[Mlir-commits] [mlir] bbf0733 - [mlir][arith] Fold `select` with poison
Markus Böck
llvmlistbot at llvm.org
Mon Aug 28 23:32:52 PDT 2023
Author: Markus Böck
Date: 2023-08-29T08:26:25+02:00
New Revision: bbf0733030ae16a1ef19c2a031f805a971e941d2
URL: https://github.com/llvm/llvm-project/commit/bbf0733030ae16a1ef19c2a031f805a971e941d2
DIFF: https://github.com/llvm/llvm-project/commit/bbf0733030ae16a1ef19c2a031f805a971e941d2.diff
LOG: [mlir][arith] Fold `select` with poison
If either operand of `select` is fully poisoned, we can simply return the other operand.
This patch implements this optimization in the `fold` method of `arith.select`.
Note that this patch is the first to add a dependency on the UB dialect within Arith. Given that this dependency was inevitable (and was part of the motivation for this change), I believe it is acceptable.
Differential Revision: https://reviews.llvm.org/D158986
Added:
Modified:
mlir/lib/Dialect/Arith/IR/ArithOps.cpp
mlir/lib/Dialect/Arith/IR/CMakeLists.txt
mlir/test/Dialect/Arith/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index 5e1c4302a16c43..007b105d2328c7 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -13,6 +13,7 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/CommonFolders.h"
+#include "mlir/Dialect/UB/IR/UBOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributeInterfaces.h"
#include "mlir/IR/BuiltinAttributes.h"
@@ -2194,6 +2195,13 @@ OpFoldResult arith::SelectOp::fold(FoldAdaptor adaptor) {
if (matchPattern(condition, m_Zero()))
return falseVal;
+ // If either operand is fully poisoned, return the other.
+ if (isa_and_nonnull<ub::PoisonAttr>(adaptor.getTrueValue()))
+ return falseVal;
+
+ if (isa_and_nonnull<ub::PoisonAttr>(adaptor.getFalseValue()))
+ return trueVal;
+
// select %x, true, false => %x
if (getType().isInteger(1) && matchPattern(getTrueValue(), m_One()) &&
matchPattern(getFalseValue(), m_Zero()))
diff --git a/mlir/lib/Dialect/Arith/IR/CMakeLists.txt b/mlir/lib/Dialect/Arith/IR/CMakeLists.txt
index fdbeb39e60c066..4beb99ccfdfbab 100644
--- a/mlir/lib/Dialect/Arith/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/Arith/IR/CMakeLists.txt
@@ -28,6 +28,7 @@ add_mlir_dialect_library(MLIRArithDialect
MLIRInferIntRangeInterface
MLIRInferTypeOpInterface
MLIRIR
+ MLIRUBDialect
)
add_mlir_dialect_library(MLIRArithValueBoundsOpInterfaceImpl
diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir
index 5b392fe9cf58a0..0c8e0974b017dc 100644
--- a/mlir/test/Dialect/Arith/canonicalize.mlir
+++ b/mlir/test/Dialect/Arith/canonicalize.mlir
@@ -2567,3 +2567,20 @@ func.func @foldOrXor6(%arg0: index) -> index {
%2 = arith.ori %arg0, %1 : index
return %2 : index
}
+
+// CHECK-LABEL: @selectOfPoison
+// CHECK-SAME: %[[ARG:[[:alnum:]]+]]: i32
+// CHECK: %[[UB:.*]] = ub.poison : i32
+// CHECK: return %[[ARG]], %[[ARG]], %[[UB]], %[[ARG]]
+func.func @selectOfPoison(%cond : i1, %arg: i32) -> (i32, i32, i32, i32) {
+ %poison = ub.poison : i32
+ %select1 = arith.select %cond, %poison, %arg : i32
+ %select2 = arith.select %cond, %arg, %poison : i32
+
+ // Check that constant folding is applied prior to poison handling.
+ %true = arith.constant true
+ %false = arith.constant false
+ %select3 = arith.select %true, %poison, %arg : i32
+ %select4 = arith.select %false, %poison, %arg : i32
+ return %select1, %select2, %select3, %select4 : i32, i32, i32, i32
+}
More information about the Mlir-commits mailing list