59 lines
		
	
	
		
			1.8 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
		
		
			
		
	
	
			59 lines
		
	
	
		
			1.8 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| 
								 | 
							
								from typing import List
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								from mmpretrain.structures import DataSample
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								class MplugOwlMMBenchPromptConstructor:
							 | 
						||
| 
								 | 
							
								    """Prompt constructor for MplugOwl on MMBench.
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    Args:
							 | 
						||
| 
								 | 
							
								        image_prompt (str): Image prompt. Defaults to `''`.
							 | 
						||
| 
								 | 
							
								        reply_prompt (str): Reply prompt. Defaults to `''`.
							 | 
						||
| 
								 | 
							
								    """
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    def __init__(self, image_prompt: str = '', reply_prompt: str = '') -> None:
							 | 
						||
| 
								 | 
							
								        self.image_prompt = image_prompt
							 | 
						||
| 
								 | 
							
								        self.reply_prompt = reply_prompt
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    def __call__(self, inputs: dict) -> dict:
							 | 
						||
| 
								 | 
							
								        """Construct prompt.
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								        Args:
							 | 
						||
| 
								 | 
							
								            inputs (dict): Input data containing image and data_samples.
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								        Returns:
							 | 
						||
| 
								 | 
							
								            dict: A dict containing prompt, images and data_samples.
							 | 
						||
| 
								 | 
							
								        """
							 | 
						||
| 
								 | 
							
								        data_samples = inputs['data_samples']
							 | 
						||
| 
								 | 
							
								        prompt = self._process(data_samples)
							 | 
						||
| 
								 | 
							
								        inputs.update({'prompt': prompt})
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								        return inputs
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    def _process(self, data_samples: List[DataSample]) -> str:
							 | 
						||
| 
								 | 
							
								        """Process data sample to prompt.
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								        Args:
							 | 
						||
| 
								 | 
							
								            data_samples (List[DataSample]): A list of data_samples.
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								        Returns:
							 | 
						||
| 
								 | 
							
								            str: Prompt.
							 | 
						||
| 
								 | 
							
								        """
							 | 
						||
| 
								 | 
							
								        question = [
							 | 
						||
| 
								 | 
							
								            data_sample.get('question') for data_sample in data_samples
							 | 
						||
| 
								 | 
							
								        ]
							 | 
						||
| 
								 | 
							
								        options = [data_sample.get('options') for data_sample in data_samples]
							 | 
						||
| 
								 | 
							
								        if data_samples[0].get('context') is not None:
							 | 
						||
| 
								 | 
							
								            context = [
							 | 
						||
| 
								 | 
							
								                data_sample.get('context') for data_sample in data_samples
							 | 
						||
| 
								 | 
							
								            ]
							 | 
						||
| 
								 | 
							
								        else:
							 | 
						||
| 
								 | 
							
								            context = [''] * len(data_samples)
							 | 
						||
| 
								 | 
							
								        prompts = []
							 | 
						||
| 
								 | 
							
								        for cur_context, cur_question, cur_options in zip(
							 | 
						||
| 
								 | 
							
								                context, question, options):
							 | 
						||
| 
								 | 
							
								            prompts.append(cur_context + ' ' + cur_question + ' ' +
							 | 
						||
| 
								 | 
							
								                           cur_options)  # noqa
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								        return prompts
							 |