import types
from typing import Optional, Tuple

import mmengine
import torch
import torch.nn as nn
from mmengine.device import get_device
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation import GenerationConfig
from transformers.modeling_outputs import BaseModelOutputWithPast

from opencompass.registry import MM_MODELS

from .generation_utils import decode_tokens, make_context


@MM_MODELS.register_module('qwen-vl-base')
class QwenVLBase(nn.Module):
    """Inference code of Qwen-VL.

    We load the Qwen model via Huggingface.

    Args:
        pretrained_path (str): Path to Qwen checkpoint or repo id.
        prompt_constructor (dict): The config of prompt constructor.
        post_processor (dict): The config of post processor.
        is_caption_task (bool): Whether the task is a caption task.
            Defaults to False.
        commit_id (str): Use the given revision of Qwen-VL.
            Warning: the latest revision may have conflicts.
            Using the default revision is recommended.
    """

    def __init__(
            self,
            pretrained_path: str,
            prompt_constructor: dict = None,
            post_processor: dict = None,
            is_caption_task: bool = False,
            commit_id: str = '548275c8b99de56dec203c0e793be18e030f2f4c'
    ) -> None:
        super().__init__()
        self.tokenizer = AutoTokenizer.from_pretrained(pretrained_path,
                                                       trust_remote_code=True,
                                                       revision=commit_id)
        self.model = AutoModelForCausalLM.from_pretrained(
            pretrained_path,
            device_map=get_device(),
            trust_remote_code=True,
            revision=commit_id)
        self.model.generation_config = GenerationConfig.from_pretrained(
            pretrained_path, trust_remote_code=True, revision=commit_id)
        if prompt_constructor is not None:
            self.prompt_constructor = mmengine.registry.build_from_cfg(
                prompt_constructor, MM_MODELS)
        if post_processor is not None:
            self.post_processor = mmengine.registry.build_from_cfg(
                post_processor, MM_MODELS)
        else:
            self.post_processor = None
        self.is_caption_task = is_caption_task
        # Monkey-patch the wrapped transformer's forward with forward_hack,
        # defined at the bottom of this file.
        self.model.transformer.forward = types.MethodType(
            forward_hack, self.model.transformer)

    def _build_embeds(self, images, input_ids):
        # encode image
        images = self.model.transformer.visual(images)
        # compute image position
        bos_pos = torch.where(input_ids == self.model.transformer.config.
                              visual['image_start_id'])
        eos_pos = torch.where(
            input_ids ==
            self.model.transformer.config.visual['image_start_id'] + 1)
        assert (bos_pos[0] == eos_pos[0]).all()
        img_pos = torch.stack((bos_pos[0], bos_pos[1], eos_pos[1]), dim=1)
        # embed words
        inputs_embeds = self.model.transformer.wte(input_ids)
        # embed image tokens
        for idx, (i, a, b) in enumerate(img_pos):
            inputs_embeds[i][a + 1:b] = images[idx]
        return inputs_embeds

    def generate(self, batch):
        images = batch.pop('inputs')
        images = torch.stack(images, dim=0)
        format_input = self.prompt_constructor(batch)
        query = self.tokenizer.from_list_format(format_input)

        inputs = self.tokenizer(query, return_tensors='pt')
        inputs = inputs.to(get_device())
        input_ids, token_type_ids, attention_mask = inputs[
            'input_ids'], inputs['token_type_ids'], inputs['attention_mask']
        inputs_embeds = self._build_embeds(images, input_ids)
        pred = self.model.generate(input_ids=input_ids,
                                   inputs_embeds=inputs_embeds,
                                   attention_mask=attention_mask,
                                   token_type_ids=token_type_ids)
        response = self.post_processor(pred.cpu()[0])

        data_sample = batch['data_samples'][0]
        if self.is_caption_task:
            data_sample.pred_caption = response
        else:
            data_sample.pred_answer = response
        return data_sample

    def forward(self, batch):
        return self.generate(batch)

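
# Usage sketch: within OpenCompass this class is normally built from a config
# dict through the MM_MODELS registry rather than instantiated directly. The
# config below is illustrative only; the concrete prompt_constructor and
# post_processor types depend on the dataset being evaluated.
#
#   from opencompass.registry import MM_MODELS
#
#   qwen_vl = MM_MODELS.build(
#       dict(type='qwen-vl-base',
#            pretrained_path='Qwen/Qwen-VL',
#            prompt_constructor=dict(type=...),
#            post_processor=dict(type=...)))
#   data_sample = qwen_vl.generate(batch)  # batch: {'inputs': [img tensors],
#                                          #         'data_samples': [...]}
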
@MM_MODELS.register_module('qwen-vl-chat')
class QwenVLChat(QwenVLBase):
    """Inference code of Qwen-VL-Chat.

    We load the Qwen model via Huggingface.

    Args:
        pretrained_path (str): Path to Qwen checkpoint or repo id.
        prompt_constructor (dict): The config of prompt constructor.
        post_processor (dict): The config of post processor.
        is_caption_task (bool): Whether the task is a caption task.
            Defaults to False.
    """

    def __init__(self,
                 pretrained_path: str,
                 prompt_constructor: dict = None,
                 post_processor: dict = None,
                 is_caption_task: bool = False) -> None:
        super().__init__(pretrained_path, prompt_constructor, post_processor,
                         is_caption_task)

    def generate(self, batch):
        images = batch.pop('inputs')
        images = torch.stack(images, dim=0)
        format_input = self.prompt_constructor(batch)
        query = self.tokenizer.from_list_format(format_input)

        raw_text, context_tokens = make_context(
            self.tokenizer,
            query,
            system='You are a helpful assistant.',
            chat_format=self.model.generation_config.chat_format,
        )

        input_ids = torch.tensor([context_tokens]).to(get_device())

        inputs_embeds = self._build_embeds(images, input_ids)
        pred = self.model.generate(input_ids=input_ids,
                                   inputs_embeds=inputs_embeds)

        response = decode_tokens(
            pred[0],
            self.tokenizer,
            raw_text_len=len(raw_text),
            context_length=len(context_tokens),
            chat_format=self.model.generation_config.chat_format,
            verbose=False,
            errors='replace')

        if self.post_processor:
            response = self.post_processor(response)

        data_sample = batch['data_samples'][0]
        if self.is_caption_task:
            data_sample.pred_caption = response
        else:
            data_sample.pred_answer = response
        return data_sample

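
# forward_hack replaces the wrapped transformer's forward pass; it is bound
# onto `self.model.transformer` via types.MethodType in QwenVLBase.__init__.
# It mirrors the upstream forward, except that on the prefill step (no
# past_key_values) any image references encoded between the image-start/end
# tokens in input_ids are decoded, passed through the visual encoder, and
# spliced into the hidden states before the transformer blocks run.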
def forward_hack(self,
                 input_ids: Optional[torch.LongTensor] = None,
                 past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
                 attention_mask: Optional[torch.FloatTensor] = None,
                 token_type_ids: Optional[torch.LongTensor] = None,
                 position_ids: Optional[torch.LongTensor] = None,
                 head_mask: Optional[torch.FloatTensor] = None,
                 inputs_embeds: Optional[torch.FloatTensor] = None,
                 encoder_hidden_states: Optional[torch.Tensor] = None,
                 encoder_attention_mask: Optional[torch.FloatTensor] = None,
                 use_cache: Optional[bool] = None,
                 output_attentions: Optional[bool] = None,
                 output_hidden_states: Optional[bool] = None,
                 return_dict: Optional[bool] = None):
    if past_key_values is None and input_ids is not None and torch.any(
            input_ids == self.config.visual['image_start_id']):
        bos_pos = torch.where(
            input_ids == self.config.visual['image_start_id'])
        eos_pos = torch.where(
            input_ids == self.config.visual['image_start_id'] + 1)
        assert (bos_pos[0] == eos_pos[0]).all()
        img_pos = torch.stack((bos_pos[0], bos_pos[1], eos_pos[1]), dim=1)
        images = []
        # The token ids between each image-start/image-end pair encode the
        # image reference as raw bytes, padded with the image-pad token
        # (image_start_id + 2); decode them back to strings.
        for i, a, b in img_pos:
            image = input_ids[i][a + 1:b - 1].tolist()
            image = image[:image.index(self.config.visual['image_start_id'] +
                                       2)]
            images.append(bytes(image).decode('utf-8'))

        images = self.visual.encode(images)
        assert images.shape[0] == len(images)
    else:
        images = None

    output_attentions = (output_attentions if output_attentions is not None
                         else self.config.output_attentions)
    output_hidden_states = (output_hidden_states if output_hidden_states
                            is not None else self.config.output_hidden_states)
    use_cache = use_cache if use_cache is not None else self.config.use_cache
    return_dict = (return_dict
                   if return_dict is not None else self.config.use_return_dict)

    if input_ids is not None and inputs_embeds is not None:
        raise ValueError(
            'You cannot specify both input_ids and inputs_embeds at the same time'  # noqa
        )
    elif input_ids is not None:
        input_shape = input_ids.size()
        input_ids = input_ids.view(-1, input_shape[-1])
        batch_size = input_ids.shape[0]
    elif inputs_embeds is not None:
        input_shape = inputs_embeds.size()[:-1]
        batch_size = inputs_embeds.shape[0]
    else:
        raise ValueError(
            'You have to specify either input_ids or inputs_embeds')

    device = input_ids.device if input_ids is not None else inputs_embeds.device  # noqa

    if token_type_ids is not None:
        token_type_ids = token_type_ids.view(-1, input_shape[-1])
    if position_ids is not None:
        position_ids = position_ids.view(-1, input_shape[-1])

    if past_key_values is None:
        past_length = 0
        past_key_values = tuple([None] * len(self.h))
    else:
        past_length = past_key_values[0][0].size(-2)

    if position_ids is None:
        position_ids = torch.arange(
            past_length,
            input_shape[-1] + past_length,
            dtype=torch.long,
            device=device,
        )
        position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])

    encoder_attention_mask = None
    head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

    if inputs_embeds is None:
        inputs_embeds = self.wte(input_ids)

    if batch_size <= 0:
        raise ValueError('batch_size has to be defined and > 0')
    attention_mask = self._prepare_decoder_attention_mask(
        attention_mask, input_shape, inputs_embeds, past_length)

    hidden_states = inputs_embeds

    hidden_states = self.drop(hidden_states)
    if images is not None:
        # Splice the visual features into the placeholder positions.
        for idx, (i, a, b) in enumerate(img_pos):
            hidden_states[i][a + 1:b] = images[idx]
    output_shape = input_shape + (hidden_states.size(-1), )

    presents = () if use_cache else None
    all_self_attentions = () if output_attentions else None
    all_hidden_states = () if output_hidden_states else None
    for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states, )

        if self.gradient_checkpointing and self.training:

            def create_custom_forward(module):

                def custom_forward(*inputs):
                    # None for past_key_value
                    return module(*inputs, use_cache, output_attentions)

                return custom_forward

            outputs = torch.utils.checkpoint.checkpoint(
                create_custom_forward(block),
                hidden_states,
                None,
                attention_mask,
                head_mask[i],
                encoder_hidden_states,
                encoder_attention_mask,
            )
        else:
            outputs = block(
                hidden_states,
                layer_past=layer_past,
                attention_mask=attention_mask,
                head_mask=head_mask[i],
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=encoder_attention_mask,
                use_cache=use_cache,
                output_attentions=output_attentions,
            )

        hidden_states = outputs[0]
        if use_cache is True:
            presents = presents + (outputs[2 if output_attentions else 1], )

        if output_attentions:
            all_self_attentions = all_self_attentions + (outputs[1], )

    hidden_states = self.ln_f(hidden_states)
    hidden_states = hidden_states.view(output_shape)
    # Add last hidden state
    if output_hidden_states:
        all_hidden_states = all_hidden_states + (hidden_states, )

    if not return_dict:
        return tuple(v for v in [hidden_states, presents, all_hidden_states]
                     if v is not None)

    return BaseModelOutputWithPast(
        last_hidden_state=hidden_states,
        past_key_values=presents,
        hidden_states=all_hidden_states,
        attentions=all_self_attentions,
    )