57 lines
		
	
	
		
			1.5 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
		
		
			
		
	
	
			57 lines
		
	
	
		
			1.5 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| 
								 | 
							
								import os
							 | 
						||
| 
								 | 
							
								import re
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								import timm.models.hub as timm_hub
							 | 
						||
| 
								 | 
							
								import torch
							 | 
						||
| 
								 | 
							
								import torch.distributed as dist
							 | 
						||
| 
								 | 
							
								from mmengine.dist import is_distributed, is_main_process
							 | 
						||
| 
								 | 
							
								from transformers import StoppingCriteria
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								class StoppingCriteriaSub(StoppingCriteria):
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    def __init__(self, stops=[], encounters=1):
							 | 
						||
| 
								 | 
							
								        super().__init__()
							 | 
						||
| 
								 | 
							
								        self.stops = stops
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
							 | 
						||
| 
								 | 
							
								        for stop in self.stops:
							 | 
						||
| 
								 | 
							
								            if torch.all((stop == input_ids[0][-len(stop):])).item():
							 | 
						||
| 
								 | 
							
								                return True
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								        return False
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								def download_cached_file(url, check_hash=True, progress=False):
							 | 
						||
| 
								 | 
							
								    """Download a file from a URL and cache it locally.
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    If the file already exists, it is not downloaded again. If distributed,
							 | 
						||
| 
								 | 
							
								    only the main process downloads the file, and the other processes wait for
							 | 
						||
| 
								 | 
							
								    the file to be downloaded.
							 | 
						||
| 
								 | 
							
								    """
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    def get_cached_file_path():
							 | 
						||
| 
								 | 
							
								        # a hack to sync the file path across processes
							 | 
						||
| 
								 | 
							
								        parts = torch.hub.urlparse(url)
							 | 
						||
| 
								 | 
							
								        filename = os.path.basename(parts.path)
							 | 
						||
| 
								 | 
							
								        cached_file = os.path.join(timm_hub.get_cache_dir(), filename)
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								        return cached_file
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    if is_main_process():
							 | 
						||
| 
								 | 
							
								        timm_hub.download_cached_file(url, check_hash, progress)
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    if is_distributed():
							 | 
						||
| 
								 | 
							
								        dist.barrier()
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    return get_cached_file_path()
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								def is_url(input_url):
							 | 
						||
| 
								 | 
							
								    """Check if an input string is a url.
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    look for http(s):// and ignoring the case
							 | 
						||
| 
								 | 
							
								    """
							 | 
						||
| 
								 | 
							
								    is_url = re.match(r'^(?:http)s?://', input_url, re.IGNORECASE) is not None
							 | 
						||
| 
								 | 
							
								    return is_url
							 |