[Mlir-commits] [mlir] [MLIR][Arith] SelectOp fix invalid folding (PR #117555)

llvmlistbot at llvm.org
Tue Nov 26 01:58:49 PST 2024


https://github.com/7FM updated https://github.com/llvm/llvm-project/pull/117555

From f459a27ef77f42d3b4cb2607bea054d31cc8cc47 Mon Sep 17 00:00:00 2001
From: 7FM <muermann at esa.tu-darmstadt.de>
Date: Mon, 25 Nov 2024 14:16:47 +0100
Subject: [PATCH 1/2] [MLIR][Arith] SelectOp fix invalid folding

The pattern `select %x, true, false => %x` is only valid when the result type is identical to the type of `%x` (i.e., signless `i1`).
`isInteger(1)` checks only the bit width, so it also accepts `si1` and `ui1`; hence, the check was replaced with `isSignlessInteger(1)`.
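
For illustration only, the difference between the two checks can be sketched with a small stand-alone model. The `IntType` struct, the `Signedness` enum, and the free functions below are hypothetical stand-ins written for this note, not the MLIR classes; they only mimic a width-only check versus a width-plus-signless check.

```cpp
// Toy model only: these types and helpers are hypothetical stand-ins,
// not the MLIR Type/IntegerType classes.
enum class Signedness { Signless, Signed, Unsigned };

struct IntType {
  unsigned width;
  Signedness signedness;
};

// Width-only check: accepts i1, si1 and ui1 alike.
constexpr bool isInteger(IntType t, unsigned width) {
  return t.width == width;
}

// Stricter check: accepts only the signless integer of the given width.
constexpr bool isSignlessInteger(IntType t, unsigned width) {
  return t.width == width && t.signedness == Signedness::Signless;
}

// The fold returns the i1 condition itself, so the result type of the
// select must be exactly i1 for the replacement to be type-correct.
constexpr IntType kI1{1, Signedness::Signless};
constexpr IntType kUI1{1, Signedness::Unsigned};

static_assert(isInteger(kUI1, 1), "width-only check lets ui1 through");
static_assert(!isSignlessInteger(kUI1, 1), "signless check rejects ui1");
static_assert(isSignlessInteger(kI1, 1), "signless check still accepts i1");
```

Under this model, a select with result type `ui1` passes the old guard but not the new one, which is the case the fold must skip.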
---
 mlir/lib/Dialect/Arith/IR/ArithOps.cpp | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index 254f54d9e459e1..f2f23954d5c191 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -2314,7 +2314,8 @@ OpFoldResult arith::SelectOp::fold(FoldAdaptor adaptor) {
     return trueVal;
 
   // select %x, true, false => %x
-  if (getType().isInteger(1) && matchPattern(adaptor.getTrueValue(), m_One()) &&
+  if (getType().isSignlessInteger(1) &&
+      matchPattern(adaptor.getTrueValue(), m_One()) &&
       matchPattern(adaptor.getFalseValue(), m_Zero()))
     return condition;
 

From 9aca6a19f7e8338e6dc7cc3c448e695def83907a Mon Sep 17 00:00:00 2001
From: 7FM <muermann at esa.tu-darmstadt.de>
Date: Tue, 26 Nov 2024 10:58:24 +0100
Subject: [PATCH 2/2] Add LIT test

---
 mlir/test/Dialect/Arith/canonicalize.mlir | 12 ++++++++++++
 1 file changed, 12 insertions(+)

diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir
index a386a178b78995..f9997ec2796afe 100644
--- a/mlir/test/Dialect/Arith/canonicalize.mlir
+++ b/mlir/test/Dialect/Arith/canonicalize.mlir
@@ -54,6 +54,18 @@ func.func @select_extui_i1(%arg0: i1) -> i1 {
   return %res : i1
 }
 
+// CHECK-LABEL: @select_no_fold_ui1
+//       CHECK:  %[[CONST_0:.+]] = "test.constant"() <{value = 0 : i32}> : () -> ui1
+//       CHECK:  %[[CONST_1:.+]] = "test.constant"() <{value = 1 : i32}> : () -> ui1
+//  CHECK-NEXT:  %[[RES:.+]] = arith.select %arg0, %[[CONST_1]], %[[CONST_0]] : ui1
+//  CHECK-NEXT:   return %[[RES]]
+func.func @select_no_fold_ui1(%arg0: i1) -> ui1 {
+  %c0_i1 = "test.constant"() {value = 0 : i32} : () -> ui1
+  %c1_i1 = "test.constant"() {value = 1 : i32} : () -> ui1
+  %res = arith.select %arg0, %c1_i1, %c0_i1 : ui1
+  return %res : ui1
+}
+
 // CHECK-LABEL: @select_cst_false_scalar
 //  CHECK-SAME:   (%[[ARG0:.+]]: i32, %[[ARG1:.+]]: i32)
 //  CHECK-NEXT:   return %[[ARG1]]


