Source code for jsonargparse._link_arguments

"""Code related to argument linking."""

import re
from argparse import SUPPRESS
from argparse import Action as ArgparseAction
from collections import defaultdict
from contextlib import contextmanager
from contextvars import ContextVar
from importlib import import_module
from typing import Any, Callable, List, Optional, Tuple, Type, Union

from ._actions import (
    Action,
    ActionConfigFile,
    _ActionConfigLoad,
    _ActionPrintConfig,
    _ActionSubCommands,
    _find_parent_action,
    filter_default_actions,
)
from ._namespace import Namespace, split_key_leaf
from ._parameter_resolvers import get_signature_parameters
from ._type_checking import ArgumentParser, _ArgumentGroup

# Public names exported by a star-import of this module.
__all__ = ["ArgumentLinking"]


def find_parent_or_child_actions(
    parser: "ArgumentParser",
    key: str,
    exclude: Optional[Union[Type[ArgparseAction], Tuple[Type[ArgparseAction], ...]]] = None,
) -> Optional[List[ArgparseAction]]:
    """Look up the actions that correspond to ``key``.

    The parent action returned by ``_find_parent_action`` takes precedence and
    is returned alone in a one-element list. When there is none, every
    non-default action of the parser whose ``dest`` starts with ``key + "."``
    is returned instead.

    Args:
        parser: Parser whose actions are searched.
        key: Dot-notation key to look for.
        exclude: Action class, or tuple of classes, to leave out of the search.

    Returns:
        The non-empty list of matching actions, or ``None`` when nothing matches.
    """
    parent = _find_parent_action(parser, key, exclude=exclude)
    if parent is not None:
        return [parent]

    child_prefix = f"{key}."
    children = [
        candidate
        for candidate in filter_default_actions(parser._actions)
        if (exclude is None or not isinstance(candidate, exclude)) and candidate.dest.startswith(child_prefix)
    ]
    # An empty result is reported as None rather than as an empty list.
    return children or None


def find_subclass_action_or_class_group(
    parser: "ArgumentParser",
    key: str,
    exclude: Optional[Union[Type[ArgparseAction], Tuple[Type[ArgparseAction], ...]]] = None,
) -> Optional[Union[ArgparseAction, "_ArgumentGroup"]]:
    """Find the subclass type hint action or the class group that ``key`` refers to.

    Args:
        parser: Parser whose actions and groups are searched.
        key: Dot-notation key to look for.
        exclude: Action class, or tuple of classes, to leave out of the action search.

    Returns:
        The parent action of ``key`` when it is a subclass type hint; otherwise
        the first argument group that has an ``instantiate_class`` attribute and
        whose ``dest`` equals ``key`` or the parent part of ``key``; ``None``
        when neither exists.
    """
    from ._typehints import ActionTypeHint

    candidate = _find_parent_action(parser, key, exclude=exclude)
    if ActionTypeHint.is_subclass_typehint(candidate):
        return candidate

    # A group matches either the key itself or the key without its leaf.
    parent_key = split_key_leaf(key)[0]
    accepted_dests = {key, parent_key}
    return next(
        (
            grp
            for grp in parser._action_groups
            if getattr(grp, "dest", None) in accepted_dests and hasattr(grp, "instantiate_class")
        ),
        None,
    )


# Context flag read by ActionLink.apply_parsing_links: while it is True,
# parse-time links are not applied. Defaults to False.
apply_config_skip: ContextVar = ContextVar("apply_config_skip", default=False)


@contextmanager
def skip_apply_links():
    """Context manager that sets ``apply_config_skip`` to True for its duration.

    The previous value of the context variable is restored on exit, also when
    the body raises.
    """
    previous_state = apply_config_skip.set(True)
    try:
        yield
    finally:
        # Reset with the token so that nested uses restore the right value.
        apply_config_skip.reset(previous_state)


class DirectedGraph:
    """Minimal directed graph with depth-first topological ordering.

    Nodes are stored in insertion order in ``nodes``; edges are stored in
    ``edges_dict`` as a mapping from the source node index to the list of
    target node indexes.
    """

    def __init__(self):
        self.nodes = []
        self.edges_dict = defaultdict(list)

    def add_edge(self, source, target):
        """Add an edge ``source --> target``, registering unseen nodes."""
        for node in [source, target]:
            if node not in self.nodes:
                self.nodes.append(node)
        self.edges_dict[self.nodes.index(source)].append(self.nodes.index(target))

    def get_topological_order(self):
        """Return the nodes sorted so that every edge points forward.

        Raises:
            ValueError: If the graph contains a cycle.
        """
        exploring = [False] * len(self.nodes)
        visited = [False] * len(self.nodes)
        order = []
        for source in range(len(self.nodes)):
            if not visited[source]:
                self.topological_sort(source, exploring, visited, order)
        return [self.nodes[n] for n in order]

    def topological_sort(self, source, exploring, visited, order):
        """Depth-first visit of node index ``source``.

        Args:
            source: Index in ``nodes`` of the node to visit.
            exploring: Per-node flags marking the nodes on the current DFS path.
            visited: Per-node flags marking the nodes already finished.
            order: Output list of node indexes; finished nodes are prepended.

        Raises:
            ValueError: If an edge leads back to a node on the current path.
        """
        exploring[source] = True
        for target in self.edges_dict[source]:
            if exploring[target]:
                # Interpolate both nodes so that non-string nodes still give a
                # ValueError instead of a TypeError from str concatenation.
                raise ValueError(
                    f"Graph has cycles, found while checking {self.nodes[source]} --> {self.nodes[target]}"
                )
            elif not visited[target]:
                self.topological_sort(target, exploring, visited, order)
        visited[source] = True
        exploring[source] = False
        # A node is finished after all of its descendants, so it goes in front.
        order.insert(0, source)


class ActionLink(Action):
    """Action that derives the value of a target argument from source arguments.

    Creating a link registers it in the parser's "Linked arguments" group.
    Unless the target is an ``init_args`` entry of a subclass type, the link
    also replaces the target's own action in the parser, so that giving the
    target directly fails (see ``__call__``).
    """

    def __init__(
        self,
        parser,
        source: Union[str, Tuple[str, ...]],
        target: str,
        compute_fn: Optional[Callable] = None,
        apply_on: str = "parse",
    ):
        """Create the link and register it in ``parser``.

        Args:
            parser: Parser in which the source and target arguments are defined.
            source: Key or tuple of keys from which the target value is derived.
            target: Key that receives the derived value.
            compute_fn: Function called with the source values to compute the
                target value. Required when there is more than one source.
            apply_on: When the link is applied, either "parse" or "instantiate".

        Raises:
            ValueError: If the inputs are invalid, an action for a key is not
                found, the target key is malformed for a subclass target, or an
                instantiation link would create a cycle.
        """
        if not hasattr(parser, "_links_group"):
            parser._links_group = parser.add_argument_group("Linked arguments")
        self.parser = parser
        self._target = target
        self._source = source = (source,) if isinstance(source, str) else source
        self.apply_on = apply_on
        self.compute_fn = compute_fn
        self._initial_input_checks(source, target)

        # Set and check source actions or group
        exclude = (ActionLink, _ActionConfigLoad, _ActionSubCommands, ActionConfigFile)
        if apply_on == "instantiate":
            self.source = [(s, find_subclass_action_or_class_group(parser, s, exclude=exclude)) for s in source]
            for key, action in self.source:
                if action is None:
                    raise ValueError(
                        f"Links applied on instantiation require source to be a subclass action or a class group: {key}"
                    )
        else:
            self.source = [
                (s, find_parent_or_child_actions(parser, s, exclude=exclude)) for s in source  # type: ignore[misc]
            ]

        # Set and check target action
        self.target = (target, _find_parent_action(parser, target, exclude=exclude))
        for key, action in self.source + [self.target]:
            if action is None:
                raise ValueError(f'No action for key "{key}".')
        assert self.target[1] is not None

        from ._typehints import ActionTypeHint

        is_target_subclass = ActionTypeHint.is_subclass_typehint(self.target[1], all_subtypes=False, also_lists=True)
        valid_target_init_arg = is_target_subclass and target.startswith(f"{self.target[1].dest}.init_args.")
        valid_target_leaf = self.target[1].dest == target
        if not valid_target_leaf and is_target_subclass and not valid_target_init_arg:
            prefix = f"{self.target[1].dest}.init_args."
            raise ValueError(f'Target key expected to start with "{prefix}", got "{target}".')

        # Replace target action with link action
        if not is_target_subclass or valid_target_leaf:
            for key in self.target[1].option_strings:
                parser._option_string_actions[key] = self
            parser._actions[parser._actions.index(self.target[1])] = self
            for group in parser._action_groups:
                if self.target[1] in group._group_actions:
                    group._group_actions.remove(self.target[1])
                    if is_target_subclass:
                        # Also drop the companion "<dest>.help" action of the subclass target.
                        help_dest = f"{self.target[1].dest}.help"
                        for action in group._group_actions:
                            if action.dest == help_dest:  # type: ignore[union-attr]
                                # Safe removal while iterating: the loop stops right after.
                                group._group_actions.remove(action)
                                break
                    # A group left with only config load actions is emptied and described as link-derived.
                    if group._group_actions and all(isinstance(a, _ActionConfigLoad) for a in group._group_actions):
                        group.description = (
                            f"Group '{group._group_actions[0].dest}': All arguments are derived from links."
                        )
                        group._group_actions.clear()

        # Remove target from required
        if target in parser.required_args:
            parser.required_args.remove(target)
        if is_target_subclass and not valid_target_leaf:
            # Record the linked init arg in the "linked_targets" set of the target's sub_add_kwargs.
            sub_add_kwargs = self.target[1].sub_add_kwargs  # type: ignore[attr-defined]
            if "linked_targets" not in sub_add_kwargs:
                sub_add_kwargs["linked_targets"] = set()
            subtarget = target.split(".init_args.", 1)[1]
            sub_add_kwargs["linked_targets"].add(subtarget)

        # Add link action to group to show in help
        parser._links_group._group_actions.append(self)

        # Check instantiation link does not create cycle
        if apply_on == "instantiate":
            try:
                self.instantiation_order(parser)
            except ValueError as ex:
                raise ValueError(f"Invalid link {source[0]} --> {target}: {ex}") from ex

        # Initialize link action
        if compute_fn is None:
            link_str = source[0]
        else:
            link_str = getattr(compute_fn, "__name__", str(compute_fn)) + "(" + ", ".join(source) + ")"
        link_str += " --> " + target

        help_str: Optional[str]
        if is_target_subclass and not valid_target_leaf:
            type_attr = None
            help_str = f"Use --{self.target[1].dest}.help CLASS_PATH for details."
        else:
            type_attr = getattr(self.target[1], "_typehint", self.target[1].type)
            help_str = self.target[1].help
            if help_str == import_module("jsonargparse._formatters").empty_help:
                help_str = f"Target argument '{self.target[1].dest}' lacks type and help"

        super().__init__(
            [link_str],
            dest=target,
            default=SUPPRESS,
            metavar=f"[applied on {self.apply_on}]",
            type=type_attr,
            help=help_str,
        )

    def get_kwargs(self) -> dict:
        """Return the source, target, apply_on and compute_fn given at creation."""
        return {
            "source": self._source,
            "target": self._target,
            "apply_on": self.apply_on,
            "compute_fn": self.compute_fn,
        }

    def _initial_input_checks(self, source, target):
        """Validate ``apply_on``, ``compute_fn`` and conflicts with existing links.

        Raises:
            ValueError: If ``apply_on`` is unknown, several sources are given
                without a compute function, or, for parse links, the target or
                a source clashes with the keys of an already registered link.
        """
        # Check apply_on
        if self.apply_on not in {"parse", "instantiate"}:
            raise ValueError("apply_on must be 'parse' or 'instantiate'.")

        # Check compute function
        if self.compute_fn is None and not (isinstance(source, str) or len(source) == 1):
            raise ValueError("Multiple source keys requires a compute function.")

        if self.apply_on == "parse":
            # Check source
            link_actions = self.parser._links_group._group_actions
            existing_targets = {a.target[0] for a in link_actions}
            if target in existing_targets:
                raise ValueError(f'Target "{target}" is already a target of another link.')
            for src in [source] if isinstance(source, str) else source:
                if src in existing_targets:
                    raise ValueError(f'Source "{src}" not allowed since it is the target of another link.')
            # Check target
            existing_sources = {s[0] for a in link_actions for s in a.source if a.apply_on == "parse"}
            if target in existing_sources:
                raise ValueError(f'Target "{target}" not allowed since it is the source of another link.')

    def __call__(self, *args, **kwargs):
        """Reject direct use: a linked target can only be given via its sources.

        Raises:
            TypeError: Always.
        """
        source = ", ".join(s[0] for s in self.source)
        raise TypeError(f'Linked "{self.target[0]}" must be given via "{source}".')

    def _check_type(self, value, cfg=None):
        """Check ``value`` against the target action via ``parser._check_value_key``."""
        return self.parser._check_value_key(self.target[1], value, self.target[0], cfg)

    def call_compute_fn(self, args):
        """Call ``compute_fn`` with ``args`` as positional arguments.

        Raises:
            ValueError: If the call fails for any reason; the original
                exception is chained as the cause.
        """
        try:
            assert callable(self.compute_fn)
            return self.compute_fn(*args)
        except Exception as ex:
            link = self.option_strings[0]
            args = ", ".join(str(a) for a in args)
            raise ValueError(f"Call to compute_fn of link '{link}' with args ({args}) failed: {ex}") from ex

    @staticmethod
    def apply_parsing_links(parser: "ArgumentParser", cfg: Namespace) -> None:
        """Apply the "parse" links of ``parser`` onto ``cfg`` in place.

        Nothing is done while ``apply_config_skip`` is set or when printing the
        config was requested. Links of the selected subcommand's parser, if
        any, are applied first on the subcommand's part of ``cfg``.
        """
        if apply_config_skip.get() or _ActionPrintConfig.is_print_config_requested(parser):
            return

        subcommand, subparser = _ActionSubCommands.get_subcommand(parser, cfg, fail_no_subcommand=False)
        if subcommand and subcommand in cfg:
            ActionLink.apply_parsing_links(subparser, cfg[subcommand])  # type: ignore[arg-type]
        if not hasattr(parser, "_links_group"):
            return

        from ._typehints import ActionTypeHint

        for action in get_link_actions(parser, "parse"):
            args = []
            skip_link = False
            for source_key, source_action in action.source:
                # For parse links each source_action is a list of actions.
                if ActionTypeHint.is_subclass_typehint(source_action[0]) and source_key not in cfg:  # type: ignore[index]
                    parser.logger.debug(
                        f"Link '{action.option_strings[0]}' ignored since source '{source_key}' not found in namespace."
                    )
                    skip_link = True
                    break
                # Check the values of the source actions present in cfg before using them.
                for source_action_n in [a for a in source_action if a.dest in cfg]:  # type: ignore[union-attr]
                    parser._check_value_key(source_action_n, cfg[source_action_n.dest], source_action_n.dest, None)
                args.append(cfg[source_key])
            if skip_link:
                continue

            if action.compute_fn is None:
                value = args[0]
                # Automatic namespace to dict based on link target type hint
                target_key, target_action = action.target
                if isinstance(value, Namespace) and isinstance(target_action, ActionTypeHint):
                    same_key = target_key == target_action.dest
                    if (
                        same_key and target_action.is_mapping_typehint(target_action._typehint)
                    ) or target_action.is_init_arg_mapping_typehint(target_key, cfg):
                        value = value.as_dict()
            else:
                # Automatic namespace to dict based on compute_fn param type hint
                params = get_signature_parameters(action.compute_fn)
                for n, param in enumerate(params):
                    if (
                        n < len(args)
                        and isinstance(args[n], Namespace)
                        and ActionTypeHint.is_mapping_typehint(param.annotation)
                    ):
                        args[n] = args[n].as_dict()
                # Compute value
                value = action.call_compute_fn(args)
            ActionLink.set_target_value(action, value, cfg, parser.logger)
            parser.logger.debug(f"Applied link '{action.option_strings[0]}'.")

    @staticmethod
    def apply_instantiation_links(parser, cfg, target=None, order=None):
        """Apply the "instantiate" links of ``parser`` onto ``cfg`` in place.

        Args:
            parser: Parser that holds the links.
            cfg: Namespace with the source objects; target values are set in it.
            target: When ``order`` is not given, only links whose target key is
                ``target`` or nested under it are applied.
            order: Keys by which the links are reordered before being applied;
                when given, links are not filtered by ``target``.

        Links already applied are tracked in ``cfg`` under an internal key that
        is popped on entry and stored back only when ``target`` is given.
        Nested instantiation links are never applied here.
        """
        if not hasattr(parser, "_links_group"):
            return

        from ._typehints import ActionTypeHint

        applied_key = "__applied_instantiation_links__"
        applied_links = cfg.pop(applied_key) if applied_key in cfg else set()
        link_actions = get_link_actions(parser, "instantiate", skip=applied_links)
        if order and link_actions:
            link_actions = ActionLink.reorder(order, link_actions)

        for action in link_actions:
            target_key = action.target[0]
            if not (
                order or target_key == target or target_key.startswith(f"{target}.")
            ) or is_nested_instantiation_link(action):
                continue
            source_objects = []
            for source_key, source_action in action.source:
                source_object = cfg[source_action.dest]
                if source_key == source_action.dest:
                    source_objects.append(source_object)
                else:
                    attr = split_key_leaf(source_key)[1]
                    if ActionTypeHint.is_subclass_typehint(source_action) and not hasattr(source_object, attr):
                        parser.logger.debug(
                            f"Link '{action.option_strings[0]}' ignored since attribute '{attr}' not found "
                            f"in source {source_object}."
                        )
                        # Ignore the whole link, not just this source: applying
                        # it with a partial list of sources would call
                        # compute_fn with too few arguments.
                        source_objects = []
                        break
                    source_objects.append(getattr(source_object, attr))
            if not source_objects:
                continue
            elif action.compute_fn is None:
                value = source_objects[0]
            else:
                value = action.call_compute_fn(source_objects)
            ActionLink.set_target_value(action, value, cfg, parser.logger)
            applied_links.add(action)
            parser.logger.debug(f"Applied link '{action.option_strings[0]}'.")

        if target:
            cfg[applied_key] = applied_links

    @staticmethod
    def get_nested_links(parser, action):
        """Return the kwargs of the nested instantiation links that target ``action``.

        The source and target keys in each returned dict are relative to
        ``action``: the ``"<dest>."`` prefix is removed from the sources and
        the ``"<dest>.init_args."`` prefix from the target.
        """

        def trim_param_keys(params: dict):
            params = params.copy()
            params["source"] = tuple(k[len(f"{action.dest}.") :] for k in params["source"])
            params["target"] = params["target"][len(f"{action.dest}.init_args.") :]
            return params

        links = []
        for link in get_link_actions(parser, "instantiate"):
            if link.target[1] is action and is_nested_instantiation_link(link):
                links.append(trim_param_keys(link.get_kwargs()))
        return links

    @staticmethod
    def set_target_value(action: "ActionLink", value: Any, cfg: Namespace, logger) -> None:
        """Store ``value`` in ``cfg`` at the target key of ``action``.

        For a subclass target whose key is the action's own ``dest``, the value
        is first checked with the target action. For an entry nested under a
        subclass target, the value is set on every namespace item when the
        parent is a list with items that contain the child key; otherwise the
        link is ignored when the target key is not already in ``cfg``.
        """
        target_key, target_action = action.target
        assert target_action
        from ._typehints import ActionTypeHint

        if ActionTypeHint.is_subclass_typehint(target_action, all_subtypes=False, also_lists=True):
            if target_key == target_action.dest:
                target_action._check_type(value)  # type: ignore[attr-defined]
            else:
                parent = cfg.get(target_action.dest)
                child_key = target_key[len(target_action.dest) + 1 :]
                if isinstance(parent, list) and any(isinstance(i, Namespace) and child_key in i for i in parent):
                    for item in parent:
                        # Same guard as in the any() above, so that items that
                        # are not namespaces are left untouched.
                        if isinstance(item, Namespace) and child_key in item:
                            item[child_key] = value
                    return
                if target_key not in cfg:
                    logger.debug(f"Link '{action.option_strings[0]}' ignored since target not found.")
                    return
        cfg[target_key] = value

    @staticmethod
    def instantiation_order(parser):
        """Return a topological order of the keys involved in instantiation links.

        An edge goes from the ``dest`` of each source action to the target key
        without its leaf and without a trailing ``.init_args``.

        Returns:
            The ordered list of keys, or an empty list when there are no
            instantiation links.

        Raises:
            ValueError: If the links form a cycle.
        """
        actions = get_link_actions(parser, "instantiate")
        if actions:
            graph = DirectedGraph()
            for action in actions:
                target = re.sub(r"\.init_args$", "", split_key_leaf(action.target[0])[0])
                for _, source_action in action.source:
                    graph.add_edge(source_action.dest, target)
            return graph.get_topological_order()
        return []

    @staticmethod
    def reorder(order, components):
        """Return ``components`` sorted according to the keys in ``order``.

        For each key in turn, the components whose ``dest`` equals the key or
        is nested under it are taken. Components that match no key come last,
        in their original relative order.
        """
        ordered = []
        for key in order:
            after = []
            for component in components:
                if key == component.dest or component.dest.startswith(key + "."):
                    ordered.append(component)
                else:
                    after.append(component)
            components = after
        return ordered + components

    @staticmethod
    def strip_link_target_keys(parser, cfg):
        """Remove from ``cfg`` the keys that are targets of links, in place.

        This covers the targets of the link actions of ``parser``, the linked
        ``init_args`` recorded in type hint actions and, recursively, the
        subcommands present in ``cfg``. A parent key left empty is also removed.
        """

        def del_target_key(target_key):
            cfg.pop(target_key, None)
            if "." not in target_key:
                return
            parent_key, _ = split_key_leaf(target_key)
            if "." in target_key and parent_key in cfg and not cfg[parent_key]:
                del cfg[parent_key]

        for action in [a for a in parser._actions if isinstance(a, ActionLink)]:
            del_target_key(action.target[0])
        from ._typehints import ActionTypeHint

        for action in [a for a in parser._actions if isinstance(a, ActionTypeHint) and hasattr(a, "sub_add_kwargs")]:
            for key in action.sub_add_kwargs.get("linked_targets", []):
                del_target_key(f"{action.dest}.init_args.{key}")

        with _ActionSubCommands.not_single_subcommand():
            subcommands, subparsers = _ActionSubCommands.get_subcommands(parser, cfg)
        if subcommands is not None:
            for num, subcommand in enumerate(subcommands):
                if subcommand in cfg:
                    ActionLink.strip_link_target_keys(subparsers[num], cfg[subcommand])


def get_link_actions(parser: "ArgumentParser", apply_on: str, skip: Optional[set] = None) -> List["ActionLink"]:
    """Return the link actions of ``parser`` that are applied on ``apply_on``.

    Args:
        parser: Parser whose "_links_group" is inspected.
        apply_on: Value that the ``apply_on`` attribute of a link must equal.
        skip: Link actions to leave out of the result. ``None`` skips nothing.

    Returns:
        The matching link actions in registration order, or an empty list when
        the parser has no links group.
    """
    if not hasattr(parser, "_links_group"):
        return []
    # None instead of a mutable set() default, which would be shared by all calls.
    excluded = () if skip is None else skip
    return [a for a in parser._links_group._group_actions if a.apply_on == apply_on and a not in excluded]


def is_nested_instantiation_link(action: ActionLink) -> bool:
    """Tell whether both ends of a link live inside the same subclass action.

    This holds when the target key is an ``init_args`` entry of the target
    action, the target action is a subclass type hint, and every source is
    that same action with a key nested under its ``dest``.
    """
    from ._typehints import ActionTypeHint

    target_key, target_action = action.target
    assert target_action
    dest_prefix = f"{target_action.dest}."
    init_args_prefix = f"{dest_prefix}init_args."
    return (
        target_key.startswith(init_args_prefix)
        and ActionTypeHint.is_subclass_typehint(target_action)
        and all(
            src_action is target_action and src_key.startswith(dest_prefix)
            for src_key, src_action in action.source
        )
    )


[docs] class ArgumentLinking: """Method for linking arguments."""