[Mlir-commits] [mlir] [mlir][Vector] Teach how to materialize UB constant to Vector (PR #125596)

Diego Caballero llvmlistbot at llvm.org
Tue Feb 4 11:05:45 PST 2025


https://github.com/dcaballe updated https://github.com/llvm/llvm-project/pull/125596

>From 96b4a39552ec1ad84c2a49d8b7115d548cafe714 Mon Sep 17 00:00:00 2001
From: Diego Caballero <dieg0ca6aller0 at gmail.com>
Date: Mon, 3 Feb 2025 15:04:30 -0800
Subject: [PATCH 1/2] [mlir][Vector] Teach Vector to materialize UB constant

This PR adds support for UB constant materialization (i.e., generating
`ub::PoisonOp`) to `VectorDialect::materializeConstant`. The lack of this
support was the reason why the vector folders generating poison didn't work.
---
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp      | 19 +++----------------
 .../VectorToLLVM/vector-to-llvm.mlir          |  4 ++--
 2 files changed, 5 insertions(+), 18 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 93f89eda2da5a6b..190b643f1c7ae58 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -437,6 +437,9 @@ void VectorDialect::initialize() {
 Operation *VectorDialect::materializeConstant(OpBuilder &builder,
                                               Attribute value, Type type,
                                               Location loc) {
+  if (auto poisonAttr = dyn_cast<ub::PoisonAttrInterface>(value))
+    return builder.create<ub::PoisonOp>(loc, type, poisonAttr);
+
   return arith::ConstantOp::materialize(builder, value, type, loc);
 }
 
@@ -2273,20 +2276,6 @@ LogicalResult foldExtractFromFromElements(ExtractOp extractOp,
   return success();
 }
 
-/// Fold an insert or extract operation into an poison value when a poison index
-/// is found at any dimension of the static position.
-template <typename OpTy>
-LogicalResult
-canonicalizePoisonIndexInsertExtractOp(OpTy op, PatternRewriter &rewriter) {
-  if (auto poisonAttr = foldPoisonIndexInsertExtractOp(
-          op.getContext(), op.getStaticPosition(), OpTy::kPoisonIndex)) {
-    rewriter.replaceOpWithNewOp<ub::PoisonOp>(op, op.getType(), poisonAttr);
-    return success();
-  }
-
-  return failure();
-}
-
 } // namespace
 
 void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
@@ -2295,7 +2284,6 @@ void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
               ExtractOpFromBroadcast, ExtractOpFromCreateMask>(context);
   results.add(foldExtractFromShapeCastToShapeCast);
   results.add(foldExtractFromFromElements);
-  results.add(canonicalizePoisonIndexInsertExtractOp<ExtractOp>);
 }
 
 static void populateFromInt64AttrArray(ArrayAttr arrayAttr,
@@ -3068,7 +3056,6 @@ void InsertOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
   results.add<InsertToBroadcast, BroadcastFolder, InsertSplatToSplat,
               InsertOpConstantFolder>(context);
-  results.add(canonicalizePoisonIndexInsertExtractOp<InsertOp>);
 }
 
 OpFoldResult vector::InsertOp::fold(FoldAdaptor adaptor) {
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index 7df6defc0f202f1..722ab0499d858ea 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -1255,8 +1255,8 @@ func.func @extract_poison_idx(%arg0: vector<16xf32>) -> f32 {
   return %0 : f32
 }
 // CHECK-LABEL: @extract_poison_idx
-//       CHECK:   %[[IDX:.*]] = llvm.mlir.constant(-1 : i64) : i64
-//       CHECK:   llvm.extractelement {{.*}}[%[[IDX]] : i64] : vector<16xf32>
+//       CHECK:   %[[UB:.*]] = ub.poison : f32
+//       CHECK:   return %[[UB]] : f32
 
 // -----
 

>From ea3a0f2af27137dd408c77e15b4a0ecdddedcffa Mon Sep 17 00:00:00 2001
From: Diego Caballero <dieg0ca6aller0 at gmail.com>
Date: Tue, 4 Feb 2025 10:38:21 -0800
Subject: [PATCH 2/2] Feedback

---
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp           |  4 ++--
 .../Conversion/VectorToLLVM/vector-to-llvm.mlir    | 14 ++++++++++++--
 2 files changed, 14 insertions(+), 4 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 190b643f1c7ae58..2ec1b97f2f241d1 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -437,8 +437,8 @@ void VectorDialect::initialize() {
 Operation *VectorDialect::materializeConstant(OpBuilder &builder,
                                               Attribute value, Type type,
                                               Location loc) {
-  if (auto poisonAttr = dyn_cast<ub::PoisonAttrInterface>(value))
-    return builder.create<ub::PoisonOp>(loc, type, poisonAttr);
+  if (isa<ub::PoisonAttrInterface>(value))
+    return value.getDialect().materializeConstant(builder, value, type, loc);
 
   return arith::ConstantOp::materialize(builder, value, type, loc);
 }
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index 722ab0499d858ea..9a6337f14ace334 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -1250,11 +1250,11 @@ func.func @extract_scalar_from_vec_1d_f32(%arg0: vector<16xf32>) -> f32 {
 
 // -----
 
-func.func @extract_poison_idx(%arg0: vector<16xf32>) -> f32 {
+func.func @extract_scalar_from_vec_1d_f32_poison_idx(%arg0: vector<16xf32>) -> f32 {
   %0 = vector.extract %arg0[-1]: f32 from vector<16xf32>
   return %0 : f32
 }
-// CHECK-LABEL: @extract_poison_idx
+// CHECK-LABEL: @extract_scalar_from_vec_1d_f32_poison_idx
 //       CHECK:   %[[UB:.*]] = ub.poison : f32
 //       CHECK:   return %[[UB]] : f32
 
@@ -1335,6 +1335,16 @@ func.func @extract_vec_2d_from_vec_3d_f32(%arg0: vector<4x3x16xf32>) -> vector<3
 
 // -----
 
+func.func @extract_vec_2d_from_vec_3d_f32_poison_idx(%arg0: vector<4x3x16xf32>) -> vector<3x16xf32> {
+  %0 = vector.extract %arg0[-1]: vector<3x16xf32> from vector<4x3x16xf32>
+  return %0 : vector<3x16xf32>
+}
+// CHECK-LABEL: @extract_vec_2d_from_vec_3d_f32_poison_idx
+//       CHECK:   %[[UB:.*]] = ub.poison : vector<3x16xf32>
+//       CHECK:   return %[[UB]] : vector<3x16xf32>
+
+// -----
+
 func.func @extract_vec_2d_from_vec_3d_f32_scalable(%arg0: vector<4x3x[16]xf32>) -> vector<3x[16]xf32> {
   %0 = vector.extract %arg0[0]: vector<3x[16]xf32> from vector<4x3x[16]xf32>
   return %0 : vector<3x[16]xf32>



More information about the Mlir-commits mailing list