# Source code for models.meta_info

"""
This module contains the meta information for a model
"""
from typing import Optional, Dict, Any, Union,Literal
import os
import sys
from importlib.util import spec_from_file_location, module_from_spec
from pydantic import Field, \
                        field_validator, \
                        model_validator, \
                        FieldValidationInfo, \
                        AliasChoices
from ui_elements.format_option import FormatOption
from ui_elements.base_element import StreamLitPydanticModel
from conversations.conversation import Conversation
from models.base_model import BaseLLMModel
from schema.shared_state import get_shared_state


class ModelMetaInfo(StreamLitPydanticModel):
    """
    This class stores the meta information for a model: where its
    implementation lives (llm_model_file / llm_model_class, or
    inherits_from), how it is presented (name, description, icon) and the
    arguments passed to the LLM and to the chat memory.
    """
    base_dir: str = Field(validation_alias=AliasChoices('base_dir', "BaseDir"),
                          description="The base dir of the app")
    key: Optional[str] = Field(description="The key for the model", default=None)
    name: str = Field(validation_alias=AliasChoices('name', 'Name'),
                      description="The name of the model", min_length=3)
    description: str = Field(validation_alias=AliasChoices('description', "Description"),
                             description="The description of the model", min_length=20)
    llm_model_file: Optional[str] = Field(
        validation_alias=AliasChoices('llm_model_file', "LLMModelFile"),
        description="The model file", default=None)
    llm_model_class: Optional[str] = Field(
        validation_alias=AliasChoices('llm_model_class', "LLMModelClass"),
        description="The model class", default=None)
    # Optional because the declared default is None; a plain ``bool``
    # annotation did not admit the field's own default value.
    supports_stream: Optional[bool] = Field(
        validation_alias=AliasChoices('supports_stream', "SupportsStream"),
        description="Whether the model supports streaming", default=None)
    # A file name under <base_dir>/assets. validate_icon replaces an explicitly
    # supplied value with the full path (NOTE(review): pydantic does not run
    # validators on defaults unless validate_default is set, so the default
    # stays relative).
    icon: str = Field(validation_alias=AliasChoices('icon', "Icon"),
                      description="The icon for the model", default="llm_model.png")
    system_message: Optional[str] = Field(
        validation_alias=AliasChoices('system_message', "SystemMessage"),
        description="The system message for the model",
        default=("You are a helpful assistant that"
                 " responds in Markdown format."))
    llm_arguments: Dict[str, Any] = Field(
        validation_alias=AliasChoices('llm_arguments', "LLMArguments"),
        description="The arguments for the LLM Model", default={})
    memory_arguments: Dict[str, Any] = Field(
        validation_alias=AliasChoices('memory_arguments', "MemoryArguments"),
        description=("The arguments for the"
                     " chat memory module"),
        default={})
    # Maps an argument name to the FormatOption format type used to show/edit it.
    required_llm_arguments: Dict[str, Literal["STRING", "INT", "FLOAT", "BOOL",
                                              "SECRET_STRING", "LONG_STRING", ]] = Field(
        validation_alias=AliasChoices('required_llm_arguments', "RequiredLLMArguments"),
        description=("The required arguments"
                     " for the model"),
        default={})
    inherits_from: Optional[str] = Field(
        validation_alias=AliasChoices('inherits_from', "InheritsFrom"),
        description="The model to inherit from", default=None)
    is_persistent: bool = Field(
        validation_alias=AliasChoices('is_persistent', "IsPersistent"),
        description="Whether the model is persistent", default=True)
    # Underscore-prefixed, so pydantic keeps this as a private attribute rather
    # than a field. NOTE(review): pydantic v2 copies a private attribute's
    # default per instance, so the [] is not shared -- confirm for the pinned
    # pydantic version.
    _conversations: list[Conversation] = []

    def add_conversation(self, conversation: Conversation) -> None:
        """
        This method is used to add a conversation

        Args:
            conversation: The conversation to append to this model's list
        """
        self._conversations.append(conversation)

    def get_conversations(self) -> list[Conversation]:
        """
        This method is used to get the conversations

        Returns:
            The conversations added to this model (the internal list itself,
            not a copy)
        """
        return self._conversations

    def get_additional_custom_field_value(self, field_name: str) -> Any:
        """
        This method is used to get the value for custom fields.

        The shared state is consulted first. Otherwise the value is read from
        llm_arguments and, when it is not None, cached in the shared state.

        Args:
            field_name: The field name

        Returns:
            The field value, or None when it is set nowhere
        """
        shared_config = get_shared_state()
        if field_name in shared_config:
            return shared_config[field_name]
        value = self.llm_arguments.get(field_name)
        if value is not None:
            shared_config[field_name] = value
        return value

    def set_additional_custom_field_value(self, field_name: str, field_value: Any):
        """
        This method is used to set the value for custom fields

        Args:
            field_name: The field name
            field_value: The field value
        """
        # get_additional_custom_field_value reads the shared state before
        # llm_arguments, so both must be updated. Otherwise a value cached by
        # an earlier read would hide the new one.
        self.set_value_for_custom_fields(field_name, field_value)

    def _required_argument_options(self) -> list[FormatOption]:
        """
        Build one FormatOption per entry of required_llm_arguments, using the
        argument name as both the title and the field name.
        """
        return [
            FormatOption(format_type=argument_type, title=argument, field_name=argument)
            for argument, argument_type in self.required_llm_arguments.items()
        ]

    def additional_custom_fields_to_show(self) -> list[Union[str, FormatOption]]:
        """
        This method is used to get the additional fields to show

        Returns:
            One FormatOption per required LLM argument
        """
        return self._required_argument_options()

    def set_value_for_custom_fields(self, field_name: str, field_value: Any):
        """
        This method is used to set the value for custom fields in both the
        shared state and llm_arguments.

        Args:
            field_name: The field name
            field_value: The field value
        """
        shared_config = get_shared_state()
        shared_config[field_name] = field_value
        self.llm_arguments[field_name] = field_value

    @classmethod
    def fields_to_show(cls) -> list[Union[str, FormatOption]]:
        """
        This method is used to get the fields to show

        Returns:
            The fields to show, titled with each field's schema description
        """
        field_info: Dict[str, Any] = cls.model_json_schema()["properties"]
        layout = [
            ("IMAGE_PATH", "icon"),
            ("STRING", "name"),
            ("LONG_STRING", "description"),
            ("LONG_STRING", "system_message"),
            ("BOOL", "supports_stream"),
            ("STRING", "llm_model_file"),
            ("STRING", "llm_model_class"),
            ("DICT", "llm_arguments"),
            ("DICT", "memory_arguments"),
            ("BOOL", "is_persistent"),
        ]
        return [
            FormatOption(format_type=format_type,
                         title=field_info[field_name]["description"],
                         field_name=field_name)
            for format_type, field_name in layout
        ]

    @classmethod
    def fields_to_edit(cls) -> list[Union[str, Dict[str, Dict[str, FormatOption]]]]:
        """
        This method is used to get the fields to edit

        Returns:
            The fields to edit, titled with each field's schema description.
            name and description carry an empty create_value.
        """
        field_info: Dict[str, Any] = cls.model_json_schema()["properties"]
        layout = [
            ("IMAGE_PATH", "icon", {}),
            ("STRING", "name", {"create_value": ""}),
            ("LONG_STRING", "description", {"create_value": ""}),
            ("LONG_STRING", "system_message", {}),
            ("BOOL", "supports_stream", {}),
            ("DICT", "llm_arguments", {}),
            ("DICT", "memory_arguments", {}),
        ]
        return [
            FormatOption(format_type=format_type,
                         title=field_info[field_name]["description"],
                         field_name=field_name,
                         **extra)
            for format_type, field_name, extra in layout
        ]

    def additional_custom_fields_to_edit(self) -> list[Union[str, FormatOption]]:
        """
        This method is used to get the additional fields to edit

        Returns:
            One FormatOption per required LLM argument
        """
        return self._required_argument_options()

    @field_validator("base_dir")
    @classmethod
    def validate_base_dir(cls, value: str) -> str:
        """
        Validates the base dir

        Args:
            value: The value of the base dir

        Returns:
            The value of the base dir

        Raises:
            ValueError: If the base dir is not found
        """
        if not os.path.isdir(value):
            raise ValueError(f"Base dir {value} not found")
        return value

    @field_validator("icon")
    @classmethod
    def validate_icon(cls, value: str, field_info: FieldValidationInfo) -> str:
        """
        Validates the icon and resolves it to <base_dir>/assets/<value>

        Args:
            value: The value of the icon
            field_info: The field info object from pydantic.

        Returns:
            str: The full path of the icon

        Raises:
            ValueError: If base_dir is unavailable or the icon is not found
        """
        # pydantic only exposes successfully validated fields in ``data``, so
        # base_dir is missing here when its own validation failed.
        base_dir = field_info.data.get("base_dir")
        if base_dir is None:
            raise ValueError("Cannot resolve the icon because base_dir is invalid")
        icon_path = os.path.join(base_dir, "assets", value)
        if not os.path.isfile(icon_path):
            raise ValueError(f"Icon file {icon_path} not found")
        return icon_path

    # pylint: disable=unused-argument
    @model_validator(mode="before")
    @classmethod
    def validate_model(cls, data: Dict, field_info: FieldValidationInfo) -> Dict:
        """
        Validates the model.

        Either inherits_from, or both llm_model_file and llm_model_class, must
        be given (under the field name or its alias). In the latter case the
        file <base_dir>/models/<llm_model_file> is imported, and the class
        must exist in it and derive from BaseLLMModel.

        Args:
            data: The raw input data of the model
            field_info: The field info object from pydantic.

        Returns:
            A copy of the data, with the model file replaced by its full path
            when it was loaded. Non-dict input is returned unchanged.

        Raises:
            ValueError: If the model is invalid
        """
        if not isinstance(data, dict):
            # e.g. an existing model instance; leave it to pydantic.
            return data

        def first_present(keys):
            # A key whose value is None counts as absent.
            return next((key for key in keys if data.get(key) is not None), None)

        model_file_key = first_present(("llm_model_file", "LLMModelFile"))
        model_class_key = first_present(("llm_model_class", "LLMModelClass"))
        if model_file_key is None or model_class_key is None:
            if first_present(("inherits_from", "InheritsFrom")) is None:
                raise ValueError(("Expected llm_model_file and llm_model_class"
                                  " or inherits_from to be present"))
            # Nothing to import here when the model inherits from another one.
            return data

        base_dir_key = first_present(("base_dir", "BaseDir"))
        if base_dir_key is None:
            raise ValueError("Expected base_dir to be present")

        model_file = data[model_file_key]
        llm_model_file = os.path.join(data[base_dir_key], "models", model_file)
        if not (os.path.isfile(llm_model_file) and llm_model_file.endswith(".py")):
            raise ValueError((f"Model file {llm_model_file} not found or is not a python"
                              " file or doesn't have a corresponding model class"))

        class_name = data[model_class_key]
        module_name = f"models.{model_file.split('.', 1)[0]}"
        # NOTE(review): this executes Python code named by the configuration;
        # only load model files from trusted locations.
        spec = spec_from_file_location(module_name, llm_model_file)
        module = module_from_spec(spec)
        sys.modules[module_name] = module
        try:
            spec.loader.exec_module(module)
        except Exception:
            # Mirror the import system: do not leave a half-initialised module
            # registered after a failed load.
            sys.modules.pop(module_name, None)
            raise

        model_class = getattr(module, class_name, None)
        # isinstance guard: issubclass raises TypeError for non-class attributes.
        if not (isinstance(model_class, type) and issubclass(model_class, BaseLLMModel)):
            raise ValueError((f"Class {class_name} not found in model file"
                              f" {llm_model_file} or is not derived from BaseLLMModel"))

        # Return a copy so the caller's input dict is not mutated.
        return {**data, model_file_key: llm_model_file}