import argparse
import dataclasses
import inspect
import logging
import os
import sys
from contextlib import contextmanager
from contextvars import ContextVar
from typing import ( # type: ignore[attr-defined]
Dict,
Generic,
List,
Optional,
Tuple,
Type,
TypeVar,
Union,
_GenericAlias,
)
from ._namespace import Namespace
from ._optionals import import_reconplogger, reconplogger_support
from ._type_checking import ArgumentParser
# Public names exported by a star-import of this module.
__all__ = [
"LoggerProperty",
"null_logger",
]
# Type variable standing for the class that an instantiator receives and returns.
ClassType = TypeVar("ClassType")

if sys.version_info >= (3, 8):
    from typing import Protocol

    class InstantiatorCallable(Protocol):
        """Structural type of a callable that creates an instance of ``class_type``."""

        def __call__(self, class_type: Type[ClassType], *args, **kwargs) -> ClassType:
            pass  # pragma: no cover

else:
    # ``typing.Protocol`` is only imported on Python 3.8+, so older versions
    # fall back to a plain callable alias.
    from typing import Callable

    InstantiatorCallable = Callable[..., ClassType]

# Mapping from ``(type, bool)`` keys to the instantiator used for that key.
InstantiatorsDictType = Dict[Tuple[type, bool], InstantiatorCallable]
# Context variables carrying parser state. ``parent_parser`` and
# ``class_instantiators`` are created without a default, so calling ``.get()``
# on them with no argument raises ``LookupError`` until a value has been set.
parent_parser: ContextVar["ArgumentParser"] = ContextVar("parent_parser")
parser_capture: ContextVar[bool] = ContextVar("parser_capture", default=False)
defaults_cache: ContextVar[Optional[Namespace]] = ContextVar("defaults_cache", default=None)
lenient_check: ContextVar[Union[bool, str]] = ContextVar("lenient_check", default=False)
load_value_mode: ContextVar[Optional[str]] = ContextVar("load_value_mode", default=None)
class_instantiators: ContextVar[Optional[InstantiatorsDictType]] = ContextVar("class_instantiators")
# NOTE(review): the default below is one list object created at import time;
# anything that mutates the value returned by ``.get()`` without setting the
# variable first mutates that shared list.
nested_links: ContextVar[List[dict]] = ContextVar("nested_links", default=[])

# Registry of the variables above keyed by name. Every variable was created
# with a name equal to its identifier, so ``ContextVar.name`` gives the key.
parser_context_vars = {
    context_var.name: context_var
    for context_var in (
        parent_parser,
        parser_capture,
        defaults_cache,
        lenient_check,
        load_value_mode,
        class_instantiators,
        nested_links,
    )
}
@contextmanager
def parser_context(**kwargs):
    """Temporarily sets parser context variables for the duration of a ``with`` block.

    Args:
        **kwargs: Names of entries in ``parser_context_vars`` mapped to the
            value each context variable should hold inside the block.

    Raises:
        KeyError: If a given name is not a key of ``parser_context_vars``.
    """
    context_var_tokens = []
    try:
        # Setting happens inside the ``try`` so that a failure part-way through
        # (e.g. an unknown name) still restores the variables already set.
        for name, value in kwargs.items():
            context_var = parser_context_vars[name]
            context_var_tokens.append((context_var, context_var.set(value)))
        yield
    finally:
        # Undo in reverse order of setting.
        for context_var, token in reversed(context_var_tokens):
            context_var.reset(token)
def is_subclass(cls, class_or_tuple) -> bool:
    """Extension of issubclass that supports non-class arguments.

    Returns False, instead of raising, when ``cls`` is not a class or when
    ``issubclass`` rejects ``class_or_tuple`` with a ``TypeError``.
    """
    if not inspect.isclass(cls):
        return False
    try:
        return issubclass(cls, class_or_tuple)
    except TypeError:
        return False
def is_final_class(cls) -> bool:
    """Checks whether a class is final, i.e. decorated with ``typing.final``.

    Detection relies on a truthy ``__final__`` attribute on ``cls``; objects
    without that attribute are reported as not final.
    """
    # Coerce so the result always matches the annotated ``bool`` return type,
    # whatever object ``__final__`` happens to hold.
    return bool(getattr(cls, "__final__", False))
def is_generic_class(cls) -> bool:
    """Checks whether ``cls`` is a ``typing._GenericAlias`` whose ``__module__`` is not ``"typing"``."""
    if not isinstance(cls, _GenericAlias):
        return False
    return getattr(cls, "__module__", "") != "typing"
def get_generic_origin(cls):
    """Returns ``cls.__origin__`` when ``cls`` is a generic class alias, otherwise ``cls`` unchanged."""
    if is_generic_class(cls):
        return cls.__origin__
    return cls
def is_dataclass_like(cls) -> bool:
    """Checks whether a class should be treated like a dataclass.

    A subscripted generic class is judged by its ``__origin__``. Non-classes
    are never dataclass-like. Otherwise the result is True when the class is
    final, when every class in its MRO (ignoring ``object`` and ``Generic``)
    is a dataclass, when it is a pydantic model, or when attrs support is
    available and it is an attrs class.
    """
    if is_generic_class(cls):
        return is_dataclass_like(cls.__origin__)
    if not inspect.isclass(cls):
        return False
    if is_final_class(cls):
        return True
    ignored = {object, Generic}
    classes = [c for c in inspect.getmro(cls) if c not in ignored]
    if not classes:
        # ``all()`` over an empty list is True, which would wrongly report
        # ``object`` and ``Generic`` themselves as dataclass-like.
        return False
    all_dataclasses = all(dataclasses.is_dataclass(c) for c in classes)
    if not all_dataclasses:
        # Imported here rather than at module level, as in the original code.
        from ._optionals import attrs_support, is_pydantic_model

        if is_pydantic_model(cls):
            return True
        if attrs_support:
            import attrs

            if attrs.has(cls):
                return True
    return all_dataclasses
def default_class_instantiator(class_type: Type[ClassType], *args, **kwargs) -> ClassType:
    """Instantiates ``class_type`` by calling it directly with the given arguments."""
    instance = class_type(*args, **kwargs)
    return instance
class ClassInstantiator:
    """Callable that instantiates classes using a dict of custom instantiators.

    Keys of the dict are ``(cls, subclasses)`` pairs. An entry applies when the
    requested class is exactly ``cls`` or, if ``subclasses`` is truthy, when it
    is a subclass of ``cls``. The first applicable entry in dict order is used.
    """

    def __init__(self, instantiators: InstantiatorsDictType) -> None:
        self.instantiators = instantiators

    def __call__(self, class_type: Type[ClassType], *args, **kwargs) -> ClassType:
        """Instantiates ``class_type`` with the first matching instantiator.

        Falls back to ``default_class_instantiator`` when no entry matches.
        """
        matching = (
            instantiator
            for (cls, subclasses), instantiator in self.instantiators.items()
            if class_type is cls or (subclasses and is_subclass(class_type, cls))
        )
        chosen = next(matching, default_class_instantiator)
        return chosen(class_type, *args, **kwargs)
def get_class_instantiator() -> InstantiatorCallable:
    """Returns the callable to use for instantiating classes in the current context.

    Returns:
        ``default_class_instantiator`` when the ``class_instantiators`` context
        variable is unset or holds a falsy value, otherwise a
        ``ClassInstantiator`` wrapping the configured instantiators.
    """
    # ``class_instantiators`` is declared without a default, so a bare ``.get()``
    # raises LookupError when nothing has been set. Supplying a fallback lets
    # the "nothing configured" branch below handle that case too.
    instantiators = class_instantiators.get(None)
    if not instantiators:
        return default_class_instantiator
    return ClassInstantiator(instantiators)
# logging
# Level names accepted when configuring a logger.
logging_levels = {"CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG"}
# Logger used when logging is disabled: its only handler added here is a
# NullHandler, and its parent is cleared.
null_logger = logging.getLogger("jsonargparse_null_logger")
null_logger.addHandler(logging.NullHandler())
null_logger.parent = None
def setup_default_logger(data, level, caller):
    """Gets and configures a logger with a stream handler.

    Args:
        data: A logger name (str), a dict that may contain a ``"name"`` key, or
            any other value, in which case ``caller`` is used as the name.
        level: Name of an attribute of the ``logging`` module giving the level
            applied to every handler of the logger.
        caller: Fallback logger name.

    Returns:
        The configured ``logging.Logger``.
    """
    if isinstance(data, str):
        logger_name = data
    elif isinstance(data, dict) and "name" in data:
        logger_name = data["name"]
    else:
        logger_name = caller

    logger = logging.getLogger(logger_name)
    logger.parent = None
    # Only attach a handler when the logger does not have one yet.
    if not logger.handlers:
        stream_handler = logging.StreamHandler()
        formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
        stream_handler.setFormatter(formatter)
        logger.addHandler(stream_handler)

    numeric_level = getattr(logging, level)
    for attached_handler in logger.handlers:
        attached_handler.setLevel(numeric_level)
    return logger
def parse_logger(logger: Union[bool, str, dict, logging.Logger], caller):
    """Converts a logger specification into a ``logging.Logger``.

    Args:
        logger: ``False`` to disable logging (returns ``null_logger``), ``True``
            or a str/dict to get a default logger, or an existing
            ``logging.Logger``, which is returned unchanged. A dict may only
            contain the keys ``"name"`` and ``"level"``.
        caller: Logger name used when the specification does not give one.

    Returns:
        The resulting ``logging.Logger``.

    Raises:
        ValueError: If ``logger`` has an unsupported type, a dict has
            unexpected keys, or the level is not in ``logging_levels``.
    """
    if not isinstance(logger, (bool, str, dict, logging.Logger)):
        raise ValueError(f"Expected logger to be an instance of (bool, str, dict, logging.Logger), but got {logger}.")

    is_dict = isinstance(logger, dict)
    if is_dict:
        value = {k: v for k, v in logger.items() if k not in {"name", "level"}}
        if value:
            raise ValueError(f"Unexpected data to configure logger: {value}.")

    if logger is False:
        return null_logger

    level = logger["level"] if is_dict and "level" in logger else "WARNING"
    if level not in logging_levels:
        raise ValueError(f"Got logger level {level!r} but must be one of {logging_levels}.")

    # With no explicit name and reconplogger available, the logger is obtained
    # from reconplogger's ``logger_setup``.
    unnamed_request = logger is True or (is_dict and "name" not in logger)
    if unnamed_request and reconplogger_support:
        kwargs = {"level": "DEBUG", "reload": True} if debug_mode_active() else {}
        logger = import_reconplogger("parse_logger").logger_setup(**kwargs)

    if isinstance(logger, logging.Logger):
        return logger
    return setup_default_logger(logger, level, caller)
class LoggerProperty:
    """Class designed to be inherited by other classes to add a logger property."""

    def __init__(self, *args, logger: Union[bool, str, dict, logging.Logger] = False, **kwargs):
        """Initializer for LoggerProperty class.

        Args:
            *args: Positional arguments forwarded to the next ``__init__`` in the MRO.
            logger: Value assigned to the ``logger`` property.
            **kwargs: Keyword arguments forwarded to the next ``__init__`` in the MRO.
        """
        self.logger = logger  # type: ignore[assignment]
        super().__init__(*args, **kwargs)

    @property
    def logger(self) -> logging.Logger:
        """The logger property for the class.

        :getter: Returns the current logger.
        :setter: Sets the given logging.Logger as logger or sets the default logger
            if given True/str(logger name)/dict(name, level), or disables logging
            if given False.

        Raises:
            ValueError: If an invalid logger value is given.
        """
        return self._logger

    @logger.setter
    def logger(self, logger: Union[bool, str, dict, logging.Logger]):
        if logger is None:
            # ``None`` is deprecated: warn and treat it like ``False``.
            from ._deprecated import deprecation_warning, logger_property_none_message

            deprecation_warning((LoggerProperty.logger, None), logger_property_none_message, stacklevel=2)
            logger = False
        # In debug mode a falsy logger value is replaced by a DEBUG-level one.
        if not logger and debug_mode_active():
            logger = {"level": "DEBUG"}
        self._logger = parse_logger(logger, type(self).__name__)
def debug_mode_active() -> bool:
    """Tells whether the ``JSONARGPARSE_DEBUG`` environment variable enables debug mode.

    Debug mode is off when the variable is unset or, ignoring case, is one of
    ``""``, ``"false"``, ``"no"`` or ``"0"``; any other value turns it on.
    """
    flag = os.getenv("JSONARGPARSE_DEBUG", "")
    return flag.lower() not in {"", "false", "no", "0"}


# At import time, debug mode also exports LOGGER_LEVEL=DEBUG to the environment.
if debug_mode_active():
    os.environ["LOGGER_LEVEL"] = "DEBUG"  # pragma: no cover
# base classes


class Action(LoggerProperty, argparse.Action):
    """Base for jsonargparse Action classes."""

    def _check_type_(self, value, **kwargs):
        """Calls ``self._check_type`` with only the keyword arguments it declares.

        The parameter names of ``self._check_type`` are read from its signature
        on first use and cached on the instance as ``_check_type_kwargs``.
        """
        if not hasattr(self, "_check_type_kwargs"):
            signature = inspect.signature(self._check_type)
            self._check_type_kwargs = set(signature.parameters)
        accepted = self._check_type_kwargs
        supported = {key: val for key, val in kwargs.items() if key in accepted}
        return self._check_type(value, **supported)