"""Collection of types and type generators."""
import inspect
import operator
import os
import pathlib
import re
from enum import Enum
from typing import Any, Callable, Dict, List, Optional, Pattern, Tuple, Type, Union
from ._common import is_final_class, is_subclass
from ._optionals import final, pydantic_support
from ._util import Path, get_import_path, get_private_kwargs, import_object
# Public names of this module: re-exported helpers, the type factories, the
# registration entry point and the predefined types created later in this file.
__all__ = [
    "final",
    "is_final_class",
    "register_type",
    "restricted_number_type",
    "restricted_string_type",
    "path_type",
    "PositiveInt",
    "NonNegativeInt",
    "PositiveFloat",
    "NonNegativeFloat",
    "ClosedUnitInterval",
    "OpenUnitInterval",
    "NotEmptyStr",
    "Email",
    "Path_fr",
    "Path_fc",
    "Path_dw",
    "Path_dc",
    "Path_drw",
]
# Comparison functions mapped to the symbol used to write them in a restriction.
_operators1 = {
    operator.gt: ">",
    operator.ge: ">=",
    operator.lt: "<",
    operator.le: "<=",
    operator.eq: "==",
    operator.ne: "!=",
}
# Inverse lookup: comparison symbol -> comparison function.
_operators2 = {v: k for k, v in _operators1.items()}

# Type classes that were registered with a uniqueness key, indexed by that key.
registered_types: Dict[tuple, type] = {}
# Handler (serializer, deserializer, type check) for every registered type class.
registered_type_handlers: Dict[type, "RegisteredType"] = {}
# Deferred registrations indexed by import path; calling the value performs the registration.
registration_pending: Dict[str, Callable] = {}
def create_type(
    name: str,
    base_type: type,
    check_value: Callable,
    register_key: Optional[Tuple] = None,
    docstring: Optional[str] = None,
    extra_attrs: Optional[dict] = None,
) -> type:
    """Builds a validated subclass of ``base_type`` or returns the one already registered.

    Args:
        name: Name given to the type class.
        base_type: Class the new type derives from.
        check_value: Called as ``check_value(cls, v)`` before an instance is created.
        register_key: Key under which the type is looked up and registered.
        docstring: Stored as ``__doc__`` of the type class.
        extra_attrs: Additional class attributes set on the validating base class.

    Returns:
        The type registered for ``register_key`` if there is one, otherwise the new type.

    Raises:
        ValueError: If ``register_key`` is already registered under a different name.
    """
    existing = registered_types.get(register_key)
    if existing is not None:
        if existing.__name__ == name:
            return existing
        raise ValueError(f"Same type already registered with a different name: {existing.__name__}.")

    class TypeCore:
        _check_value = check_value

        def __new__(cls, v):
            # Validate first so that an invalid value never produces an instance.
            cls._check_value(cls, v)
            return super().__new__(cls, v)

    for attr_name, attr_value in (extra_attrs or {}).items():
        setattr(TypeCore, attr_name, attr_value)

    # TypeCore goes first in the bases so that its __new__ runs before base_type's.
    new_type = type(name, (TypeCore, base_type), {"__doc__": docstring})
    add_type(new_type, register_key)
    return new_type
def restricted_number_type(
    name: Optional[str],
    base_type: type,
    restrictions: Union[Tuple, List[Tuple]],
    join: str = "and",
    docstring: Optional[str] = None,
) -> type:
    """Creates or returns an already registered restricted number type class.

    Args:
        name: Name for the type or None for an automatic name.
        base_type: One of {int, float}.
        restrictions: Tuples of pairs (comparison, reference), e.g. ('>', 0).
        join: How to combine multiple comparisons, one of {'or', 'and'}.
        docstring: Docstring for the type class.

    Returns:
        The created or retrieved type class.

    Raises:
        ValueError: If base_type, join or restrictions do not have one of the expected values.
    """
    if base_type not in {int, float}:
        raise ValueError("Expected base_type to be one of {int, float}.")
    if join not in {"or", "and"}:
        raise ValueError("Expected join to be one of {'or', 'and'}.")

    def is_valid_restriction(restriction) -> bool:
        if not isinstance(restriction, tuple) or len(restriction) != 2:
            return False
        comparison, ref = restriction
        try:
            return bool(comparison in _operators2 and ref == base_type(ref))
        except (TypeError, ValueError, OverflowError):
            # Unhashable comparison or a reference that base_type cannot convert
            # (e.g. None, "a", inf): report it as an invalid restriction below.
            return False

    restrictions = [restrictions] if isinstance(restrictions, tuple) else restrictions
    if not isinstance(restrictions, list) or not all(is_valid_restriction(x) for x in restrictions):
        raise ValueError(
            "Expected restrictions to be a list of tuples each with a comparison operator "
            f"(> >= < <= == !=) and a reference value of type {base_type.__name__}."
        )

    # The key is built from the sorted symbolic restrictions so that their order does not matter.
    register_key = (tuple(sorted(restrictions)), base_type, join)

    restrictions = [(_operators2[x[0]], x[1]) for x in restrictions]
    expression = (" " + join + " ").join(["v" + _operators1[op] + str(ref) for op, ref in restrictions])

    if name is None:
        # Automatic name, e.g. int_gt_0 or float_ge_0_and_le_1.
        name = base_type.__name__
        for num, (comparison, ref) in enumerate(restrictions):
            name += "_" + join + "_" if num > 0 else "_"
            name += comparison.__name__ + str(ref).replace(".", "")

    extra_attrs = {
        "_restrictions": restrictions,
        "_expression": expression,
        "_join": join,
        "_type": base_type,
    }

    def check_value(cls, v):
        # bool is a subclass of int, so it must be rejected explicitly.
        if isinstance(v, bool):
            raise ValueError(f"{v} not a number")
        if cls._type == int and isinstance(v, float) and not float.is_integer(v):
            raise ValueError(f"{v} not an integer")
        vv = cls._type(v)
        check = [comparison(vv, ref) for comparison, ref in cls._restrictions]
        if (cls._join == "and" and not all(check)) or (cls._join == "or" and not any(check)):
            raise ValueError(f"{v} does not conform to restriction {cls._expression}")

    return create_type(
        name=name,
        base_type=base_type,
        check_value=check_value,
        register_key=register_key,
        docstring=docstring,
        extra_attrs=extra_attrs,
    )
def restricted_string_type(
    name: str,
    regex: Union[str, Pattern],
    docstring: Optional[str] = None,
) -> type:
    """Creates or returns an already registered restricted string type class.

    Args:
        name: Name for the type.
        regex: Regular expression that the string must match.
        docstring: Docstring for the type class.

    Returns:
        The created or retrieved type class.
    """
    if isinstance(regex, str):
        regex = re.compile(regex)
    expression = "matching " + regex.pattern

    extra_attrs = {
        "_regex": regex,
        "_expression": expression,
        "_type": str,
    }

    def check_value(cls, v):
        # Pattern.match only anchors at the start; the pattern must use $ to constrain the end.
        if not cls._regex.match(v):
            raise ValueError(f"{v} does not match regular expression {cls._regex.pattern}")

    return create_type(
        name=name,
        base_type=str,
        check_value=check_value,
        register_key=(expression, str),
        docstring=docstring,
        extra_attrs=extra_attrs,
    )
def _is_path_type(value, type_class):
    """Type check for path types: true for any ``Path`` instance; ``type_class`` is not used."""
    return isinstance(value, Path)
def path_type(mode: str, docstring: Optional[str] = None, **kwargs) -> type:
    """Creates or returns an already registered path type class.

    Args:
        mode: The required type and access permissions among [fdrwxcuFDRWX].
        docstring: Docstring for the type class.
        **kwargs: The private option ``skip_check`` (default False) is extracted from here with
            ``get_private_kwargs``. When true, a deprecation helper is called and both the type
            name and the registry key get a "skip_check" suffix.

    Returns:
        The created or retrieved type class.
    """
    Path._check_mode(mode)
    name = "Path_" + mode
    # The key uses the sorted mode characters, so e.g. "fr" and "rf" map to the same registered type.
    key_name = "path " + "".join(sorted(mode))

    skip_check = get_private_kwargs(kwargs, skip_check=False)
    if skip_check:
        from ._deprecated import path_skip_check_deprecation

        path_skip_check_deprecation()
        name += "_skip_check"
        key_name += " skip_check"

    register_key = (key_name, str)
    if register_key in registered_types:
        return registered_types[register_key]

    class PathType(Path):
        _expression = name
        _mode = mode
        _skip_check = skip_check
        _type = str

        def __init__(self, v, **k):
            super().__init__(v, mode=self._mode, skip_check=self._skip_check, **k)

    restricted_type = type(name, (PathType,), {"__doc__": docstring})
    add_type(restricted_type, register_key, type_check=_is_path_type)
    return restricted_type
class RegisteredType:
    """Handler that knows how to serialize, deserialize and recognize values of one type."""

    def __init__(
        self,
        type_class: Any,
        serializer: Callable,
        deserializer: Optional[Callable],
        deserializer_exceptions: Union[Type[Exception], Tuple[Type[Exception], ...]],
        type_check: Callable,
    ):
        """Stores the handler functions.

        Args:
            type_class: The type this handler is for.
            serializer: Function that converts an instance to a basic type.
            deserializer: Function that converts a basic type to an instance, None to call type_class.
            deserializer_exceptions: Exceptions from the deserializer that mean "value is invalid".
            type_check: Called as ``type_check(value, type_class)`` to test whether a value is of the type.
        """
        self.type_class = type_class
        self.serializer = serializer
        self.base_deserializer = type_class if deserializer is None else deserializer
        self.deserializer_exceptions = deserializer_exceptions
        self.type_check = type_check

    def __eq__(self, other):
        """Handlers are equal when type, serializer and deserializer are equal."""
        if not isinstance(other, RegisteredType):
            # Let Python fall back to its default comparison instead of failing on missing attributes.
            return NotImplemented
        return all(getattr(self, k) == getattr(other, k) for k in ("type_class", "serializer", "base_deserializer"))

    def is_value_of_type(self, value):
        """Returns the result of the type check for the given value."""
        return self.type_check(value, self.type_class)

    def deserializer(self, value):
        """Deserializes a value.

        Raises:
            ValueError: If the base deserializer raises one of ``deserializer_exceptions``.
        """
        try:
            return self.base_deserializer(value)
        except self.deserializer_exceptions as ex:
            # type_class is annotated as Any, so it may lack __name__; fall back to str().
            type_class_name = getattr(self.type_class, "__name__", str(self.type_class))
            raise ValueError(f"Not of type {type_class_name}: {ex}") from ex
def register_type(
    type_class: Any,
    serializer: Callable = str,
    deserializer: Optional[Callable] = None,
    deserializer_exceptions: Union[Type[Exception], Tuple[Type[Exception], ...]] = (
        ValueError,
        TypeError,
        AttributeError,
    ),
    type_check: Callable = lambda v, t: v.__class__ == t,
    fail_already_registered: bool = True,
    uniqueness_key: Optional[Tuple] = None,
) -> None:
    """Registers a new type for use in jsonargparse parsers.

    Args:
        type_class: The type object to be registered.
        serializer: Function that converts an instance of the class to a basic type.
        deserializer: Function that converts a basic type to an instance of the class. Default instantiates type_class.
        deserializer_exceptions: Exceptions that deserializer raises when it fails.
        type_check: Function to check if a value is of type_class. Gets as arguments the value and type_class.
        fail_already_registered: Whether to fail if type has already been registered.
        uniqueness_key: Key to determine uniqueness of type.

    Raises:
        ValueError: If the type is already registered with a different serializer and/or
            deserializer, no uniqueness_key is given and failing is enabled.
    """
    type_handler = RegisteredType(type_class, serializer, deserializer, deserializer_exceptions, type_check)
    # A module-level ``_fail_already_registered``, when defined, takes precedence over the argument.
    fail_already_registered = globals().get("_fail_already_registered", fail_already_registered)
    if not uniqueness_key and fail_already_registered:
        existing_handler = get_registered_type(type_class)
        if existing_handler:
            if type_handler == existing_handler:
                # Registering the same handler again is a no-op.
                return
            raise ValueError(f'Type "{type_class}" already registered with different serializer and/or deserializer.')
    registered_type_handlers[type_class] = type_handler
    if uniqueness_key is not None:
        registered_types[uniqueness_key] = type_class
def register_type_on_first_use(import_path: str, *args, **kwargs):
    """Schedules a ``register_type`` call for the object at ``import_path``.

    The object is only imported when the stored callable is invoked. Extra positional and
    keyword arguments are forwarded to ``register_type`` after the imported type.
    """

    def register_now():
        return register_type(import_object(import_path), *args, **kwargs)

    registration_pending[import_path] = register_now
def get_registered_type(type_class) -> Optional[RegisteredType]:
    """Returns the handler registered for ``type_class`` or None.

    If the type has no handler yet but a registration is pending for its import path, the
    pending registration is executed first.
    """
    if type_class in registered_type_handlers:
        return registered_type_handlers[type_class]

    from contextlib import suppress

    # AttributeError/ValueError raised while resolving the import path or while running the
    # pending registration are ignored; the lookup then simply yields None.
    with suppress(AttributeError, ValueError):
        deferred = registration_pending.pop(get_import_path(type_class), None)
        if deferred is not None:
            deferred()
    return registered_type_handlers.get(type_class)
def add_type(type_class: Type, uniqueness_key: Optional[Tuple], type_check: Optional[Callable] = None):
    """Publishes a created type as a module global and registers it.

    Args:
        type_class: The created type; its ``_type`` attribute is used as serializer.
        uniqueness_key: Key for the type in ``registered_types``; must not be present yet.
        type_check: Optional type check; when None the default of ``register_type`` applies.

    Raises:
        ValueError: If a module global with the name of the type already exists.
    """
    assert uniqueness_key not in registered_types
    module_namespace = globals()
    type_name = type_class.__name__
    if type_name in module_namespace:
        raise ValueError(f'Type name "{type_name}" clashes with name already defined in jsonargparse.typing.')
    module_namespace[type_name] = type_class
    if type_check is None:
        register_type(type_class, type_class._type, uniqueness_key=uniqueness_key)  # type: ignore
    else:
        register_type(  # type: ignore
            type_class, type_class._type, uniqueness_key=uniqueness_key, type_check=type_check
        )
# While this global exists, register_type() uses it instead of its fail_already_registered
# argument, so the registrations below never fail because a type is already registered.
_fail_already_registered = False

# Predefined restricted number types.
PositiveInt = restricted_number_type("PositiveInt", int, (">", 0), docstring="int restricted to be >0")
NonNegativeInt = restricted_number_type("NonNegativeInt", int, (">=", 0), docstring="int restricted to be ≥0")
PositiveFloat = restricted_number_type("PositiveFloat", float, (">", 0), docstring="float restricted to be >0")
NonNegativeFloat = restricted_number_type("NonNegativeFloat", float, (">=", 0), docstring="float restricted to be ≥0")
ClosedUnitInterval = restricted_number_type(
    "ClosedUnitInterval", float, [(">=", 0), ("<=", 1)], docstring="float restricted to be ≥0 and ≤1"
)
OpenUnitInterval = restricted_number_type(
    "OpenUnitInterval", float, [(">", 0), ("<", 1)], docstring="float restricted to be >0 and <1"
)

# Predefined restricted string types.
NotEmptyStr = restricted_string_type(
    "NotEmptyStr", r"^.*[^ ].*$", docstring=r"str restricted to not-empty pattern ^.*[^ ].*$"
)
Email = restricted_string_type(
    "Email", r"^[^@ ]+@[^@ ]+\.[^@ ]+$", docstring=r"str restricted to the email pattern ^[^@ ]+@[^@ ]+\.[^@ ]+$"
)

# Predefined path types.
Path_fr = path_type("fr", docstring="path to a file that exists and is readable")
Path_fc = path_type("fc", docstring="path to a file that can be created if it does not exist")
Path_dw = path_type("dw", docstring="path to a directory that exists and is writeable")
Path_dc = path_type("dc", docstring="path to a directory that can be created if it does not exist")
Path_drw = path_type("drw", docstring="path to a directory that exists and is readable and writeable")

# os.PathLike is both serialized and deserialized with str.
register_type(os.PathLike, str, str)
# complex uses the defaults: serialized with str, deserialized by calling complex.
register_type(complex)
# uuid.UUID is only imported and registered when get_registered_type first sees it.
register_type_on_first_use("uuid.UUID")

# pathlib classes are serialized with str, deserialized by calling the class itself and
# matched with isinstance rather than by exact class.
for _path in [pathlib.Path, pathlib.PosixPath, pathlib.WindowsPath]:
    register_type(_path, str, _path, type_check=isinstance)
def timedelta_deserializer(value):
    """Converts a string of the form "h:m:s" or "d days, h:m:s" into a ``datetime.timedelta``.

    Args:
        value: The string to convert.

    Returns:
        The corresponding ``datetime.timedelta``.

    Raises:
        ValueError: If value is not a string, does not have the expected form, has a component
            that is not a valid number, or is outside the range that timedelta supports.
    """
    from datetime import timedelta

    def expected_form_error():
        return ValueError(f'Expected a string with form "h:m:s" or "d days, h:m:s" but got "{value}"')

    if not isinstance(value, str):
        raise expected_form_error()
    pattern = r"(?P<hours>\d+):(?P<minutes>\d+):(?P<seconds>\d[\.\d+]*)"
    if "day" in value:
        pattern = r"(?P<days>[-\d]+) day[s]*, " + pattern
    match = re.match(pattern, value)
    if not match:
        raise expected_form_error()
    try:
        # The character classes above also let through text such as "3.4.5" or "--1" that
        # float() rejects, and timedelta raises OverflowError for too large magnitudes.
        return timedelta(**{key: float(val) for key, val in match.groupdict().items()})
    except (ValueError, OverflowError) as ex:
        raise expected_form_error() from ex
# datetime.timedelta: default str serializer, custom deserializer; registered on first use.
register_type_on_first_use("datetime.timedelta", deserializer=timedelta_deserializer)
def bytes_serializer(value: Union[bytes, bytearray]) -> str:
    """Returns the base64 encoding of ``value`` as a string."""
    import base64

    encoded = base64.b64encode(value)
    return encoded.decode()
def bytes_deserializer(value: str) -> bytes:
    """Decodes a base64 string into ``bytes``."""
    import base64

    decoded = base64.b64decode(value)
    return decoded
def bytearray_deserializer(value: str) -> bytearray:
    """Decodes a base64 string into a ``bytearray``."""
    import base64

    raw = base64.b64decode(value)
    return bytearray(raw)
# bytes and bytearray share the base64 serializer and are registered on first use.
register_type_on_first_use("builtins.bytes", serializer=bytes_serializer, deserializer=bytes_deserializer)
register_type_on_first_use("builtins.bytearray", serializer=bytes_serializer, deserializer=bytearray_deserializer)
def range_serializer(value):
    """Returns the shortest ``range(...)`` text for ``value``.

    The step is left out when it is 1, and the start as well when it is additionally 0.
    """
    if value.step != 1:
        parts = (value.start, value.stop, value.step)
    elif value.start != 0:
        parts = (value.start, value.stop)
    else:
        parts = (value.stop,)
    return "range(" + ", ".join(str(part) for part in parts) + ")"
# Patterns for the argument list of a range expression after all spaces were removed.
re_range_stop = re.compile(r"^(-?\d+)$")
re_range_start_stop = re.compile(r"^(-?\d+),(-?\d+)$")
re_range_start_stop_step = re.compile(r"^(-?\d+),(-?\d+),(-?\d+)$")


def range_deserializer(value):
    """Converts text of the form ``range(...)`` with one to three integers into a ``range``.

    Raises:
        ValueError: If the text does not have one of the supported forms.
    """
    text = value.strip()
    if text.startswith("range(") and text.endswith(")"):
        # Drop the leading "range(" and the trailing ")"; spaces are not significant.
        arguments = text[6:-1].replace(" ", "")
        for pattern in (re_range_stop, re_range_start_stop, re_range_start_stop_step):
            found = pattern.match(arguments)
            if found:
                return range(*(int(group) for group in found.groups()))
    raise ValueError("Expected 'range(<stop>)' or 'range(<start>, <stop>)' or 'range(<start>, <stop>, <step>)'")
# range is registered immediately with its custom serializer and deserializer.
register_type(range, serializer=range_serializer, deserializer=range_deserializer)
# Classes listed in __all__ of pydantic.types and pydantic.networks, excluding Enum
# subclasses. Stays empty when pydantic_support is false.
pydantic_types: Tuple[type, ...] = tuple()
if pydantic_support:
    import pydantic

    for module in [pydantic.types, pydantic.networks]:
        pydantic_types += tuple(
            v for k, v in vars(module).items() if inspect.isclass(v) and k in module.__all__ and not issubclass(v, Enum)
        )
def pydantic_deserializer(type_class):
    """Returns a function that deserializes a value as ``type_class`` using pydantic.

    A model with a single required field of type ``type_class`` is created once; the returned
    function instantiates that model with the value and gives back the resulting field value.
    """
    from pydantic import create_model  # pylint: disable=no-name-in-module

    field_definition = (type_class, ...)
    wrapper_model = create_model("pydantic_model", pydantic_field=field_definition)

    def deserialize(value):
        instance = wrapper_model(pydantic_field=value)
        return instance.pydantic_field

    return deserialize
def pydantic_serializer(type_class):
    """Selects the serializer for a type based on the builtin it derives from.

    Args:
        type_class: The class to get a serializer for.

    Returns:
        ``bool``, ``int``, ``float``, ``list`` or ``dict`` for subclasses of those builtins,
        ``list`` for subclasses of ``set`` and ``str`` for anything else.
    """
    # Pairs of (base class, serializer). bool must be tested before int because bool is a
    # subclass of int and would otherwise never be selected.
    serializers = (
        (bool, bool),
        (int, int),
        (float, float),
        (list, list),
        (dict, dict),
        (set, list),
    )
    for base, serializer in serializers:
        if issubclass(type_class, base):
            return serializer
    return str
def is_pydantic_type(type_class):
    """Tells whether ``type_class`` is a subclass of one of the collected pydantic types.

    When ``pydantic_support`` is false its value is returned without checking the class.
    """
    if not pydantic_support:
        return pydantic_support
    return is_subclass(type_class, pydantic_types)
def register_pydantic_type(type_class):
    """Registers ``type_class`` if it is a pydantic type that has no handler yet.

    The handler uses the pydantic based serializer and deserializer, and treats pydantic's
    ``ValidationError`` as the deserialization failure.
    """
    if is_pydantic_type(type_class) and not get_registered_type(type_class):
        from pydantic import ValidationError

        handler_options = {
            "type_class": type_class,
            "serializer": pydantic_serializer(type_class),
            "deserializer": pydantic_deserializer(type_class),
            "deserializer_exceptions": ValidationError,
        }
        register_type(**handler_options)
# Remove the module-level override so that later register_type() calls use their own
# fail_already_registered argument again.
del _fail_already_registered