Source code for langchain.retrievers.multi_vector
from enum import Enum
from typing import Dict, List, Optional
from langchain_core.callbacks import (
AsyncCallbackManagerForRetrieverRun,
CallbackManagerForRetrieverRun,
)
from langchain_core.documents import Document
from langchain_core.pydantic_v1 import Field, root_validator
from langchain_core.retrievers import BaseRetriever
from langchain_core.stores import BaseStore, ByteStore
from langchain_core.vectorstores import VectorStore
from langchain.storage._lc_store import create_kv_docstore
[docs]
class SearchType(str, Enum):
    """The kinds of vectorstore search a retriever can be configured to run.

    Members are also ``str`` instances, so each compares equal to its value
    (e.g. ``SearchType.mmr == "mmr"``).
    """

    # Plain similarity search.
    similarity = "similarity"
    # Similarity search with a score threshold.
    similarity_score_threshold = "similarity_score_threshold"
    # Maximal Marginal Relevance reranking of similarity search.
    mmr = "mmr"
[docs]
class MultiVectorRetriever(BaseRetriever):
    """Retrieve from a set of multiple embeddings for the same document.

    Small chunks are embedded and stored in ``vectorstore``. Each chunk records
    the id of its parent document in ``metadata[id_key]``. At query time the
    vectorstore is searched, the parent ids of the hits are collected in rank
    order without duplicates, and the parent documents are fetched from
    ``docstore``.
    """

    vectorstore: VectorStore
    """The underlying vectorstore to use to store small chunks
    and their embedding vectors"""
    byte_store: Optional[ByteStore] = None
    """The lower-level backing storage layer for the parent documents"""
    docstore: BaseStore[str, Document]
    """The storage interface for the parent documents"""
    id_key: str = "doc_id"
    """Metadata key on each sub-document that holds the id of its parent
    document in ``docstore``."""
    search_kwargs: dict = Field(default_factory=dict)
    """Keyword arguments to pass to the search function."""
    search_type: SearchType = SearchType.similarity
    """Type of search to perform
    (similarity / similarity_score_threshold / mmr)"""

    @root_validator(pre=True)
    def shim_docstore(cls, values: Dict) -> Dict:
        """Derive ``docstore`` from ``byte_store`` when one is supplied.

        If ``byte_store`` is given, it takes precedence. It is wrapped with
        ``create_kv_docstore`` and replaces any explicitly passed
        ``docstore``. Otherwise an explicit ``docstore`` is required.

        Args:
            values: Raw constructor keyword arguments (pre-validation).

        Returns:
            ``values`` with ``docstore`` populated.

        Raises:
            ValueError: If neither ``byte_store`` nor ``docstore`` is given.
                Pydantic reports this as a ``ValidationError``, which is a
                ``ValueError`` subclass.
        """
        byte_store = values.get("byte_store")
        docstore = values.get("docstore")
        if byte_store is not None:
            docstore = create_kv_docstore(byte_store)
        elif docstore is None:
            raise ValueError(
                "You must pass a `byte_store` or `docstore` parameter."
            )
        values["docstore"] = docstore
        return values

    def _ordered_parent_ids(self, sub_docs: List[Document]) -> List[str]:
        """Return the distinct parent ids referenced by ``sub_docs``.

        Ids keep the order in which they first appear in ``sub_docs``, which
        preserves the search ranking. Sub-documents whose metadata lacks
        ``id_key`` are skipped.
        """
        # The set gives O(1) duplicate checks; the list keeps first-seen order.
        seen = set()
        ids: List[str] = []
        for d in sub_docs:
            if self.id_key not in d.metadata:
                continue
            parent_id = d.metadata[self.id_key]
            if parent_id not in seen:
                seen.add(parent_id)
                ids.append(parent_id)
        return ids

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        """Get documents relevant to a query.

        Args:
            query: String to find relevant documents for.
            run_manager: The callbacks handler to use.

        Returns:
            List of relevant parent documents, in the rank order of their
            best-matching sub-document.
        """
        if self.search_type == SearchType.mmr:
            sub_docs = self.vectorstore.max_marginal_relevance_search(
                query, **self.search_kwargs
            )
        elif self.search_type == SearchType.similarity_score_threshold:
            sub_docs_and_similarities = (
                self.vectorstore.similarity_search_with_relevance_scores(
                    query, **self.search_kwargs
                )
            )
            # Scores are only used for thresholding inside the vectorstore.
            sub_docs = [sub_doc for sub_doc, _ in sub_docs_and_similarities]
        else:
            sub_docs = self.vectorstore.similarity_search(query, **self.search_kwargs)
        ids = self._ordered_parent_ids(sub_docs)
        docs = self.docstore.mget(ids)
        # Drop ids for which the docstore returned no document (None).
        return [d for d in docs if d is not None]

    async def _aget_relevant_documents(
        self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
    ) -> List[Document]:
        """Asynchronously get documents relevant to a query.

        Args:
            query: String to find relevant documents for.
            run_manager: The callbacks handler to use.

        Returns:
            List of relevant parent documents, in the rank order of their
            best-matching sub-document.
        """
        if self.search_type == SearchType.mmr:
            sub_docs = await self.vectorstore.amax_marginal_relevance_search(
                query, **self.search_kwargs
            )
        elif self.search_type == SearchType.similarity_score_threshold:
            sub_docs_and_similarities = (
                await self.vectorstore.asimilarity_search_with_relevance_scores(
                    query, **self.search_kwargs
                )
            )
            # Scores are only used for thresholding inside the vectorstore.
            sub_docs = [sub_doc for sub_doc, _ in sub_docs_and_similarities]
        else:
            sub_docs = await self.vectorstore.asimilarity_search(
                query, **self.search_kwargs
            )
        ids = self._ordered_parent_ids(sub_docs)
        docs = await self.docstore.amget(ids)
        # Drop ids for which the docstore returned no document (None).
        return [d for d in docs if d is not None]