Skip to content

Compute node

FunctionNode

Bases: Node

Subclass of Node specialized for nodes that perform computation (e.g. torch functions)

Source code in torchview/computation_node/compute_node.py
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
class FunctionNode(Node):
    '''Node subclass representing a computation step
    (e.g. a call to a torch function).
    '''
    def __init__(
        self,
        function_unit: Callable[..., Any],
        depth: int,
        parents: NodeContainer[Node] | Node | None = None,
        children: NodeContainer[Node] | Node | None = None,
        name: str = 'function-node',
    ) -> None:
        '''Builds a function node.

        Args:
            function_unit: callable this node stands for; only its ``id()``
                is kept.
            depth, parents, children, name: forwarded unchanged to the
                ``Node`` initializer.
        '''
        super().__init__(depth, parents, children, name)
        # Identity of the wrapped callable; reused when building shared ids.
        self.compute_unit_id = id(function_unit)
        # Function nodes are unconditionally flagged as containers.
        self.is_container = True
        # Shape records start empty and are filled in through the setters.
        self.input_shape: list[Tuple[int, ...]] = []
        self.output_shape: list[Tuple[int, ...]] = []
        self.set_node_id()
        # Alias, not a copy: output_nodes is the same object as self.children
        # (self.children is presumably set up by Node.__init__ -- confirm).
        self.output_nodes = self.children

    def set_input_shape(self, input_shape: list[Tuple[int, ...]]) -> None:
        '''Replaces the stored input shapes with ``input_shape``.'''
        self.input_shape = input_shape

    def set_output_shape(self, output_shape: list[Tuple[int, ...]]) -> None:
        '''Replaces the stored output shapes with ``output_shape``.'''
        self.output_shape = output_shape

    def add_output_nodes(self, output_node: Node) -> None:
        '''Adds ``output_node`` to ``self.output_nodes`` via its ``add`` method.'''
        self.output_nodes.add(output_node)

    def set_node_id(self, output_id: int | str | None = None) -> None:
        '''Assigns ``self.node_id``.

        Without ``output_id`` the id is derived from ``id(self)`` and is
        therefore specific to this node object. With ``output_id`` the id is
        ``'<compute_unit_id>-<output_id>'``, so nodes created for the same
        callable and the same output id end up with equal ids. The rolling
        of recursive structures relies on this id mechanism.
        '''
        if output_id is not None:
            self.node_id = f'{self.compute_unit_id}-{output_id}'
            return
        self.node_id = f'{id(self)}'
set_node_id(output_id=None)

Sets the id of FunctionNode. If no output id is given, the id is set to a value unique to the node. If an output id is given, the id is determined by the id of the wrapped function together with the output id. This is crucial for rolling recursive modules, which are identified through this id mechanism.

Source code in torchview/computation_node/compute_node.py
133
134
135
136
137
138
139
140
141
142
def set_node_id(self, output_id: int | str | None = None) -> None:
    '''Sets the id of FunctionNode.
    If no output is given, it sets to value unique to node.
    If output id is given, id is determined by only id of nn.module object
    This is crucial when rolling recursive modules by identifying them with this id
    mechanism'''
    if output_id is None:
        self.node_id = f'{id(self)}'
    else:
        self.node_id = f'{self.compute_unit_id}-{output_id}'

ModuleNode

Bases: Node

Subclass of Node specialized for storing torch Module info

Source code in torchview/computation_node/compute_node.py
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
class ModuleNode(Node):
    '''Node subclass that records information about a torch Module.
    '''
    def __init__(
        self,
        module_unit: nn.Module,
        depth: int,
        parents: NodeContainer[Node] | Node | None = None,
        children: NodeContainer[Node] | Node | None = None,
        name: str = 'module-node',
        output_nodes: NodeContainer[Node] | None = None,
    ) -> None:
        '''Builds a module node.

        Args:
            module_unit: module this node stands for; its ``id()`` is kept
                and its parameters/children are inspected once here.
            depth, parents, children, name: forwarded unchanged to the
                ``Node`` initializer.
            output_nodes: container to use for output nodes; a fresh
                ``NodeContainer`` is created when omitted.
        '''
        super().__init__(depth, parents, children, name)
        self.compute_unit_id = id(module_unit)
        # Parameterless modules are treated as activation-like (e.g. ReLU);
        # is_generator_empty presumably reports "yields nothing" -- confirm.
        self.is_activation = is_generator_empty(module_unit.parameters())
        # NOTE(review): any() checks the truthiness of each child, not merely
        # whether a child exists -- confirm this is the intended test.
        has_truthy_child = any(module_unit.children())
        self.is_container = not has_truthy_child
        # Shape records start empty and are filled in through the setters.
        self.input_shape: list[Tuple[int, ...]] = []
        self.output_shape: list[Tuple[int, ...]] = []
        if output_nodes is None:
            output_nodes = NodeContainer()
        self.output_nodes = output_nodes
        self.set_node_id()

    def set_input_shape(self, input_shape: list[Tuple[int, ...]]) -> None:
        '''Replaces the stored input shapes with ``input_shape``.'''
        self.input_shape = input_shape

    def set_output_shape(self, output_shape: list[Tuple[int, ...]]) -> None:
        '''Replaces the stored output shapes with ``output_shape``.'''
        self.output_shape = output_shape

    def add_output_nodes(self, output_node: Node) -> None:
        '''Adds ``output_node`` to ``self.output_nodes`` via its ``add`` method.'''
        self.output_nodes.add(output_node)

    def set_node_id(self, output_id: int | str | None = None) -> None:
        '''Assigns ``self.node_id``.

        Without ``output_id`` the id is derived from ``id(self)`` and is
        therefore specific to this node object. With ``output_id``:
            1. Parameterless module (``is_activation``): the id is
               ``'<compute_unit_id>-<output_id>'``.
            2. Module with parameters: the id is ``'<compute_unit_id>-'``,
               i.e. it depends on the module object only.
        The rolling of recursive modules relies on this id mechanism.
        '''
        if output_id is None:
            self.node_id = f'{id(self)}'
        elif self.is_activation:
            # zero-parameter module -> activation-style module, e.g. ReLU
            self.node_id = f'{self.compute_unit_id}-{output_id}'
        else:
            self.node_id = f'{self.compute_unit_id}-'

set_node_id(output_id=None)

Sets the id of ModuleNode. If no output id is given, the id is set to a value unique to the node. If an output id is given, there are two cases: (1) parameterless module: the id is determined by the output id and the id of the nn.Module; (2) module with parameters: the id is determined only by the id of the nn.Module object. This is crucial for rolling recursive modules, which are identified through this id mechanism.

Source code in torchview/computation_node/compute_node.py
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
def set_node_id(self, output_id: int | str | None = None) -> None:
    '''Sets the id of ModuleNode.
    If no output is given, it sets to value unique to node.
    If output id is given, there are 2 cases:
        1. Parameterless module: id is determined by output_id and id of nn.Module
        2. Module with parameter: id is determined by only id of nn.module object
    This is crucial when rolling recursive modules by identifying them with this id
    mechanism'''
    if output_id is None:
        self.node_id = f'{id(self)}'
    else:
        if self.is_activation:
            # zero-parameter module -> module for activation function, e.g. ReLU
            self.node_id = f'{self.compute_unit_id}-{output_id}'
        else:
            self.node_id = f'{self.compute_unit_id}-'

TensorNode

Bases: Node

Subclass of Node specialized for nodes that store a tensor (a subclass of torch.Tensor called RecorderTensor)

Source code in torchview/computation_node/compute_node.py
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
class TensorNode(Node):
    '''Node subclass for nodes that hold a tensor
    (a torch.Tensor subclass called RecorderTensor).
    '''
    def __init__(
        self,
        tensor: torch.Tensor,
        depth: int,
        parents: NodeContainer[Node] | Node | None = None,
        children: NodeContainer[Node] | Node | None = None,
        name: str = 'tensor',
        context: Any | None = None,
        is_aux: bool = False,
        main_node: TensorNode | None = None,
        parent_hierarchy: dict[int, ModuleNode | FunctionNode] | None = None,
    ):
        '''Builds a tensor node.

        Args:
            tensor: tensor this node stands for; its ``id()`` and shape
                (as a tuple) are recorded.
            depth, parents, children, name: forwarded unchanged to the
                ``Node`` initializer.
            context: stored as given; a new empty list when omitted.
            is_aux: stored as given; affects how the node id is derived.
            main_node: stored as given; the node itself when omitted.
            parent_hierarchy: stored as given; a new empty dict when omitted.
        '''
        super().__init__(depth, parents, children, name)
        self.tensor_id = id(tensor)
        self.tensor_shape = tuple(tensor.shape)
        self.name = name
        self.is_aux = is_aux
        # A node with no explicit main node is its own main node.
        self.main_node = main_node if main_node is not None else self
        # Fresh containers per instance when the caller supplies none.
        self.context = context if context is not None else []
        self.parent_hierarchy = (
            parent_hierarchy if parent_hierarchy is not None else {}
        )
        # Must run after is_aux / main_node are set: the id depends on them.
        self.set_node_id()

    def set_node_id(self, children_id: int | str | None = None) -> None:
        '''Assigns ``self.node_id``.

        With ``children_id`` the id is ``'<id(self)>-<children_id>'``.
        Without it, an auxiliary node (``is_aux``) with a truthy
        ``main_node`` takes the id of its main node; every other node
        uses its own ``id(self)``.
        '''
        if children_id is not None:
            self.node_id = f'{id(self)}-{children_id}'
            return
        if self.is_aux and self.main_node:
            self.node_id = f'{id(self.main_node)}'
        else:
            self.node_id = f'{id(self)}'