[Mlir-commits] [mlir] Define an amdgpu.scaled_mfma wrapper (PR #137498)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sat Apr 26 23:03:26 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-amdgpu

Author: Muzammil (Muzammiluddin-Syed-ECE)

<details>
<summary>Changes</summary>

Create a wrapper around the new scaled MFMAs that operate on specific element types and tile sizes.

See [Issue](https://github.com/iree-org/iree/issues/20616).

---
Full diff: https://github.com/llvm/llvm-project/pull/137498.diff


4 Files Affected:

- (modified) mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td (+48) 
- (modified) mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp (+60-4) 
- (modified) mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp (+33) 
- (modified) mlir/test/Conversion/AMDGPUToROCDL/mfma-gfx950.mlir (+47) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
index f14aa5a2e1564..d1c601882fc93 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
@@ -830,4 +830,52 @@ def AMDGPU_GatherToLDSOp :
   let hasVerifier = 1;
 }
 
+def AMDGPU_ScaledMFMAOp :
+    AMDGPU_Op<"scaled_mfma", [AllTypesMatch<["destC", "destD"]>,
+                        Pure]>,
+    Arguments<(ins
+                   I32Attr:$m,
+                   I32Attr:$n,
+                   I32Attr:$k,
+                   MFMAInTypes:$sourceA,
+                   MFMAInTypes:$sourceB,
+                   MFMAOutTypes:$destC,
+                   I32Attr:$scaleA,
+                   I32Attr:$scaleB,
+                   I32Attr:$opselA,
+                   I32Attr:$opselB)>,
+    Results<(outs MFMAOutTypes: $destD)> {
+  let summary = "MLIR wrapper for CDNA scaled mfma instructions";
+  let description = [{
+    The `amdgpu.scaled_mfma` op is an MLIR wrapper around the intrinsics for
+    the scaled `mfma` instructions of the CDNA architecture (gfx950+), which
+    perform multiple outer products to allow fast matrix multiplication.
+
+    The wrapper will select an appropriate scaled `mfma` instruction, if one is
+    available, based on the provided `m`, `n`, and `k` attributes, along with
+    the types of the source and destination arguments.
+
+    The fp8/fp6/fp4 source vectors are bitcast to the packed `i32` vectors the
+    intrinsics take (for example, `vector<32xf8E4M3FN>` becomes `vector<8xi32>`
+    and `vector<32xf4E2M1FN>` becomes `vector<4xi32>`). `scaleA` and `scaleB`
+    are passed to the intrinsic as the scale values for A and B, and `opselA`
+    and `opselB` as the byte selectors for those scales; both opsel values
+    must fit in 8 bits.
+
+    This wrapper takes inspiration from `amdgpu.mfma`, but has key differences:
+    - `amdgpu.scaled_mfma` operates on fp4 (f4E2M1FN), fp6 (f6E2M3FN and
+      f6E3M2FN) and fp8 (f8E4M3FN and f8E5M2) types using either M=N=16,
+      K=128 or M=N=32, K=64 as its tile size.
+    - `amdgpu.scaled_mfma` does not support broadcasting, so `cbsz`, `abid`,
+      and `blgp` are omitted from this wrapper.
+    - The negateA, negateB, and negateC flags in `amdgpu.mfma` are only
+      supported for double-precision operations on gfx94x and so are omitted.
+  }];
+  let assemblyFormat = [{
+    $sourceA `*` $sourceB `+` $destC
+    attr-dict
+    `:` type($sourceA) `,` type($sourceB) `,` type($destC)
+  }];
+  let hasVerifier = 1;
+}
 #endif // AMDGPU
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 91dbc2de65c4e..527c22ad34782 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -803,7 +803,6 @@ mfmaOpToScaledIntrinsic(Type aType, Type bType, Type destType, uint32_t m,
   aType = getElementTypeOrSelf(aType);
   bType = getElementTypeOrSelf(bType);
   destType = getElementTypeOrSelf(destType);
-
   if (chipset < kGfx950)
     return std::nullopt;
   if (!isa<Float32Type>(destType))
@@ -833,6 +832,14 @@ mfmaOpToScaledIntrinsic(MFMAOp mfma, Chipset chipset) {
       mfma.getBlocks(), chipset);
 }
 
+static std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
+mfmaOpToScaledIntrinsic(ScaledMFMAOp smfma, Chipset chipset) {
+  // Scaled MFMAs carry no `blocks` attribute, so one block is always requested.
+  Type a = smfma.getSourceA().getType(), b = smfma.getSourceB().getType();
+  return mfmaOpToScaledIntrinsic(a, b, smfma.getDestC().getType(), smfma.getM(),
+                                 smfma.getN(), smfma.getK(), 1u, chipset);
+}
+
 /// Return the `rocdl` intrinsic corresponding to a WMMA operation `wmma`
 /// if one exists. This includes checking to ensure the intrinsic is supported
 /// on the architecture you are compiling for.
@@ -954,6 +961,54 @@ struct MFMAOpLowering : public ConvertOpToLLVMPattern<MFMAOp> {
   }
 };
 
+/// Lowers `amdgpu.scaled_mfma` to the `rocdl.mfma.scale.*` intrinsic chosen
+/// by mfmaOpToScaledIntrinsic, passing the scales and their byte selectors
+/// (`opselA` / `opselB`) through as i32 constants.
+struct ScaledMFMAOpLowering : public ConvertOpToLLVMPattern<ScaledMFMAOp> {
+  ScaledMFMAOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
+      : ConvertOpToLLVMPattern<ScaledMFMAOp>(converter), chipset(chipset) {}
+
+  Chipset chipset;
+
+  LogicalResult
+  matchAndRewrite(ScaledMFMAOp op, ScaledMFMAOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // The scaled intrinsics only exist from gfx950 on (mfmaOpToScaledIntrinsic
+    // rejects older chipsets too), so diagnose that requirement explicitly.
+    if (chipset.majorVersion != 9 || chipset < kGfx950)
+      return op.emitOpError("scaled MFMA only supported on gfx950+");
+    std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
+        maybeScaledIntrinsic = mfmaOpToScaledIntrinsic(op, chipset);
+    if (!maybeScaledIntrinsic.has_value())
+      return op.emitOpError(
+          "no intrinsic matching Scaled MFMA size on given chipset");
+    auto [intrinsicName, aTypeCode, bTypeCode] = *maybeScaledIntrinsic;
+
+    // mfmaOpToScaledIntrinsic only matches f32 accumulators, so the converted
+    // result type is already the intrinsic's result type and no bf16 -> i16
+    // result retyping or bitcast is needed.
+    Location loc = op.getLoc();
+    Type outType = typeConverter->convertType(op.getDestD().getType());
+    OperationState loweredOp(loc, intrinsicName);
+    loweredOp.addTypes(outType);
+    loweredOp.addOperands(
+        {convertMFMAVectorOperand(rewriter, loc, adaptor.getSourceA()),
+         convertMFMAVectorOperand(rewriter, loc, adaptor.getSourceB()),
+         adaptor.getDestC()});
+    Value scaleA = createI32Constant(rewriter, loc, adaptor.getScaleA());
+    Value scaleB = createI32Constant(rewriter, loc, adaptor.getScaleB());
+    Value opselA = createI32Constant(rewriter, loc, adaptor.getOpselA());
+    Value opselB = createI32Constant(rewriter, loc, adaptor.getOpselB());
+    // Operand order: A/B type codes, then (byte selector, scale) for A and B.
+    loweredOp.addOperands({createI32Constant(rewriter, loc, aTypeCode),
+                           createI32Constant(rewriter, loc, bTypeCode),
+                           /*scale A byte=*/opselA, /*scale A=*/scaleA,
+                           /*scale B byte=*/opselB, /*scale B=*/scaleB});
+    rewriter.replaceOp(op, rewriter.create(loweredOp)->getResult(0));
+    return success();
+  }
+};
+
 struct WMMAOpLowering : public ConvertOpToLLVMPattern<WMMAOp> {
   WMMAOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
       : ConvertOpToLLVMPattern<WMMAOp>(converter), chipset(chipset) {}
@@ -1474,8 +1529,9 @@ void mlir::populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter,
            RawBufferOpLowering<RawBufferAtomicCmpswapOp,
                                ROCDL::RawPtrBufferAtomicCmpSwap>,
            AMDGPUDPPLowering, LDSBarrierOpLowering, SchedBarrierOpLowering,
-           MFMAOpLowering, WMMAOpLowering, ExtPackedFp8OpLowering,
-           PackedTrunc2xFp8OpLowering, PackedStochRoundFp8OpLowering,
-           GatherToLDSOpLowering>(converter, chipset);
+           MFMAOpLowering, ScaledMFMAOpLowering, WMMAOpLowering,
+           ExtPackedFp8OpLowering, PackedTrunc2xFp8OpLowering,
+           PackedStochRoundFp8OpLowering, GatherToLDSOpLowering>(converter,
+                                                                 chipset);
   patterns.add<AMDGPUSwizzleBitModeLowering>(converter);
 }
diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
index 549a4376a4a04..90131a66448fc 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -506,6 +506,39 @@ LogicalResult GatherToLDSOp::verify() {
   return success();
 }
 
+/// Verifies a scaled MFMA:
+///  - `opselA` / `opselB` are lowered as the intrinsic's scale byte operands
+///    and must be zero-extended 8-bit values;
+///  - both sources must hold fp8 (E4M3FN, E5M2), fp6 (E2M3FN, E3M2FN) or fp4
+///    (E2M1FN) elements; a non-vector source is checked as-is.
+LogicalResult ScaledMFMAOp::verify() {
+  // Any bit set above the low byte makes a selector invalid.
+  if (getOpselA() > 0xFF)
+    return emitOpError("Opsel A must be a zero extended 8 bit value.");
+
+  if (getOpselB() > 0xFF)
+    return emitOpError("Opsel B must be a zero extended 8 bit value.");
+
+  // Sources are normally vectors (e.g. vector<32xf8E4M3FN>), so the accepted
+  // types must be matched against the element type; matching the operand type
+  // itself would reject every vector operand.
+  auto hasValidElementType = [](Type sourceType) {
+    Type elemType = sourceType;
+    if (auto vecType = dyn_cast<VectorType>(sourceType))
+      elemType = vecType.getElementType();
+    return isa<Float8E4M3FNType, Float8E5M2Type, Float6E2M3FNType,
+               Float6E3M2FNType, Float4E2M1FNType>(elemType);
+  };
+
+  if (!hasValidElementType(getSourceA().getType()))
+    return emitOpError("Source A must be of element type fp4, fp6 or fp8.");
+
+  if (!hasValidElementType(getSourceB().getType()))
+    return emitOpError("Source B must be of element type fp4, fp6 or fp8.");
+
+  return success();
+}
+
 #include "mlir/Dialect/AMDGPU/IR/AMDGPUEnums.cpp.inc"
 
 #define GET_ATTRDEF_CLASSES
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/mfma-gfx950.mlir b/mlir/test/Conversion/AMDGPUToROCDL/mfma-gfx950.mlir
index de63f249bb530..f525a37e5ec80 100644
--- a/mlir/test/Conversion/AMDGPUToROCDL/mfma-gfx950.mlir
+++ b/mlir/test/Conversion/AMDGPUToROCDL/mfma-gfx950.mlir
@@ -51,3 +51,50 @@ func.func @mfma_to_rocdl(%arg0 : vector<8xf16>, %arg1 : vector<16xf32>,
 
   func.return
 }
+
+// CHECK-LABEL: func @scaled_mfma_to_rocdl(
+func.func @scaled_mfma_to_rocdl(%arg0 : vector<16xf32>,
+                    %arg1 : vector<4xf32>, %arg2 : vector<32xf8E4M3FN>,
+                    %arg3 : vector<32xf8E5M2>, %arg4 : vector<32xf6E2M3FN>,
+                    %arg5 : vector<32xf6E3M2FN>, %arg6 : vector<32xf4E2M1FN>) {
+  // Distinct opselA/scaleA/opselB/scaleB values (0/2/1/3) expose operand swaps.
+  // CHECK: %[[c2:.+]] = llvm.mlir.constant(2 : i32) : i32
+  // CHECK: %[[c3:.+]] = llvm.mlir.constant(3 : i32) : i32
+  // CHECK: %[[c0:.+]] = llvm.mlir.constant(0 : i32) : i32
+  // CHECK: %[[c1:.+]] = llvm.mlir.constant(1 : i32) : i32
+
+  // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4{{.*}}, %[[c0]], %[[c0]], %[[c0]], %[[c2]], %[[c1]], %[[c3]] : (vector<8xi32>, vector<8xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
+  amdgpu.scaled_mfma %arg2 * %arg2 + %arg0 { k = 64 : i32, m = 32 : i32, n = 32 : i32, scaleA = 2 : i32, opselA = 0 : i32, scaleB = 3 : i32, opselB = 1 : i32 } : vector<32xf8E4M3FN>, vector<32xf8E4M3FN>, vector<16xf32>
+  // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4{{.*}}, %[[c0]], %[[c0]], %[[c0]], %[[c2]], %[[c1]], %[[c3]] : (vector<8xi32>, vector<8xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
+  amdgpu.scaled_mfma %arg2 * %arg2 + %arg1 { k = 128 : i32, m = 16 : i32, n = 16 : i32, scaleA = 2 : i32, opselA = 0 : i32, scaleB = 3 : i32, opselB = 1 : i32 } : vector<32xf8E4M3FN>, vector<32xf8E4M3FN>, vector<4xf32>
+
+  // CHECK: llvm.bitcast
+
+  // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4{{.*}}, %[[c1]], %[[c1]], %[[c0]], %[[c2]], %[[c1]], %[[c3]] : (vector<8xi32>, vector<8xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
+  amdgpu.scaled_mfma %arg3 * %arg3 + %arg0 { k = 64 : i32, m = 32 : i32, n = 32 : i32, scaleA = 2 : i32, opselA = 0 : i32, scaleB = 3 : i32, opselB = 1 : i32 } : vector<32xf8E5M2>, vector<32xf8E5M2>, vector<16xf32>
+  // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4{{.*}}, %[[c1]], %[[c1]], %[[c0]], %[[c2]], %[[c1]], %[[c3]] : (vector<8xi32>, vector<8xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
+  amdgpu.scaled_mfma %arg3 * %arg3 + %arg1 { k = 128 : i32, m = 16 : i32, n = 16 : i32, scaleA = 2 : i32, opselA = 0 : i32, scaleB = 3 : i32, opselB = 1 : i32 } : vector<32xf8E5M2>, vector<32xf8E5M2>, vector<4xf32>
+
+  // CHECK: llvm.bitcast
+
+  // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4{{.*}}, %[[c2]], %[[c2]], %[[c0]], %[[c2]], %[[c1]], %[[c3]] : (vector<6xi32>, vector<6xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
+  amdgpu.scaled_mfma %arg4 * %arg4 + %arg0 { k = 64 : i32, m = 32 : i32, n = 32 : i32, scaleA = 2 : i32, opselA = 0 : i32, scaleB = 3 : i32, opselB = 1 : i32 } : vector<32xf6E2M3FN>, vector<32xf6E2M3FN>, vector<16xf32>
+  // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4{{.*}}, %[[c2]], %[[c2]], %[[c0]], %[[c2]], %[[c1]], %[[c3]] : (vector<6xi32>, vector<6xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
+  amdgpu.scaled_mfma %arg4 * %arg4 + %arg1 { k = 128 : i32, m = 16 : i32, n = 16 : i32, scaleA = 2 : i32, opselA = 0 : i32, scaleB = 3 : i32, opselB = 1 : i32 } : vector<32xf6E2M3FN>, vector<32xf6E2M3FN>, vector<4xf32>
+
+  // CHECK: llvm.bitcast
+
+  // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4{{.*}}, %[[c3]], %[[c3]], %[[c0]], %[[c2]], %[[c1]], %[[c3]] : (vector<6xi32>, vector<6xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
+  amdgpu.scaled_mfma %arg5 * %arg5 + %arg0 { k = 64 : i32, m = 32 : i32, n = 32 : i32, scaleA = 2 : i32, opselA = 0 : i32, scaleB = 3 : i32, opselB = 1 : i32 } : vector<32xf6E3M2FN>, vector<32xf6E3M2FN>, vector<16xf32>
+  // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4{{.*}}, %[[c3]], %[[c3]], %[[c0]], %[[c2]], %[[c1]], %[[c3]] : (vector<6xi32>, vector<6xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
+  amdgpu.scaled_mfma %arg5 * %arg5 + %arg1 { k = 128 : i32, m = 16 : i32, n = 16 : i32, scaleA = 2 : i32, opselA = 0 : i32, scaleB = 3 : i32, opselB = 1 : i32 } : vector<32xf6E3M2FN>, vector<32xf6E3M2FN>, vector<4xf32>
+
+  // CHECK: llvm.bitcast
+  // CHECK: %[[c4:.+]] = llvm.mlir.constant(4 : i32) : i32
+  // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4{{.*}}, %[[c4]], %[[c4]], %[[c0]], %[[c2]], %[[c1]], %[[c3]] : (vector<4xi32>, vector<4xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
+  amdgpu.scaled_mfma %arg6 * %arg6 + %arg0 { k = 64 : i32, m = 32 : i32, n = 32 : i32, scaleA = 2 : i32, opselA = 0 : i32, scaleB = 3 : i32, opselB = 1 : i32 } : vector<32xf4E2M1FN>, vector<32xf4E2M1FN>, vector<16xf32>
+  // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4{{.*}}, %[[c4]], %[[c4]], %[[c0]], %[[c2]], %[[c1]], %[[c3]] : (vector<4xi32>, vector<4xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
+  amdgpu.scaled_mfma %arg6 * %arg6 + %arg1 { k = 128 : i32, m = 16 : i32, n = 16 : i32, scaleA = 2 : i32, opselA = 0 : i32, scaleB = 3 : i32, opselB = 1 : i32 } : vector<32xf4E2M1FN>, vector<32xf4E2M1FN>, vector<4xf32>
+
+  func.return
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/137498


More information about the Mlir-commits mailing list