update readme
parent 5447cef50a
commit dee166dad3

README.md (37 changed lines)
@@ -144,15 +144,23 @@ Compared to `jina-reranker-v2-base-multilingual`, `jina-reranker-m0` significant

```bash
pip install "transformers>=4.47.3"
```

If you run it on a GPU that supports FlashAttention-2 (as of 2024.9.12, this includes Ampere, Ada, and Hopper GPUs, e.g., A100, RTX 3090, RTX 4090, H100), you can also install `flash-attn` for faster inference:

```bash
pip install flash-attn --no-build-isolation
```
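Optionally (a sketch, not from the README), you can pick the attention backend based on whether `flash-attn` is actually importable, falling back to PyTorch's built-in SDPA implementation:

```python
# Hedged sketch: choose the attention implementation based on availability.
try:
    import flash_attn  # noqa: F401
    attn_impl = "flash_attention_2"
except ImportError:
    attn_impl = "sdpa"  # PyTorch's scaled-dot-product attention fallback
# then pass attn_implementation=attn_impl to AutoModel.from_pretrained(...)
```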
And then use the following code snippet to load the model:

```python
from transformers import AutoModel

# comment out the flash_attention_2 line if you don't have a compatible GPU
model = AutoModel.from_pretrained(
    'jinaai/jina-reranker-m0',
    torch_dtype="auto",
    trust_remote_code=True,
    attn_implementation="flash_attention_2"
)

model.to('cuda')  # or 'cpu' if no GPU is available
@@ -178,7 +186,7 @@ Compared to `jina-reranker-v2-base-multilingual`, `jina-reranker-m0` significant

image_pairs = [[query, doc] for doc in documents]

scores = model.compute_score(image_pairs, max_length=2048, doc_type="image")
- # [0.8576154708862305, 0.9356858730316162, 0.8496521711349487, 0.8664582967758179]
+ # [0.49375027418136597, 0.7889736890792847, 0.47813892364501953, 0.5210812091827393]
```

**B. Textual Documents Reranking**
@@ -201,7 +209,7 @@ Compared to `jina-reranker-v2-base-multilingual`, `jina-reranker-m0` significant

The scores will be a list of floats, where each float represents the relevance score of the corresponding document to the query. Higher scores indicate higher relevance.
For instance, the returned scores in this case will be:
```bash
- [0.9127850532531738, 0.8384682536125183, 0.8870794177055359, 0.842738926410675]
+ [0.6839263439178467, 0.4432148039340973, 0.5904013514518738, 0.45481112599372864]
```
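A minimal sketch (not part of the README) of turning that score list into a ranking, assuming the `documents` and `scores` variables from the snippet above:

```python
# Pair each document with its score and sort by descending relevance.
ranking = sorted(zip(documents, scores), key=lambda pair: pair[1], reverse=True)
for doc, score in ranking:
    print(f"{score:.4f}  {doc[:60]}")
```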
**C. Image Querying for Textual Documents**
@@ -218,10 +226,29 @@ Compared to `jina-reranker-v2-base-multilingual`, `jina-reranker-m0` significant

    "Die wichtigsten Beiträge unserer Arbeit sind zweifach: Erstens führen wir eine neuartige dreistufige Datensynthese-Pipeline namens Draft-Refine-Critique ein, die durch iterative Verfeinerung hochwertige Trainingsdaten generiert; und zweitens schlagen wir eine umfassende Trainingsstrategie vor, die kontinuierliches Vortraining zur Längenerweiterung, überwachtes Feintuning mit spezialisierten Kontrollpunkten, direkte Präferenzoptimierung (DPO) und iteratives Self-Play-Tuning kombiniert. Um die weitere Forschung und Anwendung der strukturierten Inhaltsextraktion zu erleichtern, ist das Modell auf Hugging Face öffentlich verfügbar.",
]
- # reverse the order of the query and document
- image_pairs = [[doc, query] for doc in documents]
- scores = model.compute_score(image_pairs, max_length=2048, doc_type="text")
+ image_pairs = [[query, doc] for doc in documents]
+ scores = model.compute_score(image_pairs, max_length=2048, query_type="image", doc_type="text")

- # [0.9048659801483154, 0.8266222476959229, 0.8326289653778076, 0.9075747132301331]
+ # [0.98099285364151, 0.7701883316040039, 0.5637142062187195, 0.9308615922927856]
```

**D. Image Querying for Image Documents**

The model also supports querying image documents with an image query. You can use the following code snippet:

```python
query = "https://raw.githubusercontent.com/jina-ai/multimodal-reranker-test/main/paper-11.png"

documents = [
    "https://raw.githubusercontent.com/jina-ai/multimodal-reranker-test/main/handelsblatt-preview.png",
    "https://raw.githubusercontent.com/jina-ai/multimodal-reranker-test/main/paper-11.png",
    "https://raw.githubusercontent.com/jina-ai/multimodal-reranker-test/main/wired-preview.png",
    "https://jina.ai/blog-banner/using-deepseek-r1-reasoning-model-in-deepsearch.webp"
]

image_pairs = [[query, doc] for doc in documents]
scores = model.compute_score(image_pairs, max_length=2048, doc_type="image", query_type='image')
# [0.6275860667228699, 0.9922324419021606, 0.8090347051620483, 0.7941296100616455]
```
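All four usage modes above differ only in the `query_type` and `doc_type` flags. A hypothetical convenience wrapper (not part of the repo) that returns the documents sorted best-first:

```python
def rerank(model, query, documents, query_type="text", doc_type="text"):
    """Score (query, doc) pairs and return (doc, score) sorted by descending score."""
    pairs = [[query, doc] for doc in documents]
    scores = model.compute_score(
        pairs, max_length=2048, query_type=query_type, doc_type=doc_type
    )
    return sorted(zip(documents, scores), key=lambda x: x[1], reverse=True)
```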
# Model Performance

modeling.py (26 changed lines)
@@ -10,7 +10,7 @@ from transformers.image_utils import load_image

logger = logging.getLogger(__name__)

- LOGIT_SCALE = 0.68
+ LOGIT_BIAS = 2.65  # logit bias for sigmoid normalization

def load_images(images, lazy_load: bool = True):
    # Disable PIL DecompressionBomb threshold for reading large images.
@@ -123,7 +123,7 @@ class JinaVLForRanking(Qwen2VLForConditionalGeneration):

    pairs: Union[List[Tuple[str, str]], Tuple[str, str]],
    batch_size: int = 8,
    max_length: int = 10240,
-   max_query_length: int = 1024,
+   max_query_length: int = 512,
    max_doc_length: Optional[int] = None,
    query_type: str = 'text',
    doc_type: str = 'text',
@@ -183,10 +183,24 @@ class JinaVLForRanking(Qwen2VLForConditionalGeneration):

    batch_inputs.append(formatting_prompts_func(q, d, query_type=query_type, doc_type=doc_type))

    batch_images = None
+   # if doc_type == 'image':
+   #     batch_images = load_images([d for (q, d) in mini_batch])
+   # elif query_type == 'image':
+   #     batch_images = load_images([q for (q, d) in mini_batch])
+
+   doc_images = []
+   query_images = []
    if doc_type == 'image':
-       batch_images = load_images([d for (q, d) in mini_batch])
-   elif query_type == 'image':
-       batch_images = load_images([q for (q, d) in mini_batch])
+       doc_images = load_images([d for (q, d) in mini_batch])
+   if query_type == 'image':
+       query_images = load_images([q for (q, d) in mini_batch])
+
+   if len(doc_images) == len(query_images) and len(doc_images) > 0:
+       batch_images = [[d, q] for q, d in zip(query_images, doc_images)]
+   elif len(doc_images) > 0:
+       batch_images = doc_images
+   elif len(query_images) > 0:
+       batch_images = query_images

    batch = self._processor(
        text=batch_inputs,
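For illustration only (a sketch mirroring the new branching above, with strings standing in for loaded PIL images), how `batch_images` resolves for each combination of query and document types:

```python
def resolve_batch_images(query_images, doc_images):
    # Both sides are images: each sample carries a [doc_image, query_image] pair.
    if len(doc_images) == len(query_images) and len(doc_images) > 0:
        return [[d, q] for q, d in zip(query_images, doc_images)]
    elif len(doc_images) > 0:
        return doc_images      # image documents, text query
    elif len(query_images) > 0:
        return query_images    # image query, text documents
    return None                # text-only pairs

print(resolve_batch_images([], ["doc.png"]))         # ['doc.png']
print(resolve_batch_images(["q.png"], []))           # ['q.png']
print(resolve_batch_images(["q.png"], ["doc.png"]))  # [['doc.png', 'q.png']]
```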
@@ -219,7 +233,7 @@ class JinaVLForRanking(Qwen2VLForConditionalGeneration):

    scores = self.forward(**batch).view(-1).cpu().float().numpy()

    # normalize scores to [0, 1] with a sigmoid
-   scores = 1.0 / (1.0 + np.exp(-scores * LOGIT_SCALE))
+   scores = 1.0 / (1.0 + np.exp(-(scores - LOGIT_BIAS)))

    all_scores.extend(scores.tolist())
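For intuition (a standalone sketch, not repo code), the change swaps a multiplicative logit scale for an additive bias before the sigmoid, which lowers the normalized score produced by the same raw logit:

```python
import numpy as np

raw = np.array([3.0])                        # example raw logit
old = 1.0 / (1.0 + np.exp(-raw * 0.68))      # sigmoid(0.68 * s) -> ~0.885
new = 1.0 / (1.0 + np.exp(-(raw - 2.65)))    # sigmoid(s - 2.65) -> ~0.587
print(old, new)
```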