unitxt.inference module¶
- class unitxt.inference.AzureOpenAIInferenceEngine(data_classification_policy: List[str] = ['public'], _requirements_list: List[str] | Dict[str, str] = {'openai': "Install openai package using 'pip install --upgrade openai"}, requirements: List[str] | Dict[str, str] = [], frequency_penalty: float | NoneType = None, presence_penalty: float | NoneType = None, max_tokens: int | NoneType = None, seed: int | NoneType = None, stop: str | NoneType | List[str] = None, temperature: float | NoneType = None, top_p: float | NoneType = None, top_logprobs: int | NoneType = 20, logit_bias: Dict[str, int] | NoneType = None, logprobs: bool | NoneType = True, n: int | NoneType = None, parallel_tool_calls: bool | NoneType = None, service_tier: Literal['auto', 'default'] | NoneType = None, cache_batch_size: int = 100, use_cache: bool = True, label: str = 'azure_openai', model_name: str = __required__, parameters: OpenAiInferenceEngineParams | NoneType = None, base_url: str | NoneType = None, default_headers: Dict[str, str] = {}, credentials: unitxt.inference.CredentialsOpenAi = {}, num_parallel_requests: int = 20)[source]¶
Bases:
OpenAiInferenceEngine
- class unitxt.inference.CrossProviderInferenceEngine(data_classification_policy: List[str] = None, model: str = __required__, frequency_penalty: float | NoneType = None, presence_penalty: float | NoneType = None, max_tokens: int | NoneType = None, seed: int | NoneType = None, stop: str | NoneType | List[str] = None, temperature: float | NoneType = None, top_p: float | NoneType = None, top_logprobs: int | NoneType = None, logit_bias: Dict[str, int] | NoneType = None, logprobs: bool | NoneType = None, n: int | NoneType = None, parallel_tool_calls: bool | NoneType = None, service_tier: Literal['auto', 'default'] | NoneType = None, credentials: Dict[str, str] | NoneType = {}, extra_headers: Dict[str, str] | NoneType = None, cache_batch_size: int = 100, use_cache: bool = True, label: str = 'cross_provider', provider: Literal['watsonx', 'together-ai', 'open-ai', 'aws', 'ollama', 'bam', 'watsonx-sdk', 'rits', 'azure', 'vertex-ai', 'replicate'] | NoneType = None, provider_specific_args: Dict[str, Dict[str, str]] | NoneType = None, provider_model_map: Dict[Literal['watsonx', 'together-ai', 'open-ai', 'aws', 'ollama', 'bam', 'watsonx-sdk', 'rits', 'azure', 'vertex-ai', 'replicate'], Dict[str, str]] = {'watsonx-sdk': {'granite-20b-code-instruct': 'ibm/granite-20b-code-instruct', 'granite-3-2-8b-instruct': 'ibm/granite-3-2-8b-instruct', 'granite-3-2b-instruct': 'ibm/granite-3-2b-instruct', 'granite-3-8b-instruct': 'ibm/granite-3-8b-instruct', 'granite-34b-code-instruct': 'ibm/granite-34b-code-instruct', 'granite-guardian-3-8b': 'ibm/granite-guardian-3-8b', 'granite-vision-3-2-2b': 'ibm/granite-vision-3-2-2b', 'llama-3-1-8b-instruct': 'meta-llama/llama-3-1-8b-instruct', 'llama-3-1-70b-instruct': 'meta-llama/llama-3-1-70b-instruct', 'llama-3-1-405b-instruct': 'meta-llama/llama-3-405b-instruct', 'llama-3-2-11b-vision-instruct': 'meta-llama/llama-3-2-11b-vision-instruct', 'llama-3-2-1b-instruct': 'meta-llama/llama-3-2-1b-instruct', 'llama-3-2-3b-instruct': 'meta-llama/llama-3-2-3b-instruct', 'llama-3-2-90b-vision-instruct': 'meta-llama/llama-3-2-90b-vision-instruct', 'llama-3-3-70b-instruct': 'meta-llama/llama-3-3-70b-instruct', 'llama-guard-3-11b-vision': 'meta-llama/llama-guard-3-11b-vision', 'mistral-large-instruct': 'mistralai/mistral-large', 'mixtral-8x7b-instruct-v01': 'mistralai/mixtral-8x7b-instruct-v01'}, 'together-ai': {'llama-3-8b-instruct': 'together_ai/meta-llama/Llama-3-8b-chat-hf', 'llama-3-70b-instruct': 'together_ai/meta-llama/Llama-3-70b-chat-hf', 'llama-3-1-8b-instruct': 'together_ai/meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo', 'llama-3-1-70b-instruct': 'together_ai/meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo', 'llama-3-1-405b-instruct': 'together_ai/meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo', 'llama-3-2-1b-instruct': 'together_ai/togethercomputer/llama-3-2-1b-instruct', 'llama-3-3-70b-instruct': 'together_ai/meta-llama/Llama-3.3-70B-Instruct-Turbo'}, 'aws': {'llama-3-8b-instruct': 'bedrock/meta.llama3-8b-instruct-v1:0', 'llama-3-70b-instruct': 'bedrock/meta.llama3-70b-instruct-v1:0'}, 'ollama': {'llama-3-8b-instruct': 'llama3:8b', 'llama-3-70b-instruct': 'llama3:70b', 'llama-3-1-8b-instruct': 'llama3.1:8b', 'llama-3-1-70b-instruct': 'llama3.1:70b', 'llama-3-1-405b-instruct': 'llama3.1:405b', 'llama-3-2-1b-instruct': 'llama3.2:1b', 'llama-3-2-3b-instruct': 'llama3.2:3b', 'llama-3-3-70b-instruct': 'llama3.3'}, 'bam': {'granite-3-8b-instruct': 'ibm/granite-8b-instruct-preview-4k', 'llama-3-8b-instruct': 'meta-llama/llama-3-8b-instruct', 'llama-3-2-1b-instruct': 
'meta-llama/llama-3-2-1b-instruct', 'flan-t5-xxl': 'google/flan-t5-xxl'}, 'rits': {'granite-3-8b-instruct': 'ibm-granite/granite-3.0-8b-instruct', 'granite-3-2-8b-instruct': 'ibm-granite/granite-3.2-8b-instruct', 'granite-3-3-8b-instruct': 'ibm-granite/granite-3.3-8b-instruct', 'llama-3-1-8b-instruct': 'meta-llama/llama-3-1-8b-instruct', 'llama-3-1-70b-instruct': 'meta-llama/llama-3-1-70b-instruct', 'llama-3-1-405b-instruct': 'meta-llama/llama-3-1-405b-instruct-fp8', 'llama-3-1-405b-instruct-fp8': 'meta-llama/llama-3-1-405b-instruct-fp8', 'llama-3-2-11b-vision-instruct': 'meta-llama/Llama-3.2-11B-Vision-Instruct', 'llama-3-2-90b-vision-instruct': 'meta-llama/Llama-3.2-90B-Vision-Instruct', 'llama-3-3-70b-instruct': 'meta-llama/llama-3-3-70b-instruct', 'mistral-large-instruct': 'mistralai/mistral-large-instruct-2407', 'mixtral-8x7b-instruct': 'mistralai/mixtral-8x7B-instruct-v0.1', 'deepseek-v3': 'deepseek-ai/DeepSeek-V3', 'granite-guardian-3-2-3b-a800m': 'ibm-granite/granite-guardian-3.2-3b-a800m', 'granite-guardian-3-2-5b': 'ibm-granite/granite-guardian-3.2-5b'}, 'open-ai': {'o1-mini': 'o1-mini', 'o1-preview': 'o1-preview', 'gpt-4o-mini': 'gpt-4o-mini', 'gpt-4o-mini-2024-07-18': 'gpt-4o-mini-2024-07-18', 'gpt-4o': 'gpt-4o', 'gpt-4o-2024-08-06': 'gpt-4o-2024-08-06', 'gpt-4o-2024-05-13': 'gpt-4o-2024-05-13', 'gpt-4-turbo-preview': 'gpt-4-0125-preview', 'gpt-4-turbo': 'gpt-4-turbo', 'gpt-4-0125-preview': 'gpt-4-0125-preview', 'gpt-4-1106-preview': 'gpt-4-1106-preview', 'gpt-3.5-turbo-1106': 'gpt-3.5-turbo-1106', 'gpt-3.5-turbo': 'gpt-3.5-turbo', 'gpt-3.5-turbo-0301': 'gpt-3.5-turbo-0301', 'gpt-3.5-turbo-0613': 'gpt-3.5-turbo-0613', 'gpt-3.5-turbo-16k': 'gpt-3.5-turbo-16k', 'gpt-3.5-turbo-16k-0613': 'gpt-3.5-turbo-16k-0613', 'gpt-4': 'gpt-4', 'gpt-4-0314': 'gpt-4-0314', 'gpt-4-0613': 'gpt-4-0613', 'gpt-4-32k': 'gpt-4-32k', 'gpt-4-32k-0314': 'gpt-4-32k-0314', 'gpt-4-32k-0613': 'gpt-4-32k-0613', 'gpt-4-vision-preview': 'gpt-4-vision-preview'}, 'azure': {'o1-mini': 'azure/o1-mini', 'o1-preview': 'azure/o1-preview', 'gpt-4o-mini': 'azure/gpt-4o-mini', 'gpt-4o': 'azure/gpt-4o', 'gpt-4o-2024-08-06': 'azure/gpt-4o-2024-08-06', 'gpt-4': 'azure/gpt-4', 'gpt-4-0314': 'azure/gpt-4-0314', 'gpt-4-0613': 'azure/gpt-4-0613', 'gpt-4-32k': 'azure/gpt-4-32k', 'gpt-4-32k-0314': 'azure/gpt-4-32k-0314', 'gpt-4-32k-0613': 'azure/gpt-4-32k-0613', 'gpt-4-1106-preview': 'azure/gpt-4-1106-preview', 'gpt-4-0125-preview': 'azure/gpt-4-0125-preview', 'gpt-4-turbo': 'azure/gpt-4-turbo-2024-04-09', 'gpt-3.5-turbo': 'azure/gpt-3.5-turbo', 'gpt-3.5-turbo-0301': 'azure/gpt-3.5-turbo-0301', 'gpt-3.5-turbo-0613': 'azure/gpt-3.5-turbo-0613', 'gpt-3.5-turbo-16k': 'azure/gpt-3.5-turbo-16k', 'gpt-3.5-turbo-16k-0613': 'azure/gpt-3.5-turbo-16k-0613', 'gpt-4-vision': 'azure/gpt-4-vision'}, 'vertex-ai': {'llama-3-1-8b-instruct': 'vertex_ai/meta/llama-3.1-8b-instruct-maas', 'llama-3-1-70b-instruct': 'vertex_ai/meta/llama-3.1-70b-instruct-maas', 'llama-3-1-405b-instruct': 'vertex_ai/meta/llama-3.1-405b-instruct-maas'}, 'replicate': {'granite-3-2-8b-instruct': 'replicate/ibm-granite/granite-3.2-8b-instruct', 'granite-vision-3-2-2b': 'replicate/ibm-granite/granite-vision-3.2-2b', 'granite-3-1-8b-instruct': 'replicate/ibm-granite/granite-3.1-8b-instruct', 'granite-3-1-2b-instruct': 'replicate/ibm-granite/granite-3.1-2b-instruct', 'granite-3-8b-instruct': 'replicate/ibm-granite/granite-3.0-8b-instruct', 'granite-3-2b-instruct': 'replicate/ibm-granite/granite-3.0-2b-instruct', 'granite-8b-code-instruct-128k': 
'replicate/ibm-granite/granite-8b-code-instruct-128k', 'granite-20b-code-instruct-8k': 'replicate/ibm-granite/granite-20b-code-instruct-8k', 'llama-2-13b': 'replicate/meta/llama-2-13b', 'llama-2-13b-chat': 'replicate/meta/llama-2-13b-chat', 'llama-2-70b': 'replicate/meta/llama-2-70b', 'llama-2-70b-chat': 'replicate/meta/llama-2-70b-chat', 'llama-2-7b': 'replicate/meta/llama-2-7b', 'llama-2-7b-chat': 'replicate/meta/llama-2-7b-chat', 'llama-3-1-405b-instruct': 'replicate/meta/meta-llama-3.1-405b-instruct', 'llama-3-70b': 'replicate/meta/meta-llama-3-70b', 'llama-3-70b-instruct': 'replicate/meta/meta-llama-3-70b-instruct', 'llama-3-8b': 'replicate/meta/meta-llama-3-8b', 'llama-3-8b-instruct': 'replicate/meta/meta-llama-3-8b-instruct', 'mistral-7b-instruct-v0.2': 'replicate/mistralai/mistral-7b-instruct-v0.2', 'mistral-7b-v0.1': 'replicate/mistralai/mistral-7b-v0.1', 'mixtral-8x7b-instruct-v0.1': 'replicate/mistralai/mixtral-8x7b-instruct-v0.1'}, 'watsonx': {'granite-20b-code-instruct': 'watsonx/ibm/granite-20b-code-instruct', 'granite-3-2-8b-instruct': 'watsonx/ibm/granite-3-2-8b-instruct', 'granite-3-2b-instruct': 'watsonx/ibm/granite-3-2b-instruct', 'granite-3-8b-instruct': 'watsonx/ibm/granite-3-8b-instruct', 'granite-34b-code-instruct': 'watsonx/ibm/granite-34b-code-instruct', 'granite-guardian-3-8b': 'watsonx/ibm/granite-guardian-3-8b', 'granite-vision-3-2-2b': 'watsonx/ibm/granite-vision-3-2-2b', 'llama-3-1-8b-instruct': 'watsonx/meta-llama/llama-3-1-8b-instruct', 'llama-3-1-70b-instruct': 'watsonx/meta-llama/llama-3-1-70b-instruct', 'llama-3-1-405b-instruct': 'watsonx/meta-llama/llama-3-405b-instruct', 'llama-3-2-11b-vision-instruct': 'watsonx/meta-llama/llama-3-2-11b-vision-instruct', 'llama-3-2-1b-instruct': 'watsonx/meta-llama/llama-3-2-1b-instruct', 'llama-3-2-3b-instruct': 'watsonx/meta-llama/llama-3-2-3b-instruct', 'llama-3-2-90b-vision-instruct': 'watsonx/meta-llama/llama-3-2-90b-vision-instruct', 'llama-3-3-70b-instruct': 'watsonx/meta-llama/llama-3-3-70b-instruct', 'llama-guard-3-11b-vision': 'watsonx/meta-llama/llama-guard-3-11b-vision', 'mistral-large-instruct': 'watsonx/mistralai/mistral-large', 'mixtral-8x7b-instruct-v01': 'watsonx/mistralai/mixtral-8x7b-instruct-v01'}})[source]¶
Bases:
InferenceEngine, StandardAPIParamsMixin
Inference engine capable of dynamically switching between multiple providers' APIs.
This class extends InferenceEngine and StandardAPIParamsMixin to enable seamless integration with various API providers. The supported APIs are specified in _supported_apis, allowing users to interact with multiple models from different sources. The provider_model_map dictionary maps each API to specific model identifiers, enabling automatic configuration based on user requests. Current _supported_apis = ["watsonx", "together-ai", "open-ai", "aws", "ollama", "bam", "watsonx-sdk", "rits", "vertex-ai"]
- Parameters:
provider (Optional) – Specifies the current API in use. Must be one of the literals in _supported_apis.
provider_model_map (Dict[_supported_apis, Dict[str, str]]) – Maps each supported API to a corresponding model identifier string. This mapping allows consistent access to models across different API backends.
provider_specific_args (Optional[Dict[str, Dict[str, str]]]) – Arguments specific to a provider, for example provider_specific_args={"watsonx": {"max_requests_per_second": 4}}.
- provider_model_map: Dict[Literal['watsonx', 'together-ai', 'open-ai', 'aws', 'ollama', 'bam', 'watsonx-sdk', 'rits', 'azure', 'vertex-ai', 'replicate'], Dict[str, str]] = {'aws': {'llama-3-70b-instruct': 'bedrock/meta.llama3-70b-instruct-v1:0', 'llama-3-8b-instruct': 'bedrock/meta.llama3-8b-instruct-v1:0'}, 'azure': {'gpt-3.5-turbo': 'azure/gpt-3.5-turbo', 'gpt-3.5-turbo-0301': 'azure/gpt-3.5-turbo-0301', 'gpt-3.5-turbo-0613': 'azure/gpt-3.5-turbo-0613', 'gpt-3.5-turbo-16k': 'azure/gpt-3.5-turbo-16k', 'gpt-3.5-turbo-16k-0613': 'azure/gpt-3.5-turbo-16k-0613', 'gpt-4': 'azure/gpt-4', 'gpt-4-0125-preview': 'azure/gpt-4-0125-preview', 'gpt-4-0314': 'azure/gpt-4-0314', 'gpt-4-0613': 'azure/gpt-4-0613', 'gpt-4-1106-preview': 'azure/gpt-4-1106-preview', 'gpt-4-32k': 'azure/gpt-4-32k', 'gpt-4-32k-0314': 'azure/gpt-4-32k-0314', 'gpt-4-32k-0613': 'azure/gpt-4-32k-0613', 'gpt-4-turbo': 'azure/gpt-4-turbo-2024-04-09', 'gpt-4-vision': 'azure/gpt-4-vision', 'gpt-4o': 'azure/gpt-4o', 'gpt-4o-2024-08-06': 'azure/gpt-4o-2024-08-06', 'gpt-4o-mini': 'azure/gpt-4o-mini', 'o1-mini': 'azure/o1-mini', 'o1-preview': 'azure/o1-preview'}, 'bam': {'flan-t5-xxl': 'google/flan-t5-xxl', 'granite-3-8b-instruct': 'ibm/granite-8b-instruct-preview-4k', 'llama-3-2-1b-instruct': 'meta-llama/llama-3-2-1b-instruct', 'llama-3-8b-instruct': 'meta-llama/llama-3-8b-instruct'}, 'ollama': {'llama-3-1-405b-instruct': 'llama3.1:405b', 'llama-3-1-70b-instruct': 'llama3.1:70b', 'llama-3-1-8b-instruct': 'llama3.1:8b', 'llama-3-2-1b-instruct': 'llama3.2:1b', 'llama-3-2-3b-instruct': 'llama3.2:3b', 'llama-3-3-70b-instruct': 'llama3.3', 'llama-3-70b-instruct': 'llama3:70b', 'llama-3-8b-instruct': 'llama3:8b'}, 'open-ai': {'gpt-3.5-turbo': 'gpt-3.5-turbo', 'gpt-3.5-turbo-0301': 'gpt-3.5-turbo-0301', 'gpt-3.5-turbo-0613': 'gpt-3.5-turbo-0613', 'gpt-3.5-turbo-1106': 'gpt-3.5-turbo-1106', 'gpt-3.5-turbo-16k': 'gpt-3.5-turbo-16k', 'gpt-3.5-turbo-16k-0613': 'gpt-3.5-turbo-16k-0613', 'gpt-4': 'gpt-4', 'gpt-4-0125-preview': 'gpt-4-0125-preview', 'gpt-4-0314': 'gpt-4-0314', 'gpt-4-0613': 'gpt-4-0613', 'gpt-4-1106-preview': 'gpt-4-1106-preview', 'gpt-4-32k': 'gpt-4-32k', 'gpt-4-32k-0314': 'gpt-4-32k-0314', 'gpt-4-32k-0613': 'gpt-4-32k-0613', 'gpt-4-turbo': 'gpt-4-turbo', 'gpt-4-turbo-preview': 'gpt-4-0125-preview', 'gpt-4-vision-preview': 'gpt-4-vision-preview', 'gpt-4o': 'gpt-4o', 'gpt-4o-2024-05-13': 'gpt-4o-2024-05-13', 'gpt-4o-2024-08-06': 'gpt-4o-2024-08-06', 'gpt-4o-mini': 'gpt-4o-mini', 'gpt-4o-mini-2024-07-18': 'gpt-4o-mini-2024-07-18', 'o1-mini': 'o1-mini', 'o1-preview': 'o1-preview'}, 'replicate': {'granite-20b-code-instruct-8k': 'replicate/ibm-granite/granite-20b-code-instruct-8k', 'granite-3-1-2b-instruct': 'replicate/ibm-granite/granite-3.1-2b-instruct', 'granite-3-1-8b-instruct': 'replicate/ibm-granite/granite-3.1-8b-instruct', 'granite-3-2-8b-instruct': 'replicate/ibm-granite/granite-3.2-8b-instruct', 'granite-3-2b-instruct': 'replicate/ibm-granite/granite-3.0-2b-instruct', 'granite-3-8b-instruct': 'replicate/ibm-granite/granite-3.0-8b-instruct', 'granite-8b-code-instruct-128k': 'replicate/ibm-granite/granite-8b-code-instruct-128k', 'granite-vision-3-2-2b': 'replicate/ibm-granite/granite-vision-3.2-2b', 'llama-2-13b': 'replicate/meta/llama-2-13b', 'llama-2-13b-chat': 'replicate/meta/llama-2-13b-chat', 'llama-2-70b': 'replicate/meta/llama-2-70b', 'llama-2-70b-chat': 'replicate/meta/llama-2-70b-chat', 'llama-2-7b': 'replicate/meta/llama-2-7b', 'llama-2-7b-chat': 'replicate/meta/llama-2-7b-chat', 'llama-3-1-405b-instruct': 
'replicate/meta/meta-llama-3.1-405b-instruct', 'llama-3-70b': 'replicate/meta/meta-llama-3-70b', 'llama-3-70b-instruct': 'replicate/meta/meta-llama-3-70b-instruct', 'llama-3-8b': 'replicate/meta/meta-llama-3-8b', 'llama-3-8b-instruct': 'replicate/meta/meta-llama-3-8b-instruct', 'mistral-7b-instruct-v0.2': 'replicate/mistralai/mistral-7b-instruct-v0.2', 'mistral-7b-v0.1': 'replicate/mistralai/mistral-7b-v0.1', 'mixtral-8x7b-instruct-v0.1': 'replicate/mistralai/mixtral-8x7b-instruct-v0.1'}, 'rits': {'deepseek-v3': 'deepseek-ai/DeepSeek-V3', 'granite-3-2-8b-instruct': 'ibm-granite/granite-3.2-8b-instruct', 'granite-3-3-8b-instruct': 'ibm-granite/granite-3.3-8b-instruct', 'granite-3-8b-instruct': 'ibm-granite/granite-3.0-8b-instruct', 'granite-guardian-3-2-3b-a800m': 'ibm-granite/granite-guardian-3.2-3b-a800m', 'granite-guardian-3-2-5b': 'ibm-granite/granite-guardian-3.2-5b', 'llama-3-1-405b-instruct': 'meta-llama/llama-3-1-405b-instruct-fp8', 'llama-3-1-405b-instruct-fp8': 'meta-llama/llama-3-1-405b-instruct-fp8', 'llama-3-1-70b-instruct': 'meta-llama/llama-3-1-70b-instruct', 'llama-3-1-8b-instruct': 'meta-llama/llama-3-1-8b-instruct', 'llama-3-2-11b-vision-instruct': 'meta-llama/Llama-3.2-11B-Vision-Instruct', 'llama-3-2-90b-vision-instruct': 'meta-llama/Llama-3.2-90B-Vision-Instruct', 'llama-3-3-70b-instruct': 'meta-llama/llama-3-3-70b-instruct', 'mistral-large-instruct': 'mistralai/mistral-large-instruct-2407', 'mixtral-8x7b-instruct': 'mistralai/mixtral-8x7B-instruct-v0.1'}, 'together-ai': {'llama-3-1-405b-instruct': 'together_ai/meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo', 'llama-3-1-70b-instruct': 'together_ai/meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo', 'llama-3-1-8b-instruct': 'together_ai/meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo', 'llama-3-2-1b-instruct': 'together_ai/togethercomputer/llama-3-2-1b-instruct', 'llama-3-3-70b-instruct': 'together_ai/meta-llama/Llama-3.3-70B-Instruct-Turbo', 'llama-3-70b-instruct': 'together_ai/meta-llama/Llama-3-70b-chat-hf', 'llama-3-8b-instruct': 'together_ai/meta-llama/Llama-3-8b-chat-hf'}, 'vertex-ai': {'llama-3-1-405b-instruct': 'vertex_ai/meta/llama-3.1-405b-instruct-maas', 'llama-3-1-70b-instruct': 'vertex_ai/meta/llama-3.1-70b-instruct-maas', 'llama-3-1-8b-instruct': 'vertex_ai/meta/llama-3.1-8b-instruct-maas'}, 'watsonx': {'granite-20b-code-instruct': 'watsonx/ibm/granite-20b-code-instruct', 'granite-3-2-8b-instruct': 'watsonx/ibm/granite-3-2-8b-instruct', 'granite-3-2b-instruct': 'watsonx/ibm/granite-3-2b-instruct', 'granite-3-8b-instruct': 'watsonx/ibm/granite-3-8b-instruct', 'granite-34b-code-instruct': 'watsonx/ibm/granite-34b-code-instruct', 'granite-guardian-3-8b': 'watsonx/ibm/granite-guardian-3-8b', 'granite-vision-3-2-2b': 'watsonx/ibm/granite-vision-3-2-2b', 'llama-3-1-405b-instruct': 'watsonx/meta-llama/llama-3-405b-instruct', 'llama-3-1-70b-instruct': 'watsonx/meta-llama/llama-3-1-70b-instruct', 'llama-3-1-8b-instruct': 'watsonx/meta-llama/llama-3-1-8b-instruct', 'llama-3-2-11b-vision-instruct': 'watsonx/meta-llama/llama-3-2-11b-vision-instruct', 'llama-3-2-1b-instruct': 'watsonx/meta-llama/llama-3-2-1b-instruct', 'llama-3-2-3b-instruct': 'watsonx/meta-llama/llama-3-2-3b-instruct', 'llama-3-2-90b-vision-instruct': 'watsonx/meta-llama/llama-3-2-90b-vision-instruct', 'llama-3-3-70b-instruct': 'watsonx/meta-llama/llama-3-3-70b-instruct', 'llama-guard-3-11b-vision': 'watsonx/meta-llama/llama-guard-3-11b-vision', 'mistral-large-instruct': 'watsonx/mistralai/mistral-large', 'mixtral-8x7b-instruct-v01': 
'watsonx/mistralai/mixtral-8x7b-instruct-v01'}, 'watsonx-sdk': {'granite-20b-code-instruct': 'ibm/granite-20b-code-instruct', 'granite-3-2-8b-instruct': 'ibm/granite-3-2-8b-instruct', 'granite-3-2b-instruct': 'ibm/granite-3-2b-instruct', 'granite-3-8b-instruct': 'ibm/granite-3-8b-instruct', 'granite-34b-code-instruct': 'ibm/granite-34b-code-instruct', 'granite-guardian-3-8b': 'ibm/granite-guardian-3-8b', 'granite-vision-3-2-2b': 'ibm/granite-vision-3-2-2b', 'llama-3-1-405b-instruct': 'meta-llama/llama-3-405b-instruct', 'llama-3-1-70b-instruct': 'meta-llama/llama-3-1-70b-instruct', 'llama-3-1-8b-instruct': 'meta-llama/llama-3-1-8b-instruct', 'llama-3-2-11b-vision-instruct': 'meta-llama/llama-3-2-11b-vision-instruct', 'llama-3-2-1b-instruct': 'meta-llama/llama-3-2-1b-instruct', 'llama-3-2-3b-instruct': 'meta-llama/llama-3-2-3b-instruct', 'llama-3-2-90b-vision-instruct': 'meta-llama/llama-3-2-90b-vision-instruct', 'llama-3-3-70b-instruct': 'meta-llama/llama-3-3-70b-instruct', 'llama-guard-3-11b-vision': 'meta-llama/llama-guard-3-11b-vision', 'mistral-large-instruct': 'mistralai/mistral-large', 'mixtral-8x7b-instruct-v01': 'mistralai/mixtral-8x7b-instruct-v01'}}¶
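As a quick orientation (not part of the original docstring), here is a minimal usage sketch; the provider and model alias are illustrative and must match an entry in provider_model_map for the chosen provider, and credentials are assumed to be available in the environment.

from unitxt.inference import CrossProviderInferenceEngine

# Illustrative values: "llama-3-3-70b-instruct" is an alias listed under
# provider_model_map["watsonx"]; any other provider/alias pair from the map works too.
engine = CrossProviderInferenceEngine(
    model="llama-3-3-70b-instruct",
    provider="watsonx",
    temperature=0.0,
    max_tokens=256,
)
# predictions = engine.infer(dataset)  # dataset: a list of unitxt instances or a Dataset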
- class unitxt.inference.GenericInferenceEngine(data_classification_policy: List[str] = None, cache_batch_size: int = 100, use_cache: bool = True, default: str | NoneType = None)[source]¶
Bases:
InferenceEngine, ArtifactFetcherMixin, LogProbInferenceEngine
- class unitxt.inference.HFAutoModelInferenceEngine(data_classification_policy: List[str] = None, device: str | NoneType = None, max_new_tokens: int = __required__, do_sample: bool = False, temperature: float | NoneType = None, top_p: float | NoneType = None, top_k: int | NoneType = None, num_beams: int | NoneType = None, repetition_penalty: float | NoneType = None, pad_token_id: int | NoneType = None, eos_token_id: int | NoneType = None, lazy_load: bool = False, _requirements_list: List[str] | Dict[str, str] = {'transformers': "Install huggingface package using 'pip install --upgrade transformers", 'torch': 'Install torch, go on PyTorch website for mode details.', 'accelerate': 'pip install accelerate'}, requirements: List[str] | Dict[str, str] = [], cache_batch_size: int = 100, use_cache: bool = True, model_name: str = __required__, label: str = 'hf_auto_model', n_top_tokens: int = 5, device_map: Any = None, use_fast_tokenizer: bool = True, low_cpu_mem_usage: bool = True, torch_dtype: str = 'torch.float16', batch_size: int = 1, model: Any = None, processor: Any = None, use_fp16: bool = True, load_in_8bit: bool = False, padding: bool = True, truncation: bool = True, padding_side: str = 'left', chat_kwargs_dict: dict = {})[source]¶
Bases:
HFInferenceEngineBase
- chat_kwargs_dict: dict = {}¶
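A minimal local-inference sketch (assumed, not from the docstring): it presumes the transformers, torch, and accelerate packages are installed, and the checkpoint name is illustrative.

from unitxt.inference import HFAutoModelInferenceEngine

engine = HFAutoModelInferenceEngine(
    model_name="ibm-granite/granite-3.0-2b-instruct",  # illustrative checkpoint
    max_new_tokens=64,
    batch_size=4,
)
# predictions = engine.infer(dataset)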
- class unitxt.inference.HFGenerationParamsMixin(data_classification_policy: List[str] = None, max_new_tokens: int = __required__, do_sample: bool = False, temperature: float | NoneType = None, top_p: float | NoneType = None, top_k: int | NoneType = None, num_beams: int | NoneType = None, repetition_penalty: float | NoneType = None, pad_token_id: int | NoneType = None, eos_token_id: int | NoneType = None)[source]¶
Bases:
Artifact
- class unitxt.inference.HFInferenceEngineBase(data_classification_policy: List[str] = None, device: str | NoneType = None, max_new_tokens: int = __required__, do_sample: bool = False, temperature: float | NoneType = None, top_p: float | NoneType = None, top_k: int | NoneType = None, num_beams: int | NoneType = None, repetition_penalty: float | NoneType = None, pad_token_id: int | NoneType = None, eos_token_id: int | NoneType = None, lazy_load: bool = False, _requirements_list: List[str] | Dict[str, str] = {'transformers': "Install huggingface package using 'pip install --upgrade transformers", 'torch': 'Install torch, go on PyTorch website for mode details.', 'accelerate': 'pip install accelerate'}, requirements: List[str] | Dict[str, str] = [], cache_batch_size: int = 100, use_cache: bool = True, model_name: str = __required__, label: str = __required__, n_top_tokens: int = 5, device_map: Any = None, use_fast_tokenizer: bool = True, low_cpu_mem_usage: bool = True, torch_dtype: str = 'torch.float16', batch_size: int = 1, model: Any = None, processor: Any = None)[source]¶
Bases:
InferenceEngine, LogProbInferenceEngine, PackageRequirementsMixin, LazyLoadMixin, HFGenerationParamsMixin, TorchDeviceMixin
- class unitxt.inference.HFLlavaInferenceEngine(data_classification_policy: List[str] = None, device: str | NoneType = None, max_new_tokens: int = __required__, do_sample: bool = False, temperature: float | NoneType = None, top_p: float | NoneType = None, top_k: int | NoneType = None, num_beams: int | NoneType = None, repetition_penalty: float | NoneType = None, pad_token_id: int | NoneType = None, eos_token_id: int | NoneType = None, lazy_load: bool = True, _requirements_list: List[str] | Dict[str, str] = {'transformers': "Install huggingface package using 'pip install --upgrade transformers", 'torch': 'Install torch, go on PyTorch website for mode details.', 'accelerate': 'pip install accelerate'}, requirements: List[str] | Dict[str, str] = [], cache_batch_size: int = 100, use_cache: bool = True, model_name: str = __required__, label: str = 'hf_lava', n_top_tokens: int = 5, device_map: Any = None, use_fast_tokenizer: bool = True, low_cpu_mem_usage: bool = True, torch_dtype: str = 'torch.float16', batch_size: int = 1, model: Any = None, processor: Any = None, image_token: str = '<image>')[source]¶
Bases:
HFInferenceEngineBase
- class unitxt.inference.HFOptionSelectingInferenceEngine(data_classification_policy: List[str] = None, device: str | NoneType = None, cache_batch_size: int = 100, use_cache: bool = True, model_name: str = __required__, batch_size: int = __required__)[source]¶
Bases:
InferenceEngine, TorchDeviceMixin
HuggingFace-based class for inference engines that calculate log probabilities.
This class uses models from the HuggingFace Transformers library to calculate log probabilities for text inputs.
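A minimal sketch, assuming a small causal LM is available through transformers; the model name and batch size below are illustrative.

from unitxt.inference import HFOptionSelectingInferenceEngine

engine = HFOptionSelectingInferenceEngine(
    model_name="gpt2",  # illustrative model id
    batch_size=8,
)
# predictions = engine.infer(dataset)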
- class unitxt.inference.HFPeftInferenceEngine(data_classification_policy: List[str] = None, device: str | NoneType = None, max_new_tokens: int = __required__, do_sample: bool = False, temperature: float | NoneType = None, top_p: float | NoneType = None, top_k: int | NoneType = None, num_beams: int | NoneType = None, repetition_penalty: float | NoneType = None, pad_token_id: int | NoneType = None, eos_token_id: int | NoneType = None, lazy_load: bool = False, _requirements_list: List[str] | Dict[str, str] = {'transformers': "Install huggingface package using 'pip install --upgrade transformers", 'torch': 'Install torch, go on PyTorch website for mode details.', 'accelerate': 'pip install accelerate', 'peft': "Install 'peft' package using: 'pip install peft'."}, requirements: List[str] | Dict[str, str] = [], cache_batch_size: int = 100, use_cache: bool = True, model_name: str = __required__, label: str = 'hf_peft_auto_model', n_top_tokens: int = 5, device_map: Any = None, use_fast_tokenizer: bool = True, low_cpu_mem_usage: bool = True, torch_dtype: str = 'torch.float16', batch_size: int = 1, model: Any = None, processor: Any = None, use_fp16: bool = True, load_in_8bit: bool = False, padding: bool = True, truncation: bool = True, padding_side: str = 'left', chat_kwargs_dict: dict = {}, peft_config: Any = None)[source]¶
Bases:
HFAutoModelInferenceEngine
- class unitxt.inference.HFPipelineBasedInferenceEngine(data_classification_policy: List[str] = None, device: str | NoneType = None, max_new_tokens: int = __required__, do_sample: bool = False, temperature: float | NoneType = None, top_p: float | NoneType = None, top_k: int | NoneType = None, num_beams: int | NoneType = None, repetition_penalty: float | NoneType = None, pad_token_id: int | NoneType = None, eos_token_id: int | NoneType = None, lazy_load: bool = False, _requirements_list: List[str] | Dict[str, str] = {'transformers': "Install huggingface package using 'pip install --upgrade transformers", 'torch': 'Install torch, go on PyTorch website for mode details.', 'accelerate': 'pip install accelerate'}, requirements: List[str] | Dict[str, str] = [], cache_batch_size: int = 100, use_cache: bool = True, model_name: str = __required__, label: str = 'hf_pipeline_inference_engine', use_fast_tokenizer: bool = True, use_fp16: bool = True, load_in_8bit: bool = False, task: str | NoneType = None, device_map: Any = None, pipe: Any = None)[source]¶
Bases:
InferenceEngine, PackageRequirementsMixin, LazyLoadMixin, HFGenerationParamsMixin, TorchDeviceMixin
- class unitxt.inference.IbmGenAiInferenceEngine(data_classification_policy: List[str] = ['public', 'proprietary'], _requirements_list: List[str] | Dict[str, str] = {'ibm-generative-ai': "Install ibm-genai package using 'pip install --upgrade ibm-generative-ai"}, requirements: List[str] | Dict[str, str] = [], beam_width: int | NoneType = None, decoding_method: Literal['greedy', 'sample'] | NoneType = None, include_stop_sequence: bool | NoneType = None, length_penalty: Any = None, max_new_tokens: int | NoneType = None, min_new_tokens: int | NoneType = None, random_seed: int | NoneType = None, repetition_penalty: float | NoneType = None, return_options: Any = None, stop_sequences: List[str] | NoneType = None, temperature: float | NoneType = None, time_limit: int | NoneType = None, top_k: int | NoneType = None, top_p: float | NoneType = None, truncate_input_tokens: int | NoneType = None, typical_p: float | NoneType = None, cache_batch_size: int = 100, use_cache: bool = True, label: str = 'ibm_genai', model_name: str = __required__, parameters: IbmGenAiInferenceEngineParams | NoneType = None, rate_limit: int = 10)[source]¶
Bases:
InferenceEngine, IbmGenAiInferenceEngineParamsMixin, PackageRequirementsMixin, LogProbInferenceEngine, OptionSelectingByLogProbsInferenceEngine
- data_classification_policy: List[str] = ['public', 'proprietary']¶
- class unitxt.inference.IbmGenAiInferenceEngineParamsMixin(data_classification_policy: List[str] = None, beam_width: int | NoneType = None, decoding_method: Literal['greedy', 'sample'] | NoneType = None, include_stop_sequence: bool | NoneType = None, length_penalty: Any = None, max_new_tokens: int | NoneType = None, min_new_tokens: int | NoneType = None, random_seed: int | NoneType = None, repetition_penalty: float | NoneType = None, return_options: Any = None, stop_sequences: List[str] | NoneType = None, temperature: float | NoneType = None, time_limit: int | NoneType = None, top_k: int | NoneType = None, top_p: float | NoneType = None, truncate_input_tokens: int | NoneType = None, typical_p: float | NoneType = None)[source]¶
Bases:
Artifact
- class unitxt.inference.InferenceEngine(data_classification_policy: List[str] = None, cache_batch_size: int = 100, use_cache: bool = True)[source]¶
Bases:
Artifact
Abstract base class for inference.
- get_model_details() Dict [source]¶
Might not be possible to implement for all inference engines. Returns an empty dict by default.
- infer(dataset: List[Dict[str, Any]] | Dataset, return_meta_data: bool = False) ListWithMetadata[str] | ListWithMetadata[TextGenerationInferenceOutput] [source]¶
Verifies the instances of a dataset and performs inference on them.
If return_meta_data is True, returns a list of TextGenerationInferenceOutput objects; otherwise returns a list of string predictions.
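An illustrative calling sketch (engine stands for any concrete InferenceEngine and dataset for data produced by unitxt; both names are assumptions):

predictions = engine.infer(dataset)                     # list of string predictions
outputs = engine.infer(dataset, return_meta_data=True)  # list of TextGenerationInferenceOutput
for output in outputs:
    print(output.prediction, output.stop_reason)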
- class unitxt.inference.LMMSEvalBaseInferenceEngine(data_classification_policy: List[str] = None, device: str | NoneType = None, lazy_load: bool = False, _requirements_list: List[str] | Dict[str, str] = {'lmms_eval': "Install llms-eval package using 'pip install lmms-eval==0.2.4'"}, requirements: List[str] | Dict[str, str] = [], cache_batch_size: int = 100, use_cache: bool = True, model_type: str = __required__, model_args: Dict[str, str] = __required__, batch_size: int = 1, image_token: str = '<image>')[source]¶
Bases:
InferenceEngine, PackageRequirementsMixin, LazyLoadMixin, TorchDeviceMixin
- class unitxt.inference.LMMSEvalInferenceEngine(data_classification_policy: List[str] = None, device: str | NoneType = None, lazy_load: bool = False, _requirements_list: List[str] | Dict[str, str] = {'lmms_eval': "Install llms-eval package using 'pip install lmms-eval==0.2.4'"}, requirements: List[str] | Dict[str, str] = [], cache_batch_size: int = 100, use_cache: bool = True, model_type: str = __required__, model_args: Dict[str, str] = __required__, batch_size: int = 1, image_token: str = '<image>', max_new_tokens: int = 32, temperature: float = 0.0, do_sample: bool = False, generate_until: List[str] = ['\n\n'])[source]¶
Bases:
LMMSEvalBaseInferenceEngine
- generate_until: List[str] = ['\n\n']¶
- class unitxt.inference.LMMSEvalLoglikelihoodInferenceEngine(data_classification_policy: List[str] = None, device: str | NoneType = None, lazy_load: bool = False, _requirements_list: List[str] | Dict[str, str] = {'lmms_eval': "Install llms-eval package using 'pip install lmms-eval==0.2.4'"}, requirements: List[str] | Dict[str, str] = [], cache_batch_size: int = 100, use_cache: bool = True, model_type: str = __required__, model_args: Dict[str, str] = __required__, batch_size: int = 1, image_token: str = '<image>', request_type: Literal['loglikelihood'] = 'loglikelihood')[source]¶
Bases:
LMMSEvalBaseInferenceEngine
- class unitxt.inference.LazyLoadMixin(data_classification_policy: List[str] = None, lazy_load: bool = False)[source]¶
Bases:
Artifact
- class unitxt.inference.ListWithMetadata(*args, metadata: dict | None = None, **kwargs)[source]¶
Bases:
List[T]
- class unitxt.inference.LiteLLMInferenceEngine(data_classification_policy: List[str] = None, _requirements_list: list = ['litellm', 'tenacity', 'tqdm', 'diskcache'], requirements: List[str] | Dict[str, str] = [], model: str = __required__, frequency_penalty: float | NoneType = None, presence_penalty: float | NoneType = None, max_tokens: int | NoneType = None, seed: int | NoneType = None, stop: str | NoneType | List[str] = None, temperature: float | NoneType = None, top_p: float | NoneType = None, top_logprobs: int | NoneType = None, logit_bias: Dict[str, int] | NoneType = None, logprobs: bool | NoneType = None, n: int | NoneType = None, parallel_tool_calls: bool | NoneType = None, service_tier: Literal['auto', 'default'] | NoneType = None, credentials: Dict[str, str] | NoneType = {}, extra_headers: Dict[str, str] | NoneType = None, cache_batch_size: int = 100, use_cache: bool = True, label: str = 'litellm', max_requests_per_second: float = 6, max_retries: int = 5)[source]¶
Bases:
InferenceEngine, StandardAPIParamsMixin, PackageRequirementsMixin
- class unitxt.inference.LogProbInferenceEngine(data_classification_policy: List[str] = None)[source]¶
Bases:
ABC, Artifact
Abstract base class for inference with log probs.
- infer_log_probs(dataset: List[Dict[str, Any]] | Dataset, return_meta_data: bool = False) List[Dict] | List[TextGenerationInferenceOutput] [source]¶
Verifies instances of a dataset and performs inference that returns log probabilities of top tokens.
For each instance, generates a list of the top tokens per position, e.g. [{"top_tokens": [{"text": ..., "logprob": ...}, ...]}, ...]. If return_meta_data is True, returns a list of TextGenerationInferenceOutput objects; otherwise returns the list of logprob dicts. return_meta_data is only supported by some inference engines.
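An illustrative sketch of reading the per-position log probabilities, assuming the engine supports return_meta_data (engine and dataset are assumed names):

outputs = engine.infer_log_probs(dataset, return_meta_data=True)
per_token = outputs[0].prediction      # list of per-position dicts for the first instance
for position in per_token:
    best = position["top_tokens"][0]   # highest-probability candidate at this position
    print(best["text"], best["logprob"])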
- class unitxt.inference.MockInferenceEngine(data_classification_policy: List[str] = None, cache_batch_size: int = 100, use_cache: bool = True, model_name: str = __required__, default_inference_value: str = '[[10]]', default_inference_value_logprob: List[Dict[str, Any]] = [{'logprob': -1, 'text': '[[10]]', 'top_tokens': [{'logprob': -1, 'text': '[[10]]'}]}], label: str = 'mock_inference_engine')[source]¶
- class unitxt.inference.MockModeMixin(data_classification_policy: List[str] = None, mock_mode: bool = False)[source]¶
Bases:
Artifact
- class unitxt.inference.OllamaInferenceEngine(data_classification_policy: List[str] = ['public', 'proprietary'], _requirements_list: List[str] | Dict[str, str] = {'ollama': "Install ollama package using 'pip install --upgrade ollama"}, requirements: List[str] | Dict[str, str] = [], model: str = __required__, frequency_penalty: float | NoneType = None, presence_penalty: float | NoneType = None, max_tokens: int | NoneType = None, seed: int | NoneType = None, stop: str | NoneType | List[str] = None, temperature: float | NoneType = None, top_p: float | NoneType = None, top_logprobs: int | NoneType = None, logit_bias: Dict[str, int] | NoneType = None, logprobs: bool | NoneType = None, n: int | NoneType = None, parallel_tool_calls: bool | NoneType = None, service_tier: Literal['auto', 'default'] | NoneType = None, credentials: Dict[str, str] | NoneType = {}, extra_headers: Dict[str, str] | NoneType = None, cache_batch_size: int = 100, use_cache: bool = True, label: str = 'ollama')[source]¶
Bases:
InferenceEngine, StandardAPIParamsMixin, PackageRequirementsMixin
- data_classification_policy: List[str] = ['public', 'proprietary']¶
- class unitxt.inference.OpenAiInferenceEngine(data_classification_policy: List[str] = ['public'], _requirements_list: List[str] | Dict[str, str] = {'openai': "Install openai package using 'pip install --upgrade openai"}, requirements: List[str] | Dict[str, str] = [], frequency_penalty: float | NoneType = None, presence_penalty: float | NoneType = None, max_tokens: int | NoneType = None, seed: int | NoneType = None, stop: str | NoneType | List[str] = None, temperature: float | NoneType = None, top_p: float | NoneType = None, top_logprobs: int | NoneType = 20, logit_bias: Dict[str, int] | NoneType = None, logprobs: bool | NoneType = True, n: int | NoneType = None, parallel_tool_calls: bool | NoneType = None, service_tier: Literal['auto', 'default'] | NoneType = None, cache_batch_size: int = 100, use_cache: bool = True, label: str = 'openai', model_name: str = __required__, parameters: OpenAiInferenceEngineParams | NoneType = None, base_url: str | NoneType = None, default_headers: Dict[str, str] = {}, credentials: unitxt.inference.CredentialsOpenAi = {}, num_parallel_requests: int = 20)[source]¶
Bases:
InferenceEngine, LogProbInferenceEngine, OpenAiInferenceEngineParamsMixin, PackageRequirementsMixin
- credentials: CredentialsOpenAi = {}¶
- data_classification_policy: List[str] = ['public']¶
- default_headers: Dict[str, str] = {}¶
- class unitxt.inference.OpenAiInferenceEngineParamsMixin(data_classification_policy: List[str] = None, frequency_penalty: float | NoneType = None, presence_penalty: float | NoneType = None, max_tokens: int | NoneType = None, seed: int | NoneType = None, stop: str | NoneType | List[str] = None, temperature: float | NoneType = None, top_p: float | NoneType = None, top_logprobs: int | NoneType = 20, logit_bias: Dict[str, int] | NoneType = None, logprobs: bool | NoneType = True, n: int | NoneType = None, parallel_tool_calls: bool | NoneType = None, service_tier: Literal['auto', 'default'] | NoneType = None)[source]¶
Bases:
Artifact
- class unitxt.inference.OptionSelectingByLogProbsInferenceEngine[source]¶
Bases:
object
OptionSelectingByLogProbsInferenceEngine is used to select an option from a list of options, based on the log probabilities of the options conditioned on a prompt.
The inference engines that inherit from this class must implement get_token_count and get_options_log_probs.
- abstract get_options_log_probs(dataset)[source]¶
Get the token log probabilities of the options under the task_data.options key of each dict in the dataset.
Adds an "options_log_prob" field to each instance: a dict mapping each option string to a list of {"text": str, "logprob": float} entries.
- Parameters:
dataset (List[Dict[str, Any]]) – A list of dictionaries, each representing a data instance.
- Returns:
The dataset, with the "options_log_prob" field added to each instance
- Return type:
List[Dict[str, Any]]
- abstract get_token_count(dataset)[source]¶
Get the token count of the source key of each dict in the dataset. Adds a "token_count" field to each instance.
- Parameters:
dataset (List[Dict[str, Any]]) – A list of dictionaries, each representing a data instance.
- Returns:
The token count of the texts
- Return type:
List[int]
- class unitxt.inference.RITSInferenceEngine(data_classification_policy: List[str] = ['public', 'proprietary'], _requirements_list: List[str] | Dict[str, str] = {'openai': "Install openai package using 'pip install --upgrade openai"}, requirements: List[str] | Dict[str, str] = [], frequency_penalty: float | NoneType = None, presence_penalty: float | NoneType = None, max_tokens: int | NoneType = None, seed: int | NoneType = None, stop: str | NoneType | List[str] = None, temperature: float | NoneType = None, top_p: float | NoneType = None, top_logprobs: int | NoneType = 20, logit_bias: Dict[str, int] | NoneType = None, logprobs: bool | NoneType = True, n: int | NoneType = None, parallel_tool_calls: bool | NoneType = None, service_tier: Literal['auto', 'default'] | NoneType = None, cache_batch_size: int = 100, use_cache: bool = True, label: str = 'rits', model_name: str = __required__, parameters: OpenAiInferenceEngineParams | NoneType = None, base_url: str | NoneType = None, default_headers: Dict[str, str] = {}, credentials: unitxt.inference.CredentialsOpenAi = {}, num_parallel_requests: int = 20)[source]¶
Bases:
OpenAiInferenceEngine
- data_classification_policy: List[str] = ['public', 'proprietary']¶
- model_names_dict = {'microsoft/phi-4': 'microsoft-phi-4'}¶
- class unitxt.inference.StandardAPIParamsMixin(data_classification_policy: List[str] = None, model: str = __required__, frequency_penalty: float | NoneType = None, presence_penalty: float | NoneType = None, max_tokens: int | NoneType = None, seed: int | NoneType = None, stop: str | NoneType | List[str] = None, temperature: float | NoneType = None, top_p: float | NoneType = None, top_logprobs: int | NoneType = None, logit_bias: Dict[str, int] | NoneType = None, logprobs: bool | NoneType = None, n: int | NoneType = None, parallel_tool_calls: bool | NoneType = None, service_tier: Literal['auto', 'default'] | NoneType = None, credentials: Dict[str, str] | NoneType = {}, extra_headers: Dict[str, str] | NoneType = None)[source]¶
Bases:
Artifact
- credentials: Dict[str, str] | None = {}¶
- class unitxt.inference.TextGenerationInferenceOutput(prediction: str | List[Dict[str, Any]], input_tokens: int | None = None, output_tokens: int | None = None, stop_reason: str | None = None, seed: int | None = None, input_text: str | None = None, model_name: str | None = None, inference_type: str | None = None)[source]¶
Bases:
object
Contains the prediction results and metadata for the inference.
- Parameters:
prediction (Union[str, List[Dict[str, Any]]]) – If this is the result of an _infer call, the string predicted by the model. If this is the result of an _infer_log_probs call, a list of dictionaries, where the i'th dictionary represents the i'th token in the response. The entry "top_tokens" in each dictionary holds a sorted list of the top tokens for that position and their probabilities. For example: [ {..., "top_tokens": [ {"text": "a", "logprob": ...}, {"text": "b", "logprob": ...}, ... ]}, {..., "top_tokens": [ {"text": "c", "logprob": ...}, {"text": "d", "logprob": ...}, ... ]} ]
input_tokens (int) – number of input tokens to the model.
output_tokens (int) – number of output tokens generated by the model.
stop_reason (str) – stop reason for text generation, for example “eos” (end of string).
seed (int) – seed used by the model during generation.
input_text (str) – input to the model.
model_name (str) – the model_name as kept in the InferenceEngine.
inference_type (str) – The label stating the type of the InferenceEngine.
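A hypothetical example of constructing and reading this object; the field values are made up for illustration.

from unitxt.inference import TextGenerationInferenceOutput

out = TextGenerationInferenceOutput(
    prediction="Paris",
    input_tokens=12,
    output_tokens=3,
    stop_reason="eos",
    model_name="some-model",          # illustrative
    inference_type="mock_inference_engine",
)
print(out.prediction, out.output_tokens)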
- class unitxt.inference.TogetherAiInferenceEngine(data_classification_policy: List[str] = ['public'], _requirements_list: List[str] | Dict[str, str] = {'together': "Install together package using 'pip install --upgrade together"}, requirements: List[str] | Dict[str, str] = [], max_tokens: int | NoneType = None, stop: List[str] | NoneType = None, temperature: float | NoneType = None, top_p: float | NoneType = None, top_k: int | NoneType = None, repetition_penalty: float | NoneType = None, logprobs: int | NoneType = None, echo: bool | NoneType = None, n: int | NoneType = None, min_p: float | NoneType = None, presence_penalty: float | NoneType = None, frequency_penalty: float | NoneType = None, cache_batch_size: int = 100, use_cache: bool = True, label: str = 'together', model_name: str = __required__, parameters: unitxt.inference.TogetherAiInferenceEngineParamsMixin | NoneType = None)[source]¶
Bases:
InferenceEngine, TogetherAiInferenceEngineParamsMixin, PackageRequirementsMixin
- data_classification_policy: List[str] = ['public']¶
- class unitxt.inference.TogetherAiInferenceEngineParamsMixin(data_classification_policy: List[str] = None, max_tokens: int | NoneType = None, stop: List[str] | NoneType = None, temperature: float | NoneType = None, top_p: float | NoneType = None, top_k: int | NoneType = None, repetition_penalty: float | NoneType = None, logprobs: int | NoneType = None, echo: bool | NoneType = None, n: int | NoneType = None, min_p: float | NoneType = None, presence_penalty: float | NoneType = None, frequency_penalty: float | NoneType = None)[source]¶
Bases:
Artifact
- class unitxt.inference.TorchDeviceMixin(data_classification_policy: List[str] = None, device: str | NoneType = None)[source]¶
Bases:
Artifact
- class unitxt.inference.VLLMInferenceEngine(data_classification_policy: List[str] = None, model: str = __required__, n: int = 1, best_of: int | NoneType = None, _real_n: int | NoneType = None, presence_penalty: float = 0.0, frequency_penalty: float = 0.0, repetition_penalty: float = 1.0, temperature: float = 0.0, top_p: float = 1.0, top_k: int = -1, min_p: float = 0.0, seed: int | NoneType = None, stop: str | List[str] | NoneType = None, stop_token_ids: List[int] | NoneType = None, bad_words: List[str] | NoneType = None, ignore_eos: bool = False, max_tokens: int | NoneType = 16, min_tokens: int = 0, logprobs: int | NoneType = None, prompt_logprobs: int | NoneType = None, _requirements_list: List[str] | Dict[str, str] = [], requirements: List[str] | Dict[str, str] = [], cache_batch_size: int = 100, use_cache: bool = True)[source]¶
Bases:
InferenceEngine, PackageRequirementsMixin, VLLMParamsMixin
- class unitxt.inference.VLLMParamsMixin(data_classification_policy: List[str] = None, model: str = __required__, n: int = 1, best_of: int | NoneType = None, _real_n: int | NoneType = None, presence_penalty: float = 0.0, frequency_penalty: float = 0.0, repetition_penalty: float = 1.0, temperature: float = 0.0, top_p: float = 1.0, top_k: int = -1, min_p: float = 0.0, seed: int | NoneType = None, stop: str | List[str] | NoneType = None, stop_token_ids: List[int] | NoneType = None, bad_words: List[str] | NoneType = None, ignore_eos: bool = False, max_tokens: int | NoneType = 16, min_tokens: int = 0, logprobs: int | NoneType = None, prompt_logprobs: int | NoneType = None)[source]¶
Bases:
Artifact
- class unitxt.inference.VLLMRemoteInferenceEngine(data_classification_policy: List[str] = ['public'], _requirements_list: List[str] | Dict[str, str] = {'openai': "Install openai package using 'pip install --upgrade openai"}, requirements: List[str] | Dict[str, str] = [], frequency_penalty: float | NoneType = None, presence_penalty: float | NoneType = None, max_tokens: int | NoneType = None, seed: int | NoneType = None, stop: str | NoneType | List[str] = None, temperature: float | NoneType = None, top_p: float | NoneType = None, top_logprobs: int | NoneType = 20, logit_bias: Dict[str, int] | NoneType = None, logprobs: bool | NoneType = True, n: int | NoneType = None, parallel_tool_calls: bool | NoneType = None, service_tier: Literal['auto', 'default'] | NoneType = None, cache_batch_size: int = 100, use_cache: bool = True, label: str = 'vllm-remote', model_name: str = __required__, parameters: OpenAiInferenceEngineParams | NoneType = None, base_url: str | NoneType = None, default_headers: Dict[str, str] = {}, credentials: unitxt.inference.CredentialsOpenAi = {}, num_parallel_requests: int = 20)[source]¶
Bases:
OpenAiInferenceEngine
- class unitxt.inference.WMLChatParamsMixin(data_classification_policy: List[str] = None, frequency_penalty: float | NoneType = None, top_logprobs: int | NoneType = None, presence_penalty: float | NoneType = None, response_format: Dict[str, Any] | NoneType = None, temperature: float | NoneType = None, max_tokens: int | NoneType = None, time_limit: int | NoneType = None, top_p: float | NoneType = None, n: int | NoneType = None, seed: int | NoneType = None, logit_bias: Dict[str, Any] | NoneType = None, stop: List[str] | NoneType = None)[source]¶
Bases:
Artifact
- class unitxt.inference.WMLGenerationParamsMixin(data_classification_policy: List[str] = None, decoding_method: Literal['greedy', 'sample'] | NoneType = None, length_penalty: Dict[str, float | int] | NoneType = None, temperature: float | NoneType = None, top_p: float | NoneType = None, top_k: int | NoneType = None, random_seed: int | NoneType = None, repetition_penalty: float | NoneType = None, min_new_tokens: int | NoneType = None, max_new_tokens: int | NoneType = None, stop_sequences: List[str] | NoneType = None, time_limit: int | NoneType = None, truncate_input_tokens: int | NoneType = None, prompt_variables: Dict[str, Any] | NoneType = None, return_options: Dict[str, bool] | NoneType = None)[source]¶
Bases:
Artifact
- class unitxt.inference.WMLInferenceEngineBase(data_classification_policy: List[str] = ['public', 'proprietary'], _requirements_list: List[str] | Dict[str, str] = {'ibm_watsonx_ai': "Install ibm-watsonx-ai package using 'pip install --upgrade ibm-watsonx-ai'. It is advised to have Python version >=3.10 installed, as at lower version this package may cause conflicts with other installed packages."}, requirements: List[str] | Dict[str, str] = [], cache_batch_size: int = 100, use_cache: bool = True, credentials: Dict[Literal['url', 'username', 'password', 'api_key', 'project_id', 'space_id', 'instance_id'], str] | NoneType = None, model_name: str | NoneType = None, deployment_id: str | NoneType = None, concurrency_limit: int = 10, label: str = 'wml', parameters: WMLInferenceEngineParams | unitxt.inference.WMLGenerationParamsMixin | unitxt.inference.WMLChatParamsMixin | NoneType = None, _client: Any = None, _model: Any = None)[source]¶
Bases:
InferenceEngine, PackageRequirementsMixin, LogProbInferenceEngine, OptionSelectingByLogProbsInferenceEngine
Base for classes running inference using ibm-watsonx-ai.
- Parameters:
credentials (Dict[str, str], optional) – By default, it is created by a class instance which tries to retrieve proper environment variables (“WML_URL”, “WML_PROJECT_ID”, “WML_SPACE_ID”, “WML_APIKEY”, “WML_USERNAME”, “WML_PASSWORD”, “WML_INSTANCE_ID”). However, a dictionary with the following keys: “url”, “apikey”, “project_id”, “space_id”, “username”, “password”, “instance_id” can be directly provided instead.
model_name (str, optional) – ID of a model to be used for inference. Mutually exclusive with ‘deployment_id’.
deployment_id (str, optional) – Deployment ID of a tuned model to be used for inference. Mutually exclusive with ‘model_name’.
concurrency_limit (int) – Number of concurrent requests sent to a model. Default is 10, which is also the maximum value for the generation.
parameters (Union[WMLInferenceEngineParams, WMLGenerationParamsMixin, WMLChatParamsMixin], optional) – Defines inference parameters and their values. This attribute is deprecated; pass the respective parameters directly to the class instead.
- data_classification_policy: List[str] = ['public', 'proprietary']¶
- class unitxt.inference.WMLInferenceEngineChat(data_classification_policy: List[str] = ['public', 'proprietary'], frequency_penalty: float | NoneType = None, top_logprobs: int | NoneType = None, presence_penalty: float | NoneType = None, response_format: Dict[str, Any] | NoneType = None, temperature: float | NoneType = None, max_tokens: int | NoneType = None, time_limit: int | NoneType = None, top_p: float | NoneType = None, n: int | NoneType = None, seed: int | NoneType = None, logit_bias: Dict[str, Any] | NoneType = None, stop: List[str] | NoneType = None, _requirements_list: List[str] | Dict[str, str] = {'ibm_watsonx_ai': "Install ibm-watsonx-ai package using 'pip install --upgrade ibm-watsonx-ai'. It is advised to have Python version >=3.10 installed, as at lower version this package may cause conflicts with other installed packages."}, requirements: List[str] | Dict[str, str] = [], cache_batch_size: int = 100, use_cache: bool = True, credentials: Dict[Literal['url', 'username', 'password', 'api_key', 'project_id', 'space_id', 'instance_id'], str] | NoneType = None, model_name: str | NoneType = None, deployment_id: str | NoneType = None, concurrency_limit: int = 10, label: str = 'wml', parameters: WMLInferenceEngineParams | unitxt.inference.WMLGenerationParamsMixin | unitxt.inference.WMLChatParamsMixin | NoneType = None, _client: Any = None, _model: Any = None, image_encoder: unitxt.image_operators.EncodeImageToString | NoneType = None)[source]¶
Bases:
WMLInferenceEngineBase, WMLChatParamsMixin
Creates a chat session and returns the model's response.
You can also include images in your inputs. If you use only textual input, it is recommended to use 'WMLInferenceEngineGeneration' instead, as it is faster and allows more parameters for text generation.
You can provide either already formatted messages or a raw dataset as input. In the former case, all passed images should be base64-encoded strings given as an 'image_url' within a message; moreover, only one image per list of messages may be sent. In the latter case, if there are multiple images in one instance, they will be sent separately with the same query. If that could affect the expected responses, concatenate the images within an instance into a single image and adjust your query accordingly (if necessary).
- Parameters:
image_encoder (EncodeImageToString, optional) – operator which encodes images in a given format into the base64 strings required by the service. Specify it when your inputs contain images.
Example
from .api import load_dataset
from .image_operators import EncodeImageToString

image_encoder = EncodeImageToString(image_format="JPEG")
wml_credentials = {
    "url": "some_url",
    "project_id": "some_id",
    "api_key": "some_key",
}
model_name = "meta-llama/llama-3-2-11b-vision-instruct"
wml_inference = WMLInferenceEngineChat(
    credentials=wml_credentials,
    model_name=model_name,
    image_encoder=image_encoder,
    data_classification_policy=["public"],
    max_tokens=1024,
)
dataset = load_dataset(
    dataset_query="card=cards.doc_vqa.en,template=templates.qa.with_context.with_type,loader_limit=30"
)
results = wml_inference.infer(dataset["test"])
- class unitxt.inference.WMLInferenceEngineGeneration(data_classification_policy: List[str] = ['public', 'proprietary'], decoding_method: Literal['greedy', 'sample'] | NoneType = None, length_penalty: Dict[str, float | int] | NoneType = None, temperature: float | NoneType = None, top_p: float | NoneType = None, top_k: int | NoneType = None, random_seed: int | NoneType = None, repetition_penalty: float | NoneType = None, min_new_tokens: int | NoneType = None, max_new_tokens: int | NoneType = None, stop_sequences: List[str] | NoneType = None, time_limit: int | NoneType = None, truncate_input_tokens: int | NoneType = None, prompt_variables: Dict[str, Any] | NoneType = None, return_options: Dict[str, bool] | NoneType = None, _requirements_list: List[str] | Dict[str, str] = {'ibm_watsonx_ai': "Install ibm-watsonx-ai package using 'pip install --upgrade ibm-watsonx-ai'. It is advised to have Python version >=3.10 installed, as at lower version this package may cause conflicts with other installed packages."}, requirements: List[str] | Dict[str, str] = [], cache_batch_size: int = 100, use_cache: bool = True, credentials: Dict[Literal['url', 'username', 'password', 'api_key', 'project_id', 'space_id', 'instance_id'], str] | NoneType = None, model_name: str | NoneType = None, deployment_id: str | NoneType = None, concurrency_limit: int = 10, label: str = 'wml', parameters: WMLInferenceEngineParams | unitxt.inference.WMLGenerationParamsMixin | unitxt.inference.WMLChatParamsMixin | NoneType = None, _client: Any = None, _model: Any = None)[source]¶
Bases:
WMLInferenceEngineBase, WMLGenerationParamsMixin
Generates text for textual inputs.
If you want to include images in your input, please use ‘WMLInferenceEngineChat’ instead.
Examples
from .api import load_dataset

wml_credentials = {
    "url": "some_url",
    "project_id": "some_id",
    "api_key": "some_key",
}
model_name = "google/flan-t5-xxl"
wml_inference = WMLInferenceEngineGeneration(
    credentials=wml_credentials,
    model_name=model_name,
    data_classification_policy=["public"],
    top_p=0.5,
    random_seed=123,
)
dataset = load_dataset(
    dataset_query="card=cards.argument_topic,template_card_index=0,loader_limit=5"
)
results = wml_inference.infer(dataset["test"])