[Mlir-commits] [mlir] 687f782 - [MLIR] Fold away divs and mods in affine ops with operand info

Uday Bondhugula llvmlistbot at llvm.org
Fri Feb 10 00:11:41 PST 2023


Author: Uday Bondhugula
Date: 2023-02-10T13:39:56+05:30
New Revision: 687f78210d07fe7f7741273bbaaf8c18864b6191

URL: https://github.com/llvm/llvm-project/commit/687f78210d07fe7f7741273bbaaf8c18864b6191
DIFF: https://github.com/llvm/llvm-project/commit/687f78210d07fe7f7741273bbaaf8c18864b6191.diff

LOG: [MLIR] Fold away divs and mods in affine ops with operand info

Fold away divs and mods in affine maps exploiting operand info during
canonicalization. This simplifies affine map applications such as the ones
below:

```
// Simple ones.
affine.for %i = 0 to 32 {
  affine.load %A[%i floordiv 32]
  affine.load %A[%i mod 32]
  affine.load %A[2 * %i floordiv 64]
  affine.load %A[(%i mod 16) floordiv 16]
  ...
}

// Others.
 affine.for %i = -8 to 32 {
   // Will be simplified to %A[0].
   affine.store %cst, %A[2 + (%i - 96) floordiv 64] : memref<64xf32>
}
```

Reviewed By: springerm

Differential Revision: https://reviews.llvm.org/D143456

Added: 
    

Modified: 
    mlir/lib/Dialect/Affine/IR/AffineOps.cpp
    mlir/test/Dialect/Affine/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index 4481c147a8f8c..284e099fb4a9e 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -16,6 +16,7 @@
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Interfaces/ShapedOpInterfaces.h"
+#include "mlir/Support/MathExtras.h"
 #include "mlir/Transforms/InliningUtils.h"
 #include "llvm/ADT/ScopeExit.h"
 #include "llvm/ADT/SmallBitVector.h"
@@ -673,8 +674,168 @@ static bool isQTimesDPlusR(AffineExpr e, ArrayRef<Value> operands, int64_t &div,
   return false;
 }
 
+/// Gets the constant lower bound on an affine.for `iv`, if there is one.
+static std::optional<int64_t> getLowerBound(Value iv) {
+  AffineForOp forOp = getForInductionVarOwner(iv);
+  if (forOp && forOp.hasConstantLowerBound())
+    return forOp.getConstantLowerBound();
+  return std::nullopt;
+}
+
+/// Gets the constant (inclusive) upper bound on an affine.for `iv`, if any.
+static Optional<int64_t> getUpperBound(Value iv) {
+  AffineForOp forOp = getForInductionVarOwner(iv);
+  if (!forOp || !forOp.hasConstantUpperBound())
+    return std::nullopt;
+
+  // If its lower bound is also known, we can get a more precise bound
+  // whenever the step is not one: the last value the IV actually takes.
+  if (forOp.hasConstantLowerBound()) {
+    return forOp.getConstantUpperBound() - 1 -
+           (forOp.getConstantUpperBound() - forOp.getConstantLowerBound() - 1) %
+               forOp.getStep();
+  }
+  return forOp.getConstantUpperBound() - 1;
+}
+
+/// Get a lower or upper (depending on `isUpper`) bound for `expr` while using
+/// the constant lower and upper bounds for its inputs provided in
+/// `constLowerBounds` and `constUpperBounds`. Return std::nullopt if such a
+/// bound can't be computed. A top-level floordiv, ceildiv, or mod by a positive
+/// constant is bounded via its LHS; everything else must be a simple sum of
+/// products w.r.t. constant coefficients (c0*d0 + c1*d1 + c2*s0 + ... + c_n) so
+/// as to not depend on anything heavyweight in `Analysis`. Divs and mods nested
+/// inside such a sum, and semi-affine expressions, lead to std::nullopt.
+static std::optional<int64_t>
+getBoundForExpr(AffineExpr expr, unsigned numDims, unsigned numSymbols,
+                ArrayRef<Optional<int64_t>> constLowerBounds,
+                ArrayRef<Optional<int64_t>> constUpperBounds, bool isUpper) {
+  // Handle divs and mods.
+  if (auto binOpExpr = expr.dyn_cast<AffineBinaryOpExpr>()) {
+    // If the LHS of a floordiv or ceildiv is bounded and the RHS is a positive
+    // constant, we can compute the bound (both are monotonic in the LHS).
+    if (binOpExpr.getKind() == AffineExprKind::FloorDiv) {
+      auto rhsConst = binOpExpr.getRHS().dyn_cast<AffineConstantExpr>();
+      if (!rhsConst || rhsConst.getValue() < 1)
+        return std::nullopt;
+      auto bound = getBoundForExpr(binOpExpr.getLHS(), numDims, numSymbols,
+                                   constLowerBounds, constUpperBounds, isUpper);
+      if (!bound)
+        return std::nullopt;
+      return mlir::floorDiv(*bound, rhsConst.getValue());
+    }
+    if (binOpExpr.getKind() == AffineExprKind::CeilDiv) {
+      auto rhsConst = binOpExpr.getRHS().dyn_cast<AffineConstantExpr>();
+      if (rhsConst && rhsConst.getValue() >= 1) {
+        auto bound =
+            getBoundForExpr(binOpExpr.getLHS(), numDims, numSymbols,
+                            constLowerBounds, constUpperBounds, isUpper);
+        if (!bound)
+          return std::nullopt;
+        return mlir::ceilDiv(*bound, rhsConst.getValue());
+      }
+      return std::nullopt;
+    }
+    if (binOpExpr.getKind() == AffineExprKind::Mod) {
+      // lhs mod c always lies in [0, c - 1]. Also, if lb <= lhs <= ub and
+      // lb floordiv c == ub floordiv c (same "interval"), then lb mod c <=
+      // lhs mod c <= ub mod c. Both bounds of lhs are needed for either query.
+      auto rhsConst = binOpExpr.getRHS().dyn_cast<AffineConstantExpr>();
+      if (rhsConst && rhsConst.getValue() >= 1) {
+        int64_t rhsConstVal = rhsConst.getValue();
+        AffineExpr lhs = binOpExpr.getLHS();
+        auto lb = getBoundForExpr(lhs, numDims, numSymbols, constLowerBounds,
+                                  constUpperBounds, /*isUpper=*/false);
+        auto ub = getBoundForExpr(lhs, numDims, numSymbols, constLowerBounds,
+                                  constUpperBounds, /*isUpper=*/true);
+        if (ub && lb &&
+            floorDiv(*lb, rhsConstVal) == floorDiv(*ub, rhsConstVal))
+          return isUpper ? mod(*ub, rhsConstVal) : mod(*lb, rhsConstVal);
+        return isUpper ? rhsConstVal - 1 : 0;
+      }
+    }
+  }
+  // Flatten the expression.
+  SimpleAffineExprFlattener flattener(numDims, numSymbols);
+  flattener.walkPostOrder(expr);
+  ArrayRef<int64_t> flattenedExpr = flattener.operandExprStack.back();
+  // TODO: Handle local variables. We can get hold of flattener.localExprs and
+  // get bound on the local expr recursively.
+  if (flattener.numLocals > 0)
+    return std::nullopt;
+  int64_t bound = 0;
+  // Substitute the constant lower or upper bound for the dimensional or
+  // symbolic input depending on `isUpper` to determine the bound.
+  for (unsigned i = 0, e = numDims + numSymbols; i < e; ++i) {
+    if (flattenedExpr[i] > 0) {
+      auto &constBound = isUpper ? constUpperBounds[i] : constLowerBounds[i];
+      if (!constBound)
+        return std::nullopt;
+      bound += *constBound * flattenedExpr[i];
+    } else if (flattenedExpr[i] < 0) {
+      auto &constBound = isUpper ? constLowerBounds[i] : constUpperBounds[i];
+      if (!constBound)
+        return std::nullopt;
+      bound += *constBound * flattenedExpr[i];
+    }
+  }
+  // Constant term.
+  bound += flattenedExpr.back();
+  return bound;
+}
+
+/// Determine a constant upper bound for `expr` if one exists while exploiting
+/// values in `operands`. Note that the upper bound is an inclusive one. `expr`
+/// is guaranteed to be less than or equal to it.
+static Optional<int64_t> getUpperBound(AffineExpr expr, unsigned numDims,
+                                       unsigned numSymbols,
+                                       ArrayRef<Value> operands) {
+  // Gather the constant lower and upper bound (if any) of every operand.
+  SmallVector<Optional<int64_t>> constLowerBounds, constUpperBounds;
+  constLowerBounds.reserve(operands.size());
+  constUpperBounds.reserve(operands.size());
+  for (Value operand : operands) {
+    constLowerBounds.push_back(getLowerBound(operand));
+    constUpperBounds.push_back(getUpperBound(operand));
+  }
+
+  if (auto constExpr = expr.dyn_cast<AffineConstantExpr>())
+    return constExpr.getValue();
+
+  return getBoundForExpr(expr, numDims, numSymbols, constLowerBounds,
+                         constUpperBounds,
+                         /*isUpper=*/true);
+}
+
+/// Determine a constant lower bound for `expr` if one exists while exploiting
+/// values in `operands`. Note that the lower bound is an inclusive one. `expr`
+/// is guaranteed to be greater than or equal to it.
+static Optional<int64_t> getLowerBound(AffineExpr expr, unsigned numDims,
+                                       unsigned numSymbols,
+                                       ArrayRef<Value> operands) {
+  // Gather the constant lower and upper bound (if any) of every operand.
+  SmallVector<Optional<int64_t>> constLowerBounds, constUpperBounds;
+  constLowerBounds.reserve(operands.size());
+  constUpperBounds.reserve(operands.size());
+  for (Value operand : operands) {
+    constLowerBounds.push_back(getLowerBound(operand));
+    constUpperBounds.push_back(getUpperBound(operand));
+  }
+
+  Optional<int64_t> lowerBound;
+  if (auto constExpr = expr.dyn_cast<AffineConstantExpr>()) {
+    lowerBound = constExpr.getValue();
+  } else {
+    lowerBound = getBoundForExpr(expr, numDims, numSymbols, constLowerBounds,
+                                 constUpperBounds,
+                                 /*isUpper=*/false);
+  }
+  return lowerBound;
+}
+
 /// Simplify `expr` while exploiting information from the values in `operands`.
-static void simplifyExprAndOperands(AffineExpr &expr,
+static void simplifyExprAndOperands(AffineExpr &expr, unsigned numDims,
+                                    unsigned numSymbols,
                                     ArrayRef<Value> operands) {
   // We do this only for certain floordiv/mod expressions.
   auto binExpr = expr.dyn_cast<AffineBinaryOpExpr>();
@@ -684,13 +845,14 @@ static void simplifyExprAndOperands(AffineExpr &expr,
   // Simplify the child expressions first.
   AffineExpr lhs = binExpr.getLHS();
   AffineExpr rhs = binExpr.getRHS();
-  simplifyExprAndOperands(lhs, operands);
-  simplifyExprAndOperands(rhs, operands);
+  simplifyExprAndOperands(lhs, numDims, numSymbols, operands);
+  simplifyExprAndOperands(rhs, numDims, numSymbols, operands);
   expr = getAffineBinaryOpExpr(binExpr.getKind(), lhs, rhs);
 
   binExpr = expr.dyn_cast<AffineBinaryOpExpr>();
-  if (!binExpr || (binExpr.getKind() != AffineExprKind::FloorDiv &&
-                   binExpr.getKind() != AffineExprKind::Mod)) {
+  if (!binExpr || (expr.getKind() != AffineExprKind::FloorDiv &&
+                   expr.getKind() != AffineExprKind::CeilDiv &&
+                   expr.getKind() != AffineExprKind::Mod)) {
     return;
   }
 
@@ -703,16 +865,50 @@ static void simplifyExprAndOperands(AffineExpr &expr,
 
   int64_t rhsConstVal = rhsConst.getValue();
   // Undefined exprsessions aren't touched; IR can still be valid with them.
-  if (rhsConstVal == 0)
+  if (rhsConstVal <= 0)
     return;
 
-  AffineExpr quotientTimesDiv, rem;
-  int64_t divisor;
+  // Exploit constant lower/upper bounds to simplify a floordiv, ceildiv or mod.
+  MLIRContext *context = expr.getContext();
+  std::optional<int64_t> lhsLbConst =
+      getLowerBound(lhs, numDims, numSymbols, operands);
+  std::optional<int64_t> lhsUbConst =
+      getUpperBound(lhs, numDims, numSymbols, operands);
+  if (lhsLbConst && lhsUbConst) {
+    int64_t lhsLbConstVal = *lhsLbConst;
+    int64_t lhsUbConstVal = *lhsUbConst;
+    // lhs floordiv c is a single value if lhs is bounded in a range that has
+    // the same quotient w.r.t. `c`.
+    if (binExpr.getKind() == AffineExprKind::FloorDiv &&
+        floorDiv(lhsLbConstVal, rhsConstVal) ==
+            floorDiv(lhsUbConstVal, rhsConstVal)) {
+      expr =
+          getAffineConstantExpr(floorDiv(lhsLbConstVal, rhsConstVal), context);
+      return;
+    }
+    // lhs ceildiv c is a single value if the entire range has the same ceil
+    // quotient.
+    if (binExpr.getKind() == AffineExprKind::CeilDiv &&
+        ceilDiv(lhsLbConstVal, rhsConstVal) ==
+            ceilDiv(lhsUbConstVal, rhsConstVal)) {
+      expr =
+          getAffineConstantExpr(ceilDiv(lhsLbConstVal, rhsConstVal), context);
+      return;
+    }
+    // lhs mod c is lhs if the entire range has quotient 0 w.r.t the rhs.
+    if (binExpr.getKind() == AffineExprKind::Mod && lhsLbConstVal >= 0 &&
+        lhsLbConstVal < rhsConstVal && lhsUbConstVal < rhsConstVal) {
+      expr = lhs;
+      return;
+    }
+  }
 
   // Simplify expressions of the form e = (e_1 + e_2) floordiv c or (e_1 + e_2)
   // mod c, where e_1 is a multiple of `k` and 0 <= e_2 < k. In such cases, if
   // `c` % `k` == 0, (e_1 + e_2) floordiv c can be simplified to e_1 floordiv c.
   // And when k % c == 0, (e_1 + e_2) mod c can be simplified to e_2 mod c.
+  AffineExpr quotientTimesDiv, rem;
+  int64_t divisor;
   if (isQTimesDPlusR(lhs, operands, divisor, quotientTimesDiv, rem)) {
     if (rhsConstVal % divisor == 0 &&
         binExpr.getKind() == AffineExprKind::FloorDiv) {
@@ -745,7 +941,8 @@ simplifyMapWithOperands(AffineMap &map, ArrayRef<Value> operands) {
   SmallVector<AffineExpr> newResults;
   newResults.reserve(map.getNumResults());
   for (AffineExpr expr : map.getResults()) {
-    simplifyExprAndOperands(expr, operands);
+    simplifyExprAndOperands(expr, map.getNumDims(), map.getNumSymbols(),
+                            operands);
     newResults.push_back(expr);
   }
   map = AffineMap::get(map.getNumDims(), map.getNumSymbols(), newResults,

diff  --git a/mlir/test/Dialect/Affine/canonicalize.mlir b/mlir/test/Dialect/Affine/canonicalize.mlir
index ddebab4e5f4b5..d6ca5033d6fc5 100644
--- a/mlir/test/Dialect/Affine/canonicalize.mlir
+++ b/mlir/test/Dialect/Affine/canonicalize.mlir
@@ -1170,8 +1170,8 @@ func.func @simplify_with_operands(%N: index, %A: memref<?x32xf32>) {
       "test.foo"(%x) : (f32) -> ()
 
       // %i is aligned at 32 boundary and %ii < 32.
-      // CHECK: affine.load %{{.*}}[%[[I]] floordiv 32, %[[II]] mod 32]
-      %a = affine.load %A[(%i + %ii) floordiv 32, (%i + %ii) mod 32] : memref<?x32xf32>
+      // CHECK: affine.load %{{.*}}[%[[I]] floordiv 32, %[[II]] mod 16]
+      %a = affine.load %A[(%i + %ii) floordiv 32, (%i + %ii) mod 16] : memref<?x32xf32>
       "test.foo"(%a) : (f32) -> ()
       // CHECK: affine.load %{{.*}}[%[[I]] floordiv 64, (%[[I]] + %[[II]]) mod 64]
       %b = affine.load %A[(%i + %ii) floordiv 64, (%i + %ii) mod 64] : memref<?x32xf32>
@@ -1202,6 +1202,66 @@ func.func @simplify_with_operands(%N: index, %A: memref<?x32xf32>) {
   return
 }
 
+// CHECK-LABEL: func @simplify_div_mod_with_operands
+func.func @simplify_div_mod_with_operands(%N: index, %A: memref<64xf32>, %unknown: index) {
+  // CHECK: affine.for %[[I:.*]] = 0 to 32
+  %cst = arith.constant 1.0 : f32
+  affine.for %i = 0 to 32 {
+    // CHECK: affine.store %{{.*}}, %{{.*}}[0]
+    affine.store %cst, %A[%i floordiv 32] : memref<64xf32>
+    // CHECK: affine.store %{{.*}}, %{{.*}}[1]
+    affine.store %cst, %A[(%i + 1) ceildiv 32] : memref<64xf32>
+    // CHECK: affine.store %{{.*}}, %{{.*}}[%[[I]]]
+    affine.store %cst, %A[%i mod 32] : memref<64xf32>
+    // CHECK: affine.store %{{.*}}, %{{.*}}[0]
+    affine.store %cst, %A[2 * %i floordiv 64] : memref<64xf32>
+    // CHECK: affine.store %{{.*}}, %{{.*}}[0]
+    affine.store %cst, %A[(%i mod 16) floordiv 16] : memref<64xf32>
+
+    // The ones below can't be simplified.
+    affine.store %cst, %A[%i floordiv 16] : memref<64xf32>
+    affine.store %cst, %A[%i mod 16] : memref<64xf32>
+    affine.store %cst, %A[(%i mod 16) floordiv 15] : memref<64xf32>
+    affine.store %cst, %A[%i mod 31] : memref<64xf32>
+    // CHECK:      affine.store %{{.*}}, %{{.*}}[%{{.*}} floordiv 16] : memref<64xf32>
+    // CHECK-NEXT: affine.store %{{.*}}, %{{.*}}[%{{.*}} mod 16] : memref<64xf32>
+    // CHECK-NEXT: affine.store %{{.*}}, %{{.*}}[(%{{.*}} mod 16) floordiv 15] : memref<64xf32>
+    // CHECK-NEXT: affine.store %{{.*}}, %{{.*}}[%{{.*}} mod 31] : memref<64xf32>
+  }
+
+  affine.for %i = -8 to 32 {
+    // Can't be simplified.
+    // CHECK: affine.store %{{.*}}, %{{.*}}[%{{.*}} floordiv 32] : memref<64xf32>
+    affine.store %cst, %A[%i floordiv 32] : memref<64xf32>
+    // CHECK: affine.store %{{.*}}, %{{.*}}[%{{.*}} mod 32] : memref<64xf32>
+    affine.store %cst, %A[%i mod 32] : memref<64xf32>
+    // floordiv rounds toward -inf; (%i - 96) floordiv 64 will be -2.
+    // CHECK: affine.store %{{.*}}, %{{.*}}[0] : memref<64xf32>
+    affine.store %cst, %A[2 + (%i - 96) floordiv 64] : memref<64xf32>
+  }
+
+  // CHECK: affine.for %[[II:.*]] = 8 to 16
+  affine.for %i = 8 to 16 {
+    // CHECK: affine.store %{{.*}}, %{{.*}}[1] : memref<64xf32>
+    affine.store %cst, %A[%i floordiv 8] : memref<64xf32>
+    // CHECK: affine.store %{{.*}}, %{{.*}}[2] : memref<64xf32>
+    affine.store %cst, %A[(%i + 1) ceildiv 8] : memref<64xf32>
+    // CHECK: affine.store %{{.*}}, %{{.*}}[%[[II]] mod 8] : memref<64xf32>
+    affine.store %cst, %A[%i mod 8] : memref<64xf32>
+    // CHECK: affine.store %{{.*}}, %{{.*}}[%[[II]]] : memref<64xf32>
+    affine.store %cst, %A[%i mod 32] : memref<64xf32>
+    // Upper bound on the mod 32 expression will be 15.
+    // CHECK: affine.store %{{.*}}, %{{.*}}[0] : memref<64xf32>
+    affine.store %cst, %A[(%i mod 32) floordiv 16] : memref<64xf32>
+    // Lower bound on the mod 16 expression will be 8.
+    // CHECK: affine.store %{{.*}}, %{{.*}}[1] : memref<64xf32>
+    affine.store %cst, %A[(%i mod 16) floordiv 8] : memref<64xf32>
+    // CHECK: affine.store %{{.*}}, %{{.*}}[0] : memref<64xf32>
+    affine.store %cst, %A[(%unknown mod 16) floordiv 16] : memref<64xf32>
+  }
+  return
+}
+
 // -----
 
 //           CHECK: #[[$map:.*]] = affine_map<()[s0] -> (s0 * ((-s0 + 40961) ceildiv 512))>


        


More information about the Mlir-commits mailing list