Source code for schema.config

"""
This module is used to define the config file schema
"""
from typing import Any, Dict
import yaml
from pydantic import BaseModel, Field, model_validator, field_validator
from models.meta_info import ModelMetaInfo
from schema.group_agent import GroupAgent
from schema.shared_state import get_shared_state
from utils.util import clear_default_values, get_field_name


class ConfigFile(BaseModel):
    """
    Schema of the application config file.

    Every field is read from its PascalCase alias (``Models``, ``Keys``,
    ``BaseDir``, ``GroupChatAgents``, ``SharedState``). Because
    ``populate_by_name`` is enabled, the snake_case attribute names are
    accepted as well.
    """

    # Without this, only the aliases populate the fields: a ``base_dir=...``
    # keyword would be dropped silently and the lower-case keys that the
    # "before" validator already looks up could never pass validation.
    model_config = {"populate_by_name": True}

    # Model definitions, keyed by model key.
    models: Dict[str, ModelMetaInfo] = Field(
        validation_alias="Models",
        description="The models",
    )
    # Model keys already taken; filled in by ``after_models_validate``.
    keys: set[str] = Field(
        validation_alias="Keys",
        description="The model keys",
        default_factory=set,
    )
    # Base directory copied into models and group chat agents.
    base_dir: str = Field(
        validation_alias="BaseDir",
        description="The base directory",
        default=None,
    )
    # Group chat agent definitions, keyed by agent name.
    group_chat_agents: Dict[str, GroupAgent] = Field(
        validation_alias="GroupChatAgents",
        description="The group chat agents",
    )
    # Entries pushed into the shared state object during validation.
    shared_state: Dict[str, Any] = Field(
        validation_alias="SharedState",
        description="The shared state",
        default_factory=dict,
    )

    @field_validator("shared_state")
    @classmethod
    def validate_shared_state(cls, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Push the configured shared state into the shared state object

        The entries are handed to ``update`` on the object returned by
        ``get_shared_state()``; the field value itself is left untouched.

        Args:
            data: The shared state entries from the config

        Returns:
            The data, unchanged
        """
        shared_config = get_shared_state()
        shared_config.update(**data)
        return data

    @model_validator(mode='before')
    @classmethod
    def validate_models_and_add_base_dir(cls, data: Any) -> Any:
        """
        Check the mandatory sections and propagate the base directory

        The base directory is written into every model that does not define
        its own ``base_dir`` and into every group chat agent (overwriting any
        value the agent had). The input dictionaries are modified in place.
        Non-dict input is returned untouched.

        Args:
            data: The raw data handed to the model

        Returns:
            The data

        Raises:
            ValueError: If the models or group chat agents section is
                missing, or if a base directory is needed but not present
        """
        if not isinstance(data, dict):
            return data

        # The base directory may arrive under its field name or its alias.
        base_dir_key = get_field_name(data, ["base_dir", "BaseDir"])

        def required_base_dir() -> Any:
            # Only complain when a value is actually needed, so configs that
            # have nothing to propagate keep validating as before.
            if base_dir_key is None:
                raise ValueError("BaseDir key not found in config file")
            return data[base_dir_key]

        models_key = get_field_name(data, ["models", "Models"])
        if models_key is None:
            raise ValueError("Models key not found in config file")
        for model in data[models_key].values():
            if "base_dir" not in model:
                model["base_dir"] = required_base_dir()

        group_chat_agents_key = get_field_name(
            data, ["group_chat_agents", "GroupChatAgents"]
        )
        if group_chat_agents_key is None:
            raise ValueError("GroupChatAgents key not found in config file")
        for group_chat_agent in data[group_chat_agents_key].values():
            group_chat_agent["base_dir"] = required_base_dir()

        return data

    @model_validator(mode="after")
    def after_models_validate(self) -> 'ConfigFile':
        """
        Register the model keys and resolve model inheritance

        Models are processed in declaration order, so a model can only
        inherit from one that appears before it. For an inheriting model the
        non-default values of the base are merged into the non-default values
        of the child:

        * keys only present in the base are copied over,
        * scalar values (int, float, str, bool) of the same type keep the
          child's value,
        * lists are concatenated (child first) with duplicates removed,
        * dicts are merged with the child's entries taking precedence.

        Returns:
            The config file object

        Raises:
            ValueError: If a model key is already registered, a base model is
                unknown or declared later, or a value cannot be merged
        """
        scalar_types = (int, float, str, bool)
        models_already_validated = {}

        for model_key, model in self.models.items():
            # A collision means the key was already present in ``keys``,
            # e.g. listed under "Keys" in the config.
            if model_key in self.keys:
                raise ValueError(f"Model key {model_key} already exists")
            self.keys.add(model_key)
            model.key = model_key

            if model.inherits_from is None:
                models_already_validated[model_key] = model
                continue
            if model.inherits_from not in models_already_validated:
                raise ValueError((f"Model {model.inherits_from} not found"
                                  " in models or the config is not in "
                                  "proper order"))

            # Let's apply inheritance
            base_model = clear_default_values(
                models_already_validated[model.inherits_from].model_dump()
            )
            dict_model = clear_default_values(model.model_dump())

            for key, value in base_model.items():
                if key not in dict_model:
                    dict_model[key] = value
                    continue
                own_value = dict_model[key]
                # Exact type match on purpose: bool must not pass as int.
                # pylint: disable=unidiomatic-typecheck
                if type(value) in scalar_types and type(value) is type(own_value):
                    continue
                if isinstance(value, list) and isinstance(own_value, list):
                    # Membership test instead of set(): keeps a stable order
                    # and also works for unhashable items such as dicts.
                    merged = []
                    for item in own_value + value:
                        if item not in merged:
                            merged.append(item)
                    dict_model[key] = merged
                elif isinstance(value, dict) and isinstance(own_value, dict):
                    dict_model[key] = {**value, **own_value}
                else:
                    raise ValueError((f"Model {model.name} has key {key} with value "
                                      f"{dict_model[key]} of type {type(dict_model[key])}"
                                      f" but base model {model.inherits_from} has"
                                      f" value {value} of type {type(value)}"))

            # The rebuilt model is fully resolved and must not inherit again.
            dict_model.pop("inherits_from", None)
            models_already_validated[model_key] = ModelMetaInfo(**dict_model)

        self.models = models_already_validated
        return self
def pydantic_validate_config(config_file: str, base_path: str) -> ConfigFile:
    """
    Validates the config file using pydantic

    The YAML document is loaded with ``yaml.safe_load`` and validated as a
    ``ConfigFile``. ``base_path`` is passed as ``base_dir`` and takes
    precedence over a ``base_dir`` entry in the file.

    Args:
        config_file: The path to the config file
        base_path: The base path of the app

    Returns:
        The validated config file

    Raises:
        yaml.YAMLError: If the file is not valid YAML
        ValueError: If the document is not a mapping, or if validation fails
    """
    with open(config_file, encoding="utf-8") as file_handler:
        config = yaml.safe_load(file_handler)

    # safe_load returns None for an empty document; validate it as an empty
    # mapping so the caller gets a schema error instead of a TypeError.
    if config is None:
        config = {}
    if not isinstance(config, dict):
        raise ValueError(
            f"Config file {config_file} must contain a mapping at the top "
            f"level, got {type(config).__name__}"
        )

    # model_validate takes the mapping as-is, so non-string YAML keys or a
    # "base_dir" entry in the file cannot break keyword-argument expansion.
    return ConfigFile.model_validate({**config, "base_dir": base_path})