[Mlir-commits] [mlir] [mlir][AMDGPU] Add scaled wmma ops for gfx1250 (PR #169854)

Jakub Kuderski llvmlistbot at llvm.org
Mon Dec 1 06:26:55 PST 2025


================
@@ -1404,93 +1427,74 @@ struct ScaledWMMAOpLowering : public ConvertOpToLLVMPattern<ScaledWMMAOp> {
     if (!aFmtCode || !bFmtCode)
       return op.emitOpError("unsupported element types for scaled_wmma");
 
-    // Get scale vector types and determine variant (scale vs scale16)
+    // Get scale vector types and determine variant (scale vs scale16).
     auto scaleAVecType = cast<VectorType>(op.getScaleA().getType());
     auto scaleBVecType = cast<VectorType>(op.getScaleB().getType());
 
-    bool isScale16 = (scaleAVecType.getNumElements() == 8);
-    if (isScale16 != (scaleBVecType.getNumElements() == 8))
+    if (scaleAVecType.getNumElements() != scaleBVecType.getNumElements())
       return op.emitOpError("scaleA and scaleB must have equal vector length");
 
-    // Extract scale format from element types
+    // Extract scale format from element types.
     Type scaleAElemType = scaleAVecType.getElementType();
     Type scaleBElemType = scaleBVecType.getElementType();
 
-    // Map f8 types to format codes
-    auto getScaleFormat = [](Type elemType) -> std::optional<uint32_t> {
-      if (isa<Float8E8M0FNUType>(elemType))
-        return 0;
-      if (isa<Float8E4M3FNType>(elemType))
-        return 2;
-      return std::nullopt;
-    };
-
-    std::optional<uint32_t> scaleAFmt = getScaleFormat(scaleAElemType);
-    std::optional<uint32_t> scaleBFmt = getScaleFormat(scaleBElemType);
+    std::optional<uint32_t> scaleAFmt = getWmmaScaleFormat(scaleAElemType);
+    std::optional<uint32_t> scaleBFmt = getWmmaScaleFormat(scaleBElemType);
 
     if (!scaleAFmt || !scaleBFmt)
       return op.emitOpError("unsupported scale element types");
 
-    // Determine which intrinsic to use based on dimensions
-    StringRef intrinsicName;
-    bool is32x16 = (m == 32 && n == 16 && k == 128);
-
-    if (m == 16 && n == 16 && k == 128) {
-      intrinsicName =
-          isScale16
-              ? ROCDL::wmma_scale16_f32_16x16x128_f8f6f4::getOperationName()
-              : ROCDL::wmma_scale_f32_16x16x128_f8f6f4::getOperationName();
-    } else if (is32x16) {
-      intrinsicName =
-          isScale16 ? ROCDL::wmma_scale16_f32_32x16x128_f4::getOperationName()
-                    : ROCDL::wmma_scale_f32_32x16x128_f4::getOperationName();
-    } else {
+    // Determine which intrinsic to use based on dimensions.
+    bool isScale16 = (scaleAVecType.getNumElements() == 8);
+    std::optional<StringRef> intrinsicName =
+        getScaledWmmaIntrinsicName(m, n, k, isScale16);
+    if (!intrinsicName)
       return op.emitOpError("unsupported scaled_wmma dimensions: ")
              << m << "x" << n << "x" << k;
-    }
 
     SmallVector<NamedAttribute, 8> attrs;
 
-    // The f4 variant does not have fmtA and fmtB attributes
+    // The f4 variant does not have fmtA and fmtB attributes.
+    bool is32x16 = (m == 32 && n == 16 && k == 128);
     if (!is32x16) {
-      attrs.push_back(
+      attrs.emplace_back(
           rewriter.getNamedAttr("fmtA", rewriter.getI32IntegerAttr(*aFmtCode)));
-      attrs.push_back(
+      attrs.emplace_back(
           rewriter.getNamedAttr("fmtB", rewriter.getI32IntegerAttr(*bFmtCode)));
----------------
kuhar wrote:

You don't need to call `rewriter.getNamedAttr` in any of these `attrs.emplace_back(...)` calls.

https://github.com/llvm/llvm-project/pull/169854


More information about the Mlir-commits mailing list