unitxt.inference module

class unitxt.inference.AsyncTokenBucket(rate, capacity)[source]

Bases: object

async acquire(tokens=1)[source]
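
A minimal usage sketch, assuming rate is the number of tokens added per second and capacity is the maximum burst size (the class itself does not document these semantics):

    import asyncio

    from unitxt.inference import AsyncTokenBucket

    # Assumed semantics: refill 2 tokens per second, allow bursts of up to 5 requests.
    bucket = AsyncTokenBucket(rate=2, capacity=5)

    async def call_api(i):
        await bucket.acquire()  # waits until a token is available
        print(f"request {i} sent")

    async def main():
        await asyncio.gather(*(call_api(i) for i in range(10)))

    asyncio.run(main())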
class unitxt.inference.CrossProviderInferenceEngine(data_classification_policy: List[str] = None, model: str = __required__, frequency_penalty: float | NoneType = None, presence_penalty: float | NoneType = None, max_tokens: int | NoneType = None, seed: int | NoneType = None, stop: str | NoneType | List[str] = None, temperature: float | NoneType = None, top_p: float | NoneType = None, top_logprobs: int | NoneType = None, logit_bias: Dict[str, int] | NoneType = None, logprobs: bool | NoneType = None, n: int | NoneType = None, parallel_tool_calls: bool | NoneType = None, service_tier: Literal['auto', 'default'] | NoneType = None, provider: Literal['watsonx', 'together-ai', 'open-ai', 'aws', 'ollama', 'bam', 'watsonx-sdk'] | NoneType = None, provider_model_map: Dict[Literal['watsonx', 'together-ai', 'open-ai', 'aws', 'ollama', 'bam', 'watsonx-sdk'], Dict[str, str]] = {'watsonx': {'llama-3-8b-instruct': 'watsonx/meta-llama/llama-3-8b-instruct', 'llama-3-70b-instruct': 'watsonx/meta-llama/llama-3-70b-instruct', 'granite-3-8b-instruct': 'watsonx/ibm/granite-3-8b-instruct', 'flan-t5-xxl': 'watsonx/google/flan-t5-xxl', 'llama-3-2-1b-instruct': 'watsonx/meta-llama/llama-3-2-1b-instruct'}, 'watsonx-sdk': {'llama-3-8b-instruct': 'meta-llama/llama-3-8b-instruct', 'llama-3-70b-instruct': 'meta-llama/llama-3-70b-instruct', 'granite-3-8b-instruct': 'ibm/granite-3-8b-instruct'}, 'together-ai': {'llama-3-8b-instruct': 'together_ai/togethercomputer/llama-3-8b-instruct', 'llama-3-70b-instruct': 'together_ai/togethercomputer/llama-3-70b-instruct', 'llama-3-2-1b-instruct': 'together_ai/togethercomputer/llama-3-2-1b-instruct'}, 'aws': {'llama-3-8b-instruct': 'bedrock/meta.llama3-8b-instruct-v1:0', 'llama-3-70b-instruct': 'bedrock/meta.llama3-70b-instruct-v1:0'}, 'ollama': {'llama-3-8b-instruct': 'llama3:8b', 'llama-3-70b-instruct': 'llama3:70b'}, 'bam': {'granite-3-8b-instruct': 'ibm/granite-8b-instruct-preview-4k', 'llama-3-8b-instruct': 'meta-llama/llama-3-8b-instruct', 'llama-3-2-1b-instruct': 'meta-llama/llama-3-2-1b-instruct', 'flan-t5-xxl': 'google/flan-t5-xxl'}})[source]

Bases: InferenceEngine, StandardAPIParamsMixin

Inference engine capable of dynamically switching between multiple provider APIs.

This class extends InferenceEngine and StandardAPIParamsMixin to enable seamless integration with various API providers. The supported APIs are specified in _supported_apis, allowing users to interact with models from multiple sources. The provider_model_map dictionary maps each API to specific model identifiers, enabling automatic configuration based on user requests.

provider

Optional; Specifies the current API in use. Must be one of the literals in _supported_apis.

Type:

Literal[‘watsonx’, ‘together-ai’, ‘open-ai’, ‘aws’, ‘ollama’, ‘bam’, ‘watsonx-sdk’] | None

provider_model_map

Dictionary mapping each supported API to a dictionary that maps short model aliases to provider-specific model identifier strings. This mapping allows consistent access to models across different API backends.

Type:

Dict[Literal[‘watsonx’, ‘together-ai’, ‘open-ai’, ‘aws’, ‘ollama’, ‘bam’, ‘watsonx-sdk’], Dict[str, str]]

provider_model_map: Dict[Literal['watsonx', 'together-ai', 'open-ai', 'aws', 'ollama', 'bam', 'watsonx-sdk'], Dict[str, str]] = {'aws': {'llama-3-70b-instruct': 'bedrock/meta.llama3-70b-instruct-v1:0', 'llama-3-8b-instruct': 'bedrock/meta.llama3-8b-instruct-v1:0'}, 'bam': {'flan-t5-xxl': 'google/flan-t5-xxl', 'granite-3-8b-instruct': 'ibm/granite-8b-instruct-preview-4k', 'llama-3-2-1b-instruct': 'meta-llama/llama-3-2-1b-instruct', 'llama-3-8b-instruct': 'meta-llama/llama-3-8b-instruct'}, 'ollama': {'llama-3-70b-instruct': 'llama3:70b', 'llama-3-8b-instruct': 'llama3:8b'}, 'together-ai': {'llama-3-2-1b-instruct': 'together_ai/togethercomputer/llama-3-2-1b-instruct', 'llama-3-70b-instruct': 'together_ai/togethercomputer/llama-3-70b-instruct', 'llama-3-8b-instruct': 'together_ai/togethercomputer/llama-3-8b-instruct'}, 'watsonx': {'flan-t5-xxl': 'watsonx/google/flan-t5-xxl', 'granite-3-8b-instruct': 'watsonx/ibm/granite-3-8b-instruct', 'llama-3-2-1b-instruct': 'watsonx/meta-llama/llama-3-2-1b-instruct', 'llama-3-70b-instruct': 'watsonx/meta-llama/llama-3-70b-instruct', 'llama-3-8b-instruct': 'watsonx/meta-llama/llama-3-8b-instruct'}, 'watsonx-sdk': {'granite-3-8b-instruct': 'ibm/granite-3-8b-instruct', 'llama-3-70b-instruct': 'meta-llama/llama-3-70b-instruct', 'llama-3-8b-instruct': 'meta-llama/llama-3-8b-instruct'}}
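
A minimal usage sketch; the provider/model pair is taken from provider_model_map above, the dataset query is borrowed from the WMLInferenceEngine example further down, and infer() is inherited from the InferenceEngine base class:

    from unitxt.api import load_dataset
    from unitxt.inference import CrossProviderInferenceEngine

    engine = CrossProviderInferenceEngine(
        model="llama-3-8b-instruct",
        provider="watsonx",
        max_tokens=256,
        temperature=0.0,
    )

    dataset = load_dataset(
        dataset_query="card=cards.argument_topic,template_card_index=0,loader_limit=5"
    )
    predictions = engine.infer(dataset["test"])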
class unitxt.inference.GenericInferenceEngine(data_classification_policy: List[str] = None, default: str | NoneType = None)[source]

Bases: InferenceEngine, ArtifactFetcherMixin

class unitxt.inference.HFLlavaInferenceEngine(data_classification_policy: List[str] = None, lazy_load: bool = True, model_name: str = __required__, max_new_tokens: int = __required__)[source]

Bases: InferenceEngine, LazyLoadMixin

class unitxt.inference.HFOptionSelectingInferenceEngine(data_classification_policy: List[str] = None, model_name: str = __required__, batch_size: int = __required__)[source]

Bases: InferenceEngine

HuggingFace-based class for inference engines that calculate log probabilities.

This class uses models from the HuggingFace Transformers library to calculate log probabilities for text inputs.
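
A minimal usage sketch; the model name and batch size are illustrative, and the dataset is assumed to be a unitxt dataset split whose instances carry the input text and the candidate options under task_data.options:

    from unitxt.inference import HFOptionSelectingInferenceEngine

    engine = HFOptionSelectingInferenceEngine(model_name="gpt2", batch_size=8)
    predictions = engine.infer(dataset)  # dataset: a unitxt dataset split (assumed)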

class unitxt.inference.HFPipelineBasedInferenceEngine(data_classification_policy: List[str] = None, lazy_load: bool = False, _requirements_list: List[str] | Dict[str, str] = {'transformers': "Install huggingface package using 'pip install --upgrade transformers"}, model_name: str = __required__, max_new_tokens: int = __required__, use_fp16: bool = True, batch_size: int = 1, top_k: int | NoneType = None)[source]

Bases: InferenceEngine, PackageRequirementsMixin, LazyLoadMixin

class unitxt.inference.IbmGenAiInferenceEngine(data_classification_policy: List[str] = ['public', 'proprietary'], _requirements_list: List[str] | Dict[str, str] = {'ibm-generative-ai': "Install ibm-genai package using 'pip install --upgrade ibm-generative-ai"}, beam_width: int | NoneType = None, decoding_method: Literal['greedy', 'sample'] | NoneType = None, include_stop_sequence: bool | NoneType = None, length_penalty: Any = None, max_new_tokens: int | NoneType = None, min_new_tokens: int | NoneType = None, random_seed: int | NoneType = None, repetition_penalty: float | NoneType = None, return_options: Any = None, stop_sequences: List[str] | NoneType = None, temperature: float | NoneType = None, time_limit: int | NoneType = None, top_k: int | NoneType = None, top_p: float | NoneType = None, truncate_input_tokens: int | NoneType = None, typical_p: float | NoneType = None, label: str = 'ibm_genai', model_name: str = __required__, parameters: IbmGenAiInferenceEngineParams | NoneType = None)[source]

Bases: InferenceEngine, IbmGenAiInferenceEngineParamsMixin, PackageRequirementsMixin, LogProbInferenceEngine, OptionSelectingByLogProbsInferenceEngine

data_classification_policy: List[str] = ['public', 'proprietary']
class unitxt.inference.IbmGenAiInferenceEngineParamsMixin(data_classification_policy: List[str] = None, beam_width: int | NoneType = None, decoding_method: Literal['greedy', 'sample'] | NoneType = None, include_stop_sequence: bool | NoneType = None, length_penalty: Any = None, max_new_tokens: int | NoneType = None, min_new_tokens: int | NoneType = None, random_seed: int | NoneType = None, repetition_penalty: float | NoneType = None, return_options: Any = None, stop_sequences: List[str] | NoneType = None, temperature: float | NoneType = None, time_limit: int | NoneType = None, top_k: int | NoneType = None, top_p: float | NoneType = None, truncate_input_tokens: int | NoneType = None, typical_p: float | NoneType = None)[source]

Bases: Artifact

class unitxt.inference.InferenceEngine(data_classification_policy: List[str] = None)[source]

Bases: Artifact

Abstract base class for inference.

class unitxt.inference.LMMSEvalBaseInferenceEngine(data_classification_policy: List[str] = None, lazy_load: bool = False, _requirements_list: List[str] | Dict[str, str] = ['lmms-eval==0.2.4'], model_type: str = __required__, model_args: Dict[str, str] = __required__, batch_size: int = 1)[source]

Bases: InferenceEngine, PackageRequirementsMixin, LazyLoadMixin

class unitxt.inference.LMMSEvalInferenceEngine(data_classification_policy: List[str] = None, lazy_load: bool = False, _requirements_list: List[str] | Dict[str, str] = ['lmms-eval==0.2.4'], model_type: str = __required__, model_args: Dict[str, str] = __required__, batch_size: int = 1, max_new_tokens: int = 32, temperature: float = 0.0, do_sample: bool = False, generate_until: List[str] = ['\n\n'])[source]

Bases: LMMSEvalBaseInferenceEngine

generate_until: List[str] = ['\n\n']
class unitxt.inference.LMMSEvalLoglikelihoodInferenceEngine(data_classification_policy: List[str] = None, lazy_load: bool = False, _requirements_list: List[str] | Dict[str, str] = ['lmms-eval==0.2.4'], model_type: str = __required__, model_args: Dict[str, str] = __required__, batch_size: int = 1, request_type: Literal['loglikelihood'] = 'loglikelihood')[source]

Bases: LMMSEvalBaseInferenceEngine

class unitxt.inference.LazyLoadMixin(data_classification_policy: List[str] = None, lazy_load: bool = False)[source]

Bases: Artifact

class unitxt.inference.LiteLLMInferenceEngine(data_classification_policy: List[str] = None, _requirements_list: list = ['litellm', 'tenacity', 'tqdm', 'diskcache'], model: str = __required__, frequency_penalty: float | NoneType = None, presence_penalty: float | NoneType = None, max_tokens: int | NoneType = None, seed: int | NoneType = None, stop: str | NoneType | List[str] = None, temperature: float | NoneType = None, top_p: float | NoneType = None, top_logprobs: int | NoneType = None, logit_bias: Dict[str, int] | NoneType = None, logprobs: bool | NoneType = None, n: int | NoneType = None, parallel_tool_calls: bool | NoneType = None, service_tier: Literal['auto', 'default'] | NoneType = None, max_requests_per_second: float = 6, max_retries: int = 5)[source]

Bases: InferenceEngine, StandardAPIParamsMixin, PackageRequirementsMixin

class unitxt.inference.LogProbInferenceEngine(data_classification_policy: List[str] = None)[source]

Bases: ABC, Artifact

Abstract base class for inference with log probs.

class unitxt.inference.MockInferenceEngine(data_classification_policy: List[str] = None, model_name: str = __required__, default_inference_value: str = '[[10]]')[source]

Bases: InferenceEngine

class unitxt.inference.MockModeMixin(data_classification_policy: List[str] = None, mock_mode: bool = False)[source]

Bases: Artifact

class unitxt.inference.OllamaInferenceEngine(data_classification_policy: List[str] = ['public', 'proprietary'], _requirements_list: List[str] | Dict[str, str] = {'ollama': "Install ollama package using 'pip install --upgrade ollama"}, model: str = __required__, frequency_penalty: float | NoneType = None, presence_penalty: float | NoneType = None, max_tokens: int | NoneType = None, seed: int | NoneType = None, stop: str | NoneType | List[str] = None, temperature: float | NoneType = None, top_p: float | NoneType = None, top_logprobs: int | NoneType = None, logit_bias: Dict[str, int] | NoneType = None, logprobs: bool | NoneType = None, n: int | NoneType = None, parallel_tool_calls: bool | NoneType = None, service_tier: Literal['auto', 'default'] | NoneType = None, label: str = 'ollama')[source]

Bases: InferenceEngine, StandardAPIParamsMixin, PackageRequirementsMixin

data_classification_policy: List[str] = ['public', 'proprietary']
class unitxt.inference.OpenAiInferenceEngine(data_classification_policy: List[str] = ['public'], _requirements_list: List[str] | Dict[str, str] = {'openai': "Install openai package using 'pip install --upgrade openai"}, frequency_penalty: float | NoneType = None, presence_penalty: float | NoneType = None, max_tokens: int | NoneType = None, seed: int | NoneType = None, stop: str | NoneType | List[str] = None, temperature: float | NoneType = None, top_p: float | NoneType = None, top_logprobs: int | NoneType = 20, logit_bias: Dict[str, int] | NoneType = None, logprobs: bool | NoneType = True, n: int | NoneType = None, parallel_tool_calls: bool | NoneType = None, service_tier: Literal['auto', 'default'] | NoneType = None, label: str = 'openai', model_name: str = __required__, parameters: OpenAiInferenceEngineParams | NoneType = None)[source]

Bases: InferenceEngine, LogProbInferenceEngine, OpenAiInferenceEngineParamsMixin, PackageRequirementsMixin

data_classification_policy: List[str] = ['public']
class unitxt.inference.OpenAiInferenceEngineParamsMixin(data_classification_policy: List[str] = None, frequency_penalty: float | NoneType = None, presence_penalty: float | NoneType = None, max_tokens: int | NoneType = None, seed: int | NoneType = None, stop: str | NoneType | List[str] = None, temperature: float | NoneType = None, top_p: float | NoneType = None, top_logprobs: int | NoneType = 20, logit_bias: Dict[str, int] | NoneType = None, logprobs: bool | NoneType = True, n: int | NoneType = None, parallel_tool_calls: bool | NoneType = None, service_tier: Literal['auto', 'default'] | NoneType = None)[source]

Bases: Artifact

class unitxt.inference.OptionSelectingByLogProbsInferenceEngine[source]

Bases: object

Inference engine used to select an option from a list of options, based on the log probabilities of the options conditioned on a prompt.

The inference engines that inherit from this class must implement get_token_count and get_options_log_probs.

abstract get_options_log_probs(dataset)[source]

Get the token log probabilities of the options under the task_data.options key of each instance in the dataset.

Adds to each instance in the data an “options_log_prob” field: a dict mapping each option (str) to a list of {text: str, logprob: float} entries.

Parameters:

dataset (List[Dict[str, Any]]) – A list of dictionaries, each representing a data instance.

Returns:

The dataset, with the options’ log probabilities added to each instance

Return type:

List[Dict[str, Any]]

abstract get_token_count(dataset)[source]

Get the token count of the source key of each instance in the dataset, and add a “token_count” field to each instance.

Parameters:

dataset (List[Dict[str, Any]]) – A list of dictionaries, each representing a data instance.

Returns:

The token count of the texts

Return type:

List[int]

select(dataset: List[Dict[str, Any]]) List[Dict[str, Any]][source]

Calculate most likely labels based on log probabilities for a set of fixed completions.
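
The selection idea can be pictured with the following sketch (illustrative only, not the library’s implementation): each option is scored by summing the log probabilities of its tokens given the prompt, and the highest-scoring option becomes the prediction.

    def pick_most_likely_option(instance):
        # instance["options_log_prob"] maps each option string to a list of
        # {"text": ..., "logprob": ...} token entries (see get_options_log_probs above).
        scores = {
            option: sum(token["logprob"] for token in tokens)
            for option, tokens in instance["options_log_prob"].items()
        }
        # Return the option whose tokens are, in total, most probable under the model.
        return max(scores, key=scores.get)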

class unitxt.inference.StandardAPIParamsMixin(data_classification_policy: List[str] = None, model: str = __required__, frequency_penalty: float | NoneType = None, presence_penalty: float | NoneType = None, max_tokens: int | NoneType = None, seed: int | NoneType = None, stop: str | NoneType | List[str] = None, temperature: float | NoneType = None, top_p: float | NoneType = None, top_logprobs: int | NoneType = None, logit_bias: Dict[str, int] | NoneType = None, logprobs: bool | NoneType = None, n: int | NoneType = None, parallel_tool_calls: bool | NoneType = None, service_tier: Literal['auto', 'default'] | NoneType = None)[source]

Bases: Artifact

class unitxt.inference.TextGenerationInferenceOutput(prediction: str | List[Dict[str, Any]], input_tokens: int | None = None, output_tokens: int | None = None, model_name: str | None = None, inference_type: str | None = None)[source]

Bases: object

Contains the prediction results and metadata for the inference.

Args:

prediction (Union[str, List[Dict[str, Any]]]): If this is the result of an _infer call, the string predicted by the model. If this is the result of an _infer_log_probs call, a list of dictionaries in which the i’th dictionary represents the i’th token in the response. The entry “top_tokens” in each dictionary holds a sorted list of the top tokens for this position and their probabilities. For example:

    [
        {..., "top_tokens": [{"text": "a", "logprob": ...}, {"text": "b", "logprob": ...}, ...]},
        {..., "top_tokens": [{"text": "c", "logprob": ...}, {"text": "d", "logprob": ...}, ...]},
    ]

input_tokens (int): number of input tokens to the model.

output_tokens (int): number of output tokens generated by the model.

model_name (str): the model_name as kept in the InferenceEngine.

inference_type (str): a label stating the type of the InferenceEngine.
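
For illustration, a hypothetical output of an _infer_log_probs call could be constructed like this (all values are made up; only the fields documented above are used):

    from unitxt.inference import TextGenerationInferenceOutput

    output = TextGenerationInferenceOutput(
        prediction=[
            {"top_tokens": [{"text": "Paris", "logprob": -0.12},
                            {"text": "London", "logprob": -2.31}]},
        ],
        input_tokens=12,
        output_tokens=1,
        model_name="google/flan-t5-xxl",
        inference_type="wml",
    )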

class unitxt.inference.TogetherAiInferenceEngine(data_classification_policy: List[str] = ['public'], _requirements_list: List[str] | Dict[str, str] = {'together': "Install together package using 'pip install --upgrade together"}, max_tokens: int | NoneType = None, stop: List[str] | NoneType = None, temperature: float | NoneType = None, top_p: float | NoneType = None, top_k: int | NoneType = None, repetition_penalty: float | NoneType = None, logprobs: int | NoneType = None, echo: bool | NoneType = None, n: int | NoneType = None, min_p: float | NoneType = None, presence_penalty: float | NoneType = None, frequency_penalty: float | NoneType = None, label: str = 'together', model_name: str = __required__, parameters: unitxt.inference.TogetherAiInferenceEngineParamsMixin | NoneType = None)[source]

Bases: InferenceEngine, TogetherAiInferenceEngineParamsMixin, PackageRequirementsMixin

data_classification_policy: List[str] = ['public']
class unitxt.inference.TogetherAiInferenceEngineParamsMixin(data_classification_policy: List[str] = None, max_tokens: int | NoneType = None, stop: List[str] | NoneType = None, temperature: float | NoneType = None, top_p: float | NoneType = None, top_k: int | NoneType = None, repetition_penalty: float | NoneType = None, logprobs: int | NoneType = None, echo: bool | NoneType = None, n: int | NoneType = None, min_p: float | NoneType = None, presence_penalty: float | NoneType = None, frequency_penalty: float | NoneType = None)[source]

Bases: Artifact

class unitxt.inference.VLLMInferenceEngine(data_classification_policy: List[str] = None, model: str = __required__, frequency_penalty: float | NoneType = None, presence_penalty: float | NoneType = None, max_tokens: int | NoneType = None, seed: int | NoneType = None, stop: str | NoneType | List[str] = None, temperature: float | NoneType = None, top_p: float | NoneType = None, top_logprobs: int | NoneType = None, logit_bias: Dict[str, int] | NoneType = None, logprobs: bool | NoneType = None, n: int | NoneType = None, parallel_tool_calls: bool | NoneType = None, service_tier: Literal['auto', 'default'] | NoneType = None, _requirements_list: List[str] | Dict[str, str] = [])[source]

Bases: InferenceEngine, PackageRequirementsMixin, StandardAPIParamsMixin

class unitxt.inference.VLLMRemoteInferenceEngine(data_classification_policy: List[str] = ['public'], _requirements_list: List[str] | Dict[str, str] = {'openai': "Install openai package using 'pip install --upgrade openai"}, frequency_penalty: float | NoneType = None, presence_penalty: float | NoneType = None, max_tokens: int | NoneType = None, seed: int | NoneType = None, stop: str | NoneType | List[str] = None, temperature: float | NoneType = None, top_p: float | NoneType = None, top_logprobs: int | NoneType = 20, logit_bias: Dict[str, int] | NoneType = None, logprobs: bool | NoneType = True, n: int | NoneType = None, parallel_tool_calls: bool | NoneType = None, service_tier: Literal['auto', 'default'] | NoneType = None, label: str = 'vllm', model_name: str = __required__, parameters: OpenAiInferenceEngineParams | NoneType = None)[source]

Bases: OpenAiInferenceEngine

class unitxt.inference.WMLInferenceEngine(data_classification_policy: List[str] = ['public', 'proprietary'], _requirements_list: List[str] | Dict[str, str] = {'ibm-watsonx-ai==1.1.14': "Install ibm-watsonx-ai package using 'pip install --upgrade ibm-watsonx-ai'. It is advised to have Python version >=3.10 installed, as at lower version this package may cause conflicts with other installed packages."}, decoding_method: Literal['greedy', 'sample'] | NoneType = None, length_penalty: Dict[str, float | int] | NoneType = None, temperature: float | NoneType = None, top_p: float | NoneType = None, top_k: int | NoneType = None, random_seed: int | NoneType = None, repetition_penalty: float | NoneType = None, min_new_tokens: int | NoneType = None, max_new_tokens: int | NoneType = None, stop_sequences: List[str] | NoneType = None, time_limit: int | NoneType = None, truncate_input_tokens: int | NoneType = None, prompt_variables: Dict[str, Any] | NoneType = None, return_options: Dict[str, bool] | NoneType = None, credentials: Dict[Literal['url', 'apikey', 'project_id'], str] | NoneType = None, model_name: str | NoneType = None, deployment_id: str | NoneType = None, label: str = 'wml', parameters: WMLInferenceEngineParams | NoneType = None, concurrency_limit: int = 10, _client: Any = None)[source]

Bases: InferenceEngine, WMLInferenceEngineParamsMixin, PackageRequirementsMixin, LogProbInferenceEngine, OptionSelectingByLogProbsInferenceEngine

Runs inference using ibm-watsonx-ai.

credentials

By default, the credentials are assembled by the class instance, which reads the environment variables “WML_URL”, “WML_PROJECT_ID”, and “WML_APIKEY”. Alternatively, a dictionary with the keys “url”, “apikey”, and “project_id” can be provided directly.

Type:

Dict[str, str], optional

model_name

ID of a model to be used for inference. Mutually exclusive with ‘deployment_id’.

Type:

str, optional

deployment_id

Deployment ID of a tuned model to be used for inference. Mutually exclusive with ‘model_name’.

Type:

str, optional

parameters

Instance of WMLInferenceEngineParams which defines inference parameters and their values. Deprecated attribute; pass the respective parameters directly to the WMLInferenceEngine class instead.

Type:

WMLInferenceEngineParams, optional

concurrency_limit

Number of requests sent in parallel; the maximum is 10.

Type:

int

Examples

    from .api import load_dataset

    wml_credentials = {
        "url": "some_url", "project_id": "some_id", "apikey": "some_key"
    }
    model_name = "google/flan-t5-xxl"
    wml_inference = WMLInferenceEngine(
        credentials=wml_credentials,
        model_name=model_name,
        data_classification_policy=["public"],
        top_p=0.5,
        random_seed=123,
    )

    dataset = load_dataset(
        dataset_query="card=cards.argument_topic,template_card_index=0,loader_limit=5"
    )
    results = wml_inference.infer(dataset["test"])

data_classification_policy: List[str] = ['public', 'proprietary']
class unitxt.inference.WMLInferenceEngineParamsMixin(data_classification_policy: List[str] = None, decoding_method: Literal['greedy', 'sample'] | NoneType = None, length_penalty: Dict[str, float | int] | NoneType = None, temperature: float | NoneType = None, top_p: float | NoneType = None, top_k: int | NoneType = None, random_seed: int | NoneType = None, repetition_penalty: float | NoneType = None, min_new_tokens: int | NoneType = None, max_new_tokens: int | NoneType = None, stop_sequences: List[str] | NoneType = None, time_limit: int | NoneType = None, truncate_input_tokens: int | NoneType = None, prompt_variables: Dict[str, Any] | NoneType = None, return_options: Dict[str, bool] | NoneType = None)[source]

Bases: Artifact

unitxt.inference.get_images_without_text(instance)[source]
unitxt.inference.get_model_and_label_id(model_name, label)[source]
unitxt.inference.get_text_without_images(instance, image_token='<image>')[source]