[Mlir-commits] [mlir] abc17a6 - [mlir][Arithmetic] Use matchPattern to simplify code.

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Apr 22 01:43:02 PDT 2022


Author: jacquesguan
Date: 2022-04-22T08:42:51Z
New Revision: abc17a67519747be36f1fd03e227c5103da4c677

URL: https://github.com/llvm/llvm-project/commit/abc17a67519747be36f1fd03e227c5103da4c677
DIFF: https://github.com/llvm/llvm-project/commit/abc17a67519747be36f1fd03e227c5103da4c677.diff

LOG: [mlir][Arithmetic] Use matchPattern to simplify code.

This patch replaces some code with matchPattern and moves it before the constant folder call in order to avoid a redundant invocation.

Differential Revision: https://reviews.llvm.org/D124235

Added: 
    

Modified: 
    mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
index 1fa4b1b8032a2..5a104400c48a8 100644
--- a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
+++ b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
@@ -262,6 +262,10 @@ OpFoldResult arith::MulIOp::fold(ArrayRef<Attribute> operands) {
 //===----------------------------------------------------------------------===//
 
 OpFoldResult arith::DivUIOp::fold(ArrayRef<Attribute> operands) {
+  // divui (x, 1) -> x.
+  if (matchPattern(getRhs(), m_One()))
+    return getLhs();
+
   // Don't fold if it would require a division by zero.
   bool div0 = false;
   auto result =
@@ -273,15 +277,6 @@ OpFoldResult arith::DivUIOp::fold(ArrayRef<Attribute> operands) {
         return a.udiv(b);
       });
 
-  // Fold out division by one. Assumes all tensors of all ones are splats.
-  if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>()) {
-    if (rhs.getValue() == 1)
-      return getLhs();
-  } else if (auto rhs = operands[1].dyn_cast_or_null<SplatElementsAttr>()) {
-    if (rhs.getSplatValue<IntegerAttr>().getValue() == 1)
-      return getLhs();
-  }
-
   return div0 ? Attribute() : result;
 }
 
@@ -290,6 +285,10 @@ OpFoldResult arith::DivUIOp::fold(ArrayRef<Attribute> operands) {
 //===----------------------------------------------------------------------===//
 
 OpFoldResult arith::DivSIOp::fold(ArrayRef<Attribute> operands) {
+  // divsi (x, 1) -> x.
+  if (matchPattern(getRhs(), m_One()))
+    return getLhs();
+
   // Don't fold if it would overflow or if it requires a division by zero.
   bool overflowOrDiv0 = false;
   auto result =
@@ -301,15 +300,6 @@ OpFoldResult arith::DivSIOp::fold(ArrayRef<Attribute> operands) {
         return a.sdiv_ov(b, overflowOrDiv0);
       });
 
-  // Fold out division by one. Assumes all tensors of all ones are splats.
-  if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>()) {
-    if (rhs.getValue() == 1)
-      return getLhs();
-  } else if (auto rhs = operands[1].dyn_cast_or_null<SplatElementsAttr>()) {
-    if (rhs.getSplatValue<IntegerAttr>().getValue() == 1)
-      return getLhs();
-  }
-
   return overflowOrDiv0 ? Attribute() : result;
 }
 
@@ -330,6 +320,10 @@ static APInt signedCeilNonnegInputs(const APInt &a, const APInt &b,
 //===----------------------------------------------------------------------===//
 
 OpFoldResult arith::CeilDivUIOp::fold(ArrayRef<Attribute> operands) {
+  // ceildivui (x, 1) -> x.
+  if (matchPattern(getRhs(), m_One()))
+    return getLhs();
+
   bool overflowOrDiv0 = false;
   auto result =
       constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, const APInt &b) {
@@ -343,15 +337,6 @@ OpFoldResult arith::CeilDivUIOp::fold(ArrayRef<Attribute> operands) {
         APInt one(a.getBitWidth(), 1, true);
         return quotient.uadd_ov(one, overflowOrDiv0);
       });
-  // Fold out ceil division by one. Assumes all tensors of all ones are
-  // splats.
-  if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>()) {
-    if (rhs.getValue() == 1)
-      return getLhs();
-  } else if (auto rhs = operands[1].dyn_cast_or_null<SplatElementsAttr>()) {
-    if (rhs.getSplatValue<IntegerAttr>().getValue() == 1)
-      return getLhs();
-  }
 
   return overflowOrDiv0 ? Attribute() : result;
 }
@@ -361,6 +346,10 @@ OpFoldResult arith::CeilDivUIOp::fold(ArrayRef<Attribute> operands) {
 //===----------------------------------------------------------------------===//
 
 OpFoldResult arith::CeilDivSIOp::fold(ArrayRef<Attribute> operands) {
+  // ceildivsi (x, 1) -> x.
+  if (matchPattern(getRhs(), m_One()))
+    return getLhs();
+
   // Don't fold if it would overflow or if it requires a division by zero.
   bool overflowOrDiv0 = false;
   auto result =
@@ -398,16 +387,6 @@ OpFoldResult arith::CeilDivSIOp::fold(ArrayRef<Attribute> operands) {
         return zero.ssub_ov(div, overflowOrDiv0);
       });
 
-  // Fold out ceil division by one. Assumes all tensors of all ones are
-  // splats.
-  if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>()) {
-    if (rhs.getValue() == 1)
-      return getLhs();
-  } else if (auto rhs = operands[1].dyn_cast_or_null<SplatElementsAttr>()) {
-    if (rhs.getSplatValue<IntegerAttr>().getValue() == 1)
-      return getLhs();
-  }
-
   return overflowOrDiv0 ? Attribute() : result;
 }
 
@@ -416,6 +395,10 @@ OpFoldResult arith::CeilDivSIOp::fold(ArrayRef<Attribute> operands) {
 //===----------------------------------------------------------------------===//
 
 OpFoldResult arith::FloorDivSIOp::fold(ArrayRef<Attribute> operands) {
+  // floordivsi (x, 1) -> x.
+  if (matchPattern(getRhs(), m_One()))
+    return getLhs();
+
   // Don't fold if it would overflow or if it requires a division by zero.
   bool overflowOrDiv0 = false;
   auto result =
@@ -453,16 +436,6 @@ OpFoldResult arith::FloorDivSIOp::fold(ArrayRef<Attribute> operands) {
         return zero.ssub_ov(ceil, overflowOrDiv0);
       });
 
-  // Fold out floor division by one. Assumes all tensors of all ones are
-  // splats.
-  if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>()) {
-    if (rhs.getValue() == 1)
-      return getLhs();
-  } else if (auto rhs = operands[1].dyn_cast_or_null<SplatElementsAttr>()) {
-    if (rhs.getSplatValue<IntegerAttr>().getValue() == 1)
-      return getLhs();
-  }
-
   return overflowOrDiv0 ? Attribute() : result;
 }
 


        


More information about the Mlir-commits mailing list