[Mlir-commits] [mlir] 97567bd - [MLIR][SCF] Canonicalize while statement whose cmp condition is recomputed in the after region
    William S. Moses 
    llvmlistbot at llvm.org
       
    Tue Jan 11 15:34:07 PST 2022
    
    
  
Author: William S. Moses
Date: 2022-01-11T18:34:04-05:00
New Revision: 97567bde5baaf87c0ac71cfc114dad56442b85d3
URL: https://github.com/llvm/llvm-project/commit/97567bde5baaf87c0ac71cfc114dad56442b85d3
DIFF: https://github.com/llvm/llvm-project/commit/97567bde5baaf87c0ac71cfc114dad56442b85d3.diff
LOG: [MLIR][SCF] Canonicalize while statement whose cmp condition is recomputed in the after region
Given a while loop whose condition is computed by a cmpi, don't recompute the comparison (or its inverse) in the after region; instead, use a constant (true, or false for the inverse), since the original condition must be true if we branched to the after region.
Reviewed By: mehdi_amini
Differential Revision: https://reviews.llvm.org/D117047
Added: 
    
Modified: 
    mlir/include/mlir/Dialect/Arithmetic/IR/Arithmetic.h
    mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
    mlir/lib/Dialect/SCF/SCF.cpp
    mlir/test/Dialect/SCF/canonicalize.mlir
Removed: 
    
################################################################################
diff  --git a/mlir/include/mlir/Dialect/Arithmetic/IR/Arithmetic.h b/mlir/include/mlir/Dialect/Arithmetic/IR/Arithmetic.h
index 31d6239388454..65d819f18df28 100644
--- a/mlir/include/mlir/Dialect/Arithmetic/IR/Arithmetic.h
+++ b/mlir/include/mlir/Dialect/Arithmetic/IR/Arithmetic.h
@@ -121,6 +121,8 @@ Value getIdentityValue(AtomicRMWKind op, Type resultType, OpBuilder &builder,
 /// associated with a binary AtomicRMWKind op to `lhs` and `rhs`.
 Value getReductionOp(AtomicRMWKind op, OpBuilder &builder, Location loc,
                      Value lhs, Value rhs);
+
+arith::CmpIPredicate invertPredicate(arith::CmpIPredicate pred);
 } // namespace arith
 } // namespace mlir
 
diff  --git a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
index 69acc19fb9e4f..188d6e56543a4 100644
--- a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
+++ b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
@@ -40,7 +40,7 @@ static IntegerAttr subIntegerAttrs(PatternRewriter &builder, Value res,
 }
 
 /// Invert an integer comparison predicate.
-static arith::CmpIPredicate invertPredicate(arith::CmpIPredicate pred) {
+arith::CmpIPredicate arith::invertPredicate(arith::CmpIPredicate pred) {
   switch (pred) {
   case arith::CmpIPredicate::eq:
     return arith::CmpIPredicate::ne;
diff  --git a/mlir/lib/Dialect/SCF/SCF.cpp b/mlir/lib/Dialect/SCF/SCF.cpp
index 3d6d2052f7fe1..6b9fe80bbccdd 100644
--- a/mlir/lib/Dialect/SCF/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/SCF.cpp
@@ -2443,11 +2443,76 @@ struct WhileUnusedResult : public OpRewritePattern<WhileOp> {
     return success();
   }
 };
+
+/// Replace arith.cmpi ops in the do block that recompute the condition with
+/// true (false for the inverse predicate): the block runs only if it held.
+///
+/// scf.while (..) : (i32, ...) -> ... {
+///  %z = ... : i32
+///  %condition = arith.cmpi pred %z, %a
+///  scf.condition(%condition) %z : i32, ...
+/// } do {
+/// ^bb0(%arg0: i32, ...):
+///    %condition2 = arith.cmpi pred %arg0, %a
+///    %inverse2 = arith.cmpi !pred %arg0, %a
+///    use(%condition2, %inverse2)
+///
+/// becomes, with %true/%false i1 constants and !pred = invertPredicate(pred):
+/// scf.while (..) : (i32, ...) -> ... {
+///  %z = ... : i32
+///  %condition = arith.cmpi pred %z, %a
+///  scf.condition(%condition) %z : i32, ...
+/// } do {
+/// ^bb0(%arg0: i32, ...):
+///    use(%true, %false)
+///    ...
+struct WhileCmpCond : public OpRewritePattern<scf::WhileOp> {
+  using OpRewritePattern<scf::WhileOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(scf::WhileOp op,
+                                PatternRewriter &rewriter) const override {
+    using namespace scf;
+    auto cond = op.getConditionOp();
+    auto cmp = cond.getCondition().getDefiningOp<arith::CmpIOp>();
+    if (!cmp)
+      return failure(); // Condition is not computed by arith.cmpi.
+    bool changed = false;
+    for (auto tup :
+         llvm::zip(cond.getArgs(), op.getAfter().front().getArguments())) {
+      for (size_t opIdx = 0; opIdx < 2; opIdx++) { // Each side of the cmp.
+        if (std::get<0>(tup) != cmp.getOperand(opIdx))
+          continue;
+        for (OpOperand &u :
+             llvm::make_early_inc_range(std::get<1>(tup).getUses())) {
+          auto cmp2 = dyn_cast<arith::CmpIOp>(u.getOwner());
+          if (!cmp2)
+            continue;
+          // 1 - opIdx is the other operand; it must be the same SSA value.
+          if (cmp2.getOperand(1 - opIdx) != cmp.getOperand(1 - opIdx))
+            continue;
+          bool samePredicate; // false => inverted predicate.
+          if (cmp2.getPredicate() == cmp.getPredicate())
+            samePredicate = true;
+          else if (cmp2.getPredicate() ==
+                   arith::invertPredicate(cmp.getPredicate()))
+            samePredicate = false;
+          else
+            continue;
+
+          rewriter.replaceOpWithNewOp<arith::ConstantIntOp>(cmp2, samePredicate,
+                                                            1); // Width 1: i1.
+          changed = true;
+        }
+      }
+    }
+    return success(changed); // failure() if nothing matched.
+  }
+};
 } // namespace
 
 void WhileOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                           MLIRContext *context) {
-  results.insert<WhileConditionTruth, WhileUnusedResult>(context);
+  results.insert<WhileConditionTruth, WhileUnusedResult, WhileCmpCond>(context);
 }
 
 //===----------------------------------------------------------------------===//
diff  --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir
index 7f424d892b764..2b3a31e833a3d 100644
--- a/mlir/test/Dialect/SCF/canonicalize.mlir
+++ b/mlir/test/Dialect/SCF/canonicalize.mlir
@@ -872,6 +872,60 @@ func @while_unused_result() -> i32 {
 // CHECK-NEXT:         }
 // CHECK-NEXT:         return %[[res]] : i32
 
+// CHECK-LABEL: @while_cmp_lhs
+func @while_cmp_lhs(%arg0 : i32) {
+  %0 = scf.while () : () -> i32 {
+    %val = "test.val"() : () -> i32
+    %condition = arith.cmpi ne, %val, %arg0 : i32
+    scf.condition(%condition) %val : i32
+  } do {
+  ^bb0(%val2: i32):
+    %condition2 = arith.cmpi ne, %val2, %arg0 : i32
+    %negcondition2 = arith.cmpi eq, %val2, %arg0 : i32
+    "test.use"(%condition2, %negcondition2, %val2) : (i1, i1, i32) -> ()
+    scf.yield
+  }
+  return
+}
+// CHECK-DAG:         %[[true:.+]] = arith.constant true
+// CHECK-DAG:         %[[false:.+]] = arith.constant false
+// CHECK-DAG:         %{{.+}} = scf.while : () -> i32 {
+// CHECK-NEXT:         %[[val:.+]] = "test.val"
+// CHECK-NEXT:         %[[cmp:.+]] = arith.cmpi ne, %[[val]], %arg0 : i32
+// CHECK-NEXT:           scf.condition(%[[cmp]]) %[[val]] : i32
+// CHECK-NEXT:         } do {
+// CHECK-NEXT:         ^bb0(%arg1: i32):  // no predecessors
+// CHECK-NEXT:           "test.use"(%[[true]], %[[false]], %arg1) : (i1, i1, i32) -> ()
+// CHECK-NEXT:           scf.yield
+// CHECK-NEXT:         }
+
+// CHECK-LABEL: @while_cmp_rhs
+func @while_cmp_rhs(%arg0 : i32) {
+  %0 = scf.while () : () -> i32 {
+    %val = "test.val"() : () -> i32
+    %condition = arith.cmpi ne, %arg0, %val : i32
+    scf.condition(%condition) %val : i32
+  } do {
+  ^bb0(%val2: i32):
+    %condition2 = arith.cmpi ne, %arg0, %val2 : i32
+    %negcondition2 = arith.cmpi eq, %arg0, %val2 : i32
+    "test.use"(%condition2, %negcondition2, %val2) : (i1, i1, i32) -> ()
+    scf.yield
+  }
+  return
+}
+// CHECK-DAG:         %[[true:.+]] = arith.constant true
+// CHECK-DAG:         %[[false:.+]] = arith.constant false
+// CHECK-DAG:         %{{.+}} = scf.while : () -> i32 {
+// CHECK-NEXT:         %[[val:.+]] = "test.val"
+// CHECK-NEXT:         %[[cmp:.+]] = arith.cmpi ne, %arg0, %[[val]] : i32
+// CHECK-NEXT:           scf.condition(%[[cmp]]) %[[val]] : i32
+// CHECK-NEXT:         } do {
+// CHECK-NEXT:         ^bb0(%arg1: i32):  // no predecessors
+// CHECK-NEXT:           "test.use"(%[[true]], %[[false]], %arg1) : (i1, i1, i32) -> ()
+// CHECK-NEXT:           scf.yield
+// CHECK-NEXT:         }
+
 // -----
 
 // CHECK-LABEL: @combineIfs
        
    
    
More information about the Mlir-commits
mailing list