from __future__ import annotations
import io
import json
from typing import (
TYPE_CHECKING,
Any,
Dict,
Iterator,
List,
Optional,
Sequence,
Tuple,
Union,
)
from google.api_core.exceptions import NotFound
from google.cloud import storage # type: ignore[attr-defined, unused-ignore]
from google.cloud.storage import ( # type: ignore[attr-defined, unused-ignore, import-untyped]
Blob,
transfer_manager,
)
from langchain_core.documents import Document
from langchain_core.stores import BaseStore
if TYPE_CHECKING:
from google.cloud import datastore # type: ignore[attr-defined, unused-ignore]
GCS_MAX_BATCH_SIZE = 100
[docs]
class DocumentStorage(BaseStore[str, Document]):
"""Abstract interface of a key, text storage for retrieving documents."""
[docs]
class GCSDocumentStorage(DocumentStorage):
"""Stores documents in Google Cloud Storage.
For each pair id, document_text the name of the blob will be {prefix}/{id} stored
in plain text format.
"""
[docs]
def __init__(
self,
bucket: storage.Bucket,
prefix: Optional[str] = "documents",
threaded=True,
n_threads=8,
) -> None:
"""Constructor.
Args:
bucket: Bucket where the documents will be stored.
prefix: Prefix that is prepended to all document names.
"""
super().__init__()
self._bucket = bucket
self._prefix = prefix
self._threaded = threaded
self._n_threads = n_threads
if threaded:
if not (int(n_threads) > 0 and int(n_threads) <= 50):
raise ValueError(
"n_threads must be a valid integer,"
" greater than 0 and lower than or equal to 50"
)
def _prepare_doc_for_bulk_upload(
self, key: str, value: Document
) -> Tuple[io.IOBase, Blob]:
document_json = value.dict()
document_text = json.dumps(document_json).encode("utf-8")
doc_contents = io.BytesIO(document_text)
blob_name = self._get_blob_name(key)
blob = self._bucket.blob(blob_name)
return doc_contents, blob
[docs]
def mset(self, key_value_pairs: Sequence[Tuple[str, Document]]) -> None:
"""Stores a series of documents using each keys
Args:
key_value_pairs (Sequence[Tuple[K, V]]): A sequence of key-value pairs.
"""
if self._threaded:
results = transfer_manager.upload_many(
[
self._prepare_doc_for_bulk_upload(key, value)
for key, value in key_value_pairs
],
skip_if_exists=False,
upload_kwargs=None,
deadline=None,
raise_exception=False,
worker_type="thread",
max_workers=self._n_threads,
)
for result in results:
# The results list is either `None` or an exception for each filename in
# the input list, in order.
if isinstance(result, Exception):
raise result
else:
for key, value in key_value_pairs:
self._set_one(key, value)
def _convert_bytes_to_doc(
self, doc: io.BytesIO, result: Any
) -> Union[Document, None]:
if isinstance(result, NotFound):
return None
elif result is None:
doc.seek(0)
raw_doc = doc.read()
data = raw_doc.decode("utf-8")
data_json = json.loads(data)
return Document(**data_json)
else:
raise Exception(
"Unexpected result type when batch getting multiple files from GCS"
)
[docs]
def mget(self, keys: Sequence[str]) -> List[Optional[Document]]:
"""Gets a batch of documents by id.
The default implementation only loops `get_by_id`.
Subclasses that have faster ways to retrieve data by batch should implement
this method.
Args:
ids: List of ids for the text.
Returns:
List of documents. If the key id is not found for any id record returns a
None instead.
"""
if self._threaded:
download_docs = [
(self._bucket.blob(self._get_blob_name(key)), io.BytesIO())
for key in keys
]
download_results = transfer_manager.download_many(
download_docs,
skip_if_exists=False,
download_kwargs=None,
deadline=None,
raise_exception=False,
worker_type="thread",
max_workers=self._n_threads,
)
for i, result in enumerate(download_results):
if isinstance(result, Exception) and not isinstance(result, NotFound):
raise result
return [
self._convert_bytes_to_doc(doc[1], result)
for doc, result in zip(download_docs, download_results)
]
else:
return [self._get_one(key) for key in keys]
[docs]
def mdelete(self, keys: Sequence[str]) -> None:
"""Deletes a batch of documents by id.
Args:
keys: List of ids for the text.
"""
for i in range(0, len(keys), GCS_MAX_BATCH_SIZE):
batch = keys[i : i + GCS_MAX_BATCH_SIZE]
with self._bucket.client.batch():
for key in batch:
self._delete_one(key)
[docs]
def yield_keys(self, *, prefix: str | None = None) -> Iterator[str]:
"""Yields the keys present in the storage.
Args:
prefix: Ignored. Uses the prefix provided in the constructor.
"""
for blob in self._bucket.list_blobs(prefix=self._prefix):
yield blob.name.split("/")[-1]
def _get_one(self, key: str) -> Document | None:
"""Gets the text of a document by its id. If not found, returns None.
Args:
key: Id of the document to get from the storage.
Returns:
Document if found, otherwise None.
"""
blob_name = self._get_blob_name(key)
existing_blob = self._bucket.get_blob(blob_name)
if existing_blob is None:
return None
document_str = existing_blob.download_as_text()
document_json: Dict[str, Any] = json.loads(document_str)
return Document(**document_json)
def _set_one(self, key: str, value: Document) -> None:
"""Stores a document text associated to a document_id.
Args:
key: Id of the document to be stored.
document: Document to be stored.
"""
blob_name = self._get_blob_name(key)
new_blow = self._bucket.blob(blob_name)
document_json = value.dict()
document_text = json.dumps(document_json)
new_blow.upload_from_string(document_text)
def _delete_one(self, key: str) -> None:
"""Deletes one document by its key.
Args:
key (str): Id of the document to delete.
"""
blob_name = self._get_blob_name(key)
blob = self._bucket.blob(blob_name)
blob.delete()
def _get_blob_name(self, document_id: str) -> str:
"""Builds a blob name using the prefix and the document_id.
Args:
document_id: Id of the document.
Returns:
Name of the blob that the document will be/is stored in
"""
return f"{self._prefix}/{document_id}"
[docs]
class DataStoreDocumentStorage(DocumentStorage):
"""Stores documents in Google Cloud DataStore."""
[docs]
def __init__(
self,
datastore_client: datastore.Client,
kind: str = "document_id",
text_property_name: str = "text",
metadata_property_name: str = "metadata",
) -> None:
"""Constructor.
Args:
bucket: Bucket where the documents will be stored.
prefix: Prefix that is prepended to all document names.
"""
super().__init__()
self._client = datastore_client
self._text_property_name = text_property_name
self._metadata_property_name = metadata_property_name
self._kind = kind
[docs]
def mget(self, keys: Sequence[str]) -> List[Optional[Document]]:
"""Gets a batch of documents by id.
Args:
ids: List of ids for the text.
Returns:
List of texts. If the key id is not found for any id record returns a None
instead.
"""
ds_keys = [self._client.key(self._kind, id_) for id_ in keys]
entities = self._client.get_multi(ds_keys)
# Entities are not sorted by key by default, the order is unclear. This orders
# the list by the id retrieved.
entity_id_lookup = {entity.key.id_or_name: entity for entity in entities}
entities = [entity_id_lookup.get(id_) for id_ in keys]
return [
Document(
page_content=entity[self._text_property_name],
metadata=self._convert_entity_to_dict(
entity[self._metadata_property_name]
),
)
if entity is not None
else None
for entity in entities
]
[docs]
def mset(self, key_value_pairs: Sequence[Tuple[str, Document]]) -> None:
"""Stores a series of documents using each keys
Args:
key_value_pairs (Sequence[Tuple[K, V]]): A sequence of key-value pairs.
"""
ids = [key for key, _ in key_value_pairs]
documents = [document for _, document in key_value_pairs]
with self._client.transaction():
keys = [self._client.key(self._kind, id_) for id_ in ids]
entities = []
for key, document in zip(keys, documents):
entity = self._client.entity(key=key)
entity[self._text_property_name] = document.page_content
entity[self._metadata_property_name] = document.metadata
entities.append(entity)
self._client.put_multi(entities)
[docs]
def mdelete(self, keys: Sequence[str]) -> None:
"""Deletes a sequence of documents by key.
Args:
keys (Sequence[str]): A sequence of keys to delete.
"""
with self._client.transaction():
keys = [self._client.key(self._kind, id_) for id_ in keys]
self._client.delete_multi(keys)
[docs]
def yield_keys(self, *, prefix: str | None = None) -> Iterator[str]:
"""Yields the keys of all documents in the storage.
Args:
prefix: Ignored
"""
query = self._client.query(kind=self._kind)
query.keys_only()
for entity in query.fetch():
yield str(entity.key.id_or_name)
def _convert_entity_to_dict(self, entity: datastore.Entity) -> Dict[str, Any]:
"""Recursively transform an entity into a plain dictionary."""
from google.cloud import datastore # type: ignore[attr-defined, unused-ignore]
dict_entity = dict(entity)
for key in dict_entity:
value = dict_entity[key]
if isinstance(value, datastore.Entity):
dict_entity[key] = self._convert_entity_to_dict(value)
return dict_entity