1. python decorator、 @torch.jit.trace 以及@torch.jhit.script

装饰器是一个 Python 语法糖:接收一个函数或类作为输入,返回一个新函数

@decorator 等价于:

func = decorator(func)

举例:

import time

def timer(func):
    def wrapper(*args, **kwargs):
        start = time.time()
        out = func(*args, **kwargs)
        print(f"{func.__name__} cost {time.time() - start:.3f}s")
        return out
    return wrapper

@timer
def add(a, b):
    return a + b

add(1, 2)

带参数的装饰器:

def repeat(n):
    def decorator(func):
        def wrapper(*args, **kwargs):
            for _ in range(n):
                out = func(*args, **kwargs)
            return out
        return wrapper
    return decorator

@repeat(3)
def hello():
    print("hi")

hello()
  1. @torch.jit.trace:基于运行轨迹的图捕获

用一次(或多次)真实输入运行模型,记录张量算子调用轨迹,生成 TorchScript 静态计算图

import torch
import torch.nn as nn

class MyModel(nn.Module):
    def forward(self, x):
        return x * 2 + 1

model = MyModel()
example = torch.randn(4, 8)

traced = torch.jit.trace(model, example)
print(traced.graph)

trace 的本质:执行一次 forward,记录发生了哪些 Tensor ops,得到的是一个 op-level DAG,不保留 Python 控制流;

⚠️ trace 只会记录示例输入走过的那条路径,对于其他 branch 不记录,所以 @torch.jit.trace 只适合 无 if / for 控制流依赖数据。

  1. @torch. jit.script

torch.jit.trace:把一次“实际运行”录成静态计算图
torch.jit.script:把模型“逻辑本身”编译成静态计算图

tracescript 得到的对象都是 torch.jit.ScriptModule(或 ScriptFunction),它是一个“可调用的、无 Python 依赖的、静态计算图对象”,不再执行 python forward,而是指向 pytorchscript IR,区别是script保持python控制流trace不保持,原理前者采样后者编译。

  1. 如果我需要输出特征图,或者获取一个model中命名规律的几个层e.g. l_1 l_2的输入输出,应该怎么做?

本质上是两个操作: