[Mlir-commits] [mlir] [mlir][amdgpu] lowerings for ScaledExtPacked816 (PR #168123)

Jakub Kuderski llvmlistbot at llvm.org
Mon Nov 17 08:10:56 PST 2025


================
@@ -1613,6 +1613,154 @@ LogicalResult ExtPackedFp8OpLowering::matchAndRewrite(
   return success();
 }
 
+/// Merges the `amdgpu.scaled_ext_packed816` attributes `blockSize`,
+/// `firstScaleLane` and `firstScaleByte`, together with the bit width of the
+/// source element type, into the single `scaleSel` attribute used when
+/// lowering to the `rocdl.cvt.scale.pk*.f*.f*` operations.
+///
+/// Bit layout of the returned value:
+///   * 4- and 6-bit source element types:
+///       bit 0 = (blockSize == 16), bit 1 = (firstScaleByte == 2),
+///       bit 2 = firstScaleLane.
+///   * 8-bit source element types:
+///       bit 0 = (blockSize == 16), bits 2..1 = firstScaleByte,
+///       bit 3 = firstScaleLane.
+///
+/// Out-of-range inputs and invalid combinations are diagnosed only by
+/// assertions; with NDEBUG they are encoded without any checking.
+int32_t getScaleSel(int32_t blockSize, unsigned bitWidth,
+                    int32_t firstScaleLane, int32_t firstScaleByte) {
+  assert(llvm::is_contained({16, 32}, blockSize));
+  // Spelled as ArrayRef<unsigned> so the elements have the same type as
+  // `bitWidth`.
+  assert(llvm::is_contained(llvm::ArrayRef<unsigned>{4, 6, 8}, bitWidth));
+  assert(llvm::is_contained({0, 1}, firstScaleLane));
+
+  // Bit 0 encodes the block size in both layouts.
+  const int32_t blockBit = blockSize == 16 ? 1 : 0;
+
+  if (bitWidth != 8) {
+    // Only scale bytes 0 and 2 are accepted here, so one bit suffices.
+    assert(llvm::is_contained({0, 2}, firstScaleByte));
+    const int32_t byteBit = (firstScaleByte == 2) << 1;
+    const int32_t laneBit = firstScaleLane << 2;
+    return laneBit | byteBit | blockBit;
+  }
+
+  // Any of the four scale bytes is accepted here; it occupies bits 2..1.
+  assert(llvm::is_contained({0, 1, 2, 3}, firstScaleByte));
+  const int32_t byteBits = firstScaleByte << 1;
+  const int32_t laneBit = firstScaleLane << 3;
+  const int32_t scaleSel = laneBit | byteBits | blockBit;
+  // These combinations are invalid and must never be produced.
+  assert(!llvm::is_contained(
+      {0b0011, 0b0101, 0b0111, 0b1000, 0b1001, 0b1011, 0b1111}, scaleSel));
+  return scaleSel;
+}
+
+static std::optional<StringRef>
+scaledExtPacked816ToIntrinsic(Type srcElemType, Type destElemType) {
+  using fp4 = Float4E2M1FNType;
+  using fp8 = Float8E4M3FNType;
+  using bf8 = Float8E5M2Type;
+  using fp6 = Float6E2M3FNType;
+  using bf6 = Float6E3M2FNType;
+  if (isa<fp4>(srcElemType) && destElemType.isF16())
+    return ROCDL::CvtPkScalePk8F16Fp4Op::getOperationName();
+  if (isa<fp8>(srcElemType) && destElemType.isF16())
+    return ROCDL::CvtPkScalePk8F16Fp8Op::getOperationName();
+  if (isa<bf8>(srcElemType) && destElemType.isF16())
+    return ROCDL::CvtPkScalePk8F16Bf8Op::getOperationName();
+  if (isa<fp4>(srcElemType) && destElemType.isBF16())
+    return ROCDL::CvtPkScalePk8Bf16Fp4Op::getOperationName();
+  if (isa<fp8>(srcElemType) && destElemType.isBF16())
+    return ROCDL::CvtPkScalePk8Bf16Fp8Op::getOperationName();
+  if (isa<bf8>(srcElemType) && destElemType.isBF16())
+    return ROCDL::CvtPkScalePk8Bf16Bf8Op::getOperationName();
+  if (isa<fp4>(srcElemType) && destElemType.isF32())
+    return ROCDL::CvtPkScalePk8F32Fp4Op::getOperationName();
+  if (isa<fp8>(srcElemType) && destElemType.isF32())
+    return ROCDL::CvtPkScalePk8F32Fp8Op::getOperationName();
+  if (isa<bf8>(srcElemType) && destElemType.isF32())
+    return ROCDL::CvtPkScalePk8F32Bf8Op::getOperationName();
+  if (isa<fp6>(srcElemType) && destElemType.isF16())
+    return ROCDL::CvtPkScalePk16F16Fp6Op::getOperationName();
+  if (isa<bf6>(srcElemType) && destElemType.isF16())
+    return ROCDL::CvtPkScalePk16F16Bf6Op::getOperationName();
+  if (isa<fp6>(srcElemType) && destElemType.isBF16())
+    return ROCDL::CvtPkScalePk16Bf16Fp6Op::getOperationName();
+  if (isa<bf6>(srcElemType) && destElemType.isBF16())
+    return ROCDL::CvtPkScalePk16Bf16Bf6Op::getOperationName();
+  if (isa<fp6>(srcElemType) && destElemType.isF32())
+    return ROCDL::CvtPkScalePk16F32Fp6Op::getOperationName();
+  if (isa<bf6>(srcElemType) && destElemType.isF32())
+    return ROCDL::CvtPkScalePk16F32Bf6Op::getOperationName();
----------------
kuhar wrote:

I think these would be easier to follow if we introduced one more level of grouping, e.g., put a top-level `if` condition for each src type.

https://github.com/llvm/llvm-project/pull/168123


More information about the Mlir-commits mailing list