[Mlir-commits] [mlir] 4373520 - Simplify affine bound min or max with operands

Uday Bondhugula llvmlistbot at llvm.org
Fri Apr 28 18:50:17 PDT 2023


Author: Uday Bondhugula
Date: 2023-04-29T07:12:54+05:30
New Revision: 43735204d59a87e40893460f3a05d80010184f77

URL: https://github.com/llvm/llvm-project/commit/43735204d59a87e40893460f3a05d80010184f77
DIFF: https://github.com/llvm/llvm-project/commit/43735204d59a87e40893460f3a05d80010184f77.diff

LOG: Simplify affine bound min or max with operands

Add canonicalization for affine.for bounds that uses operand information
(when operands are IVs of outer affine.for loops) to simplify bounds:
redundant bound expressions are eliminated in specific cases that are easy
to check and well suited to op canonicalization. If the lowest (for a min)
or the highest (for a max) value an affine expression can take is already
covered by another bound expression, the expression can be removed.

E.g.:
`min (d0) -> (32 * d0 + 32, 32)(%i) where 0 <= %i < 2`

The first expression ranges over [32, 64], so it can't be less than 32
and can be simplified away with a lightweight local rewrite.

Since this simplification is part of canonicalization, it only handles
simple expressions: sums of operands multiplied by constants. This is the
dominant case where such simplification is desired, and such expressions
can be flattened without introducing any local variables.

Reviewed By: dcaballe, springerm

Differential Revision: https://reviews.llvm.org/D149007

Added: 
    

Modified: 
    mlir/lib/Dialect/Affine/IR/AffineOps.cpp
    mlir/test/Dialect/Affine/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index 8f63b22ea4f4e..537ed8745ba19 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -934,6 +934,118 @@ static void simplifyExprAndOperands(AffineExpr &expr, unsigned numDims,
   }
 }
 
+/// Simplify the expressions in `map` using constant lower or upper bounds of
+/// its operands. If `isMax` is true, the map is to be treated as a max of its
+/// result expressions, and min otherwise; expressions that can never decide
+/// the max/min are dropped. Eg: min (d0, d1) -> (8, 4 * d0 + d1) becomes (8)
+/// if d0 >= 2 and d1 >= 0 (the second expression can't be lower than 8).
+static void simplifyMinOrMaxExprWithOperands(AffineMap &map,
+                                             ArrayRef<Value> operands,
+                                             bool isMax) {
+  // Without operands there is no bound information to exploit.
+  if (operands.empty())
+    return;
+
+  // Get the constant lower and upper bounds of each operand (e.g. from the
+  // range of an affine.for IV); std::nullopt means the bound is unknown.
+  SmallVector<std::optional<int64_t>> constLowerBounds, constUpperBounds;
+  constLowerBounds.reserve(operands.size());
+  constUpperBounds.reserve(operands.size());
+  for (Value operand : operands) {
+    constLowerBounds.push_back(getLowerBound(operand));
+    constUpperBounds.push_back(getUpperBound(operand));
+  }
+
+  // Compute the lower and upper bounds of each result expression. Then, for a
+  // max, an expression is redundant if its highest value does not exceed the
+  // lowest value of another expression; for a min, if its lowest value is not
+  // below the highest value of another expression. Constant expressions are
+  // their own lower and upper bounds.
+  SmallVector<std::optional<int64_t>, 4> lowerBounds, upperBounds;
+  lowerBounds.reserve(map.getNumResults());
+  upperBounds.reserve(map.getNumResults());
+  for (AffineExpr e : map.getResults()) {
+    if (auto constExpr = e.dyn_cast<AffineConstantExpr>()) {
+      lowerBounds.push_back(constExpr.getValue());
+      upperBounds.push_back(constExpr.getValue());
+    } else {
+      lowerBounds.push_back(getBoundForExpr(e, map.getNumDims(),
+                                            map.getNumSymbols(),
+                                            constLowerBounds, constUpperBounds,
+                                            /*isUpper=*/false));
+      upperBounds.push_back(getBoundForExpr(e, map.getNumDims(),
+                                            map.getNumSymbols(),
+                                            constLowerBounds, constUpperBounds,
+                                            /*isUpper=*/true));
+    }
+  }
+
+  // Collect the expressions that are not redundant.
+  SmallVector<AffineExpr, 4> irredundantExprs;
+  for (auto exprEn : llvm::enumerate(map.getResults())) {
+    AffineExpr e = exprEn.value();
+    unsigned i = exprEn.index();
+    // An expression whose lower and upper bounds coincide is a constant.
+    if (lowerBounds[i] && upperBounds[i] && *lowerBounds[i] == *upperBounds[i])
+      e = getAffineConstantExpr(*lowerBounds[i], e.getContext());
+
+    // Check if the expression is redundant.
+    if (isMax) {
+      if (!upperBounds[i]) {
+        irredundantExprs.push_back(e);
+        continue;
+      }
+      // It's redundant if there exists another expression whose lower bound
+      // is greater than or equal to this expression's upper bound.
+      if (!llvm::any_of(llvm::enumerate(lowerBounds), [&](const auto &en) {
+            const std::optional<int64_t> &otherLowerBound = en.value();
+            unsigned pos = en.index();
+            if (pos == i || !otherLowerBound)
+              return false;
+            if (*otherLowerBound > *upperBounds[i])
+              return true;
+            if (*otherLowerBound < *upperBounds[i])
+              return false;
+            // Equality case. When both expressions are constants of the same
+            // value, each would make the other redundant; keep the one that
+            // appears first.
+            if (upperBounds[pos] && lowerBounds[i] &&
+                *lowerBounds[i] == *upperBounds[i] &&
+                *otherLowerBound == *upperBounds[pos] && i < pos)
+              return false;
+            return true;
+          }))
+        irredundantExprs.push_back(e);
+    } else {
+      if (!lowerBounds[i]) {
+        irredundantExprs.push_back(e);
+        continue;
+      }
+      // Likewise for the `min` case, with the comparisons reversed.
+      if (!llvm::any_of(llvm::enumerate(upperBounds), [&](const auto &en) {
+            const std::optional<int64_t> &otherUpperBound = en.value();
+            unsigned pos = en.index();
+            if (pos == i || !otherUpperBound)
+              return false;
+            if (*otherUpperBound < *lowerBounds[i])
+              return true;
+            if (*otherUpperBound > *lowerBounds[i])
+              return false;
+            if (lowerBounds[pos] && upperBounds[i] &&
+                *lowerBounds[i] == *upperBounds[i] &&
+                *otherUpperBound == *lowerBounds[pos] && i < pos)
+              return false;
+            return true;
+          }))
+        irredundantExprs.push_back(e);
+    }
+  }
+
+  // Create the map without the redundant expressions.
+  map = AffineMap::get(map.getNumDims(), map.getNumSymbols(), irredundantExprs,
+                       map.getContext());
+}
+
 /// Simplify the map while exploiting information on the values in `operands`.
 //  Use "unused attribute" marker to silence warning stemming from the inability
 //  to see through the template expansion.
@@ -2217,6 +2329,8 @@ static LogicalResult canonicalizeLoopBounds(AffineForOp forOp) {
 
   composeAffineMapAndOperands(&lbMap, &lbOperands);
   canonicalizeMapAndOperands(&lbMap, &lbOperands);
+  simplifyMinOrMaxExprWithOperands(lbMap, lbOperands, /*isMax=*/true);
+  simplifyMinOrMaxExprWithOperands(ubMap, ubOperands, /*isMax=*/false);
   lbMap = removeDuplicateExprs(lbMap);
 
   composeAffineMapAndOperands(&ubMap, &ubOperands);

diff  --git a/mlir/test/Dialect/Affine/canonicalize.mlir b/mlir/test/Dialect/Affine/canonicalize.mlir
index d6ca5033d6fc5..ace9798c368dc 100644
--- a/mlir/test/Dialect/Affine/canonicalize.mlir
+++ b/mlir/test/Dialect/Affine/canonicalize.mlir
@@ -1264,6 +1264,141 @@ func.func @simplify_div_mod_with_operands(%N: index, %A: memref<64xf32>, %unknow
 
 // -----
 
+#map0 = affine_map<(d0) -> (32, d0 * -32 + 32)>
+#map1 = affine_map<(d0) -> (32, d0 * -32 + 64)>
+#map3 = affine_map<(d0) -> (16, d0 * -16 + 32)>
+// The maps above are used by @simplify_bounds_tiled; expected output maps follow.
+// CHECK-DAG: #[[$SIMPLE_MAP:.*]] = affine_map<()[s0] -> (3, s0)>
+// CHECK-DAG: #[[$SIMPLE_MAP_MAX:.*]] = affine_map<()[s0] -> (5, s0)>
+// CHECK-DAG: #[[$SIMPLIFIED_MAP:.*]] = affine_map<(d0, d1) -> (-9, d0 * 4 - d1 * 4)>
+// CHECK-DAG: #[[$FLOORDIV:.*]] = affine_map<(d0) -> (d0 floordiv 2)>
+
+// CHECK-LABEL: func @simplify_min_max_bounds_simple
+func.func @simplify_min_max_bounds_simple(%M: index) {
+  // min(3, 5, d0): 5 can never be the minimum alongside 3, so it is dropped.
+  // CHECK-NEXT: affine.for %{{.*}} = 0 to min #[[$SIMPLE_MAP]]
+  affine.for %i = 0 to min affine_map<(d0) -> (3, 5, d0)>(%M) {
+    "test.foo"() : () -> ()
+  }
+  // Duplicate constants in a min bound collapse into one.
+  // CHECK: affine.for %{{.*}} = 0 to min #[[$SIMPLE_MAP]]
+  affine.for %i = 0 to min affine_map<(d0) -> (3, 3, d0)>(%M) {
+    "test.foo"() : () -> ()
+  }
+  // max(3, 5, d0): 3 can never be the maximum alongside 5, so it is dropped.
+  // CHECK: affine.for %{{.*}} = max #[[$SIMPLE_MAP_MAX]]
+  affine.for %i = max affine_map<(d0) -> (3, 5, d0)>(%M) to 10 {
+    "test.foo"() : () -> ()
+  }
+  // Duplicate constants in a max bound collapse into one.
+  // CHECK: affine.for %{{.*}} = max #[[$SIMPLE_MAP_MAX]]
+  affine.for %i = max affine_map<(d0) -> (5, 5, d0)>(%M) to 10 {
+    "test.foo"() : () -> ()
+  }
+
+  return
+}
+
+// CHECK-LABEL: func @simplify_bounds_tiled
+func.func @simplify_bounds_tiled() {
+  affine.for %arg5 = 0 to 1 {
+    affine.for %arg6 = 0 to 2 {
+      affine.for %arg8 = 0 to min #map0(%arg5) step 16 {
+        affine.for %arg9 = 0 to min #map1(%arg6) step 16 {
+          affine.for %arg10 = 0 to 2 {
+            affine.for %arg12 = 0 to min #map3(%arg10) step 16 {
+              "test.foo"() : () -> ()
+            }
+          }
+        }
+      }
+    }
+  }
+  // CHECK:      affine.for
+  // CHECK-NEXT:   affine.for
+  // CHECK-NEXT:     affine.for %{{.*}} = 0 to 32 step 16
+  // CHECK-NEXT:       affine.for %{{.*}} = 0 to 32 step 16
+  // CHECK-NEXT:         affine.for %{{.*}} = 0 to 2
+  // CHECK-NEXT:           affine.for %{{.*}} = 0 to 16 step 16
+  // Over each IV's range the IV-dependent expression never drops below its map's constant, so each min folds to it.
+  return
+}
+
+// CHECK-LABEL: func @simplify_min_max_multi_expr
+func.func @simplify_min_max_multi_expr() {
+  // Lower bound max: 4 * d0 <= 4 < 5 for d0 in [0, 1], so only 5 remains.
+  // CHECK: affine.for
+  affine.for %i = 0 to 2 {
+    // CHECK: affine.for %{{.*}} = 5 to
+    affine.for %j = max affine_map<(d0) -> (5, 4 * d0)> (%i) to affine_map<(d0) -> (4 * d0 + 3)>(%i) {
+      "test.foo"() : () -> ()
+    }
+  }
+
+  // Expressions with multiple operands.
+  // CHECK: affine.for
+  affine.for %i = 0 to 2 {
+    // CHECK: affine.for
+    affine.for %j = 0 to 4 {
+      // 4 * d0 - 3 * d1 >= -9 for d0 in [0, 1], d1 in [0, 3]; it never undercuts -9, so it's redundant.
+      // CHECK-NEXT: affine.for %{{.*}} = -10 to -9
+      affine.for %k = -10 to min affine_map<(d0, d1) -> (4 * d0 - 3 * d1, -9)>(%i, %j) {
+        "test.foo"() : () -> ()
+      }
+    }
+  }
+
+  // One expression is redundant but the others are not.
+  // CHECK: affine.for
+  affine.for %i = 0 to 2 {
+    // CHECK: affine.for
+    affine.for %j = 0 to 4 {
+      // The first expression is redundant as above; 4 * d0 - 4 * d1 can reach -12, so it is kept.
+      // CHECK-NEXT: affine.for %{{.*}} = -10 to min #[[$SIMPLIFIED_MAP]]
+      affine.for %k = -10 to min affine_map<(d0, d1) -> (4 * d0 - 3 * d1, -9, 4 * d0 - 4 * d1)>(%i, %j) {
+        "test.foo"() : () -> ()
+      }
+    }
+  }
+  // d0 floordiv 2 is always 0 for d0 in [0, 1], so the max folds to the constant 0.
+  // CHECK: affine.for %{{.*}} = 0 to 1
+  affine.for %i = 0 to 2 {
+    affine.for %j = max affine_map<(d0) -> (d0 floordiv 2, 0)>(%i) to 1 {
+      "test.foo"() : () -> ()
+    }
+  }
+
+  // d0 floordiv 2 is never below 0 for d0 in [0, 7], so the constant bound is redundant.
+  // CHECK: affine.for %{{.*}} = #[[$FLOORDIV]](%{{.*}}) to 10
+  affine.for %i = 0 to 8 {
+    affine.for %j = max affine_map<(d0) -> (d0 floordiv 2, 0)>(%i) to 10 {
+      "test.foo"() : () -> ()
+    }
+  }
+
+  return
+}
+
+// CHECK-LABEL: func @no_simplify_min_max
+func.func @no_simplify_min_max(%M: index) {
+  // Negative cases: 2 * d0 spans [0, 6], straddling 2; %M is an unbounded function argument.
+  // CHECK: affine.for
+  affine.for %i = 0 to 4 {
+    // CHECK-NEXT: affine.for %{{.*}} = 0 to min
+    affine.for %j = 0 to min affine_map<(d0) -> (2 * d0, 2)>(%i) {
+      "test.foo"() : () -> ()
+    }
+    // CHECK:      affine.for %{{.*}} = 0 to min {{.*}}(%{{.*}})[%{{.*}}]
+    affine.for %j = 0 to min affine_map<(d0)[s0] -> (d0, s0)>(%i)[%M] {
+      "test.foo"() : () -> ()
+    }
+  }
+
+  return
+}
+
+// -----
+
 //           CHECK: #[[$map:.*]] = affine_map<()[s0] -> (s0 * ((-s0 + 40961) ceildiv 512))>
 // CHECK-BOTTOM-UP: #[[$map:.*]] = affine_map<()[s0] -> (s0 * ((-s0 + 40961) ceildiv 512))>
 //           CHECK-LABEL: func @regression_do_not_perform_invalid_replacements


        


More information about the Mlir-commits mailing list