[Mlir-commits] [mlir] f278cf9 - [MLIR][arith] More float op folders
Christian Sigg
llvmlistbot at llvm.org
Mon Jan 31 10:31:55 PST 2022
Author: Christian Sigg
Date: 2022-01-31T19:31:48+01:00
New Revision: f278cf9cbc3edbd4441da87e23e703d9a331c994
URL: https://github.com/llvm/llvm-project/commit/f278cf9cbc3edbd4441da87e23e703d9a331c994
DIFF: https://github.com/llvm/llvm-project/commit/f278cf9cbc3edbd4441da87e23e703d9a331c994.diff
LOG: [MLIR][arith] More float op folders
Fold `arith.addf %x, -0.0 -> %x`, `arith.subf %x, +0.0 -> %x`, `arith.mulf %x, 1.0 -> %x` and `arith.divf %x, 1.0 -> %x`.
Fold `arith.minf %x, %x -> %x`, `arith.minf %x, +inf -> %x`, and likewise `arith.maxf %x, %x -> %x`, `arith.maxf %x, -inf -> %x`.
Reviewed By: pifon2a, mehdi_amini, bondhugula
Differential Revision: https://reviews.llvm.org/D118244
Added:
Modified:
mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td
mlir/include/mlir/IR/Matchers.h
mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
mlir/test/Dialect/Arithmetic/canonicalize.mlir
mlir/test/Dialect/Linalg/vectorization.mlir
mlir/test/Dialect/SCF/loop-pipelining.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td b/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td
index 5f52e440c4da6..b7eb2730136ef 100644
--- a/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td
+++ b/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td
@@ -653,6 +653,7 @@ def Arith_MaxFOp : Arith_FloatBinaryOp<"maxf", [Commutative]> {
%a = arith.maxf %b, %c : f64
```
}];
+ let hasFolder = 1; // Implemented by arith::MaxFOp::fold in ArithmeticOps.cpp.
}
//===----------------------------------------------------------------------===//
@@ -696,6 +697,7 @@ def Arith_MinFOp : Arith_FloatBinaryOp<"minf", [Commutative]> {
%a = arith.minf %b, %c : f64
```
}];
+ let hasFolder = 1; // Implemented by arith::MinFOp::fold in ArithmeticOps.cpp.
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/IR/Matchers.h b/mlir/include/mlir/IR/Matchers.h
index 78c65f8cce301..cc879cfc80dc4 100644
--- a/mlir/include/mlir/IR/Matchers.h
+++ b/mlir/include/mlir/IR/Matchers.h
@@ -91,6 +91,43 @@ struct constant_op_binder {
}
};
+/// The matcher that matches a constant scalar / vector splat / tensor splat
+/// float operation and binds the constant float value.
+struct constant_float_op_binder {
+ FloatAttr::ValueType *bind_value;
+
+ /// Creates a matcher instance that binds the value to bv if match succeeds.
+ constant_float_op_binder(FloatAttr::ValueType *bv) : bind_value(bv) {}
+
+ bool match(Operation *op) {
+ Attribute attr;
+ if (!constant_op_binder<Attribute>(&attr).match(op))
+ return false;
+ auto type = op->getResult(0).getType();
+
+ if (type.isa<FloatType>()) // Scalar: the attribute is the FloatAttr itself.
+ return attr_value_binder<FloatAttr>(bind_value).match(attr);
+ if (type.isa<VectorType, RankedTensorType>()) { // Only splats match.
+ if (auto splatAttr = attr.dyn_cast<SplatElementsAttr>()) {
+ return attr_value_binder<FloatAttr>(bind_value)
+ .match(splatAttr.getSplatValue<Attribute>());
+ }
+ }
+ return false;
+ }
+};
+
+/// The matcher that matches a constant scalar / vector splat / tensor splat
+/// float value for which `predicate` returns true.
+struct constant_float_predicate_matcher {
+ bool (*predicate)(const APFloat &);
+
+ bool match(Operation *op) {
+ APFloat value(APFloat::Bogus()); // Placeholder; only read after a successful bind.
+ return constant_float_op_binder(&value).match(op) && predicate(value);
+ }
+};
+
/// The matcher that matches a constant scalar / vector splat / tensor splat
/// integer operation and binds the constant integer value.
struct constant_int_op_binder {
@@ -118,22 +155,13 @@ struct constant_int_op_binder {
};
/// The matcher that matches a given target constant scalar / vector splat /
-/// tensor splat integer value.
-template <int64_t TargetValue>
-struct constant_int_value_matcher {
- bool match(Operation *op) {
- APInt value;
- return constant_int_op_binder(&value).match(op) && TargetValue == value;
- }
-};
+/// tensor splat integer value that fulfills a predicate.
+struct constant_int_predicate_matcher {
+ bool (*predicate)(const APInt &); // Evaluated on the bound constant value.
-/// The matcher that matches anything except the given target constant scalar /
-/// vector splat / tensor splat integer value.
-template <int64_t TargetNotValue>
-struct constant_int_not_value_matcher {
bool match(Operation *op) {
APInt value;
- return constant_int_op_binder(&value).match(op) && TargetNotValue != value;
+ return constant_int_op_binder(&value).match(op) && predicate(value);
}
};
@@ -239,26 +267,65 @@ inline detail::constant_op_binder<AttrT> m_Constant(AttrT *bind_value) {
return detail::constant_op_binder<AttrT>(bind_value);
}
-/// Matches a constant scalar / vector splat / tensor splat integer one.
-inline detail::constant_int_value_matcher<1> m_One() {
- return detail::constant_int_value_matcher<1>();
+/// Matches a constant scalar / vector splat / tensor splat float (both positive
+/// and negative) zero.
+inline detail::constant_float_predicate_matcher m_AnyZeroFloat() {
+ return {[](const APFloat &value) { return value.isZero(); }};
}
-/// Matches the given OpClass.
-template <typename OpClass>
-inline detail::op_matcher<OpClass> m_Op() {
- return detail::op_matcher<OpClass>();
+/// Matches a constant scalar / vector splat / tensor splat float positive zero.
+inline detail::constant_float_predicate_matcher m_PosZeroFloat() {
+ return {[](const APFloat &value) { return value.isPosZero(); }};
+}
+
+/// Matches a constant scalar / vector splat / tensor splat float negative zero.
+inline detail::constant_float_predicate_matcher m_NegZeroFloat() {
+ return {[](const APFloat &value) { return value.isNegZero(); }};
+}
+
+/// Matches a constant scalar / vector splat / tensor splat float one (1.0).
+inline detail::constant_float_predicate_matcher m_OneFloat() {
+ return {[](const APFloat &value) {
+ return APFloat(value.getSemantics(), 1) == value;
+ }};
+}
+
+/// Matches a constant scalar / vector splat / tensor splat float positive
+/// infinity.
+inline detail::constant_float_predicate_matcher m_PosInfFloat() {
+ return {[](const APFloat &value) {
+ return !value.isNegative() && value.isInfinity();
+ }};
+}
+
+/// Matches a constant scalar / vector splat / tensor splat float negative
+/// infinity.
+inline detail::constant_float_predicate_matcher m_NegInfFloat() {
+ return {[](const APFloat &value) {
+ return value.isNegative() && value.isInfinity();
+ }};
}
/// Matches a constant scalar / vector splat / tensor splat integer zero.
-inline detail::constant_int_value_matcher<0> m_Zero() {
- return detail::constant_int_value_matcher<0>();
+inline detail::constant_int_predicate_matcher m_Zero() {
+ return {[](const APInt &value) { return 0 == value; }};
}
/// Matches a constant scalar / vector splat / tensor splat integer that is any
/// non-zero value.
-inline detail::constant_int_not_value_matcher<0> m_NonZero() {
- return detail::constant_int_not_value_matcher<0>();
+inline detail::constant_int_predicate_matcher m_NonZero() {
+ return {[](const APInt &value) { return 0 != value; }};
+}
+
+/// Matches a constant scalar / vector splat / tensor splat integer one.
+inline detail::constant_int_predicate_matcher m_One() {
+ return {[](const APInt &value) { return 1 == value; }};
+}
+
+/// Matches the given OpClass.
+template <typename OpClass>
+inline detail::op_matcher<OpClass> m_Op() {
+ return detail::op_matcher<OpClass>();
}
/// Entry point for matching a pattern over a Value.
@@ -276,6 +343,13 @@ inline bool matchPattern(Operation *op, const Pattern &pattern) {
return const_cast<Pattern &>(pattern).match(op);
}
+/// Matches a constant holding a scalar/vector/tensor float (splat) and
+/// writes the float value to bind_value.
+inline detail::constant_float_op_binder
+m_ConstantFloat(FloatAttr::ValueType *bind_value) {
+ return detail::constant_float_op_binder(bind_value);
+}
+
/// Matches a constant holding a scalar/vector/tensor integer (splat) and
/// writes the integer value to bind_value.
inline detail::constant_int_op_binder
diff --git a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
index d2473ebbff407..022a5674ef9ee 100644
--- a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
+++ b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
@@ -194,12 +194,12 @@ OpFoldResult arith::AddIOp::fold(ArrayRef<Attribute> operands) {
if (matchPattern(getRhs(), m_Zero()))
return getLhs();
- // add(sub(a, b), b) -> a
+ // addi(subi(a, b), b) -> a
if (auto sub = getLhs().getDefiningOp<SubIOp>())
if (getRhs() == sub.getRhs())
return sub.getLhs();
- // add(b, sub(a, b)) -> a
+ // addi(b, subi(a, b)) -> a
if (auto sub = getRhs().getDefiningOp<SubIOp>())
if (getLhs() == sub.getRhs())
return sub.getLhs();
@@ -576,6 +576,14 @@ void arith::XOrIOp::getCanonicalizationPatterns(
//===----------------------------------------------------------------------===//
OpFoldResult arith::AddFOp::fold(ArrayRef<Attribute> operands) {
+ // addf(x, -0) -> x. Only -0 is an identity: -0 + +0 is +0, so +0 is not folded.
+ if (matchPattern(getRhs(), m_NegZeroFloat()))
+ return getLhs();
+
+ // addf(-0, x) -> x
+ if (matchPattern(getLhs(), m_NegZeroFloat()))
+ return getRhs();
+
return constFoldBinaryOp<FloatAttr>(
operands, [](const APFloat &a, const APFloat &b) { return a + b; });
}
@@ -585,10 +593,34 @@ OpFoldResult arith::AddFOp::fold(ArrayRef<Attribute> operands) {
//===----------------------------------------------------------------------===//
OpFoldResult arith::SubFOp::fold(ArrayRef<Attribute> operands) {
+ // subf(x, +0) -> x. Only +0 is an identity: -0 - -0 is +0, so -0 is not folded.
+ if (matchPattern(getRhs(), m_PosZeroFloat()))
+ return getLhs();
+
return constFoldBinaryOp<FloatAttr>(
operands, [](const APFloat &a, const APFloat &b) { return a - b; });
}
+//===----------------------------------------------------------------------===//
+// MaxFOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult arith::MaxFOp::fold(ArrayRef<Attribute> operands) {
+ assert(operands.size() == 2 && "maxf takes two operands");
+
+ // maxf(x, x) -> x
+ if (getLhs() == getRhs())
+ return getRhs();
+
+ // maxf(x, -inf) -> x (only the RHS is checked; presumably constants of this Commutative op are canonicalized to the RHS)
+ if (matchPattern(getRhs(), m_NegInfFloat()))
+ return getLhs();
+
+ return constFoldBinaryOp<FloatAttr>(
+ operands,
+ [](const APFloat &a, const APFloat &b) { return llvm::maximum(a, b); });
+}
+
//===----------------------------------------------------------------------===//
// MaxSIOp
//===----------------------------------------------------------------------===//
@@ -643,6 +675,26 @@ OpFoldResult MaxUIOp::fold(ArrayRef<Attribute> operands) {
});
}
+//===----------------------------------------------------------------------===//
+// MinFOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult arith::MinFOp::fold(ArrayRef<Attribute> operands) {
+ assert(operands.size() == 2 && "minf takes two operands");
+
+ // minf(x, x) -> x
+ if (getLhs() == getRhs())
+ return getRhs();
+
+ // minf(x, +inf) -> x (only the RHS is checked; presumably constants of this Commutative op are canonicalized to the RHS)
+ if (matchPattern(getRhs(), m_PosInfFloat()))
+ return getLhs();
+
+ return constFoldBinaryOp<FloatAttr>(
+ operands,
+ [](const APFloat &a, const APFloat &b) { return llvm::minimum(a, b); });
+}
+
//===----------------------------------------------------------------------===//
// MinSIOp
//===----------------------------------------------------------------------===//
@@ -702,6 +754,14 @@ OpFoldResult MinUIOp::fold(ArrayRef<Attribute> operands) {
//===----------------------------------------------------------------------===//
OpFoldResult arith::MulFOp::fold(ArrayRef<Attribute> operands) {
+ // mulf(x, 1) -> x
+ if (matchPattern(getRhs(), m_OneFloat()))
+ return getLhs();
+
+ // mulf(1, x) -> x
+ if (matchPattern(getLhs(), m_OneFloat()))
+ return getRhs();
+
return constFoldBinaryOp<FloatAttr>(
operands, [](const APFloat &a, const APFloat &b) { return a * b; });
}
@@ -711,6 +771,10 @@ OpFoldResult arith::MulFOp::fold(ArrayRef<Attribute> operands) {
//===----------------------------------------------------------------------===//
OpFoldResult arith::DivFOp::fold(ArrayRef<Attribute> operands) {
+ // divf(x, 1) -> x
+ if (matchPattern(getRhs(), m_OneFloat()))
+ return getLhs();
+
return constFoldBinaryOp<FloatAttr>(
operands, [](const APFloat &a, const APFloat &b) { return a / b; });
}
diff --git a/mlir/test/Dialect/Arithmetic/canonicalize.mlir b/mlir/test/Dialect/Arithmetic/canonicalize.mlir
index 22a88d165d300..cf87af2e30e3c 100644
--- a/mlir/test/Dialect/Arithmetic/canonicalize.mlir
+++ b/mlir/test/Dialect/Arithmetic/canonicalize.mlir
@@ -619,6 +619,113 @@ func @test_minui(%arg0 : i8) -> (i8, i8, i8, i8) {
// -----
+// CHECK-LABEL: @test_minf(
+func @test_minf(%arg0 : f32) -> (f32, f32, f32) {
+ // CHECK-DAG: %[[C0:.+]] = arith.constant 0.0
+ // CHECK-NEXT: %[[X:.+]] = arith.minf %arg0, %[[C0]]
+ // CHECK-NEXT: return %[[X]], %arg0, %arg0
+ %c0 = arith.constant 0.0 : f32
+ %inf = arith.constant 0x7F800000 : f32 // +inf
+ %0 = arith.minf %c0, %arg0 : f32
+ %1 = arith.minf %arg0, %arg0 : f32
+ %2 = arith.minf %inf, %arg0 : f32
+ return %0, %1, %2 : f32, f32, f32
+}
+
+// -----
+
+// CHECK-LABEL: @test_maxf(
+func @test_maxf(%arg0 : f32) -> (f32, f32, f32) {
+ // CHECK-DAG: %[[C0:.+]] = arith.constant
+ // CHECK-NEXT: %[[X:.+]] = arith.maxf %arg0, %[[C0]]
+ // CHECK-NEXT: return %[[X]], %arg0, %arg0
+ %c0 = arith.constant 0.0 : f32
+ %-inf = arith.constant 0xFF800000 : f32 // -inf
+ %0 = arith.maxf %c0, %arg0 : f32
+ %1 = arith.maxf %arg0, %arg0 : f32
+ %2 = arith.maxf %-inf, %arg0 : f32
+ return %0, %1, %2 : f32, f32, f32
+}
+
+// -----
+
+// CHECK-LABEL: @test_addf(
+func @test_addf(%arg0 : f32) -> (f32, f32, f32, f32) {
+ // CHECK-DAG: %[[C2:.+]] = arith.constant 2.0
+ // CHECK-DAG: %[[C0:.+]] = arith.constant 0.0
+ // CHECK-NEXT: %[[X:.+]] = arith.addf %arg0, %[[C0]]
+ // CHECK-NEXT: return %[[X]], %arg0, %arg0, %[[C2]]
+ %c0 = arith.constant 0.0 : f32
+ %c-0 = arith.constant -0.0 : f32
+ %c1 = arith.constant 1.0 : f32
+ %0 = arith.addf %arg0, %c0 : f32 // not folded: -0.0 + 0.0 is +0.0
+ %1 = arith.addf %arg0, %c-0 : f32
+ %2 = arith.addf %c-0, %arg0 : f32
+ %3 = arith.addf %c1, %c1 : f32
+ return %0, %1, %2, %3 : f32, f32, f32, f32
+}
+
+// -----
+
+// CHECK-LABEL: @test_subf(
+func @test_subf(%arg0 : f16) -> (f16, f16, f16) {
+ // CHECK-DAG: %[[C1:.+]] = arith.constant -1.0
+ // CHECK-DAG: %[[C0:.+]] = arith.constant -0.0
+ // CHECK-NEXT: %[[X:.+]] = arith.subf %arg0, %[[C0]]
+ // CHECK-NEXT: return %arg0, %[[X]], %[[C1]]
+ %c0 = arith.constant 0.0 : f16
+ %c-0 = arith.constant -0.0 : f16
+ %c1 = arith.constant 1.0 : f16
+ %0 = arith.subf %arg0, %c0 : f16
+ %1 = arith.subf %arg0, %c-0 : f16 // not folded: -0.0 - -0.0 is +0.0
+ %2 = arith.subf %c0, %c1 : f16
+ return %0, %1, %2 : f16, f16, f16
+}
+
+// -----
+
+// CHECK-LABEL: @test_mulf(
+func @test_mulf(%arg0 : f32) -> (f32, f32, f32) {
+ // CHECK-NEXT: %[[C4:.+]] = arith.constant 4.0
+ // CHECK-NEXT: return %arg0, %arg0, %[[C4]]
+ %c1 = arith.constant 1.0 : f32
+ %c2 = arith.constant 2.0 : f32
+ %0 = arith.mulf %arg0, %c1 : f32
+ %1 = arith.mulf %c1, %arg0 : f32
+ %2 = arith.mulf %c2, %c2 : f32
+ return %0, %1, %2 : f32, f32, f32
+}
+
+// -----
+
+// CHECK-LABEL: @test_divf(
+func @test_divf(%arg0 : f64) -> (f64, f64) {
+ // CHECK-NEXT: %[[C5:.+]] = arith.constant 5.000000e-01
+ // CHECK-NEXT: return %arg0, %[[C5]]
+ %c1 = arith.constant 1.0 : f64
+ %c2 = arith.constant 2.0 : f64
+ %0 = arith.divf %arg0, %c1 : f64
+ %1 = arith.divf %c1, %c2 : f64
+ return %0, %1 : f64, f64
+}
+
+// -----
+
+// CHECK-LABEL: @test_cmpf(
+func @test_cmpf(%arg0 : f32) -> (i1, i1, i1, i1) {
+// CHECK-DAG: %[[T:.*]] = arith.constant true
+// CHECK-DAG: %[[F:.*]] = arith.constant false
+// CHECK: return %[[F]], %[[F]], %[[T]], %[[T]]
+ %nan = arith.constant 0x7fffffff : f32 // NaN
+ %0 = arith.cmpf olt, %nan, %arg0 : f32
+ %1 = arith.cmpf olt, %arg0, %nan : f32
+ %2 = arith.cmpf ugt, %nan, %arg0 : f32
+ %3 = arith.cmpf ugt, %arg0, %nan : f32
+ return %0, %1, %2, %3 : i1, i1, i1, i1
+}
+
+// -----
+
// CHECK-LABEL: @constant_FPtoUI(
func @constant_FPtoUI() -> i32 {
// CHECK: %[[C0:.+]] = arith.constant 2 : i32
@@ -678,30 +785,3 @@ func @constant_UItoFP() -> f32 {
%res = arith.sitofp %c0 : i32 to f32
return %res : f32
}
-
-// -----
-// CHECK-LABEL: @constant_MinMax(
-func @constant_MinMax(%arg0 : f32) -> f32 {
- // CHECK: %[[const:.+]] = arith.constant
- // CHECK: %[[min:.+]] = arith.minf %arg0, %[[const]] : f32
- // CHECK: %[[res:.+]] = arith.maxf %[[min]], %[[const]] : f32
- // CHECK: return %[[res]]
- %const = arith.constant 0.0 : f32
- %min = arith.minf %const, %arg0 : f32
- %res = arith.maxf %const, %min : f32
- return %res : f32
-}
-
-// -----
-// CHECK-LABEL: @cmpf_nan(
-func @cmpf_nan(%arg0 : f32) -> (i1, i1, i1, i1) {
-// CHECK-DAG: %[[T:.*]] = arith.constant true
-// CHECK-DAG: %[[F:.*]] = arith.constant false
-// CHECK: return %[[F]], %[[F]], %[[T]], %[[T]]
- %nan = arith.constant 0x7fffffff : f32
- %0 = arith.cmpf olt, %nan, %arg0 : f32
- %1 = arith.cmpf olt, %arg0, %nan : f32
- %2 = arith.cmpf ugt, %nan, %arg0 : f32
- %3 = arith.cmpf ugt, %arg0, %nan : f32
- return %0, %1, %2, %3 : i1, i1, i1, i1
-}
diff --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir
index c9f50af28ef27..ceafa28576fa0 100644
--- a/mlir/test/Dialect/Linalg/vectorization.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization.mlir
@@ -878,7 +878,6 @@ func @red_mul_2d(%arg0: tensor<4x4xf32>) -> tensor<4xf32> {
// CHECK: vector.transfer_write {{.*}} : vector<4xf32>, tensor<4xf32>
// CHECK: vector.transfer_read {{.*}} : tensor<4x4xf32>, vector<4x4xf32>
// CHECK: vector.multi_reduction <mul>, {{.*}} [1] : vector<4x4xf32> to vector<4xf32>
- // CHECK: mulf {{.*}} : vector<4xf32>
// CHECK: vector.transfer_write {{.*}} : vector<4xf32>, tensor<4xf32>
%ident = arith.constant 1.0 : f32
%init = linalg.init_tensor [4] : tensor<4xf32>
diff --git a/mlir/test/Dialect/SCF/loop-pipelining.mlir b/mlir/test/Dialect/SCF/loop-pipelining.mlir
index 83f0b78e1a1aa..9424af25bd12b 100644
--- a/mlir/test/Dialect/SCF/loop-pipelining.mlir
+++ b/mlir/test/Dialect/SCF/loop-pipelining.mlir
@@ -224,17 +224,14 @@ func @loop_carried(%A: memref<?xf32>, %result: memref<?xf32>) {
// CHECK-NEXT: %[[R:.*]]:3 = scf.for %[[IV:.*]] = %[[C0]] to %[[C2]]
// CHECK-SAME: step %[[C1]] iter_args(%[[C:.*]] = %[[CSTF]],
// CHECK-SAME: %[[ADDARG:.*]] = %[[ADD0]], %[[LARG:.*]] = %[[L1]]) -> (f32, f32, f32) {
-// CHECK-NEXT: %[[MUL0:.*]] = arith.mulf %[[CSTF]], %[[ADDARG]] : f32
-// CHECK-NEXT: %[[ADD1:.*]] = arith.addf %[[LARG]], %[[MUL0]] : f32
+// CHECK-NEXT: %[[ADD1:.*]] = arith.addf %[[LARG]], %[[ADDARG]] : f32
// CHECK-NEXT: %[[IV2:.*]] = arith.addi %[[IV]], %[[C2]] : index
// CHECK-NEXT: %[[L2:.*]] = memref.load %[[A]][%[[IV2]]] : memref<?xf32>
-// CHECK-NEXT: scf.yield %[[MUL0]], %[[ADD1]], %[[L2]] : f32, f32, f32
+// CHECK-NEXT: scf.yield %[[ADDARG]], %[[ADD1]], %[[L2]] : f32, f32, f32
// CHECK-NEXT: }
// Epilogue:
-// CHECK-NEXT: %[[MUL1:.*]] = arith.mulf %[[CSTF]], %[[R]]#1 : f32
-// CHECK-NEXT: %[[ADD2:.*]] = arith.addf %[[R]]#2, %[[MUL1]] : f32
-// CHECK-NEXT: %[[MUL2:.*]] = arith.mulf %[[CSTF]], %[[ADD2]] : f32
-// CHECK-NEXT: return %[[MUL2]] : f32
+// CHECK-NEXT: %[[ADD2:.*]] = arith.addf %[[R]]#2, %[[R]]#1 : f32
+// CHECK-NEXT: return %[[ADD2]] : f32
func @backedge_different_stage(%A: memref<?xf32>) -> f32 {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
@@ -264,15 +261,13 @@ func @backedge_different_stage(%A: memref<?xf32>) -> f32 {
// CHECK-SAME: step %[[C1]] iter_args(%[[C:.*]] = %[[CSTF]],
// CHECK-SAME: %[[LARG:.*]] = %[[L0]]) -> (f32, f32) {
// CHECK-NEXT: %[[ADD0:.*]] = arith.addf %[[LARG]], %[[C]] : f32
-// CHECK-NEXT: %[[MUL0:.*]] = arith.mulf %[[CSTF]], %[[ADD0]] : f32
// CHECK-NEXT: %[[IV1:.*]] = arith.addi %[[IV]], %[[C1]] : index
// CHECK-NEXT: %[[L2:.*]] = memref.load %[[A]][%[[IV1]]] : memref<?xf32>
-// CHECK-NEXT: scf.yield %[[MUL0]], %[[L2]] : f32, f32
+// CHECK-NEXT: scf.yield %[[ADD0]], %[[L2]] : f32, f32
// CHECK-NEXT: }
// Epilogue:
// CHECK-NEXT: %[[ADD1:.*]] = arith.addf %[[R]]#1, %[[R]]#0 : f32
-// CHECK-NEXT: %[[MUL1:.*]] = arith.mulf %[[CSTF]], %[[ADD1]] : f32
-// CHECK-NEXT: return %[[MUL1]] : f32
+// CHECK-NEXT: return %[[ADD1]] : f32
func @backedge_same_stage(%A: memref<?xf32>) -> f32 {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
More information about the Mlir-commits
mailing list