[Mlir-commits] [mlir] [MLIR][AMDGPU] Implement emulated F8 for the OCP formats. (PR #106160)

Paul C Fuqua llvmlistbot at llvm.org
Thu Sep 12 18:57:55 PDT 2024


https://github.com/pcf000 updated https://github.com/llvm/llvm-project/pull/106160

>From ab7b9f13b4d4e216184a6959e41c2a1c52d23b21 Mon Sep 17 00:00:00 2001
From: Paul Fuqua <pf at acm.org>
Date: Thu, 11 Jul 2024 20:12:45 -0500
Subject: [PATCH 1/3] [MLIR][AMDGPU] Implement emulated FP8 for the OCP
 formats. This part mostly just allows the new types.

---
 mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td | 14 +++----
 .../mlir/Dialect/AMDGPU/Utils/Chipset.h       |  7 ++++
 .../AMDGPUToROCDL/AMDGPUToROCDL.cpp           | 37 +++++++++++--------
 .../ArithToAMDGPU/ArithToAMDGPU.cpp           |  9 ++++-
 mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp  |  6 ++-
 .../Tosa/Transforms/TosaValidation.cpp        |  4 +-
 6 files changed, 49 insertions(+), 28 deletions(-)

diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
index cba35bbca1f83b..484cea84f669b8 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
@@ -41,8 +41,8 @@ class AMDGPU_Op<string mnemonic, list<Trait> traits = []> :
 
 def AMDGPU_ExtPackedFp8Op :
     AMDGPU_Op<"ext_packed_fp8", [Pure]>,
-    Arguments<(ins AnyTypeOf<[F8E5M2FNUZ, F8E4M3FNUZ,
-        VectorOfLengthAndType<[1, 2, 3, 4], [F8E5M2FNUZ, F8E4M3FNUZ]>]>:$source,
+    Arguments<(ins AnyTypeOf<[F8E5M2FNUZ, F8E4M3FNUZ, F8E5M2, F8E4M3FN,
+        VectorOfLengthAndType<[1, 2, 3, 4], [F8E5M2FNUZ, F8E4M3FNUZ, F8E5M2, F8E4M3FN]>]>:$source,
       ConfinedAttr<I32Attr, [IntNonNegative, IntMaxValue<3>]>:$index)>,
     Results<(outs F32:$res)> {
   let summary = "Extend one of a vector of packed fp8 values to a float";
@@ -68,8 +68,8 @@ def AMDGPU_PackedTrunc2xFp8Op :
     Arguments<(ins F32:$sourceA,
       Optional<F32>:$sourceB,
       ConfinedAttr<I32Attr, [IntNonNegative, IntMaxValue<1>]>:$wordIndex,
-      Optional<FixedVectorOfLengthAndType<[4], [F8E4M3FNUZ, F8E5M2FNUZ]>>:$existing)>,
-    Results<(outs FixedVectorOfLengthAndType<[4], [F8E4M3FNUZ, F8E5M2FNUZ]>:$res)> {
+      Optional<FixedVectorOfLengthAndType<[4], [F8E4M3FNUZ, F8E5M2FNUZ, F8E4M3FN, F8E5M2]>>:$existing)>,
+    Results<(outs FixedVectorOfLengthAndType<[4], [F8E4M3FNUZ, F8E5M2FNUZ, F8E4M3FN, F8E5M2]>:$res)> {
   let summary = "Round two floats into a packed vector of 8-bit floats";
   let description = [{
     Round the inputs `sourceA` and `sourceB` (which is undefined if not
@@ -95,8 +95,8 @@ def AMDGPU_PackedStochRoundFp8Op :
     Arguments<(ins F32:$source,
       I32:$stochiasticParam,
       ConfinedAttr<I32Attr, [IntNonNegative, IntMaxValue<3>]>:$storeIndex,
-      Optional<FixedVectorOfLengthAndType<[4], [F8E4M3FNUZ, F8E5M2FNUZ]>>:$existing)>,
-    Results<(outs FixedVectorOfLengthAndType<[4], [F8E4M3FNUZ, F8E5M2FNUZ]>:$res)> {
+      Optional<FixedVectorOfLengthAndType<[4], [F8E4M3FNUZ, F8E5M2FNUZ, F8E4M3FN, F8E5M2]>>:$existing)>,
+    Results<(outs FixedVectorOfLengthAndType<[4], [F8E4M3FNUZ, F8E5M2FNUZ, F8E4M3FN, F8E5M2]>:$res)> {
   let summary = "Round float stochiastically into a packed vector of 8-bit floats";
   let description = [{
     Round the input `source`, adding in `stochiasticParam`, and place it into
@@ -546,7 +546,7 @@ def MFMAInTypes : AnyTypeOf<[F32, F64, I32, I64,
                              VectorOfLengthAndType<[4], [F16]>,
                              VectorOfLengthAndType<[2, 4], [BF16]>,
                              VectorOfLengthAndType<[4, 8], [I8]>,
-                             VectorOfLengthAndType<[8], [F8E5M2FNUZ, F8E4M3FNUZ]>]>;
+                             VectorOfLengthAndType<[8], [F8E5M2FNUZ, F8E4M3FNUZ, F8E5M2, F8E4M3FN]>]>;
 def MFMAOutTypes : AnyTypeOf<[F64,
                               VectorOfLengthAndType<[4, 16, 32], [F32]>,
                               VectorOfLengthAndType<[4, 16, 32], [I32]>,
diff --git a/mlir/include/mlir/Dialect/AMDGPU/Utils/Chipset.h b/mlir/include/mlir/Dialect/AMDGPU/Utils/Chipset.h
index a5dab1ab896302..e5c14c1cb68278 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/Utils/Chipset.h
+++ b/mlir/include/mlir/Dialect/AMDGPU/Utils/Chipset.h
@@ -47,6 +47,13 @@ struct Chipset {
   DEFINE_COMP_OPERATOR(>)
   DEFINE_COMP_OPERATOR(>=)
 #undef DEFINE_COMP_OPERATOR
+
+  bool isGfx940() const {
+    return majorVersion == 9 && minorVersion >= 0x40 && minorVersion < 0x50;
+  }
+  bool hasOcpFp8() const {
+    return (majorVersion == 9 && minorVersion >= 0x50) || majorVersion >= 12;
+  }
 };
 
 } // namespace mlir::amdgpu
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index f80d2793eaef59..60bda2d92df587 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -550,38 +550,42 @@ static std::optional<StringRef> mfmaOpToIntrinsic(MFMAOp mfma,
       return ROCDL::mfma_f64_4x4x4f64::getOperationName();
   }
 
-  if (sourceElem.isFloat8E5M2FNUZ() && destElem.isF32() && chipset >= kGfx940) {
+  if (destElem.isF32() &&
+      ((sourceElem.isFloat8E5M2FNUZ() && chipset >= kGfx940) ||
+       (sourceElem.isFloat8E5M2() && chipset.hasOcpFp8()))) {
     // Known to be correct because there are no scalar f8 instructions and
     // because a length mismatch will have been caught by the verifier.
     Type sourceBElem =
         cast<VectorType>(mfma.getSourceB().getType()).getElementType();
     if (m == 16 && n == 16 && k == 32 && b == 1) {
-      if (sourceBElem.isFloat8E5M2FNUZ())
+      if (sourceBElem.isFloat8E5M2FNUZ() || sourceBElem.isFloat8E5M2())
         return ROCDL::mfma_f32_16x16x32_bf8_bf8::getOperationName();
-      if (sourceBElem.isFloat8E4M3FNUZ())
+      if (sourceBElem.isFloat8E4M3FNUZ() || sourceBElem.isFloat8E4M3FN())
         return ROCDL::mfma_f32_16x16x32_bf8_fp8::getOperationName();
     }
     if (m == 32 && n == 32 && k == 16 && b == 1) {
-      if (sourceBElem.isFloat8E5M2FNUZ())
+      if (sourceBElem.isFloat8E5M2FNUZ() || sourceBElem.isFloat8E5M2())
         return ROCDL::mfma_f32_32x32x16_bf8_bf8::getOperationName();
-      if (sourceBElem.isFloat8E4M3FNUZ())
+      if (sourceBElem.isFloat8E4M3FNUZ() || sourceBElem.isFloat8E4M3FN())
         return ROCDL::mfma_f32_32x32x16_bf8_fp8::getOperationName();
     }
   }
 
-  if (sourceElem.isFloat8E4M3FNUZ() && destElem.isF32() && chipset >= kGfx940) {
+  if (destElem.isF32() &&
+      ((sourceElem.isFloat8E4M3FNUZ() && chipset >= kGfx940) ||
+       (sourceElem.isFloat8E4M3FN() && chipset.hasOcpFp8()))) {
     Type sourceBElem =
         cast<VectorType>(mfma.getSourceB().getType()).getElementType();
     if (m == 16 && n == 16 && k == 32 && b == 1) {
-      if (sourceBElem.isFloat8E5M2FNUZ())
+      if (sourceBElem.isFloat8E5M2FNUZ() || sourceBElem.isFloat8E5M2())
         return ROCDL::mfma_f32_16x16x32_fp8_bf8::getOperationName();
-      if (sourceBElem.isFloat8E4M3FNUZ())
+      if (sourceBElem.isFloat8E4M3FNUZ() || sourceBElem.isFloat8E4M3FN())
         return ROCDL::mfma_f32_16x16x32_fp8_fp8::getOperationName();
     }
     if (m == 32 && n == 32 && k == 16 && b == 1) {
-      if (sourceBElem.isFloat8E5M2FNUZ())
+      if (sourceBElem.isFloat8E5M2FNUZ() || sourceBElem.isFloat8E5M2())
         return ROCDL::mfma_f32_32x32x16_fp8_bf8::getOperationName();
-      if (sourceBElem.isFloat8E4M3FNUZ())
+      if (sourceBElem.isFloat8E4M3FNUZ() || sourceBElem.isFloat8E4M3FN())
         return ROCDL::mfma_f32_32x32x16_fp8_fp8::getOperationName();
     }
   }
@@ -787,10 +791,11 @@ LogicalResult ExtPackedFp8OpLowering::matchAndRewrite(
   }
   Value i32Source = rewriter.create<LLVM::BitcastOp>(loc, i32, source);
   Value wordSel = createI32Constant(rewriter, loc, op.getIndex());
-  if (sourceElemType.isFloat8E5M2FNUZ()) {
+  if (sourceElemType.isFloat8E5M2FNUZ() || sourceElemType.isFloat8E5M2()) {
     rewriter.replaceOpWithNewOp<ROCDL::CvtF32Bf8Op>(op, f32, i32Source,
                                                     wordSel);
-  } else if (sourceElemType.isFloat8E4M3FNUZ()) {
+  } else if (sourceElemType.isFloat8E4M3FNUZ() ||
+             sourceElemType.isFloat8E4M3FN()) {
     rewriter.replaceOpWithNewOp<ROCDL::CvtF32Fp8Op>(op, f32, i32Source,
                                                     wordSel);
   }
@@ -822,10 +827,10 @@ LogicalResult PackedTrunc2xFp8OpLowering::matchAndRewrite(
   Value wordSel = createI1Constant(rewriter, loc, op.getWordIndex());
 
   Value result;
-  if (resultElemType.isFloat8E5M2FNUZ())
+  if (resultElemType.isFloat8E5M2FNUZ() || resultElemType.isFloat8E5M2())
     result = rewriter.create<ROCDL::CvtPkBf8F32Op>(loc, i32, sourceA, sourceB,
                                                    existing, wordSel);
-  else if (resultElemType.isFloat8E4M3FNUZ())
+  else if (resultElemType.isFloat8E4M3FNUZ() || resultElemType.isFloat8E4M3FN())
     result = rewriter.create<ROCDL::CvtPkFp8F32Op>(loc, i32, sourceA, sourceB,
                                                    existing, wordSel);
 
@@ -857,10 +862,10 @@ LogicalResult PackedStochRoundFp8OpLowering::matchAndRewrite(
   Value byteSel = createI32Constant(rewriter, loc, op.getStoreIndex());
 
   Value result;
-  if (resultElemType.isFloat8E5M2FNUZ())
+  if (resultElemType.isFloat8E5M2FNUZ() || resultElemType.isFloat8E5M2())
     result = rewriter.create<ROCDL::CvtSrBf8F32Op>(loc, i32, source, stoch,
                                                    existing, byteSel);
-  else if (resultElemType.isFloat8E4M3FNUZ())
+  else if (resultElemType.isFloat8E4M3FNUZ() || resultElemType.isFloat8E4M3FN())
     result = rewriter.create<ROCDL::CvtSrFp8F32Op>(loc, i32, source, stoch,
                                                    existing, byteSel);
 
diff --git a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
index 6b27ec9947cb0b..423069d406f472 100644
--- a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
+++ b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
@@ -86,7 +86,8 @@ LogicalResult ExtFOnFloat8RewritePattern::match(arith::ExtFOp op) const {
       return failure();
     inType = inVecType.getElementType();
   }
-  return success(inType.isFloat8E5M2FNUZ() || inType.isFloat8E4M3FNUZ());
+  return success(inType.isFloat8E5M2FNUZ() || inType.isFloat8E4M3FNUZ() ||
+                 inType.isFloat8E5M2() || inType.isFloat8E4M3FN());
 }
 
 void ExtFOnFloat8RewritePattern::rewrite(arith::ExtFOp op,
@@ -216,7 +217,11 @@ LogicalResult TruncFToFloat8RewritePattern::match(arith::TruncFOp op) const {
   if (inType && inType.getWidth() <= 8 && saturateFP8)
     // Conversion between 8-bit floats is not supported with truncation enabled.
     return failure();
-  return success(outType.isFloat8E5M2FNUZ() || outType.isFloat8E4M3FNUZ());
+
+  return success((((outType.isFloat8E5M2FNUZ() || outType.isFloat8E4M3FNUZ()) &&
+                   chipset.isGfx940()) ||
+                  ((outType.isFloat8E5M2() || outType.isFloat8E4M3FN()) &&
+                   chipset.hasOcpFp8())));
 }
 
 void TruncFToFloat8RewritePattern::rewrite(arith::TruncFOp op,
diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
index 63447baa31eb0c..4b9c532ce67a07 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -272,14 +272,16 @@ LogicalResult MFMAOp::verify() {
   }
 
   Type sourceBType = getSourceB().getType();
-  if (sourceElem.isFloat8E5M2FNUZ() || sourceElem.isFloat8E4M3FNUZ()) {
+  if (sourceElem.isFloat8E5M2FNUZ() || sourceElem.isFloat8E4M3FNUZ() ||
+      sourceElem.isFloat8E5M2() || sourceElem.isFloat8E4M3FN()) {
     int64_t sourceBLen = 1;
     Type sourceBElem = sourceBType;
     if (auto sourceBVector = llvm::dyn_cast<VectorType>(sourceBType)) {
       sourceBLen = sourceBVector.getNumElements();
       sourceBElem = sourceBVector.getElementType();
     }
-    if (!sourceBElem.isFloat8E5M2FNUZ() && !sourceBElem.isFloat8E4M3FNUZ())
+    if (!sourceBElem.isFloat8E5M2FNUZ() && !sourceBElem.isFloat8E4M3FNUZ() &&
+        !sourceBElem.isFloat8E5M2() && !sourceBElem.isFloat8E4M3FN())
       return emitOpError("expected both source operands to have f8 elements");
     if (sourceLen != sourceBLen)
       return emitOpError(
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
index b78c372af77e64..963fd6fd7c0511 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
@@ -509,7 +509,9 @@ bool TosaValidation::isValidElementType(Type type) {
   if (isa<FloatType>(type)) {
     if (profile == TosaProfileEnum::BaseInference)
       return false;
-    return type.isF32() || type.isF16() || type.isBF16();
+    return type.isF32() || type.isF16() || type.isBF16() ||
+           type.isFloat8E4M3FNUZ() || type.isFloat8E5M2FNUZ() ||
+           type.isFloat8E4M3FN() || type.isFloat8E5M2();
   }
   if (auto intTy = dyn_cast<IntegerType>(type)) {
     if (intTy.isUnsigned()) {

>From cba8da0d7099852749747caa7c221879516b6cbe Mon Sep 17 00:00:00 2001
From: Krzysztof Drewniak <Krzysztof.Drewniak at amd.com>
Date: Mon, 9 Sep 2024 17:39:17 -0500
Subject: [PATCH 2/3] [MLIR][AMDGPU] After fp8 conversions were lowered to
 AMDGPU dialect ops, those operations were not being converted to the LLVM
 intrinsics they correspond to because the rewrite patterns were still
 checking for gfx940+.

As part of this, factor out the type-matching tests into isNativeFp8()
and isNativeBf8() helper functions in the AMDGPUToROCDL rewrites.

Also, fix a typo in isGfx940() that caused it to be true for gfx950.

Finally, test all these OCP format conversions by duplicating the
gfx940 tests.
---
 .../AMDGPUToROCDL/AMDGPUToROCDL.cpp           |  51 ++---
 .../AMDGPUToROCDL/8-bit-floats-ocp.mlir       | 109 +++++++++++
 .../8-bit-float-saturation-ocp.mlir           |  58 ++++++
 .../ArithToAMDGPU/8-bit-floats-ocp.mlir       | 176 ++++++++++++++++++
 4 files changed, 373 insertions(+), 21 deletions(-)
 create mode 100644 mlir/test/Conversion/AMDGPUToROCDL/8-bit-floats-ocp.mlir
 create mode 100644 mlir/test/Conversion/ArithToAMDGPU/8-bit-float-saturation-ocp.mlir
 create mode 100644 mlir/test/Conversion/ArithToAMDGPU/8-bit-floats-ocp.mlir

diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 60bda2d92df587..770c5072e2e79d 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -454,6 +454,20 @@ static void wmmaPushOutputOperand(ConversionPatternRewriter &rewriter,
   }
 }
 
+/// Return true if `type` is the E5M2 variant of an 8-bit float that is
+/// supported by the `_bf8` instructions on the given `chipset`.
+static bool isNativeBf8(Chipset chipset, Type type) {
+  return (chipset.isGfx940() && type.isFloat8E5M2FNUZ()) ||
+         (chipset.hasOcpFp8() && type.isFloat8E5M2());
+}
+
+/// Return true if `type` is the E4M3FN variant of an 8-bit float that is
+/// supported by the `_fp8` instructions on the given `chipset`.
+static bool isNativeFp8(Chipset chipset, Type type) {
+  return (chipset.isGfx940() && type.isFloat8E4M3FNUZ()) ||
+         (chipset.hasOcpFp8() && type.isFloat8E4M3FN());
+}
+
 /// Return the `rocdl` intrinsic corresponding to a MFMA operation `mfma`
 /// if one exists. This includes checking to ensure the intrinsic is supported
 /// on the architecture you are compiling for.
@@ -550,42 +564,38 @@ static std::optional<StringRef> mfmaOpToIntrinsic(MFMAOp mfma,
       return ROCDL::mfma_f64_4x4x4f64::getOperationName();
   }
 
-  if (destElem.isF32() &&
-      ((sourceElem.isFloat8E5M2FNUZ() && chipset >= kGfx940) ||
-       (sourceElem.isFloat8E5M2() && chipset.hasOcpFp8()))) {
+  if (destElem.isF32() && isNativeBf8(chipset, sourceElem)) {
     // Known to be correct because there are no scalar f8 instructions and
     // because a length mismatch will have been caught by the verifier.
     Type sourceBElem =
         cast<VectorType>(mfma.getSourceB().getType()).getElementType();
     if (m == 16 && n == 16 && k == 32 && b == 1) {
-      if (sourceBElem.isFloat8E5M2FNUZ() || sourceBElem.isFloat8E5M2())
+      if (isNativeBf8(chipset, sourceBElem))
         return ROCDL::mfma_f32_16x16x32_bf8_bf8::getOperationName();
-      if (sourceBElem.isFloat8E4M3FNUZ() || sourceBElem.isFloat8E4M3FN())
+      if (isNativeFp8(chipset, sourceBElem))
         return ROCDL::mfma_f32_16x16x32_bf8_fp8::getOperationName();
     }
     if (m == 32 && n == 32 && k == 16 && b == 1) {
-      if (sourceBElem.isFloat8E5M2FNUZ() || sourceBElem.isFloat8E5M2())
+      if (isNativeBf8(chipset, sourceBElem))
         return ROCDL::mfma_f32_32x32x16_bf8_bf8::getOperationName();
-      if (sourceBElem.isFloat8E4M3FNUZ() || sourceBElem.isFloat8E4M3FN())
+      if (isNativeFp8(chipset, sourceBElem))
         return ROCDL::mfma_f32_32x32x16_bf8_fp8::getOperationName();
     }
   }
 
-  if (destElem.isF32() &&
-      ((sourceElem.isFloat8E4M3FNUZ() && chipset >= kGfx940) ||
-       (sourceElem.isFloat8E4M3FN() && chipset.hasOcpFp8()))) {
+  if (destElem.isF32() && isNativeFp8(chipset, sourceElem)) {
     Type sourceBElem =
         cast<VectorType>(mfma.getSourceB().getType()).getElementType();
     if (m == 16 && n == 16 && k == 32 && b == 1) {
-      if (sourceBElem.isFloat8E5M2FNUZ() || sourceBElem.isFloat8E5M2())
+      if (isNativeBf8(chipset, sourceBElem))
         return ROCDL::mfma_f32_16x16x32_fp8_bf8::getOperationName();
-      if (sourceBElem.isFloat8E4M3FNUZ() || sourceBElem.isFloat8E4M3FN())
+      if (isNativeFp8(chipset, sourceBElem))
         return ROCDL::mfma_f32_16x16x32_fp8_fp8::getOperationName();
     }
     if (m == 32 && n == 32 && k == 16 && b == 1) {
-      if (sourceBElem.isFloat8E5M2FNUZ() || sourceBElem.isFloat8E5M2())
+      if (isNativeBf8(chipset, sourceBElem))
         return ROCDL::mfma_f32_32x32x16_fp8_bf8::getOperationName();
-      if (sourceBElem.isFloat8E4M3FNUZ() || sourceBElem.isFloat8E4M3FN())
+      if (isNativeFp8(chipset, sourceBElem))
         return ROCDL::mfma_f32_32x32x16_fp8_fp8::getOperationName();
     }
   }
@@ -791,11 +801,10 @@ LogicalResult ExtPackedFp8OpLowering::matchAndRewrite(
   }
   Value i32Source = rewriter.create<LLVM::BitcastOp>(loc, i32, source);
   Value wordSel = createI32Constant(rewriter, loc, op.getIndex());
-  if (sourceElemType.isFloat8E5M2FNUZ() || sourceElemType.isFloat8E5M2()) {
+  if (isNativeBf8(chipset, sourceElemType)) {
     rewriter.replaceOpWithNewOp<ROCDL::CvtF32Bf8Op>(op, f32, i32Source,
                                                     wordSel);
-  } else if (sourceElemType.isFloat8E4M3FNUZ() ||
-             sourceElemType.isFloat8E4M3FN()) {
+  } else if (isNativeFp8(chipset, sourceElemType)) {
     rewriter.replaceOpWithNewOp<ROCDL::CvtF32Fp8Op>(op, f32, i32Source,
                                                     wordSel);
   }
@@ -827,10 +836,10 @@ LogicalResult PackedTrunc2xFp8OpLowering::matchAndRewrite(
   Value wordSel = createI1Constant(rewriter, loc, op.getWordIndex());
 
   Value result;
-  if (resultElemType.isFloat8E5M2FNUZ() || resultElemType.isFloat8E5M2())
+  if (isNativeBf8(chipset, resultElemType))
     result = rewriter.create<ROCDL::CvtPkBf8F32Op>(loc, i32, sourceA, sourceB,
                                                    existing, wordSel);
-  else if (resultElemType.isFloat8E4M3FNUZ() || resultElemType.isFloat8E4M3FN())
+  else if (isNativeFp8(chipset, resultElemType))
     result = rewriter.create<ROCDL::CvtPkFp8F32Op>(loc, i32, sourceA, sourceB,
                                                    existing, wordSel);
 
@@ -862,10 +871,10 @@ LogicalResult PackedStochRoundFp8OpLowering::matchAndRewrite(
   Value byteSel = createI32Constant(rewriter, loc, op.getStoreIndex());
 
   Value result;
-  if (resultElemType.isFloat8E5M2FNUZ() || resultElemType.isFloat8E5M2())
+  if (isNativeBf8(chipset, resultElemType))
     result = rewriter.create<ROCDL::CvtSrBf8F32Op>(loc, i32, source, stoch,
                                                    existing, byteSel);
-  else if (resultElemType.isFloat8E4M3FNUZ() || resultElemType.isFloat8E4M3FN())
+  else if (isNativeFp8(chipset, resultElemType))
     result = rewriter.create<ROCDL::CvtSrFp8F32Op>(loc, i32, source, stoch,
                                                    existing, byteSel);
 
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/8-bit-floats-ocp.mlir b/mlir/test/Conversion/AMDGPUToROCDL/8-bit-floats-ocp.mlir
new file mode 100644
index 00000000000000..70775a603e54d9
--- /dev/null
+++ b/mlir/test/Conversion/AMDGPUToROCDL/8-bit-floats-ocp.mlir
@@ -0,0 +1,109 @@
+// RUN: mlir-opt %s -convert-amdgpu-to-rocdl=chipset=gfx950 | FileCheck %s
+// RUN: mlir-opt %s -convert-amdgpu-to-rocdl=chipset=gfx1200 | FileCheck %s
+
+// CHECK-LABEL: func @ext_scalar
+// CHECK: [[V:%.+]] = builtin.unrealized_conversion_cast %{{.+}} : f8E5M2 to i8
+// CHECK-DAG: [[UNDEF:%.+]] = llvm.mlir.undef : vector<4xi8>
+// CHECK-DAG: [[C0_1:%.+]] = llvm.mlir.constant(0 : i32) : i32
+// CHECK: [[VEC:%.+]] = llvm.insertelement [[V]], [[UNDEF]]{{\[}}[[C0_1]] : i32] : vector<4xi8>
+// CHECK: [[CAST:%.+]] = llvm.bitcast [[VEC]] : vector<4xi8> to i32
+// CHECK: [[C0_2:%.+]] = llvm.mlir.constant(0 : i32) : i32
+// CHECK: [[EXT:%.+]] = rocdl.cvt.f32.bf8 [[CAST]]{{\[}}[[C0_2]]] : f32
+// CHECK: return [[EXT]]
+func.func @ext_scalar(%v: f8E5M2) -> f32 {
+  %ret = amdgpu.ext_packed_fp8 %v[0] : f8E5M2 to f32
+  func.return %ret : f32
+}
+
+// CHECK-LABEL: func @ext_short_vec
+// CHECK: [[V:%.+]] = builtin.unrealized_conversion_cast %{{.+}} : vector<2xf8E4M3FN> to vector<2xi8>
+// CHECK-DAG: [[UNDEF:%.+]] = llvm.mlir.undef : vector<4xi8>
+// CHECK-DAG: [[C0:%.+]] = llvm.mlir.constant(0 : i32) : i32
+// CHECK: [[ELEM_0:%.+]] = llvm.extractelement [[V]]{{\[}}[[C0]] : i32] : vector<2xi8>
+// CHECK: [[VEC_0:%.+]] = llvm.insertelement [[ELEM_0]], [[UNDEF]]{{\[}}[[C0]] : i32] : vector<4xi8>
+// CHECK: [[C1_1:%.+]] = llvm.mlir.constant(1 : i32) : i32
+// CHECK: [[ELEM_1:%.+]] = llvm.extractelement [[V]]{{\[}}[[C1_1]] : i32] : vector<2xi8>
+// CHECK: [[VEC_1:%.+]] = llvm.insertelement [[ELEM_1]], [[VEC_0]]{{\[}}[[C1_1]] : i32] : vector<4xi8>
+// CHECK: [[CAST:%.+]] = llvm.bitcast [[VEC_1]] : vector<4xi8> to i32
+// CHECK: [[C1_2:%.+]] = llvm.mlir.constant(1 : i32) : i32
+// CHECK: [[EXT:%.+]] = rocdl.cvt.f32.fp8 [[CAST]]{{\[}}[[C1_2]]] : f32
+// CHECK: return [[EXT]]
+func.func @ext_short_vec(%v: vector<2xf8E4M3FN>) -> f32 {
+  %ret = amdgpu.ext_packed_fp8 %v[1] : vector<2xf8E4M3FN> to f32
+  func.return %ret : f32
+}
+
+// CHECK-LABEL: func @ext_full_vec(
+// CHECK: [[V:%.+]] = builtin.unrealized_conversion_cast %{{.+}} : vector<4xf8E4M3FN> to vector<4xi8>
+// CHECK: [[CAST:%.+]] = llvm.bitcast [[V]] : vector<4xi8> to i32
+// CHECK: [[C3:%.+]] = llvm.mlir.constant(3 : i32) : i32
+// CHECK: [[EXT:%.+]] = rocdl.cvt.f32.fp8 [[CAST]]{{\[}}[[C3]]] : f32
+// CHECK: return [[EXT]] : f32
+
+func.func @ext_full_vec(%v: vector<4xf8E4M3FN>) -> f32 {
+  %ret = amdgpu.ext_packed_fp8 %v[3] : vector<4xf8E4M3FN> to f32
+  func.return %ret : f32
+}
+
+// CHECK-LABEL: func @packed_trunc
+// CHECK-SAME: ([[V:%.+]]: f32)
+// CHECK: [[V2:%.+]] = llvm.mlir.undef : f32
+// CHECK: [[EXISTING:%.+]] = llvm.mlir.undef : i32
+// CHECK: [[FALSE:%.+]] = llvm.mlir.constant(false) : i1
+// CHECK: [[PACKED:%.+]] = rocdl.cvt.pk.fp8.f32 [[V]], [[V2]] -> [[EXISTING]]{{\[}}[[FALSE]]] : i32
+// CHECK: [[CAST:%.+]] = llvm.bitcast [[PACKED]] : i32 to vector<4xi8>
+// CHECK: builtin.unrealized_conversion_cast [[CAST]] : vector<4xi8> to vector<4xf8E4M3FN>
+func.func @packed_trunc(%v: f32) -> vector<4xf8E4M3FN> {
+  %ret = amdgpu.packed_trunc_2xfp8 %v, undef into undef[word 0] : f32 to vector<4xf8E4M3FN>
+  func.return %ret : vector<4xf8E4M3FN>
+}
+
+// CHECK-LABEL: func @packed_truncx2
+// CHECK-SAME: ([[V:%.+]]: f32, [[W:%.+]]: f32)
+// CHECK: [[EXISTING:%.+]] = llvm.mlir.undef : i32
+// CHECK: [[FALSE:%.+]] = llvm.mlir.constant(false) : i1
+// CHECK: [[PACKED:%.+]] = rocdl.cvt.pk.fp8.f32 [[V]], [[W]] -> [[EXISTING]]{{\[}}[[FALSE]]] : i32
+// CHECK: [[CAST:%.+]] = llvm.bitcast [[PACKED]] : i32 to vector<4xi8>
+// CHECK: builtin.unrealized_conversion_cast [[CAST]] : vector<4xi8> to vector<4xf8E4M3FN>
+func.func @packed_truncx2(%v: f32, %w: f32) -> vector<4xf8E4M3FN> {
+  %ret = amdgpu.packed_trunc_2xfp8 %v, %w into undef[word 0] : f32 to vector<4xf8E4M3FN>
+  func.return %ret : vector<4xf8E4M3FN>
+}
+
+// CHECK-LABEL: func @packed_truncx2_into
+// CHECK-SAME: ([[V:%.+]]: f32, [[W:%.+]]: f32, [[EXISTING:%.+]]:  vector<4xf8E5M2>)
+// CHECK: [[EXISTING_BYTES:%.+]] = builtin.unrealized_conversion_cast [[EXISTING]] : vector<4xf8E5M2> to vector<4xi8>
+// CHECK: [[EXISTING_INT:%.+]] = llvm.bitcast [[EXISTING_BYTES]] : vector<4xi8> to i32
+// CHECK: [[TRUE:%.+]] = llvm.mlir.constant(true) : i1
+// CHECK: [[PACKED:%.+]] = rocdl.cvt.pk.bf8.f32 [[V]], [[W]] -> [[EXISTING_INT]]{{\[}}[[TRUE]]] : i32
+// CHECK: [[CAST:%.+]] = llvm.bitcast [[PACKED]] : i32 to vector<4xi8>
+// CHECK: builtin.unrealized_conversion_cast [[CAST]] : vector<4xi8> to vector<4xf8E5M2>
+func.func @packed_truncx2_into(%v: f32, %w: f32, %existing: vector<4xf8E5M2>) -> vector<4xf8E5M2> {
+  %ret = amdgpu.packed_trunc_2xfp8 %v, %w into %existing[word 1] : f32 to vector<4xf8E5M2> into vector<4xf8E5M2>
+  func.return %ret : vector<4xf8E5M2>
+}
+
+// CHECK-LABEL: func @packed_stoch_round
+// CHECK-SAME: ([[V:%.+]]: f32, [[S:%.+]]: i32)
+// CHECK: [[EXISTING:%.+]] = llvm.mlir.undef : i32
+// CHECK: [[C0:%.+]] = llvm.mlir.constant(0 : i32) : i32
+// CHECK: [[PACKED:%.+]] = rocdl.cvt.sr.fp8.f32 [[V]], [[S]] -> [[EXISTING]]{{\[}}[[C0]]] : i32
+// CHECK: [[CAST:%.+]] = llvm.bitcast [[PACKED]] : i32 to vector<4xi8>
+// CHECK:  builtin.unrealized_conversion_cast [[CAST]] : vector<4xi8> to vector<4xf8E4M3FN>
+func.func @packed_stoch_round(%v: f32, %s: i32) -> vector<4xf8E4M3FN> {
+  %ret = amdgpu.packed_stoch_round_fp8 %v + %s into undef[0] : f32 to vector<4xf8E4M3FN>
+  func.return %ret : vector<4xf8E4M3FN>
+}
+
+// CHECK-LABEL: func @packed_stoch_round_into
+// CHECK-SAME: ([[V:%.+]]: f32, [[S:%.+]]: i32, [[EXISTING:%.+]]:  vector<4xf8E5M2>)
+// CHECK: [[EXISTING_BYTES:%.+]] = builtin.unrealized_conversion_cast [[EXISTING]] : vector<4xf8E5M2> to vector<4xi8>
+// CHECK: [[EXISTING_INT:%.+]] = llvm.bitcast [[EXISTING_BYTES]] : vector<4xi8> to i32
+// CHECK: [[C1:%.+]] = llvm.mlir.constant(1 : i32) : i32
+// CHECK: [[PACKED:%.+]] = rocdl.cvt.sr.bf8.f32 [[V]], [[S]] -> [[EXISTING_INT]]{{\[}}[[C1]]] : i32
+// CHECK: [[CAST:%.+]] = llvm.bitcast [[PACKED]] : i32 to vector<4xi8>
+// CHECK: builtin.unrealized_conversion_cast [[CAST]] : vector<4xi8> to vector<4xf8E5M2>
+func.func @packed_stoch_round_into(%v: f32, %s: i32, %existing: vector<4xf8E5M2>) -> vector<4xf8E5M2> {
+  %ret = amdgpu.packed_stoch_round_fp8 %v + %s into %existing[1] : f32 to vector<4xf8E5M2> into vector<4xf8E5M2>
+  func.return %ret : vector<4xf8E5M2>
+}
diff --git a/mlir/test/Conversion/ArithToAMDGPU/8-bit-float-saturation-ocp.mlir b/mlir/test/Conversion/ArithToAMDGPU/8-bit-float-saturation-ocp.mlir
new file mode 100644
index 00000000000000..2df5f2fa1965fd
--- /dev/null
+++ b/mlir/test/Conversion/ArithToAMDGPU/8-bit-float-saturation-ocp.mlir
@@ -0,0 +1,58 @@
+// RUN: mlir-opt --split-input-file %s \
+// RUN: --pass-pipeline='builtin.module(func.func(convert-arith-to-amdgpu{chipset=gfx950 saturate-fp8-truncf=true}))' \
+// RUN: | FileCheck %s
+
+// RUN: mlir-opt --split-input-file %s \
+// RUN: --pass-pipeline='builtin.module(func.func(convert-arith-to-amdgpu{chipset=gfx1200 saturate-fp8-truncf=true}))' \
+// RUN: | FileCheck %s
+
+// CHECK-LABEL: func.func @scalar_trunc
+// CHECK-SAME: ([[V:%.+]]: f16)
+// CHECK-DAG: [[CMin:%.+]] = arith.constant -5.734400e+04 : f16
+// CHECK-DAG: [[CMax:%.+]] = arith.constant 5.734400e+04 : f16
+// CHECK-DAG: [[CInf:%.+]] = arith.constant 0x7C00 : f16
+// CHECK-DAG: [[CNegInf:%.+]] = arith.constant 0xFC00 : f16
+// CHECK: [[ISINF:%.+]] = arith.cmpf oeq, [[V]], [[CInf]]
+// CHECK: [[ISNEGINF:%.+]] = arith.cmpf oeq, [[V]], [[CNegInf]]
+// CHECK: [[ISNAN:%.+]] = arith.cmpf uno, [[V]], [[V]]
+// CHECK: [[ISNONFINITE_1:%.+]] = arith.ori [[ISINF]], [[ISNEGINF]]
+// CHECK: [[ISNONFINITE:%.+]] = arith.ori [[ISNONFINITE_1]], [[ISNAN]]
+// CHECK: [[CLAMPEDBELOW:%.+]] = arith.maximumf [[V]], [[CMin]]
+// CHECK: [[CLAMPED:%.+]] = arith.minimumf [[CLAMPEDBELOW]], [[CMax]]
+// CHECK: [[SATURATED:%.+]] = arith.select [[ISNONFINITE]], [[V]], [[CLAMPED]]
+// CHECK: [[FLOAT:%.+]] = arith.extf [[SATURATED]] : f16 to f32
+// CHECK: [[TRUNCV:%.+]] = amdgpu.packed_trunc_2xfp8 [[FLOAT]], undef into undef[word 0] : f32 to vector<4xf8E5M2>
+// CHECK: [[W:%.+]] = vector.extract [[TRUNCV]][0] : f8E5M2 from vector<4xf8E5M2>
+// CHECK: return [[W]] : f8E5M2
+func.func @scalar_trunc(%v: f16) -> f8E5M2 {
+  %w = arith.truncf %v : f16 to f8E5M2
+  return %w : f8E5M2
+}
+
+// No 0-D test because arith.truncf hasn't been extended to support it.
+
+// -----
+
+// CHECK-LABEL: func.func @vector_trunc
+// CHECK-SAME: ([[V:%.+]]: vector<2xf32>) -> vector<2xf8E4M3FN> {
+// CHECK-DAG: [[CMin:%.+]] = arith.constant dense<-4.480000e+02> : vector<2xf32>
+// CHECK-DAG: [[CMax:%.+]] = arith.constant dense<4.480000e+02> : vector<2xf32>
+// CHECK-DAG: [[CInf:%.+]] = arith.constant dense<0x7F800000> : vector<2xf32>
+// CHECK-DAG: [[CNegInf:%.+]] = arith.constant dense<0xFF800000> : vector<2xf32>
+// CHECK: [[ISINF:%.+]] = arith.cmpf oeq, [[V]], [[CInf]]
+// CHECK: [[ISNEGINF:%.+]] = arith.cmpf oeq, [[V]], [[CNegInf]]
+// CHECK: [[ISNAN:%.+]] = arith.cmpf uno, [[V]], [[V]]
+// CHECK: [[ISNONFINITE_1:%.+]] = arith.ori [[ISINF]], [[ISNEGINF]]
+// CHECK: [[ISNONFINITE:%.+]] = arith.ori [[ISNONFINITE_1]], [[ISNAN]]
+// CHECK: [[CLAMPEDBELOW:%.+]] = arith.maximumf [[V]], [[CMin]]
+// CHECK: [[CLAMPED:%.+]] = arith.minimumf [[CLAMPEDBELOW]], [[CMax]]
+// CHECK: [[SATURATED:%.+]] = arith.select [[ISNONFINITE]], [[V]], [[CLAMPED]]
+// CHECK: [[F0:%.+]] = vector.extract [[SATURATED]][0]
+// CHECK: [[F1:%.+]] = vector.extract [[SATURATED]][1]
+// CHECK: [[W0:%.+]] = amdgpu.packed_trunc_2xfp8 [[F0]], [[F1]] into undef[word 0] : f32 to vector<4xf8E4M3FN>
+// CHECK: [[W:%.+]] = vector.extract_strided_slice [[W0]] {offsets = [0], sizes = [2], strides = [1]} : vector<4xf8E4M3FN> to vector<2xf8E4M3FN>
+// CHECK: return [[W]] : vector<2xf8E4M3FN>
+func.func @vector_trunc_short(%v: vector<2xf32>) -> vector<2xf8E4M3FN> {
+  %w = arith.truncf %v : vector<2xf32> to vector<2xf8E4M3FN>
+  return %w : vector<2xf8E4M3FN>
+}
diff --git a/mlir/test/Conversion/ArithToAMDGPU/8-bit-floats-ocp.mlir b/mlir/test/Conversion/ArithToAMDGPU/8-bit-floats-ocp.mlir
new file mode 100644
index 00000000000000..0e7f58c9e67497
--- /dev/null
+++ b/mlir/test/Conversion/ArithToAMDGPU/8-bit-floats-ocp.mlir
@@ -0,0 +1,176 @@
+// RUN: mlir-opt --split-input-file %s -convert-arith-to-amdgpu="chipset=gfx950" | FileCheck %s
+// RUN: mlir-opt --split-input-file %s -convert-arith-to-amdgpu="chipset=gfx1200" | FileCheck %s
+
+// CHECK-LABEL: func.func @scalar_ext
+// CHECK-SAME: ([[ARG:%.+]]: f8E5M2)
+// CHECK: [[EXT:%.+]] = amdgpu.ext_packed_fp8 [[ARG]][0] : f8E5M2 to f32
+// CHECK: [[RES:%.+]] = arith.truncf [[EXT]] : f32 to f16
+// CHECK: return [[RES]]
+func.func @scalar_ext(%arg: f8E5M2) -> f16 {
+  %res = arith.extf %arg : f8E5M2 to f16
+  return %res : f16
+}
+
+// No 0-D test because arith.extf hasn't been extended to support it.
+
+// -----
+
+// CHECK-LABEL: func.func @vector_ext_short
+// CHECK-SAME: ([[V:%.+]]: vector<2xf8E5M2>)
+// CHECK-DAG: [[ZEROES:%.+]] = arith.constant dense<0.000000e+00> : vector<2xf64>
+// CHECK: [[FLOAT0:%.+]] = amdgpu.ext_packed_fp8 [[V]][0] : vector<2xf8E5M2> to f32
+// CHECK: [[EXT0:%.+]] = arith.extf [[FLOAT0]] : f32 to f64
+// CHECK: [[W0:%.+]] = vector.insert [[EXT0]], [[ZEROES]] [0]
+// CHECK: [[FLOAT1:%.+]] = amdgpu.ext_packed_fp8 [[V]][1] : vector<2xf8E5M2> to f32
+// CHECK: [[EXT1:%.+]] = arith.extf [[FLOAT1]] : f32 to f64
+// CHECK: [[W1:%.+]] = vector.insert [[EXT1]], [[W0]] [1]
+// CHECK: return [[W1]] : vector<2xf64>
+
+func.func @vector_ext_short(%v: vector<2xf8E5M2>) -> vector<2xf64> {
+  %w = arith.extf %v : vector<2xf8E5M2> to vector<2xf64>
+  return %w : vector<2xf64>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @vector_ext_long
+// CHECK-SAME: ([[ARG:%.+]]: vector<9xf8E4M3FN>)
+// CHECK: [[S0:%.+]] = vector.extract_strided_slice [[ARG]] {offsets = [0], sizes = [4], strides = [1]}
+// CHECK: [[E0:%.+]] = amdgpu.ext_packed_fp8 [[S0]][0]
+// CHECK: [[R0:%.+]] = vector.insert [[E0]]
+// CHECK: [[E1:%.+]] = amdgpu.ext_packed_fp8 [[S0]][1]
+// CHECK: [[R1:%.+]] = vector.insert [[E1]], [[R0]]
+// CHECK: [[E2:%.+]] = amdgpu.ext_packed_fp8 [[S0]][2]
+// CHECK: [[R2:%.+]] = vector.insert [[E2]], [[R1]]
+// CHECK: [[E3:%.+]] = amdgpu.ext_packed_fp8 [[S0]][3]
+// CHECK: [[R3:%.+]] = vector.insert [[E3]], [[R2]]
+
+// CHECK: [[S1:%.+]] = vector.extract_strided_slice [[ARG]] {offsets = [4], sizes = [4], strides = [1]} : vector<9xf8E4M3FN> to vector<4xf8E4M3FN>
+// CHECK: [[E4:%.+]] = amdgpu.ext_packed_fp8 [[S1]][0]
+// CHECK: [[R4:%.+]] = vector.insert [[E4]], [[R3]]
+// CHECK: [[E5:%.+]] = amdgpu.ext_packed_fp8 [[S1]][1]
+// CHECK: [[R5:%.+]] = vector.insert [[E5]], [[R4]]
+// CHECK: [[E6:%.+]] = amdgpu.ext_packed_fp8 [[S1]][2]
+// CHECK: [[R6:%.+]] = vector.insert [[E6]], [[R5]]
+// CHECK: [[E7:%.+]] = amdgpu.ext_packed_fp8 [[S1]][3]
+// CHECK: [[R7:%.+]] = vector.insert [[E7]], [[R6]]
+
+// CHECK: [[S2:%.+]] = vector.extract_strided_slice [[ARG]] {offsets = [8], sizes = [1], strides = [1]} : vector<9xf8E4M3FN> to vector<1xf8E4M3FN>
+// CHECK: [[E8:%.+]] = amdgpu.ext_packed_fp8 [[S2]][0]
+// CHECK: [[R8:%.+]] = vector.insert [[E8]], [[R7]]
+// CHECK: return [[R8]]
+func.func @vector_ext_long(%arg: vector<9xf8E4M3FN>) -> vector<9xf32> {
+  %res = arith.extf %arg : vector<9xf8E4M3FN> to vector<9xf32>
+  return %res : vector<9xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @scalar_trunc
+// CHECK-SAME: ([[ARG:%.+]]: f16)
+// CHECK: [[EXT:%.+]] = arith.extf [[ARG]] : f16 to f32
+// CHECK: [[PACKED:%.+]] = amdgpu.packed_trunc_2xfp8 [[EXT]], undef into undef[word 0] : f32 to vector<4xf8E5M2>
+// CHECK: [[RES:%.+]] = vector.extract [[PACKED]][0] : f8E5M2 from vector<4xf8E5M2>
+// CHECK: return [[RES]] : f8E5M2
+func.func @scalar_trunc(%arg: f16) -> f8E5M2 {
+  %res = arith.truncf %arg : f16 to f8E5M2
+  return %res : f8E5M2
+}
+
+// No 0-D test because arith.truncf hasn't been extended to support it.
+
+// -----
+
+// CHECK-LABEL: func.func @vector_trunc_short
+// CHECK-SAME: ([[ARG:%.+]]: vector<2xf64>) -> vector<2xf8E5M2> {
+// CHECK: [[ELT0:%.+]] = vector.extract [[ARG]][0]
+// CHECK: [[NARROW0:%.+]] = arith.truncf [[ELT0]] : f64 to f32
+// CHECK: [[ELT1:%.+]] = vector.extract [[ARG]][1]
+// CHECK: [[NARROW1:%.+]] = arith.truncf [[ELT1]] : f64 to f32
+// CHECK: [[PACKED:%.+]] = amdgpu.packed_trunc_2xfp8 [[NARROW0]], [[NARROW1]] into undef[word 0] : f32 to vector<4xf8E5M2>
+// CHECK: [[RES:%.+]] = vector.extract_strided_slice [[PACKED]] {offsets = [0], sizes = [2], strides = [1]} : vector<4xf8E5M2> to vector<2xf8E5M2>
+// CHECK: return [[RES]] : vector<2xf8E5M2>
+func.func @vector_trunc_short(%arg: vector<2xf64>) -> vector<2xf8E5M2> {
+  %res = arith.truncf %arg : vector<2xf64> to vector<2xf8E5M2>
+  return %res : vector<2xf8E5M2>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @vector_trunc_long
+// CHECK-SAME: ([[ARG:%.+]]: vector<9xf32>)
+// CHECK: [[INIT:%.+]] = arith.constant dense<0.000000e+00> : vector<9xf8E4M3FN>
+// CHECK: [[LO0:%.+]] = amdgpu.packed_trunc_2xfp8 %{{.+}}, %{{.+}} into undef[word 0]
+// CHECK: [[LO1:%.+]] = amdgpu.packed_trunc_2xfp8 %{{.+}}, %{{.+}} into [[LO0]][word 1]
+// CHECK: [[ACC0:%.+]] = vector.insert_strided_slice [[LO1]], [[INIT]] {offsets = [0], strides = [1]}
+
+// CHECK: [[HI0:%.+]] = amdgpu.packed_trunc_2xfp8 %{{.+}}, %{{.+}} into undef[word 0]
+// CHECK: [[HI1:%.+]] = amdgpu.packed_trunc_2xfp8 %{{.+}}, %{{.+}} into [[HI0]][word 1]
+// CHECK: [[ACC1:%.+]] = vector.insert_strided_slice [[HI1]], [[ACC0]] {offsets = [4], strides = [1]}
+
+// CHECK: [[TAIL:%.+]] = amdgpu.packed_trunc_2xfp8 %{{.+}}, undef into undef[word 0]
+// CHECK: [[TAIL_SLICE:%.+]] = vector.extract_strided_slice [[TAIL]] {offsets = [0], sizes = [1], strides = [1]}
+// CHECK: [[RES:%.+]] = vector.insert_strided_slice [[TAIL_SLICE]], [[ACC1]] {offsets = [8], strides = [1]}
+// CHECK: return [[RES]]
+func.func @vector_trunc_long(%arg: vector<9xf32>) -> vector<9xf8E4M3FN> {
+  %res = arith.truncf %arg : vector<9xf32> to vector<9xf8E4M3FN>
+  return %res : vector<9xf8E4M3FN>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @vector_trunc_long_2d
+// CHECK-SAME: ([[ARG:%.+]]: vector<1x9xf32>)
+// CHECK: [[INIT:%.+]] = arith.constant dense<0.000000e+00> : vector<9xf8E4M3FN>
+// CHECK: [[LO0:%.+]] = amdgpu.packed_trunc_2xfp8 %{{.+}}, %{{.+}} into undef[word 0]
+// CHECK: [[LO1:%.+]] = amdgpu.packed_trunc_2xfp8 %{{.+}}, %{{.+}} into [[LO0]][word 1]
+// CHECK: [[ACC0:%.+]] = vector.insert_strided_slice [[LO1]], [[INIT]] {offsets = [0], strides = [1]}
+
+// CHECK: [[HI0:%.+]] = amdgpu.packed_trunc_2xfp8 %{{.+}}, %{{.+}} into undef[word 0]
+// CHECK: [[HI1:%.+]] = amdgpu.packed_trunc_2xfp8 %{{.+}}, %{{.+}} into [[HI0]][word 1]
+// CHECK: [[ACC1:%.+]] = vector.insert_strided_slice [[HI1]], [[ACC0]] {offsets = [4], strides = [1]}
+
+// CHECK: [[TAIL:%.+]] = amdgpu.packed_trunc_2xfp8 %{{.+}}, undef into undef[word 0]
+// CHECK: [[TAIL_SLICE:%.+]] = vector.extract_strided_slice [[TAIL]] {offsets = [0], sizes = [1], strides = [1]}
+// CHECK: [[FLAT:%.+]] = vector.insert_strided_slice [[TAIL_SLICE]], [[ACC1]] {offsets = [8], strides = [1]}
+// CHECK: [[RES:%.+]] = vector.shape_cast [[FLAT]] : vector<9xf8E4M3FN> to vector<1x9xf8E4M3FN>
+// CHECK: return [[RES]]
+func.func @vector_trunc_long_2d(%arg: vector<1x9xf32>) -> vector<1x9xf8E4M3FN> {
+  %res = arith.truncf %arg : vector<1x9xf32> to vector<1x9xf8E4M3FN>
+  return %res : vector<1x9xf8E4M3FN>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @vector_ext_long_2d
+// CHECK-SAME: ([[V:%.+]]: vector<1x9xf8E4M3FN>)
+// CHECK: [[CAST:%.+]] = vector.shape_cast [[V]] : vector<1x9xf8E4M3FN> to vector<9xf8E4M3FN>
+// CHECK: [[V0:%.+]] = vector.extract_strided_slice [[CAST]] {offsets = [0], sizes = [4], strides = [1]}
+// CHECK: [[F0:%.+]] = amdgpu.ext_packed_fp8 [[V0]][0]
+// CHECK: [[W0:%.+]] = vector.insert [[F0]]
+// CHECK: [[F1:%.+]] = amdgpu.ext_packed_fp8 [[V0]][1]
+// CHECK: [[W1:%.+]] = vector.insert [[F1]], [[W0]]
+// CHECK: [[F2:%.+]] = amdgpu.ext_packed_fp8 [[V0]][2]
+// CHECK: [[W2:%.+]] = vector.insert [[F2]], [[W1]]
+// CHECK: [[F3:%.+]] = amdgpu.ext_packed_fp8 [[V0]][3]
+// CHECK: [[W3:%.+]] = vector.insert [[F3]], [[W2]]
+
+// CHECK: [[V1:%.+]] = vector.extract_strided_slice [[CAST]] {offsets = [4], sizes = [4], strides = [1]} : vector<9xf8E4M3FN> to vector<4xf8E4M3FN>
+// CHECK: [[F4:%.+]] = amdgpu.ext_packed_fp8 [[V1]][0]
+// CHECK: [[W4:%.+]] = vector.insert [[F4]], [[W3]]
+// CHECK: [[F5:%.+]] = amdgpu.ext_packed_fp8 [[V1]][1]
+// CHECK: [[W5:%.+]] = vector.insert [[F5]], [[W4]]
+// CHECK: [[F6:%.+]] = amdgpu.ext_packed_fp8 [[V1]][2]
+// CHECK: [[W6:%.+]] = vector.insert [[F6]], [[W5]]
+// CHECK: [[F7:%.+]] = amdgpu.ext_packed_fp8 [[V1]][3]
+// CHECK: [[W7:%.+]] = vector.insert [[F7]], [[W6]]
+
+// CHECK: [[V2:%.+]] = vector.extract_strided_slice [[CAST]] {offsets = [8], sizes = [1], strides = [1]} : vector<9xf8E4M3FN> to vector<1xf8E4M3FN>
+// CHECK: [[F8:%.+]] = amdgpu.ext_packed_fp8 [[V2]][0]
+// CHECK: [[W8:%.+]] = vector.insert [[F8]], [[W7]]
+// CHECK: [[RE:%.+]] = vector.shape_cast [[W8]] : vector<9xf32> to vector<1x9xf32>
+// CHECK: return [[RE]]
+func.func @vector_ext_long_2d(%v: vector<1x9xf8E4M3FN>) -> vector<1x9xf32> {
+  %w = arith.extf %v : vector<1x9xf8E4M3FN> to vector<1x9xf32>
+  return %w : vector<1x9xf32>
+}

>From a37ab7d71aa7cf4fda609d0be9bbe6d36da81d65 Mon Sep 17 00:00:00 2001
From: Paul Fuqua <pf at acm.org>
Date: Thu, 12 Sep 2024 20:54:06 -0500
Subject: [PATCH 3/3] [MLIR][AMDGPU] Clean up and redo after other recent
 patches here.

---
 .../mlir/Dialect/AMDGPU/Utils/Chipset.h       |  4 +--
 .../AMDGPUToROCDL/AMDGPUToROCDL.cpp           |  6 ++---
 .../ArithToAMDGPU/ArithToAMDGPU.cpp           | 25 +++++++++++++------
 3 files changed, 22 insertions(+), 13 deletions(-)

diff --git a/mlir/include/mlir/Dialect/AMDGPU/Utils/Chipset.h b/mlir/include/mlir/Dialect/AMDGPU/Utils/Chipset.h
index e5c14c1cb68278..5b071a46f49ed9 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/Utils/Chipset.h
+++ b/mlir/include/mlir/Dialect/AMDGPU/Utils/Chipset.h
@@ -49,10 +49,10 @@ struct Chipset {
 #undef DEFINE_COMP_OPERATOR
 
   bool isGfx940() const {
-    return majorVersion == 9 && minorVersion >= 0x40 && minorVersion < 0x50;
+    return majorVersion == 9 && minorVersion == 4;
   }
   bool hasOcpFp8() const {
-    return (majorVersion == 9 && minorVersion >= 0x50) || majorVersion >= 12;
+    return majorVersion >= 12 || (majorVersion == 9 && minorVersion >= 5);
   }
 };
 
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 770c5072e2e79d..75cd9a499e61f2 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -771,7 +771,7 @@ LogicalResult ExtPackedFp8OpLowering::matchAndRewrite(
     ExtPackedFp8Op op, ExtPackedFp8OpAdaptor adaptor,
     ConversionPatternRewriter &rewriter) const {
   Location loc = op.getLoc();
-  if (chipset.majorVersion != 9 || chipset < kGfx940)
+  if (!(chipset.isGfx940() || chipset.hasOcpFp8()))
     return rewriter.notifyMatchFailure(
         loc, "Fp8 conversion instructions are not available on target "
              "architecture and their emulation is not implemented");
@@ -815,7 +815,7 @@ LogicalResult PackedTrunc2xFp8OpLowering::matchAndRewrite(
     PackedTrunc2xFp8Op op, PackedTrunc2xFp8OpAdaptor adaptor,
     ConversionPatternRewriter &rewriter) const {
   Location loc = op.getLoc();
-  if (chipset.majorVersion != 9 || chipset < kGfx940)
+  if (!(chipset.isGfx940() || chipset.hasOcpFp8()))
     return rewriter.notifyMatchFailure(
         loc, "Fp8 conversion instructions are not available on target "
              "architecture and their emulation is not implemented");
@@ -852,7 +852,7 @@ LogicalResult PackedStochRoundFp8OpLowering::matchAndRewrite(
     PackedStochRoundFp8Op op, PackedStochRoundFp8OpAdaptor adaptor,
     ConversionPatternRewriter &rewriter) const {
   Location loc = op.getLoc();
-  if (chipset.majorVersion != 9 || chipset < kGfx940)
+  if (!(chipset.isGfx940() || chipset.hasOcpFp8()))
     return rewriter.notifyMatchFailure(
         loc, "Fp8 conversion instructions are not available on target "
              "architecture and their emulation is not implemented");
diff --git a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
index 423069d406f472..542f3ed0043e03 100644
--- a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
+++ b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
@@ -41,6 +41,10 @@ struct ArithToAMDGPUConversionPass final
 struct ExtFOnFloat8RewritePattern final : OpRewritePattern<arith::ExtFOp> {
   using OpRewritePattern::OpRewritePattern;
 
+  ExtFOnFloat8RewritePattern(MLIRContext *context, Chipset targetChipset)
+      : OpRewritePattern<arith::ExtFOp>(context), chipset(targetChipset) {}
+  Chipset chipset; ///< Target chipset consulted by match().
+
   LogicalResult match(arith::ExtFOp op) const override;
   void rewrite(arith::ExtFOp op, PatternRewriter &rewriter) const override;
 };
@@ -68,6 +72,15 @@ struct TruncfToFloat16RewritePattern final
 
 } // end namespace
 
+/// Success iff `chipset` has fp8 conversion instructions for `elementType`.
+static LogicalResult isSupportedFp8(Type elementType, Chipset chipset) {
+  bool fnuz = elementType.isFloat8E5M2FNUZ() || elementType.isFloat8E4M3FNUZ();
+  bool ocp = elementType.isFloat8E5M2() || elementType.isFloat8E4M3FN();
+  if (chipset.isGfx940())
+    return success(fnuz);
+  return success(chipset.hasOcpFp8() && ocp);
+}
+
 static Value castF32To(Type elementType, Value f32, Location loc,
                        PatternRewriter &rewriter) {
   if (elementType.isF32())
@@ -86,8 +99,7 @@ LogicalResult ExtFOnFloat8RewritePattern::match(arith::ExtFOp op) const {
       return failure();
     inType = inVecType.getElementType();
   }
-  return success(inType.isFloat8E5M2FNUZ() || inType.isFloat8E4M3FNUZ() ||
-                 inType.isFloat8E5M2() || inType.isFloat8E4M3FN());
+  return isSupportedFp8(inType, chipset);
 }
 
 void ExtFOnFloat8RewritePattern::rewrite(arith::ExtFOp op,
@@ -218,10 +230,7 @@ LogicalResult TruncFToFloat8RewritePattern::match(arith::TruncFOp op) const {
     // Conversion between 8-bit floats is not supported with truncation enabled.
     return failure();
 
-  return success((((outType.isFloat8E5M2FNUZ() || outType.isFloat8E4M3FNUZ()) &&
-                   chipset.isGfx940()) ||
-                  ((outType.isFloat8E5M2() || outType.isFloat8E4M3FN()) &&
-                   chipset.hasOcpFp8())));
+  return isSupportedFp8(outType, chipset);
 }
 
 void TruncFToFloat8RewritePattern::rewrite(arith::TruncFOp op,
@@ -370,7 +379,7 @@ void mlir::arith::populateArithToAMDGPUConversionPatterns(
     bool saturateFP8Truncf, bool allowPackedF16Rtz, Chipset chipset) {
 
   if (convertFP8Arithmetic) {
-    patterns.add<ExtFOnFloat8RewritePattern>(patterns.getContext());
+    patterns.add<ExtFOnFloat8RewritePattern>(patterns.getContext(), chipset);
     patterns.add<TruncFToFloat8RewritePattern>(patterns.getContext(),
                                                saturateFP8Truncf, chipset);
   }
@@ -389,7 +398,7 @@ void ArithToAMDGPUConversionPass::runOnOperation() {
   }
 
   bool convertFP8Arithmetic =
-      maybeChipset->majorVersion == 9 && *maybeChipset >= Chipset(9, 4, 0);
+      maybeChipset->isGfx940() || maybeChipset->hasOcpFp8();
   arith::populateArithToAMDGPUConversionPatterns(
       patterns, convertFP8Arithmetic, saturateFP8Truncf, allowPackedF16Rtz,
       *maybeChipset);



More information about the Mlir-commits mailing list