unitxt.inference module¶
- class unitxt.inference.CrossProviderInferenceEngine(data_classification_policy: List[str] = None, model: str = __required__, frequency_penalty: float | NoneType = None, presence_penalty: float | NoneType = None, max_tokens: int | NoneType = None, seed: int | NoneType = None, stop: str | NoneType | List[str] = None, temperature: float | NoneType = None, top_p: float | NoneType = None, top_logprobs: int | NoneType = None, logit_bias: Dict[str, int] | NoneType = None, logprobs: bool | NoneType = None, n: int | NoneType = None, parallel_tool_calls: bool | NoneType = None, service_tier: Literal['auto', 'default'] | NoneType = None, provider: Literal['watsonx', 'together-ai', 'open-ai', 'aws', 'ollama', 'bam', 'watsonx-sdk'] | NoneType = None, provider_model_map: Dict[Literal['watsonx', 'together-ai', 'open-ai', 'aws', 'ollama', 'bam', 'watsonx-sdk'], Dict[str, str]] = {'watsonx': {'llama-3-8b-instruct': 'watsonx/meta-llama/llama-3-8b-instruct', 'llama-3-70b-instruct': 'watsonx/meta-llama/llama-3-70b-instruct', 'granite-3-8b-instruct': 'watsonx/ibm/granite-3-8b-instruct', 'flan-t5-xxl': 'watsonx/google/flan-t5-xxl', 'llama-3-2-1b-instruct': 'watsonx/meta-llama/llama-3-2-1b-instruct'}, 'watsonx-sdk': {'llama-3-8b-instruct': 'meta-llama/llama-3-8b-instruct', 'llama-3-70b-instruct': 'meta-llama/llama-3-70b-instruct', 'granite-3-8b-instruct': 'ibm/granite-3-8b-instruct'}, 'together-ai': {'llama-3-8b-instruct': 'together_ai/togethercomputer/llama-3-8b-instruct', 'llama-3-70b-instruct': 'together_ai/togethercomputer/llama-3-70b-instruct', 'llama-3-2-1b-instruct': 'together_ai/togethercomputer/llama-3-2-1b-instruct'}, 'aws': {'llama-3-8b-instruct': 'bedrock/meta.llama3-8b-instruct-v1:0', 'llama-3-70b-instruct': 'bedrock/meta.llama3-70b-instruct-v1:0'}, 'ollama': {'llama-3-8b-instruct': 'llama3:8b', 'llama-3-70b-instruct': 'llama3:70b'}, 'bam': {'granite-3-8b-instruct': 'ibm/granite-8b-instruct-preview-4k', 'llama-3-8b-instruct': 'meta-llama/llama-3-8b-instruct', 'llama-3-2-1b-instruct': 'meta-llama/llama-3-2-1b-instruct', 'flan-t5-xxl': 'google/flan-t5-xxl'}})[source]¶
Bases: InferenceEngine, StandardAPIParamsMixin
Inference engine capable of dynamically switching between multiple provider APIs.
This class extends InferenceEngine and StandardAPIParamsMixin to enable seamless integration with various API providers. The supported APIs are listed in _supported_apis, allowing users to interact with models from different sources through a single interface. The provider_model_map dictionary maps each provider to its specific model identifiers, enabling automatic configuration based on user requests. A usage sketch follows the attribute list below.
- provider¶
Optional; specifies the API currently in use. Must be one of the literals in _supported_apis.
- Type:
Literal[‘watsonx’, ‘together-ai’, ‘open-ai’, ‘aws’, ‘ollama’, ‘bam’, ‘watsonx-sdk’] | None
- provider_model_map¶
Dictionary mapping each supported provider to a dictionary of generic model names and their provider-specific identifiers. This mapping allows consistent access to models across different API backends.
- Type:
Dict[Literal[‘watsonx’, ‘together-ai’, ‘open-ai’, ‘aws’, ‘ollama’, ‘bam’, ‘watsonx-sdk’], Dict[str, str]]
- provider_model_map: Dict[Literal['watsonx', 'together-ai', 'open-ai', 'aws', 'ollama', 'bam', 'watsonx-sdk'], Dict[str, str]] = {'aws': {'llama-3-70b-instruct': 'bedrock/meta.llama3-70b-instruct-v1:0', 'llama-3-8b-instruct': 'bedrock/meta.llama3-8b-instruct-v1:0'}, 'bam': {'flan-t5-xxl': 'google/flan-t5-xxl', 'granite-3-8b-instruct': 'ibm/granite-8b-instruct-preview-4k', 'llama-3-2-1b-instruct': 'meta-llama/llama-3-2-1b-instruct', 'llama-3-8b-instruct': 'meta-llama/llama-3-8b-instruct'}, 'ollama': {'llama-3-70b-instruct': 'llama3:70b', 'llama-3-8b-instruct': 'llama3:8b'}, 'together-ai': {'llama-3-2-1b-instruct': 'together_ai/togethercomputer/llama-3-2-1b-instruct', 'llama-3-70b-instruct': 'together_ai/togethercomputer/llama-3-70b-instruct', 'llama-3-8b-instruct': 'together_ai/togethercomputer/llama-3-8b-instruct'}, 'watsonx': {'flan-t5-xxl': 'watsonx/google/flan-t5-xxl', 'granite-3-8b-instruct': 'watsonx/ibm/granite-3-8b-instruct', 'llama-3-2-1b-instruct': 'watsonx/meta-llama/llama-3-2-1b-instruct', 'llama-3-70b-instruct': 'watsonx/meta-llama/llama-3-70b-instruct', 'llama-3-8b-instruct': 'watsonx/meta-llama/llama-3-8b-instruct'}, 'watsonx-sdk': {'granite-3-8b-instruct': 'ibm/granite-3-8b-instruct', 'llama-3-70b-instruct': 'meta-llama/llama-3-70b-instruct', 'llama-3-8b-instruct': 'meta-llama/llama-3-8b-instruct'}}¶
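A minimal usage sketch (an assumption for illustration, not part of the reference): it presumes watsonx credentials are available in the environment and that dataset is a unitxt-produced list of instances.
from unitxt.inference import CrossProviderInferenceEngine

# "llama-3-8b-instruct" is a generic alias; with provider="watsonx" it resolves
# to "watsonx/meta-llama/llama-3-8b-instruct" via provider_model_map.
engine = CrossProviderInferenceEngine(
    model="llama-3-8b-instruct",
    provider="watsonx",
    max_tokens=256,
    temperature=0.0,
)

# `dataset` is assumed to be a unitxt dataset split (a list of instances):
# predictions = engine.infer(dataset)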
- class unitxt.inference.GenericInferenceEngine(data_classification_policy: List[str] = None, default: str | NoneType = None)[source]¶
Bases: InferenceEngine, ArtifactFetcherMixin
- class unitxt.inference.HFLlavaInferenceEngine(data_classification_policy: List[str] = None, lazy_load: bool = True, model_name: str = __required__, max_new_tokens: int = __required__)[source]¶
Bases: InferenceEngine, LazyLoadMixin
- class unitxt.inference.HFOptionSelectingInferenceEngine(data_classification_policy: List[str] = None, model_name: str = __required__, batch_size: int = __required__)[source]¶
Bases: InferenceEngine
HuggingFace-based class for inference engines that calculate log probabilities.
This class uses models from the HuggingFace Transformers library to compute log probabilities for text inputs.
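A minimal construction sketch, assuming a small HuggingFace causal LM; the dataset is assumed to come from a unitxt multiple-choice recipe so that each instance carries candidate options.
from unitxt.inference import HFOptionSelectingInferenceEngine

# Any small HuggingFace causal LM works for a smoke test; batch_size is required.
engine = HFOptionSelectingInferenceEngine(model_name="gpt2", batch_size=4)

# `dataset` is assumed to carry candidate options in its task data; infer() is
# expected to return the highest-log-probability option per instance.
# predictions = engine.infer(dataset)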
- class unitxt.inference.HFPipelineBasedInferenceEngine(data_classification_policy: List[str] = None, lazy_load: bool = False, _requirements_list: List[str] | Dict[str, str] = {'transformers': "Install huggingface package using 'pip install --upgrade transformers"}, model_name: str = __required__, max_new_tokens: int = __required__, use_fp16: bool = True, batch_size: int = 1, top_k: int | NoneType = None)[source]¶
Bases: InferenceEngine, PackageRequirementsMixin, LazyLoadMixin
- class unitxt.inference.IbmGenAiInferenceEngine(data_classification_policy: List[str] = ['public', 'proprietary'], _requirements_list: List[str] | Dict[str, str] = {'ibm-generative-ai': "Install ibm-genai package using 'pip install --upgrade ibm-generative-ai"}, beam_width: int | NoneType = None, decoding_method: Literal['greedy', 'sample'] | NoneType = None, include_stop_sequence: bool | NoneType = None, length_penalty: Any = None, max_new_tokens: int | NoneType = None, min_new_tokens: int | NoneType = None, random_seed: int | NoneType = None, repetition_penalty: float | NoneType = None, return_options: Any = None, stop_sequences: List[str] | NoneType = None, temperature: float | NoneType = None, time_limit: int | NoneType = None, top_k: int | NoneType = None, top_p: float | NoneType = None, truncate_input_tokens: int | NoneType = None, typical_p: float | NoneType = None, label: str = 'ibm_genai', model_name: str = __required__, parameters: IbmGenAiInferenceEngineParams | NoneType = None)[source]¶
Bases: InferenceEngine, IbmGenAiInferenceEngineParamsMixin, PackageRequirementsMixin, LogProbInferenceEngine, OptionSelectingByLogProbsInferenceEngine
- data_classification_policy: List[str] = ['public', 'proprietary']¶
- class unitxt.inference.IbmGenAiInferenceEngineParamsMixin(data_classification_policy: List[str] = None, beam_width: int | NoneType = None, decoding_method: Literal['greedy', 'sample'] | NoneType = None, include_stop_sequence: bool | NoneType = None, length_penalty: Any = None, max_new_tokens: int | NoneType = None, min_new_tokens: int | NoneType = None, random_seed: int | NoneType = None, repetition_penalty: float | NoneType = None, return_options: Any = None, stop_sequences: List[str] | NoneType = None, temperature: float | NoneType = None, time_limit: int | NoneType = None, top_k: int | NoneType = None, top_p: float | NoneType = None, truncate_input_tokens: int | NoneType = None, typical_p: float | NoneType = None)[source]¶
Bases: Artifact
- class unitxt.inference.InferenceEngine(data_classification_policy: List[str] = None)[source]¶
Bases: Artifact
Abstract base class for inference.
- class unitxt.inference.LMMSEvalBaseInferenceEngine(data_classification_policy: List[str] = None, lazy_load: bool = False, _requirements_list: List[str] | Dict[str, str] = ['lmms-eval==0.2.4'], model_type: str = __required__, model_args: Dict[str, str] = __required__, batch_size: int = 1)[source]¶
Bases: InferenceEngine, PackageRequirementsMixin, LazyLoadMixin
- class unitxt.inference.LMMSEvalInferenceEngine(data_classification_policy: List[str] = None, lazy_load: bool = False, _requirements_list: List[str] | Dict[str, str] = ['lmms-eval==0.2.4'], model_type: str = __required__, model_args: Dict[str, str] = __required__, batch_size: int = 1, max_new_tokens: int = 32, temperature: float = 0.0, do_sample: bool = False, generate_until: List[str] = ['\n\n'])[source]¶
Bases: LMMSEvalBaseInferenceEngine
- generate_until: List[str] = ['\n\n']¶
- class unitxt.inference.LMMSEvalLoglikelihoodInferenceEngine(data_classification_policy: List[str] = None, lazy_load: bool = False, _requirements_list: List[str] | Dict[str, str] = ['lmms-eval==0.2.4'], model_type: str = __required__, model_args: Dict[str, str] = __required__, batch_size: int = 1, request_type: Literal['loglikelihood'] = 'loglikelihood')[source]¶
Bases: LMMSEvalBaseInferenceEngine
- class unitxt.inference.LazyLoadMixin(data_classification_policy: List[str] = None, lazy_load: bool = False)[source]¶
Bases: Artifact
- class unitxt.inference.LiteLLMInferenceEngine(data_classification_policy: List[str] = None, _requirements_list: list = ['litellm', 'tenacity', 'tqdm', 'diskcache'], model: str = __required__, frequency_penalty: float | NoneType = None, presence_penalty: float | NoneType = None, max_tokens: int | NoneType = None, seed: int | NoneType = None, stop: str | NoneType | List[str] = None, temperature: float | NoneType = None, top_p: float | NoneType = None, top_logprobs: int | NoneType = None, logit_bias: Dict[str, int] | NoneType = None, logprobs: bool | NoneType = None, n: int | NoneType = None, parallel_tool_calls: bool | NoneType = None, service_tier: Literal['auto', 'default'] | NoneType = None, max_requests_per_second: float = 6, max_retries: int = 5)[source]¶
Bases: InferenceEngine, StandardAPIParamsMixin, PackageRequirementsMixin
- class unitxt.inference.LogProbInferenceEngine(data_classification_policy: List[str] = None)[source]¶
Bases: ABC, Artifact
Abstract base class for inference with log probs.
- class unitxt.inference.MockInferenceEngine(data_classification_policy: List[str] = None, model_name: str = __required__, default_inference_value: str = '[[10]]')[source]¶
Bases: InferenceEngine
- class unitxt.inference.MockModeMixin(data_classification_policy: List[str] = None, mock_mode: bool = False)[source]¶
Bases: Artifact
- class unitxt.inference.OllamaInferenceEngine(data_classification_policy: List[str] = ['public', 'proprietary'], _requirements_list: List[str] | Dict[str, str] = {'ollama': "Install ollama package using 'pip install --upgrade ollama"}, model: str = __required__, frequency_penalty: float | NoneType = None, presence_penalty: float | NoneType = None, max_tokens: int | NoneType = None, seed: int | NoneType = None, stop: str | NoneType | List[str] = None, temperature: float | NoneType = None, top_p: float | NoneType = None, top_logprobs: int | NoneType = None, logit_bias: Dict[str, int] | NoneType = None, logprobs: bool | NoneType = None, n: int | NoneType = None, parallel_tool_calls: bool | NoneType = None, service_tier: Literal['auto', 'default'] | NoneType = None, label: str = 'ollama')[source]¶
Bases: InferenceEngine, StandardAPIParamsMixin, PackageRequirementsMixin
- data_classification_policy: List[str] = ['public', 'proprietary']¶
- class unitxt.inference.OpenAiInferenceEngine(data_classification_policy: List[str] = ['public'], _requirements_list: List[str] | Dict[str, str] = {'openai': "Install openai package using 'pip install --upgrade openai"}, frequency_penalty: float | NoneType = None, presence_penalty: float | NoneType = None, max_tokens: int | NoneType = None, seed: int | NoneType = None, stop: str | NoneType | List[str] = None, temperature: float | NoneType = None, top_p: float | NoneType = None, top_logprobs: int | NoneType = 20, logit_bias: Dict[str, int] | NoneType = None, logprobs: bool | NoneType = True, n: int | NoneType = None, parallel_tool_calls: bool | NoneType = None, service_tier: Literal['auto', 'default'] | NoneType = None, label: str = 'openai', model_name: str = __required__, parameters: OpenAiInferenceEngineParams | NoneType = None)[source]¶
Bases: InferenceEngine, LogProbInferenceEngine, OpenAiInferenceEngineParamsMixin, PackageRequirementsMixin
- data_classification_policy: List[str] = ['public']¶
- class unitxt.inference.OpenAiInferenceEngineParamsMixin(data_classification_policy: List[str] = None, frequency_penalty: float | NoneType = None, presence_penalty: float | NoneType = None, max_tokens: int | NoneType = None, seed: int | NoneType = None, stop: str | NoneType | List[str] = None, temperature: float | NoneType = None, top_p: float | NoneType = None, top_logprobs: int | NoneType = 20, logit_bias: Dict[str, int] | NoneType = None, logprobs: bool | NoneType = True, n: int | NoneType = None, parallel_tool_calls: bool | NoneType = None, service_tier: Literal['auto', 'default'] | NoneType = None)[source]¶
Bases: Artifact
- class unitxt.inference.OptionSelectingByLogProbsInferenceEngine[source]¶
Bases: object
Inference engine used to select an option from a list of options, based on the log probabilities of the options conditioned on a prompt.
Inference engines that inherit from this class must implement get_token_count and get_options_log_probs. A toy implementation sketch follows the method descriptions below.
- abstract get_options_log_probs(dataset)[source]¶
Get the token log probabilities of the options stored under the task_data.options key of each dict in the dataset.
Adds to each instance an "options_log_prob" field: a dict mapping each option string to a list of {"text": str, "logprob": float} entries.
- Parameters:
dataset (List[Dict[str, Any]]) – A list of dictionaries, each representing a data instance.
- Returns:
The dataset, with the "options_log_prob" field added to each instance.
- Return type:
List[Dict[str, Any]]
- abstract get_token_count(dataset)[source]¶
Get the token count of the "source" key of each dict in the dataset, and add a "token_count" field to each instance.
- Parameters:
dataset (List[Dict[str, Any]]) – A list of dictionaries, each representing a data instance.
- Returns:
The token count of the texts
- Return type:
List[int]
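The contract above can be illustrated with a toy subclass. This is only a sketch using the documented field names ("token_count", "options_log_prob", task_data.options); the whitespace tokenization and constant log probabilities are placeholders, not how real engines compute them.
from typing import Any, Dict, List

from unitxt.inference import OptionSelectingByLogProbsInferenceEngine


class ToyLogProbEngine(OptionSelectingByLogProbsInferenceEngine):
    """Toy subclass showing the expected shapes, not a real inference engine."""

    def get_token_count(self, dataset: List[Dict[str, Any]]) -> List[int]:
        # Fake token counting via whitespace splitting; adds "token_count" in place.
        for instance in dataset:
            instance["token_count"] = len(str(instance["source"]).split())
        return [instance["token_count"] for instance in dataset]

    def get_options_log_probs(self, dataset: List[Dict[str, Any]]):
        # Adds "options_log_prob": one {"text": ..., "logprob": ...} entry per
        # option token, with a constant placeholder score.
        for instance in dataset:
            options = instance["task_data"]["options"]
            instance["options_log_prob"] = {
                option: [{"text": token, "logprob": 0.0} for token in option.split()]
                for option in options
            }
        return dataset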
- class unitxt.inference.StandardAPIParamsMixin(data_classification_policy: List[str] = None, model: str = __required__, frequency_penalty: float | NoneType = None, presence_penalty: float | NoneType = None, max_tokens: int | NoneType = None, seed: int | NoneType = None, stop: str | NoneType | List[str] = None, temperature: float | NoneType = None, top_p: float | NoneType = None, top_logprobs: int | NoneType = None, logit_bias: Dict[str, int] | NoneType = None, logprobs: bool | NoneType = None, n: int | NoneType = None, parallel_tool_calls: bool | NoneType = None, service_tier: Literal['auto', 'default'] | NoneType = None)[source]¶
Bases: Artifact
- class unitxt.inference.TextGenerationInferenceOutput(prediction: str | List[Dict[str, Any]], input_tokens: int | None = None, output_tokens: int | None = None, model_name: str | None = None, inference_type: str | None = None)[source]¶
Bases: object
Contains the prediction results and metadata for the inference.
Args:
prediction (Union[str, List[Dict[str, Any]]]): If this is the result of an _infer call, the string predicted by the model. If this is the result of an _infer_log_probs call, a list of dictionaries in which the i'th dictionary represents the i'th token of the response. The entry "top_tokens" in each dictionary holds a sorted list of the top tokens for that position and their log probabilities, for example: [{.. "top_tokens": [{"text": "a", "logprob": ...}, {"text": "b", "logprob": ...}, ...]}, {.. "top_tokens": [{"text": "c", "logprob": ...}, {"text": "d", "logprob": ...}, ...]}]
input_tokens (int): number of input tokens to the model.
output_tokens (int): number of output tokens returned by the model.
model_name (str): the model_name as kept in the InferenceEngine.
inference_type (str): the label stating the type of the InferenceEngine.
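For concreteness, the two prediction shapes described above can be constructed directly; the field values below are illustrative, and only the "top_tokens" structure is taken from the description.
from unitxt.inference import TextGenerationInferenceOutput

# Result of an _infer call: `prediction` is the generated string.
text_result = TextGenerationInferenceOutput(
    prediction="The answer is 42.",
    input_tokens=12,
    output_tokens=6,
    model_name="google/flan-t5-xxl",
    inference_type="wml",
)

# Result of an _infer_log_probs call: `prediction` is a list of per-token
# dictionaries, each holding its sorted "top_tokens" alternatives.
logprob_result = TextGenerationInferenceOutput(
    prediction=[
        {"top_tokens": [{"text": "The", "logprob": -0.1}, {"text": "A", "logprob": -2.3}]},
        {"top_tokens": [{"text": "answer", "logprob": -0.4}, {"text": "result", "logprob": -1.9}]},
    ],
    model_name="google/flan-t5-xxl",
    inference_type="wml",
)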
- class unitxt.inference.TogetherAiInferenceEngine(data_classification_policy: List[str] = ['public'], _requirements_list: List[str] | Dict[str, str] = {'together': "Install together package using 'pip install --upgrade together"}, max_tokens: int | NoneType = None, stop: List[str] | NoneType = None, temperature: float | NoneType = None, top_p: float | NoneType = None, top_k: int | NoneType = None, repetition_penalty: float | NoneType = None, logprobs: int | NoneType = None, echo: bool | NoneType = None, n: int | NoneType = None, min_p: float | NoneType = None, presence_penalty: float | NoneType = None, frequency_penalty: float | NoneType = None, label: str = 'together', model_name: str = __required__, parameters: unitxt.inference.TogetherAiInferenceEngineParamsMixin | NoneType = None)[source]¶
Bases: InferenceEngine, TogetherAiInferenceEngineParamsMixin, PackageRequirementsMixin
- data_classification_policy: List[str] = ['public']¶
- class unitxt.inference.TogetherAiInferenceEngineParamsMixin(data_classification_policy: List[str] = None, max_tokens: int | NoneType = None, stop: List[str] | NoneType = None, temperature: float | NoneType = None, top_p: float | NoneType = None, top_k: int | NoneType = None, repetition_penalty: float | NoneType = None, logprobs: int | NoneType = None, echo: bool | NoneType = None, n: int | NoneType = None, min_p: float | NoneType = None, presence_penalty: float | NoneType = None, frequency_penalty: float | NoneType = None)[source]¶
Bases: Artifact
- class unitxt.inference.VLLMInferenceEngine(data_classification_policy: List[str] = None, model: str = __required__, frequency_penalty: float | NoneType = None, presence_penalty: float | NoneType = None, max_tokens: int | NoneType = None, seed: int | NoneType = None, stop: str | NoneType | List[str] = None, temperature: float | NoneType = None, top_p: float | NoneType = None, top_logprobs: int | NoneType = None, logit_bias: Dict[str, int] | NoneType = None, logprobs: bool | NoneType = None, n: int | NoneType = None, parallel_tool_calls: bool | NoneType = None, service_tier: Literal['auto', 'default'] | NoneType = None, _requirements_list: List[str] | Dict[str, str] = [])[source]¶
Bases: InferenceEngine, PackageRequirementsMixin, StandardAPIParamsMixin
- class unitxt.inference.VLLMRemoteInferenceEngine(data_classification_policy: List[str] = ['public'], _requirements_list: List[str] | Dict[str, str] = {'openai': "Install openai package using 'pip install --upgrade openai"}, frequency_penalty: float | NoneType = None, presence_penalty: float | NoneType = None, max_tokens: int | NoneType = None, seed: int | NoneType = None, stop: str | NoneType | List[str] = None, temperature: float | NoneType = None, top_p: float | NoneType = None, top_logprobs: int | NoneType = 20, logit_bias: Dict[str, int] | NoneType = None, logprobs: bool | NoneType = True, n: int | NoneType = None, parallel_tool_calls: bool | NoneType = None, service_tier: Literal['auto', 'default'] | NoneType = None, label: str = 'vllm', model_name: str = __required__, parameters: OpenAiInferenceEngineParams | NoneType = None)[source]¶
Bases: OpenAiInferenceEngine
- class unitxt.inference.WMLInferenceEngine(data_classification_policy: List[str] = ['public', 'proprietary'], _requirements_list: List[str] | Dict[str, str] = {'ibm-watsonx-ai==1.1.14': "Install ibm-watsonx-ai package using 'pip install --upgrade ibm-watsonx-ai'. It is advised to have Python version >=3.10 installed, as at lower version this package may cause conflicts with other installed packages."}, decoding_method: Literal['greedy', 'sample'] | NoneType = None, length_penalty: Dict[str, float | int] | NoneType = None, temperature: float | NoneType = None, top_p: float | NoneType = None, top_k: int | NoneType = None, random_seed: int | NoneType = None, repetition_penalty: float | NoneType = None, min_new_tokens: int | NoneType = None, max_new_tokens: int | NoneType = None, stop_sequences: List[str] | NoneType = None, time_limit: int | NoneType = None, truncate_input_tokens: int | NoneType = None, prompt_variables: Dict[str, Any] | NoneType = None, return_options: Dict[str, bool] | NoneType = None, credentials: Dict[Literal['url', 'apikey', 'project_id'], str] | NoneType = None, model_name: str | NoneType = None, deployment_id: str | NoneType = None, label: str = 'wml', parameters: WMLInferenceEngineParams | NoneType = None, concurrency_limit: int = 10, _client: Any = None)[source]¶
Bases: InferenceEngine, WMLInferenceEngineParamsMixin, PackageRequirementsMixin, LogProbInferenceEngine, OptionSelectingByLogProbsInferenceEngine
Runs inference using ibm-watsonx-ai.
- credentials¶
By default, credentials are created by the class instance, which tries to retrieve them from the environment variables "WML_URL", "WML_PROJECT_ID", and "WML_APIKEY". Alternatively, a dictionary with the keys "url", "apikey", and "project_id" can be provided directly.
- Type:
Dict[str, str], optional
- model_name¶
ID of a model to be used for inference. Mutually exclusive with ‘deployment_id’.
- Type:
str, optional
- deployment_id¶
Deployment ID of a tuned model to be used for inference. Mutually exclusive with ‘model_name’.
- Type:
str, optional
- parameters¶
Instance of WMLInferenceEngineParams that defines inference parameters and their values. Deprecated attribute; pass the respective parameters directly to the WMLInferenceEngine class instead.
- Type:
WMLInferenceEngineParams, optional
- concurrency_limit¶
Number of requests that will be sent in parallel; the maximum is 10.
- Type:
int
Examples
from .api import load_dataset

wml_credentials = {
    "url": "some_url", "project_id": "some_id", "api_key": "some_key"
}
model_name = "google/flan-t5-xxl"
wml_inference = WMLInferenceEngine(
    credentials=wml_credentials,
    model_name=model_name,
    data_classification_policy=["public"],
    top_p=0.5,
    random_seed=123,
)

dataset = load_dataset(
    dataset_query="card=cards.argument_topic,template_card_index=0,loader_limit=5"
)
results = wml_inference.infer(dataset["test"])
- data_classification_policy: List[str] = ['public', 'proprietary']¶
- class unitxt.inference.WMLInferenceEngineParamsMixin(data_classification_policy: List[str] = None, decoding_method: Literal['greedy', 'sample'] | NoneType = None, length_penalty: Dict[str, float | int] | NoneType = None, temperature: float | NoneType = None, top_p: float | NoneType = None, top_k: int | NoneType = None, random_seed: int | NoneType = None, repetition_penalty: float | NoneType = None, min_new_tokens: int | NoneType = None, max_new_tokens: int | NoneType = None, stop_sequences: List[str] | NoneType = None, time_limit: int | NoneType = None, truncate_input_tokens: int | NoneType = None, prompt_variables: Dict[str, Any] | NoneType = None, return_options: Dict[str, bool] | NoneType = None)[source]¶
Bases: Artifact