[Mlir-commits] [mlir] [MLIR][AMDGPU] Implement emulated F8 for the OCP formats. (PR #106160)

Paul C Fuqua llvmlistbot at llvm.org
Mon Aug 26 16:44:41 PDT 2024


https://github.com/pcf000 created https://github.com/llvm/llvm-project/pull/106160

 This part mostly just allows the new types in places where the other F8 formats were allowed.

>From 05fe240009878f51156b29b2a9afc7238c7726f2 Mon Sep 17 00:00:00 2001
From: Paul Fuqua <pf at acm.org>
Date: Thu, 11 Jul 2024 20:12:45 -0500
Subject: [PATCH] [MLIR][AMDGPU] Implement emulated FP8 for the OCP formats.
 This part mostly just allows the new types.

---
 mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td | 18 ++++-----
 .../mlir/Dialect/AMDGPU/Utils/Chipset.h       |  7 ++++
 .../AMDGPUToROCDL/AMDGPUToROCDL.cpp           | 39 ++++++++++---------
 .../ArithToAMDGPU/ArithToAMDGPU.cpp           |  9 ++++-
 mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp  |  6 ++-
 .../Tosa/Transforms/TosaValidation.cpp        |  4 +-
 6 files changed, 51 insertions(+), 32 deletions(-)

diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
index e5c1a53f34bf64..04b66cea661afc 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
@@ -41,8 +41,8 @@ class AMDGPU_Op<string mnemonic, list<Trait> traits = []> :
 
 def AMDGPU_ExtPackedFp8Op :
     AMDGPU_Op<"ext_packed_fp8", [Pure]>,
-    Arguments<(ins AnyTypeOf<[F8E5M2FNUZ, F8E4M3FNUZ,
-        VectorOfLengthAndType<[1, 2, 3, 4], [F8E5M2FNUZ, F8E4M3FNUZ]>]>:$source,
+    Arguments<(ins AnyTypeOf<[F8E5M2FNUZ, F8E4M3FNUZ, F8E5M2, F8E4M3FN,
+        VectorOfLengthAndType<[1, 2, 3, 4], [F8E5M2FNUZ, F8E4M3FNUZ, F8E5M2, F8E4M3FN]>]>:$source,
       ConfinedAttr<I32Attr, [IntNonNegative, IntMaxValue<3>]>:$index)>,
     Results<(outs F32:$res)> {
   let summary = "Extend one of a vector of packed fp8 values to a float";
@@ -68,8 +68,8 @@ def AMDGPU_PackedTrunc2xFp8Op :
     Arguments<(ins F32:$sourceA,
       Optional<F32>:$sourceB,
       ConfinedAttr<I32Attr, [IntNonNegative, IntMaxValue<1>]>:$wordIndex,
-      Optional<FixedVectorOfLengthAndType<[4], [F8E4M3FNUZ, F8E5M2FNUZ]>>:$existing)>,
-    Results<(outs FixedVectorOfLengthAndType<[4], [F8E4M3FNUZ, F8E5M2FNUZ]>:$res)> {
+      Optional<FixedVectorOfLengthAndType<[4], [F8E4M3FNUZ, F8E5M2FNUZ, F8E4M3FN, F8E5M2]>>:$existing)>,
+    Results<(outs FixedVectorOfLengthAndType<[4], [F8E4M3FNUZ, F8E5M2FNUZ, F8E4M3FN, F8E5M2]>:$res)> {
   let summary = "Round two floats into a packed vector of 8-bit floats";
   let description = [{
     Round the inputs `sourceA` and `sourceB` (which is undefined if not
@@ -95,8 +95,8 @@ def AMDGPU_PackedStochRoundFp8Op :
     Arguments<(ins F32:$source,
       I32:$stochiasticParam,
       ConfinedAttr<I32Attr, [IntNonNegative, IntMaxValue<3>]>:$storeIndex,
-      Optional<FixedVectorOfLengthAndType<[4], [F8E4M3FNUZ, F8E5M2FNUZ]>>:$existing)>,
-    Results<(outs FixedVectorOfLengthAndType<[4], [F8E4M3FNUZ, F8E5M2FNUZ]>:$res)> {
+      Optional<FixedVectorOfLengthAndType<[4], [F8E4M3FNUZ, F8E5M2FNUZ, F8E4M3FN, F8E5M2]>>:$existing)>,
+    Results<(outs FixedVectorOfLengthAndType<[4], [F8E4M3FNUZ, F8E5M2FNUZ, F8E4M3FN, F8E5M2]>:$res)> {
   let summary = "Round float stochiastically into a packed vector of 8-bit floats";
   let description = [{
     Round the input `source`, adding in `stochiasticParam`, and place it into
@@ -405,7 +405,7 @@ def AMDGPU_RawBufferAtomicUminOp :
 
 def AMDGPU_DPPPerm : I32EnumAttr<"DPPPerm",
     "The possible permutations for a DPP operation",
-    [  
+    [
       I32EnumAttrCase<"quad_perm",  0>,
       I32EnumAttrCase<"row_shl",    1>,
       I32EnumAttrCase<"row_shr",    2>,
@@ -419,7 +419,7 @@ def AMDGPU_DPPPerm : I32EnumAttr<"DPPPerm",
       I32EnumAttrCase<"row_bcast_15", 10>,
       I32EnumAttrCase<"row_bcast_31", 11>
     ]> {
-  let genSpecializedAttr = 0;  
+  let genSpecializedAttr = 0;
   let cppNamespace = "::mlir::amdgpu";
 }
 
@@ -546,7 +546,7 @@ def MFMAInTypes : AnyTypeOf<[F32, F64, I32, I64,
                              VectorOfLengthAndType<[4], [F16]>,
                              VectorOfLengthAndType<[2, 4], [BF16]>,
                              VectorOfLengthAndType<[4, 8], [I8]>,
-                             VectorOfLengthAndType<[8], [F8E5M2FNUZ, F8E4M3FNUZ]>]>;
+                             VectorOfLengthAndType<[8], [F8E5M2FNUZ, F8E4M3FNUZ, F8E5M2, F8E4M3FN]>]>;
 def MFMAOutTypes : AnyTypeOf<[F64,
                               VectorOfLengthAndType<[4, 16, 32], [F32]>,
                               VectorOfLengthAndType<[4, 16, 32], [I32]>,
diff --git a/mlir/include/mlir/Dialect/AMDGPU/Utils/Chipset.h b/mlir/include/mlir/Dialect/AMDGPU/Utils/Chipset.h
index 38e0ebe68f943b..6de12a3d50878b 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/Utils/Chipset.h
+++ b/mlir/include/mlir/Dialect/AMDGPU/Utils/Chipset.h
@@ -18,6 +18,13 @@ struct Chipset {
       : majorVersion(majorVersion), minorVersion(minorVersion){};
   static FailureOr<Chipset> parse(StringRef name);
 
+  bool isGfx940() const { // Major version 9, minor version in [0x40, 0x50).
+    return majorVersion == 9 && minorVersion >= 0x40 && minorVersion < 0x50;
+  }
+  bool hasOcpFp8() const { // Major 12+, or major 9 with minor >= 0x50.
+    return majorVersion >= 12 || (majorVersion == 9 && minorVersion >= 0x50);
+  }
+
   unsigned majorVersion = 0;
   unsigned minorVersion = 0;
 };
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 809e9448e80abf..9323fdc7dacd6d 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -539,40 +539,42 @@ static std::optional<StringRef> mfmaOpToIntrinsic(MFMAOp mfma,
       return ROCDL::mfma_f64_4x4x4f64::getOperationName();
   }
 
-  if (sourceElem.isFloat8E5M2FNUZ() && destElem.isF32() &&
-      chipset.minorVersion >= 0x40) {
+  if (destElem.isF32() &&
+      ((sourceElem.isFloat8E5M2FNUZ() && chipset.isGfx940()) ||
+       (sourceElem.isFloat8E5M2() && chipset.hasOcpFp8()))) {
     // Known to be correct because there are no scalar f8 instructions and
     // because a length mismatch will have been caught by the verifier.
     Type sourceBElem =
         cast<VectorType>(mfma.getSourceB().getType()).getElementType();
     if (m == 16 && n == 16 && k == 32 && b == 1) {
-      if (sourceBElem.isFloat8E5M2FNUZ())
+      if (sourceBElem.isFloat8E5M2FNUZ() || sourceBElem.isFloat8E5M2())
         return ROCDL::mfma_f32_16x16x32_bf8_bf8::getOperationName();
-      if (sourceBElem.isFloat8E4M3FNUZ())
+      if (sourceBElem.isFloat8E4M3FNUZ() || sourceBElem.isFloat8E4M3FN())
         return ROCDL::mfma_f32_16x16x32_bf8_fp8::getOperationName();
     }
     if (m == 32 && n == 32 && k == 16 && b == 1) {
-      if (sourceBElem.isFloat8E5M2FNUZ())
+      if (sourceBElem.isFloat8E5M2FNUZ() || sourceBElem.isFloat8E5M2())
         return ROCDL::mfma_f32_32x32x16_bf8_bf8::getOperationName();
-      if (sourceBElem.isFloat8E4M3FNUZ())
+      if (sourceBElem.isFloat8E4M3FNUZ() || sourceBElem.isFloat8E4M3FN())
         return ROCDL::mfma_f32_32x32x16_bf8_fp8::getOperationName();
     }
   }
 
-  if (sourceElem.isFloat8E4M3FNUZ() && destElem.isF32() &&
-      chipset.minorVersion >= 0x40) {
+  if (destElem.isF32() &&
+      ((sourceElem.isFloat8E4M3FNUZ() && chipset.isGfx940()) ||
+       (sourceElem.isFloat8E4M3FN() && chipset.hasOcpFp8()))) {
     Type sourceBElem =
         cast<VectorType>(mfma.getSourceB().getType()).getElementType();
     if (m == 16 && n == 16 && k == 32 && b == 1) {
-      if (sourceBElem.isFloat8E5M2FNUZ())
+      if (sourceBElem.isFloat8E5M2FNUZ() || sourceBElem.isFloat8E5M2())
         return ROCDL::mfma_f32_16x16x32_fp8_bf8::getOperationName();
-      if (sourceBElem.isFloat8E4M3FNUZ())
+      if (sourceBElem.isFloat8E4M3FNUZ() || sourceBElem.isFloat8E4M3FN())
         return ROCDL::mfma_f32_16x16x32_fp8_fp8::getOperationName();
     }
     if (m == 32 && n == 32 && k == 16 && b == 1) {
-      if (sourceBElem.isFloat8E5M2FNUZ())
+      if (sourceBElem.isFloat8E5M2FNUZ() || sourceBElem.isFloat8E5M2())
         return ROCDL::mfma_f32_32x32x16_fp8_bf8::getOperationName();
-      if (sourceBElem.isFloat8E4M3FNUZ())
+      if (sourceBElem.isFloat8E4M3FNUZ() || sourceBElem.isFloat8E4M3FN())
         return ROCDL::mfma_f32_32x32x16_fp8_fp8::getOperationName();
     }
   }
@@ -762,10 +764,11 @@ LogicalResult ExtPackedFp8OpLowering::matchAndRewrite(
   }
   Value i32Source = rewriter.create<LLVM::BitcastOp>(loc, i32, source);
   Value wordSel = createI32Constant(rewriter, loc, op.getIndex());
-  if (sourceElemType.isFloat8E5M2FNUZ()) {
+  if (sourceElemType.isFloat8E5M2FNUZ() || sourceElemType.isFloat8E5M2()) {
     rewriter.replaceOpWithNewOp<ROCDL::CvtF32Bf8Op>(op, f32, i32Source,
                                                     wordSel);
-  } else if (sourceElemType.isFloat8E4M3FNUZ()) {
+  } else if (sourceElemType.isFloat8E4M3FNUZ() ||
+             sourceElemType.isFloat8E4M3FN()) {
     rewriter.replaceOpWithNewOp<ROCDL::CvtF32Fp8Op>(op, f32, i32Source,
                                                     wordSel);
   }
@@ -797,10 +800,10 @@ LogicalResult PackedTrunc2xFp8OpLowering::matchAndRewrite(
   Value wordSel = createI1Constant(rewriter, loc, op.getWordIndex());
 
   Value result;
-  if (resultElemType.isFloat8E5M2FNUZ())
+  if (resultElemType.isFloat8E5M2FNUZ() || resultElemType.isFloat8E5M2())
     result = rewriter.create<ROCDL::CvtPkBf8F32Op>(loc, i32, sourceA, sourceB,
                                                    existing, wordSel);
-  else if (resultElemType.isFloat8E4M3FNUZ())
+  else if (resultElemType.isFloat8E4M3FNUZ() || resultElemType.isFloat8E4M3FN())
     result = rewriter.create<ROCDL::CvtPkFp8F32Op>(loc, i32, sourceA, sourceB,
                                                    existing, wordSel);
 
@@ -832,10 +835,10 @@ LogicalResult PackedStochRoundFp8OpLowering::matchAndRewrite(
   Value byteSel = createI32Constant(rewriter, loc, op.getStoreIndex());
 
   Value result;
-  if (resultElemType.isFloat8E5M2FNUZ())
+  if (resultElemType.isFloat8E5M2FNUZ() || resultElemType.isFloat8E5M2())
     result = rewriter.create<ROCDL::CvtSrBf8F32Op>(loc, i32, source, stoch,
                                                    existing, byteSel);
-  else if (resultElemType.isFloat8E4M3FNUZ())
+  else if (resultElemType.isFloat8E4M3FNUZ() || resultElemType.isFloat8E4M3FN())
     result = rewriter.create<ROCDL::CvtSrFp8F32Op>(loc, i32, source, stoch,
                                                    existing, byteSel);
 
diff --git a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
index d36583c8118ff4..a66c13caa6d0ab 100644
--- a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
+++ b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
@@ -86,7 +86,8 @@ LogicalResult ExtFOnFloat8RewritePattern::match(arith::ExtFOp op) const {
       return failure();
     inType = inVecType.getElementType();
   }
-  return success(inType.isFloat8E5M2FNUZ() || inType.isFloat8E4M3FNUZ());
+  return success(inType.isFloat8E5M2FNUZ() || inType.isFloat8E4M3FNUZ() ||
+                 inType.isFloat8E5M2() || inType.isFloat8E4M3FN());
 }
 
 void ExtFOnFloat8RewritePattern::rewrite(arith::ExtFOp op,
@@ -216,7 +217,11 @@ LogicalResult TruncFToFloat8RewritePattern::match(arith::TruncFOp op) const {
   if (inType && inType.getWidth() <= 8 && saturateFP8)
     // Conversion between 8-bit floats is not supported with truncation enabled.
     return failure();
-  return success(outType.isFloat8E5M2FNUZ() || outType.isFloat8E4M3FNUZ());
+
+  return success((((outType.isFloat8E5M2FNUZ() || outType.isFloat8E4M3FNUZ()) &&
+                   chipset.isGfx940()) ||
+                  ((outType.isFloat8E5M2() || outType.isFloat8E4M3FN()) &&
+                   chipset.hasOcpFp8())));
 }
 
 void TruncFToFloat8RewritePattern::rewrite(arith::TruncFOp op,
diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
index 3943696364950f..2747eebebefa52 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -271,14 +271,16 @@ LogicalResult MFMAOp::verify() {
   }
 
   Type sourceBType = getSourceB().getType();
-  if (sourceElem.isFloat8E5M2FNUZ() || sourceElem.isFloat8E4M3FNUZ()) {
+  if (sourceElem.isFloat8E5M2FNUZ() || sourceElem.isFloat8E4M3FNUZ() ||
+      sourceElem.isFloat8E5M2() || sourceElem.isFloat8E4M3FN()) {
     int64_t sourceBLen = 1;
     Type sourceBElem = sourceBType;
     if (auto sourceBVector = llvm::dyn_cast<VectorType>(sourceBType)) {
       sourceBLen = sourceBVector.getNumElements();
       sourceBElem = sourceBVector.getElementType();
     }
-    if (!sourceBElem.isFloat8E5M2FNUZ() && !sourceBElem.isFloat8E4M3FNUZ())
+    if (!sourceBElem.isFloat8E5M2FNUZ() && !sourceBElem.isFloat8E4M3FNUZ() &&
+        !sourceBElem.isFloat8E5M2() && !sourceBElem.isFloat8E4M3FN())
       return emitOpError("expected both source operands to have f8 elements");
     if (sourceLen != sourceBLen)
       return emitOpError(
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
index b78c372af77e64..963fd6fd7c0511 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
@@ -509,7 +509,9 @@ bool TosaValidation::isValidElementType(Type type) {
   if (isa<FloatType>(type)) {
     if (profile == TosaProfileEnum::BaseInference)
       return false;
-    return type.isF32() || type.isF16() || type.isBF16();
+    return type.isF32() || type.isF16() || type.isBF16() ||
+           type.isFloat8E4M3FNUZ() || type.isFloat8E5M2FNUZ() ||
+           type.isFloat8E4M3FN() || type.isFloat8E5M2();
   }
   if (auto intTy = dyn_cast<IntegerType>(type)) {
     if (intTy.isUnsigned()) {



More information about the Mlir-commits mailing list