import importlib

import mmengine
import torch
import torch.nn as nn
from mmengine.device import get_device

from opencompass.registry import MM_MODELS

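# Registered under the name 'otter-9b', so OpenCompass configs can build
# this class from the MM_MODELS registry via `type='otter-9b'`.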
@MM_MODELS.register_module('otter-9b')
class Otter(nn.Module):
    """Inference code of OTTER.

    Model details:
        OTTER: a multi-modal model based on OpenFlamingo
        (open-sourced version of DeepMind's Flamingo)
        https://github.com/Luodian/Otter

    Args:
        model_path (str): The path of the OTTER model in Huggingface
            model hub format.
        load_bit (str): The precision to load the model in, either
            'fp32' or 'bf16'.
        prompt_constructor (dict): Config for the prompt constructor,
            built from the MM_MODELS registry.
        post_processor (dict): Config for the post processor, built
            from the MM_MODELS registry.
        mode (str): The mode of inference. Defaults to 'generation'.
    """

    def __init__(self,
                 model_path,
                 load_bit,
                 prompt_constructor,
                 post_processor,
                 mode='generation') -> None:
        super().__init__()
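        # Anything other than 'bf16' falls back to full fp32 precision.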
        torch_dtype = torch.bfloat16 if load_bit == 'bf16' else torch.float32
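        # `otter_ai` is imported lazily so OpenCompass does not hard-depend
        # on it; it is installed separately (https://github.com/Luodian/Otter).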
        otter_ai = importlib.import_module('otter_ai')
        self.model = otter_ai.OtterForConditionalGeneration.from_pretrained(
            model_path, torch_dtype=torch_dtype, device_map=get_device())
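        # Decoder-only generation pads on the left so that newly generated
        # tokens line up at the end of every sequence in a batch.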
        self.tokenizer = self.model.text_tokenizer
        self.tokenizer.padding_side = 'left'
        self.model_dtype = next(self.model.parameters()).dtype
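        # The prompt constructor and post processor are registry items
        # themselves and are built here from their config dicts.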
        self.prompt_constructor = mmengine.registry.build_from_cfg(
            prompt_constructor, MM_MODELS)
        if post_processor is not None:
            self.post_processor = mmengine.registry.build_from_cfg(
                post_processor, MM_MODELS)
        self.mode = mode

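    # forward() dispatches on `self.mode`. Note that this file defines no
    # `loss` method, so selecting 'loss' mode would raise AttributeError.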
    def forward(self, batch):
        if self.mode == 'generation':
            return self.generate(batch)
        elif self.mode == 'loss':
            return self.loss(batch)
        else:
            raise RuntimeError(f'Invalid mode "{self.mode}".')

    def generate(self, batch):
        inputs = self.prompt_constructor(batch)
        image = inputs['image']
        prompt = inputs['prompt']
        data_samples = inputs['data_samples']
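        # OpenFlamingo-style models take vision input of shape
        # (batch, num_images, num_frames, C, H, W); the two unsqueezes add
        # the frame and batch dimensions.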
        vision_x = image.unsqueeze(1).unsqueeze(0).to(dtype=self.model_dtype)
        lang_x = self.model.text_tokenizer([prompt], return_tensors='pt')
        bad_words_id = self.model.text_tokenizer(['User:', 'GPT:']).input_ids
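        # Deterministic beam search; the dialogue role markers 'User:' and
        # 'GPT:' are banned so the model answers once rather than continuing
        # the conversation.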
        generated_text = self.model.generate(
            vision_x=vision_x.to(self.model.device),
            lang_x=lang_x['input_ids'].to(self.model.device),
            attention_mask=lang_x['attention_mask'].to(self.model.device),
            do_sample=False,
            max_new_tokens=512,
            num_beams=3,
            bad_words_ids=bad_words_id,
            no_repeat_ngram_size=3,
        )
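        # Clean up each generated sequence and attach the prediction to its
        # data sample.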
        for i, data_sample in enumerate(data_samples):
            output_text = self.post_processor(generated_text[i],
                                              self.model.text_tokenizer)
            data_sample.pred_answer = output_text
            data_samples[i] = data_sample

        return data_samples
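
# A minimal usage sketch. The registry type names below are hypothetical
# placeholders; substitute the prompt-constructor / post-processor entries
# from your OpenCompass multimodal config:
#
#     model = Otter(
#         model_path='path/to/otter-9b',  # HF-format checkpoint directory
#         load_bit='bf16',
#         prompt_constructor=dict(type='SomePromptConstructor'),
#         post_processor=dict(type='SomePostProcessor'),
#     )
#     data_samples = model.forward(batch)  # batch from an OpenCompass loader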