[Mlir-commits] [mlir] bbf0733 - [mlir][arith] Fold `select` with poison

Markus Böck llvmlistbot at llvm.org
Mon Aug 28 23:32:52 PDT 2023


Author: Markus Böck
Date: 2023-08-29T08:26:25+02:00
New Revision: bbf0733030ae16a1ef19c2a031f805a971e941d2

URL: https://github.com/llvm/llvm-project/commit/bbf0733030ae16a1ef19c2a031f805a971e941d2
DIFF: https://github.com/llvm/llvm-project/commit/bbf0733030ae16a1ef19c2a031f805a971e941d2.diff

LOG: [mlir][arith] Fold `select` with poison

If either the true value or the false value of `select` is fully poisoned, we can simply return the other one.
This PR implements this optimization inside the `fold` method.

Note that this patch is the first to make Arith depend on the UB dialect. Given that this was inevitable (and part of the motivation), I believe it should be fine.

Differential Revision: https://reviews.llvm.org/D158986

Added: 
    

Modified: 
    mlir/lib/Dialect/Arith/IR/ArithOps.cpp
    mlir/lib/Dialect/Arith/IR/CMakeLists.txt
    mlir/test/Dialect/Arith/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index 5e1c4302a16c43..007b105d2328c7 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -13,6 +13,7 @@
 
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/CommonFolders.h"
+#include "mlir/Dialect/UB/IR/UBOps.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinAttributeInterfaces.h"
 #include "mlir/IR/BuiltinAttributes.h"
@@ -2194,6 +2195,13 @@ OpFoldResult arith::SelectOp::fold(FoldAdaptor adaptor) {
   if (matchPattern(condition, m_Zero()))
     return falseVal;
 
+  // If either operand is fully poisoned, return the other.
+  if (isa_and_nonnull<ub::PoisonAttr>(adaptor.getTrueValue()))
+    return falseVal;
+
+  if (isa_and_nonnull<ub::PoisonAttr>(adaptor.getFalseValue()))
+    return trueVal;
+
   // select %x, true, false => %x
   if (getType().isInteger(1) && matchPattern(getTrueValue(), m_One()) &&
       matchPattern(getFalseValue(), m_Zero()))

diff  --git a/mlir/lib/Dialect/Arith/IR/CMakeLists.txt b/mlir/lib/Dialect/Arith/IR/CMakeLists.txt
index fdbeb39e60c066..4beb99ccfdfbab 100644
--- a/mlir/lib/Dialect/Arith/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/Arith/IR/CMakeLists.txt
@@ -28,6 +28,7 @@ add_mlir_dialect_library(MLIRArithDialect
   MLIRInferIntRangeInterface
   MLIRInferTypeOpInterface
   MLIRIR
+  MLIRUBDialect
   )
 
 add_mlir_dialect_library(MLIRArithValueBoundsOpInterfaceImpl

diff  --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir
index 5b392fe9cf58a0..0c8e0974b017dc 100644
--- a/mlir/test/Dialect/Arith/canonicalize.mlir
+++ b/mlir/test/Dialect/Arith/canonicalize.mlir
@@ -2567,3 +2567,20 @@ func.func @foldOrXor6(%arg0: index) -> index {
   %2 = arith.ori %arg0, %1 : index
   return %2 : index
 }
+
+// CHECK-LABEL: @selectOfPoison
+// CHECK-SAME: %[[ARG:[[:alnum:]]+]]: i32
+// CHECK: %[[UB:.*]] = ub.poison : i32
+// CHECK: return %[[ARG]], %[[ARG]], %[[UB]], %[[ARG]]
+func.func @selectOfPoison(%cond : i1, %arg: i32) -> (i32, i32, i32, i32) {
+  %poison = ub.poison : i32
+  %select1 = arith.select %cond, %poison, %arg : i32
+  %select2 = arith.select %cond, %arg, %poison : i32
+
+  // Check that constant folding is applied prior to poison handling.
+  %true = arith.constant true
+  %false = arith.constant false
+  %select3 = arith.select %true, %poison, %arg : i32
+  %select4 = arith.select %false, %poison, %arg : i32
+  return %select1, %select2, %select3, %select4 : i32, i32, i32, i32
+}


        


More information about the Mlir-commits mailing list