import os
import sys

import mmengine
import torch
import torch.nn as nn
from mmengine.device import get_device
from transformers import StoppingCriteriaList

from opencompass.registry import MM_MODELS

from .utils import StoppingCriteriaSub


class LayerNorm(nn.LayerNorm):
    """Subclass torch's LayerNorm to handle fp16."""

    def forward(self, x: torch.Tensor):
        orig_type = x.dtype
        ret = super().forward(x.type(torch.float32))
        return ret.type(orig_type)

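
# MiniGPT-4 is not installed as a regular package; load_package() below
# temporarily puts a local 'MiniGPT-4' checkout (expected to sit next to this
# file) on sys.path, imports the upstream MiniGPT4 model class, and then
# removes the path entry again.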

def load_package():
    """Load required packages from MiniGPT-4."""
    current_file_path = os.path.abspath(__file__)
    current_folder_path = os.path.dirname(current_file_path)

    sys.path.append(os.path.join(current_folder_path, 'MiniGPT-4'))  # noqa
    from minigpt4.models.mini_gpt4 import MiniGPT4

    sys.path.pop(-1)

    return MiniGPT4


MiniGPT4 = load_package()

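
# The inferencer below subclasses the upstream MiniGPT4 model and reuses its
# components directly (visual_encoder, ln_vision, Qformer, query_tokens,
# llama_proj, llama_model, llama_tokenizer, prompt_wrap, maybe_autocast).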

@MM_MODELS.register_module('minigpt-4')
class MiniGPT4Inferencer(MiniGPT4):
    """Inference code of MiniGPT-4.

    Args:
        llama_model (str): The path of the vicuna model.
        prompt_constructor (dict): The config of the prompt constructor.
        post_processor (dict): The config of the post processor.
        do_sample (bool): Whether to use sampling. Defaults to False.
        max_length (int): The max length of the output. Defaults to 30.
        img_size (int): The size of the input image. Defaults to 224.
        low_resource (bool): Whether to load the model in low precision.
            Defaults to False.
        is_caption_task (bool): Whether the task is a caption task.
            Defaults to False.
        mode (str): Inference mode, either 'generation' or 'loss'.
            Defaults to 'generation'.
        n_segments (int): Number of segments to split the answer candidates
            into when computing per-choice losses. Defaults to 1.
    """

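    # Illustrative sketch (not part of the upstream file): instances are
    # normally built from an OpenCompass config dict via the MM_MODELS
    # registry, roughly of the shape
    #
    #   minigpt_4_model = dict(
    #       type='minigpt-4',
    #       llama_model='/path/to/vicuna-7b',   # hypothetical path
    #       prompt_constructor=dict(type=...),  # task-specific, registered in MM_MODELS
    #       post_processor=dict(type=...),      # task-specific, registered in MM_MODELS
    #       max_length=30)
    #
    # The exact prompt-constructor/post-processor type names depend on the
    # task config and are omitted here.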
    def __init__(self,
                 llama_model: str,
                 prompt_constructor: dict,
                 post_processor: dict,
                 do_sample: bool = False,
                 max_length: int = 30,
                 img_size: int = 224,
                 low_resource: bool = False,
                 is_caption_task: bool = False,
                 mode: str = 'generation',
                 n_segments: int = 1) -> None:
        super().__init__(llama_model=llama_model,
                         low_resource=low_resource,
                         img_size=img_size)
        self.mode = mode
        self.n_segments = n_segments

        # Stop generation at the '###' turn separator used by Vicuna-style
        # prompts; the tokenizer can encode it in two different ways.
        cur_device = get_device()
        stop_words_ids = [
            torch.tensor([835]).to(cur_device),
            torch.tensor([2277, 29937]).to(cur_device),
        ]
        self.stopping_criteria = StoppingCriteriaList(
            [StoppingCriteriaSub(stops=stop_words_ids)])

        self.prompt_constructor = mmengine.registry.build_from_cfg(
            prompt_constructor, MM_MODELS)
        if post_processor is not None:
            self.post_processor = mmengine.registry.build_from_cfg(
                post_processor, MM_MODELS)
        self.do_sample = do_sample
        self.max_length = max_length
        self.is_caption_task = is_caption_task

    def forward(self, batch):
        if self.mode == 'generation':
            return self.generate(batch)
        elif self.mode == 'loss':
            return self.loss(batch)
        else:
            raise RuntimeError(f'Invalid mode "{self.mode}".')

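    # encode_img() accepts either an image batch of shape (B, C, H, W) or a
    # video batch of shape (B, C, T, H, W), as indexed below. Video frames
    # are encoded one by one and their projected Q-Former tokens are
    # concatenated along the sequence dimension.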
    def encode_img(self, image):
        device = image.device

        with self.maybe_autocast():
            if image.dim() == 5:
                inputs_llama, atts_llama = [], []
                for j in range(image.size(2)):
                    this_frame = image[:, :, j, :, :]
                    frame_embeds = self.ln_vision(
                        self.visual_encoder(this_frame))
                    frame_atts = torch.ones(frame_embeds.size()[:-1],
                                            dtype=torch.long).to(image.device)

                    query_tokens = self.query_tokens.expand(
                        frame_embeds.shape[0], -1, -1)
                    frame_query_output = self.Qformer.bert(
                        query_embeds=query_tokens,
                        encoder_hidden_states=frame_embeds,
                        encoder_attention_mask=frame_atts,
                        return_dict=True,
                    )

                    frame_inputs_llama = self.llama_proj(
                        frame_query_output.last_hidden_state[:, :query_tokens.
                                                             size(1), :])
                    frame_atts_llama = torch.ones(
                        frame_inputs_llama.size()[:-1],
                        dtype=torch.long).to(image.device)
                    inputs_llama.append(frame_inputs_llama)
                    atts_llama.append(frame_atts_llama)
                inputs_llama = torch.cat(inputs_llama, dim=1)
                atts_llama = torch.cat(atts_llama, dim=1)
            else:
                image_embeds = self.ln_vision(
                    self.visual_encoder(image)).to(device)
                image_atts = torch.ones(image_embeds.size()[:-1],
                                        dtype=torch.long).to(device)

                query_tokens = self.query_tokens.expand(
                    image_embeds.shape[0], -1, -1)
                query_output = self.Qformer.bert(
                    query_embeds=query_tokens,
                    encoder_hidden_states=image_embeds,
                    encoder_attention_mask=image_atts,
                    return_dict=True,
                )

                inputs_llama = self.llama_proj(query_output.last_hidden_state)
                atts_llama = torch.ones(inputs_llama.size()[:-1],
                                        dtype=torch.long).to(image.device)
        return inputs_llama, atts_llama

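    # pack_inputs() expects an mmengine-style batch, i.e. a dict with
    # 'inputs' (a list of per-sample image tensors) and 'data_samples';
    # the images are stacked into a single batch tensor on the current
    # device.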
    def pack_inputs(self, batch):
        images = [image.unsqueeze(0) for image in batch['inputs']]
        data_samples = [data_sample for data_sample in batch['data_samples']]
        images = torch.cat(images, dim=0).to(get_device())
        inputs = {'image': images, 'data_samples': data_samples}
        return inputs

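    # generate() builds the LLaMA input by splitting the prompt at the single
    # '<ImageHere>' placeholder and splicing the projected image embeddings
    # between the two text segments, then decodes with beam search and the
    # '###' stopping criteria configured in __init__.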
    def generate(self, batch):
        inputs = self.pack_inputs(batch)
        inputs = self.prompt_constructor(inputs)
        image = inputs['image']
        prompt = inputs['prompt']
        data_samples = inputs['data_samples']

        # The main process of generation
        img_embeds, _ = self.encode_img(image)
        prompt_segs = prompt.split('<ImageHere>')
        prompt_seg_tokens = [
            self.llama_tokenizer(seg,
                                 return_tensors='pt',
                                 add_special_tokens=i == 0).
            to(self.llama_model.model.embed_tokens.weight.device).input_ids
            for i, seg in enumerate(prompt_segs)
        ]
        prompt_seg_embs = [
            self.llama_model.model.embed_tokens(seg)
            for seg in prompt_seg_tokens
        ]
        prompt_seg_embs = [prompt_seg_embs[0], img_embeds, prompt_seg_embs[1]]
        prompt_embs = torch.cat(prompt_seg_embs, dim=1)

        # generate output
        outputs = self.llama_model.generate(
            inputs_embeds=prompt_embs,
            max_length=self.max_length,
            num_beams=5,
            do_sample=self.do_sample,
            min_length=1,
            top_p=0.9,
            repetition_penalty=1.0,
            length_penalty=-1.0,
            temperature=1.0,
            stopping_criteria=self.stopping_criteria,
            num_return_sequences=1)

        for i, data_sample in enumerate(data_samples):
            output_token = outputs[i]
            output_text = self.post_processor(output_token,
                                              self.llama_tokenizer)
            if self.is_caption_task:
                data_sample.pred_caption = output_text
            else:
                data_sample.pred_answer = output_text
            data_samples[i] = data_sample
        return data_samples

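    # loss() scores multiple-choice answers: each candidate in
    # data_samples[0].choices is appended after the wrapped image/prompt
    # embeddings and its per-token language-modeling loss is summed, so a
    # lower loss indicates a more likely choice. n_segments controls how the
    # candidates are split into chunks to bound memory usage.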
    def loss(self, batch):
        inputs = self.pack_inputs(batch)
        inputs = self.prompt_constructor(inputs)
        image = inputs['image']
        batch_size = image.size(0)
        prompt = inputs['prompt']
        data_samples = inputs['data_samples']
        choices = data_samples[0].choices

        with torch.no_grad():
            img_embeds, atts_img = self.encode_img(image)
            img_embeds, atts_img = self.prompt_wrap(img_embeds, atts_img,
                                                    prompt)

            self.llama_tokenizer.padding_side = 'right'

            n_cands = len(choices)
            losses = []
            for n in range(self.n_segments):
                seg_len = n_cands // self.n_segments
                if n == (self.n_segments - 1):
                    seg_len = n_cands - seg_len * (self.n_segments - 1)

                to_regress_tokens = self.llama_tokenizer(
                    choices,
                    return_tensors='pt',
                    padding='longest',
                    truncation=True,
                    max_length=self.max_txt_len,
                    add_special_tokens=False).to(image.device)

                targets = to_regress_tokens.input_ids.masked_fill(
                    to_regress_tokens.input_ids ==
                    self.llama_tokenizer.pad_token_id, -100)

                empty_targets = (
                    torch.ones([atts_img.shape[0], atts_img.shape[1] + 1],
                               dtype=torch.long).to(image.device).fill_(
                                   -100)  # plus one for bos
                )
                empty_targets = empty_targets.repeat_interleave(seg_len, dim=0)
                targets = torch.cat([empty_targets, targets], dim=1)

                bos = torch.ones([batch_size, 1],
                                 dtype=to_regress_tokens.input_ids.dtype,
                                 device=to_regress_tokens.input_ids.device
                                 ) * self.llama_tokenizer.bos_token_id
                bos_embeds = self.llama_model.model.embed_tokens(bos)
                bos_embeds = bos_embeds.repeat_interleave(seg_len, dim=0)
                img_embeds = img_embeds.repeat_interleave(seg_len, dim=0)

                atts_bos = atts_img[:, :1]
                atts_bos = atts_bos.repeat_interleave(seg_len, dim=0)
                atts_img = atts_img.repeat_interleave(seg_len, dim=0)

                to_regress_embeds = self.llama_model.model.embed_tokens(
                    to_regress_tokens.input_ids)

                inputs_embeds = torch.cat(
                    [bos_embeds, img_embeds, to_regress_embeds], dim=1)
                attention_mask = torch.cat(
                    [atts_bos, atts_img, to_regress_tokens.attention_mask],
                    dim=1)

                with self.maybe_autocast():
                    outputs = self.llama_model(
                        inputs_embeds=inputs_embeds,
                        attention_mask=attention_mask,
                        return_dict=True,
                        labels=targets,
                        reduction='none',
                    )
                loss = outputs.loss
                loss = loss.view(targets.size(0), -1).sum(1)
                loss = loss.reshape(batch_size, seg_len)
                losses.append(loss)
            # Per-choice losses across all segments; indexing [0] assumes
            # a batch size of 1.
            losses = torch.cat(losses, dim=-1)[0]

        for i, data_sample in enumerate(data_samples):
            data_sample.losses = losses
            data_samples[i] = data_sample
        return data_samples
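
# Minimal usage sketch (illustrative, not part of the upstream file). It
# assumes `cfg` is an OpenCompass model config like the one sketched above
# and that `dataloader` yields batches matching pack_inputs():
#
#   from opencompass.registry import MM_MODELS
#
#   model = MM_MODELS.build(cfg).eval()
#   with torch.no_grad():
#       for batch in dataloader:
#           data_samples = model.forward(batch)  # mode='generation'
#           print(data_samples[0].pred_answer)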