参考:https://zhuanlan.zhihu.com/p/1911790481416909410

“LOP3” 是 NVIDIA CUDA / PTX(Parallel Thread Execution)中一个比较高级的位/布尔操作指令。它可以对 三个输入 做任意一种布尔逻辑组合(即三元逻辑操作),由一个 8-bit 的 LUT(查找表)立即数(immLut)来决定具体的逻辑函数。下面是详细解释和用法。


LOP3 的基本概念

具体地,immLut 的 8 位(bit0 到 bit7)对应于输入组合 (a_bit, b_bit, c_bit) 的 8 种可能(000,001,010,011,100,101,110,111)。如果对于某一种组合输出要为 1,那对应的 immLut 那一位就设为 1;否则设为 0。

StackOverflow 有比较清楚的解释:

“The lop3.b32 PTX instruction can perform a more-or-less arbitrary boolean (logical) operation on 3 variables A, B, and C. … we must provide a “lookup-table” immediate argument (immLut — an 8-bit quantity).” ([Stack Overflow]1)
也就是说,immLut 就是用来指定三输入逻辑函数的查找表。([Stack Overflow]1)


为什么有 LOP3:优点和用途

LOP3 在 PTX/硬件层面是个强大的工具,主要有以下几个优点:

  1. 减少指令数 / 合并逻辑
    如果你用传统的 AND / OR / XOR / NOT 等组合逻辑去做一个复杂的三输入布尔函数,可能要几条指令(比如先 NAND、再 XOR、再 AND、再 NOT 等)。用 LOP3,只需要一条指令 + 一个立即数,就能实现这个组合逻辑。这样可以节省指令数、寄存器临时值以及执行瓶颈。

  2. 按位并行处理
    LOP3 是对每一 bit 并行做逻辑操作(即对 32 位的每一 bit 分别看作 a_i, b_i, c_i,给出输出 bit_i)。所以对于向量 / 字位操作特别高效。

  3. 灵活性高
    只要你能把逻辑表达式写成三输入布尔函数,就可以通过一个 immLut 来表达 —— 相对于只能做固定几种逻辑(AND、OR、XOR)更通用。

  4. 硬件支持
    在较新的 NVIDIA 架构(从 Maxwell / Compute Capability ≥ 5.0 起)就支持 LOP3。 ([Stack Overflow]1)
    在很多高级加速库 / 优化里,用 LOP3 来替代多个布尔组合是常见做法。([知乎专栏]2)


一个例子说明

假设你要实现逻辑函数:

$$ F = (A \lor B) \land (\lnot C) $$

你希望用 lop3.b32 来完成这个操作。做法是:

  1. 你列出这个函数对于三输入的真值表(8 种组合)
  2. 然后把这些输出写成一个 8 位二进制数 immLut
  3. 在 PTX 中写 lop3.b32 d, a, b, c, immLut

StackOverflow 上有示例,假设这个函数对应的 immLut = 0x54(这只是举例),就写成:

lop3.b32 %d, %a, %b, %c, 0x54;

这个指令就对每个位执行相应的布尔逻辑。([Stack Overflow]1)


LOP3 vs PRMT 的区别 / 场景对比

__device__ uint4 dequantize_s4_to_fp16x2(uint32_t const& source) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 750
  assert(false);
#else
  uint4 result;
  // 现在拿到的 source 是 int4 类型,要转成 fp16 类型
  uint32_t* h = reinterpret_cast<uint32_t*>(&result);
  uint32_t const i4s = reinterpret_cast<uint32_t const&>(source);

  // First, we extract the i4s and construct an intermediate fp16 number.
  // 0xcc = 0b1100 1100, 0xaa = 0b1010 1010
  // 这里 immuLut 是 lop3 的操作符
  static constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa;
  // 0~7个int4
  static constexpr uint32_t BOTTOM_MASK = 0x000f000f;
  static constexpr uint32_t TOP_MASK = 0x00f000f0;
  // 0x6400 == 0b0110 0100 0000 0000 ### 1 5=25 8
  // 这里 6400 的作用就是把 int4或者说int8转为 fp16,exp=25(-15)
  // 最后还要减一个 6480,也就是减一个(1024+128)=1152,其中1024是科学计数法的1.0中的1,128是int转uint时多加的。
  static constexpr uint32_t I4s_TO_F16s_MAGIC_NUM = 0x64006400;

  // Note that the entire sequence only requires 1 shift instruction. This is
  // thanks to the register packing format and the fact that we force our
  // integers to be unsigned, and account for this in the fp16 subtractions. In
  // addition, I exploit the fact that sub and fma have the same throughput in
  // order to convert elt_23 and elt_67 to fp16 without having to shift them to
  // the bottom bits before hand.

  // Shift right by 8 to now consider elt_45 and elt_67. Issue first to hide RAW
  // dependency if we issue immediately before required.
  const uint32_t top_i4s = i4s >> 8;
  // Extract elt_01 - (i4s & 0x000f000f) | 0x64006400
  asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
               : "=r"(h[0])
               : "r"(i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM),
                 "n"(immLut));
  // Extract elt_23 (i4s & 0x00f000f0) | 0x64006400
  asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
               : "=r"(h[1])
               : "r"(i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM),
                 "n"(immLut));
  // Extract elt_45 (top_i4s & 0x000f000f) | 0x64006400
  asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
               : "=r"(h[2])
               : "r"(top_i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM),
                 "n"(immLut));
  // Extract elt_67 (top_i4s & 0x00f000f0) | 0x64006400
  asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
               : "=r"(h[3])
               : "r"(top_i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM),
                 "n"(immLut));

  // I use inline PTX below because I am not sure if the compiler will emit
  // float2half instructions if I use the half2 ctor. In this case, I chose
  // performance reliability over code readability.

  // This is the half2 {1032, 1032} represented as an integer.
  // static constexpr uint32_t FP16_TOP_MAGIC_NUM = 0x64086408;
  // Haotian: subtract {1024, 1024} instead, we do not need to map to [-8, 7]
  // 定义用于最终转换的魔数常量
  // 表示fp16格式的{1024, 1024}
  static constexpr uint32_t FP16_TOP_MAGIC_NUM = 0x64006400;
  // This is the half2 {1 / 16, 1 / 16} represented as an integer.
  // 表示fp16格式的{1 / 16, 1 / 16},用于缩放高4位的值
  static constexpr uint32_t ONE_SIXTEENTH = 0x2c002c00;
  // This is the half2 {-72, -72} represented as an integer.
  // static constexpr uint32_t NEG_72 = 0xd480d480;
  // Haotian: Let's use {-64, -64}.
  // 表示fp16格式的{-64, -64},用于偏移校正
  static constexpr uint32_t NEG_64 = 0xd400d400;

  // 交织?
  // e7 e3 e5 e1 e6 e2 e4 e0

  // Finally, we construct the output numbers.
  // Convert elt_01
  asm volatile("sub.f16x2 %0, %1, %2;\n"
               : "=r"(h[0])
               : "r"(h[0]), "r"(FP16_TOP_MAGIC_NUM));
  // Convert elt_23
  asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n"
               : "=r"(h[1])
               : "r"(h[1]), "r"(ONE_SIXTEENTH), "r"(NEG_64));
  // Convert elt_45
  asm volatile("sub.f16x2 %0, %1, %2;\n"
               : "=r"(h[2])
               : "r"(h[2]), "r"(FP16_TOP_MAGIC_NUM));
  // Convert elt_67
  asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n"
               : "=r"(h[3])
               : "r"(h[3]), "r"(ONE_SIXTEENTH), "r"(NEG_64));

  return result;
#endif
  __builtin_unreachable();  // Suppress missing return statement warning
}