"""
Mixin for handling configs stored in yaml
Should be split off into another package :)
"""
import re
import shutil
from importlib.metadata import version
from itertools import chain
from pathlib import Path
from typing import Any, ClassVar, Literal, Self, Union, overload
import yaml
from pydantic import (
BaseModel,
ConfigDict,
Field,
GetCoreSchemaHandler,
ValidationError,
field_validator,
)
from pydantic_core import core_schema
from noob.types import AbsoluteIdentifier, ConfigID, ConfigSource, valid_config_id
class YamlDumper(yaml.SafeDumper):
    """Safe dumper extended to serialise types plain YAML lacks, such as Paths."""

    def represent_path(self, data: Path) -> yaml.ScalarNode:
        """
        Emit a path as an ordinary YAML string scalar.

        Args:
            data: The path to serialise.

        Returns:
            A scalar node tagged as a plain string holding ``str(data)``.
        """
        path_text = str(data)
        return self.represent_scalar("tag:yaml.org,2002:str", path_text)


# Register against the concrete class that ``Path()`` instantiates here.
YamlDumper.add_representer(
    type(Path()),
    YamlDumper.represent_path,
)
class YAMLMixin:
    """
    Mixin class that provides the :meth:`.from_yaml` classmethod and the
    :meth:`.to_yaml` / :meth:`.to_yamls` instance methods.
    """

    @classmethod
    def from_yaml(cls: type[Self], file_path: str | Path) -> Self:
        """
        Instantiate this class by passing the contents of a yaml file as kwargs.

        Args:
            file_path: Path of the yaml file to read.

        Returns:
            An instance of ``cls`` built from the mapping stored in the file.
        """
        # YAML is UTF-8 by convention; don't depend on the platform's locale encoding
        with open(file_path, encoding="utf-8") as file:
            config_data = yaml.safe_load(file)
        if config_data is None:
            # an empty document loads as ``None``, which cannot be splatted as kwargs
            config_data = {}
        return cls(**config_data)

    def to_yaml(self, path: Path | None = None, **kwargs: Any) -> str:
        """
        Dump the contents of this class to a yaml file, returning the
        contents of the dumped string.

        Args:
            path: Where to write the yaml. When falsy, nothing is written.
            **kwargs: forwarded to :meth:`.to_yamls`

        Returns:
            The dumped yaml string.
        """
        data_str = self.to_yamls(**kwargs)
        if path:
            with open(path, "w", encoding="utf-8") as file:
                file.write(data_str)
        return data_str

    def to_yamls(self, **kwargs: Any) -> str:
        """
        Dump the contents of this class to a yaml string.

        Args:
            **kwargs: passed to :meth:`.BaseModel.model_dump` when this object
                is a pydantic model.

        Returns:
            The yaml document, with keys kept in their original order.
        """
        data = self._dump_data(**kwargs)
        return yaml.dump(data, Dumper=YamlDumper, sort_keys=False)

    def _dump_data(self, **kwargs: Any) -> dict:
        """
        Collect the data to serialise: ``model_dump(**kwargs)`` for pydantic
        models, otherwise the instance ``__dict__`` (``kwargs`` are then unused).
        """
        data = self.model_dump(**kwargs) if isinstance(self, BaseModel) else self.__dict__
        return data
class ConfigYAMLMixin(BaseModel, YAMLMixin):
    """
    Yaml Mixin class that always puts a header consisting of

    * `noob_id` - unique identifier for this config
    * `noob_model` - fully-qualified module path to model class
    * `noob_version` - version of noob when this model was created

    at the top of the file.
    """

    model_config = ConfigDict(validate_default=True)

    noob_id: ConfigID | None = None
    noob_model: AbsoluteIdentifier = Field(None, validate_default=True)
    noob_version: str = version("noob")

    # Keys that must lead every dumped config, in this order.
    HEADER_FIELDS: ClassVar[tuple[str, ...]] = ("noob_id", "noob_model", "noob_version")

    @classmethod
    def from_yaml(cls: type[Self], file_path: str | Path) -> Self:
        """
        Instantiate this class by passing the contents of a yaml file as kwargs.

        If the file lacks the header fields (or has them out of order), the file
        is rewritten with a completed header and the original is kept as a
        ``.yaml.bak`` backup. Should validation then fail, the backup is copied
        back over the file so the rewrite is undone.

        Args:
            file_path: Path of the yaml file to read.

        Raises:
            pydantic.ValidationError: if the file contents do not validate.
        """
        file_path = Path(file_path)
        with open(file_path, encoding="utf-8") as file:
            loaded = yaml.safe_load(file)
        if loaded is None:
            # an empty document loads as ``None``
            loaded = {}

        # fill in any missing fields in the source file needed for a header
        config_data = cls._complete_header(loaded, file_path)
        # ``_complete_header`` returns the very same object when it left the file alone
        header_rewritten = config_data is not loaded

        try:
            instance = cls(**config_data)
        except ValidationError:
            backup_path = file_path.with_suffix(".yaml.bak")
            # Only undo a rewrite made during *this* call: a backup left over from an
            # earlier run is stale and would overwrite whatever the user changed since.
            if header_rewritten and backup_path.exists():
                from noob.logging import init_logger

                init_logger("config").debug(
                    f"Model instantiation failed, restoring modified backup from {backup_path}..."
                )
                shutil.copy(backup_path, file_path)
            raise

        return instance

    @classmethod
    def from_id(cls: type[Self], id: ConfigID) -> Self:
        """
        Instantiate a model from a config `id` specified in one of the .yaml configs in
        any of the directories returned by :meth:`.config_sources`.

        .. note::

            this method does not yet validate that the config matches the model loading it

        Raises:
            KeyError: if no config file declares the requested id.
        """
        candidates = chain.from_iterable(src.rglob("*.y*ml") for src in cls.config_sources())
        for config_file in candidates:
            try:
                file_id = yaml_peek("noob_id", config_file)
            except KeyError:
                # file has no ``noob_id`` at its root, so it cannot be the one we want
                continue
            if file_id == id:
                from noob.logging import init_logger

                init_logger("config").debug(
                    "Model for %s found at %s", cls._model_name(), config_file
                )
                return cls.from_yaml(config_file)
        raise KeyError(f"No config with id {id} found in {cls.config_sources()}")

    @classmethod
    def from_any(cls: type[Self], source: ConfigSource | Self) -> Self:
        """
        Try and instantiate a config model from any supported constructor.

        Args:
            source (:class:`.ConfigID`, :class:`.Path`, :class:`.PathLike[str]`):
                Either

                * the ``id`` of a config file in the user configs directory or builtin
                * a relative ``Path`` to a config file, relative to the current working directory
                * a relative ``Path`` to a config file, relative to the user config directory
                * an absolute ``Path`` to a config file
                * an instance of the class to be constructed (returned unchanged)

        Raises:
            ValueError: if ``source`` matches none of the supported forms, or names a
                yaml file that does not exist.
        """
        if isinstance(source, cls):
            return source
        elif isinstance(source, str) and valid_config_id(source):
            return cls.from_id(source)
        elif isinstance(source, Path | str):
            from noob.config import config

            source = Path(source)
            if source.suffix in (".yaml", ".yml"):
                if source.exists():
                    # either relative to cwd or absolute
                    return cls.from_yaml(source)
                elif (
                    not source.is_absolute()
                    and (user_source := config.config_dir / source).exists()
                ):
                    return cls.from_yaml(user_source)

        raise ValueError(
            f"Instance of config model {cls.__name__} could not be instantiated from "
            f"{source} - id or file not found, or type not supported"
        )

    @field_validator("noob_model", mode="before")
    @classmethod
    def fill_noob_model(cls, v: str | None) -> AbsoluteIdentifier:
        """Get name of instantiating model, if not provided"""
        if v is None:
            v = cls._model_name()
        return v

    @classmethod
    def config_sources(cls: type[Self]) -> list[Path]:
        """
        Directories to search for config files, in order of priority
        such that earlier sources are preferred over later sources.
        """
        from noob.config import Config, get_entrypoint_sources, get_extra_sources

        return [Config().config_dir, *get_extra_sources(), *get_entrypoint_sources()]

    def _dump_data(self, **kwargs: Any) -> dict:
        """Ensure that header is prepended to model data"""
        return {**self._yaml_header(self), **super()._dump_data(**kwargs)}

    @classmethod
    def _model_name(cls) -> AbsoluteIdentifier:
        """Dotted ``module.ClassName`` identifier of this class."""
        return f"{cls.__module__}.{cls.__name__}"

    @classmethod
    def _yaml_header(cls, instance: Self | dict) -> dict:
        """
        Build the header mapping for a model instance or a raw config dict.

        Values absent from ``instance`` fall back to this class's name and the
        installed ``noob`` version; a missing id falls back to the declared
        default of the ``noob_id`` field, or ``None`` when there is none.
        """
        if isinstance(instance, dict):
            model_id = instance.get("noob_id", None)
            noob_model = instance.get("noob_model", cls._model_name())
            noob_version = instance.get("noob_version", version("noob"))
        else:
            model_id = getattr(instance, "noob_id", None)
            noob_model = getattr(instance, "noob_model", cls._model_name())
            noob_version = getattr(instance, "noob_version", version("noob"))

        if model_id is None:
            from pydantic_core import PydanticUndefined

            # if missing an id, try and recover with model default cautiously
            # so we throw the exception during validation and not here, for clarity.
            model_id = getattr(cls.model_fields.get("noob_id", None), "default", None)
            if model_id is PydanticUndefined:
                model_id = None

        return {
            "noob_id": model_id,
            "noob_model": noob_model,
            "noob_version": noob_version,
        }

    @classmethod
    def _complete_header(cls: type[Self], data: dict, file_path: str | Path) -> dict:
        """
        Fill in any missing fields in the source file needed for a header.

        When header fields are missing or do not lead the document in
        ``HEADER_FIELDS`` order, the original file is copied to a ``.yaml.bak``
        backup and rewritten with the completed header first.

        Args:
            data: Parsed contents of ``file_path`` (``None`` is treated as empty).
            file_path: The file ``data`` was loaded from.

        Returns:
            ``data`` itself when nothing needed changing, otherwise a new dict
            with the header keys first.
        """
        file_path = Path(file_path)
        if data is None:
            # an empty yaml document parses to ``None``
            data = {}

        missing_fields = set(cls.HEADER_FIELDS) - set(data.keys())
        keys = tuple(data.keys())
        n_header = len(cls.HEADER_FIELDS)
        out_of_order = len(keys) >= n_header and keys[:n_header] != cls.HEADER_FIELDS

        if missing_fields or out_of_order:
            # Parenthesised so both halves are concatenated into one message; as bare
            # adjacent statements the second string literal is silently discarded.
            if missing_fields:
                msg = (
                    f"Missing required header fields {missing_fields} in config model "
                    f"{str(file_path)}. Updating file (preserving backup)..."
                )
            else:
                msg = (
                    f"Header keys were present, but either not at the start of {str(file_path)} "
                    "or out of order. Updating file (preserving backup)..."
                )
            from noob.logging import init_logger

            logger = init_logger(cls.__name__)
            logger.warning(msg)
            logger.debug(data)

            header = cls._yaml_header(data)
            # header first fixes key order; values already in ``data`` take precedence
            data = {**header, **data}
            shutil.copy(file_path, file_path.with_suffix(".yaml.bak"))
            with open(file_path, "w", encoding="utf-8") as yfile:
                yaml.dump(data, yfile, Dumper=YamlDumper, sort_keys=False)

        return data

    @classmethod
    def __get_pydantic_core_schema__(
        cls, source_type: Any, handler: GetCoreSchemaHandler
    ) -> core_schema.CoreSchema:
        """
        Add before_validator to allow instantiation from id
        """

        def _from_id(value: Union[str, "ConfigYAMLMixin"]) -> "ConfigYAMLMixin":
            # strings are treated as config ids; anything else passes through untouched
            if isinstance(value, str):
                return cls.from_id(value)
            else:
                return value

        return core_schema.no_info_before_validator_function(
            _from_id,
            handler(source_type),
            # TODO: add this when updating pydantic floor to 2.10
            # json_schema_input_schema=core_schema.union_schema(
            #     [handler(source_type), handler(ConfigID)]
            # ),
        )
@overload
def yaml_peek(
key: str, path: str | Path, root: bool = True, first: Literal[True] = True
) -> str: ...
@overload
def yaml_peek(
key: str, path: str | Path, root: bool = True, first: Literal[False] = False
) -> list[str]: ...
@overload
def yaml_peek(
key: str, path: str | Path, root: bool = True, first: bool = True
) -> str | list[str]: ...
[docs]
def yaml_peek(key: str, path: str | Path, root: bool = True, first: bool = True) -> str | list[str]:
"""
Peek into a yaml file without parsing the whole file to retrieve the value of a single key.
This function is _not_ designed for robustness to the yaml spec, it is for simple key: value
pairs, not fancy shit like multiline strings, tagged values, etc. If you want it to be,
then i'm afraid you'll have to make a PR about it.
Returns a string no matter what the yaml type is so ya have to do your own casting if you want
Args:
key (str): The key to peek for
path (:class:`pathlib.Path` , str): The yaml file to peek into
root (bool): Only find keys at the root of the document (default ``True`` ), otherwise
find keys at any level of nesting.
first (bool): Only return the first appearance of the key (default). Otherwise return a
list of values (not implemented lol)
Returns:
str
"""
if root:
pattern = re.compile(
rf"^(?P<key>{key}):\s*\"*\'*(?P<value>\S.*?)\"*\'*$", flags=re.MULTILINE
)
else:
pattern = re.compile(
rf"^\s*(?P<key>{key}):\s*\"*\'*(?P<value>\S.*?)\"*\'*$", flags=re.MULTILINE
)
res: re.Match[str] | None = None
if first:
with open(path) as yfile:
for line in yfile:
res = pattern.match(line)
if res:
break
if res is not None:
return res.groupdict()["value"]
else:
with open(path) as yfile:
text = yfile.read()
matches = [match.groupdict()["value"] for match in pattern.finditer(text)]
if matches:
return matches
raise KeyError(f"Key {key} not found in {path}")