[Mlir-commits] [mlir] [mlir][AMDGPU] Add scaled wmma ops for gfx1250 (PR #169854)

Justin Rosner llvmlistbot at llvm.org
Fri Nov 28 09:20:57 PST 2025


https://github.com/justinrosner updated https://github.com/llvm/llvm-project/pull/169854

>From 32b0c2c160607874949a87ecf38fa68bd118bd86 Mon Sep 17 00:00:00 2001
From: Justin Rosner <justin.rosner at amd.com>
Date: Thu, 27 Nov 2025 19:25:56 +0000
Subject: [PATCH 1/4] Add scaled WMMA to AMDGPU

---
 mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td |  60 ++++++++
 .../AMDGPUToROCDL/AMDGPUToROCDL.cpp           | 128 ++++++++++++++++--
 mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp  |  71 ++++++++++
 3 files changed, 244 insertions(+), 15 deletions(-)

diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
index e07c72b839e7c..a2201d3127370 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
@@ -951,6 +951,13 @@ def MFMAOutTypes : AnyTypeOf<[F64,
 def ScaledMFMAInTypes : AnyTypeOf<[VectorOfLengthAndType<[32], [F8E5M2, F8E4M3FN]>,
                                    VectorOfLengthAndType<[32], [F6E2M3FN, F6E3M2FN, F4E2M1FN]>]>;
 def ScaledMFMAOutTypes : AnyTypeOf<[VectorOfLengthAndType<[4, 16], [F32]>]>;
+
+// scaled_wmma
+def ScaledWMMAInTypes : AnyTypeOf<[VectorOfLengthAndType<[64], [F8E5M2, F8E4M3FN]>,
+                                   VectorOfLengthAndType<[64], [F6E2M3FN, F6E3M2FN]>,
+                                   VectorOfLengthAndType<[64, 128], [F4E2M1FN]>]>;
+def ScaledWMMAOutTypes : AnyTypeOf<[VectorOfLengthAndType<[4, 8], [F32]>]>;
+
 // wmma
 def WMMAInTypes : AnyTypeOf<[VectorOfLengthAndType<[2], [F32]>,
                              VectorOfLengthAndType<[4, 8, 16], [F16, BF16]>,
@@ -1218,6 +1225,59 @@ def AMDGPU_ScaledMFMAOp :
   let hasCanonicalizer = 1;
 }
 
+def AMDGPU_ScaledWMMAOp :
+    AMDGPU_Op<"scaled_wmma", [AllTypesMatch<["destC", "destD"]>,
+                              Pure]>,
+    Arguments<(ins
+                   ConfinedAttr<I32Attr, [IntIsOneOf<[16, 32]>]>:$m,
+                   ConfinedAttr<I32Attr, [IntIsOneOf<[16]>]>:$n,
+                   ConfinedAttr<I32Attr, [IntIsOneOf<[128]>]>:$k,
+                   ScaledWMMAInTypes:$sourceA,
+                   ScaledWMMAInTypes:$sourceB,
+                   ScaledWMMAOutTypes:$destC,
+                   AnyTypeOf<[I32, I64]>:$scaleA,
+                   AnyTypeOf<[I32, I64]>:$scaleB,
+                   DefaultValuedAttr<I32Attr, "0">:$scaleAType,
+                   DefaultValuedAttr<I32Attr, "0">:$fmtScaleA,
+                   DefaultValuedAttr<I32Attr, "0">:$scaleBType,
+                   DefaultValuedAttr<I32Attr, "0">:$fmtScaleB
+                   )>,
+    Results<(outs ScaledWMMAOutTypes: $destD)> {
+  let summary = "MLIR wrapper for RDNA scaled wmma instructions";
+  let description = [{
+    The `amdgpu.scaled_wmma` op is an MLIR wrapper around intrinsics for scaled
+    `wmma` instructions in the RDNA architecture. These instructions perform
+    matrix multiplication with per-block scaling of inputs, supporting fp4, fp6,
+    and fp8 data formats.
+
+    The scale instructions support two tile sizes:
+    - 16x16x128 with mixed f8/f6/f4 formats (output: vector<4xf32>)
+    - 32x16x128 with f4 format only (output: vector<8xf32>)
+
+    The `scaleA` and `scaleB` parameters are scale exponents that can be either
+    i32 (for wmma.scale) or i64 (for wmma.scale16) to support per-block scaling.
+
+    Optional modifiers:
+    - `scaleAType`, `scaleBType`: Type of scale parameter
+    - `fmtScaleA`, `fmtScaleB`: Format of scale parameter
+
+    Example:
+    ```mlir
+      %0 = amdgpu.scaled_wmma (%sa * %matA) * (%sb * %matB) + %matC
+        { m = 16 : i32, n = 16 : i32, k = 128 : i32 } : i32, vector<64xf8E4M3FN>, i32, vector<64xf8E4M3FN>, vector<4xf32>
+
+      %1 = amdgpu.scaled_wmma (%sc * %matD) * (%sd * %matE) + %matF
+        { m = 32 : i32, n = 16 : i32, k = 128 : i32 } : i32, vector<128xf4E2M1FN>, i32, vector<64xf4E2M1FN>, vector<8xf32>
+    ```
+  }];
+  let assemblyFormat = [{
+    `(` $scaleA `*` $sourceA `)` `*` `(` $scaleB `*` $sourceB `)` `+` $destC
+    attr-dict
+    `:` type($scaleA) `,` type($sourceA) `,` type($scaleB) `,` type($sourceB) `,` type($destC)
+  }];
+  let hasVerifier = 1;
+}
+
 def AMDGPU_MakeDmaBaseOp :
     AMDGPU_Op<"make_dma_base", [Pure, AttrSizedOperandSegments]>,
     Arguments<(ins
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index b9a5e7d7f6eac..02f0e14791bf2 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -612,8 +612,8 @@ struct SchedBarrierOpLowering : public ConvertOpToLLVMPattern<SchedBarrierOp> {
 
 } // namespace
 
-/// Converts a MFMA vector operand from MLIR AMDGPU dialect convention to ROCDL
-/// and LLVM AMDGPU intrinsics convention.
+/// Pack small float vector operands (fp4/fp6/fp8/bf16) into the format
+/// expected by matrix multiply intrinsics (MFMA, scaled MFMA, scaled WMMA).
 ///
 /// Specifically:
 /// 1. If the element type is bfloat16, bitcast it to i16 unless rocdl intrinsic
@@ -627,9 +627,9 @@ struct SchedBarrierOpLowering : public ConvertOpToLLVMPattern<SchedBarrierOp> {
 /// Note that the type of `input` has already been LLVM type converted:
 /// therefore 8-bit and smaller floats are represented as their corresponding
 /// `iN` integers.
-static Value convertMFMAVectorOperand(ConversionPatternRewriter &rewriter,
-                                      Location loc, Value input,
-                                      bool allowBf16 = true) {
+static Value packSmallFloatVectorOperand(ConversionPatternRewriter &rewriter,
+                                         Location loc, Value input,
+                                         bool allowBf16 = true) {
   Type inputType = input.getType();
   if (auto vectorType = dyn_cast<VectorType>(inputType)) {
     if (vectorType.getElementType().isBF16() && !allowBf16)
@@ -918,7 +918,7 @@ static std::optional<StringRef> mfmaOpToIntrinsic(MFMAOp mfma,
   return std::nullopt;
 }
 
-static std::optional<uint32_t> mfmaTypeSelectCode(Type mlirElemType) {
+static std::optional<uint32_t> smallFloatTypeToFormatCode(Type mlirElemType) {
   return llvm::TypeSwitch<Type, std::optional<uint32_t>>(mlirElemType)
       .Case([](Float8E4M3FNType) { return 0u; })
       .Case([](Float8E5M2Type) { return 1u; })
@@ -947,8 +947,8 @@ mfmaOpToScaledIntrinsic(Type aType, Type bType, Type destType, uint32_t m,
   if (!isa<Float32Type>(destType))
     return std::nullopt;
 
-  std::optional<uint32_t> aTypeCode = mfmaTypeSelectCode(aType);
-  std::optional<uint32_t> bTypeCode = mfmaTypeSelectCode(bType);
+  std::optional<uint32_t> aTypeCode = smallFloatTypeToFormatCode(aType);
+  std::optional<uint32_t> bTypeCode = smallFloatTypeToFormatCode(bType);
   if (!aTypeCode || !bTypeCode)
     return std::nullopt;
 
@@ -1212,11 +1212,12 @@ struct MFMAOpLowering : public ConvertOpToLLVMPattern<MFMAOp> {
     }();
     OperationState loweredOp(loc, intrinsicName);
     loweredOp.addTypes(intrinsicOutType);
-    loweredOp.addOperands({convertMFMAVectorOperand(
-                               rewriter, loc, adaptor.getSourceA(), allowBf16),
-                           convertMFMAVectorOperand(
-                               rewriter, loc, adaptor.getSourceB(), allowBf16),
-                           adaptor.getDestC()});
+    loweredOp.addOperands(
+        {packSmallFloatVectorOperand(rewriter, loc, adaptor.getSourceA(),
+                                      allowBf16),
+         packSmallFloatVectorOperand(rewriter, loc, adaptor.getSourceB(),
+                                      allowBf16),
+         adaptor.getDestC()});
     if (isScaled) {
       Value zero = createI32Constant(rewriter, loc, 0);
       auto [_scaledName, aTypeCode, bTypeCode] = *maybeScaledIntrinsic;
@@ -1261,8 +1262,8 @@ struct ScaledMFMAOpLowering : public ConvertOpToLLVMPattern<ScaledMFMAOp> {
     OperationState loweredOp(loc, intrinsicName);
     loweredOp.addTypes(intrinsicOutType);
     loweredOp.addOperands(
-        {convertMFMAVectorOperand(rewriter, loc, adaptor.getSourceA()),
-         convertMFMAVectorOperand(rewriter, loc, adaptor.getSourceB()),
+        {packSmallFloatVectorOperand(rewriter, loc, adaptor.getSourceA()),
+         packSmallFloatVectorOperand(rewriter, loc, adaptor.getSourceB()),
          adaptor.getDestC()});
     Value scalesIdxA =
         createI32Constant(rewriter, loc, adaptor.getScalesIdxA());
@@ -1363,6 +1364,103 @@ struct WMMAOpLowering : public ConvertOpToLLVMPattern<WMMAOp> {
   }
 };
 
+struct ScaledWMMAOpLowering : public ConvertOpToLLVMPattern<ScaledWMMAOp> {
+  ScaledWMMAOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
+      : ConvertOpToLLVMPattern<ScaledWMMAOp>(converter), chipset(chipset) {}
+
+  Chipset chipset;
+
+  LogicalResult
+  matchAndRewrite(ScaledWMMAOp op, ScaledWMMAOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    auto outType =
+        typeConverter->convertType<VectorType>(op.getDestD().getType());
+    if (!outType)
+      return rewriter.notifyMatchFailure(op, "type conversion failed");
+
+    if (chipset < Chipset(12, 5, 0))
+      return op->emitOpError("WMMA scale only supported on gfx1250+");
+
+    int64_t m = op.getM();
+    int64_t n = op.getN();
+    int64_t k = op.getK();
+
+    Type aElemType = getElementTypeOrSelf(op.getSourceA().getType());
+    Type bElemType = getElementTypeOrSelf(op.getSourceB().getType());
+
+    std::optional<uint32_t> aFmtCode = smallFloatTypeToFormatCode(aElemType);
+    std::optional<uint32_t> bFmtCode = smallFloatTypeToFormatCode(bElemType);
+
+    if (!aFmtCode || !bFmtCode)
+      return op.emitOpError("unsupported element types for scaled_wmma");
+
+    // Determine which intrinsic to use based on dimensions and scale type
+    StringRef intrinsicName;
+    bool isScale16 = adaptor.getScaleA().getType().isInteger(64);
+    bool is32x16 = (m == 32 && n == 16 && k == 128);
+
+    if (m == 16 && n == 16 && k == 128) {
+      intrinsicName = isScale16
+                ? ROCDL::wmma_scale16_f32_16x16x128_f8f6f4::getOperationName()
+                : ROCDL::wmma_scale_f32_16x16x128_f8f6f4::getOperationName();
+    } else if (is32x16) {
+      intrinsicName = isScale16
+                ? ROCDL::wmma_scale16_f32_32x16x128_f4::getOperationName()
+                : ROCDL::wmma_scale_f32_32x16x128_f4::getOperationName();
+    } else {
+      return op.emitOpError("unsupported scaled_wmma dimensions: ")
+             << m << "x" << n << "x" << k;
+    }
+
+    SmallVector<NamedAttribute, 8> attrs;
+
+    // The f4 variant does not have fmtA and fmtB attributes
+    if (!is32x16) {
+      attrs.push_back(rewriter.getNamedAttr("fmtA",
+                              rewriter.getI32IntegerAttr(*aFmtCode)));
+      attrs.push_back(rewriter.getNamedAttr("fmtB",
+                              rewriter.getI32IntegerAttr(*bFmtCode)));
+    }
+
+    // Add modifier attributes - modC and reuse flags default to 0/false
+    attrs.push_back(rewriter.getNamedAttr("reuseA",
+                              rewriter.getBoolAttr(false)));
+    attrs.push_back(rewriter.getNamedAttr("reuseB",
+                              rewriter.getBoolAttr(false)));
+    attrs.push_back(rewriter.getNamedAttr("modC",
+                              rewriter.getI16IntegerAttr(0)));
+
+    // Scale type/format parameters from the operation
+    attrs.push_back(rewriter.getNamedAttr("scaleAType",
+                              rewriter.getI32IntegerAttr(op.getScaleAType())));
+    attrs.push_back(rewriter.getNamedAttr("fmtScaleA",
+                              rewriter.getI32IntegerAttr(op.getFmtScaleA())));
+    attrs.push_back(rewriter.getNamedAttr("scaleBType",
+                              rewriter.getI32IntegerAttr(op.getScaleBType())));
+    attrs.push_back(rewriter.getNamedAttr("fmtScaleB",
+                              rewriter.getI32IntegerAttr(op.getFmtScaleB())));
+
+    // Convert typed float vectors to packed i32 format if needed
+    Value sourceA =
+        packSmallFloatVectorOperand(rewriter, loc, adaptor.getSourceA());
+    Value sourceB =
+        packSmallFloatVectorOperand(rewriter, loc, adaptor.getSourceB());
+
+    // Create the intrinsic call
+    OperationState loweredOp(loc, intrinsicName);
+    loweredOp.addTypes(outType);
+    loweredOp.addOperands({sourceA, sourceB, adaptor.getDestC(),
+                           adaptor.getScaleA(), adaptor.getScaleB()});
+    loweredOp.addAttributes(attrs);
+
+    Operation *lowered = rewriter.create(loweredOp);
+    rewriter.replaceOp(op, lowered->getResults());
+
+    return success();
+  }
+};
+
 struct TransposeLoadOpLowering
     : public ConvertOpToLLVMPattern<TransposeLoadOp> {
   TransposeLoadOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
index cdc10c60a42ae..87bd1903290ae 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -442,6 +442,77 @@ LogicalResult WMMAOp::verify() {
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// ScaledWMMAOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult ScaledWMMAOp::verify() {
+  auto sourceAType = cast<VectorType>(getSourceA().getType());
+  auto sourceBType = cast<VectorType>(getSourceB().getType());
+  auto destType = cast<VectorType>(getDestC().getType());
+  
+  // Validate output type is F32
+  if (!destType.getElementType().isF32())
+    return emitOpError("destination must have f32 element type");
+
+  // Validate source element types are small floats (fp4/fp6/fp8)
+  Type aElemType = sourceAType.getElementType();
+  Type bElemType = sourceBType.getElementType();
+
+  bool aIsSmallFloat = aElemType.isFloat(4) || aElemType.isFloat(6) ||
+                       aElemType.isFloat(8);
+  bool bIsSmallFloat = bElemType.isFloat(4) || bElemType.isFloat(6) ||
+                       bElemType.isFloat(8);
+
+  if (!aIsSmallFloat || !bIsSmallFloat)
+    return emitOpError("source operands must have small float element types "
+                       "(fp4/fp6/fp8)");
+
+  // Validate scale types match (both i32 or both i64)
+  Type scaleAType = getScaleA().getType();
+  Type scaleBType = getScaleB().getType();
+  if (scaleAType != scaleBType)
+    return emitOpError("scaleA and scaleB must have the same type");
+
+  // Validate vector lengths based on dimensions
+  int64_t m = getM();
+  int64_t aLen = sourceAType.getNumElements();
+  int64_t bLen = sourceBType.getNumElements();
+  int64_t expectedOutLen = (m == 16) ? 4 : 8;
+  
+  if (destType.getNumElements() != expectedOutLen)
+    return emitOpError("expected output vector of length " +
+                       Twine(expectedOutLen) + " but got " +
+                       Twine(destType.getNumElements()));
+
+  if (m == 16) {
+    // For 16×16×128: both A and B must be 64 elements
+    if (aLen != 64)
+      return emitOpError(
+          "for 16x16x128, sourceA must have 64 elements but got " +
+          Twine(aLen));
+    if (bLen != 64)
+      return emitOpError(
+          "for 16x16x128, sourceB must have 64 elements but got " +
+          Twine(bLen));
+  } else { // m == 32
+    // For 32×16×128: only fp4 is supported (for A and B), A is 128, B is 64
+    if (!aElemType.isFloat(4) || !bElemType.isFloat(4))
+      return emitOpError("32x16x128 only supports fp4 element types");
+
+    if (aLen != 128)
+      return emitOpError(
+          "for 32x16x128, sourceA must have 128 elements but got " +
+          Twine(aLen));
+    if (bLen != 64)
+      return emitOpError(
+          "for 32x16x128, sourceB must have 64 elements but got " +
+          Twine(bLen));
+  }
+
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // MFMAOp
 //===----------------------------------------------------------------------===//

>From 8750acce924f204b8a3381166ec2fe2e92dc0d76 Mon Sep 17 00:00:00 2001
From: Justin Rosner <justin.rosner at amd.com>
Date: Thu, 27 Nov 2025 19:36:08 +0000
Subject: [PATCH 2/4] Add LIT tests

---
 .../AMDGPUToROCDL/AMDGPUToROCDL.cpp           |  3 +-
 .../AMDGPUToROCDL/wmma-gfx1250.mlir           | 73 +++++++++++++++++++
 mlir/test/Dialect/AMDGPU/ops.mlir             | 26 +++++++
 3 files changed, 101 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 02f0e14791bf2..f4034f44d06b8 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -2427,7 +2427,8 @@ void mlir::populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter,
                                ROCDL::RawPtrBufferAtomicCmpSwap>,
            AMDGPUDPPLowering, MemoryCounterWaitOpLowering, LDSBarrierOpLowering,
            SchedBarrierOpLowering, MFMAOpLowering, ScaledMFMAOpLowering,
-           WMMAOpLowering, ExtPackedFp8OpLowering, ScaledExtPacked816OpLowering,
+           WMMAOpLowering, ScaledWMMAOpLowering, ExtPackedFp8OpLowering,
+           ScaledExtPacked816OpLowering,
            ScaledExtPackedOpLowering, PackedScaledTruncOpLowering,
            PackedTrunc2xFp8OpLowering, PackedStochRoundFp8OpLowering,
            GatherToLDSOpLowering, TransposeLoadOpLowering,
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx1250.mlir b/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx1250.mlir
index 37259f6ed06eb..d187e62484059 100644
--- a/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx1250.mlir
+++ b/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx1250.mlir
@@ -89,6 +89,79 @@ func.func @wmma_k128(%arg0 : vector<64xf8E4M3FN>, %arg1 : vector<64xf8E5M2>,
   return
 }
 
+// CHECK-LABEL: @wmma_scale_16x16x128_fp8
+func.func @wmma_scale_16x16x128_fp8(%arg0 : vector<64xf8E4M3FN>, %arg1 : vector<64xf8E5M2>,
+                                    %arg2 : vector<4xf32>, %arg3 : i32, %arg4 : i32) {
+  // CHECK: rocdl.wmma.scale.f32.16x16x128.f8f6f4 {{.*}}, {{.*}}, %arg2, %arg3, %arg4 : (vector<16xi32>, vector<16xi32>, vector<4xf32>, i32, i32) -> vector<4xf32>
+  %0 = amdgpu.scaled_wmma (%arg3 * %arg0) * (%arg4 * %arg0) + %arg2
+    { m = 16 : i32, n = 16 : i32, k = 128 : i32 } : i32, vector<64xf8E4M3FN>, i32, vector<64xf8E4M3FN>, vector<4xf32>
+
+  // CHECK: rocdl.wmma.scale.f32.16x16x128.f8f6f4 {{.*}}, {{.*}}, %arg2, %arg3, %arg4 {fmtA = 1 : i32, fmtB = 1 : i32} : (vector<16xi32>, vector<16xi32>, vector<4xf32>, i32, i32) -> vector<4xf32>
+  %1 = amdgpu.scaled_wmma (%arg3 * %arg1) * (%arg4 * %arg1) + %arg2
+    { m = 16 : i32, n = 16 : i32, k = 128 : i32 } : i32, vector<64xf8E5M2>, i32, vector<64xf8E5M2>, vector<4xf32>
+
+  func.return
+}
+
+// CHECK-LABEL: @wmma_scale_16x16x128_fp6
+func.func @wmma_scale_16x16x128_fp6(%arg0 : vector<64xf6E2M3FN>, %arg1 : vector<64xf6E3M2FN>,
+                                    %arg2 : vector<4xf32>, %arg3 : i32, %arg4 : i32) {
+  // CHECK: rocdl.wmma.scale.f32.16x16x128.f8f6f4 {{.*}}, {{.*}}, %arg2, %arg3, %arg4 {fmtA = 2 : i32, fmtB = 2 : i32} : (vector<12xi32>, vector<12xi32>, vector<4xf32>, i32, i32) -> vector<4xf32>
+  %0 = amdgpu.scaled_wmma (%arg3 * %arg0) * (%arg4 * %arg0) + %arg2
+    { m = 16 : i32, n = 16 : i32, k = 128 : i32 } : i32, vector<64xf6E2M3FN>, i32, vector<64xf6E2M3FN>, vector<4xf32>
+
+  // CHECK: rocdl.wmma.scale.f32.16x16x128.f8f6f4 {{.*}}, {{.*}}, %arg2, %arg3, %arg4 {fmtA = 3 : i32, fmtB = 3 : i32} : (vector<12xi32>, vector<12xi32>, vector<4xf32>, i32, i32) -> vector<4xf32>
+  %1 = amdgpu.scaled_wmma (%arg3 * %arg1) * (%arg4 * %arg1) + %arg2
+    { m = 16 : i32, n = 16 : i32, k = 128 : i32 } : i32, vector<64xf6E3M2FN>, i32, vector<64xf6E3M2FN>, vector<4xf32>
+
+  func.return
+}
+
+// CHECK-LABEL: @wmma_scale_16x16x128_mixed
+func.func @wmma_scale_16x16x128_mixed(%arg0 : vector<64xf8E4M3FN>, %arg1 : vector<64xf6E2M3FN>,
+                                      %arg2 : vector<64xf4E2M1FN>, %arg3 : vector<4xf32>,
+                                      %arg4 : i32, %arg5 : i32) {
+  // CHECK: rocdl.wmma.scale.f32.16x16x128.f8f6f4 {{.*}}, {{.*}}, %arg3, %arg4, %arg5 {fmtB = 2 : i32} : (vector<16xi32>, vector<12xi32>, vector<4xf32>, i32, i32) -> vector<4xf32>
+  %0 = amdgpu.scaled_wmma (%arg4 * %arg0) * (%arg5 * %arg1) + %arg3
+    { m = 16 : i32, n = 16 : i32, k = 128 : i32 } : i32, vector<64xf8E4M3FN>, i32, vector<64xf6E2M3FN>, vector<4xf32>
+
+  // CHECK: rocdl.wmma.scale.f32.16x16x128.f8f6f4 {{.*}}, {{.*}}, %arg3, %arg4, %arg5 {fmtA = 2 : i32, fmtB = 4 : i32} : (vector<12xi32>, vector<8xi32>, vector<4xf32>, i32, i32) -> vector<4xf32>
+  %1 = amdgpu.scaled_wmma (%arg4 * %arg1) * (%arg5 * %arg2) + %arg3
+    { m = 16 : i32, n = 16 : i32, k = 128 : i32 } : i32, vector<64xf6E2M3FN>, i32, vector<64xf4E2M1FN>, vector<4xf32>
+
+  func.return
+}
+
+// CHECK-LABEL: @wmma_scale16_16x16x128_fp8
+func.func @wmma_scale16_16x16x128_fp8(%arg0 : vector<64xf8E4M3FN>, %arg1 : vector<64xf8E5M2>,
+                                      %arg2 : vector<4xf32>, %arg3 : i64, %arg4 : i64) {
+  // CHECK: rocdl.wmma.scale16.f32.16x16x128.f8f6f4 {{.*}}, {{.*}}, %arg2, %arg3, %arg4 : (vector<16xi32>, vector<16xi32>, vector<4xf32>, i64, i64) -> vector<4xf32>
+  %0 = amdgpu.scaled_wmma (%arg3 * %arg0) * (%arg4 * %arg0) + %arg2
+    { m = 16 : i32, n = 16 : i32, k = 128 : i32 } : i64, vector<64xf8E4M3FN>, i64, vector<64xf8E4M3FN>, vector<4xf32>
+
+  func.return
+}
+
+// CHECK-LABEL: @wmma_scale_32x16x128_fp4
+func.func @wmma_scale_32x16x128_fp4(%arg0 : vector<128xf4E2M1FN>, %arg1 : vector<64xf4E2M1FN>,
+                                    %arg2 : vector<8xf32>, %arg3 : i32, %arg4 : i32) {
+  // CHECK: rocdl.wmma.scale.f32.32x16x128.f4 {{.*}}, {{.*}}, %arg2, %arg3, %arg4 : (vector<16xi32>, vector<8xi32>, vector<8xf32>, i32, i32) -> vector<8xf32>
+  %0 = amdgpu.scaled_wmma (%arg3 * %arg0) * (%arg4 * %arg1) + %arg2
+    { m = 32 : i32, n = 16 : i32, k = 128 : i32 } : i32, vector<128xf4E2M1FN>, i32, vector<64xf4E2M1FN>, vector<8xf32>
+
+  func.return
+}
+
+// CHECK-LABEL: @wmma_scale16_32x16x128_fp4
+func.func @wmma_scale16_32x16x128_fp4(%arg0 : vector<128xf4E2M1FN>, %arg1 : vector<64xf4E2M1FN>,
+                                      %arg2 : vector<8xf32>, %arg3 : i64, %arg4 : i64) {
+  // CHECK: rocdl.wmma.scale16.f32.32x16x128.f4 {{.*}}, {{.*}}, %arg2, %arg3, %arg4 : (vector<16xi32>, vector<8xi32>, vector<8xf32>, i64, i64) -> vector<8xf32>
+  %0 = amdgpu.scaled_wmma (%arg3 * %arg0) * (%arg4 * %arg1) + %arg2
+    { m = 32 : i32, n = 16 : i32, k = 128 : i32 } : i64, vector<128xf4E2M1FN>, i64, vector<64xf4E2M1FN>, vector<8xf32>
+
+  func.return
+}
+
 // -----
 
 func.func @wmma_unsupported_k(%arg0 : vector<8xf16>, %arg1 : vector<8xf32>) {
diff --git a/mlir/test/Dialect/AMDGPU/ops.mlir b/mlir/test/Dialect/AMDGPU/ops.mlir
index 653f9f64d24f4..f1492b2d4d14e 100644
--- a/mlir/test/Dialect/AMDGPU/ops.mlir
+++ b/mlir/test/Dialect/AMDGPU/ops.mlir
@@ -697,3 +697,29 @@ func.func @make_dma_base(%idx: index, %mem: memref<8xi32>, %smem: memref<8xi32,
   func.return
 }
 
+// CHECK-LABEL: func @wmma_scale
+func.func @wmma_scale(%fp8_src: vector<64xf8E4M3FN>, %bf8_src: vector<64xf8E5M2>,
+                      %fp6_src: vector<64xf6E2M3FN>, %fp4_src_a: vector<128xf4E2M1FN>,
+                      %fp4_src_b: vector<64xf4E2M1FN>,
+                      %dst0: vector<4xf32>, %dst1: vector<8xf32>,
+                      %scale32: i32, %scale64: i64) {
+  // CHECK: amdgpu.scaled_wmma({{.*}} * {{.*}}) * ({{.*}} * {{.*}}) + {{.*}} {k = 128 : i32, m = 16 : i32, n = 16 : i32} : i32, vector<64xf8E4M3FN>, i32, vector<64xf8E4M3FN>, vector<4xf32>
+  %0 = amdgpu.scaled_wmma (%scale32 * %fp8_src) * (%scale32 * %fp8_src) + %dst0
+    { m = 16 : i32, n = 16 : i32, k = 128 : i32 } : i32, vector<64xf8E4M3FN>, i32, vector<64xf8E4M3FN>, vector<4xf32>
+  // CHECK: amdgpu.scaled_wmma({{.*}} * {{.*}}) * ({{.*}} * {{.*}}) + {{.*}} {k = 128 : i32, m = 16 : i32, n = 16 : i32} : i32, vector<64xf8E5M2>, i32, vector<64xf8E5M2>, vector<4xf32>
+  %1 = amdgpu.scaled_wmma (%scale32 * %bf8_src) * (%scale32 * %bf8_src) + %dst0
+    { m = 16 : i32, n = 16 : i32, k = 128 : i32 } : i32, vector<64xf8E5M2>, i32, vector<64xf8E5M2>, vector<4xf32>
+  // CHECK: amdgpu.scaled_wmma({{.*}} * {{.*}}) * ({{.*}} * {{.*}}) + {{.*}} {k = 128 : i32, m = 16 : i32, n = 16 : i32} : i32, vector<64xf6E2M3FN>, i32, vector<64xf6E2M3FN>, vector<4xf32>
+  %2 = amdgpu.scaled_wmma (%scale32 * %fp6_src) * (%scale32 * %fp6_src) + %dst0
+    { m = 16 : i32, n = 16 : i32, k = 128 : i32 } : i32, vector<64xf6E2M3FN>, i32, vector<64xf6E2M3FN>, vector<4xf32>
+  // CHECK: amdgpu.scaled_wmma({{.*}} * {{.*}}) * ({{.*}} * {{.*}}) + {{.*}} {k = 128 : i32, m = 16 : i32, n = 16 : i32} : i32, vector<64xf4E2M1FN>, i32, vector<64xf6E2M3FN>, vector<4xf32>
+  %3 = amdgpu.scaled_wmma (%scale32 * %fp4_src_b) * (%scale32 * %fp6_src) + %dst0
+    { m = 16 : i32, n = 16 : i32, k = 128 : i32 } : i32, vector<64xf4E2M1FN>, i32, vector<64xf6E2M3FN>, vector<4xf32>
+  // CHECK: amdgpu.scaled_wmma({{.*}} * {{.*}}) * ({{.*}} * {{.*}}) + {{.*}} {k = 128 : i32, m = 16 : i32, n = 16 : i32} : i64, vector<64xf8E4M3FN>, i64, vector<64xf8E4M3FN>, vector<4xf32>
+  %4 = amdgpu.scaled_wmma (%scale64 * %fp8_src) * (%scale64 * %fp8_src) + %dst0
+    { m = 16 : i32, n = 16 : i32, k = 128 : i32 } : i64, vector<64xf8E4M3FN>, i64, vector<64xf8E4M3FN>, vector<4xf32>
+  // CHECK: amdgpu.scaled_wmma({{.*}} * {{.*}}) * ({{.*}} * {{.*}}) + {{.*}} {k = 128 : i32, m = 32 : i32, n = 16 : i32} : i32, vector<128xf4E2M1FN>, i32, vector<64xf4E2M1FN>, vector<8xf32>
+  %5 = amdgpu.scaled_wmma (%scale32 * %fp4_src_a) * (%scale32 * %fp4_src_b) + %dst1
+    { m = 32 : i32, n = 16 : i32, k = 128 : i32 } : i32, vector<128xf4E2M1FN>, i32, vector<64xf4E2M1FN>, vector<8xf32>
+  func.return
+}

>From e8e4aa1c7f1dc9d0421a4a907e8dc9b156387758 Mon Sep 17 00:00:00 2001
From: Justin Rosner <justin.rosner at amd.com>
Date: Thu, 27 Nov 2025 19:44:22 +0000
Subject: [PATCH 3/4] Clang-format

---
 mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td | 38 +++++-----
 .../AMDGPUToROCDL/AMDGPUToROCDL.cpp           | 69 +++++++++----------
 mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp  | 12 ++--
 3 files changed, 57 insertions(+), 62 deletions(-)

diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
index a2201d3127370..9d65130154010 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
@@ -953,9 +953,10 @@ def ScaledMFMAInTypes : AnyTypeOf<[VectorOfLengthAndType<[32], [F8E5M2, F8E4M3FN
 def ScaledMFMAOutTypes : AnyTypeOf<[VectorOfLengthAndType<[4, 16], [F32]>]>;
 
 // scaled_wmma
-def ScaledWMMAInTypes : AnyTypeOf<[VectorOfLengthAndType<[64], [F8E5M2, F8E4M3FN]>,
-                                   VectorOfLengthAndType<[64], [F6E2M3FN, F6E3M2FN]>,
-                                   VectorOfLengthAndType<[64, 128], [F4E2M1FN]>]>;
+def ScaledWMMAInTypes
+    : AnyTypeOf<[VectorOfLengthAndType<[64], [F8E5M2, F8E4M3FN]>,
+                 VectorOfLengthAndType<[64], [F6E2M3FN, F6E3M2FN]>,
+                 VectorOfLengthAndType<[64, 128], [F4E2M1FN]>]>;
 def ScaledWMMAOutTypes : AnyTypeOf<[VectorOfLengthAndType<[4, 8], [F32]>]>;
 
 // wmma
@@ -1225,24 +1226,19 @@ def AMDGPU_ScaledMFMAOp :
   let hasCanonicalizer = 1;
 }
 
-def AMDGPU_ScaledWMMAOp :
-    AMDGPU_Op<"scaled_wmma", [AllTypesMatch<["destC", "destD"]>,
-                              Pure]>,
-    Arguments<(ins
-                   ConfinedAttr<I32Attr, [IntIsOneOf<[16, 32]>]>:$m,
-                   ConfinedAttr<I32Attr, [IntIsOneOf<[16]>]>:$n,
-                   ConfinedAttr<I32Attr, [IntIsOneOf<[128]>]>:$k,
-                   ScaledWMMAInTypes:$sourceA,
-                   ScaledWMMAInTypes:$sourceB,
-                   ScaledWMMAOutTypes:$destC,
-                   AnyTypeOf<[I32, I64]>:$scaleA,
-                   AnyTypeOf<[I32, I64]>:$scaleB,
-                   DefaultValuedAttr<I32Attr, "0">:$scaleAType,
-                   DefaultValuedAttr<I32Attr, "0">:$fmtScaleA,
-                   DefaultValuedAttr<I32Attr, "0">:$scaleBType,
-                   DefaultValuedAttr<I32Attr, "0">:$fmtScaleB
-                   )>,
-    Results<(outs ScaledWMMAOutTypes: $destD)> {
+def AMDGPU_ScaledWMMAOp
+    : AMDGPU_Op<"scaled_wmma", [AllTypesMatch<["destC", "destD"]>, Pure]>,
+      Arguments<(ins ConfinedAttr<I32Attr, [IntIsOneOf<[16, 32]>]>:$m,
+          ConfinedAttr<I32Attr, [IntIsOneOf<[16]>]>:$n,
+          ConfinedAttr<I32Attr, [IntIsOneOf<[128]>]>:$k,
+          ScaledWMMAInTypes:$sourceA, ScaledWMMAInTypes:$sourceB,
+          ScaledWMMAOutTypes:$destC, AnyTypeOf<[I32, I64]>:$scaleA,
+          AnyTypeOf<[I32, I64]>:$scaleB,
+          DefaultValuedAttr<I32Attr, "0">:$scaleAType,
+          DefaultValuedAttr<I32Attr, "0">:$fmtScaleA,
+          DefaultValuedAttr<I32Attr, "0">:$scaleBType,
+          DefaultValuedAttr<I32Attr, "0">:$fmtScaleB)>,
+      Results<(outs ScaledWMMAOutTypes:$destD)> {
   let summary = "MLIR wrapper for RDNA scaled wmma instructions";
   let description = [{
     The `amdgpu.scaled_wmma` op is an MLIR wrapper around intrinsics for scaled
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index f4034f44d06b8..731e33c82b75f 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -1212,12 +1212,11 @@ struct MFMAOpLowering : public ConvertOpToLLVMPattern<MFMAOp> {
     }();
     OperationState loweredOp(loc, intrinsicName);
     loweredOp.addTypes(intrinsicOutType);
-    loweredOp.addOperands(
-        {packSmallFloatVectorOperand(rewriter, loc, adaptor.getSourceA(),
-                                      allowBf16),
-         packSmallFloatVectorOperand(rewriter, loc, adaptor.getSourceB(),
-                                      allowBf16),
-         adaptor.getDestC()});
+    loweredOp.addOperands({packSmallFloatVectorOperand(
+                               rewriter, loc, adaptor.getSourceA(), allowBf16),
+                           packSmallFloatVectorOperand(
+                               rewriter, loc, adaptor.getSourceB(), allowBf16),
+                           adaptor.getDestC()});
     if (isScaled) {
       Value zero = createI32Constant(rewriter, loc, 0);
       auto [_scaledName, aTypeCode, bTypeCode] = *maybeScaledIntrinsic;
@@ -1401,13 +1400,14 @@ struct ScaledWMMAOpLowering : public ConvertOpToLLVMPattern<ScaledWMMAOp> {
     bool is32x16 = (m == 32 && n == 16 && k == 128);
 
     if (m == 16 && n == 16 && k == 128) {
-      intrinsicName = isScale16
-                ? ROCDL::wmma_scale16_f32_16x16x128_f8f6f4::getOperationName()
-                : ROCDL::wmma_scale_f32_16x16x128_f8f6f4::getOperationName();
+      intrinsicName =
+          isScale16
+              ? ROCDL::wmma_scale16_f32_16x16x128_f8f6f4::getOperationName()
+              : ROCDL::wmma_scale_f32_16x16x128_f8f6f4::getOperationName();
     } else if (is32x16) {
-      intrinsicName = isScale16
-                ? ROCDL::wmma_scale16_f32_32x16x128_f4::getOperationName()
-                : ROCDL::wmma_scale_f32_32x16x128_f4::getOperationName();
+      intrinsicName =
+          isScale16 ? ROCDL::wmma_scale16_f32_32x16x128_f4::getOperationName()
+                    : ROCDL::wmma_scale_f32_32x16x128_f4::getOperationName();
     } else {
       return op.emitOpError("unsupported scaled_wmma dimensions: ")
              << m << "x" << n << "x" << k;
@@ -1417,29 +1417,29 @@ struct ScaledWMMAOpLowering : public ConvertOpToLLVMPattern<ScaledWMMAOp> {
 
     // The f4 variant does not have fmtA and fmtB attributes
     if (!is32x16) {
-      attrs.push_back(rewriter.getNamedAttr("fmtA",
-                              rewriter.getI32IntegerAttr(*aFmtCode)));
-      attrs.push_back(rewriter.getNamedAttr("fmtB",
-                              rewriter.getI32IntegerAttr(*bFmtCode)));
+      attrs.push_back(
+          rewriter.getNamedAttr("fmtA", rewriter.getI32IntegerAttr(*aFmtCode)));
+      attrs.push_back(
+          rewriter.getNamedAttr("fmtB", rewriter.getI32IntegerAttr(*bFmtCode)));
     }
 
     // Add modifier attributes - modC and reuse flags default to 0/false
-    attrs.push_back(rewriter.getNamedAttr("reuseA",
-                              rewriter.getBoolAttr(false)));
-    attrs.push_back(rewriter.getNamedAttr("reuseB",
-                              rewriter.getBoolAttr(false)));
-    attrs.push_back(rewriter.getNamedAttr("modC",
-                              rewriter.getI16IntegerAttr(0)));
+    attrs.push_back(
+        rewriter.getNamedAttr("reuseA", rewriter.getBoolAttr(false)));
+    attrs.push_back(
+        rewriter.getNamedAttr("reuseB", rewriter.getBoolAttr(false)));
+    attrs.push_back(
+        rewriter.getNamedAttr("modC", rewriter.getI16IntegerAttr(0)));
 
     // Scale type/format parameters from the operation
-    attrs.push_back(rewriter.getNamedAttr("scaleAType",
-                              rewriter.getI32IntegerAttr(op.getScaleAType())));
-    attrs.push_back(rewriter.getNamedAttr("fmtScaleA",
-                              rewriter.getI32IntegerAttr(op.getFmtScaleA())));
-    attrs.push_back(rewriter.getNamedAttr("scaleBType",
-                              rewriter.getI32IntegerAttr(op.getScaleBType())));
-    attrs.push_back(rewriter.getNamedAttr("fmtScaleB",
-                              rewriter.getI32IntegerAttr(op.getFmtScaleB())));
+    attrs.push_back(rewriter.getNamedAttr(
+        "scaleAType", rewriter.getI32IntegerAttr(op.getScaleAType())));
+    attrs.push_back(rewriter.getNamedAttr(
+        "fmtScaleA", rewriter.getI32IntegerAttr(op.getFmtScaleA())));
+    attrs.push_back(rewriter.getNamedAttr(
+        "scaleBType", rewriter.getI32IntegerAttr(op.getScaleBType())));
+    attrs.push_back(rewriter.getNamedAttr(
+        "fmtScaleB", rewriter.getI32IntegerAttr(op.getFmtScaleB())));
 
     // Convert typed float vectors to packed i32 format if needed
     Value sourceA =
@@ -2428,10 +2428,9 @@ void mlir::populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter,
            AMDGPUDPPLowering, MemoryCounterWaitOpLowering, LDSBarrierOpLowering,
            SchedBarrierOpLowering, MFMAOpLowering, ScaledMFMAOpLowering,
            WMMAOpLowering, ScaledWMMAOpLowering, ExtPackedFp8OpLowering,
-           ScaledExtPacked816OpLowering,
-           ScaledExtPackedOpLowering, PackedScaledTruncOpLowering,
-           PackedTrunc2xFp8OpLowering, PackedStochRoundFp8OpLowering,
-           GatherToLDSOpLowering, TransposeLoadOpLowering,
-           AMDGPUPermlaneLowering>(converter, chipset);
+           ScaledExtPacked816OpLowering, ScaledExtPackedOpLowering,
+           PackedScaledTruncOpLowering, PackedTrunc2xFp8OpLowering,
+           PackedStochRoundFp8OpLowering, GatherToLDSOpLowering,
+           TransposeLoadOpLowering, AMDGPUPermlaneLowering>(converter, chipset);
   patterns.add<AMDGPUSwizzleBitModeLowering>(converter);
 }
diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
index 87bd1903290ae..fceded5c2c01f 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -450,7 +450,7 @@ LogicalResult ScaledWMMAOp::verify() {
   auto sourceAType = cast<VectorType>(getSourceA().getType());
   auto sourceBType = cast<VectorType>(getSourceB().getType());
   auto destType = cast<VectorType>(getDestC().getType());
-  
+
   // Validate output type is F32
   if (!destType.getElementType().isF32())
     return emitOpError("destination must have f32 element type");
@@ -459,10 +459,10 @@ LogicalResult ScaledWMMAOp::verify() {
   Type aElemType = sourceAType.getElementType();
   Type bElemType = sourceBType.getElementType();
 
-  bool aIsSmallFloat = aElemType.isFloat(4) || aElemType.isFloat(6) ||
-                       aElemType.isFloat(8);
-  bool bIsSmallFloat = bElemType.isFloat(4) || bElemType.isFloat(6) ||
-                       bElemType.isFloat(8);
+  bool aIsSmallFloat =
+      aElemType.isFloat(4) || aElemType.isFloat(6) || aElemType.isFloat(8);
+  bool bIsSmallFloat =
+      bElemType.isFloat(4) || bElemType.isFloat(6) || bElemType.isFloat(8);
 
   if (!aIsSmallFloat || !bIsSmallFloat)
     return emitOpError("source operands must have small float element types "
@@ -479,7 +479,7 @@ LogicalResult ScaledWMMAOp::verify() {
   int64_t aLen = sourceAType.getNumElements();
   int64_t bLen = sourceBType.getNumElements();
   int64_t expectedOutLen = (m == 16) ? 4 : 8;
-  
+
   if (destType.getNumElements() != expectedOutLen)
     return emitOpError("expected output vector of length " +
                        Twine(expectedOutLen) + " but got " +

>From f40c9b7c38a07562a8d10ea7959d70131237f83f Mon Sep 17 00:00:00 2001
From: Justin Rosner <justin.rosner at amd.com>
Date: Fri, 28 Nov 2025 17:19:54 +0000
Subject: [PATCH 4/4] Attend to review comments

---
 mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td | 49 +++++-----
 .../AMDGPUToROCDL/AMDGPUToROCDL.cpp           | 92 ++++++++++++++-----
 mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp  | 50 ++++++++--
 .../AMDGPUToROCDL/wmma-gfx1250.mlir           | 66 ++++++-------
 mlir/test/Dialect/AMDGPU/ops.mlir             | 37 ++++----
 5 files changed, 181 insertions(+), 113 deletions(-)

diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
index 9d65130154010..148c7adde1a75 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
@@ -957,7 +957,8 @@ def ScaledWMMAInTypes
     : AnyTypeOf<[VectorOfLengthAndType<[64], [F8E5M2, F8E4M3FN]>,
                  VectorOfLengthAndType<[64], [F6E2M3FN, F6E3M2FN]>,
                  VectorOfLengthAndType<[64, 128], [F4E2M1FN]>]>;
-def ScaledWMMAOutTypes : AnyTypeOf<[VectorOfLengthAndType<[4, 8], [F32]>]>;
+
+def ScaledWMMAOutTypes : AnyTypeOf<[VectorOfLengthAndType<[8, 16], [F32]>]>;
 
 // wmma
 def WMMAInTypes : AnyTypeOf<[VectorOfLengthAndType<[2], [F32]>,
@@ -1232,42 +1233,44 @@ def AMDGPU_ScaledWMMAOp
           ConfinedAttr<I32Attr, [IntIsOneOf<[16]>]>:$n,
           ConfinedAttr<I32Attr, [IntIsOneOf<[128]>]>:$k,
           ScaledWMMAInTypes:$sourceA, ScaledWMMAInTypes:$sourceB,
-          ScaledWMMAOutTypes:$destC, AnyTypeOf<[I32, I64]>:$scaleA,
-          AnyTypeOf<[I32, I64]>:$scaleB,
-          DefaultValuedAttr<I32Attr, "0">:$scaleAType,
-          DefaultValuedAttr<I32Attr, "0">:$fmtScaleA,
-          DefaultValuedAttr<I32Attr, "0">:$scaleBType,
-          DefaultValuedAttr<I32Attr, "0">:$fmtScaleB)>,
+          ScaledWMMAOutTypes:$destC,
+          VectorOfLengthAndType<[4, 8], [F8E8M0FNU, F8E4M3FN]>:$scaleA,
+          ConfinedAttr<I32Attr, [IntIsOneOf<[0, 1]>]>:$scaleAIdx,
+          VectorOfLengthAndType<[4, 8], [F8E8M0FNU, F8E4M3FN]>:$scaleB,
+          ConfinedAttr<I32Attr, [IntIsOneOf<[0, 1]>]>:$scaleBIdx)>,
       Results<(outs ScaledWMMAOutTypes:$destD)> {
-  let summary = "MLIR wrapper for RDNA scaled wmma instructions";
+  let summary = "MLIR wrapper for scaled wmma instructions";
   let description = [{
     The `amdgpu.scaled_wmma` op is an MLIR wrapper around intrinsics for scaled
-    `wmma` instructions in the RDNA architecture. These instructions perform
-    matrix multiplication with per-block scaling of inputs, supporting fp4, fp6,
-    and fp8 data formats.
+    `wmma` instructions. These instructions perform matrix multiplication with
+    per-block scaling of inputs, supporting fp4, fp6, and fp8 data formats.
 
     The scale instructions support two tile sizes:
     - 16x16x128 with mixed f8/f6/f4 formats (output: vector<4xf32>)
     - 32x16x128 with f4 format only (output: vector<8xf32>)
 
-    The `scaleA` and `scaleB` parameters are scale exponents that can be either
-    i32 (for wmma.scale) or i64 (for wmma.scale16) to support per-block scaling.
-
-    Optional modifiers:
-    - `scaleAType`, `scaleBType`: Type of scale parameter
-    - `fmtScaleA`, `fmtScaleB`: Format of scale parameter
+    Scale parameters (`scaleA`, `scaleB`) are small vectors of f8 scale values
+    (either f8E8M0FNU or f8E4M3FN). The index attributes (`scaleAIdx`, `scaleBIdx`)
+    select which element from the scale vector to use for scaling. During lowering,
+    these vectors are packed into i32/i64 values for the hardware intrinsics.
 
     Example:
     ```mlir
-      %0 = amdgpu.scaled_wmma (%sa * %matA) * (%sb * %matB) + %matC
-        { m = 16, n = 16, k = 128 } : i32, vector<64xf8E4M3FN>, i32, vector<64xf8E4M3FN>, vector<4xf32>
-
-      %1 = amdgpu.scaled_wmma (%sc * %matD) * (%sd * %matE) + %matF
-        { m = 32, n = 16, k = 128 } : i32, vector<128xf4E2M1FN>, i32, vector<64xf4E2M1FN>, vector<8xf32>
+      // 16x16x128: fp8 inputs
+      %0 = amdgpu.scaled_wmma (%scaleVecA[0] * %matA) * (%scaleVecB[0] * %matB) + %matC
+        { m = 16, n = 16, k = 128 } : vector<4xf8E8M0FNU>, vector<64xf8E4M3FN>,
+        vector<4xf8E8M0FNU>, vector<64xf8E4M3FN>, vector<8xf32>
+
+      // 32x16x128: fp4 inputs
+      %1 = amdgpu.scaled_wmma (%scaleVecC[1] * %matD) * (%scaleVecD[0] * %matE) + %matF
+        { m = 32, n = 16, k = 128 } : vector<8xf8E4M3FN>, vector<128xf4E2M1FN>,
+        vector<8xf8E4M3FN>, vector<64xf4E2M1FN>, vector<16xf32>
     ```
   }];
   let assemblyFormat = [{
-    `(` $scaleA `*` $sourceA `)` `*` `(` $scaleB `*` $sourceB `)` `+` $destC
+    custom<MNKDimensionList>($m, $n, $k) ` `
+    `(` $scaleA `[` $scaleAIdx `]` `*` $sourceA `)` `*`
+    `(` $scaleB `[` $scaleBIdx `]` `*` $sourceB `)` `+` $destC
     attr-dict
     `:` type($scaleA) `,` type($sourceA) `,` type($scaleB) `,` type($sourceB) `,` type($destC)
   }];
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 731e33c82b75f..dcabd31ff2df3 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -653,23 +653,33 @@ static Value packSmallFloatVectorOperand(ConversionPatternRewriter &rewriter,
   return input;
 }
 
-/// Converts the scaled MFMA operands, `scalesA` and `scalesB`, from MLIR AMDGPU
-/// dialect convention to ROCDL and LLVM AMDGPU intrinsics convention.
+/// Converts the scaled MFMA/WMMA operands, `scalesA` and `scalesB`, from MLIR
+/// AMDGPU dialect convention to ROCDL and LLVM AMDGPU intrinsics convention.
 ///
 /// Specifically:
 /// 1. If `input` is a i8 value, zero extend it to i32
-/// 2. If `input` is a vector of length 4 and type i8, cast it to i32
+/// 2. If `input` is a vector of 4 or 8 i8 elements, bitcast it to i32 or i64
 ///
 /// Note that the type of `input` has already been LLVM type converted:
 /// therefore 8-bit and smaller floats are represented as their corresponding
 /// `iN` integers.
-static Value castMFMAScaleOperand(ConversionPatternRewriter &rewriter,
-                                  Location loc, Value input) {
+static Value castScaleOperand(ConversionPatternRewriter &rewriter, Location loc,
+                              Value input) {
   Type inputType = input.getType();
-  Type outputType = rewriter.getI32Type();
+
+  // Handle scalar i8: zero extend to i32
   if (auto intType = dyn_cast<IntegerType>(inputType))
-    return LLVM::ZExtOp::create(rewriter, loc, outputType, input);
-  return LLVM::BitcastOp::create(rewriter, loc, outputType, input);
+    return LLVM::ZExtOp::create(rewriter, loc, rewriter.getI32Type(), input);
+
+  // Handle vector<4xi8> -> i32 or vector<8xi8> -> i64
+  if (auto vectorType = dyn_cast<VectorType>(inputType)) {
+    int64_t numElements = vectorType.getNumElements();
+    Type outputType = (numElements == 4) ? (Type)rewriter.getI32Type()
+                                         : (Type)rewriter.getI64Type();
+    return LLVM::BitcastOp::create(rewriter, loc, outputType, input);
+  }
+
+  llvm_unreachable("unexpected input type for scale operand");
 }
 
 /// Push an input operand. If it is a float type, nothing to do. If it is
@@ -1273,10 +1283,10 @@ struct ScaledMFMAOpLowering : public ConvertOpToLLVMPattern<ScaledMFMAOp> {
          createI32Constant(rewriter, loc, bTypeCode),
          /*scales idx A=*/scalesIdxA,
          /*scales A*/
-         castMFMAScaleOperand(rewriter, loc, adaptor.getScalesA()),
+         castScaleOperand(rewriter, loc, adaptor.getScalesA()),
          /*scales idx B=*/scalesIdxB,
          /*scales B*/
-         castMFMAScaleOperand(rewriter, loc, adaptor.getScalesB())});
+         castScaleOperand(rewriter, loc, adaptor.getScalesB())});
     Value lowered = rewriter.create(loweredOp)->getResult(0);
     rewriter.replaceOp(op, lowered);
     return success();
@@ -1394,9 +1404,35 @@ struct ScaledWMMAOpLowering : public ConvertOpToLLVMPattern<ScaledWMMAOp> {
     if (!aFmtCode || !bFmtCode)
       return op.emitOpError("unsupported element types for scaled_wmma");
 
-    // Determine which intrinsic to use based on dimensions and scale type
+    // Get scale vector types and determine variant (scale vs scale16)
+    auto scaleAVecType = cast<VectorType>(op.getScaleA().getType());
+    auto scaleBVecType = cast<VectorType>(op.getScaleB().getType());
+
+    bool isScale16 = (scaleAVecType.getNumElements() == 8);
+    if (isScale16 != (scaleBVecType.getNumElements() == 8))
+      return op.emitOpError("scaleA and scaleB must have equal vector length");
+
+    // Extract scale format from element types
+    Type scaleAElemType = scaleAVecType.getElementType();
+    Type scaleBElemType = scaleBVecType.getElementType();
+
+    // Map f8 types to format codes
+    auto getScaleFormat = [](Type elemType) -> std::optional<uint32_t> {
+      if (isa<Float8E8M0FNUType>(elemType))
+        return 0;
+      if (isa<Float8E4M3FNType>(elemType))
+        return 2;
+      return std::nullopt;
+    };
+
+    std::optional<uint32_t> scaleAFmt = getScaleFormat(scaleAElemType);
+    std::optional<uint32_t> scaleBFmt = getScaleFormat(scaleBElemType);
+
+    if (!scaleAFmt || !scaleBFmt)
+      return op.emitOpError("unsupported scale element types");
+
+    // Determine which intrinsic to use based on dimensions
     StringRef intrinsicName;
-    bool isScale16 = adaptor.getScaleA().getType().isInteger(64);
     bool is32x16 = (m == 32 && n == 16 && k == 128);
 
     if (m == 16 && n == 16 && k == 128) {
@@ -1423,35 +1459,41 @@ struct ScaledWMMAOpLowering : public ConvertOpToLLVMPattern<ScaledWMMAOp> {
           rewriter.getNamedAttr("fmtB", rewriter.getI32IntegerAttr(*bFmtCode)));
     }
 
-    // Add modifier attributes - modC and reuse flags default to 0/false
-    attrs.push_back(
-        rewriter.getNamedAttr("reuseA", rewriter.getBoolAttr(false)));
-    attrs.push_back(
-        rewriter.getNamedAttr("reuseB", rewriter.getBoolAttr(false)));
+    // modC uses default value of 0
     attrs.push_back(
         rewriter.getNamedAttr("modC", rewriter.getI16IntegerAttr(0)));
 
-    // Scale type/format parameters from the operation
+    // Scale attributes
     attrs.push_back(rewriter.getNamedAttr(
-        "scaleAType", rewriter.getI32IntegerAttr(op.getScaleAType())));
+        "scaleAType", rewriter.getI32IntegerAttr(op.getScaleAIdx())));
     attrs.push_back(rewriter.getNamedAttr(
-        "fmtScaleA", rewriter.getI32IntegerAttr(op.getFmtScaleA())));
+        "fmtScaleA", rewriter.getI32IntegerAttr(*scaleAFmt)));
     attrs.push_back(rewriter.getNamedAttr(
-        "scaleBType", rewriter.getI32IntegerAttr(op.getScaleBType())));
+        "scaleBType", rewriter.getI32IntegerAttr(op.getScaleBIdx())));
     attrs.push_back(rewriter.getNamedAttr(
-        "fmtScaleB", rewriter.getI32IntegerAttr(op.getFmtScaleB())));
+        "fmtScaleB", rewriter.getI32IntegerAttr(*scaleBFmt)));
+
+    // Reuse flags use default value of false
+    attrs.push_back(
+        rewriter.getNamedAttr("reuseA", rewriter.getBoolAttr(false)));
+    attrs.push_back(
+        rewriter.getNamedAttr("reuseB", rewriter.getBoolAttr(false)));
 
-    // Convert typed float vectors to packed i32 format if needed
+    // Convert typed float vectors to packed format
     Value sourceA =
         packSmallFloatVectorOperand(rewriter, loc, adaptor.getSourceA());
     Value sourceB =
         packSmallFloatVectorOperand(rewriter, loc, adaptor.getSourceB());
 
+    // Pack scale vectors into i32/i64
+    Value packedScaleA = castScaleOperand(rewriter, loc, adaptor.getScaleA());
+    Value packedScaleB = castScaleOperand(rewriter, loc, adaptor.getScaleB());
+
     // Create the intrinsic call
     OperationState loweredOp(loc, intrinsicName);
     loweredOp.addTypes(outType);
-    loweredOp.addOperands({sourceA, sourceB, adaptor.getDestC(),
-                           adaptor.getScaleA(), adaptor.getScaleB()});
+    loweredOp.addOperands(
+        {sourceA, sourceB, adaptor.getDestC(), packedScaleA, packedScaleB});
     loweredOp.addAttributes(attrs);
 
     Operation *lowered = rewriter.create(loweredOp);
diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
index fceded5c2c01f..3a33432da5b14 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -468,17 +468,11 @@ LogicalResult ScaledWMMAOp::verify() {
     return emitOpError("source operands must have small float element types "
                        "(fp4/fp6/fp8)");
 
-  // Validate scale types match (both i32 or both i64)
-  Type scaleAType = getScaleA().getType();
-  Type scaleBType = getScaleB().getType();
-  if (scaleAType != scaleBType)
-    return emitOpError("scaleA and scaleB must have the same type");
-
   // Validate vector lengths based on dimensions
   int64_t m = getM();
   int64_t aLen = sourceAType.getNumElements();
   int64_t bLen = sourceBType.getNumElements();
-  int64_t expectedOutLen = (m == 16) ? 4 : 8;
+  int64_t expectedOutLen = (m == 16) ? 8 : 16;
 
   if (destType.getNumElements() != expectedOutLen)
     return emitOpError("expected output vector of length " +
@@ -510,7 +504,47 @@ LogicalResult ScaledWMMAOp::verify() {
           Twine(bLen));
   }
 
-  return success();
+  // Validate scale types and their compatibility with matrix element types
+  auto scaleAType = cast<VectorType>(getScaleA().getType());
+  auto scaleBType = cast<VectorType>(getScaleB().getType());
+  Type scaleAElemType = scaleAType.getElementType();
+  Type scaleBElemType = scaleBType.getElementType();
+
+  // Validate scale element types are valid f8 types
+  if (!scaleAElemType.isFloat(8) || !scaleBElemType.isFloat(8))
+    return emitOpError("scale operands must have f8 element types");
+
+  // Helper functions for scale type classification
+  auto isE8M0 = [](Type t) { return isa<Float8E8M0FNUType>(t); };
+  auto isE4M3 = [](Type t) {
+    return isa<Float8E4M3FNType, Float8E4M3FNUZType>(t);
+  };
+
+  bool aIsF4 = aElemType.isFloat(4);
+  bool bIsF4 = bElemType.isFloat(4);
+  bool aIsF8F6 = aElemType.isFloat(8) || aElemType.isFloat(6);
+  bool bIsF8F6 = bElemType.isFloat(8) || bElemType.isFloat(6);
+
+  // Any matrices A/B (fp8|fp6|fp4) with E8M0 scales for matrix A/B are valid
+  if (isE8M0(scaleAElemType) && isE8M0(scaleBElemType))
+    return success();
+
+  // Matrix A (F8|F6) x Matrix B (F4) with Scale A (E8M0), Scale B (E4M3)
+  if (aIsF8F6 && isE8M0(scaleAElemType) && bIsF4 && (isE4M3(scaleBElemType)))
+    return success();
+
+  // Matrix A (F4) x Matrix B (F8|F6) with Scale A (E4M3), Scale B (E8M0)
+  if (aIsF4 && (isE4M3(scaleAElemType)) && bIsF8F6 && isE8M0(scaleBElemType))
+    return success();
+
+  // Matrix A (F4) x Matrix B (F4) with Scale A (E4M3), Scale B (E4M3)
+  if (aIsF4 && bIsF4 && isE4M3(scaleAElemType) && isE4M3(scaleBElemType))
+    return success();
+
+  // No valid combination matched
+  return emitOpError("invalid combination of matrix and scale types: ")
+         << "sourceA=" << aElemType << ", scaleA=" << scaleAElemType
+         << ", sourceB=" << bElemType << ", scaleB=" << scaleBElemType;
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx1250.mlir b/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx1250.mlir
index d187e62484059..896b5dba27b14 100644
--- a/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx1250.mlir
+++ b/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx1250.mlir
@@ -90,74 +90,68 @@ func.func @wmma_k128(%arg0 : vector<64xf8E4M3FN>, %arg1 : vector<64xf8E5M2>,
 }
 
 // CHECK-LABEL: @wmma_scale_16x16x128_fp8
-func.func @wmma_scale_16x16x128_fp8(%arg0 : vector<64xf8E4M3FN>, %arg1 : vector<64xf8E5M2>,
-                                    %arg2 : vector<4xf32>, %arg3 : i32, %arg4 : i32) {
-  // CHECK: rocdl.wmma.scale.f32.16x16x128.f8f6f4 {{.*}}, {{.*}}, %arg2, %arg3, %arg4 : (vector<16xi32>, vector<16xi32>, vector<4xf32>, i32, i32) -> vector<4xf32>
-  %0 = amdgpu.scaled_wmma (%arg3 * %arg0) * (%arg4 * %arg0) + %arg2
-    { m = 16 : i32, n = 16 : i32, k = 128 : i32 } : i32, vector<64xf8E4M3FN>, i32, vector<64xf8E4M3FN>, vector<4xf32>
+func.func @wmma_scale_16x16x128_fp8(%arg0 : vector<64xf8E4M3FN>, %arg1 : vector<64xf6E2M3FN>,
+                                    %arg2 : vector<8xf32>, %arg3 : vector<4xf8E8M0FNU>) {
+  // CHECK: rocdl.wmma.scale.f32.16x16x128.f8f6f4 {{.*}}, {{.*}}, %arg2, {{.*}}, {{.*}} : (vector<16xi32>, vector<16xi32>, vector<8xf32>, i32, i32) -> vector<8xf32>
+  %0 = amdgpu.scaled_wmma 16x16x128 (%arg3[0] * %arg0) * (%arg3[0] * %arg0) + %arg2 : vector<4xf8E8M0FNU>, vector<64xf8E4M3FN>, vector<4xf8E8M0FNU>, vector<64xf8E4M3FN>, vector<8xf32>
 
-  // CHECK: rocdl.wmma.scale.f32.16x16x128.f8f6f4 {{.*}}, {{.*}}, %arg2, %arg3, %arg4 {fmtA = 1 : i32, fmtB = 1 : i32} : (vector<16xi32>, vector<16xi32>, vector<4xf32>, i32, i32) -> vector<4xf32>
-  %1 = amdgpu.scaled_wmma (%arg3 * %arg1) * (%arg4 * %arg1) + %arg2
-    { m = 16 : i32, n = 16 : i32, k = 128 : i32 } : i32, vector<64xf8E5M2>, i32, vector<64xf8E5M2>, vector<4xf32>
+  // CHECK: rocdl.wmma.scale.f32.16x16x128.f8f6f4 {{.*}}, {{.*}}, %arg2, {{.*}}, {{.*}} {fmtA = 2 : i32, fmtB = 2 : i32, scaleAType = 1 : i32} : (vector<12xi32>, vector<12xi32>, vector<8xf32>, i32, i32) -> vector<8xf32>
+  %1 = amdgpu.scaled_wmma 16x16x128 (%arg3[1] * %arg1) * (%arg3[0] * %arg1) + %arg2 : vector<4xf8E8M0FNU>, vector<64xf6E2M3FN>, vector<4xf8E8M0FNU>, vector<64xf6E2M3FN>, vector<8xf32>
 
   func.return
 }
 
 // CHECK-LABEL: @wmma_scale_16x16x128_fp6
 func.func @wmma_scale_16x16x128_fp6(%arg0 : vector<64xf6E2M3FN>, %arg1 : vector<64xf6E3M2FN>,
-                                    %arg2 : vector<4xf32>, %arg3 : i32, %arg4 : i32) {
-  // CHECK: rocdl.wmma.scale.f32.16x16x128.f8f6f4 {{.*}}, {{.*}}, %arg2, %arg3, %arg4 {fmtA = 2 : i32, fmtB = 2 : i32} : (vector<12xi32>, vector<12xi32>, vector<4xf32>, i32, i32) -> vector<4xf32>
-  %0 = amdgpu.scaled_wmma (%arg3 * %arg0) * (%arg4 * %arg0) + %arg2
-    { m = 16 : i32, n = 16 : i32, k = 128 : i32 } : i32, vector<64xf6E2M3FN>, i32, vector<64xf6E2M3FN>, vector<4xf32>
+                                    %arg2 : vector<8xf32>, %arg3 : vector<4xf8E8M0FNU>) {
+  // CHECK: rocdl.wmma.scale.f32.16x16x128.f8f6f4 {{.*}}, {{.*}}, %arg2, {{.*}}, {{.*}} {fmtA = 2 : i32, fmtB = 2 : i32} : (vector<12xi32>, vector<12xi32>, vector<8xf32>, i32, i32) -> vector<8xf32>
+  %0 = amdgpu.scaled_wmma 16x16x128 (%arg3[0] * %arg0) * (%arg3[0] * %arg0) + %arg2 : vector<4xf8E8M0FNU>, vector<64xf6E2M3FN>, vector<4xf8E8M0FNU>, vector<64xf6E2M3FN>, vector<8xf32>
 
-  // CHECK: rocdl.wmma.scale.f32.16x16x128.f8f6f4 {{.*}}, {{.*}}, %arg2, %arg3, %arg4 {fmtA = 3 : i32, fmtB = 3 : i32} : (vector<12xi32>, vector<12xi32>, vector<4xf32>, i32, i32) -> vector<4xf32>
-  %1 = amdgpu.scaled_wmma (%arg3 * %arg1) * (%arg4 * %arg1) + %arg2
-    { m = 16 : i32, n = 16 : i32, k = 128 : i32 } : i32, vector<64xf6E3M2FN>, i32, vector<64xf6E3M2FN>, vector<4xf32>
+  // CHECK: rocdl.wmma.scale.f32.16x16x128.f8f6f4 {{.*}}, {{.*}}, %arg2, {{.*}}, {{.*}} {fmtA = 3 : i32, fmtB = 3 : i32} : (vector<12xi32>, vector<12xi32>, vector<8xf32>, i32, i32) -> vector<8xf32>
+  %1 = amdgpu.scaled_wmma 16x16x128 (%arg3[0] * %arg1) * (%arg3[0] * %arg1) + %arg2 : vector<4xf8E8M0FNU>, vector<64xf6E3M2FN>, vector<4xf8E8M0FNU>, vector<64xf6E3M2FN>, vector<8xf32>
 
   func.return
 }
 
 // CHECK-LABEL: @wmma_scale_16x16x128_mixed
 func.func @wmma_scale_16x16x128_mixed(%arg0 : vector<64xf8E4M3FN>, %arg1 : vector<64xf6E2M3FN>,
-                                      %arg2 : vector<64xf4E2M1FN>, %arg3 : vector<4xf32>,
-                                      %arg4 : i32, %arg5 : i32) {
-  // CHECK: rocdl.wmma.scale.f32.16x16x128.f8f6f4 {{.*}}, {{.*}}, %arg3, %arg4, %arg5 {fmtB = 2 : i32} : (vector<16xi32>, vector<12xi32>, vector<4xf32>, i32, i32) -> vector<4xf32>
-  %0 = amdgpu.scaled_wmma (%arg4 * %arg0) * (%arg5 * %arg1) + %arg3
-    { m = 16 : i32, n = 16 : i32, k = 128 : i32 } : i32, vector<64xf8E4M3FN>, i32, vector<64xf6E2M3FN>, vector<4xf32>
+                                      %arg2 : vector<64xf4E2M1FN>, %arg3 : vector<8xf32>,
+                                      %arg4 : vector<4xf8E8M0FNU>, %arg5 : vector<4xf8E4M3FN>) {
+  // CHECK: rocdl.wmma.scale.f32.16x16x128.f8f6f4 {{.*}}, {{.*}}, %arg3, {{.*}}, {{.*}} {fmtB = 4 : i32, fmtScaleB = 2 : i32} : (vector<16xi32>, vector<8xi32>, vector<8xf32>, i32, i32) -> vector<8xf32>
+  %0 = amdgpu.scaled_wmma 16x16x128 (%arg4[0] * %arg0) * (%arg5[0] * %arg2) + %arg3 : vector<4xf8E8M0FNU>, vector<64xf8E4M3FN>, vector<4xf8E4M3FN>, vector<64xf4E2M1FN>, vector<8xf32>
 
-  // CHECK: rocdl.wmma.scale.f32.16x16x128.f8f6f4 {{.*}}, {{.*}}, %arg3, %arg4, %arg5 {fmtA = 2 : i32, fmtB = 4 : i32} : (vector<12xi32>, vector<8xi32>, vector<4xf32>, i32, i32) -> vector<4xf32>
-  %1 = amdgpu.scaled_wmma (%arg4 * %arg1) * (%arg5 * %arg2) + %arg3
-    { m = 16 : i32, n = 16 : i32, k = 128 : i32 } : i32, vector<64xf6E2M3FN>, i32, vector<64xf4E2M1FN>, vector<4xf32>
+  // CHECK: rocdl.wmma.scale.f32.16x16x128.f8f6f4 {{.*}}, {{.*}}, %arg3, {{.*}}, {{.*}} {fmtA = 2 : i32, fmtB = 4 : i32, fmtScaleB = 2 : i32} : (vector<12xi32>, vector<8xi32>, vector<8xf32>, i32, i32) -> vector<8xf32>
+  %1 = amdgpu.scaled_wmma 16x16x128 (%arg4[0] * %arg1) * (%arg5[0] * %arg2) + %arg3 : vector<4xf8E8M0FNU>, vector<64xf6E2M3FN>, vector<4xf8E4M3FN>, vector<64xf4E2M1FN>, vector<8xf32>
 
   func.return
 }
 
 // CHECK-LABEL: @wmma_scale16_16x16x128_fp8
-func.func @wmma_scale16_16x16x128_fp8(%arg0 : vector<64xf8E4M3FN>, %arg1 : vector<64xf8E5M2>,
-                                      %arg2 : vector<4xf32>, %arg3 : i64, %arg4 : i64) {
-  // CHECK: rocdl.wmma.scale16.f32.16x16x128.f8f6f4 {{.*}}, {{.*}}, %arg2, %arg3, %arg4 : (vector<16xi32>, vector<16xi32>, vector<4xf32>, i64, i64) -> vector<4xf32>
-  %0 = amdgpu.scaled_wmma (%arg3 * %arg0) * (%arg4 * %arg0) + %arg2
-    { m = 16 : i32, n = 16 : i32, k = 128 : i32 } : i64, vector<64xf8E4M3FN>, i64, vector<64xf8E4M3FN>, vector<4xf32>
+func.func @wmma_scale16_16x16x128_fp8(%arg0 : vector<64xf8E4M3FN>, %arg1 : vector<64xf6E3M2FN>,
+                                      %arg2 : vector<8xf32>, %arg3 : vector<8xf8E8M0FNU>) {
+  // CHECK: rocdl.wmma.scale16.f32.16x16x128.f8f6f4 {{.*}}, {{.*}}, %arg2, {{.*}}, {{.*}} : (vector<16xi32>, vector<16xi32>, vector<8xf32>, i64, i64) -> vector<8xf32>
+  %0 = amdgpu.scaled_wmma 16x16x128 (%arg3[0] * %arg0) * (%arg3[0] * %arg0) + %arg2 : vector<8xf8E8M0FNU>, vector<64xf8E4M3FN>, vector<8xf8E8M0FNU>, vector<64xf8E4M3FN>, vector<8xf32>
+
+  // CHECK: rocdl.wmma.scale16.f32.16x16x128.f8f6f4 {{.*}}, {{.*}}, %arg2, {{.*}}, {{.*}} {fmtA = 3 : i32, fmtB = 3 : i32, scaleAType = 1 : i32} : (vector<12xi32>, vector<12xi32>, vector<8xf32>, i64, i64) -> vector<8xf32>
+  %1 = amdgpu.scaled_wmma 16x16x128 (%arg3[1] * %arg1) * (%arg3[0] * %arg1) + %arg2 : vector<8xf8E8M0FNU>, vector<64xf6E3M2FN>, vector<8xf8E8M0FNU>, vector<64xf6E3M2FN>, vector<8xf32>
 
   func.return
 }
 
 // CHECK-LABEL: @wmma_scale_32x16x128_fp4
 func.func @wmma_scale_32x16x128_fp4(%arg0 : vector<128xf4E2M1FN>, %arg1 : vector<64xf4E2M1FN>,
-                                    %arg2 : vector<8xf32>, %arg3 : i32, %arg4 : i32) {
-  // CHECK: rocdl.wmma.scale.f32.32x16x128.f4 {{.*}}, {{.*}}, %arg2, %arg3, %arg4 : (vector<16xi32>, vector<8xi32>, vector<8xf32>, i32, i32) -> vector<8xf32>
-  %0 = amdgpu.scaled_wmma (%arg3 * %arg0) * (%arg4 * %arg1) + %arg2
-    { m = 32 : i32, n = 16 : i32, k = 128 : i32 } : i32, vector<128xf4E2M1FN>, i32, vector<64xf4E2M1FN>, vector<8xf32>
+                                    %arg2 : vector<16xf32>, %arg3 : vector<4xf8E4M3FN>) {
+  // CHECK: rocdl.wmma.scale.f32.32x16x128.f4 {{.*}}, {{.*}}, %arg2, {{.*}}, {{.*}} {fmtScaleA = 2 : i32, fmtScaleB = 2 : i32} : (vector<16xi32>, vector<8xi32>, vector<16xf32>, i32, i32) -> vector<16xf32>
+  %0 = amdgpu.scaled_wmma 32x16x128 (%arg3[0] * %arg0) * (%arg3[0] * %arg1) + %arg2 : vector<4xf8E4M3FN>, vector<128xf4E2M1FN>, vector<4xf8E4M3FN>, vector<64xf4E2M1FN>, vector<16xf32>
 
   func.return
 }
 
 // CHECK-LABEL: @wmma_scale16_32x16x128_fp4
 func.func @wmma_scale16_32x16x128_fp4(%arg0 : vector<128xf4E2M1FN>, %arg1 : vector<64xf4E2M1FN>,
-                                      %arg2 : vector<8xf32>, %arg3 : i64, %arg4 : i64) {
-  // CHECK: rocdl.wmma.scale16.f32.32x16x128.f4 {{.*}}, {{.*}}, %arg2, %arg3, %arg4 : (vector<16xi32>, vector<8xi32>, vector<8xf32>, i64, i64) -> vector<8xf32>
-  %0 = amdgpu.scaled_wmma (%arg3 * %arg0) * (%arg4 * %arg1) + %arg2
-    { m = 32 : i32, n = 16 : i32, k = 128 : i32 } : i64, vector<128xf4E2M1FN>, i64, vector<64xf4E2M1FN>, vector<8xf32>
+                                      %arg2 : vector<16xf32>, %arg3 : vector<8xf8E4M3FN>) {
+  // CHECK: rocdl.wmma.scale16.f32.32x16x128.f4 {{.*}}, {{.*}}, %arg2, {{.*}}, {{.*}} {fmtScaleA = 2 : i32, fmtScaleB = 2 : i32} : (vector<16xi32>, vector<8xi32>, vector<16xf32>, i64, i64) -> vector<16xf32>
+  %0 = amdgpu.scaled_wmma 32x16x128 (%arg3[0] * %arg0) * (%arg3[0] * %arg1) + %arg2 : vector<8xf8E4M3FN>, vector<128xf4E2M1FN>, vector<8xf8E4M3FN>, vector<64xf4E2M1FN>, vector<16xf32>
 
   func.return
 }
diff --git a/mlir/test/Dialect/AMDGPU/ops.mlir b/mlir/test/Dialect/AMDGPU/ops.mlir
index f1492b2d4d14e..4e635cefcdf8f 100644
--- a/mlir/test/Dialect/AMDGPU/ops.mlir
+++ b/mlir/test/Dialect/AMDGPU/ops.mlir
@@ -698,28 +698,23 @@ func.func @make_dma_base(%idx: index, %mem: memref<8xi32>, %smem: memref<8xi32,
 }
 
 // CHECK-LABEL: func @wmma_scale
-func.func @wmma_scale(%fp8_src: vector<64xf8E4M3FN>, %bf8_src: vector<64xf8E5M2>,
+func.func @wmma_scale(%fp8_src: vector<64xf8E4M3FN>, %fp6_alt_src: vector<64xf6E3M2FN>,
                       %fp6_src: vector<64xf6E2M3FN>, %fp4_src_a: vector<128xf4E2M1FN>,
                       %fp4_src_b: vector<64xf4E2M1FN>,
-                      %dst0: vector<4xf32>, %dst1: vector<8xf32>,
-                      %scale32: i32, %scale64: i64) {
-  // CHECK: amdgpu.scaled_wmma({{.*}} * {{.*}}) * ({{.*}} * {{.*}}) + {{.*}} {k = 128 : i32, m = 16 : i32, n = 16 : i32} : i32, vector<64xf8E4M3FN>, i32, vector<64xf8E4M3FN>, vector<4xf32>
-  %0 = amdgpu.scaled_wmma (%scale32 * %fp8_src) * (%scale32 * %fp8_src) + %dst0
-    { m = 16 : i32, n = 16 : i32, k = 128 : i32 } : i32, vector<64xf8E4M3FN>, i32, vector<64xf8E4M3FN>, vector<4xf32>
-  // CHECK: amdgpu.scaled_wmma({{.*}} * {{.*}}) * ({{.*}} * {{.*}}) + {{.*}} {k = 128 : i32, m = 16 : i32, n = 16 : i32} : i32, vector<64xf8E5M2>, i32, vector<64xf8E5M2>, vector<4xf32>
-  %1 = amdgpu.scaled_wmma (%scale32 * %bf8_src) * (%scale32 * %bf8_src) + %dst0
-    { m = 16 : i32, n = 16 : i32, k = 128 : i32 } : i32, vector<64xf8E5M2>, i32, vector<64xf8E5M2>, vector<4xf32>
-  // CHECK: amdgpu.scaled_wmma({{.*}} * {{.*}}) * ({{.*}} * {{.*}}) + {{.*}} {k = 128 : i32, m = 16 : i32, n = 16 : i32} : i32, vector<64xf6E2M3FN>, i32, vector<64xf6E2M3FN>, vector<4xf32>
-  %2 = amdgpu.scaled_wmma (%scale32 * %fp6_src) * (%scale32 * %fp6_src) + %dst0
-    { m = 16 : i32, n = 16 : i32, k = 128 : i32 } : i32, vector<64xf6E2M3FN>, i32, vector<64xf6E2M3FN>, vector<4xf32>
-  // CHECK: amdgpu.scaled_wmma({{.*}} * {{.*}}) * ({{.*}} * {{.*}}) + {{.*}} {k = 128 : i32, m = 16 : i32, n = 16 : i32} : i32, vector<64xf4E2M1FN>, i32, vector<64xf6E2M3FN>, vector<4xf32>
-  %3 = amdgpu.scaled_wmma (%scale32 * %fp4_src_b) * (%scale32 * %fp6_src) + %dst0
-    { m = 16 : i32, n = 16 : i32, k = 128 : i32 } : i32, vector<64xf4E2M1FN>, i32, vector<64xf6E2M3FN>, vector<4xf32>
-  // CHECK: amdgpu.scaled_wmma({{.*}} * {{.*}}) * ({{.*}} * {{.*}}) + {{.*}} {k = 128 : i32, m = 16 : i32, n = 16 : i32} : i64, vector<64xf8E4M3FN>, i64, vector<64xf8E4M3FN>, vector<4xf32>
-  %4 = amdgpu.scaled_wmma (%scale64 * %fp8_src) * (%scale64 * %fp8_src) + %dst0
-    { m = 16 : i32, n = 16 : i32, k = 128 : i32 } : i64, vector<64xf8E4M3FN>, i64, vector<64xf8E4M3FN>, vector<4xf32>
-  // CHECK: amdgpu.scaled_wmma({{.*}} * {{.*}}) * ({{.*}} * {{.*}}) + {{.*}} {k = 128 : i32, m = 32 : i32, n = 16 : i32} : i32, vector<128xf4E2M1FN>, i32, vector<64xf4E2M1FN>, vector<8xf32>
-  %5 = amdgpu.scaled_wmma (%scale32 * %fp4_src_a) * (%scale32 * %fp4_src_b) + %dst1
-    { m = 32 : i32, n = 16 : i32, k = 128 : i32 } : i32, vector<128xf4E2M1FN>, i32, vector<64xf4E2M1FN>, vector<8xf32>
+                      %dst0: vector<8xf32>, %dst1: vector<16xf32>,
+                      %scale_vec4: vector<4xf8E8M0FNU>, %scale_vec8: vector<8xf8E8M0FNU>,
+                      %scale_vec4_e4m3: vector<4xf8E4M3FN>) {
+  // CHECK: amdgpu.scaled_wmma 16x16x128 ({{.*}}[{{.*}}] * {{.*}}) * ({{.*}}[{{.*}}] * {{.*}}) + {{.*}} : vector<4xf8E8M0FNU>, vector<64xf8E4M3FN>, vector<4xf8E8M0FNU>, vector<64xf8E4M3FN>, vector<8xf32>
+  %0 = amdgpu.scaled_wmma 16x16x128 (%scale_vec4[0] * %fp8_src) * (%scale_vec4[0] * %fp8_src) + %dst0 : vector<4xf8E8M0FNU>, vector<64xf8E4M3FN>, vector<4xf8E8M0FNU>, vector<64xf8E4M3FN>, vector<8xf32>
+  // CHECK: amdgpu.scaled_wmma 16x16x128 ({{.*}}[{{.*}}] * {{.*}}) * ({{.*}}[{{.*}}] * {{.*}}) + {{.*}} : vector<4xf8E8M0FNU>, vector<64xf6E3M2FN>, vector<4xf8E8M0FNU>, vector<64xf6E3M2FN>, vector<8xf32>
+  %1 = amdgpu.scaled_wmma 16x16x128 (%scale_vec4[0] * %fp6_alt_src) * (%scale_vec4[0] * %fp6_alt_src) + %dst0 : vector<4xf8E8M0FNU>, vector<64xf6E3M2FN>, vector<4xf8E8M0FNU>, vector<64xf6E3M2FN>, vector<8xf32>
+  // CHECK: amdgpu.scaled_wmma 16x16x128 ({{.*}}[{{.*}}] * {{.*}}) * ({{.*}}[{{.*}}] * {{.*}}) + {{.*}} : vector<4xf8E8M0FNU>, vector<64xf6E2M3FN>, vector<4xf8E8M0FNU>, vector<64xf6E2M3FN>, vector<8xf32>
+  %2 = amdgpu.scaled_wmma 16x16x128 (%scale_vec4[0] * %fp6_src) * (%scale_vec4[0] * %fp6_src) + %dst0 : vector<4xf8E8M0FNU>, vector<64xf6E2M3FN>, vector<4xf8E8M0FNU>, vector<64xf6E2M3FN>, vector<8xf32>
+  // CHECK: amdgpu.scaled_wmma 16x16x128 ({{.*}}[{{.*}}] * {{.*}}) * ({{.*}}[{{.*}}] * {{.*}}) + {{.*}} : vector<4xf8E4M3FN>, vector<64xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<64xf6E2M3FN>, vector<8xf32>
+  %3 = amdgpu.scaled_wmma 16x16x128 (%scale_vec4_e4m3[0] * %fp4_src_b) * (%scale_vec4[0] * %fp6_src) + %dst0 : vector<4xf8E4M3FN>, vector<64xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<64xf6E2M3FN>, vector<8xf32>
+  // CHECK: amdgpu.scaled_wmma 16x16x128 ({{.*}}[{{.*}}] * {{.*}}) * ({{.*}}[{{.*}}] * {{.*}}) + {{.*}} : vector<8xf8E8M0FNU>, vector<64xf8E4M3FN>, vector<8xf8E8M0FNU>, vector<64xf8E4M3FN>, vector<8xf32>
+  %4 = amdgpu.scaled_wmma 16x16x128 (%scale_vec8[0] * %fp8_src) * (%scale_vec8[0] * %fp8_src) + %dst0 : vector<8xf8E8M0FNU>, vector<64xf8E4M3FN>, vector<8xf8E8M0FNU>, vector<64xf8E4M3FN>, vector<8xf32>
+  // CHECK: amdgpu.scaled_wmma 32x16x128 ({{.*}}[{{.*}}] * {{.*}}) * ({{.*}}[{{.*}}] * {{.*}}) + {{.*}} : vector<4xf8E4M3FN>, vector<128xf4E2M1FN>, vector<4xf8E4M3FN>, vector<64xf4E2M1FN>, vector<16xf32>
+  %5 = amdgpu.scaled_wmma 32x16x128 (%scale_vec4_e4m3[0] * %fp4_src_a) * (%scale_vec4_e4m3[0] * %fp4_src_b) + %dst1 : vector<4xf8E4M3FN>, vector<128xf4E2M1FN>, vector<4xf8E4M3FN>, vector<64xf4E2M1FN>, vector<16xf32>
   func.return
 }



More information about the Mlir-commits mailing list