
Recorder tensor

Recorder

Context manager that patches module forward calls and torch creation ops so that they are recorded in the computation graph

Source code in torchview/recorder_tensor.py, lines 36-66
class Recorder:
    '''Context manager that patches module forward calls and torch creation
    ops so that they are recorded in the computation graph'''
    def __init__(
        self, orig_mod_forward: Callable[..., Any], new_mod_forward: Callable[..., Any],
        model_graph: ComputationGraph
    ) -> None:
        self.orig_module_forward = orig_mod_forward
        self.new_module_forward = new_mod_forward
        self.model_graph = model_graph

    def __enter__(self) -> None:
        setattr(
            torch.nn.Module, "__call__", self.new_module_forward
        )

        for name, op in zip(orig_name_list, _orig_op_list):
            setattr(
                torch, name, creation_ops_wrapper(op, self.model_graph)
            )

    def __exit__(self, exc_type: Any, exc_value: Any, exc_traceback: Any) -> None:
        # reset module __call__ back to original method and torch creation ops
        setattr(
            torch.nn.Module, "__call__", self.orig_module_forward
        )

        for name, op in zip(orig_name_list, _orig_op_list):
            setattr(
                torch, name, op
            )

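The essential mechanism here is temporary monkey-patching: on __enter__, nn.Module.__call__ and the torch creation ops are swapped for recording versions, and on __exit__ the originals are put back. The following is a minimal, self-contained sketch of that pattern (not torchview's actual code): the replacement below only prints module names, whereas torchview's replacement, produced by module_forward_wrapper further down this page, records graph nodes.

import torch
import torch.nn as nn

class LoggingRecorder:
    '''Minimal analogue of Recorder: swaps nn.Module.__call__ on entry
    and restores it on exit.'''
    def __init__(self) -> None:
        self.orig_call = nn.Module.__call__

    def __enter__(self) -> None:
        orig_call = self.orig_call

        def logged_call(mod: nn.Module, *args, **kwargs):
            print(f'calling {type(mod).__name__}')
            return orig_call(mod, *args, **kwargs)

        nn.Module.__call__ = logged_call

    def __exit__(self, exc_type, exc_value, exc_traceback) -> None:
        nn.Module.__call__ = self.orig_call

with LoggingRecorder():
    nn.Linear(4, 2)(torch.randn(1, 4))   # prints 'calling Linear'

In torchview itself, the recording replacement is presumably obtained from module_forward_wrapper(model_graph) and the original from the module-level _orig_module_forward, matching the two callables that Recorder.__init__ expects.
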
RecorderTensor

Bases: Tensor

Subclass of torch.Tensor used for constructing the visual computation graph.

This class stores a list of TensorNode objects to keep a record of nodes during forward propagation. __torch_function__ is also overridden to record the nodes needed for the visual computation graph.

Attributes:

Name          Type              Description
tensor_nodes  list[TensorNode]  List of TensorNode objects that store the relevant TensorNodes

Source code in torchview/recorder_tensor.py, lines 180-281
class RecorderTensor(torch.Tensor):
    '''Subclass of torch.Tensor used for constructing visual computation graph.

    This class stores a list of TensorNode objects to keep a record of nodes during
    forward propagation. __torch_function__ is also overridden to record the nodes
    needed for the visual computation graph.

    Attributes:
        All the inherited attributes from torch.Tensor
        tensor_nodes: list[TensorNode]
            List of TensorNode objects to store relevant TensorNodes'''
    @staticmethod
    def __new__(
        cls: Any,
        x: Any,
        tensor_nodes: Any,
        *args: Any,
        **kwargs: Any
    ) -> Any:
        # pylint: disable=unused-argument
        return super().__new__(cls, x, *args, **kwargs)  # type: ignore[call-arg]

    def __init__(
        self, x: Any, tensor_node: TensorNode | list[TensorNode]
    ):
        # pylint: disable=unused-argument
        # super().__init__() # optional

        if isinstance(tensor_node, TensorNode):
            self.tensor_nodes = [tensor_node]
        else:
            self.tensor_nodes = tensor_node

    @classmethod
    def __torch_function__(
        cls: Any, func: Callable[..., Any] | ScriptMethod,
        types: Any,
        args: Any = (),
        kwargs: Any = None,
    ) -> Any:
        '''Calls torch functions for the RecorderTensor subclass of torch.Tensor.
        Forward prop => Construct FunctionNode => Construct output TensorNode
        Args:
            The same arguments as those of the original __torch_function__,
            except that tensors that originated from the input (through forward
            prop) are RecorderTensors.
        '''
        if kwargs is None:
            kwargs = {}

        args_nodes: NodeContainer[TensorNode] = (
            reduce_data_info([args, kwargs], collect_tensor_node, NodeContainer())
        )

        # This is necessary for torch version < 1.10
        if func in [F.linear, F.embedding]:
            out = nn.parameter.Parameter.__torch_function__(
                func, types, args, kwargs).as_subclass(RecorderTensor)
        else:
            # use original torch_function; otherwise,
            # it leads to infinite recursive call of torch_function
            out = super().__torch_function__(func, types, args, kwargs)

        # if no RecorderTensor is found in the input or output,
        # don't create any node; just return the result
        if not args_nodes:
            return out
        if not reduce_data_info(out, collect_tensor, OrderedSet()):
            return out

        # Create function_node and connect to its parents tensor node
        cur_depth = next(iter(args_nodes)).depth
        input_context = next(iter(args_nodes)).context
        func_name = (
            func.name if isinstance(func, ScriptMethod) else func.__name__
        )
        cur_node = FunctionNode(
            func, cur_depth, args_nodes, name=func_name  # type: ignore[arg-type]
        )

        for i in args_nodes:
            i.add_child(cur_node)

        input_context.append(cur_node)
        attach_kwargs = {
            'parents': cur_node, 'depth': cur_depth, "context": input_context,
            'is_aux': False, 'parent_hierarchy': {cur_depth: cur_node},
            'name': 'output-tensor' if cur_depth == 0 else 'hidden-tensor'
        }
        traverse_data_inplace(out, attach_node(attach_kwargs))

        # note that for an in-place operation, the input shape is calculated
        # correctly only if the operation preserves the input shape,
        # which all torch built-in in-place operations do.
        # this can't be done before the output computation, since shape calls
        # go through __torch_function__ again (infinite recursion)
        cur_node.set_input_shape(
            reduce_data_info([args, kwargs], collect_shape, [])
        )
        cur_node.set_output_shape(reduce_data_info(out, collect_shape, []))

        return out

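The __new__/__init__ pair above is the standard recipe for a torch.Tensor subclass that carries extra state: __new__ hands the tensor data to torch.Tensor, and __init__ stores the bookkeeping attribute. A self-contained, simplified sketch of the same recipe, with a plain list of string labels standing in for the TensorNode objects the real class stores:

import torch

class LabelledTensor(torch.Tensor):
    '''Simplified stand-in for RecorderTensor: a tensor subclass that
    carries extra metadata (string labels instead of TensorNode objects).'''
    @staticmethod
    def __new__(cls, x, labels, *args, **kwargs):
        # the extra argument is ignored here; torch.Tensor only sees the data
        return super().__new__(cls, x, *args, **kwargs)

    def __init__(self, x, labels):
        self.labels = list(labels)

t = LabelledTensor(torch.ones(2, 3), ['input'])
print(isinstance(t, torch.Tensor), t.shape, t.labels)
# True torch.Size([2, 3]) ['input']
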
__torch_function__(func, types, args=(), kwargs=None) classmethod

Calls torch functions for the RecorderTensor subclass of torch.Tensor. Forward prop => construct FunctionNode => construct output TensorNode. Takes the same arguments as the original __torch_function__, except that tensors that originated from the input (through forward prop) are RecorderTensors.

Source code in torchview/recorder_tensor.py, lines 213-281
@classmethod
def __torch_function__(
    cls: Any, func: Callable[..., Any] | ScriptMethod,
    types: Any,
    args: Any = (),
    kwargs: Any = None,
) -> Any:
    '''Calls torch functions for the RecorderTensor subclass of torch.Tensor.
    Forward prop => Construct FunctionNode => Construct output TensorNode
    Args:
        The same arguments as those of the original __torch_function__,
        except that tensors that originated from the input (through forward
        prop) are RecorderTensors.
    '''
    if kwargs is None:
        kwargs = {}

    args_nodes: NodeContainer[TensorNode] = (
        reduce_data_info([args, kwargs], collect_tensor_node, NodeContainer())
    )

    # This is necessary for torch version < 1.10
    if func in [F.linear, F.embedding]:
        out = nn.parameter.Parameter.__torch_function__(
            func, types, args, kwargs).as_subclass(RecorderTensor)
    else:
        # use original torch_function; otherwise,
        # it leads to infinite recursive call of torch_function
        out = super().__torch_function__(func, types, args, kwargs)

    # if no RecorderTensor is found in the input or output,
    # don't create any node; just return the result
    if not args_nodes:
        return out
    if not reduce_data_info(out, collect_tensor, OrderedSet()):
        return out

    # Create function_node and connect to its parents tensor node
    cur_depth = next(iter(args_nodes)).depth
    input_context = next(iter(args_nodes)).context
    func_name = (
        func.name if isinstance(func, ScriptMethod) else func.__name__
    )
    cur_node = FunctionNode(
        func, cur_depth, args_nodes, name=func_name  # type: ignore[arg-type]
    )

    for i in args_nodes:
        i.add_child(cur_node)

    input_context.append(cur_node)
    attach_kwargs = {
        'parents': cur_node, 'depth': cur_depth, "context": input_context,
        'is_aux': False, 'parent_hierarchy': {cur_depth: cur_node},
        'name': 'output-tensor' if cur_depth == 0 else 'hidden-tensor'
    }
    traverse_data_inplace(out, attach_node(attach_kwargs))

    # note that for an in-place operation, the input shape is calculated
    # correctly only if the operation preserves the input shape,
    # which all torch built-in in-place operations do.
    # this can't be done before the output computation, since shape calls
    # go through __torch_function__ again (infinite recursion)
    cur_node.set_input_shape(
        reduce_data_info([args, kwargs], collect_shape, [])
    )
    cur_node.set_output_shape(reduce_data_info(out, collect_shape, []))

    return out

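The heart of this method is the __torch_function__ protocol: every torch operation applied to the subclass is routed through the hook, which delegates to super().__torch_function__ for the real computation and builds a FunctionNode plus output TensorNodes around it. A self-contained sketch of the same hook that merely records operation names instead of constructing graph nodes:

import torch

class TracingTensor(torch.Tensor):
    '''Simplified analogue of RecorderTensor.__torch_function__: intercepts
    every torch op applied to the subclass and records its name.'''
    traced_ops: list = []

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}
        cls.traced_ops.append(getattr(func, '__name__', str(func)))
        # delegate to the default implementation for the actual computation;
        # calling torch ops directly here would recurse back into this hook
        return super().__torch_function__(func, types, args, kwargs)

x = torch.randn(3).as_subclass(TracingTensor)
y = torch.relu(x).sum()
print(TracingTensor.traced_ops)   # op names for relu and sum, in call order
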
attach_node(kwargs, tensor_to_node=None)

Creates the function to attach TensorNodes, needed for nested calls

Source code in torchview/recorder_tensor.py, lines 324-371
def attach_node(
    kwargs: dict[str, Any],
    tensor_to_node: dict[RecorderTensor, TensorNode] | None = None
) -> Callable[..., Any]:
    '''Creates the function to attach TensorNodes, needed for nested calls'''
    def _func(recorded_tensor: RecorderTensor) -> None:
        '''Attaches TensorNode to ModuleNode or FunctionNode
        '''

        if kwargs['is_aux'] and tensor_to_node:
            kwargs['main_node'] = tensor_to_node[recorded_tensor]
        new_kwargs = {
            key_word: value
            for key_word, value in kwargs.items() if key_word != 'tensor_to_node'
        }
        tensor_node = TensorNode(
            tensor=recorded_tensor,
            **new_kwargs
        )
        if isinstance(kwargs["parents"], ModuleNode):
            assert getattr(recorded_tensor, 'tensor_nodes', None) is not None, (
                f'RecorderTensor to be attached to the Node '
                f'{kwargs["parents"]} must have a tensor node'
            )
        assert isinstance(kwargs["parents"], (FunctionNode, ModuleNode)), (
            f'Node {kwargs["parents"]} to which to attach must be either '
            'a FunctionNode or a ModuleNode'
        )

        if getattr(recorded_tensor, 'tensor_nodes', None) is None:
            recorded_tensor.tensor_nodes = [tensor_node]
        else:
            # ModuleNode: attaches an auxiliary node to tensors.
            # Auxiliary nodes should be appended to keep track of the node
            # history of the tensor.
            # FunctionNode: these should overwrite the last tensor node.
            # There are 2 different cases:
            # Non-inplace ops -> new tensor, and this node is its first
            # Inplace ops -> the result tensor overwrites the input tensor
            # in memory, so it should overwrite it in the node history as well.
            # For both cases, overwriting the last tensor node is correct
            if isinstance(kwargs["parents"], ModuleNode):
                recorded_tensor.tensor_nodes.append(tensor_node)
            elif isinstance(kwargs["parents"], FunctionNode):
                recorded_tensor.tensor_nodes[-1] = tensor_node
        kwargs["parents"].add_child(tensor_node)
        kwargs['context'].append(tensor_node)
    return _func

insert_empty_pass_node(recorded_tensor, out_node)

First, inserts empty-pass node as a child of tensor nodes. Then, inserts TensorNode as a child of this empty-pass node

Source code in torchview/recorder_tensor.py, lines 458-485
def insert_empty_pass_node(
    recorded_tensor: RecorderTensor, out_node: TensorNode
) -> None:
    '''First, inserts empty-pass node as a child of tensor nodes. Then, inserts
    TensorNode as a child of this empty-pass node'''
    out_pass = FunctionNode(
        lambda x: x, out_node.depth, out_node,
        name='empty-pass'
    )
    out_node.add_child(out_pass)
    out_node.context.append(out_pass)

    passed_out_node = TensorNode(
        recorded_tensor, out_node.depth, out_pass,
        context=out_node.context, is_aux=False,
        parent_hierarchy={
            recorded_tensor.tensor_nodes[-1].depth: out_pass
        }
    )

    out_node.context.append(passed_out_node)
    out_pass.add_child(passed_out_node)

    # Update the current node of the RecorderTensor.
    # Here we append instead of overwriting the last node, because
    # this is a dummy FunctionNode that has no actual place in the
    # computation graph
    recorded_tensor.tensor_nodes.append(passed_out_node)

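One situation that appears to call for the empty-pass node (per the comment in process_output_node below: the output node is reused inside the module, or an "empty" module is used) is a module that returns its input unchanged. A small, self-contained illustration of such a module (hypothetical example, not from torchview):

import torch
import torch.nn as nn

class PassThrough(nn.Module):
    '''A module whose output is simply its input. It applies no recorded
    torch op of its own, which is the kind of case insert_empty_pass_node
    appears to cover by adding an explicit identity node.'''
    def forward(self, x):
        return x

x = torch.ones(2)
out = PassThrough()(x)
print(out is x)   # True: the module's output is the very tensor it received
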
module_forward_wrapper(model_graph)

Wrapper for forward functions of modules

Source code in torchview/recorder_tensor.py, lines 92-177
def module_forward_wrapper(model_graph: ComputationGraph) -> Callable[..., Any]:
    '''Wrapper for forward functions of modules'''
    def _module_forward_wrapper(mod: nn.Module, *args: Any, **kwargs: Any) -> Any:
        '''Forward prop of module for RecorderTensor subclass
        Construct Module Node => forward-prop => process output nodes to retain
        module hierarchy correctly
        '''
        # Create module node and connect to its parents tensor node
        input_nodes: NodeContainer[TensorNode] = (
            reduce_data_info([args, kwargs], collect_tensor_node, NodeContainer())
        )
        # get unique input tensors, prevent duplications for auxiliary nodes
        input_recorder: OrderedSet[RecorderTensor] = (
            reduce_data_info([args, kwargs], collect_tensor, OrderedSet())
        )
        # if none of the args originated from the input
        # (i.e. they are plain torch.Tensor), fall back to the original forward
        if not input_nodes:
            return _orig_module_forward(mod, *args, **kwargs)

        # Create module_node and connect to its parents tensor node
        cur_depth = next(iter(input_nodes)).depth
        input_context = next(iter(input_nodes)).context
        cur_node = ModuleNode(
            mod, cur_depth, input_nodes,  # type: ignore[arg-type]
            name=type(mod).__name__
        )
        cur_node.set_input_shape(
            reduce_data_info([args, kwargs], collect_shape, [])
        )

        # update context with the current module's context
        input_context.append({cur_node: []})
        for node in input_nodes:
            node.add_child(cur_node)

        tensor_to_node: dict[RecorderTensor, TensorNode] = (
            reduce_data_info([args, kwargs], collect_tensor_node_id_dict, {})
        )
        attach_kwargs = {
            'parents': cur_node, 'depth': cur_depth+1,
            'context': input_context[-1][cur_node], 'is_aux': True,
            'name': 'auxiliary-tensor'
        }

        traverse_data_inplace(
            input_recorder, attach_node(attach_kwargs, tensor_to_node)
        )

        model_graph.context_tracker['current_depth'] = cur_depth+1
        model_graph.context_tracker['current_context'] = input_context[-1][cur_node]

        # TODO: check if output contains RecorderTensor
        # this seems not to be necessary so far
        out = _orig_module_forward(mod, *args, **kwargs)

        model_graph.context_tracker['current_depth'] = cur_depth
        model_graph.context_tracker['current_context'] = input_context

        # pop appropriate nodes, see implementation below
        output_recorder: OrderedSet[RecorderTensor] = (
            reduce_data_info(out, collect_tensor, OrderedSet())
        )

        traverse_data_inplace(
            output_recorder,
            process_output_node(cur_node)
        )

        traverse_data_inplace(
            input_recorder, pop_after_forward, recorded_output=output_recorder,
        )

        # remove auxiliary tensor nodes from recorder_tensor
        output_nodes: NodeContainer[TensorNode] = (
            reduce_data_info(out, collect_tensor_node, NodeContainer())
        )

        for output_node in output_nodes:
            cur_node.add_output_nodes(output_node)
            output_node.context = input_context

        cur_node.set_output_shape(reduce_data_info(out, collect_shape, []))
        return out

    return _module_forward_wrapper

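Stripped of the node bookkeeping, the shape of this function is a closure factory: it captures shared state (the ComputationGraph) and returns a drop-in replacement for nn.Module.__call__ that records something for every module call, including nested submodule calls, before delegating to the original forward. A self-contained sketch with a plain list standing in for the graph:

import torch
import torch.nn as nn

_orig_call = nn.Module.__call__   # analogue of _orig_module_forward

def make_recording_call(record: list):
    '''Closure factory (analogue of module_forward_wrapper): captures
    `record` and returns a replacement for nn.Module.__call__.'''
    def recording_call(mod: nn.Module, *args, **kwargs):
        record.append(type(mod).__name__)          # stand-in for creating a ModuleNode
        return _orig_call(mod, *args, **kwargs)    # delegate to the real forward
    return recording_call

seen: list = []
nn.Module.__call__ = make_recording_call(seen)
try:
    nn.Sequential(nn.Linear(4, 4), nn.ReLU())(torch.randn(1, 4))
finally:
    nn.Module.__call__ = _orig_call   # always restore, as Recorder.__exit__ does

print(seen)   # ['Sequential', 'Linear', 'ReLU'] -- nested calls are recorded too
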
pop_after_forward(r_in, recorded_output)

Removes/pops nodes from RecorderTensors to maintain the correct node history. Two cases exist, depending on the type of operation: non-inplace ops => pop auxiliary nodes; in-place ops => pop input nodes, since in-place ops overwrite the input in memory.

Source code in torchview/recorder_tensor.py, lines 374-403
def pop_after_forward(
    r_in: RecorderTensor,
    recorded_output: OrderedSet[RecorderTensor],
) -> None:
    '''Removes/pops nodes from RecorderTensors to maintain the correct node history.
    Two cases exist, depending on the type of operation:
    Non-inplace ops => pop auxiliary nodes
    In-place ops => pop input nodes, since in-place ops overwrite the input in memory.
    '''

    in_place_func_message = (
        'Tensor before and after inplace operation must have the same memory address'
    )
    output_id: OrderedSet[int] = OrderedSet(id(x) for x in recorded_output)

    if id(r_in) not in output_id:
        _ = reduce_data_info(
            r_in, collect_tensor_node, NodeContainer(), is_pop=True
        )

    # input of inplace operation
    else:
        assert id(r_in) == r_in.tensor_nodes[-1].tensor_id, (
            in_place_func_message
        )
        assert id(r_in) == r_in.tensor_nodes[-2].tensor_id, (
            in_place_func_message
        )
        # pop tensor node before inplace operation
        r_in.tensor_nodes.pop(-2)

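The branch taken above rests on object identity: a torch in-place operation returns the very same tensor object it received, so its id() shows up in the set of output ids, while an out-of-place operation allocates a new tensor object. A quick self-contained check of that assumption:

import torch

x = torch.zeros(3)
out_of_place = torch.relu(x)   # new tensor object
in_place = x.add_(1)           # in-place ops return the input tensor itself

print(out_of_place is x)   # False: its id would not be among the output ids
print(in_place is x)       # True: the input's id appears among the outputs
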
process_output_node(cur_node)

Returns function to update output node after forward pass of nn.Modules

Source code in torchview/recorder_tensor.py, lines 438-455
def process_output_node(
    cur_node: ModuleNode
) -> Callable[..., Any]:
    '''Returns function to update output node after forward
    pass of nn.Modules'''
    def _func(recorded_data: RecorderTensor) -> None:
        output_node = recorded_data.tensor_nodes[-1]
        cur_depth = cur_node.depth
        # if the output node is reused inside the module, or an empty module
        # is used, introduce a node for the empty-pass function
        if not output_node.is_leaf() or output_node.is_aux:
            insert_empty_pass_node(recorded_data, output_node)

        recorded_data.tensor_nodes[-1].depth = cur_depth
        name = 'output-tensor' if cur_depth == 0 else 'hidden-tensor'
        recorded_data.tensor_nodes[-1].name = name
        recorded_data.tensor_nodes[-1].parent_hierarchy[cur_depth] = cur_node
    return _func

reduce_data_info(recorded_data, action_fn, collected, **kwargs)

Applies action_fn to every RecorderTensor inside recorded_data and collects the resulting info (e.g. the shape of each RecorderTensor) into collected (an Iterable).

Source code in torchview/recorder_tensor.py, lines 287-303
def reduce_data_info(
    recorded_data: Any, action_fn: Callable[..., Any], collected: L, **kwargs: Any
) -> L:
    '''Applies action_fn to every RecorderTensor inside recorded_data and collects
    the resulting info (e.g. the shape of each RecorderTensor) into collected'''
    if isinstance(recorded_data, RecorderTensor):
        action_fn(recorded_data, collected, **kwargs)
    elif isinstance(recorded_data, Mapping):
        for r_d in recorded_data.values():
            reduce_data_info(r_d, action_fn, collected, **kwargs)
    elif (
        isinstance(recorded_data, Iterable) and
        not isinstance(recorded_data, (str, torch.Tensor))
    ):
        for r_d in recorded_data:
            reduce_data_info(r_d, action_fn, collected, **kwargs)
    return collected

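The recursion is: apply action_fn directly to a tensor, recurse into the values of mappings, and recurse into any other iterable that is not a string or a tensor, threading collected through every call. A self-contained analogue that works on plain torch.Tensor (the real function acts only on RecorderTensor) and collects shapes, much like the collect_shape calls in the sources above:

from collections.abc import Iterable, Mapping
from typing import Any, Callable

import torch

def reduce_tensor_info(data: Any, action_fn: Callable[..., None], collected: list) -> list:
    '''Generic analogue of reduce_data_info, acting on plain torch.Tensor.'''
    if isinstance(data, torch.Tensor):
        action_fn(data, collected)
    elif isinstance(data, Mapping):
        for value in data.values():
            reduce_tensor_info(value, action_fn, collected)
    elif isinstance(data, Iterable) and not isinstance(data, (str, torch.Tensor)):
        for item in data:
            reduce_tensor_info(item, action_fn, collected)
    return collected

nested = {'x': torch.ones(2, 3), 'extra': [torch.zeros(4), (torch.ones(1),)]}
shapes = reduce_tensor_info(nested, lambda t, acc: acc.append(tuple(t.shape)), [])
print(shapes)   # [(2, 3), (4,), (1,)]

traverse_data_inplace below follows the same recursion, but calls action_fn purely for its side effect on each RecorderTensor and collects nothing.
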
traverse_data_inplace(recorded_data, action_fn, **kwargs)

Applies action_fn to RecorderTensor objects inside recorded_data to change the data. Usually, action_fn is a function that transforms a RecorderTensor in memory.

Source code in torchview/recorder_tensor.py, lines 306-321
def traverse_data_inplace(
    recorded_data: Any, action_fn: Callable[..., Any], **kwargs: Any
) -> None:
    '''Applies action_fn to RecorderTensor objects inside recorded_data to change
    the data. Usually, action_fn is a function that transforms a RecorderTensor in memory'''
    if isinstance(recorded_data, RecorderTensor):
        action_fn(recorded_data, **kwargs)
    elif isinstance(recorded_data, Mapping):
        for r_d in recorded_data.values():
            traverse_data_inplace(r_d, action_fn, **kwargs)
    elif (
        isinstance(recorded_data, Iterable) and
        not isinstance(recorded_data, (str, torch.Tensor))
    ):
        for r_d in recorded_data:
            traverse_data_inplace(r_d, action_fn, **kwargs)