[Mlir-commits] [mlir] [mlir][amdgpu] Add lowerings for ScaledExtPacked816 (PR #168123)

Erick Ochoa Lopez llvmlistbot at llvm.org
Mon Nov 17 11:48:28 PST 2025


================
@@ -1613,6 +1613,154 @@ LogicalResult ExtPackedFp8OpLowering::matchAndRewrite(
   return success();
 }
 
+/// Computes the `scaleSel` immediate used when lowering
+/// amdgpu.scaled_ext_packed816 to rocdl.cvt.scale.pk*.f*.f* operations. The
+/// op's blockSize, source element bit width, firstScaleLane and firstScaleByte
+/// attributes are merged into this single value.
+///
+/// Encoding of the returned value:
+///   * 4- and 6-bit sources: bit 0 = (blockSize == 16),
+///     bit 1 = (firstScaleByte == 2), bit 2 = firstScaleLane.
+///   * 8-bit sources: bit 0 = (blockSize == 16),
+///     bits 2:1 = firstScaleByte, bit 3 = firstScaleLane.
+/// Argument ranges and the rejected 8-bit encodings are checked only by
+/// assertions; callers must pass valid values.
+static int32_t getScaleSel(int32_t blockSize, unsigned bitWidth,
+                           int32_t firstScaleLane, int32_t firstScaleByte) {
+  assert(llvm::is_contained({16, 32}, blockSize));
+  assert(llvm::is_contained(::llvm::ArrayRef<unsigned>{4, 6, 8}, bitWidth));
+  assert(llvm::is_contained({0, 1}, firstScaleLane));
+
+  // Bit 0 selects the 16-element block size for every source type.
+  const int32_t blockBit = blockSize == 16 ? 1 : 0;
+
+  if (bitWidth != 8) {
+    // 4/6-bit sources only accept scale bytes 0 and 2, so the byte choice
+    // fits in a single bit and the lane lives in bit 2.
+    assert(llvm::is_contained({0, 2}, firstScaleByte));
+    const int32_t byteBit = (firstScaleByte == 2 ? 1 : 0) << 1;
+    const int32_t laneBit = firstScaleLane << 2;
+    return laneBit | byteBit | blockBit;
+  }
+
+  // 8-bit sources: firstScaleByte is a two-bit field (bits 2:1) and the lane
+  // moves up to bit 3.
+  assert(llvm::is_contained({0, 1, 2, 3}, firstScaleByte));
+  const int32_t byteBits = firstScaleByte << 1;
+  const int32_t laneBit = firstScaleLane << 3;
+  const int32_t scaleSel = laneBit | byteBits | blockBit;
+  // These encodings are invalid.
+  assert(!llvm::is_contained(
+      {0b0011, 0b0101, 0b0111, 0b1000, 0b1001, 0b1011, 0b1111}, scaleSel));
+  return scaleSel;
+}
+
+/// Returns the operation name of the ROCDL packed scaled conversion op that
+/// converts vectors of `srcElemType` into vectors of `destElemType`, or
+/// std::nullopt when no such op exists for the pair. fp4/fp8/bf8 sources map
+/// to the Pk8 ops, fp6/bf6 sources to the Pk16 ops; supported destinations
+/// are f16, bf16 and f32.
+static std::optional<StringRef>
+scaledExtPacked816ToIntrinsic(Type srcElemType, Type destElemType) {
+  const bool toF16 = destElemType.isF16();
+  const bool toBF16 = destElemType.isBF16();
+  const bool toF32 = destElemType.isF32();
+
+  if (isa<Float4E2M1FNType>(srcElemType)) {
+    if (toF16)
+      return ROCDL::CvtPkScalePk8F16Fp4Op::getOperationName();
+    if (toBF16)
+      return ROCDL::CvtPkScalePk8Bf16Fp4Op::getOperationName();
+    if (toF32)
+      return ROCDL::CvtPkScalePk8F32Fp4Op::getOperationName();
+  } else if (isa<Float8E4M3FNType>(srcElemType)) {
+    if (toF16)
+      return ROCDL::CvtPkScalePk8F16Fp8Op::getOperationName();
+    if (toBF16)
+      return ROCDL::CvtPkScalePk8Bf16Fp8Op::getOperationName();
+    if (toF32)
+      return ROCDL::CvtPkScalePk8F32Fp8Op::getOperationName();
+  } else if (isa<Float8E5M2Type>(srcElemType)) {
+    if (toF16)
+      return ROCDL::CvtPkScalePk8F16Bf8Op::getOperationName();
+    if (toBF16)
+      return ROCDL::CvtPkScalePk8Bf16Bf8Op::getOperationName();
+    if (toF32)
+      return ROCDL::CvtPkScalePk8F32Bf8Op::getOperationName();
+  } else if (isa<Float6E2M3FNType>(srcElemType)) {
+    if (toF16)
+      return ROCDL::CvtPkScalePk16F16Fp6Op::getOperationName();
+    if (toBF16)
+      return ROCDL::CvtPkScalePk16Bf16Fp6Op::getOperationName();
+    if (toF32)
+      return ROCDL::CvtPkScalePk16F32Fp6Op::getOperationName();
+  } else if (isa<Float6E3M2FNType>(srcElemType)) {
+    if (toF16)
+      return ROCDL::CvtPkScalePk16F16Bf6Op::getOperationName();
+    if (toBF16)
+      return ROCDL::CvtPkScalePk16Bf16Bf6Op::getOperationName();
+    if (toF32)
+      return ROCDL::CvtPkScalePk16F32Bf6Op::getOperationName();
+  }
+  // Unsupported source/destination element type combination.
+  return std::nullopt;
+}
+
+LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite(
+    ScaledExtPacked816Op op, ScaledExtPacked816OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
+  using fp4 = Float4E2M1FNType;
+  using fp8 = Float8E4M3FNType;
+  using bf8 = Float8E5M2Type;
+  using fp6 = Float6E2M3FNType;
+  using bf6 = Float6E3M2FNType;
+  Location loc = op.getLoc();
+  if (chipset != Chipset{12, 5, 0}) {
+    return rewriter.notifyMatchFailure(
+        loc,
+        "Scaled fp packed conversion instructions are not available on target "
+        "architecture and their emulation is not implemented");
+  }
+  int32_t firstScaleLane = op.getFirstScaleLane();
+  int32_t firstScaleByte = op.getFirstScaleByte();
+  int32_t blockSize = op.getBlockSize();
+  auto sourceType = cast<VectorType>(op.getSource().getType());
+  auto srcElemType = cast<FloatType>(sourceType.getElementType());
+  unsigned bitWidth = srcElemType.getWidth();
+  int32_t scaleSel =
+      getScaleSel(blockSize, bitWidth, firstScaleLane, firstScaleByte);
+
+  auto targetType = cast<VectorType>(op.getResult().getType());
+  auto destElemType = cast<FloatType>(targetType.getElementType());
+  IntegerType i32 = rewriter.getI32Type();
+  Value castedScale =
+      LLVM::BitcastOp::create(rewriter, loc, i32, adaptor.getScale());
+
+  Value source = adaptor.getSource();
+  Type llvmResultType = typeConverter->convertType(op.getResult().getType());
+  Type packedType = nullptr;
+  if (isa<fp4>(srcElemType)) {
+    packedType = i32;
+    packedType = getTypeConverter()->convertType(packedType);
+  } else if (isa<fp8, bf8>(srcElemType)) {
+    packedType = VectorType::get(2, i32);
+    packedType = getTypeConverter()->convertType(packedType);
+  } else if (isa<fp6, bf6>(srcElemType)) {
+    packedType = VectorType::get(3, i32);
+    packedType = getTypeConverter()->convertType(packedType);
+  } else {
+    llvm_unreachable("invalid element type for packed scaled ext");
----------------
amd-eochoalo wrote:

https://github.com/llvm/llvm-project/pull/168123/commits/9860cdd9ffb108f472a18dff751a3401e13f695e

https://github.com/llvm/llvm-project/pull/168123


More information about the Mlir-commits mailing list