[Mlir-commits] [mlir] [mlir][AMDGPU] Add scaled wmma ops for gfx1250 (PR #169854)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Nov 27 11:40:52 PST 2025


github-actions[bot] wrote:



:warning: The C/C++ code formatter, clang-format, found issues in your code. :warning:

<details>
<summary>
You can test this locally with the following command:
</summary>

``````````bash
git-clang-format --diff origin/main HEAD --extensions cpp -- mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp --diff_from_common_commit
``````````

:warning:
The reproduction command above might return results for more than one PR
if you are using a stacked PR workflow. You can limit the results by
changing `origin/main` to the base branch or commit you want to compare against, as in the example below.
:warning:
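
For example, if your stack is based on a branch other than `main`, pass that branch as the base instead. The branch name `users/me/base-branch` below is only a placeholder; substitute your actual base:

``````````bash
# Hypothetical example: compare against the base branch of your stack
# instead of origin/main (replace users/me/base-branch with your real base).
git-clang-format --diff users/me/base-branch HEAD --extensions cpp -- mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp --diff_from_common_commit
``````````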

</details>

<details>
<summary>
View the diff from clang-format here.
</summary>

``````````diff
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index f4034f44d..731e33c82 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -1212,12 +1212,11 @@ struct MFMAOpLowering : public ConvertOpToLLVMPattern<MFMAOp> {
     }();
     OperationState loweredOp(loc, intrinsicName);
     loweredOp.addTypes(intrinsicOutType);
-    loweredOp.addOperands(
-        {packSmallFloatVectorOperand(rewriter, loc, adaptor.getSourceA(),
-                                      allowBf16),
-         packSmallFloatVectorOperand(rewriter, loc, adaptor.getSourceB(),
-                                      allowBf16),
-         adaptor.getDestC()});
+    loweredOp.addOperands({packSmallFloatVectorOperand(
+                               rewriter, loc, adaptor.getSourceA(), allowBf16),
+                           packSmallFloatVectorOperand(
+                               rewriter, loc, adaptor.getSourceB(), allowBf16),
+                           adaptor.getDestC()});
     if (isScaled) {
       Value zero = createI32Constant(rewriter, loc, 0);
       auto [_scaledName, aTypeCode, bTypeCode] = *maybeScaledIntrinsic;
@@ -1401,13 +1400,14 @@ struct ScaledWMMAOpLowering : public ConvertOpToLLVMPattern<ScaledWMMAOp> {
     bool is32x16 = (m == 32 && n == 16 && k == 128);
 
     if (m == 16 && n == 16 && k == 128) {
-      intrinsicName = isScale16
-                ? ROCDL::wmma_scale16_f32_16x16x128_f8f6f4::getOperationName()
-                : ROCDL::wmma_scale_f32_16x16x128_f8f6f4::getOperationName();
+      intrinsicName =
+          isScale16
+              ? ROCDL::wmma_scale16_f32_16x16x128_f8f6f4::getOperationName()
+              : ROCDL::wmma_scale_f32_16x16x128_f8f6f4::getOperationName();
     } else if (is32x16) {
-      intrinsicName = isScale16
-                ? ROCDL::wmma_scale16_f32_32x16x128_f4::getOperationName()
-                : ROCDL::wmma_scale_f32_32x16x128_f4::getOperationName();
+      intrinsicName =
+          isScale16 ? ROCDL::wmma_scale16_f32_32x16x128_f4::getOperationName()
+                    : ROCDL::wmma_scale_f32_32x16x128_f4::getOperationName();
     } else {
       return op.emitOpError("unsupported scaled_wmma dimensions: ")
              << m << "x" << n << "x" << k;
@@ -1417,29 +1417,29 @@ struct ScaledWMMAOpLowering : public ConvertOpToLLVMPattern<ScaledWMMAOp> {
 
     // The f4 variant does not have fmtA and fmtB attributes
     if (!is32x16) {
-      attrs.push_back(rewriter.getNamedAttr("fmtA",
-                              rewriter.getI32IntegerAttr(*aFmtCode)));
-      attrs.push_back(rewriter.getNamedAttr("fmtB",
-                              rewriter.getI32IntegerAttr(*bFmtCode)));
+      attrs.push_back(
+          rewriter.getNamedAttr("fmtA", rewriter.getI32IntegerAttr(*aFmtCode)));
+      attrs.push_back(
+          rewriter.getNamedAttr("fmtB", rewriter.getI32IntegerAttr(*bFmtCode)));
     }
 
     // Add modifier attributes - modC and reuse flags default to 0/false
-    attrs.push_back(rewriter.getNamedAttr("reuseA",
-                              rewriter.getBoolAttr(false)));
-    attrs.push_back(rewriter.getNamedAttr("reuseB",
-                              rewriter.getBoolAttr(false)));
-    attrs.push_back(rewriter.getNamedAttr("modC",
-                              rewriter.getI16IntegerAttr(0)));
+    attrs.push_back(
+        rewriter.getNamedAttr("reuseA", rewriter.getBoolAttr(false)));
+    attrs.push_back(
+        rewriter.getNamedAttr("reuseB", rewriter.getBoolAttr(false)));
+    attrs.push_back(
+        rewriter.getNamedAttr("modC", rewriter.getI16IntegerAttr(0)));
 
     // Scale type/format parameters from the operation
-    attrs.push_back(rewriter.getNamedAttr("scaleAType",
-                              rewriter.getI32IntegerAttr(op.getScaleAType())));
-    attrs.push_back(rewriter.getNamedAttr("fmtScaleA",
-                              rewriter.getI32IntegerAttr(op.getFmtScaleA())));
-    attrs.push_back(rewriter.getNamedAttr("scaleBType",
-                              rewriter.getI32IntegerAttr(op.getScaleBType())));
-    attrs.push_back(rewriter.getNamedAttr("fmtScaleB",
-                              rewriter.getI32IntegerAttr(op.getFmtScaleB())));
+    attrs.push_back(rewriter.getNamedAttr(
+        "scaleAType", rewriter.getI32IntegerAttr(op.getScaleAType())));
+    attrs.push_back(rewriter.getNamedAttr(
+        "fmtScaleA", rewriter.getI32IntegerAttr(op.getFmtScaleA())));
+    attrs.push_back(rewriter.getNamedAttr(
+        "scaleBType", rewriter.getI32IntegerAttr(op.getScaleBType())));
+    attrs.push_back(rewriter.getNamedAttr(
+        "fmtScaleB", rewriter.getI32IntegerAttr(op.getFmtScaleB())));
 
     // Convert typed float vectors to packed i32 format if needed
     Value sourceA =
@@ -2428,10 +2428,9 @@ void mlir::populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter,
            AMDGPUDPPLowering, MemoryCounterWaitOpLowering, LDSBarrierOpLowering,
            SchedBarrierOpLowering, MFMAOpLowering, ScaledMFMAOpLowering,
            WMMAOpLowering, ScaledWMMAOpLowering, ExtPackedFp8OpLowering,
-           ScaledExtPacked816OpLowering,
-           ScaledExtPackedOpLowering, PackedScaledTruncOpLowering,
-           PackedTrunc2xFp8OpLowering, PackedStochRoundFp8OpLowering,
-           GatherToLDSOpLowering, TransposeLoadOpLowering,
-           AMDGPUPermlaneLowering>(converter, chipset);
+           ScaledExtPacked816OpLowering, ScaledExtPackedOpLowering,
+           PackedScaledTruncOpLowering, PackedTrunc2xFp8OpLowering,
+           PackedStochRoundFp8OpLowering, GatherToLDSOpLowering,
+           TransposeLoadOpLowering, AMDGPUPermlaneLowering>(converter, chipset);
   patterns.add<AMDGPUSwizzleBitModeLowering>(converter);
 }
diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
index 87bd19032..fceded5c2 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -450,7 +450,7 @@ LogicalResult ScaledWMMAOp::verify() {
   auto sourceAType = cast<VectorType>(getSourceA().getType());
   auto sourceBType = cast<VectorType>(getSourceB().getType());
   auto destType = cast<VectorType>(getDestC().getType());
-  
+
   // Validate output type is F32
   if (!destType.getElementType().isF32())
     return emitOpError("destination must have f32 element type");
@@ -459,10 +459,10 @@ LogicalResult ScaledWMMAOp::verify() {
   Type aElemType = sourceAType.getElementType();
   Type bElemType = sourceBType.getElementType();
 
-  bool aIsSmallFloat = aElemType.isFloat(4) || aElemType.isFloat(6) ||
-                       aElemType.isFloat(8);
-  bool bIsSmallFloat = bElemType.isFloat(4) || bElemType.isFloat(6) ||
-                       bElemType.isFloat(8);
+  bool aIsSmallFloat =
+      aElemType.isFloat(4) || aElemType.isFloat(6) || aElemType.isFloat(8);
+  bool bIsSmallFloat =
+      bElemType.isFloat(4) || bElemType.isFloat(6) || bElemType.isFloat(8);
 
   if (!aIsSmallFloat || !bIsSmallFloat)
     return emitOpError("source operands must have small float element types "
@@ -479,7 +479,7 @@ LogicalResult ScaledWMMAOp::verify() {
   int64_t aLen = sourceAType.getNumElements();
   int64_t bLen = sourceBType.getNumElements();
   int64_t expectedOutLen = (m == 16) ? 4 : 8;
-  
+
   if (destType.getNumElements() != expectedOutLen)
     return emitOpError("expected output vector of length " +
                        Twine(expectedOutLen) + " but got " +

``````````

</details>
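To apply the suggested formatting locally rather than only viewing the diff, you can run `git-clang-format` without `--diff`, which reformats the changed lines of the listed files in place. This is a sketch based on the reproduction command above; run it from the repository root with a clean working tree and adjust the paths to your checkout:

``````````bash
# Reformat the lines of the affected files that changed relative to
# origin/main, writing the result back to the working tree.
git-clang-format origin/main --extensions cpp -- mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp

# Review the formatting changes before committing them.
git diff
``````````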


https://github.com/llvm/llvm-project/pull/169854

