Module src.utils

Expand source code
import networkx as nwx

from src.CONSTANTS import (
    ACTION_NODE,
    ALTERNATIVE_NODE,
    CONTEXT_NODE,
    DATA_ITEM_ATTR,
    DECISION_NODE,
    ID_RO,
    IS_ALTERNATIVE,
    IS_IN_PARALLEL,
    IS_ORIGINAL_ATTR,
    PARALLEL_NODE,
    PARALLEL_START_ATTR,
    PARLLEL_END_ATTR,
    RANGE_ATTR,
    TRIGGER,
    TYPE_ATTR,
)


def get_type_nodes(graph, node_type):
    """
    Retrieves list of nodes of the given type

    Args:
        graph (networkx graph): The graph.
        node_type (str): Type of the nodes to retrieve.

    Returns:
        list: List of node of the given type.

    """
    return [node for node, attr in graph.nodes.items() if attr[TYPE_ATTR] == node_type]


def find_goal_node(graph, start_node):
    """
    Finds a goal node recursively from a given start_node.

    Goal node MUST NOT have any out edges.

    Args:
        graph (networkx graph): The graph.
        start_node (str): Node to start the search.

    Returns:
        str: Name of the goal node.

    """
    out_edges = graph.out_edges(start_node)
    if len(out_edges):
        return find_goal_node(graph, list(out_edges)[0][1])
    return start_node


def find_init_node(graph, start_node):
    """
    Finds a initial node recursively from a given start_node.

    Initial node MUST NOT have any in edges.

    Args:
        graph (networkx graph): The graph.
        start_node (str): Node to start the search.

    Returns:
        str: Name of the goal node.

    """
    in_edges = graph.in_edges(start_node)
    if len(in_edges):
        return find_init_node(graph, list(in_edges)[0][0])
    return start_node


def get_metric_name(metric):
    """
    Extract the name of the metric.

    Metrics are node attributs that contains the word 'Cost'.

    Args:
        metric (str): The metric.

    Returns:
        str: Name of the metric.

    """
    metric_name = metric if metric == "cost" else metric.replace("Cost", "")
    return metric_name


def get_all_parallel_nodes(graph):
    """
    Finds all the nodes involved in parallel paths.

    Args:
        graph (networkx graph): The graph.


    Returns:
        list: List of parallel nodes.

    """
    parallel_nodes = []
    for node, attributes in graph.nodes.items():

        if attributes.get(IS_IN_PARALLEL) == True:
            parallel_nodes.append(f"{node}")

    return parallel_nodes


def get_number_parallel_paths(graph):
    """
    Finds the number of parallel paths.

    Args:
        graph (networkx graph): The graph.

    Returns:
        int: Number of parallel paths.
    """  # find_init_node
    p_start = ""
    p_end = ""
    n_path_found = {}
    for node, attributes in graph.nodes.items():
        if find_init_node(graph, node) not in n_path_found:
            n_path_found[find_init_node(graph, node)] = 0
        if attributes.get(PARALLEL_START_ATTR) == True:
            p_start = node
        if attributes.get(PARLLEL_END_ATTR) == True:
            p_end = node
        if p_start != "" and p_end != "":
            parallel_sequence = list(
                nwx.all_simple_paths(graph, source=p_start, target=p_end)
            )
            n_paths = len(parallel_sequence)
            context = find_init_node(graph, p_start)
            n_path_found[context] = n_paths + n_path_found.get(context, 0)
            p_start = ""
    return n_path_found


def find_parallel_path(graph, p_nodes_found):
    """
    Finds all parallel paths from a list of parallel nodes.

    Args:
        graph (networkx graph): The graph.
        p_nodes_found (list): List of parallel start and end nodes.


    Returns:
        str: PDDL representation of the parallel path.

    """
    parallelNode = ""
    # TODO: numParallelPaths for each diseases
    end_nodes = []

    for start_node in p_nodes_found:

        # TODO: Check whether this will make it more robust with a bigger graph
        # Check if the current start node is an en d node
        if start_node not in end_nodes:
            for end_node in p_nodes_found:

                parallel_sequence = list(
                    nwx.all_simple_paths(graph, source=start_node, target=end_node)
                )
                if not parallel_sequence:
                    continue
                elif len(parallel_sequence) == 1:
                    continue
                else:
                    parallelTypeNode = ""
                    untraversedParallelNode = ""

                    parallelNode += "(parallelStartNode {})\n\t".format(start_node)
                    graph.nodes[start_node][PARALLEL_START_ATTR] = True
                    if end_node not in end_nodes:
                        end_nodes.append(end_node)
                        parallelNode += "(parallelEndNode {})\n\t".format(end_node)
                        graph.nodes[end_node][PARLLEL_END_ATTR] = True

                    # for path in parallel_sequence:
                    (
                        parallelTypeNode,
                        untraversedParallelNode,
                    ) = update_between_parallel_nodes(
                        graph,
                        start_node,
                        end_node,
                        parallelTypeNode,
                        untraversedParallelNode,
                    )

                    parallelNode += parallelTypeNode
                    parallelNode += untraversedParallelNode

    return parallelNode


def update_between_parallel_nodes(
    graph,
    start_node,
    end_node,
    parallelTypeNode,
    untraversedParallelNode,
    numParallelPaths=0,
):
    """
    Updates the PDDL representation of the parallel nodes between a parallel start node and a parallel end node.



    Args:
        graph (networkx graph): The graph.
        start_node (str): Start node of the parallel path, Parallel Start Node.
        end_node (str): End node of the parallel path, Parallel End Node.
        parallelTypeNode (str): PDDL representation of the parallel node.
        untraversedParallelNode (str): PDDL representation of the untraversed parallel node.
        numParallelPaths (int): Number of parallel paths.

    Returns:
        str: PDDL representation of the parallel nodes.
    """
    if start_node == end_node:
        return parallelTypeNode, untraversedParallelNode

    if type(start_node) == str:
        first_path, *path_list = graph.out_edges(start_node)
    else:
        first_path, *path_list = start_node

    if len(path_list) == 1:
        _, node = path_list.pop()
        _, nodefp = first_path
        parallelTypeNode, untraversedParallelNode = update_between_parallel_nodes(
            graph,
            nodefp,
            end_node,
            parallelTypeNode,
            untraversedParallelNode,
            numParallelPaths + 1,
        )
        return update_between_parallel_nodes(
            graph,
            node,
            end_node,
            parallelTypeNode,
            untraversedParallelNode,
            numParallelPaths + 1,
        )

    elif len(path_list) > 1:
        _, nodefp = first_path
        parallelTypeNode, untraversedParallelNode = update_between_parallel_nodes(
            graph,
            nodefp,
            end_node,
            parallelTypeNode,
            untraversedParallelNode,
            numParallelPaths + 1,
        )

        return update_between_parallel_nodes(
            graph,
            path_list,
            end_node,
            parallelTypeNode,
            untraversedParallelNode,
            numParallelPaths + 1,
        )

    if graph.nodes[start_node][TYPE_ATTR] != PARALLEL_NODE:
        graph.nodes[start_node][IS_IN_PARALLEL] = True

        parallelTypeNode += "(parallel{}Node {})\n\t".format(
            graph.nodes[start_node][TYPE_ATTR].capitalize(), start_node
        )
        untraversedParallelNode += "(untraversedParallelNode {})\n\t".format(start_node)

    _, node = first_path
    return update_between_parallel_nodes(
        graph,
        node,
        end_node,
        parallelTypeNode,
        untraversedParallelNode,
        numParallelPaths + 1,
    )


def get_all_metrics(graph):
    """
    Finds all the metrics.

    Metrics are node attributs that contains the word 'Cost'.

    Args:
        graph (networkx graph): The graph.


    Returns:
        list: List of metrics.

    """
    action_nodes = get_type_nodes(graph, ACTION_NODE)
    metrics = []
    for node in action_nodes:
        node_metrics = [
            attr for attr in graph.nodes[node] if attr.lower().find("cost") != -1
        ]
        metrics.extend(node_metrics)
    return list(set(metrics))


def get_all_revIds(graph):
    """
    Finds all revision IDs.

    Args:
        graph (networkx graph): The graph.

    Returns:
        list: List of revision IDs.
    """
    revIds = []
    for _, attr in graph.nodes.items():
        idRO = attr.get(ID_RO, False)
        if idRO and idRO not in revIds:
            revIds.append(idRO)
    return revIds


def find_revId_involved_nodes(graph, revId):
    """
    Finds all the nodes involved in a given revision ID. This includes the list of triggering nodes and the inserted nodes

    Args:
        graph (networkx graph): The graph.
        revId (str): The revision ID.

    Returns:
        list: List of node's name.
    """
    nodes = []
    for node, attr in graph.nodes.items():
        node_revId = attr.get(ID_RO, False)
        if node_revId and node_revId == revId:
            nodes.extend(attr.get(TRIGGER))
            parent_nodes = list(graph.predecessors(node))
            # Loop over all the parents of the existing node
            for parent_node in parent_nodes:
                children = list(graph.successors(parent_node))
                while len(children) > 0:
                    child = children.pop()
                    child_attr = graph.nodes[child]
                    child_id_ro = child_attr.get(ID_RO, None)
                    # need to check the revision flags when the added nodes are not just action nodes
                    if child_id_ro and child_id_ro != revId:
                        print(revId, child)
                        nodes.append(child)
                        children.extend(list(graph.successors(child)))
    return list(set(nodes))


def match_nodes_to_disease(graph):
    """
    For each disease, find the revision operators that involve the disease.

    Format of the return object:

        {
            "disease1": ['ro1', 'ro2'],
            ...
        }


    Args:
        graph (networkx graph): The graph.

    Returns:
        object: Object where the keys are the diseases and the values are a list of revision operations IDs.
    """
    revIds = get_all_revIds(graph)
    diseases = get_type_nodes(graph, CONTEXT_NODE)

    ro_disease = {}
    for disease in diseases:
        ro_disease[disease] = set()

    for revId in revIds:
        nodes = find_revId_involved_nodes(graph, revId)
        diseases_involved = []
        for node in nodes:
            diseases_involved.append(find_init_node(graph, node))
        for disease in diseases_involved:
            ro_disease[disease].add(revId)
    return ro_disease


def handle_alternative_nodes(graph):
    """
    Modifies the alternative nodes to decision nodes where all successors have the same edges value.
    That way, the planner will look into all successors for an optimize solution.


    The patient values provided should include a value called "default" with the value of 0 or 1 that will be used for the alternative nodes.

    Args:
        graph (networkx graph): The graph.
    """
    for node, attr in graph.nodes.items():
        if attr.get(TYPE_ATTR) == ALTERNATIVE_NODE:
            graph.nodes[node][TYPE_ATTR] = DECISION_NODE
            graph.nodes[node][DATA_ITEM_ATTR] = "default_value"
            graph.nodes[node][IS_ALTERNATIVE] = True
            for succ in graph.successors(node):
                graph.edges[node, succ, 0][RANGE_ATTR] = "0..1"

Functions

def find_goal_node(graph, start_node)

Finds a goal node recursively from a given start_node.

Goal node MUST NOT have any out edges.

Args

graph : networkx graph
The graph.
start_node : str
Node to start the search.

Returns

str
Name of the goal node.
Expand source code
def find_goal_node(graph, start_node):
    """
    Finds a goal node recursively from a given start_node.

    Goal node MUST NOT have any out edges.

    Args:
        graph (networkx graph): The graph.
        start_node (str): Node to start the search.

    Returns:
        str: Name of the goal node.

    """
    out_edges = graph.out_edges(start_node)
    if len(out_edges):
        return find_goal_node(graph, list(out_edges)[0][1])
    return start_node
def find_init_node(graph, start_node)

Finds a initial node recursively from a given start_node.

Initial node MUST NOT have any in edges.

Args

graph : networkx graph
The graph.
start_node : str
Node to start the search.

Returns

str
Name of the goal node.
Expand source code
def find_init_node(graph, start_node):
    """
    Finds a initial node recursively from a given start_node.

    Initial node MUST NOT have any in edges.

    Args:
        graph (networkx graph): The graph.
        start_node (str): Node to start the search.

    Returns:
        str: Name of the goal node.

    """
    in_edges = graph.in_edges(start_node)
    if len(in_edges):
        return find_init_node(graph, list(in_edges)[0][0])
    return start_node
def find_parallel_path(graph, p_nodes_found)

Finds all parallel paths from a list of parallel nodes.

Args

graph : networkx graph
The graph.
p_nodes_found : list
List of parallel start and end nodes.

Returns

str
PDDL representation of the parallel path.
Expand source code
def find_parallel_path(graph, p_nodes_found):
    """
    Finds all parallel paths from a list of parallel nodes.

    Args:
        graph (networkx graph): The graph.
        p_nodes_found (list): List of parallel start and end nodes.


    Returns:
        str: PDDL representation of the parallel path.

    """
    parallelNode = ""
    # TODO: numParallelPaths for each diseases
    end_nodes = []

    for start_node in p_nodes_found:

        # TODO: Check whether this will make it more robust with a bigger graph
        # Check if the current start node is an en d node
        if start_node not in end_nodes:
            for end_node in p_nodes_found:

                parallel_sequence = list(
                    nwx.all_simple_paths(graph, source=start_node, target=end_node)
                )
                if not parallel_sequence:
                    continue
                elif len(parallel_sequence) == 1:
                    continue
                else:
                    parallelTypeNode = ""
                    untraversedParallelNode = ""

                    parallelNode += "(parallelStartNode {})\n\t".format(start_node)
                    graph.nodes[start_node][PARALLEL_START_ATTR] = True
                    if end_node not in end_nodes:
                        end_nodes.append(end_node)
                        parallelNode += "(parallelEndNode {})\n\t".format(end_node)
                        graph.nodes[end_node][PARLLEL_END_ATTR] = True

                    # for path in parallel_sequence:
                    (
                        parallelTypeNode,
                        untraversedParallelNode,
                    ) = update_between_parallel_nodes(
                        graph,
                        start_node,
                        end_node,
                        parallelTypeNode,
                        untraversedParallelNode,
                    )

                    parallelNode += parallelTypeNode
                    parallelNode += untraversedParallelNode

    return parallelNode
def find_revId_involved_nodes(graph, revId)

Finds all the nodes involved in a given revision ID. This includes the list of triggering nodes and the inserted nodes

Args

graph : networkx graph
The graph.
revId : str
The revision ID.

Returns

list
List of node's name.
Expand source code
def find_revId_involved_nodes(graph, revId):
    """
    Finds all the nodes involved in a given revision ID. This includes the list of triggering nodes and the inserted nodes

    Args:
        graph (networkx graph): The graph.
        revId (str): The revision ID.

    Returns:
        list: List of node's name.
    """
    nodes = []
    for node, attr in graph.nodes.items():
        node_revId = attr.get(ID_RO, False)
        if node_revId and node_revId == revId:
            nodes.extend(attr.get(TRIGGER))
            parent_nodes = list(graph.predecessors(node))
            # Loop over all the parents of the existing node
            for parent_node in parent_nodes:
                children = list(graph.successors(parent_node))
                while len(children) > 0:
                    child = children.pop()
                    child_attr = graph.nodes[child]
                    child_id_ro = child_attr.get(ID_RO, None)
                    # need to check the revision flags when the added nodes are not just action nodes
                    if child_id_ro and child_id_ro != revId:
                        print(revId, child)
                        nodes.append(child)
                        children.extend(list(graph.successors(child)))
    return list(set(nodes))
def get_all_metrics(graph)

Finds all the metrics.

Metrics are node attributs that contains the word 'Cost'.

Args

graph : networkx graph
The graph.

Returns

list
List of metrics.
Expand source code
def get_all_metrics(graph):
    """
    Finds all the metrics.

    Metrics are node attributs that contains the word 'Cost'.

    Args:
        graph (networkx graph): The graph.


    Returns:
        list: List of metrics.

    """
    action_nodes = get_type_nodes(graph, ACTION_NODE)
    metrics = []
    for node in action_nodes:
        node_metrics = [
            attr for attr in graph.nodes[node] if attr.lower().find("cost") != -1
        ]
        metrics.extend(node_metrics)
    return list(set(metrics))
def get_all_parallel_nodes(graph)

Finds all the nodes involved in parallel paths.

Args

graph : networkx graph
The graph.

Returns

list
List of parallel nodes.
Expand source code
def get_all_parallel_nodes(graph):
    """
    Finds all the nodes involved in parallel paths.

    Args:
        graph (networkx graph): The graph.


    Returns:
        list: List of parallel nodes.

    """
    parallel_nodes = []
    for node, attributes in graph.nodes.items():

        if attributes.get(IS_IN_PARALLEL) == True:
            parallel_nodes.append(f"{node}")

    return parallel_nodes
def get_all_revIds(graph)

Finds all revision IDs.

Args

graph : networkx graph
The graph.

Returns

list
List of revision IDs.
Expand source code
def get_all_revIds(graph):
    """
    Finds all revision IDs.

    Args:
        graph (networkx graph): The graph.

    Returns:
        list: List of revision IDs.
    """
    revIds = []
    for _, attr in graph.nodes.items():
        idRO = attr.get(ID_RO, False)
        if idRO and idRO not in revIds:
            revIds.append(idRO)
    return revIds
def get_metric_name(metric)

Extract the name of the metric.

Metrics are node attributs that contains the word 'Cost'.

Args

metric : str
The metric.

Returns

str
Name of the metric.
Expand source code
def get_metric_name(metric):
    """
    Extract the name of the metric.

    Metrics are node attributs that contains the word 'Cost'.

    Args:
        metric (str): The metric.

    Returns:
        str: Name of the metric.

    """
    metric_name = metric if metric == "cost" else metric.replace("Cost", "")
    return metric_name
def get_number_parallel_paths(graph)

Finds the number of parallel paths.

Args

graph : networkx graph
The graph.

Returns

int
Number of parallel paths.
Expand source code
def get_number_parallel_paths(graph):
    """
    Finds the number of parallel paths.

    Args:
        graph (networkx graph): The graph.

    Returns:
        int: Number of parallel paths.
    """  # find_init_node
    p_start = ""
    p_end = ""
    n_path_found = {}
    for node, attributes in graph.nodes.items():
        if find_init_node(graph, node) not in n_path_found:
            n_path_found[find_init_node(graph, node)] = 0
        if attributes.get(PARALLEL_START_ATTR) == True:
            p_start = node
        if attributes.get(PARLLEL_END_ATTR) == True:
            p_end = node
        if p_start != "" and p_end != "":
            parallel_sequence = list(
                nwx.all_simple_paths(graph, source=p_start, target=p_end)
            )
            n_paths = len(parallel_sequence)
            context = find_init_node(graph, p_start)
            n_path_found[context] = n_paths + n_path_found.get(context, 0)
            p_start = ""
    return n_path_found
def get_type_nodes(graph, node_type)

Retrieves list of nodes of the given type

Args

graph : networkx graph
The graph.
node_type : str
Type of the nodes to retrieve.

Returns

list
List of node of the given type.
Expand source code
def get_type_nodes(graph, node_type):
    """
    Retrieves list of nodes of the given type

    Args:
        graph (networkx graph): The graph.
        node_type (str): Type of the nodes to retrieve.

    Returns:
        list: List of node of the given type.

    """
    return [node for node, attr in graph.nodes.items() if attr[TYPE_ATTR] == node_type]
def handle_alternative_nodes(graph)

Modifies the alternative nodes to decision nodes where all successors have the same edges value. That way, the planner will look into all successors for an optimize solution.

The patient values provided should include a value called "default" with the value of 0 or 1 that will be used for the alternative nodes.

Args

graph : networkx graph
The graph.
Expand source code
def handle_alternative_nodes(graph):
    """
    Modifies the alternative nodes to decision nodes where all successors have the same edges value.
    That way, the planner will look into all successors for an optimize solution.


    The patient values provided should include a value called "default" with the value of 0 or 1 that will be used for the alternative nodes.

    Args:
        graph (networkx graph): The graph.
    """
    for node, attr in graph.nodes.items():
        if attr.get(TYPE_ATTR) == ALTERNATIVE_NODE:
            graph.nodes[node][TYPE_ATTR] = DECISION_NODE
            graph.nodes[node][DATA_ITEM_ATTR] = "default_value"
            graph.nodes[node][IS_ALTERNATIVE] = True
            for succ in graph.successors(node):
                graph.edges[node, succ, 0][RANGE_ATTR] = "0..1"
def match_nodes_to_disease(graph)

For each disease, find the revision operators that involve the disease.

Format of the return object:

{
    "disease1": ['ro1', 'ro2'],
    ...
}

Args

graph : networkx graph
The graph.

Returns

object
Object where the keys are the diseases and the values are a list of revision operations IDs.
Expand source code
def match_nodes_to_disease(graph):
    """
    For each disease, find the revision operators that involve the disease.

    Format of the return object:

        {
            "disease1": ['ro1', 'ro2'],
            ...
        }


    Args:
        graph (networkx graph): The graph.

    Returns:
        object: Object where the keys are the diseases and the values are a list of revision operations IDs.
    """
    revIds = get_all_revIds(graph)
    diseases = get_type_nodes(graph, CONTEXT_NODE)

    ro_disease = {}
    for disease in diseases:
        ro_disease[disease] = set()

    for revId in revIds:
        nodes = find_revId_involved_nodes(graph, revId)
        diseases_involved = []
        for node in nodes:
            diseases_involved.append(find_init_node(graph, node))
        for disease in diseases_involved:
            ro_disease[disease].add(revId)
    return ro_disease
def update_between_parallel_nodes(graph, start_node, end_node, parallelTypeNode, untraversedParallelNode, numParallelPaths=0)

Updates the PDDL representation of the parallel nodes between a parallel start node and a parallel end node.

Args

graph : networkx graph
The graph.
start_node : str
Start node of the parallel path, Parallel Start Node.
end_node : str
End node of the parallel path, Parallel End Node.
parallelTypeNode : str
PDDL representation of the parallel node.
untraversedParallelNode : str
PDDL representation of the untraversed parallel node.
numParallelPaths : int
Number of parallel paths.

Returns

str
PDDL representation of the parallel nodes.
Expand source code
def update_between_parallel_nodes(
    graph,
    start_node,
    end_node,
    parallelTypeNode,
    untraversedParallelNode,
    numParallelPaths=0,
):
    """
    Updates the PDDL representation of the parallel nodes between a parallel start node and a parallel end node.



    Args:
        graph (networkx graph): The graph.
        start_node (str): Start node of the parallel path, Parallel Start Node.
        end_node (str): End node of the parallel path, Parallel End Node.
        parallelTypeNode (str): PDDL representation of the parallel node.
        untraversedParallelNode (str): PDDL representation of the untraversed parallel node.
        numParallelPaths (int): Number of parallel paths.

    Returns:
        str: PDDL representation of the parallel nodes.
    """
    if start_node == end_node:
        return parallelTypeNode, untraversedParallelNode

    if type(start_node) == str:
        first_path, *path_list = graph.out_edges(start_node)
    else:
        first_path, *path_list = start_node

    if len(path_list) == 1:
        _, node = path_list.pop()
        _, nodefp = first_path
        parallelTypeNode, untraversedParallelNode = update_between_parallel_nodes(
            graph,
            nodefp,
            end_node,
            parallelTypeNode,
            untraversedParallelNode,
            numParallelPaths + 1,
        )
        return update_between_parallel_nodes(
            graph,
            node,
            end_node,
            parallelTypeNode,
            untraversedParallelNode,
            numParallelPaths + 1,
        )

    elif len(path_list) > 1:
        _, nodefp = first_path
        parallelTypeNode, untraversedParallelNode = update_between_parallel_nodes(
            graph,
            nodefp,
            end_node,
            parallelTypeNode,
            untraversedParallelNode,
            numParallelPaths + 1,
        )

        return update_between_parallel_nodes(
            graph,
            path_list,
            end_node,
            parallelTypeNode,
            untraversedParallelNode,
            numParallelPaths + 1,
        )

    if graph.nodes[start_node][TYPE_ATTR] != PARALLEL_NODE:
        graph.nodes[start_node][IS_IN_PARALLEL] = True

        parallelTypeNode += "(parallel{}Node {})\n\t".format(
            graph.nodes[start_node][TYPE_ATTR].capitalize(), start_node
        )
        untraversedParallelNode += "(untraversedParallelNode {})\n\t".format(start_node)

    _, node = first_path
    return update_between_parallel_nodes(
        graph,
        node,
        end_node,
        parallelTypeNode,
        untraversedParallelNode,
        numParallelPaths + 1,
    )