# Source code for jsonargparse._common

import dataclasses
import inspect
import sys
from contextlib import contextmanager
from contextvars import ContextVar
from typing import (  # type: ignore[attr-defined]
    Dict,
    Generic,
    List,
    Optional,
    Tuple,
    Type,
    TypeVar,
    Union,
    _GenericAlias,
)

from ._namespace import Namespace
from ._type_checking import ArgumentParser

ClassType = TypeVar("ClassType")

if sys.version_info >= (3, 8):
    from typing import Protocol

    class InstantiatorCallable(Protocol):
        """Structural type of a callable that builds an instance of ``class_type``."""

        def __call__(self, class_type: Type[ClassType], *args, **kwargs) -> ClassType:
            ...  # pragma: no cover

else:
    from typing import Callable

    # ``Protocol`` is unavailable before Python 3.8; fall back to a plain Callable.
    InstantiatorCallable = Callable[..., ClassType]


# Maps ``(class, include_subclasses)`` pairs to the instantiator used for them.
InstantiatorsDictType = Dict[Tuple[type, bool], InstantiatorCallable]


# Per-context parser state; ``parser_context`` can set these temporarily.
parent_parser: ContextVar["ArgumentParser"] = ContextVar("parent_parser")
parser_capture: ContextVar[bool] = ContextVar("parser_capture", default=False)
defaults_cache: ContextVar[Optional[Namespace]] = ContextVar("defaults_cache", default=None)
lenient_check: ContextVar[Union[bool, str]] = ContextVar("lenient_check", default=False)
load_value_mode: ContextVar[Optional[str]] = ContextVar("load_value_mode", default=None)
# Declared without a default: a bare ``.get()`` raises LookupError until it is set.
class_instantiators: ContextVar[Optional[InstantiatorsDictType]] = ContextVar("class_instantiators")
# NOTE(review): this default list is one object shared by every context that never
# sets ``nested_links``; mutating ``nested_links.get()`` in place would leak across
# contexts — confirm callers only ever replace the value.
nested_links: ContextVar[List[dict]] = ContextVar("nested_links", default=[])


# Lookup table used by ``parser_context``: keyword name -> context variable.
# Keys come from each variable's own ``name`` so the two cannot drift apart.
parser_context_vars = {
    context_var.name: context_var
    for context_var in (
        parent_parser,
        parser_capture,
        defaults_cache,
        lenient_check,
        load_value_mode,
        class_instantiators,
        nested_links,
    )
}


@contextmanager
def parser_context(**kwargs):
    """Temporarily set parser context variables for the body of a ``with`` block.

    Each keyword must be a key of ``parser_context_vars``. Its value is set on
    the corresponding ``ContextVar`` on entry, and the previous value is
    restored on exit, even if the body raises.

    Raises:
        KeyError: If a keyword is not a known parser context variable. In that
            case no context variable is modified.
    """
    # Resolve every name up front so an unknown keyword fails before any
    # variable has been changed (previously earlier ones leaked unreset).
    assignments = [(parser_context_vars[name], value) for name, value in kwargs.items()]
    context_var_tokens = []
    try:
        for context_var, value in assignments:
            context_var_tokens.append((context_var, context_var.set(value)))
        yield
    finally:
        # Unwind in reverse order of setting, like a stack.
        for context_var, token in reversed(context_var_tokens):
            context_var.reset(token)


def is_subclass(cls, class_or_tuple) -> bool:
    """Like ``issubclass`` but tolerant of non-class arguments.

    Returns False instead of raising when ``cls`` is not a class, or when
    ``issubclass`` rejects the arguments with a ``TypeError``.
    """
    try:
        if not inspect.isclass(cls):
            return False
        return issubclass(cls, class_or_tuple)
    except TypeError:
        return False


def is_final_class(cls) -> bool:
    """Checks whether a class is final, i.e. decorated with ``typing.final``.

    Detection relies on the ``__final__`` attribute, which ``typing.final``
    sets on Python 3.11+; objects without it are reported as not final.
    """
    return getattr(cls, "__final__", False)
def is_generic_class(cls) -> bool:
    """Check whether ``cls`` is a parametrized generic alias not created by ``typing``.

    Aliases produced by the ``typing`` module itself (e.g. ``List[int]``)
    report ``__module__ == "typing"`` and are therefore excluded.
    """
    return isinstance(cls, _GenericAlias) and getattr(cls, "__module__", "") != "typing"


def get_generic_origin(cls):
    """Return ``cls.__origin__`` for a generic class, otherwise ``cls`` unchanged."""
    return cls.__origin__ if is_generic_class(cls) else cls


def is_dataclass_like(cls) -> bool:
    """Check whether ``cls`` should be treated like a dataclass.

    A class is dataclass-like when it is:

    - a parametrized generic whose origin is dataclass-like,
    - final (see ``is_final_class``),
    - a class whose whole MRO, ignoring ``object`` and ``Generic``, consists
      of dataclasses,
    - a pydantic model, or
    - an attrs class, when attrs support is available.

    ``object`` and ``Generic`` themselves are not dataclass-like.
    """
    if is_generic_class(cls):
        return is_dataclass_like(cls.__origin__)
    if not inspect.isclass(cls):
        return False
    if is_final_class(cls):
        return True
    ignored_bases = {object, Generic}  # built once, not per MRO entry
    classes = [c for c in inspect.getmro(cls) if c not in ignored_bases]
    if not classes:
        # Only ``object``/``Generic`` were in the MRO; ``all([])`` would
        # otherwise wrongly report them as dataclass-like.
        return False
    all_dataclasses = all(dataclasses.is_dataclass(c) for c in classes)
    if not all_dataclasses:
        # Optional integrations are only imported and consulted when the
        # plain dataclass check fails.
        from ._optionals import attrs_support, is_pydantic_model

        if is_pydantic_model(cls):
            return True
        if attrs_support:
            import attrs

            if attrs.has(cls):
                return True
    return all_dataclasses


def default_class_instantiator(class_type: Type[ClassType], *args, **kwargs) -> ClassType:
    """Instantiate ``class_type`` by calling it with the given arguments."""
    return class_type(*args, **kwargs)


class ClassInstantiator:
    """Instantiator that dispatches to registered per-class instantiators.

    Args:
        instantiators: Mapping of ``(class, include_subclasses)`` to the
            instantiator to use. The first matching entry, in mapping order,
            wins; unmatched classes fall back to ``default_class_instantiator``.
    """

    def __init__(self, instantiators: InstantiatorsDictType) -> None:
        self.instantiators = instantiators

    def __call__(self, class_type: Type[ClassType], *args, **kwargs) -> ClassType:
        for (cls, subclasses), instantiator in self.instantiators.items():
            # An exact match always applies; subclass matches only when enabled.
            if class_type is cls or (subclasses and is_subclass(class_type, cls)):
                return instantiator(class_type, *args, **kwargs)
        return default_class_instantiator(class_type, *args, **kwargs)


def get_class_instantiator() -> InstantiatorCallable:
    """Return the class instantiator for the current parser context.

    Falls back to ``default_class_instantiator`` when ``class_instantiators``
    is unset or empty in the current context.
    """
    # ``class_instantiators`` is declared without a default, so a bare
    # ``get()`` would raise LookupError when it was never set in this context.
    instantiators = class_instantiators.get(None)
    if not instantiators:
        return default_class_instantiator
    return ClassInstantiator(instantiators)