[Mlir-commits] [mlir] Define an amdgpu.scaled_mfma wrapper (PR #137498)
llvmlistbot at llvm.org
Sat Apr 26 23:02:54 PDT 2025
https://github.com/Muzammiluddin-Syed-ECE created https://github.com/llvm/llvm-project/pull/137498
Create a wrapper around the new scaled MFMAs that operate on specific element types and tile sizes.
See [Issue](https://github.com/iree-org/iree/issues/20616).
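For reference, a minimal usage sketch, mirroring the lowering tests added in this patch (the operand names are illustrative): a 32x32x64 tile of `f8E4M3FN` inputs accumulates into a `vector<16xf32>` and, on gfx950, lowers to the `rocdl.mfma.scale.f32.32x32x64.f8f6f4` intrinsic with the element-type codes and the scale/opsel values passed as `i32` operands.

```mlir
// Scaled MFMA on fp8 inputs; %a and %b hold the A/B tiles, %c the accumulator.
amdgpu.scaled_mfma %a * %b + %c
  { m = 32 : i32, n = 32 : i32, k = 64 : i32,
    scaleA = 1 : i32, opselA = 1 : i32, scaleB = 2 : i32, opselB = 2 : i32 }
  : vector<32xf8E4M3FN>, vector<32xf8E4M3FN>, vector<16xf32>
```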
From 03a7547735ddbaa04f11b80e564569ac9cb307de Mon Sep 17 00:00:00 2001
From: Muzammiluddin Syed <muzasyed at amd.com>
Date: Fri, 25 Apr 2025 14:02:10 -0500
Subject: [PATCH] Defining amdgpu.scaled_mfma op
Signed-off-by: Muzammiluddin Syed <muzasyed at amd.com>
---
mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td | 48 ++++++++++++++
.../AMDGPUToROCDL/AMDGPUToROCDL.cpp | 64 +++++++++++++++++--
mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp | 33 ++++++++++
.../Conversion/AMDGPUToROCDL/mfma-gfx950.mlir | 47 ++++++++++++++
4 files changed, 188 insertions(+), 4 deletions(-)
diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
index f14aa5a2e1564..d1c601882fc93 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
@@ -830,4 +830,52 @@ def AMDGPU_GatherToLDSOp :
let hasVerifier = 1;
}
+def AMDGPU_ScaledMFMAOp :
+ AMDGPU_Op<"scaled_mfma", [AllTypesMatch<["destC", "destD"]>,
+ Pure]>,
+ Arguments<(ins
+ I32Attr:$m,
+ I32Attr:$n,
+ I32Attr:$k,
+ MFMAInTypes:$sourceA,
+ MFMAInTypes:$sourceB,
+ MFMAOutTypes:$destC,
+ I32Attr:$scaleA,
+ I32Attr:$scaleB,
+ I32Attr:$opselA,
+ I32Attr:$opselB)>,
+ Results<(outs MFMAOutTypes: $destD)> {
+ let summary = "MLIR wrapper for CDNA scaled mfma instructions";
+ let description = [{
+ The `amdgpu.scaled_mfma` op is an MLIR wrapper around intrinsics
+ for various scaled versions of `mfma` instructions in the CDNA architecture, which perform
+ multiple outer products in order to allow fast matrix multiplication.
+
+ The wrapper will select an appropriate `mfma` instruction, if one is available,
+ based on the provided `m`, `n`, and `k` attributes, along with the
+ types of the source and destination arguments.
+
+ Note that this wrapper allows specifying `vector<4Kxi8>` arguments to MFMA
+ intrinsics that take an integer type of width `4K`. For example,
+ one can provide a `vector<4xi8>` as an argument to an MFMA instruction that
+ logically takes 4 i8s but whose intrinsics are specified to take an i32.
+ In these cases, the bytes in the vector will be concatenated in little-endian
+ order (that is, v[0] will go to arg[7:0], v[1] to arg[15:8] and so on).
+
+ This wrapper takes inspiration from `amdgpu.mfma`, but has some key differences:
+ - `amdgpu.scaled_mfma` operates on fp4 (f4E2M1FN), fp6 (f6E2M3FN and f6E3M2FN) and
+ fp8 (f8E4M3FN and f8E5M2) types using either M=N=16, K=128 or M=N=32, K=64 as their tile
+ size.
+ - `amdgpu.scaled_mfma` does not support broadcasting. So, `cbsz`, `abid`, and `blgp`
+ are omitted from this wrapper.
+ - The negateA, negateB, and negateC flags in `amdgpu.mfma` are only supported for
+ double-precision operations on gfx94x and so are not included here.
+ }];
+ let assemblyFormat = [{
+ $sourceA `*` $sourceB `+` $destC
+ attr-dict
+ `:` type($sourceA) `,` type($sourceB) `,` type($destC)
+ }];
+ let hasVerifier = 1;
+}
#endif // AMDGPU
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 91dbc2de65c4e..527c22ad34782 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -803,7 +803,6 @@ mfmaOpToScaledIntrinsic(Type aType, Type bType, Type destType, uint32_t m,
aType = getElementTypeOrSelf(aType);
bType = getElementTypeOrSelf(bType);
destType = getElementTypeOrSelf(destType);
-
if (chipset < kGfx950)
return std::nullopt;
if (!isa<Float32Type>(destType))
@@ -833,6 +832,14 @@ mfmaOpToScaledIntrinsic(MFMAOp mfma, Chipset chipset) {
mfma.getBlocks(), chipset);
}
+static std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
+mfmaOpToScaledIntrinsic(ScaledMFMAOp smfma, Chipset chipset) {
+ return mfmaOpToScaledIntrinsic(smfma.getSourceA().getType(),
+ smfma.getSourceB().getType(),
+ smfma.getDestC().getType(), smfma.getM(),
+ smfma.getN(), smfma.getK(), 1u, chipset);
+}
+
/// Return the `rocdl` intrinsic corresponding to a WMMA operation `wmma`
/// if one exists. This includes checking to ensure the intrinsic is supported
/// on the architecture you are compiling for.
@@ -954,6 +961,54 @@ struct MFMAOpLowering : public ConvertOpToLLVMPattern<MFMAOp> {
}
};
+struct ScaledMFMAOpLowering : public ConvertOpToLLVMPattern<ScaledMFMAOp> {
+ ScaledMFMAOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
+ : ConvertOpToLLVMPattern<ScaledMFMAOp>(converter), chipset(chipset) {}
+
+ Chipset chipset;
+
+ LogicalResult
+ matchAndRewrite(ScaledMFMAOp op, ScaledMFMAOpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ Location loc = op.getLoc();
+ Type outType = typeConverter->convertType(op.getDestD().getType());
+ Type intrinsicOutType = outType;
+ if (auto outVecType = dyn_cast<VectorType>(outType))
+ if (outVecType.getElementType().isBF16())
+ intrinsicOutType = outVecType.clone(rewriter.getI16Type());
+
+ if (chipset.majorVersion != 9 || chipset < kGfx950)
+ return op->emitOpError("scaled MFMA only supported on gfx950+");
+ std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
+ maybeScaledIntrinsic = mfmaOpToScaledIntrinsic(op, chipset);
+ if (!maybeScaledIntrinsic.has_value())
+ return op.emitOpError(
+ "no intrinsic matching Scaled MFMA size on given chipset");
+
+ StringRef intrinsicName = std::get<0>(*maybeScaledIntrinsic);
+ OperationState loweredOp(loc, intrinsicName);
+ loweredOp.addTypes(intrinsicOutType);
+ loweredOp.addOperands(
+ {convertMFMAVectorOperand(rewriter, loc, adaptor.getSourceA()),
+ convertMFMAVectorOperand(rewriter, loc, adaptor.getSourceB()),
+ adaptor.getDestC()});
+ Value scaleA = createI32Constant(rewriter, loc, adaptor.getScaleA());
+ Value scaleB = createI32Constant(rewriter, loc, adaptor.getScaleB());
+ Value opselA = createI32Constant(rewriter, loc, adaptor.getOpselA());
+ Value opselB = createI32Constant(rewriter, loc, adaptor.getOpselB());
+ auto [_scaledName, aTypeCode, bTypeCode] = *maybeScaledIntrinsic;
+ loweredOp.addOperands({createI32Constant(rewriter, loc, aTypeCode),
+ createI32Constant(rewriter, loc, bTypeCode),
+ /*scale A byte=*/opselA, /*scale A=*/scaleA,
+ /*scale B byte=*/opselB, /*scale B=*/scaleB});
+ Value lowered = rewriter.create(loweredOp)->getResult(0);
+ if (outType != intrinsicOutType)
+ lowered = rewriter.create<LLVM::BitcastOp>(loc, outType, lowered);
+ rewriter.replaceOp(op, lowered);
+ return success();
+ }
+};
+
struct WMMAOpLowering : public ConvertOpToLLVMPattern<WMMAOp> {
WMMAOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
: ConvertOpToLLVMPattern<WMMAOp>(converter), chipset(chipset) {}
@@ -1474,8 +1529,9 @@ void mlir::populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter,
RawBufferOpLowering<RawBufferAtomicCmpswapOp,
ROCDL::RawPtrBufferAtomicCmpSwap>,
AMDGPUDPPLowering, LDSBarrierOpLowering, SchedBarrierOpLowering,
- MFMAOpLowering, WMMAOpLowering, ExtPackedFp8OpLowering,
- PackedTrunc2xFp8OpLowering, PackedStochRoundFp8OpLowering,
- GatherToLDSOpLowering>(converter, chipset);
+ MFMAOpLowering, ScaledMFMAOpLowering, WMMAOpLowering,
+ ExtPackedFp8OpLowering, PackedTrunc2xFp8OpLowering,
+ PackedStochRoundFp8OpLowering, GatherToLDSOpLowering>(converter,
+ chipset);
patterns.add<AMDGPUSwizzleBitModeLowering>(converter);
}
diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
index 549a4376a4a04..90131a66448fc 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -506,6 +506,39 @@ LogicalResult GatherToLDSOp::verify() {
return success();
}
+LogicalResult ScaledMFMAOp::verify() {
+ unsigned opselA = getOpselA();
+ unsigned opselB = getOpselB();
+
+ opselA >>= 8;
+ opselB >>= 8;
+
+ if (opselA != 0)
+ return emitOpError("Opsel A must be a zero-extended 8-bit value.");
+
+ if (opselB != 0)
+ return emitOpError("Opsel B must be a zero-extended 8-bit value.");
+
+ auto valid = [&](Type mlirElemType) {
+ return llvm::TypeSwitch<Type, bool>(mlirElemType)
+ .Case([](Float8E4M3FNType) { return true; })
+ .Case([](Float8E5M2Type) { return true; })
+ .Case([](Float6E2M3FNType) { return true; })
+ .Case([](Float6E3M2FNType) { return true; })
+ .Case([](Float4E2M1FNType) { return true; })
+ .Default([](Type) { return false; });
+ };
+
+ // The sources are vectors of small floats; validate their element types.
+ Type aType = getSourceA().getType();
+ Type bType = getSourceB().getType();
+ if (auto aVecType = dyn_cast<VectorType>(aType))
+ aType = aVecType.getElementType();
+ if (auto bVecType = dyn_cast<VectorType>(bType))
+ bType = bVecType.getElementType();
+ if (!valid(aType))
+ return emitOpError("Source A must be of element type fp4, fp6 or fp8.");
+ if (!valid(bType))
+ return emitOpError("Source B must be of element type fp4, fp6 or fp8.");
+
+ return success();
+}
+
#include "mlir/Dialect/AMDGPU/IR/AMDGPUEnums.cpp.inc"
#define GET_ATTRDEF_CLASSES
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/mfma-gfx950.mlir b/mlir/test/Conversion/AMDGPUToROCDL/mfma-gfx950.mlir
index de63f249bb530..f525a37e5ec80 100644
--- a/mlir/test/Conversion/AMDGPUToROCDL/mfma-gfx950.mlir
+++ b/mlir/test/Conversion/AMDGPUToROCDL/mfma-gfx950.mlir
@@ -51,3 +51,50 @@ func.func @mfma_to_rocdl(%arg0 : vector<8xf16>, %arg1 : vector<16xf32>,
func.return
}
+
+func.func @scaled_mfma_to_rocdl(%arg0 : vector<16xf32>,
+ %arg1 : vector<4xf32>, %arg2 : vector<32xf8E4M3FN>,
+ %arg3 : vector<32xf8E5M2>, %arg4 : vector<32xf6E2M3FN>,
+ %arg5 : vector<32xf6E3M2FN>, %arg6 : vector<32xf4E2M1FN>) {
+
+ // CHECK: %[[c1:.+]] = llvm.mlir.constant(1 : i32) : i32
+ // CHECK: %[[c2:.+]] = llvm.mlir.constant(2 : i32) : i32
+ // CHECK: %[[c0:.+]] = llvm.mlir.constant(0 : i32) : i32
+
+ // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4{{.*}}, %[[c0]], %[[c0]], %[[c1]], %[[c1]], %[[c2]], %[[c2]] : (vector<8xi32>, vector<8xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
+ amdgpu.scaled_mfma %arg2 * %arg2 + %arg0 { k = 64 : i32, m = 32 : i32, n = 32 : i32, scaleA = 1 : i32, opselA = 1 : i32, scaleB = 2 : i32, opselB = 2 : i32 } : vector<32xf8E4M3FN>, vector<32xf8E4M3FN>, vector<16xf32>
+ // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4{{.*}}, %[[c0]], %[[c0]], %[[c1]], %[[c1]], %[[c2]], %[[c2]] : (vector<8xi32>, vector<8xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
+ amdgpu.scaled_mfma %arg2 * %arg2 + %arg1 { k = 128 : i32, m = 16 : i32, n = 16 : i32, scaleA = 1 : i32, opselA = 1 : i32, scaleB = 2 : i32, opselB = 2 : i32 } : vector<32xf8E4M3FN>, vector<32xf8E4M3FN>, vector<4xf32>
+
+ // CHECK: llvm.bitcast
+
+ // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4{{.*}}, %[[c1]], %[[c1]], %[[c1]], %[[c1]], %[[c2]], %[[c2]] : (vector<8xi32>, vector<8xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
+ amdgpu.scaled_mfma %arg3 * %arg3 + %arg0 { k = 64 : i32, m = 32 : i32, n = 32 : i32, scaleA = 1 : i32, opselA = 1 : i32, scaleB = 2 : i32, opselB = 2 : i32 } : vector<32xf8E5M2>, vector<32xf8E5M2>, vector<16xf32>
+ // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4{{.*}}, %[[c1]], %[[c1]], %[[c1]], %[[c1]], %[[c2]], %[[c2]] : (vector<8xi32>, vector<8xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
+ amdgpu.scaled_mfma %arg3 * %arg3 + %arg1 { k = 128 : i32, m = 16 : i32, n = 16 : i32, scaleA = 1 : i32, opselA = 1 : i32, scaleB = 2 : i32, opselB = 2 : i32 } : vector<32xf8E5M2>, vector<32xf8E5M2>, vector<4xf32>
+
+ // CHECK: llvm.bitcast
+
+ // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4{{.*}}, %[[c2]], %[[c2]], %[[c1]], %[[c1]], %[[c2]], %[[c2]] : (vector<6xi32>, vector<6xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
+ amdgpu.scaled_mfma %arg4 * %arg4 + %arg0 { k = 64 : i32, m = 32 : i32, n = 32 : i32, scaleA = 1 : i32, opselA = 1 : i32, scaleB = 2 : i32, opselB = 2 : i32 } : vector<32xf6E2M3FN>, vector<32xf6E2M3FN>, vector<16xf32>
+ // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4{{.*}}, %[[c2]], %[[c2]], %[[c1]], %[[c1]], %[[c2]], %[[c2]] : (vector<6xi32>, vector<6xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
+ amdgpu.scaled_mfma %arg4 * %arg4 + %arg1 { k = 128 : i32, m = 16 : i32, n = 16 : i32, scaleA = 1 : i32, opselA = 1 : i32, scaleB = 2 : i32, opselB = 2 : i32 } : vector<32xf6E2M3FN>, vector<32xf6E2M3FN>, vector<4xf32>
+
+ // CHECK: llvm.bitcast
+ // CHECK: %[[c3:.+]] = llvm.mlir.constant(3 : i32) : i32
+
+ // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4{{.*}}, %[[c3]], %[[c3]], %[[c1]], %[[c1]], %[[c2]], %[[c2]] : (vector<6xi32>, vector<6xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
+ amdgpu.scaled_mfma %arg5 * %arg5 + %arg0 { k = 64 : i32, m = 32 : i32, n = 32 : i32, scaleA = 1 : i32, opselA = 1 : i32, scaleB = 2 : i32, opselB = 2 : i32 } : vector<32xf6E3M2FN>, vector<32xf6E3M2FN>, vector<16xf32>
+ // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4{{.*}}, %[[c3]], %[[c3]], %[[c1]], %[[c1]], %[[c2]], %[[c2]] : (vector<6xi32>, vector<6xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
+ amdgpu.scaled_mfma %arg5 * %arg5 + %arg1 { k = 128 : i32, m = 16 : i32, n = 16 : i32, scaleA = 1 : i32, opselA = 1 : i32, scaleB = 2 : i32, opselB = 2 : i32 } : vector<32xf6E3M2FN>, vector<32xf6E3M2FN>, vector<4xf32>
+
+ // CHECK: llvm.bitcast
+ // CHECK: %[[c4:.+]] = llvm.mlir.constant(4 : i32) : i32
+
+ // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4{{.*}}, %[[c4]], %[[c4]], %[[c1]], %[[c1]], %[[c2]], %[[c2]] : (vector<4xi32>, vector<4xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
+ amdgpu.scaled_mfma %arg6 * %arg6 + %arg0 { k = 64 : i32, m = 32 : i32, n = 32 : i32, scaleA = 1 : i32, opselA = 1 : i32, scaleB = 2 : i32, opselB = 2 : i32 } : vector<32xf4E2M1FN>, vector<32xf4E2M1FN>, vector<16xf32>
+ // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4{{.*}}, %[[c4]], %[[c4]], %[[c1]], %[[c1]], %[[c2]], %[[c2]] : (vector<4xi32>, vector<4xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
+ amdgpu.scaled_mfma %arg6 * %arg6 + %arg1 { k = 128 : i32, m = 16 : i32, n = 16 : i32, scaleA = 1 : i32, opselA = 1 : i32, scaleB = 2 : i32, opselB = 2 : i32 } : vector<32xf4E2M1FN>, vector<32xf4E2M1FN>, vector<4xf32>
+
+ func.return
+}