[Mlir-commits] [mlir] 6e41a06 - [mlir][tosa] Revert add-0 canonicalization for floating-point
Rob Suderman
llvmlistbot at llvm.org
Wed Nov 17 17:31:07 PST 2021
Author: Robert Suderman
Date: 2021-11-17T17:29:57-08:00
New Revision: 6e41a0691132482478c4e82d0610fe19a4128f74
URL: https://github.com/llvm/llvm-project/commit/6e41a0691132482478c4e82d0610fe19a4128f74
DIFF: https://github.com/llvm/llvm-project/commit/6e41a0691132482478c4e82d0610fe19a4128f74.diff
LOG: [mlir][tosa] Revert add-0 canonicalization for floating-point
Floating-point add-0 canonicalization can produce incorrect numerical results:
for x = -0.0, x + 0.0 evaluates to +0.0 under IEEE 754, so folding the add to x would wrongly yield -0.0.
Reviewed By: eric-k256
Differential Revision: https://reviews.llvm.org/D114127
Added:
Modified:
mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
mlir/test/Dialect/Tosa/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index aaba64afb71d4..b6d8ecb538a1a 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -301,12 +301,6 @@ struct AddZeroOptimization : public OpRewritePattern<tosa::AddOp> {
DenseElementsAttr input1Attr;
if (matchPattern(input1, m_Constant(&input1Attr)) && input1Attr.isSplat() &&
input2.getType() == op.getType()) {
- if (input1Attr.getType().getElementType().isa<FloatType>() &&
- input1Attr.getSplatValue<APFloat>().isZero()) {
- rewriter.replaceOp(op, op.input2());
- return success();
- }
-
if (input1Attr.getType().getElementType().isa<IntegerType>() &&
input1Attr.getSplatValue<APInt>().isZero()) {
rewriter.replaceOp(op, op.input2());
@@ -317,12 +311,6 @@ struct AddZeroOptimization : public OpRewritePattern<tosa::AddOp> {
DenseElementsAttr input2Attr;
if (matchPattern(input2, m_Constant(&input2Attr)) && input2Attr.isSplat() &&
input1.getType() == op.getType()) {
- if (input2Attr.getType().getElementType().isa<FloatType>() &&
- input2Attr.getSplatValue<APFloat>().isZero()) {
- rewriter.replaceOp(op, op.input1());
- return success();
- }
-
if (input2Attr.getType().getElementType().isa<IntegerType>() &&
input2Attr.getSplatValue<APInt>().isZero()) {
rewriter.replaceOp(op, op.input1());
diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
index 11583eef7c867..65e59b201a248 100644
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -10,23 +10,13 @@ func @argmax_nofold(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
// -----
// CHECK-LABEL: @add_zero_different_shape
-func @add_zero_different_shape(%arg0: tensor<2x3xf32>) -> tensor<4x2x3xf32> {
+func @add_zero_different_shape(%arg0: tensor<2x3xi32>) -> tensor<4x2x3xi32> {
// CHECK: tosa.add
- %zeros = "tosa.const"() {value = dense<0.0> : tensor<4x2x3xf32>} : () -> tensor<4x2x3xf32>
- %1 = "tosa.add"(%arg0, %zeros) : (tensor<2x3xf32>, tensor<4x2x3xf32>) -> tensor<4x2x3xf32>
- return %1 : tensor<4x2x3xf32>
+ %zeros = "tosa.const"() {value = dense<0> : tensor<4x2x3xi32>} : () -> tensor<4x2x3xi32>
+ %1 = "tosa.add"(%arg0, %zeros) : (tensor<2x3xi32>, tensor<4x2x3xi32>) -> tensor<4x2x3xi32>
+ return %1 : tensor<4x2x3xi32>
}
-// -----
-
-// CHECK-LABEL: @add_zero_float
-func @add_zero_float(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> {
- // CHECK: return %arg0
- // CHECK-NOT: tosa.add
- %zeros = "tosa.const"() {value = dense<0.0> : tensor<2x3xf32>} : () -> tensor<2x3xf32>
- %1 = "tosa.add"(%arg0, %zeros) : (tensor<2x3xf32>, tensor<2x3xf32>) -> tensor<2x3xf32>
- return %1 : tensor<2x3xf32>
-}
// -----
More information about the Mlir-commits
mailing list