feat: Add support for SparseEncoder and sparse embedding models in Sentence Transformers #9588
base: main
Conversation
Hello and thanks for this idea! I think it's a big topic and will probably require some work. Some high-level notes:
Hey, ping me when you need another review. In the meantime, feel free to:
💙
Hey @anakin87, sure. Thank you, I'll ping you when this PR is ready for review. I'll probably manage to finish it this week if there are no urgent tasks at work.
@anakin87 Hey, I think it's finished. I ran the tests locally and they passed. However, could you please help me with the formatting? Something strange is going on on my side: the format checks found many errors, though when I used … And, just in case, here's a code snippet to check that the new sparse models work:
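Something along these lines should do it (a minimal sketch; the import paths and the SPLADE model name are assumptions, adjust as needed):

```python
from haystack import Document
from haystack.components.embedders import (  # assumed import path
    SentenceTransformersSparseDocumentEmbedder,
    SentenceTransformersSparseTextEmbedder,
)

# Any SparseEncoder-compatible model should work; SPLADE is a common choice.
MODEL = "naver/splade-cocondenser-ensembledistil"

text_embedder = SentenceTransformersSparseTextEmbedder(model=MODEL)
text_embedder.warm_up()
print(text_embedder.run(text="Sparse embedding models in Haystack"))

doc_embedder = SentenceTransformersSparseDocumentEmbedder(model=MODEL)
doc_embedder.warm_up()
docs = doc_embedder.run(documents=[Document(content="SparseEncoder support.")])["documents"]
print(docs[0].sparse_embedding)
```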
I'll take a look in the next few days. @Ryzhtus please ping me if I forget to do that.
Pull Request Test Coverage Report for Build 17290431721
💛 - Coveralls
I fixed the format and left a few comments.
Please also adjust types. You can run mypy locally with `hatch run test:types`.
(Reminder to myself: if we add integration tests, follow the process for slow/unstable tests)
```python
def get_embedding_backend(  # pylint: disable=too-many-positional-arguments
    model: str,
    device: Optional[str] = None,
```
Suggested change:

```diff
-def get_embedding_backend(  # pylint: disable=too-many-positional-arguments
-    model: str,
-    device: Optional[str] = None,
+def get_embedding_backend(
+    *,
+    model: str,
+    device: Optional[str] = None,
```
Could you explore using only keyword args? This would probably imply updating some other code.
```python
    Class to manage Sparse embeddings from Sentence Transformers.
    """

    def __init__(  # pylint: disable=too-many-positional-arguments
```
Let's use keyword args if possible.
```python
def embed(self, data: list[str], **kwargs) -> list[SparseEmbedding]:
    # SparseEncoder.encode returns a sparse COO torch tensor of shape
    # (batch_size, dims); coalesce() merges duplicate coordinates so each
    # (row, column) pair appears exactly once.
    embeddings = self.model.encode(data, **kwargs).coalesce()

    # In COO format, indices() is a (2, nnz) tensor: row i pairs with column i.
    rows, columns = embeddings.indices()
    values = embeddings.values()
    batch_size = embeddings.size(0)

    # Split the flat COO representation back into one SparseEmbedding per input text.
    sparse_embeddings: list[SparseEmbedding] = []
    for embedding in range(batch_size):
        mask = rows == embedding
        embedding_columns = columns[mask].tolist()
        embedding_values = values[mask].tolist()
        sparse_embeddings.append(SparseEmbedding(indices=embedding_columns, values=embedding_values))
```
- Please add some high-level comments to explain this code
- We need to test it
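To make the decoding above concrete, here is a standalone sketch of the same COO pattern with plain torch (not code from the PR):

```python
import torch

# Two "texts" embedded into a 5-dimensional space; most entries are zero.
dense = torch.tensor([[0.0, 1.5, 0.0, 0.0, 2.0],
                      [3.0, 0.0, 0.0, 0.5, 0.0]])
sparse = dense.to_sparse().coalesce()

rows, cols = sparse.indices()  # indices() is a (2, nnz) tensor: rows pair with cols
vals = sparse.values()
for i in range(sparse.size(0)):  # one pass per input text
    mask = rows == i
    print(i, cols[mask].tolist(), vals[mask].tolist())
# 0 [1, 4] [1.5, 2.0]
# 1 [0, 3] [3.0, 0.5]
```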
```diff
@@ -50,10 +52,51 @@ def get_embedding_backend(  # pylint: disable=too-many-positional-arguments
             config_kwargs=config_kwargs,
             backend=backend,
         )

         _SentenceTransformersEmbeddingBackendFactory._instances[embedding_backend_id] = embedding_backend
         return embedding_backend
```
I would prefer to put the new classes and logic in a new module: `sentence_transformers_sparse_backend.py`.
```python
for doc, emb in zip(documents, embeddings):
    doc.embedding = emb
```
Suggested change:

```diff
 for doc, emb in zip(documents, embeddings):
-    doc.embedding = emb
+    doc.sparse_embedding = emb
```
Let's put the sparse embedding in the corresponding field and update docstrings as needed.
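For reference, a minimal sketch of the target field, assuming Haystack's existing `SparseEmbedding` dataclass:

```python
from haystack import Document
from haystack.dataclasses import SparseEmbedding

doc = Document(content="hello")
# Parallel lists: indices[i] is an active dimension, values[i] its weight.
doc.sparse_embedding = SparseEmbedding(indices=[1, 4], values=[0.5, 2.0])
```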
```python
        if self.tokenizer_kwargs and self.tokenizer_kwargs.get("model_max_length"):
            self.embedding_backend.model.max_seq_length = self.tokenizer_kwargs["model_max_length"]

    @component.output_types(embedding=list[float])
```
Suggested change:

```diff
-@component.output_types(embedding=list[float])
+@component.output_types(sparse_embedding=SparseEmbedding)
```
```python
            show_progress_bar=self.progress_bar,
            **(self.encode_kwargs if self.encode_kwargs else {}),
        )[0]
        return {"embedding": embedding}
```
return {"embedding": embedding} | |
return {"sparse_embedding": embedding} |
```diff
@@ -5,9 +5,11 @@
 from unittest.mock import patch

 import pytest
 import torch
```
Let's create a new module: `test_sentence_transformers_sparse_embedding_backend.py`.
```python
        tokenizer_kwargs=None,
        config_kwargs=None,
        backend="torch",
    )
```
Let's add a single integration test.
```python
        tokenizer_kwargs=None,
        config_kwargs=None,
        backend="torch",
    )
```
Let's add a single integration test.
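A hypothetical shape for such a test (the marker, import path, and model name are assumptions):

```python
import pytest

from haystack.components.embedders import SentenceTransformersSparseTextEmbedder  # assumed path


@pytest.mark.integration
def test_sparse_text_embedder_with_real_model():
    embedder = SentenceTransformersSparseTextEmbedder(model="naver/splade-cocondenser-ensembledistil")
    embedder.warm_up()

    result = embedder.run(text="sparse retrieval with SPLADE")
    sparse = result["sparse_embedding"]

    # SparseEmbedding stores parallel index/value lists; both must be non-empty.
    assert len(sparse.indices) == len(sparse.values) > 0
```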
Related Issues
Sentence Transformers introduced support for sparse embedding models via the SparseEncoder class in v5.0.0. I thought it would be cool to support these in Haystack as well, since sparse models were previously only available through the FastEmbed integration (e.g. FastembedSparseTextEmbedder).
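For context, the upstream API looks roughly like this (a sketch; the model name is an example, not from this PR):

```python
from sentence_transformers import SparseEncoder  # requires sentence-transformers >= 5.0.0

model = SparseEncoder("naver/splade-cocondenser-ensembledistil")
embeddings = model.encode(["Sparse vectors keep only a few non-zero dimensions."])
# Unlike SentenceTransformer, encode() here returns a sparse torch tensor.
print(embeddings.shape)
```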
Proposed Changes:
Introduced two new embedder classes, plus a backend class to manage them:
- `SentenceTransformersSparseTextEmbedder`
- `SentenceTransformersSparseDocumentEmbedder`
- `SentenceTransformersSparseEncoderEmbeddingBackend`
How did you test it?
I added unit tests for both embedders.
Notes for the reviewer
Some tests are currently failing; I'd appreciate your support in resolving them.
We'll likely need to add documentation as well.
Checklist
- The PR title follows the conventional commit format, using one of the prefixes `fix:`, `feat:`, `build:`, `chore:`, `ci:`, `docs:`, `style:`, `refactor:`, `perf:`, or `test:`, and adds `!` in case the PR includes breaking changes.