[Mlir-commits] [mlir] [mlir][amdgpu] Define an amdgpu.scaling_mfma wrapper (PR #137498)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri May 2 09:55:06 PDT 2025


https://github.com/Muzammiluddin-Syed-ECE updated https://github.com/llvm/llvm-project/pull/137498

>From c43bc26494473a002c74dfa00088cc9b58a8dd8a Mon Sep 17 00:00:00 2001
From: Muzammiluddin Syed <muzasyed at amd.com>
Date: Fri, 25 Apr 2025 14:02:10 -0500
Subject: [PATCH 1/6] Defining amdgpu.scaled_mfma op

Signed-off-by: Muzammiluddin Syed <muzasyed at amd.com>
---
 mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td | 48 ++++++++++++++
 .../AMDGPUToROCDL/AMDGPUToROCDL.cpp           | 63 ++++++++++++++++++-
 mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp  | 43 +++++++++++++
 .../Conversion/AMDGPUToROCDL/mfma-gfx950.mlir | 47 ++++++++++++++
 4 files changed, 198 insertions(+), 3 deletions(-)

diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
index f14aa5a2e1564..d1c601882fc93 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
@@ -830,4 +830,52 @@ def AMDGPU_GatherToLDSOp :
   let hasVerifier = 1;
 }
 
+def AMDGPU_ScaledMFMAOp :
+    AMDGPU_Op<"scaled_mfma", [AllTypesMatch<["destC", "destD"]>,
+                        Pure]>,
+    Arguments<(ins
+                   I32Attr:$m,
+                   I32Attr:$n,
+                   I32Attr:$k,
+                   MFMAInTypes:$sourceA,
+                   MFMAInTypes:$sourceB,
+                   MFMAOutTypes:$destC,
+                   I32Attr:$scaleA,
+                   I32Attr:$scaleB,
+                   I32Attr:$opselA,
+                   I32Attr:$opselB)>,
+    Results<(outs MFMAOutTypes: $destD)> {
+  let summary = "MLIR wrapper for CDNA mfma instructions";
+  let description = [{
+    The `amdgpu.scaled_mfma` op is an MLIR wrapper around intrinsics
+    for various scaled versions of `mfma` instructions in the CDNA architecture, which perform
+    multiple outer products in order to allow fast matrix multiplication.
+
+    The wrapper will select an appropriate `mfma` instruction, if one is available,
+    based on the provided `m`, `n`, and `k` attributes, along with the
+    types of the source and destination arguments.
+
+    Note, this wrapper allows specifying `vector<4Kxi8>` arguments to MFMA
+    intrinsics that take an integer type of width `4K`. For example,
+    one can provide a vector<4xi8> as an argument to an MFMA instruction that
+    logically takes 4 i8s but whose intrinsics are specified to take an i32.
+    In these cases, the bytes in the vector will be concatenated in little-endian
+    order (that is, v[0] will go to arg[7:0], v[1] to arg[15:8] and so on).
+
+    This wrapper takes inspiration from `amdgpu.mfma`, but has some key differences:
+    - `amdgpu.scaled_mfma` operates on fp4 (f4E2M1FN), fp6 (f6E2M3FN and f6E3M2FN) and 
+    fp8 (f8E4M3FN and f8E5M2) types using either M=N=16, K=128 or M=N=32, K=64 as their tile 
+    size. 
+    - `amdgpu.scaled_mfma` does not support broadcasting. So, `cbsz`, `abid`, and `blgp` 
+    are omitted from this wrapper.
+    - The negateA, negateB, and negateC flags in `amdgpu.mfma` are only supported for 
+    double-precision operations on gfx94x and so are not included here. 
+  }];
+  let assemblyFormat = [{
+    $sourceA `*` $sourceB `+` $destC
+    attr-dict
+    `:` type($sourceA) `,` type($sourceB) `,` type($destC)
+  }];
+  let hasVerifier = 1;
+}
 #endif // AMDGPU
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 91dbc2de65c4e..6ace61ae653de 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -833,6 +833,14 @@ mfmaOpToScaledIntrinsic(MFMAOp mfma, Chipset chipset) {
       mfma.getBlocks(), chipset);
 }
 
+static std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
+mfmaOpToScaledIntrinsic(ScaledMFMAOp smfma, Chipset chipset) {
+  return mfmaOpToScaledIntrinsic(smfma.getSourceA().getType(),
+                                 smfma.getSourceB().getType(),
+                                 smfma.getDestC().getType(), smfma.getM(),
+                                 smfma.getN(), smfma.getK(), 1u, chipset);
+}
+
 /// Return the `rocdl` intrinsic corresponding to a WMMA operation `wmma`
 /// if one exists. This includes checking to ensure the intrinsic is supported
 /// on the architecture you are compiling for.
@@ -954,6 +962,54 @@ struct MFMAOpLowering : public ConvertOpToLLVMPattern<MFMAOp> {
   }
 };
 
+struct ScaledMFMAOpLowering : public ConvertOpToLLVMPattern<ScaledMFMAOp> {
+  ScaledMFMAOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
+      : ConvertOpToLLVMPattern<ScaledMFMAOp>(converter), chipset(chipset) {}
+
+  Chipset chipset;
+
+  LogicalResult
+  matchAndRewrite(ScaledMFMAOp op, ScaledMFMAOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    Type outType = typeConverter->convertType(op.getDestD().getType());
+    Type intrinsicOutType = outType;
+    if (auto outVecType = dyn_cast<VectorType>(outType))
+      if (outVecType.getElementType().isBF16())
+        intrinsicOutType = outVecType.clone(rewriter.getI16Type());
+
+    if (chipset.majorVersion != 9 || chipset < kGfx908)
+      return op->emitOpError("Scaled MFMA only supported on gfx908+");
+    std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
+        maybeScaledIntrinsic = mfmaOpToScaledIntrinsic(op, chipset);
+    if (!maybeScaledIntrinsic.has_value())
+      return op.emitOpError(
+          "no intrinsic matching Scaled MFMA size on given chipset");
+
+    StringRef intrinsicName = std::get<0>(*maybeScaledIntrinsic);
+    OperationState loweredOp(loc, intrinsicName);
+    loweredOp.addTypes(intrinsicOutType);
+    loweredOp.addOperands(
+        {convertMFMAVectorOperand(rewriter, loc, adaptor.getSourceA()),
+         convertMFMAVectorOperand(rewriter, loc, adaptor.getSourceB()),
+         adaptor.getDestC()});
+    Value scaleA = createI32Constant(rewriter, loc, adaptor.getScaleA());
+    Value scaleB = createI32Constant(rewriter, loc, adaptor.getScaleB());
+    Value opselA = createI32Constant(rewriter, loc, adaptor.getOpselA());
+    Value opselB = createI32Constant(rewriter, loc, adaptor.getOpselB());
+    auto [_scaledName, aTypeCode, bTypeCode] = *maybeScaledIntrinsic;
+    loweredOp.addOperands({createI32Constant(rewriter, loc, aTypeCode),
+                           createI32Constant(rewriter, loc, bTypeCode),
+                           /*scale A byte=*/opselA, /*scale A=*/scaleA,
+                           /*scale B byte=*/opselB, /*scale B=*/scaleB});
+    Value lowered = rewriter.create(loweredOp)->getResult(0);
+    if (outType != intrinsicOutType)
+      lowered = rewriter.create<LLVM::BitcastOp>(loc, outType, lowered);
+    rewriter.replaceOp(op, lowered);
+    return success();
+  }
+};
+
 struct WMMAOpLowering : public ConvertOpToLLVMPattern<WMMAOp> {
   WMMAOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
       : ConvertOpToLLVMPattern<WMMAOp>(converter), chipset(chipset) {}
@@ -1474,8 +1530,9 @@ void mlir::populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter,
            RawBufferOpLowering<RawBufferAtomicCmpswapOp,
                                ROCDL::RawPtrBufferAtomicCmpSwap>,
            AMDGPUDPPLowering, LDSBarrierOpLowering, SchedBarrierOpLowering,
-           MFMAOpLowering, WMMAOpLowering, ExtPackedFp8OpLowering,
-           PackedTrunc2xFp8OpLowering, PackedStochRoundFp8OpLowering,
-           GatherToLDSOpLowering>(converter, chipset);
+           MFMAOpLowering, ScaledMFMAOpLowering, WMMAOpLowering,
+           ExtPackedFp8OpLowering, PackedTrunc2xFp8OpLowering,
+           PackedStochRoundFp8OpLowering, GatherToLDSOpLowering>(converter,
+                                                                 chipset);
   patterns.add<AMDGPUSwizzleBitModeLowering>(converter);
 }
diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
index 549a4376a4a04..5da40a06dc2f0 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -506,6 +506,49 @@ LogicalResult GatherToLDSOp::verify() {
   return success();
 }
 
+LogicalResult ScaledMFMAOp::verify() {
+  unsigned opselA = getOpselA();
+  unsigned opselB = getOpselB();
+
+  opselA >>= 8;
+  opselB >>= 8;
+
+  if (opselA != 0)
+    return emitOpError("Opsel A must be a zero extended 8 bit value.");
+
+  if (opselB != 0)
+    return emitOpError("Opsel B must be a zero extended 8 bit value.");
+
+  auto validType = [&](Type mlirElemType) {
+    return llvm::TypeSwitch<Type, bool>(mlirElemType)
+        .Case([](Float8E4M3FNType) { return true; })
+        .Case([](Float8E5M2Type) { return true; })
+        .Case([](Float6E2M3FNType) { return true; })
+        .Case([](Float6E3M2FNType) { return true; })
+        .Case([](Float4E2M1FNType) { return true; })
+        .Default([](Type) { return false; });
+  };
+
+  Type aType = getSourceA().getType();
+  Type bType = getSourceB().getType();
+  aType = getElementTypeOrSelf(aType);
+  bType = getElementTypeOrSelf(bType);
+  if (!validType(aType))
+    return emitOpError("Source A must be of element type fp4, fp6 or fp8.");
+  if (!validType(bType))
+    return emitOpError("Source B must be of element type fp4, fp6 or fp8.");
+
+  unsigned m = getM();
+  unsigned n = getN();
+  unsigned k = getK();
+  bool tileConfig1 = (m == n && n == 32 && k == 64);
+  bool tileConfig2 = (m == n && n == 16 && k == 128);
+  if (!tileConfig1 && !tileConfig2)
+    return emitOpError("Invalid tile size for scaled mfma.");
+
+  return success();
+}
+
 #include "mlir/Dialect/AMDGPU/IR/AMDGPUEnums.cpp.inc"
 
 #define GET_ATTRDEF_CLASSES
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/mfma-gfx950.mlir b/mlir/test/Conversion/AMDGPUToROCDL/mfma-gfx950.mlir
index de63f249bb530..f525a37e5ec80 100644
--- a/mlir/test/Conversion/AMDGPUToROCDL/mfma-gfx950.mlir
+++ b/mlir/test/Conversion/AMDGPUToROCDL/mfma-gfx950.mlir
@@ -51,3 +51,50 @@ func.func @mfma_to_rocdl(%arg0 : vector<8xf16>, %arg1 : vector<16xf32>,
 
   func.return
 }
+
+func.func @scaled_mfma_to_rocdl(%arg0 : vector<16xf32>,
+                    %arg1 : vector<4xf32>, %arg2 : vector<32xf8E4M3FN>,
+                    %arg3 : vector<32xf8E5M2>, %arg4 : vector<32xf6E2M3FN>,
+                    %arg5 : vector<32xf6E3M2FN>, %arg6 : vector<32xf4E2M1FN>) {
+  
+  // CHECK: %[[c1:.+]] = llvm.mlir.constant(1 : i32) : i32
+  // CHECK: %[[c2:.+]] = llvm.mlir.constant(2 : i32) : i32
+  // CHECK: %[[c0:.+]] = llvm.mlir.constant(0 : i32) : i32
+
+  // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4{{.*}}, %[[c0]], %[[c0]], %[[c1]], %[[c1]], %[[c2]], %[[c2]] : (vector<8xi32>, vector<8xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
+  amdgpu.scaled_mfma %arg2 * %arg2 + %arg0 { k = 64 : i32, m = 32 : i32, n = 32 : i32,  scaleA = 1 : i32, opselA = 1 : i32, scaleB = 2 : i32, opselB = 2 : i32 } : vector<32xf8E4M3FN>, vector<32xf8E4M3FN>, vector<16xf32>
+  // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4{{.*}}, %[[c0]], %[[c0]], %[[c1]], %[[c1]], %[[c2]], %[[c2]] : (vector<8xi32>, vector<8xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
+  amdgpu.scaled_mfma %arg2 * %arg2 + %arg1 { k = 128 : i32, m = 16 : i32, n = 16 : i32,  scaleA = 1 : i32, opselA = 1 : i32, scaleB = 2 : i32, opselB = 2 : i32 } : vector<32xf8E4M3FN>, vector<32xf8E4M3FN>, vector<4xf32>
+  
+  // CHECK: llvm.bitcast
+  
+  // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4{{.*}}, %[[c1]], %[[c1]], %[[c1]], %[[c1]], %[[c2]], %[[c2]] : (vector<8xi32>, vector<8xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
+  amdgpu.scaled_mfma %arg3 * %arg3 + %arg0 { k = 64 : i32, m = 32 : i32, n = 32 : i32,  scaleA = 1 : i32, opselA = 1 : i32, scaleB = 2 : i32, opselB = 2 : i32 } : vector<32xf8E5M2>, vector<32xf8E5M2>, vector<16xf32>
+  // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4{{.*}}, %[[c1]], %[[c1]], %[[c1]], %[[c1]], %[[c2]], %[[c2]] : (vector<8xi32>, vector<8xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
+  amdgpu.scaled_mfma %arg3 * %arg3 + %arg1 { k = 128 : i32, m = 16 : i32, n = 16 : i32,  scaleA = 1 : i32, opselA = 1 : i32, scaleB = 2 : i32, opselB = 2 : i32 } : vector<32xf8E5M2>, vector<32xf8E5M2>, vector<4xf32>
+  
+  // CHECK: llvm.bitcast
+  
+  // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4{{.*}}, %[[c2]], %[[c2]], %[[c1]], %[[c1]], %[[c2]], %[[c2]] : (vector<6xi32>, vector<6xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
+  amdgpu.scaled_mfma %arg4 * %arg4 + %arg0 { k = 64 : i32, m = 32 : i32, n = 32 : i32,  scaleA = 1 : i32, opselA = 1 : i32, scaleB = 2 : i32, opselB = 2 : i32 } : vector<32xf6E2M3FN>, vector<32xf6E2M3FN>, vector<16xf32>
+  // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4{{.*}}, %[[c2]], %[[c2]], %[[c1]], %[[c1]], %[[c2]], %[[c2]] : (vector<6xi32>, vector<6xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
+  amdgpu.scaled_mfma %arg4 * %arg4 + %arg1 { k = 128 : i32, m = 16 : i32, n = 16 : i32,  scaleA = 1 : i32, opselA = 1 : i32, scaleB = 2 : i32, opselB = 2 : i32 } : vector<32xf6E2M3FN>, vector<32xf6E2M3FN>, vector<4xf32>
+  
+  // CHECK: llvm.bitcast
+  // CHECK: %[[c3:.+]] = llvm.mlir.constant(3 : i32) : i32
+
+  // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4{{.*}}, %[[c3]], %[[c3]], %[[c1]], %[[c1]], %[[c2]], %[[c2]] : (vector<6xi32>, vector<6xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
+  amdgpu.scaled_mfma %arg5 * %arg5 + %arg0 { k = 64 : i32, m = 32 : i32, n = 32 : i32,  scaleA = 1 : i32, opselA = 1 : i32, scaleB = 2 : i32, opselB = 2 : i32 } : vector<32xf6E3M2FN>, vector<32xf6E3M2FN>, vector<16xf32>
+  // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4{{.*}}, %[[c3]], %[[c3]], %[[c1]], %[[c1]], %[[c2]], %[[c2]] : (vector<6xi32>, vector<6xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
+  amdgpu.scaled_mfma %arg5 * %arg5 + %arg1 { k = 128 : i32, m = 16 : i32, n = 16 : i32,  scaleA = 1 : i32, opselA = 1 : i32, scaleB = 2 : i32, opselB = 2 : i32 } : vector<32xf6E3M2FN>, vector<32xf6E3M2FN>, vector<4xf32>
+  
+  // CHECK: llvm.bitcast
+  // CHECK: %[[c4:.+]] = llvm.mlir.constant(4 : i32) : i32
+  
+  // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4{{.*}}, %[[c4]], %[[c4]], %[[c1]], %[[c1]], %[[c2]], %[[c2]] : (vector<4xi32>, vector<4xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
+  amdgpu.scaled_mfma %arg6 * %arg6 + %arg0 { k = 64 : i32, m = 32 : i32, n = 32 : i32,  scaleA = 1 : i32, opselA = 1 : i32, scaleB = 2 : i32, opselB = 2 : i32 } : vector<32xf4E2M1FN>, vector<32xf4E2M1FN>, vector<16xf32>
+  // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4{{.*}}, %[[c4]], %[[c4]], %[[c1]], %[[c1]], %[[c2]], %[[c2]] : (vector<4xi32>, vector<4xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
+  amdgpu.scaled_mfma %arg6 * %arg6 + %arg1 { k = 128 : i32, m = 16 : i32, n = 16 : i32,  scaleA = 1 : i32, opselA = 1 : i32, scaleB = 2 : i32, opselB = 2 : i32 } : vector<32xf4E2M1FN>, vector<32xf4E2M1FN>, vector<4xf32>
+
+  func.return
+}

>From 3ba7ea8cc97b59cafeba682dadedb00a982075d4 Mon Sep 17 00:00:00 2001
From: Muzammiluddin Syed <muzasyed at amd.com>
Date: Mon, 28 Apr 2025 00:37:09 -0500
Subject: [PATCH 2/6] PR Review Round 1

Signed-off-by: Muzammiluddin Syed <muzasyed at amd.com>
---
 mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td |  8 ++---
 .../AMDGPUToROCDL/AMDGPUToROCDL.cpp           | 33 ++++++++---------
 mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp  | 35 +++++++------------
 3 files changed, 34 insertions(+), 42 deletions(-)

diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
index d1c601882fc93..d8d59c2f27924 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
@@ -804,7 +804,7 @@ def AMDGPU_GatherToLDSOp :
                    TypeAttr:$transferType
                    )>,
     Results<(outs)> {
-  let summary = "MLIR wrapper for CDNA mfma instructions";
+  let summary = "MLIR wrapper for CDNA Gather to LDS instructions";
   let description = [{
     The `amdgpu.global_load` op is a wrapper around the `global_load_lds` instructions.
 
@@ -845,7 +845,7 @@ def AMDGPU_ScaledMFMAOp :
                    I32Attr:$opselA,
                    I32Attr:$opselB)>,
     Results<(outs MFMAOutTypes: $destD)> {
-  let summary = "MLIR wrapper for CDNA mfma instructions";
+  let summary = "MLIR wrapper for CDNA scaled mfma instructions";
   let description = [{
     The `amdgpu.scaled_mfma` op is an MLIR wrapper around intrinsics
     for various scaled versions of `mfma` instructions in the CDNA architecture, which perform
@@ -857,7 +857,7 @@ def AMDGPU_ScaledMFMAOp :
 
     Note, this wrapper allows specifying `vector<4Kxi8>` arguments to MFMA
     intrinsics that take an integer type of width `4K`. For example,
-    one can provide a vector<4xi8> as an argument to an MFMA instruction that
+    one can provide a `vector<4xi8>` as an argument to an MFMA instruction that
     logically takes 4 i8s but whose intrinsics are specified to take an i32.
     In these cases, the bytes in the vector will be concatenated in little-endian
     order (that is, v[0] will go to arg[7:0], v[1] to arg[15:8] and so on).
@@ -868,7 +868,7 @@ def AMDGPU_ScaledMFMAOp :
     size. 
     - `amdgpu.scaled_mfma` does not support broadcasting. So, `cbsz`, `abid`, and `blgp` 
     are omitted from this wrapper.
-    - The negateA, negateB, and negateC flags in `amdgpu.mfma` are only supported for 
+    - The `negateA`, `negateB`, and `negateC` flags in `amdgpu.mfma` are only supported for 
     double-precision operations on gfx94x and so are not included here. 
   }];
   let assemblyFormat = [{
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 6ace61ae653de..af71ba7de5c6b 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -23,6 +23,7 @@
 
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/TypeSwitch.h"
+#include "llvm/Support/Casting.h"
 #include <optional>
 
 namespace mlir {
@@ -826,19 +827,20 @@ mfmaOpToScaledIntrinsic(Type aType, Type bType, Type destType, uint32_t m,
 }
 
 static std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
-mfmaOpToScaledIntrinsic(MFMAOp mfma, Chipset chipset) {
-  return mfmaOpToScaledIntrinsic(
-      mfma.getSourceA().getType(), mfma.getSourceB().getType(),
-      mfma.getDestC().getType(), mfma.getM(), mfma.getN(), mfma.getK(),
-      mfma.getBlocks(), chipset);
-}
-
-static std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
-mfmaOpToScaledIntrinsic(ScaledMFMAOp smfma, Chipset chipset) {
-  return mfmaOpToScaledIntrinsic(smfma.getSourceA().getType(),
-                                 smfma.getSourceB().getType(),
-                                 smfma.getDestC().getType(), smfma.getM(),
-                                 smfma.getN(), smfma.getK(), 1u, chipset);
+mfmaOpToScaledIntrinsic(Operation *op, Chipset chipset) {
+  if (auto mfma = llvm::dyn_cast_or_null<MFMAOp>(op)) {
+    return mfmaOpToScaledIntrinsic(
+        mfma.getSourceA().getType(), mfma.getSourceB().getType(),
+        mfma.getDestC().getType(), mfma.getM(), mfma.getN(), mfma.getK(),
+        mfma.getBlocks(), chipset);
+  }
+  if (auto smfma = llvm::dyn_cast_or_null<ScaledMFMAOp>(op)) {
+    return mfmaOpToScaledIntrinsic(smfma.getSourceA().getType(),
+                                   smfma.getSourceB().getType(),
+                                   smfma.getDestC().getType(), smfma.getM(),
+                                   smfma.getN(), smfma.getK(), 1u, chipset);
+  }
+  return std::nullopt;
 }
 
 /// Return the `rocdl` intrinsic corresponding to a WMMA operation `wmma`
@@ -964,7 +966,7 @@ struct MFMAOpLowering : public ConvertOpToLLVMPattern<MFMAOp> {
 
 struct ScaledMFMAOpLowering : public ConvertOpToLLVMPattern<ScaledMFMAOp> {
   ScaledMFMAOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
-      : ConvertOpToLLVMPattern<ScaledMFMAOp>(converter), chipset(chipset) {}
+      : ConvertOpToLLVMPattern(converter), chipset(chipset) {}
 
   Chipset chipset;
 
@@ -986,7 +988,7 @@ struct ScaledMFMAOpLowering : public ConvertOpToLLVMPattern<ScaledMFMAOp> {
       return op.emitOpError(
           "no intrinsic matching Scaled MFMA size on given chipset");
 
-    StringRef intrinsicName = std::get<0>(*maybeScaledIntrinsic);
+    auto [intrinsicName, aTypeCode, bTypeCode] = *maybeScaledIntrinsic;
     OperationState loweredOp(loc, intrinsicName);
     loweredOp.addTypes(intrinsicOutType);
     loweredOp.addOperands(
@@ -997,7 +999,6 @@ struct ScaledMFMAOpLowering : public ConvertOpToLLVMPattern<ScaledMFMAOp> {
     Value scaleB = createI32Constant(rewriter, loc, adaptor.getScaleB());
     Value opselA = createI32Constant(rewriter, loc, adaptor.getOpselA());
     Value opselB = createI32Constant(rewriter, loc, adaptor.getOpselB());
-    auto [_scaledName, aTypeCode, bTypeCode] = *maybeScaledIntrinsic;
     loweredOp.addOperands({createI32Constant(rewriter, loc, aTypeCode),
                            createI32Constant(rewriter, loc, bTypeCode),
                            /*scale A byte=*/opselA, /*scale A=*/scaleA,
diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
index 5da40a06dc2f0..1ffba5c542b1b 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -507,36 +507,27 @@ LogicalResult GatherToLDSOp::verify() {
 }
 
 LogicalResult ScaledMFMAOp::verify() {
-  unsigned opselA = getOpselA();
-  unsigned opselB = getOpselB();
-
-  opselA >>= 8;
-  opselB >>= 8;
+  unsigned opselA = getOpselA() >> 8;
+  unsigned opselB = getOpselB() >> 8;
 
   if (opselA != 0)
-    return emitOpError("Opsel A must be a zero extended 8 bit value.");
+    return emitOpError("Opsel A must be a zero extended 8 bit value");
 
   if (opselB != 0)
-    return emitOpError("Opsel B must be a zero extended 8 bit value.");
-
-  auto validType = [&](Type mlirElemType) {
-    return llvm::TypeSwitch<Type, bool>(mlirElemType)
-        .Case([](Float8E4M3FNType) { return true; })
-        .Case([](Float8E5M2Type) { return true; })
-        .Case([](Float6E2M3FNType) { return true; })
-        .Case([](Float6E3M2FNType) { return true; })
-        .Case([](Float4E2M1FNType) { return true; })
-        .Default([](Type) { return false; });
-  };
+    return emitOpError("Opsel B must be a zero extended 8 bit value");
+
+  auto isValidType =
+      llvm::IsaPred<Float8E4M3FNType, Float8E5M2Type, Float6E2M3FNType,
+                    Float6E3M2FNType, Float4E2M1FNType>;
 
   Type aType = getSourceA().getType();
   Type bType = getSourceB().getType();
   aType = getElementTypeOrSelf(aType);
   bType = getElementTypeOrSelf(bType);
-  if (!validType(aType))
-    return emitOpError("Source A must be of element type fp4, fp6 or fp8.");
-  if (!validType(bType))
-    return emitOpError("Source B must be of element type fp4, fp6 or fp8.");
+  if (!isValidType(aType))
+    return emitOpError("Source A must be of element type fp4, fp6 or fp8");
+  if (!isValidType(bType))
+    return emitOpError("Source B must be of element type fp4, fp6 or fp8");
 
   unsigned m = getM();
   unsigned n = getN();
@@ -544,7 +535,7 @@ LogicalResult ScaledMFMAOp::verify() {
   bool tileConfig1 = (m == n && n == 32 && k == 64);
   bool tileConfig2 = (m == n && n == 16 && k == 128);
   if (!tileConfig1 && !tileConfig2)
-    return emitOpError("Invalid tile size for scaled mfma.");
+    return emitOpError("Invalid tile size for scaled mfma");
 
   return success();
 }

>From 846c389810800b490a7dd0fcdb75230e93945e75 Mon Sep 17 00:00:00 2001
From: Muzammiluddin Syed <muzasyed at amd.com>
Date: Tue, 29 Apr 2025 21:03:34 -0500
Subject: [PATCH 3/6] PR review round 2

Signed-off-by: Muzammiluddin Syed <muzasyed at amd.com>
---
 mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td | 25 ++++++----
 .../AMDGPUToROCDL/AMDGPUToROCDL.cpp           | 33 ++++++-------
 mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp  | 37 ++++-----------
 .../Conversion/AMDGPUToROCDL/mfma-gfx950.mlir | 47 ++++++++++---------
 4 files changed, 62 insertions(+), 80 deletions(-)

diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
index d8d59c2f27924..f1527a5f5cca3 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
@@ -687,6 +687,11 @@ def MFMAOutTypes : AnyTypeOf<[F64,
                               VectorOfLengthAndType<[4, 16, 32], [F32]>,
                               VectorOfLengthAndType<[4, 16, 32], [I32]>,
                               VectorOfLengthAndType<[4], [F64]>]>;
+// scaled_mfma: each source holds 32 fp8/fp6/fp4 values; the accumulator is
+// 4xf32 for the 16x16x128 variant and 16xf32 for the 32x32x64 variant.
+def ScaledMFMAInTypes : AnyTypeOf<[VectorOfLengthAndType<[32],
+                             [F8E5M2, F8E4M3FN, F6E2M3FN, F6E3M2FN, F4E2M1FN]>]>;
+def ScaledMFMAOutTypes : AnyTypeOf<[VectorOfLengthAndType<[4, 16], [F32]>]>;
 // wmma
 def WMMAInTypes : AnyTypeOf<[VectorOfLengthAndType<
                              [4, 8, 16],
@@ -837,14 +842,14 @@ def AMDGPU_ScaledMFMAOp :
                    I32Attr:$m,
                    I32Attr:$n,
                    I32Attr:$k,
-                   MFMAInTypes:$sourceA,
-                   MFMAInTypes:$sourceB,
-                   MFMAOutTypes:$destC,
-                   I32Attr:$scaleA,
-                   I32Attr:$scaleB,
-                   I32Attr:$opselA,
-                   I32Attr:$opselB)>,
-    Results<(outs MFMAOutTypes: $destD)> {
+                   ScaledMFMAInTypes:$sourceA,
+                   ScaledMFMAInTypes:$sourceB,
+                   ScaledMFMAOutTypes:$destC,
+                   AnyTypeOf<[I8, FixedVectorOfLengthAndType<[4], [I8]>]>:$scalesA,
+                   AnyTypeOf<[I8, FixedVectorOfLengthAndType<[4], [I8]>]>:$scalesB,
+                   I32Attr:$scalesIdxA,
+                   I32Attr:$scalesIdxB)>,
+    Results<(outs ScaledMFMAOutTypes: $destD)> {
   let summary = "MLIR wrapper for CDNA scaled mfma instructions";
   let description = [{
     The `amdgpu.scaled_mfma` op is an MLIR wrapper around intrinsics
@@ -872,9 +877,9 @@ def AMDGPU_ScaledMFMAOp :
     double-precision operations on gfx94x and so are not included here. 
   }];
   let assemblyFormat = [{
-    $sourceA `*` $sourceB `+` $destC
+    `(` $scalesA `[` $scalesIdxA `]` `*` $sourceA `)` `*` `(` $scalesB `[` $scalesIdxB `]` `*` $sourceB `)` `+` $destC
     attr-dict
-    `:` type($sourceA) `,` type($sourceB) `,` type($destC)
+    `:` type($sourceA) `,` type($scalesA) `,` type($sourceB) `,` type($scalesB) `,` type($destC)
   }];
   let hasVerifier = 1;
 }
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index af71ba7de5c6b..62d05867681ff 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -974,19 +974,15 @@ struct ScaledMFMAOpLowering : public ConvertOpToLLVMPattern<ScaledMFMAOp> {
   matchAndRewrite(ScaledMFMAOp op, ScaledMFMAOpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     Location loc = op.getLoc();
-    Type outType = typeConverter->convertType(op.getDestD().getType());
-    Type intrinsicOutType = outType;
-    if (auto outVecType = dyn_cast<VectorType>(outType))
-      if (outVecType.getElementType().isBF16())
-        intrinsicOutType = outVecType.clone(rewriter.getI16Type());
+    Type intrinsicOutType = typeConverter->convertType(op.getDestD().getType());
 
-    if (chipset.majorVersion != 9 || chipset < kGfx908)
-      return op->emitOpError("Scaled MFMA only supported on gfx908+");
+    if (chipset.majorVersion != 9 || chipset < kGfx950)
+      return op->emitOpError("scaled MFMA only supported on gfx950+");
     std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
         maybeScaledIntrinsic = mfmaOpToScaledIntrinsic(op, chipset);
     if (!maybeScaledIntrinsic.has_value())
       return op.emitOpError(
-          "no intrinsic matching Scaled MFMA size on given chipset");
+          "no intrinsic matching scaled MFMA size on given chipset");
 
     auto [intrinsicName, aTypeCode, bTypeCode] = *maybeScaledIntrinsic;
     OperationState loweredOp(loc, intrinsicName);
@@ -995,17 +991,18 @@ struct ScaledMFMAOpLowering : public ConvertOpToLLVMPattern<ScaledMFMAOp> {
         {convertMFMAVectorOperand(rewriter, loc, adaptor.getSourceA()),
          convertMFMAVectorOperand(rewriter, loc, adaptor.getSourceB()),
          adaptor.getDestC()});
-    Value scaleA = createI32Constant(rewriter, loc, adaptor.getScaleA());
-    Value scaleB = createI32Constant(rewriter, loc, adaptor.getScaleB());
-    Value opselA = createI32Constant(rewriter, loc, adaptor.getOpselA());
-    Value opselB = createI32Constant(rewriter, loc, adaptor.getOpselB());
-    loweredOp.addOperands({createI32Constant(rewriter, loc, aTypeCode),
-                           createI32Constant(rewriter, loc, bTypeCode),
-                           /*scale A byte=*/opselA, /*scale A=*/scaleA,
-                           /*scale B byte=*/opselB, /*scale B=*/scaleB});
+    Value scalesIdxA = createI32Constant(rewriter, loc, adaptor.getScalesIdxA());
+    Value scalesIdxB = createI32Constant(rewriter, loc, adaptor.getScalesIdxB());
+    loweredOp.addOperands(
+        {createI32Constant(rewriter, loc, aTypeCode),
+         createI32Constant(rewriter, loc, bTypeCode),
+         /*scales A*/
+         convertMFMAVectorOperand(rewriter, loc, adaptor.getScalesA()),
+         /*scales B*/
+         convertMFMAVectorOperand(rewriter, loc, adaptor.getScalesB()),
+         /*scales idx A=*/scalesIdxA,
+         /*scales idx B=*/scalesIdxB});
     Value lowered = rewriter.create(loweredOp)->getResult(0);
-    if (outType != intrinsicOutType)
-      lowered = rewriter.create<LLVM::BitcastOp>(loc, outType, lowered);
     rewriter.replaceOp(op, lowered);
     return success();
   }
diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
index 1ffba5c542b1b..1e1cdd680033e 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -507,35 +507,14 @@ LogicalResult GatherToLDSOp::verify() {
 }
 
 LogicalResult ScaledMFMAOp::verify() {
-  unsigned opselA = getOpselA() >> 8;
-  unsigned opselB = getOpselB() >> 8;
-
-  if (opselA != 0)
-    return emitOpError("Opsel A must be a zero extended 8 bit value");
-
-  if (opselB != 0)
-    return emitOpError("Opsel B must be a zero extended 8 bit value");
-
-  auto isValidType =
-      llvm::IsaPred<Float8E4M3FNType, Float8E5M2Type, Float6E2M3FNType,
-                    Float6E3M2FNType, Float4E2M1FNType>;
-
-  Type aType = getSourceA().getType();
-  Type bType = getSourceB().getType();
-  aType = getElementTypeOrSelf(aType);
-  bType = getElementTypeOrSelf(bType);
-  if (!isValidType(aType))
-    return emitOpError("Source A must be of element type fp4, fp6 or fp8");
-  if (!isValidType(bType))
-    return emitOpError("Source B must be of element type fp4, fp6 or fp8");
-
-  unsigned m = getM();
-  unsigned n = getN();
-  unsigned k = getK();
-  bool tileConfig1 = (m == n && n == 32 && k == 64);
-  bool tileConfig2 = (m == n && n == 16 && k == 128);
-  if (!tileConfig1 && !tileConfig2)
-    return emitOpError("Invalid tile size for scaled mfma");
+  unsigned scalesIdxA = getScalesIdxA();
+  unsigned scalesIdxB = getScalesIdxB();
+
+  if (scalesIdxA > 3)
+    return emitOpError("scales idx A must be a value from 0 to 3 inclusive");
+
+  if (scalesIdxB > 3)
+    return emitOpError("scales idx B must be a value from 0 to 3 inclusive");
 
   return success();
 }
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/mfma-gfx950.mlir b/mlir/test/Conversion/AMDGPUToROCDL/mfma-gfx950.mlir
index f525a37e5ec80..76113076845d6 100644
--- a/mlir/test/Conversion/AMDGPUToROCDL/mfma-gfx950.mlir
+++ b/mlir/test/Conversion/AMDGPUToROCDL/mfma-gfx950.mlir
@@ -55,46 +55,47 @@ func.func @mfma_to_rocdl(%arg0 : vector<8xf16>, %arg1 : vector<16xf32>,
 func.func @scaled_mfma_to_rocdl(%arg0 : vector<16xf32>,
                     %arg1 : vector<4xf32>, %arg2 : vector<32xf8E4M3FN>,
                     %arg3 : vector<32xf8E5M2>, %arg4 : vector<32xf6E2M3FN>,
-                    %arg5 : vector<32xf6E3M2FN>, %arg6 : vector<32xf4E2M1FN>) {
+                    %arg5 : vector<32xf6E3M2FN>, %arg6 : vector<32xf4E2M1FN>, 
+                    %arg7 : vector<4xi8>, %arg8 : i8) {
   
-  // CHECK: %[[c1:.+]] = llvm.mlir.constant(1 : i32) : i32
-  // CHECK: %[[c2:.+]] = llvm.mlir.constant(2 : i32) : i32
   // CHECK: %[[c0:.+]] = llvm.mlir.constant(0 : i32) : i32
+  // CHECK: %[[c1:.+]] = llvm.mlir.constant(1 : i32) : i32
+  // CHECK: llvm.bitcast
 
-  // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4{{.*}}, %[[c0]], %[[c0]], %[[c1]], %[[c1]], %[[c2]], %[[c2]] : (vector<8xi32>, vector<8xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
-  amdgpu.scaled_mfma %arg2 * %arg2 + %arg0 { k = 64 : i32, m = 32 : i32, n = 32 : i32,  scaleA = 1 : i32, opselA = 1 : i32, scaleB = 2 : i32, opselB = 2 : i32 } : vector<32xf8E4M3FN>, vector<32xf8E4M3FN>, vector<16xf32>
-  // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4{{.*}}, %[[c0]], %[[c0]], %[[c1]], %[[c1]], %[[c2]], %[[c2]] : (vector<8xi32>, vector<8xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
-  amdgpu.scaled_mfma %arg2 * %arg2 + %arg1 { k = 128 : i32, m = 16 : i32, n = 16 : i32,  scaleA = 1 : i32, opselA = 1 : i32, scaleB = 2 : i32, opselB = 2 : i32 } : vector<32xf8E4M3FN>, vector<32xf8E4M3FN>, vector<4xf32>
+  // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4{{.*}}, %[[c0]], %[[c1]] : (vector<8xi32>, vector<8xi32>, vector<16xf32>, i32, i32, i32, i8, i32, i32) -> vector<16xf32>
+  amdgpu.scaled_mfma ( %arg7 [ 0 ] * %arg2 ) * ( %arg8 [ 1 ] * %arg2 ) + %arg0 { k = 64 : i32, m = 32 : i32, n = 32 : i32 } : vector<32xf8E4M3FN>, vector<4xi8>, vector<32xf8E4M3FN>, i8, vector<16xf32>
+  // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4{{.*}}, %[[c0]], %[[c1]] : (vector<8xi32>, vector<8xi32>, vector<4xf32>, i32, i32, i32, i8, i32, i32) -> vector<4xf32>
+  amdgpu.scaled_mfma ( %arg7 [ 0 ] * %arg2 ) * ( %arg8 [ 1 ] * %arg2 ) + %arg1 { k = 128 : i32, m = 16 : i32, n = 16 : i32 } : vector<32xf8E4M3FN>, vector<4xi8>, vector<32xf8E4M3FN>, i8, vector<4xf32>
   
   // CHECK: llvm.bitcast
   
-  // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4{{.*}}, %[[c1]], %[[c1]], %[[c1]], %[[c1]], %[[c2]], %[[c2]] : (vector<8xi32>, vector<8xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
-  amdgpu.scaled_mfma %arg3 * %arg3 + %arg0 { k = 64 : i32, m = 32 : i32, n = 32 : i32,  scaleA = 1 : i32, opselA = 1 : i32, scaleB = 2 : i32, opselB = 2 : i32 } : vector<32xf8E5M2>, vector<32xf8E5M2>, vector<16xf32>
-  // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4{{.*}}, %[[c1]], %[[c1]], %[[c1]], %[[c1]], %[[c2]], %[[c2]] : (vector<8xi32>, vector<8xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
-  amdgpu.scaled_mfma %arg3 * %arg3 + %arg1 { k = 128 : i32, m = 16 : i32, n = 16 : i32,  scaleA = 1 : i32, opselA = 1 : i32, scaleB = 2 : i32, opselB = 2 : i32 } : vector<32xf8E5M2>, vector<32xf8E5M2>, vector<4xf32>
+  // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4{{.*}}, %[[c0]], %[[c1]] : (vector<8xi32>, vector<8xi32>, vector<16xf32>, i32, i32, i32, i8, i32, i32) -> vector<16xf32>
+  amdgpu.scaled_mfma ( %arg7 [ 0 ] * %arg3 ) * ( %arg8 [ 1 ] * %arg3 ) + %arg0 { k = 64 : i32, m = 32 : i32, n = 32 : i32 } : vector<32xf8E5M2>, vector<4xi8>, vector<32xf8E5M2>, i8, vector<16xf32>
+  // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4{{.*}}, %[[c0]], %[[c1]] : (vector<8xi32>, vector<8xi32>, vector<4xf32>, i32, i32, i32, i8, i32, i32) -> vector<4xf32>
+  amdgpu.scaled_mfma ( %arg7 [ 0 ] * %arg3 ) * ( %arg8 [ 1 ] * %arg3 ) + %arg1 { k = 128 : i32, m = 16 : i32, n = 16 : i32 } : vector<32xf8E5M2>, vector<4xi8>, vector<32xf8E5M2>, i8, vector<4xf32>
   
   // CHECK: llvm.bitcast
   
-  // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4{{.*}}, %[[c2]], %[[c2]], %[[c1]], %[[c1]], %[[c2]], %[[c2]] : (vector<6xi32>, vector<6xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
-  amdgpu.scaled_mfma %arg4 * %arg4 + %arg0 { k = 64 : i32, m = 32 : i32, n = 32 : i32,  scaleA = 1 : i32, opselA = 1 : i32, scaleB = 2 : i32, opselB = 2 : i32 } : vector<32xf6E2M3FN>, vector<32xf6E2M3FN>, vector<16xf32>
-  // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4{{.*}}, %[[c2]], %[[c2]], %[[c1]], %[[c1]], %[[c2]], %[[c2]] : (vector<6xi32>, vector<6xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
-  amdgpu.scaled_mfma %arg4 * %arg4 + %arg1 { k = 128 : i32, m = 16 : i32, n = 16 : i32,  scaleA = 1 : i32, opselA = 1 : i32, scaleB = 2 : i32, opselB = 2 : i32 } : vector<32xf6E2M3FN>, vector<32xf6E2M3FN>, vector<4xf32>
+  // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4{{.*}}, %[[c0]], %[[c1]] : (vector<6xi32>, vector<6xi32>, vector<16xf32>, i32, i32, i32, i8, i32, i32) -> vector<16xf32>
+  amdgpu.scaled_mfma ( %arg7 [ 0 ] * %arg4 ) * ( %arg8 [ 1 ] * %arg4 ) + %arg0 { k = 64 : i32, m = 32 : i32, n = 32 : i32 } : vector<32xf6E2M3FN>, vector<4xi8>, vector<32xf6E2M3FN>, i8, vector<16xf32>
+  // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4{{.*}}, %[[c0]], %[[c1]] : (vector<6xi32>, vector<6xi32>, vector<4xf32>, i32, i32, i32, i8, i32, i32) -> vector<4xf32>
+  amdgpu.scaled_mfma ( %arg7 [ 0 ] * %arg4 ) * ( %arg8 [ 1 ] * %arg4 ) + %arg1 { k = 128 : i32, m = 16 : i32, n = 16 : i32 } : vector<32xf6E2M3FN>, vector<4xi8>, vector<32xf6E2M3FN>, i8, vector<4xf32>
   
   // CHECK: llvm.bitcast
   // CHECK: %[[c3:.+]] = llvm.mlir.constant(3 : i32) : i32
 
-  // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4{{.*}}, %[[c3]], %[[c3]], %[[c1]], %[[c1]], %[[c2]], %[[c2]] : (vector<6xi32>, vector<6xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
-  amdgpu.scaled_mfma %arg5 * %arg5 + %arg0 { k = 64 : i32, m = 32 : i32, n = 32 : i32,  scaleA = 1 : i32, opselA = 1 : i32, scaleB = 2 : i32, opselB = 2 : i32 } : vector<32xf6E3M2FN>, vector<32xf6E3M2FN>, vector<16xf32>
-  // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4{{.*}}, %[[c3]], %[[c3]], %[[c1]], %[[c1]], %[[c2]], %[[c2]] : (vector<6xi32>, vector<6xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
-  amdgpu.scaled_mfma %arg5 * %arg5 + %arg1 { k = 128 : i32, m = 16 : i32, n = 16 : i32,  scaleA = 1 : i32, opselA = 1 : i32, scaleB = 2 : i32, opselB = 2 : i32 } : vector<32xf6E3M2FN>, vector<32xf6E3M2FN>, vector<4xf32>
+  // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4{{.*}}, %[[c0]], %[[c1]] : (vector<6xi32>, vector<6xi32>, vector<16xf32>, i32, i32, i32, i8, i32, i32) -> vector<16xf32>
+  amdgpu.scaled_mfma ( %arg7 [ 0 ] * %arg5 ) * ( %arg8 [ 1 ] * %arg5 ) + %arg0 { k = 64 : i32, m = 32 : i32, n = 32 : i32 } : vector<32xf6E3M2FN>, vector<4xi8>, vector<32xf6E3M2FN>, i8, vector<16xf32>
+  // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4{{.*}}, %[[c0]], %[[c1]] : (vector<6xi32>, vector<6xi32>, vector<4xf32>, i32, i32, i32, i8, i32, i32) -> vector<4xf32>
+  amdgpu.scaled_mfma ( %arg7 [ 0 ] * %arg5 ) * ( %arg8 [ 1 ] * %arg5 ) + %arg1 { k = 128 : i32, m = 16 : i32, n = 16 : i32 } : vector<32xf6E3M2FN>, vector<4xi8>, vector<32xf6E3M2FN>, i8, vector<4xf32>
   
   // CHECK: llvm.bitcast
   // CHECK: %[[c4:.+]] = llvm.mlir.constant(4 : i32) : i32
   
-  // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4{{.*}}, %[[c4]], %[[c4]], %[[c1]], %[[c1]], %[[c2]], %[[c2]] : (vector<4xi32>, vector<4xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
-  amdgpu.scaled_mfma %arg6 * %arg6 + %arg0 { k = 64 : i32, m = 32 : i32, n = 32 : i32,  scaleA = 1 : i32, opselA = 1 : i32, scaleB = 2 : i32, opselB = 2 : i32 } : vector<32xf4E2M1FN>, vector<32xf4E2M1FN>, vector<16xf32>
-  // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4{{.*}}, %[[c4]], %[[c4]], %[[c1]], %[[c1]], %[[c2]], %[[c2]] : (vector<4xi32>, vector<4xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
-  amdgpu.scaled_mfma %arg6 * %arg6 + %arg1 { k = 128 : i32, m = 16 : i32, n = 16 : i32,  scaleA = 1 : i32, opselA = 1 : i32, scaleB = 2 : i32, opselB = 2 : i32 } : vector<32xf4E2M1FN>, vector<32xf4E2M1FN>, vector<4xf32>
+  // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4{{.*}}, %[[c0]], %[[c1]] : (vector<4xi32>, vector<4xi32>, vector<16xf32>, i32, i32, i32, i8, i32, i32) -> vector<16xf32>
+  amdgpu.scaled_mfma ( %arg7 [ 0 ] * %arg6 ) * ( %arg8 [ 1 ] * %arg6 ) + %arg0 { k = 64 : i32, m = 32 : i32, n = 32 : i32 } : vector<32xf4E2M1FN>, vector<4xi8>, vector<32xf4E2M1FN>, i8, vector<16xf32>
+  // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4{{.*}}, %[[c0]], %[[c1]] : (vector<4xi32>, vector<4xi32>, vector<4xf32>, i32, i32, i32, i8, i32, i32) -> vector<4xf32>
+  amdgpu.scaled_mfma ( %arg7 [ 0 ] * %arg6 ) * ( %arg8 [ 1 ] * %arg6 ) + %arg1 { k = 128 : i32, m = 16 : i32, n = 16 : i32 } : vector<32xf4E2M1FN>, vector<4xi8>, vector<32xf4E2M1FN>, i8, vector<4xf32>
 
   func.return
 }

>From 02f5d98008983f7e30da3fc711caac0bfd489f84 Mon Sep 17 00:00:00 2001
From: Muzammiluddin Syed <muzasyed at amd.com>
Date: Wed, 30 Apr 2025 11:37:00 -0500
Subject: [PATCH 4/6] PR review round 3

Signed-off-by: Muzammiluddin Syed <muzasyed at amd.com>
---
 mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td | 15 +++--
 .../AMDGPUToROCDL/AMDGPUToROCDL.cpp           | 60 ++++++++++++-------
 mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp  | 13 ----
 .../Conversion/AMDGPUToROCDL/mfma-gfx950.mlir | 47 ++++++++-------
 mlir/test/Dialect/AMDGPU/ops.mlir             |  7 +++
 5 files changed, 78 insertions(+), 64 deletions(-)

diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
index f1527a5f5cca3..744aa7a28a4b9 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
@@ -688,10 +688,9 @@ def MFMAOutTypes : AnyTypeOf<[F64,
                               VectorOfLengthAndType<[4, 16, 32], [I32]>,
                               VectorOfLengthAndType<[4], [F64]>]>;
 // scaled_mfma
-def ScaledMFMAInTypes : AnyTypeOf<[VectorOfLengthAndType<[8], [F8E5M2FNUZ, F8E4M3FNUZ]>,
-                             VectorOfLengthAndType<[8, 32], [F8E5M2, F8E4M3FN]>,
-                             VectorOfLengthAndType<[32], [F6E2M3FN, F6E3M2FN, F4E2M1FN]>]>;
-def ScaledMFMAOutTypes : AnyTypeOf<[VectorOfLengthAndType<[4, 16, 32], [F32]>]>;
+def ScaledMFMAInTypes : AnyTypeOf<[VectorOfLengthAndType<[32], [F8E5M2, F8E4M3FN]>,
+                                   VectorOfLengthAndType<[32], [F6E2M3FN, F6E3M2FN, F4E2M1FN]>]>;
+def ScaledMFMAOutTypes : AnyTypeOf<[VectorOfLengthAndType<[4, 16], [F32]>]>;
 // wmma
 def WMMAInTypes : AnyTypeOf<[VectorOfLengthAndType<
                              [4, 8, 16],
@@ -847,8 +846,9 @@ def AMDGPU_ScaledMFMAOp :
                    ScaledMFMAOutTypes:$destC,
                    AnyTypeOf<[I8, FixedVectorOfLengthAndType<[4], [I8]>]>:$scalesA,
                    AnyTypeOf<[I8, FixedVectorOfLengthAndType<[4], [I8]>]>:$scalesB,
-                   I32Attr:$scalesIdxA,
-                   I32Attr:$scalesIdxB)>,
+                   ConfinedAttr<I32Attr, [IntNonNegative, IntMaxValue<3>]>:$scalesIdxA,
+                   ConfinedAttr<I32Attr, [IntNonNegative, IntMaxValue<3>]>:$scalesIdxB
+                   )>,
     Results<(outs ScaledMFMAOutTypes: $destD)> {
   let summary = "MLIR wrapper for CDNA scaled mfma instructions";
   let description = [{
@@ -879,8 +879,7 @@ def AMDGPU_ScaledMFMAOp :
   let assemblyFormat = [{
     `(` $scalesA `[` $scalesIdxA `]` `*` $sourceA `)` `*` `(` $scalesB `[` $scalesIdxB `]` `*` $sourceB `)` `+` $destC
     attr-dict
-    `:` type($sourceA) `,` type($scalesA) `,` type($sourceB) `,` type($scalesB) `,` type($destC)
+    `:` type($scalesA) `,` type($sourceA) `,` type($scalesB) `,` type($sourceB) `,` type($destC)
   }];
-  let hasVerifier = 1;
 }
 #endif // AMDGPU
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 62d05867681ff..01525154cc8c5 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -529,6 +529,25 @@ static Value convertMFMAVectorOperand(ConversionPatternRewriter &rewriter,
   return input;
 }
 
+/// Converts the scaled MFMA operands, `scalesA` and `scalesB`, from MLIR AMDGPU
+/// dialect convention to ROCDL and LLVM AMDGPU intrinsics convention.
+///
+/// Specifically:
+/// 1. If `input` is an i8 value, zero-extend it to i32
+/// 2. If `input` is a vector of length 4 and type i8, bitcast it to i32
+///
+/// Note that the type of `input` has already been LLVM type converted:
+/// therefore 8-bit and smaller floats are represented as their corresponding
+/// `iN` integers.
+static Value castScaledMFMAVectorOperand(ConversionPatternRewriter &rewriter,
+                                         Location loc, Value input) {
+  Type inputType = input.getType();
+  Type outputType = rewriter.getI32Type();
+  if (auto intType = dyn_cast<IntegerType>(inputType))
+    return rewriter.create<LLVM::ZExtOp>(loc, outputType, input);
+  return rewriter.create<LLVM::BitcastOp>(loc, outputType, input);
+}
+
 /// Push an input operand. If it is a float type, nothing to do. If it is
 /// an integer type, then we need to also push its signdness (1 for signed, 0
 /// for unsigned) and we need to pack the input 16xi8 vector into a 4xi32
@@ -827,20 +846,19 @@ mfmaOpToScaledIntrinsic(Type aType, Type bType, Type destType, uint32_t m,
 }
 
 static std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
-mfmaOpToScaledIntrinsic(Operation *op, Chipset chipset) {
-  if (auto mfma = llvm::dyn_cast_or_null<MFMAOp>(op)) {
-    return mfmaOpToScaledIntrinsic(
-        mfma.getSourceA().getType(), mfma.getSourceB().getType(),
-        mfma.getDestC().getType(), mfma.getM(), mfma.getN(), mfma.getK(),
-        mfma.getBlocks(), chipset);
-  }
-  if (auto smfma = llvm::dyn_cast_or_null<ScaledMFMAOp>(op)) {
-    return mfmaOpToScaledIntrinsic(smfma.getSourceA().getType(),
-                                   smfma.getSourceB().getType(),
-                                   smfma.getDestC().getType(), smfma.getM(),
-                                   smfma.getN(), smfma.getK(), 1u, chipset);
-  }
-  return std::nullopt;
+mfmaOpToScaledIntrinsic(MFMAOp mfma, Chipset chipset) {
+  return mfmaOpToScaledIntrinsic(
+      mfma.getSourceA().getType(), mfma.getSourceB().getType(),
+      mfma.getDestC().getType(), mfma.getM(), mfma.getN(), mfma.getK(),
+      mfma.getBlocks(), chipset);
+}
+
+static std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
+mfmaOpToScaledIntrinsic(ScaledMFMAOp smfma, Chipset chipset) {
+  return mfmaOpToScaledIntrinsic(smfma.getSourceA().getType(),
+                                 smfma.getSourceB().getType(),
+                                 smfma.getDestC().getType(), smfma.getM(),
+                                 smfma.getN(), smfma.getK(), 1u, chipset);
 }
 
 /// Return the `rocdl` intrinsic corresponding to a WMMA operation `wmma`
@@ -991,17 +1009,19 @@ struct ScaledMFMAOpLowering : public ConvertOpToLLVMPattern<ScaledMFMAOp> {
         {convertMFMAVectorOperand(rewriter, loc, adaptor.getSourceA()),
          convertMFMAVectorOperand(rewriter, loc, adaptor.getSourceB()),
          adaptor.getDestC()});
-    Value scalesIdxA = createI32Constant(rewriter, loc, adaptor.getScalesIdxA());
-    Value scalesIdxB = createI32Constant(rewriter, loc, adaptor.getScalesIdxB());
+    Value scalesIdxA =
+        createI32Constant(rewriter, loc, adaptor.getScalesIdxA());
+    Value scalesIdxB =
+        createI32Constant(rewriter, loc, adaptor.getScalesIdxB());
     loweredOp.addOperands(
         {createI32Constant(rewriter, loc, aTypeCode),
          createI32Constant(rewriter, loc, bTypeCode),
+         /*scales idx A=*/scalesIdxA,
          /*scales A*/
-         convertMFMAVectorOperand(rewriter, loc, adaptor.getScalesA()),
+         castScaledMFMAVectorOperand(rewriter, loc, adaptor.getScalesA()),
+         /*scales idx B=*/scalesIdxB,
          /*scales B*/
-         convertMFMAVectorOperand(rewriter, loc, adaptor.getScalesB()),
-         /*scales idx A=*/scalesIdxA,
-         /*scales idx B=*/scalesIdxB});
+         castScaledMFMAVectorOperand(rewriter, loc, adaptor.getScalesB())});
     Value lowered = rewriter.create(loweredOp)->getResult(0);
     rewriter.replaceOp(op, lowered);
     return success();
diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
index 1e1cdd680033e..549a4376a4a04 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -506,19 +506,6 @@ LogicalResult GatherToLDSOp::verify() {
   return success();
 }
 
-LogicalResult ScaledMFMAOp::verify() {
-  unsigned scalesIdxA = getScalesIdxA();
-  unsigned scalesIdxB = getScalesIdxB();
-
-  if (scalesIdxA > 3)
-    return emitOpError("scales idx A must be a value from 0 to 3 inclusive");
-
-  if (scalesIdxB > 3)
-    return emitOpError("scales idx B must be a value from 0 to 3 inclusive");
-
-  return success();
-}
-
 #include "mlir/Dialect/AMDGPU/IR/AMDGPUEnums.cpp.inc"
 
 #define GET_ATTRDEF_CLASSES
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/mfma-gfx950.mlir b/mlir/test/Conversion/AMDGPUToROCDL/mfma-gfx950.mlir
index 76113076845d6..f9fdd0ef4cbfc 100644
--- a/mlir/test/Conversion/AMDGPUToROCDL/mfma-gfx950.mlir
+++ b/mlir/test/Conversion/AMDGPUToROCDL/mfma-gfx950.mlir
@@ -60,42 +60,43 @@ func.func @scaled_mfma_to_rocdl(%arg0 : vector<16xf32>,
   
   // CHECK: %[[c0:.+]] = llvm.mlir.constant(0 : i32) : i32
   // CHECK: %[[c1:.+]] = llvm.mlir.constant(1 : i32) : i32
-  // CHECK: llvm.bitcast
+  // CHECK: %[[c2:.+]] = llvm.bitcast{{.*}} : vector<4xi8> to i32
+  // CHECK: %[[c3:.+]] = llvm.zext{{.*}} : i8 to i32
 
-  // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4{{.*}}, %[[c0]], %[[c1]] : (vector<8xi32>, vector<8xi32>, vector<16xf32>, i32, i32, i32, i8, i32, i32) -> vector<16xf32>
-  amdgpu.scaled_mfma ( %arg7 [ 0 ] * %arg2 ) * ( %arg8 [ 1 ] * %arg2 ) + %arg0 { k = 64 : i32, m = 32 : i32, n = 32 : i32 } : vector<32xf8E4M3FN>, vector<4xi8>, vector<32xf8E4M3FN>, i8, vector<16xf32>
-  // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4{{.*}}, %[[c0]], %[[c1]] : (vector<8xi32>, vector<8xi32>, vector<4xf32>, i32, i32, i32, i8, i32, i32) -> vector<4xf32>
-  amdgpu.scaled_mfma ( %arg7 [ 0 ] * %arg2 ) * ( %arg8 [ 1 ] * %arg2 ) + %arg1 { k = 128 : i32, m = 16 : i32, n = 16 : i32 } : vector<32xf8E4M3FN>, vector<4xi8>, vector<32xf8E4M3FN>, i8, vector<4xf32>
+  // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4{{.*}}, %[[c0]], %[[c2]], %[[c1]], %[[c3]] : (vector<8xi32>, vector<8xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
+  amdgpu.scaled_mfma(%arg7[0] * %arg2) * (%arg8[1] * %arg2) + %arg0 { k = 64 : i32, m = 32 : i32, n = 32 : i32 } : vector<4xi8>, vector<32xf8E4M3FN>, i8, vector<32xf8E4M3FN>, vector<16xf32>
+  // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4{{.*}}, %[[c0]], %[[c2]], %[[c1]], %[[c3]] : (vector<8xi32>, vector<8xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
+  amdgpu.scaled_mfma(%arg7[0] * %arg2) * (%arg8[1] * %arg2) + %arg1 { k = 128 : i32, m = 16 : i32, n = 16 : i32 } : vector<4xi8>, vector<32xf8E4M3FN>, i8, vector<32xf8E4M3FN>, vector<4xf32>
   
   // CHECK: llvm.bitcast
   
-  // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4{{.*}}, %[[c0]], %[[c1]] : (vector<8xi32>, vector<8xi32>, vector<16xf32>, i32, i32, i32, i8, i32, i32) -> vector<16xf32>
-  amdgpu.scaled_mfma ( %arg7 [ 0 ] * %arg3 ) * ( %arg8 [ 1 ] * %arg3 ) + %arg0 { k = 64 : i32, m = 32 : i32, n = 32 : i32 } : vector<32xf8E5M2>, vector<4xi8>, vector<32xf8E5M2>, i8, vector<16xf32>
-  // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4{{.*}}, %[[c0]], %[[c1]] : (vector<8xi32>, vector<8xi32>, vector<4xf32>, i32, i32, i32, i8, i32, i32) -> vector<4xf32>
-  amdgpu.scaled_mfma ( %arg7 [ 0 ] * %arg3 ) * ( %arg8 [ 1 ] * %arg3 ) + %arg1 { k = 128 : i32, m = 16 : i32, n = 16 : i32 } : vector<32xf8E5M2>, vector<4xi8>, vector<32xf8E5M2>, i8, vector<4xf32>
+  // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4{{.*}}, %[[c0]], %[[c2]], %[[c1]], %[[c3]] : (vector<8xi32>, vector<8xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
+  amdgpu.scaled_mfma(%arg7[0] * %arg3) * (%arg8[1] * %arg3) + %arg0 { k = 64 : i32, m = 32 : i32, n = 32 : i32 } : vector<4xi8>, vector<32xf8E5M2>, i8, vector<32xf8E5M2>, vector<16xf32>
+  // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4{{.*}}, %[[c0]], %[[c2]], %[[c1]], %[[c3]] : (vector<8xi32>, vector<8xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
+  amdgpu.scaled_mfma(%arg7[0] * %arg3) * (%arg8[1] * %arg3) + %arg1 { k = 128 : i32, m = 16 : i32, n = 16 : i32 } : vector<4xi8>, vector<32xf8E5M2>, i8, vector<32xf8E5M2>, vector<4xf32>
   
   // CHECK: llvm.bitcast
   
-  // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4{{.*}}, %[[c0]], %[[c1]] : (vector<6xi32>, vector<6xi32>, vector<16xf32>, i32, i32, i32, i8, i32, i32) -> vector<16xf32>
-  amdgpu.scaled_mfma ( %arg7 [ 0 ] * %arg4 ) * ( %arg8 [ 1 ] * %arg4 ) + %arg0 { k = 64 : i32, m = 32 : i32, n = 32 : i32 } : vector<32xf6E2M3FN>, vector<4xi8>, vector<32xf6E2M3FN>, i8, vector<16xf32>
-  // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4{{.*}}, %[[c0]], %[[c1]] : (vector<6xi32>, vector<6xi32>, vector<4xf32>, i32, i32, i32, i8, i32, i32) -> vector<4xf32>
-  amdgpu.scaled_mfma ( %arg7 [ 0 ] * %arg4 ) * ( %arg8 [ 1 ] * %arg4 ) + %arg1 { k = 128 : i32, m = 16 : i32, n = 16 : i32 } : vector<32xf6E2M3FN>, vector<4xi8>, vector<32xf6E2M3FN>, i8, vector<4xf32>
+  // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4{{.*}}, %[[c0]], %[[c2]], %[[c1]], %[[c3]] : (vector<6xi32>, vector<6xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
+  amdgpu.scaled_mfma(%arg7[0] * %arg4) * (%arg8[1] * %arg4) + %arg0 { k = 64 : i32, m = 32 : i32, n = 32 : i32 } : vector<4xi8>, vector<32xf6E2M3FN>, i8, vector<32xf6E2M3FN>, vector<16xf32>
+  // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4{{.*}}, %[[c0]], %[[c2]], %[[c1]], %[[c3]] : (vector<6xi32>, vector<6xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
+  amdgpu.scaled_mfma(%arg7[0] * %arg4) * (%arg8[1] * %arg4) + %arg1 { k = 128 : i32, m = 16 : i32, n = 16 : i32 } : vector<4xi8>, vector<32xf6E2M3FN>, i8, vector<32xf6E2M3FN>, vector<4xf32>
   
   // CHECK: llvm.bitcast
-  // CHECK: %[[c3:.+]] = llvm.mlir.constant(3 : i32) : i32
+  // CHECK: llvm.mlir.constant(3 : i32) : i32
 
-  // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4{{.*}}, %[[c0]], %[[c1]] : (vector<6xi32>, vector<6xi32>, vector<16xf32>, i32, i32, i32, i8, i32, i32) -> vector<16xf32>
-  amdgpu.scaled_mfma ( %arg7 [ 0 ] * %arg5 ) * ( %arg8 [ 1 ] * %arg5 ) + %arg0 { k = 64 : i32, m = 32 : i32, n = 32 : i32 } : vector<32xf6E3M2FN>, vector<4xi8>, vector<32xf6E3M2FN>, i8, vector<16xf32>
-  // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4{{.*}}, %[[c0]], %[[c1]] : (vector<6xi32>, vector<6xi32>, vector<4xf32>, i32, i32, i32, i8, i32, i32) -> vector<4xf32>
-  amdgpu.scaled_mfma ( %arg7 [ 0 ] * %arg5 ) * ( %arg8 [ 1 ] * %arg5 ) + %arg1 { k = 128 : i32, m = 16 : i32, n = 16 : i32 } : vector<32xf6E3M2FN>, vector<4xi8>, vector<32xf6E3M2FN>, i8, vector<4xf32>
+  // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4{{.*}}, %[[c0]], %[[c2]], %[[c1]], %[[c3]] : (vector<6xi32>, vector<6xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
+  amdgpu.scaled_mfma(%arg7[0] * %arg5) * (%arg8[1] * %arg5) + %arg0 { k = 64 : i32, m = 32 : i32, n = 32 : i32 } : vector<4xi8>, vector<32xf6E3M2FN>, i8, vector<32xf6E3M2FN>, vector<16xf32>
+  // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4{{.*}}, %[[c0]], %[[c2]], %[[c1]], %[[c3]] : (vector<6xi32>, vector<6xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
+  amdgpu.scaled_mfma(%arg7[0] * %arg5) * (%arg8[1] * %arg5) + %arg1 { k = 128 : i32, m = 16 : i32, n = 16 : i32 } : vector<4xi8>, vector<32xf6E3M2FN>, i8, vector<32xf6E3M2FN>, vector<4xf32>
   
   // CHECK: llvm.bitcast
-  // CHECK: %[[c4:.+]] = llvm.mlir.constant(4 : i32) : i32
+  // CHECK: llvm.mlir.constant(4 : i32) : i32
   
-  // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4{{.*}}, %[[c0]], %[[c1]] : (vector<4xi32>, vector<4xi32>, vector<16xf32>, i32, i32, i32, i8, i32, i32) -> vector<16xf32>
-  amdgpu.scaled_mfma ( %arg7 [ 0 ] * %arg6 ) * ( %arg8 [ 1 ] * %arg6 ) + %arg0 { k = 64 : i32, m = 32 : i32, n = 32 : i32 } : vector<32xf4E2M1FN>, vector<4xi8>, vector<32xf4E2M1FN>, i8, vector<16xf32>
-  // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4{{.*}}, %[[c0]], %[[c1]] : (vector<4xi32>, vector<4xi32>, vector<4xf32>, i32, i32, i32, i8, i32, i32) -> vector<4xf32>
-  amdgpu.scaled_mfma ( %arg7 [ 0 ] * %arg6 ) * ( %arg8 [ 1 ] * %arg6 ) + %arg1 { k = 128 : i32, m = 16 : i32, n = 16 : i32 } : vector<32xf4E2M1FN>, vector<4xi8>, vector<32xf4E2M1FN>, i8, vector<4xf32>
+  // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4{{.*}}, %[[c0]], %[[c2]], %[[c1]], %[[c3]] : (vector<4xi32>, vector<4xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
+  amdgpu.scaled_mfma(%arg7[0] * %arg6) * (%arg8[1] * %arg6) + %arg0 { k = 64 : i32, m = 32 : i32, n = 32 : i32 } : vector<4xi8>, vector<32xf4E2M1FN>, i8, vector<32xf4E2M1FN>, vector<16xf32>
+  // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4{{.*}}, %[[c0]], %[[c2]], %[[c1]], %[[c3]] : (vector<4xi32>, vector<4xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
+  amdgpu.scaled_mfma(%arg7[0] * %arg6) * (%arg8[1] * %arg6) + %arg1 { k = 128 : i32, m = 16 : i32, n = 16 : i32 } : vector<4xi8>, vector<32xf4E2M1FN>, i8, vector<32xf4E2M1FN>, vector<4xf32>
 
   func.return
 }
diff --git a/mlir/test/Dialect/AMDGPU/ops.mlir b/mlir/test/Dialect/AMDGPU/ops.mlir
index 16b3193d270cb..8dbee80eff54f 100644
--- a/mlir/test/Dialect/AMDGPU/ops.mlir
+++ b/mlir/test/Dialect/AMDGPU/ops.mlir
@@ -164,3 +164,10 @@ func.func @swizzle_bitmode(%arg0 : f32) -> f32 {
   %0 = amdgpu.swizzle_bitmode %arg0 1 2 4 : f32
   func.return %0 : f32
 }
+
+// CHECK-LABEL: func @scaled_mfma
+func.func @scaled_mfma(%arg0 : i8, %arg1 : vector<32xf6E2M3FN>, %arg2 : vector<16xf32>) -> vector<16xf32> {
+  // CHECK: amdgpu.scaled_mfma
+  %0 = amdgpu.scaled_mfma(%arg0[0] * %arg1) * (%arg0[1] * %arg1) + %arg2 { k = 64 : i32, m = 32 : i32, n = 32 : i32 } : i8, vector<32xf6E2M3FN>, i8, vector<32xf6E2M3FN>, vector<16xf32>
+  func.return %0 : vector<16xf32>
+}

>From 8a9face20dd8d279cbfa22b07e1b46f2159feb5d Mon Sep 17 00:00:00 2001
From: Muzammiluddin Syed <muzasyed at amd.com>
Date: Thu, 1 May 2025 11:03:07 -0500
Subject: [PATCH 5/6] PR review round 4

Signed-off-by: Muzammiluddin Syed <muzasyed at amd.com>
---
 .../AMDGPUToROCDL/AMDGPUToROCDL.cpp           |  8 +++---
 .../Conversion/AMDGPUToROCDL/mfma-gfx950.mlir | 26 ++++++++++---------
 2 files changed, 18 insertions(+), 16 deletions(-)

diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 01525154cc8c5..6e596485cbb58 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -539,8 +539,8 @@ static Value convertMFMAVectorOperand(ConversionPatternRewriter &rewriter,
 /// Note that the type of `input` has already been LLVM type converted:
 /// therefore 8-bit and smaller floats are represented as their corresponding
 /// `iN` integers.
-static Value castScaledMFMAVectorOperand(ConversionPatternRewriter &rewriter,
-                                         Location loc, Value input) {
+static Value castMFMAScaleOperand(ConversionPatternRewriter &rewriter,
+                                  Location loc, Value input) {
   Type inputType = input.getType();
   Type outputType = rewriter.getI32Type();
   if (auto intType = dyn_cast<IntegerType>(inputType))
@@ -1018,10 +1018,10 @@ struct ScaledMFMAOpLowering : public ConvertOpToLLVMPattern<ScaledMFMAOp> {
          createI32Constant(rewriter, loc, bTypeCode),
          /*scales idx A=*/scalesIdxA,
          /*scales A*/
-         castScaledMFMAVectorOperand(rewriter, loc, adaptor.getScalesA()),
+         castMFMAScaleOperand(rewriter, loc, adaptor.getScalesA()),
          /*scales idx B=*/scalesIdxB,
          /*scales B*/
-         castScaledMFMAVectorOperand(rewriter, loc, adaptor.getScalesB())});
+         castMFMAScaleOperand(rewriter, loc, adaptor.getScalesB())});
     Value lowered = rewriter.create(loweredOp)->getResult(0);
     rewriter.replaceOp(op, lowered);
     return success();
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/mfma-gfx950.mlir b/mlir/test/Conversion/AMDGPUToROCDL/mfma-gfx950.mlir
index f9fdd0ef4cbfc..b23aca76919b8 100644
--- a/mlir/test/Conversion/AMDGPUToROCDL/mfma-gfx950.mlir
+++ b/mlir/test/Conversion/AMDGPUToROCDL/mfma-gfx950.mlir
@@ -52,6 +52,8 @@ func.func @mfma_to_rocdl(%arg0 : vector<8xf16>, %arg1 : vector<16xf32>,
   func.return
 }
 
+// CHECK-LABEL: func @scaled_mfma_to_rocdl(
+// CHECK-SAME: %[[ARG0:.*]]: vector<16xf32>, %[[ARG1:.*]]: vector<4xf32>, %[[ARG2:.*]]: vector<32xf8E4M3FN>, %[[ARG3:.*]]: vector<32xf8E5M2>, %[[ARG4:.*]]: vector<32xf6E2M3FN>, %[[ARG5:.*]]: vector<32xf6E3M2FN>, %[[ARG6:.*]]: vector<32xf4E2M1FN>, %[[ARG7:.*]]: vector<4xi8>, %[[ARG8:.*]]: i8
 func.func @scaled_mfma_to_rocdl(%arg0 : vector<16xf32>,
                     %arg1 : vector<4xf32>, %arg2 : vector<32xf8E4M3FN>,
                     %arg3 : vector<32xf8E5M2>, %arg4 : vector<32xf6E2M3FN>,
@@ -60,42 +62,42 @@ func.func @scaled_mfma_to_rocdl(%arg0 : vector<16xf32>,
   
   // CHECK: %[[c0:.+]] = llvm.mlir.constant(0 : i32) : i32
   // CHECK: %[[c1:.+]] = llvm.mlir.constant(1 : i32) : i32
-  // CHECK: %[[c2:.+]] = llvm.bitcast{{.*}} : vector<4xi8> to i32
-  // CHECK: %[[c3:.+]] = llvm.zext{{.*}} : i8 to i32
+  // CHECK: %[[b0:.+]] = llvm.bitcast %[[ARG7]] : vector<4xi8> to i32
+  // CHECK: %[[z0:.+]] = llvm.zext %[[ARG8]] : i8 to i32
 
-  // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4{{.*}}, %[[c0]], %[[c2]], %[[c1]], %[[c3]] : (vector<8xi32>, vector<8xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
+  // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4{{.*}}, %[[c0]], %[[b0]], %[[c1]], %[[z0]] : (vector<8xi32>, vector<8xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
   amdgpu.scaled_mfma(%arg7[0] * %arg2) * (%arg8[1] * %arg2) + %arg0 { k = 64 : i32, m = 32 : i32, n = 32 : i32 } : vector<4xi8>, vector<32xf8E4M3FN>, i8, vector<32xf8E4M3FN>, vector<16xf32>
-  // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4{{.*}}, %[[c0]], %[[c2]], %[[c1]], %[[c3]] : (vector<8xi32>, vector<8xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
+  // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4{{.*}}, %[[c0]], %[[b0]], %[[c1]], %[[z0]] : (vector<8xi32>, vector<8xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
   amdgpu.scaled_mfma(%arg7[0] * %arg2) * (%arg8[1] * %arg2) + %arg1 { k = 128 : i32, m = 16 : i32, n = 16 : i32 } : vector<4xi8>, vector<32xf8E4M3FN>, i8, vector<32xf8E4M3FN>, vector<4xf32>
   
   // CHECK: llvm.bitcast
   
-  // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4{{.*}}, %[[c0]], %[[c2]], %[[c1]], %[[c3]] : (vector<8xi32>, vector<8xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
+  // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4{{.*}}, %[[c0]], %[[b0]], %[[c1]], %[[z0]] : (vector<8xi32>, vector<8xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
   amdgpu.scaled_mfma(%arg7[0] * %arg3) * (%arg8[1] * %arg3) + %arg0 { k = 64 : i32, m = 32 : i32, n = 32 : i32 } : vector<4xi8>, vector<32xf8E5M2>, i8, vector<32xf8E5M2>, vector<16xf32>
-  // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4{{.*}}, %[[c0]], %[[c2]], %[[c1]], %[[c3]] : (vector<8xi32>, vector<8xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
+  // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4{{.*}}, %[[c0]], %[[b0]], %[[c1]], %[[z0]] : (vector<8xi32>, vector<8xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
   amdgpu.scaled_mfma(%arg7[0] * %arg3) * (%arg8[1] * %arg3) + %arg1 { k = 128 : i32, m = 16 : i32, n = 16 : i32 } : vector<4xi8>, vector<32xf8E5M2>, i8, vector<32xf8E5M2>, vector<4xf32>
   
   // CHECK: llvm.bitcast
   
-  // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4{{.*}}, %[[c0]], %[[c2]], %[[c1]], %[[c3]] : (vector<6xi32>, vector<6xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
+  // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4{{.*}}, %[[c0]], %[[b0]], %[[c1]], %[[z0]] : (vector<6xi32>, vector<6xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
   amdgpu.scaled_mfma(%arg7[0] * %arg4) * (%arg8[1] * %arg4) + %arg0 { k = 64 : i32, m = 32 : i32, n = 32 : i32 } : vector<4xi8>, vector<32xf6E2M3FN>, i8, vector<32xf6E2M3FN>, vector<16xf32>
-  // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4{{.*}}, %[[c0]], %[[c2]], %[[c1]], %[[c3]] : (vector<6xi32>, vector<6xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
+  // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4{{.*}}, %[[c0]], %[[b0]], %[[c1]], %[[z0]] : (vector<6xi32>, vector<6xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
   amdgpu.scaled_mfma(%arg7[0] * %arg4) * (%arg8[1] * %arg4) + %arg1 { k = 128 : i32, m = 16 : i32, n = 16 : i32 } : vector<4xi8>, vector<32xf6E2M3FN>, i8, vector<32xf6E2M3FN>, vector<4xf32>
   
   // CHECK: llvm.bitcast
   // CHECK: llvm.mlir.constant(3 : i32) : i32
 
-  // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4{{.*}}, %[[c0]], %[[c2]], %[[c1]], %[[c3]] : (vector<6xi32>, vector<6xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
+  // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4{{.*}}, %[[c0]], %[[b0]], %[[c1]], %[[z0]] : (vector<6xi32>, vector<6xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
   amdgpu.scaled_mfma(%arg7[0] * %arg5) * (%arg8[1] * %arg5) + %arg0 { k = 64 : i32, m = 32 : i32, n = 32 : i32 } : vector<4xi8>, vector<32xf6E3M2FN>, i8, vector<32xf6E3M2FN>, vector<16xf32>
-  // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4{{.*}}, %[[c0]], %[[c2]], %[[c1]], %[[c3]] : (vector<6xi32>, vector<6xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
+  // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4{{.*}}, %[[c0]], %[[b0]], %[[c1]], %[[z0]] : (vector<6xi32>, vector<6xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
   amdgpu.scaled_mfma(%arg7[0] * %arg5) * (%arg8[1] * %arg5) + %arg1 { k = 128 : i32, m = 16 : i32, n = 16 : i32 } : vector<4xi8>, vector<32xf6E3M2FN>, i8, vector<32xf6E3M2FN>, vector<4xf32>
   
   // CHECK: llvm.bitcast
   // CHECK: llvm.mlir.constant(4 : i32) : i32
   
-  // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4{{.*}}, %[[c0]], %[[c2]], %[[c1]], %[[c3]] : (vector<4xi32>, vector<4xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
+  // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4{{.*}}, %[[c0]], %[[b0]], %[[c1]], %[[z0]] : (vector<4xi32>, vector<4xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
   amdgpu.scaled_mfma(%arg7[0] * %arg6) * (%arg8[1] * %arg6) + %arg0 { k = 64 : i32, m = 32 : i32, n = 32 : i32 } : vector<4xi8>, vector<32xf4E2M1FN>, i8, vector<32xf4E2M1FN>, vector<16xf32>
-  // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4{{.*}}, %[[c0]], %[[c2]], %[[c1]], %[[c3]] : (vector<4xi32>, vector<4xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
+  // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4{{.*}}, %[[c0]], %[[b0]], %[[c1]], %[[z0]] : (vector<4xi32>, vector<4xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
   amdgpu.scaled_mfma(%arg7[0] * %arg6) * (%arg8[1] * %arg6) + %arg1 { k = 128 : i32, m = 16 : i32, n = 16 : i32 } : vector<4xi8>, vector<32xf4E2M1FN>, i8, vector<32xf4E2M1FN>, vector<4xf32>
 
   func.return

>From b38e93faf2cf2a327fd384018b7b4f1fb21e61c1 Mon Sep 17 00:00:00 2001
From: Muzammiluddin Syed <muzasyed at amd.com>
Date: Fri, 2 May 2025 11:54:51 -0500
Subject: [PATCH 6/6] Changing scales type to float

Signed-off-by: Muzammiluddin Syed <muzasyed at amd.com>
---
 mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td |  4 +--
 .../Conversion/AMDGPUToROCDL/mfma-gfx950.mlir | 28 +++++++++----------
 mlir/test/Dialect/AMDGPU/ops.mlir             |  4 +--
 3 files changed, 18 insertions(+), 18 deletions(-)

diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
index 744aa7a28a4b9..0cebecee0390f 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
@@ -844,8 +844,8 @@ def AMDGPU_ScaledMFMAOp :
                    ScaledMFMAInTypes:$sourceA,
                    ScaledMFMAInTypes:$sourceB,
                    ScaledMFMAOutTypes:$destC,
-                   AnyTypeOf<[I8, FixedVectorOfLengthAndType<[4], [I8]>]>:$scalesA,
-                   AnyTypeOf<[I8, FixedVectorOfLengthAndType<[4], [I8]>]>:$scalesB,
+                   AnyTypeOf<[F8E8M0FNU, FixedVectorOfLengthAndType<[4], [F8E8M0FNU]>]>:$scalesA,
+                   AnyTypeOf<[F8E8M0FNU, FixedVectorOfLengthAndType<[4], [F8E8M0FNU]>]>:$scalesB,
                    ConfinedAttr<I32Attr, [IntNonNegative, IntMaxValue<3>]>:$scalesIdxA,
                    ConfinedAttr<I32Attr, [IntNonNegative, IntMaxValue<3>]>:$scalesIdxB
                    )>,
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/mfma-gfx950.mlir b/mlir/test/Conversion/AMDGPUToROCDL/mfma-gfx950.mlir
index b23aca76919b8..52a5d39f668c6 100644
--- a/mlir/test/Conversion/AMDGPUToROCDL/mfma-gfx950.mlir
+++ b/mlir/test/Conversion/AMDGPUToROCDL/mfma-gfx950.mlir
@@ -53,52 +53,52 @@ func.func @mfma_to_rocdl(%arg0 : vector<8xf16>, %arg1 : vector<16xf32>,
 }
 
 // CHECK-LABEL: func @scaled_mfma_to_rocdl(
-// CHECK-SAME: %[[ARG0:.*]]: vector<16xf32>, %[[ARG1:.*]]: vector<4xf32>, %[[ARG2:.*]]: vector<32xf8E4M3FN>, %[[ARG3:.*]]: vector<32xf8E5M2>, %[[ARG4:.*]]: vector<32xf6E2M3FN>, %[[ARG5:.*]]: vector<32xf6E3M2FN>, %[[ARG6:.*]]: vector<32xf4E2M1FN>, %[[ARG7:.*]]: vector<4xi8>, %[[ARG8:.*]]: i8
+// CHECK-SAME: %[[ARG0:.*]]: vector<16xf32>, %[[ARG1:.*]]: vector<4xf32>, %[[ARG2:.*]]: vector<32xf8E4M3FN>, %[[ARG3:.*]]: vector<32xf8E5M2>, %[[ARG4:.*]]: vector<32xf6E2M3FN>, %[[ARG5:.*]]: vector<32xf6E3M2FN>, %[[ARG6:.*]]: vector<32xf4E2M1FN>, %[[ARG7:.*]]: vector<4xf8E8M0FNU>, %[[ARG8:.*]]: f8E8M0FNU
 func.func @scaled_mfma_to_rocdl(%arg0 : vector<16xf32>,
                     %arg1 : vector<4xf32>, %arg2 : vector<32xf8E4M3FN>,
                     %arg3 : vector<32xf8E5M2>, %arg4 : vector<32xf6E2M3FN>,
                     %arg5 : vector<32xf6E3M2FN>, %arg6 : vector<32xf4E2M1FN>, 
-                    %arg7 : vector<4xi8>, %arg8 : i8) {
+                    %arg7 : vector<4xf8E8M0FNU>, %arg8 : f8E8M0FNU) {
   
   // CHECK: %[[c0:.+]] = llvm.mlir.constant(0 : i32) : i32
   // CHECK: %[[c1:.+]] = llvm.mlir.constant(1 : i32) : i32
-  // CHECK: %[[b0:.+]] = llvm.bitcast %[[ARG7]] : vector<4xi8> to i32
-  // CHECK: %[[z0:.+]] = llvm.zext %[[ARG8]] : i8 to i32
+  // CHECK: %[[b0:.+]] = llvm.bitcast {{.*}} : vector<4xi8> to i32
+  // CHECK: %[[z0:.+]] = llvm.zext {{.*}} : i8 to i32
 
   // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4{{.*}}, %[[c0]], %[[b0]], %[[c1]], %[[z0]] : (vector<8xi32>, vector<8xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
-  amdgpu.scaled_mfma(%arg7[0] * %arg2) * (%arg8[1] * %arg2) + %arg0 { k = 64 : i32, m = 32 : i32, n = 32 : i32 } : vector<4xi8>, vector<32xf8E4M3FN>, i8, vector<32xf8E4M3FN>, vector<16xf32>
+  amdgpu.scaled_mfma(%arg7[0] * %arg2) * (%arg8[1] * %arg2) + %arg0 { k = 64 : i32, m = 32 : i32, n = 32 : i32 } : vector<4xf8E8M0FNU>, vector<32xf8E4M3FN>, f8E8M0FNU, vector<32xf8E4M3FN>, vector<16xf32>
   // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4{{.*}}, %[[c0]], %[[b0]], %[[c1]], %[[z0]] : (vector<8xi32>, vector<8xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
-  amdgpu.scaled_mfma(%arg7[0] * %arg2) * (%arg8[1] * %arg2) + %arg1 { k = 128 : i32, m = 16 : i32, n = 16 : i32 } : vector<4xi8>, vector<32xf8E4M3FN>, i8, vector<32xf8E4M3FN>, vector<4xf32>
+  amdgpu.scaled_mfma(%arg7[0] * %arg2) * (%arg8[1] * %arg2) + %arg1 { k = 128 : i32, m = 16 : i32, n = 16 : i32 } : vector<4xf8E8M0FNU>, vector<32xf8E4M3FN>, f8E8M0FNU, vector<32xf8E4M3FN>, vector<4xf32>
   
   // CHECK: llvm.bitcast
   
   // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4{{.*}}, %[[c0]], %[[b0]], %[[c1]], %[[z0]] : (vector<8xi32>, vector<8xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
-  amdgpu.scaled_mfma(%arg7[0] * %arg3) * (%arg8[1] * %arg3) + %arg0 { k = 64 : i32, m = 32 : i32, n = 32 : i32 } : vector<4xi8>, vector<32xf8E5M2>, i8, vector<32xf8E5M2>, vector<16xf32>
+  amdgpu.scaled_mfma(%arg7[0] * %arg3) * (%arg8[1] * %arg3) + %arg0 { k = 64 : i32, m = 32 : i32, n = 32 : i32 } : vector<4xf8E8M0FNU>, vector<32xf8E5M2>, f8E8M0FNU, vector<32xf8E5M2>, vector<16xf32>
   // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4{{.*}}, %[[c0]], %[[b0]], %[[c1]], %[[z0]] : (vector<8xi32>, vector<8xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
-  amdgpu.scaled_mfma(%arg7[0] * %arg3) * (%arg8[1] * %arg3) + %arg1 { k = 128 : i32, m = 16 : i32, n = 16 : i32 } : vector<4xi8>, vector<32xf8E5M2>, i8, vector<32xf8E5M2>, vector<4xf32>
+  amdgpu.scaled_mfma(%arg7[0] * %arg3) * (%arg8[1] * %arg3) + %arg1 { k = 128 : i32, m = 16 : i32, n = 16 : i32 } : vector<4xf8E8M0FNU>, vector<32xf8E5M2>, f8E8M0FNU, vector<32xf8E5M2>, vector<4xf32>
   
   // CHECK: llvm.bitcast
   
   // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4{{.*}}, %[[c0]], %[[b0]], %[[c1]], %[[z0]] : (vector<6xi32>, vector<6xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
-  amdgpu.scaled_mfma(%arg7[0] * %arg4) * (%arg8[1] * %arg4) + %arg0 { k = 64 : i32, m = 32 : i32, n = 32 : i32 } : vector<4xi8>, vector<32xf6E2M3FN>, i8, vector<32xf6E2M3FN>, vector<16xf32>
+  amdgpu.scaled_mfma(%arg7[0] * %arg4) * (%arg8[1] * %arg4) + %arg0 { k = 64 : i32, m = 32 : i32, n = 32 : i32 } : vector<4xf8E8M0FNU>, vector<32xf6E2M3FN>, f8E8M0FNU, vector<32xf6E2M3FN>, vector<16xf32>
   // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4{{.*}}, %[[c0]], %[[b0]], %[[c1]], %[[z0]] : (vector<6xi32>, vector<6xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
-  amdgpu.scaled_mfma(%arg7[0] * %arg4) * (%arg8[1] * %arg4) + %arg1 { k = 128 : i32, m = 16 : i32, n = 16 : i32 } : vector<4xi8>, vector<32xf6E2M3FN>, i8, vector<32xf6E2M3FN>, vector<4xf32>
+  amdgpu.scaled_mfma(%arg7[0] * %arg4) * (%arg8[1] * %arg4) + %arg1 { k = 128 : i32, m = 16 : i32, n = 16 : i32 } : vector<4xf8E8M0FNU>, vector<32xf6E2M3FN>, f8E8M0FNU, vector<32xf6E2M3FN>, vector<4xf32>
   
   // CHECK: llvm.bitcast
   // CHECK: llvm.mlir.constant(3 : i32) : i32
 
   // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4{{.*}}, %[[c0]], %[[b0]], %[[c1]], %[[z0]] : (vector<6xi32>, vector<6xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
-  amdgpu.scaled_mfma(%arg7[0] * %arg5) * (%arg8[1] * %arg5) + %arg0 { k = 64 : i32, m = 32 : i32, n = 32 : i32 } : vector<4xi8>, vector<32xf6E3M2FN>, i8, vector<32xf6E3M2FN>, vector<16xf32>
+  amdgpu.scaled_mfma(%arg7[0] * %arg5) * (%arg8[1] * %arg5) + %arg0 { k = 64 : i32, m = 32 : i32, n = 32 : i32 } : vector<4xf8E8M0FNU>, vector<32xf6E3M2FN>, f8E8M0FNU, vector<32xf6E3M2FN>, vector<16xf32>
   // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4{{.*}}, %[[c0]], %[[b0]], %[[c1]], %[[z0]] : (vector<6xi32>, vector<6xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
-  amdgpu.scaled_mfma(%arg7[0] * %arg5) * (%arg8[1] * %arg5) + %arg1 { k = 128 : i32, m = 16 : i32, n = 16 : i32 } : vector<4xi8>, vector<32xf6E3M2FN>, i8, vector<32xf6E3M2FN>, vector<4xf32>
+  amdgpu.scaled_mfma(%arg7[0] * %arg5) * (%arg8[1] * %arg5) + %arg1 { k = 128 : i32, m = 16 : i32, n = 16 : i32 } : vector<4xf8E8M0FNU>, vector<32xf6E3M2FN>, f8E8M0FNU, vector<32xf6E3M2FN>, vector<4xf32>
   
   // CHECK: llvm.bitcast
   // CHECK: llvm.mlir.constant(4 : i32) : i32
   
   // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4{{.*}}, %[[c0]], %[[b0]], %[[c1]], %[[z0]] : (vector<4xi32>, vector<4xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
-  amdgpu.scaled_mfma(%arg7[0] * %arg6) * (%arg8[1] * %arg6) + %arg0 { k = 64 : i32, m = 32 : i32, n = 32 : i32 } : vector<4xi8>, vector<32xf4E2M1FN>, i8, vector<32xf4E2M1FN>, vector<16xf32>
+  amdgpu.scaled_mfma(%arg7[0] * %arg6) * (%arg8[1] * %arg6) + %arg0 { k = 64 : i32, m = 32 : i32, n = 32 : i32 } : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<16xf32>
   // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4{{.*}}, %[[c0]], %[[b0]], %[[c1]], %[[z0]] : (vector<4xi32>, vector<4xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
-  amdgpu.scaled_mfma(%arg7[0] * %arg6) * (%arg8[1] * %arg6) + %arg1 { k = 128 : i32, m = 16 : i32, n = 16 : i32 } : vector<4xi8>, vector<32xf4E2M1FN>, i8, vector<32xf4E2M1FN>, vector<4xf32>
+  amdgpu.scaled_mfma(%arg7[0] * %arg6) * (%arg8[1] * %arg6) + %arg1 { k = 128 : i32, m = 16 : i32, n = 16 : i32 } : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32>
 
   func.return
 }
diff --git a/mlir/test/Dialect/AMDGPU/ops.mlir b/mlir/test/Dialect/AMDGPU/ops.mlir
index 8dbee80eff54f..188cfcc4eb38b 100644
--- a/mlir/test/Dialect/AMDGPU/ops.mlir
+++ b/mlir/test/Dialect/AMDGPU/ops.mlir
@@ -166,8 +166,8 @@ func.func @swizzle_bitmode(%arg0 : f32) -> f32 {
 }
 
 // CHECK-LABEL: func @scaled_mfma
-func.func @scaled_mfma(%arg0 : i8, %arg1 : vector<32xf6E2M3FN>, %arg2 : vector<16xf32>) -> vector<16xf32> {
+func.func @scaled_mfma(%arg0 : f8E8M0FNU, %arg1 : vector<32xf6E2M3FN>, %arg2 : vector<16xf32>) -> vector<16xf32> {
   // CHECK: amdgpu.scaled_mfma
-  %0 = amdgpu.scaled_mfma(%arg0[0] * %arg1) * (%arg0[1] * %arg1) + %arg2 { k = 64 : i32, m = 32 : i32, n = 32 : i32 } : i8, vector<32xf6E2M3FN>, i8, vector<32xf6E2M3FN>, vector<16xf32>
+  %0 = amdgpu.scaled_mfma(%arg0[0] * %arg1) * (%arg0[1] * %arg1) + %arg2 { k = 64 : i32, m = 32 : i32, n = 32 : i32 } : f8E8M0FNU, vector<32xf6E2M3FN>, f8E8M0FNU, vector<32xf6E2M3FN>, vector<16xf32>
   func.return %0 : vector<16xf32>
 }



More information about the Mlir-commits mailing list