diff --git a/README.md b/README.md
index cb5af56..f387968 100644
--- a/README.md
+++ b/README.md
@@ -144,15 +144,23 @@ Compared to `jina-reranker-v2-base-multilingual`, `jina-reranker-m0` significant
 pip install transformers >= 4.47.3
 ```

+If you run the model on a GPU that supports FlashAttention-2 (as of 2024.9.12, this includes Ampere, Ada, and Hopper GPUs, e.g., A100, RTX 3090, RTX 4090, H100), you can also install `flash-attn`:
+
+```bash
+pip install flash-attn --no-build-isolation
+```
+
 And then use the following code snippet to load the model:

 ```python
 from transformers import AutoModel

+# comment out the flash_attention_2 line if you don't have a compatible GPU
 model = AutoModel.from_pretrained(
     'jinaai/jina-reranker-m0',
     torch_dtype="auto",
     trust_remote_code=True,
+    attn_implementation="flash_attention_2"
 )

 model.to('cuda') # or 'cpu' if no GPU is available
@@ -178,7 +186,7 @@ Compared to `jina-reranker-v2-base-multilingual`, `jina-reranker-m0` significant
 image_pairs = [[query, doc] for doc in documents]
 scores = model.compute_score(image_pairs, max_length=2048, doc_type="image")

-# [0.8576154708862305, 0.9356858730316162, 0.8496521711349487, 0.8664582967758179]
+# [0.49375027418136597, 0.7889736890792847, 0.47813892364501953, 0.5210812091827393]
 ```

 **B. Textual Documents Reranking**
@@ -201,7 +209,7 @@ Compared to `jina-reranker-v2-base-multilingual`, `jina-reranker-m0` significant
 The scores will be a list of floats, where each float represents the relevance score of the corresponding document to the query. Higher scores indicate higher relevance. For instance the returning scores in this case will be:

 ```bash
-[0.9127850532531738, 0.8384682536125183, 0.8870794177055359, 0.842738926410675]
+[0.6839263439178467, 0.4432148039340973, 0.5904013514518738, 0.45481112599372864]
 ```

 **C. Image Querying for Textual Documents**
@@ -218,10 +226,29 @@ Compared to `jina-reranker-v2-base-multilingual`, `jina-reranker-m0` significant
     "Die wichtigsten Beiträge unserer Arbeit sind zweifach: Erstens führen wir eine neuartige dreistufige Datensynthese-Pipeline namens Draft-Refine-Critique ein, die durch iterative Verfeinerung hochwertige Trainingsdaten generiert; und zweitens schlagen wir eine umfassende Trainingsstrategie vor, die kontinuierliches Vortraining zur Längenerweiterung, überwachtes Feintuning mit spezialisierten Kontrollpunkten, direkte Präferenzoptimierung (DPO) und iteratives Self-Play-Tuning kombiniert. Um die weitere Forschung und Anwendung der strukturierten Inhaltsextraktion zu erleichtern, ist das Modell auf Hugging Face öffentlich verfügbar.",
 ]
-# reverse the order of the query and document
-image_pairs = [[doc, query] for doc in documents]
-scores = model.compute_score(image_pairs, max_length=2048, doc_type="text")
+# the query is an image, so keep the [query, doc] order and pass query_type="image"
+image_pairs = [[query, doc] for doc in documents]
+scores = model.compute_score(image_pairs, max_length=2048, query_type="image", doc_type="text")

-# [0.9048659801483154, 0.8266222476959229, 0.8326289653778076, 0.9075747132301331]
+# [0.98099285364151, 0.7701883316040039, 0.5637142062187195, 0.9308615922927856]
+```
+
+**D. Image Querying for Image Documents**
+
+The model also supports reranking image documents against an image query. You can use the following code snippet:
+
+```python
+query = "https://raw.githubusercontent.com/jina-ai/multimodal-reranker-test/main/paper-11.png"
+
+documents = [
+    "https://raw.githubusercontent.com/jina-ai/multimodal-reranker-test/main/handelsblatt-preview.png",
+    "https://raw.githubusercontent.com/jina-ai/multimodal-reranker-test/main/paper-11.png",
+    "https://raw.githubusercontent.com/jina-ai/multimodal-reranker-test/main/wired-preview.png",
+    "https://jina.ai/blog-banner/using-deepseek-r1-reasoning-model-in-deepsearch.webp"
+]
+
+image_pairs = [[query, doc] for doc in documents]
+scores = model.compute_score(image_pairs, max_length=2048, doc_type="image", query_type="image")
+# [0.6275860667228699, 0.9922324419021606, 0.8090347051620483, 0.7941296100616455]
 ```

 # Model Performance
diff --git a/modeling.py b/modeling.py
index 0214f34..8dd72d6 100644
--- a/modeling.py
+++ b/modeling.py
@@ -10,7 +10,7 @@ from transformers.image_utils import load_image

 logger = logging.getLogger(__name__)

-LOGIT_SCALE = 0.68
+LOGIT_BIAS = 2.65  # logit bias for sigmoid normalization

 def load_images(images, lazy_load: bool = True):
     # Disable PIL DecompositionBomb threshold for reading large images.
@@ -123,7 +123,7 @@ class JinaVLForRanking(Qwen2VLForConditionalGeneration):
         pairs: Union[List[Tuple[str, str]], Tuple[str, str]],
         batch_size: int = 8,
         max_length: int = 10240,
-        max_query_length: int = 1024,
+        max_query_length: int = 512,
         max_doc_length: Optional[int] = None,
         query_type: str = 'text',
         doc_type: str = 'text',
@@ -183,10 +183,21 @@ class JinaVLForRanking(Qwen2VLForConditionalGeneration):
                 batch_inputs.append(formatting_prompts_func(q, d, query_type=query_type, doc_type=doc_type))

             batch_images = None
+
+            doc_images = []
+            query_images = []
             if doc_type == 'image':
-                batch_images = load_images([d for (q, d) in mini_batch])
-            elif query_type == 'image':
-                batch_images = load_images([q for (q, d) in mini_batch])
+                doc_images = load_images([d for (q, d) in mini_batch])
+            if query_type == 'image':
+                query_images = load_images([q for (q, d) in mini_batch])
+
+            if len(doc_images) == len(query_images) and len(doc_images) > 0:
+                # both sides are images: pass [document image, query image] for each pair
+                batch_images = [[d, q] for q, d in zip(query_images, doc_images)]
+            elif len(doc_images) > 0:
+                batch_images = doc_images
+            elif len(query_images) > 0:
+                batch_images = query_images

             batch = self._processor(
                 text=batch_inputs,
@@ -219,7 +230,7 @@ class JinaVLForRanking(Qwen2VLForConditionalGeneration):
             scores = self.forward(**batch).view(-1).cpu().float().numpy()

-            # normalize scores to [0, 1] with sigmoid with a scale
-            scores = 1.0 / (1.0 + np.exp(-scores * LOGIT_SCALE))
+            # normalize scores to [0, 1] with a sigmoid shifted by LOGIT_BIAS
+            scores = 1.0 / (1.0 + np.exp(-(scores - LOGIT_BIAS)))

             all_scores.extend(scores.tolist())
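
For reference, the updated normalization maps a raw ranking logit `x` to `1 / (1 + exp(-(x - LOGIT_BIAS)))`, so a logit equal to `LOGIT_BIAS` lands at exactly 0.5; the old formula instead multiplied the logit by `LOGIT_SCALE` before the sigmoid, which changes the slope around zero rather than shifting the midpoint. A minimal standalone sketch of the new behavior (not part of this diff; the example logits are made up):

```python
import numpy as np

LOGIT_BIAS = 2.65  # same constant the diff introduces in modeling.py


def normalize(raw_logits):
    """Shift raw ranking logits by LOGIT_BIAS and squash them into [0, 1] with a sigmoid."""
    raw_logits = np.asarray(raw_logits, dtype=np.float32)
    return 1.0 / (1.0 + np.exp(-(raw_logits - LOGIT_BIAS)))


# Made-up example logits: a logit equal to LOGIT_BIAS maps to exactly 0.5,
# larger logits approach 1.0 and smaller ones approach 0.0.
print(normalize([0.0, 2.65, 5.0]))  # ~[0.066, 0.5, 0.913]
```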