[Mlir-commits] [mlir] [mlir][amdgpu] lowerings for ScaledExtPacked816 (PR #168123)

Erick Ochoa Lopez llvmlistbot at llvm.org
Mon Nov 17 11:45:43 PST 2025


https://github.com/amd-eochoalo updated https://github.com/llvm/llvm-project/pull/168123

>From 7b64631ffa0f4e2880d0a839443e450120da7a68 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Mon, 10 Nov 2025 13:23:20 -0500
Subject: [PATCH 01/43] Update documentation

---
 mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td | 35 +++++++++++++------
 1 file changed, 25 insertions(+), 10 deletions(-)

diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
index 45cb67f0eee4a..4820b7a747ac2 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
@@ -127,7 +127,7 @@ def AMDGPU_ScaledExtPacked816Op
           FixedVectorOfShapeAndType<[4], F8E8M0FNU>:$scale,
           ConfinedAttr<I32Attr, [IsValidBlockSize]>:$blockSize,
           ConfinedAttr<I32Attr, [IntMinValue<0>, IntMaxValue<1>]>:$firstScaleLane,
-          ConfinedAttr<I32Attr, [IntMinValue<0>, IntMaxValue<2>]>:$firstScaleByte)>,
+          ConfinedAttr<I32Attr, [IntMinValue<0>, IntMaxValue<3>]>:$firstScaleByte)>,
       Results<(
           outs AnyTypeOf<[FixedVectorOfShapeAndType<[8], F32>,
                           FixedVectorOfShapeAndType<[8], F16>,
@@ -139,17 +139,21 @@ def AMDGPU_ScaledExtPacked816Op
   let summary = "Extend a vector of packed floating point values";
 
   let description = [{
-    The scales applied to the input microfloats are stored in two bytes which
+    The scales applied to the input microfloats are stored in bytes which
     come from the `scales` input provided in a *half* of the wave identified
-    by `firstScaleLane`. The pair of bytes used is selected by
-    `firstScaleByte`. The 16 vectors in consecutive lanes starting from
+    by `firstScaleLane`. The bytes used are selected by `firstScaleByte` and depend
+    on the type of `source`. The 16 vectors in consecutive lanes starting from
     `firstScaleLane` (which we'll call the scale vectors) will be used by both
-    halves of the wave (with lane L reading from L % 16'th scale vector), but
-    each half will use a different byte.
+    halves of the wave (with lane L reading from L % 16'th scale vector).
+
+    When `source` is either F4E2M1FN, F6E2M3FN, or F6E3M2FN, each half of the
+    wave uses a different byte; which bytes are used depends on the block size.
+    When the block size is 32, `firstScaleByte` can be either 0 or 2,
+    selecting halves of the scale vectors.
+    Lanes 0-15 will read from `firstScaleByte` and lanes 16-31 will read
+    from `firstScaleByte` + 1.
+
 
-    When the block size is 32, `firstScaleByte` can be either 0 or 2,
-    selecting halves of the scale vectors. Lanes 0-15 will read from
-    `firstScaleByte` and lanes 16-31 will read from `firstScaleByte` + 1.
     For example:
     ```mlir
     // Input: 8-element vector of F8E4M3FN, converting to F32
@@ -165,7 +169,8 @@ def AMDGPU_ScaledExtPacked816Op
       : vector<16xf6E2M3FN>, vector<4xf8E8M0FNU> -> vector<16xf16>
     ```
 
-    However, when the block size is 16, `firstScaleByte` can be 0 or 1.
+    When `source` is either F4E2M1FN, F6E2M3FN, or F6E3M2FN and
+    the block size is 16, `firstScaleByte` can be 0 or 1.
     Lanes 0-15 read from the `firstScaleByte`th element of the scale vectors,
     while lanes 16-31 read from `firstScaleByte` + 2.
     For example:
@@ -187,6 +192,16 @@ def AMDGPU_ScaledExtPacked816Op
     instructions use for matix scales. These selection operands allows
     one to choose portions of the matrix to convert.
 
+    When `source` is either F8E4M3FN or F8E5M2 and `blockSize` is 32,
+    then the same byte will be used by both halves of the wave.
+    In this case, `firstScaleByte` can be any value from 0 to 3.
+
+    When `source` is either F8E4M3FN or F8E5M2 and `blockSize` is 16,
+    the following combinations are allowed:
+    * `firstScaleLane(0), firstScaleByte(0)`
+    * `firstScaleLane(1), firstScaleByte(2)`
+    All other combinations are reserved.
+
     Available on gfx1250+.
   }];
 

>From 08e96b19369451dd5ec4e72ed2905bd0b2e0cf71 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Mon, 10 Nov 2025 13:44:08 -0500
Subject: [PATCH 02/43] Fix verifiers

---
 mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp | 23 +++++++++++++++-----
 mlir/test/Dialect/AMDGPU/invalid.mlir        | 20 ++++++++++++-----
 2 files changed, 32 insertions(+), 11 deletions(-)

diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
index df955fc90b45f..5c35823678576 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -344,14 +344,32 @@ void RawBufferAtomicCmpswapOp::getCanonicalizationPatterns(
 LogicalResult ScaledExtPacked816Op::verify() {
   int blockSize = getBlockSize();
   assert((blockSize == 16 || blockSize == 32) && "invalid block size");
+
   int firstScaleByte = getFirstScaleByte();
-  if (blockSize == 16 && !llvm::is_contained({0, 1}, firstScaleByte)) {
-    return emitOpError(
-        "blockSize of 16 can only have firstScaleByte be 0 or 1.");
+  int firstScaleLane = getFirstScaleLane();
+  auto sourceType = cast<VectorType>(getSource().getType());
+  Type elementType = sourceType.getElementType();
+  auto floatType = cast<FloatType>(elementType);
+  int bitWidth = floatType.getWidth();
+
+  if (llvm::is_contained({4, 6}, bitWidth) && blockSize == 16 &&
+      !llvm::is_contained({0, 1}, firstScaleByte)) {
+    return emitOpError("blockSize of 16 can only have firstScaleByte be 0 or 1 "
+                       "for f4 and f6.");
+  }
+  if (llvm::is_contained({4, 6}, bitWidth) && blockSize == 32 &&
+      !llvm::is_contained({0, 2}, firstScaleByte)) {
+    return emitOpError("blockSize of 32 can only have firstScaleByte be 0 or 2 "
+                       "for f4 and f6.");
   }
-  if (blockSize == 32 && !llvm::is_contained({0, 2}, firstScaleByte)) {
+  // For f8 with a blockSize of 16, only the (firstScaleLane, firstScaleByte)
+  // pairs (0, 0) and (1, 2) are valid; all other combinations are reserved.
+  if (bitWidth == 8 && blockSize == 16 &&
+      !((firstScaleLane == 0 && firstScaleByte == 0) ||
+        (firstScaleLane == 1 && firstScaleByte == 2))) {
     return emitOpError(
-        "blockSize of 32 can only have firstScaleByte be 0 or 2.");
+        "blockSize of 16 can only have (firstScaleLane, firstScaleByte) be "
+        "(0, 0) or (1, 2) for f8.");
   }
 
   return success();
diff --git a/mlir/test/Dialect/AMDGPU/invalid.mlir b/mlir/test/Dialect/AMDGPU/invalid.mlir
index 4c6f62a045405..5c8cc8b67c4b3 100644
--- a/mlir/test/Dialect/AMDGPU/invalid.mlir
+++ b/mlir/test/Dialect/AMDGPU/invalid.mlir
@@ -333,17 +333,33 @@ func.func @gather_to_lds_non_lds(%idx1 : index, %mem1 : memref<32xf16>, %mem2 :
 
 // -----
 
-func.func @amdgpu.scaled_ext_packed816_invalid_block_size_and_first_scale_byte_16(%v: vector<8xf8E5M2>, %scale: vector<4xf8E8M0FNU>) {
-  // expected-error at +1 {{'amdgpu.scaled_ext_packed816' op blockSize of 16 can only have firstScaleByte be 0 or 1.}}
-  %ret0 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(16) firstScaleLane(0) firstScaleByte(2) : vector<8xf8E5M2>, vector<4xf8E8M0FNU> -> vector<8xf16>
+func.func @amdgpu.scaled_ext_packed816_invalid_block_size_and_first_scale_byte_16(%v: vector<8xf4E2M1FN>, %scale: vector<4xf8E8M0FNU>) {
+  // expected-error at +1 {{'amdgpu.scaled_ext_packed816' op blockSize of 16 can only have firstScaleByte be 0 or 1 for f4 and f6}}
+  %ret0 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(16) firstScaleLane(0) firstScaleByte(2) : vector<8xf4E2M1FN>, vector<4xf8E8M0FNU> -> vector<8xf16>
   func.return
 }
 
 // -----
 
-func.func @amdgpu.scaled_ext_packed816_invalid_block_size_and_first_scale_byte_32(%v: vector<8xf8E5M2>, %scale: vector<4xf8E8M0FNU>) {
-  // expected-error at +1 {{'amdgpu.scaled_ext_packed816' op blockSize of 32 can only have firstScaleByte be 0 or 2.}}
-  %ret0 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(1) : vector<8xf8E5M2>, vector<4xf8E8M0FNU> -> vector<8xf16>
+func.func @amdgpu.scaled_ext_packed816_invalid_block_size_and_first_scale_byte_32(%v: vector<8xf4E2M1FN>, %scale: vector<4xf8E8M0FNU>) {
+  // expected-error at +1 {{'amdgpu.scaled_ext_packed816' op blockSize of 32 can only have firstScaleByte be 0 or 2 for f4 and f6.}}
+  %ret0 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(1) : vector<8xf4E2M1FN>, vector<4xf8E8M0FNU> -> vector<8xf16>
+  func.return
+}
+
+// -----
+
+func.func @amdgpu.scaled_ext_packed816_invalid_lane_and_byte_for_f8(%v: vector<8xf8E5M2>, %scale: vector<4xf8E8M0FNU>) {
+  // expected-error at +1 {{'amdgpu.scaled_ext_packed816' op blockSize of 16 can only have (firstScaleLane, firstScaleByte) be (0, 0) or (1, 2) for f8.}}
+  %ret0 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(16) firstScaleLane(1) firstScaleByte(0) : vector<8xf8E5M2>, vector<4xf8E8M0FNU> -> vector<8xf16>
+  func.return
+}
+
+// -----
+
+func.func @amdgpu.scaled_ext_packed816_invalid_attributes_for_f8(%v: vector<8xf8E5M2>, %scale: vector<4xf8E8M0FNU>) {
+  // expected-error at +1 {{'amdgpu.scaled_ext_packed816' op blockSize of 16 can only have (firstScaleLane, firstScaleByte) be (0, 0) or (1, 2) for f8.}}
+  %ret0 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(16) firstScaleLane(0) firstScaleByte(1) : vector<8xf8E5M2>, vector<4xf8E8M0FNU> -> vector<8xf16>
   func.return
 }
 

>From d0932cc2935840b7e86dc700f0ec056ccabebdea Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Mon, 10 Nov 2025 11:13:57 -0500
Subject: [PATCH 03/43] [mlir][amdgpu] Convert scaled_ext_packed816 to rocdl

---
 .../AMDGPUToROCDL/AMDGPUToROCDL.cpp           | 28 ++++++++++++++++---
 1 file changed, 24 insertions(+), 4 deletions(-)

diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 3a307a0756d93..48c1b17a2203a 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -1492,6 +1492,19 @@ struct ExtPackedFp8OpLowering final
                   ConversionPatternRewriter &rewriter) const override;
 };
 
+struct ScaledExtPacked816OpLowering final
+    : public ConvertOpToLLVMPattern<ScaledExtPacked816Op> {
+  ScaledExtPacked816OpLowering(const LLVMTypeConverter &converter,
+                               Chipset chipset)
+      : ConvertOpToLLVMPattern<amdgpu::ScaledExtPacked816Op>(converter),
+        chipset(chipset) {}
+  Chipset chipset;
+
+  LogicalResult
+  matchAndRewrite(ScaledExtPacked816Op op, ScaledExtPacked816OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override;
+};
+
 struct PackedTrunc2xFp8OpLowering final
     : public ConvertOpToLLVMPattern<PackedTrunc2xFp8Op> {
   PackedTrunc2xFp8OpLowering(const LLVMTypeConverter &converter,
@@ -1600,6 +1613,12 @@ LogicalResult ExtPackedFp8OpLowering::matchAndRewrite(
   return success();
 }
 
+LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite(
+    ScaledExtPacked816Op op, ScaledExtPacked816OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
+  return failure();
+}
+
 LogicalResult ScaledExtPackedOpLowering::matchAndRewrite(
     ScaledExtPackedOp op, ScaledExtPackedOpAdaptor adaptor,
     ConversionPatternRewriter &rewriter) const {
@@ -2138,9 +2157,10 @@ void mlir::populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter,
                                ROCDL::RawPtrBufferAtomicCmpSwap>,
            AMDGPUDPPLowering, MemoryCounterWaitOpLowering, LDSBarrierOpLowering,
            SchedBarrierOpLowering, MFMAOpLowering, ScaledMFMAOpLowering,
-           WMMAOpLowering, ExtPackedFp8OpLowering, ScaledExtPackedOpLowering,
-           PackedScaledTruncOpLowering, PackedTrunc2xFp8OpLowering,
-           PackedStochRoundFp8OpLowering, GatherToLDSOpLowering,
-           TransposeLoadOpLowering, AMDGPUPermlaneLowering>(converter, chipset);
+           WMMAOpLowering, ExtPackedFp8OpLowering, ScaledExtPacked816OpLowering,
+           ScaledExtPackedOpLowering, PackedScaledTruncOpLowering,
+           PackedTrunc2xFp8OpLowering, PackedStochRoundFp8OpLowering,
+           GatherToLDSOpLowering, TransposeLoadOpLowering,
+           AMDGPUPermlaneLowering>(converter, chipset);
   patterns.add<AMDGPUSwizzleBitModeLowering>(converter);
 }

>From 0d1d668762b534c90f2f838788b495c09b872d81 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Mon, 10 Nov 2025 15:43:41 -0500
Subject: [PATCH 04/43] Create skeleton for pattern

---
 .../AMDGPUToROCDL/AMDGPUToROCDL.cpp           | 120 ++++++++++++++++++
 1 file changed, 120 insertions(+)

diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 48c1b17a2203a..568013bee5ec8 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -1613,9 +1613,129 @@ LogicalResult ExtPackedFp8OpLowering::matchAndRewrite(
   return success();
 }
 
+int getScaleSel(int blockSize, int bitWidth, int firstScaleLane,
+                int firstScaleByte) {
+  // When lowering amdgpu.scaled_ext_packed816 to
+  // rocdl.cvt.scale.pk*.f*.f* operations, the
+  // attributes blockSize, sourceType, firstScaleLane and firstScaleByte
+  // are merged into a single attribute scaleSel.
+  //
+  // This is how those values are merged together.
+  assert(llvm::is_contained({16, 32}, blockSize));
+  assert(llvm::is_contained({4, 6, 8}, bitWidth));
+
+  const bool is_fp8 = bitWidth == 8;
+  const bool is_block_16 = blockSize == 16;
+
+  if (!is_fp8 && firstScaleLane == 0 && firstScaleByte == 0 && !is_block_16) {
+    return 0b000;
+  }
+  if (!is_fp8 && firstScaleLane == 0 && firstScaleByte == 0 && is_block_16) {
+    return 0b001;
+  }
+  if (!is_fp8 && firstScaleLane == 0 && firstScaleByte == 2 && !is_block_16) {
+    return 0b010;
+  }
+  if (!is_fp8 && firstScaleLane == 0 && firstScaleByte == 2 && is_block_16) {
+    return 0b011;
+  }
+  if (!is_fp8 && firstScaleLane == 1 && firstScaleByte == 0 && !is_block_16) {
+    return 0b100;
+  }
+  if (!is_fp8 && firstScaleLane == 1 && firstScaleByte == 0 && is_block_16) {
+    return 0b101;
+  }
+  if (!is_fp8 && firstScaleLane == 1 && firstScaleByte == 2 && !is_block_16) {
+    return 0b110;
+  }
+  if (!is_fp8 && firstScaleLane == 1 && firstScaleByte == 2 && is_block_16) {
+    return 0b111;
+  }
+
+  if (is_fp8 && firstScaleLane == 0 && firstScaleByte == 0 && !is_block_16) {
+    return 0b0000;
+  }
+  if (is_fp8 && firstScaleLane == 0 && firstScaleByte == 0 && is_block_16) {
+    return 0b0001;
+  }
+  if (is_fp8 && firstScaleLane == 0 && firstScaleByte == 1 && !is_block_16) {
+    return 0b0010;
+  }
+  if (is_fp8 && firstScaleLane == 0 && firstScaleByte == 2 && !is_block_16) {
+    return 0b0100;
+  }
+  if (is_fp8 && firstScaleLane == 0 && firstScaleByte == 3 && !is_block_16) {
+    return 0b0110;
+  }
+  if (is_fp8 && firstScaleLane == 1 && firstScaleByte == 1 && !is_block_16) {
+    return 0b1010;
+  }
+  if (is_fp8 && firstScaleLane == 1 && firstScaleByte == 2 && !is_block_16) {
+    return 0b1100;
+  }
+  if (is_fp8 && firstScaleLane == 1 && firstScaleByte == 2 && is_block_16) {
+    return 0b1101;
+  }
+  if (is_fp8 && firstScaleLane == 1 && firstScaleByte == 3 && !is_block_16) {
+    return 0b1110;
+  }
+
+  llvm_unreachable("invalid combination of firstScaleLane, firstScaleByte, "
+                   "blockSize and type.");
+  return 0;
+}
+
 LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite(
     ScaledExtPacked816Op op, ScaledExtPacked816OpAdaptor adaptor,
     ConversionPatternRewriter &rewriter) const {
+
+  int firstScaleLane = op.getFirstScaleLane();
+  int firstScaleByte = op.getFirstScaleByte();
+  int blockSize = op.getBlockSize();
+  auto sourceType = cast<VectorType>(op.getSource().getType());
+  auto srcElemType = cast<FloatType>(sourceType.getElementType());
+  int bitWidth = srcElemType.getWidth();
+  int scaleSel =
+      getScaleSel(blockSize, bitWidth, firstScaleLane, firstScaleByte);
+
+  auto targetType = cast<VectorType>(op.getResult().getType());
+  auto tgtElemType = cast<FloatType>(targetType.getElementType());
+  Location loc = op.getLoc();
+  // Ok, so we need to construct ops depending on the sourceType and targetType.
+  // smallT = [Fp4, Fp8, Bf8]
+  // largeT = [F16, Bf16, F32]
+  // CvtPkScalePk{8}${largeT}${smallT}
+  if (isa<Float4E2M1FNType>(srcElemType) and isa<Float16Type>(tgtElemType)) {
+    ROCDL::CvtPkScalePk8F16Fp4Op::create(
+        rewriter, loc, op.getResult().getType(), adaptor.getSource(),
+        adaptor.getScale(), scaleSel);
+    return success();
+  }
+  /*
+  CvtPkScalePk8F16Fp8Op
+  CvtPkScalePk8F16Bf8Op
+
+  CvtPkScalePk8Bf16Fp4Op
+  CvtPkScalePk8Bf16Fp8Op
+  CvtPkScalePk8Bf16Bf8Op
+
+  CvtPkScalePk8F32Fp4Op
+  CvtPkScalePk8F32Fp8Op
+  CvtPkScalePk8F32Bf8Op
+
+  // smallT = [Fp6, Bf6]
+  // largeT = [F16, Bf16, F32]
+  // CvtPkScalePk{16}${largeT}${smallT}
+  CvtPkScale16F16Fp6Op
+  CvtPkScale16F16Bf6Op
+
+  CvtPkScale16Bf16Fp6Op
+  CvtPkScale16Bf16Bf6Op
+
+  CvtPkScale16F32Fp6Op
+  CvtPkScale16F32Bf6Op
+  */
+
   return failure();
 }
 

>From 163b15aef1795ed5d30e627a7f781406b00a37d2 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Mon, 10 Nov 2025 16:15:34 -0500
Subject: [PATCH 05/43] Initial conversion

---
 .../Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp   | 16 +++++++++++++---
 1 file changed, 13 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 568013bee5ec8..d93332a3f3c40 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -1701,14 +1701,24 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite(
   auto targetType = cast<VectorType>(op.getResult().getType());
   auto tgtElemType = cast<FloatType>(targetType.getElementType());
   Location loc = op.getLoc();
+  // %scale: vector<4xf8E8M0FNU>
+  // ===========================
+  // %scale: i32
+  IntegerType i32 = rewriter.getI32Type();
+  Value castedScale =
+      LLVM::BitcastOp::create(rewriter, loc, i32, adaptor.getScale());
+
   // Ok, so we need to construct ops depending on the sourceType and targetType.
   // smallT = [Fp4, Fp8, Bf8]
   // largeT = [F16, Bf16, F32]
   // CvtPkScalePk{8}${largeT}${smallT}
   if (isa<Float4E2M1FNType>(srcElemType) and isa<Float16Type>(tgtElemType)) {
-    ROCDL::CvtPkScalePk8F16Fp4Op::create(
-        rewriter, loc, op.getResult().getType(), adaptor.getSource(),
-        adaptor.getScale(), scaleSel);
+    Value castedSource =
+        LLVM::BitcastOp::create(rewriter, loc, i32, adaptor.getSource());
+    auto newOp = ROCDL::CvtPkScalePk8F16Fp4Op::create(
+        rewriter, loc, op.getResult().getType(), castedSource, castedScale,
+        scaleSel);
+    rewriter.replaceOp(op, newOp);
     return success();
   }
   /*

>From 6d7e2a65171ae253430067e3db23b94ca6a29c27 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Mon, 10 Nov 2025 16:26:39 -0500
Subject: [PATCH 06/43] Add first test

---
 .../Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir   | 13 +++++++++++++
 1 file changed, 13 insertions(+)

diff --git a/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir b/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir
index 2fd3df6dcfa71..840187d1f36d7 100644
--- a/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir
+++ b/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir
@@ -456,3 +456,16 @@ func.func @sched_barrier() {
   amdgpu.sched_barrier allow = <valu|all_vmem>
   func.return
 }
+
+// CHECK-LABEL: @scaled_ext_packed816_fp4
+// CHECK: (%[[SOURCE:.+]]: vector<8xf4E2M1FN>, %[[SCALE:.+]]: vector<4xf8E8M0FNU>)
+func.func @scaled_ext_packed816_fp4(%v: vector<8xf4E2M1FN>, %scale: vector<4xf8E8M0FNU>) -> (vector<8xf16>) {
+  // CHECK: %[[SCALE_4xi8:.+]] = builtin.unrealized_conversion_cast %[[SCALE]] : vector<4xf8E8M0FNU> to vector<4xi8>
+  // CHECK: %[[SOURCE_8xi4:.+]] = builtin.unrealized_conversion_cast %[[SOURCE]] : vector<8xf4E2M1FN> to vector<8xi4>
+  // CHECK: %[[SCALE_i32:.+]] = llvm.bitcast %[[SCALE_4xi8]] : vector<4xi8> to i32
+  // CHECK: %[[SOURCE_i32:.+]] = llvm.bitcast %[[SOURCE_8xi4]] : vector<8xi4> to i32
+  // CHECK: %[[RES:.+]] = rocdl.cvt.scale.pk8.f16.fp4 %[[SOURCE_i32]], %[[SCALE_i32]][0] : vector<8xf16>
+  %ret0 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<8xf4E2M1FN>, vector<4xf8E8M0FNU> -> vector<8xf16>
+  func.return %ret0: vector<8xf16>
+}
+

>From fc5d8587ac8f3ea9142d3810d1dfea7ae42b4ff1 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Mon, 10 Nov 2025 17:18:58 -0500
Subject: [PATCH 07/43] Adds two more cases

---
 .../AMDGPUToROCDL/AMDGPUToROCDL.cpp           | 28 +++++++++++++++++++
 .../AMDGPUToROCDL/amdgpu-to-rocdl.mlir        | 25 ++++++++++++++++-
 2 files changed, 52 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index d93332a3f3c40..605bcf38204ba 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -1721,6 +1721,34 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite(
     rewriter.replaceOp(op, newOp);
     return success();
   }
+  if (isa<Float8E4M3FNType>(srcElemType) and isa<Float16Type>(tgtElemType)) {
+    // vector<8xf8E4M3FN>
+    Value source = adaptor.getSource();
+
+    // vector<2xi32>
+    VectorType v2xi32 = VectorType::get(2, i32);
+    Value castedSource = LLVM::BitcastOp::create(rewriter, loc, v2xi32, source);
+
+    auto newOp = ROCDL::CvtPkScalePk8F16Fp8Op::create(
+        rewriter, loc, op.getResult().getType(), castedSource, castedScale,
+        scaleSel);
+    rewriter.replaceOp(op, newOp);
+    return success();
+  }
+  if (isa<Float8E5M2Type>(srcElemType) and isa<Float16Type>(tgtElemType)) {
+    // vector<8xf8E5M2>
+    Value source = adaptor.getSource();
+
+    // vector<2xi32>
+    VectorType v2xi32 = VectorType::get(2, i32);
+    Value castedSource = LLVM::BitcastOp::create(rewriter, loc, v2xi32, source);
+
+    auto newOp = ROCDL::CvtPkScalePk8F16Bf8Op::create(
+        rewriter, loc, op.getResult().getType(), castedSource, castedScale,
+        scaleSel);
+    rewriter.replaceOp(op, newOp);
+    return success();
+  }
   /*
   CvtPkScalePk8F16Fp8Op
   CvtPkScalePk8F16Bf8Op
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir b/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir
index 840187d1f36d7..8edd6c038af1e 100644
--- a/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir
+++ b/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir
@@ -458,7 +458,7 @@ func.func @sched_barrier() {
 }
 
 // CHECK-LABEL: @scaled_ext_packed816_fp4
-// CHECK: (%[[SOURCE:.+]]: vector<8xf4E2M1FN>, %[[SCALE:.+]]: vector<4xf8E8M0FNU>)
+// CHECK-SAME: (%[[SOURCE:.+]]: vector<8xf4E2M1FN>, %[[SCALE:.+]]: vector<4xf8E8M0FNU>)
 func.func @scaled_ext_packed816_fp4(%v: vector<8xf4E2M1FN>, %scale: vector<4xf8E8M0FNU>) -> (vector<8xf16>) {
   // CHECK: %[[SCALE_4xi8:.+]] = builtin.unrealized_conversion_cast %[[SCALE]] : vector<4xf8E8M0FNU> to vector<4xi8>
   // CHECK: %[[SOURCE_8xi4:.+]] = builtin.unrealized_conversion_cast %[[SOURCE]] : vector<8xf4E2M1FN> to vector<8xi4>
@@ -469,3 +469,26 @@ func.func @scaled_ext_packed816_fp4(%v: vector<8xf4E2M1FN>, %scale: vector<4xf8E
   func.return %ret0: vector<8xf16>
 }
 
+// CHECK-LABEL: @scaled_ext_packed816_fp8
+// CHECK-SAME: (%[[SOURCE:.+]]: vector<8xf8E4M3FN>, %[[SCALE:.+]]: vector<4xf8E8M0FNU>)
+func.func @scaled_ext_packed816_fp8(%v: vector<8xf8E4M3FN>, %scale: vector<4xf8E8M0FNU>) -> (vector<8xf16>) {
+  // CHECK: %[[SCALE_4xi8:.+]] = builtin.unrealized_conversion_cast %[[SCALE]] : vector<4xf8E8M0FNU> to vector<4xi8>
+  // CHECK: %[[SOURCE_8xi8:.+]] = builtin.unrealized_conversion_cast %[[SOURCE]] : vector<8xf8E4M3FN> to vector<8xi8>
+  // CHECK: %[[SCALE_i32:.+]] = llvm.bitcast %[[SCALE_4xi8]] : vector<4xi8> to i32
+  // CHECK: %[[SOURCE_v2xi32:.+]] = llvm.bitcast %[[SOURCE_8xi8]] : vector<8xi8> to vector<2xi32>
+  // CHECK: %[[RES:.+]] = rocdl.cvt.scale.pk8.f16.fp8 %[[SOURCE_v2xi32]], %[[SCALE_i32]][0] : vector<8xf16>
+  %ret0 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<8xf8E4M3FN>, vector<4xf8E8M0FNU> -> vector<8xf16>
+  func.return %ret0 : vector<8xf16>
+}
+
+// CHECK-LABEL: @scaled_ext_packed816_bf8
+// CHECK-SAME: (%[[SOURCE:.+]]: vector<8xf8E5M2>, %[[SCALE:.+]]: vector<4xf8E8M0FNU>)
+func.func @scaled_ext_packed816_bf8(%v: vector<8xf8E5M2>, %scale: vector<4xf8E8M0FNU>) -> (vector<8xf16>) {
+  // CHECK: %[[SCALE_4xi8:.+]] = builtin.unrealized_conversion_cast %[[SCALE]] : vector<4xf8E8M0FNU> to vector<4xi8>
+  // CHECK: %[[SOURCE_8xi8:.+]] = builtin.unrealized_conversion_cast %[[SOURCE]] : vector<8xf8E5M2> to vector<8xi8>
+  // CHECK: %[[SCALE_i32:.+]] = llvm.bitcast %[[SCALE_4xi8]] : vector<4xi8> to i32
+  // CHECK: %[[SOURCE_v2xi32:.+]] = llvm.bitcast %[[SOURCE_8xi8]] : vector<8xi8> to vector<2xi32>
+  // CHECK: %[[RES:.+]] = rocdl.cvt.scale.pk8.f16.bf8 %[[SOURCE_v2xi32]], %[[SCALE_i32]][0] : vector<8xf16>
+  %ret0 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<8xf8E5M2>, vector<4xf8E8M0FNU> -> vector<8xf16>
+  func.return %ret0 : vector<8xf16>
+}

>From d9a254f629fe430c961ef407e2d529a1a10f691a Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Tue, 11 Nov 2025 09:14:28 -0500
Subject: [PATCH 08/43] Add case for pk8.bf16.fp4

---
 .../Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp | 18 +++++++++++++++---
 .../AMDGPUToROCDL/amdgpu-to-rocdl.mlir         | 11 ++++++++---
 2 files changed, 23 insertions(+), 6 deletions(-)

diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 605bcf38204ba..5fd58c2b2906b 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -1713,6 +1713,8 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite(
   // largeT = [F16, Bf16, F32]
   // CvtPkScalePk{8}${largeT}${smallT}
   if (isa<Float4E2M1FNType>(srcElemType) and isa<Float16Type>(tgtElemType)) {
+    // CvtPkScalePk8F16Fp4Op
+    // i32
     Value castedSource =
         LLVM::BitcastOp::create(rewriter, loc, i32, adaptor.getSource());
     auto newOp = ROCDL::CvtPkScalePk8F16Fp4Op::create(
@@ -1722,6 +1724,7 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite(
     return success();
   }
   if (isa<Float8E4M3FNType>(srcElemType) and isa<Float16Type>(tgtElemType)) {
+    // CvtPkScalePk8F16Fp8Op
     // vector<8xf8E4M3FN>
     Value source = adaptor.getSource();
 
@@ -1736,6 +1739,7 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite(
     return success();
   }
   if (isa<Float8E5M2Type>(srcElemType) and isa<Float16Type>(tgtElemType)) {
+    // CvtPkScalePk8F16Bf8Op
     // vector<8xf8E5M2>
     Value source = adaptor.getSource();
 
@@ -1749,11 +1753,19 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite(
     rewriter.replaceOp(op, newOp);
     return success();
   }
+  if (isa<Float4E2M1FNType>(srcElemType) and isa<BFloat16Type>(tgtElemType)) {
+    // CvtPkScalePk8Bf16Fp4Op
+    // i32
+    Value castedSource =
+        LLVM::BitcastOp::create(rewriter, loc, i32, adaptor.getSource());
+    auto newOp = ROCDL::CvtPkScalePk8Bf16Fp4Op::create(
+        rewriter, loc, op.getResult().getType(), castedSource, castedScale,
+        scaleSel);
+    rewriter.replaceOp(op, newOp);
+    return success();
+  }
   /*
-  CvtPkScalePk8F16Fp8Op
-  CvtPkScalePk8F16Bf8Op
 
-  CvtPkScalePk8Bf16Fp4Op
   CvtPkScalePk8Bf16Fp8Op
   CvtPkScalePk8Bf16Bf8Op
 
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir b/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir
index 8edd6c038af1e..d2099d2f60eff 100644
--- a/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir
+++ b/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir
@@ -459,14 +459,19 @@ func.func @sched_barrier() {
 
 // CHECK-LABEL: @scaled_ext_packed816_fp4
 // CHECK-SAME: (%[[SOURCE:.+]]: vector<8xf4E2M1FN>, %[[SCALE:.+]]: vector<4xf8E8M0FNU>)
-func.func @scaled_ext_packed816_fp4(%v: vector<8xf4E2M1FN>, %scale: vector<4xf8E8M0FNU>) -> (vector<8xf16>) {
+func.func @scaled_ext_packed816_fp4(%v: vector<8xf4E2M1FN>, %scale: vector<4xf8E8M0FNU>) -> (vector<8xf16>, vector<8xbf16>) {
   // CHECK: %[[SCALE_4xi8:.+]] = builtin.unrealized_conversion_cast %[[SCALE]] : vector<4xf8E8M0FNU> to vector<4xi8>
   // CHECK: %[[SOURCE_8xi4:.+]] = builtin.unrealized_conversion_cast %[[SOURCE]] : vector<8xf4E2M1FN> to vector<8xi4>
   // CHECK: %[[SCALE_i32:.+]] = llvm.bitcast %[[SCALE_4xi8]] : vector<4xi8> to i32
   // CHECK: %[[SOURCE_i32:.+]] = llvm.bitcast %[[SOURCE_8xi4]] : vector<8xi4> to i32
-  // CHECK: %[[RES:.+]] = rocdl.cvt.scale.pk8.f16.fp4 %[[SOURCE_i32]], %[[SCALE_i32]][0] : vector<8xf16>
+  // CHECK: rocdl.cvt.scale.pk8.f16.fp4 %[[SOURCE_i32]], %[[SCALE_i32]][0] : vector<8xf16>
   %ret0 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<8xf4E2M1FN>, vector<4xf8E8M0FNU> -> vector<8xf16>
-  func.return %ret0: vector<8xf16>
+
+  // CHECK: %[[SCALE_i32:.+]] = llvm.bitcast %[[SCALE_4xi8]] : vector<4xi8> to i32
+  // CHECK: %[[SOURCE_i32:.+]] = llvm.bitcast %[[SOURCE_8xi4]] : vector<8xi4> to i32
+  // CHECK: rocdl.cvt.scale.pk8.bf16.fp4 %[[SOURCE_i32]], %[[SCALE_i32]][0] : vector<8xbf16>
+  %ret1 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<8xf4E2M1FN>, vector<4xf8E8M0FNU> -> vector<8xbf16>
+  func.return %ret0, %ret1: vector<8xf16>, vector<8xbf16>
 }
 
 // CHECK-LABEL: @scaled_ext_packed816_fp8

>From c5eb6989affbd32c2100909943e790af851a5a43 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Tue, 11 Nov 2025 09:28:50 -0500
Subject: [PATCH 09/43] Add conversion for pk8.bf16.bf8

---
 .../AMDGPUToROCDL/AMDGPUToROCDL.cpp           | 16 +++++++++-
 .../AMDGPUToROCDL/amdgpu-to-rocdl.mlir        | 31 ++++++++++++-------
 2 files changed, 34 insertions(+), 13 deletions(-)

diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 5fd58c2b2906b..ede420606b7b4 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -1764,9 +1764,23 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite(
     rewriter.replaceOp(op, newOp);
     return success();
   }
+  if (isa<Float8E5M2Type>(srcElemType) and isa<BFloat16Type>(tgtElemType)) {
+    // CvtPkScalePk8Bf16Bf8Op
+    // vector<8xf8E5M2>
+    Value source = adaptor.getSource();
+
+    // vector<2xi32>
+    VectorType v2xi32 = VectorType::get(2, i32);
+    Value castedSource = LLVM::BitcastOp::create(rewriter, loc, v2xi32, source);
+
+    auto newOp = ROCDL::CvtPkScalePk8Bf16Bf8Op::create(
+        rewriter, loc, op.getResult().getType(), castedSource, castedScale,
+        scaleSel);
+    rewriter.replaceOp(op, newOp);
+    return success();
+  }
  /*
 
   CvtPkScalePk8Bf16Fp8Op
-  CvtPkScalePk8Bf16Bf8Op
 
   CvtPkScalePk8F32Fp4Op
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir b/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir
index d2099d2f60eff..e6b2ef7bc3f79 100644
--- a/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir
+++ b/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir
@@ -474,6 +474,23 @@ func.func @scaled_ext_packed816_fp4(%v: vector<8xf4E2M1FN>, %scale: vector<4xf8E
   func.return %ret0, %ret1: vector<8xf16>, vector<8xbf16>
 }
 
+// CHECK-LABEL: @scaled_ext_packed816_bf8
+// CHECK-SAME: (%[[SOURCE:.+]]: vector<8xf8E5M2>, %[[SCALE:.+]]: vector<4xf8E8M0FNU>)
+func.func @scaled_ext_packed816_bf8(%v: vector<8xf8E5M2>, %scale: vector<4xf8E8M0FNU>) -> (vector<8xf16>, vector<8xbf16>) {
+  // CHECK: %[[SCALE_4xi8:.+]] = builtin.unrealized_conversion_cast %[[SCALE]] : vector<4xf8E8M0FNU> to vector<4xi8>
+  // CHECK: %[[SOURCE_8xi8:.+]] = builtin.unrealized_conversion_cast %[[SOURCE]] : vector<8xf8E5M2> to vector<8xi8>
+  // CHECK: %[[SCALE_i32:.+]] = llvm.bitcast %[[SCALE_4xi8]] : vector<4xi8> to i32
+  // CHECK: %[[SOURCE_v2xi32:.+]] = llvm.bitcast %[[SOURCE_8xi8]] : vector<8xi8> to vector<2xi32>
+  // CHECK: %[[RES:.+]] = rocdl.cvt.scale.pk8.f16.bf8 %[[SOURCE_v2xi32]], %[[SCALE_i32]][0] : vector<8xf16>
+  %ret0 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<8xf8E5M2>, vector<4xf8E8M0FNU> -> vector<8xf16>
+
+  // CHECK: %[[SCALE_i32:.+]] = llvm.bitcast %[[SCALE_4xi8]] : vector<4xi8> to i32
+  // CHECK: %[[SOURCE_v2xi32:.+]] = llvm.bitcast %[[SOURCE_8xi8]] : vector<8xi8> to vector<2xi32>
+  // CHECK: rocdl.cvt.scale.pk8.bf16.bf8 %[[SOURCE_v2xi32]], %[[SCALE_i32]][0] : vector<8xbf16>
+  %ret1 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<8xf8E5M2>, vector<4xf8E8M0FNU> -> vector<8xbf16>
+  func.return %ret0, %ret1 : vector<8xf16>, vector<8xbf16>
+}
+
 // CHECK-LABEL: @scaled_ext_packed816_fp8
 // CHECK-SAME: (%[[SOURCE:.+]]: vector<8xf8E4M3FN>, %[[SCALE:.+]]: vector<4xf8E8M0FNU>)
 func.func @scaled_ext_packed816_fp8(%v: vector<8xf8E4M3FN>, %scale: vector<4xf8E8M0FNU>) -> (vector<8xf16>) {
@@ -481,19 +498,9 @@ func.func @scaled_ext_packed816_fp8(%v: vector<8xf8E4M3FN>, %scale: vector<4xf8E
   // CHECK: %[[SOURCE_8xi8:.+]] = builtin.unrealized_conversion_cast %[[SOURCE]] : vector<8xf8E4M3FN> to vector<8xi8>
   // CHECK: %[[SCALE_i32:.+]] = llvm.bitcast %[[SCALE_4xi8]] : vector<4xi8> to i32
   // CHECK: %[[SOURCE_v2xi32:.+]] = llvm.bitcast %[[SOURCE_8xi8]] : vector<8xi8> to vector<2xi32>
-  // CHECK: %[[RES:.+]] = rocdl.cvt.scale.pk8.f16.fp8 %[[SOURCE_v2xi32]], %[[SCALE_i32]][0] : vector<8xf16>
+  // CHECK: rocdl.cvt.scale.pk8.f16.fp8 %[[SOURCE_v2xi32]], %[[SCALE_i32]][0] : vector<8xf16>
   %ret0 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<8xf8E4M3FN>, vector<4xf8E8M0FNU> -> vector<8xf16>
-  func.return %ret0 : vector<8xf16>
-}
 
-// CHECK-LABEL: @scaled_ext_packed816_bf8
-// CHECK-SAME: (%[[SOURCE:.+]]: vector<8xf8E5M2>, %[[SCALE:.+]]: vector<4xf8E8M0FNU>)
-func.func @scaled_ext_packed816_bf8(%v: vector<8xf8E5M2>, %scale: vector<4xf8E8M0FNU>) -> (vector<8xf16>) {
-  // CHECK: %[[SCALE_4xi8:.+]] = builtin.unrealized_conversion_cast %[[SCALE]] : vector<4xf8E8M0FNU> to vector<4xi8>
-  // CHECK: %[[SOURCE_8xi8:.+]] = builtin.unrealized_conversion_cast %[[SOURCE]] : vector<8xf8E5M2> to vector<8xi8>
-  // CHECK: %[[SCALE_i32:.+]] = llvm.bitcast %[[SCALE_4xi8]] : vector<4xi8> to i32
-  // CHECK: %[[SOURCE_v2xi32:.+]] = llvm.bitcast %[[SOURCE_8xi8]] : vector<8xi8> to vector<2xi32>
-  // CHECK: %[[RES:.+]] = rocdl.cvt.scale.pk8.f16.bf8 %[[SOURCE_v2xi32]], %[[SCALE_i32]][0] : vector<8xf16>
-  %ret0 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<8xf8E5M2>, vector<4xf8E8M0FNU> -> vector<8xf16>
   func.return %ret0 : vector<8xf16>
 }
+

>From cec5f045c639adeb2fda944751768b69fd9a5b9d Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Tue, 11 Nov 2025 09:41:25 -0500
Subject: [PATCH 10/43] Fix pk8.bf16.fp8 source type (E4M3FN) and add case for pk8.bf16.bf8

---
 .../AMDGPUToROCDL/AMDGPUToROCDL.cpp           | 24 ++++++++++++--
 .../AMDGPUToROCDL/amdgpu-to-rocdl.mlir        | 32 +++++++++++--------
 2 files changed, 40 insertions(+), 16 deletions(-)

diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index ede420606b7b4..8a55f41e8495c 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -1710,8 +1710,12 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite(
 
   // Ok, so we need to construct ops depending on the sourceType and targetType.
   // smallT = [Fp4, Fp8, Bf8]
+  //           Bf8 = E5M2
+  //           Fp8 = E4M3
+  //
   // largeT = [F16, Bf16, F32]
   // CvtPkScalePk{8}${largeT}${smallT}
+
   if (isa<Float4E2M1FNType>(srcElemType) and isa<Float16Type>(tgtElemType)) {
     // CvtPkScalePk8F16Fp4Op
     // i32
@@ -1764,9 +1768,9 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite(
     rewriter.replaceOp(op, newOp);
     return success();
   }
-  if (isa<Float8E5M2Type>(srcElemType) and isa<BFloat16Type>(tgtElemType)) {
+  if (isa<Float8E4M3FNType>(srcElemType) and isa<BFloat16Type>(tgtElemType)) {
     // CvtPkScalePk8Bf16Fp8Op
-    // vector<8xf8E5M2>
+    // vector<8xf8E4M3FN>
     Value source = adaptor.getSource();
 
     // vector<2xi32>
@@ -1779,9 +1783,23 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite(
     rewriter.replaceOp(op, newOp);
     return success();
   }
+  if (isa<Float8E5M2Type>(srcElemType) and isa<BFloat16Type>(tgtElemType)) {
+    // CvtPkScalePk8Bf16Bf8Op
+    // vector<8xf8E5M2>
+    Value source = adaptor.getSource();
+
+    // vector<2xi32>
+    VectorType v2xi32 = VectorType::get(2, i32);
+    Value castedSource = LLVM::BitcastOp::create(rewriter, loc, v2xi32, source);
+
+    auto newOp = ROCDL::CvtPkScalePk8Bf16Bf8Op::create(
+        rewriter, loc, op.getResult().getType(), castedSource, castedScale,
+        scaleSel);
+    rewriter.replaceOp(op, newOp);
+    return success();
+  }
   /*
 
-  CvtPkScalePk8Bf16Bf8Op
 
   CvtPkScalePk8F32Fp4Op
   CvtPkScalePk8F32Fp8Op
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir b/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir
index e6b2ef7bc3f79..e248856ed472e 100644
--- a/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir
+++ b/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir
@@ -474,6 +474,24 @@ func.func @scaled_ext_packed816_fp4(%v: vector<8xf4E2M1FN>, %scale: vector<4xf8E
   func.return %ret0, %ret1: vector<8xf16>, vector<8xbf16>
 }
 
+// CHECK-LABEL: @scaled_ext_packed816_fp8
+// CHECK-SAME: (%[[SOURCE:.+]]: vector<8xf8E4M3FN>, %[[SCALE:.+]]: vector<4xf8E8M0FNU>)
+func.func @scaled_ext_packed816_fp8(%v: vector<8xf8E4M3FN>, %scale: vector<4xf8E8M0FNU>) -> (vector<8xf16>, vector<8xbf16>) {
+  // CHECK: %[[SCALE_4xi8:.+]] = builtin.unrealized_conversion_cast %[[SCALE]] : vector<4xf8E8M0FNU> to vector<4xi8>
+  // CHECK: %[[SOURCE_8xi8:.+]] = builtin.unrealized_conversion_cast %[[SOURCE]] : vector<8xf8E4M3FN> to vector<8xi8>
+  // CHECK: %[[SCALE_i32:.+]] = llvm.bitcast %[[SCALE_4xi8]] : vector<4xi8> to i32
+  // CHECK: %[[SOURCE_v2xi32:.+]] = llvm.bitcast %[[SOURCE_8xi8]] : vector<8xi8> to vector<2xi32>
+  // CHECK: rocdl.cvt.scale.pk8.f16.fp8 %[[SOURCE_v2xi32]], %[[SCALE_i32]][0] : vector<8xf16>
+  %ret0 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<8xf8E4M3FN>, vector<4xf8E8M0FNU> -> vector<8xf16>
+
+  // CHECK: %[[SCALE_i32:.+]] = llvm.bitcast %[[SCALE_4xi8]] : vector<4xi8> to i32
+  // CHECK: %[[SOURCE_v2xi32:.+]] = llvm.bitcast %[[SOURCE_8xi8]] : vector<8xi8> to vector<2xi32>
+  // CHECK: rocdl.cvt.scale.pk8.bf16.fp8 %[[SOURCE_v2xi32]], %[[SCALE_i32]][0] : vector<8xbf16>
+  %ret1 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<8xf8E4M3FN>, vector<4xf8E8M0FNU> -> vector<8xbf16>
+
+  func.return %ret0, %ret1 : vector<8xf16>, vector<8xbf16>
+}
+
 // CHECK-LABEL: @scaled_ext_packed816_bf8
 // CHECK-SAME: (%[[SOURCE:.+]]: vector<8xf8E5M2>, %[[SCALE:.+]]: vector<4xf8E8M0FNU>)
 func.func @scaled_ext_packed816_bf8(%v: vector<8xf8E5M2>, %scale: vector<4xf8E8M0FNU>) -> (vector<8xf16>, vector<8xbf16>) {
@@ -486,21 +504,9 @@ func.func @scaled_ext_packed816_bf8(%v: vector<8xf8E5M2>, %scale: vector<4xf8E8M
 
   // CHECK: %[[SCALE_i32:.+]] = llvm.bitcast %[[SCALE_4xi8]] : vector<4xi8> to i32
   // CHECK: %[[SOURCE_v2xi32:.+]] = llvm.bitcast %[[SOURCE_8xi8]] : vector<8xi8> to vector<2xi32>
-  // CHECK: rocdl.cvt.scale.pk8.bf16.fp8 %[[SOURCE_v2xi32]], %[[SCALE_i32]][0] : vector<8xbf16>
+  // CHECK: rocdl.cvt.scale.pk8.bf16.bf8 %[[SOURCE_v2xi32]], %[[SCALE_i32]][0] : vector<8xbf16>
   %ret1 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<8xf8E5M2>, vector<4xf8E8M0FNU> -> vector<8xbf16>
   func.return %ret0, %ret1 : vector<8xf16>, vector<8xbf16>
 }
 
-// CHECK-LABEL: @scaled_ext_packed816_fp8
-// CHECK-SAME: (%[[SOURCE:.+]]: vector<8xf8E4M3FN>, %[[SCALE:.+]]: vector<4xf8E8M0FNU>)
-func.func @scaled_ext_packed816_fp8(%v: vector<8xf8E4M3FN>, %scale: vector<4xf8E8M0FNU>) -> (vector<8xf16>) {
-  // CHECK: %[[SCALE_4xi8:.+]] = builtin.unrealized_conversion_cast %[[SCALE]] : vector<4xf8E8M0FNU> to vector<4xi8>
-  // CHECK: %[[SOURCE_8xi8:.+]] = builtin.unrealized_conversion_cast %[[SOURCE]] : vector<8xf8E4M3FN> to vector<8xi8>
-  // CHECK: %[[SCALE_i32:.+]] = llvm.bitcast %[[SCALE_4xi8]] : vector<4xi8> to i32
-  // CHECK: %[[SOURCE_v2xi32:.+]] = llvm.bitcast %[[SOURCE_8xi8]] : vector<8xi8> to vector<2xi32>
-  // CHECK: rocdl.cvt.scale.pk8.f16.fp8 %[[SOURCE_v2xi32]], %[[SCALE_i32]][0] : vector<8xf16>
-  %ret0 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<8xf8E4M3FN>, vector<4xf8E8M0FNU> -> vector<8xf16>
-
-  func.return %ret0 : vector<8xf16>
-}
 

>From 7dc34425abb1ff3271b039cbd95d008531c6a940 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Tue, 11 Nov 2025 09:45:32 -0500
Subject: [PATCH 11/43] Add case for pk8.f32.fp4

---
 mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp | 13 +++++++++++--
 .../Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir   |  9 +++++++--
 2 files changed, 18 insertions(+), 4 deletions(-)

diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 8a55f41e8495c..5c5aad91cd8e0 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -1798,9 +1798,18 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite(
     rewriter.replaceOp(op, newOp);
     return success();
   }
+  if (isa<Float4E2M1FNType>(srcElemType) and isa<Float32Type>(tgtElemType)) {
+    // CvtPkScalePk8F32Fp4Op
+    // i32
+    Value castedSource =
+        LLVM::BitcastOp::create(rewriter, loc, i32, adaptor.getSource());
+    auto newOp = ROCDL::CvtPkScalePk8F32Fp4Op::create(
+        rewriter, loc, op.getResult().getType(), castedSource, castedScale,
+        scaleSel);
+    rewriter.replaceOp(op, newOp);
+    return success();
+  }
   /*
-
-
   CvtPkScalePk8F32Fp4Op
   CvtPkScalePk8F32Fp8Op
   CvtPkScalePk8F32Bf8Op
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir b/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir
index e248856ed472e..13f109b787d4d 100644
--- a/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir
+++ b/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir
@@ -459,7 +459,7 @@ func.func @sched_barrier() {
 
 // CHECK-LABEL: @scaled_ext_packed816_fp4
 // CHECK-SAME: (%[[SOURCE:.+]]: vector<8xf4E2M1FN>, %[[SCALE:.+]]: vector<4xf8E8M0FNU>)
-func.func @scaled_ext_packed816_fp4(%v: vector<8xf4E2M1FN>, %scale: vector<4xf8E8M0FNU>) -> (vector<8xf16>, vector<8xbf16>) {
+func.func @scaled_ext_packed816_fp4(%v: vector<8xf4E2M1FN>, %scale: vector<4xf8E8M0FNU>) -> (vector<8xf16>, vector<8xbf16>, vector<8xf32>) {
   // CHECK: %[[SCALE_4xi8:.+]] = builtin.unrealized_conversion_cast %[[SCALE]] : vector<4xf8E8M0FNU> to vector<4xi8>
   // CHECK: %[[SOURCE_8xi4:.+]] = builtin.unrealized_conversion_cast %[[SOURCE]] : vector<8xf4E2M1FN> to vector<8xi4>
   // CHECK: %[[SCALE_i32:.+]] = llvm.bitcast %[[SCALE_4xi8]] : vector<4xi8> to i32
@@ -471,7 +471,12 @@ func.func @scaled_ext_packed816_fp4(%v: vector<8xf4E2M1FN>, %scale: vector<4xf8E
   // CHECK: %[[SOURCE_i32:.+]] = llvm.bitcast %[[SOURCE_8xi4]] : vector<8xi4> to i32
   // CHECK: rocdl.cvt.scale.pk8.bf16.fp4 %[[SOURCE_i32]], %[[SCALE_i32]][0] : vector<8xbf16>
   %ret1 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<8xf4E2M1FN>, vector<4xf8E8M0FNU> -> vector<8xbf16>
-  func.return %ret0, %ret1: vector<8xf16>, vector<8xbf16>
+
+  // CHECK: %[[SCALE_i32:.+]] = llvm.bitcast %[[SCALE_4xi8]] : vector<4xi8> to i32
+  // CHECK: %[[SOURCE_i32:.+]] = llvm.bitcast %[[SOURCE_8xi4]] : vector<8xi4> to i32
+  // CHECK: rocdl.cvt.scale.pk8.f32.fp4 %[[SOURCE_i32]], %[[SCALE_i32]][0] : vector<8xf32>
+  %ret2 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<8xf4E2M1FN>, vector<4xf8E8M0FNU> -> vector<8xf32>
+  func.return %ret0, %ret1, %ret2: vector<8xf16>, vector<8xbf16>, vector<8xf32>
 }
 
 // CHECK-LABEL: @scaled_ext_packed816_fp8

>From 0ba6b949e626bbafdb5fdc5b94cea342b4e7aab5 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Tue, 11 Nov 2025 09:50:01 -0500
Subject: [PATCH 12/43] Add case for pk8.f32.fp8

---
 .../Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp  | 17 +++++++++++++++--
 .../AMDGPUToROCDL/amdgpu-to-rocdl.mlir          |  9 +++++++--
 2 files changed, 22 insertions(+), 4 deletions(-)

diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 5c5aad91cd8e0..a347a822aba66 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -1809,9 +1809,22 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite(
     rewriter.replaceOp(op, newOp);
     return success();
   }
+  if (isa<Float8E4M3FNType>(srcElemType) and isa<Float32Type>(tgtElemType)) {
+    // CvtPkScalePk8F32Fp8Op
+    // vector<8xf8E4M3FN>
+    Value source = adaptor.getSource();
+
+    // vector<2xi32>
+    VectorType v2xi32 = VectorType::get(2, i32);
+    Value castedSource = LLVM::BitcastOp::create(rewriter, loc, v2xi32, source);
+
+    auto newOp = ROCDL::CvtPkScalePk8F32Fp8Op::create(
+        rewriter, loc, op.getResult().getType(), castedSource, castedScale,
+        scaleSel);
+    rewriter.replaceOp(op, newOp);
+    return success();
+  }
   /*
-  CvtPkScalePk8F32Fp4Op
-  CvtPkScalePk8F32Fp8Op
   CvtPkScalePk8F32Bf8Op
 
   // smallT = [Fp6, Bf6]
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir b/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir
index 13f109b787d4d..917d25b9bc8f2 100644
--- a/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir
+++ b/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir
@@ -481,7 +481,7 @@ func.func @scaled_ext_packed816_fp4(%v: vector<8xf4E2M1FN>, %scale: vector<4xf8E
 
 // CHECK-LABEL: @scaled_ext_packed816_fp8
 // CHECK-SAME: (%[[SOURCE:.+]]: vector<8xf8E4M3FN>, %[[SCALE:.+]]: vector<4xf8E8M0FNU>)
-func.func @scaled_ext_packed816_fp8(%v: vector<8xf8E4M3FN>, %scale: vector<4xf8E8M0FNU>) -> (vector<8xf16>, vector<8xbf16>) {
+func.func @scaled_ext_packed816_fp8(%v: vector<8xf8E4M3FN>, %scale: vector<4xf8E8M0FNU>) -> (vector<8xf16>, vector<8xbf16>, vector<8xf32>) {
   // CHECK: %[[SCALE_4xi8:.+]] = builtin.unrealized_conversion_cast %[[SCALE]] : vector<4xf8E8M0FNU> to vector<4xi8>
   // CHECK: %[[SOURCE_8xi8:.+]] = builtin.unrealized_conversion_cast %[[SOURCE]] : vector<8xf8E4M3FN> to vector<8xi8>
   // CHECK: %[[SCALE_i32:.+]] = llvm.bitcast %[[SCALE_4xi8]] : vector<4xi8> to i32
@@ -494,7 +494,12 @@ func.func @scaled_ext_packed816_fp8(%v: vector<8xf8E4M3FN>, %scale: vector<4xf8E
   // CHECK: rocdl.cvt.scale.pk8.bf16.fp8 %[[SOURCE_v2xi32]], %[[SCALE_i32]][0] : vector<8xbf16>
   %ret1 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<8xf8E4M3FN>, vector<4xf8E8M0FNU> -> vector<8xbf16>
 
-  func.return %ret0, %ret1 : vector<8xf16>, vector<8xbf16>
+  // CHECK: %[[SCALE_i32:.+]] = llvm.bitcast %[[SCALE_4xi8]] : vector<4xi8> to i32
+  // CHECK: %[[SOURCE_v2xi32:.+]] = llvm.bitcast %[[SOURCE_8xi8]] : vector<8xi8> to vector<2xi32>
+  // CHECK: rocdl.cvt.scale.pk8.f32.fp8 %[[SOURCE_v2xi32]], %[[SCALE_i32]][0] : vector<8xf32>
+  %ret2 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<8xf8E4M3FN>, vector<4xf8E8M0FNU> -> vector<8xf32>
+
+  func.return %ret0, %ret1, %ret2 : vector<8xf16>, vector<8xbf16>, vector<8xf32>
 }
 
 // CHECK-LABEL: @scaled_ext_packed816_bf8

>From 551849e8496343a293005c32e9f3c6cc6429eba6 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Tue, 11 Nov 2025 09:54:18 -0500
Subject: [PATCH 13/43] Add case for pk8.f32.bf8

---
 .../Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp   | 16 +++++++++++++++-
 .../AMDGPUToROCDL/amdgpu-to-rocdl.mlir           |  9 +++++++--
 2 files changed, 22 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index a347a822aba66..72617f76e6c91 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -1824,8 +1824,22 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite(
     rewriter.replaceOp(op, newOp);
     return success();
   }
+  if (isa<Float8E5M2Type>(srcElemType) and isa<Float32Type>(tgtElemType)) {
+    // CvtPkScalePk8F32Bf8Op
+    // vector<8xf8E5M2>
+    Value source = adaptor.getSource();
+
+    // vector<2xi32>
+    VectorType v2xi32 = VectorType::get(2, i32);
+    Value castedSource = LLVM::BitcastOp::create(rewriter, loc, v2xi32, source);
+
+    auto newOp = ROCDL::CvtPkScalePk8F32Bf8Op::create(
+        rewriter, loc, op.getResult().getType(), castedSource, castedScale,
+        scaleSel);
+    rewriter.replaceOp(op, newOp);
+    return success();
+  }
   /*
-  CvtPkScalePk8F32Bf8Op
 
   // smallT = [Fp6, Bf6]
   // largeT = [F16, Bf16, F32]
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir b/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir
index 917d25b9bc8f2..220759b6fab39 100644
--- a/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir
+++ b/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir
@@ -504,7 +504,7 @@ func.func @scaled_ext_packed816_fp8(%v: vector<8xf8E4M3FN>, %scale: vector<4xf8E
 
 // CHECK-LABEL: @scaled_ext_packed816_bf8
 // CHECK-SAME: (%[[SOURCE:.+]]: vector<8xf8E5M2>, %[[SCALE:.+]]: vector<4xf8E8M0FNU>)
-func.func @scaled_ext_packed816_bf8(%v: vector<8xf8E5M2>, %scale: vector<4xf8E8M0FNU>) -> (vector<8xf16>, vector<8xbf16>) {
+func.func @scaled_ext_packed816_bf8(%v: vector<8xf8E5M2>, %scale: vector<4xf8E8M0FNU>) -> (vector<8xf16>, vector<8xbf16>, vector<8xf32>) {
   // CHECK: %[[SCALE_4xi8:.+]] = builtin.unrealized_conversion_cast %[[SCALE]] : vector<4xf8E8M0FNU> to vector<4xi8>
   // CHECK: %[[SOURCE_8xi8:.+]] = builtin.unrealized_conversion_cast %[[SOURCE]] : vector<8xf8E5M2> to vector<8xi8>
   // CHECK: %[[SCALE_i32:.+]] = llvm.bitcast %[[SCALE_4xi8]] : vector<4xi8> to i32
@@ -516,7 +516,12 @@ func.func @scaled_ext_packed816_bf8(%v: vector<8xf8E5M2>, %scale: vector<4xf8E8M
   // CHECK: %[[SOURCE_v2xi32:.+]] = llvm.bitcast %[[SOURCE_8xi8]] : vector<8xi8> to vector<2xi32>
   // CHECK: rocdl.cvt.scale.pk8.bf16.bf8 %[[SOURCE_v2xi32]], %[[SCALE_i32]][0] : vector<8xbf16>
   %ret1 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<8xf8E5M2>, vector<4xf8E8M0FNU> -> vector<8xbf16>
-  func.return %ret0, %ret1 : vector<8xf16>, vector<8xbf16>
+
+  // CHECK: %[[SCALE_i32:.+]] = llvm.bitcast %[[SCALE_4xi8]] : vector<4xi8> to i32
+  // CHECK: %[[SOURCE_v2xi32:.+]] = llvm.bitcast %[[SOURCE_8xi8]] : vector<8xi8> to vector<2xi32>
+  // CHECK: rocdl.cvt.scale.pk8.f32.bf8 %[[SOURCE_v2xi32]], %[[SCALE_i32]][0] : vector<8xf32>
+  %ret2 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<8xf8E5M2>, vector<4xf8E8M0FNU> -> vector<8xf32>
+  func.return %ret0, %ret1, %ret2 : vector<8xf16>, vector<8xbf16>, vector<8xf32>
 }
 
 

>From c1e10c8e4b9b6a6ba9ce210aabc7bc97a5608cab Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Tue, 11 Nov 2025 10:26:48 -0500
Subject: [PATCH 14/43] Add case for pk16.f16.bf6

---
 .../AMDGPUToROCDL/AMDGPUToROCDL.cpp           | 38 ++++++++++++++++---
 .../AMDGPUToROCDL/amdgpu-to-rocdl.mlir        | 26 +++++++++++++
 2 files changed, 58 insertions(+), 6 deletions(-)

diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 72617f76e6c91..af068fbfd957f 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -1714,7 +1714,7 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite(
   //           Fp8 = E4M3
   //
   // largeT = [F16, Bf16, F32]
-  // CvtPkScalePk{8}${largeT}${smallT}
+  // CvtPkScalePk8${largeT}${smallT}
 
   if (isa<Float4E2M1FNType>(srcElemType) and isa<Float16Type>(tgtElemType)) {
     // CvtPkScalePk8F16Fp4Op
@@ -1839,13 +1839,39 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite(
     rewriter.replaceOp(op, newOp);
     return success();
   }
-  /*
-
   // smallT = [Fp6, Bf6]
   // largeT = [F16, Bf16, F32]
-  // CvtPkScalePk{16}${largeT}${smallT}
-  CvtPkScale16F16Fp6Op
-  CvtPkScale16F16Bf6Op
+  //
+  // Fp6 = Float6E2M3FN
+  // Bf6 = Float6E3M2FN
+
+  // CvtPkScalePk16${largeT}${smallT}
+  if (isa<Float6E2M3FNType>(srcElemType) and isa<Float16Type>(tgtElemType)) {
+    // CvtPkScale16F16Fp6Op
+    Value source = adaptor.getSource();
+    VectorType v3xi32 = VectorType::get(3, i32);
+    Value castedSource = LLVM::BitcastOp::create(rewriter, loc, v3xi32, source);
+
+    auto newOp = ROCDL::CvtPkScalePk16F16Fp6Op::create(
+        rewriter, loc, op.getResult().getType(), castedSource, castedScale,
+        scaleSel);
+    rewriter.replaceOp(op, newOp);
+    return success();
+  }
+  if (isa<Float6E3M2FNType>(srcElemType) and isa<Float16Type>(tgtElemType)) {
+    // CvtPkScale16F16Bf6Op
+    Value source = adaptor.getSource();
+    VectorType v3xi32 = VectorType::get(3, i32);
+    Value castedSource = LLVM::BitcastOp::create(rewriter, loc, v3xi32, source);
+
+    auto newOp = ROCDL::CvtPkScalePk16F16Bf6Op::create(
+        rewriter, loc, op.getResult().getType(), castedSource, castedScale,
+        scaleSel);
+    rewriter.replaceOp(op, newOp);
+    return success();
+  }
+
+  /*
 
   CvtPkScale16Bf16Fp6Op
   CvtPkScale16Bf16Bf6Op
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir b/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir
index 220759b6fab39..349440053a646 100644
--- a/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir
+++ b/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir
@@ -525,3 +525,29 @@ func.func @scaled_ext_packed816_bf8(%v: vector<8xf8E5M2>, %scale: vector<4xf8E8M
 }
 
 
+// CHECK-LABEL: @scaled_ext_packed816_fp6
+// CHECK-SAME: (%[[SOURCE:.+]]: vector<16xf6E2M3FN>, %[[SCALE:.+]]: vector<4xf8E8M0FNU>)
+func.func @scaled_ext_packed816_fp6(%v: vector<16xf6E2M3FN>, %scale: vector<4xf8E8M0FNU>) -> (vector<16xf16>) {
+  // CHECK-DAG: %[[SCALE_4xi8:.+]] = builtin.unrealized_conversion_cast %[[SCALE]] : vector<4xf8E8M0FNU> to vector<4xi8>
+  // CHECK-DAG: %[[SOURCE_16xi6:.+]] = builtin.unrealized_conversion_cast %[[SOURCE]] : vector<16xf6E2M3FN> to vector<16xi6>
+
+  // CHECK: %[[SCALE_i32:.+]] = llvm.bitcast %[[SCALE_4xi8]] : vector<4xi8> to i32
+  // CHECK: %[[SOURCE_v3xi32:.+]] = llvm.bitcast %[[SOURCE_16xi6]] : vector<16xi6> to vector<3xi32>
+  // CHECK: rocdl.cvt.scale.pk16.f16.fp6 %[[SOURCE_v3xi32]], %[[SCALE_i32]][0] : vector<16xf16>
+  %ret0 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<16xf6E2M3FN>, vector<4xf8E8M0FNU> -> vector<16xf16>
+  return %ret0: vector<16xf16>
+}
+
+// CHECK-LABEL: @scaled_ext_packed816_bf6
+// CHECK-SAME: (%[[SOURCE:.+]]: vector<16xf6E3M2FN>, %[[SCALE:.+]]: vector<4xf8E8M0FNU>)
+func.func @scaled_ext_packed816_bf6(%v: vector<16xf6E3M2FN>, %scale: vector<4xf8E8M0FNU>) -> (vector<16xf16>) {
+  // CHECK-DAG: %[[SCALE_4xi8:.+]] = builtin.unrealized_conversion_cast %[[SCALE]] : vector<4xf8E8M0FNU> to vector<4xi8>
+  // CHECK-DAG: %[[SOURCE_16xi6:.+]] = builtin.unrealized_conversion_cast %[[SOURCE]] : vector<16xf6E3M2FN> to vector<16xi6>
+
+  // CHECK: %[[SCALE_i32:.+]] = llvm.bitcast %[[SCALE_4xi8]] : vector<4xi8> to i32
+  // CHECK: %[[SOURCE_v3xi32:.+]] = llvm.bitcast %[[SOURCE_16xi6]] : vector<16xi6> to vector<3xi32>
+  // CHECK: rocdl.cvt.scale.pk16.f16.bf6 %[[SOURCE_v3xi32]], %[[SCALE_i32]][0] : vector<16xf16>
+  %ret0 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<16xf6E3M2FN>, vector<4xf8E8M0FNU> -> vector<16xf16>
+  return %ret0: vector<16xf16>
+}
+

>From e958d56ca7e5fa32767a10186fe17905adf11ecf Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Tue, 11 Nov 2025 10:30:17 -0500
Subject: [PATCH 15/43] Add case for pk16.bf16.fp6

---
 mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp | 13 ++++++++++++-
 .../Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir   |  9 +++++++--
 2 files changed, 19 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index af068fbfd957f..cbe3251268f78 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -1870,10 +1870,21 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite(
     rewriter.replaceOp(op, newOp);
     return success();
   }
+  if (isa<Float6E2M3FNType>(srcElemType) and isa<BFloat16Type>(tgtElemType)) {
+    // CvtPkScale16Bf16Fp6Op
+    Value source = adaptor.getSource();
+    VectorType v3xi32 = VectorType::get(3, i32);
+    Value castedSource = LLVM::BitcastOp::create(rewriter, loc, v3xi32, source);
+
+    auto newOp = ROCDL::CvtPkScalePk16Bf16Fp6Op::create(
+        rewriter, loc, op.getResult().getType(), castedSource, castedScale,
+        scaleSel);
+    rewriter.replaceOp(op, newOp);
+    return success();
+  }
 
   /*
 
-  CvtPkScale16Bf16Fp6Op
   CvtPkScale16Bf16Bf6Op
 
   CvtPkScale16F32Fp6Op
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir b/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir
index 349440053a646..76e26e0cab40a 100644
--- a/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir
+++ b/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir
@@ -527,7 +527,7 @@ func.func @scaled_ext_packed816_bf8(%v: vector<8xf8E5M2>, %scale: vector<4xf8E8M
 
 // CHECK-LABEL: @scaled_ext_packed816_fp6
 // CHECK-SAME: (%[[SOURCE:.+]]: vector<16xf6E2M3FN>, %[[SCALE:.+]]: vector<4xf8E8M0FNU>)
-func.func @scaled_ext_packed816_fp6(%v: vector<16xf6E2M3FN>, %scale: vector<4xf8E8M0FNU>) -> (vector<16xf16>) {
+func.func @scaled_ext_packed816_fp6(%v: vector<16xf6E2M3FN>, %scale: vector<4xf8E8M0FNU>) -> (vector<16xf16>, vector<16xbf16>) {
   // CHECK-DAG: %[[SCALE_4xi8:.+]] = builtin.unrealized_conversion_cast %[[SCALE]] : vector<4xf8E8M0FNU> to vector<4xi8>
   // CHECK-DAG: %[[SOURCE_16xi6:.+]] = builtin.unrealized_conversion_cast %[[SOURCE]] : vector<16xf6E2M3FN> to vector<16xi6>
 
@@ -535,7 +535,12 @@ func.func @scaled_ext_packed816_fp6(%v: vector<16xf6E2M3FN>, %scale: vector<4xf8
   // CHECK: %[[SOURCE_v3xi32:.+]] = llvm.bitcast %[[SOURCE_16xi6]] : vector<16xi6> to vector<3xi32>
   // CHECK: rocdl.cvt.scale.pk16.f16.fp6 %[[SOURCE_v3xi32]], %[[SCALE_i32]][0] : vector<16xf16>
   %ret0 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<16xf6E2M3FN>, vector<4xf8E8M0FNU> -> vector<16xf16>
-  return %ret0: vector<16xf16>
+
+  // CHECK: %[[SCALE_i32:.+]] = llvm.bitcast %[[SCALE_4xi8]] : vector<4xi8> to i32
+  // CHECK: %[[SOURCE_v3xi32:.+]] = llvm.bitcast %[[SOURCE_16xi6]] : vector<16xi6> to vector<3xi32>
+  // CHECK: rocdl.cvt.scale.pk16.bf16.fp6 %[[SOURCE_v3xi32]], %[[SCALE_i32]][0] : vector<16xbf16>
+  %ret1 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<16xf6E2M3FN>, vector<4xf8E8M0FNU> -> vector<16xbf16>
+  return %ret0, %ret1: vector<16xf16>, vector<16xbf16>
 }
 
 // CHECK-LABEL: @scaled_ext_packed816_bf6

>From 1f79bdd99637a7be7765df808c3e3a63199ca524 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Tue, 11 Nov 2025 10:38:08 -0500
Subject: [PATCH 16/43] Add case for pk16.bf16.bf6

---
 mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp | 13 ++++++++++++-
 .../Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir   |  9 +++++++--
 2 files changed, 19 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index cbe3251268f78..322c8efcf3aea 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -1882,10 +1882,21 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite(
     rewriter.replaceOp(op, newOp);
     return success();
   }
+  if (isa<Float6E3M2FNType>(srcElemType) and isa<BFloat16Type>(tgtElemType)) {
+    // CvtPkScale16Bf16Bf6Op
+    Value source = adaptor.getSource();
+    VectorType v3xi32 = VectorType::get(3, i32);
+    Value castedSource = LLVM::BitcastOp::create(rewriter, loc, v3xi32, source);
+
+    auto newOp = ROCDL::CvtPkScalePk16Bf16Bf6Op::create(
+        rewriter, loc, op.getResult().getType(), castedSource, castedScale,
+        scaleSel);
+    rewriter.replaceOp(op, newOp);
+    return success();
+  }
 
   /*
 
-  CvtPkScale16Bf16Bf6Op
 
   CvtPkScale16F32Fp6Op
   CvtPkScale16F32Bf6Op
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir b/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir
index 76e26e0cab40a..b31116538228e 100644
--- a/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir
+++ b/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir
@@ -545,7 +545,7 @@ func.func @scaled_ext_packed816_fp6(%v: vector<16xf6E2M3FN>, %scale: vector<4xf8
 
 // CHECK-LABEL: @scaled_ext_packed816_bf6
 // CHECK-SAME: (%[[SOURCE:.+]]: vector<16xf6E3M2FN>, %[[SCALE:.+]]: vector<4xf8E8M0FNU>)
-func.func @scaled_ext_packed816_bf6(%v: vector<16xf6E3M2FN>, %scale: vector<4xf8E8M0FNU>) -> (vector<16xf16>) {
+func.func @scaled_ext_packed816_bf6(%v: vector<16xf6E3M2FN>, %scale: vector<4xf8E8M0FNU>) -> (vector<16xf16>, vector<16xbf16>) {
   // CHECK-DAG: %[[SCALE_4xi8:.+]] = builtin.unrealized_conversion_cast %[[SCALE]] : vector<4xf8E8M0FNU> to vector<4xi8>
   // CHECK-DAG: %[[SOURCE_16xi6:.+]] = builtin.unrealized_conversion_cast %[[SOURCE]] : vector<16xf6E3M2FN> to vector<16xi6>
 
@@ -553,6 +553,11 @@ func.func @scaled_ext_packed816_bf6(%v: vector<16xf6E3M2FN>, %scale: vector<4xf8
   // CHECK: %[[SOURCE_v3xi32:.+]] = llvm.bitcast %[[SOURCE_16xi6]] : vector<16xi6> to vector<3xi32>
   // CHECK: rocdl.cvt.scale.pk16.f16.bf6 %[[SOURCE_v3xi32]], %[[SCALE_i32]][0] : vector<16xf16>
   %ret0 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<16xf6E3M2FN>, vector<4xf8E8M0FNU> -> vector<16xf16>
-  return %ret0: vector<16xf16>
+
+  // CHECK: %[[SCALE_i32:.+]] = llvm.bitcast %[[SCALE_4xi8]] : vector<4xi8> to i32
+  // CHECK: %[[SOURCE_v3xi32:.+]] = llvm.bitcast %[[SOURCE_16xi6]] : vector<16xi6> to vector<3xi32>
+  // CHECK: rocdl.cvt.scale.pk16.bf16.bf6 %[[SOURCE_v3xi32]], %[[SCALE_i32]][0] : vector<16xbf16>
+  %ret1 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<16xf6E3M2FN>, vector<4xf8E8M0FNU> -> vector<16xbf16>
+  return %ret0, %ret1: vector<16xf16>, vector<16xbf16>
 }
 

>From c5628e6ecf7a101a247da33fb55a497cfe33e98c Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Tue, 11 Nov 2025 10:44:19 -0500
Subject: [PATCH 17/43] Add case for pk16.f32.fp6

---
 mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp | 13 ++++++++++++-
 .../Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir   |  9 +++++++--
 2 files changed, 19 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 322c8efcf3aea..fb189b24b29e7 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -1894,11 +1894,22 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite(
     rewriter.replaceOp(op, newOp);
     return success();
   }
+  if (isa<Float6E2M3FNType>(srcElemType) and isa<Float32Type>(tgtElemType)) {
+    // CvtPkScale16F32Fp6Op
+    Value source = adaptor.getSource();
+    VectorType v3xi32 = VectorType::get(3, i32);
+    Value castedSource = LLVM::BitcastOp::create(rewriter, loc, v3xi32, source);
+
+    auto newOp = ROCDL::CvtPkScalePk16F32Fp6Op::create(
+        rewriter, loc, op.getResult().getType(), castedSource, castedScale,
+        scaleSel);
+    rewriter.replaceOp(op, newOp);
+    return success();
+  }
 
   /*
 
 
-  CvtPkScale16F32Fp6Op
   CvtPkScale16F32Bf6Op
   */
 
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir b/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir
index b31116538228e..dccdb81033738 100644
--- a/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir
+++ b/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir
@@ -527,7 +527,7 @@ func.func @scaled_ext_packed816_bf8(%v: vector<8xf8E5M2>, %scale: vector<4xf8E8M
 
 // CHECK-LABEL: @scaled_ext_packed816_fp6
 // CHECK-SAME: (%[[SOURCE:.+]]: vector<16xf6E2M3FN>, %[[SCALE:.+]]: vector<4xf8E8M0FNU>)
-func.func @scaled_ext_packed816_fp6(%v: vector<16xf6E2M3FN>, %scale: vector<4xf8E8M0FNU>) -> (vector<16xf16>, vector<16xbf16>) {
+func.func @scaled_ext_packed816_fp6(%v: vector<16xf6E2M3FN>, %scale: vector<4xf8E8M0FNU>) -> (vector<16xf16>, vector<16xbf16>, vector<16xf32>) {
   // CHECK-DAG: %[[SCALE_4xi8:.+]] = builtin.unrealized_conversion_cast %[[SCALE]] : vector<4xf8E8M0FNU> to vector<4xi8>
   // CHECK-DAG: %[[SOURCE_16xi6:.+]] = builtin.unrealized_conversion_cast %[[SOURCE]] : vector<16xf6E2M3FN> to vector<16xi6>
 
@@ -540,7 +540,12 @@ func.func @scaled_ext_packed816_fp6(%v: vector<16xf6E2M3FN>, %scale: vector<4xf8
   // CHECK: %[[SOURCE_v3xi32:.+]] = llvm.bitcast %[[SOURCE_16xi6]] : vector<16xi6> to vector<3xi32>
   // CHECK: rocdl.cvt.scale.pk16.bf16.fp6 %[[SOURCE_v3xi32]], %[[SCALE_i32]][0] : vector<16xbf16>
   %ret1 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<16xf6E2M3FN>, vector<4xf8E8M0FNU> -> vector<16xbf16>
-  return %ret0, %ret1: vector<16xf16>, vector<16xbf16>
+
+  // CHECK: %[[SCALE_i32:.+]] = llvm.bitcast %[[SCALE_4xi8]] : vector<4xi8> to i32
+  // CHECK: %[[SOURCE_v3xi32:.+]] = llvm.bitcast %[[SOURCE_16xi6]] : vector<16xi6> to vector<3xi32>
+  // CHECK: rocdl.cvt.scale.pk16.f32.fp6 %[[SOURCE_v3xi32]], %[[SCALE_i32]][0] : vector<16xf32>
+  %ret2 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<16xf6E2M3FN>, vector<4xf8E8M0FNU> -> vector<16xf32>
+  return %ret0, %ret1, %ret2: vector<16xf16>, vector<16xbf16>, vector<16xf32>
 }
 
 // CHECK-LABEL: @scaled_ext_packed816_bf6

>From db56c98bfa77ab923d4cfed8ddbb4ed15863b690 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Tue, 11 Nov 2025 10:48:29 -0500
Subject: [PATCH 18/43] Add case for pk16.f32.bf6

---
 .../Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp   | 16 +++++++++++-----
 .../AMDGPUToROCDL/amdgpu-to-rocdl.mlir           |  9 +++++++--
 2 files changed, 18 insertions(+), 7 deletions(-)

diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index fb189b24b29e7..26282ec4c2279 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -1906,12 +1906,18 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite(
     rewriter.replaceOp(op, newOp);
     return success();
   }
+  if (isa<Float6E3M2FNType>(srcElemType) and isa<Float32Type>(tgtElemType)) {
+    // CvtPkScale16F32Bf6Op
+    Value source = adaptor.getSource();
+    VectorType v3xi32 = VectorType::get(3, i32);
+    Value castedSource = LLVM::BitcastOp::create(rewriter, loc, v3xi32, source);
 
-  /*
-
-
-  CvtPkScale16F32Bf6Op
-  */
+    auto newOp = ROCDL::CvtPkScalePk16F32Bf6Op::create(
+        rewriter, loc, op.getResult().getType(), castedSource, castedScale,
+        scaleSel);
+    rewriter.replaceOp(op, newOp);
+    return success();
+  }
 
   return failure();
 }
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir b/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir
index dccdb81033738..94a04d98004c7 100644
--- a/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir
+++ b/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir
@@ -550,7 +550,7 @@ func.func @scaled_ext_packed816_fp6(%v: vector<16xf6E2M3FN>, %scale: vector<4xf8
 
 // CHECK-LABEL: @scaled_ext_packed816_bf6
 // CHECK-SAME: (%[[SOURCE:.+]]: vector<16xf6E3M2FN>, %[[SCALE:.+]]: vector<4xf8E8M0FNU>)
-func.func @scaled_ext_packed816_bf6(%v: vector<16xf6E3M2FN>, %scale: vector<4xf8E8M0FNU>) -> (vector<16xf16>, vector<16xbf16>) {
+func.func @scaled_ext_packed816_bf6(%v: vector<16xf6E3M2FN>, %scale: vector<4xf8E8M0FNU>) -> (vector<16xf16>, vector<16xbf16>, vector<16xf32>) {
   // CHECK-DAG: %[[SCALE_4xi8:.+]] = builtin.unrealized_conversion_cast %[[SCALE]] : vector<4xf8E8M0FNU> to vector<4xi8>
   // CHECK-DAG: %[[SOURCE_16xi6:.+]] = builtin.unrealized_conversion_cast %[[SOURCE]] : vector<16xf6E3M2FN> to vector<16xi6>
 
@@ -563,6 +563,11 @@ func.func @scaled_ext_packed816_bf6(%v: vector<16xf6E3M2FN>, %scale: vector<4xf8
   // CHECK: %[[SOURCE_v3xi32:.+]] = llvm.bitcast %[[SOURCE_16xi6]] : vector<16xi6> to vector<3xi32>
   // CHECK: rocdl.cvt.scale.pk16.bf16.bf6 %[[SOURCE_v3xi32]], %[[SCALE_i32]][0] : vector<16xbf16>
   %ret1 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<16xf6E3M2FN>, vector<4xf8E8M0FNU> -> vector<16xbf16>
-  return %ret0, %ret1: vector<16xf16>, vector<16xbf16>
+
+  // CHECK: %[[SCALE_i32:.+]] = llvm.bitcast %[[SCALE_4xi8]] : vector<4xi8> to i32
+  // CHECK: %[[SOURCE_v3xi32:.+]] = llvm.bitcast %[[SOURCE_16xi6]] : vector<16xi6> to vector<3xi32>
+  // CHECK: rocdl.cvt.scale.pk16.f32.bf6 %[[SOURCE_v3xi32]], %[[SCALE_i32]][0] : vector<16xf32>
+  %ret2 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<16xf6E3M2FN>, vector<4xf8E8M0FNU> -> vector<16xf32>
+  return %ret0, %ret1, %ret2: vector<16xf16>, vector<16xbf16>, vector<16xf32>
 }
 

>From 73ed4b7b7d1562a0aa50fa7caec5c15426c57c27 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Tue, 11 Nov 2025 10:52:19 -0500
Subject: [PATCH 19/43] Refactor NFC

---
 .../AMDGPUToROCDL/AMDGPUToROCDL.cpp           | 74 ++++++++-----------
 1 file changed, 31 insertions(+), 43 deletions(-)

diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 26282ec4c2279..44393033ef442 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -1725,9 +1725,8 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite(
         rewriter, loc, op.getResult().getType(), castedSource, castedScale,
         scaleSel);
     rewriter.replaceOp(op, newOp);
-    return success();
-  }
-  if (isa<Float8E4M3FNType>(srcElemType) and isa<Float16Type>(tgtElemType)) {
+  } else if (isa<Float8E4M3FNType>(srcElemType) and
+             isa<Float16Type>(tgtElemType)) {
     // CvtPkScalePk8F16Fp8Op
     // vector<8xf8E4M3FN>
     Value source = adaptor.getSource();
@@ -1740,9 +1739,8 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite(
         rewriter, loc, op.getResult().getType(), castedSource, castedScale,
         scaleSel);
     rewriter.replaceOp(op, newOp);
-    return success();
-  }
-  if (isa<Float8E5M2Type>(srcElemType) and isa<Float16Type>(tgtElemType)) {
+  } else if (isa<Float8E5M2Type>(srcElemType) and
+             isa<Float16Type>(tgtElemType)) {
     // CvtPkScalePk8F16Bf8Op
     // vector<8xf8E5M2>
     Value source = adaptor.getSource();
@@ -1755,9 +1753,8 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite(
         rewriter, loc, op.getResult().getType(), castedSource, castedScale,
         scaleSel);
     rewriter.replaceOp(op, newOp);
-    return success();
-  }
-  if (isa<Float4E2M1FNType>(srcElemType) and isa<BFloat16Type>(tgtElemType)) {
+  } else if (isa<Float4E2M1FNType>(srcElemType) and
+             isa<BFloat16Type>(tgtElemType)) {
     // CvtPkScalePk8Bf16Fp4Op
     // i32
     Value castedSource =
@@ -1766,9 +1763,8 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite(
         rewriter, loc, op.getResult().getType(), castedSource, castedScale,
         scaleSel);
     rewriter.replaceOp(op, newOp);
-    return success();
-  }
-  if (isa<Float8E4M3FNType>(srcElemType) and isa<BFloat16Type>(tgtElemType)) {
+  } else if (isa<Float8E4M3FNType>(srcElemType) and
+             isa<BFloat16Type>(tgtElemType)) {
     // CvtPkScalePk8Bf16Fp8Op
     // vector<8xf8E4M3FN>
     Value source = adaptor.getSource();
@@ -1781,9 +1777,8 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite(
         rewriter, loc, op.getResult().getType(), castedSource, castedScale,
         scaleSel);
     rewriter.replaceOp(op, newOp);
-    return success();
-  }
-  if (isa<Float8E5M2Type>(srcElemType) and isa<BFloat16Type>(tgtElemType)) {
+  } else if (isa<Float8E5M2Type>(srcElemType) and
+             isa<BFloat16Type>(tgtElemType)) {
     // CvtPkScalePk8Bf16Bf8Op
     // vector<8xf8E5M2>
     Value source = adaptor.getSource();
@@ -1796,9 +1791,8 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite(
         rewriter, loc, op.getResult().getType(), castedSource, castedScale,
         scaleSel);
     rewriter.replaceOp(op, newOp);
-    return success();
-  }
-  if (isa<Float4E2M1FNType>(srcElemType) and isa<Float32Type>(tgtElemType)) {
+  } else if (isa<Float4E2M1FNType>(srcElemType) and
+             isa<Float32Type>(tgtElemType)) {
     // CvtPkScalePk8F32Fp4Op
     // i32
     Value castedSource =
@@ -1807,9 +1801,8 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite(
         rewriter, loc, op.getResult().getType(), castedSource, castedScale,
         scaleSel);
     rewriter.replaceOp(op, newOp);
-    return success();
-  }
-  if (isa<Float8E4M3FNType>(srcElemType) and isa<Float32Type>(tgtElemType)) {
+  } else if (isa<Float8E4M3FNType>(srcElemType) and
+             isa<Float32Type>(tgtElemType)) {
     // CvtPkScalePk8F32Fp8Op
     // vector<8xf8E4M3FN>
     Value source = adaptor.getSource();
@@ -1822,9 +1815,8 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite(
         rewriter, loc, op.getResult().getType(), castedSource, castedScale,
         scaleSel);
     rewriter.replaceOp(op, newOp);
-    return success();
-  }
-  if (isa<Float8E5M2Type>(srcElemType) and isa<Float32Type>(tgtElemType)) {
+  } else if (isa<Float8E5M2Type>(srcElemType) and
+             isa<Float32Type>(tgtElemType)) {
     // CvtPkScalePk8F32Bf8Op
     // vector<8xf8E5M2>
     Value source = adaptor.getSource();
@@ -1837,7 +1829,6 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite(
         rewriter, loc, op.getResult().getType(), castedSource, castedScale,
         scaleSel);
     rewriter.replaceOp(op, newOp);
-    return success();
   }
   // smallT = [Fp6, Bf6]
   // largeT = [F16, Bf16, F32]
@@ -1846,7 +1837,8 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite(
   // Bf6 = Float6E3M2FN
 
   // CvtPkScalePk16${largeT}${smallT}
-  if (isa<Float6E2M3FNType>(srcElemType) and isa<Float16Type>(tgtElemType)) {
+  else if (isa<Float6E2M3FNType>(srcElemType) and
+           isa<Float16Type>(tgtElemType)) {
     // CvtPkScale16F16Fp6Op
     Value source = adaptor.getSource();
     VectorType v3xi32 = VectorType::get(3, i32);
@@ -1856,9 +1848,8 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite(
         rewriter, loc, op.getResult().getType(), castedSource, castedScale,
         scaleSel);
     rewriter.replaceOp(op, newOp);
-    return success();
-  }
-  if (isa<Float6E3M2FNType>(srcElemType) and isa<Float16Type>(tgtElemType)) {
+  } else if (isa<Float6E3M2FNType>(srcElemType) and
+             isa<Float16Type>(tgtElemType)) {
     // CvtPkScale16F16Bf6Op
     Value source = adaptor.getSource();
     VectorType v3xi32 = VectorType::get(3, i32);
@@ -1868,9 +1859,8 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite(
         rewriter, loc, op.getResult().getType(), castedSource, castedScale,
         scaleSel);
     rewriter.replaceOp(op, newOp);
-    return success();
-  }
-  if (isa<Float6E2M3FNType>(srcElemType) and isa<BFloat16Type>(tgtElemType)) {
+  } else if (isa<Float6E2M3FNType>(srcElemType) and
+             isa<BFloat16Type>(tgtElemType)) {
     // CvtPkScale16Bf16Fp6Op
     Value source = adaptor.getSource();
     VectorType v3xi32 = VectorType::get(3, i32);
@@ -1880,9 +1870,8 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite(
         rewriter, loc, op.getResult().getType(), castedSource, castedScale,
         scaleSel);
     rewriter.replaceOp(op, newOp);
-    return success();
-  }
-  if (isa<Float6E3M2FNType>(srcElemType) and isa<BFloat16Type>(tgtElemType)) {
+  } else if (isa<Float6E3M2FNType>(srcElemType) and
+             isa<BFloat16Type>(tgtElemType)) {
     // CvtPkScale16Bf16Bf6Op
     Value source = adaptor.getSource();
     VectorType v3xi32 = VectorType::get(3, i32);
@@ -1892,9 +1881,8 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite(
         rewriter, loc, op.getResult().getType(), castedSource, castedScale,
         scaleSel);
     rewriter.replaceOp(op, newOp);
-    return success();
-  }
-  if (isa<Float6E2M3FNType>(srcElemType) and isa<Float32Type>(tgtElemType)) {
+  } else if (isa<Float6E2M3FNType>(srcElemType) and
+             isa<Float32Type>(tgtElemType)) {
     // CvtPkScale16F32Fp6Op
     Value source = adaptor.getSource();
     VectorType v3xi32 = VectorType::get(3, i32);
@@ -1904,9 +1892,8 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite(
         rewriter, loc, op.getResult().getType(), castedSource, castedScale,
         scaleSel);
     rewriter.replaceOp(op, newOp);
-    return success();
-  }
-  if (isa<Float6E3M2FNType>(srcElemType) and isa<Float32Type>(tgtElemType)) {
+  } else if (isa<Float6E3M2FNType>(srcElemType) and
+             isa<Float32Type>(tgtElemType)) {
     // CvtPkScale16F32Bf6Op
     Value source = adaptor.getSource();
     VectorType v3xi32 = VectorType::get(3, i32);
@@ -1916,10 +1903,11 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite(
         rewriter, loc, op.getResult().getType(), castedSource, castedScale,
         scaleSel);
     rewriter.replaceOp(op, newOp);
-    return success();
+  } else {
+    return failure();
   }
 
-  return failure();
+  return success();
 }
 
 LogicalResult ScaledExtPackedOpLowering::matchAndRewrite(

>From 94cc74070c0d91382c3444d96e837eb17f6afc52 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Tue, 11 Nov 2025 11:01:43 -0500
Subject: [PATCH 20/43] Refactor NFC

---
 .../AMDGPUToROCDL/AMDGPUToROCDL.cpp           | 90 +++++++------------
 1 file changed, 30 insertions(+), 60 deletions(-)

diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 44393033ef442..e28b53d8ed1a5 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -1721,10 +1721,8 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite(
     // i32
     Value castedSource =
         LLVM::BitcastOp::create(rewriter, loc, i32, adaptor.getSource());
-    auto newOp = ROCDL::CvtPkScalePk8F16Fp4Op::create(
-        rewriter, loc, op.getResult().getType(), castedSource, castedScale,
-        scaleSel);
-    rewriter.replaceOp(op, newOp);
+    rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk8F16Fp4Op>(
+        op, op.getResult().getType(), castedSource, castedScale, scaleSel);
   } else if (isa<Float8E4M3FNType>(srcElemType) and
              isa<Float16Type>(tgtElemType)) {
     // CvtPkScalePk8F16Fp8Op
@@ -1735,10 +1733,8 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite(
     VectorType v2xi32 = VectorType::get(2, i32);
     Value castedSource = LLVM::BitcastOp::create(rewriter, loc, v2xi32, source);
 
-    auto newOp = ROCDL::CvtPkScalePk8F16Fp8Op::create(
-        rewriter, loc, op.getResult().getType(), castedSource, castedScale,
-        scaleSel);
-    rewriter.replaceOp(op, newOp);
+    rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk8F16Fp8Op>(
+        op, op.getResult().getType(), castedSource, castedScale, scaleSel);
   } else if (isa<Float8E5M2Type>(srcElemType) and
              isa<Float16Type>(tgtElemType)) {
     // CvtPkScalePk8F16Bf8Op
@@ -1749,20 +1745,16 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite(
     VectorType v2xi32 = VectorType::get(2, i32);
     Value castedSource = LLVM::BitcastOp::create(rewriter, loc, v2xi32, source);
 
-    auto newOp = ROCDL::CvtPkScalePk8F16Bf8Op::create(
-        rewriter, loc, op.getResult().getType(), castedSource, castedScale,
-        scaleSel);
-    rewriter.replaceOp(op, newOp);
+    rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk8F16Bf8Op>(
+        op, op.getResult().getType(), castedSource, castedScale, scaleSel);
   } else if (isa<Float4E2M1FNType>(srcElemType) and
              isa<BFloat16Type>(tgtElemType)) {
     // CvtPkScalePk8Bf16Fp4Op
     // i32
     Value castedSource =
         LLVM::BitcastOp::create(rewriter, loc, i32, adaptor.getSource());
-    auto newOp = ROCDL::CvtPkScalePk8Bf16Fp4Op::create(
-        rewriter, loc, op.getResult().getType(), castedSource, castedScale,
-        scaleSel);
-    rewriter.replaceOp(op, newOp);
+    rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk8Bf16Fp4Op>(
+        op, op.getResult().getType(), castedSource, castedScale, scaleSel);
   } else if (isa<Float8E4M3FNType>(srcElemType) and
              isa<BFloat16Type>(tgtElemType)) {
     // CvtPkScalePk8Bf16Fp8Op
@@ -1773,10 +1765,8 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite(
     VectorType v2xi32 = VectorType::get(2, i32);
     Value castedSource = LLVM::BitcastOp::create(rewriter, loc, v2xi32, source);
 
-    auto newOp = ROCDL::CvtPkScalePk8Bf16Fp8Op::create(
-        rewriter, loc, op.getResult().getType(), castedSource, castedScale,
-        scaleSel);
-    rewriter.replaceOp(op, newOp);
+    rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk8Bf16Fp8Op>(
+        op, op.getResult().getType(), castedSource, castedScale, scaleSel);
   } else if (isa<Float8E5M2Type>(srcElemType) and
              isa<BFloat16Type>(tgtElemType)) {
     // CvtPkScalePk8Bf16Bf8Op
@@ -1787,20 +1777,16 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite(
     VectorType v2xi32 = VectorType::get(2, i32);
     Value castedSource = LLVM::BitcastOp::create(rewriter, loc, v2xi32, source);
 
-    auto newOp = ROCDL::CvtPkScalePk8Bf16Bf8Op::create(
-        rewriter, loc, op.getResult().getType(), castedSource, castedScale,
-        scaleSel);
-    rewriter.replaceOp(op, newOp);
+    rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk8Bf16Bf8Op>(
+        op, op.getResult().getType(), castedSource, castedScale, scaleSel);
   } else if (isa<Float4E2M1FNType>(srcElemType) and
              isa<Float32Type>(tgtElemType)) {
     // CvtPkScalePk8F32Fp4Op
     // i32
     Value castedSource =
         LLVM::BitcastOp::create(rewriter, loc, i32, adaptor.getSource());
-    auto newOp = ROCDL::CvtPkScalePk8F32Fp4Op::create(
-        rewriter, loc, op.getResult().getType(), castedSource, castedScale,
-        scaleSel);
-    rewriter.replaceOp(op, newOp);
+    rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk8F32Fp4Op>(
+        op, op.getResult().getType(), castedSource, castedScale, scaleSel);
   } else if (isa<Float8E4M3FNType>(srcElemType) and
              isa<Float32Type>(tgtElemType)) {
     // CvtPkScalePk8F32Fp8Op
@@ -1811,10 +1797,8 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite(
     VectorType v2xi32 = VectorType::get(2, i32);
     Value castedSource = LLVM::BitcastOp::create(rewriter, loc, v2xi32, source);
 
-    auto newOp = ROCDL::CvtPkScalePk8F32Fp8Op::create(
-        rewriter, loc, op.getResult().getType(), castedSource, castedScale,
-        scaleSel);
-    rewriter.replaceOp(op, newOp);
+    rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk8F32Fp8Op>(
+        op, op.getResult().getType(), castedSource, castedScale, scaleSel);
   } else if (isa<Float8E5M2Type>(srcElemType) and
              isa<Float32Type>(tgtElemType)) {
     // CvtPkScalePk8F32Bf8Op
@@ -1825,10 +1809,8 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite(
     VectorType v2xi32 = VectorType::get(2, i32);
     Value castedSource = LLVM::BitcastOp::create(rewriter, loc, v2xi32, source);
 
-    auto newOp = ROCDL::CvtPkScalePk8F32Bf8Op::create(
-        rewriter, loc, op.getResult().getType(), castedSource, castedScale,
-        scaleSel);
-    rewriter.replaceOp(op, newOp);
+    rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk8F32Bf8Op>(
+        op, op.getResult().getType(), castedSource, castedScale, scaleSel);
   }
   // smallT = [Fp6, Bf6]
   // largeT = [F16, Bf16, F32]
@@ -1844,10 +1826,8 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite(
     VectorType v3xi32 = VectorType::get(3, i32);
     Value castedSource = LLVM::BitcastOp::create(rewriter, loc, v3xi32, source);
 
-    auto newOp = ROCDL::CvtPkScalePk16F16Fp6Op::create(
-        rewriter, loc, op.getResult().getType(), castedSource, castedScale,
-        scaleSel);
-    rewriter.replaceOp(op, newOp);
+    rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk16F16Fp6Op>(
+        op, op.getResult().getType(), castedSource, castedScale, scaleSel);
   } else if (isa<Float6E3M2FNType>(srcElemType) and
              isa<Float16Type>(tgtElemType)) {
     // CvtPkScale16F16Bf6Op
@@ -1855,10 +1835,8 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite(
     VectorType v3xi32 = VectorType::get(3, i32);
     Value castedSource = LLVM::BitcastOp::create(rewriter, loc, v3xi32, source);
 
-    auto newOp = ROCDL::CvtPkScalePk16F16Bf6Op::create(
-        rewriter, loc, op.getResult().getType(), castedSource, castedScale,
-        scaleSel);
-    rewriter.replaceOp(op, newOp);
+    rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk16F16Bf6Op>(
+        op, op.getResult().getType(), castedSource, castedScale, scaleSel);
   } else if (isa<Float6E2M3FNType>(srcElemType) and
              isa<BFloat16Type>(tgtElemType)) {
     // CvtPkScale16Bf16Fp6Op
@@ -1866,10 +1844,8 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite(
     VectorType v3xi32 = VectorType::get(3, i32);
     Value castedSource = LLVM::BitcastOp::create(rewriter, loc, v3xi32, source);
 
-    auto newOp = ROCDL::CvtPkScalePk16Bf16Fp6Op::create(
-        rewriter, loc, op.getResult().getType(), castedSource, castedScale,
-        scaleSel);
-    rewriter.replaceOp(op, newOp);
+    rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk16Bf16Fp6Op>(
+        op, op.getResult().getType(), castedSource, castedScale, scaleSel);
   } else if (isa<Float6E3M2FNType>(srcElemType) and
              isa<BFloat16Type>(tgtElemType)) {
     // CvtPkScale16Bf16Bf6Op
@@ -1877,10 +1853,8 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite(
     VectorType v3xi32 = VectorType::get(3, i32);
     Value castedSource = LLVM::BitcastOp::create(rewriter, loc, v3xi32, source);
 
-    auto newOp = ROCDL::CvtPkScalePk16Bf16Bf6Op::create(
-        rewriter, loc, op.getResult().getType(), castedSource, castedScale,
-        scaleSel);
-    rewriter.replaceOp(op, newOp);
+    rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk16Bf16Bf6Op>(
+        op, op.getResult().getType(), castedSource, castedScale, scaleSel);
   } else if (isa<Float6E2M3FNType>(srcElemType) and
              isa<Float32Type>(tgtElemType)) {
     // CvtPkScale16F32Fp6Op
@@ -1888,10 +1862,8 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite(
     VectorType v3xi32 = VectorType::get(3, i32);
     Value castedSource = LLVM::BitcastOp::create(rewriter, loc, v3xi32, source);
 
-    auto newOp = ROCDL::CvtPkScalePk16F32Fp6Op::create(
-        rewriter, loc, op.getResult().getType(), castedSource, castedScale,
-        scaleSel);
-    rewriter.replaceOp(op, newOp);
+    rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk16F32Fp6Op>(
+        op, op.getResult().getType(), castedSource, castedScale, scaleSel);
   } else if (isa<Float6E3M2FNType>(srcElemType) and
              isa<Float32Type>(tgtElemType)) {
     // CvtPkScale16F32Bf6Op
@@ -1899,10 +1871,8 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite(
     VectorType v3xi32 = VectorType::get(3, i32);
     Value castedSource = LLVM::BitcastOp::create(rewriter, loc, v3xi32, source);
 
-    auto newOp = ROCDL::CvtPkScalePk16F32Bf6Op::create(
-        rewriter, loc, op.getResult().getType(), castedSource, castedScale,
-        scaleSel);
-    rewriter.replaceOp(op, newOp);
+    rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk16F32Bf6Op>(
+        op, op.getResult().getType(), castedSource, castedScale, scaleSel);
   } else {
     return failure();
   }

>From 4f27e043e39776b908e0ee4e079fd16fc7ef2f3a Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Tue, 11 Nov 2025 13:45:30 -0500
Subject: [PATCH 21/43] Use method instead of isa

---
 .../AMDGPUToROCDL/AMDGPUToROCDL.cpp           | 46 +++++++------------
 1 file changed, 16 insertions(+), 30 deletions(-)

diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index e28b53d8ed1a5..53f7accdb5a54 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -1699,7 +1699,7 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite(
       getScaleSel(blockSize, bitWidth, firstScaleLane, firstScaleByte);
 
   auto targetType = cast<VectorType>(op.getResult().getType());
-  auto tgtElemType = cast<FloatType>(targetType.getElementType());
+  auto destElemType = cast<FloatType>(targetType.getElementType());
   Location loc = op.getLoc();
   // %scale: vector<4xf8E8M0FNU>
   // ===========================
@@ -1716,15 +1716,14 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite(
   // largeT = [F16, Bf16, F32]
   // CvtPkScalePk8${largeT}${smallT}
 
-  if (isa<Float4E2M1FNType>(srcElemType) and isa<Float16Type>(tgtElemType)) {
+  if (isa<Float4E2M1FNType>(srcElemType) and destElemType.isF16()) {
     // CvtPkScalePk8F16Fp4Op
     // i32
     Value castedSource =
         LLVM::BitcastOp::create(rewriter, loc, i32, adaptor.getSource());
     rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk8F16Fp4Op>(
         op, op.getResult().getType(), castedSource, castedScale, scaleSel);
-  } else if (isa<Float8E4M3FNType>(srcElemType) and
-             isa<Float16Type>(tgtElemType)) {
+  } else if (isa<Float8E4M3FNType>(srcElemType) and destElemType.isF16()) {
     // CvtPkScalePk8F16Fp8Op
     // vector<8xf8E4M3FN>
     Value source = adaptor.getSource();
@@ -1735,8 +1734,7 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite(
 
     rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk8F16Fp8Op>(
         op, op.getResult().getType(), castedSource, castedScale, scaleSel);
-  } else if (isa<Float8E5M2Type>(srcElemType) and
-             isa<Float16Type>(tgtElemType)) {
+  } else if (isa<Float8E5M2Type>(srcElemType) and destElemType.isF16()) {
     // CvtPkScalePk8F16Bf8Op
     // vector<8xf8E5M2>
     Value source = adaptor.getSource();
@@ -1747,16 +1745,14 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite(
 
     rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk8F16Bf8Op>(
         op, op.getResult().getType(), castedSource, castedScale, scaleSel);
-  } else if (isa<Float4E2M1FNType>(srcElemType) and
-             isa<BFloat16Type>(tgtElemType)) {
+  } else if (isa<Float4E2M1FNType>(srcElemType) and destElemType.isBF16()) {
     // CvtPkScalePk8Bf16Fp4Op
     // i32
     Value castedSource =
         LLVM::BitcastOp::create(rewriter, loc, i32, adaptor.getSource());
     rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk8Bf16Fp4Op>(
         op, op.getResult().getType(), castedSource, castedScale, scaleSel);
-  } else if (isa<Float8E4M3FNType>(srcElemType) and
-             isa<BFloat16Type>(tgtElemType)) {
+  } else if (isa<Float8E4M3FNType>(srcElemType) and destElemType.isBF16()) {
     // CvtPkScalePk8Bf16Fp8Op
     // vector<8xf8E4M3FN>
     Value source = adaptor.getSource();
@@ -1767,8 +1763,7 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite(
 
     rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk8Bf16Fp8Op>(
         op, op.getResult().getType(), castedSource, castedScale, scaleSel);
-  } else if (isa<Float8E5M2Type>(srcElemType) and
-             isa<BFloat16Type>(tgtElemType)) {
+  } else if (isa<Float8E5M2Type>(srcElemType) and destElemType.isBF16()) {
     // CvtPkScalePk8Bf16Bf8Op
     // vector<8xf8E5M2>
     Value source = adaptor.getSource();
@@ -1779,16 +1774,14 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite(
 
     rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk8Bf16Bf8Op>(
         op, op.getResult().getType(), castedSource, castedScale, scaleSel);
-  } else if (isa<Float4E2M1FNType>(srcElemType) and
-             isa<Float32Type>(tgtElemType)) {
+  } else if (isa<Float4E2M1FNType>(srcElemType) and destElemType.isF32()) {
     // CvtPkScalePk8F32Fp4Op
     // i32
     Value castedSource =
         LLVM::BitcastOp::create(rewriter, loc, i32, adaptor.getSource());
     rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk8F32Fp4Op>(
         op, op.getResult().getType(), castedSource, castedScale, scaleSel);
-  } else if (isa<Float8E4M3FNType>(srcElemType) and
-             isa<Float32Type>(tgtElemType)) {
+  } else if (isa<Float8E4M3FNType>(srcElemType) and destElemType.isF32()) {
     // CvtPkScalePk8F32Fp8Op
     // vector<8xf8E4M3FN>
     Value source = adaptor.getSource();
@@ -1799,8 +1792,7 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite(
 
     rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk8F32Fp8Op>(
         op, op.getResult().getType(), castedSource, castedScale, scaleSel);
-  } else if (isa<Float8E5M2Type>(srcElemType) and
-             isa<Float32Type>(tgtElemType)) {
+  } else if (isa<Float8E5M2Type>(srcElemType) and destElemType.isF32()) {
     // CvtPkScalePk8F32Bf8Op
     // vector<8xf8E5M2>
     Value source = adaptor.getSource();
@@ -1819,8 +1811,7 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite(
   // Bf6 = Float6E3M2FN
 
   // CvtPkScalePk16${largeT}${smallT}
-  else if (isa<Float6E2M3FNType>(srcElemType) and
-           isa<Float16Type>(tgtElemType)) {
+  else if (isa<Float6E2M3FNType>(srcElemType) and destElemType.isF16()) {
     // CvtPkScale16F16Fp6Op
     Value source = adaptor.getSource();
     VectorType v3xi32 = VectorType::get(3, i32);
@@ -1828,8 +1819,7 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite(
 
     rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk16F16Fp6Op>(
         op, op.getResult().getType(), castedSource, castedScale, scaleSel);
-  } else if (isa<Float6E3M2FNType>(srcElemType) and
-             isa<Float16Type>(tgtElemType)) {
+  } else if (isa<Float6E3M2FNType>(srcElemType) and destElemType.isF16()) {
     // CvtPkScale16F16Bf6Op
     Value source = adaptor.getSource();
     VectorType v3xi32 = VectorType::get(3, i32);
@@ -1837,8 +1827,7 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite(
 
     rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk16F16Bf6Op>(
         op, op.getResult().getType(), castedSource, castedScale, scaleSel);
-  } else if (isa<Float6E2M3FNType>(srcElemType) and
-             isa<BFloat16Type>(tgtElemType)) {
+  } else if (isa<Float6E2M3FNType>(srcElemType) and destElemType.isBF16()) {
     // CvtPkScale16Bf16Fp6Op
     Value source = adaptor.getSource();
     VectorType v3xi32 = VectorType::get(3, i32);
@@ -1846,8 +1835,7 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite(
 
     rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk16Bf16Fp6Op>(
         op, op.getResult().getType(), castedSource, castedScale, scaleSel);
-  } else if (isa<Float6E3M2FNType>(srcElemType) and
-             isa<BFloat16Type>(tgtElemType)) {
+  } else if (isa<Float6E3M2FNType>(srcElemType) and destElemType.isBF16()) {
     // CvtPkScale16Bf16Bf6Op
     Value source = adaptor.getSource();
     VectorType v3xi32 = VectorType::get(3, i32);
@@ -1855,8 +1843,7 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite(
 
     rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk16Bf16Bf6Op>(
         op, op.getResult().getType(), castedSource, castedScale, scaleSel);
-  } else if (isa<Float6E2M3FNType>(srcElemType) and
-             isa<Float32Type>(tgtElemType)) {
+  } else if (isa<Float6E2M3FNType>(srcElemType) and destElemType.isF32()) {
     // CvtPkScale16F32Fp6Op
     Value source = adaptor.getSource();
     VectorType v3xi32 = VectorType::get(3, i32);
@@ -1864,8 +1851,7 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite(
 
     rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk16F32Fp6Op>(
         op, op.getResult().getType(), castedSource, castedScale, scaleSel);
-  } else if (isa<Float6E3M2FNType>(srcElemType) and
-             isa<Float32Type>(tgtElemType)) {
+  } else if (isa<Float6E3M2FNType>(srcElemType) and destElemType.isF32()) {
     // CvtPkScale16F32Bf6Op
     Value source = adaptor.getSource();
     VectorType v3xi32 = VectorType::get(3, i32);

>From 0f8f3c493817873ca8047b7f2b788d1845716f09 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Tue, 11 Nov 2025 13:47:59 -0500
Subject: [PATCH 22/43] Hoist variable

---
 mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp | 13 +------------
 1 file changed, 1 insertion(+), 12 deletions(-)

diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 53f7accdb5a54..646d27a164830 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -1715,6 +1715,7 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite(
   //
   // largeT = [F16, Bf16, F32]
   // CvtPkScalePk8${largeT}${smallT}
+  Value source = adaptor.getSource();
 
   if (isa<Float4E2M1FNType>(srcElemType) and destElemType.isF16()) {
     // CvtPkScalePk8F16Fp4Op
@@ -1726,7 +1727,6 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite(
   } else if (isa<Float8E4M3FNType>(srcElemType) and destElemType.isF16()) {
     // CvtPkScalePk8F16Fp8Op
     // vector<8xf8E4M3FN>
-    Value source = adaptor.getSource();
 
     // vector<2xi32>
     VectorType v2xi32 = VectorType::get(2, i32);
@@ -1737,7 +1737,6 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite(
   } else if (isa<Float8E5M2Type>(srcElemType) and destElemType.isF16()) {
     // CvtPkScalePk8F16Bf8Op
     // vector<8xf8E5M2>
-    Value source = adaptor.getSource();
 
     // vector<2xi32>
     VectorType v2xi32 = VectorType::get(2, i32);
@@ -1755,7 +1754,6 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite(
   } else if (isa<Float8E4M3FNType>(srcElemType) and destElemType.isBF16()) {
     // CvtPkScalePk8Bf16Fp8Op
     // vector<8xf8E4M3FN>
-    Value source = adaptor.getSource();
 
     // vector<2xi32>
     VectorType v2xi32 = VectorType::get(2, i32);
@@ -1766,7 +1764,6 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite(
   } else if (isa<Float8E5M2Type>(srcElemType) and destElemType.isBF16()) {
     // CvtPkScalePk8Bf16Bf8Op
     // vector<8xf8E5M2>
-    Value source = adaptor.getSource();
 
     // vector<2xi32>
     VectorType v2xi32 = VectorType::get(2, i32);
@@ -1784,7 +1781,6 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite(
   } else if (isa<Float8E4M3FNType>(srcElemType) and destElemType.isF32()) {
     // CvtPkScalePk8F32Fp8Op
     // vector<8xf8E4M3FN>
-    Value source = adaptor.getSource();
 
     // vector<2xi32>
     VectorType v2xi32 = VectorType::get(2, i32);
@@ -1795,7 +1791,6 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite(
   } else if (isa<Float8E5M2Type>(srcElemType) and destElemType.isF32()) {
     // CvtPkScalePk8F32Bf8Op
     // vector<8xf8E5M2>
-    Value source = adaptor.getSource();
 
     // vector<2xi32>
     VectorType v2xi32 = VectorType::get(2, i32);
@@ -1813,7 +1808,6 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite(
   // CvtPkScalePk16${largeT}${smallT}
   else if (isa<Float6E2M3FNType>(srcElemType) and destElemType.isF16()) {
     // CvtPkScale16F16Fp6Op
-    Value source = adaptor.getSource();
     VectorType v3xi32 = VectorType::get(3, i32);
     Value castedSource = LLVM::BitcastOp::create(rewriter, loc, v3xi32, source);
 
@@ -1821,7 +1815,6 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite(
         op, op.getResult().getType(), castedSource, castedScale, scaleSel);
   } else if (isa<Float6E3M2FNType>(srcElemType) and destElemType.isF16()) {
     // CvtPkScale16F16Bf6Op
-    Value source = adaptor.getSource();
     VectorType v3xi32 = VectorType::get(3, i32);
     Value castedSource = LLVM::BitcastOp::create(rewriter, loc, v3xi32, source);
 
@@ -1829,7 +1822,6 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite(
         op, op.getResult().getType(), castedSource, castedScale, scaleSel);
   } else if (isa<Float6E2M3FNType>(srcElemType) and destElemType.isBF16()) {
     // CvtPkScale16Bf16Fp6Op
-    Value source = adaptor.getSource();
     VectorType v3xi32 = VectorType::get(3, i32);
     Value castedSource = LLVM::BitcastOp::create(rewriter, loc, v3xi32, source);
 
@@ -1837,7 +1829,6 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite(
         op, op.getResult().getType(), castedSource, castedScale, scaleSel);
   } else if (isa<Float6E3M2FNType>(srcElemType) and destElemType.isBF16()) {
     // CvtPkScale16Bf16Bf6Op
-    Value source = adaptor.getSource();
     VectorType v3xi32 = VectorType::get(3, i32);
     Value castedSource = LLVM::BitcastOp::create(rewriter, loc, v3xi32, source);
 
@@ -1845,7 +1836,6 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite(
         op, op.getResult().getType(), castedSource, castedScale, scaleSel);
   } else if (isa<Float6E2M3FNType>(srcElemType) and destElemType.isF32()) {
     // CvtPkScale16F32Fp6Op
-    Value source = adaptor.getSource();
     VectorType v3xi32 = VectorType::get(3, i32);
     Value castedSource = LLVM::BitcastOp::create(rewriter, loc, v3xi32, source);
 
@@ -1853,7 +1843,6 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite(
         op, op.getResult().getType(), castedSource, castedScale, scaleSel);
   } else if (isa<Float6E3M2FNType>(srcElemType) and destElemType.isF32()) {
     // CvtPkScale16F32Bf6Op
-    Value source = adaptor.getSource();
     VectorType v3xi32 = VectorType::get(3, i32);
     Value castedSource = LLVM::BitcastOp::create(rewriter, loc, v3xi32, source);
 

>From a7a853ea525814ef1f7707cdb99a6327fcbbbfa1 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Tue, 11 Nov 2025 13:59:31 -0500
Subject: [PATCH 23/43] Refactor NFC

---
 .../AMDGPUToROCDL/AMDGPUToROCDL.cpp           | 92 ++++++++-----------
 1 file changed, 40 insertions(+), 52 deletions(-)

diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 646d27a164830..824e4249088ae 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -1708,6 +1708,19 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite(
   Value castedScale =
       LLVM::BitcastOp::create(rewriter, loc, i32, adaptor.getScale());
 
+  Value source = adaptor.getSource();
+  Type packedType;
+  if (isa<Float4E2M1FNType>(srcElemType)) {
+    packedType = i32;
+  } else if (isa<Float8E4M3FNType>(srcElemType) ||
+             isa<Float8E5M2Type>(srcElemType)) {
+    packedType = VectorType::get(2, i32);
+  } else if (isa<Float6E2M3FNType>(srcElemType) ||
+             isa<Float6E3M2FNType>(srcElemType)) {
+    packedType = VectorType::get(3, i32);
+  } else {
+    llvm_unreachable("invalid element type for scaled ext");
+  }
   // Ok, so we need to construct ops depending on the sourceType and targetType.
   // smallT = [Fp4, Fp8, Bf8]
   //           Bf8 = E5M2
@@ -1715,87 +1728,68 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite(
   //
   // largeT = [F16, Bf16, F32]
   // CvtPkScalePk8${largeT}${smallT}
-  Value source = adaptor.getSource();
 
   if (isa<Float4E2M1FNType>(srcElemType) and destElemType.isF16()) {
     // CvtPkScalePk8F16Fp4Op
     // i32
     Value castedSource =
-        LLVM::BitcastOp::create(rewriter, loc, i32, adaptor.getSource());
+        LLVM::BitcastOp::create(rewriter, loc, packedType, source);
     rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk8F16Fp4Op>(
         op, op.getResult().getType(), castedSource, castedScale, scaleSel);
   } else if (isa<Float8E4M3FNType>(srcElemType) and destElemType.isF16()) {
     // CvtPkScalePk8F16Fp8Op
     // vector<8xf8E4M3FN>
-
-    // vector<2xi32>
-    VectorType v2xi32 = VectorType::get(2, i32);
-    Value castedSource = LLVM::BitcastOp::create(rewriter, loc, v2xi32, source);
-
+    Value castedSource =
+        LLVM::BitcastOp::create(rewriter, loc, packedType, source);
     rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk8F16Fp8Op>(
         op, op.getResult().getType(), castedSource, castedScale, scaleSel);
   } else if (isa<Float8E5M2Type>(srcElemType) and destElemType.isF16()) {
     // CvtPkScalePk8F16Bf8Op
     // vector<8xf8E5M2>
-
-    // vector<2xi32>
-    VectorType v2xi32 = VectorType::get(2, i32);
-    Value castedSource = LLVM::BitcastOp::create(rewriter, loc, v2xi32, source);
-
+    Value castedSource =
+        LLVM::BitcastOp::create(rewriter, loc, packedType, source);
     rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk8F16Bf8Op>(
         op, op.getResult().getType(), castedSource, castedScale, scaleSel);
   } else if (isa<Float4E2M1FNType>(srcElemType) and destElemType.isBF16()) {
     // CvtPkScalePk8Bf16Fp4Op
     // i32
     Value castedSource =
-        LLVM::BitcastOp::create(rewriter, loc, i32, adaptor.getSource());
+        LLVM::BitcastOp::create(rewriter, loc, packedType, source);
     rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk8Bf16Fp4Op>(
         op, op.getResult().getType(), castedSource, castedScale, scaleSel);
   } else if (isa<Float8E4M3FNType>(srcElemType) and destElemType.isBF16()) {
     // CvtPkScalePk8Bf16Fp8Op
     // vector<8xf8E4M3FN>
-
-    // vector<2xi32>
-    VectorType v2xi32 = VectorType::get(2, i32);
-    Value castedSource = LLVM::BitcastOp::create(rewriter, loc, v2xi32, source);
-
+    Value castedSource =
+        LLVM::BitcastOp::create(rewriter, loc, packedType, source);
     rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk8Bf16Fp8Op>(
         op, op.getResult().getType(), castedSource, castedScale, scaleSel);
   } else if (isa<Float8E5M2Type>(srcElemType) and destElemType.isBF16()) {
     // CvtPkScalePk8Bf16Bf8Op
     // vector<8xf8E5M2>
-
-    // vector<2xi32>
-    VectorType v2xi32 = VectorType::get(2, i32);
-    Value castedSource = LLVM::BitcastOp::create(rewriter, loc, v2xi32, source);
-
+    Value castedSource =
+        LLVM::BitcastOp::create(rewriter, loc, packedType, source);
     rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk8Bf16Bf8Op>(
         op, op.getResult().getType(), castedSource, castedScale, scaleSel);
   } else if (isa<Float4E2M1FNType>(srcElemType) and destElemType.isF32()) {
     // CvtPkScalePk8F32Fp4Op
     // i32
     Value castedSource =
-        LLVM::BitcastOp::create(rewriter, loc, i32, adaptor.getSource());
+        LLVM::BitcastOp::create(rewriter, loc, packedType, source);
     rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk8F32Fp4Op>(
         op, op.getResult().getType(), castedSource, castedScale, scaleSel);
   } else if (isa<Float8E4M3FNType>(srcElemType) and destElemType.isF32()) {
     // CvtPkScalePk8F32Fp8Op
     // vector<8xf8E4M3FN>
-
-    // vector<2xi32>
-    VectorType v2xi32 = VectorType::get(2, i32);
-    Value castedSource = LLVM::BitcastOp::create(rewriter, loc, v2xi32, source);
-
+    Value castedSource =
+        LLVM::BitcastOp::create(rewriter, loc, packedType, source);
     rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk8F32Fp8Op>(
         op, op.getResult().getType(), castedSource, castedScale, scaleSel);
   } else if (isa<Float8E5M2Type>(srcElemType) and destElemType.isF32()) {
     // CvtPkScalePk8F32Bf8Op
     // vector<8xf8E5M2>
-
-    // vector<2xi32>
-    VectorType v2xi32 = VectorType::get(2, i32);
-    Value castedSource = LLVM::BitcastOp::create(rewriter, loc, v2xi32, source);
-
+    Value castedSource =
+        LLVM::BitcastOp::create(rewriter, loc, packedType, source);
     rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk8F32Bf8Op>(
         op, op.getResult().getType(), castedSource, castedScale, scaleSel);
   }
@@ -1808,44 +1802,38 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite(
   // CvtPkScalePk16${largeT}${smallT}
   else if (isa<Float6E2M3FNType>(srcElemType) and destElemType.isF16()) {
     // CvtPkScale16F16Fp6Op
-    VectorType v3xi32 = VectorType::get(3, i32);
-    Value castedSource = LLVM::BitcastOp::create(rewriter, loc, v3xi32, source);
-
+    Value castedSource =
+        LLVM::BitcastOp::create(rewriter, loc, packedType, source);
     rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk16F16Fp6Op>(
         op, op.getResult().getType(), castedSource, castedScale, scaleSel);
   } else if (isa<Float6E3M2FNType>(srcElemType) and destElemType.isF16()) {
     // CvtPkScale16F16Bf6Op
-    VectorType v3xi32 = VectorType::get(3, i32);
-    Value castedSource = LLVM::BitcastOp::create(rewriter, loc, v3xi32, source);
-
+    Value castedSource =
+        LLVM::BitcastOp::create(rewriter, loc, packedType, source);
     rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk16F16Bf6Op>(
         op, op.getResult().getType(), castedSource, castedScale, scaleSel);
   } else if (isa<Float6E2M3FNType>(srcElemType) and destElemType.isBF16()) {
     // CvtPkScale16Bf16Fp6Op
-    VectorType v3xi32 = VectorType::get(3, i32);
-    Value castedSource = LLVM::BitcastOp::create(rewriter, loc, v3xi32, source);
-
+    Value castedSource =
+        LLVM::BitcastOp::create(rewriter, loc, packedType, source);
     rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk16Bf16Fp6Op>(
         op, op.getResult().getType(), castedSource, castedScale, scaleSel);
   } else if (isa<Float6E3M2FNType>(srcElemType) and destElemType.isBF16()) {
     // CvtPkScale16Bf16Bf6Op
-    VectorType v3xi32 = VectorType::get(3, i32);
-    Value castedSource = LLVM::BitcastOp::create(rewriter, loc, v3xi32, source);
-
+    Value castedSource =
+        LLVM::BitcastOp::create(rewriter, loc, packedType, source);
     rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk16Bf16Bf6Op>(
         op, op.getResult().getType(), castedSource, castedScale, scaleSel);
   } else if (isa<Float6E2M3FNType>(srcElemType) and destElemType.isF32()) {
     // CvtPkScale16F32Fp6Op
-    VectorType v3xi32 = VectorType::get(3, i32);
-    Value castedSource = LLVM::BitcastOp::create(rewriter, loc, v3xi32, source);
-
+    Value castedSource =
+        LLVM::BitcastOp::create(rewriter, loc, packedType, source);
     rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk16F32Fp6Op>(
         op, op.getResult().getType(), castedSource, castedScale, scaleSel);
   } else if (isa<Float6E3M2FNType>(srcElemType) and destElemType.isF32()) {
     // CvtPkScale16F32Bf6Op
-    VectorType v3xi32 = VectorType::get(3, i32);
-    Value castedSource = LLVM::BitcastOp::create(rewriter, loc, v3xi32, source);
-
+    Value castedSource =
+        LLVM::BitcastOp::create(rewriter, loc, packedType, source);
     rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk16F32Bf6Op>(
         op, op.getResult().getType(), castedSource, castedScale, scaleSel);
   } else {

>From 9e4ab0e7b03009fe50cddaf8f218e93fe0bc82f1 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Tue, 11 Nov 2025 14:02:38 -0500
Subject: [PATCH 24/43] Hoist variable. NFC

---
 .../AMDGPUToROCDL/AMDGPUToROCDL.cpp           | 56 +------------------
 1 file changed, 2 insertions(+), 54 deletions(-)

diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 824e4249088ae..3d41d47da6e00 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -1728,68 +1728,34 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite(
   //
   // largeT = [F16, Bf16, F32]
   // CvtPkScalePk8${largeT}${smallT}
+  Value castedSource =
+      LLVM::BitcastOp::create(rewriter, loc, packedType, source);
 
   if (isa<Float4E2M1FNType>(srcElemType) and destElemType.isF16()) {
-    // CvtPkScalePk8F16Fp4Op
-    // i32
-    Value castedSource =
-        LLVM::BitcastOp::create(rewriter, loc, packedType, source);
     rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk8F16Fp4Op>(
         op, op.getResult().getType(), castedSource, castedScale, scaleSel);
   } else if (isa<Float8E4M3FNType>(srcElemType) and destElemType.isF16()) {
-    // CvtPkScalePk8F16Fp8Op
-    // vector<8xf8E4M3FN>
-    Value castedSource =
-        LLVM::BitcastOp::create(rewriter, loc, packedType, source);
     rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk8F16Fp8Op>(
         op, op.getResult().getType(), castedSource, castedScale, scaleSel);
   } else if (isa<Float8E5M2Type>(srcElemType) and destElemType.isF16()) {
-    // CvtPkScalePk8F16Bf8Op
-    // vector<8xf8E5M2>
-    Value castedSource =
-        LLVM::BitcastOp::create(rewriter, loc, packedType, source);
     rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk8F16Bf8Op>(
         op, op.getResult().getType(), castedSource, castedScale, scaleSel);
   } else if (isa<Float4E2M1FNType>(srcElemType) and destElemType.isBF16()) {
-    // CvtPkScalePk8Bf16Fp4Op
-    // i32
-    Value castedSource =
-        LLVM::BitcastOp::create(rewriter, loc, packedType, source);
     rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk8Bf16Fp4Op>(
         op, op.getResult().getType(), castedSource, castedScale, scaleSel);
   } else if (isa<Float8E4M3FNType>(srcElemType) and destElemType.isBF16()) {
-    // CvtPkScalePk8Bf16Fp8Op
-    // vector<8xf8E4M3FN>
-    Value castedSource =
-        LLVM::BitcastOp::create(rewriter, loc, packedType, source);
     rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk8Bf16Fp8Op>(
         op, op.getResult().getType(), castedSource, castedScale, scaleSel);
   } else if (isa<Float8E5M2Type>(srcElemType) and destElemType.isBF16()) {
-    // CvtPkScalePk8Bf16Bf8Op
-    // vector<8xf8E5M2>
-    Value castedSource =
-        LLVM::BitcastOp::create(rewriter, loc, packedType, source);
     rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk8Bf16Bf8Op>(
         op, op.getResult().getType(), castedSource, castedScale, scaleSel);
   } else if (isa<Float4E2M1FNType>(srcElemType) and destElemType.isF32()) {
-    // CvtPkScalePk8F32Fp4Op
-    // i32
-    Value castedSource =
-        LLVM::BitcastOp::create(rewriter, loc, packedType, source);
     rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk8F32Fp4Op>(
         op, op.getResult().getType(), castedSource, castedScale, scaleSel);
   } else if (isa<Float8E4M3FNType>(srcElemType) and destElemType.isF32()) {
-    // CvtPkScalePk8F32Fp8Op
-    // vector<8xf8E4M3FN>
-    Value castedSource =
-        LLVM::BitcastOp::create(rewriter, loc, packedType, source);
     rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk8F32Fp8Op>(
         op, op.getResult().getType(), castedSource, castedScale, scaleSel);
   } else if (isa<Float8E5M2Type>(srcElemType) and destElemType.isF32()) {
-    // CvtPkScalePk8F32Bf8Op
-    // vector<8xf8E5M2>
-    Value castedSource =
-        LLVM::BitcastOp::create(rewriter, loc, packedType, source);
     rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk8F32Bf8Op>(
         op, op.getResult().getType(), castedSource, castedScale, scaleSel);
   }
@@ -1801,39 +1767,21 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite(
 
   // CvtPkScalePk16${largeT}${smallT}
   else if (isa<Float6E2M3FNType>(srcElemType) and destElemType.isF16()) {
-    // CvtPkScale16F16Fp6Op
-    Value castedSource =
-        LLVM::BitcastOp::create(rewriter, loc, packedType, source);
     rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk16F16Fp6Op>(
         op, op.getResult().getType(), castedSource, castedScale, scaleSel);
   } else if (isa<Float6E3M2FNType>(srcElemType) and destElemType.isF16()) {
-    // CvtPkScale16F16Bf6Op
-    Value castedSource =
-        LLVM::BitcastOp::create(rewriter, loc, packedType, source);
     rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk16F16Bf6Op>(
         op, op.getResult().getType(), castedSource, castedScale, scaleSel);
   } else if (isa<Float6E2M3FNType>(srcElemType) and destElemType.isBF16()) {
-    // CvtPkScale16Bf16Fp6Op
-    Value castedSource =
-        LLVM::BitcastOp::create(rewriter, loc, packedType, source);
     rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk16Bf16Fp6Op>(
         op, op.getResult().getType(), castedSource, castedScale, scaleSel);
   } else if (isa<Float6E3M2FNType>(srcElemType) and destElemType.isBF16()) {
-    // CvtPkScale16Bf16Bf6Op
-    Value castedSource =
-        LLVM::BitcastOp::create(rewriter, loc, packedType, source);
     rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk16Bf16Bf6Op>(
         op, op.getResult().getType(), castedSource, castedScale, scaleSel);
   } else if (isa<Float6E2M3FNType>(srcElemType) and destElemType.isF32()) {
-    // CvtPkScale16F32Fp6Op
-    Value castedSource =
-        LLVM::BitcastOp::create(rewriter, loc, packedType, source);
     rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk16F32Fp6Op>(
         op, op.getResult().getType(), castedSource, castedScale, scaleSel);
   } else if (isa<Float6E3M2FNType>(srcElemType) and destElemType.isF32()) {
-    // CvtPkScale16F32Bf6Op
-    Value castedSource =
-        LLVM::BitcastOp::create(rewriter, loc, packedType, source);
     rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk16F32Bf6Op>(
         op, op.getResult().getType(), castedSource, castedScale, scaleSel);
   } else {

>From 728686f4ab79a5f64c389c83fc47447f0b69b663 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Tue, 11 Nov 2025 14:03:55 -0500
Subject: [PATCH 25/43] Comments. NFC

---
 mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp | 6 ++----
 1 file changed, 2 insertions(+), 4 deletions(-)

diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 3d41d47da6e00..0affa2ead9f78 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -1721,7 +1721,6 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite(
   } else {
     llvm_unreachable("invalid element type for scaled ext");
   }
-  // Ok, so we need to construct ops depending on the sourceType and targetType.
   // smallT = [Fp4, Fp8, Bf8]
   //           Bf8 = E5M2
   //           Fp8 = E4M3
@@ -1760,11 +1759,10 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite(
         op, op.getResult().getType(), castedSource, castedScale, scaleSel);
   }
   // smallT = [Fp6, Bf6]
+  //           Fp6 = Float6E2M3FN
+  //           Bf6 = Float6E3M2FN
   // largeT = [F16, Bf16, F32]
   //
-  // Fp6 = Float6E2M3FN
-  // Bf6 = Float6E3M2FN
-
   // CvtPkScalePk16${largeT}${smallT}
   else if (isa<Float6E2M3FNType>(srcElemType) and destElemType.isF16()) {
     rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk16F16Fp6Op>(

>From 47dc32e6698b7ea8f3926bc39d97273b32a012df Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Tue, 11 Nov 2025 14:09:01 -0500
Subject: [PATCH 26/43] refactor. nfc

---
 .../AMDGPUToROCDL/AMDGPUToROCDL.cpp           | 33 +++++++++----------
 1 file changed, 15 insertions(+), 18 deletions(-)

diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 0affa2ead9f78..ec889879fd8d4 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -1701,9 +1701,6 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite(
   auto targetType = cast<VectorType>(op.getResult().getType());
   auto destElemType = cast<FloatType>(targetType.getElementType());
   Location loc = op.getLoc();
-  // %scale: vector<4xf8E8M0FNU>
-  // ===========================
-  // %scale: i32
   IntegerType i32 = rewriter.getI32Type();
   Value castedScale =
       LLVM::BitcastOp::create(rewriter, loc, i32, adaptor.getScale());
@@ -1730,31 +1727,31 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite(
   Value castedSource =
       LLVM::BitcastOp::create(rewriter, loc, packedType, source);
 
-  if (isa<Float4E2M1FNType>(srcElemType) and destElemType.isF16()) {
+  if (isa<Float4E2M1FNType>(srcElemType) && destElemType.isF16()) {
     rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk8F16Fp4Op>(
         op, op.getResult().getType(), castedSource, castedScale, scaleSel);
-  } else if (isa<Float8E4M3FNType>(srcElemType) and destElemType.isF16()) {
+  } else if (isa<Float8E4M3FNType>(srcElemType) && destElemType.isF16()) {
     rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk8F16Fp8Op>(
         op, op.getResult().getType(), castedSource, castedScale, scaleSel);
-  } else if (isa<Float8E5M2Type>(srcElemType) and destElemType.isF16()) {
+  } else if (isa<Float8E5M2Type>(srcElemType) && destElemType.isF16()) {
     rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk8F16Bf8Op>(
         op, op.getResult().getType(), castedSource, castedScale, scaleSel);
-  } else if (isa<Float4E2M1FNType>(srcElemType) and destElemType.isBF16()) {
+  } else if (isa<Float4E2M1FNType>(srcElemType) && destElemType.isBF16()) {
     rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk8Bf16Fp4Op>(
         op, op.getResult().getType(), castedSource, castedScale, scaleSel);
-  } else if (isa<Float8E4M3FNType>(srcElemType) and destElemType.isBF16()) {
+  } else if (isa<Float8E4M3FNType>(srcElemType) && destElemType.isBF16()) {
     rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk8Bf16Fp8Op>(
         op, op.getResult().getType(), castedSource, castedScale, scaleSel);
-  } else if (isa<Float8E5M2Type>(srcElemType) and destElemType.isBF16()) {
+  } else if (isa<Float8E5M2Type>(srcElemType) && destElemType.isBF16()) {
     rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk8Bf16Bf8Op>(
         op, op.getResult().getType(), castedSource, castedScale, scaleSel);
-  } else if (isa<Float4E2M1FNType>(srcElemType) and destElemType.isF32()) {
+  } else if (isa<Float4E2M1FNType>(srcElemType) && destElemType.isF32()) {
     rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk8F32Fp4Op>(
         op, op.getResult().getType(), castedSource, castedScale, scaleSel);
-  } else if (isa<Float8E4M3FNType>(srcElemType) and destElemType.isF32()) {
+  } else if (isa<Float8E4M3FNType>(srcElemType) && destElemType.isF32()) {
     rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk8F32Fp8Op>(
         op, op.getResult().getType(), castedSource, castedScale, scaleSel);
-  } else if (isa<Float8E5M2Type>(srcElemType) and destElemType.isF32()) {
+  } else if (isa<Float8E5M2Type>(srcElemType) && destElemType.isF32()) {
     rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk8F32Bf8Op>(
         op, op.getResult().getType(), castedSource, castedScale, scaleSel);
   }
@@ -1764,22 +1761,22 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite(
   // largeT = [F16, Bf16, F32]
   //
   // CvtPkScalePk16${largeT}${smallT}
-  else if (isa<Float6E2M3FNType>(srcElemType) and destElemType.isF16()) {
+  else if (isa<Float6E2M3FNType>(srcElemType) && destElemType.isF16()) {
     rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk16F16Fp6Op>(
         op, op.getResult().getType(), castedSource, castedScale, scaleSel);
-  } else if (isa<Float6E3M2FNType>(srcElemType) and destElemType.isF16()) {
+  } else if (isa<Float6E3M2FNType>(srcElemType) && destElemType.isF16()) {
     rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk16F16Bf6Op>(
         op, op.getResult().getType(), castedSource, castedScale, scaleSel);
-  } else if (isa<Float6E2M3FNType>(srcElemType) and destElemType.isBF16()) {
+  } else if (isa<Float6E2M3FNType>(srcElemType) && destElemType.isBF16()) {
     rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk16Bf16Fp6Op>(
         op, op.getResult().getType(), castedSource, castedScale, scaleSel);
-  } else if (isa<Float6E3M2FNType>(srcElemType) and destElemType.isBF16()) {
+  } else if (isa<Float6E3M2FNType>(srcElemType) && destElemType.isBF16()) {
     rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk16Bf16Bf6Op>(
         op, op.getResult().getType(), castedSource, castedScale, scaleSel);
-  } else if (isa<Float6E2M3FNType>(srcElemType) and destElemType.isF32()) {
+  } else if (isa<Float6E2M3FNType>(srcElemType) && destElemType.isF32()) {
     rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk16F32Fp6Op>(
         op, op.getResult().getType(), castedSource, castedScale, scaleSel);
-  } else if (isa<Float6E3M2FNType>(srcElemType) and destElemType.isF32()) {
+  } else if (isa<Float6E3M2FNType>(srcElemType) && destElemType.isF32()) {
     rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk16F32Bf6Op>(
         op, op.getResult().getType(), castedSource, castedScale, scaleSel);
   } else {

>From 3cfea7e2abf631970b8a115c3ed1f7d56267ad38 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Tue, 11 Nov 2025 14:14:44 -0500
Subject: [PATCH 27/43] Keep conventions

---
 mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index ec889879fd8d4..6dbf57342cb3f 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -1709,12 +1709,15 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite(
   Type packedType;
   if (isa<Float4E2M1FNType>(srcElemType)) {
     packedType = i32;
+    packedType = getTypeConverter()->convertType(packedType);
   } else if (isa<Float8E4M3FNType>(srcElemType) ||
              isa<Float8E5M2Type>(srcElemType)) {
     packedType = VectorType::get(2, i32);
+    packedType = getTypeConverter()->convertType(packedType);
   } else if (isa<Float6E2M3FNType>(srcElemType) ||
              isa<Float6E3M2FNType>(srcElemType)) {
     packedType = VectorType::get(3, i32);
+    packedType = getTypeConverter()->convertType(packedType);
   } else {
     llvm_unreachable("invalid element type for scaled ext");
   }

>From 2b010cd04d4644b06a864777de0c63edae4be0c6 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Mon, 17 Nov 2025 09:35:11 -0500
Subject: [PATCH 28/43] Replace exhaustive enumeration with bit composition

---
 .../AMDGPUToROCDL/AMDGPUToROCDL.cpp           | 71 +++++--------------
 1 file changed, 19 insertions(+), 52 deletions(-)

diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 6dbf57342cb3f..a9f58063ac32b 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -1627,62 +1627,29 @@ int getScaleSel(int blockSize, int bitWidth, int firstScaleLane,
   const bool is_fp8 = bitWidth == 8;
   const bool is_block_16 = blockSize == 16;
 
-  if (!is_fp8 && firstScaleLane == 0 && firstScaleByte == 0 && !is_block_16) {
-    return 0b000;
-  }
-  if (!is_fp8 && firstScaleLane == 0 && firstScaleByte == 0 && is_block_16) {
-    return 0b001;
-  }
-  if (!is_fp8 && firstScaleLane == 0 && firstScaleByte == 2 && !is_block_16) {
-    return 0b010;
-  }
-  if (!is_fp8 && firstScaleLane == 0 && firstScaleByte == 2 && is_block_16) {
-    return 0b011;
-  }
-  if (!is_fp8 && firstScaleLane == 1 && firstScaleByte == 0 && !is_block_16) {
-    return 0b100;
-  }
-  if (!is_fp8 && firstScaleLane == 1 && firstScaleByte == 0 && is_block_16) {
-    return 0b101;
-  }
-  if (!is_fp8 && firstScaleLane == 1 && firstScaleByte == 2 && !is_block_16) {
-    return 0b110;
-  }
-  if (!is_fp8 && firstScaleLane == 1 && firstScaleByte == 2 && is_block_16) {
-    return 0b111;
-  }
-
-  if (is_fp8 && firstScaleLane == 0 && firstScaleByte == 0 && !is_block_16) {
-    return 0b0000;
-  }
-  if (is_fp8 && firstScaleLane == 0 && firstScaleByte == 0 && is_block_16) {
-    return 0b0001;
-  }
-  if (is_fp8 && firstScaleLane == 0 && firstScaleByte == 1 && !is_block_16) {
-    return 0b0010;
-  }
-  if (is_fp8 && firstScaleLane == 0 && firstScaleByte == 2 && !is_block_16) {
-    return 0b0100;
-  }
-  if (is_fp8 && firstScaleLane == 0 && firstScaleByte == 3 && !is_block_16) {
-    return 0b0110;
-  }
-  if (is_fp8 && firstScaleLane == 1 && firstScaleByte == 1 && !is_block_16) {
-    return 0b1010;
-  }
-  if (is_fp8 && firstScaleLane == 1 && firstScaleByte == 2 && !is_block_16) {
-    return 0b1100;
-  }
-  if (is_fp8 && firstScaleLane == 1 && firstScaleByte == 2 && is_block_16) {
-    return 0b1101;
-  }
-  if (is_fp8 && firstScaleLane == 1 && firstScaleByte == 3 && !is_block_16) {
-    return 0b1110;
+  if (!is_fp8) {
+    int bit_0 = is_block_16;
+    assert(llvm::is_contained({0, 2}, firstScaleByte));
+    int bit_1 = (firstScaleByte == 2) << 1;
+    assert(llvm::is_contained({0, 1}, firstScaleLane));
+    int bit_2 = firstScaleLane << 2;
+    return bit_2 | bit_1 | bit_0;
+  } else {
+    int bit_0 = is_block_16;
+    // firstScaleByte is guaranteed to be defined by two bits.
+    assert(llvm::is_contained({0, 1, 2, 3}, firstScaleByte));
+    int bit_2_and_1 = firstScaleByte << 1;
+    assert(llvm::is_contained({0, 1}, firstScaleLane));
+    int bit_3 = firstScaleLane << 3;
+    int bits = bit_3 | bit_2_and_1 | bit_0;
+    // These are invalid cases.
+    assert(!llvm::is_contained(
+        {0b0011, 0b0101, 0b0111, 0b1000, 0b1001, 0b1011, 0b1111}, bits));
+    return bits;
   }
 
   llvm_unreachable("invalid combination of firstScaleLane, firstScaleByte, "
                    "blockSize and type.");
-  return 0;
 }
 
 LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite(

>From 6f07ef03b6f778fec82abc85914c6b47095d2312 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Mon, 17 Nov 2025 09:39:10 -0500
Subject: [PATCH 29/43] Correct types

---
 .../lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp | 14 +++++++-------
 1 file changed, 7 insertions(+), 7 deletions(-)

diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index a9f58063ac32b..596cac1b76469 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -1613,8 +1613,8 @@ LogicalResult ExtPackedFp8OpLowering::matchAndRewrite(
   return success();
 }
 
-int getScaleSel(int blockSize, int bitWidth, int firstScaleLane,
-                int firstScaleByte) {
+int32_t getScaleSel(int32_t blockSize, unsigned bitWidth,
+                    int32_t firstScaleLane, int32_t firstScaleByte) {
   // When lowering amdgpu.scaled_ext_packed816 to
   // rocdl.cvt.scale.pk*.f*.f* operations, the
   // attributes blockSize, sourceType, firstScaleLane and firstScaleByte
@@ -1656,13 +1656,13 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite(
     ScaledExtPacked816Op op, ScaledExtPacked816OpAdaptor adaptor,
     ConversionPatternRewriter &rewriter) const {
 
-  int firstScaleLane = op.getFirstScaleLane();
-  int firstScaleByte = op.getFirstScaleByte();
-  int blockSize = op.getBlockSize();
+  int32_t firstScaleLane = op.getFirstScaleLane();
+  int32_t firstScaleByte = op.getFirstScaleByte();
+  int32_t blockSize = op.getBlockSize();
   auto sourceType = cast<VectorType>(op.getSource().getType());
   auto srcElemType = cast<FloatType>(sourceType.getElementType());
-  int bitWidth = srcElemType.getWidth();
-  int scaleSel =
+  unsigned bitWidth = srcElemType.getWidth();
+  int32_t scaleSel =
       getScaleSel(blockSize, bitWidth, firstScaleLane, firstScaleByte);
 
   auto targetType = cast<VectorType>(op.getResult().getType());

>From 69787933cc44ed01faff7abf38078c0f61917bbd Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Mon, 17 Nov 2025 09:39:48 -0500
Subject: [PATCH 30/43] Reflow comment

---
 mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp | 10 ++++------
 1 file changed, 4 insertions(+), 6 deletions(-)

diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 596cac1b76469..a7c9a68dd7731 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -1615,12 +1615,10 @@ LogicalResult ExtPackedFp8OpLowering::matchAndRewrite(
 
 int32_t getScaleSel(int32_t blockSize, unsigned bitWidth,
                     int32_t firstScaleLane, int32_t firstScaleByte) {
-  // When lowering amdgpu.scaled_ext_packed816 to
-  // rocdl.cvt.scale.pk*.f*.f* operations, the
-  // attributes blockSize, sourceType, firstScaleLane and firstScaleByte
-  // are merged into a single attribute scaleSel.
-  //
-  // This is how those values are merged together.
+  // When lowering amdgpu.scaled_ext_packed816 to rocdl.cvt.scale.pk*.f*.f*
+  // operations, the attributes blockSize, sourceType, firstScaleLane and
+  // firstScaleByte are merged into a single attribute scaleSel. This is how
+  // those values are merged together.
   assert(llvm::is_contained({16, 32}, blockSize));
   assert(llvm::is_contained({4, 6, 8}, bitWidth));
 

>From ed66571310d9fd687e1b549beb5a79704c629cb4 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Mon, 17 Nov 2025 09:40:29 -0500
Subject: [PATCH 31/43] Remove superfluous empty line

---
 mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp | 1 -
 1 file changed, 1 deletion(-)

diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index a7c9a68dd7731..15c511460a552 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -1653,7 +1653,6 @@ int32_t getScaleSel(int32_t blockSize, unsigned bitWidth,
 LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite(
     ScaledExtPacked816Op op, ScaledExtPacked816OpAdaptor adaptor,
     ConversionPatternRewriter &rewriter) const {
-
   int32_t firstScaleLane = op.getFirstScaleLane();
   int32_t firstScaleByte = op.getFirstScaleByte();
   int32_t blockSize = op.getBlockSize();

>From 33ef57e0dce2640dc8c3cf3c5623ffc71eb42d18 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Mon, 17 Nov 2025 09:55:53 -0500
Subject: [PATCH 32/43] Use short type aliases for source element types

---
 .../AMDGPUToROCDL/AMDGPUToROCDL.cpp           | 58 ++++++++-----------
 1 file changed, 24 insertions(+), 34 deletions(-)

diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 15c511460a552..b0b0d4a9fe604 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -1620,7 +1620,7 @@ int32_t getScaleSel(int32_t blockSize, unsigned bitWidth,
   // firstScaleByte are merged into a single attribute scaleSel. This is how
   // those values are merged together.
   assert(llvm::is_contained({16, 32}, blockSize));
-  assert(llvm::is_contained({4, 6, 8}, bitWidth));
+  assert(llvm::is_contained(::llvm::ArrayRef<unsigned>{4, 6, 8}, bitWidth));
 
   const bool is_fp8 = bitWidth == 8;
   const bool is_block_16 = blockSize == 16;
@@ -1653,6 +1653,11 @@ int32_t getScaleSel(int32_t blockSize, unsigned bitWidth,
 LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite(
     ScaledExtPacked816Op op, ScaledExtPacked816OpAdaptor adaptor,
     ConversionPatternRewriter &rewriter) const {
+  using fp4 = Float4E2M1FNType;
+  using fp8 = Float8E4M3FNType;
+  using bf8 = Float8E5M2Type;
+  using fp6 = Float6E2M3FNType;
+  using bf6 = Float6E3M2FNType;
   int32_t firstScaleLane = op.getFirstScaleLane();
   int32_t firstScaleByte = op.getFirstScaleByte();
   int32_t blockSize = op.getBlockSize();
@@ -1671,79 +1676,64 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite(
 
   Value source = adaptor.getSource();
   Type packedType;
-  if (isa<Float4E2M1FNType>(srcElemType)) {
+  if (isa<fp4>(srcElemType)) {
     packedType = i32;
     packedType = getTypeConverter()->convertType(packedType);
-  } else if (isa<Float8E4M3FNType>(srcElemType) ||
-             isa<Float8E5M2Type>(srcElemType)) {
+  } else if (isa<fp8, bf8>(srcElemType)) {
     packedType = VectorType::get(2, i32);
     packedType = getTypeConverter()->convertType(packedType);
-  } else if (isa<Float6E2M3FNType>(srcElemType) ||
-             isa<Float6E3M2FNType>(srcElemType)) {
+  } else if (isa<fp6, bf6>(srcElemType)) {
     packedType = VectorType::get(3, i32);
     packedType = getTypeConverter()->convertType(packedType);
   } else {
     llvm_unreachable("invalid element type for scaled ext");
   }
-  // smallT = [Fp4, Fp8, Bf8]
-  //           Bf8 = E5M2
-  //           Fp8 = E4M3
-  //
-  // largeT = [F16, Bf16, F32]
-  // CvtPkScalePk8${largeT}${smallT}
   Value castedSource =
       LLVM::BitcastOp::create(rewriter, loc, packedType, source);
 
-  if (isa<Float4E2M1FNType>(srcElemType) && destElemType.isF16()) {
+  if (isa<fp4>(srcElemType) && destElemType.isF16()) {
     rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk8F16Fp4Op>(
         op, op.getResult().getType(), castedSource, castedScale, scaleSel);
-  } else if (isa<Float8E4M3FNType>(srcElemType) && destElemType.isF16()) {
+  } else if (isa<fp8>(srcElemType) && destElemType.isF16()) {
     rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk8F16Fp8Op>(
         op, op.getResult().getType(), castedSource, castedScale, scaleSel);
-  } else if (isa<Float8E5M2Type>(srcElemType) && destElemType.isF16()) {
+  } else if (isa<bf8>(srcElemType) && destElemType.isF16()) {
     rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk8F16Bf8Op>(
         op, op.getResult().getType(), castedSource, castedScale, scaleSel);
-  } else if (isa<Float4E2M1FNType>(srcElemType) && destElemType.isBF16()) {
+  } else if (isa<fp4>(srcElemType) && destElemType.isBF16()) {
     rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk8Bf16Fp4Op>(
         op, op.getResult().getType(), castedSource, castedScale, scaleSel);
-  } else if (isa<Float8E4M3FNType>(srcElemType) && destElemType.isBF16()) {
+  } else if (isa<fp8>(srcElemType) && destElemType.isBF16()) {
     rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk8Bf16Fp8Op>(
         op, op.getResult().getType(), castedSource, castedScale, scaleSel);
-  } else if (isa<Float8E5M2Type>(srcElemType) && destElemType.isBF16()) {
+  } else if (isa<bf8>(srcElemType) && destElemType.isBF16()) {
     rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk8Bf16Bf8Op>(
         op, op.getResult().getType(), castedSource, castedScale, scaleSel);
-  } else if (isa<Float4E2M1FNType>(srcElemType) && destElemType.isF32()) {
+  } else if (isa<fp4>(srcElemType) && destElemType.isF32()) {
     rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk8F32Fp4Op>(
         op, op.getResult().getType(), castedSource, castedScale, scaleSel);
-  } else if (isa<Float8E4M3FNType>(srcElemType) && destElemType.isF32()) {
+  } else if (isa<fp8>(srcElemType) && destElemType.isF32()) {
     rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk8F32Fp8Op>(
         op, op.getResult().getType(), castedSource, castedScale, scaleSel);
-  } else if (isa<Float8E5M2Type>(srcElemType) && destElemType.isF32()) {
+  } else if (isa<bf8>(srcElemType) && destElemType.isF32()) {
     rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk8F32Bf8Op>(
         op, op.getResult().getType(), castedSource, castedScale, scaleSel);
-  }
-  // smallT = [Fp6, Bf6]
-  //           Fp6 = Float6E2M3FN
-  //           Bf6 = Float6E3M2FN
-  // largeT = [F16, Bf16, F32]
-  //
-  // CvtPkScalePk16${largeT}${smallT}
-  else if (isa<Float6E2M3FNType>(srcElemType) && destElemType.isF16()) {
+  } else if (isa<fp6>(srcElemType) && destElemType.isF16()) {
     rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk16F16Fp6Op>(
         op, op.getResult().getType(), castedSource, castedScale, scaleSel);
-  } else if (isa<Float6E3M2FNType>(srcElemType) && destElemType.isF16()) {
+  } else if (isa<bf6>(srcElemType) && destElemType.isF16()) {
     rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk16F16Bf6Op>(
         op, op.getResult().getType(), castedSource, castedScale, scaleSel);
-  } else if (isa<Float6E2M3FNType>(srcElemType) && destElemType.isBF16()) {
+  } else if (isa<fp6>(srcElemType) && destElemType.isBF16()) {
     rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk16Bf16Fp6Op>(
         op, op.getResult().getType(), castedSource, castedScale, scaleSel);
-  } else if (isa<Float6E3M2FNType>(srcElemType) && destElemType.isBF16()) {
+  } else if (isa<bf6>(srcElemType) && destElemType.isBF16()) {
     rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk16Bf16Bf6Op>(
         op, op.getResult().getType(), castedSource, castedScale, scaleSel);
-  } else if (isa<Float6E2M3FNType>(srcElemType) && destElemType.isF32()) {
+  } else if (isa<fp6>(srcElemType) && destElemType.isF32()) {
     rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk16F32Fp6Op>(
         op, op.getResult().getType(), castedSource, castedScale, scaleSel);
-  } else if (isa<Float6E3M2FNType>(srcElemType) && destElemType.isF32()) {
+  } else if (isa<bf6>(srcElemType) && destElemType.isF32()) {
     rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk16F32Bf6Op>(
         op, op.getResult().getType(), castedSource, castedScale, scaleSel);
   } else {

>From a83cec94451eae1889c92e333dd6fa7b47904bad Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Mon, 17 Nov 2025 10:18:27 -0500
Subject: [PATCH 33/43] Add chipset check and move tests

---
 .../AMDGPUToROCDL/AMDGPUToROCDL.cpp           |   8 +-
 .../AMDGPUToROCDL/amdgpu-to-rocdl.mlir        | 114 -----------------
 .../AMDGPUToROCDL/cvt_scale_pk-gfx1250.mlir   | 116 ++++++++++++++++++
 3 files changed, 123 insertions(+), 115 deletions(-)
 create mode 100644 mlir/test/Conversion/AMDGPUToROCDL/cvt_scale_pk-gfx1250.mlir

diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index b0b0d4a9fe604..7e30feb520a72 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -1658,6 +1658,13 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite(
   using bf8 = Float8E5M2Type;
   using fp6 = Float6E2M3FNType;
   using bf6 = Float6E3M2FNType;
+  Location loc = op.getLoc();
+  if (chipset != Chipset{12, 5, 0}) {
+    return rewriter.notifyMatchFailure(
+        loc,
+        "Scaled fp packed conversion instructions are not available on target "
+        "architecture and their emulation is not implemented");
+  }
   int32_t firstScaleLane = op.getFirstScaleLane();
   int32_t firstScaleByte = op.getFirstScaleByte();
   int32_t blockSize = op.getBlockSize();
@@ -1669,7 +1676,6 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite(
 
   auto targetType = cast<VectorType>(op.getResult().getType());
   auto destElemType = cast<FloatType>(targetType.getElementType());
-  Location loc = op.getLoc();
   IntegerType i32 = rewriter.getI32Type();
   Value castedScale =
       LLVM::BitcastOp::create(rewriter, loc, i32, adaptor.getScale());
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir b/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir
index 94a04d98004c7..432b8876696a9 100644
--- a/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir
+++ b/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir
@@ -457,117 +457,3 @@ func.func @sched_barrier() {
   func.return
 }
 
-// CHECK-LABEL: @scaled_ext_packed816_fp4
-// CHECK-SAME: (%[[SOURCE:.+]]: vector<8xf4E2M1FN>, %[[SCALE:.+]]: vector<4xf8E8M0FNU>)
-func.func @scaled_ext_packed816_fp4(%v: vector<8xf4E2M1FN>, %scale: vector<4xf8E8M0FNU>) -> (vector<8xf16>, vector<8xbf16>, vector<8xf32>) {
-  // CHECK: %[[SCALE_4xi8:.+]] = builtin.unrealized_conversion_cast %[[SCALE]] : vector<4xf8E8M0FNU> to vector<4xi8>
-  // CHECK: %[[SOURCE_8xi4:.+]] = builtin.unrealized_conversion_cast %[[SOURCE]] : vector<8xf4E2M1FN> to vector<8xi4>
-  // CHECK: %[[SCALE_i32:.+]] = llvm.bitcast %[[SCALE_4xi8]] : vector<4xi8> to i32
-  // CHECK: %[[SOURCE_i32:.+]] = llvm.bitcast %[[SOURCE_8xi4]] : vector<8xi4> to i32
-  // CHECK: rocdl.cvt.scale.pk8.f16.fp4 %[[SOURCE_i32]], %[[SCALE_i32]][0] : vector<8xf16>
-  %ret0 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<8xf4E2M1FN>, vector<4xf8E8M0FNU> -> vector<8xf16>
-
-  // CHECK: %[[SCALE_i32:.+]] = llvm.bitcast %[[SCALE_4xi8]] : vector<4xi8> to i32
-  // CHECK: %[[SOURCE_i32:.+]] = llvm.bitcast %[[SOURCE_8xi4]] : vector<8xi4> to i32
-  // CHECK: rocdl.cvt.scale.pk8.bf16.fp4 %[[SOURCE_i32]], %[[SCALE_i32]][0] : vector<8xbf16>
-  %ret1 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<8xf4E2M1FN>, vector<4xf8E8M0FNU> -> vector<8xbf16>
-
-  // CHECK: %[[SCALE_i32:.+]] = llvm.bitcast %[[SCALE_4xi8]] : vector<4xi8> to i32
-  // CHECK: %[[SOURCE_i32:.+]] = llvm.bitcast %[[SOURCE_8xi4]] : vector<8xi4> to i32
-  // CHECK: rocdl.cvt.scale.pk8.f32.fp4 %[[SOURCE_i32]], %[[SCALE_i32]][0] : vector<8xf32>
-  %ret2 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<8xf4E2M1FN>, vector<4xf8E8M0FNU> -> vector<8xf32>
-  func.return %ret0, %ret1, %ret2: vector<8xf16>, vector<8xbf16>, vector<8xf32>
-}
-
-// CHECK-LABEL: @scaled_ext_packed816_fp8
-// CHECK-SAME: (%[[SOURCE:.+]]: vector<8xf8E4M3FN>, %[[SCALE:.+]]: vector<4xf8E8M0FNU>)
-func.func @scaled_ext_packed816_fp8(%v: vector<8xf8E4M3FN>, %scale: vector<4xf8E8M0FNU>) -> (vector<8xf16>, vector<8xbf16>, vector<8xf32>) {
-  // CHECK: %[[SCALE_4xi8:.+]] = builtin.unrealized_conversion_cast %[[SCALE]] : vector<4xf8E8M0FNU> to vector<4xi8>
-  // CHECK: %[[SOURCE_8xi8:.+]] = builtin.unrealized_conversion_cast %[[SOURCE]] : vector<8xf8E4M3FN> to vector<8xi8>
-  // CHECK: %[[SCALE_i32:.+]] = llvm.bitcast %[[SCALE_4xi8]] : vector<4xi8> to i32
-  // CHECK: %[[SOURCE_v2xi32:.+]] = llvm.bitcast %[[SOURCE_8xi8]] : vector<8xi8> to vector<2xi32>
-  // CHECK: rocdl.cvt.scale.pk8.f16.fp8 %[[SOURCE_v2xi32]], %[[SCALE_i32]][0] : vector<8xf16>
-  %ret0 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<8xf8E4M3FN>, vector<4xf8E8M0FNU> -> vector<8xf16>
-
-  // CHECK: %[[SCALE_i32:.+]] = llvm.bitcast %[[SCALE_4xi8]] : vector<4xi8> to i32
-  // CHECK: %[[SOURCE_v2xi32:.+]] = llvm.bitcast %[[SOURCE_8xi8]] : vector<8xi8> to vector<2xi32>
-  // CHECK: rocdl.cvt.scale.pk8.bf16.fp8 %[[SOURCE_v2xi32]], %[[SCALE_i32]][0] : vector<8xbf16>
-  %ret1 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<8xf8E4M3FN>, vector<4xf8E8M0FNU> -> vector<8xbf16>
-
-  // CHECK: %[[SCALE_i32:.+]] = llvm.bitcast %[[SCALE_4xi8]] : vector<4xi8> to i32
-  // CHECK: %[[SOURCE_v2xi32:.+]] = llvm.bitcast %[[SOURCE_8xi8]] : vector<8xi8> to vector<2xi32>
-  // CHECK: rocdl.cvt.scale.pk8.f32.fp8 %[[SOURCE_v2xi32]], %[[SCALE_i32]][0] : vector<8xf32>
-  %ret2 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<8xf8E4M3FN>, vector<4xf8E8M0FNU> -> vector<8xf32>
-
-  func.return %ret0, %ret1, %ret2 : vector<8xf16>, vector<8xbf16>, vector<8xf32>
-}
-
-// CHECK-LABEL: @scaled_ext_packed816_bf8
-// CHECK-SAME: (%[[SOURCE:.+]]: vector<8xf8E5M2>, %[[SCALE:.+]]: vector<4xf8E8M0FNU>)
-func.func @scaled_ext_packed816_bf8(%v: vector<8xf8E5M2>, %scale: vector<4xf8E8M0FNU>) -> (vector<8xf16>, vector<8xbf16>, vector<8xf32>) {
-  // CHECK: %[[SCALE_4xi8:.+]] = builtin.unrealized_conversion_cast %[[SCALE]] : vector<4xf8E8M0FNU> to vector<4xi8>
-  // CHECK: %[[SOURCE_8xi8:.+]] = builtin.unrealized_conversion_cast %[[SOURCE]] : vector<8xf8E5M2> to vector<8xi8>
-  // CHECK: %[[SCALE_i32:.+]] = llvm.bitcast %[[SCALE_4xi8]] : vector<4xi8> to i32
-  // CHECK: %[[SOURCE_v2xi32:.+]] = llvm.bitcast %[[SOURCE_8xi8]] : vector<8xi8> to vector<2xi32>
-  // CHECK: %[[RES:.+]] = rocdl.cvt.scale.pk8.f16.bf8 %[[SOURCE_v2xi32]], %[[SCALE_i32]][0] : vector<8xf16>
-  %ret0 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<8xf8E5M2>, vector<4xf8E8M0FNU> -> vector<8xf16>
-
-  // CHECK: %[[SCALE_i32:.+]] = llvm.bitcast %[[SCALE_4xi8]] : vector<4xi8> to i32
-  // CHECK: %[[SOURCE_v2xi32:.+]] = llvm.bitcast %[[SOURCE_8xi8]] : vector<8xi8> to vector<2xi32>
-  // CHECK: rocdl.cvt.scale.pk8.bf16.bf8 %[[SOURCE_v2xi32]], %[[SCALE_i32]][0] : vector<8xbf16>
-  %ret1 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<8xf8E5M2>, vector<4xf8E8M0FNU> -> vector<8xbf16>
-
-  // CHECK: %[[SCALE_i32:.+]] = llvm.bitcast %[[SCALE_4xi8]] : vector<4xi8> to i32
-  // CHECK: %[[SOURCE_v2xi32:.+]] = llvm.bitcast %[[SOURCE_8xi8]] : vector<8xi8> to vector<2xi32>
-  // CHECK: rocdl.cvt.scale.pk8.f32.bf8 %[[SOURCE_v2xi32]], %[[SCALE_i32]][0] : vector<8xf32>
-  %ret2 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<8xf8E5M2>, vector<4xf8E8M0FNU> -> vector<8xf32>
-  func.return %ret0, %ret1, %ret2 : vector<8xf16>, vector<8xbf16>, vector<8xf32>
-}
-
-
-// CHECK-LABEL: @scaled_ext_packed816_fp6
-// CHECK-SAME: (%[[SOURCE:.+]]: vector<16xf6E2M3FN>, %[[SCALE:.+]]: vector<4xf8E8M0FNU>)
-func.func @scaled_ext_packed816_fp6(%v: vector<16xf6E2M3FN>, %scale: vector<4xf8E8M0FNU>) -> (vector<16xf16>, vector<16xbf16>, vector<16xf32>) {
-  // CHECK-DAG: %[[SCALE_4xi8:.+]] = builtin.unrealized_conversion_cast %[[SCALE]] : vector<4xf8E8M0FNU> to vector<4xi8>
-  // CHECK-DAG: %[[SOURCE_16xi6:.+]] = builtin.unrealized_conversion_cast %[[SOURCE]] : vector<16xf6E2M3FN> to vector<16xi6>
-
-  // CHECK: %[[SCALE_i32:.+]] = llvm.bitcast %[[SCALE_4xi8]] : vector<4xi8> to i32
-  // CHECK: %[[SOURCE_v3xi32:.+]] = llvm.bitcast %[[SOURCE_16xi6]] : vector<16xi6> to vector<3xi32>
-  // CHECK: rocdl.cvt.scale.pk16.f16.fp6 %[[SOURCE_v3xi32]], %[[SCALE_i32]][0] : vector<16xf16>
-  %ret0 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<16xf6E2M3FN>, vector<4xf8E8M0FNU> -> vector<16xf16>
-
-  // CHECK: %[[SCALE_i32:.+]] = llvm.bitcast %[[SCALE_4xi8]] : vector<4xi8> to i32
-  // CHECK: %[[SOURCE_v3xi32:.+]] = llvm.bitcast %[[SOURCE_16xi6]] : vector<16xi6> to vector<3xi32>
-  // CHECK: rocdl.cvt.scale.pk16.bf16.fp6 %[[SOURCE_v3xi32]], %[[SCALE_i32]][0] : vector<16xbf16>
-  %ret1 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<16xf6E2M3FN>, vector<4xf8E8M0FNU> -> vector<16xbf16>
-
-  // CHECK: %[[SCALE_i32:.+]] = llvm.bitcast %[[SCALE_4xi8]] : vector<4xi8> to i32
-  // CHECK: %[[SOURCE_v3xi32:.+]] = llvm.bitcast %[[SOURCE_16xi6]] : vector<16xi6> to vector<3xi32>
-  // CHECK: rocdl.cvt.scale.pk16.f32.fp6 %[[SOURCE_v3xi32]], %[[SCALE_i32]][0] : vector<16xf32>
-  %ret2 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<16xf6E2M3FN>, vector<4xf8E8M0FNU> -> vector<16xf32>
-  return %ret0, %ret1, %ret2: vector<16xf16>, vector<16xbf16>, vector<16xf32>
-}
-
-// CHECK-LABEL: @scaled_ext_packed816_bf6
-// CHECK-SAME: (%[[SOURCE:.+]]: vector<16xf6E3M2FN>, %[[SCALE:.+]]: vector<4xf8E8M0FNU>)
-func.func @scaled_ext_packed816_bf6(%v: vector<16xf6E3M2FN>, %scale: vector<4xf8E8M0FNU>) -> (vector<16xf16>, vector<16xbf16>, vector<16xf32>) {
-  // CHECK-DAG: %[[SCALE_4xi8:.+]] = builtin.unrealized_conversion_cast %[[SCALE]] : vector<4xf8E8M0FNU> to vector<4xi8>
-  // CHECK-DAG: %[[SOURCE_16xi6:.+]] = builtin.unrealized_conversion_cast %[[SOURCE]] : vector<16xf6E3M2FN> to vector<16xi6>
-
-  // CHECK: %[[SCALE_i32:.+]] = llvm.bitcast %[[SCALE_4xi8]] : vector<4xi8> to i32
-  // CHECK: %[[SOURCE_v3xi32:.+]] = llvm.bitcast %[[SOURCE_16xi6]] : vector<16xi6> to vector<3xi32>
-  // CHECK: rocdl.cvt.scale.pk16.f16.bf6 %[[SOURCE_v3xi32]], %[[SCALE_i32]][0] : vector<16xf16>
-  %ret0 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<16xf6E3M2FN>, vector<4xf8E8M0FNU> -> vector<16xf16>
-
-  // CHECK: %[[SCALE_i32:.+]] = llvm.bitcast %[[SCALE_4xi8]] : vector<4xi8> to i32
-  // CHECK: %[[SOURCE_v3xi32:.+]] = llvm.bitcast %[[SOURCE_16xi6]] : vector<16xi6> to vector<3xi32>
-  // CHECK: rocdl.cvt.scale.pk16.bf16.bf6 %[[SOURCE_v3xi32]], %[[SCALE_i32]][0] : vector<16xbf16>
-  %ret1 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<16xf6E3M2FN>, vector<4xf8E8M0FNU> -> vector<16xbf16>
-
-  // CHECK: %[[SCALE_i32:.+]] = llvm.bitcast %[[SCALE_4xi8]] : vector<4xi8> to i32
-  // CHECK: %[[SOURCE_v3xi32:.+]] = llvm.bitcast %[[SOURCE_16xi6]] : vector<16xi6> to vector<3xi32>
-  // CHECK: rocdl.cvt.scale.pk16.f32.bf6 %[[SOURCE_v3xi32]], %[[SCALE_i32]][0] : vector<16xf32>
-  %ret2 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<16xf6E3M2FN>, vector<4xf8E8M0FNU> -> vector<16xf32>
-  return %ret0, %ret1, %ret2: vector<16xf16>, vector<16xbf16>, vector<16xf32>
-}
-
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/cvt_scale_pk-gfx1250.mlir b/mlir/test/Conversion/AMDGPUToROCDL/cvt_scale_pk-gfx1250.mlir
new file mode 100644
index 0000000000000..811a8e49dc5c6
--- /dev/null
+++ b/mlir/test/Conversion/AMDGPUToROCDL/cvt_scale_pk-gfx1250.mlir
@@ -0,0 +1,116 @@
+// RUN: mlir-opt %s --convert-amdgpu-to-rocdl=chipset=gfx1250 | FileCheck %s
+
+// CHECK-LABEL: @scaled_ext_packed816_fp4
+// CHECK-SAME: (%[[SOURCE:.+]]: vector<8xf4E2M1FN>, %[[SCALE:.+]]: vector<4xf8E8M0FNU>)
+func.func @scaled_ext_packed816_fp4(%v: vector<8xf4E2M1FN>, %scale: vector<4xf8E8M0FNU>) -> (vector<8xf16>, vector<8xbf16>, vector<8xf32>) {
+  // CHECK: %[[SCALE_4xi8:.+]] = builtin.unrealized_conversion_cast %[[SCALE]] : vector<4xf8E8M0FNU> to vector<4xi8>
+  // CHECK: %[[SOURCE_8xi4:.+]] = builtin.unrealized_conversion_cast %[[SOURCE]] : vector<8xf4E2M1FN> to vector<8xi4>
+  // CHECK: %[[SCALE_i32:.+]] = llvm.bitcast %[[SCALE_4xi8]] : vector<4xi8> to i32
+  // CHECK: %[[SOURCE_i32:.+]] = llvm.bitcast %[[SOURCE_8xi4]] : vector<8xi4> to i32
+  // CHECK: rocdl.cvt.scale.pk8.f16.fp4 %[[SOURCE_i32]], %[[SCALE_i32]][0] : vector<8xf16>
+  %ret0 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<8xf4E2M1FN>, vector<4xf8E8M0FNU> -> vector<8xf16>
+
+  // CHECK: %[[SCALE_i32:.+]] = llvm.bitcast %[[SCALE_4xi8]] : vector<4xi8> to i32
+  // CHECK: %[[SOURCE_i32:.+]] = llvm.bitcast %[[SOURCE_8xi4]] : vector<8xi4> to i32
+  // CHECK: rocdl.cvt.scale.pk8.bf16.fp4 %[[SOURCE_i32]], %[[SCALE_i32]][0] : vector<8xbf16>
+  %ret1 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<8xf4E2M1FN>, vector<4xf8E8M0FNU> -> vector<8xbf16>
+
+  // CHECK: %[[SCALE_i32:.+]] = llvm.bitcast %[[SCALE_4xi8]] : vector<4xi8> to i32
+  // CHECK: %[[SOURCE_i32:.+]] = llvm.bitcast %[[SOURCE_8xi4]] : vector<8xi4> to i32
+  // CHECK: rocdl.cvt.scale.pk8.f32.fp4 %[[SOURCE_i32]], %[[SCALE_i32]][0] : vector<8xf32>
+  %ret2 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<8xf4E2M1FN>, vector<4xf8E8M0FNU> -> vector<8xf32>
+  func.return %ret0, %ret1, %ret2: vector<8xf16>, vector<8xbf16>, vector<8xf32>
+}
+
+// CHECK-LABEL: @scaled_ext_packed816_fp8
+// CHECK-SAME: (%[[SOURCE:.+]]: vector<8xf8E4M3FN>, %[[SCALE:.+]]: vector<4xf8E8M0FNU>)
+func.func @scaled_ext_packed816_fp8(%v: vector<8xf8E4M3FN>, %scale: vector<4xf8E8M0FNU>) -> (vector<8xf16>, vector<8xbf16>, vector<8xf32>) {
+  // CHECK: %[[SCALE_4xi8:.+]] = builtin.unrealized_conversion_cast %[[SCALE]] : vector<4xf8E8M0FNU> to vector<4xi8>
+  // CHECK: %[[SOURCE_8xi8:.+]] = builtin.unrealized_conversion_cast %[[SOURCE]] : vector<8xf8E4M3FN> to vector<8xi8>
+  // CHECK: %[[SCALE_i32:.+]] = llvm.bitcast %[[SCALE_4xi8]] : vector<4xi8> to i32
+  // CHECK: %[[SOURCE_v2xi32:.+]] = llvm.bitcast %[[SOURCE_8xi8]] : vector<8xi8> to vector<2xi32>
+  // CHECK: rocdl.cvt.scale.pk8.f16.fp8 %[[SOURCE_v2xi32]], %[[SCALE_i32]][0] : vector<8xf16>
+  %ret0 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<8xf8E4M3FN>, vector<4xf8E8M0FNU> -> vector<8xf16>
+
+  // CHECK: %[[SCALE_i32:.+]] = llvm.bitcast %[[SCALE_4xi8]] : vector<4xi8> to i32
+  // CHECK: %[[SOURCE_v2xi32:.+]] = llvm.bitcast %[[SOURCE_8xi8]] : vector<8xi8> to vector<2xi32>
+  // CHECK: rocdl.cvt.scale.pk8.bf16.fp8 %[[SOURCE_v2xi32]], %[[SCALE_i32]][0] : vector<8xbf16>
+  %ret1 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<8xf8E4M3FN>, vector<4xf8E8M0FNU> -> vector<8xbf16>
+
+  // CHECK: %[[SCALE_i32:.+]] = llvm.bitcast %[[SCALE_4xi8]] : vector<4xi8> to i32
+  // CHECK: %[[SOURCE_v2xi32:.+]] = llvm.bitcast %[[SOURCE_8xi8]] : vector<8xi8> to vector<2xi32>
+  // CHECK: rocdl.cvt.scale.pk8.f32.fp8 %[[SOURCE_v2xi32]], %[[SCALE_i32]][0] : vector<8xf32>
+  %ret2 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<8xf8E4M3FN>, vector<4xf8E8M0FNU> -> vector<8xf32>
+
+  func.return %ret0, %ret1, %ret2 : vector<8xf16>, vector<8xbf16>, vector<8xf32>
+}
+
+// CHECK-LABEL: @scaled_ext_packed816_bf8
+// CHECK-SAME: (%[[SOURCE:.+]]: vector<8xf8E5M2>, %[[SCALE:.+]]: vector<4xf8E8M0FNU>)
+func.func @scaled_ext_packed816_bf8(%v: vector<8xf8E5M2>, %scale: vector<4xf8E8M0FNU>) -> (vector<8xf16>, vector<8xbf16>, vector<8xf32>) {
+  // CHECK: %[[SCALE_4xi8:.+]] = builtin.unrealized_conversion_cast %[[SCALE]] : vector<4xf8E8M0FNU> to vector<4xi8>
+  // CHECK: %[[SOURCE_8xi8:.+]] = builtin.unrealized_conversion_cast %[[SOURCE]] : vector<8xf8E5M2> to vector<8xi8>
+  // CHECK: %[[SCALE_i32:.+]] = llvm.bitcast %[[SCALE_4xi8]] : vector<4xi8> to i32
+  // CHECK: %[[SOURCE_v2xi32:.+]] = llvm.bitcast %[[SOURCE_8xi8]] : vector<8xi8> to vector<2xi32>
+  // CHECK: %[[RES:.+]] = rocdl.cvt.scale.pk8.f16.bf8 %[[SOURCE_v2xi32]], %[[SCALE_i32]][0] : vector<8xf16>
+  %ret0 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<8xf8E5M2>, vector<4xf8E8M0FNU> -> vector<8xf16>
+
+  // CHECK: %[[SCALE_i32:.+]] = llvm.bitcast %[[SCALE_4xi8]] : vector<4xi8> to i32
+  // CHECK: %[[SOURCE_v2xi32:.+]] = llvm.bitcast %[[SOURCE_8xi8]] : vector<8xi8> to vector<2xi32>
+  // CHECK: rocdl.cvt.scale.pk8.bf16.bf8 %[[SOURCE_v2xi32]], %[[SCALE_i32]][0] : vector<8xbf16>
+  %ret1 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<8xf8E5M2>, vector<4xf8E8M0FNU> -> vector<8xbf16>
+
+  // CHECK: %[[SCALE_i32:.+]] = llvm.bitcast %[[SCALE_4xi8]] : vector<4xi8> to i32
+  // CHECK: %[[SOURCE_v2xi32:.+]] = llvm.bitcast %[[SOURCE_8xi8]] : vector<8xi8> to vector<2xi32>
+  // CHECK: rocdl.cvt.scale.pk8.f32.bf8 %[[SOURCE_v2xi32]], %[[SCALE_i32]][0] : vector<8xf32>
+  %ret2 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<8xf8E5M2>, vector<4xf8E8M0FNU> -> vector<8xf32>
+  func.return %ret0, %ret1, %ret2 : vector<8xf16>, vector<8xbf16>, vector<8xf32>
+}
+
+
+// CHECK-LABEL: @scaled_ext_packed816_fp6
+// CHECK-SAME: (%[[SOURCE:.+]]: vector<16xf6E2M3FN>, %[[SCALE:.+]]: vector<4xf8E8M0FNU>)
+func.func @scaled_ext_packed816_fp6(%v: vector<16xf6E2M3FN>, %scale: vector<4xf8E8M0FNU>) -> (vector<16xf16>, vector<16xbf16>, vector<16xf32>) {
+  // CHECK-DAG: %[[SCALE_4xi8:.+]] = builtin.unrealized_conversion_cast %[[SCALE]] : vector<4xf8E8M0FNU> to vector<4xi8>
+  // CHECK-DAG: %[[SOURCE_16xi6:.+]] = builtin.unrealized_conversion_cast %[[SOURCE]] : vector<16xf6E2M3FN> to vector<16xi6>
+
+  // CHECK: %[[SCALE_i32:.+]] = llvm.bitcast %[[SCALE_4xi8]] : vector<4xi8> to i32
+  // CHECK: %[[SOURCE_v3xi32:.+]] = llvm.bitcast %[[SOURCE_16xi6]] : vector<16xi6> to vector<3xi32>
+  // CHECK: rocdl.cvt.scale.pk16.f16.fp6 %[[SOURCE_v3xi32]], %[[SCALE_i32]][0] : vector<16xf16>
+  %ret0 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<16xf6E2M3FN>, vector<4xf8E8M0FNU> -> vector<16xf16>
+
+  // CHECK: %[[SCALE_i32:.+]] = llvm.bitcast %[[SCALE_4xi8]] : vector<4xi8> to i32
+  // CHECK: %[[SOURCE_v3xi32:.+]] = llvm.bitcast %[[SOURCE_16xi6]] : vector<16xi6> to vector<3xi32>
+  // CHECK: rocdl.cvt.scale.pk16.bf16.fp6 %[[SOURCE_v3xi32]], %[[SCALE_i32]][0] : vector<16xbf16>
+  %ret1 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<16xf6E2M3FN>, vector<4xf8E8M0FNU> -> vector<16xbf16>
+
+  // CHECK: %[[SCALE_i32:.+]] = llvm.bitcast %[[SCALE_4xi8]] : vector<4xi8> to i32
+  // CHECK: %[[SOURCE_v3xi32:.+]] = llvm.bitcast %[[SOURCE_16xi6]] : vector<16xi6> to vector<3xi32>
+  // CHECK: rocdl.cvt.scale.pk16.f32.fp6 %[[SOURCE_v3xi32]], %[[SCALE_i32]][0] : vector<16xf32>
+  %ret2 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<16xf6E2M3FN>, vector<4xf8E8M0FNU> -> vector<16xf32>
+  return %ret0, %ret1, %ret2: vector<16xf16>, vector<16xbf16>, vector<16xf32>
+}
+
+// CHECK-LABEL: @scaled_ext_packed816_bf6
+// CHECK-SAME: (%[[SOURCE:.+]]: vector<16xf6E3M2FN>, %[[SCALE:.+]]: vector<4xf8E8M0FNU>)
+func.func @scaled_ext_packed816_bf6(%v: vector<16xf6E3M2FN>, %scale: vector<4xf8E8M0FNU>) -> (vector<16xf16>, vector<16xbf16>, vector<16xf32>) {
+  // CHECK-DAG: %[[SCALE_4xi8:.+]] = builtin.unrealized_conversion_cast %[[SCALE]] : vector<4xf8E8M0FNU> to vector<4xi8>
+  // CHECK-DAG: %[[SOURCE_16xi6:.+]] = builtin.unrealized_conversion_cast %[[SOURCE]] : vector<16xf6E3M2FN> to vector<16xi6>
+
+  // CHECK: %[[SCALE_i32:.+]] = llvm.bitcast %[[SCALE_4xi8]] : vector<4xi8> to i32
+  // CHECK: %[[SOURCE_v3xi32:.+]] = llvm.bitcast %[[SOURCE_16xi6]] : vector<16xi6> to vector<3xi32>
+  // CHECK: rocdl.cvt.scale.pk16.f16.bf6 %[[SOURCE_v3xi32]], %[[SCALE_i32]][0] : vector<16xf16>
+  %ret0 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<16xf6E3M2FN>, vector<4xf8E8M0FNU> -> vector<16xf16>
+
+  // CHECK: %[[SCALE_i32:.+]] = llvm.bitcast %[[SCALE_4xi8]] : vector<4xi8> to i32
+  // CHECK: %[[SOURCE_v3xi32:.+]] = llvm.bitcast %[[SOURCE_16xi6]] : vector<16xi6> to vector<3xi32>
+  // CHECK: rocdl.cvt.scale.pk16.bf16.bf6 %[[SOURCE_v3xi32]], %[[SCALE_i32]][0] : vector<16xbf16>
+  %ret1 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<16xf6E3M2FN>, vector<4xf8E8M0FNU> -> vector<16xbf16>
+
+  // CHECK: %[[SCALE_i32:.+]] = llvm.bitcast %[[SCALE_4xi8]] : vector<4xi8> to i32
+  // CHECK: %[[SOURCE_v3xi32:.+]] = llvm.bitcast %[[SOURCE_16xi6]] : vector<16xi6> to vector<3xi32>
+  // CHECK: rocdl.cvt.scale.pk16.f32.bf6 %[[SOURCE_v3xi32]], %[[SCALE_i32]][0] : vector<16xf32>
+  %ret2 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<16xf6E3M2FN>, vector<4xf8E8M0FNU> -> vector<16xf32>
+  return %ret0, %ret1, %ret2: vector<16xf16>, vector<16xbf16>, vector<16xf32>
+}
+

>From 34ed3e9384a683e44b967baff53fb952b3320e90 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Mon, 17 Nov 2025 10:21:24 -0500
Subject: [PATCH 34/43] Refactor NFC

---
 .../AMDGPUToROCDL/AMDGPUToROCDL.cpp           | 25 ++++++++-----------
 1 file changed, 11 insertions(+), 14 deletions(-)

diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 7e30feb520a72..0c5f4ff7f8227 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -1632,22 +1632,19 @@ int32_t getScaleSel(int32_t blockSize, unsigned bitWidth,
     assert(llvm::is_contained({0, 1}, firstScaleLane));
     int bit_2 = firstScaleLane << 2;
     return bit_2 | bit_1 | bit_0;
-  } else {
-    int bit_0 = is_block_16;
-    // firstScaleByte is guaranteed to be defined by two bits.
-    assert(llvm::is_contained({0, 1, 2, 3}, firstScaleByte));
-    int bit_2_and_1 = firstScaleByte << 1;
-    assert(llvm::is_contained({0, 1}, firstScaleLane));
-    int bit_3 = firstScaleLane << 3;
-    int bits = bit_3 | bit_2_and_1 | bit_0;
-    // These are invalid cases.
-    assert(!llvm::is_contained(
-        {0b0011, 0b0101, 0b0111, 0b1000, 0b1001, 0b1011, 0b1111}, bits));
-    return bits;
   }
 
-  llvm_unreachable("invalid combination of firstScaleLane, firstScaleByte, "
-                   "blockSize and type.");
+  int bit_0 = is_block_16;
+  // firstScaleByte is guaranteed to be defined by two bits.
+  assert(llvm::is_contained({0, 1, 2, 3}, firstScaleByte));
+  int bit_2_and_1 = firstScaleByte << 1;
+  assert(llvm::is_contained({0, 1}, firstScaleLane));
+  int bit_3 = firstScaleLane << 3;
+  int bits = bit_3 | bit_2_and_1 | bit_0;
+  // These are invalid cases.
+  assert(!llvm::is_contained(
+      {0b0011, 0b0101, 0b0111, 0b1000, 0b1001, 0b1011, 0b1111}, bits));
+  return bits;
 }
 
 LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite(

>From b88f7f6e74c0038873dabec3d31b2bba12b8e6ab Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Mon, 17 Nov 2025 10:44:41 -0500
Subject: [PATCH 35/43] Use operation name

---
 .../AMDGPUToROCDL/AMDGPUToROCDL.cpp           | 105 ++++++++++--------
 1 file changed, 57 insertions(+), 48 deletions(-)

diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 0c5f4ff7f8227..23660361094c3 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -1647,6 +1647,46 @@ int32_t getScaleSel(int32_t blockSize, unsigned bitWidth,
   return bits;
 }
 
+static std::optional<StringRef>
+scaledExtPacked816ToIntrinsic(Type srcElemType, Type destElemType) {
+  using fp4 = Float4E2M1FNType;
+  using fp8 = Float8E4M3FNType;
+  using bf8 = Float8E5M2Type;
+  using fp6 = Float6E2M3FNType;
+  using bf6 = Float6E3M2FNType;
+  if (isa<fp4>(srcElemType) && destElemType.isF16())
+    return ROCDL::CvtPkScalePk8F16Fp4Op::getOperationName();
+  if (isa<fp8>(srcElemType) && destElemType.isF16())
+    return ROCDL::CvtPkScalePk8F16Fp8Op::getOperationName();
+  if (isa<bf8>(srcElemType) && destElemType.isF16())
+    return ROCDL::CvtPkScalePk8F16Bf8Op::getOperationName();
+  if (isa<fp4>(srcElemType) && destElemType.isBF16())
+    return ROCDL::CvtPkScalePk8Bf16Fp4Op::getOperationName();
+  if (isa<fp8>(srcElemType) && destElemType.isBF16())
+    return ROCDL::CvtPkScalePk8Bf16Fp8Op::getOperationName();
+  if (isa<bf8>(srcElemType) && destElemType.isBF16())
+    return ROCDL::CvtPkScalePk8Bf16Bf8Op::getOperationName();
+  if (isa<fp4>(srcElemType) && destElemType.isF32())
+    return ROCDL::CvtPkScalePk8F32Fp4Op::getOperationName();
+  if (isa<fp8>(srcElemType) && destElemType.isF32())
+    return ROCDL::CvtPkScalePk8F32Fp8Op::getOperationName();
+  if (isa<bf8>(srcElemType) && destElemType.isF32())
+    return ROCDL::CvtPkScalePk8F32Bf8Op::getOperationName();
+  if (isa<fp6>(srcElemType) && destElemType.isF16())
+    return ROCDL::CvtPkScalePk16F16Fp6Op::getOperationName();
+  if (isa<bf6>(srcElemType) && destElemType.isF16())
+    return ROCDL::CvtPkScalePk16F16Bf6Op::getOperationName();
+  if (isa<fp6>(srcElemType) && destElemType.isBF16())
+    return ROCDL::CvtPkScalePk16Bf16Fp6Op::getOperationName();
+  if (isa<bf6>(srcElemType) && destElemType.isBF16())
+    return ROCDL::CvtPkScalePk16Bf16Bf6Op::getOperationName();
+  if (isa<fp6>(srcElemType) && destElemType.isF32())
+    return ROCDL::CvtPkScalePk16F32Fp6Op::getOperationName();
+  if (isa<bf6>(srcElemType) && destElemType.isF32())
+    return ROCDL::CvtPkScalePk16F32Bf6Op::getOperationName();
+  return std::nullopt;
+}
+
 LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite(
     ScaledExtPacked816Op op, ScaledExtPacked816OpAdaptor adaptor,
     ConversionPatternRewriter &rewriter) const {
@@ -1694,54 +1734,23 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite(
   Value castedSource =
       LLVM::BitcastOp::create(rewriter, loc, packedType, source);
 
-  if (isa<fp4>(srcElemType) && destElemType.isF16()) {
-    rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk8F16Fp4Op>(
-        op, op.getResult().getType(), castedSource, castedScale, scaleSel);
-  } else if (isa<fp8>(srcElemType) && destElemType.isF16()) {
-    rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk8F16Fp8Op>(
-        op, op.getResult().getType(), castedSource, castedScale, scaleSel);
-  } else if (isa<bf8>(srcElemType) && destElemType.isF16()) {
-    rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk8F16Bf8Op>(
-        op, op.getResult().getType(), castedSource, castedScale, scaleSel);
-  } else if (isa<fp4>(srcElemType) && destElemType.isBF16()) {
-    rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk8Bf16Fp4Op>(
-        op, op.getResult().getType(), castedSource, castedScale, scaleSel);
-  } else if (isa<fp8>(srcElemType) && destElemType.isBF16()) {
-    rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk8Bf16Fp8Op>(
-        op, op.getResult().getType(), castedSource, castedScale, scaleSel);
-  } else if (isa<bf8>(srcElemType) && destElemType.isBF16()) {
-    rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk8Bf16Bf8Op>(
-        op, op.getResult().getType(), castedSource, castedScale, scaleSel);
-  } else if (isa<fp4>(srcElemType) && destElemType.isF32()) {
-    rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk8F32Fp4Op>(
-        op, op.getResult().getType(), castedSource, castedScale, scaleSel);
-  } else if (isa<fp8>(srcElemType) && destElemType.isF32()) {
-    rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk8F32Fp8Op>(
-        op, op.getResult().getType(), castedSource, castedScale, scaleSel);
-  } else if (isa<bf8>(srcElemType) && destElemType.isF32()) {
-    rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk8F32Bf8Op>(
-        op, op.getResult().getType(), castedSource, castedScale, scaleSel);
-  } else if (isa<fp6>(srcElemType) && destElemType.isF16()) {
-    rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk16F16Fp6Op>(
-        op, op.getResult().getType(), castedSource, castedScale, scaleSel);
-  } else if (isa<bf6>(srcElemType) && destElemType.isF16()) {
-    rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk16F16Bf6Op>(
-        op, op.getResult().getType(), castedSource, castedScale, scaleSel);
-  } else if (isa<fp6>(srcElemType) && destElemType.isBF16()) {
-    rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk16Bf16Fp6Op>(
-        op, op.getResult().getType(), castedSource, castedScale, scaleSel);
-  } else if (isa<bf6>(srcElemType) && destElemType.isBF16()) {
-    rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk16Bf16Bf6Op>(
-        op, op.getResult().getType(), castedSource, castedScale, scaleSel);
-  } else if (isa<fp6>(srcElemType) && destElemType.isF32()) {
-    rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk16F32Fp6Op>(
-        op, op.getResult().getType(), castedSource, castedScale, scaleSel);
-  } else if (isa<bf6>(srcElemType) && destElemType.isF32()) {
-    rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk16F32Bf6Op>(
-        op, op.getResult().getType(), castedSource, castedScale, scaleSel);
-  } else {
-    return failure();
-  }
+  std::optional<StringRef> maybeIntrinsic =
+      scaledExtPacked816ToIntrinsic(srcElemType, destElemType);
+  if (!maybeIntrinsic.has_value())
+    return op.emitOpError(
+        "no intrinsic matching packed scaled conversion on the given chipset");
+
+  OperationState loweredOp(loc, *maybeIntrinsic);
+  loweredOp.addTypes({op.getResult().getType()});
+  loweredOp.addOperands({castedSource, castedScale});
+
+  SmallVector<NamedAttribute, 1> attrs;
+  attrs.push_back(
+      NamedAttribute("scaleSel", rewriter.getI32IntegerAttr(scaleSel)));
+
+  loweredOp.addAttributes(attrs);
+  Operation *lowered = rewriter.create(loweredOp);
+  rewriter.replaceOp(op, lowered);
 
   return success();
 }

>From 7a7ecaf31b5ad267343b871d5f2f44b8eff875b3 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Mon, 17 Nov 2025 10:48:38 -0500
Subject: [PATCH 36/43] Convert result type

---
 mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 23660361094c3..2e73f7b0d4266 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -1741,7 +1741,8 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite(
         "no intrinsic matching packed scaled conversion on the given chipset");
 
   OperationState loweredOp(loc, *maybeIntrinsic);
-  loweredOp.addTypes({op.getResult().getType()});
+  Type llvmResultType = typeConverter->convertType(op.getResult().getType());
+  loweredOp.addTypes({llvmResultType});
   loweredOp.addOperands({castedSource, castedScale});
 
   SmallVector<NamedAttribute, 1> attrs;

>From 1025e2b999b7f42947842cfb3185b921a0ea534d Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Mon, 17 Nov 2025 10:58:20 -0500
Subject: [PATCH 37/43] Check for type conversion failures

---
 mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp | 11 ++++++++---
 1 file changed, 8 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 2e73f7b0d4266..ae907b8ffefc3 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -1718,7 +1718,8 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite(
       LLVM::BitcastOp::create(rewriter, loc, i32, adaptor.getScale());
 
   Value source = adaptor.getSource();
-  Type packedType;
+  Type llvmResultType = typeConverter->convertType(op.getResult().getType());
+  Type packedType = nullptr;
   if (isa<fp4>(srcElemType)) {
     packedType = i32;
     packedType = getTypeConverter()->convertType(packedType);
@@ -1729,8 +1730,13 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite(
     packedType = VectorType::get(3, i32);
     packedType = getTypeConverter()->convertType(packedType);
   } else {
-    llvm_unreachable("invalid element type for scaled ext");
+    llvm_unreachable("invalid element type for packed scaled ext");
+  }
+
+  if (!packedType || !llvmResultType) {
+    return rewriter.notifyMatchFailure(op, "type conversion failed");
   }
+
   Value castedSource =
       LLVM::BitcastOp::create(rewriter, loc, packedType, source);
 
@@ -1741,7 +1747,6 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite(
         "no intrinsic matching packed scaled conversion on the given chipset");
 
   OperationState loweredOp(loc, *maybeIntrinsic);
-  Type llvmResultType = typeConverter->convertType(op.getResult().getType());
   loweredOp.addTypes({llvmResultType});
   loweredOp.addOperands({castedSource, castedScale});
 

>From 7c44f0959a85b0a801a687e781605df81e3b1c6d Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Mon, 17 Nov 2025 11:21:25 -0500
Subject: [PATCH 38/43] Add top-level if condition for each src type

---
 .../AMDGPUToROCDL/AMDGPUToROCDL.cpp           | 78 +++++++++++--------
 1 file changed, 47 insertions(+), 31 deletions(-)

diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index f9bbc6535f0e1..d23a89a199131 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -1667,37 +1667,53 @@ scaledExtPacked816ToIntrinsic(Type srcElemType, Type destElemType) {
   using bf8 = Float8E5M2Type;
   using fp6 = Float6E2M3FNType;
   using bf6 = Float6E3M2FNType;
-  if (isa<fp4>(srcElemType) && destElemType.isF16())
-    return ROCDL::CvtPkScalePk8F16Fp4Op::getOperationName();
-  if (isa<fp8>(srcElemType) && destElemType.isF16())
-    return ROCDL::CvtPkScalePk8F16Fp8Op::getOperationName();
-  if (isa<bf8>(srcElemType) && destElemType.isF16())
-    return ROCDL::CvtPkScalePk8F16Bf8Op::getOperationName();
-  if (isa<fp4>(srcElemType) && destElemType.isBF16())
-    return ROCDL::CvtPkScalePk8Bf16Fp4Op::getOperationName();
-  if (isa<fp8>(srcElemType) && destElemType.isBF16())
-    return ROCDL::CvtPkScalePk8Bf16Fp8Op::getOperationName();
-  if (isa<bf8>(srcElemType) && destElemType.isBF16())
-    return ROCDL::CvtPkScalePk8Bf16Bf8Op::getOperationName();
-  if (isa<fp4>(srcElemType) && destElemType.isF32())
-    return ROCDL::CvtPkScalePk8F32Fp4Op::getOperationName();
-  if (isa<fp8>(srcElemType) && destElemType.isF32())
-    return ROCDL::CvtPkScalePk8F32Fp8Op::getOperationName();
-  if (isa<bf8>(srcElemType) && destElemType.isF32())
-    return ROCDL::CvtPkScalePk8F32Bf8Op::getOperationName();
-  if (isa<fp6>(srcElemType) && destElemType.isF16())
-    return ROCDL::CvtPkScalePk16F16Fp6Op::getOperationName();
-  if (isa<bf6>(srcElemType) && destElemType.isF16())
-    return ROCDL::CvtPkScalePk16F16Bf6Op::getOperationName();
-  if (isa<fp6>(srcElemType) && destElemType.isBF16())
-    return ROCDL::CvtPkScalePk16Bf16Fp6Op::getOperationName();
-  if (isa<bf6>(srcElemType) && destElemType.isBF16())
-    return ROCDL::CvtPkScalePk16Bf16Bf6Op::getOperationName();
-  if (isa<fp6>(srcElemType) && destElemType.isF32())
-    return ROCDL::CvtPkScalePk16F32Fp6Op::getOperationName();
-  if (isa<bf6>(srcElemType) && destElemType.isF32())
-    return ROCDL::CvtPkScalePk16F32Bf6Op::getOperationName();
-  return std::nullopt;
+  if (isa<fp4>(srcElemType)) {
+    if (destElemType.isF16())
+      return ROCDL::CvtPkScalePk8F16Fp4Op::getOperationName();
+    if (destElemType.isBF16())
+      return ROCDL::CvtPkScalePk8Bf16Fp4Op::getOperationName();
+    if (destElemType.isF32())
+      return ROCDL::CvtPkScalePk8F32Fp4Op::getOperationName();
+    return std::nullopt;
+  }
+  if (isa<fp8>(srcElemType)) {
+    if (destElemType.isF16())
+      return ROCDL::CvtPkScalePk8F16Fp8Op::getOperationName();
+    if (destElemType.isBF16())
+      return ROCDL::CvtPkScalePk8Bf16Fp8Op::getOperationName();
+    if (destElemType.isF32())
+      return ROCDL::CvtPkScalePk8F32Fp8Op::getOperationName();
+    return std::nullopt;
+  }
+  if (isa<bf8>(srcElemType)) {
+    if (destElemType.isF16())
+      return ROCDL::CvtPkScalePk8F16Bf8Op::getOperationName();
+    if (destElemType.isBF16())
+      return ROCDL::CvtPkScalePk8Bf16Bf8Op::getOperationName();
+    if (destElemType.isF32())
+      return ROCDL::CvtPkScalePk8F32Bf8Op::getOperationName();
+    return std::nullopt;
+  }
+  if (isa<fp6>(srcElemType)) {
+    if (destElemType.isF16())
+      return ROCDL::CvtPkScalePk16F16Fp6Op::getOperationName();
+    if (destElemType.isBF16())
+      return ROCDL::CvtPkScalePk16Bf16Fp6Op::getOperationName();
+    if (destElemType.isF32())
+      return ROCDL::CvtPkScalePk16F32Fp6Op::getOperationName();
+    return std::nullopt;
+  }
+  if (isa<bf6>(srcElemType)) {
+    if (destElemType.isF16())
+      return ROCDL::CvtPkScalePk16F16Bf6Op::getOperationName();
+    if (destElemType.isBF16())
+      return ROCDL::CvtPkScalePk16Bf16Bf6Op::getOperationName();
+    if (destElemType.isF32())
+      return ROCDL::CvtPkScalePk16F32Bf6Op::getOperationName();
+    return std::nullopt;
+  }
+  llvm_unreachable("invalid combination of element types for packed conversion "
+                   "instructions");
 }
 
 LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite(

>From f06a67e94e08dd242c6a89dd06023b6ce5d95ad0 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Mon, 17 Nov 2025 11:37:30 -0500
Subject: [PATCH 39/43] Add chipset constant at beginning of file

---
 mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index d23a89a199131..4d1734fe710c0 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -43,6 +43,7 @@ constexpr Chipset kGfx908 = Chipset(9, 0, 8);
 constexpr Chipset kGfx90a = Chipset(9, 0, 0xa);
 constexpr Chipset kGfx942 = Chipset(9, 4, 2);
 constexpr Chipset kGfx950 = Chipset(9, 5, 0);
+constexpr Chipset kGfx1250 = Chipset(12, 5, 0);
 
 /// Convert an unsigned number `val` to i32.
 static Value convertUnsignedToI32(ConversionPatternRewriter &rewriter,
@@ -1149,7 +1150,7 @@ static std::optional<StringRef> wmmaOpToIntrinsic(WMMAOp wmma,
                                  k, isRDNA3);
 
   // Handle gfx1250.
-  if (chipset == Chipset{12, 5, 0})
+  if (chipset == kGfx1250)
     return wmmaOpToIntrinsicGfx1250(elemSourceType, elemBSourceType,
                                     elemDestType, k);
 
@@ -1300,7 +1301,7 @@ struct WMMAOpLowering : public ConvertOpToLLVMPattern<WMMAOp> {
     if (chipset.majorVersion != 11 && chipset.majorVersion != 12)
       return op->emitOpError("WMMA only supported on gfx11 and gfx12");
 
-    bool isGFX1250 = chipset >= Chipset(12, 5, 0);
+    bool isGFX1250 = chipset >= kGfx1250;
 
     // The WMMA operations represent vectors of bf16s as vectors of i16s
     // (except on gfx1250), so we need to bitcast bfloats to i16 and then
@@ -1725,7 +1726,7 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite(
   using fp6 = Float6E2M3FNType;
   using bf6 = Float6E3M2FNType;
   Location loc = op.getLoc();
-  if (chipset != Chipset{12, 5, 0}) {
+  if (chipset != kGfx1250) {
     return rewriter.notifyMatchFailure(
         loc,
         "Scaled fp packed conversion instructions are not available on target "

>From 1dbcb95412e274d598c337b700b2524130fd3d25 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Mon, 17 Nov 2025 11:42:14 -0500
Subject: [PATCH 40/43] wip

---
 mlir/test/Conversion/AMDGPUToROCDL/cvt_scale_pk-gfx1250.mlir | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/mlir/test/Conversion/AMDGPUToROCDL/cvt_scale_pk-gfx1250.mlir b/mlir/test/Conversion/AMDGPUToROCDL/cvt_scale_pk-gfx1250.mlir
index 811a8e49dc5c6..87e4eb363d343 100644
--- a/mlir/test/Conversion/AMDGPUToROCDL/cvt_scale_pk-gfx1250.mlir
+++ b/mlir/test/Conversion/AMDGPUToROCDL/cvt_scale_pk-gfx1250.mlir
@@ -1,4 +1,5 @@
-// RUN: mlir-opt %s --convert-amdgpu-to-rocdl=chipset=gfx1250 | FileCheck %s
+// RUN: mlir-opt %s --convert-amdgpu-to-rocdl=chipset=gfx1250 --split-input-file --verify-diagnostics \
+// RUN: | FileCheck %s
 
 // CHECK-LABEL: @scaled_ext_packed816_fp4
 // CHECK-SAME: (%[[SOURCE:.+]]: vector<8xf4E2M1FN>, %[[SCALE:.+]]: vector<4xf8E8M0FNU>)

>From eee0ce97137e993bed9450d281536e74d21d8108 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Mon, 17 Nov 2025 11:52:55 -0500
Subject: [PATCH 41/43] Add invalid srcElemType case

---
 .../AMDGPUToROCDL/cvt_scale_pk-gfx1250.mlir   | 40 +++++++++++++++++++
 mlir/test/Dialect/AMDGPU/invalid.mlir         | 32 ---------------
 2 files changed, 40 insertions(+), 32 deletions(-)

diff --git a/mlir/test/Conversion/AMDGPUToROCDL/cvt_scale_pk-gfx1250.mlir b/mlir/test/Conversion/AMDGPUToROCDL/cvt_scale_pk-gfx1250.mlir
index 87e4eb363d343..73711a2b98ac9 100644
--- a/mlir/test/Conversion/AMDGPUToROCDL/cvt_scale_pk-gfx1250.mlir
+++ b/mlir/test/Conversion/AMDGPUToROCDL/cvt_scale_pk-gfx1250.mlir
@@ -115,3 +115,43 @@ func.func @scaled_ext_packed816_bf6(%v: vector<16xf6E3M2FN>, %scale: vector<4xf8
   return %ret0, %ret1, %ret2: vector<16xf16>, vector<16xbf16>, vector<16xf32>
 }
 
+// -----
+
+func.func @amdgpu.scaled_ext_packed816_invalid_block_size_and_first_scale_byte_16(%v: vector<8xf4E2M1FN>, %scale: vector<4xf8E8M0FNU>) {
+  // expected-error@+1 {{'amdgpu.scaled_ext_packed816' op blockSize of 16 can only have firstScaleByte be 0 or 1 for f4 and f6}}
+  %ret0 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(16) firstScaleLane(0) firstScaleByte(2) : vector<8xf4E2M1FN>, vector<4xf8E8M0FNU> -> vector<8xf16>
+  func.return
+}
+
+// -----
+
+func.func @amdgpu.scaled_ext_packed816_invalid_block_size_and_first_scale_byte_32(%v: vector<8xf4E2M1FN>, %scale: vector<4xf8E8M0FNU>) {
+  // expected-error@+1 {{'amdgpu.scaled_ext_packed816' op blockSize of 32 can only have firstScaleByte be 0 or 2 for f4 and f6.}}
+  %ret0 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(1) : vector<8xf4E2M1FN>, vector<4xf8E8M0FNU> -> vector<8xf16>
+  func.return
+}
+
+// -----
+
+func.func @amdgpu.scaled_ext_packed816_invalid_attributes_for_f8(%v: vector<8xf8E5M2>, %scale: vector<4xf8E8M0FNU>) {
+  // expected-error@+1 {{'amdgpu.scaled_ext_packed816' op blockSize of 16 can only have firstScaleByte be 0 or 2 for f8.}}
+  %ret0 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(16) firstScaleLane(0) firstScaleByte(1) : vector<8xf8E5M2>, vector<4xf8E8M0FNU> -> vector<8xf16>
+  func.return
+}
+
+// -----
+
+func.func @amdgpu.scaled_ext_packed816_invalid_input_output_sizes(%v: vector<8xf8E5M2>, %scale: vector<4xf8E8M0FNU>) {
+  // expected-error@+1 {{'amdgpu.scaled_ext_packed816' op failed to verify that all of {source, res} have same shape}}
+  %ret0 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(16) firstScaleLane(0) firstScaleByte(0) : vector<8xf8E5M2>, vector<4xf8E8M0FNU> -> vector<16xf16>
+  func.return
+}
+
+// -----
+
+func.func @amdgpu.scaled_ext_packed816_invalid_src_elem_type(%v: vector<16xf16>, %scale: vector<4xf8E8M0FNU>) -> (vector<16xf16>) {
+  // expected-error@+1 {{'amdgpu.scaled_ext_packed816' op operand #0 must be}}
+  %ret0 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<16xf16>, vector<4xf8E8M0FNU> -> vector<16xf16>
+  return %ret0: vector<16xf16>
+}
+
diff --git a/mlir/test/Dialect/AMDGPU/invalid.mlir b/mlir/test/Dialect/AMDGPU/invalid.mlir
index 5c8cc8b67c4b3..61fdf29a78cbd 100644
--- a/mlir/test/Dialect/AMDGPU/invalid.mlir
+++ b/mlir/test/Dialect/AMDGPU/invalid.mlir
@@ -333,38 +333,6 @@ func.func @gather_to_lds_non_lds(%idx1 : index, %mem1 : memref<32xf16>, %mem2 :
 
 // -----
 
-func.func @amdgpu.scaled_ext_packed816_invalid_block_size_and_first_scale_byte_16(%v: vector<8xf4E2M1FN>, %scale: vector<4xf8E8M0FNU>) {
-  // expected-error@+1 {{'amdgpu.scaled_ext_packed816' op blockSize of 16 can only have firstScaleByte be 0 or 1 for f4 and f6}}
-  %ret0 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(16) firstScaleLane(0) firstScaleByte(2) : vector<8xf4E2M1FN>, vector<4xf8E8M0FNU> -> vector<8xf16>
-  func.return
-}
-
-// -----
-
-func.func @amdgpu.scaled_ext_packed816_invalid_block_size_and_first_scale_byte_32(%v: vector<8xf4E2M1FN>, %scale: vector<4xf8E8M0FNU>) {
-  // expected-error@+1 {{'amdgpu.scaled_ext_packed816' op blockSize of 32 can only have firstScaleByte be 0 or 2 for f4 and f6.}}
-  %ret0 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(1) : vector<8xf4E2M1FN>, vector<4xf8E8M0FNU> -> vector<8xf16>
-  func.return
-}
-
-// -----
-
-func.func @amdgpu.scaled_ext_packed816_invalid_attributes_for_f8(%v: vector<8xf8E5M2>, %scale: vector<4xf8E8M0FNU>) {
-  // expected-error@+1 {{'amdgpu.scaled_ext_packed816' op blockSize of 16 can only have firstScaleByte be 0 or 2 for f8.}}
-  %ret0 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(16) firstScaleLane(0) firstScaleByte(1) : vector<8xf8E5M2>, vector<4xf8E8M0FNU> -> vector<8xf16>
-  func.return
-}
-
-// -----
-
-func.func @amdgpu.scaled_ext_packed816_invalid_input_output_sizes(%v: vector<8xf8E5M2>, %scale: vector<4xf8E8M0FNU>) {
-  // expected-error@+1 {{'amdgpu.scaled_ext_packed816' op failed to verify that all of {source, res} have same shape}}
-  %ret0 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(16) firstScaleLane(0) firstScaleByte(0) : vector<8xf8E5M2>, vector<4xf8E8M0FNU> -> vector<16xf16>
-  func.return
-}
-
-// -----
-
 func.func @scaled_mfma_invalid_m(%arg0 : vector<4xf8E8M0FNU>, %arg1 : vector<32xf4E2M1FN>, %arg2 : vector<16xf32>) -> vector<16xf32> {
   // expected-error at +1 {{'amdgpu.scaled_mfma' op attribute 'm' failed to satisfy constraint: 32-bit signless integer attribute whose value is one of {16, 32}}}
   %0 = amdgpu.scaled_mfma 8x32x64 (%arg0[0] * %arg1) * (%arg0[1] * %arg1) + %arg2 : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<16xf32>

>From 9860cdd9ffb108f472a18dff751a3401e13f695e Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Mon, 17 Nov 2025 14:28:08 -0500
Subject: [PATCH 42/43] Update verifiers

---
 .../AMDGPUToROCDL/AMDGPUToROCDL.cpp           |  2 +-
 mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp  | 45 ++++++++++++-------
 .../AMDGPUToROCDL/cvt_scale_pk-gfx1250.mlir   |  7 +++
 3 files changed, 37 insertions(+), 17 deletions(-)

diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 4d1734fe710c0..bf83591bf6047 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -1634,7 +1634,7 @@ int32_t getScaleSel(int32_t blockSize, unsigned bitWidth,
   // firstScaleByte are merged into a single attribute scaleSel. This is how
   // those values are merged together.
   assert(llvm::is_contained({16, 32}, blockSize));
-  assert(llvm::is_contained(::llvm::ArrayRef<unsigned>{4, 6, 8}, bitWidth));
+  assert(llvm::is_contained(llvm::ArrayRef<unsigned>{4, 6, 8}, bitWidth));
 
   const bool is_fp8 = bitWidth == 8;
   const bool is_block_16 = blockSize == 16;
diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
index 5c35823678576..955de3bb861ba 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -343,28 +343,41 @@ void RawBufferAtomicCmpswapOp::getCanonicalizationPatterns(
 //===----------------------------------------------------------------------===//
 LogicalResult ScaledExtPacked816Op::verify() {
   int blockSize = getBlockSize();
-  assert((blockSize == 16 || blockSize == 32) && "invalid block size");
+  assert(llvm::is_contained({16, 32}, blockSize) && "invalid block size");
 
   int firstScaleByte = getFirstScaleByte();
+  int firstScaleLane = getFirstScaleLane();
   auto sourceType = cast<VectorType>(getSource().getType());
   Type elementType = sourceType.getElementType();
   auto floatType = cast<FloatType>(elementType);
-  int bitWidth = floatType.getWidth();
+  unsigned bitWidth = floatType.getWidth();
 
-  if (llvm::is_contained({4, 6}, bitWidth) && blockSize == 16 &&
-      !llvm::is_contained({0, 1}, firstScaleByte)) {
-    return emitOpError("blockSize of 16 can only have firstScaleByte be 0 or 1 "
-                       "for f4 and f6.");
-  }
-  if (llvm::is_contained({4, 6}, bitWidth) && blockSize == 32 &&
-      !llvm::is_contained({0, 2}, firstScaleByte)) {
-    return emitOpError("blockSize of 32 can only have firstScaleByte be 0 or 2 "
-                       "for f4 and f6.");
-  }
-  if (bitWidth == 8 && blockSize == 16 &&
-      !llvm::is_contained({0, 2}, firstScaleByte)) {
-    return emitOpError(
-        "blockSize of 16 can only have firstScaleByte be 0 or 2 for f8.");
+  assert(llvm::is_contained(llvm::ArrayRef<unsigned>{4, 6, 8}, bitWidth));
+
+  const bool is_fp8 = bitWidth == 8;
+  const bool is_block_16 = blockSize == 16;
+
+  if (!is_fp8) {
+    if (is_block_16) {
+      if (!llvm::is_contained({0, 1}, firstScaleByte)) {
+        return emitOpError("blockSize of 16 can only have firstScaleByte be 0 "
+                           "or 1 for f4 and f6.");
+      }
+    } else {
+      if (!llvm::is_contained({0, 2}, firstScaleByte)) {
+        return emitOpError("blockSize of 32 can only have firstScaleByte be 0 "
+                           "or 2 for f4 and f6.");
+      }
+    }
+  } else {
+    if (is_block_16) {
+      bool is_valid = ((firstScaleLane == 0) && (firstScaleByte == 0)) ||
+                      ((firstScaleLane == 1) && (firstScaleByte == 2));
+      if (!is_valid) {
+        return emitOpError(
+            "blockSize of 16 can only have firstScaleByte be 0 or 2 for f8.");
+      }
+    }
   }
 
   return success();
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/cvt_scale_pk-gfx1250.mlir b/mlir/test/Conversion/AMDGPUToROCDL/cvt_scale_pk-gfx1250.mlir
index 73711a2b98ac9..fbe13a29c53ab 100644
--- a/mlir/test/Conversion/AMDGPUToROCDL/cvt_scale_pk-gfx1250.mlir
+++ b/mlir/test/Conversion/AMDGPUToROCDL/cvt_scale_pk-gfx1250.mlir
@@ -155,3 +155,10 @@ func.func @amdgpu.scaled_ext_packed816_invalid_src_elem_type(%v: vector<16xf16>,
   return %ret0: vector<16xf16>
 }
 
+// -----
+
+func.func @amdgpu.scaled_ext_packed816_invalid_dst_elem_type(%v: vector<16xf6E3M2FN>, %scale: vector<4xf8E8M0FNU>) -> (vector<16xf64>) {
+  // expected-error@+1 {{'amdgpu.scaled_ext_packed816' op result #0 must be vector}}
+  %ret0 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<16xf6E3M2FN>, vector<4xf8E8M0FNU> -> vector<16xf64>
+  return %ret0: vector<16xf64>
+}

>From 0b8f561e00d2cea1a6e60e53e13720a617611330 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Mon, 17 Nov 2025 14:35:10 -0500
Subject: [PATCH 43/43] Update verifier message

---
 mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp                 | 4 ++--
 mlir/test/Conversion/AMDGPUToROCDL/cvt_scale_pk-gfx1250.mlir | 2 +-
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
index 955de3bb861ba..d55f3cec47c1f 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -374,8 +374,8 @@ LogicalResult ScaledExtPacked816Op::verify() {
       bool is_valid = ((firstScaleLane == 0) && (firstScaleByte == 0)) ||
                       ((firstScaleLane == 1) && (firstScaleByte == 2));
       if (!is_valid) {
-        return emitOpError(
-            "blockSize of 16 can only have firstScaleByte be 0 or 2 for f8.");
+        return emitOpError("blockSize of 16 can only have (firstScaleLane, "
+                           "firstScaleByte) be (0, 0) or (1, 2) for f8.");
       }
     }
   }
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/cvt_scale_pk-gfx1250.mlir b/mlir/test/Conversion/AMDGPUToROCDL/cvt_scale_pk-gfx1250.mlir
index fbe13a29c53ab..d2391140ce056 100644
--- a/mlir/test/Conversion/AMDGPUToROCDL/cvt_scale_pk-gfx1250.mlir
+++ b/mlir/test/Conversion/AMDGPUToROCDL/cvt_scale_pk-gfx1250.mlir
@@ -134,7 +134,7 @@ func.func @amdgpu.scaled_ext_packed816_invalid_block_size_and_first_scale_byte_3
 // -----
 
 func.func @amdgpu.scaled_ext_packed816_invalid_attributes_for_f8(%v: vector<8xf8E5M2>, %scale: vector<4xf8E8M0FNU>) {
-  // expected-error@+1 {{'amdgpu.scaled_ext_packed816' op blockSize of 16 can only have firstScaleByte be 0 or 2 for f8.}}
+  // expected-error@+1 {{'amdgpu.scaled_ext_packed816' op blockSize of 16 can only have (firstScaleLane, firstScaleByte) be (0, 0) or (1, 2) for f8.}}
   %ret0 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(16) firstScaleLane(0) firstScaleByte(1) : vector<8xf8E5M2>, vector<4xf8E8M0FNU> -> vector<8xf16>
   func.return
 }



More information about the Mlir-commits mailing list