[Mlir-commits] [mlir] [mlir][amdgpu] lowerings for ScaledExtPacked816 (PR #168123)

Erick Ochoa Lopez llvmlistbot at llvm.org
Mon Nov 17 05:36:56 PST 2025


================
@@ -1613,6 +1613,182 @@ LogicalResult ExtPackedFp8OpLowering::matchAndRewrite(
   return success();
 }
 
+// Computes the scaleSel immediate for rocdl.cvt.scale.pk*.f*.f* operations.
+//
+// When lowering amdgpu.scaled_ext_packed816 to rocdl.cvt.scale.pk*.f*.f*
+// operations, the attributes blockSize, sourceType (represented here by the
+// source element bit width), firstScaleLane and firstScaleByte are merged
+// into a single attribute, scaleSel. The encoding is:
+//
+//   4- and 6-bit sources: lane << 2 | (byte == 2) << 1 | block16
+//                         where firstScaleByte must be 0 or 2.
+//   8-bit sources:        lane << 3 | byte << 1 | block16
+//                         where firstScaleByte is in [0, 3], except that
+//                         lane 1 may not use byte 0, and block size 16 is
+//                         only accepted for (lane 0, byte 0) and
+//                         (lane 1, byte 2).
+//
+// block16 is 1 when blockSize == 16 and 0 otherwise. Any combination outside
+// the ones listed above is treated as a programming error (llvm_unreachable).
+int getScaleSel(int blockSize, int bitWidth, int firstScaleLane,
+                int firstScaleByte) {
+  assert(llvm::is_contained({16, 32}, blockSize));
+  assert(llvm::is_contained({4, 6, 8}, bitWidth));
+
+  const bool isFp8 = bitWidth == 8;
+  const int block16Bit = blockSize == 16 ? 1 : 0;
+  const bool isValidLane = firstScaleLane == 0 || firstScaleLane == 1;
+
+  if (isValidLane && !isFp8 && (firstScaleByte == 0 || firstScaleByte == 2)) {
+    // Only bytes 0 and 2 are selectable here, so one bit records which.
+    const int byteBit = firstScaleByte == 2 ? 1 : 0;
+    return (firstScaleLane << 2) | (byteBit << 1) | block16Bit;
+  }
+
+  if (isValidLane && isFp8 && firstScaleByte >= 0 && firstScaleByte <= 3) {
+    const int scaleSel =
+        (firstScaleLane << 3) | (firstScaleByte << 1) | block16Bit;
+    // These encodings fit in 4 bits but are not accepted combinations:
+    // lane 1 with byte 0, and block size 16 with anything other than
+    // (lane 0, byte 0) or (lane 1, byte 2).
+    if (!llvm::is_contained({0b0011, 0b0101, 0b0111, 0b1000, 0b1001, 0b1011,
+                             0b1111},
+                            scaleSel))
+      return scaleSel;
+  }
+
+  llvm_unreachable("invalid combination of firstScaleLane, firstScaleByte, "
+                   "blockSize and type.");
+}
+
+LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite(
+    ScaledExtPacked816Op op, ScaledExtPacked816OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
+
+  int firstScaleLane = op.getFirstScaleLane();
+  int firstScaleByte = op.getFirstScaleByte();
+  int blockSize = op.getBlockSize();
+  auto sourceType = cast<VectorType>(op.getSource().getType());
+  auto srcElemType = cast<FloatType>(sourceType.getElementType());
+  int bitWidth = srcElemType.getWidth();
+  int scaleSel =
+      getScaleSel(blockSize, bitWidth, firstScaleLane, firstScaleByte);
+
+  auto targetType = cast<VectorType>(op.getResult().getType());
+  auto destElemType = cast<FloatType>(targetType.getElementType());
+  Location loc = op.getLoc();
+  IntegerType i32 = rewriter.getI32Type();
+  Value castedScale =
+      LLVM::BitcastOp::create(rewriter, loc, i32, adaptor.getScale());
+
+  Value source = adaptor.getSource();
+  Type packedType;
+  if (isa<Float4E2M1FNType>(srcElemType)) {
+    packedType = i32;
+    packedType = getTypeConverter()->convertType(packedType);
+  } else if (isa<Float8E4M3FNType>(srcElemType) ||
+             isa<Float8E5M2Type>(srcElemType)) {
+    packedType = VectorType::get(2, i32);
+    packedType = getTypeConverter()->convertType(packedType);
+  } else if (isa<Float6E2M3FNType>(srcElemType) ||
+             isa<Float6E3M2FNType>(srcElemType)) {
+    packedType = VectorType::get(3, i32);
+    packedType = getTypeConverter()->convertType(packedType);
+  } else {
+    llvm_unreachable("invalid element type for scaled ext");
+  }
+  // smallT = [Fp4, Fp8, Bf8]
+  //           Bf8 = E5M2
+  //           Fp8 = E4M3
+  //
+  // largeT = [F16, Bf16, F32]
+  // CvtPkScalePk8${largeT}${smallT}
+  Value castedSource =
+      LLVM::BitcastOp::create(rewriter, loc, packedType, source);
+
+  if (isa<Float4E2M1FNType>(srcElemType) && destElemType.isF16()) {
+    rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk8F16Fp4Op>(
+        op, op.getResult().getType(), castedSource, castedScale, scaleSel);
+  } else if (isa<Float8E4M3FNType>(srcElemType) && destElemType.isF16()) {
+    rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk8F16Fp8Op>(
+        op, op.getResult().getType(), castedSource, castedScale, scaleSel);
+  } else if (isa<Float8E5M2Type>(srcElemType) && destElemType.isF16()) {
+    rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk8F16Bf8Op>(
+        op, op.getResult().getType(), castedSource, castedScale, scaleSel);
+  } else if (isa<Float4E2M1FNType>(srcElemType) && destElemType.isBF16()) {
+    rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk8Bf16Fp4Op>(
+        op, op.getResult().getType(), castedSource, castedScale, scaleSel);
+  } else if (isa<Float8E4M3FNType>(srcElemType) && destElemType.isBF16()) {
+    rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk8Bf16Fp8Op>(
+        op, op.getResult().getType(), castedSource, castedScale, scaleSel);
+  } else if (isa<Float8E5M2Type>(srcElemType) && destElemType.isBF16()) {
+    rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk8Bf16Bf8Op>(
+        op, op.getResult().getType(), castedSource, castedScale, scaleSel);
+  } else if (isa<Float4E2M1FNType>(srcElemType) && destElemType.isF32()) {
+    rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk8F32Fp4Op>(
+        op, op.getResult().getType(), castedSource, castedScale, scaleSel);
+  } else if (isa<Float8E4M3FNType>(srcElemType) && destElemType.isF32()) {
+    rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk8F32Fp8Op>(
+        op, op.getResult().getType(), castedSource, castedScale, scaleSel);
+  } else if (isa<Float8E5M2Type>(srcElemType) && destElemType.isF32()) {
+    rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk8F32Bf8Op>(
+        op, op.getResult().getType(), castedSource, castedScale, scaleSel);
+  }
+  // smallT = [Fp6, Bf6]
+  //           Fp6 = Float6E2M3FN
+  //           Bf6 = Float6E3M2FN
+  // largeT = [F16, Bf16, F32]
+  //
+  // CvtPkScalePk16${largeT}${smallT}
+  else if (isa<Float6E2M3FNType>(srcElemType) && destElemType.isF16()) {
+    rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk16F16Fp6Op>(
+        op, op.getResult().getType(), castedSource, castedScale, scaleSel);
+  } else if (isa<Float6E3M2FNType>(srcElemType) && destElemType.isF16()) {
+    rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk16F16Bf6Op>(
+        op, op.getResult().getType(), castedSource, castedScale, scaleSel);
+  } else if (isa<Float6E2M3FNType>(srcElemType) && destElemType.isBF16()) {
+    rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk16Bf16Fp6Op>(
+        op, op.getResult().getType(), castedSource, castedScale, scaleSel);
+  } else if (isa<Float6E3M2FNType>(srcElemType) && destElemType.isBF16()) {
+    rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk16Bf16Bf6Op>(
+        op, op.getResult().getType(), castedSource, castedScale, scaleSel);
+  } else if (isa<Float6E2M3FNType>(srcElemType) && destElemType.isF32()) {
+    rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk16F32Fp6Op>(
+        op, op.getResult().getType(), castedSource, castedScale, scaleSel);
+  } else if (isa<Float6E3M2FNType>(srcElemType) && destElemType.isF32()) {
+    rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk16F32Bf6Op>(
+        op, op.getResult().getType(), castedSource, castedScale, scaleSel);
+  } else {
+    return failure();
+  }
+
+  return success();
----------------
amd-eochoalo wrote:

Yes. We can chat on Slack.

https://github.com/llvm/llvm-project/pull/168123


More information about the Mlir-commits mailing list