[Mlir-commits] [mlir] [mlir][arith] Fold `arith.cmpi eq, %val, %one : i1` -> `%val` and `arith.cmpi ne, %val, %zero : i1` -> `%val` (PR #124436)
Ivan Butygin
llvmlistbot at llvm.org
Sat Jan 25 16:34:25 PST 2025
https://github.com/Hardcode84 created https://github.com/llvm/llvm-project/pull/124436
None
>From 3e34fe8f1b6ddf292add90d73d7c2a5938fafe27 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Sun, 26 Jan 2025 01:25:13 +0100
Subject: [PATCH] [mlir][arith] Fold `arith.cmpi eq, %val, %one : i1` -> `%val`
---
mlir/lib/Dialect/Arith/IR/ArithOps.cpp | 12 ++++
mlir/test/Dialect/Arith/canonicalize.mlir | 72 +++++++++++++++++++++++
2 files changed, 84 insertions(+)
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index 7ca104691e6df6..75d59ba8c1a108 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -1865,6 +1865,18 @@ OpFoldResult arith::CmpIOp::fold(FoldAdaptor adaptor) {
getPredicate() == arith::CmpIPredicate::ne)
return extOp.getOperand();
}
+
+ // arith.cmpi ne, %val, %zero : i1 -> %val
+ if (getElementTypeOrSelf(getLhs().getType()).isInteger(1) &&
+ getPredicate() == arith::CmpIPredicate::ne)
+ return getLhs();
+ }
+
+ if (matchPattern(adaptor.getRhs(), m_One())) {
+ // arith.cmpi eq, %val, %one : i1 -> %val
+ if (getElementTypeOrSelf(getLhs().getType()).isInteger(1) &&
+ getPredicate() == arith::CmpIPredicate::eq)
+ return getLhs();
}
// Move constant to the right side.
diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir
index 522711b08f289d..3a16ee3d4f8fde 100644
--- a/mlir/test/Dialect/Arith/canonicalize.mlir
+++ b/mlir/test/Dialect/Arith/canonicalize.mlir
@@ -160,6 +160,78 @@ func.func @selNotCond(%arg0: i1, %arg1 : i32, %arg2 : i32, %arg3 : i32, %arg4 :
return %res1, %res2 : i32, i32
}
+// CHECK-LABEL: @cmpiI1eq
+// CHECK-SAME: (%[[ARG:.*]]: i1)
+// CHECK: return %[[ARG]]
+func.func @cmpiI1eq(%arg0: i1) -> i1 {
+ %one = arith.constant 1 : i1
+ %res = arith.cmpi eq, %arg0, %one : i1
+ return %res : i1
+}
+
+// CHECK-LABEL: @cmpiI1eqVec
+// CHECK-SAME: (%[[ARG:.*]]: vector<4xi1>)
+// CHECK: return %[[ARG]]
+func.func @cmpiI1eqVec(%arg0: vector<4xi1>) -> vector<4xi1> {
+ %one = arith.constant dense<1> : vector<4xi1>
+ %res = arith.cmpi eq, %arg0, %one : vector<4xi1>
+ return %res : vector<4xi1>
+}
+
+// CHECK-LABEL: @cmpiI1ne
+// CHECK-SAME: (%[[ARG:.*]]: i1)
+// CHECK: return %[[ARG]]
+func.func @cmpiI1ne(%arg0: i1) -> i1 {
+ %zero = arith.constant 0 : i1
+ %res = arith.cmpi ne, %arg0, %zero : i1
+ return %res : i1
+}
+
+// CHECK-LABEL: @cmpiI1neVec
+// CHECK-SAME: (%[[ARG:.*]]: vector<4xi1>)
+// CHECK: return %[[ARG]]
+func.func @cmpiI1neVec(%arg0: vector<4xi1>) -> vector<4xi1> {
+ %zero = arith.constant dense<0> : vector<4xi1>
+ %res = arith.cmpi ne, %arg0, %zero : vector<4xi1>
+ return %res : vector<4xi1>
+}
+
+// CHECK-LABEL: @cmpiI1eqLhs
+// CHECK-SAME: (%[[ARG:.*]]: i1)
+// CHECK: return %[[ARG]]
+func.func @cmpiI1eqLhs(%arg0: i1) -> i1 {
+ %one = arith.constant 1 : i1
+ %res = arith.cmpi eq, %one, %arg0 : i1
+ return %res : i1
+}
+
+// CHECK-LABEL: @cmpiI1eqVecLhs
+// CHECK-SAME: (%[[ARG:.*]]: vector<4xi1>)
+// CHECK: return %[[ARG]]
+func.func @cmpiI1eqVecLhs(%arg0: vector<4xi1>) -> vector<4xi1> {
+ %one = arith.constant dense<1> : vector<4xi1>
+ %res = arith.cmpi eq, %one, %arg0 : vector<4xi1>
+ return %res : vector<4xi1>
+}
+
+// CHECK-LABEL: @cmpiI1neLhs
+// CHECK-SAME: (%[[ARG:.*]]: i1)
+// CHECK: return %[[ARG]]
+func.func @cmpiI1neLhs(%arg0: i1) -> i1 {
+ %zero = arith.constant 0 : i1
+ %res = arith.cmpi ne, %zero, %arg0 : i1
+ return %res : i1
+}
+
+// CHECK-LABEL: @cmpiI1neVecLhs
+// CHECK-SAME: (%[[ARG:.*]]: vector<4xi1>)
+// CHECK: return %[[ARG]]
+func.func @cmpiI1neVecLhs(%arg0: vector<4xi1>) -> vector<4xi1> {
+ %zero = arith.constant dense<0> : vector<4xi1>
+ %res = arith.cmpi ne, %zero, %arg0 : vector<4xi1>
+ return %res : vector<4xi1>
+}
+
// Test case: Folding of comparisons with equal operands.
// CHECK-LABEL: @cmpi_equal_operands
// CHECK-DAG: %[[T:.*]] = arith.constant true
More information about the Mlir-commits
mailing list