[Mlir-commits] [mlir] [MLIR][AMDGPU] Add OCP FP8 support for new hardware (PR #127728)
Mirza Halilčević
llvmlistbot at llvm.org
Mon Feb 24 12:48:33 PST 2025
https://github.com/mirza-halilcevic updated https://github.com/llvm/llvm-project/pull/127728
>From 32e052ddddf3067d16fe89a507c074f4bdf55137 Mon Sep 17 00:00:00 2001
From: Paul Fuqua <pf at acm.org>
Date: Thu, 11 Jul 2024 20:12:45 -0500
Subject: [PATCH 1/7] [MLIR][AMDGPU] Add OCP FP8 support for new hardware
Upcoming hardware (gfx12 and some future gfx9) will support the OCP
8-bit float formats for their matrix multiplication intrinsics and
conversion operations, retaining existing opcodes and compiler builtins.
This commit adds support for these types to the MLIR wrappers around
such operations, ensuring that the OCP types aren't used to generate
those builtins on hardware that doesn't expect that format and,
conversely, that the pre-OCP formats aren't used on new
hardware.
---
mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td | 14 +++----
.../mlir/Dialect/AMDGPU/Utils/Chipset.h | 7 ++++
.../AMDGPUToROCDL/AMDGPUToROCDL.cpp | 38 ++++++++++---------
.../ArithToAMDGPU/ArithToAMDGPU.cpp | 9 ++++-
mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp | 6 ++-
.../Tosa/Transforms/TosaValidation.cpp | 3 +-
6 files changed, 47 insertions(+), 30 deletions(-)
diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
index f795dd89b79a1..28a908476c2ba 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
@@ -41,8 +41,8 @@ class AMDGPU_Op<string mnemonic, list<Trait> traits = []> :
def AMDGPU_ExtPackedFp8Op :
AMDGPU_Op<"ext_packed_fp8", [Pure]>,
- Arguments<(ins AnyTypeOf<[F8E5M2FNUZ, F8E4M3FNUZ,
- VectorOfLengthAndType<[1, 2, 3, 4], [F8E5M2FNUZ, F8E4M3FNUZ]>]>:$source,
+ Arguments<(ins AnyTypeOf<[F8E5M2FNUZ, F8E4M3FNUZ, F8E5M2, F8E4M3FN,
+ VectorOfLengthAndType<[1, 2, 3, 4], [F8E5M2FNUZ, F8E4M3FNUZ, F8E5M2, F8E4M3FN]>]>:$source,
ConfinedAttr<I32Attr, [IntNonNegative, IntMaxValue<3>]>:$index)>,
Results<(outs F32:$res)> {
let summary = "Extend one of a vector of packed fp8 values to a float";
@@ -68,8 +68,8 @@ def AMDGPU_PackedTrunc2xFp8Op :
Arguments<(ins F32:$sourceA,
Optional<F32>:$sourceB,
ConfinedAttr<I32Attr, [IntNonNegative, IntMaxValue<1>]>:$wordIndex,
- Optional<FixedVectorOfLengthAndType<[4], [F8E4M3FNUZ, F8E5M2FNUZ]>>:$existing)>,
- Results<(outs FixedVectorOfLengthAndType<[4], [F8E4M3FNUZ, F8E5M2FNUZ]>:$res)> {
+ Optional<FixedVectorOfLengthAndType<[4], [F8E4M3FNUZ, F8E5M2FNUZ, F8E4M3FN, F8E5M2]>>:$existing)>,
+ Results<(outs FixedVectorOfLengthAndType<[4], [F8E4M3FNUZ, F8E5M2FNUZ, F8E4M3FN, F8E5M2]>:$res)> {
let summary = "Round two floats into a packed vector of 8-bit floats";
let description = [{
Round the inputs `sourceA` and `sourceB` (which is undefined if not
@@ -95,8 +95,8 @@ def AMDGPU_PackedStochRoundFp8Op :
Arguments<(ins F32:$source,
I32:$stochiasticParam,
ConfinedAttr<I32Attr, [IntNonNegative, IntMaxValue<3>]>:$storeIndex,
- Optional<FixedVectorOfLengthAndType<[4], [F8E4M3FNUZ, F8E5M2FNUZ]>>:$existing)>,
- Results<(outs FixedVectorOfLengthAndType<[4], [F8E4M3FNUZ, F8E5M2FNUZ]>:$res)> {
+ Optional<FixedVectorOfLengthAndType<[4], [F8E4M3FNUZ, F8E5M2FNUZ, F8E4M3FN, F8E5M2]>>:$existing)>,
+ Results<(outs FixedVectorOfLengthAndType<[4], [F8E4M3FNUZ, F8E5M2FNUZ, F8E4M3FN, F8E5M2]>:$res)> {
let summary = "Round float stochiastically into a packed vector of 8-bit floats";
let description = [{
Round the input `source`, adding in `stochiasticParam`, and place it into
@@ -546,7 +546,7 @@ def MFMAInTypes : AnyTypeOf<[F32, F64, I32, I64,
VectorOfLengthAndType<[4], [F16]>,
VectorOfLengthAndType<[2, 4], [BF16]>,
VectorOfLengthAndType<[4, 8], [I8]>,
- VectorOfLengthAndType<[8], [F8E5M2FNUZ, F8E4M3FNUZ]>]>;
+ VectorOfLengthAndType<[8], [F8E5M2FNUZ, F8E4M3FNUZ, F8E5M2, F8E4M3FN]>]>;
def MFMAOutTypes : AnyTypeOf<[F64,
VectorOfLengthAndType<[4, 16, 32], [F32]>,
VectorOfLengthAndType<[4, 16, 32], [I32]>,
diff --git a/mlir/include/mlir/Dialect/AMDGPU/Utils/Chipset.h b/mlir/include/mlir/Dialect/AMDGPU/Utils/Chipset.h
index a5dab1ab89630..e5c14c1cb6827 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/Utils/Chipset.h
+++ b/mlir/include/mlir/Dialect/AMDGPU/Utils/Chipset.h
@@ -47,6 +47,13 @@ struct Chipset {
DEFINE_COMP_OPERATOR(>)
DEFINE_COMP_OPERATOR(>=)
#undef DEFINE_COMP_OPERATOR
+
+ bool isGfx940() const {
+ return majorVersion == 9 && minorVersion >= 0x40 && minorVersion < 0x50;
+ }
+ bool hasOcpFp8() const {
+ return (majorVersion == 9 && minorVersion >= 0x50) || majorVersion >= 12;
+ }
};
} // namespace mlir::amdgpu
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index b8574bbbee345..c8db72f85f103 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -570,40 +570,42 @@ static std::optional<StringRef> mfmaOpToIntrinsic(MFMAOp mfma,
return ROCDL::mfma_f64_4x4x4f64::getOperationName();
}
- if (isa<Float8E5M2FNUZType>(sourceElem) && destElem.isF32() &&
- chipset >= kGfx942) {
+ if (destElem.isF32() &&
+ ((isa<Float8E5M2FNUZType>(sourceElem) && chipset >= kGfx942) ||
+ (isa<Float8E5M2Type>(sourceElem) && chipset.hasOcpFp8()))) {
// Known to be correct because there are no scalar f8 instructions and
// because a length mismatch will have been caught by the verifier.
Type sourceBElem =
cast<VectorType>(mfma.getSourceB().getType()).getElementType();
if (m == 16 && n == 16 && k == 32 && b == 1) {
- if (isa<Float8E5M2FNUZType>(sourceBElem))
+ if (isa<Float8E5M2FNUZType, Float8E5M2Type>(sourceBElem))
return ROCDL::mfma_f32_16x16x32_bf8_bf8::getOperationName();
- if (isa<Float8E4M3FNUZType>(sourceBElem))
+ if (isa<Float8E4M3FNUZType, Float8E4M3FNType>(sourceBElem))
return ROCDL::mfma_f32_16x16x32_bf8_fp8::getOperationName();
}
if (m == 32 && n == 32 && k == 16 && b == 1) {
- if (isa<Float8E5M2FNUZType>(sourceBElem))
+ if (isa<Float8E5M2FNUZType, Float8E5M2Type>(sourceBElem))
return ROCDL::mfma_f32_32x32x16_bf8_bf8::getOperationName();
- if (isa<Float8E4M3FNUZType>(sourceBElem))
+ if (isa<Float8E4M3FNUZType, Float8E4M3FNType>(sourceBElem))
return ROCDL::mfma_f32_32x32x16_bf8_fp8::getOperationName();
}
}
- if (isa<Float8E4M3FNUZType>(sourceElem) && destElem.isF32() &&
- chipset >= kGfx942) {
+ if (destElem.isF32() &&
+ ((isa<Float8E4M3FNUZType>(sourceElem) && chipset >= kGfx942) ||
+ (isa<Float8E4M3FNType>(sourceElem) && chipset.hasOcpFp8()))) {
Type sourceBElem =
cast<VectorType>(mfma.getSourceB().getType()).getElementType();
if (m == 16 && n == 16 && k == 32 && b == 1) {
- if (isa<Float8E5M2FNUZType>(sourceBElem))
+ if (isa<Float8E5M2FNUZType, Float8E5M2Type>(sourceBElem))
return ROCDL::mfma_f32_16x16x32_fp8_bf8::getOperationName();
- if (isa<Float8E4M3FNUZType>(sourceBElem))
+ if (isa<Float8E4M3FNUZType, Float8E4M3FNType>(sourceBElem))
return ROCDL::mfma_f32_16x16x32_fp8_fp8::getOperationName();
}
if (m == 32 && n == 32 && k == 16 && b == 1) {
- if (isa<Float8E5M2FNUZType>(sourceBElem))
+ if (isa<Float8E5M2FNUZType, Float8E5M2Type>(sourceBElem))
return ROCDL::mfma_f32_32x32x16_fp8_bf8::getOperationName();
- if (isa<Float8E4M3FNUZType>(sourceBElem))
+ if (isa<Float8E4M3FNUZType, Float8E4M3FNType>(sourceBElem))
return ROCDL::mfma_f32_32x32x16_fp8_fp8::getOperationName();
}
}
@@ -811,10 +813,10 @@ LogicalResult ExtPackedFp8OpLowering::matchAndRewrite(
}
Value i32Source = rewriter.create<LLVM::BitcastOp>(loc, i32, source);
Value wordSel = createI32Constant(rewriter, loc, op.getIndex());
- if (isa<Float8E5M2FNUZType>(sourceElemType)) {
+ if (isa<Float8E5M2FNUZType, Float8E5M2Type>(sourceElemType)) {
rewriter.replaceOpWithNewOp<ROCDL::CvtF32Bf8Op>(op, f32, i32Source,
wordSel);
- } else if (isa<Float8E4M3FNUZType>(sourceElemType)) {
+ } else if (isa<Float8E4M3FNUZType, Float8E4M3FNType>(sourceElemType)) {
rewriter.replaceOpWithNewOp<ROCDL::CvtF32Fp8Op>(op, f32, i32Source,
wordSel);
}
@@ -846,10 +848,10 @@ LogicalResult PackedTrunc2xFp8OpLowering::matchAndRewrite(
Value wordSel = createI1Constant(rewriter, loc, op.getWordIndex());
Value result;
- if (isa<Float8E5M2FNUZType>(resultElemType))
+ if (isa<Float8E5M2FNUZType, Float8E5M2Type>(resultElemType))
result = rewriter.create<ROCDL::CvtPkBf8F32Op>(loc, i32, sourceA, sourceB,
existing, wordSel);
- else if (isa<Float8E4M3FNUZType>(resultElemType))
+ else if (isa<Float8E4M3FNUZType, Float8E4M3FNType>(resultElemType))
result = rewriter.create<ROCDL::CvtPkFp8F32Op>(loc, i32, sourceA, sourceB,
existing, wordSel);
@@ -881,10 +883,10 @@ LogicalResult PackedStochRoundFp8OpLowering::matchAndRewrite(
Value byteSel = createI32Constant(rewriter, loc, op.getStoreIndex());
Value result;
- if (isa<Float8E5M2FNUZType>(resultElemType))
+ if (isa<Float8E5M2FNUZType, Float8E5M2Type>(resultElemType))
result = rewriter.create<ROCDL::CvtSrBf8F32Op>(loc, i32, source, stoch,
existing, byteSel);
- else if (isa<Float8E4M3FNUZType>(resultElemType))
+ else if (isa<Float8E4M3FNUZType, Float8E4M3FNType>(resultElemType))
result = rewriter.create<ROCDL::CvtSrFp8F32Op>(loc, i32, source, stoch,
existing, byteSel);
diff --git a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
index b22d852f7c543..d70717f0ca2d1 100644
--- a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
+++ b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
@@ -86,7 +86,8 @@ LogicalResult ExtFOnFloat8RewritePattern::match(arith::ExtFOp op) const {
return failure();
inType = inVecType.getElementType();
}
- return success(isa<Float8E5M2FNUZType, Float8E4M3FNUZType>(inType));
+ return success(isa<Float8E5M2FNUZType, Float8E4M3FNUZType, Float8E5M2Type,
+ Float8E4M3FNType>(inType));
}
void ExtFOnFloat8RewritePattern::rewrite(arith::ExtFOp op,
@@ -219,7 +220,11 @@ LogicalResult TruncFToFloat8RewritePattern::match(arith::TruncFOp op) const {
if (inType && inType.getWidth() <= 8 && saturateFP8)
// Conversion between 8-bit floats is not supported with truncation enabled.
return failure();
- return success(isa<Float8E5M2FNUZType, Float8E4M3FNUZType>(outType));
+
+ return success((
+ (isa<Float8E5M2FNUZType, Float8E4M3FNUZType>(outType) &&
+ chipset.isGfx940()) ||
+ (isa<Float8E5M2Type, Float8E4M3FNType>(outType) && chipset.hasOcpFp8())));
}
void TruncFToFloat8RewritePattern::rewrite(arith::TruncFOp op,
diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
index 271ca382e2f0b..c6aba7560acf5 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -272,14 +272,16 @@ LogicalResult MFMAOp::verify() {
}
Type sourceBType = getSourceB().getType();
- if (isa<Float8E5M2FNUZType, Float8E4M3FNUZType>(sourceElem)) {
+ if (isa<Float8E5M2FNUZType, Float8E4M3FNUZType, Float8E5M2Type,
+ Float8E4M3FNType>(sourceElem)) {
int64_t sourceBLen = 1;
Type sourceBElem = sourceBType;
if (auto sourceBVector = llvm::dyn_cast<VectorType>(sourceBType)) {
sourceBLen = sourceBVector.getNumElements();
sourceBElem = sourceBVector.getElementType();
}
- if (!isa<Float8E5M2FNUZType, Float8E4M3FNUZType>(sourceBElem))
+ if (!isa<Float8E5M2FNUZType, Float8E4M3FNUZType, Float8E5M2Type,
+ Float8E4M3FNType>(sourceBElem))
return emitOpError("expected both source operands to have f8 elements");
if (sourceLen != sourceBLen)
return emitOpError(
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
index f74a4b4c58b80..0113813e13419 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
@@ -696,7 +696,8 @@ LogicalResult TosaValidation::applyErrorIfCheck(Operation *op) {
bool TosaValidation::isValidElementType(Type type) {
if (isa<FloatType>(type)) {
- return type.isF32() || type.isF16() || type.isBF16();
+ return isa<Float32Type, Float16Type, BFloat16Type, Float8E4M3FNUZType,
+ Float8E5M2FNUZType, Float8E4M3FNType, Float8E5M2Type>(type);
} else if (auto intTy = dyn_cast<IntegerType>(type)) {
if (intTy.isSignless()) {
switch (intTy.getWidth()) {
>From b87a0a07883a1b55e409e20bd0d5dd45dc6a9f5f Mon Sep 17 00:00:00 2001
From: Krzysztof Drewniak <Krzysztof.Drewniak at amd.com>
Date: Mon, 9 Sep 2024 17:39:17 -0500
Subject: [PATCH 2/7] [MLIR][AMDGPU] After fp8 conversions were lowered to
AMDGPU dialect ops, those operations were not being converted to the LLVM
intrinsics they correspond to because the rewrite patterns were still
checking for gfx940+.
As part of this, factor out the type-match tests into isNativeFp8()
and isNativeBf8() functions in the AMDGPUToROCDL rewrites.
Also, fix a typo in isGfx940() that caused it to be true for gfx950.
Finally, test all these OCP format conversions by duplicating the
gfx940 tests.
---
.../AMDGPUToROCDL/AMDGPUToROCDL.cpp | 50 +++--
.../AMDGPUToROCDL/8-bit-floats-ocp.mlir | 109 +++++++++++
.../8-bit-float-saturation-ocp.mlir | 58 ++++++
.../ArithToAMDGPU/8-bit-floats-ocp.mlir | 176 ++++++++++++++++++
4 files changed, 373 insertions(+), 20 deletions(-)
create mode 100644 mlir/test/Conversion/AMDGPUToROCDL/8-bit-floats-ocp.mlir
create mode 100644 mlir/test/Conversion/ArithToAMDGPU/8-bit-float-saturation-ocp.mlir
create mode 100644 mlir/test/Conversion/ArithToAMDGPU/8-bit-floats-ocp.mlir
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index c8db72f85f103..d904c47b28300 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -474,6 +474,20 @@ static void wmmaPushOutputOperand(ConversionPatternRewriter &rewriter,
}
}
+/// Return true if `type` is the E5M2 variant of an 8-bit float that is
+/// supported by the `_bf8` instructions on the given `chipset`.
+static bool isNativeBf8(Chipset chipset, Type type) {
+ return (chipset.isGfx940() && isa<Float8E5M2FNUZType>(type)) ||
+ (chipset.hasOcpFp8() && isa<Float8E5M2Type>(type));
+}
+
+/// Return true if `type` is the E4M3FN variant of an 8-bit float that is
+/// supported by the `_fp8` instructions on the given `chipset`.
+static bool isNativeFp8(Chipset chipset, Type type) {
+ return (chipset.isGfx940() && isa<Float8E4M3FNUZType>(type)) ||
+ (chipset.hasOcpFp8() && isa<Float8E4M3FNType>(type));
+}
+
/// Return the `rocdl` intrinsic corresponding to a MFMA operation `mfma`
/// if one exists. This includes checking to ensure the intrinsic is supported
/// on the architecture you are compiling for.
@@ -570,42 +584,38 @@ static std::optional<StringRef> mfmaOpToIntrinsic(MFMAOp mfma,
return ROCDL::mfma_f64_4x4x4f64::getOperationName();
}
- if (destElem.isF32() &&
- ((isa<Float8E5M2FNUZType>(sourceElem) && chipset >= kGfx942) ||
- (isa<Float8E5M2Type>(sourceElem) && chipset.hasOcpFp8()))) {
+ if (destElem.isF32() && isNativeBf8(chipset, sourceElem)) {
// Known to be correct because there are no scalar f8 instructions and
// because a length mismatch will have been caught by the verifier.
Type sourceBElem =
cast<VectorType>(mfma.getSourceB().getType()).getElementType();
if (m == 16 && n == 16 && k == 32 && b == 1) {
- if (isa<Float8E5M2FNUZType, Float8E5M2Type>(sourceBElem))
+ if (isNativeBf8(chipset, sourceBElem))
return ROCDL::mfma_f32_16x16x32_bf8_bf8::getOperationName();
- if (isa<Float8E4M3FNUZType, Float8E4M3FNType>(sourceBElem))
+ if (isNativeFp8(chipset, sourceBElem))
return ROCDL::mfma_f32_16x16x32_bf8_fp8::getOperationName();
}
if (m == 32 && n == 32 && k == 16 && b == 1) {
- if (isa<Float8E5M2FNUZType, Float8E5M2Type>(sourceBElem))
+ if (isNativeBf8(chipset, sourceBElem))
return ROCDL::mfma_f32_32x32x16_bf8_bf8::getOperationName();
- if (isa<Float8E4M3FNUZType, Float8E4M3FNType>(sourceBElem))
+ if (isNativeFp8(chipset, sourceBElem))
return ROCDL::mfma_f32_32x32x16_bf8_fp8::getOperationName();
}
}
- if (destElem.isF32() &&
- ((isa<Float8E4M3FNUZType>(sourceElem) && chipset >= kGfx942) ||
- (isa<Float8E4M3FNType>(sourceElem) && chipset.hasOcpFp8()))) {
+ if (destElem.isF32() && isNativeFp8(chipset, sourceElem)) {
Type sourceBElem =
cast<VectorType>(mfma.getSourceB().getType()).getElementType();
if (m == 16 && n == 16 && k == 32 && b == 1) {
- if (isa<Float8E5M2FNUZType, Float8E5M2Type>(sourceBElem))
+ if (isNativeBf8(chipset, sourceBElem))
return ROCDL::mfma_f32_16x16x32_fp8_bf8::getOperationName();
- if (isa<Float8E4M3FNUZType, Float8E4M3FNType>(sourceBElem))
+ if (isNativeFp8(chipset, sourceBElem))
return ROCDL::mfma_f32_16x16x32_fp8_fp8::getOperationName();
}
if (m == 32 && n == 32 && k == 16 && b == 1) {
- if (isa<Float8E5M2FNUZType, Float8E5M2Type>(sourceBElem))
+ if (isNativeBf8(chipset, sourceBElem))
return ROCDL::mfma_f32_32x32x16_fp8_bf8::getOperationName();
- if (isa<Float8E4M3FNUZType, Float8E4M3FNType>(sourceBElem))
+ if (isNativeFp8(chipset, sourceBElem))
return ROCDL::mfma_f32_32x32x16_fp8_fp8::getOperationName();
}
}
@@ -813,10 +823,10 @@ LogicalResult ExtPackedFp8OpLowering::matchAndRewrite(
}
Value i32Source = rewriter.create<LLVM::BitcastOp>(loc, i32, source);
Value wordSel = createI32Constant(rewriter, loc, op.getIndex());
- if (isa<Float8E5M2FNUZType, Float8E5M2Type>(sourceElemType)) {
+ if (isNativeBf8(chipset, sourceElemType)) {
rewriter.replaceOpWithNewOp<ROCDL::CvtF32Bf8Op>(op, f32, i32Source,
wordSel);
- } else if (isa<Float8E4M3FNUZType, Float8E4M3FNType>(sourceElemType)) {
+ } else if (isNativeFp8(chipset, sourceElemType)) {
rewriter.replaceOpWithNewOp<ROCDL::CvtF32Fp8Op>(op, f32, i32Source,
wordSel);
}
@@ -848,10 +858,10 @@ LogicalResult PackedTrunc2xFp8OpLowering::matchAndRewrite(
Value wordSel = createI1Constant(rewriter, loc, op.getWordIndex());
Value result;
- if (isa<Float8E5M2FNUZType, Float8E5M2Type>(resultElemType))
+ if (isNativeBf8(chipset, resultElemType))
result = rewriter.create<ROCDL::CvtPkBf8F32Op>(loc, i32, sourceA, sourceB,
existing, wordSel);
- else if (isa<Float8E4M3FNUZType, Float8E4M3FNType>(resultElemType))
+ else if (isNativeFp8(chipset, resultElemType))
result = rewriter.create<ROCDL::CvtPkFp8F32Op>(loc, i32, sourceA, sourceB,
existing, wordSel);
@@ -883,10 +893,10 @@ LogicalResult PackedStochRoundFp8OpLowering::matchAndRewrite(
Value byteSel = createI32Constant(rewriter, loc, op.getStoreIndex());
Value result;
- if (isa<Float8E5M2FNUZType, Float8E5M2Type>(resultElemType))
+ if (isNativeBf8(chipset, resultElemType))
result = rewriter.create<ROCDL::CvtSrBf8F32Op>(loc, i32, source, stoch,
existing, byteSel);
- else if (isa<Float8E4M3FNUZType, Float8E4M3FNType>(resultElemType))
+ else if (isNativeFp8(chipset, resultElemType))
result = rewriter.create<ROCDL::CvtSrFp8F32Op>(loc, i32, source, stoch,
existing, byteSel);
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/8-bit-floats-ocp.mlir b/mlir/test/Conversion/AMDGPUToROCDL/8-bit-floats-ocp.mlir
new file mode 100644
index 0000000000000..70775a603e54d
--- /dev/null
+++ b/mlir/test/Conversion/AMDGPUToROCDL/8-bit-floats-ocp.mlir
@@ -0,0 +1,109 @@
+// RUN: mlir-opt %s -convert-amdgpu-to-rocdl=chipset=gfx950 | FileCheck %s
+// RUN: mlir-opt %s -convert-amdgpu-to-rocdl=chipset=gfx1200 | FileCheck %s
+
+// CHECK-LABEL: func @ext_scalar
+// CHECK: [[V:%.+]] = builtin.unrealized_conversion_cast %{{.+}} : f8E5M2 to i8
+// CHECK-DAG: [[UNDEF:%.+]] = llvm.mlir.undef : vector<4xi8>
+// CHECK-DAG: [[C0_1:%.+]] = llvm.mlir.constant(0 : i32) : i32
+// CHECK: [[VEC:%.+]] = llvm.insertelement [[V]], [[UNDEF]]{{\[}}[[C0_1]] : i32] : vector<4xi8>
+// CHECK: [[CAST:%.+]] = llvm.bitcast [[VEC]] : vector<4xi8> to i32
+// CHECK: [[C0_2:%.+]] = llvm.mlir.constant(0 : i32) : i32
+// CHECK: [[EXT:%.+]] = rocdl.cvt.f32.bf8 [[CAST]]{{\[}}[[C0_2]]] : f32
+// CHECK: return [[EXT]]
+func.func @ext_scalar(%v: f8E5M2) -> f32 {
+ %ret = amdgpu.ext_packed_fp8 %v[0] : f8E5M2 to f32
+ func.return %ret : f32
+}
+
+// CHECK-LABEL: func @ext_short_vec
+// CHECK: [[V:%.+]] = builtin.unrealized_conversion_cast %{{.+}} : vector<2xf8E4M3FN> to vector<2xi8>
+// CHECK-DAG: [[UNDEF:%.+]] = llvm.mlir.undef : vector<4xi8>
+// CHECK-DAG: [[C0:%.+]] = llvm.mlir.constant(0 : i32) : i32
+// CHECK: [[ELEM_0:%.+]] = llvm.extractelement [[V]]{{\[}}[[C0]] : i32] : vector<2xi8>
+// CHECK: [[VEC_0:%.+]] = llvm.insertelement [[ELEM_0]], [[UNDEF]]{{\[}}[[C0]] : i32] : vector<4xi8>
+// CHECK: [[C1_1:%.+]] = llvm.mlir.constant(1 : i32) : i32
+// CHECK: [[ELEM_1:%.+]] = llvm.extractelement [[V]]{{\[}}[[C1_1]] : i32] : vector<2xi8>
+// CHECK: [[VEC_1:%.+]] = llvm.insertelement [[ELEM_1]], [[VEC_0]]{{\[}}[[C1_1]] : i32] : vector<4xi8>
+// CHECK: [[CAST:%.+]] = llvm.bitcast [[VEC_1]] : vector<4xi8> to i32
+// CHECK: [[C1_2:%.+]] = llvm.mlir.constant(1 : i32) : i32
+// CHECK: [[EXT:%.+]] = rocdl.cvt.f32.fp8 [[CAST]]{{\[}}[[C1_2]]] : f32
+// CHECK: return [[EXT]]
+func.func @ext_short_vec(%v: vector<2xf8E4M3FN>) -> f32 {
+ %ret = amdgpu.ext_packed_fp8 %v[1] : vector<2xf8E4M3FN> to f32
+ func.return %ret : f32
+}
+
+// CHECK-LABEL: func @ext_full_vec(
+// CHECK: [[V:%.+]] = builtin.unrealized_conversion_cast %{{.+}} : vector<4xf8E4M3FN> to vector<4xi8>
+// CHECK: [[CAST:%.+]] = llvm.bitcast [[V]] : vector<4xi8> to i32
+// CHECK: [[C3:%.+]] = llvm.mlir.constant(3 : i32) : i32
+// CHECK: [[EXT:%.+]] = rocdl.cvt.f32.fp8 [[CAST]]{{\[}}[[C3]]] : f32
+// CHECK: return [[EXT]] : f32
+
+func.func @ext_full_vec(%v: vector<4xf8E4M3FN>) -> f32 {
+ %ret = amdgpu.ext_packed_fp8 %v[3] : vector<4xf8E4M3FN> to f32
+ func.return %ret : f32
+}
+
+// CHECK-LABEL: func @packed_trunc
+// CHECK-SAME: ([[V:%.+]]: f32)
+// CHECK: [[V2:%.+]] = llvm.mlir.undef : f32
+// CHECK: [[EXISTING:%.+]] = llvm.mlir.undef : i32
+// CHECK: [[FALSE:%.+]] = llvm.mlir.constant(false) : i1
+// CHECK: [[PACKED:%.+]] = rocdl.cvt.pk.fp8.f32 [[V]], [[V2]] -> [[EXISTING]]{{\[}}[[FALSE]]] : i32
+// CHECK: [[CAST:%.+]] = llvm.bitcast [[PACKED]] : i32 to vector<4xi8>
+// CHECK: builtin.unrealized_conversion_cast [[CAST]] : vector<4xi8> to vector<4xf8E4M3FN>
+func.func @packed_trunc(%v: f32) -> vector<4xf8E4M3FN> {
+ %ret = amdgpu.packed_trunc_2xfp8 %v, undef into undef[word 0] : f32 to vector<4xf8E4M3FN>
+ func.return %ret : vector<4xf8E4M3FN>
+}
+
+// CHECK-LABEL: func @packed_truncx2
+// CHECK-SAME: ([[V:%.+]]: f32, [[W:%.+]]: f32)
+// CHECK: [[EXISTING:%.+]] = llvm.mlir.undef : i32
+// CHECK: [[FALSE:%.+]] = llvm.mlir.constant(false) : i1
+// CHECK: [[PACKED:%.+]] = rocdl.cvt.pk.fp8.f32 [[V]], [[W]] -> [[EXISTING]]{{\[}}[[FALSE]]] : i32
+// CHECK: [[CAST:%.+]] = llvm.bitcast [[PACKED]] : i32 to vector<4xi8>
+// CHECK: builtin.unrealized_conversion_cast [[CAST]] : vector<4xi8> to vector<4xf8E4M3FN>
+func.func @packed_truncx2(%v: f32, %w: f32) -> vector<4xf8E4M3FN> {
+ %ret = amdgpu.packed_trunc_2xfp8 %v, %w into undef[word 0] : f32 to vector<4xf8E4M3FN>
+ func.return %ret : vector<4xf8E4M3FN>
+}
+
+// CHECK-LABEL: func @packed_truncx2_into
+// CHECK-SAME: ([[V:%.+]]: f32, [[W:%.+]]: f32, [[EXISTING:%.+]]: vector<4xf8E5M2>)
+// CHECK: [[EXISTING_BYTES:%.+]] = builtin.unrealized_conversion_cast [[EXISTING]] : vector<4xf8E5M2> to vector<4xi8>
+// CHECK: [[EXISTING_INT:%.+]] = llvm.bitcast [[EXISTING_BYTES]] : vector<4xi8> to i32
+// CHECK: [[TRUE:%.+]] = llvm.mlir.constant(true) : i1
+// CHECK: [[PACKED:%.+]] = rocdl.cvt.pk.bf8.f32 [[V]], [[W]] -> [[EXISTING_INT]]{{\[}}[[TRUE]]] : i32
+// CHECK: [[CAST:%.+]] = llvm.bitcast [[PACKED]] : i32 to vector<4xi8>
+// CHECK: builtin.unrealized_conversion_cast [[CAST]] : vector<4xi8> to vector<4xf8E5M2>
+func.func @packed_truncx2_into(%v: f32, %w: f32, %existing: vector<4xf8E5M2>) -> vector<4xf8E5M2> {
+ %ret = amdgpu.packed_trunc_2xfp8 %v, %w into %existing[word 1] : f32 to vector<4xf8E5M2> into vector<4xf8E5M2>
+ func.return %ret : vector<4xf8E5M2>
+}
+
+// CHECK-LABEL: func @packed_stoch_round
+// CHECK-SAME: ([[V:%.+]]: f32, [[S:%.+]]: i32)
+// CHECK: [[EXISTING:%.+]] = llvm.mlir.undef : i32
+// CHECK: [[C0:%.+]] = llvm.mlir.constant(0 : i32) : i32
+// CHECK: [[PACKED:%.+]] = rocdl.cvt.sr.fp8.f32 [[V]], [[S]] -> [[EXISTING]]{{\[}}[[C0]]] : i32
+// CHECK: [[CAST:%.+]] = llvm.bitcast [[PACKED]] : i32 to vector<4xi8>
+// CHECK: builtin.unrealized_conversion_cast [[CAST]] : vector<4xi8> to vector<4xf8E4M3FN>
+func.func @packed_stoch_round(%v: f32, %s: i32) -> vector<4xf8E4M3FN> {
+ %ret = amdgpu.packed_stoch_round_fp8 %v + %s into undef[0] : f32 to vector<4xf8E4M3FN>
+ func.return %ret : vector<4xf8E4M3FN>
+}
+
+// CHECK-LABEL: func @packed_stoch_round_into
+// CHECK-SAME: ([[V:%.+]]: f32, [[S:%.+]]: i32, [[EXISTING:%.+]]: vector<4xf8E5M2>)
+// CHECK: [[EXISTING_BYTES:%.+]] = builtin.unrealized_conversion_cast [[EXISTING]] : vector<4xf8E5M2> to vector<4xi8>
+// CHECK: [[EXISTING_INT:%.+]] = llvm.bitcast [[EXISTING_BYTES]] : vector<4xi8> to i32
+// CHECK: [[C1:%.+]] = llvm.mlir.constant(1 : i32) : i32
+// CHECK: [[PACKED:%.+]] = rocdl.cvt.sr.bf8.f32 [[V]], [[S]] -> [[EXISTING_INT]]{{\[}}[[C1]]] : i32
+// CHECK: [[CAST:%.+]] = llvm.bitcast [[PACKED]] : i32 to vector<4xi8>
+// CHECK: builtin.unrealized_conversion_cast [[CAST]] : vector<4xi8> to vector<4xf8E5M2>
+func.func @packed_stoch_round_into(%v: f32, %s: i32, %existing: vector<4xf8E5M2>) -> vector<4xf8E5M2> {
+ %ret = amdgpu.packed_stoch_round_fp8 %v + %s into %existing[1] : f32 to vector<4xf8E5M2> into vector<4xf8E5M2>
+ func.return %ret : vector<4xf8E5M2>
+}
diff --git a/mlir/test/Conversion/ArithToAMDGPU/8-bit-float-saturation-ocp.mlir b/mlir/test/Conversion/ArithToAMDGPU/8-bit-float-saturation-ocp.mlir
new file mode 100644
index 0000000000000..2df5f2fa1965f
--- /dev/null
+++ b/mlir/test/Conversion/ArithToAMDGPU/8-bit-float-saturation-ocp.mlir
@@ -0,0 +1,58 @@
+// RUN: mlir-opt --split-input-file %s \
+// RUN: --pass-pipeline='builtin.module(func.func(convert-arith-to-amdgpu{chipset=gfx950 saturate-fp8-truncf=true}))' \
+// RUN: | FileCheck %s
+
+// RUN: mlir-opt --split-input-file %s \
+// RUN: --pass-pipeline='builtin.module(func.func(convert-arith-to-amdgpu{chipset=gfx1200 saturate-fp8-truncf=true}))' \
+// RUN: | FileCheck %s
+
+// CHECK-LABEL: func.func @scalar_trunc
+// CHECK-SAME: ([[V:%.+]]: f16)
+// CHECK-DAG: [[CMin:%.+]] = arith.constant -5.734400e+04 : f16
+// CHECK-DAG: [[CMax:%.+]] = arith.constant 5.734400e+04 : f16
+// CHECK-DAG: [[CInf:%.+]] = arith.constant 0x7C00 : f16
+// CHECK-DAG: [[CNegInf:%.+]] = arith.constant 0xFC00 : f16
+// CHECK: [[ISINF:%.+]] = arith.cmpf oeq, [[V]], [[CInf]]
+// CHECK: [[ISNEGINF:%.+]] = arith.cmpf oeq, [[V]], [[CNegInf]]
+// CHECK: [[ISNAN:%.+]] = arith.cmpf uno, [[V]], [[V]]
+// CHECK: [[ISNONFINITE_1:%.+]] = arith.ori [[ISINF]], [[ISNEGINF]]
+// CHECK: [[ISNONFINITE:%.+]] = arith.ori [[ISNONFINITE_1]], [[ISNAN]]
+// CHECK: [[CLAMPEDBELOW:%.+]] = arith.maximumf [[V]], [[CMin]]
+// CHECK: [[CLAMPED:%.+]] = arith.minimumf [[CLAMPEDBELOW]], [[CMax]]
+// CHECK: [[SATURATED:%.+]] = arith.select [[ISNONFINITE]], [[V]], [[CLAMPED]]
+// CHECK: [[FLOAT:%.+]] = arith.extf [[SATURATED]] : f16 to f32
+// CHECK: [[TRUNCV:%.+]] = amdgpu.packed_trunc_2xfp8 [[FLOAT]], undef into undef[word 0] : f32 to vector<4xf8E5M2>
+// CHECK: [[W:%.+]] = vector.extract [[TRUNCV]][0] : f8E5M2 from vector<4xf8E5M2>
+// CHECK: return [[W]] : f8E5M2
+func.func @scalar_trunc(%v: f16) -> f8E5M2 {
+ %w = arith.truncf %v : f16 to f8E5M2
+ return %w : f8E5M2
+}
+
+// No 0-D test because arith.truncf hasn't been extended to support it.
+
+// -----
+
+// CHECK-LABEL: func.func @vector_trunc
+// CHECK-SAME: ([[V:%.+]]: vector<2xf32>) -> vector<2xf8E4M3FN> {
+// CHECK-DAG: [[CMin:%.+]] = arith.constant dense<-4.480000e+02> : vector<2xf32>
+// CHECK-DAG: [[CMax:%.+]] = arith.constant dense<4.480000e+02> : vector<2xf32>
+// CHECK-DAG: [[CInf:%.+]] = arith.constant dense<0x7F800000> : vector<2xf32>
+// CHECK-DAG: [[CNegInf:%.+]] = arith.constant dense<0xFF800000> : vector<2xf32>
+// CHECK: [[ISINF:%.+]] = arith.cmpf oeq, [[V]], [[CInf]]
+// CHECK: [[ISNEGINF:%.+]] = arith.cmpf oeq, [[V]], [[CNegInf]]
+// CHECK: [[ISNAN:%.+]] = arith.cmpf uno, [[V]], [[V]]
+// CHECK: [[ISNONFINITE_1:%.+]] = arith.ori [[ISINF]], [[ISNEGINF]]
+// CHECK: [[ISNONFINITE:%.+]] = arith.ori [[ISNONFINITE_1]], [[ISNAN]]
+// CHECK: [[CLAMPEDBELOW:%.+]] = arith.maximumf [[V]], [[CMin]]
+// CHECK: [[CLAMPED:%.+]] = arith.minimumf [[CLAMPEDBELOW]], [[CMax]]
+// CHECK: [[SATURATED:%.+]] = arith.select [[ISNONFINITE]], [[V]], [[CLAMPED]]
+// CHECK: [[F0:%.+]] = vector.extract [[SATURATED]][0]
+// CHECK: [[F1:%.+]] = vector.extract [[SATURATED]][1]
+// CHECK: [[W0:%.+]] = amdgpu.packed_trunc_2xfp8 [[F0]], [[F1]] into undef[word 0] : f32 to vector<4xf8E4M3FN>
+// CHECK: [[W:%.+]] = vector.extract_strided_slice [[W0]] {offsets = [0], sizes = [2], strides = [1]} : vector<4xf8E4M3FN> to vector<2xf8E4M3FN>
+// CHECK: return [[W]] : vector<2xf8E4M3FN>
+func.func @vector_trunc_short(%v: vector<2xf32>) -> vector<2xf8E4M3FN> {
+ %w = arith.truncf %v : vector<2xf32> to vector<2xf8E4M3FN>
+ return %w : vector<2xf8E4M3FN>
+}
diff --git a/mlir/test/Conversion/ArithToAMDGPU/8-bit-floats-ocp.mlir b/mlir/test/Conversion/ArithToAMDGPU/8-bit-floats-ocp.mlir
new file mode 100644
index 0000000000000..0e7f58c9e6749
--- /dev/null
+++ b/mlir/test/Conversion/ArithToAMDGPU/8-bit-floats-ocp.mlir
@@ -0,0 +1,176 @@
+// RUN: mlir-opt --split-input-file %s -convert-arith-to-amdgpu="chipset=gfx950" | FileCheck %s
+// RUN: mlir-opt --split-input-file %s -convert-arith-to-amdgpu="chipset=gfx1200" | FileCheck %s
+
+// CHECK-LABEL: func.func @scalar_ext
+// CHECK-SAME: ([[V:%.+]]: f8E5M2)
+// CHECK: [[FLOAT:%.+]] = amdgpu.ext_packed_fp8 [[V]][0] : f8E5M2 to f32
+// CHECK: [[W:%.+]] = arith.truncf [[FLOAT]] : f32 to f16
+// CHECK: return [[W]]
+func.func @scalar_ext(%v: f8E5M2) -> f16 {
+ %w = arith.extf %v : f8E5M2 to f16
+ return %w : f16
+}
+
+// No 0-D test because arith.extf hasn't been extended to support it.
+
+// -----
+
+// CHECK-LABEL: func.func @vector_ext_short
+// CHECK-SAME: ([[V:%.+]]: vector<2xf8E5M2>)
+// CHECK-DAG: [[ZEROES:%.+]] = arith.constant dense<0.000000e+00> : vector<2xf64>
+// CHECK: [[FLOAT0:%.+]] = amdgpu.ext_packed_fp8 [[V]][0] : vector<2xf8E5M2> to f32
+// CHECK: [[EXT0:%.+]] = arith.extf [[FLOAT0]] : f32 to f64
+// CHECK: [[W0:%.+]] = vector.insert [[EXT0]], [[ZEROES]] [0]
+// CHECK: [[FLOAT1:%.+]] = amdgpu.ext_packed_fp8 [[V]][1] : vector<2xf8E5M2> to f32
+// CHECK: [[EXT1:%.+]] = arith.extf [[FLOAT1]]
+// CHECK: [[W1:%.+]] = vector.insert [[EXT1]], [[W0]] [1]
+// CHECK: return [[W1]] : vector<2xf64>
+
+func.func @vector_ext_short(%v: vector<2xf8E5M2>) -> vector<2xf64> {
+ %w = arith.extf %v : vector<2xf8E5M2> to vector<2xf64>
+ return %w : vector<2xf64>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @vector_ext_long
+// CHECK-SAME: ([[V:%.+]]: vector<9xf8E4M3FN>)
+// CHECK: [[V0:%.+]] = vector.extract_strided_slice [[V]] {offsets = [0], sizes = [4], strides = [1]}
+// CHECK: [[F0:%.+]] = amdgpu.ext_packed_fp8 [[V0]][0]
+// CHECK: [[W0:%.+]] = vector.insert [[F0]]
+// CHECK: [[F1:%.+]] = amdgpu.ext_packed_fp8 [[V0]][1]
+// CHECK: [[W1:%.+]] = vector.insert [[F1]], [[W0]]
+// CHECK: [[F2:%.+]] = amdgpu.ext_packed_fp8 [[V0]][2]
+// CHECK: [[W2:%.+]] = vector.insert [[F2]], [[W1]]
+// CHECK: [[F3:%.+]] = amdgpu.ext_packed_fp8 [[V0]][3]
+// CHECK: [[W3:%.+]] = vector.insert [[F3]], [[W2]]
+
+// CHECK: [[V1:%.+]] = vector.extract_strided_slice [[V]] {offsets = [4], sizes = [4], strides = [1]} : vector<9xf8E4M3FN> to vector<4xf8E4M3FN>
+// CHECK: [[F4:%.+]] = amdgpu.ext_packed_fp8 [[V1]][0]
+// CHECK: [[W4:%.+]] = vector.insert [[F4]], [[W3]]
+// CHECK: [[F5:%.+]] = amdgpu.ext_packed_fp8 [[V1]][1]
+// CHECK: [[W5:%.+]] = vector.insert [[F5]], [[W4]]
+// CHECK: [[F6:%.+]] = amdgpu.ext_packed_fp8 [[V1]][2]
+// CHECK: [[W6:%.+]] = vector.insert [[F6]], [[W5]]
+// CHECK: [[F7:%.+]] = amdgpu.ext_packed_fp8 [[V1]][3]
+// CHECK: [[W7:%.+]] = vector.insert [[F7]], [[W6]]
+
+// CHECK: [[V2:%.+]] = vector.extract_strided_slice [[V]] {offsets = [8], sizes = [1], strides = [1]} : vector<9xf8E4M3FN> to vector<1xf8E4M3FN>
+// CHECK: [[F8:%.+]] = amdgpu.ext_packed_fp8 [[V2]][0]
+// CHECK: [[W8:%.+]] = vector.insert [[F8]], [[W7]]
+// CHECK: return [[W8]]
+func.func @vector_ext_long(%v: vector<9xf8E4M3FN>) -> vector<9xf32> {
+ %w = arith.extf %v : vector<9xf8E4M3FN> to vector<9xf32>
+ return %w : vector<9xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @scalar_trunc
+// CHECK-SAME: ([[V:%.+]]: f16)
+// CHECK: [[FLOAT:%.+]] = arith.extf [[V]] : f16 to f32
+// CHECK: [[TRUNCV:%.+]] = amdgpu.packed_trunc_2xfp8 [[FLOAT]], undef into undef[word 0] : f32 to vector<4xf8E5M2>
+// CHECK: [[W:%.+]] = vector.extract [[TRUNCV]][0] : f8E5M2 from vector<4xf8E5M2>
+// CHECK: return [[W]] : f8E5M2
+func.func @scalar_trunc(%v: f16) -> f8E5M2 {
+ %w = arith.truncf %v : f16 to f8E5M2
+ return %w : f8E5M2
+}
+
+// No 0-D test because arith.truncf hasn't been extended to support it.
+
+// -----
+
+// CHECK-LABEL: func.func @vector_trunc_short
+// CHECK-SAME: ([[V:%.+]]: vector<2xf64>) -> vector<2xf8E5M2> {
+// CHECK: [[V0:%.+]] = vector.extract [[V]][0]
+// CHECK: [[F0:%.+]] = arith.truncf [[V0]] : f64 to f32
+// CHECK: [[V1:%.+]] = vector.extract [[V]][1]
+// CHECK: [[F1:%.+]] = arith.truncf [[V1]] : f64 to f32
+// CHECK: [[W0:%.+]] = amdgpu.packed_trunc_2xfp8 [[F0]], [[F1]] into undef[word 0] : f32 to vector<4xf8E5M2>
+// CHECK: [[W:%.+]] = vector.extract_strided_slice [[W0]] {offsets = [0], sizes = [2], strides = [1]} : vector<4xf8E5M2> to vector<2xf8E5M2>
+// CHECK: return [[W]] : vector<2xf8E5M2>
+func.func @vector_trunc_short(%v: vector<2xf64>) -> vector<2xf8E5M2> {
+ %w = arith.truncf %v : vector<2xf64> to vector<2xf8E5M2>
+ return %w : vector<2xf8E5M2>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @vector_trunc_long
+// CHECK-SAME: ([[V:%.+]]: vector<9xf32>)
+// CHECK: [[ZEROES:%.+]] = arith.constant dense<0.000000e+00> : vector<9xf8E4M3FN>
+// CHECK: [[T0:%.+]] = amdgpu.packed_trunc_2xfp8 %{{.+}}, %{{.+}} into undef[word 0]
+// CHECK: [[T1:%.+]] = amdgpu.packed_trunc_2xfp8 %{{.+}}, %{{.+}} into [[T0]][word 1]
+// CHECK: [[W0:%.+]] = vector.insert_strided_slice [[T1]], [[ZEROES]] {offsets = [0], strides = [1]}
+
+// CHECK: [[T2:%.+]] = amdgpu.packed_trunc_2xfp8 %{{.+}}, %{{.+}} into undef[word 0]
+// CHECK: [[T3:%.+]] = amdgpu.packed_trunc_2xfp8 %{{.+}}, %{{.+}} into [[T2]][word 1]
+// CHECK: [[W1:%.+]] = vector.insert_strided_slice [[T3]], [[W0]] {offsets = [4], strides = [1]}
+
+// CHECK: [[T4:%.+]] = amdgpu.packed_trunc_2xfp8 %{{.+}}, undef into undef[word 0]
+// CHECK: [[T4_SHORT:%.+]] = vector.extract_strided_slice [[T4]] {offsets = [0], sizes = [1], strides = [1]}
+// CHECK: [[W:%.+]] = vector.insert_strided_slice [[T4_SHORT]], [[W1]] {offsets = [8], strides = [1]}
+// CHECK: return [[W]]
+func.func @vector_trunc_long(%v: vector<9xf32>) -> vector<9xf8E4M3FN> {
+ %w = arith.truncf %v : vector<9xf32> to vector<9xf8E4M3FN>
+ return %w : vector<9xf8E4M3FN>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @vector_trunc_long_2d
+// CHECK-SAME: ([[V:%.+]]: vector<1x9xf32>)
+// CHECK: [[ZEROES:%.+]] = arith.constant dense<0.000000e+00> : vector<9xf8E4M3FN>
+// CHECK: [[T0:%.+]] = amdgpu.packed_trunc_2xfp8 %{{.+}}, %{{.+}} into undef[word 0]
+// CHECK: [[T1:%.+]] = amdgpu.packed_trunc_2xfp8 %{{.+}}, %{{.+}} into [[T0]][word 1]
+// CHECK: [[W0:%.+]] = vector.insert_strided_slice [[T1]], [[ZEROES]] {offsets = [0], strides = [1]}
+
+// CHECK: [[T2:%.+]] = amdgpu.packed_trunc_2xfp8 %{{.+}}, %{{.+}} into undef[word 0]
+// CHECK: [[T3:%.+]] = amdgpu.packed_trunc_2xfp8 %{{.+}}, %{{.+}} into [[T2]][word 1]
+// CHECK: [[W1:%.+]] = vector.insert_strided_slice [[T3]], [[W0]] {offsets = [4], strides = [1]}
+
+// CHECK: [[T4:%.+]] = amdgpu.packed_trunc_2xfp8 %{{.+}}, undef into undef[word 0]
+// CHECK: [[T4_SHORT:%.+]] = vector.extract_strided_slice [[T4]] {offsets = [0], sizes = [1], strides = [1]}
+// CHECK: [[W:%.+]] = vector.insert_strided_slice [[T4_SHORT]], [[W1]] {offsets = [8], strides = [1]}
+// CHECK: [[RE:%.+]] = vector.shape_cast [[W]] : vector<9xf8E4M3FN> to vector<1x9xf8E4M3FN>
+// CHECK: return [[RE]]
+func.func @vector_trunc_long_2d(%v: vector<1x9xf32>) -> vector<1x9xf8E4M3FN> {
+ %w = arith.truncf %v : vector<1x9xf32> to vector<1x9xf8E4M3FN>
+ return %w : vector<1x9xf8E4M3FN>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @vector_ext_long_2d
+// CHECK-SAME: ([[V:%.+]]: vector<1x9xf8E4M3FN>)
+// CHECK: [[CAST:%.+]] = vector.shape_cast [[V]] : vector<1x9xf8E4M3FN> to vector<9xf8E4M3FN>
+// CHECK: [[V0:%.+]] = vector.extract_strided_slice [[CAST]] {offsets = [0], sizes = [4], strides = [1]}
+// CHECK: [[F0:%.+]] = amdgpu.ext_packed_fp8 [[V0]][0]
+// CHECK: [[W0:%.+]] = vector.insert [[F0]]
+// CHECK: [[F1:%.+]] = amdgpu.ext_packed_fp8 [[V0]][1]
+// CHECK: [[W1:%.+]] = vector.insert [[F1]], [[W0]]
+// CHECK: [[F2:%.+]] = amdgpu.ext_packed_fp8 [[V0]][2]
+// CHECK: [[W2:%.+]] = vector.insert [[F2]], [[W1]]
+// CHECK: [[F3:%.+]] = amdgpu.ext_packed_fp8 [[V0]][3]
+// CHECK: [[W3:%.+]] = vector.insert [[F3]], [[W2]]
+
+// CHECK: [[V1:%.+]] = vector.extract_strided_slice [[CAST]] {offsets = [4], sizes = [4], strides = [1]} : vector<9xf8E4M3FN> to vector<4xf8E4M3FN>
+// CHECK: [[F4:%.+]] = amdgpu.ext_packed_fp8 [[V1]][0]
+// CHECK: [[W4:%.+]] = vector.insert [[F4]], [[W3]]
+// CHECK: [[F5:%.+]] = amdgpu.ext_packed_fp8 [[V1]][1]
+// CHECK: [[W5:%.+]] = vector.insert [[F5]], [[W4]]
+// CHECK: [[F6:%.+]] = amdgpu.ext_packed_fp8 [[V1]][2]
+// CHECK: [[W6:%.+]] = vector.insert [[F6]], [[W5]]
+// CHECK: [[F7:%.+]] = amdgpu.ext_packed_fp8 [[V1]][3]
+// CHECK: [[W7:%.+]] = vector.insert [[F7]], [[W6]]
+
+// CHECK: [[V2:%.+]] = vector.extract_strided_slice [[CAST]] {offsets = [8], sizes = [1], strides = [1]} : vector<9xf8E4M3FN> to vector<1xf8E4M3FN>
+// CHECK: [[F8:%.+]] = amdgpu.ext_packed_fp8 [[V2]][0]
+// CHECK: [[W8:%.+]] = vector.insert [[F8]], [[W7]]
+// CHECK: [[CAST:%.+]] = vector.shape_cast [[W8]] : vector<9xf32> to vector<1x9xf32>
+// CHECK: return [[CAST]]
+func.func @vector_ext_long_2d(%v: vector<1x9xf8E4M3FN>) -> vector<1x9xf32> {
+ %w = arith.extf %v : vector<1x9xf8E4M3FN> to vector<1x9xf32>
+ return %w : vector<1x9xf32>
+}
>From a0911cc0417b5f24bf38881b04983c91d4bcf548 Mon Sep 17 00:00:00 2001
From: Paul Fuqua <pf at acm.org>
Date: Thu, 12 Sep 2024 20:54:06 -0500
Subject: [PATCH 3/7] [MLIR][AMDGPU] Clean up and redo after other recent
patches here.
---
.../mlir/Dialect/AMDGPU/Utils/Chipset.h | 4 ++--
.../AMDGPUToROCDL/AMDGPUToROCDL.cpp | 6 ++---
.../ArithToAMDGPU/ArithToAMDGPU.cpp | 24 ++++++++++++-------
3 files changed, 21 insertions(+), 13 deletions(-)
diff --git a/mlir/include/mlir/Dialect/AMDGPU/Utils/Chipset.h b/mlir/include/mlir/Dialect/AMDGPU/Utils/Chipset.h
index e5c14c1cb6827..5b071a46f49ed 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/Utils/Chipset.h
+++ b/mlir/include/mlir/Dialect/AMDGPU/Utils/Chipset.h
@@ -49,10 +49,10 @@ struct Chipset {
#undef DEFINE_COMP_OPERATOR
bool isGfx940() const {
- return majorVersion == 9 && minorVersion >= 0x40 && minorVersion < 0x50;
+ return majorVersion == 9 && minorVersion >= 4 && minorVersion < 5;
}
bool hasOcpFp8() const {
- return (majorVersion == 9 && minorVersion >= 0x50) || majorVersion >= 12;
+ return (majorVersion == 9 && minorVersion >= 5) || majorVersion >= 12;
}
};
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index d904c47b28300..7227f879e666b 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -793,7 +793,7 @@ LogicalResult ExtPackedFp8OpLowering::matchAndRewrite(
ExtPackedFp8Op op, ExtPackedFp8OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
Location loc = op.getLoc();
- if (chipset.majorVersion != 9 || chipset < kGfx942)
+ if (!(chipset.isGfx940() || chipset.hasOcpFp8()))
return rewriter.notifyMatchFailure(
loc, "Fp8 conversion instructions are not available on target "
"architecture and their emulation is not implemented");
@@ -837,7 +837,7 @@ LogicalResult PackedTrunc2xFp8OpLowering::matchAndRewrite(
PackedTrunc2xFp8Op op, PackedTrunc2xFp8OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
Location loc = op.getLoc();
- if (chipset.majorVersion != 9 || chipset < kGfx942)
+ if (!(chipset.isGfx940() || chipset.hasOcpFp8()))
return rewriter.notifyMatchFailure(
loc, "Fp8 conversion instructions are not available on target "
"architecture and their emulation is not implemented");
@@ -874,7 +874,7 @@ LogicalResult PackedStochRoundFp8OpLowering::matchAndRewrite(
PackedStochRoundFp8Op op, PackedStochRoundFp8OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
Location loc = op.getLoc();
- if (chipset.majorVersion != 9 || chipset < kGfx942)
+ if (!(chipset.isGfx940() || chipset.hasOcpFp8()))
return rewriter.notifyMatchFailure(
loc, "Fp8 conversion instructions are not available on target "
"architecture and their emulation is not implemented");
diff --git a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
index d70717f0ca2d1..83d239bb6f269 100644
--- a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
+++ b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
@@ -41,6 +41,10 @@ struct ArithToAMDGPUConversionPass final
struct ExtFOnFloat8RewritePattern final : OpRewritePattern<arith::ExtFOp> {
using OpRewritePattern::OpRewritePattern;
+ Chipset chipset;
+ ExtFOnFloat8RewritePattern(MLIRContext *ctx, Chipset chipset)
+ : OpRewritePattern::OpRewritePattern(ctx), chipset(chipset) {}
+
LogicalResult match(arith::ExtFOp op) const override;
void rewrite(arith::ExtFOp op, PatternRewriter &rewriter) const override;
};
@@ -68,6 +72,14 @@ struct TruncfToFloat16RewritePattern final
} // end namespace
+static LogicalResult isSupportedFp8(Type elementType, Chipset chipset) {
+ if (chipset.isGfx940())
+ return success(isa<Float8E5M2FNUZType, Float8E4M3FNUZType>(elementType));
+ if (chipset.hasOcpFp8())
+ return success(isa<Float8E5M2Type, Float8E4M3FNType>(elementType));
+ return failure();
+}
+
static Value castF32To(Type elementType, Value f32, Location loc,
PatternRewriter &rewriter) {
if (elementType.isF32())
@@ -86,8 +98,7 @@ LogicalResult ExtFOnFloat8RewritePattern::match(arith::ExtFOp op) const {
return failure();
inType = inVecType.getElementType();
}
- return success(isa<Float8E5M2FNUZType, Float8E4M3FNUZType, Float8E5M2Type,
- Float8E4M3FNType>(inType));
+ return isSupportedFp8(inType, chipset);
}
void ExtFOnFloat8RewritePattern::rewrite(arith::ExtFOp op,
@@ -221,10 +232,7 @@ LogicalResult TruncFToFloat8RewritePattern::match(arith::TruncFOp op) const {
// Conversion between 8-bit floats is not supported with truncation enabled.
return failure();
- return success((
- (isa<Float8E5M2FNUZType, Float8E4M3FNUZType>(outType) &&
- chipset.isGfx940()) ||
- (isa<Float8E5M2Type, Float8E4M3FNType>(outType) && chipset.hasOcpFp8())));
+ return isSupportedFp8(outType, chipset);
}
void TruncFToFloat8RewritePattern::rewrite(arith::TruncFOp op,
@@ -370,7 +378,7 @@ void mlir::arith::populateArithToAMDGPUConversionPatterns(
bool saturateFP8Truncf, bool allowPackedF16Rtz, Chipset chipset) {
if (convertFP8Arithmetic) {
- patterns.add<ExtFOnFloat8RewritePattern>(patterns.getContext());
+ patterns.add<ExtFOnFloat8RewritePattern>(patterns.getContext(), chipset);
patterns.add<TruncFToFloat8RewritePattern>(patterns.getContext(),
saturateFP8Truncf, chipset);
}
@@ -389,7 +397,7 @@ void ArithToAMDGPUConversionPass::runOnOperation() {
}
bool convertFP8Arithmetic =
- maybeChipset->majorVersion == 9 && *maybeChipset >= Chipset(9, 4, 2);
+ maybeChipset->isGfx940() || maybeChipset->hasOcpFp8();
arith::populateArithToAMDGPUConversionPatterns(
patterns, convertFP8Arithmetic, saturateFP8Truncf, allowPackedF16Rtz,
*maybeChipset);
>From 0a761c0b0f0b542dcbb52049435412798609b843 Mon Sep 17 00:00:00 2001
From: Paul Fuqua <pf at acm.org>
Date: Wed, 18 Sep 2024 14:22:38 -0500
Subject: [PATCH 4/7] [MLIR][AMDGPU] Changes from the review.
---
mlir/include/mlir/Dialect/AMDGPU/Utils/Chipset.h | 15 ++++++++-------
mlir/include/mlir/IR/Types.h | 3 +++
.../Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp | 14 +++++++-------
.../Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp | 16 ++++++++--------
mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp | 6 ++----
mlir/lib/IR/Types.cpp | 9 +++++++++
6 files changed, 37 insertions(+), 26 deletions(-)
diff --git a/mlir/include/mlir/Dialect/AMDGPU/Utils/Chipset.h b/mlir/include/mlir/Dialect/AMDGPU/Utils/Chipset.h
index 5b071a46f49ed..768b390ed5381 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/Utils/Chipset.h
+++ b/mlir/include/mlir/Dialect/AMDGPU/Utils/Chipset.h
@@ -47,15 +47,16 @@ struct Chipset {
DEFINE_COMP_OPERATOR(>)
DEFINE_COMP_OPERATOR(>=)
#undef DEFINE_COMP_OPERATOR
-
- bool isGfx940() const {
- return majorVersion == 9 && minorVersion >= 4 && minorVersion < 5;
- }
- bool hasOcpFp8() const {
- return (majorVersion == 9 && minorVersion >= 5) || majorVersion >= 12;
- }
};
+inline bool isGfx940Series(const Chipset &chipset) {
+ return chipset.majorVersion == 9 && chipset.minorVersion == 4;
+}
+inline bool hasOcpFp8(const Chipset &chipset) {
+ return (chipset.majorVersion == 9 && chipset.minorVersion >= 5) ||
+ chipset.majorVersion >= 12;
+}
+
} // namespace mlir::amdgpu
#endif
diff --git a/mlir/include/mlir/IR/Types.h b/mlir/include/mlir/IR/Types.h
index 0e82ad2be907a..e60f19a1ca585 100644
--- a/mlir/include/mlir/IR/Types.h
+++ b/mlir/include/mlir/IR/Types.h
@@ -132,6 +132,9 @@ class Type {
bool isF64() const;
bool isF80() const;
bool isF128() const;
+ /// Return true if this is a float type (with the specified width).
+ bool isFloat() const;
+ bool isFloat(unsigned width) const;
/// Return true if this is an integer type (with the specified width).
bool isInteger() const;
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 7227f879e666b..887e4401c2776 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -477,15 +477,15 @@ static void wmmaPushOutputOperand(ConversionPatternRewriter &rewriter,
/// Return true if `type` is the E5M2 variant of an 8-bit float that is
/// supported by the `_bf8` instructions on the given `chipset`.
static bool isNativeBf8(Chipset chipset, Type type) {
- return (chipset.isGfx940() && isa<Float8E5M2FNUZType>(type)) ||
- (chipset.hasOcpFp8() && isa<Float8E5M2Type>(type));
+ return (isGfx940Series(chipset) && isa<Float8E5M2FNUZType>(type)) ||
+ (hasOcpFp8(chipset) && isa<Float8E5M2Type>(type));
}
/// Return true if `type` is the E4M3FN variant of an 8-bit float that is
/// supported by the `_fp8` instructions on the given `chipset`.
static bool isNativeFp8(Chipset chipset, Type type) {
- return (chipset.isGfx940() && isa<Float8E4M3FNUZType>(type)) ||
- (chipset.hasOcpFp8() && isa<Float8E4M3FNType>(type));
+ return (isGfx940Series(chipset) && isa<Float8E4M3FNUZType>(type)) ||
+ (hasOcpFp8(chipset) && isa<Float8E4M3FNType>(type));
}
/// Return the `rocdl` intrinsic corresponding to a MFMA operation `mfma`
@@ -793,7 +793,7 @@ LogicalResult ExtPackedFp8OpLowering::matchAndRewrite(
ExtPackedFp8Op op, ExtPackedFp8OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
Location loc = op.getLoc();
- if (!(chipset.isGfx940() || chipset.hasOcpFp8()))
+ if (!(isGfx940Series(chipset) || hasOcpFp8(chipset)))
return rewriter.notifyMatchFailure(
loc, "Fp8 conversion instructions are not available on target "
"architecture and their emulation is not implemented");
@@ -837,7 +837,7 @@ LogicalResult PackedTrunc2xFp8OpLowering::matchAndRewrite(
PackedTrunc2xFp8Op op, PackedTrunc2xFp8OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
Location loc = op.getLoc();
- if (!(chipset.isGfx940() || chipset.hasOcpFp8()))
+ if (!(isGfx940Series(chipset) || hasOcpFp8(chipset)))
return rewriter.notifyMatchFailure(
loc, "Fp8 conversion instructions are not available on target "
"architecture and their emulation is not implemented");
@@ -874,7 +874,7 @@ LogicalResult PackedStochRoundFp8OpLowering::matchAndRewrite(
PackedStochRoundFp8Op op, PackedStochRoundFp8OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
Location loc = op.getLoc();
- if (!(chipset.isGfx940() || chipset.hasOcpFp8()))
+ if (!(isGfx940Series(chipset) || hasOcpFp8(chipset)))
return rewriter.notifyMatchFailure(
loc, "Fp8 conversion instructions are not available on target "
"architecture and their emulation is not implemented");
diff --git a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
index 83d239bb6f269..29ecbd4b9a604 100644
--- a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
+++ b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
@@ -72,11 +72,11 @@ struct TruncfToFloat16RewritePattern final
} // end namespace
-static LogicalResult isSupportedFp8(Type elementType, Chipset chipset) {
- if (chipset.isGfx940())
- return success(isa<Float8E5M2FNUZType, Float8E4M3FNUZType>(elementType));
- if (chipset.hasOcpFp8())
- return success(isa<Float8E5M2Type, Float8E4M3FNType>(elementType));
+static LogicalResult isSupportedF8(Type elementType, Chipset chipset) {
+ if (isGfx940Series(chipset))
+ return success(isa<Float8E4M3FNUZType, Float8E5M2FNUZType>(elementType));
+ if (hasOcpFp8(chipset))
+ return success(isa<Float8E4M3FNType, Float8E5M2Type>(elementType));
return failure();
}
@@ -98,7 +98,7 @@ LogicalResult ExtFOnFloat8RewritePattern::match(arith::ExtFOp op) const {
return failure();
inType = inVecType.getElementType();
}
- return isSupportedFp8(inType, chipset);
+ return isSupportedF8(inType, chipset);
}
void ExtFOnFloat8RewritePattern::rewrite(arith::ExtFOp op,
@@ -232,7 +232,7 @@ LogicalResult TruncFToFloat8RewritePattern::match(arith::TruncFOp op) const {
// Conversion between 8-bit floats is not supported with truncation enabled.
return failure();
- return isSupportedFp8(outType, chipset);
+ return isSupportedF8(outType, chipset);
}
void TruncFToFloat8RewritePattern::rewrite(arith::TruncFOp op,
@@ -397,7 +397,7 @@ void ArithToAMDGPUConversionPass::runOnOperation() {
}
bool convertFP8Arithmetic =
- maybeChipset->isGfx940() || maybeChipset->hasOcpFp8();
+ isGfx940Series(*maybeChipset) || hasOcpFp8(*maybeChipset);
arith::populateArithToAMDGPUConversionPatterns(
patterns, convertFP8Arithmetic, saturateFP8Truncf, allowPackedF16Rtz,
*maybeChipset);
diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
index c6aba7560acf5..630bece3a685f 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -272,16 +272,14 @@ LogicalResult MFMAOp::verify() {
}
Type sourceBType = getSourceB().getType();
- if (isa<Float8E5M2FNUZType, Float8E4M3FNUZType, Float8E5M2Type,
- Float8E4M3FNType>(sourceElem)) {
+ if (sourceElem.isFloat(8)) {
int64_t sourceBLen = 1;
Type sourceBElem = sourceBType;
if (auto sourceBVector = llvm::dyn_cast<VectorType>(sourceBType)) {
sourceBLen = sourceBVector.getNumElements();
sourceBElem = sourceBVector.getElementType();
}
- if (!isa<Float8E5M2FNUZType, Float8E4M3FNUZType, Float8E5M2Type,
- Float8E4M3FNType>(sourceBElem))
+ if (!sourceBElem.isFloat(8))
return emitOpError("expected both source operands to have f8 elements");
if (sourceLen != sourceBLen)
return emitOpError(
diff --git a/mlir/lib/IR/Types.cpp b/mlir/lib/IR/Types.cpp
index bca90de6f4a8a..76bfa4f4e4cef 100644
--- a/mlir/lib/IR/Types.cpp
+++ b/mlir/lib/IR/Types.cpp
@@ -42,6 +42,15 @@ bool Type::isF64() const { return llvm::isa<Float64Type>(*this); }
bool Type::isF80() const { return llvm::isa<Float80Type>(*this); }
bool Type::isF128() const { return llvm::isa<Float128Type>(*this); }
+bool Type::isFloat() const { return llvm::isa<FloatType>(*this); }
+
+/// Return true if this is an integer type with the specified width.
+bool Type::isFloat(unsigned width) const {
+ if (auto fltTy = llvm::dyn_cast<FloatType>(*this))
+ return fltTy.getWidth() == width;
+ return false;
+}
+
bool Type::isIndex() const { return llvm::isa<IndexType>(*this); }
bool Type::isInteger() const { return llvm::isa<IntegerType>(*this); }
>From dd2cad3567e0cad4961e00d2bf9d1722705f0318 Mon Sep 17 00:00:00 2001
From: Paul Fuqua <pf at acm.org>
Date: Mon, 30 Sep 2024 12:42:13 -0500
Subject: [PATCH 5/7] [MLIR][AMDGPU] Renaming using suggestions from review.
---
.../AMDGPUToROCDL/AMDGPUToROCDL.cpp | 36 +++++++++----------
1 file changed, 18 insertions(+), 18 deletions(-)
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 887e4401c2776..a793eb7b746fc 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -476,14 +476,14 @@ static void wmmaPushOutputOperand(ConversionPatternRewriter &rewriter,
/// Return true if `type` is the E5M2 variant of an 8-bit float that is
/// supported by the `_bf8` instructions on the given `chipset`.
-static bool isNativeBf8(Chipset chipset, Type type) {
+static bool typeIsExpectedBf8ForChipset(Chipset chipset, Type type) {
return (isGfx940Series(chipset) && isa<Float8E5M2FNUZType>(type)) ||
(hasOcpFp8(chipset) && isa<Float8E5M2Type>(type));
}
/// Return true if `type` is the E4M3FN variant of an 8-bit float that is
/// supported by the `_fp8` instructions on the given `chipset`.
-static bool isNativeFp8(Chipset chipset, Type type) {
+static bool typeIsExpectedFp8ForChipset(Chipset chipset, Type type) {
return (isGfx940Series(chipset) && isa<Float8E4M3FNUZType>(type)) ||
(hasOcpFp8(chipset) && isa<Float8E4M3FNType>(type));
}
@@ -584,38 +584,38 @@ static std::optional<StringRef> mfmaOpToIntrinsic(MFMAOp mfma,
return ROCDL::mfma_f64_4x4x4f64::getOperationName();
}
- if (destElem.isF32() && isNativeBf8(chipset, sourceElem)) {
+ if (destElem.isF32() && typeIsExpectedBf8ForChipset(chipset, sourceElem)) {
// Known to be correct because there are no scalar f8 instructions and
// because a length mismatch will have been caught by the verifier.
Type sourceBElem =
cast<VectorType>(mfma.getSourceB().getType()).getElementType();
if (m == 16 && n == 16 && k == 32 && b == 1) {
- if (isNativeBf8(chipset, sourceBElem))
+ if (typeIsExpectedBf8ForChipset(chipset, sourceBElem))
return ROCDL::mfma_f32_16x16x32_bf8_bf8::getOperationName();
- if (isNativeFp8(chipset, sourceBElem))
+ if (typeIsExpectedFp8ForChipset(chipset, sourceBElem))
return ROCDL::mfma_f32_16x16x32_bf8_fp8::getOperationName();
}
if (m == 32 && n == 32 && k == 16 && b == 1) {
- if (isNativeBf8(chipset, sourceBElem))
+ if (typeIsExpectedBf8ForChipset(chipset, sourceBElem))
return ROCDL::mfma_f32_32x32x16_bf8_bf8::getOperationName();
- if (isNativeFp8(chipset, sourceBElem))
+ if (typeIsExpectedFp8ForChipset(chipset, sourceBElem))
return ROCDL::mfma_f32_32x32x16_bf8_fp8::getOperationName();
}
}
- if (destElem.isF32() && isNativeFp8(chipset, sourceElem)) {
+ if (destElem.isF32() && typeIsExpectedFp8ForChipset(chipset, sourceElem)) {
Type sourceBElem =
cast<VectorType>(mfma.getSourceB().getType()).getElementType();
if (m == 16 && n == 16 && k == 32 && b == 1) {
- if (isNativeBf8(chipset, sourceBElem))
+ if (typeIsExpectedBf8ForChipset(chipset, sourceBElem))
return ROCDL::mfma_f32_16x16x32_fp8_bf8::getOperationName();
- if (isNativeFp8(chipset, sourceBElem))
+ if (typeIsExpectedFp8ForChipset(chipset, sourceBElem))
return ROCDL::mfma_f32_16x16x32_fp8_fp8::getOperationName();
}
if (m == 32 && n == 32 && k == 16 && b == 1) {
- if (isNativeBf8(chipset, sourceBElem))
+ if (typeIsExpectedBf8ForChipset(chipset, sourceBElem))
return ROCDL::mfma_f32_32x32x16_fp8_bf8::getOperationName();
- if (isNativeFp8(chipset, sourceBElem))
+ if (typeIsExpectedFp8ForChipset(chipset, sourceBElem))
return ROCDL::mfma_f32_32x32x16_fp8_fp8::getOperationName();
}
}
@@ -823,10 +823,10 @@ LogicalResult ExtPackedFp8OpLowering::matchAndRewrite(
}
Value i32Source = rewriter.create<LLVM::BitcastOp>(loc, i32, source);
Value wordSel = createI32Constant(rewriter, loc, op.getIndex());
- if (isNativeBf8(chipset, sourceElemType)) {
+ if (typeIsExpectedBf8ForChipset(chipset, sourceElemType)) {
rewriter.replaceOpWithNewOp<ROCDL::CvtF32Bf8Op>(op, f32, i32Source,
wordSel);
- } else if (isNativeFp8(chipset, sourceElemType)) {
+ } else if (typeIsExpectedFp8ForChipset(chipset, sourceElemType)) {
rewriter.replaceOpWithNewOp<ROCDL::CvtF32Fp8Op>(op, f32, i32Source,
wordSel);
}
@@ -858,10 +858,10 @@ LogicalResult PackedTrunc2xFp8OpLowering::matchAndRewrite(
Value wordSel = createI1Constant(rewriter, loc, op.getWordIndex());
Value result;
- if (isNativeBf8(chipset, resultElemType))
+ if (typeIsExpectedBf8ForChipset(chipset, resultElemType))
result = rewriter.create<ROCDL::CvtPkBf8F32Op>(loc, i32, sourceA, sourceB,
existing, wordSel);
- else if (isNativeFp8(chipset, resultElemType))
+ else if (typeIsExpectedFp8ForChipset(chipset, resultElemType))
result = rewriter.create<ROCDL::CvtPkFp8F32Op>(loc, i32, sourceA, sourceB,
existing, wordSel);
@@ -893,10 +893,10 @@ LogicalResult PackedStochRoundFp8OpLowering::matchAndRewrite(
Value byteSel = createI32Constant(rewriter, loc, op.getStoreIndex());
Value result;
- if (isNativeBf8(chipset, resultElemType))
+ if (typeIsExpectedBf8ForChipset(chipset, resultElemType))
result = rewriter.create<ROCDL::CvtSrBf8F32Op>(loc, i32, source, stoch,
existing, byteSel);
- else if (isNativeFp8(chipset, resultElemType))
+ else if (typeIsExpectedFp8ForChipset(chipset, resultElemType))
result = rewriter.create<ROCDL::CvtSrFp8F32Op>(loc, i32, source, stoch,
existing, byteSel);
>From 5bace4697b6f381f7cef0b6a25fecc67268565ff Mon Sep 17 00:00:00 2001
From: Mirza Halilcevic <mirza.halilcevic at amd.com>
Date: Tue, 18 Feb 2025 14:55:23 +0000
Subject: [PATCH 6/7] [MLIR][AMDGPU] Address TOSA related review comments.
Signed-off-by: Mirza Halilcevic <mirza.halilcevic at amd.com>
---
mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp | 4 ++--
mlir/lib/IR/Types.cpp | 2 +-
2 files changed, 3 insertions(+), 3 deletions(-)
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
index 0113813e13419..404e45edf3a2f 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
@@ -696,8 +696,8 @@ LogicalResult TosaValidation::applyErrorIfCheck(Operation *op) {
bool TosaValidation::isValidElementType(Type type) {
if (isa<FloatType>(type)) {
- return isa<Float32Type, Float16Type, BFloat16Type, Float8E4M3FNUZType,
- Float8E5M2FNUZType, Float8E4M3FNType, Float8E5M2Type>(type);
+ return isa<Float32Type, Float16Type, BFloat16Type, Float8E4M3FNType,
+ Float8E5M2Type>(type);
} else if (auto intTy = dyn_cast<IntegerType>(type)) {
if (intTy.isSignless()) {
switch (intTy.getWidth()) {
diff --git a/mlir/lib/IR/Types.cpp b/mlir/lib/IR/Types.cpp
index 76bfa4f4e4cef..bd00ffeabec7b 100644
--- a/mlir/lib/IR/Types.cpp
+++ b/mlir/lib/IR/Types.cpp
@@ -44,7 +44,7 @@ bool Type::isF128() const { return llvm::isa<Float128Type>(*this); }
bool Type::isFloat() const { return llvm::isa<FloatType>(*this); }
-/// Return true if this is an integer type with the specified width.
+/// Return true if this is a float type with the specified width.
bool Type::isFloat(unsigned width) const {
if (auto fltTy = llvm::dyn_cast<FloatType>(*this))
return fltTy.getWidth() == width;
>From 19b23a289da5bada4eaad6b3ada59f9cfa3bcb36 Mon Sep 17 00:00:00 2001
From: Mirza Halilcevic <mirza.halilcevic at amd.com>
Date: Mon, 24 Feb 2025 20:45:26 +0000
Subject: [PATCH 7/7] [MLIR][AMDGPU] Replace isGfx940Series with an equality
check for gfx942, since gfx940 and gfx941 are no longer supported.
Signed-off-by: Mirza Halilcevic <mirza.halilcevic at amd.com>
---
mlir/include/mlir/Dialect/AMDGPU/Utils/Chipset.h | 3 ---
mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp | 10 +++++-----
mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp | 7 +++++--
3 files changed, 10 insertions(+), 10 deletions(-)
diff --git a/mlir/include/mlir/Dialect/AMDGPU/Utils/Chipset.h b/mlir/include/mlir/Dialect/AMDGPU/Utils/Chipset.h
index 768b390ed5381..ca9809799588c 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/Utils/Chipset.h
+++ b/mlir/include/mlir/Dialect/AMDGPU/Utils/Chipset.h
@@ -49,9 +49,6 @@ struct Chipset {
#undef DEFINE_COMP_OPERATOR
};
-inline bool isGfx940Series(const Chipset &chipset) {
- return chipset.majorVersion == 9 && chipset.minorVersion == 4;
-}
inline bool hasOcpFp8(const Chipset &chipset) {
return (chipset.majorVersion == 9 && chipset.minorVersion >= 5) ||
chipset.majorVersion >= 12;
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index a793eb7b746fc..ac41ba9938ea5 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -477,14 +477,14 @@ static void wmmaPushOutputOperand(ConversionPatternRewriter &rewriter,
/// Return true if `type` is the E5M2 variant of an 8-bit float that is
/// supported by the `_bf8` instructions on the given `chipset`.
static bool typeIsExpectedBf8ForChipset(Chipset chipset, Type type) {
- return (isGfx940Series(chipset) && isa<Float8E5M2FNUZType>(type)) ||
+ return (chipset == kGfx942 && isa<Float8E5M2FNUZType>(type)) ||
(hasOcpFp8(chipset) && isa<Float8E5M2Type>(type));
}
/// Return true if `type` is the E4M3FN variant of an 8-bit float that is
/// supported by the `_fp8` instructions on the given `chipset`.
static bool typeIsExpectedFp8ForChipset(Chipset chipset, Type type) {
- return (isGfx940Series(chipset) && isa<Float8E4M3FNUZType>(type)) ||
+ return (chipset == kGfx942 && isa<Float8E4M3FNUZType>(type)) ||
(hasOcpFp8(chipset) && isa<Float8E4M3FNType>(type));
}
@@ -793,7 +793,7 @@ LogicalResult ExtPackedFp8OpLowering::matchAndRewrite(
ExtPackedFp8Op op, ExtPackedFp8OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
Location loc = op.getLoc();
- if (!(isGfx940Series(chipset) || hasOcpFp8(chipset)))
+ if (!(chipset == kGfx942 || hasOcpFp8(chipset)))
return rewriter.notifyMatchFailure(
loc, "Fp8 conversion instructions are not available on target "
"architecture and their emulation is not implemented");
@@ -837,7 +837,7 @@ LogicalResult PackedTrunc2xFp8OpLowering::matchAndRewrite(
PackedTrunc2xFp8Op op, PackedTrunc2xFp8OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
Location loc = op.getLoc();
- if (!(isGfx940Series(chipset) || hasOcpFp8(chipset)))
+ if (!(chipset == kGfx942 || hasOcpFp8(chipset)))
return rewriter.notifyMatchFailure(
loc, "Fp8 conversion instructions are not available on target "
"architecture and their emulation is not implemented");
@@ -874,7 +874,7 @@ LogicalResult PackedStochRoundFp8OpLowering::matchAndRewrite(
PackedStochRoundFp8Op op, PackedStochRoundFp8OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
Location loc = op.getLoc();
- if (!(isGfx940Series(chipset) || hasOcpFp8(chipset)))
+ if (!(chipset == kGfx942 || hasOcpFp8(chipset)))
return rewriter.notifyMatchFailure(
loc, "Fp8 conversion instructions are not available on target "
"architecture and their emulation is not implemented");
diff --git a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
index 29ecbd4b9a604..cba71740f9380 100644
--- a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
+++ b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
@@ -30,6 +30,9 @@ using namespace mlir;
using namespace mlir::amdgpu;
namespace {
+// Define commonly used chipset versions for convenience.
+constexpr Chipset kGfx942 = Chipset(9, 4, 2);
+
struct ArithToAMDGPUConversionPass final
: impl::ArithToAMDGPUConversionPassBase<ArithToAMDGPUConversionPass> {
using impl::ArithToAMDGPUConversionPassBase<
@@ -73,7 +76,7 @@ struct TruncfToFloat16RewritePattern final
} // end namespace
static LogicalResult isSupportedF8(Type elementType, Chipset chipset) {
- if (isGfx940Series(chipset))
+ if (chipset == kGfx942)
return success(isa<Float8E4M3FNUZType, Float8E5M2FNUZType>(elementType));
if (hasOcpFp8(chipset))
return success(isa<Float8E4M3FNType, Float8E5M2Type>(elementType));
@@ -397,7 +400,7 @@ void ArithToAMDGPUConversionPass::runOnOperation() {
}
bool convertFP8Arithmetic =
- isGfx940Series(*maybeChipset) || hasOcpFp8(*maybeChipset);
+ *maybeChipset == kGfx942 || hasOcpFp8(*maybeChipset);
arith::populateArithToAMDGPUConversionPatterns(
patterns, convertFP8Arithmetic, saturateFP8Truncf, allowPackedF16Rtz,
*maybeChipset);
More information about the Mlir-commits
mailing list