[Mlir-commits] [mlir] 6edfb62 - [mlir] Extend AffineForEmptyLoopFolder

Amy Zhuang llvmlistbot at llvm.org
Tue Mar 8 17:37:28 PST 2022


Author: Amy Zhuang
Date: 2022-03-08T17:17:22-08:00
New Revision: 6edfb628f9cc1008d6d0dd7719483458a324daa8

URL: https://github.com/llvm/llvm-project/commit/6edfb628f9cc1008d6d0dd7719483458a324daa8
DIFF: https://github.com/llvm/llvm-project/commit/6edfb628f9cc1008d6d0dd7719483458a324daa8.diff

LOG: [mlir] Extend AffineForEmptyLoopFolder

Currently when we fold an empty loop, we assume that any loop
with iterArgs returns its iterArgs in order, which is not always
the case. It may return values defined outside of the loop or
return its iterArgs out of order. This patch adds support for
those cases.

Reviewed By: dcaballe

Differential Revision: https://reviews.llvm.org/D120776

Added: 
    

Modified: 
    mlir/lib/Dialect/Affine/IR/AffineOps.cpp
    mlir/test/Dialect/Affine/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index 90aa7baef8e31..c2952069b7745 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -1657,6 +1657,16 @@ static LogicalResult canonicalizeLoopBounds(AffineForOp forOp) {
 }
 
 namespace {
+/// Returns constant trip count in trivial cases.
+static Optional<uint64_t> getTrivialConstantTripCount(AffineForOp forOp) {
+  int64_t step = forOp.getStep();
+  if (!forOp.hasConstantBounds() || step <= 0)
+    return None;
+  int64_t lb = forOp.getConstantLowerBound();
+  int64_t ub = forOp.getConstantUpperBound();
+  return ub - lb <= 0 ? 0 : (ub - lb + step - 1) / step;
+}
+
 /// This is a pattern to fold trivially empty loop bodies.
 /// TODO: This should be moved into the folding hook.
 struct AffineForEmptyLoopFolder : public OpRewritePattern<AffineForOp> {
@@ -1667,8 +1677,46 @@ struct AffineForEmptyLoopFolder : public OpRewritePattern<AffineForOp> {
     // Check that the body only contains a yield.
     if (!llvm::hasSingleElement(*forOp.getBody()))
       return failure();
-    // The initial values of the iteration arguments would be the op's results.
-    rewriter.replaceOp(forOp, forOp.getIterOperands());
+    if (forOp.getNumResults() == 0)
+      return success();
+    Optional<uint64_t> tripCount = getTrivialConstantTripCount(forOp);
+    if (tripCount.hasValue() && tripCount.getValue() == 0) {
+      // The initial values of the iteration arguments would be the op's
+      // results.
+      rewriter.replaceOp(forOp, forOp.getIterOperands());
+      return success();
+    }
+    SmallVector<Value, 4> replacements;
+    auto yieldOp = cast<AffineYieldOp>(forOp.getBody()->getTerminator());
+    auto iterArgs = forOp.getRegionIterArgs();
+    bool hasValDefinedOutsideLoop = false;
+    bool iterArgsNotInOrder = false;
+    for (unsigned i = 0, e = yieldOp->getNumOperands(); i < e; ++i) {
+      Value val = yieldOp.getOperand(i);
+      auto iterArgIt = llvm::find(iterArgs, val);
+      if (iterArgIt == iterArgs.end()) {
+        // `val` is defined outside of the loop.
+        assert(forOp.isDefinedOutsideOfLoop(val) &&
+               "must be defined outside of the loop");
+        hasValDefinedOutsideLoop = true;
+        replacements.push_back(val);
+      } else {
+        unsigned pos = std::distance(iterArgs.begin(), iterArgIt);
+        if (pos != i)
+          iterArgsNotInOrder = true;
+        replacements.push_back(forOp.getIterOperands()[pos]);
+      }
+    }
+    // Bail out when the trip count is unknown and the loop returns any value
+    // defined outside of the loop or any iterArg out of order.
+    if (!tripCount.hasValue() &&
+        (hasValDefinedOutsideLoop || iterArgsNotInOrder))
+      return failure();
+    // Bail out when the loop iterates more than once and it returns any iterArg
+    // out of order.
+    if (tripCount.hasValue() && tripCount.getValue() >= 2 && iterArgsNotInOrder)
+      return failure();
+    rewriter.replaceOp(forOp, replacements);
     return success();
   }
 };
@@ -1681,11 +1729,10 @@ void AffineForOp::getCanonicalizationPatterns(RewritePatternSet &results,
 
 /// Returns true if the affine.for has zero iterations in trivial cases.
 static bool hasTrivialZeroTripCount(AffineForOp op) {
-  if (!op.hasConstantBounds())
-    return false;
-  int64_t lb = op.getConstantLowerBound();
-  int64_t ub = op.getConstantUpperBound();
-  return ub - lb <= 0;
+  Optional<uint64_t> tripCount = getTrivialConstantTripCount(op);
+  if (tripCount.hasValue() && tripCount.getValue() == 0)
+    return true;
+  return false;
 }
 
 LogicalResult AffineForOp::fold(ArrayRef<Attribute> operands,

diff  --git a/mlir/test/Dialect/Affine/canonicalize.mlir b/mlir/test/Dialect/Affine/canonicalize.mlir
index bed5f0c9d41d1..cd6b08a7bab11 100644
--- a/mlir/test/Dialect/Affine/canonicalize.mlir
+++ b/mlir/test/Dialect/Affine/canonicalize.mlir
@@ -475,6 +475,112 @@ func @fold_empty_loops() -> index {
 
 // -----
 
+// CHECK-LABEL:  func @fold_empty_loop()
+func @fold_empty_loop() -> (index, index) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  %res:2 = affine.for %i = 0 to 10 iter_args(%arg0 = %c0, %arg1 = %c1) -> (index, index) {
+    affine.yield %c2, %arg1 : index, index
+  }
+  // CHECK-DAG: %[[one:.*]] = arith.constant 1
+  // CHECK-DAG: %[[two:.*]] = arith.constant 2
+  // CHECK-NEXT: return %[[two]], %[[one]]
+  return %res#0, %res#1 : index, index
+}
+
+// -----
+
+// CHECK-LABEL:  func @fold_empty_loops_trip_count_1()
+func @fold_empty_loops_trip_count_1() -> (index, index, index, index) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  %res1:2 = affine.for %i = 0 to 1 iter_args(%arg0 = %c2, %arg1 = %c0) -> (index, index) {
+    affine.yield %c1, %arg0 : index, index
+  }
+  %res2:2 = affine.for %i = 0 to 2 step 3 iter_args(%arg0 = %c2, %arg1 = %c0) -> (index, index) {
+    affine.yield %arg1, %arg0 : index, index
+  }
+  // CHECK-DAG: %[[zero:.*]] = arith.constant 0
+  // CHECK-DAG: %[[one:.*]] = arith.constant 1
+  // CHECK-DAG: %[[two:.*]] = arith.constant 2
+  // CHECK-NEXT: return %[[one]], %[[two]], %[[zero]], %[[two]]
+  return %res1#0, %res1#1, %res2#0, %res2#1 : index, index, index, index
+}
+
+// -----
+
+// CHECK-LABEL:  func @fold_empty_loop_trip_count_0()
+func @fold_empty_loop_trip_count_0() -> (index, index) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  %res:2 = affine.for %i = 0 to 0 iter_args(%arg0 = %c2, %arg1 = %c0) -> (index, index) {
+    affine.yield %c1, %arg0 : index, index
+  }
+  // CHECK-DAG: %[[zero:.*]] = arith.constant 0
+  // CHECK-DAG: %[[two:.*]] = arith.constant 2
+  // CHECK-NEXT: return %[[two]], %[[zero]]
+  return %res#0, %res#1 : index, index
+}
+
+// -----
+
+// CHECK-LABEL:  func @fold_empty_loop_trip_count_unknown
+func @fold_empty_loop_trip_count_unknown(%in : index) -> (index, index) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %res:2 = affine.for %i = 0 to %in iter_args(%arg0 = %c0, %arg1 = %c1) -> (index, index) {
+    affine.yield %arg0, %arg1 : index, index
+  }
+  // CHECK-DAG: %[[zero:.*]] = arith.constant 0
+  // CHECK-DAG: %[[one:.*]] = arith.constant 1
+  // CHECK-NEXT: return %[[zero]], %[[one]]
+  return %res#0, %res#1 : index, index
+}
+
+// -----
+
+// CHECK-LABEL:  func @empty_loops_not_folded_1
+func @empty_loops_not_folded_1(%in : index) -> index {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  // CHECK: affine.for
+  %res = affine.for %i = 0 to %in iter_args(%arg = %c0) -> index {
+    affine.yield %c1 : index
+  }
+  return %res : index
+}
+
+// -----
+
+// CHECK-LABEL:  func @empty_loops_not_folded_2
+func @empty_loops_not_folded_2(%in : index) -> (index, index) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  // CHECK: affine.for
+  %res:2 = affine.for %i = 0 to %in iter_args(%arg0 = %c0, %arg1 = %c1) -> (index, index) {
+    affine.yield %arg1, %arg0 : index, index
+  }
+  return %res#0, %res#1 : index, index
+}
+
+// -----
+
+// CHECK-LABEL:  func @empty_loops_not_folded_3
+func @empty_loops_not_folded_3() -> (index, index) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  // CHECK: affine.for
+  %res:2 = affine.for %i = 0 to 10 iter_args(%arg0 = %c0, %arg1 = %c1) -> (index, index) {
+    affine.yield %arg1, %arg0 : index, index
+  }
+  return %res#0, %res#1 : index, index
+}
+
+// -----
+
 // CHECK-LABEL:  func @fold_zero_iter_loops
 // CHECK-SAME: %[[ARG:.*]]: index
 func @fold_zero_iter_loops(%in : index) -> index {


        


More information about the Mlir-commits mailing list