# dataset-opencompass/opencompass/openicl/icl_inferencer/icl_attack_inferencer.py

"""Direct Generation Inferencer."""
import os
import os.path as osp
from typing import List, Optional
import mmengine
import torch
from tqdm import tqdm
from opencompass.models.base import BaseModel
from opencompass.registry import (ICL_EVALUATORS, ICL_INFERENCERS,
TEXT_POSTPROCESSORS)
from ..icl_prompt_template import PromptTemplate
from ..icl_retriever import BaseRetriever
from ..utils.logging import get_logger
from .icl_base_inferencer import BaseInferencer, GenInferencerOutputHandler
logger = get_logger(__name__)
@ICL_INFERENCERS.register_module()
class AttackInferencer(BaseInferencer):
"""Generation Inferencer class to directly evaluate by generation.
Attributes:
model (:obj:`BaseModelWrapper`, optional): The module to inference.
        max_out_len (:obj:`int`, optional): Maximum number of tokens in the
            generated output.
adv_key (:obj:`str`): Prompt key in template to be attacked.
metric_key (:obj:`str`): Metric key to be returned and compared.
Defaults to `accuracy`.
        max_seq_len (:obj:`int`, optional): Maximum number of prompt tokens
            allowed by the LM.
batch_size (:obj:`int`, optional): Batch size for the
:obj:`DataLoader`.
output_json_filepath (:obj:`str`, optional): File path for output
`JSON` file.
output_json_filename (:obj:`str`, optional): File name for output
`JSON` file.
gen_field_replace_token (:obj:`str`, optional): Used to replace the
generation field token when generating prompts.
        save_every (:obj:`int`, optional): Save intermediate results every
            `save_every` samples.
generation_kwargs (:obj:`Dict`, optional): Parameters for the
:obj:`model.generate()` method.
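        fix_id_list (:obj:`List[int]`, optional): Fixed in-context example
            ids to retrieve when a fix-id retriever is used.
        dataset_cfg (:obj:`dict`, optional): Dataset config; only
            `reader_cfg['output_column']` and `eval_cfg` (evaluator plus an
            optional `pred_postprocessor`) are read here, see the sketch
            below.

    Example:
        A minimal sketch of the `dataset_cfg` fields this class actually
        reads; the evaluator and postprocessor names are illustrative, not
        mandated by this file::

            dataset_cfg = dict(
                reader_cfg=dict(output_column='answer'),
                eval_cfg=dict(
                    evaluator=dict(type='AccEvaluator'),
                    # optional; applied to predictions before scoring
                    pred_postprocessor=dict(type='first-capital'),
                ),
            )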
"""
def __init__(
self,
model: BaseModel,
max_out_len: int,
adv_key: str,
metric_key: str = 'accuracy',
max_seq_len: Optional[int] = None,
batch_size: Optional[int] = 1,
gen_field_replace_token: Optional[str] = '',
output_json_filepath: Optional[str] = './icl_inference_output',
output_json_filename: Optional[str] = 'predictions',
save_every: Optional[int] = None,
fix_id_list: Optional[List[int]] = None,
            dataset_cfg: Optional[dict] = None,
**kwargs) -> None:
super().__init__(
model=model,
max_seq_len=max_seq_len,
batch_size=batch_size,
output_json_filename=output_json_filename,
output_json_filepath=output_json_filepath,
**kwargs,
)
self.adv_key = adv_key
self.metric_key = metric_key
self.dataset_cfg = dataset_cfg
self.eval_cfg = dataset_cfg['eval_cfg']
self.output_column = dataset_cfg['reader_cfg']['output_column']
self.gen_field_replace_token = gen_field_replace_token
self.max_out_len = max_out_len
self.fix_id_list = fix_id_list
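        # API-backed models are presumably slower and more failure prone, so
        # checkpoint after every sample by default; predict() can then resume
        # from the tmp JSON file.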
if self.model.is_api and save_every is None:
save_every = 1
self.save_every = save_every
    def predict(self, adv_prompt) -> float:
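        """Generate predictions with `adv_prompt` substituted into the
        `adv_key` field of the prompt template and score them.

        Note: `self.retriever`, `self.ice_template` and `self.prompt_template`
        are not set in `__init__`; they are presumably attached by the calling
        attack task before `predict` is invoked.

        Returns:
            float: The `metric_key` score, rescaled to [0, 1] when the
            evaluator reports it as a percentage.
        """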
# 1. Preparation for output logs
output_handler = GenInferencerOutputHandler()
        # predict() always writes to the output paths configured at init time
        output_json_filepath = self.output_json_filepath
        output_json_filename = self.output_json_filename
# 2. Get results of retrieval process
if 'Fix' in self.retriever.__class__.__name__:
ice_idx_list = self.retriever.retrieve(self.fix_id_list)
else:
ice_idx_list = self.retriever.retrieve()
# 3. Generate prompts for testing input
prompt_list, label_list = self.get_generation_prompt_list_from_retriever_indices( # noqa
ice_idx_list, {self.adv_key: adv_prompt},
self.retriever,
self.gen_field_replace_token,
max_seq_len=self.max_seq_len,
ice_template=self.ice_template,
prompt_template=self.prompt_template)
# Create tmp json file for saving intermediate results and future
# resuming
index = 0
tmp_json_filepath = os.path.join(output_json_filepath,
'tmp_' + output_json_filename)
if osp.exists(tmp_json_filepath):
# TODO: move resume to output handler
tmp_result_dict = mmengine.load(tmp_json_filepath)
output_handler.results_dict = tmp_result_dict
index = len(tmp_result_dict)
# 4. Wrap prompts with Dataloader
dataloader = self.get_dataloader(prompt_list[index:], self.batch_size)
# 5. Inference for prompts in each batch
logger.info('Starting inference process...')
for entry in tqdm(dataloader, disable=not self.is_main_process):
# 5-1. Inference with local model
with torch.no_grad():
parsed_entries = self.model.parse_template(entry, mode='gen')
results = self.model.generate_from_template(
entry, max_out_len=self.max_out_len)
generated = results
            # 5-2. Save current output
for prompt, prediction in zip(parsed_entries, generated):
output_handler.save_results(prompt, prediction, index)
index = index + 1
            # 5-3. Save intermediate results
if (self.save_every is not None and index % self.save_every == 0
and self.is_main_process):
output_handler.write_to_json(output_json_filepath,
'tmp_' + output_json_filename)
# 6. Output
if self.is_main_process:
os.makedirs(output_json_filepath, exist_ok=True)
output_handler.write_to_json(output_json_filepath,
output_json_filename)
if osp.exists(tmp_json_filepath):
os.remove(tmp_json_filepath)
pred_strs = [
sample['prediction']
for sample in output_handler.results_dict.values()
]
if 'pred_postprocessor' in self.eval_cfg:
kwargs = self.eval_cfg['pred_postprocessor'].copy()
proc = TEXT_POSTPROCESSORS.get(kwargs.pop('type'))
pred_strs = [proc(s, **kwargs) for s in pred_strs]
icl_evaluator = ICL_EVALUATORS.build(self.eval_cfg['evaluator'])
result = icl_evaluator.score(predictions=pred_strs,
references=label_list)
score = result.get(self.metric_key)
        # normalize percentage-style metrics (e.g. accuracy in 0-100) to 0-1
return score / 100 if score > 1 else score
def get_generation_prompt_list_from_retriever_indices(
self,
ice_idx_list: List[List[int]],
extra_prompt: dict,
retriever: BaseRetriever,
gen_field_replace_token: str,
max_seq_len: Optional[int] = None,
ice_template: Optional[PromptTemplate] = None,
prompt_template: Optional[PromptTemplate] = None):
prompt_list = []
label_list = []
for idx, ice_idx in enumerate(ice_idx_list):
ice = retriever.generate_ice(ice_idx, ice_template=ice_template)
prompt = retriever.generate_prompt_for_adv_generate_task(
idx,
ice,
extra_prompt,
gen_field_replace_token=gen_field_replace_token,
ice_template=ice_template,
prompt_template=prompt_template)
label = retriever.test_ds[idx][self.output_column]
label_list.append(label)
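            # If the assembled prompt is too long for the model, drop
            # in-context examples from the end one at a time until it fits.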
if max_seq_len is not None:
prompt_token_num = self.model.get_token_len_from_template(
prompt, mode='gen')
while len(ice_idx) > 0 and prompt_token_num > max_seq_len:
ice_idx = ice_idx[:-1]
ice = retriever.generate_ice(ice_idx,
ice_template=ice_template)
prompt = retriever.generate_prompt_for_adv_generate_task(
idx,
ice,
extra_prompt,
gen_field_replace_token=gen_field_replace_token,
ice_template=ice_template,
prompt_template=prompt_template)
prompt_token_num = self.model.get_token_len_from_template(
prompt, mode='gen')
prompt_list.append(prompt)
return prompt_list, label_list
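

# --- Usage sketch (hypothetical, not part of the upstream file) -------------
# A prompt attack typically builds one AttackInferencer per dataset, attaches
# the retriever / ice_template / prompt_template, and then calls ``predict``
# once per candidate adversarial prompt, keeping the candidate that drives the
# metric lowest. ``run_attack`` below is an illustrative helper, not an
# OpenCompass API.
def run_attack(inferencer: AttackInferencer, candidates: List[str]) -> str:
    """Return the candidate adversarial prompt with the lowest score."""
    scores = {cand: inferencer.predict(cand) for cand in candidates}
    return min(scores, key=scores.get)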