[Mlir-commits] [mlir] 10c9ecc - [mlir][NFC] Replace some nested if with logical and.

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sun May 22 19:24:02 PDT 2022


Author: jacquesguan
Date: 2022-05-23T02:04:48Z
New Revision: 10c9ecce9f6096e18222a331c5e7d085bd813f75

URL: https://github.com/llvm/llvm-project/commit/10c9ecce9f6096e18222a331c5e7d085bd813f75
DIFF: https://github.com/llvm/llvm-project/commit/10c9ecce9f6096e18222a331c5e7d085bd813f75.diff

LOG: [mlir][NFC] Replace some nested if with logical and.

This patch replaces some nested if statements with single conditions joined by logical AND (`&&`) to reduce the nesting depth.

Differential Revision: https://reviews.llvm.org/D126050

Added: 
    

Modified: 
    mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
index 7db71a636c9d7..04a41acc041bc 100644
--- a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
+++ b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
@@ -1296,20 +1296,16 @@ OpFoldResult arith::CmpIOp::fold(ArrayRef<Attribute> operands) {
 
   if (matchPattern(getRhs(), m_Zero())) {
     if (auto extOp = getLhs().getDefiningOp<ExtSIOp>()) {
-      if (extOp.getOperand().getType().cast<IntegerType>().getWidth() == 1) {
-        // extsi(%x : i1 -> iN) != 0  ->  %x
-        if (getPredicate() == arith::CmpIPredicate::ne) {
-          return extOp.getOperand();
-        }
-      }
+      // extsi(%x : i1 -> iN) != 0  ->  %x
+      if (extOp.getOperand().getType().cast<IntegerType>().getWidth() == 1 &&
+          getPredicate() == arith::CmpIPredicate::ne)
+        return extOp.getOperand();
     }
     if (auto extOp = getLhs().getDefiningOp<ExtUIOp>()) {
-      if (extOp.getOperand().getType().cast<IntegerType>().getWidth() == 1) {
-        // extui(%x : i1 -> iN) != 0  ->  %x
-        if (getPredicate() == arith::CmpIPredicate::ne) {
-          return extOp.getOperand();
-        }
-      }
+      // extui(%x : i1 -> iN) != 0  ->  %x
+      if (extOp.getOperand().getType().cast<IntegerType>().getWidth() == 1 &&
+          getPredicate() == arith::CmpIPredicate::ne)
+        return extOp.getOperand();
     }
   }
 
@@ -1733,24 +1729,24 @@ struct SelectToExtUI : public OpRewritePattern<arith::SelectOp> {
       return failure();
 
     // select %x, c1, %c0 => extui %arg
-    if (matchPattern(op.getTrueValue(), m_One()))
-      if (matchPattern(op.getFalseValue(), m_Zero())) {
-        rewriter.replaceOpWithNewOp<arith::ExtUIOp>(op, op.getType(),
-                                                    op.getCondition());
-        return success();
-      }
+    if (matchPattern(op.getTrueValue(), m_One()) &&
+        matchPattern(op.getFalseValue(), m_Zero())) {
+      rewriter.replaceOpWithNewOp<arith::ExtUIOp>(op, op.getType(),
+                                                  op.getCondition());
+      return success();
+    }
 
     // select %x, c0, %c1 => extui (xor %arg, true)
-    if (matchPattern(op.getTrueValue(), m_Zero()))
-      if (matchPattern(op.getFalseValue(), m_One())) {
-        rewriter.replaceOpWithNewOp<arith::ExtUIOp>(
-            op, op.getType(),
-            rewriter.create<arith::XOrIOp>(
-                op.getLoc(), op.getCondition(),
-                rewriter.create<arith::ConstantIntOp>(
-                    op.getLoc(), 1, op.getCondition().getType())));
-        return success();
-      }
+    if (matchPattern(op.getTrueValue(), m_Zero()) &&
+        matchPattern(op.getFalseValue(), m_One())) {
+      rewriter.replaceOpWithNewOp<arith::ExtUIOp>(
+          op, op.getType(),
+          rewriter.create<arith::XOrIOp>(
+              op.getLoc(), op.getCondition(),
+              rewriter.create<arith::ConstantIntOp>(
+                  op.getLoc(), 1, op.getCondition().getType())));
+      return success();
+    }
 
     return failure();
   }
@@ -1778,10 +1774,9 @@ OpFoldResult arith::SelectOp::fold(ArrayRef<Attribute> operands) {
     return falseVal;
 
   // select %x, true, false => %x
-  if (getType().isInteger(1))
-    if (matchPattern(getTrueValue(), m_One()))
-      if (matchPattern(getFalseValue(), m_Zero()))
-        return condition;
+  if (getType().isInteger(1) && matchPattern(getTrueValue(), m_One()) &&
+      matchPattern(getFalseValue(), m_Zero()))
+    return condition;
 
   if (auto cmp = dyn_cast_or_null<arith::CmpIOp>(condition.getDefiningOp())) {
     auto pred = cmp.getPredicate();


        


More information about the Mlir-commits mailing list