[llvm-branch-commits] [mlir] 8944c8d - Revert "[MLIR][Arith] add fastMathAttr on arith::extf and arith::truncf (#93443)"
via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Wed Jun 12 20:12:43 PDT 2024
Author: Ivy Zhang
Date: 2024-06-13T11:12:39+08:00
New Revision: 8944c8df45f8e4da860bf04118106d9a950cbf75
URL: https://github.com/llvm/llvm-project/commit/8944c8df45f8e4da860bf04118106d9a950cbf75
DIFF: https://github.com/llvm/llvm-project/commit/8944c8df45f8e4da860bf04118106d9a950cbf75.diff
LOG: Revert "[MLIR][Arith] add fastMathAttr on arith::extf and arith::truncf (#93443)"
This reverts commit 6784bf764207d267b781b4f515a2fafdcb345509.
Added:
Modified:
mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
mlir/lib/Dialect/Arith/IR/ArithOps.cpp
mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp
mlir/lib/Dialect/Math/Transforms/LegalizeToF32.cpp
mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
mlir/test/Dialect/Arith/canonicalize.mlir
mlir/test/Dialect/Arith/emulate-unsupported-floats.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
index c4471f9bc5af2..06fbdb7f2c4cb 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
+++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
@@ -1199,7 +1199,7 @@ def Arith_ExtSIOp : Arith_IToICastOp<"extsi"> {
// ExtFOp
//===----------------------------------------------------------------------===//
-def Arith_ExtFOp : Arith_FToFCastOp<"extf", [DeclareOpInterfaceMethods<ArithFastMathInterface>]> {
+def Arith_ExtFOp : Arith_FToFCastOp<"extf"> {
let summary = "cast from floating-point to wider floating-point";
let description = [{
Cast a floating-point value to a larger floating-point-typed value.
@@ -1208,13 +1208,6 @@ def Arith_ExtFOp : Arith_FToFCastOp<"extf", [DeclareOpInterfaceMethods<ArithFast
}];
let hasVerifier = 1;
let hasFolder = 1;
-
- let arguments = (ins FloatLike:$in, DefaultValuedAttr<
- Arith_FastMathAttr, "::mlir::arith::FastMathFlags::none">:$fastmath);
- let results = (outs FloatLike:$out);
-
- let assemblyFormat = [{ $in (`fastmath` `` $fastmath^)?
- attr-dict `:` type($in) `to` type($out) }];
}
//===----------------------------------------------------------------------===//
@@ -1253,11 +1246,8 @@ def Arith_TruncFOp :
Arith_Op<"truncf",
[Pure, SameOperandsAndResultShape, SameInputOutputTensorDims,
DeclareOpInterfaceMethods<ArithRoundingModeInterface>,
- DeclareOpInterfaceMethods<ArithFastMathInterface>,
DeclareOpInterfaceMethods<CastOpInterface>]>,
Arguments<(ins FloatLike:$in,
- DefaultValuedAttr<
- Arith_FastMathAttr, "::mlir::arith::FastMathFlags::none">:$fastmath,
OptionalAttr<Arith_RoundingModeAttr>:$roundingmode)>,
Results<(outs FloatLike:$out)> {
let summary = "cast from floating-point to narrower floating-point";
@@ -1277,9 +1267,7 @@ def Arith_TruncFOp :
let hasFolder = 1;
let hasVerifier = 1;
- let assemblyFormat = [{ $in ($roundingmode^)?
- (`fastmath` `` $fastmath^)?
- attr-dict `:` type($in) `to` type($out) }];
+ let assemblyFormat = "$in ($roundingmode^)? attr-dict `:` type($in) `to` type($out)";
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index 291f6e5424ba5..2f6647a2a27b1 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -1390,20 +1390,6 @@ LogicalResult arith::ExtSIOp::verify() {
/// Fold extension of float constants when there is no information loss due the
/// difference in fp semantics.
OpFoldResult arith::ExtFOp::fold(FoldAdaptor adaptor) {
- if (auto truncFOp = getOperand().getDefiningOp<TruncFOp>()) {
- if (truncFOp.getOperand().getType() == getType()) {
- arith::FastMathFlags truncFMF = truncFOp.getFastmath();
- bool isTruncContract =
- bitEnumContainsAll(truncFMF, arith::FastMathFlags::contract);
- arith::FastMathFlags extFMF = getFastmath();
- bool isExtContract =
- bitEnumContainsAll(extFMF, arith::FastMathFlags::contract);
- if (isTruncContract && isExtContract) {
- return truncFOp.getOperand();
- }
- }
- }
-
auto resElemType = cast<FloatType>(getElementTypeOrSelf(getType()));
const llvm::fltSemantics &targetSemantics = resElemType.getFloatSemantics();
return constFoldCastOp<FloatAttr, FloatAttr>(
diff --git a/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp b/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp
index 8e1cb474feee7..4a50da3513f99 100644
--- a/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp
@@ -94,11 +94,8 @@ void EmulateFloatPattern::rewrite(Operation *op, ArrayRef<Value> operands,
SmallVector<Value> newResults(expandedOp->getResults());
for (auto [res, oldType, newType] : llvm::zip_equal(
MutableArrayRef{newResults}, op->getResultTypes(), resultTypes)) {
- if (oldType != newType) {
- auto truncFOp = rewriter.create<arith::TruncFOp>(loc, oldType, res);
- truncFOp.setFastmath(arith::FastMathFlags::contract);
- res = truncFOp.getResult();
- }
+ if (oldType != newType)
+ res = rewriter.create<arith::TruncFOp>(loc, oldType, res);
}
rewriter.replaceOp(op, newResults);
}
@@ -117,9 +114,7 @@ void mlir::arith::populateEmulateUnsupportedFloatsConversions(
});
converter.addTargetMaterialization(
[](OpBuilder &b, Type target, ValueRange input, Location loc) {
- auto extFOp = b.create<arith::ExtFOp>(loc, target, input);
- extFOp.setFastmath(arith::FastMathFlags::contract);
- return extFOp;
+ return b.create<arith::ExtFOp>(loc, target, input);
});
}
diff --git a/mlir/lib/Dialect/Math/Transforms/LegalizeToF32.cpp b/mlir/lib/Dialect/Math/Transforms/LegalizeToF32.cpp
index 3d99f3033cf56..5998133b7eab8 100644
--- a/mlir/lib/Dialect/Math/Transforms/LegalizeToF32.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/LegalizeToF32.cpp
@@ -57,9 +57,7 @@ void mlir::math::populateLegalizeToF32TypeConverter(
});
typeConverter.addTargetMaterialization(
[](OpBuilder &b, Type target, ValueRange input, Location loc) {
- auto extFOp = b.create<arith::ExtFOp>(loc, target, input);
- extFOp.setFastmath(arith::FastMathFlags::contract);
- return extFOp;
+ return b.create<arith::ExtFOp>(loc, target, input);
});
}
@@ -86,11 +84,8 @@ LogicalResult LegalizeToF32RewritePattern::matchAndRewrite(
SmallVector<Value> results = (*legalized)->getResults();
for (auto [result, newType, origType] : llvm::zip_equal(
results, (*legalized)->getResultTypes(), op->getResultTypes())) {
- if (newType != origType) {
- auto truncFOp = rewriter.create<arith::TruncFOp>(loc, origType, result);
- truncFOp.setFastmath(arith::FastMathFlags::contract);
- result = truncFOp.getResult();
- }
+ if (newType != origType)
+ result = rewriter.create<arith::TruncFOp>(loc, origType, result);
}
rewriter.replaceOp(op, results);
return success();
diff --git a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
index cacdd801871fb..56ae930e6d627 100644
--- a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
+++ b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
@@ -162,11 +162,11 @@ func.func @uitofp(%arg0 : i32, %arg1 : i64) {
// Checking conversion of integer types to floating point.
// CHECK-LABEL: @fpext
func.func @fpext(%arg0 : f16, %arg1 : f32) {
-// CHECK-NEXT: = llvm.fpext {{.*}} {fastmath = #arith.fastmath<none>} : f16 to f32
+// CHECK-NEXT: = llvm.fpext {{.*}} : f16 to f32
%0 = arith.extf %arg0: f16 to f32
-// CHECK-NEXT: = llvm.fpext {{.*}} {fastmath = #arith.fastmath<none>} : f16 to f64
+// CHECK-NEXT: = llvm.fpext {{.*}} : f16 to f64
%1 = arith.extf %arg0: f16 to f64
-// CHECK-NEXT: = llvm.fpext {{.*}} {fastmath = #arith.fastmath<none>} : f32 to f64
+// CHECK-NEXT: = llvm.fpext {{.*}} : f32 to f64
%2 = arith.extf %arg1: f32 to f64
return
}
@@ -174,11 +174,11 @@ func.func @fpext(%arg0 : f16, %arg1 : f32) {
// Checking conversion of integer types to floating point.
// CHECK-LABEL: @fpext
func.func @fpext_vector(%arg0 : vector<2xf16>, %arg1 : vector<2xf32>) {
-// CHECK-NEXT: = llvm.fpext {{.*}} {fastmath = #arith.fastmath<none>} : vector<2xf16> to vector<2xf32>
+// CHECK-NEXT: = llvm.fpext {{.*}} : vector<2xf16> to vector<2xf32>
%0 = arith.extf %arg0: vector<2xf16> to vector<2xf32>
-// CHECK-NEXT: = llvm.fpext {{.*}} {fastmath = #arith.fastmath<none>} : vector<2xf16> to vector<2xf64>
+// CHECK-NEXT: = llvm.fpext {{.*}} : vector<2xf16> to vector<2xf64>
%1 = arith.extf %arg0: vector<2xf16> to vector<2xf64>
-// CHECK-NEXT: = llvm.fpext {{.*}} {fastmath = #arith.fastmath<none>} : vector<2xf32> to vector<2xf64>
+// CHECK-NEXT: = llvm.fpext {{.*}} : vector<2xf32> to vector<2xf64>
%2 = arith.extf %arg1: vector<2xf32> to vector<2xf64>
return
}
@@ -268,11 +268,11 @@ func.func @uitofp_vector(%arg0 : vector<2xi16>, %arg1 : vector<2xi32>, %arg2 : v
// Checking conversion of integer types to floating point.
// CHECK-LABEL: @fptrunc
func.func @fptrunc(%arg0 : f32, %arg1 : f64) {
-// CHECK-NEXT: = llvm.fptrunc {{.*}} {fastmath = #arith.fastmath<none>} : f32 to f16
+// CHECK-NEXT: = llvm.fptrunc {{.*}} : f32 to f16
%0 = arith.truncf %arg0: f32 to f16
-// CHECK-NEXT: = llvm.fptrunc {{.*}} {fastmath = #arith.fastmath<none>} : f64 to f16
+// CHECK-NEXT: = llvm.fptrunc {{.*}} : f64 to f16
%1 = arith.truncf %arg1: f64 to f16
-// CHECK-NEXT: = llvm.fptrunc {{.*}} {fastmath = #arith.fastmath<none>} : f64 to f32
+// CHECK-NEXT: = llvm.fptrunc {{.*}} : f64 to f32
%2 = arith.truncf %arg1: f64 to f32
return
}
@@ -280,26 +280,26 @@ func.func @fptrunc(%arg0 : f32, %arg1 : f64) {
// Checking conversion of integer types to floating point.
// CHECK-LABEL: @fptrunc
func.func @fptrunc_vector(%arg0 : vector<2xf32>, %arg1 : vector<2xf64>) {
-// CHECK-NEXT: = llvm.fptrunc {{.*}} {fastmath = #arith.fastmath<none>} : vector<2xf32> to vector<2xf16>
+// CHECK-NEXT: = llvm.fptrunc {{.*}} : vector<2xf32> to vector<2xf16>
%0 = arith.truncf %arg0: vector<2xf32> to vector<2xf16>
-// CHECK-NEXT: = llvm.fptrunc {{.*}} {fastmath = #arith.fastmath<none>} : vector<2xf64> to vector<2xf16>
+// CHECK-NEXT: = llvm.fptrunc {{.*}} : vector<2xf64> to vector<2xf16>
%1 = arith.truncf %arg1: vector<2xf64> to vector<2xf16>
-// CHECK-NEXT: = llvm.fptrunc {{.*}} {fastmath = #arith.fastmath<none>} : vector<2xf64> to vector<2xf32>
+// CHECK-NEXT: = llvm.fptrunc {{.*}} : vector<2xf64> to vector<2xf32>
%2 = arith.truncf %arg1: vector<2xf64> to vector<2xf32>
return
}
// CHECK-LABEL: experimental_constrained_fptrunc
func.func @experimental_constrained_fptrunc(%arg0 : f64) {
-// CHECK-NEXT: = llvm.intr.experimental.constrained.fptrunc {{.*}} tonearest ignore {fastmath = #arith.fastmath<none>} : f64 to f32
+// CHECK-NEXT: = llvm.intr.experimental.constrained.fptrunc {{.*}} tonearest ignore : f64 to f32
%0 = arith.truncf %arg0 to_nearest_even : f64 to f32
-// CHECK-NEXT: = llvm.intr.experimental.constrained.fptrunc {{.*}} downward ignore {fastmath = #arith.fastmath<none>} : f64 to f32
+// CHECK-NEXT: = llvm.intr.experimental.constrained.fptrunc {{.*}} downward ignore : f64 to f32
%1 = arith.truncf %arg0 downward : f64 to f32
-// CHECK-NEXT: = llvm.intr.experimental.constrained.fptrunc {{.*}} upward ignore {fastmath = #arith.fastmath<none>} : f64 to f32
+// CHECK-NEXT: = llvm.intr.experimental.constrained.fptrunc {{.*}} upward ignore : f64 to f32
%2 = arith.truncf %arg0 upward : f64 to f32
-// CHECK-NEXT: = llvm.intr.experimental.constrained.fptrunc {{.*}} towardzero ignore {fastmath = #arith.fastmath<none>} : f64 to f32
+// CHECK-NEXT: = llvm.intr.experimental.constrained.fptrunc {{.*}} towardzero ignore : f64 to f32
%3 = arith.truncf %arg0 toward_zero : f64 to f32
-// CHECK-NEXT: = llvm.intr.experimental.constrained.fptrunc {{.*}} tonearestaway ignore {fastmath = #arith.fastmath<none>} : f64 to f32
+// CHECK-NEXT: = llvm.intr.experimental.constrained.fptrunc {{.*}} tonearestaway ignore : f64 to f32
%4 = arith.truncf %arg0 to_nearest_away : f64 to f32
return
}
diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir
index 4fe7cfb689be8..e4f95bb0545a2 100644
--- a/mlir/test/Dialect/Arith/canonicalize.mlir
+++ b/mlir/test/Dialect/Arith/canonicalize.mlir
@@ -3031,143 +3031,6 @@ func.func @mulsi_extended_i0() -> (i0, i0) {
return %mulsi_extended#0, %mulsi_extended#1 : i0, i0
}
-// CHECK-LABEL: @sequences_fastmath_contract
-// CHECK-SAME: ([[ARG0:%.+]]: bf16)
-// CHECK: [[EXTF:%.+]] = arith.extf [[ARG0]]
-// CHECK: [[ABSF:%.+]] = math.absf [[EXTF]]
-// CHECK: [[SIN:%.+]] = math.sin [[ABSF]]
-// CHECK: [[TRUNCF:%.+]] = arith.truncf [[SIN]]
-// CHECK: return [[TRUNCF]] : bf16
-func.func @sequences_fastmath_contract(%arg0: bf16) -> bf16 {
- %0 = arith.extf %arg0 fastmath<contract> : bf16 to f32
- %1 = math.absf %0 : f32
- %2 = arith.truncf %1 fastmath<contract> : f32 to bf16
- %3 = arith.extf %2 fastmath<contract> : bf16 to f32
- %4 = math.sin %3 : f32
- %5 = arith.truncf %4 fastmath<contract> : f32 to bf16
- return %5 : bf16
-}
-
-// CHECK-LABEL: @sequences_no_fastmath
-// CHECK-SAME: ([[ARG0:%.+]]: bf16)
-// CHECK: [[EXTF:%.+]] = arith.extf [[ARG0]]
-// CHECK: [[ABSF:%.+]] = math.absf [[EXTF]]
-// CHECK: [[TRUNCF1:%.+]] = arith.truncf [[ABSF]]
-// CHECK: [[EXTF1:%.+]] = arith.extf [[TRUNCF1]]
-// CHECK: [[SIN:%.+]] = math.sin [[EXTF1]]
-// CHECK: [[TRUNCF:%.+]] = arith.truncf [[SIN]]
-// CHECK: return [[TRUNCF]] : bf16
-func.func @sequences_no_fastmath(%arg0: bf16) -> bf16 {
- %0 = arith.extf %arg0 : bf16 to f32
- %1 = math.absf %0 : f32
- %2 = arith.truncf %1 : f32 to bf16
- %3 = arith.extf %2 : bf16 to f32
- %4 = math.sin %3 : f32
- %5 = arith.truncf %4 : f32 to bf16
- return %5 : bf16
-}
-
-// CHECK-LABEL: @eliminate_cast_to_f16
-// CHECK: return [[arg0:%.+]] : f32
-func.func @eliminate_cast_to_f16(%arg0: f32) -> f32 {
- %0 = arith.truncf %arg0 fastmath<contract> : f32 to f16
- %1 = arith.extf %0 fastmath<contract> : f16 to f32
- return %1 : f32
-}
-
-// CHECK-LABEL: @eliminate_cast_to_bf16
-// CHECK: return [[arg0:%.+]] : f32
-func.func @eliminate_cast_to_bf16(%arg0: f32) -> f32 {
- %0 = arith.truncf %arg0 fastmath<contract> : f32 to bf16
- %1 = arith.extf %0 fastmath<contract> : bf16 to f32
- return %1 : f32
-}
-
-// CHECK-LABEL: @bf16_sin_vector
-// CHECK-SAME: ([[ARG0:%.+]]: vector<32x32x32xbf16>)
-// CHECK: [[EXTF:%.+]] = arith.extf [[ARG0]]
-// CHECK: [[ABSF:%.+]] = math.absf [[EXTF]]
-// CHECK: [[SIN:%.+]] = math.sin [[ABSF]]
-// CHECK: [[TRUNCF:%.+]] = arith.truncf [[SIN]]
-// CHECK: return [[TRUNCF]] : vector<32x32x32xbf16>
-func.func @bf16_sin_vector(%arg0: vector<32x32x32xbf16>) -> vector<32x32x32xbf16> {
- %0 = arith.extf %arg0 fastmath<contract> : vector<32x32x32xbf16> to vector<32x32x32xf32>
- %1 = math.absf %0 : vector<32x32x32xf32>
- %2 = arith.truncf %1 fastmath<contract> : vector<32x32x32xf32> to vector<32x32x32xbf16>
- %3 = arith.extf %2 fastmath<contract> : vector<32x32x32xbf16> to vector<32x32x32xf32>
- %4 = math.sin %3 : vector<32x32x32xf32>
- %5 = arith.truncf %4 fastmath<contract> : vector<32x32x32xf32> to vector<32x32x32xbf16>
- return %5 : vector<32x32x32xbf16>
-}
-
-// CHECK-LABEL: @f16_sin_vector
-// CHECK-SAME: ([[ARG0:%.+]]: vector<32x32x32xf16>)
-// CHECK: [[EXTF:%.+]] = arith.extf [[ARG0]]
-// CHECK: [[ABSF:%.+]] = math.absf [[EXTF]]
-// CHECK: [[SIN:%.+]] = math.sin [[ABSF]]
-// CHECK: [[TRUNCF:%.+]] = arith.truncf [[SIN]]
-// CHECK: return [[TRUNCF]] : vector<32x32x32xf16>
-func.func @f16_sin_vector(%arg0: vector<32x32x32xf16>) -> vector<32x32x32xf16> {
- %0 = arith.extf %arg0 fastmath<contract> : vector<32x32x32xf16> to vector<32x32x32xf32>
- %1 = math.absf %0 : vector<32x32x32xf32>
- %2 = arith.truncf %1 fastmath<contract> : vector<32x32x32xf32> to vector<32x32x32xf16>
- %3 = arith.extf %2 fastmath<contract> : vector<32x32x32xf16> to vector<32x32x32xf32>
- %4 = math.sin %3 : vector<32x32x32xf32>
- %5 = arith.truncf %4 fastmath<contract> : vector<32x32x32xf32> to vector<32x32x32xf16>
- return %5 : vector<32x32x32xf16>
-}
-
-// CHECK-LABEL: @bf16_branch_vector
-// CHECK-SAME: ([[ARG0:%.+]]: vector<32x32x32xbf16>)
-// CHECK: [[EXTF:%.+]] = arith.extf [[ARG0]]
-// CHECK: [[ABSF:%.+]] = math.absf [[EXTF]]
-// CHECK-DAG: [[SIN:%.+]] = math.sin [[ABSF]]
-// CHECK-DAG: [[COS:%.+]] = math.cos [[ABSF]]
-// CHECK: [[ADDF:%.+]] = arith.addf [[SIN]], [[COS]]
-// CHECK: [[TRUNCF:%.+]] = arith.truncf [[ADDF]]
-// CHECK: return [[TRUNCF]] : vector<32x32x32xbf16>
-func.func @bf16_branch_vector(%arg0: vector<32x32x32xbf16>) -> vector<32x32x32xbf16> {
- %0 = arith.extf %arg0 fastmath<contract> : vector<32x32x32xbf16> to vector<32x32x32xf32>
- %1 = math.absf %0 : vector<32x32x32xf32>
- %2 = arith.truncf %1 fastmath<contract> : vector<32x32x32xf32> to vector<32x32x32xbf16>
- %3 = arith.extf %2 fastmath<contract> : vector<32x32x32xbf16> to vector<32x32x32xf32>
- %4 = math.sin %3 : vector<32x32x32xf32>
- %5 = arith.truncf %4 fastmath<contract> : vector<32x32x32xf32> to vector<32x32x32xbf16>
- %6 = arith.extf %5 fastmath<contract> : vector<32x32x32xbf16> to vector<32x32x32xf32>
- %7 = math.cos %3 : vector<32x32x32xf32>
- %8 = arith.truncf %7 fastmath<contract> : vector<32x32x32xf32> to vector<32x32x32xbf16>
- %9 = arith.extf %8 fastmath<contract> : vector<32x32x32xbf16> to vector<32x32x32xf32>
- %10 = arith.addf %6, %9 : vector<32x32x32xf32>
- %11 = arith.truncf %10 fastmath<contract> : vector<32x32x32xf32> to vector<32x32x32xbf16>
- return %11 : vector<32x32x32xbf16>
-}
-
-// CHECK-LABEL: @bf16_fma
-// CHECK-SAME: ([[ARG0:%.+]]: vector<32x32x32xbf16>, [[ARG1:%.+]]: vector<32x32x32xbf16>, [[ARG2:%.+]]: vector<32x32x32xbf16>)
-// CHECK: [[EXTF0:%.+]] = arith.extf [[ARG0]]
-// CHECK: [[ABSF:%.+]] = math.absf [[EXTF0]]
-// CHECK-DAG: [[SIN:%.+]] = math.sin [[ABSF]]
-// CHECK: [[TRUNCF0:%.+]] = arith.truncf [[SIN]]
-// CHECK-DAG: [[FMA:%.+]] = math.fma [[TRUNCF0]], [[ARG1]], [[ARG2]]
-// CHECK: [[EXTF1:%.+]] = arith.extf [[FMA]]
-// CHECK: [[ADDF:%.+]] = arith.addf [[EXTF1]], [[SIN]]
-// CHECK: [[TRUNCF1:%.+]] = arith.truncf [[ADDF]]
-// CHECK: return [[TRUNCF1]] : vector<32x32x32xbf16>
-func.func @bf16_fma(%arg0: vector<32x32x32xbf16>, %arg1: vector<32x32x32xbf16>, %arg2: vector<32x32x32xbf16>) -> vector<32x32x32xbf16> {
- %0 = arith.extf %arg0 fastmath<contract> : vector<32x32x32xbf16> to vector<32x32x32xf32>
- %1 = math.absf %0 : vector<32x32x32xf32>
- %2 = arith.truncf %1 fastmath<contract> : vector<32x32x32xf32> to vector<32x32x32xbf16>
- %3 = arith.extf %2 fastmath<contract> : vector<32x32x32xbf16> to vector<32x32x32xf32>
- %4 = math.sin %3 : vector<32x32x32xf32>
- %5 = arith.truncf %4 fastmath<contract> : vector<32x32x32xf32> to vector<32x32x32xbf16>
- %6 = arith.extf %5 fastmath<contract> : vector<32x32x32xbf16> to vector<32x32x32xf32>
- %7 = math.fma %5, %arg1, %arg2 : vector<32x32x32xbf16>
- %8 = arith.extf %7 fastmath<contract> : vector<32x32x32xbf16> to vector<32x32x32xf32>
- %9 = arith.addf %8, %6 : vector<32x32x32xf32>
- %10 = arith.truncf %9 fastmath<contract> : vector<32x32x32xf32> to vector<32x32x32xbf16>
- return %10 : vector<32x32x32xbf16>
-}
-
{-#
dialect_resources: {
builtin: {
diff --git a/mlir/test/Dialect/Arith/emulate-unsupported-floats.mlir b/mlir/test/Dialect/Arith/emulate-unsupported-floats.mlir
index 99790cc45d490..a69ef131d8d47 100644
--- a/mlir/test/Dialect/Arith/emulate-unsupported-floats.mlir
+++ b/mlir/test/Dialect/Arith/emulate-unsupported-floats.mlir
@@ -4,10 +4,10 @@ func.func @basic_expansion(%x: bf16) -> bf16 {
// CHECK-LABEL: @basic_expansion
// CHECK-SAME: [[X:%.+]]: bf16
// CHECK-DAG: [[C:%.+]] = arith.constant {{.*}} : bf16
-// CHECK-DAG: [[X_EXP:%.+]] = arith.extf [[X]] fastmath<contract> : bf16 to f32
-// CHECK-DAG: [[C_EXP:%.+]] = arith.extf [[C]] fastmath<contract> : bf16 to f32
+// CHECK-DAG: [[X_EXP:%.+]] = arith.extf [[X]] : bf16 to f32
+// CHECK-DAG: [[C_EXP:%.+]] = arith.extf [[C]] : bf16 to f32
// CHECK: [[Y_EXP:%.+]] = arith.addf [[X_EXP]], [[C_EXP]] : f32
-// CHECK: [[Y:%.+]] = arith.truncf [[Y_EXP]] fastmath<contract> : f32 to bf16
+// CHECK: [[Y:%.+]] = arith.truncf [[Y_EXP]] : f32 to bf16
// CHECK: return [[Y]]
%c = arith.constant 1.0 : bf16
%y = arith.addf %x, %c : bf16
@@ -19,15 +19,15 @@ func.func @basic_expansion(%x: bf16) -> bf16 {
func.func @chained(%x: bf16, %y: bf16, %z: bf16) -> i1 {
// CHECK-LABEL: @chained
// CHECK-SAME: [[X:%.+]]: bf16, [[Y:%.+]]: bf16, [[Z:%.+]]: bf16
-// CHECK-DAG: [[X_EXP:%.+]] = arith.extf [[X]] fastmath<contract> : bf16 to f32
-// CHECK-DAG: [[Y_EXP:%.+]] = arith.extf [[Y]] fastmath<contract> : bf16 to f32
-// CHECK-DAG: [[Z_EXP:%.+]] = arith.extf [[Z]] fastmath<contract> : bf16 to f32
+// CHECK-DAG: [[X_EXP:%.+]] = arith.extf [[X]] : bf16 to f32
+// CHECK-DAG: [[Y_EXP:%.+]] = arith.extf [[Y]] : bf16 to f32
+// CHECK-DAG: [[Z_EXP:%.+]] = arith.extf [[Z]] : bf16 to f32
// CHECK: [[P_EXP:%.+]] = arith.addf [[X_EXP]], [[Y_EXP]] : f32
-// CHECK: [[P:%.+]] = arith.truncf [[P_EXP]] fastmath<contract> : f32 to bf16
-// CHECK: [[P_EXP2:%.+]] = arith.extf [[P]] fastmath<contract> : bf16 to f32
+// CHECK: [[P:%.+]] = arith.truncf [[P_EXP]] : f32 to bf16
+// CHECK: [[P_EXP2:%.+]] = arith.extf [[P]] : bf16 to f32
// CHECK: [[Q_EXP:%.+]] = arith.mulf [[P_EXP2]], [[Z_EXP]]
-// CHECK: [[Q:%.+]] = arith.truncf [[Q_EXP]] fastmath<contract> : f32 to bf16
-// CHECK: [[Q_EXP2:%.+]] = arith.extf [[Q]] fastmath<contract> : bf16 to f32
+// CHECK: [[Q:%.+]] = arith.truncf [[Q_EXP]] : f32 to bf16
+// CHECK: [[Q_EXP2:%.+]] = arith.extf [[Q]] : bf16 to f32
// CHECK: [[RES:%.+]] = arith.cmpf ole, [[P_EXP2]], [[Q_EXP2]] : f32
// CHECK: return [[RES]]
%p = arith.addf %x, %y : bf16
@@ -41,12 +41,12 @@ func.func @chained(%x: bf16, %y: bf16, %z: bf16) -> i1 {
func.func @memops(%a: memref<4xf8E4M3FNUZ>, %b: memref<4xf8E4M3FNUZ>) {
// CHECK-LABEL: @memops
// CHECK: [[V:%.+]] = memref.load {{.*}} : memref<4xf8E4M3FNUZ>
-// CHECK: [[V_EXP:%.+]] = arith.extf [[V]] fastmath<contract> : f8E4M3FNUZ to f32
+// CHECK: [[V_EXP:%.+]] = arith.extf [[V]] : f8E4M3FNUZ to f32
// CHECK: memref.store [[V]]
// CHECK: [[W:%.+]] = memref.load
-// CHECK: [[W_EXP:%.+]] = arith.extf [[W]] fastmath<contract> : f8E4M3FNUZ to f32
+// CHECK: [[W_EXP:%.+]] = arith.extf [[W]] : f8E4M3FNUZ to f32
// CHECK: [[X_EXP:%.+]] = arith.addf [[V_EXP]], [[W_EXP]] : f32
-// CHECK: [[X:%.+]] = arith.truncf [[X_EXP]] fastmath<contract> : f32 to f8E4M3FNUZ
+// CHECK: [[X:%.+]] = arith.truncf [[X_EXP]] : f32 to f8E4M3FNUZ
// CHECK: memref.store [[X]]
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
@@ -63,9 +63,9 @@ func.func @memops(%a: memref<4xf8E4M3FNUZ>, %b: memref<4xf8E4M3FNUZ>) {
func.func @vectors(%a: vector<4xf8E4M3FNUZ>) -> vector<4xf32> {
// CHECK-LABEL: @vectors
// CHECK-SAME: [[A:%.+]]: vector<4xf8E4M3FNUZ>
-// CHECK: [[A_EXP:%.+]] = arith.extf [[A]] fastmath<contract> : vector<4xf8E4M3FNUZ> to vector<4xf32>
+// CHECK: [[A_EXP:%.+]] = arith.extf [[A]] : vector<4xf8E4M3FNUZ> to vector<4xf32>
// CHECK: [[B_EXP:%.+]] = arith.mulf [[A_EXP]], [[A_EXP]] : vector<4xf32>
-// CHECK: [[B:%.+]] = arith.truncf [[B_EXP]] fastmath<contract> : vector<4xf32> to vector<4xf8E4M3FNUZ>
+// CHECK: [[B:%.+]] = arith.truncf [[B_EXP]] : vector<4xf32> to vector<4xf8E4M3FNUZ>
// CHECK: [[RET:%.+]] = arith.extf [[B]] : vector<4xf8E4M3FNUZ> to vector<4xf32>
// CHECK: return [[RET]]
%b = arith.mulf %a, %a : vector<4xf8E4M3FNUZ>
More information about the llvm-branch-commits mailing list