Skip to content

Base node

Node

Base Class Node to keep track of Computation Graph of torch models

Source code in `torchview/computation_node/base_node.py` (lines 24–85):
class Node:
    '''Base class for a vertex in the computation graph of a torch model.

    A node keeps its neighbours in two containers, ``children`` and
    ``parents``, together with a display ``name``, a ``depth`` value and a
    ``node_id``. Edge bookkeeping is one-directional: linking a child does
    not register this node as that child's parent (and vice versa), so
    callers are responsible for updating both endpoints.
    '''

    def __init__(
        self,
        depth: int,
        parents: NodeContainer[Node] | Node | None = None,
        children: NodeContainer[Node] | Node | None = None,
        name: str = 'node',
    ) -> None:
        '''Create a node.

        Args:
            depth: Depth value for this node; stored unchanged.
            parents: Initial parents. ``None`` yields an empty container, a
                single ``Node`` is wrapped in a one-element container, and
                any other value (e.g. an existing container) is stored
                as-is, without copying.
            children: Initial children, normalised exactly like ``parents``.
            name: Label returned by ``str()`` and embedded in ``repr()``.
        '''
        def as_container(value):
            # Collapse the three accepted argument shapes into one container.
            if value is None:
                return NodeContainer()
            if isinstance(value, Node):
                return NodeContainer([value])
            # Caller-supplied containers are aliased, not copied.
            return value

        self.children = as_container(children)  # outgoing neighbours
        self.parents = as_container(parents)  # incoming neighbours

        self.name = name
        self.depth = depth
        # Placeholder identifier; the base set_node_id only raises, so
        # subclasses presumably overwrite this -- confirm in subclasses.
        self.node_id = 'null'

    def __str__(self) -> str:
        '''Return the node's name.'''
        return self.name

    def __repr__(self) -> str:
        '''Return the name plus the object's address, so that distinct nodes
        sharing a name remain distinguishable.'''
        return '{} at {}'.format(self.name, hex(id(self)))

    def add_child(self, node: Node) -> None:
        '''Add ``node`` to this node's children (does not touch node.parents).'''
        self.children.add(node)

    def add_parent(self, node: Node) -> None:
        '''Add ``node`` to this node's parents (does not touch node.children).'''
        self.parents.add(node)

    def remove_child(self, node: Node) -> None:
        '''Remove ``node`` from this node's children.

        Whatever the container's ``remove`` raises for a missing node
        propagates to the caller.
        '''
        self.children.remove(node)

    def remove_parent(self, node: Node) -> None:
        '''Remove ``node`` from this node's parents.

        Whatever the container's ``remove`` raises for a missing node
        propagates to the caller.
        '''
        self.parents.remove(node)

    def set_children(self, node_arr: NodeContainer[Node]) -> None:
        '''Replace the children container with ``node_arr`` (aliased, not copied).'''
        self.children = node_arr

    def set_parents(self, node_arr: NodeContainer[Node]) -> None:
        '''Replace the parents container with ``node_arr`` (aliased, not copied).'''
        self.parents = node_arr

    def is_root(self) -> bool:
        '''Return True when the parents container is empty (falsy).'''
        has_parents = bool(self.parents)
        return not has_parents

    def is_leaf(self) -> bool:
        '''Return True when the children container is empty (falsy).'''
        has_children = bool(self.children)
        return not has_children

    def set_node_id(self) -> None:
        '''Assign ``node_id``; must be overridden by subclasses.

        Raises:
            NotImplementedError: always, in this base class.
        '''
        message = 'To be implemented by subclasses of Node Class !!!'
        raise NotImplementedError(message)