"""Code related to optional dependencies."""

import inspect
import locale
import os
from contextlib import contextmanager, suppress
from importlib.util import find_spec
from subprocess import PIPE, Popen
from typing import Optional

# Public API of this module (the functions rendered in the API documentation).
__all__ = [
    "final",
    "get_config_read_mode",
    "set_config_read_mode",
    "set_docstring_parse_options",
]

def _is_installed(module_name: str) -> bool:
    """Return whether the top-level module ``module_name`` can be found, without importing it."""
    return find_spec(module_name) is not None


# Availability flags for each optional dependency, computed once at import time.
typing_extensions_support = _is_installed("typing_extensions")
typeshed_client_support = _is_installed("typeshed_client")
jsonschema_support = _is_installed("jsonschema")
jsonnet_support = _is_installed("_jsonnet")
url_support = _is_installed("requests")
docstring_parser_support = _is_installed("docstring_parser")
argcomplete_support = _is_installed("argcomplete")
fsspec_support = _is_installed("fsspec")
ruyaml_support = _is_installed("ruyaml")
omegaconf_support = _is_installed("omegaconf")
reconplogger_support = _is_installed("reconplogger")
attrs_support = _is_installed("attrs")

# Enabled config read mode flags. Starts as "fr"; set_config_read_mode() inserts
# "u" (URLs via requests) and/or "s" (fsspec) right after the "f".
_config_read_mode = "fr"

# Module-wide docstring parsing options. "style" stays None until first use, when
# get_docstring_parse_options() defaults it to docstring_parser.DocstringStyle.AUTO.
_docstring_parse_options = {
    "style": None,
    "attribute_docstrings": False,
}

def typing_extensions_import(name):
    """Get ``name`` from typing_extensions when installed, otherwise from typing.

    Args:
        name: Attribute name to look up, e.g. ``"final"``.

    Returns:
        The requested object, or ``False`` when the selected module does not
        provide it, so callers can simply test the result for truthiness.
    """
    # Previously the typing fallback sat after an unconditional return (the
    # ``else`` was missing), so None was returned without typing_extensions.
    module_name = "typing_extensions" if typing_extensions_support else "typing"
    return getattr(__import__(module_name), name, False)

def final(cls):
    """Decorator to make a class ``final``, i.e., it shouldn't be subclassed.

    It is the same as :py:meth:`typing.final` or an equivalent implementation
    depending on the python version and whether typing-extensions is installed.
    This implementation only sets the ``__final__`` marker attribute on the
    class and returns it unchanged; it does not prevent subclassing at runtime.
    """
    cls.__final__ = True
    return cls
def is_compatible_final(final) -> bool:
    """Return whether the decorator ``final`` marks classes with a truthy ``__final__``.

    The check is done on a throwaway class, so only implementations that set the
    ``__final__`` marker (as this module's own ``final`` does) are accepted.
    """

    @final
    class Probe:
        pass

    return getattr(Probe, "__final__", False)


# Prefer the typing / typing_extensions ``final`` when it is compatible, except when
# SPHINX_BUILD is set (presumably documentation builds, which should see the local one).
stdlib_final = typing_extensions_import("final")
if stdlib_final and is_compatible_final(stdlib_final) and "SPHINX_BUILD" not in os.environ:
    final = stdlib_final  # noqa: F811


def import_typeshed_client():
    """Return the typeshed_client module, or a stand-in when it is not installed.

    The stand-in is an ``argparse.Namespace`` whose ``ImportedInfo``, ``ModulePath``
    and ``Resolver`` attributes are all ``object``, presumably so references to
    those names do not fail when the package is absent.
    """
    if not typeshed_client_support:
        return __import__("argparse").Namespace(ImportedInfo=object, ModulePath=object, Resolver=object)
    import typeshed_client

    return typeshed_client


class UndefinedException(Exception):
    """Exception type with no extra behavior defined here."""


def get_jsonschema_exceptions():
    """Return the jsonschema exception types as a tuple (currently only ``ValidationError``)."""
    from jsonschema.exceptions import ValidationError

    return (ValidationError,)


@contextmanager
def missing_package_raise(package, importer):
    """Context manager that turns an ``ImportError`` into a more helpful one.

    The re-raised ``ImportError`` names the missing ``package`` and the
    ``importer`` that required it, and chains the original error as its cause.
    """
    try:
        yield None
    except ImportError as ex:
        message = f"{package} package is required by {importer} :: {ex}"
        raise ImportError(message) from ex


def _import_required(package, module_name, importer):
    """Import top-level module ``module_name``, reporting ``package`` as missing on failure."""
    with missing_package_raise(package, importer):
        return __import__(module_name)


def import_jsonschema(importer):
    """Return the jsonschema module together with its ``Draft7Validator`` class."""
    jsonschema = _import_required("jsonschema", "jsonschema", importer)
    return jsonschema, jsonschema.Draft7Validator


def import_jsonnet(importer):
    """Return the ``_jsonnet`` module (from the jsonnet package)."""
    return _import_required("jsonnet", "_jsonnet", importer)


def import_requests(importer):
    """Return the requests module."""
    return _import_required("requests", "requests", importer)


def import_docstring_parser(importer):
    """Return the docstring_parser module (from the docstring-parser package)."""
    return _import_required("docstring-parser", "docstring_parser", importer)


def import_argcomplete(importer):
    """Return the argcomplete module."""
    return _import_required("argcomplete", "argcomplete", importer)


def import_fsspec(importer):
    """Return the fsspec module."""
    return _import_required("fsspec", "fsspec", importer)


def import_ruyaml(importer):
    """Return the ruyaml module."""
    return _import_required("ruyaml", "ruyaml", importer)


def import_reconplogger(importer):
    """Return the reconplogger module."""
    return _import_required("reconplogger", "reconplogger", importer)
def set_config_read_mode(
    urls_enabled: bool = False,
    fsspec_enabled: bool = False,
):
    """Enables/disables optional config read modes.

    Enabling a mode first imports its backing package, so a missing package raises
    ``ImportError`` before the mode changes. Enabled modes are recorded as single
    letters inserted after ``"f"`` in the module's read-mode string: ``"u"`` for
    URLs and ``"s"`` for fsspec.

    Args:
        urls_enabled: Whether to read config files from URLs using requests package.
        fsspec_enabled: Whether to read config files from fsspec supported file systems.
    """
    global _config_read_mode
    modes = (
        ("u", urls_enabled, import_requests),
        ("s", fsspec_enabled, import_fsspec),
    )
    for flag, enabled, importer in modes:
        if not enabled:
            _config_read_mode = _config_read_mode.replace(flag, "")
            continue
        importer("set_config_read_mode")
        if flag not in _config_read_mode:
            _config_read_mode = _config_read_mode.replace("f", "f" + flag)
def get_config_read_mode() -> str:
    """Return the config read mode string currently in effect (initially ``"fr"``)."""
    current_mode = _config_read_mode
    return current_mode
def set_docstring_parse_options(style=None, attribute_docstrings: Optional[bool] = None):
    """Sets options for docstring parsing.

    Arguments left as ``None`` keep their current value. Both arguments are
    validated before anything is stored, so an invalid value leaves the current
    options untouched.

    Args:
        style (docstring_parser.DocstringStyle): The docstring style to expect.
        attribute_docstrings: Whether to parse attribute docstrings (slower).

    Raises:
        ValueError: If ``style`` is not a ``DocstringStyle`` or
            ``attribute_docstrings`` is not a boolean.
        ImportError: If the docstring-parser package is not installed.
    """
    dp = import_docstring_parser("set_docstring_parse_options")
    # Validate everything first: previously ``style`` was stored before
    # ``attribute_docstrings`` was checked, leaving a half-applied update on error.
    if style is not None and not isinstance(style, dp.DocstringStyle):
        raise ValueError(f"Expected style to be of type {dp.DocstringStyle}.")
    if attribute_docstrings is not None and not isinstance(attribute_docstrings, bool):
        raise ValueError("Expected attribute_docstrings to be boolean.")
    # The options dict is mutated in place, so no ``global`` statement is needed.
    if style is not None:
        _docstring_parse_options["style"] = style
    if attribute_docstrings is not None:
        _docstring_parse_options["attribute_docstrings"] = attribute_docstrings
def get_docstring_parse_options():
    """Return the module-wide docstring parse options dict.

    On first use, when no style has been set, ``style`` is defaulted to
    ``docstring_parser.DocstringStyle.AUTO``.
    """
    if _docstring_parse_options["style"] is None:
        dp = import_docstring_parser("get_docstring_parse_options")
        _docstring_parse_options["style"] = dp.DocstringStyle.AUTO
    return _docstring_parse_options


def parse_docstring(component, params=False, logger=None):
    """Parse the docstring of ``component`` with docstring-parser.

    When ``params`` is true and attribute docstrings are enabled, the whole object
    is parsed with ``parse_from_object``; otherwise only ``component.__doc__`` is.

    Returns:
        The parsed docstring, or ``None`` when parsing raises ``ValueError`` or
        ``ParseError`` (logged at debug level if a ``logger`` is given).
    """
    dp = import_docstring_parser("parse_docstring")
    options = get_docstring_parse_options()
    try:
        if params and options["attribute_docstrings"]:
            return dp.parse_from_object(component, style=options["style"])
        else:
            return dp.parse(component.__doc__, style=options["style"])
    except (ValueError, dp.ParseError) as ex:
        if logger:
            logger.debug(f"Failed parsing docstring for {component}: {ex}")
    return None


def parse_docs(component, parent, logger):
    """Map parameter names to their docstring descriptions.

    Docstrings of ``component`` are read, plus those of ``parent`` when
    ``component`` is the ``__init__`` of the class ``parent``; on name clashes the
    entry parsed later (the parent's) wins. Returns an empty dict when
    docstring-parser is not installed.
    """
    docs = {}
    if docstring_parser_support:
        doc_sources = [component]
        if inspect.isclass(parent) and component.__name__ == "__init__":
            doc_sources += [parent]
        for src in doc_sources:
            doc = parse_docstring(src, params=True, logger=logger)
            if doc:
                for param in doc.params:
                    docs[param.arg_name] = param.description
    return docs


def get_doc_short_description(function_or_class, method_name=None, logger=None):
    """Return the short description from a docstring, or ``None``.

    For a class without ``method_name``, the class docstring is tried first and
    its ``__init__`` docstring is the fallback; with ``method_name`` that method's
    docstring is used. Returns ``None`` when docstring-parser is not installed.
    """
    if docstring_parser_support:
        component = function_or_class
        if inspect.isclass(function_or_class):
            if not method_name:
                docstring = parse_docstring(function_or_class, params=False, logger=logger)
                if docstring and docstring.short_description:
                    return docstring.short_description
            component = getattr(function_or_class, method_name or "__init__")
        docstring = parse_docstring(component, params=False, logger=logger)
        if docstring:
            return docstring.short_description
    return None


def get_files_completer():
    """Return a new argcomplete ``FilesCompleter`` instance."""
    from argcomplete.completers import FilesCompleter

    return FilesCompleter()


class FilesCompleterMethod:
    """Completer method for Action classes that should complete files."""

    def completer(self, prefix, **kwargs):
        """Return the file completions for ``prefix``, sorted."""
        files_completer = get_files_completer()
        return sorted(files_completer(prefix, **kwargs))


def argcomplete_autocomplete(parser):
    """Run argcomplete on ``parser`` (if installed) within a parser context using its mode."""
    if argcomplete_support:
        argcomplete = import_argcomplete("argcomplete_autocomplete")
        from ._common import parser_context

        with parser_context(load_value_mode=parser.parser_mode):
            argcomplete.autocomplete(parser)


def argcomplete_namespace(caller, parser, namespace):
    """Prepare ``namespace`` for completion when called from argcomplete.

    When ``caller`` is ``"argcomplete"``, the namespace is turned into a
    jsonargparse ``Namespace``, merged over the parser defaults and flattened;
    otherwise it is returned unchanged.
    """
    if caller == "argcomplete":
        namespace.__class__ = __import__("jsonargparse").Namespace
        namespace = parser.merge_config(parser.get_defaults(skip_check=True), namespace).as_flat()
    return namespace


def argcomplete_warn_redraw_prompt(prefix, message):
    """Show ``message`` to the user as completion output.

    When ``prefix`` is non-empty, the message is also emitted with
    ``argcomplete.warn``. A best-effort SIGWINCH is then sent to the parent of the
    parent process (presumably the interactive shell) to redraw the prompt; any
    failure there is suppressed.

    Returns:
        Two completion candidates: the message with spaces replaced (by ``"\\xa0"``
        under a UTF-8 locale, else ``"_"``), and an empty string (presumably so
        the shell lists it instead of auto-inserting a single match).
    """
    import signal

    argcomplete = import_argcomplete("argcomplete_warn_redraw_prompt")
    if prefix != "":
        argcomplete.warn(message)
        with suppress(Exception):
            proc = Popen(f"ps -p {os.getppid()} -oppid=".split(), stdout=PIPE, stderr=PIPE)
            stdout, _ = proc.communicate()
            shell_pid = int(stdout.decode().strip())
            # SIGWINCH (window-size change; 28 on Linux and macOS). The hard-coded 28 was
            # unsafe on Windows, where os.kill with a non-console signal terminates the
            # target; there signal.SIGWINCH is absent and the AttributeError is suppressed.
            os.kill(shell_pid, signal.SIGWINCH)
    space = "_" if locale.getlocale()[1] != "UTF-8" else "\xa0"
    return [space + message.replace(" ", space), ""]


def get_omegaconf_loader():
    """Returns a yaml loader function based on OmegaConf which supports variable interpolation."""
    import io

    from ._loaders_dumpers import yaml_load

    with missing_package_raise("omegaconf", "get_omegaconf_loader"):
        from omegaconf import OmegaConf

    def omegaconf_load(value):
        """Load ``value``, resolving OmegaConf interpolation for non-scalar documents."""
        value_pyyaml = yaml_load(value)
        # Scalars (and null) need no interpolation; keep the plain yaml result.
        if isinstance(value_pyyaml, (str, int, float, bool)) or value_pyyaml is None:
            return value_pyyaml
        value_omegaconf = OmegaConf.to_object(OmegaConf.load(io.StringIO(value)))
        # If OmegaConf read the whole text as a single key with no value, keep the yaml result.
        str_ref = {k: None for k in [value]}
        return value_pyyaml if value_omegaconf == str_ref else value_omegaconf

    return omegaconf_load


# typing(_extensions) class of ``Annotated[...]`` aliases, or False if unavailable.
annotated_alias = typing_extensions_import("_AnnotatedAlias")


def is_annotated(typehint: type) -> bool:
    """Return whether ``typehint`` is an ``Annotated[...]`` alias (always a bool)."""
    return bool(annotated_alias) and isinstance(typehint, annotated_alias)


def get_pydantic_support() -> int:
    """Return the major version of the installed pydantic package, or 0 if not installed."""
    support = "0"
    if find_spec("pydantic"):
        try:
            from importlib.metadata import version

            support = version("pydantic")
        except ImportError:
            import pydantic

            support = pydantic.version.VERSION
    return int(support.split(".", 1)[0])


pydantic_support = get_pydantic_support()


def is_pydantic_model(class_type) -> int:
    """Return the pydantic major version a class derives its ``BaseModel`` from.

    Returns ``pydantic_support`` for a ``pydantic.BaseModel`` subclass, ``1`` for a
    ``pydantic.v1.BaseModel`` subclass (only checked with pydantic 2+), and ``0``
    otherwise, including for non-classes or when pydantic is not installed.
    """
    classes = inspect.getmro(class_type) if pydantic_support and inspect.isclass(class_type) else []
    for cls in classes:
        if getattr(cls, "__module__", "").startswith("pydantic") and getattr(cls, "__name__", "") == "BaseModel":
            import pydantic

            if issubclass(cls, pydantic.BaseModel):
                return pydantic_support
            elif pydantic_support > 1 and issubclass(cls, pydantic.v1.BaseModel):
                return 1
    return 0


def get_module(value):
    """Return the top-level package name of the type of ``value``."""
    return getattr(type(value), "__module__", "").split(".", 1)[0]


def is_annotated_validator(typehint: type) -> bool:
    """Return whether ``typehint`` is ``Annotated`` with pydantic or annotated_types metadata.

    Only considered with pydantic 2 or newer.
    """
    return (
        pydantic_support > 1
        and is_annotated(typehint)
        and any(get_module(m) in {"pydantic", "annotated_types"} for m in typehint.__metadata__)  # type: ignore
    )


def validate_annotated(value, typehint: type):
    """Validate ``value`` against ``typehint`` with a pydantic ``TypeAdapter`` and return the result."""
    from pydantic import TypeAdapter

    return TypeAdapter(typehint).validate_python(value)