[Mlir-commits] [mlir] 35df525 - [mlir][Vector] Add support for poison indices to `Extract/InsertOp` (#123488)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Jan 28 13:51:54 PST 2025


Author: Diego Caballero
Date: 2025-01-28T13:51:50-08:00
New Revision: 35df525fd00c2037ef144189ee818b7d612241ff

URL: https://github.com/llvm/llvm-project/commit/35df525fd00c2037ef144189ee818b7d612241ff
DIFF: https://github.com/llvm/llvm-project/commit/35df525fd00c2037ef144189ee818b7d612241ff.diff

LOG: [mlir][Vector] Add support for poison indices to `Extract/InsertOp` (#123488)

Following up on #122188, this PR adds support for poison indices to
`ExtractOp` and `InsertOp`. It also includes canonicalization patterns
to turn extract/insert ops with poison indices into `ub.poison`.

Added: 
    

Modified: 
    mlir/include/mlir/Conversion/Passes.td
    mlir/include/mlir/Dialect/Vector/IR/Vector.td
    mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
    mlir/include/mlir/Transforms/Passes.td
    mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp
    mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
    mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRVPass.cpp
    mlir/lib/Dialect/Vector/IR/VectorOps.cpp
    mlir/lib/Transforms/CMakeLists.txt
    mlir/lib/Transforms/Canonicalizer.cpp
    mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
    mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
    mlir/test/Dialect/Vector/canonicalize.mlir
    mlir/test/Dialect/Vector/invalid.mlir
    mlir/test/Dialect/Vector/ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index b547839d76738c..4cd6c17e3379cd 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -1454,7 +1454,10 @@ def ConvertVectorToLLVMPass : Pass<"convert-vector-to-llvm"> {
 def ConvertVectorToSPIRV : Pass<"convert-vector-to-spirv"> {
   let summary = "Convert Vector dialect to SPIR-V dialect";
   let constructor = "mlir::createConvertVectorToSPIRVPass()";
-  let dependentDialects = ["spirv::SPIRVDialect"];
+  let dependentDialects = [
+    "spirv::SPIRVDialect",
+    "ub::UBDialect"
+  ];
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/include/mlir/Dialect/Vector/IR/Vector.td b/mlir/include/mlir/Dialect/Vector/IR/Vector.td
index c439ca083e2e09..1922cc63ef3538 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/Vector.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/Vector.td
@@ -26,6 +26,15 @@ def Vector_Dialect : Dialect {
 
 // Base class for Vector dialect ops.
 class Vector_Op<string mnemonic, list<Trait> traits = []> :
-    Op<Vector_Dialect, mnemonic, traits>;
+    Op<Vector_Dialect, mnemonic, traits> {
+
+  // Includes definitions for operations that support the use of poison values
+  // within positive index ranges.
+  code extraPoisonClassDeclaration = [{
+    // Integer to represent a poison index within a static and positive integer
+    // range.
+    static constexpr int64_t kPoisonIndex = -1;
+  }];
+}
 
 #endif // MLIR_DIALECT_VECTOR_IR_VECTOR

diff  --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 4331eda1661960..3b027dcfdfc70a 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -469,10 +469,7 @@ def Vector_ShuffleOp
     ```
   }];
 
-  let extraClassDeclaration = [{
-    // Integer to represent a poison value in a vector shuffle mask.
-    static constexpr int64_t kMaskPoisonValue = -1;
-
+  let extraClassDeclaration = extraPoisonClassDeclaration # [{
     VectorType getV1VectorType() {
       return ::llvm::cast<VectorType>(getV1().getType());
     }
@@ -693,9 +690,10 @@ def Vector_ExtractOp :
     Takes an n-D vector and a k-D position and extracts the (n-k)-D vector at
     the proper position. Degenerates to an element type if n-k is zero.
 
-    Dynamic indices must be greater or equal to zero and less than the size of
-    the corresponding dimension. The result is undefined if any index is
-    out-of-bounds.
+    Static and dynamic indices must be greater than or equal to zero and less
+    than the size of the corresponding dimension. The result is undefined if
+    any index is out-of-bounds. The value `-1` represents a poison index, which
+    specifies that the extracted element is poison.
 
     Example:
 
@@ -705,9 +703,8 @@ def Vector_ExtractOp :
     %3 = vector.extract %1[]: vector<f32> from vector<f32>
     %4 = vector.extract %0[%a, %b, %c]: f32 from vector<4x8x16xf32>
     %5 = vector.extract %0[2, %b]: vector<16xf32> from vector<4x8x16xf32>
+    %6 = vector.extract %10[-1, %c]: f32 from vector<4x16xf32>
     ```
-
-    TODO: Implement support for poison indices.
   }];
 
   let arguments = (ins
@@ -724,7 +721,7 @@ def Vector_ExtractOp :
     OpBuilder<(ins "Value":$source, "ArrayRef<OpFoldResult>":$position)>,
   ];
 
-  let extraClassDeclaration = [{
+  let extraClassDeclaration = extraPoisonClassDeclaration # [{
     VectorType getSourceVectorType() {
       return ::llvm::cast<VectorType>(getVector().getType());
     }
@@ -885,9 +882,10 @@ def Vector_InsertOp :
     and inserts the n-D source into the (n+k)-D destination at the proper
     position. Degenerates to a scalar or a 0-d vector source type when n = 0.
 
-    Dynamic indices must be greater or equal to zero and less than the size of
-    the corresponding dimension. The result is undefined if any index is
-    out-of-bounds.
+    Static and dynamic indices must be greater than or equal to zero and less
+    than the size of the corresponding dimension. The result is undefined if
+    any index is out-of-bounds. The value `-1` represents a poison index, which
+    specifies that the resulting vector is poison.
 
     Example:
 
@@ -897,9 +895,8 @@ def Vector_InsertOp :
     %8 = vector.insert %6, %7[] : f32 into vector<f32>
     %11 = vector.insert %9, %10[%a, %b, %c] : vector<f32> into vector<4x8x16xf32>
     %12 = vector.insert %4, %10[2, %b] : vector<16xf32> into vector<4x8x16xf32>
+    %13 = vector.insert %20, %1[-1, %c] : f32 into vector<4x16xf32>
     ```
-
-    TODO: Implement support for poison indices.
   }];
 
   let arguments = (ins
@@ -917,7 +914,7 @@ def Vector_InsertOp :
     OpBuilder<(ins "Value":$source, "Value":$dest, "ArrayRef<OpFoldResult>":$position)>,
   ];
 
-  let extraClassDeclaration = [{
+  let extraClassDeclaration = extraPoisonClassDeclaration # [{
     Type getSourceType() { return getSource().getType(); }
     VectorType getDestVectorType() {
       return ::llvm::cast<VectorType>(getDest().getType());
@@ -990,15 +987,13 @@ def Vector_ScalableInsertOp :
     ```mlir
     %2 = vector.scalable.insert %0, %1[5] : vector<4xf32> into vector<[16]xf32>
     ```
-
-    TODO: Implement support for poison indices.
   }];
 
   let assemblyFormat = [{
     $source `,` $dest `[` $pos `]` attr-dict `:` type($source) `into` type($dest)
   }];
 
-  let extraClassDeclaration = [{
+  let extraClassDeclaration = extraPoisonClassDeclaration # [{
     VectorType getSourceVectorType() {
       return ::llvm::cast<VectorType>(getSource().getType());
     }
@@ -1043,15 +1038,13 @@ def Vector_ScalableExtractOp :
     ```mlir
     %1 = vector.scalable.extract %0[5] : vector<4xf32> from vector<[16]xf32>
     ```
-
-    TODO: Implement support for poison indices.
   }];
 
   let assemblyFormat = [{
     $source `[` $pos `]` attr-dict `:` type($res) `from` type($source)
   }];
 
-  let extraClassDeclaration = [{
+  let extraClassDeclaration = extraPoisonClassDeclaration # [{
     VectorType getSourceVectorType() {
       return ::llvm::cast<VectorType>(getSource().getType());
     }
@@ -1089,8 +1082,6 @@ def Vector_InsertStridedSliceOp :
         {offsets = [0, 0, 2], strides = [1, 1]}:
       vector<2x4xf32> into vector<16x4x8xf32>
     ```
-
-    TODO: Implement support for poison indices.
   }];
 
   let assemblyFormat = [{

diff  --git a/mlir/include/mlir/Transforms/Passes.td b/mlir/include/mlir/Transforms/Passes.td
index c4a8e7a81fa483..a39ab77fc8fb3b 100644
--- a/mlir/include/mlir/Transforms/Passes.td
+++ b/mlir/include/mlir/Transforms/Passes.td
@@ -28,6 +28,7 @@ def Canonicalizer : Pass<"canonicalize"> {
     details.
   }];
   let constructor = "mlir::createCanonicalizerPass()";
+  let dependentDialects = ["ub::UBDialect"];
   let options = [
     Option<"topDownProcessingEnabled", "top-down", "bool",
            /*default=*/"true",

diff  --git a/mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp b/mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp
index ab9c048f561069..4481c0a4973544 100644
--- a/mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp
+++ b/mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp
@@ -18,7 +18,6 @@
 #include "mlir/Dialect/Arith/Transforms/Passes.h"
 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
-#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
@@ -27,7 +26,6 @@
 #include "mlir/Pass/Pass.h"
 #include "mlir/Rewrite/FrozenRewritePatternSet.h"
 #include "mlir/Transforms/DialectConversion.h"
-#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include <memory>
 
 #define DEBUG_TYPE "convert-to-spirv"

diff  --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index d3731db1ce55c9..af882cb1ca6e91 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -13,7 +13,6 @@
 #include "mlir/Conversion/VectorToSPIRV/VectorToSPIRV.h"
 
 #include "mlir/Dialect/Arith/IR/Arith.h"
-#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"

diff  --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRVPass.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRVPass.cpp
index 1932de1be603b6..cc115b1d368262 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRVPass.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRVPass.cpp
@@ -15,6 +15,7 @@
 #include "mlir/Conversion/VectorToSPIRV/VectorToSPIRV.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
+#include "mlir/Dialect/UB/IR/UBOps.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Transforms/DialectConversion.h"
 

diff  --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 3fbfcb4979b495..b35422f4ca3a9f 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -19,6 +19,7 @@
 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/UB/IR/UBOps.h"
 #include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
 #include "mlir/IR/AffineExpr.h"
@@ -1274,6 +1275,13 @@ OpFoldResult vector::ExtractElementOp::fold(FoldAdaptor adaptor) {
   return srcElements[posIdx];
 }
 
+// Returns `true` if `index` is either within [0, maxIndex) or equal to
+// `poisonValue`.
+static bool isValidPositiveIndexOrPoison(int64_t index, int64_t poisonValue,
+                                         int64_t maxIndex) {
+  return index == poisonValue || (index >= 0 && index < maxIndex);
+}
+
 //===----------------------------------------------------------------------===//
 // ExtractOp
 //===----------------------------------------------------------------------===//
@@ -1355,11 +1363,12 @@ LogicalResult vector::ExtractOp::verify() {
   for (auto [idx, pos] : llvm::enumerate(position)) {
     if (auto attr = dyn_cast<Attribute>(pos)) {
       int64_t constIdx = cast<IntegerAttr>(attr).getInt();
-      if (constIdx < 0 || constIdx >= getSourceVectorType().getDimSize(idx)) {
+      if (!isValidPositiveIndexOrPoison(
+              constIdx, kPoisonIndex, getSourceVectorType().getDimSize(idx))) {
         return emitOpError("expected position attribute #")
                << (idx + 1)
                << " to be a non-negative integer smaller than the "
-                  "corresponding vector dimension";
+                  "corresponding vector dimension or poison (-1)";
       }
     }
   }
@@ -1977,12 +1986,26 @@ static Value foldScalarExtractFromFromElements(ExtractOp extractOp) {
   return fromElementsOp.getElements()[flatIndex];
 }
 
-OpFoldResult ExtractOp::fold(FoldAdaptor) {
+/// Fold an insert or extract operation into a poison value when a poison index
+/// is found at any dimension of the static position.
+static ub::PoisonAttr
+foldPoisonIndexInsertExtractOp(MLIRContext *context,
+                               ArrayRef<int64_t> staticPos, int64_t poisonVal) {
+  if (!llvm::is_contained(staticPos, poisonVal))
+    return ub::PoisonAttr();
+
+  return ub::PoisonAttr::get(context);
+}
+
+OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
   // Fold "vector.extract %v[] : vector<2x2xf32> from vector<2x2xf32>" to %v.
   // Note: Do not fold "vector.extract %v[] : f32 from vector<f32>" (type
   // mismatch).
   if (getNumIndices() == 0 && getVector().getType() == getResult().getType())
     return getVector();
+  if (auto res = foldPoisonIndexInsertExtractOp(
+          getContext(), adaptor.getStaticPosition(), kPoisonIndex))
+    return res;
   if (succeeded(foldExtractOpFromExtractChain(*this)))
     return getResult();
   if (auto res = ExtractFromInsertTransposeChainState(*this).fold())
@@ -2249,6 +2272,21 @@ LogicalResult foldExtractFromFromElements(ExtractOp extractOp,
                                          resultType.getNumElements()));
   return success();
 }
+
+/// Fold an insert or extract operation into a poison value when a poison index
+/// is found at any dimension of the static position.
+template <typename OpTy>
+LogicalResult
+canonicalizePoisonIndexInsertExtractOp(OpTy op, PatternRewriter &rewriter) {
+  if (auto poisonAttr = foldPoisonIndexInsertExtractOp(
+          op.getContext(), op.getStaticPosition(), OpTy::kPoisonIndex)) {
+    rewriter.replaceOpWithNewOp<ub::PoisonOp>(op, op.getType(), poisonAttr);
+    return success();
+  }
+
+  return failure();
+}
+
 } // namespace
 
 void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
@@ -2257,6 +2295,7 @@ void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
               ExtractOpFromBroadcast, ExtractOpFromCreateMask>(context);
   results.add(foldExtractFromShapeCastToShapeCast);
   results.add(foldExtractFromFromElements);
+  results.add(canonicalizePoisonIndexInsertExtractOp<ExtractOp>);
 }
 
 static void populateFromInt64AttrArray(ArrayAttr arrayAttr,
@@ -2600,7 +2639,7 @@ LogicalResult ShuffleOp::verify() {
   int64_t indexSize = (v1Type.getRank() == 0 ? 1 : v1Type.getDimSize(0)) +
                       (v2Type.getRank() == 0 ? 1 : v2Type.getDimSize(0));
   for (auto [idx, maskPos] : llvm::enumerate(mask)) {
-    if (maskPos != kMaskPoisonValue && (maskPos < 0 || maskPos >= indexSize))
+    if (!isValidPositiveIndexOrPoison(maskPos, kPoisonIndex, indexSize))
       return emitOpError("mask index #") << (idx + 1) << " out of range";
   }
   return success();
@@ -2882,7 +2921,8 @@ LogicalResult InsertOp::verify() {
   for (auto [idx, pos] : llvm::enumerate(position)) {
     if (auto attr = pos.dyn_cast<Attribute>()) {
       int64_t constIdx = cast<IntegerAttr>(attr).getInt();
-      if (constIdx < 0 || constIdx >= destVectorType.getDimSize(idx)) {
+      if (!isValidPositiveIndexOrPoison(constIdx, kPoisonIndex,
+                                        destVectorType.getDimSize(idx))) {
         return emitOpError("expected position attribute #")
                << (idx + 1)
                << " to be a non-negative integer smaller than the "
@@ -3020,6 +3060,7 @@ void InsertOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
   results.add<InsertToBroadcast, BroadcastFolder, InsertSplatToSplat,
               InsertOpConstantFolder>(context);
+  results.add(canonicalizePoisonIndexInsertExtractOp<InsertOp>);
 }
 
 OpFoldResult vector::InsertOp::fold(FoldAdaptor adaptor) {
@@ -3028,6 +3069,10 @@ OpFoldResult vector::InsertOp::fold(FoldAdaptor adaptor) {
   // (type mismatch).
   if (getNumIndices() == 0 && getSourceType() == getType())
     return getSource();
+  if (auto res = foldPoisonIndexInsertExtractOp(
+          getContext(), adaptor.getStaticPosition(), kPoisonIndex))
+    return res;
+
   return {};
 }
 

diff  --git a/mlir/lib/Transforms/CMakeLists.txt b/mlir/lib/Transforms/CMakeLists.txt
index 058039e47313e2..3a8088bccf2994 100644
--- a/mlir/lib/Transforms/CMakeLists.txt
+++ b/mlir/lib/Transforms/CMakeLists.txt
@@ -37,4 +37,5 @@ add_mlir_library(MLIRTransforms
   MLIRSideEffectInterfaces
   MLIRSupport
   MLIRTransformUtils
+  MLIRUBDialect
   )

diff  --git a/mlir/lib/Transforms/Canonicalizer.cpp b/mlir/lib/Transforms/Canonicalizer.cpp
index 5f469605070367..7ccd503fb02882 100644
--- a/mlir/lib/Transforms/Canonicalizer.cpp
+++ b/mlir/lib/Transforms/Canonicalizer.cpp
@@ -13,6 +13,7 @@
 
 #include "mlir/Transforms/Passes.h"
 
+#include "mlir/Dialect/UB/IR/UBOps.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 

diff  --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index 29bed9aae56827..62649b83d887d1 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -1250,6 +1250,16 @@ func.func @extract_scalar_from_vec_1d_f32(%arg0: vector<16xf32>) -> f32 {
 
 // -----
 
+func.func @extract_poison_idx(%arg0: vector<16xf32>) -> f32 {
+  %0 = vector.extract %arg0[-1]: f32 from vector<16xf32>
+  return %0 : f32
+}
+// CHECK-LABEL: @extract_poison_idx
+//       CHECK:   %[[IDX:.*]] = llvm.mlir.constant(-1 : i64) : i64
+//       CHECK:   llvm.extractelement {{.*}}[%[[IDX]] : i64] : vector<16xf32>
+
+// -----
+
 func.func @extract_scalar_from_vec_1d_f32_scalable(%arg0: vector<[16]xf32>) -> f32 {
   %0 = vector.extract %arg0[15]: f32 from vector<[16]xf32>
   return %0 : f32

diff  --git a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
index fd73cea5e4f306..383215c016039a 100644
--- a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
+++ b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
@@ -175,6 +175,14 @@ func.func @extract(%arg0 : vector<2xf32>) -> (vector<1xf32>, f32) {
 
 // -----
 
+func.func @extract_poison_idx(%arg0 : vector<4xf32>) -> f32 {
+  // expected-error @+1 {{index -1 out of bounds for 'vector<4xf32>'}}
+  %0 = vector.extract %arg0[-1] : f32 from vector<4xf32>
+  return %0: f32
+}
+
+// -----
+
 // CHECK-LABEL: @extract_size1_vector
 //  CHECK-SAME: %[[ARG0:.+]]: vector<1xf32>
 //       CHECK:   %[[R:.+]] = builtin.unrealized_conversion_cast %[[ARG0]]
@@ -256,6 +264,14 @@ func.func @insert(%arg0 : vector<4xf32>, %arg1: f32) -> vector<4xf32> {
 
 // -----
 
+func.func @insert_poison_idx(%arg0 : vector<4xf32>, %arg1: f32) -> vector<4xf32> {
+  // expected-error @+1 {{index -1 out of bounds for 'vector<4xf32>'}}
+  %1 = vector.insert %arg1, %arg0[-1] : f32 into vector<4xf32>
+  return %1: vector<4xf32>
+}
+
+// -----
+
 // CHECK-LABEL: @insert_index_vector
 //       CHECK:   spirv.CompositeInsert %{{.+}}, %{{.+}}[2 : i32] : i32 into vector<4xi32>
 func.func @insert_index_vector(%arg0 : vector<4xindex>, %arg1: index) -> vector<4xindex> {

diff  --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 0eebb6e8d612d4..f9e3b772f9f0a2 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -132,6 +132,37 @@ func.func @extract_from_create_mask_dynamic_position(%dim0: index, %index: index
 
 // -----
 
+// CHECK-LABEL: @extract_scalar_poison_idx
+func.func @extract_scalar_poison_idx(%a: vector<4x5xf32>) -> f32 {
+  //  CHECK-NOT: vector.extract
+  // CHECK-NEXT: ub.poison : f32
+  %0 = vector.extract %a[-1, 0] : f32 from vector<4x5xf32>
+  return %0 : f32
+}
+
+// -----
+
+// CHECK-LABEL: @extract_vector_poison_idx
+func.func @extract_vector_poison_idx(%a: vector<4x5xf32>) -> vector<5xf32> {
+  //  CHECK-NOT: vector.extract
+  // CHECK-NEXT: ub.poison : vector<5xf32>
+  %0 = vector.extract %a[-1] : vector<5xf32> from vector<4x5xf32>
+  return %0 : vector<5xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @extract_multiple_poison_idx
+func.func @extract_multiple_poison_idx(%a: vector<4x5x8xf32>)
+    -> vector<8xf32> {
+  //  CHECK-NOT: vector.extract
+  // CHECK-NEXT: ub.poison : vector<8xf32>
+  %0 = vector.extract %a[-1, -1] : vector<8xf32> from vector<4x5x8xf32>
+  return %0 : vector<8xf32>
+}
+
+// -----
+
 // CHECK-LABEL: extract_from_create_mask_dynamic_position_all_false
 //  CHECK-SAME: %[[DIM0:.*]]: index, %[[INDEX:.*]]: index
 func.func @extract_from_create_mask_dynamic_position_all_false(%dim0: index, %index: index) -> vector<6xi1> {
@@ -2778,7 +2809,6 @@ func.func @from_elements_to_splat(%a: f32, %b: f32) -> (vector<2x3xf32>, vector<
   return %0, %1, %2 : vector<2x3xf32>, vector<2x3xf32>, vector<f32>
 }
 
-
 // -----
 
 // CHECK-LABEL: func @vector_insert_const_regression(
@@ -2792,6 +2822,39 @@ func.func @vector_insert_const_regression(%arg0: i8) -> vector<4xi8> {
 
 // -----
 
+// CHECK-LABEL: @insert_scalar_poison_idx
+func.func @insert_scalar_poison_idx(%a: vector<4x5xf32>, %b: f32)
+    -> vector<4x5xf32> {
+  //  CHECK-NOT: vector.insert
+  // CHECK-NEXT: ub.poison : vector<4x5xf32>
+  %0 = vector.insert %b, %a[-1, 0] : f32 into vector<4x5xf32>
+  return %0 : vector<4x5xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @insert_vector_poison_idx
+func.func @insert_vector_poison_idx(%a: vector<4x5xf32>, %b: vector<5xf32>)
+    -> vector<4x5xf32> {
+  //  CHECK-NOT: vector.insert
+  // CHECK-NEXT: ub.poison : vector<4x5xf32>
+  %0 = vector.insert %b, %a[-1] : vector<5xf32> into vector<4x5xf32>
+  return %0 : vector<4x5xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @insert_multiple_poison_idx
+func.func @insert_multiple_poison_idx(%a: vector<4x5x8xf32>, %b: vector<8xf32>)
+    -> vector<4x5x8xf32> {
+  //  CHECK-NOT: vector.insert
+  // CHECK-NEXT: ub.poison : vector<4x5x8xf32>
+  %0 = vector.insert %b, %a[-1, -1] : vector<8xf32> into vector<4x5x8xf32>
+  return %0 : vector<4x5x8xf32>
+}
+
+// -----
+
 // CHECK-LABEL: @contiguous_extract_strided_slices_to_extract
 // CHECK:        %[[EXTRACT:.+]] = vector.extract {{.*}}[0, 0, 0, 0, 0] : vector<4xi32> from vector<8x1x2x1x1x4xi32>
 // CHECK-NEXT:   return %[[EXTRACT]] :  vector<4xi32>

diff  --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index 1a70791fae1257..57e348c7d59912 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -186,8 +186,8 @@ func.func @extract_0d(%arg0: vector<f32>) {
 // -----
 
 func.func @extract_position_overflow(%arg0: vector<4x8x16xf32>) {
-  // expected-error @+1 {{expected position attribute #3 to be a non-negative integer smaller than the corresponding vector dimension}}
-  %1 = vector.extract %arg0[0, 0, -1] : f32 from vector<4x8x16xf32>
+  // expected-error @+1 {{expected position attribute #3 to be a non-negative integer smaller than the corresponding vector dimension or poison (-1)}}
+  %1 = vector.extract %arg0[0, 0, -5] : f32 from vector<4x8x16xf32>
 }
 
 // -----
@@ -247,7 +247,7 @@ func.func @insert_vector_type(%a: f32, %b: vector<4x8x16xf32>) {
 
 func.func @insert_position_overflow(%a: f32, %b: vector<4x8x16xf32>) {
   // expected-error @+1 {{expected position attribute #3 to be a non-negative integer smaller than the corresponding dest vector dimension}}
-  %1 = vector.insert %a, %b[0, 0, -1] : f32 into vector<4x8x16xf32>
+  %1 = vector.insert %a, %b[0, 0, -5] : f32 into vector<4x8x16xf32>
 }
 
 // -----

diff  --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
index cd6f3f518a1c07..67484e06f456dc 100644
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -247,6 +247,13 @@ func.func @extract_0d(%a: vector<f32>) -> f32 {
   return %0 : f32
 }
 
+// CHECK-LABEL: @extract_poison_idx
+func.func @extract_poison_idx(%a: vector<4x5xf32>) -> f32 {
+  // CHECK-NEXT: vector.extract %{{.*}}[-1, 0] : f32 from vector<4x5xf32>
+  %0 = vector.extract %a[-1, 0] : f32 from vector<4x5xf32>
+  return %0 : f32
+}
+
 // CHECK-LABEL: @insert_element_0d
 func.func @insert_element_0d(%a: f32, %b: vector<f32>) -> vector<f32> {
   // CHECK-NEXT: vector.insertelement %{{.*}}, %{{.*}}[] : vector<f32>
@@ -299,6 +306,13 @@ func.func @insert_0d(%a: f32, %b: vector<f32>, %c: vector<2x3xf32>) -> (vector<f
   return %1, %2 : vector<f32>, vector<2x3xf32>
 }
 
+// CHECK-LABEL: @insert_poison_idx
+func.func @insert_poison_idx(%a: vector<4x5xf32>, %b: f32) {
+  // CHECK-NEXT: vector.insert %{{.*}}, %{{.*}}[-1, 0] : f32 into vector<4x5xf32>
+  vector.insert %b, %a[-1, 0] : f32 into vector<4x5xf32>
+  return
+}
+
 // CHECK-LABEL: @outerproduct
 func.func @outerproduct(%arg0: vector<4xf32>, %arg1: vector<8xf32>, %arg2: vector<4x8xf32>) -> vector<4x8xf32> {
   // CHECK: vector.outerproduct {{.*}} : vector<4xf32>, vector<8xf32>


        


More information about the Mlir-commits mailing list