[Mlir-commits] [mlir] 4373520 - Simplify affine bound min or max with operands
Uday Bondhugula
llvmlistbot at llvm.org
Fri Apr 28 18:50:17 PDT 2023
Author: Uday Bondhugula
Date: 2023-04-29T07:12:54+05:30
New Revision: 43735204d59a87e40893460f3a05d80010184f77
URL: https://github.com/llvm/llvm-project/commit/43735204d59a87e40893460f3a05d80010184f77
DIFF: https://github.com/llvm/llvm-project/commit/43735204d59a87e40893460f3a05d80010184f77.diff
LOG: Simplify affine bound min or max with operands
Add canonicalization for affine.for bounds to use operand info (when
operands are outer loop affine.for IVs) to simplify bounds: redundant
bound expressions are eliminated in specific cases that are easy to
check and well-suited for op canonicalization. If the lowest or the
highest value the affine expression can take is already covered by other
constant bounds, the expression can be removed.
Eg:
`min (d0) -> (32 * d0 + 32, 32)(%i) where 0 <= %i < 2`
The first expression can't be less than 32 and can be simplified away
with a lightweight local rewrite.
Because this simplification is part of canonicalization, it only handles
simple expressions: sums of products of operands with constants.
This is by far the most common case where such simplification is
desired, and such expressions can be flattened without any local variables.
Reviewed By: dcaballe, springerm
Differential Revision: https://reviews.llvm.org/D149007
Added:
Modified:
mlir/lib/Dialect/Affine/IR/AffineOps.cpp
mlir/test/Dialect/Affine/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index 8f63b22ea4f4e..537ed8745ba19 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -934,6 +934,118 @@ static void simplifyExprAndOperands(AffineExpr &expr, unsigned numDims,
}
}
+/// Simplify the expressions in `map` while making use of lower or upper bounds
+/// of its operands. If `isMax` is true, the map is to be treated as a max of
+/// its result expressions, and min otherwise. Eg: min (d0, d1) -> (8, 4 * d0 +
+/// d1) can be simplified to (8) if the operands are respectively lower bounded
+/// by 2 and 0 (the second expression can't be lower than 8).
+///
+/// A result is dropped when another result always dominates it: for a max,
+/// when another result's lower bound is at least this result's upper bound;
+/// for a min, when another result's upper bound is at most this result's lower
+/// bound. When two results are the same constant, only the first is kept.
+/// Results whose lower and upper bounds coincide are replaced by that constant.
+/// `map` is left untouched if `operands` is empty.
+static void simplifyMinOrMaxExprWithOperands(AffineMap &map,
+                                             ArrayRef<Value> operands,
+                                             bool isMax) {
+  // Can't simplify.
+  if (operands.empty())
+    return;
+
+  // Get the constant lower and upper bounds on the operands (e.g., from the
+  // range of an affine.for IV); std::nullopt where no bound is known.
+  SmallVector<std::optional<int64_t>> constLowerBounds, constUpperBounds;
+  constLowerBounds.reserve(operands.size());
+  constUpperBounds.reserve(operands.size());
+  for (Value operand : operands) {
+    constLowerBounds.push_back(getLowerBound(operand));
+    constUpperBounds.push_back(getUpperBound(operand));
+  }
+
+  // Compute the constant lower and upper bounds on each result expression.
+  SmallVector<std::optional<int64_t>, 4> lowerBounds, upperBounds;
+  lowerBounds.reserve(map.getNumResults());
+  upperBounds.reserve(map.getNumResults());
+  for (AffineExpr e : map.getResults()) {
+    if (auto constExpr = e.dyn_cast<AffineConstantExpr>()) {
+      lowerBounds.push_back(constExpr.getValue());
+      upperBounds.push_back(constExpr.getValue());
+    } else {
+      lowerBounds.push_back(getBoundForExpr(e, map.getNumDims(),
+                                            map.getNumSymbols(),
+                                            constLowerBounds, constUpperBounds,
+                                            /*isUpper=*/false));
+      upperBounds.push_back(getBoundForExpr(e, map.getNumDims(),
+                                            map.getNumSymbols(),
+                                            constLowerBounds, constUpperBounds,
+                                            /*isUpper=*/true));
+    }
+  }
+
+  // Returns true if result `idx` is known to always take a single value.
+  auto isConstantResult = [&](unsigned idx) {
+    return lowerBounds[idx] && upperBounds[idx] &&
+           *lowerBounds[idx] == *upperBounds[idx];
+  };
+
+  // Returns true if result `i` can be dropped without changing the value of
+  // the max (resp. min). For a max, the upper bound of `i` is compared against
+  // the lower bounds of the other results; for a min, the lower bound of `i`
+  // is compared against their upper bounds.
+  auto isRedundant = [&](unsigned i) {
+    std::optional<int64_t> ownBound = isMax ? upperBounds[i] : lowerBounds[i];
+    // Nothing can be concluded without a bound on `i` itself.
+    if (!ownBound)
+      return false;
+    ArrayRef<std::optional<int64_t>> otherBounds =
+        isMax ? lowerBounds : upperBounds;
+    for (unsigned pos = 0, n = otherBounds.size(); pos < n; ++pos) {
+      if (pos == i || !otherBounds[pos])
+        continue;
+      int64_t other = *otherBounds[pos];
+      // Result `pos` strictly dominates result `i`.
+      if (isMax ? other > *ownBound : other < *ownBound)
+        return true;
+      if (other != *ownBound)
+        continue;
+      // Equality case. If both results are the same constant, each would make
+      // the other redundant; keep the one that appears first so exactly one
+      // of them survives.
+      if (i < pos && isConstantResult(i) && isConstantResult(pos))
+        continue;
+      return true;
+    }
+    return false;
+  };
+
+  // Collect the results that are not redundant, folding results whose bounds
+  // coincide into constants.
+  SmallVector<AffineExpr, 4> irredundantExprs;
+  for (auto exprEn : llvm::enumerate(map.getResults())) {
+    unsigned i = exprEn.index();
+    if (isRedundant(i))
+      continue;
+    AffineExpr e = exprEn.value();
+    if (isConstantResult(i))
+      e = getAffineConstantExpr(*lowerBounds[i], e.getContext());
+    irredundantExprs.push_back(e);
+  }
+
+  // Create the map without the redundant expressions.
+  map = AffineMap::get(map.getNumDims(), map.getNumSymbols(), irredundantExprs,
+                       map.getContext());
+}
+
/// Simplify the map while exploiting information on the values in `operands`.
// Use "unused attribute" marker to silence warning stemming from the inability
// to see through the template expansion.
@@ -2217,6 +2329,8 @@ static LogicalResult canonicalizeLoopBounds(AffineForOp forOp) {
composeAffineMapAndOperands(&lbMap, &lbOperands);
canonicalizeMapAndOperands(&lbMap, &lbOperands);
+ simplifyMinOrMaxExprWithOperands(lbMap, lbOperands, /*isMax=*/true);
+ simplifyMinOrMaxExprWithOperands(ubMap, ubOperands, /*isMax=*/false);
lbMap = removeDuplicateExprs(lbMap);
composeAffineMapAndOperands(&ubMap, &ubOperands);
diff --git a/mlir/test/Dialect/Affine/canonicalize.mlir b/mlir/test/Dialect/Affine/canonicalize.mlir
index d6ca5033d6fc5..ace9798c368dc 100644
--- a/mlir/test/Dialect/Affine/canonicalize.mlir
+++ b/mlir/test/Dialect/Affine/canonicalize.mlir
@@ -1264,6 +1264,141 @@ func.func @simplify_div_mod_with_operands(%N: index, %A: memref<64xf32>, %unknow
// -----
+// Bound maps used by @simplify_bounds_tiled below.
+#map0 = affine_map<(d0) -> (32, d0 * -32 + 32)>
+#map1 = affine_map<(d0) -> (32, d0 * -32 + 64)>
+#map3 = affine_map<(d0) -> (16, d0 * -16 + 32)>
+
+// CHECK-DAG: #[[$SIMPLE_MAP:.*]] = affine_map<()[s0] -> (3, s0)>
+// CHECK-DAG: #[[$SIMPLE_MAP_MAX:.*]] = affine_map<()[s0] -> (5, s0)>
+// CHECK-DAG: #[[$SIMPLIFIED_MAP:.*]] = affine_map<(d0, d1) -> (-9, d0 * 4 - d1 * 4)>
+// CHECK-DAG: #[[$FLOORDIV:.*]] = affine_map<(d0) -> (d0 floordiv 2)>
+
+// Bounds mixing constants with %M, a function argument: only the constant
+// expressions are pruned; the %M expression is expected to survive (it shows
+// up as s0 in the expected maps).
+// CHECK-LABEL: func @simplify_min_max_bounds_simple
+func.func @simplify_min_max_bounds_simple(%M: index) {
+
+ // min(3, 5, %M): 5 can never be the minimum since 3 < 5.
+ // CHECK-NEXT: affine.for %{{.*}} = 0 to min #[[$SIMPLE_MAP]]
+ affine.for %i = 0 to min affine_map<(d0) -> (3, 5, d0)>(%M) {
+ "test.foo"() : () -> ()
+ }
+
+ // min(3, 3, %M): of two equal constants only the first one is kept.
+ // CHECK: affine.for %{{.*}} = 0 to min #[[$SIMPLE_MAP]]
+ affine.for %i = 0 to min affine_map<(d0) -> (3, 3, d0)>(%M) {
+ "test.foo"() : () -> ()
+ }
+
+ // max(3, 5, %M): 3 can never be the maximum since 5 > 3.
+ // CHECK: affine.for %{{.*}} = max #[[$SIMPLE_MAP_MAX]]
+ affine.for %i = max affine_map<(d0) -> (3, 5, d0)>(%M) to 10 {
+ "test.foo"() : () -> ()
+ }
+
+ // max(5, 5, %M): of two equal constants only the first one is kept.
+ // CHECK: affine.for %{{.*}} = max #[[$SIMPLE_MAP_MAX]]
+ affine.for %i = max affine_map<(d0) -> (5, 5, d0)>(%M) to 10 {
+ "test.foo"() : () -> ()
+ }
+
+ return
+}
+
+// CHECK-LABEL: func @simplify_bounds_tiled
+// Every min bound below folds to a single constant once the ranges of the
+// enclosing IVs are taken into account:
+//   %arg5 in [0, 0]:  32 - 32 * %arg5 is always 32, so min(32, 32) -> 32.
+//   %arg6 in [0, 1]:  64 - 32 * %arg6 is in [32, 64], never below 32 -> 32.
+//   %arg10 in [0, 1]: 32 - 16 * %arg10 is in [16, 32], never below 16 -> 16.
+func.func @simplify_bounds_tiled() {
+ affine.for %arg5 = 0 to 1 {
+ affine.for %arg6 = 0 to 2 {
+ affine.for %arg8 = 0 to min #map0(%arg5) step 16 {
+ affine.for %arg9 = 0 to min #map1(%arg6) step 16 {
+ affine.for %arg10 = 0 to 2 {
+ affine.for %arg12 = 0 to min #map3(%arg10) step 16 {
+ "test.foo"() : () -> ()
+ }
+ }
+ }
+ }
+ }
+ }
+ // CHECK: affine.for
+ // CHECK-NEXT: affine.for
+ // CHECK-NEXT: affine.for %{{.*}} = 0 to 32 step 16
+ // CHECK-NEXT: affine.for %{{.*}} = 0 to 32 step 16
+ // CHECK-NEXT: affine.for %{{.*}} = 0 to 2
+ // CHECK-NEXT: affine.for %{{.*}} = 0 to 16 step 16
+
+ return
+}
+
+// CHECK-LABEL: func @simplify_min_max_multi_expr
+func.func @simplify_min_max_multi_expr() {
+ // Lower bound max.
+ // %i is in [0, 1], so 4 * %i is at most 4 and never exceeds the constant 5.
+ // CHECK: affine.for
+ affine.for %i = 0 to 2 {
+ // CHECK: affine.for %{{.*}} = 5 to
+ affine.for %j = max affine_map<(d0) -> (5, 4 * d0)> (%i) to affine_map<(d0) -> (4 * d0 + 3)>(%i) {
+ "test.foo"() : () -> ()
+ }
+ }
+
+ // Expressions with multiple operands.
+ // CHECK: affine.for
+ affine.for %i = 0 to 2 {
+ // CHECK: affine.for
+ affine.for %j = 0 to 4 {
+ // With %i in [0, 1] and %j in [0, 3], the first upper bound expression will
+ // not be lower than -9. So, it's redundant.
+ // CHECK-NEXT: affine.for %{{.*}} = -10 to -9
+ affine.for %k = -10 to min affine_map<(d0, d1) -> (4 * d0 - 3 * d1, -9)>(%i, %j) {
+ "test.foo"() : () -> ()
+ }
+ }
+ }
+
+ // One expression is redundant but not the others.
+ // CHECK: affine.for
+ affine.for %i = 0 to 2 {
+ // CHECK: affine.for
+ affine.for %j = 0 to 4 {
+ // The first upper bound expression will not be lower than -9. So, it's
+ // redundant. The third one (4 * d0 - 4 * d1) can go down to -12 and is kept.
+ // CHECK-NEXT: affine.for %{{.*}} = -10 to min #[[$SIMPLIFIED_MAP]]
+ affine.for %k = -10 to min affine_map<(d0, d1) -> (4 * d0 - 3 * d1, -9, 4 * d0 - 4 * d1)>(%i, %j) {
+ "test.foo"() : () -> ()
+ }
+ }
+ }
+
+ // %i is in [0, 1], so %i floordiv 2 is always 0 and the max folds to 0.
+ // CHECK: affine.for %{{.*}} = 0 to 1
+ affine.for %i = 0 to 2 {
+ affine.for %j = max affine_map<(d0) -> (d0 floordiv 2, 0)>(%i) to 1 {
+ "test.foo"() : () -> ()
+ }
+ }
+
+ // The constant bound is redundant here: %i is in [0, 7], so
+ // %i floordiv 2 is never below 0.
+ // CHECK: affine.for %{{.*}} = #[[$FLOORDIV]](%{{.*}}) to 10
+ affine.for %i = 0 to 8 {
+ affine.for %j = max affine_map<(d0) -> (d0 floordiv 2, 0)>(%i) to 10 {
+ "test.foo"() : () -> ()
+ }
+ }
+
+ return
+}
+
+// CHECK-LABEL: func @no_simplify_min_max
+func.func @no_simplify_min_max(%M: index) {
+ // Negative test cases.
+ // CHECK: affine.for
+ affine.for %i = 0 to 4 {
+ // %i is in [0, 3], so 2 * %i spans [0, 6], which straddles 2: neither
+ // expression dominates the other and both are kept.
+ // CHECK-NEXT: affine.for %{{.*}} = 0 to min
+ affine.for %j = 0 to min affine_map<(d0) -> (2 * d0, 2)>(%i) {
+ "test.foo"() : () -> ()
+ }
+ // %M is a function argument, so the s0 expression offers nothing to prune
+ // against and both expressions are kept.
+ // CHECK: affine.for %{{.*}} = 0 to min {{.*}}(%{{.*}})[%{{.*}}]
+ affine.for %j = 0 to min affine_map<(d0)[s0] -> (d0, s0)>(%i)[%M] {
+ "test.foo"() : () -> ()
+ }
+ }
+
+ return
+}
+
+// -----
+
// CHECK: #[[$map:.*]] = affine_map<()[s0] -> (s0 * ((-s0 + 40961) ceildiv 512))>
// CHECK-BOTTOM-UP: #[[$map:.*]] = affine_map<()[s0] -> (s0 * ((-s0 + 40961) ceildiv 512))>
// CHECK-LABEL: func @regression_do_not_perform_invalid_replacements
More information about the Mlir-commits
mailing list