[Mlir-commits] [mlir] f278cf9 - [MLIR][arith] More float op folders

Christian Sigg llvmlistbot at llvm.org
Mon Jan 31 10:31:55 PST 2022


Author: Christian Sigg
Date: 2022-01-31T19:31:48+01:00
New Revision: f278cf9cbc3edbd4441da87e23e703d9a331c994

URL: https://github.com/llvm/llvm-project/commit/f278cf9cbc3edbd4441da87e23e703d9a331c994
DIFF: https://github.com/llvm/llvm-project/commit/f278cf9cbc3edbd4441da87e23e703d9a331c994.diff

LOG: [MLIR][arith] More float op folders

Fold `arith.addf %x, -0.0 -> %x`, and similarly `arith.subf %x, +0.0 -> %x`, `arith.mulf %x, 1.0 -> %x` and `arith.divf %x, 1.0 -> %x`.

Fold `arith.minf %x, %x -> %x`, `arith.minf %x, +inf -> %x` and similarly for `arith.maxf` (with `-inf`).

Reviewed By: pifon2a, mehdi_amini, bondhugula

Differential Revision: https://reviews.llvm.org/D118244

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td
    mlir/include/mlir/IR/Matchers.h
    mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
    mlir/test/Dialect/Arithmetic/canonicalize.mlir
    mlir/test/Dialect/Linalg/vectorization.mlir
    mlir/test/Dialect/SCF/loop-pipelining.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td b/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td
index 5f52e440c4da6..b7eb2730136ef 100644
--- a/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td
+++ b/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td
@@ -653,6 +653,7 @@ def Arith_MaxFOp : Arith_FloatBinaryOp<"maxf", [Commutative]> {
     %a = arith.maxf %b, %c : f64
     ```
   }];
+  let hasFolder = 1;
 }
 
 //===----------------------------------------------------------------------===//
@@ -696,6 +697,7 @@ def Arith_MinFOp : Arith_FloatBinaryOp<"minf", [Commutative]> {
     %a = arith.minf %b, %c : f64
     ```
   }];
+  let hasFolder = 1;
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/include/mlir/IR/Matchers.h b/mlir/include/mlir/IR/Matchers.h
index 78c65f8cce301..cc879cfc80dc4 100644
--- a/mlir/include/mlir/IR/Matchers.h
+++ b/mlir/include/mlir/IR/Matchers.h
@@ -91,6 +91,43 @@ struct constant_op_binder {
   }
 };
 
+/// The matcher that matches a constant scalar / vector splat / tensor splat
+/// float operation and binds the constant float value.
+struct constant_float_op_binder {
+  FloatAttr::ValueType *bind_value;
+
+  /// Creates a matcher instance that binds the value to bv if match succeeds.
+  constant_float_op_binder(FloatAttr::ValueType *bv) : bind_value(bv) {}
+
+  bool match(Operation *op) {
+    Attribute attr;
+    if (!constant_op_binder<Attribute>(&attr).match(op))
+      return false;
+    auto type = op->getResult(0).getType();
+
+    if (type.isa<FloatType>())
+      return attr_value_binder<FloatAttr>(bind_value).match(attr);
+    if (type.isa<VectorType, RankedTensorType>()) {
+      if (auto splatAttr = attr.dyn_cast<SplatElementsAttr>()) {
+        return attr_value_binder<FloatAttr>(bind_value)
+            .match(splatAttr.getSplatValue<Attribute>());
+      }
+    }
+    return false;
+  }
+};
+
+/// The matcher that matches a given target constant scalar / vector splat /
+/// tensor splat float value that fulfills a predicate.
+struct constant_float_predicate_matcher {
+  bool (*predicate)(const APFloat &);
+
+  bool match(Operation *op) {
+    APFloat value(APFloat::Bogus());
+    return constant_float_op_binder(&value).match(op) && predicate(value);
+  }
+};
+
 /// The matcher that matches a constant scalar / vector splat / tensor splat
 /// integer operation and binds the constant integer value.
 struct constant_int_op_binder {
@@ -118,22 +155,13 @@ struct constant_int_op_binder {
 };
 
 /// The matcher that matches a given target constant scalar / vector splat /
-/// tensor splat integer value.
-template <int64_t TargetValue>
-struct constant_int_value_matcher {
-  bool match(Operation *op) {
-    APInt value;
-    return constant_int_op_binder(&value).match(op) && TargetValue == value;
-  }
-};
+/// tensor splat integer value that fulfills a predicate.
+struct constant_int_predicate_matcher {
+  bool (*predicate)(const APInt &);
 
-/// The matcher that matches anything except the given target constant scalar /
-/// vector splat / tensor splat integer value.
-template <int64_t TargetNotValue>
-struct constant_int_not_value_matcher {
   bool match(Operation *op) {
     APInt value;
-    return constant_int_op_binder(&value).match(op) && TargetNotValue != value;
+    return constant_int_op_binder(&value).match(op) && predicate(value);
   }
 };
 
@@ -239,26 +267,65 @@ inline detail::constant_op_binder<AttrT> m_Constant(AttrT *bind_value) {
   return detail::constant_op_binder<AttrT>(bind_value);
 }
 
-/// Matches a constant scalar / vector splat / tensor splat integer one.
-inline detail::constant_int_value_matcher<1> m_One() {
-  return detail::constant_int_value_matcher<1>();
+/// Matches a constant scalar / vector splat / tensor splat float (both positive
+/// and negative) zero.
+inline detail::constant_float_predicate_matcher m_AnyZeroFloat() {
+  return {[](const APFloat &value) { return value.isZero(); }};
 }
 
-/// Matches the given OpClass.
-template <typename OpClass>
-inline detail::op_matcher<OpClass> m_Op() {
-  return detail::op_matcher<OpClass>();
+/// Matches a constant scalar / vector splat / tensor splat float positive zero.
+inline detail::constant_float_predicate_matcher m_PosZeroFloat() {
+  return {[](const APFloat &value) { return value.isPosZero(); }};
+}
+
+/// Matches a constant scalar / vector splat / tensor splat float negative zero.
+inline detail::constant_float_predicate_matcher m_NegZeroFloat() {
+  return {[](const APFloat &value) { return value.isNegZero(); }};
+}
+
+/// Matches a constant scalar / vector splat / tensor splat float one (1.0).
+inline detail::constant_float_predicate_matcher m_OneFloat() {
+  return {[](const APFloat &value) {
+    return APFloat(value.getSemantics(), 1) == value;
+  }};
+}
+
+/// Matches a constant scalar / vector splat / tensor splat float positive
+/// infinity.
+inline detail::constant_float_predicate_matcher m_PosInfFloat() {
+  return {[](const APFloat &value) {
+    return !value.isNegative() && value.isInfinity();
+  }};
+}
+
+/// Matches a constant scalar / vector splat / tensor splat float negative
+/// infinity.
+inline detail::constant_float_predicate_matcher m_NegInfFloat() {
+  return {[](const APFloat &value) {
+    return value.isNegative() && value.isInfinity();
+  }};
 }
 
 /// Matches a constant scalar / vector splat / tensor splat integer zero.
-inline detail::constant_int_value_matcher<0> m_Zero() {
-  return detail::constant_int_value_matcher<0>();
+inline detail::constant_int_predicate_matcher m_Zero() {
+  return {[](const APInt &value) { return 0 == value; }};
 }
 
 /// Matches a constant scalar / vector splat / tensor splat integer that is any
 /// non-zero value.
-inline detail::constant_int_not_value_matcher<0> m_NonZero() {
-  return detail::constant_int_not_value_matcher<0>();
+inline detail::constant_int_predicate_matcher m_NonZero() {
+  return {[](const APInt &value) { return 0 != value; }};
+}
+
+/// Matches a constant scalar / vector splat / tensor splat integer one.
+inline detail::constant_int_predicate_matcher m_One() {
+  return {[](const APInt &value) { return 1 == value; }};
+}
+
+/// Matches the given OpClass.
+template <typename OpClass>
+inline detail::op_matcher<OpClass> m_Op() {
+  return detail::op_matcher<OpClass>();
 }
 
 /// Entry point for matching a pattern over a Value.
@@ -276,6 +343,13 @@ inline bool matchPattern(Operation *op, const Pattern &pattern) {
   return const_cast<Pattern &>(pattern).match(op);
 }
 
+/// Matches a constant holding a scalar/vector/tensor float (splat) and
+/// writes the float value to bind_value.
+inline detail::constant_float_op_binder
+m_ConstantFloat(FloatAttr::ValueType *bind_value) {
+  return detail::constant_float_op_binder(bind_value);
+}
+
 /// Matches a constant holding a scalar/vector/tensor integer (splat) and
 /// writes the integer value to bind_value.
 inline detail::constant_int_op_binder

diff  --git a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
index d2473ebbff407..022a5674ef9ee 100644
--- a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
+++ b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
@@ -194,12 +194,12 @@ OpFoldResult arith::AddIOp::fold(ArrayRef<Attribute> operands) {
   if (matchPattern(getRhs(), m_Zero()))
     return getLhs();
 
-  // add(sub(a, b), b) -> a
+  // addi(subi(a, b), b) -> a
   if (auto sub = getLhs().getDefiningOp<SubIOp>())
     if (getRhs() == sub.getRhs())
       return sub.getLhs();
 
-  // add(b, sub(a, b)) -> a
+  // addi(b, subi(a, b)) -> a
   if (auto sub = getRhs().getDefiningOp<SubIOp>())
     if (getLhs() == sub.getRhs())
       return sub.getLhs();
@@ -576,6 +576,14 @@ void arith::XOrIOp::getCanonicalizationPatterns(
 //===----------------------------------------------------------------------===//
 
 OpFoldResult arith::AddFOp::fold(ArrayRef<Attribute> operands) {
+  // addf(x, -0) -> x. Not +0: for x = -0, -0 + +0 = +0 != x.
+  if (matchPattern(getRhs(), m_NegZeroFloat()))
+    return getLhs();
+
+  // addf(-0, x) -> x
+  if (matchPattern(getLhs(), m_NegZeroFloat()))
+    return getRhs();
+
   return constFoldBinaryOp<FloatAttr>(
       operands, [](const APFloat &a, const APFloat &b) { return a + b; });
 }
@@ -585,10 +593,34 @@ OpFoldResult arith::AddFOp::fold(ArrayRef<Attribute> operands) {
//===----------------------------------------------------------------------===//

 OpFoldResult arith::SubFOp::fold(ArrayRef<Attribute> operands) {
+  // subf(x, +0) -> x. Not -0: for x = -0, -0 - -0 = +0 != x.
+  if (matchPattern(getRhs(), m_PosZeroFloat()))
+    return getLhs();
+
   return constFoldBinaryOp<FloatAttr>(
       operands, [](const APFloat &a, const APFloat &b) { return a - b; });
 }
 
+//===----------------------------------------------------------------------===//
+// MaxFOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult arith::MaxFOp::fold(ArrayRef<Attribute> operands) {
+  assert(operands.size() == 2 && "maxf takes two operands");
+
+  // maxf(x,x) -> x
+  if (getLhs() == getRhs())
+    return getRhs();
+
+  // maxf(x, -inf) -> x; holds for x = NaN too, as maxf propagates NaN.
+  if (matchPattern(getRhs(), m_NegInfFloat()))
+    return getLhs();
+
+  return constFoldBinaryOp<FloatAttr>(
+      operands,
+      [](const APFloat &a, const APFloat &b) { return llvm::maximum(a, b); });
+}
+
 //===----------------------------------------------------------------------===//
 // MaxSIOp
 //===----------------------------------------------------------------------===//
@@ -643,6 +675,26 @@ OpFoldResult MaxUIOp::fold(ArrayRef<Attribute> operands) {
                                         });
 }
 
+//===----------------------------------------------------------------------===//
+// MinFOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult arith::MinFOp::fold(ArrayRef<Attribute> operands) {
+  assert(operands.size() == 2 && "minf takes two operands");
+
+  // minf(x,x) -> x
+  if (getLhs() == getRhs())
+    return getRhs();
+
+  // minf(x, +inf) -> x; holds for x = NaN too, as minf propagates NaN.
+  if (matchPattern(getRhs(), m_PosInfFloat()))
+    return getLhs();
+
+  return constFoldBinaryOp<FloatAttr>(
+      operands,
+      [](const APFloat &a, const APFloat &b) { return llvm::minimum(a, b); });
+}
+
 //===----------------------------------------------------------------------===//
 // MinSIOp
 //===----------------------------------------------------------------------===//
@@ -702,6 +754,14 @@ OpFoldResult MinUIOp::fold(ArrayRef<Attribute> operands) {
//===----------------------------------------------------------------------===//

 OpFoldResult arith::MulFOp::fold(ArrayRef<Attribute> operands) {
+  // mulf(x, 1) -> x
+  if (matchPattern(getRhs(), m_OneFloat()))
+    return getLhs();
+
+  // mulf(1, x) -> x
+  if (matchPattern(getLhs(), m_OneFloat()))
+    return getRhs();
+
   return constFoldBinaryOp<FloatAttr>(
       operands, [](const APFloat &a, const APFloat &b) { return a * b; });
 }
@@ -711,6 +771,10 @@ OpFoldResult arith::MulFOp::fold(ArrayRef<Attribute> operands) {
//===----------------------------------------------------------------------===//

 OpFoldResult arith::DivFOp::fold(ArrayRef<Attribute> operands) {
+  // divf(x, 1) -> x
+  if (matchPattern(getRhs(), m_OneFloat()))
+    return getLhs();
+
   return constFoldBinaryOp<FloatAttr>(
       operands, [](const APFloat &a, const APFloat &b) { return a / b; });
 }

diff  --git a/mlir/test/Dialect/Arithmetic/canonicalize.mlir b/mlir/test/Dialect/Arithmetic/canonicalize.mlir
index 22a88d165d300..cf87af2e30e3c 100644
--- a/mlir/test/Dialect/Arithmetic/canonicalize.mlir
+++ b/mlir/test/Dialect/Arithmetic/canonicalize.mlir
@@ -619,6 +619,113 @@ func @test_minui(%arg0 : i8) -> (i8, i8, i8, i8) {
 
 // -----
 
+// CHECK-LABEL: @test_minf(
+func @test_minf(%arg0 : f32) -> (f32, f32, f32) {
+  // CHECK-DAG:   %[[C0:.+]] = arith.constant 0.0
+  // CHECK-NEXT:  %[[X:.+]] = arith.minf %arg0, %[[C0]]
+  // CHECK-NEXT:  return %[[X]], %arg0, %arg0
+  %c0 = arith.constant 0.0 : f32
+  %inf = arith.constant 0x7F800000 : f32
+  %0 = arith.minf %c0, %arg0 : f32
+  %1 = arith.minf %arg0, %arg0 : f32
+  %2 = arith.minf %inf, %arg0 : f32
+  return %0, %1, %2 : f32, f32, f32
+}
+
+// -----
+
+// CHECK-LABEL: @test_maxf(
+func @test_maxf(%arg0 : f32) -> (f32, f32, f32) {
+  // CHECK-DAG:   %[[C0:.+]] = arith.constant
+  // CHECK-NEXT:  %[[X:.+]] = arith.maxf %arg0, %[[C0]]
+  // CHECK-NEXT:   return %[[X]], %arg0, %arg0
+  %c0 = arith.constant 0.0 : f32
+  %-inf = arith.constant 0xFF800000 : f32
+  %0 = arith.maxf %c0, %arg0 : f32
+  %1 = arith.maxf %arg0, %arg0 : f32
+  %2 = arith.maxf %-inf, %arg0 : f32
+  return %0, %1, %2 : f32, f32, f32
+}
+
+// -----
+
+// CHECK-LABEL: @test_addf(
+func @test_addf(%arg0 : f32) -> (f32, f32, f32, f32) {
+  // CHECK-DAG:   %[[C2:.+]] = arith.constant 2.0
+  // CHECK-DAG:   %[[C0:.+]] = arith.constant 0.0
+  // CHECK-NEXT:  %[[X:.+]] = arith.addf %arg0, %[[C0]]
+  // CHECK-NEXT:   return %[[X]], %arg0, %arg0, %[[C2]]
+  %c0 = arith.constant 0.0 : f32
+  %c-0 = arith.constant -0.0 : f32
+  %c1 = arith.constant 1.0 : f32
+  %0 = arith.addf %arg0, %c0 : f32
+  %1 = arith.addf %arg0, %c-0 : f32
+  %2 = arith.addf %c-0, %arg0 : f32
+  %3 = arith.addf %c1, %c1 : f32
+  return %0, %1, %2, %3 : f32, f32, f32, f32
+}
+
+// -----
+
+// CHECK-LABEL: @test_subf(
+func @test_subf(%arg0 : f16) -> (f16, f16, f16) {
+  // CHECK-DAG:   %[[C1:.+]] = arith.constant -1.0
+  // CHECK-DAG:   %[[C0:.+]] = arith.constant -0.0
+  // CHECK-NEXT:  %[[X:.+]] = arith.subf %arg0, %[[C0]]
+  // CHECK-NEXT:   return %arg0, %[[X]], %[[C1]]
+  %c0 = arith.constant 0.0 : f16
+  %c-0 = arith.constant -0.0 : f16
+  %c1 = arith.constant 1.0 : f16
+  %0 = arith.subf %arg0, %c0 : f16
+  %1 = arith.subf %arg0, %c-0 : f16
+  %2 = arith.subf %c0, %c1 : f16
+  return %0, %1, %2 : f16, f16, f16
+}
+
+// -----
+
+// CHECK-LABEL: @test_mulf(
+func @test_mulf(%arg0 : f32) -> (f32, f32, f32) {
+  // CHECK-NEXT:  %[[C4:.+]] = arith.constant 4.0
+  // CHECK-NEXT:   return %arg0, %arg0, %[[C4]]
+  %c1 = arith.constant 1.0 : f32
+  %c2 = arith.constant 2.0 : f32
+  %0 = arith.mulf %arg0, %c1 : f32
+  %1 = arith.mulf %c1, %arg0 : f32
+  %2 = arith.mulf %c2, %c2 : f32
+  return %0, %1, %2 : f32, f32, f32
+}
+
+// -----
+
+// CHECK-LABEL: @test_divf(
+func @test_divf(%arg0 : f64) -> (f64, f64) {
+  // CHECK-NEXT:  %[[C5:.+]] = arith.constant 5.000000e-01
+  // CHECK-NEXT:   return %arg0, %[[C5]]
+  %c1 = arith.constant 1.0 : f64
+  %c2 = arith.constant 2.0 : f64
+  %0 = arith.divf %arg0, %c1 : f64
+  %1 = arith.divf %c1, %c2 : f64
+  return %0, %1 : f64, f64
+}
+
+// -----
+
+// CHECK-LABEL: @test_cmpf(
+func @test_cmpf(%arg0 : f32) -> (i1, i1, i1, i1) {
+//   CHECK-DAG:   %[[T:.*]] = arith.constant true
+//   CHECK-DAG:   %[[F:.*]] = arith.constant false
+//       CHECK:   return %[[F]], %[[F]], %[[T]], %[[T]]
+  %nan = arith.constant 0x7fffffff : f32
+  %0 = arith.cmpf olt, %nan, %arg0 : f32
+  %1 = arith.cmpf olt, %arg0, %nan : f32
+  %2 = arith.cmpf ugt, %nan, %arg0 : f32
+  %3 = arith.cmpf ugt, %arg0, %nan : f32
+  return %0, %1, %2, %3 : i1, i1, i1, i1
+}
+
+// -----
+
 // CHECK-LABEL: @constant_FPtoUI(
 func @constant_FPtoUI() -> i32 {
   // CHECK: %[[C0:.+]] = arith.constant 2 : i32
@@ -678,30 +785,3 @@ func @constant_UItoFP() -> f32 {
   %res = arith.sitofp %c0 : i32 to f32
   return %res : f32
 }
-
-// -----
-// CHECK-LABEL: @constant_MinMax(
-func @constant_MinMax(%arg0 : f32) -> f32 {
-  // CHECK:  %[[const:.+]] = arith.constant
-  // CHECK:  %[[min:.+]] = arith.minf %arg0, %[[const]] : f32
-  // CHECK:  %[[res:.+]] = arith.maxf %[[min]], %[[const]] : f32
-  // CHECK:   return %[[res]]
-  %const = arith.constant 0.0 : f32
-  %min = arith.minf %const, %arg0 : f32
-  %res = arith.maxf %const, %min : f32
-  return %res : f32
-}
-
-// -----
-// CHECK-LABEL: @cmpf_nan(
-func @cmpf_nan(%arg0 : f32) -> (i1, i1, i1, i1) {
-//   CHECK-DAG:   %[[T:.*]] = arith.constant true
-//   CHECK-DAG:   %[[F:.*]] = arith.constant false
-//       CHECK:   return %[[F]], %[[F]], %[[T]], %[[T]]
-  %nan = arith.constant 0x7fffffff : f32
-  %0 = arith.cmpf olt, %nan, %arg0 : f32
-  %1 = arith.cmpf olt, %arg0, %nan : f32
-  %2 = arith.cmpf ugt, %nan, %arg0 : f32
-  %3 = arith.cmpf ugt, %arg0, %nan : f32
-  return %0, %1, %2, %3 : i1, i1, i1, i1
-}

diff  --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir
index c9f50af28ef27..ceafa28576fa0 100644
--- a/mlir/test/Dialect/Linalg/vectorization.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization.mlir
@@ -878,7 +878,6 @@ func @red_mul_2d(%arg0: tensor<4x4xf32>) -> tensor<4xf32> {
   // CHECK: vector.transfer_write {{.*}} : vector<4xf32>, tensor<4xf32>
   // CHECK: vector.transfer_read {{.*}} : tensor<4x4xf32>, vector<4x4xf32>
   // CHECK: vector.multi_reduction <mul>, {{.*}} [1] : vector<4x4xf32> to vector<4xf32>
-  // CHECK: mulf {{.*}} : vector<4xf32>
   // CHECK: vector.transfer_write {{.*}} : vector<4xf32>, tensor<4xf32>
   %ident = arith.constant 1.0 : f32
   %init = linalg.init_tensor [4] : tensor<4xf32>

diff  --git a/mlir/test/Dialect/SCF/loop-pipelining.mlir b/mlir/test/Dialect/SCF/loop-pipelining.mlir
index 83f0b78e1a1aa..9424af25bd12b 100644
--- a/mlir/test/Dialect/SCF/loop-pipelining.mlir
+++ b/mlir/test/Dialect/SCF/loop-pipelining.mlir
@@ -224,17 +224,14 @@ func @loop_carried(%A: memref<?xf32>, %result: memref<?xf32>) {
 //  CHECK-NEXT:   %[[R:.*]]:3 = scf.for %[[IV:.*]] = %[[C0]] to %[[C2]]
 //  CHECK-SAME:     step %[[C1]] iter_args(%[[C:.*]] = %[[CSTF]],
 //  CHECK-SAME:     %[[ADDARG:.*]] = %[[ADD0]], %[[LARG:.*]] = %[[L1]]) -> (f32, f32, f32) {
-//  CHECK-NEXT:     %[[MUL0:.*]] = arith.mulf %[[CSTF]], %[[ADDARG]] : f32
-//  CHECK-NEXT:     %[[ADD1:.*]] = arith.addf %[[LARG]], %[[MUL0]] : f32
+//  CHECK-NEXT:     %[[ADD1:.*]] = arith.addf %[[LARG]], %[[ADDARG]] : f32
 //  CHECK-NEXT:     %[[IV2:.*]] = arith.addi %[[IV]], %[[C2]] : index
 //  CHECK-NEXT:     %[[L2:.*]] = memref.load %[[A]][%[[IV2]]] : memref<?xf32>
-//  CHECK-NEXT:     scf.yield %[[MUL0]], %[[ADD1]], %[[L2]] : f32, f32, f32
+//  CHECK-NEXT:     scf.yield %[[ADDARG]], %[[ADD1]], %[[L2]] : f32, f32, f32
 //  CHECK-NEXT:   }
 // Epilogue:
-//  CHECK-NEXT:   %[[MUL1:.*]] = arith.mulf %[[CSTF]], %[[R]]#1 : f32
-//  CHECK-NEXT:   %[[ADD2:.*]] = arith.addf %[[R]]#2, %[[MUL1]] : f32
-//  CHECK-NEXT:   %[[MUL2:.*]] = arith.mulf %[[CSTF]], %[[ADD2]] : f32
-//  CHECK-NEXT:   return %[[MUL2]] : f32
+//  CHECK-NEXT:   %[[ADD2:.*]] = arith.addf %[[R]]#2, %[[R]]#1 : f32
+//  CHECK-NEXT:   return %[[ADD2]] : f32
 func @backedge_different_stage(%A: memref<?xf32>) -> f32 {
   %c0 = arith.constant 0 : index
   %c1 = arith.constant 1 : index
@@ -264,15 +261,13 @@ func @backedge_different_stage(%A: memref<?xf32>) -> f32 {
 //  CHECK-SAME:     step %[[C1]] iter_args(%[[C:.*]] = %[[CSTF]],
 //  CHECK-SAME:     %[[LARG:.*]] = %[[L0]]) -> (f32, f32) {
 //  CHECK-NEXT:     %[[ADD0:.*]] = arith.addf %[[LARG]], %[[C]] : f32
-//  CHECK-NEXT:     %[[MUL0:.*]] = arith.mulf %[[CSTF]], %[[ADD0]] : f32
 //  CHECK-NEXT:     %[[IV1:.*]] = arith.addi %[[IV]], %[[C1]] : index
 //  CHECK-NEXT:     %[[L2:.*]] = memref.load %[[A]][%[[IV1]]] : memref<?xf32>
-//  CHECK-NEXT:     scf.yield %[[MUL0]], %[[L2]] : f32, f32
+//  CHECK-NEXT:     scf.yield %[[ADD0]], %[[L2]] : f32, f32
 //  CHECK-NEXT:   }
 // Epilogue:
 //  CHECK-NEXT:   %[[ADD1:.*]] = arith.addf %[[R]]#1, %[[R]]#0 : f32
-//  CHECK-NEXT:   %[[MUL1:.*]] = arith.mulf %[[CSTF]], %[[ADD1]] : f32
-//  CHECK-NEXT:   return %[[MUL1]] : f32
+//  CHECK-NEXT:   return %[[ADD1]] : f32
 func @backedge_same_stage(%A: memref<?xf32>) -> f32 {
   %c0 = arith.constant 0 : index
   %c1 = arith.constant 1 : index


        


More information about the Mlir-commits mailing list