- python decorator、 @torch.jit.trace 以及@torch.jhit.script
装饰器是一个 Python 语法糖:接收一个函数或类作为输入,返回一个新函数
@decorator 等价于:
func = decorator(func)举例:
import time
def timer(func):
def wrapper(*args, **kwargs):
start = time.time()
out = func(*args, **kwargs)
print(f"{func.__name__} cost {time.time() - start:.3f}s")
return out
return wrapper
@timer
def add(a, b):
return a + b
add(1, 2)带参数的装饰器:
def repeat(n):
def decorator(func):
def wrapper(*args, **kwargs):
for _ in range(n):
out = func(*args, **kwargs)
return out
return wrapper
return decorator
@repeat(3)
def hello():
print("hi")
hello()@torch.jit.trace:基于运行轨迹的图捕获
用一次(或多次)真实输入运行模型,记录张量算子调用轨迹,生成 TorchScript 静态计算图
import torch
import torch.nn as nn
class MyModel(nn.Module):
def forward(self, x):
return x * 2 + 1
model = MyModel()
example = torch.randn(4, 8)
traced = torch.jit.trace(model, example)
print(traced.graph)trace 的本质:执行一次 forward,记录发生了哪些 Tensor ops,得到的是一个 op-level DAG,不保留 Python 控制流;
@torch. jit.script
- 将 Python 子集直接编译为 TorchScript,保留控制流、类型和语义
torch.jit.trace:把一次“实际运行”录成静态计算图
torch.jit.script:把模型“逻辑本身”编译成静态计算图
trace 和 script 得到的对象都是 torch.jit.ScriptModule(或 ScriptFunction),它是一个“可调用的、无 Python 依赖的、静态计算图对象”,不再执行 python forward,而是指向 pytorchscript IR,区别是script保持python控制流trace不保持,原理前者采样后者编译。
- 如果我需要输出特征图,或者获取一个model中命名规律的几个层e.g. l_1 l_2的输入输出,应该怎么做?
本质上是两个操作: