unitxt.text2sql_utils module

class unitxt.text2sql_utils.Cache[source]

Bases: object

A class that provides disk-based caching functionality for a given function.

async async_get_or_set(key, compute_fn, no_cache=False, refresh=False)[source]
async_memoize(key_func=generate_cache_key, no_cache=False, refresh=False)[source]
get_or_set(key, compute_fn, no_cache=False, refresh=False)[source]
memoize(key_func=generate_cache_key, no_cache=False, refresh=False)[source]
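The method signatures above suggest a get-or-compute pattern. The following is a minimal sketch of that pattern, not the unitxt implementation. `SketchDiskCache` is a hypothetical stand-in built on the standard-library `shelve` module, and the meanings given to `no_cache` (bypass the cache entirely) and `refresh` (recompute and overwrite) are assumptions inferred from the parameter names.

```python
import os
import shelve
import tempfile


class SketchDiskCache:
    """Hypothetical stand-in for a disk-backed cache (not unitxt's Cache)."""

    def __init__(self, path):
        self.path = path

    def get_or_set(self, key, compute_fn, no_cache=False, refresh=False):
        # Assumption: no_cache skips the cache for both reads and writes.
        if no_cache:
            return compute_fn()
        with shelve.open(self.path) as db:
            # Assumption: refresh forces recomputation and overwrites the entry.
            if not refresh and key in db:
                return db[key]
            value = compute_fn()
            db[key] = value
            return value


calls = []


def expensive():
    # Records each invocation so cache hits can be told apart from misses.
    calls.append(1)
    return len(calls)


cache = SketchDiskCache(os.path.join(tempfile.mkdtemp(), "sketch_cache"))
first = cache.get_or_set("k", expensive)                # computed and stored
second = cache.get_or_set("k", expensive)               # served from disk
third = cache.get_or_set("k", expensive, refresh=True)  # recomputed
```

In this sketch the second call does not invoke `expensive`, while the third call does because `refresh=True` is passed.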
class unitxt.text2sql_utils.DatabaseConnector(db_config: SQLDatabase)[source]

Bases: ABC

Abstract base class for database connectors.

abstract execute_query(query: str) -> Any[source]

Abstract method to execute a query against the database.

abstract get_table_schema() -> str[source]

Abstract method to get database schema.
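The two abstract methods form the contract that every connector subclass has to satisfy. A minimal sketch of that contract follows. `SketchConnector` and `SketchSQLiteConnector` are hypothetical names, and the constructor here takes a plain `sqlite3` connection rather than an `SQLDatabase` config, which is an assumption made to keep the example self-contained.

```python
import sqlite3
from abc import ABC, abstractmethod
from typing import Any


class SketchConnector(ABC):
    """Hypothetical mirror of the two-method connector interface."""

    @abstractmethod
    def get_table_schema(self) -> str: ...

    @abstractmethod
    def execute_query(self, query: str) -> Any: ...


class SketchSQLiteConnector(SketchConnector):
    def __init__(self, conn: sqlite3.Connection):
        self.conn = conn

    def get_table_schema(self) -> str:
        # sqlite_master stores the CREATE statement of every table.
        rows = self.conn.execute(
            "SELECT sql FROM sqlite_master WHERE type = 'table' ORDER BY name"
        ).fetchall()
        return "\n".join(row[0] for row in rows)

    def execute_query(self, query: str) -> Any:
        return self.conn.execute(query).fetchall()


conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE employees (id INTEGER, name TEXT)")
conn.execute("INSERT INTO employees VALUES (1, 'Ada')")
connector = SketchSQLiteConnector(conn)
schema = connector.get_table_schema()
rows = connector.execute_query("SELECT name FROM employees")
```

Because both methods are abstract, instantiating `SketchConnector` directly raises `TypeError`; only subclasses that implement both methods can be created.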

class unitxt.text2sql_utils.InMemoryDatabaseConnector(db_config: SQLDatabase)[source]

Bases: DatabaseConnector

Database connector for mocking databases with in-memory data structures.

execute_query(query: str) -> Any[source]

Simulates executing a query against the mock database.

get_table_schema(select_tables: List[str] | None = None) -> str[source]

Generates a mock schema from the tables structure.

class unitxt.text2sql_utils.LocalSQLiteConnector(db_config: SQLDatabase)[source]

Bases: DatabaseConnector

Database connector for SQLite databases.

download_database(db_id)[source]

Downloads the database from huggingface if needed.

execute_query(query: str) -> Any[source]

Executes a query against the SQLite database.

get_db_file_path(db_id)[source]

Gets the local path of a downloaded database file.

get_table_schema() -> str[source]

Extracts schema from an SQLite database.

class unitxt.text2sql_utils.RemoteDatabaseConnector(db_config: SQLDatabase)[source]

Bases: DatabaseConnector

Database connector for remote databases accessed via HTTP.

execute_query(query: str) -> Any[source]

Executes a query against the remote database, with retries for certain exceptions.

get_table_schema() -> str[source]

Retrieves the schema of a database.

class unitxt.text2sql_utils.SQLExecutionResult(execution_accuracy: int, non_empty_execution_accuracy: int, subset_non_empty_execution_accuracy: int, execution_accuracy_bird: int, non_empty_gold_df: int, gold_sql_runtime: float, predicted_sql_runtime: float, pred_to_gold_runtime_ratio: float, gold_error: int, predicted_error: int, gold_df_json: str, predicted_df_json: str, error_message: str)[source]

Bases: object

class unitxt.text2sql_utils.SQLNonExecutionMetricResult(sqlglot_validity: int, sqlparse_validity: int, sqlglot_equivalence: int, sqlglot_optimized_equivalence: int, sqlparse_equivalence: int, sql_exact_match: int, sql_syntactic_equivalence: int)[source]

Bases: object

unitxt.text2sql_utils.collect_clause(statement, clause_keyword)[source]

Parse SQL statement and collect clauses.

unitxt.text2sql_utils.compare_dfs_bird_eval_logic(df1: DataFrame, df2: DataFrame)[source]

Check if two SQL query result sets are exactly equal, as in BIRD evaluation.

This function checks if the set of rows returned by the predicted SQL query (predicted_res) is exactly equal to the set of rows returned by the ground truth SQL query (ground_truth_res). This is the logic used in the original BIRD evaluation code: https://github.com/AlibabaResearch/DAMO-ConvAI/blob/main/bird/llm/src/evaluation.py.
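The set-equality check described above can be illustrated on plain row tuples. This is a sketch only: `bird_style_match` is a hypothetical name, and it operates on lists of tuples rather than on the DataFrames that the unitxt function accepts.

```python
def bird_style_match(predicted_rows, gold_rows):
    """Set equality over rows: row order and duplicate rows are ignored."""
    return set(predicted_rows) == set(gold_rows)


gold = [(1, "Ada"), (2, "Bob")]
same_rows_reordered = [(2, "Bob"), (1, "Ada")]
with_duplicate = [(1, "Ada"), (1, "Ada"), (2, "Bob")]
missing_row = [(1, "Ada")]

reordered_ok = bird_style_match(same_rows_reordered, gold)
duplicate_ok = bird_style_match(with_duplicate, gold)
missing_ok = bird_style_match(missing_row, gold)
```

Because rows are compared as sets, a reordered result and a result with duplicated rows both count as matches, while a result that lacks a row does not.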

unitxt.text2sql_utils.compare_dfs_ignore_colnames_ordered_rows(df1: DataFrame, df2: DataFrame) -> bool[source]
unitxt.text2sql_utils.compare_dfs_ignore_colnames_subset(df1: DataFrame, df2: DataFrame, ignore_row_order: bool = True) -> bool[source]

Checks if the smaller of the two DataFrames is likely a subset of the other.

Subset comparison is column-based, so that Text2SQL evaluation can tolerate a predicted result with missing or additional columns. Each row is treated as a multiset of stringified values, and the function checks whether every row in the smaller DataFrame (by column count) is a multiset subset of the corresponding row in the larger DataFrame.

When the ground truth SQL has no ORDER BY, set ignore_row_order to True to ignore the order of rows. In that case, column values are sorted before comparison. As a result, two DataFrames with the same number of rows and identical sorted column values can pass the check even though neither is actually a subset of the other. This is unlikely to happen in practice, but it means the result is not guaranteed to be exact and may overestimate accuracy.

Parameters:
  • df1 (pd.DataFrame) – The first DataFrame to compare.

  • df2 (pd.DataFrame) – The second DataFrame to compare.

  • ignore_row_order (bool, optional) – If True, ignores the order of rows by sorting them before comparison. Defaults to True.

Returns:

True if the smaller DataFrame (column-wise) is likely a subset of the larger one, False otherwise.

Return type:

bool
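The row-level multiset test can be sketched with `collections.Counter`. The helper names below are hypothetical, the example works on lists of tuples instead of DataFrames, and it covers only the positional (ordered) case. The column sorting applied when ignore_row_order is True is omitted, and treating unequal row counts as a mismatch is an assumption.

```python
from collections import Counter


def row_is_multiset_subset(small_row, large_row):
    """True if every value in small_row occurs at least as often in large_row."""
    small = Counter(str(v) for v in small_row)
    large = Counter(str(v) for v in large_row)
    return not (small - large)  # empty difference means nothing is missing


def rows_subset_ordered(rows_a, rows_b):
    """Positional row-by-row check; the narrower table is treated as 'smaller'."""
    if len(rows_a) != len(rows_b):
        return False  # assumption: differing row counts never match
    if not rows_a:
        return True
    if len(rows_a[0]) <= len(rows_b[0]):
        small, large = rows_a, rows_b
    else:
        small, large = rows_b, rows_a
    return all(row_is_multiset_subset(s, l) for s, l in zip(small, large))


gold = [("Ada", 36), ("Bob", 41)]
pred_extra_column = [(1, "Ada", 36), (2, "Bob", 41)]
pred_wrong_value = [(1, "Ada", 36), (2, "Bob", 40)]

extra_column_ok = rows_subset_ordered(gold, pred_extra_column)
wrong_value_ok = rows_subset_ordered(gold, pred_wrong_value)
```

A prediction that adds an id column still passes, since every gold row is contained in the corresponding predicted row, whereas a prediction with a changed value fails.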

unitxt.text2sql_utils.compare_dfs_ignore_colnames_unordered_rows(df1: DataFrame, df2: DataFrame) -> bool[source]
unitxt.text2sql_utils.compare_result_dfs(gold_df: DataFrame, pred_df: DataFrame, gold_sql: str) -> Tuple[int, int, int][source]

Compares two DataFrames representing SQL query results.

Parameters:
  • gold_df (pd.DataFrame) – The ground truth DataFrame.

  • pred_df (pd.DataFrame) – The predicted DataFrame.

  • gold_sql (str) – The ground truth SQL query string.

Returns:

A tuple containing:
  • match (int): 1 if the predicted DataFrame matches the gold DataFrame, 0 otherwise.

  • non_empty_match (int): 1 if both DataFrames are non-empty and match, 0 otherwise.

  • subset_match (int): 1 if the predicted DataFrame is a subset or superset of the gold DataFrame, 0 otherwise.

Return type:

Tuple[int, int, int]

Notes

  • The comparison ignores column names.

  • Row order is considered only if ‘ORDER BY’ is present in the gold SQL query.
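The row-order rule in the notes can be sketched as follows. `rows_match` is a hypothetical helper that works on lists of tuples and returns only the first of the three flags. Detecting ORDER BY with a case-insensitive substring test is an assumption; the source does not say how unitxt detects it.

```python
from collections import Counter


def rows_match(gold_rows, pred_rows, gold_sql):
    # Assumption: a case-insensitive substring test is enough to detect ORDER BY.
    if "ORDER BY" in gold_sql.upper():
        # Ordered comparison: rows must agree position by position.
        return int(list(gold_rows) == list(pred_rows))
    # Unordered comparison: same rows with the same multiplicities.
    return int(Counter(gold_rows) == Counter(pred_rows))


gold = [(1,), (2,), (3,)]
shuffled = [(3,), (1,), (2,)]

unordered_flag = rows_match(gold, shuffled, "SELECT id FROM t")
ordered_flag = rows_match(gold, shuffled, "SELECT id FROM t ORDER BY id")
```

The same shuffled prediction matches when the gold query imposes no ordering and fails when it does.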

unitxt.text2sql_utils.execute_query_local(db_path: str, query: str) -> Any

Executes a query against the SQLite database.

unitxt.text2sql_utils.execute_query_remote(api_url: str, database_id: str, api_key: str, query: str, retryable_exceptions: tuple = (requests.exceptions.ConnectionError, requests.exceptions.ReadTimeout), max_retries: int = 3, retry_delay: int = 5, timeout: int = 30) -> Tuple[dict | None, str]

Executes a query against the remote database, with retries for certain exceptions.
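A retry loop of the kind the signature implies might look like the sketch below. `call_with_retries` and `flaky_request` are hypothetical, the built-in `ConnectionError` and `TimeoutError` stand in for the `requests` exceptions named in the signature, and treating `max_retries` as the total number of attempts is an assumption.

```python
import time


def call_with_retries(fn, retryable_exceptions=(ConnectionError, TimeoutError),
                      max_retries=3, retry_delay=0.0):
    """Call fn(); retry only on the listed exception types.

    Returns (result, "") on success or (None, error_message) on failure.
    """
    last_error = None
    for attempt in range(1, max_retries + 1):
        try:
            return fn(), ""
        except retryable_exceptions as exc:
            last_error = exc
            if attempt < max_retries:
                time.sleep(retry_delay)  # wait before the next attempt
    return None, f"failed after {max_retries} attempts: {last_error}"


attempts = []


def flaky_request():
    # Fails twice with a retryable error, then succeeds.
    attempts.append(1)
    if len(attempts) < 3:
        raise ConnectionError("connection reset")
    return {"rows": [[1]]}


result, error = call_with_retries(flaky_request)
```

Exceptions outside `retryable_exceptions` are not caught in this sketch and propagate to the caller on the first occurrence.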

unitxt.text2sql_utils.extract_select_columns(statement)[source]

Parse SQL using sqlparse and extract columns.

unitxt.text2sql_utils.extract_select_info(sql: str)[source]

Parse SQL using sqlparse and return a dict of extracted columns and clauses.

unitxt.text2sql_utils.extract_sql_from_text(text: str) -> str[source]

Extracts the first SQL query from the given text.

Priority:
  1. SQL inside a fenced code block tagged sql

  2. SQL starting on a new line or after a colon/label

  3. SQL without a terminating semicolon

Returns:

The SQL query string, or an empty string if not found.
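The priority order can be approximated with a few regular expressions. This is a simplified sketch under stated assumptions, not the unitxt implementation: `extract_first_sql` is a hypothetical name, only statements beginning with SELECT are recognized, the second rule is reduced to "first SELECT up to a semicolon", and the semicolon is kept in the returned string.

```python
import re

FENCE = "`" * 3  # three backticks, built at runtime so this example stays readable


def extract_first_sql(text):
    # 1. Prefer the body of a fenced block tagged "sql".
    fenced = re.search(FENCE + r"sql\s*(.*?)" + FENCE, text, re.DOTALL | re.IGNORECASE)
    if fenced:
        return fenced.group(1).strip()
    # 2. Otherwise take the first SELECT statement up to a semicolon.
    terminated = re.search(r"\bSELECT\b.*?;", text, re.DOTALL | re.IGNORECASE)
    if terminated:
        return terminated.group(0).strip()
    # 3. Fall back to an unterminated statement running to the end of the text.
    unterminated = re.search(r"\bSELECT\b.*", text, re.DOTALL | re.IGNORECASE)
    return unterminated.group(0).strip() if unterminated else ""


fenced_text = "Here is the query:\n" + FENCE + "sql\nSELECT 1;\n" + FENCE + "\nDone."
from_fence = extract_first_sql(fenced_text)
from_label = extract_first_sql("Answer: SELECT a FROM t; thanks")
no_semicolon = extract_first_sql("Answer: select a from t")
nothing = extract_first_sql("no query here")
```

When no candidate is found, the sketch returns an empty string, matching the documented return value.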

unitxt.text2sql_utils.generate_cache_key(*args, **kwargs)[source]

Generate a stable hashable cache key for various input types.

Parameters:
  • args – Positional arguments of the function.

  • kwargs – Keyword arguments of the function.

Returns:

A hashed key as a string.
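One common way to build such a key is to serialize the arguments deterministically and hash the result. The sketch below assumes that approach; `stable_key` is a hypothetical name, and the choice of JSON plus SHA-256 is an illustration rather than what unitxt necessarily does.

```python
import hashlib
import json


def stable_key(*args, **kwargs):
    # sort_keys makes keyword/dict ordering irrelevant. default=repr is a
    # fallback for objects json cannot serialize; objects whose repr embeds
    # a memory address would not hash stably across runs.
    payload = json.dumps({"args": args, "kwargs": kwargs}, sort_keys=True, default=repr)
    return hashlib.sha256(payload.encode("utf-8")).hexdigest()


key_a = stable_key("SELECT 1", db="sales", timeout=5)
key_b = stable_key("SELECT 1", timeout=5, db="sales")  # same arguments, new order
key_c = stable_key("SELECT 2", db="sales", timeout=5)  # different query
```

Keyword order does not affect the key, while any change in argument values produces a different key.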

unitxt.text2sql_utils.get_cache()[source]

Returns a singleton cache instance, initializing it if necessary.

unitxt.text2sql_utils.get_db_connector(db_type: str)[source]

Creates and returns the appropriate DatabaseConnector instance based on db_type.

unitxt.text2sql_utils.get_sql_execution_results(predicted_sql: str, gold_sql: str, connector, sql_timeout: float) -> SQLExecutionResult[source]

Execute and compare predicted and gold SQL queries, returning execution metrics.

Parameters:
  • predicted_sql (str) – The SQL query predicted by the model.

  • gold_sql (str) – The reference (gold) SQL query.

  • connector – Database connector object used to execute the queries.

  • sql_timeout (float) – Maximum time (in seconds) allowed for query execution.

Returns:

An object containing various execution metrics, including:
  • execution_accuracy (int): 1 if predicted and gold queries produce equivalent results, else 0.

  • non_empty_execution_accuracy (int): 1 if both queries produce non-empty and equivalent results, else 0.

  • subset_non_empty_execution_accuracy (int): 1 if predicted results are a subset or superset of gold results and non-empty, else 0. Subset comparison is column-based. This means that the predicted SQL dataframe can have missing or additional columns compared to the gold SQL dataframe.

  • execution_accuracy_bird (int): 1 if results match according to BIRD evaluation logic, else 0.

  • non_empty_gold_df (int): 1 if the gold query result is non-empty, else 0.

  • gold_sql_runtime (float): Execution time for the gold SQL query.

  • predicted_sql_runtime (float): Execution time for the predicted SQL query.

  • pred_to_gold_runtime_ratio (float): Ratio of predicted to gold query runtimes.

  • gold_error (int): 1 if the gold query failed, else 0.

  • predicted_error (int): 1 if the predicted query failed, else 0.

  • gold_df_json (str): JSON representation of the gold query result DataFrame.

  • predicted_df_json (str): JSON representation of the predicted query result DataFrame.

  • error_message (str): Error message if any query failed, else empty string.

Return type:

SQLExecutionResult

Notes

  • If the gold query fails, the function returns early with error details.

  • If the predicted query is identical or SQL-equivalent to the gold query, results are considered correct without execution.

  • Otherwise, both queries are executed and their results compared using multiple metrics.

unitxt.text2sql_utils.is_sqlglot_parsable(sql: str, db_type='sqlite') -> bool[source]

Returns True if sqlglot does not encounter any error, False otherwise.

unitxt.text2sql_utils.is_sqlparse_parsable(sql: str) -> bool[source]

Returns True if sqlparse does not encounter any error, False otherwise.

unitxt.text2sql_utils.replace_select_clause(source_query: str, target_query: str, dialect: str = 'postgres') -> str[source]

Replaces the SELECT clause of the target SQL query with the SELECT clause from the source SQL query.

Parameters:
  • source_query (str) – SQL query whose SELECT clause will be used.

  • target_query (str) – SQL query whose SELECT clause will be replaced.

  • dialect (str) – SQL dialect for parsing and rendering (default: “postgres”).

Returns:

A new SQL query with the SELECT clause of target_query replaced by that of source_query.

Return type:

str

Raises:

ValueError – If either query is not a valid SELECT statement.

Example

>>> replace_select_clause(
...     "SELECT id FROM employees",
...     "SELECT name FROM employees WHERE age > 30"
... )
"SELECT id FROM employees WHERE age > 30"
unitxt.text2sql_utils.run_query(sql: str, connector, sql_timeout: float) -> Tuple[DataFrame | None, float, str][source]

Executes a SQL query using the provided connector with a timeout.

Parameters:
  • sql (str) – The SQL query string to execute.

  • connector – An object with an execute_query method that executes the SQL query.

  • sql_timeout (float) – The maximum time in seconds to allow for query execution.

Returns:

  • A pandas DataFrame containing the query results, or None if an error occurred.

  • The duration in seconds taken to execute the query, or 0.0 if an error occurred.

  • An error message string if an error occurred, otherwise an empty string.

Return type:

Tuple[Optional[pd.DataFrame], float, str]

Notes

  • If the SQL string is empty or only whitespace, returns immediately with a message.

  • If the query execution exceeds the timeout, returns a timeout error message.

  • Any other exceptions are caught and returned as error messages.
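The three cases in the notes can be sketched with a worker thread and a bounded wait. `run_with_timeout` is a hypothetical helper that takes a plain callable instead of a connector and returns a generic result instead of a DataFrame. Using `concurrent.futures` for the timeout is an assumption, since the source does not say how unitxt enforces it.

```python
import time
from concurrent.futures import ThreadPoolExecutor
from concurrent.futures import TimeoutError as FutureTimeout


def run_with_timeout(sql, execute_fn, sql_timeout):
    """Return (result, duration_seconds, error_message)."""
    if not sql or not sql.strip():
        return None, 0.0, "empty SQL"
    executor = ThreadPoolExecutor(max_workers=1)
    start = time.perf_counter()
    try:
        result = executor.submit(execute_fn, sql).result(timeout=sql_timeout)
        return result, time.perf_counter() - start, ""
    except FutureTimeout:
        return None, 0.0, f"timeout after {sql_timeout} seconds"
    except Exception as exc:  # any other failure becomes an error message
        return None, 0.0, f"error: {exc}"
    finally:
        # Do not block on a still-running worker; the thread itself is not killed.
        executor.shutdown(wait=False)


def fast_query(sql):
    return [(1,)]


def slow_query(sql):
    time.sleep(0.3)
    return [(1,)]


def failing_query(sql):
    raise ValueError("bad")


ok = run_with_timeout("SELECT 1", fast_query, 1.0)
timed_out = run_with_timeout("SELECT 1", slow_query, 0.05)
failed = run_with_timeout("SELECT 1", failing_query, 1.0)
blank = run_with_timeout("   ", fast_query, 1.0)
```

One limitation of this design is that a timed-out query keeps running in its worker thread; the caller simply stops waiting for it.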

unitxt.text2sql_utils.sql_exact_match(sql1: str, sql2: str) -> bool[source]

Return True if two SQL strings match after very basic normalization.
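The source does not spell out what the normalization covers. The sketch below assumes one plausible reading: trim the string, drop a trailing semicolon, collapse whitespace, and lowercase. `normalize_sql` and `exact_match_sketch` are hypothetical names.

```python
import re


def normalize_sql(sql):
    # Assumed normalization: trim, drop a trailing semicolon,
    # collapse whitespace runs, and lowercase everything.
    sql = sql.strip().rstrip(";")
    return re.sub(r"\s+", " ", sql).strip().lower()


def exact_match_sketch(sql1, sql2):
    return normalize_sql(sql1) == normalize_sql(sql2)


same = exact_match_sketch("SELECT  id\nFROM t;", "select id from t")
different = exact_match_sketch("SELECT id FROM t", "SELECT name FROM t")
```

Lowercasing the whole string also lowercases string literals, so this sketch would treat `'Ada'` and `'ada'` as equal; a stricter comparison would need a real parser.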

unitxt.text2sql_utils.sqlglot_optimized_equivalence(expected: str, generated: str) -> int[source]

Checks whether two SQL queries are equivalent using SQLGlot parsing, so that equivalent queries do not need to be executed.

unitxt.text2sql_utils.sqlglot_parsed_queries_equivalent(sql1: str, sql2: str, dialect: str = '') -> bool[source]

Return True if two SQL queries match after parsing with SQLGlot.

unitxt.text2sql_utils.sqlparse_queries_equivalent(sql1: str, sql2: str) -> bool[source]

Returns True if both SQL queries are naively considered equivalent.

unitxt.text2sql_utils.strip_alias(col: str) -> str[source]

Remove any AS alias from a column.
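A regex-based sketch of alias removal is shown below. `strip_alias_sketch` is a hypothetical name, and the pattern assumes the alias is a single, optionally double-quoted identifier at the end of the column expression; the actual unitxt rules may differ.

```python
import re


def strip_alias_sketch(col):
    # Drop a trailing "AS alias" (case-insensitive). Requiring an identifier
    # right before the end of the string keeps "CAST(x AS INT)" intact.
    return re.sub(r'\s+AS\s+"?\w+"?\s*$', "", col, flags=re.IGNORECASE).strip()


aggregated = strip_alias_sketch("COUNT(*) AS total")
lowercase_as = strip_alias_sketch("e.salary as pay")
no_alias = strip_alias_sketch("name")
cast_only = strip_alias_sketch("CAST(x AS INT)")
cast_with_alias = strip_alias_sketch("CAST(x AS INT) AS y")
```

Aliases written without the AS keyword (for example `salary pay`) are left unchanged by this sketch.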