Source code for langchain_mistralai.embeddings

import asyncio
import logging
import warnings
from typing import Dict, Iterable, List, Optional

import httpx
from langchain_core.embeddings import Embeddings
from langchain_core.pydantic_v1 import (
    BaseModel,
    Extra,
    Field,
    SecretStr,
    root_validator,
)
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
from tokenizers import Tokenizer  # type: ignore

logger = logging.getLogger(__name__)

MAX_TOKENS = 16_000
"""A batching parameter for the Mistral API. This is NOT the maximum number of tokens
accepted by the embedding model for each document/chunk, but rather the maximum number 
of tokens that can be sent in a single request to the Mistral API (across multiple
documents/chunks)"""


[docs] class DummyTokenizer: """Dummy tokenizer for when tokenizer cannot be accessed (e.g., via Huggingface)"""
[docs] def encode_batch(self, texts: List[str]) -> List[List[str]]: return [list(text) for text in texts]
[docs] class MistralAIEmbeddings(BaseModel, Embeddings): """MistralAI embedding models. To use, set the environment variable `MISTRAL_API_KEY` is set with your API key or pass it as a named parameter to the constructor. Example: .. code-block:: python from langchain_mistralai import MistralAIEmbeddings mistral = MistralAIEmbeddings( model="mistral-embed", api_key="my-api-key" ) """ client: httpx.Client = Field(default=None) #: :meta private: async_client: httpx.AsyncClient = Field(default=None) #: :meta private: mistral_api_key: Optional[SecretStr] = Field(default=None, alias="api_key") endpoint: str = "https://api.mistral.ai/v1/" max_retries: int = 5 timeout: int = 120 max_concurrent_requests: int = 64 tokenizer: Tokenizer = Field(default=None) model: str = "mistral-embed" class Config: extra = Extra.forbid arbitrary_types_allowed = True allow_population_by_field_name = True @root_validator() def validate_environment(cls, values: Dict) -> Dict: """Validate configuration.""" values["mistral_api_key"] = convert_to_secret_str( get_from_dict_or_env( values, "mistral_api_key", "MISTRAL_API_KEY", default="" ) ) api_key_str = values["mistral_api_key"].get_secret_value() # todo: handle retries if not values.get("client"): values["client"] = httpx.Client( base_url=values["endpoint"], headers={ "Content-Type": "application/json", "Accept": "application/json", "Authorization": f"Bearer {api_key_str}", }, timeout=values["timeout"], ) # todo: handle retries and max_concurrency if not values.get("async_client"): values["async_client"] = httpx.AsyncClient( base_url=values["endpoint"], headers={ "Content-Type": "application/json", "Accept": "application/json", "Authorization": f"Bearer {api_key_str}", }, timeout=values["timeout"], ) if values["tokenizer"] is None: try: values["tokenizer"] = Tokenizer.from_pretrained( "mistralai/Mixtral-8x7B-v0.1" ) except IOError: # huggingface_hub GatedRepoError warnings.warn( "Could not download mistral tokenizer from Huggingface for " "calculating batch sizes. Set a Huggingface token via the " "HF_TOKEN environment variable to download the real tokenizer. " "Falling back to a dummy tokenizer that uses `len()`." ) values["tokenizer"] = DummyTokenizer() return values def _get_batches(self, texts: List[str]) -> Iterable[List[str]]: """Split a list of texts into batches of less than 16k tokens for Mistral API.""" batch: List[str] = [] batch_tokens = 0 text_token_lengths = [ len(encoded) for encoded in self.tokenizer.encode_batch(texts) ] for text, text_tokens in zip(texts, text_token_lengths): if batch_tokens + text_tokens > MAX_TOKENS: if len(batch) > 0: # edge case where first batch exceeds max tokens # should not yield an empty batch. yield batch batch = [text] batch_tokens = text_tokens else: batch.append(text) batch_tokens += text_tokens if batch: yield batch
[docs] def embed_documents(self, texts: List[str]) -> List[List[float]]: """Embed a list of document texts. Args: texts: The list of texts to embed. Returns: List of embeddings, one for each text. """ try: batch_responses = ( self.client.post( url="/embeddings", json=dict( model=self.model, input=batch, ), ) for batch in self._get_batches(texts) ) return [ list(map(float, embedding_obj["embedding"])) for response in batch_responses for embedding_obj in response.json()["data"] ] except Exception as e: logger.error(f"An error occurred with MistralAI: {e}") raise
[docs] async def aembed_documents(self, texts: List[str]) -> List[List[float]]: """Embed a list of document texts. Args: texts: The list of texts to embed. Returns: List of embeddings, one for each text. """ try: batch_responses = await asyncio.gather( *[ self.async_client.post( url="/embeddings", json=dict( model=self.model, input=batch, ), ) for batch in self._get_batches(texts) ] ) return [ list(map(float, embedding_obj["embedding"])) for response in batch_responses for embedding_obj in response.json()["data"] ] except Exception as e: logger.error(f"An error occurred with MistralAI: {e}") raise
[docs] def embed_query(self, text: str) -> List[float]: """Embed a single query text. Args: text: The text to embed. Returns: Embedding for the text. """ return self.embed_documents([text])[0]
[docs] async def aembed_query(self, text: str) -> List[float]: """Embed a single query text. Args: text: The text to embed. Returns: Embedding for the text. """ return (await self.aembed_documents([text]))[0]