参考:https://zhuanlan.zhihu.com/p/1911790481416909410
“LOP3” 是 NVIDIA CUDA / PTX(Parallel Thread Execution)中一个比较高级的位/布尔操作指令。它可以对 三个输入 做任意一种布尔逻辑组合(即三元逻辑操作),由一个 8-bit 的 LUT(查找表)立即数(immLut)来决定具体的逻辑函数。下面是详细解释和用法。
LOP3 的基本概念
- 指令形式(PTX)是
lop3.b32 d, a, b, c, immLut;(也有.u32、.b32等变体) - 它的作用是:把
a, b, c这三个 32-bit 值作为布尔变量,按位进行逻辑操作。immLut是一个 8-bit 的常数,表示一个从 3 输入到 1 输出的真值表(truth table)。 - 换句话说,LOP3 可以实现任何三输入布尔函数 —— 例如
F = (a AND b) OR (NOT c),或者F = (a XOR b) AND c,等等。只要你用合适的immLut表格,就能让lop3做这个函数。
具体地,immLut 的 8 位(bit0 到 bit7)对应于输入组合 (a_bit, b_bit, c_bit) 的 8 种可能(000,001,010,011,100,101,110,111)。如果对于某一种组合输出要为 1,那对应的 immLut 那一位就设为 1;否则设为 0。
StackOverflow 有比较清楚的解释:
“The lop3.b32 PTX instruction can perform a more-or-less arbitrary boolean (logical) operation on 3 variables A, B, and C. … we must provide a “lookup-table” immediate argument (immLut — an 8-bit quantity).” ([Stack Overflow]1)
也就是说,immLut就是用来指定三输入逻辑函数的查找表。([Stack Overflow]1)
为什么有 LOP3:优点和用途
LOP3 在 PTX/硬件层面是个强大的工具,主要有以下几个优点:
-
减少指令数 / 合并逻辑
如果你用传统的 AND / OR / XOR / NOT 等组合逻辑去做一个复杂的三输入布尔函数,可能要几条指令(比如先 NAND、再 XOR、再 AND、再 NOT 等)。用 LOP3,只需要一条指令 + 一个立即数,就能实现这个组合逻辑。这样可以节省指令数、寄存器临时值以及执行瓶颈。 -
按位并行处理
LOP3 是对每一 bit 并行做逻辑操作(即对 32 位的每一 bit 分别看作 a_i, b_i, c_i,给出输出 bit_i)。所以对于向量 / 字位操作特别高效。 -
灵活性高
只要你能把逻辑表达式写成三输入布尔函数,就可以通过一个immLut来表达 —— 相对于只能做固定几种逻辑(AND、OR、XOR)更通用。 -
硬件支持
在较新的 NVIDIA 架构(从 Maxwell / Compute Capability ≥ 5.0 起)就支持 LOP3。 ([Stack Overflow]1)
在很多高级加速库 / 优化里,用 LOP3 来替代多个布尔组合是常见做法。([知乎专栏]2)
一个例子说明
假设你要实现逻辑函数:
$$ F = (A \lor B) \land (\lnot C) $$
你希望用 lop3.b32 来完成这个操作。做法是:
- 你列出这个函数对于三输入的真值表(8 种组合)
- 然后把这些输出写成一个 8 位二进制数
immLut - 在 PTX 中写
lop3.b32 d, a, b, c, immLut
StackOverflow 上有示例,假设这个函数对应的 immLut = 0x54(这只是举例),就写成:
lop3.b32 %d, %a, %b, %c, 0x54;这个指令就对每个位执行相应的布尔逻辑。([Stack Overflow]1)
LOP3 vs PRMT 的区别 / 场景对比
-
PRMT(或prmt.b32)是字节级(byte-level)的重排 / permute 指令,用来做字节重组 / 重排 / 插入 / 混合常量等操作。它的粒度是每个字节的选择。 -
LOP3的粒度是位(bit-level),做的是逻辑操作。用它可以非常灵活地对每一位做布尔组合。比如在 INT4 / INT8 等量化技巧里,对单个位做条件掩码 / 重组 / bit 插入等操作,LOP3 很常见。([知乎专栏]2) -
在某些更高级的量化 / 解码 / 反量化方案里(尤其是 INT4、半比特 / nibble 级别的重组),开发者可能会用 LOP3 替代 PRMT 以得到更精细的控制。正如某篇文章中说的:
在 NV FasterTransformer 的实现中,使用了另一个指令 LOP3 来替代 PRMT,从而完成 INT4 快速反量化到 FP16/BF16 的核心逻辑。([知乎专栏]2)
__device__ uint4 dequantize_s4_to_fp16x2(uint32_t const& source) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 750
assert(false);
#else
uint4 result;
// 现在拿到的 source 是 int4 类型,要转成 fp16 类型
uint32_t* h = reinterpret_cast<uint32_t*>(&result);
uint32_t const i4s = reinterpret_cast<uint32_t const&>(source);
// First, we extract the i4s and construct an intermediate fp16 number.
// 0xcc = 0b1100 1100, 0xaa = 0b1010 1010
// 这里 immuLut 是 lop3 的操作符
static constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa;
// 0~7个int4
static constexpr uint32_t BOTTOM_MASK = 0x000f000f;
static constexpr uint32_t TOP_MASK = 0x00f000f0;
// 0x6400 == 0b0110 0100 0000 0000 ### 1 5=25 8
// 这里 6400 的作用就是把 int4或者说int8转为 fp16,exp=25(-15)
// 最后还要减一个 6480,也就是减一个(1024+128)=1152,其中1024是科学计数法的1.0中的1,128是int转uint时多加的。
static constexpr uint32_t I4s_TO_F16s_MAGIC_NUM = 0x64006400;
// Note that the entire sequence only requires 1 shift instruction. This is
// thanks to the register packing format and the fact that we force our
// integers to be unsigned, and account for this in the fp16 subtractions. In
// addition, I exploit the fact that sub and fma have the same throughput in
// order to convert elt_23 and elt_67 to fp16 without having to shift them to
// the bottom bits before hand.
// Shift right by 8 to now consider elt_45 and elt_67. Issue first to hide RAW
// dependency if we issue immediately before required.
const uint32_t top_i4s = i4s >> 8;
// Extract elt_01 - (i4s & 0x000f000f) | 0x64006400
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
: "=r"(h[0])
: "r"(i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM),
"n"(immLut));
// Extract elt_23 (i4s & 0x00f000f0) | 0x64006400
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
: "=r"(h[1])
: "r"(i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM),
"n"(immLut));
// Extract elt_45 (top_i4s & 0x000f000f) | 0x64006400
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
: "=r"(h[2])
: "r"(top_i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM),
"n"(immLut));
// Extract elt_67 (top_i4s & 0x00f000f0) | 0x64006400
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
: "=r"(h[3])
: "r"(top_i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM),
"n"(immLut));
// I use inline PTX below because I am not sure if the compiler will emit
// float2half instructions if I use the half2 ctor. In this case, I chose
// performance reliability over code readability.
// This is the half2 {1032, 1032} represented as an integer.
// static constexpr uint32_t FP16_TOP_MAGIC_NUM = 0x64086408;
// Haotian: subtract {1024, 1024} instead, we do not need to map to [-8, 7]
// 定义用于最终转换的魔数常量
// 表示fp16格式的{1024, 1024}
static constexpr uint32_t FP16_TOP_MAGIC_NUM = 0x64006400;
// This is the half2 {1 / 16, 1 / 16} represented as an integer.
// 表示fp16格式的{1 / 16, 1 / 16},用于缩放高4位的值
static constexpr uint32_t ONE_SIXTEENTH = 0x2c002c00;
// This is the half2 {-72, -72} represented as an integer.
// static constexpr uint32_t NEG_72 = 0xd480d480;
// Haotian: Let's use {-64, -64}.
// 表示fp16格式的{-64, -64},用于偏移校正
static constexpr uint32_t NEG_64 = 0xd400d400;
// 交织?
// e7 e3 e5 e1 e6 e2 e4 e0
// Finally, we construct the output numbers.
// Convert elt_01
asm volatile("sub.f16x2 %0, %1, %2;\n"
: "=r"(h[0])
: "r"(h[0]), "r"(FP16_TOP_MAGIC_NUM));
// Convert elt_23
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n"
: "=r"(h[1])
: "r"(h[1]), "r"(ONE_SIXTEENTH), "r"(NEG_64));
// Convert elt_45
asm volatile("sub.f16x2 %0, %1, %2;\n"
: "=r"(h[2])
: "r"(h[2]), "r"(FP16_TOP_MAGIC_NUM));
// Convert elt_67
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n"
: "=r"(h[3])
: "r"(h[3]), "r"(ONE_SIXTEENTH), "r"(NEG_64));
return result;
#endif
__builtin_unreachable(); // Suppress missing return statement warning
}