[Mlir-commits] [mlir] 1a0a177 - [MLIR] Create fold for cmp of ext

William S. Moses llvmlistbot at llvm.org
Sun Jan 2 16:48:56 PST 2022


Author: William S. Moses
Date: 2022-01-02T19:48:52-05:00
New Revision: 1a0a177965e88d61b5d3cd3e7f7f89011f0827c1

URL: https://github.com/llvm/llvm-project/commit/1a0a177965e88d61b5d3cd3e7f7f89011f0827c1
DIFF: https://github.com/llvm/llvm-project/commit/1a0a177965e88d61b5d3cd3e7f7f89011f0827c1.diff

LOG: [MLIR] Create fold for cmp of ext

This patch adds folds for cmpi ne (ext(%x : i1 -> iN), 0) -> %x, where ext is either extsi or extui.

In essence, this matches an extension of a boolean (i1) value that is compared != 0; the result of that comparison is equivalent to the original boolean.

Reviewed By: mehdi_amini

Differential Revision: https://reviews.llvm.org/D116504

Added: 
    

Modified: 
    mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
    mlir/test/Dialect/Arithmetic/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
index 91cbf4bdb5280..2a6b463bd4e85 100644
--- a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
+++ b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
@@ -1150,6 +1150,25 @@ OpFoldResult arith::CmpIOp::fold(ArrayRef<Attribute> operands) {
     return getBoolAttribute(getType(), getContext(), val);
   }
 
+  if (matchPattern(getRhs(), m_Zero())) {
+    if (auto extOp = getLhs().getDefiningOp<ExtSIOp>()) {
+      if (extOp.getOperand().getType().cast<IntegerType>().getWidth() == 1) {
+        // extsi(%x : i1 -> iN) != 0  ->  %x
+        if (getPredicate() == arith::CmpIPredicate::ne) {
+          return extOp.getOperand();
+        }
+      }
+    }
+    if (auto extOp = getLhs().getDefiningOp<ExtUIOp>()) {
+      if (extOp.getOperand().getType().cast<IntegerType>().getWidth() == 1) {
+        // extui(%x : i1 -> iN) != 0  ->  %x
+        if (getPredicate() == arith::CmpIPredicate::ne) {
+          return extOp.getOperand();
+        }
+      }
+    }
+  }
+
   auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>();
   auto rhs = operands.back().dyn_cast_or_null<IntegerAttr>();
   if (!lhs || !rhs)

diff  --git a/mlir/test/Dialect/Arithmetic/canonicalize.mlir b/mlir/test/Dialect/Arithmetic/canonicalize.mlir
index 834842c0f3515..b4a5cf43ba82c 100644
--- a/mlir/test/Dialect/Arithmetic/canonicalize.mlir
+++ b/mlir/test/Dialect/Arithmetic/canonicalize.mlir
@@ -50,6 +50,26 @@ func @cmpi_equal_vector_operands(%arg0: vector<1x8xi64>)
 
 // -----
 
+// CHECK-LABEL: @cmpOfExtSI
+//  CHECK-NEXT:   return %arg0
+func @cmpOfExtSI(%arg0: i1) -> i1 {
+  %ext = arith.extsi %arg0 : i1 to i64
+  %c0 = arith.constant 0 : i64
+  %res = arith.cmpi ne, %ext, %c0 : i64
+  return %res : i1
+}
+
+// CHECK-LABEL: @cmpOfExtUI
+//  CHECK-NEXT:   return %arg0
+func @cmpOfExtUI(%arg0: i1) -> i1 {
+  %ext = arith.extui %arg0 : i1 to i64
+  %c0 = arith.constant 0 : i64
+  %res = arith.cmpi ne, %ext, %c0 : i64
+  return %res : i1
+}
+
+// -----
+
 // CHECK-LABEL: @indexCastOfSignExtend
 //       CHECK:   %[[res:.+]] = arith.index_cast %arg0 : i8 to index
 //       CHECK:   return %[[res]]


        


More information about the Mlir-commits mailing list