Source code for langchain.output_parsers.pandas_dataframe

import re
from typing import Any, Dict, List, Tuple, Union

from langchain_core.exceptions import OutputParserException
from langchain_core.output_parsers.base import BaseOutputParser
from langchain_core.pydantic_v1 import validator

from langchain.output_parsers.format_instructions import (
    PANDAS_DATAFRAME_FORMAT_INSTRUCTIONS,
)


class PandasDataFrameOutputParser(BaseOutputParser[Dict[str, Any]]):
    """Parse an output using Pandas DataFrame format.

    The output is a single ``<request_type>:<params>`` string such as
    ``column:chicken[0]``, ``row:1[chicken]`` or ``mean:chicken[1..3]``,
    which is evaluated against ``dataframe``.
    """

    # The Pandas DataFrame to parse (validated to be a non-empty pd.DataFrame).
    dataframe: Any

    @validator("dataframe")
    def validate_dataframe(cls, val: Any) -> Any:
        """Check that ``val`` is a non-empty ``pd.DataFrame`` and return it.

        Raises:
            TypeError: If ``val`` is not a ``pd.DataFrame`` (or a subclass).
            ValueError: If ``val`` is an empty DataFrame.
        """
        import pandas as pd

        if not isinstance(val, pd.DataFrame):
            raise TypeError(
                "Wrong type for 'dataframe', must be a subclass "
                "of Pandas DataFrame (pd.DataFrame)"
            )
        # Check emptiness on the frame itself: an empty frame (no rows or no
        # columns) leaves nothing for a request to address.
        if val.empty:
            raise ValueError("DataFrame cannot be empty.")
        return val

    def _check_max_index(self, index: int) -> None:
        """Raise if integer ``index`` exceeds the DataFrame's largest index label."""
        # NOTE(review): assumes a numeric index; comparing an int with string
        # index labels would raise TypeError here.
        max_index = self.dataframe.index.max()
        if index > max_index:
            raise OutputParserException(
                f"The maximum index {index} exceeds the maximum index of "
                f"the Pandas DataFrame {max_index}."
            )

    def parse_array(
        self, array: str, original_request_params: str
    ) -> Tuple[List[Union[int, str]], str]:
        """Parse a bracketed array taken from a request's parameters.

        Supported formats:
            * ``[1,3,5]`` -- explicit integer indices,
            * ``[1..5]`` -- inclusive integer range,
            * ``[name1,name2]`` -- names made of ``[a-zA-Z0-9_]`` characters.

        Args:
            array: The bracketed substring, e.g. ``"[1..3]"``.
            original_request_params: The parameter string the array was
                extracted from, e.g. ``"chicken[1..3]"``.

        Returns:
            The parsed elements, and the part of ``original_request_params``
            before the first ``[``.

        Raises:
            OutputParserException: If the array is malformed or empty, or if
                an integer index exceeds the DataFrame's maximum index.
        """
        parsed_array: List[Union[int, str]] = []

        index_list = re.match(r"\[\d+(,\s*\d+)*\]", array)
        index_range = re.match(r"\[(\d+)\.\.(\d+)\]", array)
        name_list = re.match(r"\[[a-zA-Z0-9_]+(?:,\s*[a-zA-Z0-9_]+)*\]", array)

        if index_list:
            indices = [int(i) for i in re.findall(r"\d+", index_list.group())]
            # Validate the largest index, not the last one: "[5,1]" is legal.
            self._check_max_index(max(indices))
            parsed_array.extend(indices)
        elif index_range:
            start, end = map(int, index_range.groups())
            if start <= end:
                # Validate before materializing the range so an enormous upper
                # bound cannot exhaust memory.
                self._check_max_index(end)
            parsed_array.extend(range(start, end + 1))
        elif name_list:
            # Accept optional whitespace after commas, like the integer form.
            parsed_array.extend(
                name.strip() for name in name_list.group().strip("[]").split(",")
            )

        if not parsed_array:
            raise OutputParserException(
                f"Invalid array format in '{original_request_params}'. "
                "Please check the format instructions."
            )
        return parsed_array, original_request_params.split("[")[0]

    def parse(self, request: str) -> Dict[str, Any]:
        """Evaluate a ``<request_type>:<params>`` request against the DataFrame.

        ``request_type`` is ``column``, ``row``, or the name of a Series
        attribute that is called with no arguments (e.g. ``mean``).
        ``params`` is a column name (a row position for ``row``), optionally
        followed by a bracketed array handled by ``parse_array``.

        Returns:
            A single-entry dict keyed by the column/row parameter, or by the
            request type for method requests.

        Raises:
            OutputParserException: On malformed requests, explicit
                "Invalid column"/"Invalid operation" outputs, unknown
                columns or methods, and out-of-bounds indices.
        """
        stripped_request_params: Union[str, None] = None
        splitted_request = request.strip().split(":")
        if len(splitted_request) != 2:
            raise OutputParserException(
                f"Request '{request}' is not correctly formatted. "
                "Please refer to the format instructions."
            )
        result: Dict[str, Any] = {}
        try:
            request_type, request_params = splitted_request
            if request_type in {"Invalid column", "Invalid operation"}:
                raise OutputParserException(
                    f"{request}. Please check the format instructions."
                )
            array_exists = re.search(r"(\[.*?\])", request_params)
            if array_exists:
                parsed_array, stripped_request_params = self.parse_array(
                    array_exists.group(1), request_params
                )
                if request_type == "column":
                    filtered_df = self.dataframe[
                        self.dataframe.index.isin(parsed_array)
                    ]
                    if len(parsed_array) == 1:
                        # filtered_df already holds only the requested row, so
                        # take position 0; positional access with the label
                        # itself only worked for label 0.
                        result[stripped_request_params] = filtered_df[
                            stripped_request_params
                        ].iloc[0]
                    else:
                        result[stripped_request_params] = filtered_df[
                            stripped_request_params
                        ]
                elif request_type == "row":
                    filtered_df = self.dataframe[
                        self.dataframe.columns.intersection(parsed_array)
                    ]
                    row = filtered_df.iloc[int(stripped_request_params)]
                    if len(parsed_array) == 1:
                        result[stripped_request_params] = row[parsed_array[0]]
                    else:
                        result[stripped_request_params] = row
                else:
                    filtered_df = self.dataframe[
                        self.dataframe.index.isin(parsed_array)
                    ]
                    # NOTE(review): request_type comes from model output and
                    # selects any zero-argument Series attribute via getattr;
                    # consider an allow-list of operations.
                    result[request_type] = getattr(
                        filtered_df[stripped_request_params], request_type
                    )()
            else:
                if request_type == "column":
                    result[request_params] = self.dataframe[request_params]
                elif request_type == "row":
                    result[request_params] = self.dataframe.iloc[int(request_params)]
                else:
                    result[request_type] = getattr(
                        self.dataframe[request_params], request_type
                    )()
        except (AttributeError, IndexError, KeyError) as e:
            if request_type not in {"column", "row"}:
                raise OutputParserException(
                    f"Unsupported request type '{request_type}'. "
                    "Please check the format instructions."
                ) from e
            requested = (
                request_params
                if stripped_request_params is None
                else stripped_request_params
            )
            raise OutputParserException(
                f"Requested index {requested} is out of bounds."
            ) from e

        return result

    def get_format_instructions(self) -> str:
        """Return the format instructions, listing the DataFrame's columns."""
        # str() so non-string column labels (e.g. integers) do not break join().
        return PANDAS_DATAFRAME_FORMAT_INSTRUCTIONS.format(
            columns=", ".join(str(column) for column in self.dataframe.columns)
        )