[Mlir-commits] [mlir] [mlir][amdgpu] Define an amdgpu.scaling_mfma wrapper (PR #137498)
llvmlistbot at llvm.org
Sun Apr 27 22:55:50 PDT 2025
https://github.com/Muzammiluddin-Syed-ECE updated https://github.com/llvm/llvm-project/pull/137498
From c43bc26494473a002c74dfa00088cc9b58a8dd8a Mon Sep 17 00:00:00 2001
From: Muzammiluddin Syed <muzasyed at amd.com>
Date: Fri, 25 Apr 2025 14:02:10 -0500
Subject: [PATCH 1/2] Defining amdgpu.scaled_mfma op
Signed-off-by: Muzammiluddin Syed <muzasyed at amd.com>
---
mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td | 48 ++++++++++++++
.../AMDGPUToROCDL/AMDGPUToROCDL.cpp | 63 ++++++++++++++++++-
mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp | 43 +++++++++++++
.../Conversion/AMDGPUToROCDL/mfma-gfx950.mlir | 47 ++++++++++++++
4 files changed, 198 insertions(+), 3 deletions(-)
diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
index f14aa5a2e1564..d1c601882fc93 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
@@ -830,4 +830,52 @@ def AMDGPU_GatherToLDSOp :
let hasVerifier = 1;
}
+def AMDGPU_ScaledMFMAOp :
+ AMDGPU_Op<"scaled_mfma", [AllTypesMatch<["destC", "destD"]>,
+ Pure]>,
+ Arguments<(ins
+ I32Attr:$m,
+ I32Attr:$n,
+ I32Attr:$k,
+ MFMAInTypes:$sourceA,
+ MFMAInTypes:$sourceB,
+ MFMAOutTypes:$destC,
+ I32Attr:$scaleA,
+ I32Attr:$scaleB,
+ I32Attr:$opselA,
+ I32Attr:$opselB)>,
+ Results<(outs MFMAOutTypes: $destD)> {
+ let summary = "MLIR wrapper for CDNA mfma instructions";
+ let description = [{
+ The `amdgpu.scaled_mfma` op is an MLIR wrapper around intrinsics
+ for various scaled versions of `mfma` instructions in the CDNA architecture, which perform
+ multiple outer products to enable fast matrix multiplication.
+
+ The wrapper will select an appropriate `mfma` instruction, if one is available,
+ based on the provided `m`, `k`, and `n` attributes, along with the
+ types of the source and destination arguments.
+
+ Note that this wrapper allows specifying `vector<4Kxi8>` arguments to MFMA
+ intrinsics that take an integer type of width `32K` bits. For example,
+ one can provide a vector<4xi8> as an argument to an MFMA instruction that
+ logically takes 4 i8s but whose intrinsics are specified to take an i32.
+ In these cases, the bytes in the vector will be concatenated in little-endian
+ order (that is, v[0] will go to arg[7:0], v[1] to arg[15:8] and so on).
+
+ This wrapper takes inspiration from `amdgpu.mfma`, but has some key differences:
+ - `amdgpu.scaled_mfma` operates on fp4 (f4E2M1FN), fp6 (f6E2M3FN and f6E3M2FN) and
+ fp8 (f8E4M3FN and f8E5M2) types using either M=N=16, K=128 or M=N=32, K=64 as their tile
+ size.
+ - `amdgpu.scaled_mfma` does not support broadcasting. So, `cbsz`, `abid`, and `blgp`
+ are omitted from this wrapper.
+ - The negateA, negateB, and negateC flags in `amdgpu.mfma` are only supported for
+ double-precision operations on gfx94x and so are not included here.
+ }];
+ let assemblyFormat = [{
+ $sourceA `*` $sourceB `+` $destC
+ attr-dict
+ `:` type($sourceA) `,` type($sourceB) `,` type($destC)
+ }];
+ let hasVerifier = 1;
+}
#endif // AMDGPU
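
For a concrete picture of the syntax defined by the assembly format above, here is a minimal usage sketch (mirroring the tests added later in this patch; %res and the operand names are placeholders):

    // Placeholder SSA names; attribute values as in the tests below.
    %res = amdgpu.scaled_mfma %a * %b + %c
      { k = 64 : i32, m = 32 : i32, n = 32 : i32, scaleA = 1 : i32,
        opselA = 1 : i32, scaleB = 2 : i32, opselB = 2 : i32 }
      : vector<32xf8E4M3FN>, vector<32xf8E4M3FN>, vector<16xf32>

This expresses a 32x32 tile with K = 64 on fp8 inputs, accumulating into a vector<16xf32>.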
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 91dbc2de65c4e..6ace61ae653de 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -833,6 +833,14 @@ mfmaOpToScaledIntrinsic(MFMAOp mfma, Chipset chipset) {
mfma.getBlocks(), chipset);
}
+static std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
+mfmaOpToScaledIntrinsic(ScaledMFMAOp smfma, Chipset chipset) {
+ return mfmaOpToScaledIntrinsic(smfma.getSourceA().getType(),
+ smfma.getSourceB().getType(),
+ smfma.getDestC().getType(), smfma.getM(),
+ smfma.getN(), smfma.getK(), 1u, chipset);
+}
+
/// Return the `rocdl` intrinsic corresponding to a WMMA operation `wmma`
/// if one exists. This includes checking to ensure the intrinsic is supported
/// on the architecture you are compiling for.
@@ -954,6 +962,54 @@ struct MFMAOpLowering : public ConvertOpToLLVMPattern<MFMAOp> {
}
};
+struct ScaledMFMAOpLowering : public ConvertOpToLLVMPattern<ScaledMFMAOp> {
+ ScaledMFMAOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
+ : ConvertOpToLLVMPattern<ScaledMFMAOp>(converter), chipset(chipset) {}
+
+ Chipset chipset;
+
+ LogicalResult
+ matchAndRewrite(ScaledMFMAOp op, ScaledMFMAOpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ Location loc = op.getLoc();
+ Type outType = typeConverter->convertType(op.getDestD().getType());
+ Type intrinsicOutType = outType;
+ if (auto outVecType = dyn_cast<VectorType>(outType))
+ if (outVecType.getElementType().isBF16())
+ intrinsicOutType = outVecType.clone(rewriter.getI16Type());
+
+ if (chipset.majorVersion != 9 || chipset < kGfx908)
+ return op->emitOpError("Scaled MFMA only supported on gfx908+");
+ std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
+ maybeScaledIntrinsic = mfmaOpToScaledIntrinsic(op, chipset);
+ if (!maybeScaledIntrinsic.has_value())
+ return op.emitOpError(
+ "no intrinsic matching Scaled MFMA size on given chipset");
+
+ StringRef intrinsicName = std::get<0>(*maybeScaledIntrinsic);
+ OperationState loweredOp(loc, intrinsicName);
+ loweredOp.addTypes(intrinsicOutType);
+ loweredOp.addOperands(
+ {convertMFMAVectorOperand(rewriter, loc, adaptor.getSourceA()),
+ convertMFMAVectorOperand(rewriter, loc, adaptor.getSourceB()),
+ adaptor.getDestC()});
+ Value scaleA = createI32Constant(rewriter, loc, adaptor.getScaleA());
+ Value scaleB = createI32Constant(rewriter, loc, adaptor.getScaleB());
+ Value opselA = createI32Constant(rewriter, loc, adaptor.getOpselA());
+ Value opselB = createI32Constant(rewriter, loc, adaptor.getOpselB());
+ auto [_scaledName, aTypeCode, bTypeCode] = *maybeScaledIntrinsic;
+ loweredOp.addOperands({createI32Constant(rewriter, loc, aTypeCode),
+ createI32Constant(rewriter, loc, bTypeCode),
+ /*scale A byte=*/opselA, /*scale A=*/scaleA,
+ /*scale B byte=*/opselB, /*scale B=*/scaleB});
+ Value lowered = rewriter.create(loweredOp)->getResult(0);
+ if (outType != intrinsicOutType)
+ lowered = rewriter.create<LLVM::BitcastOp>(loc, outType, lowered);
+ rewriter.replaceOp(op, lowered);
+ return success();
+ }
+};
+
struct WMMAOpLowering : public ConvertOpToLLVMPattern<WMMAOp> {
WMMAOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
: ConvertOpToLLVMPattern<WMMAOp>(converter), chipset(chipset) {}
@@ -1474,8 +1530,9 @@ void mlir::populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter,
RawBufferOpLowering<RawBufferAtomicCmpswapOp,
ROCDL::RawPtrBufferAtomicCmpSwap>,
AMDGPUDPPLowering, LDSBarrierOpLowering, SchedBarrierOpLowering,
- MFMAOpLowering, WMMAOpLowering, ExtPackedFp8OpLowering,
- PackedTrunc2xFp8OpLowering, PackedStochRoundFp8OpLowering,
- GatherToLDSOpLowering>(converter, chipset);
+ MFMAOpLowering, ScaledMFMAOpLowering, WMMAOpLowering,
+ ExtPackedFp8OpLowering, PackedTrunc2xFp8OpLowering,
+ PackedStochRoundFp8OpLowering, GatherToLDSOpLowering>(converter,
+ chipset);
patterns.add<AMDGPUSwizzleBitModeLowering>(converter);
}
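
As a rough sketch of what the pattern above emits for the 32x32x64 fp8 case (SSA names are placeholders; the trailing operand order — the two type codes, then the opsel/scale pair for each source — follows the addOperands calls, and the types match the CHECK lines in the test below):

    // Placeholder SSA names; shapes as in the fp8 test case.
    %d = rocdl.mfma.scale.f32.32x32x64.f8f6f4 %a, %b, %acc,
        %aTypeCode, %bTypeCode, %opselA, %scaleA, %opselB, %scaleB
        : (vector<8xi32>, vector<8xi32>, vector<16xf32>,
           i32, i32, i32, i32, i32, i32) -> vector<16xf32>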
diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
index 549a4376a4a04..5da40a06dc2f0 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -506,6 +506,49 @@ LogicalResult GatherToLDSOp::verify() {
return success();
}
+LogicalResult ScaledMFMAOp::verify() {
+ unsigned opselA = getOpselA();
+ unsigned opselB = getOpselB();
+
+ opselA >>= 8;
+ opselB >>= 8;
+
+ if (opselA != 0)
+ return emitOpError("Opsel A must be a zero extended 8 bit value.");
+
+ if (opselB != 0)
+ return emitOpError("Opsel B must be a zero extended 8 bit value.");
+
+ auto validType = [&](Type mlirElemType) {
+ return llvm::TypeSwitch<Type, bool>(mlirElemType)
+ .Case([](Float8E4M3FNType) { return true; })
+ .Case([](Float8E5M2Type) { return true; })
+ .Case([](Float6E2M3FNType) { return true; })
+ .Case([](Float6E3M2FNType) { return true; })
+ .Case([](Float4E2M1FNType) { return true; })
+ .Default([](Type) { return false; });
+ };
+
+ Type aType = getSourceA().getType();
+ Type bType = getSourceB().getType();
+ aType = getElementTypeOrSelf(aType);
+ bType = getElementTypeOrSelf(bType);
+ if (!validType(aType))
+ return emitOpError("Source A must be of element type fp4, fp6 or fp8.");
+ if (!validType(bType))
+ return emitOpError("Source B must be of element type fp4, fp6 or fp8.");
+
+ unsigned m = getM();
+ unsigned n = getN();
+ unsigned k = getK();
+ bool tileConfig1 = (m == n && n == 32 && k == 64);
+ bool tileConfig2 = (m == n && n == 16 && k == 128);
+ if (!tileConfig1 && !tileConfig2)
+ return emitOpError("Invalid tile size for scaled mfma.");
+
+ return success();
+}
+
#include "mlir/Dialect/AMDGPU/IR/AMDGPUEnums.cpp.inc"
#define GET_ATTRDEF_CLASSES
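
To illustrate the verifier above: only M=N=32, K=64 and M=N=16, K=128 pass the tile-size check, so a hypothetical op like the following (placeholder SSA names) would be rejected:

    // Rejected: "Invalid tile size for scaled mfma." (16x16 requires K = 128).
    %res = amdgpu.scaled_mfma %a * %b + %c
      { k = 64 : i32, m = 16 : i32, n = 16 : i32, scaleA = 0 : i32,
        opselA = 0 : i32, scaleB = 0 : i32, opselB = 0 : i32 }
      : vector<32xf4E2M1FN>, vector<32xf4E2M1FN>, vector<4xf32>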
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/mfma-gfx950.mlir b/mlir/test/Conversion/AMDGPUToROCDL/mfma-gfx950.mlir
index de63f249bb530..f525a37e5ec80 100644
--- a/mlir/test/Conversion/AMDGPUToROCDL/mfma-gfx950.mlir
+++ b/mlir/test/Conversion/AMDGPUToROCDL/mfma-gfx950.mlir
@@ -51,3 +51,50 @@ func.func @mfma_to_rocdl(%arg0 : vector<8xf16>, %arg1 : vector<16xf32>,
func.return
}
+
+func.func @scaled_mfma_to_rocdl(%arg0 : vector<16xf32>,
+ %arg1 : vector<4xf32>, %arg2 : vector<32xf8E4M3FN>,
+ %arg3 : vector<32xf8E5M2>, %arg4 : vector<32xf6E2M3FN>,
+ %arg5 : vector<32xf6E3M2FN>, %arg6 : vector<32xf4E2M1FN>) {
+
+ // CHECK: %[[c1:.+]] = llvm.mlir.constant(1 : i32) : i32
+ // CHECK: %[[c2:.+]] = llvm.mlir.constant(2 : i32) : i32
+ // CHECK: %[[c0:.+]] = llvm.mlir.constant(0 : i32) : i32
+
+ // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4{{.*}}, %[[c0]], %[[c0]], %[[c1]], %[[c1]], %[[c2]], %[[c2]] : (vector<8xi32>, vector<8xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
+ amdgpu.scaled_mfma %arg2 * %arg2 + %arg0 { k = 64 : i32, m = 32 : i32, n = 32 : i32, scaleA = 1 : i32, opselA = 1 : i32, scaleB = 2 : i32, opselB = 2 : i32 } : vector<32xf8E4M3FN>, vector<32xf8E4M3FN>, vector<16xf32>
+ // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4{{.*}}, %[[c0]], %[[c0]], %[[c1]], %[[c1]], %[[c2]], %[[c2]] : (vector<8xi32>, vector<8xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
+ amdgpu.scaled_mfma %arg2 * %arg2 + %arg1 { k = 128 : i32, m = 16 : i32, n = 16 : i32, scaleA = 1 : i32, opselA = 1 : i32, scaleB = 2 : i32, opselB = 2 : i32 } : vector<32xf8E4M3FN>, vector<32xf8E4M3FN>, vector<4xf32>
+
+ // CHECK: llvm.bitcast
+
+ // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4{{.*}}, %[[c1]], %[[c1]], %[[c1]], %[[c1]], %[[c2]], %[[c2]] : (vector<8xi32>, vector<8xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
+ amdgpu.scaled_mfma %arg3 * %arg3 + %arg0 { k = 64 : i32, m = 32 : i32, n = 32 : i32, scaleA = 1 : i32, opselA = 1 : i32, scaleB = 2 : i32, opselB = 2 : i32 } : vector<32xf8E5M2>, vector<32xf8E5M2>, vector<16xf32>
+ // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4{{.*}}, %[[c1]], %[[c1]], %[[c1]], %[[c1]], %[[c2]], %[[c2]] : (vector<8xi32>, vector<8xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
+ amdgpu.scaled_mfma %arg3 * %arg3 + %arg1 { k = 128 : i32, m = 16 : i32, n = 16 : i32, scaleA = 1 : i32, opselA = 1 : i32, scaleB = 2 : i32, opselB = 2 : i32 } : vector<32xf8E5M2>, vector<32xf8E5M2>, vector<4xf32>
+
+ // CHECK: llvm.bitcast
+
+ // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4{{.*}}, %[[c2]], %[[c2]], %[[c1]], %[[c1]], %[[c2]], %[[c2]] : (vector<6xi32>, vector<6xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
+ amdgpu.scaled_mfma %arg4 * %arg4 + %arg0 { k = 64 : i32, m = 32 : i32, n = 32 : i32, scaleA = 1 : i32, opselA = 1 : i32, scaleB = 2 : i32, opselB = 2 : i32 } : vector<32xf6E2M3FN>, vector<32xf6E2M3FN>, vector<16xf32>
+ // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4{{.*}}, %[[c2]], %[[c2]], %[[c1]], %[[c1]], %[[c2]], %[[c2]] : (vector<6xi32>, vector<6xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
+ amdgpu.scaled_mfma %arg4 * %arg4 + %arg1 { k = 128 : i32, m = 16 : i32, n = 16 : i32, scaleA = 1 : i32, opselA = 1 : i32, scaleB = 2 : i32, opselB = 2 : i32 } : vector<32xf6E2M3FN>, vector<32xf6E2M3FN>, vector<4xf32>
+
+ // CHECK: llvm.bitcast
+ // CHECK: %[[c3:.+]] = llvm.mlir.constant(3 : i32) : i32
+
+ // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4{{.*}}, %[[c3]], %[[c3]], %[[c1]], %[[c1]], %[[c2]], %[[c2]] : (vector<6xi32>, vector<6xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
+ amdgpu.scaled_mfma %arg5 * %arg5 + %arg0 { k = 64 : i32, m = 32 : i32, n = 32 : i32, scaleA = 1 : i32, opselA = 1 : i32, scaleB = 2 : i32, opselB = 2 : i32 } : vector<32xf6E3M2FN>, vector<32xf6E3M2FN>, vector<16xf32>
+ // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4{{.*}}, %[[c3]], %[[c3]], %[[c1]], %[[c1]], %[[c2]], %[[c2]] : (vector<6xi32>, vector<6xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
+ amdgpu.scaled_mfma %arg5 * %arg5 + %arg1 { k = 128 : i32, m = 16 : i32, n = 16 : i32, scaleA = 1 : i32, opselA = 1 : i32, scaleB = 2 : i32, opselB = 2 : i32 } : vector<32xf6E3M2FN>, vector<32xf6E3M2FN>, vector<4xf32>
+
+ // CHECK: llvm.bitcast
+ // CHECK: %[[c4:.+]] = llvm.mlir.constant(4 : i32) : i32
+
+ // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4{{.*}}, %[[c4]], %[[c4]], %[[c1]], %[[c1]], %[[c2]], %[[c2]] : (vector<4xi32>, vector<4xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
+ amdgpu.scaled_mfma %arg6 * %arg6 + %arg0 { k = 64 : i32, m = 32 : i32, n = 32 : i32, scaleA = 1 : i32, opselA = 1 : i32, scaleB = 2 : i32, opselB = 2 : i32 } : vector<32xf4E2M1FN>, vector<32xf4E2M1FN>, vector<16xf32>
+ // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4{{.*}}, %[[c4]], %[[c4]], %[[c1]], %[[c1]], %[[c2]], %[[c2]] : (vector<4xi32>, vector<4xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
+ amdgpu.scaled_mfma %arg6 * %arg6 + %arg1 { k = 128 : i32, m = 16 : i32, n = 16 : i32, scaleA = 1 : i32, opselA = 1 : i32, scaleB = 2 : i32, opselB = 2 : i32 } : vector<32xf4E2M1FN>, vector<32xf4E2M1FN>, vector<4xf32>
+
+ func.return
+}
From 3ba7ea8cc97b59cafeba682dadedb00a982075d4 Mon Sep 17 00:00:00 2001
From: Muzammiluddin Syed <muzasyed at amd.com>
Date: Mon, 28 Apr 2025 00:37:09 -0500
Subject: [PATCH 2/2] PR Review Round 1
Signed-off-by: Muzammiluddin Syed <muzasyed at amd.com>
---
mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td | 8 ++---
.../AMDGPUToROCDL/AMDGPUToROCDL.cpp | 33 ++++++++---------
mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp | 35 +++++++------------
3 files changed, 34 insertions(+), 42 deletions(-)
diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
index d1c601882fc93..d8d59c2f27924 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
@@ -804,7 +804,7 @@ def AMDGPU_GatherToLDSOp :
TypeAttr:$transferType
)>,
Results<(outs)> {
- let summary = "MLIR wrapper for CDNA mfma instructions";
+ let summary = "MLIR wrapper for CDNA Gather to LDS instructions";
let description = [{
The `amdgpu.global_load` op is a wrapper around the `global_load_lds` instructions.
@@ -845,7 +845,7 @@ def AMDGPU_ScaledMFMAOp :
I32Attr:$opselA,
I32Attr:$opselB)>,
Results<(outs MFMAOutTypes: $destD)> {
- let summary = "MLIR wrapper for CDNA mfma instructions";
+ let summary = "MLIR wrapper for CDNA scaled mfma instructions";
let description = [{
The `amdgpu.scaled_mfma` op is an MLIR wrapper around intrinsics
for various scaled versions of `mfma` instructions in the CDNA architecture, which perform
@@ -857,7 +857,7 @@ def AMDGPU_ScaledMFMAOp :
Note that this wrapper allows specifying `vector<4Kxi8>` arguments to MFMA
intrinsics that take an integer type of width `32K` bits. For example,
- one can provide a vector<4xi8> as an argument to an MFMA instruction that
+ one can provide a `vector<4xi8>` as an argument to an MFMA instruction that
logically takes 4 i8s but whose intrinsics are specified to take an i32.
In these cases, the bytes in the vector will be concatenated in little-endian
order (that is, v[0] will go to arg[7:0], v[1] to arg[15:8] and so on).
@@ -868,7 +868,7 @@ def AMDGPU_ScaledMFMAOp :
size.
- `amdgpu.scaled_mfma` does not support broadcasting. So, `cbsz`, `abid`, and `blgp`
are omitted from this wrapper.
- - The negateA, negateB, and negateC flags in `amdgpu.mfma` are only supported for
+ - The `negateA`, `negateB`, and `negateC` flags in `amdgpu.mfma` are only supported for
double-precision operations on gfx94x and so are not included here.
}];
let assemblyFormat = [{
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 6ace61ae653de..af71ba7de5c6b 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -23,6 +23,7 @@
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/TypeSwitch.h"
+#include "llvm/Support/Casting.h"
#include <optional>
namespace mlir {
@@ -826,19 +827,20 @@ mfmaOpToScaledIntrinsic(Type aType, Type bType, Type destType, uint32_t m,
}
static std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
-mfmaOpToScaledIntrinsic(MFMAOp mfma, Chipset chipset) {
- return mfmaOpToScaledIntrinsic(
- mfma.getSourceA().getType(), mfma.getSourceB().getType(),
- mfma.getDestC().getType(), mfma.getM(), mfma.getN(), mfma.getK(),
- mfma.getBlocks(), chipset);
-}
-
-static std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
-mfmaOpToScaledIntrinsic(ScaledMFMAOp smfma, Chipset chipset) {
- return mfmaOpToScaledIntrinsic(smfma.getSourceA().getType(),
- smfma.getSourceB().getType(),
- smfma.getDestC().getType(), smfma.getM(),
- smfma.getN(), smfma.getK(), 1u, chipset);
+mfmaOpToScaledIntrinsic(Operation *op, Chipset chipset) {
+ if (auto mfma = llvm::dyn_cast_or_null<MFMAOp>(op)) {
+ return mfmaOpToScaledIntrinsic(
+ mfma.getSourceA().getType(), mfma.getSourceB().getType(),
+ mfma.getDestC().getType(), mfma.getM(), mfma.getN(), mfma.getK(),
+ mfma.getBlocks(), chipset);
+ }
+ if (auto smfma = llvm::dyn_cast_or_null<ScaledMFMAOp>(op)) {
+ return mfmaOpToScaledIntrinsic(smfma.getSourceA().getType(),
+ smfma.getSourceB().getType(),
+ smfma.getDestC().getType(), smfma.getM(),
+ smfma.getN(), smfma.getK(), 1u, chipset);
+ }
+ return std::nullopt;
}
/// Return the `rocdl` intrinsic corresponding to a WMMA operation `wmma`
@@ -964,7 +966,7 @@ struct MFMAOpLowering : public ConvertOpToLLVMPattern<MFMAOp> {
struct ScaledMFMAOpLowering : public ConvertOpToLLVMPattern<ScaledMFMAOp> {
ScaledMFMAOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
- : ConvertOpToLLVMPattern<ScaledMFMAOp>(converter), chipset(chipset) {}
+ : ConvertOpToLLVMPattern(converter), chipset(chipset) {}
Chipset chipset;
@@ -986,7 +988,7 @@ struct ScaledMFMAOpLowering : public ConvertOpToLLVMPattern<ScaledMFMAOp> {
return op.emitOpError(
"no intrinsic matching Scaled MFMA size on given chipset");
- StringRef intrinsicName = std::get<0>(*maybeScaledIntrinsic);
+ auto [intrinsicName, aTypeCode, bTypeCode] = *maybeScaledIntrinsic;
OperationState loweredOp(loc, intrinsicName);
loweredOp.addTypes(intrinsicOutType);
loweredOp.addOperands(
@@ -997,7 +999,6 @@ struct ScaledMFMAOpLowering : public ConvertOpToLLVMPattern<ScaledMFMAOp> {
Value scaleB = createI32Constant(rewriter, loc, adaptor.getScaleB());
Value opselA = createI32Constant(rewriter, loc, adaptor.getOpselA());
Value opselB = createI32Constant(rewriter, loc, adaptor.getOpselB());
- auto [_scaledName, aTypeCode, bTypeCode] = *maybeScaledIntrinsic;
loweredOp.addOperands({createI32Constant(rewriter, loc, aTypeCode),
createI32Constant(rewriter, loc, bTypeCode),
/*scale A byte=*/opselA, /*scale A=*/scaleA,
diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
index 5da40a06dc2f0..1ffba5c542b1b 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -507,36 +507,27 @@ LogicalResult GatherToLDSOp::verify() {
}
LogicalResult ScaledMFMAOp::verify() {
- unsigned opselA = getOpselA();
- unsigned opselB = getOpselB();
-
- opselA >>= 8;
- opselB >>= 8;
+ unsigned opselA = getOpselA() >> 8;
+ unsigned opselB = getOpselB() >> 8;
if (opselA != 0)
- return emitOpError("Opsel A must be a zero extended 8 bit value.");
+ return emitOpError("Opsel A must be a zero extended 8 bit value");
if (opselB != 0)
- return emitOpError("Opsel B must be a zero extended 8 bit value.");
-
- auto validType = [&](Type mlirElemType) {
- return llvm::TypeSwitch<Type, bool>(mlirElemType)
- .Case([](Float8E4M3FNType) { return true; })
- .Case([](Float8E5M2Type) { return true; })
- .Case([](Float6E2M3FNType) { return true; })
- .Case([](Float6E3M2FNType) { return true; })
- .Case([](Float4E2M1FNType) { return true; })
- .Default([](Type) { return false; });
- };
+ return emitOpError("Opsel B must be a zero extended 8 bit value");
+
+ auto isValidType =
+ llvm::IsaPred<Float8E4M3FNType, Float8E5M2Type, Float6E2M3FNType,
+ Float6E3M2FNType, Float4E2M1FNType>;
Type aType = getSourceA().getType();
Type bType = getSourceB().getType();
aType = getElementTypeOrSelf(aType);
bType = getElementTypeOrSelf(bType);
- if (!validType(aType))
- return emitOpError("Source A must be of element type fp4, fp6 or fp8.");
- if (!validType(bType))
- return emitOpError("Source B must be of element type fp4, fp6 or fp8.");
+ if (!isValidType(aType))
+ return emitOpError("Source A must be of element type fp4, fp6 or fp8");
+ if (!isValidType(bType))
+ return emitOpError("Source B must be of element type fp4, fp6 or fp8");
unsigned m = getM();
unsigned n = getN();
@@ -544,7 +535,7 @@ LogicalResult ScaledMFMAOp::verify() {
bool tileConfig1 = (m == n && n == 32 && k == 64);
bool tileConfig2 = (m == n && n == 16 && k == 128);
if (!tileConfig1 && !tileConfig2)
- return emitOpError("Invalid tile size for scaled mfma.");
+ return emitOpError("Invalid tile size for scaled mfma");
return success();
}
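
Given the shift-by-8 checks above, any opsel value outside the low 8 bits is rejected. A hypothetical example (placeholder SSA names): opselA = 255 passes, while

    // Rejected: "Opsel A must be a zero extended 8 bit value" (256 >> 8 == 1).
    %res = amdgpu.scaled_mfma %a * %b + %c
      { k = 64 : i32, m = 32 : i32, n = 32 : i32, scaleA = 0 : i32,
        opselA = 256 : i32, scaleB = 0 : i32, opselB = 0 : i32 }
      : vector<32xf8E5M2>, vector<32xf8E5M2>, vector<16xf32>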