Source code for jsonargparse.loaders_dumpers

"""Code related to loading and dumping."""

import inspect
import json
import re
import yaml
from contextlib import contextmanager
from contextvars import ContextVar
from typing import Any, Callable, Dict, Tuple, Type

from .optionals import dump_preserve_order_support, import_jsonnet
from .type_checking import ArgumentParser


# Public API of this module: only the two registration helpers are exported.
__all__ = [
    'set_loader',
    'set_dumper',
]


# Context variable holding the name of the active loader mode. It is set by
# load_value_context and read by load_value and get_loader_exceptions. It has
# no default, so reading it outside of a load_value_context raises LookupError.
load_value_mode: ContextVar = ContextVar('load_value_mode')
# Splits a string on curly braces and commas (with any surrounding spaces).
# yaml_load uses it to recover the key names of inputs such as ``{a, b}``.
regex_curly_comma = re.compile(' *[{},] *')


@contextmanager
def load_value_context(mode):
    """Temporarily make ``mode`` the active load-value mode.

    The previous state of ``load_value_mode`` is restored when the ``with``
    block exits, including when the block raises.

    Args:
        mode: Key of the loader to activate, as used in ``loaders``.
    """
    token = load_value_mode.set(mode)
    try:
        yield
    finally:
        # Undo exactly the ``set`` above, whatever happened inside the block.
        load_value_mode.reset(token)


class DefaultLoader(getattr(yaml, 'CSafeLoader', yaml.SafeLoader)):  # type: ignore
    """YAML loader class used by ``yaml_load``.

    Derives from ``yaml.CSafeLoader`` when PyYAML exposes it and from
    ``yaml.SafeLoader`` otherwise. The implicit resolvers of this subclass are
    adjusted further down in this module.
    """


# https://stackoverflow.com/a/37958106/2732151
def remove_implicit_resolver(cls, tag_to_remove):
    """Remove every implicit resolver registered for ``tag_to_remove`` on ``cls``.

    Args:
        cls: Class exposing a ``yaml_implicit_resolvers`` mapping from a first
            character to a list of ``(tag, regexp)`` pairs.
        tag_to_remove: Tag whose resolver entries are dropped.
    """
    # When the mapping is only inherited, give ``cls`` its own shallow copy so
    # that the base class keeps its resolvers untouched.
    if 'yaml_implicit_resolvers' not in vars(cls):
        cls.yaml_implicit_resolvers = cls.yaml_implicit_resolvers.copy()

    resolvers = cls.yaml_implicit_resolvers
    for first_char in list(resolvers):
        kept = []
        for tag, regexp in resolvers[first_char]:
            if tag != tag_to_remove:
                kept.append((tag, regexp))
        # Rebind the list instead of mutating it: after a shallow copy the
        # original list object is still shared with the base class.
        resolvers[first_char] = kept


# Drop the implicit resolvers for the timestamp and float tags from
# DefaultLoader (presumably so that date-like scalars stay plain strings and
# floats are recognized only by the pattern registered below).
remove_implicit_resolver(DefaultLoader, 'tag:yaml.org,2002:timestamp')
remove_implicit_resolver(DefaultLoader, 'tag:yaml.org,2002:float')


# Register a replacement float resolver. In addition to the decimal-point
# forms, the second alternative accepts an integer mantissa followed by an
# exponent (e.g. ``1e3``), and the exponent sign is optional in the first two
# alternatives (e.g. ``1.5e3``).
# NOTE(review): the third alternative (a dot followed by ``[0-9_]+``) also
# matches a dot followed only by underscores, e.g. ``._`` -- confirm that the
# float constructor copes with such input.
DefaultLoader.add_implicit_resolver(
    'tag:yaml.org,2002:float',
    re.compile('''^(?:
     [-+]?(?:[0-9][0-9_]*)\\.[0-9_]*(?:[eE][-+]?[0-9]+)?
    |[-+]?(?:[0-9][0-9_]*)(?:[eE][-+]?[0-9]+)
    |\\.[0-9_]+(?:[eE][-+][0-9]+)?
    |[-+]?[0-9][0-9_]*(?::[0-5]?[0-9])+\\.[0-9_]*
    |[-+]?\\.(?:inf|Inf|INF)
    |\\.(?:nan|NaN|NAN))$''', re.X),
    list('-+0123456789.'),
)


def yaml_load(stream):
    """Load a YAML string with ``DefaultLoader``, keeping some inputs as text.

    Two kinds of input are returned unchanged as the original string:

    * a lone ``-`` (ignoring surrounding whitespace), and
    * input that parses to a mapping whose values are all ``None`` and whose
      keys are exactly the pieces obtained by splitting the input on braces
      and commas, e.g. ``{a, b}``.

    Args:
        stream: The string to parse.

    Returns:
        The parsed value, or ``stream`` itself in the cases described above.
    """
    if stream.strip() == '-':
        return stream

    value = yaml.load(stream, Loader=DefaultLoader)
    if not isinstance(value, dict):
        return value
    if any(item is not None for item in value.values()):
        return value

    # Every value is None: check whether the mapping is just the brace/comma
    # separated names from the input, in which case the raw text is returned.
    names = set(filter(None, regex_curly_comma.split(stream)))
    if names and names == set(value):
        return stream
    return value


def jsonnet_load(stream, path='', ext_vars=None):
    """Evaluate ``stream`` as a jsonnet snippet and load the result.

    The output of the jsonnet evaluation is passed through ``yaml_load``. If
    the evaluation raises ``RuntimeError``, the raw ``stream`` is tried with
    ``yaml_load`` instead.

    Args:
        stream: The jsonnet source text.
        path: Passed as the first argument of ``evaluate_snippet``.
        ext_vars: External variables, split into vars and codes by
            ``ActionJsonnet.split_ext_vars``.

    Returns:
        The loaded value.

    Raises:
        ValueError: If jsonnet evaluation fails and the fallback YAML parsing
            raises one of ``pyyaml_exceptions``. The message is that of the
            jsonnet error.
    """
    from .jsonnet import ActionJsonnet

    ext_vars, ext_codes = ActionJsonnet.split_ext_vars(ext_vars)
    _jsonnet = import_jsonnet('jsonnet_load')
    try:
        evaluated = _jsonnet.evaluate_snippet(
            path,
            stream,
            ext_vars=ext_vars,
            ext_codes=ext_codes,
        )
    except RuntimeError as ex:
        # Not valid jsonnet: fall back to treating the raw input as YAML.
        try:
            fallback = yaml_load(stream)
        except pyyaml_exceptions:
            raise ValueError(str(ex)) from ex
        return fallback
    else:
        return yaml_load(evaluated)


# Registry of value-loading functions keyed by parser mode. load_value looks
# up the active mode here; set_loader adds or replaces entries.
loaders: Dict[str, Callable] = {
    'yaml': yaml_load,
    'jsonnet': jsonnet_load,
}

# PyYAML parser and scanner error types, used as the failure signal of
# PyYAML-based loaders.
pyyaml_exceptions = (yaml.parser.ParserError, yaml.scanner.ScannerError)
# jsonnet_load raises ValueError when both jsonnet evaluation and the YAML
# fallback fail, hence the extra entry.
jsonnet_exceptions = pyyaml_exceptions + (ValueError,)

# Exception types that indicate a failed load, keyed by parser mode. Read by
# get_loader_exceptions and updated together with ``loaders`` by set_loader.
loader_exceptions: Dict[str, Tuple[Type[Exception], ...]] = {
    'yaml': pyyaml_exceptions,
    'jsonnet': jsonnet_exceptions,
}


def get_loader_exceptions():
    """Return the exception types registered for the active load-value mode."""
    active_mode = load_value_mode.get()
    return loader_exceptions[active_mode]


def load_value(value: str, **kwargs):
    """Load ``value`` with the loader registered for the active mode.

    Args:
        value: The string to load.
        **kwargs: Optional extra arguments. Only those whose names match a
            parameter of the loader, other than its first one, are forwarded.

    Returns:
        Whatever the selected loader returns.
    """
    loader = loaders[load_value_mode.get()]
    if not kwargs:
        return loader(value)

    # The first parameter receives ``value``; only the remaining names are
    # eligible to be filled from kwargs.
    accepted = set(list(inspect.signature(loader).parameters)[1:])
    forwarded = {name: arg for name, arg in kwargs.items() if name in accepted}
    return loader(value, **forwarded)


# Keyword arguments that yaml_dump forwards to ``yaml.safe_dump``. Keys are
# sorted only when ``dump_preserve_order_support`` is false.
dump_yaml_kwargs = {
    'default_flow_style': False,
    'allow_unicode': True,
    'sort_keys': not dump_preserve_order_support,
}

# Keyword arguments that json_dump and json_indented_dump forward to
# ``json.dumps``, with the same key-sorting rule as above.
dump_json_kwargs = {
    'ensure_ascii': False,
    'sort_keys': not dump_preserve_order_support,
}


def yaml_dump(data):
    """Serialize ``data`` to a YAML string using the module's YAML dump options."""
    options = dump_yaml_kwargs
    return yaml.safe_dump(data, **options)


def yaml_comments_dump(data, parser):
    """Dump ``data`` as YAML and let the parser's formatter add comments.

    Args:
        data: The object to serialize.
        parser: Object providing ``formatter_class`` and ``prog``. The
            formatter instance must offer ``add_yaml_comments``.

    Returns:
        The result of ``add_yaml_comments`` applied to the plain YAML dump.
    """
    plain_yaml = dumpers['yaml'](data)
    return parser.formatter_class(parser.prog).add_yaml_comments(plain_yaml)


def json_dump(data):
    """Serialize ``data`` to compact JSON, with no spaces after separators."""
    compact_separators = (',', ':')
    return json.dumps(data, separators=compact_separators, **dump_json_kwargs)


def json_indented_dump(data):
    """Serialize ``data`` to JSON indented by two spaces, ending in a newline."""
    text = json.dumps(data, indent=2, **dump_json_kwargs)
    return f'{text}\n'


# Registry of dump functions keyed by format name; set_dumper adds or replaces
# entries. The 'jsonnet' format reuses the indented JSON dumper.
dumpers: Dict[str, Callable] = {
    'yaml': yaml_dump,
    'yaml_comments': yaml_comments_dump,
    'json': json_dump,
    'json_indented': json_indented_dump,
    'jsonnet': json_indented_dump,
}

# Line-comment prefix per dump format, used by dump_using_format to write the
# header. Formats without an entry here get no header.
comment_prefix: Dict[str,str] = {
    'yaml': '# ',
    'yaml_comments': '# ',
    'jsonnet': '// ',
}


def check_valid_dump_format(dump_format: str):
    """Validate the name of a dump format.

    Args:
        dump_format: Either ``'parser_mode'`` or a key of ``dumpers``.

    Raises:
        ValueError: If ``dump_format`` is neither of those.
    """
    if dump_format == 'parser_mode' or dump_format in dumpers:
        return
    raise ValueError(f'Unknown output format "{dump_format}".')


def dump_using_format(parser: 'ArgumentParser', data: dict, dump_format: str) -> str:
    """Dump ``data`` in the requested format, with the parser's header if any.

    Args:
        parser: Parser providing ``parser_mode`` and ``dump_header``. It is
            also handed to the dumper when the format is ``'yaml_comments'``.
        data: The content to serialize.
        dump_format: A key of ``dumpers``, or ``'parser_mode'`` to use the
            parser's own mode (falling back to ``'yaml'`` if that mode has no
            dumper).

    Returns:
        The serialized text. When the parser has a dump header and the format
        has a comment prefix, the header lines come first, each prefixed.
    """
    if dump_format == 'parser_mode':
        dump_format = parser.parser_mode if parser.parser_mode in dumpers else 'yaml'

    dumper = dumpers[dump_format]
    # Only the commented-YAML dumper takes the parser as a second argument.
    if dump_format == 'yaml_comments':
        dump = dumper(data, parser)
    else:
        dump = dumper(data)

    if not (parser.dump_header and comment_prefix.get(dump_format)):
        return dump

    prefix = comment_prefix[dump_format]
    header_lines = [prefix + line for line in parser.dump_header]
    header = '\n'.join(header_lines)
    return f'{header}\n{dump}'


def set_loader(
    mode: str,
    loader_fn: Callable[[str], Any],
    exceptions: Tuple[Type[Exception], ...] = pyyaml_exceptions,
):
    """Sets the value loader function to be used when parsing with a certain mode.

    The ``loader_fn`` function must accept as input a single str type parameter
    and return any of the basic types {str, bool, int, float, list, dict, None}.
    If this function is not based on PyYAML, then for things to work correctly
    the exception types that can be raised when parsing a value fails should be
    provided.

    Args:
        mode: The parser mode for which to set its loader function. Example: "yaml".
        loader_fn: The loader function to set. Example: ``yaml.safe_load``.
        exceptions: Exceptions that the loader can raise when load fails.
            Example: (yaml.parser.ParserError, yaml.scanner.ScannerError).
    """
    # Keep both registries in step so get_loader_exceptions matches the loader.
    loaders[mode] = loader_fn
    loader_exceptions[mode] = exceptions


def set_dumper(format_name: str, dumper_fn: Callable[[Any], str]):
    """Sets the dumping function for a given format name.

    Args:
        format_name: Name to use for dumping with this function. Example: "yaml_custom".
        dumper_fn: The dumper function to set. Example: ``yaml.safe_dump``.
    """
    dumpers[format_name] = dumper_fn