[Mlir-commits] [mlir] [mlir][ArithToAMDGPU] Add option for saturating truncation to fp8 (PR #74153)

Krzysztof Drewniak llvmlistbot at llvm.org
Tue Jan 23 12:30:30 PST 2024


https://github.com/krzysz00 updated https://github.com/llvm/llvm-project/pull/74153

>From f113249fface8a6ce9fda4917ea750ca3d45a268 Mon Sep 17 00:00:00 2001
From: Krzysztof Drewniak <Krzysztof.Drewniak at amd.com>
Date: Tue, 28 Nov 2023 20:22:09 +0000
Subject: [PATCH] [mlir][ArithToAMDGPU] Add option for saturating truncation to
 fp8

Many machine-learning applications (and most software written at AMD)
expect the operation that truncates floats to 8-bit floats to be
saturating. That is, they expect `truncf 256.0 : f32 to f8E4M3FNUZ`
to yield `240.0`, not `NaN`, and similarly for negative numbers.
However, the underlying hardware instruction that can be used for this
truncation implements overflow-to-NaN semantics.

To enable handling this usecase, we add the saturate-fp8-truncf option
to ArithToAMDGPU (off by default), which causes the requisite clamping
code to be emitted. Said clamping code ensures that Inf and NaN are
passed through exactly (and thus truncate to NaN).

Per review feedback, this commit refactors
createScalarOrSplatConstant() to the Arith dialect utilities and uses
it in this code. It also fixes naming of existing patterns and
switches from vector.extractelement/insertelement to
vector.extract/insert.
---
 .../Conversion/ArithToAMDGPU/ArithToAMDGPU.h  |   8 +-
 mlir/include/mlir/Conversion/Passes.td        |   6 +
 mlir/include/mlir/Dialect/Arith/Utils/Utils.h |  11 ++
 .../ArithToAMDGPU/ArithToAMDGPU.cpp           | 108 +++++++++++++-----
 .../Conversion/ArithToAMDGPU/CMakeLists.txt   |   1 +
 .../Arith/Transforms/EmulateWideInt.cpp       |  30 +----
 mlir/lib/Dialect/Arith/Utils/Utils.cpp        |  34 ++++++
 .../ArithToAMDGPU/8-bit-float-saturation.mlir |  54 +++++++++
 .../ArithToAMDGPU/8-bit-floats.mlir           |  33 +++---
 9 files changed, 208 insertions(+), 77 deletions(-)
 create mode 100644 mlir/test/Conversion/ArithToAMDGPU/8-bit-float-saturation.mlir

diff --git a/mlir/include/mlir/Conversion/ArithToAMDGPU/ArithToAMDGPU.h b/mlir/include/mlir/Conversion/ArithToAMDGPU/ArithToAMDGPU.h
index 7f445fee5ba6b8..78c79c915e0607 100644
--- a/mlir/include/mlir/Conversion/ArithToAMDGPU/ArithToAMDGPU.h
+++ b/mlir/include/mlir/Conversion/ArithToAMDGPU/ArithToAMDGPU.h
@@ -20,7 +20,13 @@ class Pass;
 #include "mlir/Conversion/Passes.h.inc"
 
 namespace arith {
-void populateArithToAMDGPUConversionPatterns(RewritePatternSet &patterns);
+/// Add patterns for rewriting `arith.extf` and `arith.truncf` on FP8 types
+/// to wrappers around AMDGPU-specific intrinsics. If `saturateFP8TruncF`
+/// is set, values outside the range of the destination type are clamped
+/// to the largest value of that type instead of being rewritten to Inf (aka
+/// NaN).
+void populateArithToAMDGPUConversionPatterns(RewritePatternSet &patterns,
+                                             bool saturateFP8TruncF);
 } // namespace arith
 } // namespace mlir
 
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 3467e042c493e9..ec0a6284fe97d8 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -125,6 +125,12 @@ def ArithToAMDGPUConversionPass : Pass<"convert-arith-to-amdgpu"> {
   }];
 
   let dependentDialects = ["amdgpu::AMDGPUDialect", "vector::VectorDialect"];
+
+  let options = [
+    Option<"saturateFP8Truncf", "saturate-fp8-truncf", "bool",
+           /*default=*/"false",
+           "Use saturating truncation for 8-bit float types">,
+  ];
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Arith/Utils/Utils.h b/mlir/include/mlir/Dialect/Arith/Utils/Utils.h
index 62a84ee7903d76..402bd196f0736a 100644
--- a/mlir/include/mlir/Dialect/Arith/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Arith/Utils/Utils.h
@@ -53,6 +53,17 @@ Value getValueOrCreateCastToIndexLike(OpBuilder &b, Location loc,
 Value convertScalarToDtype(OpBuilder &b, Location loc, Value operand,
                            Type toType, bool isUnsignedCast);
 
+/// Create a constant of type `type` at location `loc` whose value is `value`
+/// (an APInt or APFloat whose type must match the element type of `type`).
+/// If `type` is a shaped type, create a splat constant of the given value.
+/// Constants are folded if possible.
+Value createScalarOrSplatConstant(OpBuilder &builder, Location loc, Type type,
+                                  const APInt &value);
+Value createScalarOrSplatConstant(OpBuilder &builder, Location loc, Type type,
+                                  int64_t value);
+Value createScalarOrSplatConstant(OpBuilder &builder, Location loc, Type type,
+                                  const APFloat &value);
+
 /// Helper struct to build simple arithmetic quantities with minimal type
 /// inference support.
 struct ArithBuilder {
diff --git a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
index 7785405eae67be..c625a302a39703 100644
--- a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
+++ b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
@@ -10,6 +10,7 @@
 
 #include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Arith/Utils/Utils.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/PatternMatch.h"
@@ -34,17 +35,17 @@ struct ArithToAMDGPUConversionPass final
   void runOnOperation() override;
 };
 
-struct ExtfOnFloat8RewritePattern final
-    : public OpRewritePattern<arith::ExtFOp> {
-  using OpRewritePattern<arith::ExtFOp>::OpRewritePattern;
+struct ExtFOnFloat8RewritePattern final : OpRewritePattern<arith::ExtFOp> {
+  using OpRewritePattern::OpRewritePattern;
 
   LogicalResult match(arith::ExtFOp op) const override;
   void rewrite(arith::ExtFOp op, PatternRewriter &rewriter) const override;
 };
 
-struct TruncfToFloat8RewritePattern final
-    : public OpRewritePattern<arith::TruncFOp> {
-  using OpRewritePattern<arith::TruncFOp>::OpRewritePattern;
+struct TruncFToFloat8RewritePattern final : OpRewritePattern<arith::TruncFOp> {
+  bool saturateFP8 = false;
+  TruncFToFloat8RewritePattern(MLIRContext *ctx, bool saturateFP8)
+      : OpRewritePattern::OpRewritePattern(ctx), saturateFP8(saturateFP8) {}
 
   LogicalResult match(arith::TruncFOp op) const override;
   void rewrite(arith::TruncFOp op, PatternRewriter &rewriter) const override;
@@ -62,7 +63,7 @@ static Value castF32To(Type elementType, Value f32, Location loc,
   llvm_unreachable("The only 32-bit float type is f32");
 }
 
-LogicalResult ExtfOnFloat8RewritePattern::match(arith::ExtFOp op) const {
+LogicalResult ExtFOnFloat8RewritePattern::match(arith::ExtFOp op) const {
   Type inType = op.getIn().getType();
   if (auto inVecType = inType.dyn_cast<VectorType>()) {
     if (inVecType.isScalable())
@@ -75,7 +76,7 @@ LogicalResult ExtfOnFloat8RewritePattern::match(arith::ExtFOp op) const {
   return success(inType.isFloat8E5M2FNUZ() || inType.isFloat8E4M3FNUZ());
 }
 
-void ExtfOnFloat8RewritePattern::rewrite(arith::ExtFOp op,
+void ExtFOnFloat8RewritePattern::rewrite(arith::ExtFOp op,
                                          PatternRewriter &rewriter) const {
   Location loc = op.getLoc();
   Value in = op.getIn();
@@ -93,11 +94,13 @@ void ExtfOnFloat8RewritePattern::rewrite(arith::ExtFOp op,
   Value result =
       rewriter.createOrFold<vector::SplatOp>(loc, op.getOut().getType(), zero);
   if (inType.getShape().empty()) {
-    Value scalarIn = rewriter.create<vector::ExtractElementOp>(loc, in);
+    Value scalarIn =
+        rewriter.create<vector::ExtractOp>(loc, in, ArrayRef<int64_t>{});
     // Recurse to send the 0-D vector case to the 1-D vector case
     Value scalarExt =
         rewriter.create<arith::ExtFOp>(loc, outElemType, scalarIn);
-    result = rewriter.create<vector::InsertElementOp>(loc, scalarExt, zero);
+    result = rewriter.create<vector::InsertOp>(loc, scalarExt, zero,
+                                               ArrayRef<int64_t>{});
     return rewriter.replaceOp(op, result);
   }
   for (int64_t i = 0; i < numElements; i += 4) {
@@ -108,9 +111,7 @@ void ExtfOnFloat8RewritePattern::rewrite(arith::ExtFOp op,
       Value asFloat = rewriter.create<amdgpu::ExtPackedFp8Op>(
           loc, rewriter.getF32Type(), inSlice, j);
       Value asType = castF32To(outElemType, asFloat, loc, rewriter);
-      result = rewriter.create<vector::InsertElementOp>(
-          loc, asType, result,
-          rewriter.createOrFold<arith::ConstantIndexOp>(loc, i + j));
+      result = rewriter.create<vector::InsertOp>(loc, asType, result, i + j);
     }
   }
   rewriter.replaceOp(op, result);
@@ -127,7 +128,53 @@ static Value castToF32(Value value, Location loc, PatternRewriter &rewriter) {
   llvm_unreachable("The only 32-bit float type is f32");
 }
 
-LogicalResult TruncfToFloat8RewritePattern::match(arith::TruncFOp op) const {
+// If `in` is a finite value, clamp it between the maximum and minimum values
+// of `outElemType` so that subsequent conversion instructions don't
+// overflow those out-of-range values to NaN. These semantics are commonly
+// used in machine-learning contexts where failure to clamp would lead to
+// excessive NaN production.
+static Value clampInput(PatternRewriter &rewriter, Location loc,
+                        Type outElemType, Value source) {
+  Type sourceType = source.getType();
+  const llvm::fltSemantics &sourceSem =
+      cast<FloatType>(getElementTypeOrSelf(sourceType)).getFloatSemantics();
+  const llvm::fltSemantics &targetSem =
+      cast<FloatType>(outElemType).getFloatSemantics();
+
+  APFloat min = APFloat::getLargest(targetSem, /*Negative=*/true);
+  APFloat max = APFloat::getLargest(targetSem, /*Negative=*/false);
+  bool ignoredLosesInfo = false;
+  // We can ignore conversion failures here because this conversion promotes
+  // from a smaller type to a larger one - ex. there can be no loss of precision
+  // when casting fp8 to f16.
+  (void)min.convert(sourceSem, APFloat::rmNearestTiesToEven, &ignoredLosesInfo);
+  (void)max.convert(sourceSem, APFloat::rmNearestTiesToEven, &ignoredLosesInfo);
+
+  Value minCst = createScalarOrSplatConstant(rewriter, loc, sourceType, min);
+  Value maxCst = createScalarOrSplatConstant(rewriter, loc, sourceType, max);
+
+  Value inf = createScalarOrSplatConstant(
+      rewriter, loc, sourceType,
+      APFloat::getInf(sourceSem, /*Negative=*/false));
+  Value negInf = createScalarOrSplatConstant(
+      rewriter, loc, sourceType, APFloat::getInf(sourceSem, /*Negative=*/true));
+  Value isInf = rewriter.createOrFold<arith::CmpFOp>(
+      loc, arith::CmpFPredicate::OEQ, source, inf);
+  Value isNegInf = rewriter.createOrFold<arith::CmpFOp>(
+      loc, arith::CmpFPredicate::OEQ, source, negInf);
+  Value isNan = rewriter.createOrFold<arith::CmpFOp>(
+      loc, arith::CmpFPredicate::UNO, source, source);
+  Value isNonFinite = rewriter.create<arith::OrIOp>(
+      loc, rewriter.create<arith::OrIOp>(loc, isInf, isNegInf), isNan);
+
+  Value clampedBelow = rewriter.create<arith::MaximumFOp>(loc, source, minCst);
+  Value clamped = rewriter.create<arith::MinimumFOp>(loc, clampedBelow, maxCst);
+  Value res =
+      rewriter.create<arith::SelectOp>(loc, isNonFinite, source, clamped);
+  return res;
+}
+
+LogicalResult TruncFToFloat8RewritePattern::match(arith::TruncFOp op) const {
   Type outType = op.getOut().getType();
   if (auto outVecType = outType.dyn_cast<VectorType>()) {
     if (outVecType.isScalable())
@@ -137,22 +184,27 @@ LogicalResult TruncfToFloat8RewritePattern::match(arith::TruncFOp op) const {
       return failure();
     outType = outVecType.getElementType();
   }
+  auto inType = dyn_cast<FloatType>(getElementTypeOrSelf(op.getIn().getType()));
+  if (inType && inType.getWidth() <= 8 && saturateFP8)
+    // Conversion between 8-bit floats is not supported with saturation enabled.
+    return failure();
   return success(outType.isFloat8E5M2FNUZ() || outType.isFloat8E4M3FNUZ());
 }
 
-void TruncfToFloat8RewritePattern::rewrite(arith::TruncFOp op,
+void TruncFToFloat8RewritePattern::rewrite(arith::TruncFOp op,
                                            PatternRewriter &rewriter) const {
   Location loc = op.getLoc();
   Value in = op.getIn();
   Type outElemType = getElementTypeOrSelf(op.getOut().getType());
+  if (saturateFP8)
+    in = clampInput(rewriter, loc, outElemType, in);
   VectorType truncResType = VectorType::get(4, outElemType);
   if (!in.getType().isa<VectorType>()) {
     Value asFloat = castToF32(in, loc, rewriter);
     Value asF8s = rewriter.create<amdgpu::PackedTrunc2xFp8Op>(
         loc, truncResType, asFloat, /*sourceB=*/nullptr, 0,
         /*existing=*/nullptr);
-    Value result = rewriter.create<vector::ExtractElementOp>(
-        loc, asF8s, rewriter.createOrFold<arith::ConstantIndexOp>(loc, 0));
+    Value result = rewriter.create<vector::ExtractOp>(loc, asF8s, 0);
     return rewriter.replaceOp(op, result);
   }
   VectorType outType = op.getOut().getType().cast<VectorType>();
@@ -161,11 +213,13 @@ void TruncfToFloat8RewritePattern::rewrite(arith::TruncFOp op,
       loc, outElemType, rewriter.getFloatAttr(outElemType, 0.0));
   Value result = rewriter.createOrFold<vector::SplatOp>(loc, outType, zero);
   if (outType.getShape().empty()) {
-    Value scalarIn = rewriter.create<vector::ExtractElementOp>(loc, in);
+    Value scalarIn =
+        rewriter.create<vector::ExtractOp>(loc, in, ArrayRef<int64_t>{});
     // Recurse to send the 0-D vector case to the 1-D vector case
     Value scalarTrunc =
         rewriter.create<arith::TruncFOp>(loc, outElemType, scalarIn);
-    result = rewriter.create<vector::InsertElementOp>(loc, scalarTrunc, zero);
+    result = rewriter.create<vector::InsertOp>(loc, scalarTrunc, zero,
+                                               ArrayRef<int64_t>{});
     return rewriter.replaceOp(op, result);
   }
 
@@ -173,14 +227,11 @@ void TruncfToFloat8RewritePattern::rewrite(arith::TruncFOp op,
     int64_t elemsThisOp = std::min(numElements, i + 4) - i;
     Value thisResult = nullptr;
     for (int64_t j = 0; j < elemsThisOp; j += 2) {
-      Value elemA = rewriter.create<vector::ExtractElementOp>(
-          loc, in, rewriter.create<arith::ConstantIndexOp>(loc, i + j));
+      Value elemA = rewriter.create<vector::ExtractOp>(loc, in, i + j);
       Value asFloatA = castToF32(elemA, loc, rewriter);
       Value asFloatB = nullptr;
       if (j + 1 < elemsThisOp) {
-        Value elemB = rewriter.create<vector::ExtractElementOp>(
-            loc, in,
-            rewriter.createOrFold<arith::ConstantIndexOp>(loc, i + j + 1));
+        Value elemB = rewriter.create<vector::ExtractOp>(loc, in, i + j + 1);
         asFloatB = castToF32(elemB, loc, rewriter);
       }
       thisResult = rewriter.create<amdgpu::PackedTrunc2xFp8Op>(
@@ -196,15 +247,16 @@ void TruncfToFloat8RewritePattern::rewrite(arith::TruncFOp op,
 }
 
 void mlir::arith::populateArithToAMDGPUConversionPatterns(
-    RewritePatternSet &patterns) {
-  patterns.add<ExtfOnFloat8RewritePattern, TruncfToFloat8RewritePattern>(
-      patterns.getContext());
+    RewritePatternSet &patterns, bool saturateFP8TruncF) {
+  patterns.add<ExtFOnFloat8RewritePattern>(patterns.getContext());
+  patterns.add<TruncFToFloat8RewritePattern>(patterns.getContext(),
+                                             saturateFP8TruncF);
 }
 
 void ArithToAMDGPUConversionPass::runOnOperation() {
   Operation *op = getOperation();
   RewritePatternSet patterns(op->getContext());
-  arith::populateArithToAMDGPUConversionPatterns(patterns);
+  arith::populateArithToAMDGPUConversionPatterns(patterns, saturateFP8Truncf);
   if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns))))
     return signalPassFailure();
 }
diff --git a/mlir/lib/Conversion/ArithToAMDGPU/CMakeLists.txt b/mlir/lib/Conversion/ArithToAMDGPU/CMakeLists.txt
index 359015b6f86ad8..e2c951b0b34d8b 100644
--- a/mlir/lib/Conversion/ArithToAMDGPU/CMakeLists.txt
+++ b/mlir/lib/Conversion/ArithToAMDGPU/CMakeLists.txt
@@ -13,6 +13,7 @@ add_mlir_conversion_library(MLIRArithToAMDGPU
   LINK_LIBS PUBLIC
   MLIRAMDGPUDialect
   MLIRArithDialect
+  MLIRArithUtils
   MLIRVectorDialect
   MLIRPass
   MLIRTransforms
diff --git a/mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp b/mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp
index 9e783c51c63d1d..8a4080ea019702 100644
--- a/mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp
@@ -10,6 +10,7 @@
 
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Arith/Transforms/WideIntEmulationConverter.h"
+#include "mlir/Dialect/Arith/Utils/Utils.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/Func/Transforms/FuncConversions.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
@@ -58,35 +59,6 @@ static Type reduceInnermostDim(VectorType type) {
   return VectorType::get(newShape, type.getElementType());
 }
 
-/// Returns a constant of integer of vector type filled with (repeated) `value`.
-static Value createScalarOrSplatConstant(ConversionPatternRewriter &rewriter,
-                                         Location loc, Type type,
-                                         const APInt &value) {
-  TypedAttr attr;
-  if (dyn_cast<IntegerType>(type)) {
-    attr = rewriter.getIntegerAttr(type, value);
-  } else {
-    auto vecTy = cast<VectorType>(type);
-    attr = SplatElementsAttr::get(vecTy, value);
-  }
-
-  return rewriter.create<arith::ConstantOp>(loc, attr);
-}
-
-/// Returns a constant of integer of vector type filled with (repeated) `value`.
-static Value createScalarOrSplatConstant(ConversionPatternRewriter &rewriter,
-                                         Location loc, Type type,
-                                         int64_t value) {
-  unsigned elementBitWidth = 0;
-  if (auto intTy = dyn_cast<IntegerType>(type))
-    elementBitWidth = intTy.getWidth();
-  else
-    elementBitWidth = cast<VectorType>(type).getElementTypeBitWidth();
-
-  return createScalarOrSplatConstant(rewriter, loc, type,
-                                     APInt(elementBitWidth, value));
-}
-
 /// Extracts the `input` vector slice with elements at the last dimension offset
 /// by `lastOffset`. Returns a value of vector type with the last dimension
 /// reduced to x1 or fully scalarized, e.g.:
diff --git a/mlir/lib/Dialect/Arith/Utils/Utils.cpp b/mlir/lib/Dialect/Arith/Utils/Utils.cpp
index 0f39c24fb917d8..bf274d4ae27ed8 100644
--- a/mlir/lib/Dialect/Arith/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Arith/Utils/Utils.cpp
@@ -197,6 +197,40 @@ mlir::getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc,
       }));
 }
 
+Value mlir::createScalarOrSplatConstant(OpBuilder &builder, Location loc,
+                                        Type type, const APInt &value) {
+  TypedAttr attr;
+  if (isa<IntegerType>(type)) {
+    attr = builder.getIntegerAttr(type, value);
+  } else {
+    auto vecTy = cast<ShapedType>(type);
+    attr = SplatElementsAttr::get(vecTy, value);
+  }
+
+  return builder.create<arith::ConstantOp>(loc, attr);
+}
+
+Value mlir::createScalarOrSplatConstant(OpBuilder &builder, Location loc,
+                                        Type type, int64_t value) {
+  unsigned elementBitWidth = 0;
+  if (auto intTy = dyn_cast<IntegerType>(type))
+    elementBitWidth = intTy.getWidth();
+  else
+    elementBitWidth = cast<ShapedType>(type).getElementTypeBitWidth();
+
+  return createScalarOrSplatConstant(builder, loc, type,
+                                     APInt(elementBitWidth, value));
+}
+
+Value mlir::createScalarOrSplatConstant(OpBuilder &builder, Location loc,
+                                        Type type, const APFloat &value) {
+  if (isa<FloatType>(type))
+    return builder.createOrFold<arith::ConstantOp>(
+        loc, type, builder.getFloatAttr(type, value));
+  TypedAttr splat = SplatElementsAttr::get(cast<ShapedType>(type), value);
+  return builder.createOrFold<arith::ConstantOp>(loc, type, splat);
+}
+
 Value ArithBuilder::_and(Value lhs, Value rhs) {
   return b.create<arith::AndIOp>(loc, lhs, rhs);
 }
diff --git a/mlir/test/Conversion/ArithToAMDGPU/8-bit-float-saturation.mlir b/mlir/test/Conversion/ArithToAMDGPU/8-bit-float-saturation.mlir
new file mode 100644
index 00000000000000..c7f39440a349b0
--- /dev/null
+++ b/mlir/test/Conversion/ArithToAMDGPU/8-bit-float-saturation.mlir
@@ -0,0 +1,54 @@
+// RUN: mlir-opt --split-input-file %s \
+// RUN:   --pass-pipeline='builtin.module(func.func(convert-arith-to-amdgpu{saturate-fp8-truncf=true}))' \
+// RUN:   | FileCheck %s
+
+// CHECK-LABEL: func.func @scalar_trunc
+// CHECK-SAME: ([[V:%.+]]: f16)
+// CHECK-DAG: [[CMin:%.+]] = arith.constant -5.734400e+04 : f16
+// CHECK-DAG: [[CMax:%.+]] = arith.constant 5.734400e+04 : f16
+// CHECK-DAG: [[CInf:%.+]] = arith.constant 0x7C00 : f16
+// CHECK-DAG: [[CNegInf:%.+]] = arith.constant 0xFC00 : f16
+// CHECK: [[ISINF:%.+]] = arith.cmpf oeq, [[V]], [[CInf]]
+// CHECK: [[ISNEGINF:%.+]] = arith.cmpf oeq, [[V]], [[CNegInf]]
+// CHECK: [[ISNAN:%.+]] = arith.cmpf uno, [[V]], [[V]]
+// CHECK: [[ISNONFINITE_1:%.+]] = arith.ori [[ISINF]], [[ISNEGINF]]
+// CHECK: [[ISNONFINITE:%.+]] = arith.ori [[ISNONFINITE_1]], [[ISNAN]]
+// CHECK: [[CLAMPEDBELOW:%.+]] = arith.maximumf [[V]], [[CMin]]
+// CHECK: [[CLAMPED:%.+]] = arith.minimumf [[CLAMPEDBELOW]], [[CMax]]
+// CHECK: [[SATURATED:%.+]] = arith.select [[ISNONFINITE]], [[V]], [[CLAMPED]]
+// CHECK: [[FLOAT:%.+]] = arith.extf [[SATURATED]] : f16 to f32
+// CHECK: [[TRUNCV:%.+]] = amdgpu.packed_trunc_2xfp8 [[FLOAT]], undef into undef[word 0] : f32 to vector<4xf8E5M2FNUZ>
+// CHECK: [[W:%.+]] = vector.extract [[TRUNCV]][0] : f8E5M2FNUZ from vector<4xf8E5M2FNUZ>
+// CHECK: return [[W]] : f8E5M2FNUZ
+func.func @scalar_trunc(%v: f16) -> f8E5M2FNUZ {
+  %w = arith.truncf %v : f16 to f8E5M2FNUZ
+  return %w : f8E5M2FNUZ
+}
+
+// No 0-D test because arith.truncf hasn't been extended to support it.
+
+// -----
+
+// CHECK-LABEL: func.func @vector_trunc
+// CHECK-SAME: ([[V:%.+]]: vector<2xf32>) -> vector<2xf8E4M3FNUZ> {
+// CHECK-DAG: [[CMin:%.+]] = arith.constant dense<-2.400000e+02> : vector<2xf32>
+// CHECK-DAG: [[CMax:%.+]] = arith.constant dense<2.400000e+02> : vector<2xf32>
+// CHECK-DAG: [[CInf:%.+]] = arith.constant dense<0x7F800000> : vector<2xf32>
+// CHECK-DAG: [[CNegInf:%.+]] = arith.constant dense<0xFF800000> : vector<2xf32>
+// CHECK: [[ISINF:%.+]] = arith.cmpf oeq, [[V]], [[CInf]]
+// CHECK: [[ISNEGINF:%.+]] = arith.cmpf oeq, [[V]], [[CNegInf]]
+// CHECK: [[ISNAN:%.+]] = arith.cmpf uno, [[V]], [[V]]
+// CHECK: [[ISNONFINITE_1:%.+]] = arith.ori [[ISINF]], [[ISNEGINF]]
+// CHECK: [[ISNONFINITE:%.+]] = arith.ori [[ISNONFINITE_1]], [[ISNAN]]
+// CHECK: [[CLAMPEDBELOW:%.+]] = arith.maximumf [[V]], [[CMin]]
+// CHECK: [[CLAMPED:%.+]] = arith.minimumf [[CLAMPEDBELOW]], [[CMax]]
+// CHECK: [[SATURATED:%.+]] = arith.select [[ISNONFINITE]], [[V]], [[CLAMPED]]
+// CHECK: [[F0:%.+]] = vector.extract [[SATURATED]][0]
+// CHECK: [[F1:%.+]] = vector.extract [[SATURATED]][1]
+// CHECK: [[W0:%.+]] = amdgpu.packed_trunc_2xfp8 [[F0]], [[F1]] into undef[word 0] : f32 to vector<4xf8E4M3FNUZ>
+// CHECK: [[W:%.+]] = vector.extract_strided_slice [[W0]] {offsets = [0], sizes = [2], strides = [1]} : vector<4xf8E4M3FNUZ> to vector<2xf8E4M3FNUZ>
+// CHECK: return [[W]] : vector<2xf8E4M3FNUZ>
+func.func @vector_trunc_short(%v: vector<2xf32>) -> vector<2xf8E4M3FNUZ> {
+  %w = arith.truncf %v : vector<2xf32> to vector<2xf8E4M3FNUZ>
+  return %w : vector<2xf8E4M3FNUZ>
+}
diff --git a/mlir/test/Conversion/ArithToAMDGPU/8-bit-floats.mlir b/mlir/test/Conversion/ArithToAMDGPU/8-bit-floats.mlir
index a6c11d022e2c15..159a2f02f0560e 100644
--- a/mlir/test/Conversion/ArithToAMDGPU/8-bit-floats.mlir
+++ b/mlir/test/Conversion/ArithToAMDGPU/8-bit-floats.mlir
@@ -17,14 +17,12 @@ func.func @scalar_ext(%v: f8E5M2FNUZ) -> f16 {
 // CHECK-LABEL: func.func @vector_ext_short
 // CHECK-SAME: ([[V:%.+]]: vector<2xf8E5M2FNUZ>)
 // CHECK-DAG: [[ZEROES:%.+]] = arith.constant dense<0.000000e+00> : vector<2xf64>
-// CHECK-DAG: [[C0:%.+]] = arith.constant 0 : index
-// CHECK-DAG: [[C1:%.+]] = arith.constant 1 : index
 // CHECK: [[FLOAT0:%.+]] = amdgpu.ext_packed_fp8 [[V]][0] : vector<2xf8E5M2FNUZ> to f32
 // CHECK: [[EXT0:%.+]] = arith.extf [[FLOAT0]] : f32 to f64
-// CHECK: [[W0:%.+]] = vector.insertelement [[EXT0]], [[ZEROES]]{{\[}}[[C0]]
+// CHECK: [[W0:%.+]] = vector.insert [[EXT0]], [[ZEROES]] [0]
 // CHECK: [[FLOAT1:%.+]] = amdgpu.ext_packed_fp8 [[V]][1] : vector<2xf8E5M2FNUZ> to f32
 // CHECK: [[EXT1:%.+]] = arith.extf [[FLOAT1]]
-// CHECK: [[W1:%.+]] = vector.insertelement [[EXT1]], [[W0]]{{\[}}[[C1]]
+// CHECK: [[W1:%.+]] = vector.insert [[EXT1]], [[W0]] [1]
 // CHECK: return [[W1]] : vector<2xf64>
 
 func.func @vector_ext_short(%v: vector<2xf8E5M2FNUZ>) -> vector<2xf64> {
@@ -38,27 +36,27 @@ func.func @vector_ext_short(%v: vector<2xf8E5M2FNUZ>) -> vector<2xf64> {
 // CHECK-SAME: ([[V:%.+]]: vector<9xf8E4M3FNUZ>)
 // CHECK: [[V0:%.+]] = vector.extract_strided_slice [[V]] {offsets = [0], sizes = [4], strides = [1]}
 // CHECK: [[F0:%.+]] = amdgpu.ext_packed_fp8 [[V0]][0]
-// CHECK: [[W0:%.+]] = vector.insertelement [[F0]]
+// CHECK: [[W0:%.+]] = vector.insert [[F0]]
 // CHECK: [[F1:%.+]] = amdgpu.ext_packed_fp8 [[V0]][1]
-// CHECK: [[W1:%.+]] = vector.insertelement [[F1]], [[W0]]
+// CHECK: [[W1:%.+]] = vector.insert [[F1]], [[W0]]
 // CHECK: [[F2:%.+]] = amdgpu.ext_packed_fp8 [[V0]][2]
-// CHECK: [[W2:%.+]] = vector.insertelement [[F2]], [[W1]]
+// CHECK: [[W2:%.+]] = vector.insert [[F2]], [[W1]]
 // CHECK: [[F3:%.+]] = amdgpu.ext_packed_fp8 [[V0]][3]
-// CHECK: [[W3:%.+]] = vector.insertelement [[F3]], [[W2]]
+// CHECK: [[W3:%.+]] = vector.insert [[F3]], [[W2]]
 
 // CHECK: [[V1:%.+]] = vector.extract_strided_slice [[V]] {offsets = [4], sizes = [4], strides = [1]} : vector<9xf8E4M3FNUZ> to vector<4xf8E4M3FNUZ>
 // CHECK: [[F4:%.+]] = amdgpu.ext_packed_fp8 [[V1]][0]
-// CHECK: [[W4:%.+]] = vector.insertelement [[F4]], [[W3]]
+// CHECK: [[W4:%.+]] = vector.insert [[F4]], [[W3]]
 // CHECK: [[F5:%.+]] = amdgpu.ext_packed_fp8 [[V1]][1]
-// CHECK: [[W5:%.+]] = vector.insertelement [[F5]], [[W4]]
+// CHECK: [[W5:%.+]] = vector.insert [[F5]], [[W4]]
 // CHECK: [[F6:%.+]] = amdgpu.ext_packed_fp8 [[V1]][2]
-// CHECK: [[W6:%.+]] = vector.insertelement [[F6]], [[W5]]
+// CHECK: [[W6:%.+]] = vector.insert [[F6]], [[W5]]
 // CHECK: [[F7:%.+]] = amdgpu.ext_packed_fp8 [[V1]][3]
-// CHECK: [[W7:%.+]] = vector.insertelement [[F7]], [[W6]]
+// CHECK: [[W7:%.+]] = vector.insert [[F7]], [[W6]]
 
 // CHECK: [[V2:%.+]] = vector.extract_strided_slice [[V]] {offsets = [8], sizes = [1], strides = [1]} : vector<9xf8E4M3FNUZ> to vector<1xf8E4M3FNUZ>
 // CHECK: [[F8:%.+]] = amdgpu.ext_packed_fp8 [[V2]][0]
-// CHECK: [[W8:%.+]] = vector.insertelement [[F8]], [[W7]]
+// CHECK: [[W8:%.+]] = vector.insert [[F8]], [[W7]]
 // CHECK: return [[W8]]
 func.func @vector_ext_long(%v: vector<9xf8E4M3FNUZ>) -> vector<9xf32> {
   %w = arith.extf %v : vector<9xf8E4M3FNUZ> to vector<9xf32>
@@ -69,10 +67,9 @@ func.func @vector_ext_long(%v: vector<9xf8E4M3FNUZ>) -> vector<9xf32> {
 
 // CHECK-LABEL: func.func @scalar_trunc
 // CHECK-SAME: ([[V:%.+]]: f16)
-// CHECK: [[C0:%.+]] = arith.constant 0 : index
 // CHECK: [[FLOAT:%.+]] = arith.extf [[V]] : f16 to f32
 // CHECK: [[TRUNCV:%.+]] = amdgpu.packed_trunc_2xfp8 [[FLOAT]], undef into undef[word 0] : f32 to vector<4xf8E5M2FNUZ>
-// CHECK: [[W:%.+]] = vector.extractelement [[TRUNCV]]{{\[}}[[C0]] : index] : vector<4xf8E5M2FNUZ>
+// CHECK: [[W:%.+]] = vector.extract [[TRUNCV]][0] : f8E5M2FNUZ from vector<4xf8E5M2FNUZ>
 // CHECK: return [[W]] : f8E5M2FNUZ
 func.func @scalar_trunc(%v: f16) -> f8E5M2FNUZ {
   %w = arith.truncf %v : f16 to f8E5M2FNUZ
@@ -85,11 +82,9 @@ func.func @scalar_trunc(%v: f16) -> f8E5M2FNUZ {
 
 // CHECK-LABEL: func.func @vector_trunc_short
 // CHECK-SAME: ([[V:%.+]]: vector<2xf64>) -> vector<2xf8E5M2FNUZ> {
-// CHECK-DAG: [[C0:%.+]] = arith.constant 0 : index
-// CHECK-DAG: [[C1:%.+]] = arith.constant 1 : index
-// CHECK: [[V0:%.+]] = vector.extractelement [[V]]{{\[}}[[C0]] : index]
+// CHECK: [[V0:%.+]] = vector.extract [[V]][0]
 // CHECK: [[F0:%.+]] = arith.truncf [[V0]] : f64 to f32
-// CHECK: [[V1:%.+]] = vector.extractelement [[V]]{{\[}}[[C1]] : index]
+// CHECK: [[V1:%.+]] = vector.extract [[V]][1]
 // CHECK: [[F1:%.+]] = arith.truncf [[V1]] : f64 to f32
 // CHECK: [[W0:%.+]] = amdgpu.packed_trunc_2xfp8 [[F0]], [[F1]] into undef[word 0] : f32 to vector<4xf8E5M2FNUZ>
 // CHECK: [[W:%.+]] = vector.extract_strided_slice [[W0]] {offsets = [0], sizes = [2], strides = [1]} : vector<4xf8E5M2FNUZ> to vector<2xf8E5M2FNUZ>



More information about the Mlir-commits mailing list