"""Collection of types and type generators."""
import re
import operator
from enum import Enum
from typing import Dict, List, Tuple, Any, Union, Optional, Type, Pattern
from .util import Path, _issubclass
__all__ = [
    # Registry of every type created by the generators below.
    'registered_types',
    # Type generators.
    'restricted_number_type',
    'restricted_string_type',
    'path_type',
    # Predefined restricted number types.
    'PositiveInt',
    'NonNegativeInt',
    'PositiveFloat',
    'NonNegativeFloat',
    'ClosedUnitInterval',
    'OpenUnitInterval',
    # Predefined restricted string types.
    'NotEmptyStr',
    'Email',
    # Predefined path types.
    'Path_fr',
    'Path_fc',
    'Path_dw',
]
# Comparison functions mapped to the textual symbol used in restrictions and
# in the human readable expression of a restricted type.
_operators1 = dict(zip(
    (operator.gt, operator.ge, operator.lt, operator.le, operator.eq, operator.ne),
    ('>', '>=', '<', '<=', '==', '!='),
))

# Inverse lookup: textual symbol -> comparison function.
_operators2 = {symbol: func for func, symbol in _operators1.items()}

# Comparisons that have a json schema keyword equivalent; == and != do not.
_schema_operator_map = {
    operator.gt: 'exclusiveMinimum',
    operator.ge: 'minimum',
    operator.lt: 'exclusiveMaximum',
    operator.le: 'maximum',
}

# Every generated type keyed by a tuple describing its restriction.
registered_types = {}  # type: Dict[Tuple, Any]
def _is_valid_restriction(restriction, base_type) -> bool:
    """Tells whether restriction is a (comparison, reference) pair valid for base_type."""
    if not isinstance(restriction, tuple) or len(restriction) != 2:
        return False
    comparison, reference = restriction
    # Check the type first so that an unhashable comparison cannot make the
    # dict membership test raise TypeError.
    if not isinstance(comparison, str) or comparison not in _operators2:
        return False
    try:
        # The reference must already be a value of base_type: e.g. 0.5 is
        # rejected for int (int(0.5) != 0.5) and a string like '0' is rejected.
        return bool(reference == base_type(reference))
    except (TypeError, ValueError, OverflowError):
        # e.g. int(None), float('abc') or int(float('inf')).
        return False


def restricted_number_type(
    name: Optional[str],
    base_type: Type,
    restrictions: Union[Tuple, List[Tuple]],
    join: str = 'and',
    docstring: Optional[str] = None,
) -> Type:
    """Creates or returns an already registered restricted number type class.

    Instantiating the created class converts the value with ``base_type`` and
    checks it against all restrictions (``join='and'``) or at least one of them
    (``join='or'``), raising ValueError otherwise. For ``int`` types a float
    with a fractional part is rejected instead of being silently truncated.

    Args:
        name: Name for the type or None for an automatic name.
        base_type: One of {int, float}.
        restrictions: Tuples of pairs (comparison, reference), e.g. ('>', 0).
            At least one restriction is required.
        join: How to combine multiple comparisons, one of {'or', 'and'}.
        docstring: Docstring for the type class.

    Returns:
        The created or retrieved type class.

    Raises:
        ValueError: If an argument is invalid, or if the same restrictions are
            already registered under a different name.
    """
    # Tuple membership compares by equality, so unhashable arguments produce
    # the ValueError below instead of a TypeError.
    if base_type not in (int, float):
        raise ValueError('Expected base_type to be one of {int, float}.')
    if join not in ('or', 'and'):
        raise ValueError("Expected join to be one of {'or', 'and'}.")
    restrictions = [restrictions] if isinstance(restrictions, tuple) else restrictions
    # An empty list must be rejected: its automatic name would be just 'int' or
    # 'float', which register_type would then install as a module global that
    # shadows the builtin.
    if not isinstance(restrictions, list) or not restrictions or \
       not all(_is_valid_restriction(x, base_type) for x in restrictions):
        raise ValueError('Expected restrictions to be a non-empty list of tuples each with a comparison operator '
                         '(> >= < <= == !=) and a reference value of type '+base_type.__name__+'.')

    # Sorting makes the key independent of the order the restrictions are given in.
    register_key = (tuple(sorted(restrictions)), base_type, join)
    if register_key in registered_types:
        registered_type = registered_types[register_key]
        if name is not None and registered_type.__name__ != name:
            raise ValueError('Same type already registered with a different name: '+registered_type.__name__+'.')
        return registered_type

    # Replace the textual operators by their functions, e.g. '>' -> operator.gt.
    restrictions = [(_operators2[op], ref) for op, ref in restrictions]
    expression = (' '+join+' ').join(['v'+_operators1[op]+str(ref) for op, ref in restrictions])

    class RestrictedNumber:
        _restrictions = restrictions
        _expression = expression
        _type = base_type
        _join = join

        def __new__(cls, v):
            # int() truncates, which would turn 1.9 into 1 and let -0.5 pass a
            # '>= 0' restriction as 0, so non-integral floats are refused.
            if cls._type is int and isinstance(v, float) and not v.is_integer():
                raise ValueError('invalid value, '+str(v)+' is not an integer')
            v = cls._type(v)
            checks = [comparison(v, ref) for comparison, ref in cls._restrictions]
            # join was validated above to be either 'and' or 'or'.
            conforms = all(checks) if cls._join == 'and' else any(checks)
            if not conforms:
                raise ValueError('invalid value, '+str(v)+' does not conform to restriction '+cls._expression)
            return super().__new__(cls, v)

    if name is None:
        # Automatic name such as 'float_ge0_and_le1'. NOTE: dropping '.' can make
        # names collide (e.g. float refs 1.5 and 15 both give 'float_gt15'), in
        # which case register_type raises ValueError for the second one.
        name = base_type.__name__
        for num, (comparison, ref) in enumerate(restrictions):
            name += '_'+join+'_' if num > 0 else '_'
            name += comparison.__name__ + str(ref).replace('.', '')
    restricted_type = type(name, (RestrictedNumber, base_type), {})
    if docstring is not None:
        restricted_type.__doc__ = docstring
    register_type(register_key, restricted_type)
    return restricted_type
def restricted_string_type(
    name: str,
    regex: Union[str, Pattern],
    docstring: Optional[str] = None,
) -> Type:
    """Creates or returns an already registered restricted string type class.

    Instantiating the created class converts the value with ``str`` and checks
    it with ``regex.match``. That check anchors only at the start of the
    string, so the pattern itself must constrain the end if required.

    Args:
        name: Name for the type. Unlike restricted_number_type, no automatic
            name is generated, so a string is required.
        regex: Regular expression that the string must match, given either as
            a string or as a compiled pattern (whose flags are respected).
        docstring: Docstring for the type class.

    Returns:
        The created or retrieved type class.

    Raises:
        ValueError: If the same pattern with the same flags is already
            registered with a different name.
    """
    if isinstance(regex, str):
        regex = re.compile(regex)
    expression = 'matching '+regex.pattern
    register_key = (expression, str)
    # A pattern compiled with extra flags (e.g. re.IGNORECASE) accepts other
    # values than the plain pattern, so it must not share that registry entry.
    # Keys for default flags keep their original form for compatibility.
    if regex.flags != re.compile('').flags:
        register_key += (regex.flags,)
    if register_key in registered_types:
        registered_type = registered_types[register_key]
        if registered_type.__name__ != name:
            raise ValueError('Same type already registered with a different name: '+registered_type.__name__+'.')
        return registered_type

    class RestrictedString:
        _regex = regex
        _expression = expression
        _type = str

        def __new__(cls, v):
            v = str(v)
            if not cls._regex.match(v):
                raise ValueError('invalid value, "'+v+'" does not match regular expression '+cls._expression)
            return super().__new__(cls, v)

    restricted_type = type(name, (RestrictedString, str), {})
    if docstring is not None:
        restricted_type.__doc__ = docstring
    register_type(register_key, restricted_type)
    return restricted_type
def path_type(
    mode: str,
    docstring: Optional[str] = None,
) -> Type:
    """Creates or returns an already registered path type class.

    The class is named ``'Path_'+mode``. The registry key uses the sorted mode
    characters, so modes differing only in character order (e.g. 'fr' and 'rf')
    resolve to whichever type was registered first. When an already registered
    type is returned, ``docstring`` is not applied.

    Args:
        mode: The required type and access permissions among [fdrwxcuFDRWX].
        docstring: Docstring for the type class.

    Returns:
        The created or retrieved type class.
    """
    type_name = 'Path_' + mode
    key = ('path ' + ''.join(sorted(mode)), str)
    try:
        return registered_types[key]
    except KeyError:
        pass

    class PathType(Path):
        _expression = type_name
        _mode = mode
        _type = str

        def __init__(self, v):
            # The mode is forwarded to Path.__init__, presumably where the
            # existence/permission checks happen.
            super().__init__(v, mode=self._mode)

        def __repr__(self):
            return self.rel_path

    path_cls = type(type_name, (PathType, str), {})
    if docstring is not None:
        path_cls.__doc__ = docstring
    register_type(key, path_cls)
    return path_cls
def register_type(register_key, new_type):
    """Exposes new_type as a global of this module and records it in registered_types.

    Callers are expected to have checked that register_key is not yet taken.

    Raises:
        ValueError: If the type's name is already defined in this module.
    """
    assert register_key not in registered_types
    type_name = new_type.__name__
    namespace = globals()
    if type_name in namespace:
        raise ValueError('Type name "'+type_name+'" clashes with name already defined in jsonargparse.typing.')
    namespace[type_name] = new_type
    registered_types[register_key] = new_type
# Predefined restricted number types.
PositiveInt = restricted_number_type(
    name='PositiveInt', base_type=int, restrictions=('>', 0),
    docstring='int restricted to be >0',
)
NonNegativeInt = restricted_number_type(
    name='NonNegativeInt', base_type=int, restrictions=('>=', 0),
    docstring='int restricted to be ≥0',
)
PositiveFloat = restricted_number_type(
    name='PositiveFloat', base_type=float, restrictions=('>', 0),
    docstring='float restricted to be >0',
)
NonNegativeFloat = restricted_number_type(
    name='NonNegativeFloat', base_type=float, restrictions=('>=', 0),
    docstring='float restricted to be ≥0',
)
ClosedUnitInterval = restricted_number_type(
    name='ClosedUnitInterval', base_type=float, restrictions=[('>=', 0), ('<=', 1)],
    docstring='float restricted to be ≥0 and ≤1',
)
OpenUnitInterval = restricted_number_type(
    name='OpenUnitInterval', base_type=float, restrictions=[('>', 0), ('<', 1)],
    docstring='float restricted to be >0 and <1',
)

# Predefined restricted string types.
NotEmptyStr = restricted_string_type(
    name='NotEmptyStr', regex=r'^.*[^ ].*$',
    docstring=r'str restricted to not-empty pattern ^.*[^ ].*$',
)
Email = restricted_string_type(
    name='Email', regex=r'^[^@ ]+@[^@ ]+\.[^@ ]+$',
    docstring=r'str restricted to the email pattern ^[^@ ]+@[^@ ]+\.[^@ ]+$',
)

# Predefined path types.
Path_fr = path_type(
    mode='fr',
    docstring='str pointing to a file that exists and is readable',
)
Path_fc = path_type(
    mode='fc',
    docstring='str pointing to a file that can be created if it does not exist',
)
Path_dw = path_type(
    mode='dw',
    docstring='str pointing to a directory that exists and is writeable',
)
def is_optional(annotation, ref_type):
    """Checks whether a type annotation is an optional for one type class.

    True only for a two-member ``Union`` (such as ``Optional[X]``) where one
    member is ``NoneType`` and some member satisfies ``_issubclass(member, ref_type)``.
    """
    if not hasattr(annotation, '__origin__'):
        return False
    if not annotation.__origin__ == Union:
        return False
    members = annotation.__args__
    if len(members) != 2:
        return False
    none_type = type(None)
    if not any(none_type == member for member in members):
        return False
    return any(_issubclass(member, ref_type) for member in members)
def annotation_to_schema(annotation) -> Optional[Dict[str, Any]]:
    """Generates a json schema from a type annotation if possible.

    Supported are restricted number types whose restrictions are joined with
    'and' and only use >, >=, <, <= (mapped to the json schema
    exclusiveMinimum/minimum/exclusiveMaximum/maximum keywords), and
    restricted string types (mapped to a 'pattern').

    Args:
        annotation: The type annotation to process. It need not be a class:
            typing constructs such as ``Optional[int]`` simply yield None.

    Returns:
        The json schema or None if an unsupported type.
    """
    # issubclass raises TypeError for non-classes (List[int], Optional[int],
    # None, ...). Those are unsupported, so report None as documented.
    if not isinstance(annotation, type):
        return None
    if issubclass(annotation, (int, float)) and \
       getattr(annotation, '_join', None) == 'and' and \
       hasattr(annotation, '_restrictions') and \
       all(x[0] in _schema_operator_map for x in annotation._restrictions):
        schema = {'type': 'integer' if issubclass(annotation, int) else 'number'}  # type: Dict[str, Any]
        for comparison, ref in annotation._restrictions:
            schema[_schema_operator_map[comparison]] = ref
        return schema
    if issubclass(annotation, str) and hasattr(annotation, '_regex'):
        return {'type': 'string', 'pattern': annotation._regex.pattern}
    return None
def type_in(obj, types_set):
    """Checks whether obj, or else its ``__origin__`` (for generic aliases such as List[int]), is in types_set."""
    if obj in types_set:
        return True
    try:
        origin = obj.__origin__
    except AttributeError:
        return False
    return origin in types_set
def type_to_str(obj):
    """Returns a human readable description of a type.

    Basic types (bool, int, float, str, Enum subclasses) are described by name,
    with the restriction expression appended for generated restricted types.
    Other non-None objects use their str() with lowercase module prefixes
    stripped and 'NoneType' shown as 'null'. None yields None.
    """
    if not _issubclass(obj, (bool, int, float, str, Enum)):
        if obj is None:
            return None
        return re.sub(r'[a-z_.]+\.', '', str(obj)).replace('NoneType', 'null')
    if not hasattr(obj, '_expression'):
        return obj.__name__
    return obj._type.__name__ + ' ' + obj._expression