"""Module containing configuration."""
from __future__ import annotations
import importlib
import json
import sys
from pathlib import Path
from typing import TYPE_CHECKING
from typing import Any
from packaging.version import Version
from pydantic import BaseModel
from pydantic import ConfigDict
from pydantic import PrivateAttr
from pydantic import PydanticUserError
from pydantic import __version__ as pydantic_version
from pydantic import create_model
from pydantic.fields import FieldInfo
from ruamel.yaml import YAML
from pyglotaran_extras.config.plot_config import PlotConfig
from pyglotaran_extras.config.plot_config import PlotLabelOverrideMap
from pyglotaran_extras.config.plot_config import __PlotFunctionRegistry
from pyglotaran_extras.config.utils import add_yaml_repr
from pyglotaran_extras.io.setup_case_study import get_script_dir
if TYPE_CHECKING:
from collections.abc import Generator
from collections.abc import Iterable
# Only imported for builtin schema generation
from collections.abc import Sequence # noqa: F401
from typing import Literal # noqa: F401
# File name stem shared by the config files looked up in ``find_config_in_dir``
# (``.yaml``/``.yml``), the exported config (``.yml``) and the default schema file
# name (``.schema.json``).
CONFIG_FILE_STEM = "pygta_config"
# Template used by ``Config.export``: a ``yaml-language-server`` comment pointing to the
# schema file, followed by the rendered config. The trailing backslashes suppress the
# newlines after the opening quotes and after ``{config_yaml}``.
EXPORT_TEMPLATE = """\
# yaml-language-server: $schema={schema_path}
{config_yaml}\
"""
[docs]
class UsePlotConfigError(Exception):
    """Error raised when ``use_plot_config`` decorates a function with non JSON serializable kwargs.

    The message names the offending function, points the user to the
    ``exclude_from_config`` keyword argument of ``use_plot_config`` and appends the text
    of the original error.
    """

    def __init__(self, func_name: str, error: PydanticUserError) -> None:  # noqa: DOC
        """Build the error message from ``func_name`` and the original ``error``.

        Parameters
        ----------
        func_name : str
            Name of the function decorated with ``use_plot_config``.
        error : PydanticUserError
            Original error, its string representation is appended to the message.
        """
        problem = (
            f"The function ``{func_name}`` decorated with ``use_plot_config`` has an keyword "
            "argument with a type annotation can not be represents in the config.\n"
        )
        hint = (
            "Please use the name of this keyword argument in the ``exclude_from_config`` "
            "keyword argument to ``use_plot_config``.\n"
        )
        cause = f"Original error:\n{error}"
        super().__init__(problem + hint + cause)
[docs]
@add_yaml_repr
class Config(BaseModel):
    """Main configuration class.

    Besides the validated ``plotting`` section, an instance keeps track of the config
    files it was built from and of a fingerprint of those files, so that ``reload`` can
    skip work when nothing changed.
    """

    model_config = ConfigDict(extra="forbid")

    plotting: PlotConfig = PlotConfig()

    # Config files this instance was built from; ``merge`` moves files of the overriding
    # config to the end of the list.
    _source_files: list[Path] = PrivateAttr(default_factory=list)
    # Fingerprint of ``_source_files`` as of the last (re)load, ``hash(())`` when there
    # are no source files.
    _source_hash: int = PrivateAttr(default=hash(()))

    def merge(self, other: Config) -> Config:
        """Merge two ``Config``'s where ``other`` overrides values and return a new instance.

        ``self`` is left untouched, the merge is done on a deep copy of it.

        Parameters
        ----------
        other : Config
            Other ``Config`` to merge in.

        Returns
        -------
        Config
            New instance holding the merged plotting config and source files.
        """
        merged = self.model_copy(deep=True)
        merged.plotting = merged.plotting.merge(other.plotting)
        for source_file in other._source_files:
            # Move files that are already known to the end, so the list order keeps
            # reflecting the override order.
            if source_file in merged._source_files:
                merged._source_files.remove(source_file)
            merged._source_files.append(source_file)
        merged._source_hash = merged._calculate_source_hash()
        return merged

    def _reset(self, other: Config | None = None) -> Config:
        """Reset self to ``other`` config or default initialization.

        When ``other`` is ``None`` only ``plotting`` is reset (to the one of a fresh
        ``Config``) and the source files are kept, otherwise the source files of
        ``other`` are taken over as well.

        Parameters
        ----------
        other : Config | None
            Other ``Config`` to reset to.

        Returns
        -------
        Config
            ``self`` after the reset.
        """
        if other is None:
            other = Config()
        else:
            self._source_files = other._source_files
        self.plotting = other.plotting
        return self

    def _calculate_source_hash(self) -> int:  # noqa: DOC
        """Calculate hash of the source files based on their path and modification time.

        The path is part of the fingerprint so that swapping the source files for other
        files that happen to have the same modification time is still detected as a
        change by ``reload``. Without source files the result is ``hash(())``, which is
        the default value of ``_source_hash``.
        """
        return hash(
            tuple(
                (source_file, source_file.stat().st_mtime) for source_file in self._source_files
            )
        )

    def reload(self) -> Config:
        """Reset and reload config from files.

        Nothing is done if the fingerprint of the source files did not change since the
        last (re)load.

        Returns
        -------
        Config
            ``self`` after reloading.
        """
        if self._source_hash == self._calculate_source_hash():
            return self
        # ``getattr``/``setattr`` with a string literal are used on purpose: inside of a
        # class body ``self.plotting.__context_config`` would be name mangled.
        context_config = getattr(self.plotting, "__context_config", None)
        merged = self._reset()
        for config in load_config_files(self._source_files):
            merged = merged.merge(config)
        self.plotting = merged.plotting
        if context_config is not None:
            # Carry the context config over to the freshly loaded plotting config.
            setattr(self.plotting, "__context_config", context_config)
        self._source_hash = merged._source_hash
        return self

    def load(self, config_file_path: Path | str) -> Config:
        """Disregard current config and config file paths, and reload from ``config_file_path``.

        Parameters
        ----------
        config_file_path : Path | str
            Path to the config file to load.

        Returns
        -------
        Config
            ``self`` after loading.
        """
        self._source_files = [Path(config_file_path)]
        return self.reload()

    def export(self, export_folder: Path | str | None = None, *, update: bool = True) -> Path:
        """Export current config and schema to ``export_folder``.

        Parameters
        ----------
        export_folder : Path | str | None
            Folder to export config and schema to. Defaults to None, which means that the
            script folder is used.
        update : bool
            Whether to update or overwrite an existing config file. Defaults to True.

        Returns
        -------
        Path
            Path to exported config file.
        """
        if export_folder is None:
            # Imported lazily, ``SCRIPT_DIR`` lives in the package root.
            from pyglotaran_extras import SCRIPT_DIR

            export_folder = SCRIPT_DIR
        else:
            export_folder = Path(export_folder)
        export_folder.mkdir(parents=True, exist_ok=True)
        schema_path = create_config_schema(export_folder)
        export_path = export_folder / f"{CONFIG_FILE_STEM}.yml"
        if export_path.is_file() is True and update is True:
            # Values of ``self`` override the ones already present in the file.
            config = Config().load(export_path).merge(self)
        else:
            config = self
        export_path.write_text(
            EXPORT_TEMPLATE.format(schema_path=schema_path.name, config_yaml=config),
            encoding="utf8",
        )
        return export_path

    def rediscover(self, *, include_home_dir: bool = True, lookup_depth: int = 2) -> list[Path]:
        """Rediscover config paths based on the ``SCRIPT_DIR`` discovered on import.

        Only the list of source files is replaced, call ``reload`` to apply them.

        Parameters
        ----------
        include_home_dir : bool
            Whether or not to include the users home folder in the config lookup.
            Defaults to True.
        lookup_depth : int
            Depth at which to look for configs in parent folders of ``script_dir``.
            If set to ``1`` only ``script_dir`` will be considered as config dir.
            Defaults to ``2``.

        Returns
        -------
        list[Path]
            Paths of the discovered config files.
        """
        from pyglotaran_extras import SCRIPT_DIR

        self._source_files = list(
            discover_config_files(
                SCRIPT_DIR, include_home_dir=include_home_dir, lookup_depth=lookup_depth
            )
        )
        return self._source_files

    def init_project(self) -> Config:
        """Initialize configuration for the current project.

        This will use the configs discovered and resolved config during import to create a new
        config and schema for your current project inside of your working directory (script dir),
        if it didn't exist before.

        Returns
        -------
        Config
            ``self`` after rediscovering and reloading the config files.
        """
        from pyglotaran_extras import SCRIPT_DIR

        if any(find_config_in_dir(SCRIPT_DIR)) is False:
            self.export()
        # NOTE(review): rediscover and reload are assumed to run unconditionally, so a
        # freshly exported config file gets picked up - confirm against the intended
        # block structure.
        self.rediscover()
        self.reload()
        return self
[docs]
def find_config_in_dir(dir_path: Path) -> Generator[Path, None, None]:
    """Yield the config files that exist directly inside of ``dir_path``.

    The candidates are ``CONFIG_FILE_STEM`` with the extension ``.yaml`` and ``.yml``
    (checked in that order), candidates that are not an existing file are skipped.

    Parameters
    ----------
    dir_path : Path
        Directory path to look for a config file.

    Yields
    ------
    Path
        Path of an existing config file.
    """
    stem_path = dir_path / CONFIG_FILE_STEM
    candidates = (stem_path.with_suffix(suffix) for suffix in (".yaml", ".yml"))
    yield from (candidate for candidate in candidates if candidate.is_file())
[docs]
def discover_config_files(
    script_dir: Path, *, include_home_dir: bool = True, lookup_depth: int = 2
) -> Generator[Path, None, None]:
    """Yield config files from the users home folder, ``script_dir`` and its parent folders.

    Config files in the home folder come first (if requested), followed by the ones in
    the looked up folders ordered from the outermost folder down to ``script_dir``.

    Parameters
    ----------
    script_dir : Path
        Path to the current scripts/notebooks parent folder.
    include_home_dir : bool
        Whether or not to include the users home folder in the config lookup.
        Defaults to True.
    lookup_depth : int
        Depth at which to look for configs in parent folders of ``script_dir``.
        If set to ``1`` only ``script_dir`` will be considered as config dir.
        Values that are not positive or exceed the number of available folders result
        in all folders being looked up. Defaults to ``2``.

    Yields
    ------
    Path
        Path of a discovered config file.
    """
    if include_home_dir is True:
        yield from find_config_in_dir(Path.home())

    # Appending a dummy child makes ``script_dir`` itself part of ``parents``, reversing
    # orders the folders from the outermost one down to ``script_dir``.
    lookup_dirs = tuple((script_dir / "dummy").parents)[::-1]
    if 0 < lookup_depth <= len(lookup_dirs):
        lookup_dirs = lookup_dirs[-lookup_depth:]
    for lookup_dir in lookup_dirs:
        yield from find_config_in_dir(lookup_dir)
[docs]
def load_config_files(config_paths: Iterable[Path]) -> Generator[Config, None, None]:
    """Load each config file into a new ``Config`` instance.

    Files that fail to load or validate are reported on ``stderr`` and skipped, an empty
    file results in a default ``Config``. Each yielded config has its path registered as
    source file.

    Parameters
    ----------
    config_paths : Iterable[Path]
        Paths to the config files.

    Yields
    ------
    Config
        Config loaded from a single file.
    """
    yaml_loader = YAML()
    for path in config_paths:
        try:
            raw_config = yaml_loader.load(path)
            if raw_config is not None:
                loaded = Config.model_validate(raw_config)
            else:
                loaded = Config()
            loaded._source_files.append(path)
            yield loaded
        # We use a very broad range of exception to ensure the config loading at import never
        # breaks importing
        except Exception as error:  # noqa: BLE001
            report = (
                "Error loading the config:\n",
                f"Source path: {path.as_posix()}\n",
                f"Error: {error}",
            )
            print(*report, file=sys.stderr, sep="")  # noqa: T201
[docs]
def merge_configs(configs: Iterable[Config]) -> Config:
    """Fold ``configs`` into a single ``Config``, later configs override earlier ones.

    The merge starts from a default ``Config``, so an empty iterable results in the
    default configuration.

    Parameters
    ----------
    configs : Iterable[Config]
        Config instances to merge together.

    Returns
    -------
    Config
        Result of merging all configs from left to right.
    """
    merged = Config()
    for override in configs:
        merged = merged.merge(override)
    return merged
[docs]
def load_config(
    script_dir: Path, *, include_home_dir: bool = True, lookup_depth: int = 2
) -> Config:
    """Discover config files, load them and merge them into a single ``Config``.

    Parameters
    ----------
    script_dir : Path
        Path to the current scripts/notebooks parent folder.
    include_home_dir : bool
        Whether or not to include the users home folder in the config lookup.
        Defaults to True.
    lookup_depth : int
        Depth at which to look for configs in parent folders of ``script_dir``.
        If set to ``1`` only ``script_dir`` will be considered as config dir.
        Defaults to ``2``.

    Returns
    -------
    Config
        Merged config of all discovered config files.

    See Also
    --------
    discover_config_files
    load_config_files
    merge_configs
    """
    return merge_configs(
        load_config_files(
            discover_config_files(
                script_dir, include_home_dir=include_home_dir, lookup_depth=lookup_depth
            )
        )
    )
def _find_script_dir_at_import(package_root_file: str) -> Path:
    """Find the script dir when importing ``pyglotaran_extras``.

    The assumption is that the first file not inside of ``pyglotaran_extras`` or importlib
    is the script in question.
    The nesting is increased step by step for as long as the found folder is located
    inside of the package or inside of importlib. The maximum offset of 20 was chosen semi
    arbitrarily (typically ``nesting + offset`` is around 9-13 depending on the import) to
    ensure that there won't be an infinite loop.

    Parameters
    ----------
    package_root_file : str
        The dunder file attribute (``__file__``) in the package root file.

    Returns
    -------
    Path
        Folder found at the first nesting outside of the package and importlib, or the
        one found at the maximum offset.
    """
    base_nesting = 2
    max_nesting_offset = 20
    importlib_dir = Path(importlib.__file__).parent
    package_dir = Path(package_root_file).parent

    # NOTE(review): ``get_script_dir`` presumably counts stack frames relative to its
    # caller, so it is only called directly from this function's frame - confirm before
    # moving the calls into a helper.
    script_dir = get_script_dir(nesting=base_nesting)
    for nesting_offset in range(1, max_nesting_offset + 1):
        # The dummy child makes ``script_dir`` itself part of ``parents``.
        ancestors = (script_dir / "dummy").parents
        if importlib_dir not in ancestors and package_dir not in ancestors:
            break
        script_dir = get_script_dir(nesting=base_nesting + nesting_offset)
    return script_dir
[docs]
def create_config_schema(
    output_folder: Path | str | None = None,
    file_name: Path | str = f"{CONFIG_FILE_STEM}.schema.json",
) -> Path:
    """Create json schema file to be used for autocompletion and linting of the config.

    The schema of ``Config`` is extended with one config model per function in the plot
    function registry, and the keyword arguments of all registered functions are
    collected as allowed properties of the general ``default_args_override``.

    Parameters
    ----------
    output_folder : Path | str | None
        Folder to write schema file to. Defaults to None, which means that the script
        folder is used.
    file_name : Path | str
        Name of the schema file. Defaults to "pygta_config.schema.json"

    Returns
    -------
    Path
        Path to the file the schema got saved to.

    Raises
    ------
    UsePlotConfigError
        If any function decorated with ``use_plot_config`` has a keyword argument with a default
        value and a type annotation that can not be serialized into a json schema.
        The original ``PydanticUserError`` is chained as ``__cause__``.
    """
    json_schema = Config.model_json_schema()
    general_kwargs: dict[str, Any] = {}
    # Loop invariant: pydantic>=2.9 gets a plain ``$ref``, older versions get the
    # reference wrapped in ``allOf`` (presumably to match the style of the schema that
    # version generates).
    use_plain_ref = Version(pydantic_version) >= Version("2.9")

    for function_name, default_kwargs in __PlotFunctionRegistry.items():
        try:
            # ``plot_some_thing`` -> ``PlotSomeThing``
            name_prefix = "".join(part.capitalize() for part in function_name.split("_"))
            fields: Any = {
                kwarg_name: (
                    kwarg_value["annotation"],
                    FieldInfo(
                        default=kwarg_value["default"], description=kwarg_value["docstring"]
                    ),
                )
                for kwarg_name, kwarg_value in default_kwargs.items()
            }
            kwargs_model_name = f"{name_prefix}Kwargs"
            func_kwargs = create_model(
                kwargs_model_name,
                __config__=ConfigDict(extra="forbid"),
                __doc__=(
                    f"Default arguments to use for ``{function_name}``, "
                    "if not specified in function call."
                ),
                **fields,
            )
            config_model_name = f"{name_prefix}Config"
            func_config = create_model(
                config_model_name,
                __config__=ConfigDict(extra="forbid"),
                __doc__=(
                    f"Plot function configuration specific to ``{function_name}`` "
                    "(overrides values in general)."
                ),
                default_args_override=(func_kwargs, {}),
                axis_label_override=(PlotLabelOverrideMap, PlotLabelOverrideMap()),
            )
            func_json_schema = func_config.model_json_schema()
            general_kwargs |= func_json_schema["$defs"][kwargs_model_name]["properties"]
            # Hoist the nested definitions into the top level ``$defs`` and register the
            # function config itself next to them.
            json_schema["$defs"] |= func_json_schema.pop("$defs")
            json_schema["$defs"][config_model_name] = func_json_schema
            config_ref = {"$ref": f"#/$defs/{config_model_name}"}
            json_schema["$defs"]["PlotConfig"]["properties"][function_name] = (
                config_ref if use_plain_ref else {"allOf": [config_ref]}
            )
        except PydanticUserError as error:
            raise UsePlotConfigError(function_name, error) from error

    # The general ``default_args_override`` accepts the kwargs of all registered
    # functions, but nothing else.
    general_override_schema = json_schema["$defs"]["PerFunctionPlotConfig"]["properties"][
        "default_args_override"
    ]
    general_override_schema["properties"] = general_kwargs
    general_override_schema["additionalProperties"] = False

    if output_folder is None:
        # Imported lazily, ``SCRIPT_DIR`` lives in the package root.
        from pyglotaran_extras import SCRIPT_DIR

        output_folder = SCRIPT_DIR
    else:
        output_folder = Path(output_folder)
    output_folder.mkdir(parents=True, exist_ok=True)
    output_file = output_folder / file_name
    output_file.write_text(json.dumps(json_schema, ensure_ascii=False), encoding="utf8")
    return output_file