import importlib
import os
import sys

import mmengine
import torch
import torch.nn as nn
from mmengine.device import get_device
from transformers import StoppingCriteria

from opencompass.registry import MM_MODELS

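# LLaVA's sentinel index marking where image features are spliced into the
# token sequence (matches llava.constants.IMAGE_TOKEN_INDEX).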
IMAGE_TOKEN_INDEX = -200


def load_package():
    """Load required packages from LLaVA."""
    current_file_path = os.path.abspath(__file__)
    current_folder_path = os.path.dirname(current_file_path)

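    # Make the cloned LLaVA repo importable by appending it to sys.path.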
    sys.path.append(os.path.join(current_folder_path, 'LLaVA'))  # noqa


class KeywordsStoppingCriteria(StoppingCriteria):
    """Keyword stopping criteria implemented for llava."""

    def __init__(self, keywords, tokenizer, input_ids):
        self.keywords = keywords
        self.tokenizer = tokenizer
        self.start_len = None
        self.input_ids = input_ids

    def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor,
                 **kwargs) -> bool:
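        # The first call sees only the prompt; record its length so that
        # later calls decode just the newly generated tokens.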
        if self.start_len is None:
            self.start_len = self.input_ids.shape[1]
        else:
            outputs = self.tokenizer.batch_decode(output_ids[:,
                                                             self.start_len:],
                                                  skip_special_tokens=True)[0]
            for keyword in self.keywords:
                if keyword in outputs:
                    return True
        return False


@MM_MODELS.register_module('llava')
class LLaVA(nn.Module):
    """Inference code of LLaVA. The official LLaVA repo needs to be cloned
    first; please check out the README in the config.

    Args:
        model_path (str): The path of the LLaVA checkpoint.
        prompt_constructor (dict): The config of the prompt constructor.
        post_processor (dict): The config of the post processor.
        is_caption_task (bool): Whether the task is a caption task.
            Defaults to False.
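
    Example:
        >>> # Minimal sketch; the constructor/post-processor type names
        >>> # below are assumptions for illustration.
        >>> model = LLaVA(
        ...     model_path='/path/to/llava-v1.5-7b',
        ...     prompt_constructor=dict(type='LLaVABasePromptConstructor'),
        ...     post_processor=dict(type='LLaVABasePostProcessor'),
        ... )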
    """

    def __init__(
        self,
        model_path: str,
        prompt_constructor: dict,
        post_processor: dict,
        is_caption_task: bool = False,
    ) -> None:
        super().__init__()
        self.dtype = torch.float16
        self.is_caption_task = is_caption_task

        # load LLaVA modules
        load_package()
        mm_utils = importlib.import_module('llava.mm_utils')
        builder = importlib.import_module('llava.model.builder')

        # load pretrained LLaVA
        # Note: when encountering device-related errors, try setting
        # `low_cpu_mem_usage` in `load_pretrained_model` to False
        model_name = mm_utils.get_model_name_from_path(model_path)
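        # load_pretrained_model returns (tokenizer, model, image_processor,
        # context_len); the image processor and context length are unused here.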
        tokenizer, model, _, _ = builder.load_pretrained_model(
            model_path, None, model_name)
        vision_tower = model.get_vision_tower()
        vision_tower.to(device=get_device(), dtype=self.dtype)
        model.to(device=get_device(), dtype=self.dtype)

        # load prompt constructor and post processor
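        # Pick the conversation template that matches the checkpoint family;
        # LLaVA keys its chat format off the model path.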
        if 'v1' in model_path.lower():
            conv_mode = 'llava_v1'
        elif 'mpt' in model_path.lower():
            conv_mode = 'mpt_multimodal'
        else:
            conv_mode = 'multimodal'
        mm_use_im_start_end = getattr(model.config, 'mm_use_im_start_end',
                                      False)
        prompt_constructor.update({
            'conv_mode': conv_mode,
            'mm_use_im_start_end': mm_use_im_start_end
        })
        self.prompt_constructor = mmengine.registry.build_from_cfg(
            prompt_constructor, MM_MODELS)
        self.post_processor = mmengine.registry.build_from_cfg(
            post_processor, MM_MODELS)
        self.model = model
        self.tokenizer = tokenizer

    def generate(self, batch):
        prompt, stop_str = self.prompt_constructor(batch)
        keywords = [stop_str]
        data_sample = batch['data_samples'][0]

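        # Guard against text-only samples before adding the batch dimension.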
        image = batch['inputs'][0]
        if image is not None:
            images = image.unsqueeze(0).to(get_device())
        else:
            images = None

        mm_utils = importlib.import_module('llava.mm_utils')
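        # tokenizer_image_token splices IMAGE_TOKEN_INDEX placeholders into
        # the token ids wherever the prompt contains the `<image>` marker.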
        input_ids = mm_utils.tokenizer_image_token(
            prompt, self.tokenizer, IMAGE_TOKEN_INDEX,
            return_tensors='pt').unsqueeze(0).to(get_device())

        stopping_criteria = KeywordsStoppingCriteria(keywords, self.tokenizer,
                                                     input_ids)

        with torch.inference_mode():
            output_ids = self.model.generate(
                input_ids,
                images=images.half() if images is not None else None,
                do_sample=True,
                temperature=0.2,
                max_new_tokens=1024,
                stopping_criteria=[stopping_criteria],
            )

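        # HF generate returns the prompt followed by new tokens: strip the
        # first input_token_len ids and sanity-check the echoed prompt.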
        input_token_len = input_ids.shape[1]
        n_diff_input_output = (input_ids !=
                               output_ids[:, :input_token_len]).sum().item()
        if n_diff_input_output > 0:
            print(
                f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids'  # noqa
            )
        outputs = self.tokenizer.batch_decode(output_ids[:, input_token_len:],
                                              skip_special_tokens=True)[0]

        output_text = self.post_processor(outputs, stop_str)

        if self.is_caption_task:
            data_sample.pred_caption = output_text
        else:
            data_sample.pred_answer = output_text
        return data_sample

    def forward(self, batch):
        return self.generate(batch)
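
# Registry usage sketch (illustrative; the exact config keys are assumptions,
# see the multimodal configs in this repo for real examples):
#
#     from opencompass.registry import MM_MODELS
#     model = MM_MODELS.build(dict(type='llava', model_path=..., ...))
#
# `type='llava'` resolves to this class through the register_module call
# above; a dataloader batch then flows through `forward`/`generate`.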