Base class Node, used to keep track of the computation graph of torch models.
Source code in torchview/computation_node/base_node.py
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
class Node:
    '''Base class for a single node of a torch model's computation graph.

    Every node stores its neighbours in two containers, `parents` and
    `children`, together with a display `name`, a `depth` value and a
    `node_id` that starts out as the placeholder string 'null'.
    Subclasses are expected to override `set_node_id`.
    '''

    def __init__(
        self,
        depth: int,
        parents: NodeContainer[Node] | Node | None = None,
        children: NodeContainer[Node] | Node | None = None,
        name: str = 'node',
    ) -> None:
        '''Create a node and normalise its neighbour arguments.

        Args:
            depth: Depth value stored unchanged on the node.
            parents: Existing container of parents, a single parent node,
                or None for a fresh empty container.
            children: Existing container of children, a single child
                node, or None for a fresh empty container.
            name: Label returned by `str()` and used in `repr()`.
        '''
        # A missing argument becomes a brand-new container, so no two
        # nodes ever share a default container object.
        children = NodeContainer() if children is None else children
        parents = NodeContainer() if parents is None else parents

        # A bare node is wrapped so both attributes are always containers.
        if isinstance(children, Node):
            children = NodeContainer([children])
        if isinstance(parents, Node):
            parents = NodeContainer([parents])

        self.children = children
        self.parents = parents
        self.name = name
        self.depth = depth
        # Placeholder value; this class itself never assigns a real id.
        self.node_id = 'null'

    def __str__(self) -> str:
        '''Return the node's name.'''
        return self.name

    def __repr__(self) -> str:
        '''Return the name followed by the object's id in hexadecimal.'''
        address = hex(id(self))
        return f"{self.name} at {address}"

    def add_child(self, node: Node) -> None:
        '''Add `node` to the children container.'''
        self.children.add(node)

    def add_parent(self, node: Node) -> None:
        '''Add `node` to the parents container.'''
        self.parents.add(node)

    def remove_child(self, node: Node) -> None:
        '''Remove `node` from the children container.'''
        self.children.remove(node)

    def remove_parent(self, node: Node) -> None:
        '''Remove `node` from the parents container.'''
        self.parents.remove(node)

    def set_children(self, node_arr: NodeContainer[Node]) -> None:
        '''Replace the children container with `node_arr` (not copied).'''
        self.children = node_arr

    def set_parents(self, node_arr: NodeContainer[Node]) -> None:
        '''Replace the parents container with `node_arr` (not copied).'''
        self.parents = node_arr

    def is_root(self) -> bool:
        '''Return True when the parents container is falsy.'''
        # Relies on the container's truthiness; presumably an empty
        # container is falsy -- confirm against NodeContainer.
        has_parents = bool(self.parents)
        return not has_parents

    def is_leaf(self) -> bool:
        '''Return True when the children container is falsy.'''
        has_children = bool(self.children)
        return not has_children

    def set_node_id(self) -> None:
        '''Hook for subclasses; the base class always raises.

        Raises:
            NotImplementedError: Always, when called on the base class.
        '''
        message = 'To be implemented by subclasses of Node Class !!!'
        raise NotImplementedError(message)
|