[Mlir-commits] [mlir] 98f6289 - [mlir][Vector] Add support for Value indices to vector.extract/insert

Diego Caballero llvmlistbot at llvm.org
Thu Sep 21 17:41:54 PDT 2023


Author: Diego Caballero
Date: 2023-09-22T00:39:32Z
New Revision: 98f6289a34bdaf7bc6cda8768e26e4405fc7726e

URL: https://github.com/llvm/llvm-project/commit/98f6289a34bdaf7bc6cda8768e26e4405fc7726e
DIFF: https://github.com/llvm/llvm-project/commit/98f6289a34bdaf7bc6cda8768e26e4405fc7726e.diff

LOG: [mlir][Vector] Add support for Value indices to vector.extract/insert

`vector.extract/insert` ops only support constant indices. This PR extends
them so that arbitrary `index` values can also be used, alone or mixed with
constant indices.

This work is part of the RFC: https://discourse.llvm.org/t/rfc-psa-remove-vector-extractelement-and-vector-insertelement-ops-in-favor-of-vector-extract-and-vector-insert-ops

Differential Revision: https://reviews.llvm.org/D155034

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
    mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
    mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
    mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
    mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
    mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp
    mlir/lib/Dialect/NVGPU/Transforms/CreateAsyncGroups.cpp
    mlir/lib/Dialect/Vector/IR/VectorOps.cpp
    mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
    mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp
    mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
    mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
    mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
    mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
    mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
    mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
    mlir/test/Dialect/Vector/invalid.mlir
    mlir/test/Dialect/Vector/ops.mlir
    mlir/test/Dialect/Vector/vector-transforms.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
index fcf7eb4a616b073..fc0c80036ff79ad 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
@@ -131,6 +131,24 @@ inline bool isReductionIterator(Attribute attr) {
   return cast<IteratorTypeAttr>(attr).getValue() == IteratorType::reduction;
 }
 
+/// Returns the integer numbers in `values`. `values` are expected to be
+/// constant operations.
+SmallVector<int64_t> getAsIntegers(ArrayRef<Value> values);
+
+/// Returns the integer numbers in `foldResults`. `foldResults` are expected to
+/// be constant operations.
+SmallVector<int64_t> getAsIntegers(ArrayRef<OpFoldResult> foldResults);
+
+/// Convert `foldResults` into Values. Integer attributes are converted to
+/// constant op.
+SmallVector<Value> getAsValues(OpBuilder &builder, Location loc,
+                               ArrayRef<OpFoldResult> foldResults);
+
+/// Returns the constant index ops in `values`. `values` are expected to be
+/// constant operations.
+SmallVector<arith::ConstantIndexOp>
+getAsConstantIndexOps(ArrayRef<Value> values);
+
 //===----------------------------------------------------------------------===//
 // Vector Masking Utilities
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 701eefcc1e7da6a..ea96f2660126870 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -523,9 +523,7 @@ def Vector_ExtractOp :
   Vector_Op<"extract", [Pure,
      PredOpTrait<"operand and result have same element type",
                  TCresVTEtIsSameAsOpBase<0, 0>>,
-     InferTypeOpAdaptorWithIsCompatible]>,
-    Arguments<(ins AnyVectorOfAnyRank:$vector, DenseI64ArrayAttr:$position)>,
-    Results<(outs AnyType)> {
+     InferTypeOpAdaptorWithIsCompatible]> {
   let summary = "extract operation";
   let description = [{
     Takes an n-D vector and a k-D position and extracts the (n-k)-D vector at
@@ -535,21 +533,55 @@ def Vector_ExtractOp :
 
     ```mlir
     %1 = vector.extract %0[3]: vector<4x8x16xf32>
-    %2 = vector.extract %0[3, 3, 3]: vector<4x8x16xf32>
+    %2 = vector.extract %0[2, 1, 3]: vector<4x8x16xf32>
     %3 = vector.extract %1[]: vector<f32>
+    %4 = vector.extract %0[%a, %b, %c]: vector<4x8x16xf32>
+    %5 = vector.extract %0[2, %b]: vector<4x8x16xf32>
     ```
   }];
+
+  let arguments = (ins
+    AnyVectorOfAnyRank:$vector,
+    Variadic<Index>:$dynamic_position,
+    DenseI64ArrayAttr:$static_position
+  );
+  let results = (outs AnyType:$result);
+
   let builders = [
-    // Convenience builder which assumes the values in `position` are defined by
-    // ConstantIndexOp.
-    OpBuilder<(ins "Value":$source, "ValueRange":$position)>
+    OpBuilder<(ins "Value":$source, "int64_t":$position)>,
+    OpBuilder<(ins "Value":$source, "OpFoldResult":$position)>,
+    OpBuilder<(ins "Value":$source, "ArrayRef<int64_t>":$position)>,
+    OpBuilder<(ins "Value":$source, "ArrayRef<OpFoldResult>":$position)>,
   ];
+
   let extraClassDeclaration = [{
     VectorType getSourceVectorType() {
       return ::llvm::cast<VectorType>(getVector().getType());
     }
+
+    /// Return a vector with all the static and dynamic position indices.
+    SmallVector<OpFoldResult> getMixedPosition() {
+      OpBuilder builder(getContext());
+      return getMixedValues(getStaticPosition(), getDynamicPosition(), builder);
+    }
+
+    unsigned getNumIndices() {
+      return getStaticPosition().size();
+    }
+
+    bool hasDynamicPosition() {
+      auto dynPos = getDynamicPosition();
+      return std::any_of(dynPos.begin(), dynPos.end(),
+                         [](Value operand) { return operand != nullptr; });
+    }
   }];
-  let assemblyFormat = "$vector `` $position attr-dict `:` type($vector)";
+
+  let assemblyFormat = [{
+    $vector ``
+    custom<DynamicIndexList>($dynamic_position, $static_position)
+    attr-dict `:` type($vector)
+  }];
+
   let hasCanonicalizer = 1;
   let hasFolder = 1;
   let hasVerifier = 1;
@@ -638,9 +670,7 @@ def Vector_InsertOp :
   Vector_Op<"insert", [Pure,
      PredOpTrait<"source operand and result have same element type",
                  TCresVTEtIsSameAsOpBase<0, 0>>,
-     AllTypesMatch<["dest", "res"]>]>,
-     Arguments<(ins AnyType:$source, AnyVectorOfAnyRank:$dest, DenseI64ArrayAttr:$position)>,
-     Results<(outs AnyVectorOfAnyRank:$res)> {
+     AllTypesMatch<["dest", "result"]>]> {
   let summary = "insert operation";
   let description = [{
     Takes an n-D source vector, an (n+k)-D destination vector and a k-D position
@@ -651,24 +681,53 @@ def Vector_InsertOp :
 
     ```mlir
     %2 = vector.insert %0, %1[3] : vector<8x16xf32> into vector<4x8x16xf32>
-    %5 = vector.insert %3, %4[3, 3, 3] : f32 into vector<4x8x16xf32>
+    %5 = vector.insert %3, %4[2, 1, 3] : f32 into vector<4x8x16xf32>
     %8 = vector.insert %6, %7[] : f32 into vector<f32>
-    %11 = vector.insert %9, %10[3, 3, 3] : vector<f32> into vector<4x8x16xf32>
+    %11 = vector.insert %9, %10[%a, %b, %c] : vector<f32> into vector<4x8x16xf32>
+    %12 = vector.insert %4, %10[2, %b] : vector<16xf32> into vector<4x8x16xf32>
     ```
   }];
-  let assemblyFormat = [{
-    $source `,` $dest $position attr-dict `:` type($source) `into` type($dest)
-  }];
+
+  let arguments = (ins
+    AnyType:$source,
+    AnyVectorOfAnyRank:$dest,
+    Variadic<Index>:$dynamic_position,
+    DenseI64ArrayAttr:$static_position
+  );
+  let results = (outs AnyVectorOfAnyRank:$result);
 
   let builders = [
-    // Convenience builder which assumes all values are constant indices.
-    OpBuilder<(ins "Value":$source, "Value":$dest, "ValueRange":$position)>
+    OpBuilder<(ins "Value":$source, "Value":$dest, "int64_t":$position)>,
+    OpBuilder<(ins "Value":$source, "Value":$dest, "OpFoldResult":$position)>,
+    OpBuilder<(ins "Value":$source, "Value":$dest, "ArrayRef<int64_t>":$position)>,
+    OpBuilder<(ins "Value":$source, "Value":$dest, "ArrayRef<OpFoldResult>":$position)>,
   ];
+
   let extraClassDeclaration = [{
     Type getSourceType() { return getSource().getType(); }
     VectorType getDestVectorType() {
       return ::llvm::cast<VectorType>(getDest().getType());
     }
+
+    /// Return a vector with all the static and dynamic position indices.
+    SmallVector<OpFoldResult> getMixedPosition() {
+      OpBuilder builder(getContext());
+      return getMixedValues(getStaticPosition(), getDynamicPosition(), builder);
+    }
+
+    unsigned getNumIndices() {
+      return getStaticPosition().size();
+    }
+
+    bool hasDynamicPosition() {
+      return llvm::any_of(getDynamicPosition(),
+                          [](Value operand) { return operand != nullptr; });
+    }
+  }];
+
+  let assemblyFormat = [{
+    $source `,` $dest custom<DynamicIndexList>($dynamic_position, $static_position)
+    attr-dict `:` type($source) `into` type($dest)
   }];
 
   let hasCanonicalizer = 1;

diff  --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 2c08257fc3089b6..3f77c5b5f24e9b5 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -126,6 +126,18 @@ static Value castDataPtr(ConversionPatternRewriter &rewriter, Location loc,
   return rewriter.create<LLVM::BitcastOp>(loc, pType, ptr);
 }
 
+/// Convert `foldResult` into a Value. Integer attribute is converted to
+/// an LLVM constant op.
+static Value getAsLLVMValue(OpBuilder &builder, Location loc,
+                            OpFoldResult foldResult) {
+  if (auto attr = foldResult.dyn_cast<Attribute>()) {
+    auto intAttr = cast<IntegerAttr>(attr);
+    return builder.create<LLVM::ConstantOp>(loc, intAttr).getResult();
+  }
+
+  return foldResult.get<Value>();
+}
+
 namespace {
 
 /// Trivial Vector to LLVM conversions
@@ -1079,41 +1091,53 @@ class VectorExtractOpConversion
     auto loc = extractOp->getLoc();
     auto resultType = extractOp.getResult().getType();
     auto llvmResultType = typeConverter->convertType(resultType);
-    ArrayRef<int64_t> positionArray = extractOp.getPosition();
-
     // Bail if result type cannot be lowered.
     if (!llvmResultType)
       return failure();
 
+    SmallVector<OpFoldResult> positionVec;
+    for (auto [idx, pos] : llvm::enumerate(extractOp.getMixedPosition())) {
+      if (pos.is<Value>())
+        // Make sure we use the value that has been already converted to LLVM.
+        positionVec.push_back(adaptor.getDynamicPosition()[idx]);
+      else
+        positionVec.push_back(pos);
+    }
+
     // Extract entire vector. Should be handled by folder, but just to be safe.
-    if (positionArray.empty()) {
+    ArrayRef<OpFoldResult> position(positionVec);
+    if (position.empty()) {
       rewriter.replaceOp(extractOp, adaptor.getVector());
       return success();
     }
 
     // One-shot extraction of vector from array (only requires extractvalue).
     if (isa<VectorType>(resultType)) {
+      if (extractOp.hasDynamicPosition())
+        return failure();
+
       Value extracted = rewriter.create<LLVM::ExtractValueOp>(
-          loc, adaptor.getVector(), positionArray);
+          loc, adaptor.getVector(), getAsIntegers(position));
       rewriter.replaceOp(extractOp, extracted);
       return success();
     }
 
     // Potential extraction of 1-D vector from array.
     Value extracted = adaptor.getVector();
-    if (positionArray.size() > 1) {
-      extracted = rewriter.create<LLVM::ExtractValueOp>(
-          loc, extracted, positionArray.drop_back());
-    }
+    if (position.size() > 1) {
+      if (extractOp.hasDynamicPosition())
+        return failure();
 
-    // Remaining extraction of element from 1-D LLVM vector
-    auto i64Type = IntegerType::get(rewriter.getContext(), 64);
-    auto constant =
-        rewriter.create<LLVM::ConstantOp>(loc, i64Type, positionArray.back());
-    extracted =
-        rewriter.create<LLVM::ExtractElementOp>(loc, extracted, constant);
-    rewriter.replaceOp(extractOp, extracted);
+      SmallVector<int64_t> nMinusOnePosition =
+          getAsIntegers(position.drop_back());
+      extracted = rewriter.create<LLVM::ExtractValueOp>(loc, extracted,
+                                                        nMinusOnePosition);
+    }
 
+    Value lastPosition = getAsLLVMValue(rewriter, loc, position.back());
+    // Remaining extraction of element from 1-D LLVM vector.
+    rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(extractOp, extracted,
+                                                        lastPosition);
     return success();
   }
 };
@@ -1194,23 +1218,34 @@ class VectorInsertOpConversion
     auto sourceType = insertOp.getSourceType();
     auto destVectorType = insertOp.getDestVectorType();
     auto llvmResultType = typeConverter->convertType(destVectorType);
-    ArrayRef<int64_t> positionArray = insertOp.getPosition();
-
     // Bail if result type cannot be lowered.
     if (!llvmResultType)
       return failure();
 
+    SmallVector<OpFoldResult> positionVec;
+    for (auto [idx, pos] : llvm::enumerate(insertOp.getMixedPosition())) {
+      if (pos.is<Value>())
+        // Make sure we use the value that has been already converted to LLVM.
+        positionVec.push_back(adaptor.getDynamicPosition()[idx]);
+      else
+        positionVec.push_back(pos);
+    }
+
     // Overwrite entire vector with value. Should be handled by folder, but
     // just to be safe.
-    if (positionArray.empty()) {
+    ArrayRef<OpFoldResult> position(positionVec);
+    if (position.empty()) {
       rewriter.replaceOp(insertOp, adaptor.getSource());
       return success();
     }
 
     // One-shot insertion of a vector into an array (only requires insertvalue).
     if (isa<VectorType>(sourceType)) {
+      if (insertOp.hasDynamicPosition())
+        return failure();
+
       Value inserted = rewriter.create<LLVM::InsertValueOp>(
-          loc, adaptor.getDest(), adaptor.getSource(), positionArray);
+          loc, adaptor.getDest(), adaptor.getSource(), getAsIntegers(position));
       rewriter.replaceOp(insertOp, inserted);
       return success();
     }
@@ -1218,24 +1253,28 @@ class VectorInsertOpConversion
     // Potential extraction of 1-D vector from array.
     Value extracted = adaptor.getDest();
     auto oneDVectorType = destVectorType;
-    if (positionArray.size() > 1) {
+    if (position.size() > 1) {
+      if (insertOp.hasDynamicPosition())
+        return failure();
+
       oneDVectorType = reducedVectorTypeBack(destVectorType);
       extracted = rewriter.create<LLVM::ExtractValueOp>(
-          loc, extracted, positionArray.drop_back());
+          loc, extracted, getAsIntegers(position.drop_back()));
     }
 
     // Insertion of an element into a 1-D LLVM vector.
-    auto i64Type = IntegerType::get(rewriter.getContext(), 64);
-    auto constant =
-        rewriter.create<LLVM::ConstantOp>(loc, i64Type, positionArray.back());
     Value inserted = rewriter.create<LLVM::InsertElementOp>(
         loc, typeConverter->convertType(oneDVectorType), extracted,
-        adaptor.getSource(), constant);
+        adaptor.getSource(), getAsLLVMValue(rewriter, loc, position.back()));
 
     // Potential insertion of resulting 1-D vector into array.
-    if (positionArray.size() > 1) {
+    if (position.size() > 1) {
+      if (insertOp.hasDynamicPosition())
+        return failure();
+
       inserted = rewriter.create<LLVM::InsertValueOp>(
-          loc, adaptor.getDest(), inserted, positionArray.drop_back());
+          loc, adaptor.getDest(), inserted,
+          getAsIntegers(position.drop_back()));
     }
 
     rewriter.replaceOp(insertOp, inserted);

diff  --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
index 1aeed4594f94505..f8fd89c542c0699 100644
--- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
+++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
@@ -1063,10 +1063,11 @@ struct UnrollTransferReadConversion
   /// If the result of the TransferReadOp has exactly one user, which is a
   /// vector::InsertOp, return that operation's indices.
   void getInsertionIndices(TransferReadOp xferOp,
-                           SmallVector<int64_t, 8> &indices) const {
-    if (auto insertOp = getInsertOp(xferOp))
-      indices.assign(insertOp.getPosition().begin(),
-                     insertOp.getPosition().end());
+                           SmallVectorImpl<OpFoldResult> &indices) const {
+    if (auto insertOp = getInsertOp(xferOp)) {
+      auto pos = insertOp.getMixedPosition();
+      indices.append(pos.begin(), pos.end());
+    }
   }
 
   /// Rewrite the op: Unpack one dimension. Can handle masks, out-of-bounds
@@ -1110,9 +1111,9 @@ struct UnrollTransferReadConversion
             getXferIndices(b, xferOp, iv, xferIndices);
 
             // Indices for the new vector.insert op.
-            SmallVector<int64_t, 8> insertionIndices;
+            SmallVector<OpFoldResult, 8> insertionIndices;
             getInsertionIndices(xferOp, insertionIndices);
-            insertionIndices.push_back(i);
+            insertionIndices.push_back(rewriter.getIndexAttr(i));
 
             auto inBoundsAttr = dropFirstElem(b, xferOp.getInBoundsAttr());
             auto newXferOp = b.create<vector::TransferReadOp>(
@@ -1195,10 +1196,11 @@ struct UnrollTransferWriteConversion
   /// If the input of the given TransferWriteOp is an ExtractOp, return its
   /// indices.
   void getExtractionIndices(TransferWriteOp xferOp,
-                            SmallVector<int64_t, 8> &indices) const {
-    if (auto extractOp = getExtractOp(xferOp))
-      indices.assign(extractOp.getPosition().begin(),
-                     extractOp.getPosition().end());
+                            SmallVectorImpl<OpFoldResult> &indices) const {
+    if (auto extractOp = getExtractOp(xferOp)) {
+      auto pos = extractOp.getMixedPosition();
+      indices.append(pos.begin(), pos.end());
+    }
   }
 
   /// Rewrite the op: Unpack one dimension. Can handle masks, out-of-bounds
@@ -1235,9 +1237,9 @@ struct UnrollTransferWriteConversion
             getXferIndices(b, xferOp, iv, xferIndices);
 
             // Indices for the new vector.extract op.
-            SmallVector<int64_t, 8> extractionIndices;
+            SmallVector<OpFoldResult, 8> extractionIndices;
             getExtractionIndices(xferOp, extractionIndices);
-            extractionIndices.push_back(i);
+            extractionIndices.push_back(b.getI64IntegerAttr(i));
 
             auto extracted =
                 b.create<vector::ExtractOp>(loc, vec, extractionIndices);

diff  --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index a8c68abc8bcbf5c..9b29179f3687165 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -35,11 +35,25 @@
 
 using namespace mlir;
 
-/// Gets the first integer value from `attr`, assuming it is an integer array
-/// attribute.
+/// Returns the integer value from the first valid input element, assuming Value
+/// inputs are defined by a constant index ops and Attribute inputs are integer
+/// attributes.
+static uint64_t getFirstIntValue(ValueRange values) {
+  return values[0].getDefiningOp<arith::ConstantIndexOp>().value();
+}
+static uint64_t getFirstIntValue(ArrayRef<Attribute> attr) {
+  return cast<IntegerAttr>(attr[0]).getInt();
+}
 static uint64_t getFirstIntValue(ArrayAttr attr) {
   return (*attr.getAsValueRange<IntegerAttr>().begin()).getZExtValue();
 }
+static uint64_t getFirstIntValue(ArrayRef<OpFoldResult> foldResults) {
+  auto attr = foldResults[0].dyn_cast<Attribute>();
+  if (attr)
+    return getFirstIntValue(attr);
+
+  return getFirstIntValue(ValueRange{foldResults[0].get<Value>()});
+}
 
 /// Returns the number of bits for the given scalar/vector type.
 static int getNumBits(Type type) {
@@ -141,9 +155,7 @@ struct VectorExtractOpConvert final
   LogicalResult
   matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    // Only support extracting a scalar value now.
-    VectorType resultVectorType = dyn_cast<VectorType>(extractOp.getType());
-    if (resultVectorType && resultVectorType.getNumElements() > 1)
+    if (extractOp.hasDynamicPosition())
       return failure();
 
     Type dstType = getTypeConverter()->convertType(extractOp.getType());
@@ -155,7 +167,7 @@ struct VectorExtractOpConvert final
       return success();
     }
 
-    int32_t id = extractOp.getPosition()[0];
+    int32_t id = getFirstIntValue(extractOp.getMixedPosition());
     rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(
         extractOp, adaptor.getVector(), id);
     return success();
@@ -235,7 +247,7 @@ struct VectorInsertOpConvert final
       return success();
     }
 
-    int32_t id = insertOp.getPosition()[0];
+    int32_t id = getFirstIntValue(insertOp.getMixedPosition());
     rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(
         insertOp, adaptor.getSource(), adaptor.getDest(), id);
     return success();

diff  --git a/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp b/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp
index 906c13a6579f158..1084fbc890053b9 100644
--- a/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp
@@ -516,7 +516,7 @@ struct ExtensionOverExtract final : NarrowingPattern<vector::ExtractOp> {
       return failure();
 
     Value newExtract = rewriter.create<vector::ExtractOp>(
-        op.getLoc(), ext->getIn(), op.getPosition());
+        op.getLoc(), ext->getIn(), op.getMixedPosition());
     ext->recreateAndReplace(rewriter, op, newExtract);
     return success();
   }
@@ -645,8 +645,9 @@ struct ExtensionOverInsert final
                                      vector::InsertOp origInsert,
                                      Value narrowValue,
                                      Value narrowDest) const override {
-    return rewriter.create<vector::InsertOp>(
-        origInsert.getLoc(), narrowValue, narrowDest, origInsert.getPosition());
+    return rewriter.create<vector::InsertOp>(origInsert.getLoc(), narrowValue,
+                                             narrowDest,
+                                             origInsert.getMixedPosition());
   }
 };
 

diff  --git a/mlir/lib/Dialect/NVGPU/Transforms/CreateAsyncGroups.cpp b/mlir/lib/Dialect/NVGPU/Transforms/CreateAsyncGroups.cpp
index ad2180d501148f1..f63825cdc8f6179 100644
--- a/mlir/lib/Dialect/NVGPU/Transforms/CreateAsyncGroups.cpp
+++ b/mlir/lib/Dialect/NVGPU/Transforms/CreateAsyncGroups.cpp
@@ -74,7 +74,7 @@ static FailureOr<TransferMask> getMaskOp(Operation *loadOp) {
     if (auto maskOp =
             extractOp.getVector().getDefiningOp<vector::CreateMaskOp>())
       return TransferMask{maskOp,
-                          SmallVector<int64_t>(extractOp.getPosition())};
+                          SmallVector<int64_t>(extractOp.getStaticPosition())};
 
   // All other cases: not supported.
   return failure();

diff  --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 7b9c5f9b879e8c4..85d21938d0ab711 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -223,6 +223,48 @@ static LogicalResult incSlicePosition(MutableArrayRef<int64_t> position,
   return failure();
 }
 
+/// Returns the integer numbers in `values`. `values` are expected to be
+/// constant operations.
+SmallVector<int64_t> vector::getAsIntegers(ArrayRef<Value> values) {
+  SmallVector<int64_t> ints;
+  llvm::transform(values, std::back_inserter(ints), [](Value value) {
+    auto constOp = value.getDefiningOp<arith::ConstantIndexOp>();
+    assert(constOp && "Unexpected non-constant index");
+    return constOp.value();
+  });
+  return ints;
+}
+
+/// Returns the integer numbers in `foldResults`. `foldResults` are expected to
+/// be constant operations.
+SmallVector<int64_t> vector::getAsIntegers(ArrayRef<OpFoldResult> foldResults) {
+  SmallVector<int64_t> ints;
+  llvm::transform(
+      foldResults, std::back_inserter(ints), [](OpFoldResult foldResult) {
+        assert(foldResult.is<Attribute>() && "Unexpected non-constant index");
+        return cast<IntegerAttr>(foldResult.get<Attribute>()).getInt();
+      });
+  return ints;
+}
+
+/// Convert `foldResults` into Values. Integer attributes are converted to
+/// constant op.
+SmallVector<Value> vector::getAsValues(OpBuilder &builder, Location loc,
+                                       ArrayRef<OpFoldResult> foldResults) {
+  SmallVector<Value> values;
+  llvm::transform(foldResults, std::back_inserter(values),
+                  [&](OpFoldResult foldResult) {
+                    if (auto attr = foldResult.dyn_cast<Attribute>())
+                      return builder
+                          .create<arith::ConstantIndexOp>(
+                              loc, cast<IntegerAttr>(attr).getInt())
+                          .getResult();
+
+                    return foldResult.get<Value>();
+                  });
+  return values;
+}
+
 //===----------------------------------------------------------------------===//
 // CombiningKindAttr
 //===----------------------------------------------------------------------===//
@@ -389,12 +431,11 @@ struct ElideUnitDimsInMultiDimReduction
     } else {
       // This means we are reducing all the dimensions, and all reduction
       // dimensions are of size 1. So a simple extraction would do.
-      SmallVector<int64_t> zeroAttr(shape.size(), 0);
+      SmallVector<int64_t> zeroIdx(shape.size(), 0);
       if (mask)
-        mask = rewriter.create<vector::ExtractOp>(loc, rewriter.getI1Type(),
-                                                  mask, zeroAttr);
-      cast = rewriter.create<vector::ExtractOp>(
-          loc, reductionOp.getDestType(), reductionOp.getSource(), zeroAttr);
+        mask = rewriter.create<vector::ExtractOp>(loc, mask, zeroIdx);
+      cast = rewriter.create<vector::ExtractOp>(loc, reductionOp.getSource(),
+                                                zeroIdx);
     }
 
     Value result = vector::makeArithReduction(
@@ -574,11 +615,9 @@ struct ElideSingleElementReduction : public OpRewritePattern<ReductionOp> {
         mask = rewriter.create<ExtractElementOp>(loc, mask);
       result = rewriter.create<ExtractElementOp>(loc, reductionOp.getVector());
     } else {
-      if (mask) {
-        mask = rewriter.create<ExtractOp>(loc, rewriter.getI1Type(), mask, 0);
-      }
-      result = rewriter.create<ExtractOp>(loc, reductionOp.getType(),
-                                          reductionOp.getVector(), 0);
+      if (mask)
+        mask = rewriter.create<ExtractOp>(loc, mask, 0);
+      result = rewriter.create<ExtractOp>(loc, reductionOp.getVector(), 0);
     }
 
     if (Value acc = reductionOp.getAcc())
@@ -1148,12 +1187,29 @@ OpFoldResult vector::ExtractElementOp::fold(FoldAdaptor adaptor) {
 // ExtractOp
 //===----------------------------------------------------------------------===//
 
-// Convenience builder which assumes the values are constant indices.
 void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
-                              Value source, ValueRange position) {
-  SmallVector<int64_t> positionConstants = llvm::to_vector(llvm::map_range(
-      position, [](Value pos) { return getConstantIntValue(pos).value(); }));
-  build(builder, result, source, positionConstants);
+                              Value source, int64_t position) {
+  build(builder, result, source, ArrayRef<int64_t>{position});
+}
+
+void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
+                              Value source, OpFoldResult position) {
+  build(builder, result, source, ArrayRef<OpFoldResult>{position});
+}
+
+void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
+                              Value source, ArrayRef<int64_t> position) {
+  build(builder, result, source, /*dynamic_position=*/ArrayRef<Value>(),
+        builder.getDenseI64ArrayAttr(position));
+}
+
+void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
+                              Value source, ArrayRef<OpFoldResult> position) {
+  SmallVector<int64_t> staticPos;
+  SmallVector<Value> dynamicPos;
+  dispatchIndexOpFoldResults(position, dynamicPos, staticPos);
+  build(builder, result, source, dynamicPos,
+        builder.getDenseI64ArrayAttr(staticPos));
 }
 
 LogicalResult
@@ -1161,12 +1217,12 @@ ExtractOp::inferReturnTypes(MLIRContext *, std::optional<Location>,
                             ExtractOp::Adaptor adaptor,
                             SmallVectorImpl<Type> &inferredReturnTypes) {
   auto vectorType = llvm::cast<VectorType>(adaptor.getVector().getType());
-  if (static_cast<int64_t>(adaptor.getPosition().size()) ==
+  if (static_cast<int64_t>(adaptor.getStaticPosition().size()) ==
       vectorType.getRank()) {
     inferredReturnTypes.push_back(vectorType.getElementType());
   } else {
-    auto n =
-        std::min<size_t>(adaptor.getPosition().size(), vectorType.getRank());
+    auto n = std::min<size_t>(adaptor.getStaticPosition().size(),
+                              vectorType.getRank());
     inferredReturnTypes.push_back(VectorType::get(
         vectorType.getShape().drop_front(n), vectorType.getElementType(),
         vectorType.getScalableDims().drop_front(n)));
@@ -1188,17 +1244,20 @@ bool ExtractOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
 }
 
 LogicalResult vector::ExtractOp::verify() {
-  ArrayRef<int64_t> position = getPosition();
+  auto position = getMixedPosition();
   if (position.size() > static_cast<unsigned>(getSourceVectorType().getRank()))
     return emitOpError(
         "expected position attribute of rank no greater than vector rank");
-  for (const auto &en : llvm::enumerate(position)) {
-    if (en.value() < 0 ||
-        en.value() >= getSourceVectorType().getDimSize(en.index()))
-      return emitOpError("expected position attribute #")
-             << (en.index() + 1)
-             << " to be a non-negative integer smaller than the corresponding "
-                "vector dimension";
+  for (auto [idx, pos] : llvm::enumerate(position)) {
+    if (pos.is<Attribute>()) {
+      int64_t constIdx = cast<IntegerAttr>(pos.get<Attribute>()).getInt();
+      if (constIdx < 0 || constIdx >= getSourceVectorType().getDimSize(idx)) {
+        return emitOpError("expected position attribute #")
+               << (idx + 1)
+               << " to be a non-negative integer smaller than the "
+                  "corresponding vector dimension";
+      }
+    }
   }
   return success();
 }
@@ -1216,20 +1275,24 @@ static LogicalResult foldExtractOpFromExtractChain(ExtractOp extractOp) {
   if (!extractOp.getVector().getDefiningOp<ExtractOp>())
     return failure();
 
-  SmallVector<int64_t, 4> globalPosition;
+  // TODO: Canonicalization for dynamic position not implemented yet.
+  if (extractOp.hasDynamicPosition())
+    return failure();
+
+  SmallVector<int64_t> globalPosition;
   ExtractOp currentOp = extractOp;
-  ArrayRef<int64_t> extrPos = currentOp.getPosition();
+  ArrayRef<int64_t> extrPos = currentOp.getStaticPosition();
   globalPosition.append(extrPos.rbegin(), extrPos.rend());
   while (ExtractOp nextOp = currentOp.getVector().getDefiningOp<ExtractOp>()) {
     currentOp = nextOp;
-    ArrayRef<int64_t> extrPos = currentOp.getPosition();
+    ArrayRef<int64_t> extrPos = currentOp.getStaticPosition();
     globalPosition.append(extrPos.rbegin(), extrPos.rend());
   }
-  extractOp.setOperand(currentOp.getVector());
+  extractOp.setOperand(0, currentOp.getVector());
   // OpBuilder is only used as a helper to build an I64ArrayAttr.
   OpBuilder b(extractOp.getContext());
   std::reverse(globalPosition.begin(), globalPosition.end());
-  extractOp.setPosition(globalPosition);
+  extractOp.setStaticPosition(globalPosition);
   return success();
 }
 
@@ -1335,19 +1398,23 @@ class ExtractFromInsertTransposeChainState {
 ExtractFromInsertTransposeChainState::ExtractFromInsertTransposeChainState(
     ExtractOp e)
     : extractOp(e), vectorRank(extractOp.getSourceVectorType().getRank()),
-      extractedRank(extractOp.getPosition().size()) {
-  assert(vectorRank >= extractedRank && "extracted pos overflow");
+      extractedRank(extractOp.getNumIndices()) {
+  assert(vectorRank >= extractedRank && "Extracted position overflow");
   sentinels.reserve(vectorRank - extractedRank);
   for (int64_t i = 0, e = vectorRank - extractedRank; i < e; ++i)
     sentinels.push_back(-(i + 1));
-  extractPosition.assign(extractOp.getPosition().begin(),
-                         extractOp.getPosition().end());
+  extractPosition.assign(extractOp.getStaticPosition().begin(),
+                         extractOp.getStaticPosition().end());
   llvm::append_range(extractPosition, sentinels);
 }
 
 // Case 1. If we hit a transpose, just compose the map and iterate.
 // Invariant: insert + transpose do not change rank, we can always compose.
 LogicalResult ExtractFromInsertTransposeChainState::handleTransposeOp() {
+  // TODO: Canonicalization for dynamic position not implemented yet.
+  if (extractOp.hasDynamicPosition())
+    return failure();
+
   if (!nextTransposeOp)
     return failure();
   auto permutation = extractVector<unsigned>(nextTransposeOp.getTransp());
@@ -1361,7 +1428,11 @@ LogicalResult ExtractFromInsertTransposeChainState::handleTransposeOp() {
 LogicalResult
 ExtractFromInsertTransposeChainState::handleInsertOpWithMatchingPos(
     Value &res) {
-  ArrayRef<int64_t> insertedPos = nextInsertOp.getPosition();
+  // TODO: Canonicalization for dynamic position not implemented yet.
+  if (extractOp.hasDynamicPosition() || nextInsertOp.hasDynamicPosition())
+    return failure();
+
+  ArrayRef<int64_t> insertedPos = nextInsertOp.getStaticPosition();
   if (insertedPos != llvm::ArrayRef(extractPosition).take_front(extractedRank))
     return failure();
   // Case 2.a. early-exit fold.
@@ -1375,7 +1446,11 @@ ExtractFromInsertTransposeChainState::handleInsertOpWithMatchingPos(
 /// This method updates the internal state.
 LogicalResult
 ExtractFromInsertTransposeChainState::handleInsertOpWithPrefixPos(Value &res) {
-  ArrayRef<int64_t> insertedPos = nextInsertOp.getPosition();
+  // TODO: Canonicalization for dynamic position not implemented yet.
+  if (extractOp.hasDynamicPosition() || nextInsertOp.hasDynamicPosition())
+    return failure();
+
+  ArrayRef<int64_t> insertedPos = nextInsertOp.getStaticPosition();
   if (!isContainedWithin(insertedPos, extractPosition))
     return failure();
   // Set leading dims to zero.
@@ -1395,19 +1470,29 @@ ExtractFromInsertTransposeChainState::handleInsertOpWithPrefixPos(Value &res) {
 /// internal tranposition in the result).
 Value ExtractFromInsertTransposeChainState::tryToFoldExtractOpInPlace(
     Value source) {
+  // TODO: Canonicalization for dynamic position not implemented yet.
+  if (extractOp.hasDynamicPosition())
+    return Value();
+
   // If we can't fold (either internal transposition, or nothing to fold), bail.
   bool nothingToFold = (source == extractOp.getVector());
   if (nothingToFold || !canFold())
     return Value();
+
   // Otherwise, fold by updating the op inplace and return its result.
   OpBuilder b(extractOp.getContext());
-  extractOp.setPosition(ArrayRef(extractPosition).take_front(extractedRank));
+  extractOp.setStaticPosition(
+      ArrayRef(extractPosition).take_front(extractedRank));
   extractOp.getVectorMutable().assign(source);
   return extractOp.getResult();
 }
 
 /// Iterate over producing insert and transpose ops until we find a fold.
 Value ExtractFromInsertTransposeChainState::fold() {
+  // TODO: Canonicalization for dynamic position not implemented yet.
+  if (extractOp.hasDynamicPosition())
+    return Value();
+
   Value valueToExtractFrom = extractOp.getVector();
   updateStateForNextIteration(valueToExtractFrom);
   while (nextInsertOp || nextTransposeOp) {
@@ -1431,7 +1516,7 @@ Value ExtractFromInsertTransposeChainState::fold() {
 
     // Case 4: extractPositionRef intersects insertedPosRef on non-sentinel
     // values. This is a more difficult case and we bail.
-    ArrayRef<int64_t> insertedPos = nextInsertOp.getPosition();
+    ArrayRef<int64_t> insertedPos = nextInsertOp.getStaticPosition();
     if (isContainedWithin(extractPosition, insertedPos) ||
         intersectsWhereNonNegative(extractPosition, insertedPos))
       return Value();
@@ -1457,6 +1542,10 @@ static bool hasZeroDimVectors(Operation *op) {
 
 /// Fold extractOp with scalar result coming from BroadcastOp or SplatOp.
 static Value foldExtractFromBroadcast(ExtractOp extractOp) {
+  // TODO: Canonicalization for dynamic position not implemented yet.
+  if (extractOp.hasDynamicPosition())
+    return Value();
+
   Operation *defOp = extractOp.getVector().getDefiningOp();
   if (!defOp || !isa<vector::BroadcastOp, SplatOp>(defOp))
     return Value();
@@ -1497,7 +1586,7 @@ static Value foldExtractFromBroadcast(ExtractOp extractOp) {
   // extract position to `0` when extracting from the source operand.
   llvm::SetVector<int64_t> broadcastedUnitDims =
       broadcastOp.computeBroadcastedUnitDims();
-  SmallVector<int64_t> extractPos(extractOp.getPosition());
+  SmallVector<int64_t> extractPos(extractOp.getStaticPosition());
   int64_t broadcastRankDiff = broadcastDstRank - broadcastSrcRank;
   for (int64_t i = broadcastRankDiff, e = extractPos.size(); i < e; ++i)
     if (broadcastedUnitDims.contains(i))
@@ -1509,13 +1598,17 @@ static Value foldExtractFromBroadcast(ExtractOp extractOp) {
                    std::next(extractPos.begin(), extractPos.size() - rankDiff));
   // OpBuilder is only used as a helper to build an I64ArrayAttr.
   OpBuilder b(extractOp.getContext());
-  extractOp.setOperand(source);
-  extractOp.setPosition(extractPos);
+  extractOp.setOperand(0, source);
+  extractOp.setStaticPosition(extractPos);
   return extractOp.getResult();
 }
 
 // Fold extractOp with source coming from ShapeCast op.
 static Value foldExtractFromShapeCast(ExtractOp extractOp) {
+  // TODO: Canonicalization for dynamic position not implemented yet.
+  if (extractOp.hasDynamicPosition())
+    return Value();
+
   auto shapeCastOp = extractOp.getVector().getDefiningOp<vector::ShapeCastOp>();
   if (!shapeCastOp)
     return Value();
@@ -1549,7 +1642,7 @@ static Value foldExtractFromShapeCast(ExtractOp extractOp) {
   }
   // Extract the strides associated with the extract op vector source. Then use
   // this to calculate a linearized position for the extract.
-  SmallVector<int64_t> extractedPos(extractOp.getPosition());
+  SmallVector<int64_t> extractedPos(extractOp.getStaticPosition());
   std::reverse(extractedPos.begin(), extractedPos.end());
   SmallVector<int64_t, 4> strides;
   int64_t stride = 1;
@@ -1575,13 +1668,17 @@ static Value foldExtractFromShapeCast(ExtractOp extractOp) {
   SmallVector<int64_t, 4> newPosition = delinearize(position, newStrides);
   // OpBuilder is only used as a helper to build an I64ArrayAttr.
   OpBuilder b(extractOp.getContext());
-  extractOp.setPosition(newPosition);
-  extractOp.setOperand(shapeCastOp.getSource());
+  extractOp.setStaticPosition(newPosition);
+  extractOp.setOperand(0, shapeCastOp.getSource());
   return extractOp.getResult();
 }
 
 /// Fold an ExtractOp from ExtractStridedSliceOp.
 static Value foldExtractFromExtractStrided(ExtractOp extractOp) {
+  // TODO: Canonicalization for dynamic position not implemented yet.
+  if (extractOp.hasDynamicPosition())
+    return Value();
+
   auto extractStridedSliceOp =
       extractOp.getVector().getDefiningOp<vector::ExtractStridedSliceOp>();
   if (!extractStridedSliceOp)
@@ -1615,19 +1712,25 @@ static Value foldExtractFromExtractStrided(ExtractOp extractOp) {
   if (destinationRank > extractStridedSliceOp.getSourceVectorType().getRank() -
                             sliceOffsets.size())
     return Value();
-  SmallVector<int64_t> extractedPos(extractOp.getPosition());
+
+  SmallVector<int64_t> extractedPos(extractOp.getStaticPosition());
   assert(extractedPos.size() >= sliceOffsets.size());
   for (size_t i = 0, e = sliceOffsets.size(); i < e; i++)
     extractedPos[i] = extractedPos[i] + sliceOffsets[i];
   extractOp.getVectorMutable().assign(extractStridedSliceOp.getVector());
+
   // OpBuilder is only used as a helper to build an I64ArrayAttr.
   OpBuilder b(extractOp.getContext());
-  extractOp.setPosition(extractedPos);
+  extractOp.setStaticPosition(extractedPos);
   return extractOp.getResult();
 }
 
 /// Fold extract_op fed from a chain of insertStridedSlice ops.
 static Value foldExtractStridedOpFromInsertChain(ExtractOp extractOp) {
+  // TODO: Canonicalization for dynamic position not implemented yet.
+  if (extractOp.hasDynamicPosition())
+    return Value();
+
   int64_t destinationRank =
       llvm::isa<VectorType>(extractOp.getType())
           ? llvm::cast<VectorType>(extractOp.getType()).getRank()
@@ -1647,7 +1750,7 @@ static Value foldExtractStridedOpFromInsertChain(ExtractOp extractOp) {
     if (destinationRank > insertOp.getSourceVectorType().getRank())
       return Value();
     auto insertOffsets = extractVector<int64_t>(insertOp.getOffsets());
-    ArrayRef<int64_t> extractOffsets = extractOp.getPosition();
+    ArrayRef<int64_t> extractOffsets = extractOp.getStaticPosition();
 
     if (llvm::any_of(insertOp.getStrides(), [](Attribute attr) {
           return llvm::cast<IntegerAttr>(attr).getInt() != 1;
@@ -1687,7 +1790,7 @@ static Value foldExtractStridedOpFromInsertChain(ExtractOp extractOp) {
       extractOp.getVectorMutable().assign(insertOp.getSource());
       // OpBuilder is only used as a helper to build an I64ArrayAttr.
       OpBuilder b(extractOp.getContext());
-      extractOp.setPosition(offsetDiffs);
+      extractOp.setStaticPosition(offsetDiffs);
       return extractOp.getResult();
     }
     // If the chunk extracted is disjoint from the chunk inserted, keep
@@ -1698,7 +1801,7 @@ static Value foldExtractStridedOpFromInsertChain(ExtractOp extractOp) {
 }
 
 OpFoldResult ExtractOp::fold(FoldAdaptor) {
-  if (getPosition().empty())
+  if (getNumIndices() == 0)
     return getVector();
   if (succeeded(foldExtractOpFromExtractChain(*this)))
     return getResult();
@@ -1788,6 +1891,10 @@ class ExtractOpNonSplatConstantFolder final
 
   LogicalResult matchAndRewrite(ExtractOp extractOp,
                                 PatternRewriter &rewriter) const override {
+    // TODO: Canonicalization for dynamic position not implemented yet.
+    if (extractOp.hasDynamicPosition())
+      return failure();
+
     // Return if 'ExtractOp' operand is not defined by a compatible vector
     // ConstantOp.
     Value sourceVector = extractOp.getVector();
@@ -1807,7 +1914,7 @@ class ExtractOpNonSplatConstantFolder final
     // Calculate the linearized position of the continuous chunk of elements to
     // extract.
     llvm::SmallVector<int64_t> completePositions(vecTy.getRank(), 0);
-    copy(extractOp.getPosition(), completePositions.begin());
+    copy(extractOp.getStaticPosition(), completePositions.begin());
     int64_t elemBeginPosition =
         linearize(completePositions, computeStrides(vecTy.getShape()));
     auto denseValuesBegin = dense.value_begin<TypedAttr>() + elemBeginPosition;
@@ -2322,18 +2429,38 @@ OpFoldResult vector::InsertElementOp::fold(FoldAdaptor adaptor) {
 // InsertOp
 //===----------------------------------------------------------------------===//
 
-// Convenience builder which assumes the values are constant indices.
-void InsertOp::build(OpBuilder &builder, OperationState &result, Value source,
-                     Value dest, ValueRange position) {
-  SmallVector<int64_t, 4> positionConstants =
-      llvm::to_vector<4>(llvm::map_range(position, [](Value pos) {
-        return getConstantIntValue(pos).value();
-      }));
-  build(builder, result, source, dest, positionConstants);
+void vector::InsertOp::build(OpBuilder &builder, OperationState &result,
+                             Value source, Value dest, int64_t position) {
+  build(builder, result, source, dest, ArrayRef<int64_t>{position});
+}
+
+void vector::InsertOp::build(OpBuilder &builder, OperationState &result,
+                             Value source, Value dest, OpFoldResult position) {
+  build(builder, result, source, dest, ArrayRef<OpFoldResult>{position});
+}
+
+void vector::InsertOp::build(OpBuilder &builder, OperationState &result,
+                             Value source, Value dest,
+                             ArrayRef<int64_t> position) {
+  SmallVector<OpFoldResult> posVals;
+  posVals.reserve(position.size());
+  llvm::transform(position, std::back_inserter(posVals),
+                  [&](int64_t pos) { return builder.getI64IntegerAttr(pos); });
+  build(builder, result, source, dest, posVals);
+}
+
+void vector::InsertOp::build(OpBuilder &builder, OperationState &result,
+                             Value source, Value dest,
+                             ArrayRef<OpFoldResult> position) {
+  SmallVector<int64_t> staticPos;
+  SmallVector<Value> dynamicPos;
+  dispatchIndexOpFoldResults(position, dynamicPos, staticPos);
+  build(builder, result, source, dest, dynamicPos,
+        builder.getDenseI64ArrayAttr(staticPos));
 }
 
 LogicalResult InsertOp::verify() {
-  ArrayRef<int64_t> position = getPosition();
+  SmallVector<OpFoldResult> position = getMixedPosition();
   auto destVectorType = getDestVectorType();
   if (position.size() > static_cast<unsigned>(destVectorType.getRank()))
     return emitOpError(
@@ -2348,13 +2475,17 @@ LogicalResult InsertOp::verify() {
       (position.size() != static_cast<unsigned>(destVectorType.getRank())))
     return emitOpError(
         "expected position attribute rank to match the dest vector rank");
-  for (const auto &en : llvm::enumerate(position)) {
-    int64_t attr = en.value();
-    if (attr < 0 || attr >= destVectorType.getDimSize(en.index()))
-      return emitOpError("expected position attribute #")
-             << (en.index() + 1)
-             << " to be a non-negative integer smaller than the corresponding "
-                "dest vector dimension";
+  for (auto [idx, pos] : llvm::enumerate(position)) {
+    if (auto attr = pos.dyn_cast<Attribute>()) {
+      int64_t constIdx = cast<IntegerAttr>(attr).getInt();
+      if (constIdx < 0 || constIdx >= destVectorType.getDimSize(idx)) {
+        return emitOpError("expected position attribute #")
+               << (idx + 1)
+               << " to be a non-negative integer smaller than the "
+                  "corresponding "
+                  "dest vector dimension";
+      }
+    }
   }
   return success();
 }
@@ -2411,6 +2542,10 @@ class InsertOpConstantFolder final : public OpRewritePattern<InsertOp> {
 
   LogicalResult matchAndRewrite(InsertOp op,
                                 PatternRewriter &rewriter) const override {
+    // TODO: Canonicalization for dynamic position not implemented yet.
+    if (op.hasDynamicPosition())
+      return failure();
+
     // Return if 'InsertOp' operand is not defined by a compatible vector
     // ConstantOp.
     TypedValue<VectorType> destVector = op.getDest();
@@ -2437,7 +2572,7 @@ class InsertOpConstantFolder final : public OpRewritePattern<InsertOp> {
     // Calculate the linearized position of the continuous chunk of elements to
     // insert.
     llvm::SmallVector<int64_t> completePositions(destTy.getRank(), 0);
-    copy(op.getPosition(), completePositions.begin());
+    copy(op.getStaticPosition(), completePositions.begin());
     int64_t insertBeginPosition =
         linearize(completePositions, computeStrides(destTy.getShape()));
 
@@ -2468,7 +2603,7 @@ void InsertOp::getCanonicalizationPatterns(RewritePatternSet &results,
 // value. This happens when the source and destination vectors have identical
 // sizes.
 OpFoldResult vector::InsertOp::fold(FoldAdaptor adaptor) {
-  if (getPosition().empty())
+  if (getNumIndices() == 0)
     return getSource();
   return {};
 }

diff  --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
index 64ab0abda26e640..7560db2332cf8d9 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
@@ -89,20 +89,20 @@ static Value reshapeLoad(Location loc, Value val, VectorType type,
                          PatternRewriter &rewriter) {
   if (index == -1)
     return val;
-  Type lowType = type.getRank() > 1 ? VectorType::Builder(type).dropDim(0)
-                                    : type.getElementType();
+
   // At extraction dimension?
   if (index == 0)
-    return rewriter.create<vector::ExtractOp>(loc, lowType, val, pos);
+    return rewriter.create<vector::ExtractOp>(loc, val, pos);
+
   // Unroll leading dimensions.
-  VectorType vType = cast<VectorType>(lowType);
+  VectorType vType = VectorType::Builder(type).dropDim(0);
   VectorType resType = VectorType::Builder(type).dropDim(index);
   Value result = rewriter.create<arith::ConstantOp>(
       loc, resType, rewriter.getZeroAttr(resType));
   for (int64_t d = 0, e = resType.getDimSize(0); d < e; d++) {
-    Value ext = rewriter.create<vector::ExtractOp>(loc, vType, val, d);
+    Value ext = rewriter.create<vector::ExtractOp>(loc, val, d);
     Value load = reshapeLoad(loc, ext, vType, index - 1, pos, rewriter);
-    result = rewriter.create<vector::InsertOp>(loc, resType, load, result, d);
+    result = rewriter.create<vector::InsertOp>(loc, load, result, d);
   }
   return result;
 }
@@ -117,16 +117,15 @@ static Value reshapeStore(Location loc, Value val, Value result,
     return val;
   // At insertion dimension?
   if (index == 0)
-    return rewriter.create<vector::InsertOp>(loc, type, val, result, pos);
+    return rewriter.create<vector::InsertOp>(loc, val, result, pos);
+
   // Unroll leading dimensions.
-  VectorType lowType = VectorType::Builder(type).dropDim(0);
-  Type insType = lowType.getRank() > 1 ? VectorType::Builder(lowType).dropDim(0)
-                                       : lowType.getElementType();
+  VectorType vType = VectorType::Builder(type).dropDim(0);
   for (int64_t d = 0, e = type.getDimSize(0); d < e; d++) {
-    Value ext = rewriter.create<vector::ExtractOp>(loc, lowType, result, d);
-    Value ins = rewriter.create<vector::ExtractOp>(loc, insType, val, d);
-    Value sto = reshapeStore(loc, ins, ext, lowType, index - 1, pos, rewriter);
-    result = rewriter.create<vector::InsertOp>(loc, type, sto, result, d);
+    Value ext = rewriter.create<vector::ExtractOp>(loc, result, d);
+    Value ins = rewriter.create<vector::ExtractOp>(loc, val, d);
+    Value sto = reshapeStore(loc, ins, ext, vType, index - 1, pos, rewriter);
+    result = rewriter.create<vector::InsertOp>(loc, sto, result, d);
   }
   return result;
 }
@@ -1175,7 +1174,7 @@ class OuterProductOpLowering : public OpRewritePattern<vector::OuterProductOp> {
           loc, a, op.getRhs(), r, kind, rewriter, isInt, extrMask);
       if (!m.has_value())
         return failure();
-      result = rewriter.create<vector::InsertOp>(loc, resType, *m, result, d);
+      result = rewriter.create<vector::InsertOp>(loc, *m, result, d);
     }
 
     rewriter.replaceOp(rootOp, result);

diff  --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp
index 95b5ea011c82569..887d1af7645419f 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp
@@ -79,7 +79,7 @@ class CreateMaskOpLowering : public OpRewritePattern<vector::CreateMaskOp> {
       Value val = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt,
                                                  bnd, idx);
       Value sel = rewriter.create<arith::SelectOp>(loc, val, trueVal, falseVal);
-      result = rewriter.create<vector::InsertOp>(loc, dstType, sel, result, d);
+      result = rewriter.create<vector::InsertOp>(loc, sel, result, d);
     }
     rewriter.replaceOp(op, result);
     return success();
@@ -151,8 +151,8 @@ class ConstantMaskOpLowering : public OpRewritePattern<vector::ConstantMaskOp> {
     Value result = rewriter.create<arith::ConstantOp>(
         loc, dstType, rewriter.getZeroAttr(dstType));
     for (int64_t d = 0; d < trueDimSize; d++)
-      result =
-          rewriter.create<vector::InsertOp>(loc, dstType, trueVal, result, d);
+      result = rewriter.create<vector::InsertOp>(loc, trueVal, result, d);
+
     rewriter.replaceOp(op, result);
     return success();
   }

diff  --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index 2a50947e976dffb..f4486ea117a2934 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -1040,13 +1040,17 @@ struct WarpOpExtract : public OpRewritePattern<WarpExecuteOnLane0Op> {
            "vector.extract does not support rank 0 sources");
 
     // "vector.extract %v[] : vector<...xf32>" can be canonicalized to %v.
-    if (extractOp.getPosition().empty())
+    if (extractOp.getNumIndices() == 0)
       return failure();
 
     // Rewrite vector.extract with 1d source to vector.extractelement.
     if (extractSrcType.getRank() == 1) {
-      assert(extractOp.getPosition().size() == 1 && "expected 1 index");
-      int64_t pos = extractOp.getPosition()[0];
+      if (extractOp.hasDynamicPosition())
+        // TODO: Dynamic position not supported yet.
+        return failure();
+
+      assert(extractOp.getNumIndices() == 1 && "expected 1 index");
+      int64_t pos = extractOp.getStaticPosition()[0];
       rewriter.setInsertionPoint(extractOp);
       rewriter.replaceOpWithNewOp<vector::ExtractElementOp>(
           extractOp, extractOp.getVector(),
@@ -1070,7 +1074,7 @@ struct WarpOpExtract : public OpRewritePattern<WarpExecuteOnLane0Op> {
       Value distributedVec = newWarpOp->getResult(newRetIndices[0]);
       // Extract from distributed vector.
       Value newExtract = rewriter.create<vector::ExtractOp>(
-          loc, distributedVec, extractOp.getPosition());
+          loc, distributedVec, extractOp.getMixedPosition());
       rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
                                   newExtract);
       return success();
@@ -1096,7 +1100,7 @@ struct WarpOpExtract : public OpRewritePattern<WarpExecuteOnLane0Op> {
     SmallVector<int64_t> newDistributedShape(extractSrcType.getShape().begin(),
                                              extractSrcType.getShape().end());
     for (int i = 0; i < distributedType.getRank(); ++i)
-      newDistributedShape[i + extractOp.getPosition().size()] =
+      newDistributedShape[i + extractOp.getNumIndices()] =
           distributedType.getDimSize(i);
     auto newDistributedType =
         VectorType::get(newDistributedShape, distributedType.getElementType());
@@ -1108,7 +1112,7 @@ struct WarpOpExtract : public OpRewritePattern<WarpExecuteOnLane0Op> {
     Value distributedVec = newWarpOp->getResult(newRetIndices[0]);
     // Extract from distributed vector.
     Value newExtract = rewriter.create<vector::ExtractOp>(
-        loc, distributedVec, extractOp.getPosition());
+        loc, distributedVec, extractOp.getMixedPosition());
     rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
                                 newExtract);
     return success();
@@ -1297,13 +1301,17 @@ struct WarpOpInsert : public OpRewritePattern<WarpExecuteOnLane0Op> {
     Location loc = insertOp.getLoc();
 
     // "vector.insert %v, %v[] : ..." can be canonicalized to %v.
-    if (insertOp.getPosition().empty())
+    if (insertOp.getNumIndices() == 0)
       return failure();
 
     // Rewrite vector.insert with 1d dest to vector.insertelement.
     if (insertOp.getDestVectorType().getRank() == 1) {
-      assert(insertOp.getPosition().size() == 1 && "expected 1 index");
-      int64_t pos = insertOp.getPosition()[0];
+      if (insertOp.hasDynamicPosition())
+        // TODO: Dynamic position not supported yet.
+        return failure();
+
+      assert(insertOp.getNumIndices() == 1 && "expected 1 index");
+      int64_t pos = insertOp.getStaticPosition()[0];
       rewriter.setInsertionPoint(insertOp);
       rewriter.replaceOpWithNewOp<vector::InsertElementOp>(
           insertOp, insertOp.getSource(), insertOp.getDest(),
@@ -1323,7 +1331,7 @@ struct WarpOpInsert : public OpRewritePattern<WarpExecuteOnLane0Op> {
       Value distributedSrc = newWarpOp->getResult(newRetIndices[0]);
       Value distributedDest = newWarpOp->getResult(newRetIndices[1]);
       Value newResult = rewriter.create<vector::InsertOp>(
-          loc, distributedSrc, distributedDest, insertOp.getPosition());
+          loc, distributedSrc, distributedDest, insertOp.getMixedPosition());
       rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
                                   newResult);
       return success();
@@ -1354,7 +1362,7 @@ struct WarpOpInsert : public OpRewritePattern<WarpExecuteOnLane0Op> {
     // Case 2: distrDestDim = 0 (dim of size 128) => distrSrcDim = -1. In that
     //         case, one lane will insert the source vector<96xf32>. The other
     //         lanes will not do anything.
-    int64_t distrSrcDim = distrDestDim - insertOp.getPosition().size();
+    int64_t distrSrcDim = distrDestDim - insertOp.getNumIndices();
     if (distrSrcDim >= 0)
       distrSrcShape[distrSrcDim] = distrDestType.getDimSize(distrDestDim);
     auto distrSrcType =
@@ -1374,11 +1382,12 @@ struct WarpOpInsert : public OpRewritePattern<WarpExecuteOnLane0Op> {
     if (distrSrcDim >= 0) {
       // Every lane inserts a small piece.
       newResult = rewriter.create<vector::InsertOp>(
-          loc, distributedSrc, distributedDest, insertOp.getPosition());
+          loc, distributedSrc, distributedDest, insertOp.getMixedPosition());
     } else {
       // One lane inserts the entire source vector.
       int64_t elementsPerLane = distrDestType.getDimSize(distrDestDim);
-      SmallVector<int64_t> newPos(insertOp.getPosition());
+      SmallVector<OpFoldResult> pos = insertOp.getMixedPosition();
+      SmallVector<int64_t> newPos = getAsIntegers(pos);
       // tid of inserting lane: pos / elementsPerLane
       Value insertingLane = rewriter.create<arith::ConstantIndexOp>(
           loc, newPos[distrDestDim] / elementsPerLane);

diff  --git a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
index 913c826dd912470..6bbb293fa2a6b5c 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
@@ -6,6 +6,7 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
@@ -176,14 +177,16 @@ struct CastAwayInsertLeadingOneDim : public OpRewritePattern<vector::InsertOp> {
     // type has leading unit dims, we also trim the position array accordingly,
     // then (2) if source type also has leading unit dims, we need to append
     // zeroes to the position array accordingly.
-    unsigned oldPosRank = insertOp.getPosition().size();
+    unsigned oldPosRank = insertOp.getNumIndices();
     unsigned newPosRank = std::max<int64_t>(0, oldPosRank - dstDropCount);
-    SmallVector<int64_t> newPositions =
-        llvm::to_vector(insertOp.getPosition().take_back(newPosRank));
-    newPositions.resize(newDstType.getRank() - newSrcRank, 0);
+    SmallVector<OpFoldResult> oldPosition = insertOp.getMixedPosition();
+    SmallVector<OpFoldResult> newPosition =
+        llvm::to_vector(ArrayRef(oldPosition).take_back(newPosRank));
+    newPosition.resize(newDstType.getRank() - newSrcRank,
+                       rewriter.getI64IntegerAttr(0));
 
     auto newInsertOp = rewriter.create<vector::InsertOp>(
-        loc, newDstType, newSrcVector, newDstVector, newPositions);
+        loc, newSrcVector, newDstVector, newPosition);
 
     rewriter.replaceOpWithNewOp<vector::BroadcastOp>(insertOp, oldDstType,
                                                      newInsertOp);

diff  --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
index f715c543eb17955..603b88f11c8e007 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
@@ -707,10 +707,10 @@ class RewriteScalarExtractOfTransferRead
     auto xferOp = extractOp.getVector().getDefiningOp<vector::TransferReadOp>();
     SmallVector<Value> newIndices(xferOp.getIndices().begin(),
                                   xferOp.getIndices().end());
-    for (const auto &it : llvm::enumerate(extractOp.getPosition())) {
-      int64_t offset = it.value();
-      int64_t idx =
-          newIndices.size() - extractOp.getPosition().size() + it.index();
+    for (auto [i, pos] : llvm::enumerate(extractOp.getMixedPosition())) {
+      assert(pos.is<Attribute>() && "Unexpected non-constant index");
+      int64_t offset = cast<IntegerAttr>(pos.get<Attribute>()).getInt();
+      int64_t idx = newIndices.size() - extractOp.getNumIndices() + i;
       OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
           rewriter, extractOp.getLoc(),
           rewriter.getAffineSymbolExpr(0) + offset, {newIndices[idx]});

diff  --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index b2a5aef5ee62d0f..b891d62ee508e30 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -598,27 +598,34 @@ struct BubbleDownVectorBitCastForExtract
     unsigned expandRatio =
         castDstType.getNumElements() / castSrcType.getNumElements();
 
-    uint64_t index = extractOp.getPosition()[0];
+    auto getFirstIntValue = [](ArrayRef<OpFoldResult> values) -> uint64_t {
+      assert(values[0].is<Attribute>() && "Unexpected non-constant index");
+      return cast<IntegerAttr>(values[0].get<Attribute>()).getInt();
+    };
+
+    uint64_t index = getFirstIntValue(extractOp.getMixedPosition());
 
     // Get the single scalar (as a vector) in the source value that packs the
     // desired scalar. E.g. extract vector<1xf32> from vector<4xf32>
-    VectorType oneScalarType =
-        VectorType::get({1}, castSrcType.getElementType());
+    Location loc = extractOp.getLoc();
     Value packedValue = rewriter.create<vector::ExtractOp>(
-        extractOp.getLoc(), oneScalarType, castOp.getSource(),
-        index / expandRatio);
+        loc, castOp.getSource(), index / expandRatio);
+    Type packedVecType = VectorType::get(/*shape=*/{1}, packedValue.getType());
+    Value zero = rewriter.create<arith::ConstantOp>(
+        loc, packedVecType, rewriter.getZeroAttr(packedVecType));
+    packedValue = rewriter.create<vector::InsertOp>(loc, packedValue, zero,
+                                                    /*position=*/0);
 
     // Cast it to a vector with the desired scalar's type.
     // E.g. f32 -> vector<2xf16>
     VectorType packedType =
         VectorType::get({expandRatio}, castDstType.getElementType());
-    Value castedValue = rewriter.create<vector::BitCastOp>(
-        extractOp.getLoc(), packedType, packedValue);
+    Value castedValue =
+        rewriter.create<vector::BitCastOp>(loc, packedType, packedValue);
 
     // Finally extract the desired scalar.
-    rewriter.replaceOpWithNewOp<vector::ExtractOp>(
-        extractOp, extractOp.getType(), castedValue, index % expandRatio);
-
+    rewriter.replaceOpWithNewOp<vector::ExtractOp>(extractOp, castedValue,
+                                                   index % expandRatio);
     return success();
   }
 };

diff  --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index b07c4bd67be2dc7..41ab06f2e23b501 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -728,6 +728,17 @@ func.func @extract_element_from_vec_3d(%arg0: vector<4x3x16xf32>) -> f32 {
 
 // -----
 
+func.func @extract_element_with_value_1d(%arg0: vector<16xf32>, %arg1: index) -> f32 {
+  %0 = vector.extract %arg0[%arg1]: vector<16xf32>
+  return %0 : f32
+}
+// CHECK-LABEL: @extract_element_with_value_1d
+//  CHECK-SAME:   %[[VEC:.+]]: vector<16xf32>, %[[INDEX:.+]]: index
+//       CHECK:   %[[UC:.+]] = builtin.unrealized_conversion_cast %[[INDEX]] : index to i64
+//       CHECK:   llvm.extractelement %[[VEC]][%[[UC]] : i64] : vector<16xf32>
+
+// -----
+
 // CHECK-LABEL: @insert_element_0d
 // CHECK-SAME: %[[A:.*]]: f32,
 func.func @insert_element_0d(%a: f32, %b: vector<f32>) -> vector<f32> {
@@ -830,6 +841,19 @@ func.func @insert_element_into_vec_3d(%arg0: f32, %arg1: vector<4x8x16xf32>) ->
 
 // -----
 
+func.func @insert_element_with_value_1d(%arg0: vector<16xf32>, %arg1: f32, %arg2: index)
+                                      -> vector<16xf32> {
+  %0 = vector.insert %arg1, %arg0[%arg2]: f32 into vector<16xf32>
+  return %0 : vector<16xf32>
+}
+
+// CHECK-LABEL: @insert_element_with_value_1d
+//  CHECK-SAME:   %[[DST:.+]]: vector<16xf32>, %[[SRC:.+]]: f32, %[[INDEX:.+]]: index
+//       CHECK:   %[[UC:.+]] = builtin.unrealized_conversion_cast %[[INDEX]] : index to i64
+//       CHECK:   llvm.insertelement %[[SRC]], %[[DST]][%[[UC]] : i64] : vector<16xf32>
+
+// -----
+
 func.func @vector_type_cast(%arg0: memref<8x8x8xf32>) -> memref<vector<8x8x8xf32>> {
   %0 = vector.type_cast %arg0: memref<8x8x8xf32> to memref<vector<8x8x8xf32>>
   return %0 : memref<vector<8x8x8xf32>>

diff  --git a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
index f60a522cbfdba56..266161d5268e985 100644
--- a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
+++ b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
@@ -155,8 +155,8 @@ func.func @broadcast(%arg0 : f32) -> (vector<4xf32>, vector<2xf32>) {
 //       CHECK:   spirv.CompositeExtract %[[ARG]][0 : i32] : vector<2xf32>
 //       CHECK:   spirv.CompositeExtract %[[ARG]][1 : i32] : vector<2xf32>
 func.func @extract(%arg0 : vector<2xf32>) -> (vector<1xf32>, f32) {
-  %0 = "vector.extract"(%arg0) <{position = array<i64: 0>}> : (vector<2xf32>) -> vector<1xf32>
-  %1 = "vector.extract"(%arg0) <{position = array<i64: 1>}> : (vector<2xf32>) -> f32
+  %0 = "vector.extract"(%arg0) <{static_position = array<i64: 0>}> : (vector<2xf32>) -> vector<1xf32>
+  %1 = "vector.extract"(%arg0) <{static_position = array<i64: 1>}> : (vector<2xf32>) -> f32
   return %0, %1: vector<1xf32>, f32
 }
 

diff  --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index 26772b929493585..549fe7a6a61f6ac 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -133,7 +133,7 @@ func.func @extract_position_rank_overflow(%arg0: vector<4x8x16xf32>) {
 
 func.func @extract_position_rank_overflow_generic(%arg0: vector<4x8x16xf32>) {
  // expected-error@+1 {{expected position attribute of rank no greater than vector rank}}
-  %1 = "vector.extract" (%arg0) <{position = array<i64: 0, 0, 0, 0>}> : (vector<4x8x16xf32>) -> (vector<16xf32>)
+  %1 = "vector.extract" (%arg0) <{static_position = array<i64: 0, 0, 0, 0>}> : (vector<4x8x16xf32>) -> (vector<16xf32>)
 }
 
 // -----

diff  --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
index 3a23ee14ca14fa0..f879cd122469a65 100644
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -206,8 +206,9 @@ func.func @extract_element(%a: vector<16xf32>) -> f32 {
   return %1 : f32
 }
 
-// CHECK-LABEL: @extract
-func.func @extract(%arg0: vector<4x8x16xf32>) -> (vector<4x8x16xf32>, vector<8x16xf32>, vector<16xf32>, f32) {
+// CHECK-LABEL: @extract_const_idx
+func.func @extract_const_idx(%arg0: vector<4x8x16xf32>)
+                             -> (vector<4x8x16xf32>, vector<8x16xf32>, vector<16xf32>, f32) {
   // CHECK: vector.extract {{.*}}[] : vector<4x8x16xf32>
   %0 = vector.extract %arg0[] : vector<4x8x16xf32>
   // CHECK: vector.extract {{.*}}[3] : vector<4x8x16xf32>
@@ -219,6 +220,19 @@ func.func @extract(%arg0: vector<4x8x16xf32>) -> (vector<4x8x16xf32>, vector<8x1
   return %0, %1, %2, %3 : vector<4x8x16xf32>, vector<8x16xf32>, vector<16xf32>, f32
 }
 
+// CHECK-LABEL: @extract_val_idx
+//  CHECK-SAME:   %[[VEC:.+]]: vector<4x8x16xf32>, %[[IDX:.+]]: index
+func.func @extract_val_idx(%arg0: vector<4x8x16xf32>, %idx: index)
+                           -> (vector<8x16xf32>, vector<16xf32>, f32) {
+  // CHECK: vector.extract %[[VEC]][%[[IDX]]] : vector<4x8x16xf32>
+  %0 = vector.extract %arg0[%idx] : vector<4x8x16xf32>
+  // CHECK-NEXT: vector.extract %[[VEC]][%[[IDX]], %[[IDX]]] : vector<4x8x16xf32>
+  %1 = vector.extract %arg0[%idx, %idx] : vector<4x8x16xf32>
+  // CHECK-NEXT: vector.extract %[[VEC]][%[[IDX]], 5, %[[IDX]]] : vector<4x8x16xf32>
+  %2 = vector.extract %arg0[%idx, 5, %idx] : vector<4x8x16xf32>
+  return %0, %1, %2 : vector<8x16xf32>, vector<16xf32>, f32
+}
+
 // CHECK-LABEL: @extract_0d
 func.func @extract_0d(%a: vector<f32>) -> f32 {
   // CHECK-NEXT: vector.extract %{{.*}}[] : vector<f32>
@@ -242,8 +256,9 @@ func.func @insert_element(%a: f32, %b: vector<16xf32>) -> vector<16xf32> {
   return %1 : vector<16xf32>
 }
 
-// CHECK-LABEL: @insert
-func.func @insert(%a: f32, %b: vector<16xf32>, %c: vector<8x16xf32>, %res: vector<4x8x16xf32>) -> vector<4x8x16xf32> {
+// CHECK-LABEL: @insert_const_idx
+func.func @insert_const_idx(%a: f32, %b: vector<16xf32>, %c: vector<8x16xf32>,
+                            %res: vector<4x8x16xf32>) -> vector<4x8x16xf32> {
   // CHECK: vector.insert %{{.*}}, %{{.*}}[3] : vector<8x16xf32> into vector<4x8x16xf32>
   %1 = vector.insert %c, %res[3] : vector<8x16xf32> into vector<4x8x16xf32>
   // CHECK: vector.insert %{{.*}}, %{{.*}}[3, 3] : vector<16xf32> into vector<4x8x16xf32>
@@ -255,6 +270,19 @@ func.func @insert(%a: f32, %b: vector<16xf32>, %c: vector<8x16xf32>, %res: vecto
   return %4 : vector<4x8x16xf32>
 }
 
+// CHECK-LABEL: @insert_val_idx
+//  CHECK-SAME:   %[[A:.+]]: f32, %[[B:.+]]: vector<16xf32>, %[[C:.+]]: vector<8x16xf32>, %[[IDX:.+]]: index
+func.func @insert_val_idx(%a: f32, %b: vector<16xf32>, %c: vector<8x16xf32>,
+                          %idx: index, %res: vector<4x8x16xf32>) -> vector<4x8x16xf32> {
+  // CHECK: vector.insert %[[C]], %{{.*}}[%[[IDX]]] : vector<8x16xf32> into vector<4x8x16xf32>
+  %0 = vector.insert %c, %res[%idx] : vector<8x16xf32> into vector<4x8x16xf32>
+  // CHECK: vector.insert %[[B]], %{{.*}}[%[[IDX]], %[[IDX]]] : vector<16xf32> into vector<4x8x16xf32>
+  %1 = vector.insert %b, %res[%idx, %idx] : vector<16xf32> into vector<4x8x16xf32>
+  // CHECK: vector.insert %[[A]], %{{.*}}[%[[IDX]], 5, %[[IDX]]] : f32 into vector<4x8x16xf32>
+  %2 = vector.insert %a, %res[%idx, 5, %idx] : f32 into vector<4x8x16xf32>
+  return %2 : vector<4x8x16xf32>
+}
+
 // CHECK-LABEL: @insert_0d
 func.func @insert_0d(%a: f32, %b: vector<f32>, %c: vector<2x3xf32>) -> (vector<f32>, vector<2x3xf32>) {
   // CHECK-NEXT: vector.insert %{{.*}}, %{{.*}}[] : f32 into vector<f32>
@@ -1007,7 +1035,7 @@ func.func @contraction_masked_scalable(%A: vector<3x4xf32>,
                                     %C: vector<3x[8]xf32>,
                                     %M : vector<3x[8]x4xi1>) -> vector<3x[8]xf32> {
  // CHECK:  vector.mask %[[M]] { vector.contract {indexing_maps = [#{{.*}}, #{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %[[A]], %[[B]], %[[C]] : vector<3x4xf32>, vector<4x[8]xf32> into vector<3x[8]xf32> } : vector<3x[8]x4xi1> -> vector<3x[8]xf32>
-  %0 = vector.mask %M { vector.contract #matmat_trait %A, %B, %C : vector<3x4xf32>, vector<4x[8]xf32> into vector<3x[8]xf32> } 
+  %0 = vector.mask %M { vector.contract #matmat_trait %A, %B, %C : vector<3x4xf32>, vector<4x[8]xf32> into vector<3x[8]xf32> }
     : vector<3x[8]x4xi1> -> vector<3x[8]xf32>
   return %0 : vector<3x[8]xf32>
 }

diff  --git a/mlir/test/Dialect/Vector/vector-transforms.mlir b/mlir/test/Dialect/Vector/vector-transforms.mlir
index dfc564ca6fe4836..27bbe1bb0d0349d 100644
--- a/mlir/test/Dialect/Vector/vector-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-transforms.mlir
@@ -286,11 +286,13 @@ func.func @contraction4x4_ikj_xfer_read_tensor(%arg0 : tensor<4x2xf32>,
 func.func @bubble_down_bitcast_in_extract(%src: vector<4xf32>) -> (f16, f16) {
   %0 = vector.bitcast %src : vector<4xf32> to vector<8xf16>
   // CHECK: %[[EXTRACT1:.+]] = vector.extract %[[SRC]][1] : vector<4xf32>
-  // CHECK:    %[[CAST1:.+]] = vector.bitcast %[[EXTRACT1]] : vector<1xf32> to vector<2xf16>
+  // CHECK:  %[[INSERT1:.+]] = vector.insert %[[EXTRACT1]], %{{.+}} [0] : f32 into vector<1xf32>
+  // CHECK:    %[[CAST1:.+]] = vector.bitcast %[[INSERT1]] : vector<1xf32> to vector<2xf16>
   // CHECK: %[[EXTRACT2:.+]] = vector.extract %[[CAST1]][1] : vector<2xf16>
   %1 = vector.extract %0[3] : vector<8xf16>
   // CHECK: %[[EXTRACT3:.+]] = vector.extract %[[SRC]][2] : vector<4xf32>
-  // CHECK:    %[[CAST2:.+]] = vector.bitcast %[[EXTRACT3]] : vector<1xf32> to vector<2xf16>
+  // CHECK:  %[[INSERT3:.+]] = vector.insert %[[EXTRACT3]], %{{.+}} [0] : f32 into vector<1xf32>
+  // CHECK:    %[[CAST2:.+]] = vector.bitcast %[[INSERT3]] : vector<1xf32> to vector<2xf16>
   // CHECK: %[[EXTRACT4:.+]] = vector.extract %[[CAST2]][0] : vector<2xf16>
   %2 = vector.extract %0[4] : vector<8xf16>
   // CHECK: return %[[EXTRACT2]], %[[EXTRACT4]]


        


More information about the Mlir-commits mailing list