# Source code for jsonargparse._loaders_dumpers

"""Code related to loading and dumping."""

import inspect
import re
from typing import Any, Callable, Dict, Tuple, Type

import yaml

from ._common import load_value_mode, parent_parser
from ._optionals import import_jsonnet, omegaconf_support
from ._type_checking import ArgumentParser

# Names exported as the public API of this module.
__all__ = [
    "set_loader",
    "set_dumper",
]


class DefaultLoader(getattr(yaml, "CSafeLoader", yaml.SafeLoader)):  # type: ignore[misc]
    """YAML loader class used by ``yaml_load``.

    Derives from ``yaml.CSafeLoader`` when PyYAML exposes it, otherwise from
    ``yaml.SafeLoader``. Having a dedicated subclass allows its implicit
    resolvers to be adjusted on this class rather than on the PyYAML one.
    """

    pass


# https://stackoverflow.com/a/37958106/2732151
def remove_implicit_resolver(cls, tag_to_remove):
    """Drop every implicit resolver registered for ``tag_to_remove`` on ``cls``.

    When ``cls`` only inherits its ``yaml_implicit_resolvers`` table, a copy
    is first stored on ``cls`` itself so that the parent class table is left
    untouched.

    Args:
        cls: Class holding a ``yaml_implicit_resolvers`` mapping from first
            character to a list of ``(tag, regexp)`` pairs.
        tag_to_remove: Tag whose resolvers are discarded.
    """
    if "yaml_implicit_resolvers" not in vars(cls):
        cls.yaml_implicit_resolvers = cls.yaml_implicit_resolvers.copy()

    resolvers = cls.yaml_implicit_resolvers
    # Snapshot the keys; each entry is replaced by a freshly filtered list.
    for first_letter in list(resolvers):
        kept = []
        for tag, regexp in resolvers[first_letter]:
            if tag != tag_to_remove:
                kept.append((tag, regexp))
        resolvers[first_letter] = kept


# Remove the implicit resolvers for the timestamp and float tags from
# DefaultLoader. The float one is registered again below with a custom pattern.
remove_implicit_resolver(DefaultLoader, "tag:yaml.org,2002:timestamp")
remove_implicit_resolver(DefaultLoader, "tag:yaml.org,2002:float")


# Custom float resolver. Among its alternatives, the second one accepts a
# number written with an exponent but without a decimal point (such as "1e3").
# The last argument lists the first characters for which the pattern is tried.
DefaultLoader.add_implicit_resolver(
    "tag:yaml.org,2002:float",
    re.compile(
        """^(?:
     [-+]?(?:[0-9][0-9_]*)\\.[0-9_]*(?:[eE][-+]?[0-9]+)?
    |[-+]?(?:[0-9][0-9_]*)(?:[eE][-+]?[0-9]+)
    |\\.[0-9_]+(?:[eE][-+][0-9]+)?
    |[-+]?[0-9][0-9_]*(?::[0-5]?[0-9])+\\.[0-9_]*
    |[-+]?\\.(?:inf|Inf|INF)
    |\\.(?:nan|NaN|NAN))$""",
        re.X,
    ),
    list("-+0123456789."),
)


def yaml_load(stream):
    """Load a YAML string, keeping certain ambiguous inputs as plain strings.

    The text is parsed with ``DefaultLoader`` (a subclass of PyYAML's safe
    loader). The original ``stream`` is returned unchanged, instead of the
    parsed value, when:

    - the stripped text is a single ``-``;
    - the text is just ``key:`` and was parsed as ``{key: None}``;
    - the text is a comma separated list of names, optionally inside braces,
      and was parsed as a dict of exactly those names with all values None.

    Args:
        stream: The text to load.

    Returns:
        The parsed value, or ``stream`` itself in the cases listed above.

    Raises:
        yaml.YAMLError: If the text is not valid YAML.
    """
    stripped = stream.strip()
    if stripped == "-":
        # A lone dash is kept verbatim rather than handed to the YAML parser.
        return stream

    value = yaml.load(stream, Loader=DefaultLoader)
    if not (isinstance(value, dict) and value and all(v is None for v in value.values())):
        return value

    # From here on: a non-empty dict whose values are all None.
    only_key = next(iter(value)) if len(value) == 1 else None
    # Keys can be non-string (e.g. int, bool or None); only a string key can be
    # concatenated with ":" and compared against the original text.
    if isinstance(only_key, str) and stripped == only_key + ":":
        return stream

    keys = set(stream.strip(" {}").replace(" ", "").split(","))
    if keys == set(value):
        return stream
    return value


def jsonnet_load(stream, path="", ext_vars=None):
    """Evaluate a jsonnet snippet and load its output, with a YAML fallback.

    Args:
        stream: The jsonnet text to evaluate.
        path: Passed as first argument to jsonnet's ``evaluate_snippet``.
        ext_vars: External variables, split into variables and codes by
            ``ActionJsonnet.split_ext_vars``.

    Returns:
        The result of ``yaml_load`` applied to the jsonnet output or, if the
        jsonnet evaluation fails with ``RuntimeError``, applied to ``stream``.

    Raises:
        ValueError: If jsonnet evaluation fails and loading ``stream`` as YAML
            also fails. The message is the one of the jsonnet error.
    """
    from ._jsonnet import ActionJsonnet

    ext_vars, ext_codes = ActionJsonnet.split_ext_vars(ext_vars)
    _jsonnet = import_jsonnet("jsonnet_load")
    try:
        evaluated = _jsonnet.evaluate_snippet(path, stream, ext_vars=ext_vars, ext_codes=ext_codes)
    except RuntimeError as jsonnet_error:
        # Not valid jsonnet: try to interpret the raw text as YAML instead.
        try:
            fallback = yaml_load(stream)
        except pyyaml_exceptions:
            raise ValueError(str(jsonnet_error)) from jsonnet_error
        return fallback
    else:
        return yaml_load(evaluated)


# Registry of value loader functions, keyed by parser mode.
loaders: Dict[str, Callable] = {
    "yaml": yaml_load,
    "jsonnet": jsonnet_load,
}

# Exception types of the built-in loaders. The jsonnet one adds ValueError,
# which jsonnet_load raises when both jsonnet and YAML loading fail.
pyyaml_exceptions = (yaml.YAMLError,)
jsonnet_exceptions = pyyaml_exceptions + (ValueError,)

# Exception types associated with each parser mode; read by
# get_loader_exceptions().
loader_exceptions: Dict[str, Tuple[Type[Exception], ...]] = {
    "yaml": pyyaml_exceptions,
    "jsonnet": jsonnet_exceptions,
}


def get_load_value_mode() -> str:
    """Return the mode to use for loading values.

    The value given by ``load_value_mode.get()`` is used unless it is None, in
    which case the ``parser_mode`` of ``parent_parser.get()`` is returned.
    """
    current = load_value_mode.get()
    if current is not None:
        return current
    return parent_parser.get().parser_mode


def get_loader_exceptions():
    """Return the exception types registered for the current load value mode."""
    active_mode = get_load_value_mode()
    return loader_exceptions[active_mode]


def load_value(value: str, simple_types: bool = False, **kwargs):
    """Load a value with the loader of the current load value mode.

    Args:
        value: The text to load.
        simple_types: If False, a result of type int, float, bool or str is
            discarded and ``value`` is returned as given.
        **kwargs: Extra arguments for the loader. Only those whose name matches
            a loader parameter after the first one are passed on.

    Returns:
        The loaded value, or ``value`` itself as described for ``simple_types``.
    """
    loader = loaders[get_load_value_mode()]
    if kwargs:
        # Skip the first parameter, which receives the text to load.
        accepted = tuple(inspect.signature(loader).parameters)[1:]
        kwargs = {name: arg for name, arg in kwargs.items() if name in accepted}
    result = loader(value, **kwargs)
    if simple_types or not isinstance(result, (int, float, bool, str)):
        return result
    return value


# Keyword arguments given to yaml.safe_dump by yaml_dump().
dump_yaml_kwargs = {
    "default_flow_style": False,
    "allow_unicode": True,
    "sort_keys": False,
}

# Keyword arguments given to json.dumps by json_dump() and json_indented_dump().
dump_json_kwargs = {
    "ensure_ascii": False,
    "sort_keys": False,
}


def yaml_dump(data):
    """Serialize ``data`` to a YAML string using the options in ``dump_yaml_kwargs``."""
    serialized = yaml.safe_dump(data, **dump_yaml_kwargs)
    return serialized


def yaml_comments_dump(data, parser):
    """Dump ``data`` as YAML and let the parser's formatter add comments to it.

    Args:
        data: The object to dump with the dumper registered as "yaml".
        parser: Parser whose ``formatter_class`` is instantiated with
            ``parser.prog`` and whose ``add_yaml_comments`` processes the dump.

    Returns:
        The value returned by the formatter's ``add_yaml_comments``.
    """
    plain_yaml = dumpers["yaml"](data)
    return parser.formatter_class(parser.prog).add_yaml_comments(plain_yaml)


def json_dump(data):
    """Serialize ``data`` to a compact JSON string without spaces after separators."""
    from json import dumps

    compact_separators = (",", ":")
    return dumps(data, separators=compact_separators, **dump_json_kwargs)


def json_indented_dump(data):
    """Serialize ``data`` to JSON indented by two spaces, ending with a newline."""
    from json import dumps

    body = dumps(data, indent=2, **dump_json_kwargs)
    return f"{body}\n"


# Registry of dump functions, keyed by format name. The "jsonnet" format is
# dumped as indented JSON.
dumpers: Dict[str, Callable] = {
    "yaml": yaml_dump,
    "yaml_comments": yaml_comments_dump,
    "json": json_dump,
    "json_indented": json_indented_dump,
    "jsonnet": json_indented_dump,
}

# Line comment prefix for the formats in which dump_using_format() can write a
# header. Formats without an entry here get no header.
comment_prefix: Dict[str, str] = {
    "yaml": "# ",
    "yaml_comments": "# ",
    "jsonnet": "// ",
}


def check_valid_dump_format(dump_format: str):
    """Check that ``dump_format`` is "parser_mode" or a registered dumper name.

    Args:
        dump_format: The format name to check.

    Raises:
        ValueError: If the format is not known.
    """
    if dump_format == "parser_mode" or dump_format in dumpers:
        return
    raise ValueError(f'Unknown output format "{dump_format}".')


def dump_using_format(parser: "ArgumentParser", data: dict, dump_format: str) -> str:
    """Dump ``data`` in the requested format, with the parser's header if any.

    Args:
        parser: Parser providing ``parser_mode`` and ``dump_header``.
        data: The object to dump.
        dump_format: A key of ``dumpers``, or "parser_mode" to use the parser's
            mode, falling back to "yaml" if that mode has no dumper.

    Returns:
        The dumped text. If ``parser.dump_header`` is set and the format has a
        comment prefix, the header lines are prepended as comments.
    """
    if dump_format == "parser_mode":
        dump_format = "yaml" if parser.parser_mode not in dumpers else parser.parser_mode

    dumper = dumpers[dump_format]
    # Only the "yaml_comments" dumper takes the parser as second argument.
    dump = dumper(data, parser) if dump_format == "yaml_comments" else dumper(data)

    if not parser.dump_header:
        return dump
    prefix = comment_prefix.get(dump_format)
    if not prefix:
        return dump

    header = "\n".join(prefix + line for line in parser.dump_header)
    return f"{header}\n{dump}"


def set_loader(mode: str, loader_fn: Callable[[str], Any], exceptions: Tuple[Type[Exception], ...] = pyyaml_exceptions):
    """Sets the value loader function to be used when parsing with a certain mode.

    The ``loader_fn`` function must accept as input a single str type parameter
    and return any of the basic types {str, bool, int, float, list, dict, None}.
    If this function is not based on PyYAML, then for things to work correctly
    the exception types that can be raised when parsing a value fails should
    be provided.

    Args:
        mode: The parser mode for which to set its loader function. Example: "yaml".
        loader_fn: The loader function to set. Example: ``yaml.safe_load``.
        exceptions: Exceptions that the loader can raise when load fails.
            Example: (yaml.parser.ParserError, yaml.scanner.ScannerError).
    """
    loaders[mode] = loader_fn
    loader_exceptions[mode] = exceptions


def set_dumper(format_name: str, dumper_fn: Callable[[Any], str]):
    """Sets the dumping function for a given format name.

    Args:
        format_name: Name to use for dumping with this function. Example: "yaml_custom".
        dumper_fn: The dumper function to set. Example: ``yaml.safe_dump``.
    """
    dumpers[format_name] = dumper_fn


def set_omegaconf_loader():
    """Register the "omegaconf" loader if supported and not yet registered.

    Does nothing when ``omegaconf_support`` is false or when ``loaders``
    already has an "omegaconf" entry.
    """
    if omegaconf_support and "omegaconf" not in loaders:
        # Imported here so that it is only needed when omegaconf is supported.
        from ._optionals import get_omegaconf_loader

        set_loader("omegaconf", get_omegaconf_loader())