[Mlir-commits] [mlir] 1a0a177 - [MLIR] Create fold for cmp of ext
William S. Moses
llvmlistbot at llvm.org
Sun Jan 2 16:48:56 PST 2022
Author: William S. Moses
Date: 2022-01-02T19:48:52-05:00
New Revision: 1a0a177965e88d61b5d3cd3e7f7f89011f0827c1
URL: https://github.com/llvm/llvm-project/commit/1a0a177965e88d61b5d3cd3e7f7f89011f0827c1
DIFF: https://github.com/llvm/llvm-project/commit/1a0a177965e88d61b5d3cd3e7f7f89011f0827c1.diff
LOG: [MLIR] Create fold for cmp of ext
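This patch adds folds of the form cmpi ne (ext(%x : i1 -> iN), 0) -> %x, where ext is either extsi or extui.
In essence, it matches an extension of a boolean that is then compared `!= 0`; the result of that comparison is equivalent to the original boolean condition, so the cmpi folds to %x.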
Reviewed By: mehdi_amini
Differential Revision: https://reviews.llvm.org/D116504
Added:
Modified:
mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
mlir/test/Dialect/Arithmetic/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
index 91cbf4bdb5280..2a6b463bd4e85 100644
--- a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
+++ b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
@@ -1150,6 +1150,25 @@ OpFoldResult arith::CmpIOp::fold(ArrayRef<Attribute> operands) {
return getBoolAttribute(getType(), getContext(), val);
}
+ if (matchPattern(getRhs(), m_Zero())) {
+ if (auto extOp = getLhs().getDefiningOp<ExtSIOp>()) {
+ if (extOp.getOperand().getType().cast<IntegerType>().getWidth() == 1) {
+ // extsi(%x : i1 -> iN) != 0 -> %x
+ if (getPredicate() == arith::CmpIPredicate::ne) {
+ return extOp.getOperand();
+ }
+ }
+ }
+ if (auto extOp = getLhs().getDefiningOp<ExtUIOp>()) {
+ if (extOp.getOperand().getType().cast<IntegerType>().getWidth() == 1) {
+ // extui(%x : i1 -> iN) != 0 -> %x
+ if (getPredicate() == arith::CmpIPredicate::ne) {
+ return extOp.getOperand();
+ }
+ }
+ }
+ }
+
auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>();
auto rhs = operands.back().dyn_cast_or_null<IntegerAttr>();
if (!lhs || !rhs)
diff --git a/mlir/test/Dialect/Arithmetic/canonicalize.mlir b/mlir/test/Dialect/Arithmetic/canonicalize.mlir
index 834842c0f3515..b4a5cf43ba82c 100644
--- a/mlir/test/Dialect/Arithmetic/canonicalize.mlir
+++ b/mlir/test/Dialect/Arithmetic/canonicalize.mlir
@@ -50,6 +50,26 @@ func @cmpi_equal_vector_operands(%arg0: vector<1x8xi64>)
// -----
+// CHECK-LABEL: @cmpOfExtSI
+// CHECK-NEXT: return %arg0
+func @cmpOfExtSI(%arg0: i1) -> i1 {
+ %ext = arith.extsi %arg0 : i1 to i64
+ %c0 = arith.constant 0 : i64
+ %res = arith.cmpi ne, %ext, %c0 : i64
+ return %res : i1
+}
+
+// CHECK-LABEL: @cmpOfExtUI
+// CHECK-NEXT: return %arg0
+func @cmpOfExtUI(%arg0: i1) -> i1 {
+ %ext = arith.extui %arg0 : i1 to i64
+ %c0 = arith.constant 0 : i64
+ %res = arith.cmpi ne, %ext, %c0 : i64
+ return %res : i1
+}
+
+// -----
+
// CHECK-LABEL: @indexCastOfSignExtend
// CHECK: %[[res:.+]] = arith.index_cast %arg0 : i8 to index
// CHECK: return %[[res]]
More information about the Mlir-commits
mailing list