[Mlir-commits] [mlir] 7042fcc - [MLIR][Arith][Resubmit] add fastMathAttr on arith::extf and arith::truncf (#95346)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Jun 14 16:42:33 PDT 2024
Author: Ivy Zhang
Date: 2024-06-15T07:42:29+08:00
New Revision: 7042fcc6389c6c103d501b6f39988eafed0d9b5b
URL: https://github.com/llvm/llvm-project/commit/7042fcc6389c6c103d501b6f39988eafed0d9b5b
DIFF: https://github.com/llvm/llvm-project/commit/7042fcc6389c6c103d501b6f39988eafed0d9b5b.diff
LOG: [MLIR][Arith][Resubmit] add fastMathAttr on arith::extf and arith::truncf (#95346)
Add a `fastMathAttr` on `arith::extf` and `arith::truncf`. If these two
ops are inserted by some promotion passes (like legalize-to-f32 /
emulate-unsupported-floats), they will be labeled with
`FastMathFlags::contract`, denoting that they can then be eliminated by
the canonicalizer.
The elimination can help improve performance, but it may introduce some
numerical differences.
Added:
Modified:
mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
mlir/lib/Dialect/Arith/IR/ArithOps.cpp
mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp
mlir/lib/Dialect/Math/Transforms/LegalizeToF32.cpp
mlir/test/Dialect/Arith/canonicalize.mlir
mlir/test/Dialect/Arith/emulate-unsupported-floats.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
index 06fbdb7f2c4cb..6e24b607cebcf 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
+++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
@@ -1199,7 +1199,7 @@ def Arith_ExtSIOp : Arith_IToICastOp<"extsi"> {
// ExtFOp
//===----------------------------------------------------------------------===//
-def Arith_ExtFOp : Arith_FToFCastOp<"extf"> {
+def Arith_ExtFOp : Arith_FToFCastOp<"extf", [DeclareOpInterfaceMethods<ArithFastMathInterface>]> {
let summary = "cast from floating-point to wider floating-point";
let description = [{
Cast a floating-point value to a larger floating-point-typed value.
@@ -1208,6 +1208,13 @@ def Arith_ExtFOp : Arith_FToFCastOp<"extf"> {
}];
let hasVerifier = 1;
let hasFolder = 1;
+
+ let arguments = (ins FloatLike:$in,
+ OptionalAttr<Arith_FastMathAttr>:$fastmath);
+ let results = (outs FloatLike:$out);
+
+ let assemblyFormat = [{ $in (`fastmath` `` $fastmath^)?
+ attr-dict `:` type($in) `to` type($out) }];
}
//===----------------------------------------------------------------------===//
@@ -1246,9 +1253,11 @@ def Arith_TruncFOp :
Arith_Op<"truncf",
[Pure, SameOperandsAndResultShape, SameInputOutputTensorDims,
DeclareOpInterfaceMethods<ArithRoundingModeInterface>,
+ DeclareOpInterfaceMethods<ArithFastMathInterface>,
DeclareOpInterfaceMethods<CastOpInterface>]>,
Arguments<(ins FloatLike:$in,
- OptionalAttr<Arith_RoundingModeAttr>:$roundingmode)>,
+ OptionalAttr<Arith_RoundingModeAttr>:$roundingmode,
+ OptionalAttr<Arith_FastMathAttr>:$fastmath)>,
Results<(outs FloatLike:$out)> {
let summary = "cast from floating-point to narrower floating-point";
let description = [{
@@ -1267,7 +1276,9 @@ def Arith_TruncFOp :
let hasFolder = 1;
let hasVerifier = 1;
- let assemblyFormat = "$in ($roundingmode^)? attr-dict `:` type($in) `to` type($out)";
+ let assemblyFormat = [{ $in ($roundingmode^)?
+ (`fastmath` `` $fastmath^)?
+ attr-dict `:` type($in) `to` type($out) }];
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index 2f6647a2a27b1..f4781fcb546a3 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -1390,6 +1390,22 @@ LogicalResult arith::ExtSIOp::verify() {
/// Fold extension of float constants when there is no information loss due the
/// difference in fp semantics.
OpFoldResult arith::ExtFOp::fold(FoldAdaptor adaptor) {
+ if (auto truncFOp = getOperand().getDefiningOp<TruncFOp>()) {
+ if (truncFOp.getOperand().getType() == getType()) {
+ arith::FastMathFlags truncFMF =
+ truncFOp.getFastmath().value_or(arith::FastMathFlags::none);
+ bool isTruncContract =
+ bitEnumContainsAll(truncFMF, arith::FastMathFlags::contract);
+ arith::FastMathFlags extFMF =
+ getFastmath().value_or(arith::FastMathFlags::none);
+ bool isExtContract =
+ bitEnumContainsAll(extFMF, arith::FastMathFlags::contract);
+ if (isTruncContract && isExtContract) {
+ return truncFOp.getOperand();
+ }
+ }
+ }
+
auto resElemType = cast<FloatType>(getElementTypeOrSelf(getType()));
const llvm::fltSemantics &targetSemantics = resElemType.getFloatSemantics();
return constFoldCastOp<FloatAttr, FloatAttr>(
diff --git a/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp b/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp
index 4a50da3513f99..8e1cb474feee7 100644
--- a/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp
@@ -94,8 +94,11 @@ void EmulateFloatPattern::rewrite(Operation *op, ArrayRef<Value> operands,
SmallVector<Value> newResults(expandedOp->getResults());
for (auto [res, oldType, newType] : llvm::zip_equal(
MutableArrayRef{newResults}, op->getResultTypes(), resultTypes)) {
- if (oldType != newType)
- res = rewriter.create<arith::TruncFOp>(loc, oldType, res);
+ if (oldType != newType) {
+ auto truncFOp = rewriter.create<arith::TruncFOp>(loc, oldType, res);
+ truncFOp.setFastmath(arith::FastMathFlags::contract);
+ res = truncFOp.getResult();
+ }
}
rewriter.replaceOp(op, newResults);
}
@@ -114,7 +117,9 @@ void mlir::arith::populateEmulateUnsupportedFloatsConversions(
});
converter.addTargetMaterialization(
[](OpBuilder &b, Type target, ValueRange input, Location loc) {
- return b.create<arith::ExtFOp>(loc, target, input);
+ auto extFOp = b.create<arith::ExtFOp>(loc, target, input);
+ extFOp.setFastmath(arith::FastMathFlags::contract);
+ return extFOp;
});
}
diff --git a/mlir/lib/Dialect/Math/Transforms/LegalizeToF32.cpp b/mlir/lib/Dialect/Math/Transforms/LegalizeToF32.cpp
index 5998133b7eab8..3d99f3033cf56 100644
--- a/mlir/lib/Dialect/Math/Transforms/LegalizeToF32.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/LegalizeToF32.cpp
@@ -57,7 +57,9 @@ void mlir::math::populateLegalizeToF32TypeConverter(
});
typeConverter.addTargetMaterialization(
[](OpBuilder &b, Type target, ValueRange input, Location loc) {
- return b.create<arith::ExtFOp>(loc, target, input);
+ auto extFOp = b.create<arith::ExtFOp>(loc, target, input);
+ extFOp.setFastmath(arith::FastMathFlags::contract);
+ return extFOp;
});
}
@@ -84,8 +86,11 @@ LogicalResult LegalizeToF32RewritePattern::matchAndRewrite(
SmallVector<Value> results = (*legalized)->getResults();
for (auto [result, newType, origType] : llvm::zip_equal(
results, (*legalized)->getResultTypes(), op->getResultTypes())) {
- if (newType != origType)
- result = rewriter.create<arith::TruncFOp>(loc, origType, result);
+ if (newType != origType) {
+ auto truncFOp = rewriter.create<arith::TruncFOp>(loc, origType, result);
+ truncFOp.setFastmath(arith::FastMathFlags::contract);
+ result = truncFOp.getResult();
+ }
}
rewriter.replaceOp(op, results);
return success();
diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir
index e4f95bb0545a2..4fe7cfb689be8 100644
--- a/mlir/test/Dialect/Arith/canonicalize.mlir
+++ b/mlir/test/Dialect/Arith/canonicalize.mlir
@@ -3031,6 +3031,143 @@ func.func @mulsi_extended_i0() -> (i0, i0) {
return %mulsi_extended#0, %mulsi_extended#1 : i0, i0
}
+// CHECK-LABEL: @sequences_fastmath_contract
+// CHECK-SAME: ([[ARG0:%.+]]: bf16)
+// CHECK: [[EXTF:%.+]] = arith.extf [[ARG0]]
+// CHECK: [[ABSF:%.+]] = math.absf [[EXTF]]
+// CHECK: [[SIN:%.+]] = math.sin [[ABSF]]
+// CHECK: [[TRUNCF:%.+]] = arith.truncf [[SIN]]
+// CHECK: return [[TRUNCF]] : bf16
+func.func @sequences_fastmath_contract(%arg0: bf16) -> bf16 {
+ %0 = arith.extf %arg0 fastmath<contract> : bf16 to f32
+ %1 = math.absf %0 : f32
+ %2 = arith.truncf %1 fastmath<contract> : f32 to bf16
+ %3 = arith.extf %2 fastmath<contract> : bf16 to f32
+ %4 = math.sin %3 : f32
+ %5 = arith.truncf %4 fastmath<contract> : f32 to bf16
+ return %5 : bf16
+}
+
+// CHECK-LABEL: @sequences_no_fastmath
+// CHECK-SAME: ([[ARG0:%.+]]: bf16)
+// CHECK: [[EXTF:%.+]] = arith.extf [[ARG0]]
+// CHECK: [[ABSF:%.+]] = math.absf [[EXTF]]
+// CHECK: [[TRUNCF1:%.+]] = arith.truncf [[ABSF]]
+// CHECK: [[EXTF1:%.+]] = arith.extf [[TRUNCF1]]
+// CHECK: [[SIN:%.+]] = math.sin [[EXTF1]]
+// CHECK: [[TRUNCF:%.+]] = arith.truncf [[SIN]]
+// CHECK: return [[TRUNCF]] : bf16
+func.func @sequences_no_fastmath(%arg0: bf16) -> bf16 {
+ %0 = arith.extf %arg0 : bf16 to f32
+ %1 = math.absf %0 : f32
+ %2 = arith.truncf %1 : f32 to bf16
+ %3 = arith.extf %2 : bf16 to f32
+ %4 = math.sin %3 : f32
+ %5 = arith.truncf %4 : f32 to bf16
+ return %5 : bf16
+}
+
+// CHECK-LABEL: @eliminate_cast_to_f16
+// CHECK: return [[arg0:%.+]] : f32
+func.func @eliminate_cast_to_f16(%arg0: f32) -> f32 {
+ %0 = arith.truncf %arg0 fastmath<contract> : f32 to f16
+ %1 = arith.extf %0 fastmath<contract> : f16 to f32
+ return %1 : f32
+}
+
+// CHECK-LABEL: @eliminate_cast_to_bf16
+// CHECK: return [[arg0:%.+]] : f32
+func.func @eliminate_cast_to_bf16(%arg0: f32) -> f32 {
+ %0 = arith.truncf %arg0 fastmath<contract> : f32 to bf16
+ %1 = arith.extf %0 fastmath<contract> : bf16 to f32
+ return %1 : f32
+}
+
+// CHECK-LABEL: @bf16_sin_vector
+// CHECK-SAME: ([[ARG0:%.+]]: vector<32x32x32xbf16>)
+// CHECK: [[EXTF:%.+]] = arith.extf [[ARG0]]
+// CHECK: [[ABSF:%.+]] = math.absf [[EXTF]]
+// CHECK: [[SIN:%.+]] = math.sin [[ABSF]]
+// CHECK: [[TRUNCF:%.+]] = arith.truncf [[SIN]]
+// CHECK: return [[TRUNCF]] : vector<32x32x32xbf16>
+func.func @bf16_sin_vector(%arg0: vector<32x32x32xbf16>) -> vector<32x32x32xbf16> {
+ %0 = arith.extf %arg0 fastmath<contract> : vector<32x32x32xbf16> to vector<32x32x32xf32>
+ %1 = math.absf %0 : vector<32x32x32xf32>
+ %2 = arith.truncf %1 fastmath<contract> : vector<32x32x32xf32> to vector<32x32x32xbf16>
+ %3 = arith.extf %2 fastmath<contract> : vector<32x32x32xbf16> to vector<32x32x32xf32>
+ %4 = math.sin %3 : vector<32x32x32xf32>
+ %5 = arith.truncf %4 fastmath<contract> : vector<32x32x32xf32> to vector<32x32x32xbf16>
+ return %5 : vector<32x32x32xbf16>
+}
+
+// CHECK-LABEL: @f16_sin_vector
+// CHECK-SAME: ([[ARG0:%.+]]: vector<32x32x32xf16>)
+// CHECK: [[EXTF:%.+]] = arith.extf [[ARG0]]
+// CHECK: [[ABSF:%.+]] = math.absf [[EXTF]]
+// CHECK: [[SIN:%.+]] = math.sin [[ABSF]]
+// CHECK: [[TRUNCF:%.+]] = arith.truncf [[SIN]]
+// CHECK: return [[TRUNCF]] : vector<32x32x32xf16>
+func.func @f16_sin_vector(%arg0: vector<32x32x32xf16>) -> vector<32x32x32xf16> {
+ %0 = arith.extf %arg0 fastmath<contract> : vector<32x32x32xf16> to vector<32x32x32xf32>
+ %1 = math.absf %0 : vector<32x32x32xf32>
+ %2 = arith.truncf %1 fastmath<contract> : vector<32x32x32xf32> to vector<32x32x32xf16>
+ %3 = arith.extf %2 fastmath<contract> : vector<32x32x32xf16> to vector<32x32x32xf32>
+ %4 = math.sin %3 : vector<32x32x32xf32>
+ %5 = arith.truncf %4 fastmath<contract> : vector<32x32x32xf32> to vector<32x32x32xf16>
+ return %5 : vector<32x32x32xf16>
+}
+
+// CHECK-LABEL: @bf16_branch_vector
+// CHECK-SAME: ([[ARG0:%.+]]: vector<32x32x32xbf16>)
+// CHECK: [[EXTF:%.+]] = arith.extf [[ARG0]]
+// CHECK: [[ABSF:%.+]] = math.absf [[EXTF]]
+// CHECK-DAG: [[SIN:%.+]] = math.sin [[ABSF]]
+// CHECK-DAG: [[COS:%.+]] = math.cos [[ABSF]]
+// CHECK: [[ADDF:%.+]] = arith.addf [[SIN]], [[COS]]
+// CHECK: [[TRUNCF:%.+]] = arith.truncf [[ADDF]]
+// CHECK: return [[TRUNCF]] : vector<32x32x32xbf16>
+func.func @bf16_branch_vector(%arg0: vector<32x32x32xbf16>) -> vector<32x32x32xbf16> {
+ %0 = arith.extf %arg0 fastmath<contract> : vector<32x32x32xbf16> to vector<32x32x32xf32>
+ %1 = math.absf %0 : vector<32x32x32xf32>
+ %2 = arith.truncf %1 fastmath<contract> : vector<32x32x32xf32> to vector<32x32x32xbf16>
+ %3 = arith.extf %2 fastmath<contract> : vector<32x32x32xbf16> to vector<32x32x32xf32>
+ %4 = math.sin %3 : vector<32x32x32xf32>
+ %5 = arith.truncf %4 fastmath<contract> : vector<32x32x32xf32> to vector<32x32x32xbf16>
+ %6 = arith.extf %5 fastmath<contract> : vector<32x32x32xbf16> to vector<32x32x32xf32>
+ %7 = math.cos %3 : vector<32x32x32xf32>
+ %8 = arith.truncf %7 fastmath<contract> : vector<32x32x32xf32> to vector<32x32x32xbf16>
+ %9 = arith.extf %8 fastmath<contract> : vector<32x32x32xbf16> to vector<32x32x32xf32>
+ %10 = arith.addf %6, %9 : vector<32x32x32xf32>
+ %11 = arith.truncf %10 fastmath<contract> : vector<32x32x32xf32> to vector<32x32x32xbf16>
+ return %11 : vector<32x32x32xbf16>
+}
+
+// CHECK-LABEL: @bf16_fma
+// CHECK-SAME: ([[ARG0:%.+]]: vector<32x32x32xbf16>, [[ARG1:%.+]]: vector<32x32x32xbf16>, [[ARG2:%.+]]: vector<32x32x32xbf16>)
+// CHECK: [[EXTF0:%.+]] = arith.extf [[ARG0]]
+// CHECK: [[ABSF:%.+]] = math.absf [[EXTF0]]
+// CHECK-DAG: [[SIN:%.+]] = math.sin [[ABSF]]
+// CHECK: [[TRUNCF0:%.+]] = arith.truncf [[SIN]]
+// CHECK-DAG: [[FMA:%.+]] = math.fma [[TRUNCF0]], [[ARG1]], [[ARG2]]
+// CHECK: [[EXTF1:%.+]] = arith.extf [[FMA]]
+// CHECK: [[ADDF:%.+]] = arith.addf [[EXTF1]], [[SIN]]
+// CHECK: [[TRUNCF1:%.+]] = arith.truncf [[ADDF]]
+// CHECK: return [[TRUNCF1]] : vector<32x32x32xbf16>
+func.func @bf16_fma(%arg0: vector<32x32x32xbf16>, %arg1: vector<32x32x32xbf16>, %arg2: vector<32x32x32xbf16>) -> vector<32x32x32xbf16> {
+ %0 = arith.extf %arg0 fastmath<contract> : vector<32x32x32xbf16> to vector<32x32x32xf32>
+ %1 = math.absf %0 : vector<32x32x32xf32>
+ %2 = arith.truncf %1 fastmath<contract> : vector<32x32x32xf32> to vector<32x32x32xbf16>
+ %3 = arith.extf %2 fastmath<contract> : vector<32x32x32xbf16> to vector<32x32x32xf32>
+ %4 = math.sin %3 : vector<32x32x32xf32>
+ %5 = arith.truncf %4 fastmath<contract> : vector<32x32x32xf32> to vector<32x32x32xbf16>
+ %6 = arith.extf %5 fastmath<contract> : vector<32x32x32xbf16> to vector<32x32x32xf32>
+ %7 = math.fma %5, %arg1, %arg2 : vector<32x32x32xbf16>
+ %8 = arith.extf %7 fastmath<contract> : vector<32x32x32xbf16> to vector<32x32x32xf32>
+ %9 = arith.addf %8, %6 : vector<32x32x32xf32>
+ %10 = arith.truncf %9 fastmath<contract> : vector<32x32x32xf32> to vector<32x32x32xbf16>
+ return %10 : vector<32x32x32xbf16>
+}
+
{-#
dialect_resources: {
builtin: {
diff --git a/mlir/test/Dialect/Arith/emulate-unsupported-floats.mlir b/mlir/test/Dialect/Arith/emulate-unsupported-floats.mlir
index a69ef131d8d47..99790cc45d490 100644
--- a/mlir/test/Dialect/Arith/emulate-unsupported-floats.mlir
+++ b/mlir/test/Dialect/Arith/emulate-unsupported-floats.mlir
@@ -4,10 +4,10 @@ func.func @basic_expansion(%x: bf16) -> bf16 {
// CHECK-LABEL: @basic_expansion
// CHECK-SAME: [[X:%.+]]: bf16
// CHECK-DAG: [[C:%.+]] = arith.constant {{.*}} : bf16
-// CHECK-DAG: [[X_EXP:%.+]] = arith.extf [[X]] : bf16 to f32
-// CHECK-DAG: [[C_EXP:%.+]] = arith.extf [[C]] : bf16 to f32
+// CHECK-DAG: [[X_EXP:%.+]] = arith.extf [[X]] fastmath<contract> : bf16 to f32
+// CHECK-DAG: [[C_EXP:%.+]] = arith.extf [[C]] fastmath<contract> : bf16 to f32
// CHECK: [[Y_EXP:%.+]] = arith.addf [[X_EXP]], [[C_EXP]] : f32
-// CHECK: [[Y:%.+]] = arith.truncf [[Y_EXP]] : f32 to bf16
+// CHECK: [[Y:%.+]] = arith.truncf [[Y_EXP]] fastmath<contract> : f32 to bf16
// CHECK: return [[Y]]
%c = arith.constant 1.0 : bf16
%y = arith.addf %x, %c : bf16
@@ -19,15 +19,15 @@ func.func @basic_expansion(%x: bf16) -> bf16 {
func.func @chained(%x: bf16, %y: bf16, %z: bf16) -> i1 {
// CHECK-LABEL: @chained
// CHECK-SAME: [[X:%.+]]: bf16, [[Y:%.+]]: bf16, [[Z:%.+]]: bf16
-// CHECK-DAG: [[X_EXP:%.+]] = arith.extf [[X]] : bf16 to f32
-// CHECK-DAG: [[Y_EXP:%.+]] = arith.extf [[Y]] : bf16 to f32
-// CHECK-DAG: [[Z_EXP:%.+]] = arith.extf [[Z]] : bf16 to f32
+// CHECK-DAG: [[X_EXP:%.+]] = arith.extf [[X]] fastmath<contract> : bf16 to f32
+// CHECK-DAG: [[Y_EXP:%.+]] = arith.extf [[Y]] fastmath<contract> : bf16 to f32
+// CHECK-DAG: [[Z_EXP:%.+]] = arith.extf [[Z]] fastmath<contract> : bf16 to f32
// CHECK: [[P_EXP:%.+]] = arith.addf [[X_EXP]], [[Y_EXP]] : f32
-// CHECK: [[P:%.+]] = arith.truncf [[P_EXP]] : f32 to bf16
-// CHECK: [[P_EXP2:%.+]] = arith.extf [[P]] : bf16 to f32
+// CHECK: [[P:%.+]] = arith.truncf [[P_EXP]] fastmath<contract> : f32 to bf16
+// CHECK: [[P_EXP2:%.+]] = arith.extf [[P]] fastmath<contract> : bf16 to f32
// CHECK: [[Q_EXP:%.+]] = arith.mulf [[P_EXP2]], [[Z_EXP]]
-// CHECK: [[Q:%.+]] = arith.truncf [[Q_EXP]] : f32 to bf16
-// CHECK: [[Q_EXP2:%.+]] = arith.extf [[Q]] : bf16 to f32
+// CHECK: [[Q:%.+]] = arith.truncf [[Q_EXP]] fastmath<contract> : f32 to bf16
+// CHECK: [[Q_EXP2:%.+]] = arith.extf [[Q]] fastmath<contract> : bf16 to f32
// CHECK: [[RES:%.+]] = arith.cmpf ole, [[P_EXP2]], [[Q_EXP2]] : f32
// CHECK: return [[RES]]
%p = arith.addf %x, %y : bf16
@@ -41,12 +41,12 @@ func.func @chained(%x: bf16, %y: bf16, %z: bf16) -> i1 {
func.func @memops(%a: memref<4xf8E4M3FNUZ>, %b: memref<4xf8E4M3FNUZ>) {
// CHECK-LABEL: @memops
// CHECK: [[V:%.+]] = memref.load {{.*}} : memref<4xf8E4M3FNUZ>
-// CHECK: [[V_EXP:%.+]] = arith.extf [[V]] : f8E4M3FNUZ to f32
+// CHECK: [[V_EXP:%.+]] = arith.extf [[V]] fastmath<contract> : f8E4M3FNUZ to f32
// CHECK: memref.store [[V]]
// CHECK: [[W:%.+]] = memref.load
-// CHECK: [[W_EXP:%.+]] = arith.extf [[W]] : f8E4M3FNUZ to f32
+// CHECK: [[W_EXP:%.+]] = arith.extf [[W]] fastmath<contract> : f8E4M3FNUZ to f32
// CHECK: [[X_EXP:%.+]] = arith.addf [[V_EXP]], [[W_EXP]] : f32
-// CHECK: [[X:%.+]] = arith.truncf [[X_EXP]] : f32 to f8E4M3FNUZ
+// CHECK: [[X:%.+]] = arith.truncf [[X_EXP]] fastmath<contract> : f32 to f8E4M3FNUZ
// CHECK: memref.store [[X]]
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
@@ -63,9 +63,9 @@ func.func @memops(%a: memref<4xf8E4M3FNUZ>, %b: memref<4xf8E4M3FNUZ>) {
func.func @vectors(%a: vector<4xf8E4M3FNUZ>) -> vector<4xf32> {
// CHECK-LABEL: @vectors
// CHECK-SAME: [[A:%.+]]: vector<4xf8E4M3FNUZ>
-// CHECK: [[A_EXP:%.+]] = arith.extf [[A]] : vector<4xf8E4M3FNUZ> to vector<4xf32>
+// CHECK: [[A_EXP:%.+]] = arith.extf [[A]] fastmath<contract> : vector<4xf8E4M3FNUZ> to vector<4xf32>
// CHECK: [[B_EXP:%.+]] = arith.mulf [[A_EXP]], [[A_EXP]] : vector<4xf32>
-// CHECK: [[B:%.+]] = arith.truncf [[B_EXP]] : vector<4xf32> to vector<4xf8E4M3FNUZ>
+// CHECK: [[B:%.+]] = arith.truncf [[B_EXP]] fastmath<contract> : vector<4xf32> to vector<4xf8E4M3FNUZ>
// CHECK: [[RET:%.+]] = arith.extf [[B]] : vector<4xf8E4M3FNUZ> to vector<4xf32>
// CHECK: return [[RET]]
%b = arith.mulf %a, %a : vector<4xf8E4M3FNUZ>
More information about the Mlir-commits
mailing list