update readme
This commit is contained in:
parent
5447cef50a
commit
dee166dad3
37
README.md
37
README.md
@ -144,15 +144,23 @@ Compared to `jina-reranker-v2-base-multilingual`, `jina-reranker-m0` significant
|
|||||||
pip install transformers >= 4.47.3
|
pip install transformers >= 4.47.3
|
||||||
```
|
```
|
||||||
|
|
||||||
|
If you run it on a GPU that supports FlashAttention-2 (as of 2024-09-12, this includes Ampere, Ada, or Hopper GPUs, e.g., A100, RTX 3090, RTX 4090, H100),
|
||||||
|
|
||||||
|
```bash
|
||||||
|
pip install flash-attn --no-build-isolation
|
||||||
|
```
|
||||||
|
|
||||||
And then use the following code snippet to load the model:
|
And then use the following code snippet to load the model:
|
||||||
|
|
||||||
```python
|
```python
|
||||||
from transformers import AutoModel
|
from transformers import AutoModel
|
||||||
|
|
||||||
|
# comment out the flash_attention_2 line if you don't have a compatible GPU
|
||||||
model = AutoModel.from_pretrained(
|
model = AutoModel.from_pretrained(
|
||||||
'jinaai/jina-reranker-m0',
|
'jinaai/jina-reranker-m0',
|
||||||
torch_dtype="auto",
|
torch_dtype="auto",
|
||||||
trust_remote_code=True,
|
trust_remote_code=True,
|
||||||
|
attn_implementation="flash_attention_2"
|
||||||
)
|
)
|
||||||
|
|
||||||
model.to('cuda') # or 'cpu' if no GPU is available
|
model.to('cuda') # or 'cpu' if no GPU is available
|
||||||
@ -178,7 +186,7 @@ Compared to `jina-reranker-v2-base-multilingual`, `jina-reranker-m0` significant
|
|||||||
image_pairs = [[query, doc] for doc in documents]
|
image_pairs = [[query, doc] for doc in documents]
|
||||||
|
|
||||||
scores = model.compute_score(image_pairs, max_length=2048, doc_type="image")
|
scores = model.compute_score(image_pairs, max_length=2048, doc_type="image")
|
||||||
# [0.8576154708862305, 0.9356858730316162, 0.8496521711349487, 0.8664582967758179]
|
# [0.49375027418136597, 0.7889736890792847, 0.47813892364501953, 0.5210812091827393]
|
||||||
```
|
```
|
||||||
|
|
||||||
**B. Textual Documents Reranking**
|
**B. Textual Documents Reranking**
|
||||||
@ -201,7 +209,7 @@ Compared to `jina-reranker-v2-base-multilingual`, `jina-reranker-m0` significant
|
|||||||
The scores will be a list of floats, where each float represents the relevance score of the corresponding document to the query. Higher scores indicate higher relevance.
|
The scores will be a list of floats, where each float represents the relevance score of the corresponding document to the query. Higher scores indicate higher relevance.
|
||||||
For instance, the returned scores in this case will be:
|
For instance, the returned scores in this case will be:
|
||||||
```bash
|
```bash
|
||||||
[0.9127850532531738, 0.8384682536125183, 0.8870794177055359, 0.842738926410675]
|
[0.6839263439178467, 0.4432148039340973, 0.5904013514518738, 0.45481112599372864]
|
||||||
```
|
```
|
||||||
|
|
||||||
**C. Image Querying for Textual Documents**
|
**C. Image Querying for Textual Documents**
|
||||||
@ -218,10 +226,29 @@ Compared to `jina-reranker-v2-base-multilingual`, `jina-reranker-m0` significant
|
|||||||
"Die wichtigsten Beiträge unserer Arbeit sind zweifach: Erstens führen wir eine neuartige dreistufige Datensynthese-Pipeline namens Draft-Refine-Critique ein, die durch iterative Verfeinerung hochwertige Trainingsdaten generiert; und zweitens schlagen wir eine umfassende Trainingsstrategie vor, die kontinuierliches Vortraining zur Längenerweiterung, überwachtes Feintuning mit spezialisierten Kontrollpunkten, direkte Präferenzoptimierung (DPO) und iteratives Self-Play-Tuning kombiniert. Um die weitere Forschung und Anwendung der strukturierten Inhaltsextraktion zu erleichtern, ist das Modell auf Hugging Face öffentlich verfügbar.",
|
"Die wichtigsten Beiträge unserer Arbeit sind zweifach: Erstens führen wir eine neuartige dreistufige Datensynthese-Pipeline namens Draft-Refine-Critique ein, die durch iterative Verfeinerung hochwertige Trainingsdaten generiert; und zweitens schlagen wir eine umfassende Trainingsstrategie vor, die kontinuierliches Vortraining zur Längenerweiterung, überwachtes Feintuning mit spezialisierten Kontrollpunkten, direkte Präferenzoptimierung (DPO) und iteratives Self-Play-Tuning kombiniert. Um die weitere Forschung und Anwendung der strukturierten Inhaltsextraktion zu erleichtern, ist das Modell auf Hugging Face öffentlich verfügbar.",
|
||||||
]
|
]
|
||||||
# reverse the order of the query and document
|
# reverse the order of the query and document
|
||||||
image_pairs = [[doc, query] for doc in documents]
|
image_pairs = [[query, doc] for doc in documents]
|
||||||
scores = model.compute_score(image_pairs, max_length=2048, doc_type="text")
|
scores = model.compute_score(image_pairs, max_length=2048, query_type="image", doc_type="text")
|
||||||
|
|
||||||
# [0.9048659801483154, 0.8266222476959229, 0.8326289653778076, 0.9075747132301331]
|
# [0.98099285364151, 0.7701883316040039, 0.5637142062187195, 0.9308615922927856]
|
||||||
|
```
|
||||||
|
|
||||||
|
**D. Image Querying for Image Documents**
|
||||||
|
|
||||||
|
The model also supports querying image documents with an image query. You can use the following code snippet:
|
||||||
|
|
||||||
|
```python
|
||||||
|
query = "https://raw.githubusercontent.com/jina-ai/multimodal-reranker-test/main/paper-11.png"
|
||||||
|
|
||||||
|
documents = [
|
||||||
|
"https://raw.githubusercontent.com/jina-ai/multimodal-reranker-test/main/handelsblatt-preview.png",
|
||||||
|
"https://raw.githubusercontent.com/jina-ai/multimodal-reranker-test/main/paper-11.png",
|
||||||
|
"https://raw.githubusercontent.com/jina-ai/multimodal-reranker-test/main/wired-preview.png",
|
||||||
|
"https://jina.ai/blog-banner/using-deepseek-r1-reasoning-model-in-deepsearch.webp"
|
||||||
|
]
|
||||||
|
|
||||||
|
image_pairs = [[query, doc] for doc in documents]
|
||||||
|
scores = model.compute_score(image_pairs, max_length=2048, doc_type="image", query_type='image')
|
||||||
|
# [0.6275860667228699, 0.9922324419021606, 0.8090347051620483, 0.7941296100616455]
|
||||||
```
|
```
|
||||||
|
|
||||||
# Model Performance
|
# Model Performance
|
||||||
|
|||||||
26
modeling.py
26
modeling.py
@ -10,7 +10,7 @@ from transformers.image_utils import load_image
|
|||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
LOGIT_SCALE = 0.68
|
LOGIT_BIAS = 2.65 # logit bias for sigmoid normalization
|
||||||
|
|
||||||
def load_images(images, lazy_load: bool = True):
|
def load_images(images, lazy_load: bool = True):
|
||||||
# Disable PIL DecompositionBomb threshold for reading large images.
|
# Disable PIL DecompositionBomb threshold for reading large images.
|
||||||
@ -123,7 +123,7 @@ class JinaVLForRanking(Qwen2VLForConditionalGeneration):
|
|||||||
pairs: Union[List[Tuple[str, str]], Tuple[str, str]],
|
pairs: Union[List[Tuple[str, str]], Tuple[str, str]],
|
||||||
batch_size: int = 8,
|
batch_size: int = 8,
|
||||||
max_length: int = 10240,
|
max_length: int = 10240,
|
||||||
max_query_length: int = 1024,
|
max_query_length: int = 512,
|
||||||
max_doc_length: Optional[int] = None,
|
max_doc_length: Optional[int] = None,
|
||||||
query_type: str = 'text',
|
query_type: str = 'text',
|
||||||
doc_type: str = 'text',
|
doc_type: str = 'text',
|
||||||
@ -183,10 +183,24 @@ class JinaVLForRanking(Qwen2VLForConditionalGeneration):
|
|||||||
batch_inputs.append(formatting_prompts_func(q, d, query_type=query_type, doc_type=doc_type))
|
batch_inputs.append(formatting_prompts_func(q, d, query_type=query_type, doc_type=doc_type))
|
||||||
|
|
||||||
batch_images = None
|
batch_images = None
|
||||||
|
# if doc_type == 'image':
|
||||||
|
# batch_images = load_images([d for (q, d) in mini_batch])
|
||||||
|
# elif query_type == 'image':
|
||||||
|
# batch_images = load_images([q for (q, d) in mini_batch])
|
||||||
|
|
||||||
|
doc_images = []
|
||||||
|
query_images = []
|
||||||
if doc_type == 'image':
|
if doc_type == 'image':
|
||||||
batch_images = load_images([d for (q, d) in mini_batch])
|
doc_images = load_images([d for (q, d) in mini_batch])
|
||||||
elif query_type == 'image':
|
if query_type == 'image':
|
||||||
batch_images = load_images([q for (q, d) in mini_batch])
|
query_images = load_images([q for (q, d) in mini_batch])
|
||||||
|
|
||||||
|
if len(doc_images) == len(query_images) and len(doc_images) > 0:
|
||||||
|
batch_images = [[d, q] for q, d in zip(query_images, doc_images)]
|
||||||
|
elif len(doc_images) > 0:
|
||||||
|
batch_images = doc_images
|
||||||
|
elif len(query_images) > 0:
|
||||||
|
batch_images = query_images
|
||||||
|
|
||||||
batch = self._processor(
|
batch = self._processor(
|
||||||
text=batch_inputs,
|
text=batch_inputs,
|
||||||
@ -219,7 +233,7 @@ class JinaVLForRanking(Qwen2VLForConditionalGeneration):
|
|||||||
scores = self.forward(**batch).view(-1).cpu().float().numpy()
|
scores = self.forward(**batch).view(-1).cpu().float().numpy()
|
||||||
|
|
||||||
# normalize scores to [0, 1] with sigmoid with a scale
|
# normalize scores to [0, 1] with sigmoid with a scale
|
||||||
scores = 1.0 / (1.0 + np.exp(-scores * LOGIT_SCALE))
|
scores = 1.0 / (1.0 + np.exp(-(scores - LOGIT_BIAS)))
|
||||||
|
|
||||||
all_scores.extend(scores.tolist())
|
all_scores.extend(scores.tolist())
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user