[Mlir-commits] [mlir] 5dce748 - [mlir][ub] Add poison support to CommonFolders.h

Ivan Butygin llvmlistbot at llvm.org
Thu Sep 7 03:30:39 PDT 2023


Author: Ivan Butygin
Date: 2023-09-07T12:30:29+02:00
New Revision: 5dce74817b71a1f646fb2857c037b3a66f41c7cd

URL: https://github.com/llvm/llvm-project/commit/5dce74817b71a1f646fb2857c037b3a66f41c7cd
DIFF: https://github.com/llvm/llvm-project/commit/5dce74817b71a1f646fb2857c037b3a66f41c7cd.diff

LOG: [mlir][ub] Add poison support to CommonFolders.h

Return poison from the binary/unary constant folders if any argument is poison. Add the ub dialect as a dependency of the affected dialects (arith, math, spirv, shape).
Add poison materialization to these dialects. Add tests for some ops from each dialect.
Not all affected ops are covered, as that would involve a huge amount of copy-paste.

Differential Revision: https://reviews.llvm.org/D159013

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/CommonFolders.h
    mlir/lib/Dialect/Arith/IR/ArithDialect.cpp
    mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp
    mlir/lib/Dialect/Math/IR/CMakeLists.txt
    mlir/lib/Dialect/Math/IR/MathDialect.cpp
    mlir/lib/Dialect/Math/IR/MathOps.cpp
    mlir/lib/Dialect/SPIRV/IR/CMakeLists.txt
    mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
    mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
    mlir/lib/Dialect/Shape/IR/CMakeLists.txt
    mlir/lib/Dialect/Shape/IR/Shape.cpp
    mlir/test/Dialect/Arith/canonicalize.mlir
    mlir/test/Dialect/Math/canonicalize.mlir
    mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
    mlir/test/Dialect/Shape/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/CommonFolders.h b/mlir/include/mlir/Dialect/CommonFolders.h
index d3fbc723632a3b8..6257e4a60188d57 100644
--- a/mlir/include/mlir/Dialect/CommonFolders.h
+++ b/mlir/include/mlir/Dialect/CommonFolders.h
@@ -22,17 +22,35 @@
 #include <optional>
 
 namespace mlir {
+namespace ub {
+class PoisonAttr;
+}
 /// Performs constant folding `calculate` with element-wise behavior on the two
 /// attributes in `operands` and returns the result if possible.
 /// Uses `resultType` for the type of the returned attribute.
+/// The optional PoisonAttr template argument specifies the 'poison' attribute
+/// type; a poison operand is propagated directly to the result.
 template <class AttrElementT,
           class ElementValueT = typename AttrElementT::ValueType,
+          class PoisonAttr = ub::PoisonAttr,
           class CalculationT = function_ref<
               std::optional<ElementValueT>(ElementValueT, ElementValueT)>>
 Attribute constFoldBinaryOpConditional(ArrayRef<Attribute> operands,
                                        Type resultType,
-                                       const CalculationT &calculate) {
+                                       CalculationT &&calculate) {
   assert(operands.size() == 2 && "binary op takes two operands");
+  static_assert(
+      std::is_void_v<PoisonAttr> || !llvm::is_incomplete_v<PoisonAttr>,
+      "PoisonAttr is undefined, either add a dependency on UB dialect or pass "
+      "void as template argument to opt-out from poison semantics.");
+  if constexpr (!std::is_void_v<PoisonAttr>) {
+    if (isa_and_nonnull<PoisonAttr>(operands[0]))
+      return operands[0];
+
+    if (isa_and_nonnull<PoisonAttr>(operands[1]))
+      return operands[1];
+  }
+
   if (!resultType || !operands[0] || !operands[1])
     return {};
 
@@ -95,13 +113,28 @@ Attribute constFoldBinaryOpConditional(ArrayRef<Attribute> operands,
 /// attributes in `operands` and returns the result if possible.
 /// Uses the operand element type for the element type of the returned
 /// attribute.
+/// The optional PoisonAttr template argument specifies the 'poison' attribute
+/// type; a poison operand is propagated directly to the result.
 template <class AttrElementT,
           class ElementValueT = typename AttrElementT::ValueType,
+          class PoisonAttr = ub::PoisonAttr,
           class CalculationT = function_ref<
               std::optional<ElementValueT>(ElementValueT, ElementValueT)>>
 Attribute constFoldBinaryOpConditional(ArrayRef<Attribute> operands,
-                                       const CalculationT &calculate) {
+                                       CalculationT &&calculate) {
   assert(operands.size() == 2 && "binary op takes two operands");
+  static_assert(
+      std::is_void_v<PoisonAttr> || !llvm::is_incomplete_v<PoisonAttr>,
+      "PoisonAttr is undefined, either add a dependency on UB dialect or pass "
+      "void as template argument to opt-out from poison semantics.");
+  if constexpr (!std::is_void_v<PoisonAttr>) {
+    if (isa_and_nonnull<PoisonAttr>(operands[0]))
+      return operands[0];
+
+    if (isa_and_nonnull<PoisonAttr>(operands[1]))
+      return operands[1];
+  }
+
   auto getResultType = [](Attribute attr) -> Type {
     if (auto typed = dyn_cast_or_null<TypedAttr>(attr))
       return typed.getType();
@@ -115,18 +148,19 @@ Attribute constFoldBinaryOpConditional(ArrayRef<Attribute> operands,
   if (lhsType != rhsType)
     return {};
 
-  return constFoldBinaryOpConditional<AttrElementT, ElementValueT,
-                                      CalculationT>(operands, lhsType,
-                                                    calculate);
+  return constFoldBinaryOpConditional<AttrElementT, ElementValueT, PoisonAttr,
+                                      CalculationT>(
+      operands, lhsType, std::forward<CalculationT>(calculate));
 }
 
 template <class AttrElementT,
           class ElementValueT = typename AttrElementT::ValueType,
+          class PoisonAttr = void,
           class CalculationT =
               function_ref<ElementValueT(ElementValueT, ElementValueT)>>
 Attribute constFoldBinaryOp(ArrayRef<Attribute> operands, Type resultType,
-                            const CalculationT &calculate) {
-  return constFoldBinaryOpConditional<AttrElementT>(
+                            CalculationT &&calculate) {
+  return constFoldBinaryOpConditional<AttrElementT, ElementValueT, PoisonAttr>(
       operands, resultType,
       [&](ElementValueT a, ElementValueT b) -> std::optional<ElementValueT> {
         return calculate(a, b);
@@ -135,11 +169,12 @@ Attribute constFoldBinaryOp(ArrayRef<Attribute> operands, Type resultType,
 
 template <class AttrElementT,
           class ElementValueT = typename AttrElementT::ValueType,
+          class PoisonAttr = ub::PoisonAttr,
           class CalculationT =
               function_ref<ElementValueT(ElementValueT, ElementValueT)>>
 Attribute constFoldBinaryOp(ArrayRef<Attribute> operands,
-                            const CalculationT &calculate) {
-  return constFoldBinaryOpConditional<AttrElementT>(
+                            CalculationT &&calculate) {
+  return constFoldBinaryOpConditional<AttrElementT, ElementValueT, PoisonAttr>(
       operands,
       [&](ElementValueT a, ElementValueT b) -> std::optional<ElementValueT> {
         return calculate(a, b);
@@ -148,16 +183,28 @@ Attribute constFoldBinaryOp(ArrayRef<Attribute> operands,
 
 /// Performs constant folding `calculate` with element-wise behavior on the one
 /// attribute in `operands` and returns the result if possible.
+/// The optional PoisonAttr template argument specifies the 'poison' attribute
+/// type; a poison operand is propagated directly to the result.
 template <class AttrElementT,
           class ElementValueT = typename AttrElementT::ValueType,
+          class PoisonAttr = ub::PoisonAttr,
           class CalculationT =
               function_ref<std::optional<ElementValueT>(ElementValueT)>>
 Attribute constFoldUnaryOpConditional(ArrayRef<Attribute> operands,
-                                      const CalculationT &&calculate) {
+                                      CalculationT &&calculate) {
   assert(operands.size() == 1 && "unary op takes one operands");
   if (!operands[0])
     return {};
 
+  static_assert(
+      std::is_void_v<PoisonAttr> || !llvm::is_incomplete_v<PoisonAttr>,
+      "PoisonAttr is undefined, either add a dependency on UB dialect or pass "
+      "void as template argument to opt-out from poison semantics.");
+  if constexpr (!std::is_void_v<PoisonAttr>) {
+    if (isa<PoisonAttr>(operands[0]))
+      return operands[0];
+  }
+
   if (isa<AttrElementT>(operands[0])) {
     auto op = cast<AttrElementT>(operands[0]);
 
@@ -196,10 +243,11 @@ Attribute constFoldUnaryOpConditional(ArrayRef<Attribute> operands,
 
 template <class AttrElementT,
           class ElementValueT = typename AttrElementT::ValueType,
+          class PoisonAttr = ub::PoisonAttr,
           class CalculationT = function_ref<ElementValueT(ElementValueT)>>
 Attribute constFoldUnaryOp(ArrayRef<Attribute> operands,
-                           const CalculationT &&calculate) {
-  return constFoldUnaryOpConditional<AttrElementT>(
+                           CalculationT &&calculate) {
+  return constFoldUnaryOpConditional<AttrElementT, ElementValueT, PoisonAttr>(
       operands, [&](ElementValueT a) -> std::optional<ElementValueT> {
         return calculate(a);
       });
@@ -209,13 +257,23 @@ template <
     class AttrElementT, class TargetAttrElementT,
     class ElementValueT = typename AttrElementT::ValueType,
     class TargetElementValueT = typename TargetAttrElementT::ValueType,
+    class PoisonAttr = ub::PoisonAttr,
     class CalculationT = function_ref<TargetElementValueT(ElementValueT, bool)>>
 Attribute constFoldCastOp(ArrayRef<Attribute> operands, Type resType,
-                          const CalculationT &calculate) {
+                          CalculationT &&calculate) {
   assert(operands.size() == 1 && "Cast op takes one operand");
   if (!operands[0])
     return {};
 
+  static_assert(
+      std::is_void_v<PoisonAttr> || !llvm::is_incomplete_v<PoisonAttr>,
+      "PoisonAttr is undefined, either add a dependency on UB dialect or pass "
+      "void as template argument to opt-out from poison semantics.");
+  if constexpr (!std::is_void_v<PoisonAttr>) {
+    if (isa<PoisonAttr>(operands[0]))
+      return operands[0];
+  }
+
   if (isa<AttrElementT>(operands[0])) {
     auto op = cast<AttrElementT>(operands[0]);
     bool castStatus = true;
@@ -254,7 +312,6 @@ Attribute constFoldCastOp(ArrayRef<Attribute> operands, Type resType,
   }
   return {};
 }
-
 } // namespace mlir
 
 #endif // MLIR_DIALECT_COMMONFOLDERS_H

diff  --git a/mlir/lib/Dialect/Arith/IR/ArithDialect.cpp b/mlir/lib/Dialect/Arith/IR/ArithDialect.cpp
index 7f2d79355fe0f0f..ed4b91cbe516c95 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithDialect.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithDialect.cpp
@@ -7,6 +7,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/UB/IR/UBOps.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/DialectImplementation.h"
 #include "mlir/Transforms/InliningUtils.h"
@@ -49,5 +50,8 @@ void arith::ArithDialect::initialize() {
 Operation *arith::ArithDialect::materializeConstant(OpBuilder &builder,
                                                     Attribute value, Type type,
                                                     Location loc) {
+  if (auto poison = dyn_cast<ub::PoisonAttr>(value))
+    return builder.create<ub::PoisonOp>(loc, type, poison);
+
   return ConstantOp::materialize(builder, value, type, loc);
 }

diff  --git a/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp b/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp
index 0a86d8f15b0d631..fab6f341699908d 100644
--- a/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp
+++ b/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp
@@ -9,7 +9,6 @@
 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
 
 #include "mlir/Dialect/Arith/IR/Arith.h"
-#include "mlir/Dialect/CommonFolders.h"
 #include "mlir/IR/AffineExpr.h"
 #include "mlir/IR/AffineMap.h"
 #include "mlir/IR/Builders.h"

diff  --git a/mlir/lib/Dialect/Math/IR/CMakeLists.txt b/mlir/lib/Dialect/Math/IR/CMakeLists.txt
index 3b7b65e581432cb..ed95bf846cdeade 100644
--- a/mlir/lib/Dialect/Math/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/Math/IR/CMakeLists.txt
@@ -12,4 +12,5 @@ add_mlir_dialect_library(MLIRMathDialect
   MLIRArithDialect
   MLIRDialect
   MLIRIR
+  MLIRUBDialect
   )

diff  --git a/mlir/lib/Dialect/Math/IR/MathDialect.cpp b/mlir/lib/Dialect/Math/IR/MathDialect.cpp
index 54a8cc1d697b49c..9cf47ac7130622b 100644
--- a/mlir/lib/Dialect/Math/IR/MathDialect.cpp
+++ b/mlir/lib/Dialect/Math/IR/MathDialect.cpp
@@ -7,6 +7,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/Math/IR/Math.h"
+#include "mlir/Dialect/UB/IR/UBOps.h"
 #include "mlir/Transforms/InliningUtils.h"
 
 using namespace mlir;

diff  --git a/mlir/lib/Dialect/Math/IR/MathOps.cpp b/mlir/lib/Dialect/Math/IR/MathOps.cpp
index ae9dc08c745b4fd..28d1c062f235e60 100644
--- a/mlir/lib/Dialect/Math/IR/MathOps.cpp
+++ b/mlir/lib/Dialect/Math/IR/MathOps.cpp
@@ -9,6 +9,7 @@
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/CommonFolders.h"
 #include "mlir/Dialect/Math/IR/Math.h"
+#include "mlir/Dialect/UB/IR/UBOps.h"
 #include "mlir/IR/Builders.h"
 #include <optional>
 
@@ -522,5 +523,8 @@ OpFoldResult math::TruncOp::fold(FoldAdaptor adaptor) {
 Operation *math::MathDialect::materializeConstant(OpBuilder &builder,
                                                   Attribute value, Type type,
                                                   Location loc) {
+  if (auto poison = dyn_cast<ub::PoisonAttr>(value))
+    return builder.create<ub::PoisonOp>(loc, type, poison);
+
   return arith::ConstantOp::materialize(builder, value, type, loc);
 }

diff  --git a/mlir/lib/Dialect/SPIRV/IR/CMakeLists.txt b/mlir/lib/Dialect/SPIRV/IR/CMakeLists.txt
index 0189e79ea12fa15..2b5cedafae1e85b 100644
--- a/mlir/lib/Dialect/SPIRV/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/SPIRV/IR/CMakeLists.txt
@@ -43,4 +43,5 @@ add_mlir_dialect_library(MLIRSPIRVDialect
   MLIRSideEffectInterfaces
   MLIRSupport
   MLIRTransforms
+  MLIRUBDialect
 )

diff  --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
index def62b4467ce30b..9acd982dc95af6d 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
@@ -18,6 +18,7 @@
 #include "mlir/Dialect/CommonFolders.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
+#include "mlir/Dialect/UB/IR/UBOps.h"
 #include "mlir/IR/Matchers.h"
 #include "mlir/IR/PatternMatch.h"
 #include "llvm/ADT/STLExtras.h"

diff  --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
index 76e703946428361..a51d77dda78bf2f 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
@@ -17,6 +17,7 @@
 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
 #include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
+#include "mlir/Dialect/UB/IR/UBOps.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/DialectImplementation.h"
@@ -949,6 +950,9 @@ void SPIRVDialect::printType(Type type, DialectAsmPrinter &os) const {
 Operation *SPIRVDialect::materializeConstant(OpBuilder &builder,
                                              Attribute value, Type type,
                                              Location loc) {
+  if (auto poison = dyn_cast<ub::PoisonAttr>(value))
+    return builder.create<ub::PoisonOp>(loc, type, poison);
+
   if (!spirv::ConstantOp::isBuildableWith(type))
     return nullptr;
 

diff  --git a/mlir/lib/Dialect/Shape/IR/CMakeLists.txt b/mlir/lib/Dialect/Shape/IR/CMakeLists.txt
index ba41f1aec8d97a8..32a86b483a49b10 100644
--- a/mlir/lib/Dialect/Shape/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/Shape/IR/CMakeLists.txt
@@ -23,4 +23,5 @@ add_mlir_dialect_library(MLIRShapeDialect
   MLIRIR
   MLIRSideEffectInterfaces
   MLIRTensorDialect
+  MLIRUBDialect
   )

diff  --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
index e4efa09316770c4..2444556a4563512 100644
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -14,6 +14,7 @@
 #include "mlir/Dialect/CommonFolders.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Traits.h"
+#include "mlir/Dialect/UB/IR/UBOps.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/DialectImplementation.h"
@@ -147,6 +148,9 @@ void ShapeDialect::initialize() {
 Operation *ShapeDialect::materializeConstant(OpBuilder &builder,
                                              Attribute value, Type type,
                                              Location loc) {
+  if (auto poison = dyn_cast<ub::PoisonAttr>(value))
+    return builder.create<ub::PoisonOp>(loc, type, poison);
+
   if (llvm::isa<ShapeType>(type) || isExtentTensorType(type))
     return builder.create<ConstShapeOp>(
         loc, type, llvm::cast<DenseIntElementsAttr>(value));
@@ -156,6 +160,7 @@ Operation *ShapeDialect::materializeConstant(OpBuilder &builder,
   if (llvm::isa<WitnessType>(type))
     return builder.create<ConstWitnessOp>(loc, type,
                                           llvm::cast<BoolAttr>(value));
+
   return arith::ConstantOp::materialize(builder, value, type, loc);
 }
 

diff  --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir
index 0c8e0974b017dcd..347b6346b786279 100644
--- a/mlir/test/Dialect/Arith/canonicalize.mlir
+++ b/mlir/test/Dialect/Arith/canonicalize.mlir
@@ -2584,3 +2584,58 @@ func.func @selectOfPoison(%cond : i1, %arg: i32) -> (i32, i32, i32, i32) {
   %select4 = arith.select %false, %poison, %arg : i32
   return %select1, %select2, %select3, %select4 : i32, i32, i32, i32
 }
+
+// CHECK-LABEL: @addi_poison1
+//       CHECK:   %[[P:.*]] = ub.poison : i32
+//       CHECK:   return %[[P]]
+func.func @addi_poison1(%arg: i32) -> i32 {
+  %0 = ub.poison : i32
+  %1 = arith.addi %0, %arg : i32
+  return %1 : i32
+}
+
+// CHECK-LABEL: @addi_poison2
+//       CHECK:   %[[P:.*]] = ub.poison : i32
+//       CHECK:   return %[[P]]
+func.func @addi_poison2(%arg: i32) -> i32 {
+  %0 = ub.poison : i32
+  %1 = arith.addi %arg, %0 : i32
+  return %1 : i32
+}
+
+// CHECK-LABEL: @addf_poison1
+//       CHECK:   %[[P:.*]] = ub.poison : f32
+//       CHECK:   return %[[P]]
+func.func @addf_poison1(%arg: f32) -> f32 {
+  %0 = ub.poison : f32
+  %1 = arith.addf %0, %arg : f32
+  return %1 : f32
+}
+
+// CHECK-LABEL: @addf_poison2
+//       CHECK:   %[[P:.*]] = ub.poison : f32
+//       CHECK:   return %[[P]]
+func.func @addf_poison2(%arg: f32) -> f32 {
+  %0 = ub.poison : f32
+  %1 = arith.addf %arg, %0 : f32
+  return %1 : f32
+}
+
+
+// CHECK-LABEL: @negf_poison
+//       CHECK:   %[[P:.*]] = ub.poison : f32
+//       CHECK:   return %[[P]]
+func.func @negf_poison() -> f32 {
+  %0 = ub.poison : f32
+  %1 = arith.negf %0 : f32
+  return %1 : f32
+}
+
+// CHECK-LABEL: @extsi_poison
+//       CHECK:   %[[P:.*]] = ub.poison : i64
+//       CHECK:   return %[[P]]
+func.func @extsi_poison() -> i64 {
+  %0 = ub.poison : i32
+  %1 = arith.extsi %0 : i32 to i64
+  return %1 : i64
+}

diff  --git a/mlir/test/Dialect/Math/canonicalize.mlir b/mlir/test/Dialect/Math/canonicalize.mlir
index 7a5194b89a5ceeb..d24f7649269fe02 100644
--- a/mlir/test/Dialect/Math/canonicalize.mlir
+++ b/mlir/test/Dialect/Math/canonicalize.mlir
@@ -483,3 +483,12 @@ func.func @erf_fold_vec() -> (vector<4xf32>) {
   %0 = math.erf %v1 : vector<4xf32>
   return %0 : vector<4xf32>
 }
+
+// CHECK-LABEL: @abs_poison
+//       CHECK:   %[[P:.*]] = ub.poison : f32
+//       CHECK:   return %[[P]]
+func.func @abs_poison() -> f32 {
+  %0 = ub.poison : f32
+  %1 = math.absf %0 : f32
+  return %1 : f32
+}

diff  --git a/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir b/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
index 52607d7267852a8..0200805a444397a 100644
--- a/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
+++ b/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
@@ -325,6 +325,15 @@ func.func @const_fold_vector_iadd() -> vector<3xi32> {
   return %0: vector<3xi32>
 }
 
+// CHECK-LABEL: @iadd_poison
+//       CHECK:   %[[P:.*]] = ub.poison : i32
+//       CHECK:   return %[[P]]
+func.func @iadd_poison(%arg0: i32) -> i32 {
+  %0 = ub.poison : i32
+  %1 = spirv.IAdd %arg0, %0 : i32
+  return %1: i32
+}
+
 // -----
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir
index aec5f3202c9b881..8edbae3baf52e6a 100644
--- a/mlir/test/Dialect/Shape/canonicalize.mlir
+++ b/mlir/test/Dialect/Shape/canonicalize.mlir
@@ -1479,3 +1479,16 @@ func.func @extract_shapeof(%arg0 : tensor<?x?xf64>) -> index {
 // CHECK:        return %[[DIM]]
  return %result : index
 }
+
+
+// -----
+
+// CHECK-LABEL: @add_poison
+//       CHECK:   %[[P:.*]] = ub.poison : !shape.size
+//       CHECK:   return %[[P]]
+func.func @add_poison() -> !shape.size {
+  %1 = shape.const_size 2
+  %2 = ub.poison : !shape.size
+  %result = shape.add %1, %2 : !shape.size, !shape.size -> !shape.size
+  return %result : !shape.size
+}


        


More information about the Mlir-commits mailing list