[Mlir-commits] [mlir] [mlir][AMDGPU] Add scaled wmma ops for gfx1250 (PR #169854)

Jakub Kuderski llvmlistbot at llvm.org
Sun Nov 30 06:26:25 PST 2025


================
@@ -1363,6 +1373,136 @@ struct WMMAOpLowering : public ConvertOpToLLVMPattern<WMMAOp> {
   }
 };
 
+struct ScaledWMMAOpLowering : public ConvertOpToLLVMPattern<ScaledWMMAOp> {
+  ScaledWMMAOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
+      : ConvertOpToLLVMPattern<ScaledWMMAOp>(converter), chipset(chipset) {}
+
+  Chipset chipset;
+
+  LogicalResult
+  matchAndRewrite(ScaledWMMAOp op, ScaledWMMAOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    auto outType =
+        typeConverter->convertType<VectorType>(op.getDestD().getType());
+    if (!outType)
+      return rewriter.notifyMatchFailure(op, "type conversion failed");
+
+    if (chipset < Chipset(12, 5, 0))
+      return op->emitOpError("WMMA scale only supported on gfx1250+");
+
+    int64_t m = op.getM();
+    int64_t n = op.getN();
+    int64_t k = op.getK();
+
+    Type aElemType = getElementTypeOrSelf(op.getSourceA().getType());
+    Type bElemType = getElementTypeOrSelf(op.getSourceB().getType());
+
+    std::optional<uint32_t> aFmtCode = smallFloatTypeToFormatCode(aElemType);
+    std::optional<uint32_t> bFmtCode = smallFloatTypeToFormatCode(bElemType);
+
+    if (!aFmtCode || !bFmtCode)
+      return op.emitOpError("unsupported element types for scaled_wmma");
+
+    // Get scale vector types and determine variant (scale vs scale16)
+    auto scaleAVecType = cast<VectorType>(op.getScaleA().getType());
+    auto scaleBVecType = cast<VectorType>(op.getScaleB().getType());
+
+    bool isScale16 = (scaleAVecType.getNumElements() == 8);
+    if (isScale16 != (scaleBVecType.getNumElements() == 8))
+      return op.emitOpError("scaleA and scaleB must have equal vector length");
+
+    // Extract scale format from element types
+    Type scaleAElemType = scaleAVecType.getElementType();
+    Type scaleBElemType = scaleBVecType.getElementType();
+
+    // Map f8 types to format codes
+    auto getScaleFormat = [](Type elemType) -> std::optional<uint32_t> {
+      if (isa<Float8E8M0FNUType>(elemType))
+        return 0;
+      if (isa<Float8E4M3FNType>(elemType))
+        return 2;
+      return std::nullopt;
+    };
+
+    std::optional<uint32_t> scaleAFmt = getScaleFormat(scaleAElemType);
+    std::optional<uint32_t> scaleBFmt = getScaleFormat(scaleBElemType);
+
+    if (!scaleAFmt || !scaleBFmt)
+      return op.emitOpError("unsupported scale element types");
+
+    // Determine which intrinsic to use based on dimensions
+    StringRef intrinsicName;
+    bool is32x16 = (m == 32 && n == 16 && k == 128);
+
+    if (m == 16 && n == 16 && k == 128) {
+      intrinsicName =
+          isScale16
+              ? ROCDL::wmma_scale16_f32_16x16x128_f8f6f4::getOperationName()
+              : ROCDL::wmma_scale_f32_16x16x128_f8f6f4::getOperationName();
+    } else if (is32x16) {
+      intrinsicName =
+          isScale16 ? ROCDL::wmma_scale16_f32_32x16x128_f4::getOperationName()
+                    : ROCDL::wmma_scale_f32_32x16x128_f4::getOperationName();
+    } else {
+      return op.emitOpError("unsupported scaled_wmma dimensions: ")
+             << m << "x" << n << "x" << k;
+    }
----------------
kuhar wrote:

Can you outline this to a helper function?

https://github.com/llvm/llvm-project/pull/169854


More information about the Mlir-commits mailing list