"""Code related to loading and dumping."""
import inspect
import re
from typing import Any, Callable, Dict, Tuple, Type
import yaml
from ._common import load_value_mode, parent_parser
from .optionals import import_jsonnet, omegaconf_support
from .type_checking import ArgumentParser
__all__ = [
'set_loader',
'set_dumper',
]
regex_curly_comma = re.compile(' *[{},] *')
class DefaultLoader(getattr(yaml, 'CSafeLoader', yaml.SafeLoader)):  # type: ignore
    """Safe YAML loader used by ``yaml_load``.

    Derives from PyYAML's libyaml-backed ``CSafeLoader`` when the installed
    ``yaml`` module provides it, and from the pure-Python ``SafeLoader``
    otherwise. Being a dedicated subclass lets this module adjust its implicit
    resolvers without modifying the base loader class.
    """
# https://stackoverflow.com/a/37958106/2732151
def remove_implicit_resolver(cls, tag_to_remove):
    """Remove every implicit resolver registered for ``tag_to_remove`` on ``cls``.

    When ``cls`` only inherits ``yaml_implicit_resolvers`` (it is not in the
    class's own ``__dict__``), a shallow copy of the mapping is stored on
    ``cls`` first, so the parent class's mapping object is left untouched.

    Args:
        cls: PyYAML loader/resolver class whose ``yaml_implicit_resolvers``
            mapping (first character -> list of ``(tag, regexp)`` pairs) is
            edited.
        tag_to_remove: Tag whose resolvers are dropped, e.g.
            ``'tag:yaml.org,2002:float'``.
    """
    if 'yaml_implicit_resolvers' not in vars(cls):
        cls.yaml_implicit_resolvers = cls.yaml_implicit_resolvers.copy()
    resolvers = cls.yaml_implicit_resolvers
    for first_char in list(resolvers):
        # Assign a fresh list rather than filtering in place: after the shallow
        # copy above the lists are still shared with the parent class.
        resolvers[first_char] = [
            (tag, pattern)
            for tag, pattern in resolvers[first_char]
            if tag != tag_to_remove
        ]
# Do not implicitly resolve timestamps: scalars such as 2020-01-01 are not
# converted to date/datetime objects by DefaultLoader.
remove_implicit_resolver(DefaultLoader, 'tag:yaml.org,2002:timestamp')
# Drop the stock float resolver so it can be replaced by the one below.
remove_implicit_resolver(DefaultLoader, 'tag:yaml.org,2002:float')
# Replacement float resolver. Alternatives of the verbose (re.X) pattern:
#   1. digits '.' [digits] with an optional exponent whose sign is optional
#      (e.g. 1.5, 1., 2.5e-3, 1.0e3)
#   2. digits with a mandatory exponent and no dot (e.g. 1e3, -2E+5)
#   3. '.' digits, optional exponent that *requires* a sign (e.g. .5, .5e+3);
#      no leading sign is accepted in this alternative
#   4. sexagesimal numbers (e.g. 190:20:30.15)
#   5. optionally signed .inf / .Inf / .INF
#   6. .nan / .NaN / .NAN
# Compared with PyYAML's stock YAML 1.1 float resolver, alternatives 1 and 2
# additionally accept unsigned exponents and dot-less forms like 1e3, which
# the stock resolver leaves as strings.
DefaultLoader.add_implicit_resolver(
'tag:yaml.org,2002:float',
re.compile('''^(?:
[-+]?(?:[0-9][0-9_]*)\\.[0-9_]*(?:[eE][-+]?[0-9]+)?
|[-+]?(?:[0-9][0-9_]*)(?:[eE][-+]?[0-9]+)
|\\.[0-9_]+(?:[eE][-+][0-9]+)?
|[-+]?[0-9][0-9_]*(?::[0-5]?[0-9])+\\.[0-9_]*
|[-+]?\\.(?:inf|Inf|INF)
|\\.(?:nan|NaN|NAN))$''', re.X),
# First characters of a scalar for which this resolver is tried.
list('-+0123456789.'),
)
def yaml_load(stream):
    """Load a value from a string using ``DefaultLoader``.

    A few inputs are returned unchanged (the original string) instead of the
    parsed result:

    * ``stream`` that is just ``-`` after stripping whitespace (not parsed).
    * A single ``key:`` whose parsed form is ``{key: None}``, e.g. ``"a:"`` or
      ``"1:"``.
    * Text whose brace/comma-separated tokens are exactly the keys of a parsed
      mapping with only ``None`` values, e.g. ``"{a, b}"``.

    Args:
        stream: The string to load.

    Returns:
        The loaded value, or ``stream`` itself in the cases listed above.

    Raises:
        yaml.YAMLError: If ``stream`` is not valid YAML.
    """
    stripped = stream.strip()
    if stripped == '-':
        value = stream
    else:
        # DefaultLoader derives from a *safe* loader, so yaml.load here does
        # not construct arbitrary Python objects.
        value = yaml.load(stream, Loader=DefaultLoader)
    if isinstance(value, dict) and value and all(v is None for v in value.values()):
        # YAML keys need not be strings ("1:" loads as {1: None}, "null:" as
        # {None: None}); convert with str() so the comparison cannot raise
        # TypeError the way `key + ':'` did for non-str keys.
        if len(value) == 1 and stripped == str(next(iter(value))) + ':':
            value = stream
        else:
            keys = {k for k in regex_curly_comma.split(stream) if k}
            if keys and keys == set(value):
                value = stream
    return value
def jsonnet_load(stream, path='', ext_vars=None):
    """Load a value from jsonnet source text.

    The text is evaluated with the jsonnet module and the resulting JSON is
    parsed with ``yaml_load``. If jsonnet evaluation raises ``RuntimeError``,
    the text is parsed directly with ``yaml_load`` instead; if that also fails
    with a YAML error, a ``ValueError`` carrying the jsonnet error message is
    raised.

    Args:
        stream: The jsonnet (or plain YAML/JSON) text to load.
        path: Passed to ``evaluate_snippet`` as the snippet's filename.
        ext_vars: External variables, split by
            ``ActionJsonnet.split_ext_vars`` into jsonnet's ``ext_vars`` and
            ``ext_codes``.

    Returns:
        The loaded value.

    Raises:
        ValueError: If the text is neither valid jsonnet nor valid YAML.
    """
    from .jsonnet import ActionJsonnet
    plain_vars, code_vars = ActionJsonnet.split_ext_vars(ext_vars)
    jsonnet_module = import_jsonnet('jsonnet_load')
    try:
        evaluated = jsonnet_module.evaluate_snippet(
            path, stream, ext_vars=plain_vars, ext_codes=code_vars,
        )
    except RuntimeError as jsonnet_error:
        # Evaluation failed: fall back to reading the text as YAML/JSON.
        try:
            return yaml_load(stream)
        except pyyaml_exceptions:
            raise ValueError(str(jsonnet_error)) from jsonnet_error
    else:
        return yaml_load(evaluated)
# Loader function for each parser mode; extended via set_loader().
loaders: Dict[str, Callable] = dict(
    yaml=yaml_load,
    jsonnet=jsonnet_load,
)

# Exception types that signal a failed parse, per loader mode.
pyyaml_exceptions = (yaml.YAMLError,)
# jsonnet_load re-raises unparsable input as ValueError.
jsonnet_exceptions = (*pyyaml_exceptions, ValueError)
loader_exceptions: Dict[str, Tuple[Type[Exception], ...]] = dict(
    yaml=pyyaml_exceptions,
    jsonnet=jsonnet_exceptions,
)
def get_load_value_mode() -> str:
    """Return the mode whose loader should parse values.

    The value held by ``load_value_mode`` takes precedence; when it is
    ``None``, the ``parser_mode`` of the parser held by ``parent_parser`` is
    used.
    """
    explicit_mode = load_value_mode.get()
    if explicit_mode is not None:
        return explicit_mode
    return parent_parser.get().parser_mode
def get_loader_exceptions():
    """Return the exception types that signal a failed parse in the current load mode."""
    current_mode = get_load_value_mode()
    return loader_exceptions[current_mode]
def load_value(value: str, simple_types: bool = False, **kwargs):
    """Parse a string with the loader registered for the current load mode.

    Args:
        value: The string to parse.
        simple_types: If False, results of type int, float, bool or str are
            discarded and ``value`` is returned unchanged.
        **kwargs: Extra options; only those whose name matches a parameter of
            the loader (other than its first one) are forwarded to it.

    Returns:
        The loaded value, or ``value`` itself for scalar results when
        ``simple_types`` is False.
    """
    loader = loaders[get_load_value_mode()]
    if kwargs:
        # The first parameter receives ``value`` positionally, so skip it.
        accepted = set([*inspect.signature(loader).parameters][1:])
        kwargs = {name: arg for name, arg in kwargs.items() if name in accepted}
    result = loader(value, **kwargs)
    keep_raw_string = not simple_types and isinstance(result, (int, float, bool, str))
    return value if keep_raw_string else result
# Keyword arguments given to yaml.safe_dump by yaml_dump: block style,
# non-ASCII characters emitted as-is, keys kept in insertion order.
dump_yaml_kwargs = dict(
    default_flow_style=False,
    allow_unicode=True,
    sort_keys=False,
)
# Keyword arguments given to json.dumps by the JSON dumpers: non-ASCII
# characters emitted as-is, keys kept in insertion order.
dump_json_kwargs = dict(
    ensure_ascii=False,
    sort_keys=False,
)
def yaml_dump(data):
    """Serialize ``data`` to a YAML string with ``yaml.safe_dump`` and ``dump_yaml_kwargs``."""
    yaml_text = yaml.safe_dump(data, **dump_yaml_kwargs)
    return yaml_text
def yaml_comments_dump(data, parser):
    """Dump ``data`` as YAML and let the parser's help formatter add comments.

    The YAML text comes from the dumper currently registered as ``'yaml'``, so
    a custom one installed with ``set_dumper`` is honored.
    """
    yaml_text = dumpers['yaml'](data)
    help_formatter = parser.formatter_class(parser.prog)
    return help_formatter.add_yaml_comments(yaml_text)
def json_dump(data):
    """Serialize ``data`` as compact JSON (no spaces after ',' or ':')."""
    import json
    compact_separators = (',', ':')
    return json.dumps(data, separators=compact_separators, **dump_json_kwargs)
def json_indented_dump(data):
    """Serialize ``data`` as JSON indented by two spaces, ending with a newline."""
    import json
    text = json.dumps(data, indent=2, **dump_json_kwargs)
    return f'{text}\n'
# Dumper function for each output format name; extended via set_dumper().
# 'yaml_comments' is called with (data, parser), all others with (data,).
dumpers: Dict[str, Callable] = dict(
    yaml=yaml_dump,
    yaml_comments=yaml_comments_dump,
    json=json_dump,
    json_indented=json_indented_dump,
    jsonnet=json_indented_dump,
)

# Line-comment prefix used for dump headers; formats without an entry
# (the JSON ones) get no header.
comment_prefix: Dict[str, str] = dict(
    yaml='# ',
    yaml_comments='# ',
    jsonnet='// ',
)
def check_valid_dump_format(dump_format: str):
    """Raise ``ValueError`` unless ``dump_format`` is ``'parser_mode'`` or a registered dumper name."""
    valid_formats = {'parser_mode', *dumpers}
    if dump_format not in valid_formats:
        raise ValueError(f'Unknown output format "{dump_format}".')
def dump_using_format(parser: 'ArgumentParser', data: dict, dump_format: str) -> str:
    """Serialize ``data`` with the dumper registered for ``dump_format``.

    ``'parser_mode'`` resolves to the parser's ``parser_mode`` when a dumper
    exists for it, and to ``'yaml'`` otherwise. The ``'yaml_comments'`` dumper
    additionally receives the parser. When ``parser.dump_header`` is set and
    the format has a comment prefix, each header line is prefixed and the
    header is prepended to the output.
    """
    if dump_format == 'parser_mode':
        dump_format = parser.parser_mode if parser.parser_mode in dumpers else 'yaml'
    dumper = dumpers[dump_format]
    if dump_format == 'yaml_comments':
        dump = dumper(data, parser)
    else:
        dump = dumper(data)
    if parser.dump_header and comment_prefix.get(dump_format):
        prefix = comment_prefix[dump_format]
        header_lines = [prefix + line for line in parser.dump_header]
        header = '\n'.join(header_lines)
        dump = f'{header}\n{dump}'
    return dump
def set_loader(mode: str, loader_fn: Callable[[str], Any], exceptions: Tuple[Type[Exception], ...] = pyyaml_exceptions):
    """Sets the value loader function to be used when parsing with a certain mode.

    The ``loader_fn`` function must accept as input a single str type parameter
    and return any of the basic types {str, bool, int, float, list, dict, None}.
    If this function is not based on PyYAML, then for things to work correctly
    the exception types that it can raise when parsing a value fails should be
    provided.

    Extra keyword arguments given to ``load_value`` are forwarded to
    ``loader_fn`` only when it declares a parameter (after the first) with the
    same name. Setting a mode that already has a loader replaces it.

    Args:
        mode: The parser mode for which to set its loader function. Example: "yaml".
        loader_fn: The loader function to set. Example: ``yaml.safe_load``.
        exceptions: Exceptions that the loader can raise when load fails. Example: (yaml.parser.ParserError, yaml.scanner.ScannerError).
    """
    loaders[mode] = loader_fn
    loader_exceptions[mode] = exceptions
def set_dumper(format_name: str, dumper_fn: Callable[[Any], str]):
    """Sets the dumping function for a given format name.

    ``dumper_fn`` receives the data to serialize and must return a string.
    Registering a name that already exists replaces the previous dumper. Note
    that the ``'yaml_comments'`` format is called with ``(data, parser)``.

    Args:
        format_name: Name to use for dumping with this function. Example: "yaml_custom".
        dumper_fn: The dumper function to set. Example: ``yaml.safe_dump``.
    """
    dumpers[format_name] = dumper_fn
def set_omegaconf_loader():
    """Register the ``'omegaconf'`` loader mode when omegaconf support is available.

    Does nothing when support is unavailable or an ``'omegaconf'`` loader is
    already registered. The loader is registered with ``set_loader``'s default
    (PyYAML) exception types.
    """
    if not omegaconf_support or 'omegaconf' in loaders:
        return
    from .optionals import get_omegaconf_loader
    set_loader('omegaconf', get_omegaconf_loader())