[Mlir-commits] [mlir] [MLIR][AMDGPU] Implement emulated F8 for the OCP formats. (PR #106160)
llvmlistbot at llvm.org
Mon Aug 26 16:45:09 PDT 2024
llvmbot wrote:
@llvm/pr-subscribers-mlir
@llvm/pr-subscribers-mlir-tosa
Author: Paul C Fuqua (pcf000)
Changes
This part mostly allows the new OCP F8 types (f8E4M3FN and f8E5M2) in the places where the existing FNUZ F8 formats were already allowed.
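For orientation, here is a minimal sketch (not part of the patch) of the predicate this change effectively threads through each call site. The helper name `isAnyFp8` is hypothetical; the `Type` member functions are the ones used in the diff below.

```c++
#include "mlir/IR/Types.h"

// Hypothetical helper, for illustration only: accept either the existing
// FNUZ f8 types or the OCP f8 types this PR adds.
static bool isAnyFp8(mlir::Type type) {
  return type.isFloat8E4M3FNUZ() || type.isFloat8E5M2FNUZ() || // existing
         type.isFloat8E4M3FN() || type.isFloat8E5M2();         // new (OCP)
}
```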
---
Full diff: https://github.com/llvm/llvm-project/pull/106160.diff
6 Files Affected:
- (modified) mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td (+9-9)
- (modified) mlir/include/mlir/Dialect/AMDGPU/Utils/Chipset.h (+7)
- (modified) mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp (+21-18)
- (modified) mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp (+7-2)
- (modified) mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp (+4-2)
- (modified) mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp (+3-1)
``````````diff
diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
index e5c1a53f34bf64..04b66cea661afc 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
@@ -41,8 +41,8 @@ class AMDGPU_Op<string mnemonic, list<Trait> traits = []> :
def AMDGPU_ExtPackedFp8Op :
AMDGPU_Op<"ext_packed_fp8", [Pure]>,
- Arguments<(ins AnyTypeOf<[F8E5M2FNUZ, F8E4M3FNUZ,
- VectorOfLengthAndType<[1, 2, 3, 4], [F8E5M2FNUZ, F8E4M3FNUZ]>]>:$source,
+ Arguments<(ins AnyTypeOf<[F8E5M2FNUZ, F8E4M3FNUZ, F8E5M2, F8E4M3FN,
+ VectorOfLengthAndType<[1, 2, 3, 4], [F8E5M2FNUZ, F8E4M3FNUZ, F8E5M2, F8E4M3FN]>]>:$source,
ConfinedAttr<I32Attr, [IntNonNegative, IntMaxValue<3>]>:$index)>,
Results<(outs F32:$res)> {
let summary = "Extend one of a vector of packed fp8 values to a float";
@@ -68,8 +68,8 @@ def AMDGPU_PackedTrunc2xFp8Op :
Arguments<(ins F32:$sourceA,
Optional<F32>:$sourceB,
ConfinedAttr<I32Attr, [IntNonNegative, IntMaxValue<1>]>:$wordIndex,
- Optional<FixedVectorOfLengthAndType<[4], [F8E4M3FNUZ, F8E5M2FNUZ]>>:$existing)>,
- Results<(outs FixedVectorOfLengthAndType<[4], [F8E4M3FNUZ, F8E5M2FNUZ]>:$res)> {
+ Optional<FixedVectorOfLengthAndType<[4], [F8E4M3FNUZ, F8E5M2FNUZ, F8E4M3FN, F8E5M2]>>:$existing)>,
+ Results<(outs FixedVectorOfLengthAndType<[4], [F8E4M3FNUZ, F8E5M2FNUZ, F8E4M3FN, F8E5M2]>:$res)> {
let summary = "Round two floats into a packed vector of 8-bit floats";
let description = [{
Round the inputs `sourceA` and `sourceB` (which is undefined if not
@@ -95,8 +95,8 @@ def AMDGPU_PackedStochRoundFp8Op :
Arguments<(ins F32:$source,
I32:$stochiasticParam,
ConfinedAttr<I32Attr, [IntNonNegative, IntMaxValue<3>]>:$storeIndex,
- Optional<FixedVectorOfLengthAndType<[4], [F8E4M3FNUZ, F8E5M2FNUZ]>>:$existing)>,
- Results<(outs FixedVectorOfLengthAndType<[4], [F8E4M3FNUZ, F8E5M2FNUZ]>:$res)> {
+ Optional<FixedVectorOfLengthAndType<[4], [F8E4M3FNUZ, F8E5M2FNUZ, F8E4M3FN, F8E5M2]>>:$existing)>,
+ Results<(outs FixedVectorOfLengthAndType<[4], [F8E4M3FNUZ, F8E5M2FNUZ, F8E4M3FN, F8E5M2]>:$res)> {
let summary = "Round float stochiastically into a packed vector of 8-bit floats";
let description = [{
Round the input `source`, adding in `stochiasticParam`, and place it into
@@ -405,7 +405,7 @@ def AMDGPU_RawBufferAtomicUminOp :
def AMDGPU_DPPPerm : I32EnumAttr<"DPPPerm",
"The possible permutations for a DPP operation",
- [
+ [
I32EnumAttrCase<"quad_perm", 0>,
I32EnumAttrCase<"row_shl", 1>,
I32EnumAttrCase<"row_shr", 2>,
@@ -419,7 +419,7 @@ def AMDGPU_DPPPerm : I32EnumAttr<"DPPPerm",
I32EnumAttrCase<"row_bcast_15", 10>,
I32EnumAttrCase<"row_bcast_31", 11>
]> {
- let genSpecializedAttr = 0;
+ let genSpecializedAttr = 0;
let cppNamespace = "::mlir::amdgpu";
}
@@ -546,7 +546,7 @@ def MFMAInTypes : AnyTypeOf<[F32, F64, I32, I64,
VectorOfLengthAndType<[4], [F16]>,
VectorOfLengthAndType<[2, 4], [BF16]>,
VectorOfLengthAndType<[4, 8], [I8]>,
- VectorOfLengthAndType<[8], [F8E5M2FNUZ, F8E4M3FNUZ]>]>;
+ VectorOfLengthAndType<[8], [F8E5M2FNUZ, F8E4M3FNUZ, F8E5M2, F8E4M3FN]>]>;
def MFMAOutTypes : AnyTypeOf<[F64,
VectorOfLengthAndType<[4, 16, 32], [F32]>,
VectorOfLengthAndType<[4, 16, 32], [I32]>,
diff --git a/mlir/include/mlir/Dialect/AMDGPU/Utils/Chipset.h b/mlir/include/mlir/Dialect/AMDGPU/Utils/Chipset.h
index 38e0ebe68f943b..6de12a3d50878b 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/Utils/Chipset.h
+++ b/mlir/include/mlir/Dialect/AMDGPU/Utils/Chipset.h
@@ -18,6 +18,13 @@ struct Chipset {
: majorVersion(majorVersion), minorVersion(minorVersion){};
static FailureOr<Chipset> parse(StringRef name);
+ bool isGfx940() const {
+ return majorVersion == 9 && minorVersion >= 0x40 && minorVersion < 0x50;
+ }
+ bool hasOcpFp8() const {
+ return (majorVersion == 9 && minorVersion >= 0x50) || majorVersion >= 12;
+ }
+
unsigned majorVersion = 0;
unsigned minorVersion = 0;
};
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 809e9448e80abf..9323fdc7dacd6d 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -539,40 +539,42 @@ static std::optional<StringRef> mfmaOpToIntrinsic(MFMAOp mfma,
return ROCDL::mfma_f64_4x4x4f64::getOperationName();
}
- if (sourceElem.isFloat8E5M2FNUZ() && destElem.isF32() &&
- chipset.minorVersion >= 0x40) {
+ if (destElem.isF32() &&
+ ((sourceElem.isFloat8E5M2FNUZ() && chipset.isGfx940()) ||
+ (sourceElem.isFloat8E5M2() && chipset.hasOcpFp8()))) {
// Known to be correct because there are no scalar f8 instructions and
// because a length mismatch will have been caught by the verifier.
Type sourceBElem =
cast<VectorType>(mfma.getSourceB().getType()).getElementType();
if (m == 16 && n == 16 && k == 32 && b == 1) {
- if (sourceBElem.isFloat8E5M2FNUZ())
+ if (sourceBElem.isFloat8E5M2FNUZ() || sourceBElem.isFloat8E5M2())
return ROCDL::mfma_f32_16x16x32_bf8_bf8::getOperationName();
- if (sourceBElem.isFloat8E4M3FNUZ())
+ if (sourceBElem.isFloat8E4M3FNUZ() || sourceBElem.isFloat8E4M3FN())
return ROCDL::mfma_f32_16x16x32_bf8_fp8::getOperationName();
}
if (m == 32 && n == 32 && k == 16 && b == 1) {
- if (sourceBElem.isFloat8E5M2FNUZ())
+ if (sourceBElem.isFloat8E5M2FNUZ() || sourceBElem.isFloat8E5M2())
return ROCDL::mfma_f32_32x32x16_bf8_bf8::getOperationName();
- if (sourceBElem.isFloat8E4M3FNUZ())
+ if (sourceBElem.isFloat8E4M3FNUZ() || sourceBElem.isFloat8E4M3FN())
return ROCDL::mfma_f32_32x32x16_bf8_fp8::getOperationName();
}
}
- if (sourceElem.isFloat8E4M3FNUZ() && destElem.isF32() &&
- chipset.minorVersion >= 0x40) {
+ if (destElem.isF32() &&
+ ((sourceElem.isFloat8E4M3FNUZ() && chipset.isGfx940()) ||
+ (sourceElem.isFloat8E4M3FN() && chipset.hasOcpFp8()))) {
Type sourceBElem =
cast<VectorType>(mfma.getSourceB().getType()).getElementType();
if (m == 16 && n == 16 && k == 32 && b == 1) {
- if (sourceBElem.isFloat8E5M2FNUZ())
+ if (sourceBElem.isFloat8E5M2FNUZ() || sourceBElem.isFloat8E5M2())
return ROCDL::mfma_f32_16x16x32_fp8_bf8::getOperationName();
- if (sourceBElem.isFloat8E4M3FNUZ())
+ if (sourceBElem.isFloat8E4M3FNUZ() || sourceBElem.isFloat8E4M3FN())
return ROCDL::mfma_f32_16x16x32_fp8_fp8::getOperationName();
}
if (m == 32 && n == 32 && k == 16 && b == 1) {
- if (sourceBElem.isFloat8E5M2FNUZ())
+ if (sourceBElem.isFloat8E5M2FNUZ() || sourceBElem.isFloat8E5M2())
return ROCDL::mfma_f32_32x32x16_fp8_bf8::getOperationName();
- if (sourceBElem.isFloat8E4M3FNUZ())
+ if (sourceBElem.isFloat8E4M3FNUZ() || sourceBElem.isFloat8E4M3FN())
return ROCDL::mfma_f32_32x32x16_fp8_fp8::getOperationName();
}
}
@@ -762,10 +764,11 @@ LogicalResult ExtPackedFp8OpLowering::matchAndRewrite(
}
Value i32Source = rewriter.create<LLVM::BitcastOp>(loc, i32, source);
Value wordSel = createI32Constant(rewriter, loc, op.getIndex());
- if (sourceElemType.isFloat8E5M2FNUZ()) {
+ if (sourceElemType.isFloat8E5M2FNUZ() || sourceElemType.isFloat8E5M2()) {
rewriter.replaceOpWithNewOp<ROCDL::CvtF32Bf8Op>(op, f32, i32Source,
wordSel);
- } else if (sourceElemType.isFloat8E4M3FNUZ()) {
+ } else if (sourceElemType.isFloat8E4M3FNUZ() ||
+ sourceElemType.isFloat8E4M3FN()) {
rewriter.replaceOpWithNewOp<ROCDL::CvtF32Fp8Op>(op, f32, i32Source,
wordSel);
}
@@ -797,10 +800,10 @@ LogicalResult PackedTrunc2xFp8OpLowering::matchAndRewrite(
Value wordSel = createI1Constant(rewriter, loc, op.getWordIndex());
Value result;
- if (resultElemType.isFloat8E5M2FNUZ())
+ if (resultElemType.isFloat8E5M2FNUZ() || resultElemType.isFloat8E5M2())
result = rewriter.create<ROCDL::CvtPkBf8F32Op>(loc, i32, sourceA, sourceB,
existing, wordSel);
- else if (resultElemType.isFloat8E4M3FNUZ())
+ else if (resultElemType.isFloat8E4M3FNUZ() || resultElemType.isFloat8E4M3FN())
result = rewriter.create<ROCDL::CvtPkFp8F32Op>(loc, i32, sourceA, sourceB,
existing, wordSel);
@@ -832,10 +835,10 @@ LogicalResult PackedStochRoundFp8OpLowering::matchAndRewrite(
Value byteSel = createI32Constant(rewriter, loc, op.getStoreIndex());
Value result;
- if (resultElemType.isFloat8E5M2FNUZ())
+ if (resultElemType.isFloat8E5M2FNUZ() || resultElemType.isFloat8E5M2())
result = rewriter.create<ROCDL::CvtSrBf8F32Op>(loc, i32, source, stoch,
existing, byteSel);
- else if (resultElemType.isFloat8E4M3FNUZ())
+ else if (resultElemType.isFloat8E4M3FNUZ() || resultElemType.isFloat8E4M3FN())
result = rewriter.create<ROCDL::CvtSrFp8F32Op>(loc, i32, source, stoch,
existing, byteSel);
diff --git a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
index d36583c8118ff4..a66c13caa6d0ab 100644
--- a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
+++ b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
@@ -86,7 +86,8 @@ LogicalResult ExtFOnFloat8RewritePattern::match(arith::ExtFOp op) const {
return failure();
inType = inVecType.getElementType();
}
- return success(inType.isFloat8E5M2FNUZ() || inType.isFloat8E4M3FNUZ());
+ return success(inType.isFloat8E5M2FNUZ() || inType.isFloat8E4M3FNUZ() ||
+ inType.isFloat8E5M2() || inType.isFloat8E4M3FN());
}
void ExtFOnFloat8RewritePattern::rewrite(arith::ExtFOp op,
@@ -216,7 +217,11 @@ LogicalResult TruncFToFloat8RewritePattern::match(arith::TruncFOp op) const {
if (inType && inType.getWidth() <= 8 && saturateFP8)
// Conversion between 8-bit floats is not supported with truncation enabled.
return failure();
- return success(outType.isFloat8E5M2FNUZ() || outType.isFloat8E4M3FNUZ());
+
+ return success((((outType.isFloat8E5M2FNUZ() || outType.isFloat8E4M3FNUZ()) &&
+ chipset.isGfx940()) ||
+ ((outType.isFloat8E5M2() || outType.isFloat8E4M3FN()) &&
+ chipset.hasOcpFp8())));
}
void TruncFToFloat8RewritePattern::rewrite(arith::TruncFOp op,
diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
index 3943696364950f..2747eebebefa52 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -271,14 +271,16 @@ LogicalResult MFMAOp::verify() {
}
Type sourceBType = getSourceB().getType();
- if (sourceElem.isFloat8E5M2FNUZ() || sourceElem.isFloat8E4M3FNUZ()) {
+ if (sourceElem.isFloat8E5M2FNUZ() || sourceElem.isFloat8E4M3FNUZ() ||
+ sourceElem.isFloat8E5M2() || sourceElem.isFloat8E4M3FN()) {
int64_t sourceBLen = 1;
Type sourceBElem = sourceBType;
if (auto sourceBVector = llvm::dyn_cast<VectorType>(sourceBType)) {
sourceBLen = sourceBVector.getNumElements();
sourceBElem = sourceBVector.getElementType();
}
- if (!sourceBElem.isFloat8E5M2FNUZ() && !sourceBElem.isFloat8E4M3FNUZ())
+ if (!sourceBElem.isFloat8E5M2FNUZ() && !sourceBElem.isFloat8E4M3FNUZ() &&
+ !sourceBElem.isFloat8E5M2() && !sourceBElem.isFloat8E4M3FN())
return emitOpError("expected both source operands to have f8 elements");
if (sourceLen != sourceBLen)
return emitOpError(
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
index b78c372af77e64..963fd6fd7c0511 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
@@ -509,7 +509,9 @@ bool TosaValidation::isValidElementType(Type type) {
if (isa<FloatType>(type)) {
if (profile == TosaProfileEnum::BaseInference)
return false;
- return type.isF32() || type.isF16() || type.isBF16();
+ return type.isF32() || type.isF16() || type.isBF16() ||
+ type.isFloat8E4M3FNUZ() || type.isFloat8E5M2FNUZ() ||
+ type.isFloat8E4M3FN() || type.isFloat8E5M2();
}
if (auto intTy = dyn_cast<IntegerType>(type)) {
if (intTy.isUnsigned()) {
``````````
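As a rough sketch of how the two new `Chipset` predicates pair each f8 family with its hardware, mirroring the `mfmaOpToIntrinsic` and `TruncFToFloat8RewritePattern` changes above: the helper below is illustrative only (not from the patch) and assumes the `mlir::amdgpu` namespace used elsewhere in the dialect.

```c++
#include "mlir/Dialect/AMDGPU/Utils/Chipset.h"
#include "mlir/IR/Types.h"

// Illustrative helper (not part of the patch): the FNUZ encodings belong to
// the gfx94x chips, while the OCP encodings belong to gfx950 and gfx12+.
static bool fp8MatchesChipset(const mlir::amdgpu::Chipset &chipset,
                              mlir::Type elemType) {
  if (elemType.isFloat8E4M3FNUZ() || elemType.isFloat8E5M2FNUZ())
    return chipset.isGfx940();
  if (elemType.isFloat8E4M3FN() || elemType.isFloat8E5M2())
    return chipset.hasOcpFp8();
  return false;
}
```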
https://github.com/llvm/llvm-project/pull/106160