"""Classes and functions related to namespace objects."""
import argparse
from contextlib import contextmanager
from typing import (
Any,
Callable,
Dict,
Iterator,
List,
Optional,
Set,
Tuple,
Union,
overload,
)
# Names exported by ``from <this module> import *``.
__all__ = ["Namespace", "namespace_to_dict", "dict_to_namespace", "strip_meta"]

# Leaf key names that hold configuration metadata rather than user values;
# ``is_meta_key`` checks the last component of a dotted key against this set.
meta_keys = {
    "__default_config__",
    "__path__",
    "__orig__",
}
def split_key(key: str) -> List[str]:
    """Splits a dotted key into every component, e.g. ``"a.b.c"`` -> ``["a", "b", "c"]``."""
    parts = key.split(sep=".")
    return parts
def split_key_root(key: str) -> List[str]:
    """Splits off the first component of a dotted key, e.g. ``"a.b.c"`` -> ``["a", "b.c"]``."""
    parts = key.split(sep=".", maxsplit=1)
    return parts
def split_key_leaf(key: str) -> List[str]:
    """Splits off the last component of a dotted key, e.g. ``"a.b.c"`` -> ``["a.b", "c"]``."""
    parts = key.rsplit(sep=".", maxsplit=1)
    return parts
def is_meta_key(key: str) -> bool:
    """Tells whether the last component of a (possibly dotted) key is a metadata key."""
    *_, leaf = split_key_leaf(key)
    return leaf in meta_keys
@overload
def strip_meta(cfg: "Namespace") -> "Namespace":
    ...  # pragma: no cover


@overload
def strip_meta(cfg: Dict[str, Any]) -> Dict[str, Any]:
    ...  # pragma: no cover


def strip_meta(cfg: Union["Namespace", Dict[str, Any]]) -> Union["Namespace", Dict[str, Any]]:
    """Removes all metadata keys from a configuration object.

    Args:
        cfg: The configuration object (Namespace or dict) to strip.

    Returns:
        A copy of ``cfg`` with every key listed in ``meta_keys`` dropped at all
        nesting levels. Empty (falsy) inputs are returned unchanged.
    """
    # The ``@overload`` stubs above only declare signatures. Without this
    # implementation the name would stay bound to the last stub, which raises
    # NotImplementedError when called.
    if cfg:
        cfg = recreate_branches(cfg, skip_keys=meta_keys)
    return cfg
def recreate_branches(data, skip_keys=None):
    """Returns a copy of ``data`` in which every branch container is new.

    Namespaces, dicts (including dict subclasses) and lists are rebuilt
    recursively. Any other value is treated as a leaf and reused as is.

    Args:
        data: The value to copy.
        skip_keys: Optional collection of keys to drop from every recreated
            namespace or dict, at all nesting levels.

    Returns:
        The recreated structure, using the same container types as the input.
    """
    if isinstance(data, dict):
        # Iterate the mapping itself. Dict subclasses can carry an instance
        # ``__dict__`` that does not hold their items, so ``__dict__`` must
        # only be used for Namespace objects.
        entries = data.items()
    elif isinstance(data, list):
        return [recreate_branches(item, skip_keys) for item in data]
    elif isinstance(data, Namespace):
        entries = vars(data).items()
    else:
        return data
    new_data = type(data)()
    for key, val in entries:
        if skip_keys is None or key not in skip_keys:
            new_data[key] = recreate_branches(val, skip_keys)
    return new_data
@contextmanager
def patch_namespace():
    """Context manager that temporarily rebinds ``argparse.Namespace``.

    While active, ``argparse.Namespace`` refers to this module's ``Namespace``.
    The previously bound class is put back on exit, including when the body
    raises.
    """
    saved_class = argparse.Namespace
    argparse.Namespace = Namespace
    try:
        yield
    finally:
        # Always undo the global patch of the argparse module.
        argparse.Namespace = saved_class
def namespace_to_dict(namespace: "Namespace") -> Dict[str, Any]:
    """Returns a copy of a nested namespace converted into a nested dictionary.

    The namespace is cloned before conversion, so branch containers in the
    result are not shared with the input.
    """
    cloned = namespace.clone()
    return cloned.as_dict()
def dict_to_namespace(cfg_dict: Union[Dict[str, Any], "Namespace"]) -> "Namespace":
    """Converts a nested dictionary into a nested namespace.

    The input is copied first, so it is never mutated. A dict is converted to a
    Namespace when all of its keys are strings. This applies to dicts that are
    direct values and to dicts that are direct elements of list values.
    """

    def has_only_str_keys(obj):
        # Only dicts with exclusively string keys can become Namespace objects.
        return isinstance(obj, dict) and all(isinstance(name, str) for name in obj)

    def to_namespace(mapping):
        for name, value in mapping.items():
            if has_only_str_keys(value):
                mapping[name] = to_namespace(value)
            elif isinstance(value, list):
                for index, element in enumerate(value):
                    if has_only_str_keys(element):
                        value[index] = to_namespace(element)
        return Namespace(**mapping)

    return to_namespace(recreate_branches(cfg_dict))
class Namespace(argparse.Namespace):
    """Extension of argparse's Namespace to support nesting and subscript access.

    Nested values are addressed with dotted keys, e.g. ``ns["a.b"]``. Key
    components that collide with attribute names of this class are stored with
    a clash mark prefix (see ``add_clash_mark`` / ``del_clash_mark``). Dict
    values are traversed by dotted keys for reading, writing, membership tests,
    deletion and popping.
    """

    def __init__(self, *args, **kwargs):
        """Initializer for Namespace objects.

        Instantiating a Namespace with initial values most commonly is done by
        providing keyword arguments, e.g. ``Namespace(name1=value1,
        name2=value2)``. Alternatively a single positional ``Namespace`` or
        ``dict`` object can be given.

        Raises:
            ValueError: If keyword and positional arguments are mixed, more than
                one positional argument is given, or the positional argument is
                neither a Namespace nor a dict.
        """
        if len(args) == 0:
            super().__init__(**kwargs)
        else:
            if len(kwargs) != 0 or len(args) != 1 or not isinstance(args[0], (argparse.Namespace, dict)):
                raise ValueError("Expected a single positional parameter of type Namespace or dict.")
            # isinstance (not an exact type check) so that dict subclasses
            # contribute their items and not their instance __dict__.
            source = args[0].items() if isinstance(args[0], dict) else vars(args[0]).items()
            for key, val in source:
                self[key] = val

    def _parse_key(self, key: str) -> Tuple[str, Optional["Namespace"], str]:
        """Parses a key for the nested namespace.

        Args:
            key: The key that is being parsed.

        Returns:
            Tuple with three elements corresponding to:
            - The leaf key (clash-marked if needed).
            - The parent namespace object (a dict when the parent branch is a
              dict), or None if the parent branch is missing or holds a value
              that is neither a Namespace nor a dict.
            - The parent namespace key.

        Raises:
            KeyError: When given invalid key.
        """
        if " " in key:
            raise KeyError(f'Spaces not allowed in keys: "{key}".')
        key_split = split_key(key)
        if any(k == "" for k in key_split):
            raise KeyError(f'Empty nested key: "{key}".')
        key_split = [add_clash_mark(k) for k in key_split]
        leaf_key = key_split[-1]
        parent_ns: Namespace = self
        parent_key = ""
        if len(key_split) > 1:
            parent_key = ".".join(key_split[:-1])
            for subkey in key_split[:-1]:
                if isinstance(parent_ns, dict):
                    # Dict branches are looked up by key only. hasattr would
                    # wrongly match dict methods such as "copy" or "clear".
                    if subkey not in parent_ns:
                        return leaf_key, None, parent_key
                    parent_ns = parent_ns[subkey]
                elif hasattr(parent_ns, subkey):
                    parent_ns = getattr(parent_ns, subkey)
                else:
                    return leaf_key, None, parent_key
                if parent_ns is not None and not isinstance(parent_ns, (Namespace, dict)):
                    return leaf_key, None, parent_key
        return leaf_key, parent_ns, parent_key

    def _parse_required_key(self, key: str) -> Tuple[str, "Namespace", str]:
        """Same as _parse_key but raises KeyError if key not found."""
        leaf_key, parent_ns, parent_key = self._parse_key(key)
        if parent_ns is None:
            found = False
        elif isinstance(parent_ns, dict):
            # Membership for dict branches is by key, not by attribute.
            found = leaf_key in parent_ns
        else:
            found = hasattr(parent_ns, leaf_key)
        if not found:
            raise KeyError(f'Key "{key}" not found in namespace.')
        return leaf_key, parent_ns, parent_key

    def _create_nested_namespace(self, key: str) -> "Namespace":
        """Creates a nested namespace object.

        Any component along ``key`` that is not already a Namespace is replaced
        by a new empty Namespace.

        Args:
            key: The key where the nested namespace is created.

        Returns:
            The created nested namespace.
        """
        parent_ns = self
        for subkey in split_key(key):
            # NOTE(review): an intermediate dict branch is also replaced here,
            # which discards its contents.
            if not isinstance(getattr(parent_ns, subkey, None), Namespace):
                setattr(parent_ns, subkey, Namespace())
            parent_ns = getattr(parent_ns, subkey)
        return parent_ns

    def __setattr__(self, name: str, value: Any) -> None:
        """Sets an attribute to a possibly nested namespace."""
        if "." in name:
            self.__setitem__(name, value)
        else:
            super().__setattr__(add_clash_mark(name), value)

    def __setitem__(self, key: str, item: Any) -> None:
        """Sets an item to a possibly nested namespace."""
        leaf_key, parent_ns, parent_key = self._parse_key(key)
        if parent_ns is None:
            parent_ns = self._create_nested_namespace(parent_key)
        if isinstance(parent_ns, dict):
            parent_ns[leaf_key] = item
        else:
            setattr(parent_ns, leaf_key, item)

    def __getitem__(self, key: str) -> Any:
        """Gets an item from a possibly nested namespace.

        Raises:
            KeyError: If the key is not set.
        """
        leaf_key, parent_ns, _ = self._parse_required_key(key)
        if isinstance(parent_ns, dict):
            return parent_ns[leaf_key]
        return getattr(parent_ns, leaf_key)

    def __delitem__(self, key: str) -> None:
        """Deletes an item from a possibly nested namespace.

        Raises:
            KeyError: If the key is not set.
        """
        # The required variant makes a missing nested parent raise KeyError
        # instead of an AttributeError on None.
        leaf_key, parent_ns, _ = self._parse_required_key(key)
        if isinstance(parent_ns, dict):
            del parent_ns[leaf_key]
        else:
            del parent_ns.__dict__[leaf_key]

    def __contains__(self, key: str) -> bool:
        """Checks if an item is set possibly in a nested namespace."""
        if not isinstance(key, str):
            return False
        try:
            leaf_key, parent_ns, _ = self._parse_required_key(key)
        except KeyError:
            return False
        if isinstance(parent_ns, dict):
            return leaf_key in parent_ns
        return leaf_key in parent_ns.__dict__

    def __bool__(self) -> bool:
        """Returns False if namespace is empty, otherwise True."""
        return bool(self.__dict__)

    def as_dict(self) -> Dict[str, Any]:
        """Converts the nested namespaces into nested dictionaries.

        Dicts and lists whose elements are all Namespace objects are converted
        as well. Clash marks are removed from the keys.
        """
        dic = {}
        for key, val in vars(self).items():
            if isinstance(val, Namespace):
                val = val.as_dict()
            elif isinstance(val, dict) and val != {} and all(isinstance(v, Namespace) for v in val.values()):
                val = {k: v.as_dict() for k, v in val.items()}
            elif isinstance(val, list) and val != [] and all(isinstance(v, Namespace) for v in val):
                val = [v.as_dict() for v in val]
            dic[del_clash_mark(key)] = val
        return dic

    def as_flat(self) -> argparse.Namespace:
        """Converts the nested namespaces into a single argparse flat namespace."""
        flat = argparse.Namespace()
        for key, val in self.items():
            setattr(flat, key, val)
        return flat

    def items(self, branches: bool = False) -> Iterator[Tuple[str, Any]]:
        """Returns a generator of all leaf (key, value) items, optionally including branches."""
        for key, val in vars(self).items():
            key = del_clash_mark(key)
            if isinstance(val, Namespace):
                if branches:
                    yield key, val
                for subkey, subval in val.items(branches):
                    yield key + "." + del_clash_mark(subkey), subval
            else:
                yield key, val

    def keys(self, branches: bool = False) -> Iterator[str]:
        """Returns a generator of all leaf keys, optionally including branches."""
        for key, _ in self.items(branches):
            yield key

    def values(self, branches: bool = False) -> Iterator[Any]:
        """Returns a generator of all leaf values, optionally including branches."""
        for _, val in self.items(branches):
            yield val

    def get_sorted_keys(self, branches: bool = True, key_filter: Callable = is_meta_key) -> List[str]:
        """Returns a list of keys sorted by descending depth.

        Args:
            branches: Whether to include branch keys instead of only leaves.
            key_filter: Function that selects keys to exclude.
        """
        keys = [k for k in self.keys() if not key_filter(k)]
        if branches:
            # Set mirror of ``keys`` for O(1) membership tests. The list keeps
            # the insertion order that the stable sort below relies on.
            seen = set(keys)
            for key in [k for k in keys if "." in k]:
                key_split = split_key(key)
                for num in range(len(key_split) - 1):
                    parent_key = ".".join(key_split[: num + 1])
                    if parent_key not in seen:
                        seen.add(parent_key)
                        keys.append(parent_key)
        keys.sort(key=lambda x: -len(split_key(x)))
        return keys

    def clone(self) -> "Namespace":
        """Creates a new identical nested namespace."""
        return recreate_branches(self)

    def update(
        self, value: Union["Namespace", Any], key: Optional[str] = None, only_unset: bool = False
    ) -> "Namespace":
        """Sets or replaces all items from the given nested namespace.

        Args:
            value: A namespace to update multiple values or other type to set in a single key.
            key: Branch key where to set the value. Required if value is not namespace.
            only_unset: Whether to only set the value if not set in namespace.

        Returns:
            This namespace, to allow chaining.

        Raises:
            KeyError: If ``value`` is not a Namespace and no ``key`` is given.
        """
        if not isinstance(value, Namespace):
            if not key:
                raise KeyError("Key is required if value not a Namespace.")
            if not only_unset or key not in self:
                self[key] = value
        else:
            prefix = key + "." if key else ""
            for subkey, val in value.items():
                if not only_unset or prefix + subkey not in self:
                    self[prefix + subkey] = val
        return self

    def get(self, key: str, default: Any = None) -> Any:
        """Returns the value of ``key``, or ``default`` if it is not set or invalid."""
        try:
            return self[key]
        except (KeyError, TypeError):
            return default

    def get_value_and_parent(self, key: str) -> Tuple[Any, "Namespace", str]:
        """Returns the value of ``key``, its parent branch and its leaf key.

        Raises:
            KeyError: If the key is not set.
        """
        leaf_key, parent_ns, _ = self._parse_required_key(key)
        return parent_ns[leaf_key], parent_ns, leaf_key

    def pop(self, key: str, default: Any = None) -> Any:
        """Removes ``key`` and returns its value, or ``default`` if it is not set."""
        leaf_key, parent_ns, _ = self._parse_key(key)
        if not parent_ns:
            return default
        if isinstance(parent_ns, dict):
            return parent_ns.pop(leaf_key, default)
        return parent_ns.__dict__.pop(leaf_key, default)
# Every attribute name visible on the Namespace class (methods included). User
# keys equal to one of these are stored with ``clash_mark`` prepended so that
# they do not shadow class attributes.
clash_names: Set[str] = {*dir(Namespace)}
# Zero-width space used as the prefix for clashing keys.
clash_mark = "\u200B"
def add_clash_mark(key: str) -> str:
    """Prefixes ``key`` with ``clash_mark`` when it equals a Namespace attribute name."""
    return clash_mark + key if key in clash_names else key
def del_clash_mark(key: str) -> str:
    """Removes a leading ``clash_mark`` from ``key``, if present.

    An empty key is returned unchanged instead of raising IndexError.
    """
    # startswith is safe for empty strings, unlike indexing ``key[0]``.
    if key.startswith(clash_mark):
        return key[len(clash_mark):]
    return key
# Temporary, kept for backward compatibility with pytorch-lightning: registers a
# representer so that yaml.SafeDumper emits Namespace objects as plain mappings.
import yaml  # noqa: E402

yaml.SafeDumper.add_representer(
    Namespace,
    lambda dumper, ns: dumper.represent_mapping("tag:yaml.org,2002:map", ns.as_dict()),
)