[Mlir-commits] [mlir] [mlir][amdgpu] Define an amdgpu.scaling_mfma wrapper (PR #137498)

Krzysztof Drewniak llvmlistbot at llvm.org
Tue Apr 29 23:52:31 PDT 2025


================
@@ -954,6 +964,50 @@ struct MFMAOpLowering : public ConvertOpToLLVMPattern<MFMAOp> {
   }
 };
 
+struct ScaledMFMAOpLowering : public ConvertOpToLLVMPattern<ScaledMFMAOp> {
+  ScaledMFMAOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
+      : ConvertOpToLLVMPattern(converter), chipset(chipset) {}
+
+  Chipset chipset;
+
+  LogicalResult
+  matchAndRewrite(ScaledMFMAOp op, ScaledMFMAOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    Type intrinsicOutType = typeConverter->convertType(op.getDestD().getType());
+
+    if (chipset.majorVersion != 9 || chipset < kGfx950)
+      return op->emitOpError("scaled MFMA only supported on gfx908+");
+    std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
+        maybeScaledIntrinsic = mfmaOpToScaledIntrinsic(op, chipset);
+    if (!maybeScaledIntrinsic.has_value())
+      return op.emitOpError(
+          "no intrinsic matching scaled MFMA size on given chipset");
+
+    auto [intrinsicName, aTypeCode, bTypeCode] = *maybeScaledIntrinsic;
+    OperationState loweredOp(loc, intrinsicName);
+    loweredOp.addTypes(intrinsicOutType);
+    loweredOp.addOperands(
+        {convertMFMAVectorOperand(rewriter, loc, adaptor.getSourceA()),
+         convertMFMAVectorOperand(rewriter, loc, adaptor.getSourceB()),
+         adaptor.getDestC()});
+    Value scalesIdxA = createI32Constant(rewriter, loc, adaptor.getScalesIdxA());
+    Value scalesIdxB = createI32Constant(rewriter, loc, adaptor.getScalesIdxB());
+    loweredOp.addOperands(
+        {createI32Constant(rewriter, loc, aTypeCode),
+         createI32Constant(rewriter, loc, bTypeCode),
+         /*scales A*/
----------------
krzysz00 wrote:

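I think these operands are in the wrong order: per IntrinsicsAMDGPU.td, each matrix's op_sel (scale index) comes before its scale value, i.e. the trailing operands are op_sel A, scale A, op_sel B, scale B: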
```
           llvm_i32_ty, // op_sel (A matrix scale, 2-bits) // TODO: Make i2?
             llvm_i32_ty, // v_mfma_ld_scale_b32 src0 (A matrix scale)
             llvm_i32_ty, // op_sel (B matrix scale, 2-bits) // TODO: Make i2?
             llvm_i32_ty  // v_mfma_ld_scale_b32 src1 (B matrix scale)
            ],
```

https://github.com/llvm/llvm-project/pull/137498


More information about the Mlir-commits mailing list