[Mlir-commits] [mlir] 16b75cd - [mlir][vector] Use DenseI64ArrayAttr for ExtractOp/InsertOp positions

Matthias Springer llvmlistbot at llvm.org
Mon Jul 31 06:29:30 PDT 2023


Author: Matthias Springer
Date: 2023-07-31T15:25:37+02:00
New Revision: 16b75cd2bb439633d29c99a7663f2586e4068ecf

URL: https://github.com/llvm/llvm-project/commit/16b75cd2bb439633d29c99a7663f2586e4068ecf
DIFF: https://github.com/llvm/llvm-project/commit/16b75cd2bb439633d29c99a7663f2586e4068ecf.diff

LOG: [mlir][vector] Use DenseI64ArrayAttr for ExtractOp/InsertOp positions

`DenseI64ArrayAttr` provides a better API than `I64ArrayAttr`: for example, the generated accessors return `ArrayRef<int64_t>` instead of `ArrayAttr`, so callers no longer have to unpack `IntegerAttr` elements by hand.

Differential Revision: https://reviews.llvm.org/D156684

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
    mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
    mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
    mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
    mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
    mlir/lib/Dialect/Vector/IR/VectorOps.cpp
    mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
    mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp
    mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
    mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
    mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
    mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
    mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
    mlir/test/Dialect/Vector/invalid.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 588998853e6995..63d96721bfd400 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -573,7 +573,7 @@ def Vector_ExtractOp :
      PredOpTrait<"operand and result have same element type",
                  TCresVTEtIsSameAsOpBase<0, 0>>,
      InferTypeOpAdaptorWithIsCompatible]>,
-    Arguments<(ins AnyVectorOfAnyRank:$vector, I64ArrayAttr:$position)>,
+    Arguments<(ins AnyVectorOfAnyRank:$vector, DenseI64ArrayAttr:$position)>,
     Results<(outs AnyType)> {
   let summary = "extract operation";
   let description = [{
@@ -589,7 +589,6 @@ def Vector_ExtractOp :
     ```
   }];
   let builders = [
-    OpBuilder<(ins "Value":$source, "ArrayRef<int64_t>":$position)>,
     // Convenience builder which assumes the values in `position` are defined by
     // ConstantIndexOp.
     OpBuilder<(ins "Value":$source, "ValueRange":$position)>
@@ -689,7 +688,7 @@ def Vector_InsertOp :
      PredOpTrait<"source operand and result have same element type",
                  TCresVTEtIsSameAsOpBase<0, 0>>,
      AllTypesMatch<["dest", "res"]>]>,
-     Arguments<(ins AnyType:$source, AnyVectorOfAnyRank:$dest, I64ArrayAttr:$position)>,
+     Arguments<(ins AnyType:$source, AnyVectorOfAnyRank:$dest, DenseI64ArrayAttr:$position)>,
      Results<(outs AnyVectorOfAnyRank:$res)> {
   let summary = "insert operation";
   let description = [{
@@ -711,8 +710,6 @@ def Vector_InsertOp :
   }];
 
   let builders = [
-    OpBuilder<(ins "Value":$source, "Value":$dest,
-      "ArrayRef<int64_t>":$position)>,
     // Convenience builder which assumes all values are constant indices.
     OpBuilder<(ins "Value":$source, "Value":$dest, "ValueRange":$position)>
   ];

diff  --git a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
index da573686967971..409e9365a9f207 100644
--- a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
+++ b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
@@ -807,8 +807,7 @@ createNonLdMatrixLoads(RewriterBase &rewriter, vector::TransferReadOp op,
 
       Value el = rewriter.create<vector::LoadOp>(loc, loadedElType,
                                                  op.getSource(), newIndices);
-      result = rewriter.create<vector::InsertOp>(loc, el, result,
-                                                 rewriter.getI64ArrayAttr(i));
+      result = rewriter.create<vector::InsertOp>(loc, el, result, i);
     }
   } else {
     if (auto vecType = dyn_cast<VectorType>(loadedElType)) {
@@ -832,7 +831,7 @@ createNonLdMatrixLoads(RewriterBase &rewriter, vector::TransferReadOp op,
         Value el = rewriter.create<memref::LoadOp>(op.getLoc(), loadedElType,
                                                    op.getSource(), newIndices);
         result = rewriter.create<vector::InsertOp>(
-            op.getLoc(), el, result, rewriter.getI64ArrayAttr({i, innerIdx}));
+            op.getLoc(), el, result, ArrayRef<int64_t>{i, innerIdx});
       }
     }
   }

diff  --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index d0c0d8fa0540f9..fc93f0537c47f0 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -1025,44 +1025,37 @@ class VectorExtractOpConversion
     auto loc = extractOp->getLoc();
     auto resultType = extractOp.getResult().getType();
     auto llvmResultType = typeConverter->convertType(resultType);
-    auto positionArrayAttr = extractOp.getPosition();
+    ArrayRef<int64_t> positionArray = extractOp.getPosition();
 
     // Bail if result type cannot be lowered.
     if (!llvmResultType)
       return failure();
 
     // Extract entire vector. Should be handled by folder, but just to be safe.
-    if (positionArrayAttr.empty()) {
+    if (positionArray.empty()) {
       rewriter.replaceOp(extractOp, adaptor.getVector());
       return success();
     }
 
     // One-shot extraction of vector from array (only requires extractvalue).
     if (isa<VectorType>(resultType)) {
-      SmallVector<int64_t> indices;
-      for (auto idx : positionArrayAttr.getAsRange<IntegerAttr>())
-        indices.push_back(idx.getInt());
       Value extracted = rewriter.create<LLVM::ExtractValueOp>(
-          loc, adaptor.getVector(), indices);
+          loc, adaptor.getVector(), positionArray);
       rewriter.replaceOp(extractOp, extracted);
       return success();
     }
 
     // Potential extraction of 1-D vector from array.
     Value extracted = adaptor.getVector();
-    auto positionAttrs = positionArrayAttr.getValue();
-    if (positionAttrs.size() > 1) {
-      SmallVector<int64_t> nMinusOnePosition;
-      for (auto idx : positionAttrs.drop_back())
-        nMinusOnePosition.push_back(cast<IntegerAttr>(idx).getInt());
-      extracted = rewriter.create<LLVM::ExtractValueOp>(loc, extracted,
-                                                        nMinusOnePosition);
+    if (positionArray.size() > 1) {
+      extracted = rewriter.create<LLVM::ExtractValueOp>(
+          loc, extracted, positionArray.drop_back());
     }
 
     // Remaining extraction of element from 1-D LLVM vector
-    auto position = cast<IntegerAttr>(positionAttrs.back());
     auto i64Type = IntegerType::get(rewriter.getContext(), 64);
-    auto constant = rewriter.create<LLVM::ConstantOp>(loc, i64Type, position);
+    auto constant =
+        rewriter.create<LLVM::ConstantOp>(loc, i64Type, positionArray.back());
     extracted =
         rewriter.create<LLVM::ExtractElementOp>(loc, extracted, constant);
     rewriter.replaceOp(extractOp, extracted);
@@ -1147,7 +1140,7 @@ class VectorInsertOpConversion
     auto sourceType = insertOp.getSourceType();
     auto destVectorType = insertOp.getDestVectorType();
     auto llvmResultType = typeConverter->convertType(destVectorType);
-    auto positionArrayAttr = insertOp.getPosition();
+    ArrayRef<int64_t> positionArray = insertOp.getPosition();
 
     // Bail if result type cannot be lowered.
     if (!llvmResultType)
@@ -1155,7 +1148,7 @@ class VectorInsertOpConversion
 
     // Overwrite entire vector with value. Should be handled by folder, but
     // just to be safe.
-    if (positionArrayAttr.empty()) {
+    if (positionArray.empty()) {
       rewriter.replaceOp(insertOp, adaptor.getSource());
       return success();
     }
@@ -1163,36 +1156,32 @@ class VectorInsertOpConversion
     // One-shot insertion of a vector into an array (only requires insertvalue).
     if (isa<VectorType>(sourceType)) {
       Value inserted = rewriter.create<LLVM::InsertValueOp>(
-          loc, adaptor.getDest(), adaptor.getSource(),
-          LLVM::convertArrayToIndices(positionArrayAttr));
+          loc, adaptor.getDest(), adaptor.getSource(), positionArray);
       rewriter.replaceOp(insertOp, inserted);
       return success();
     }
 
     // Potential extraction of 1-D vector from array.
     Value extracted = adaptor.getDest();
-    auto positionAttrs = positionArrayAttr.getValue();
-    auto position = cast<IntegerAttr>(positionAttrs.back());
     auto oneDVectorType = destVectorType;
-    if (positionAttrs.size() > 1) {
+    if (positionArray.size() > 1) {
       oneDVectorType = reducedVectorTypeBack(destVectorType);
       extracted = rewriter.create<LLVM::ExtractValueOp>(
-          loc, extracted,
-          LLVM::convertArrayToIndices(positionAttrs.drop_back()));
+          loc, extracted, positionArray.drop_back());
     }
 
     // Insertion of an element into a 1-D LLVM vector.
     auto i64Type = IntegerType::get(rewriter.getContext(), 64);
-    auto constant = rewriter.create<LLVM::ConstantOp>(loc, i64Type, position);
+    auto constant =
+        rewriter.create<LLVM::ConstantOp>(loc, i64Type, positionArray.back());
     Value inserted = rewriter.create<LLVM::InsertElementOp>(
         loc, typeConverter->convertType(oneDVectorType), extracted,
         adaptor.getSource(), constant);
 
     // Potential insertion of resulting 1-D vector into array.
-    if (positionAttrs.size() > 1) {
+    if (positionArray.size() > 1) {
       inserted = rewriter.create<LLVM::InsertValueOp>(
-          loc, adaptor.getDest(), inserted,
-          LLVM::convertArrayToIndices(positionAttrs.drop_back()));
+          loc, adaptor.getDest(), inserted, positionArray.drop_back());
     }
 
     rewriter.replaceOp(insertOp, inserted);

diff  --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
index fc274c989196ca..5e19e422b61116 100644
--- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
+++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
@@ -886,10 +886,9 @@ struct UnrollTransferReadConversion
   /// vector::InsertOp, return that operation's indices.
   void getInsertionIndices(TransferReadOp xferOp,
                            SmallVector<int64_t, 8> &indices) const {
-    if (auto insertOp = getInsertOp(xferOp)) {
-      for (Attribute attr : insertOp.getPosition())
-        indices.push_back(dyn_cast<IntegerAttr>(attr).getInt());
-    }
+    if (auto insertOp = getInsertOp(xferOp))
+      indices.assign(insertOp.getPosition().begin(),
+                     insertOp.getPosition().end());
   }
 
   /// Rewrite the op: Unpack one dimension. Can handle masks, out-of-bounds
@@ -1013,10 +1012,9 @@ struct UnrollTransferWriteConversion
   /// indices.
   void getExtractionIndices(TransferWriteOp xferOp,
                             SmallVector<int64_t, 8> &indices) const {
-    if (auto extractOp = getExtractOp(xferOp)) {
-      for (Attribute attr : extractOp.getPosition())
-        indices.push_back(dyn_cast<IntegerAttr>(attr).getInt());
-    }
+    if (auto extractOp = getExtractOp(xferOp))
+      indices.assign(extractOp.getPosition().begin(),
+                     extractOp.getPosition().end());
   }
 
   /// Rewrite the op: Unpack one dimension. Can handle masks, out-of-bounds

diff  --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index 903441943f200e..c15b99b5a62d3e 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -152,7 +152,7 @@ struct VectorExtractOpConvert final
       return success();
     }
 
-    int32_t id = getFirstIntValue(extractOp.getPosition());
+    int32_t id = extractOp.getPosition()[0];
     rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(
         extractOp, adaptor.getVector(), id);
     return success();
@@ -232,7 +232,7 @@ struct VectorInsertOpConvert final
       return success();
     }
 
-    int32_t id = getFirstIntValue(insertOp.getPosition());
+    int32_t id = insertOp.getPosition()[0];
     rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(
         insertOp, adaptor.getSource(), adaptor.getDest(), id);
     return success();

diff  --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 075b139e2f3b1c..20bd3f32fac91c 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -385,8 +385,7 @@ struct ElideUnitDimsInMultiDimReduction
     } else {
       // This means we are reducing all the dimensions, and all reduction
       // dimensions are of size 1. So a simple extraction would do.
-      auto zeroAttr =
-          rewriter.getI64ArrayAttr(SmallVector<int64_t>(shape.size(), 0));
+      SmallVector<int64_t> zeroAttr(shape.size(), 0);
       if (mask)
         mask = rewriter.create<vector::ExtractOp>(loc, rewriter.getI1Type(),
                                                   mask, zeroAttr);
@@ -560,12 +559,10 @@ struct ElideSingleElementReduction : public OpRewritePattern<ReductionOp> {
       result = rewriter.create<ExtractElementOp>(loc, reductionOp.getVector());
     } else {
       if (mask) {
-        mask = rewriter.create<ExtractOp>(loc, rewriter.getI1Type(), mask,
-                                          rewriter.getI64ArrayAttr(0));
+        mask = rewriter.create<ExtractOp>(loc, rewriter.getI1Type(), mask, 0);
       }
       result = rewriter.create<ExtractOp>(loc, reductionOp.getType(),
-                                          reductionOp.getVector(),
-                                          rewriter.getI64ArrayAttr(0));
+                                          reductionOp.getVector(), 0);
     }
 
     if (Value acc = reductionOp.getAcc())
@@ -1129,18 +1126,11 @@ OpFoldResult vector::ExtractElementOp::fold(FoldAdaptor adaptor) {
 // ExtractOp
 //===----------------------------------------------------------------------===//
 
-void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
-                              Value source, ArrayRef<int64_t> position) {
-  build(builder, result, source, getVectorSubscriptAttr(builder, position));
-}
-
 // Convenience builder which assumes the values are constant indices.
 void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
                               Value source, ValueRange position) {
-  SmallVector<int64_t, 4> positionConstants =
-      llvm::to_vector<4>(llvm::map_range(position, [](Value pos) {
-        return getConstantIntValue(pos).value();
-      }));
+  SmallVector<int64_t> positionConstants = llvm::to_vector(llvm::map_range(
+      position, [](Value pos) { return getConstantIntValue(pos).value(); }));
   build(builder, result, source, positionConstants);
 }
 
@@ -1175,15 +1165,13 @@ bool ExtractOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
 }
 
 LogicalResult vector::ExtractOp::verify() {
-  auto positionAttr = getPosition().getValue();
-  if (positionAttr.size() >
-      static_cast<unsigned>(getSourceVectorType().getRank()))
+  ArrayRef<int64_t> position = getPosition();
+  if (position.size() > static_cast<unsigned>(getSourceVectorType().getRank()))
     return emitOpError(
         "expected position attribute of rank no greater than vector rank");
-  for (const auto &en : llvm::enumerate(positionAttr)) {
-    auto attr = llvm::dyn_cast<IntegerAttr>(en.value());
-    if (!attr || attr.getInt() < 0 ||
-        attr.getInt() >= getSourceVectorType().getDimSize(en.index()))
+  for (const auto &en : llvm::enumerate(position)) {
+    if (en.value() < 0 ||
+        en.value() >= getSourceVectorType().getDimSize(en.index()))
       return emitOpError("expected position attribute #")
              << (en.index() + 1)
              << " to be a non-negative integer smaller than the corresponding "
@@ -1207,18 +1195,18 @@ static LogicalResult foldExtractOpFromExtractChain(ExtractOp extractOp) {
 
   SmallVector<int64_t, 4> globalPosition;
   ExtractOp currentOp = extractOp;
-  auto extrPos = extractVector<int64_t>(currentOp.getPosition());
+  ArrayRef<int64_t> extrPos = currentOp.getPosition();
   globalPosition.append(extrPos.rbegin(), extrPos.rend());
   while (ExtractOp nextOp = currentOp.getVector().getDefiningOp<ExtractOp>()) {
     currentOp = nextOp;
-    auto extrPos = extractVector<int64_t>(currentOp.getPosition());
+    ArrayRef<int64_t> extrPos = currentOp.getPosition();
     globalPosition.append(extrPos.rbegin(), extrPos.rend());
   }
   extractOp.setOperand(currentOp.getVector());
   // OpBuilder is only used as a helper to build an I64ArrayAttr.
   OpBuilder b(extractOp.getContext());
   std::reverse(globalPosition.begin(), globalPosition.end());
-  extractOp.setPositionAttr(b.getI64ArrayAttr(globalPosition));
+  extractOp.setPosition(globalPosition);
   return success();
 }
 
@@ -1329,7 +1317,8 @@ ExtractFromInsertTransposeChainState::ExtractFromInsertTransposeChainState(
   sentinels.reserve(vectorRank - extractedRank);
   for (int64_t i = 0, e = vectorRank - extractedRank; i < e; ++i)
     sentinels.push_back(-(i + 1));
-  extractPosition = extractVector<int64_t>(extractOp.getPosition());
+  extractPosition.assign(extractOp.getPosition().begin(),
+                         extractOp.getPosition().end());
   llvm::append_range(extractPosition, sentinels);
 }
 
@@ -1349,9 +1338,8 @@ LogicalResult ExtractFromInsertTransposeChainState::handleTransposeOp() {
 LogicalResult
 ExtractFromInsertTransposeChainState::handleInsertOpWithMatchingPos(
     Value &res) {
-  auto insertedPos = extractVector<int64_t>(nextInsertOp.getPosition());
-  if (ArrayRef(insertedPos) !=
-      llvm::ArrayRef(extractPosition).take_front(extractedRank))
+  ArrayRef<int64_t> insertedPos = nextInsertOp.getPosition();
+  if (insertedPos != llvm::ArrayRef(extractPosition).take_front(extractedRank))
     return failure();
   // Case 2.a. early-exit fold.
   res = nextInsertOp.getSource();
@@ -1364,7 +1352,7 @@ ExtractFromInsertTransposeChainState::handleInsertOpWithMatchingPos(
 /// This method updates the internal state.
 LogicalResult
 ExtractFromInsertTransposeChainState::handleInsertOpWithPrefixPos(Value &res) {
-  auto insertedPos = extractVector<int64_t>(nextInsertOp.getPosition());
+  ArrayRef<int64_t> insertedPos = nextInsertOp.getPosition();
   if (!isContainedWithin(insertedPos, extractPosition))
     return failure();
   // Set leading dims to zero.
@@ -1390,9 +1378,7 @@ Value ExtractFromInsertTransposeChainState::tryToFoldExtractOpInPlace(
     return Value();
   // Otherwise, fold by updating the op inplace and return its result.
   OpBuilder b(extractOp.getContext());
-  extractOp->setAttr(
-      extractOp.getPositionAttrName(),
-      b.getI64ArrayAttr(ArrayRef(extractPosition).take_front(extractedRank)));
+  extractOp.setPosition(ArrayRef(extractPosition).take_front(extractedRank));
   extractOp.getVectorMutable().assign(source);
   return extractOp.getResult();
 }
@@ -1422,7 +1408,7 @@ Value ExtractFromInsertTransposeChainState::fold() {
 
     // Case 4: extractPositionRef intersects insertedPosRef on non-sentinel
     // values. This is a more difficult case and we bail.
-    auto insertedPos = extractVector<int64_t>(nextInsertOp.getPosition());
+    ArrayRef<int64_t> insertedPos = nextInsertOp.getPosition();
     if (isContainedWithin(extractPosition, insertedPos) ||
         intersectsWhereNonNegative(extractPosition, insertedPos))
       return Value();
@@ -1487,7 +1473,7 @@ static Value foldExtractFromBroadcast(ExtractOp extractOp) {
   // extract position to `0` when extracting from the source operand.
   llvm::SetVector<int64_t> broadcastedUnitDims =
       broadcastOp.computeBroadcastedUnitDims();
-  auto extractPos = extractVector<int64_t>(extractOp.getPosition());
+  SmallVector<int64_t> extractPos(extractOp.getPosition());
   for (int64_t i = rankDiff, e = extractPos.size(); i < e; ++i)
     if (broadcastedUnitDims.contains(i))
       extractPos[i] = 0;
@@ -1498,7 +1484,7 @@ static Value foldExtractFromBroadcast(ExtractOp extractOp) {
   // OpBuilder is only used as a helper to build an I64ArrayAttr.
   OpBuilder b(extractOp.getContext());
   extractOp.setOperand(source);
-  extractOp.setPositionAttr(b.getI64ArrayAttr(extractPos));
+  extractOp.setPosition(extractPos);
   return extractOp.getResult();
 }
 
@@ -1537,7 +1523,7 @@ static Value foldExtractFromShapeCast(ExtractOp extractOp) {
   }
   // Extract the strides associated with the extract op vector source. Then use
   // this to calculate a linearized position for the extract.
-  auto extractedPos = extractVector<int64_t>(extractOp.getPosition());
+  SmallVector<int64_t> extractedPos(extractOp.getPosition());
   std::reverse(extractedPos.begin(), extractedPos.end());
   SmallVector<int64_t, 4> strides;
   int64_t stride = 1;
@@ -1563,7 +1549,7 @@ static Value foldExtractFromShapeCast(ExtractOp extractOp) {
   SmallVector<int64_t, 4> newPosition = delinearize(position, newStrides);
   // OpBuilder is only used as a helper to build an I64ArrayAttr.
   OpBuilder b(extractOp.getContext());
-  extractOp.setPositionAttr(b.getI64ArrayAttr(newPosition));
+  extractOp.setPosition(newPosition);
   extractOp.setOperand(shapeCastOp.getSource());
   return extractOp.getResult();
 }
@@ -1603,14 +1589,14 @@ static Value foldExtractFromExtractStrided(ExtractOp extractOp) {
   if (destinationRank > extractStridedSliceOp.getSourceVectorType().getRank() -
                             sliceOffsets.size())
     return Value();
-  auto extractedPos = extractVector<int64_t>(extractOp.getPosition());
+  SmallVector<int64_t> extractedPos(extractOp.getPosition());
   assert(extractedPos.size() >= sliceOffsets.size());
   for (size_t i = 0, e = sliceOffsets.size(); i < e; i++)
     extractedPos[i] = extractedPos[i] + sliceOffsets[i];
   extractOp.getVectorMutable().assign(extractStridedSliceOp.getVector());
   // OpBuilder is only used as a helper to build an I64ArrayAttr.
   OpBuilder b(extractOp.getContext());
-  extractOp.setPositionAttr(b.getI64ArrayAttr(extractedPos));
+  extractOp.setPosition(extractedPos);
   return extractOp.getResult();
 }
 
@@ -1635,7 +1621,7 @@ static Value foldExtractStridedOpFromInsertChain(ExtractOp extractOp) {
     if (destinationRank > insertOp.getSourceVectorType().getRank())
       return Value();
     auto insertOffsets = extractVector<int64_t>(insertOp.getOffsets());
-    auto extractOffsets = extractVector<int64_t>(extractOp.getPosition());
+    ArrayRef<int64_t> extractOffsets = extractOp.getPosition();
 
     if (llvm::any_of(insertOp.getStrides(), [](Attribute attr) {
           return llvm::cast<IntegerAttr>(attr).getInt() != 1;
@@ -1675,7 +1661,7 @@ static Value foldExtractStridedOpFromInsertChain(ExtractOp extractOp) {
       extractOp.getVectorMutable().assign(insertOp.getSource());
       // OpBuilder is only used as a helper to build an I64ArrayAttr.
       OpBuilder b(extractOp.getContext());
-      extractOp.setPositionAttr(b.getI64ArrayAttr(offsetDiffs));
+      extractOp.setPosition(offsetDiffs);
       return extractOp.getResult();
     }
     // If the chunk extracted is disjoint from the chunk inserted, keep
@@ -1795,7 +1781,7 @@ class ExtractOpNonSplatConstantFolder final
     // Calculate the linearized position of the continuous chunk of elements to
     // extract.
     llvm::SmallVector<int64_t> completePositions(vecTy.getRank(), 0);
-    copy(getI64SubArray(extractOp.getPosition()), completePositions.begin());
+    copy(extractOp.getPosition(), completePositions.begin());
     int64_t elemBeginPosition =
         linearize(completePositions, computeStrides(vecTy.getShape()));
     auto denseValuesBegin = dense.value_begin<TypedAttr>() + elemBeginPosition;
@@ -2288,14 +2274,6 @@ OpFoldResult vector::InsertElementOp::fold(FoldAdaptor adaptor) {
 // InsertOp
 //===----------------------------------------------------------------------===//
 
-void InsertOp::build(OpBuilder &builder, OperationState &result, Value source,
-                     Value dest, ArrayRef<int64_t> position) {
-  result.addOperands({source, dest});
-  auto positionAttr = getVectorSubscriptAttr(builder, position);
-  result.addTypes(dest.getType());
-  result.addAttribute(InsertOp::getPositionAttrName(result.name), positionAttr);
-}
-
 // Convenience builder which assumes the values are constant indices.
 void InsertOp::build(OpBuilder &builder, OperationState &result, Value source,
                      Value dest, ValueRange position) {
@@ -2307,25 +2285,24 @@ void InsertOp::build(OpBuilder &builder, OperationState &result, Value source,
 }
 
 LogicalResult InsertOp::verify() {
-  auto positionAttr = getPosition().getValue();
+  ArrayRef<int64_t> position = getPosition();
   auto destVectorType = getDestVectorType();
-  if (positionAttr.size() > static_cast<unsigned>(destVectorType.getRank()))
+  if (position.size() > static_cast<unsigned>(destVectorType.getRank()))
     return emitOpError(
         "expected position attribute of rank no greater than dest vector rank");
   auto srcVectorType = llvm::dyn_cast<VectorType>(getSourceType());
   if (srcVectorType &&
-      (static_cast<unsigned>(srcVectorType.getRank()) + positionAttr.size() !=
+      (static_cast<unsigned>(srcVectorType.getRank()) + position.size() !=
        static_cast<unsigned>(destVectorType.getRank())))
     return emitOpError("expected position attribute rank + source rank to "
                        "match dest vector rank");
   if (!srcVectorType &&
-      (positionAttr.size() != static_cast<unsigned>(destVectorType.getRank())))
+      (position.size() != static_cast<unsigned>(destVectorType.getRank())))
     return emitOpError(
         "expected position attribute rank to match the dest vector rank");
-  for (const auto &en : llvm::enumerate(positionAttr)) {
-    auto attr = llvm::dyn_cast<IntegerAttr>(en.value());
-    if (!attr || attr.getInt() < 0 ||
-        attr.getInt() >= destVectorType.getDimSize(en.index()))
+  for (const auto &en : llvm::enumerate(position)) {
+    int64_t attr = en.value();
+    if (attr < 0 || attr >= destVectorType.getDimSize(en.index()))
       return emitOpError("expected position attribute #")
              << (en.index() + 1)
              << " to be a non-negative integer smaller than the corresponding "
@@ -2412,7 +2389,7 @@ class InsertOpConstantFolder final : public OpRewritePattern<InsertOp> {
     // Calculate the linearized position of the continuous chunk of elements to
     // insert.
     llvm::SmallVector<int64_t> completePositions(destTy.getRank(), 0);
-    copy(getI64SubArray(op.getPosition()), completePositions.begin());
+    copy(op.getPosition(), completePositions.begin());
     int64_t insertBeginPosition =
         linearize(completePositions, computeStrides(destTy.getShape()));
 

diff  --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
index 986c5f81d60c22..66ac5ffef3e3ed 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
@@ -91,10 +91,8 @@ static Value reshapeLoad(Location loc, Value val, VectorType type,
     return val;
   Type lowType = VectorType::Builder(type).dropDim(0);
   // At extraction dimension?
-  if (index == 0) {
-    auto posAttr = rewriter.getI64ArrayAttr(pos);
-    return rewriter.create<vector::ExtractOp>(loc, lowType, val, posAttr);
-  }
+  if (index == 0)
+    return rewriter.create<vector::ExtractOp>(loc, lowType, val, pos);
   // Unroll leading dimensions.
   VectorType vType = cast<VectorType>(lowType);
   Type resType = VectorType::Builder(type).dropDim(index);
@@ -102,11 +100,10 @@ static Value reshapeLoad(Location loc, Value val, VectorType type,
   Value result = rewriter.create<arith::ConstantOp>(
       loc, resVectorType, rewriter.getZeroAttr(resVectorType));
   for (int64_t d = 0, e = resVectorType.getDimSize(0); d < e; d++) {
-    auto posAttr = rewriter.getI64ArrayAttr(d);
-    Value ext = rewriter.create<vector::ExtractOp>(loc, vType, val, posAttr);
+    Value ext = rewriter.create<vector::ExtractOp>(loc, vType, val, d);
     Value load = reshapeLoad(loc, ext, vType, index - 1, pos, rewriter);
-    result = rewriter.create<vector::InsertOp>(loc, resVectorType, load, result,
-                                               posAttr);
+    result =
+        rewriter.create<vector::InsertOp>(loc, resVectorType, load, result, d);
   }
   return result;
 }
@@ -120,20 +117,17 @@ static Value reshapeStore(Location loc, Value val, Value result,
   if (index == -1)
     return val;
   // At insertion dimension?
-  if (index == 0) {
-    auto posAttr = rewriter.getI64ArrayAttr(pos);
-    return rewriter.create<vector::InsertOp>(loc, type, val, result, posAttr);
-  }
+  if (index == 0)
+    return rewriter.create<vector::InsertOp>(loc, type, val, result, pos);
   // Unroll leading dimensions.
   Type lowType = VectorType::Builder(type).dropDim(0);
   VectorType vType = cast<VectorType>(lowType);
   Type insType = VectorType::Builder(vType).dropDim(0);
   for (int64_t d = 0, e = type.getDimSize(0); d < e; d++) {
-    auto posAttr = rewriter.getI64ArrayAttr(d);
-    Value ext = rewriter.create<vector::ExtractOp>(loc, vType, result, posAttr);
-    Value ins = rewriter.create<vector::ExtractOp>(loc, insType, val, posAttr);
+    Value ext = rewriter.create<vector::ExtractOp>(loc, vType, result, d);
+    Value ins = rewriter.create<vector::ExtractOp>(loc, insType, val, d);
     Value sto = reshapeStore(loc, ins, ext, vType, index - 1, pos, rewriter);
-    result = rewriter.create<vector::InsertOp>(loc, type, sto, result, posAttr);
+    result = rewriter.create<vector::InsertOp>(loc, type, sto, result, d);
   }
   return result;
 }
@@ -823,10 +817,8 @@ struct ContractOpToElementwise
     newRhs = rewriter.create<vector::TransposeOp>(loc, newRhs, rhsTranspose);
     SmallVector<int64_t> lhsOffsets(lhsReductionDims.size(), 0);
     SmallVector<int64_t> rhsOffsets(rhsReductionDims.size(), 0);
-    newLhs = rewriter.create<vector::ExtractOp>(
-        loc, newLhs, rewriter.getI64ArrayAttr(lhsOffsets));
-    newRhs = rewriter.create<vector::ExtractOp>(
-        loc, newRhs, rewriter.getI64ArrayAttr(rhsOffsets));
+    newLhs = rewriter.create<vector::ExtractOp>(loc, newLhs, lhsOffsets);
+    newRhs = rewriter.create<vector::ExtractOp>(loc, newRhs, rhsOffsets);
     std::optional<Value> result =
         createContractArithOp(loc, newLhs, newRhs, contractOp.getAcc(),
                               contractOp.getKind(), rewriter, isInt);
@@ -1167,21 +1159,20 @@ class OuterProductOpLowering : public OpRewritePattern<vector::OuterProductOp> {
     Value result = rewriter.create<arith::ConstantOp>(
         loc, resType, rewriter.getZeroAttr(resType));
     for (int64_t d = 0, e = resType.getDimSize(0); d < e; ++d) {
-      auto pos = rewriter.getI64ArrayAttr(d);
-      Value x = rewriter.create<vector::ExtractOp>(loc, op.getLhs(), pos);
+      Value x = rewriter.create<vector::ExtractOp>(loc, op.getLhs(), d);
       Value a = rewriter.create<vector::BroadcastOp>(loc, rhsType, x);
       Value r = nullptr;
       if (acc)
-        r = rewriter.create<vector::ExtractOp>(loc, acc, pos);
+        r = rewriter.create<vector::ExtractOp>(loc, acc, d);
       Value extrMask;
       if (mask)
-        extrMask = rewriter.create<vector::ExtractOp>(loc, mask, pos);
+        extrMask = rewriter.create<vector::ExtractOp>(loc, mask, d);
 
       std::optional<Value> m = createContractArithOp(
           loc, a, op.getRhs(), r, kind, rewriter, isInt, extrMask);
       if (!m.has_value())
         return failure();
-      result = rewriter.create<vector::InsertOp>(loc, resType, *m, result, pos);
+      result = rewriter.create<vector::InsertOp>(loc, resType, *m, result, d);
     }
 
     rewriter.replaceOp(rootOp, result);

diff  --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp
index a0ed056fc7a328..796bbab38dcbf6 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp
@@ -77,9 +77,7 @@ class CreateMaskOpLowering : public OpRewritePattern<vector::CreateMaskOp> {
       Value val = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt,
                                                  bnd, idx);
       Value sel = rewriter.create<arith::SelectOp>(loc, val, trueVal, falseVal);
-      auto pos = rewriter.getI64ArrayAttr(d);
-      result =
-          rewriter.create<vector::InsertOp>(loc, dstType, sel, result, pos);
+      result = rewriter.create<vector::InsertOp>(loc, dstType, sel, result, d);
     }
     rewriter.replaceOp(op, result);
     return success();
@@ -151,11 +149,9 @@ class ConstantMaskOpLowering : public OpRewritePattern<vector::ConstantMaskOp> {
         loc, lowType, rewriter.getI64ArrayAttr(newDimSizes));
     Value result = rewriter.create<arith::ConstantOp>(
         loc, dstType, rewriter.getZeroAttr(dstType));
-    for (int64_t d = 0; d < trueDim; d++) {
-      auto pos = rewriter.getI64ArrayAttr(d);
+    for (int64_t d = 0; d < trueDim; d++)
       result =
-          rewriter.create<vector::InsertOp>(loc, dstType, trueVal, result, pos);
-    }
+          rewriter.create<vector::InsertOp>(loc, dstType, trueVal, result, d);
     rewriter.replaceOp(op, result);
     return success();
   }

diff  --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index 72aae4956b3e28..9d6c45b4bceaec 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -944,7 +944,7 @@ struct WarpOpExtract : public OpRewritePattern<WarpExecuteOnLane0Op> {
     // Rewrite vector.extract with 1d source to vector.extractelement.
     if (extractSrcType.getRank() == 1) {
       assert(extractOp.getPosition().size() == 1 && "expected 1 index");
-      int64_t pos = cast<IntegerAttr>(extractOp.getPosition()[0]).getInt();
+      int64_t pos = extractOp.getPosition()[0];
       rewriter.setInsertionPoint(extractOp);
       rewriter.replaceOpWithNewOp<vector::ExtractElementOp>(
           extractOp, extractOp.getVector(),
@@ -1201,7 +1201,7 @@ struct WarpOpInsert : public OpRewritePattern<WarpExecuteOnLane0Op> {
     // Rewrite vector.insert with 1d dest to vector.insertelement.
     if (insertOp.getDestVectorType().getRank() == 1) {
       assert(insertOp.getPosition().size() == 1 && "expected 1 index");
-      int64_t pos = cast<IntegerAttr>(insertOp.getPosition()[0]).getInt();
+      int64_t pos = insertOp.getPosition()[0];
       rewriter.setInsertionPoint(insertOp);
       rewriter.replaceOpWithNewOp<vector::InsertElementOp>(
           insertOp, insertOp.getSource(), insertOp.getDest(),
@@ -1276,10 +1276,7 @@ struct WarpOpInsert : public OpRewritePattern<WarpExecuteOnLane0Op> {
     } else {
       // One lane inserts the entire source vector.
       int64_t elementsPerLane = distrDestType.getDimSize(distrDestDim);
-      SmallVector<int64_t> newPos = llvm::to_vector(
-          llvm::map_range(insertOp.getPosition(), [](Attribute attr) {
-            return cast<IntegerAttr>(attr).getInt();
-          }));
+      SmallVector<int64_t> newPos(insertOp.getPosition());
       // tid of inserting lane: pos / elementsPerLane
       Value insertingLane = rewriter.create<arith::ConstantIndexOp>(
           loc, newPos[distrDestDim] / elementsPerLane);

diff  --git a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
index 062950f6456f46..dabbca3f7a2271 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
@@ -165,16 +165,14 @@ struct CastAwayInsertLeadingOneDim : public OpRewritePattern<vector::InsertOp> {
     // type has leading unit dims, we also trim the position array accordingly,
     // then (2) if source type also has leading unit dims, we need to append
     // zeroes to the position array accordingly.
-    unsigned oldPosRank = insertOp.getPosition().getValue().size();
+    unsigned oldPosRank = insertOp.getPosition().size();
     unsigned newPosRank = std::max<int64_t>(0, oldPosRank - dstDropCount);
-    SmallVector<Attribute> newPositions = llvm::to_vector(
-        insertOp.getPosition().getValue().take_back(newPosRank));
-    newPositions.resize(newDstType.getRank() - newSrcRank,
-                        rewriter.getI64IntegerAttr(0));
+    SmallVector<int64_t> newPositions =
+        llvm::to_vector(insertOp.getPosition().take_back(newPosRank));
+    newPositions.resize(newDstType.getRank() - newSrcRank, 0);
 
     auto newInsertOp = rewriter.create<vector::InsertOp>(
-        loc, newDstType, newSrcVector, newDstVector,
-        rewriter.getArrayAttr(newPositions));
+        loc, newDstType, newSrcVector, newDstVector, newPositions);
 
     rewriter.replaceOpWithNewOp<vector::BroadcastOp>(insertOp, oldDstType,
                                                      newInsertOp);

diff  --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
index dd4948f34d6682..74d4b7636315fd 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
@@ -704,7 +704,7 @@ class RewriteScalarExtractOfTransferRead
     SmallVector<Value> newIndices(xferOp.getIndices().begin(),
                                   xferOp.getIndices().end());
     for (const auto &it : llvm::enumerate(extractOp.getPosition())) {
-      int64_t offset = cast<IntegerAttr>(it.value()).getInt();
+      int64_t offset = it.value();
       int64_t idx =
           newIndices.size() - extractOp.getPosition().size() + it.index();
       OpFoldResult ofr = affine::makeComposedFoldedAffineApply(

diff  --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index abe6d8846a2357..a6177641dc6b43 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -598,11 +598,7 @@ struct BubbleDownVectorBitCastForExtract
     unsigned expandRatio =
         castDstType.getNumElements() / castSrcType.getNumElements();
 
-    auto getFirstIntValue = [](ArrayAttr attr) -> uint64_t {
-      return (*attr.getAsValueRange<IntegerAttr>().begin()).getZExtValue();
-    };
-
-    uint64_t index = getFirstIntValue(extractOp.getPosition());
+    uint64_t index = extractOp.getPosition()[0];
 
     // Get the single scalar (as a vector) in the source value that packs the
     // desired scalar. E.g. extract vector<1xf32> from vector<4xf32>
@@ -610,7 +606,7 @@ struct BubbleDownVectorBitCastForExtract
         VectorType::get({1}, castSrcType.getElementType());
     Value packedValue = rewriter.create<vector::ExtractOp>(
         extractOp.getLoc(), oneScalarType, castOp.getSource(),
-        rewriter.getI64ArrayAttr(index / expandRatio));
+        index / expandRatio);
 
     // Cast it to a vector with the desired scalar's type.
     // E.g. f32 -> vector<2xf16>
@@ -621,8 +617,7 @@ struct BubbleDownVectorBitCastForExtract
 
     // Finally extract the desired scalar.
     rewriter.replaceOpWithNewOp<vector::ExtractOp>(
-        extractOp, extractOp.getType(), castedValue,
-        rewriter.getI64ArrayAttr(index % expandRatio));
+        extractOp, extractOp.getType(), castedValue, index % expandRatio);
 
     return success();
   }

diff  --git a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
index 535da1328d34c5..922351265e38b9 100644
--- a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
+++ b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
@@ -155,8 +155,8 @@ func.func @broadcast(%arg0 : f32) -> (vector<4xf32>, vector<2xf32>) {
 //       CHECK:   spirv.CompositeExtract %[[ARG]][0 : i32] : vector<2xf32>
 //       CHECK:   spirv.CompositeExtract %[[ARG]][1 : i32] : vector<2xf32>
 func.func @extract(%arg0 : vector<2xf32>) -> (vector<1xf32>, f32) {
-  %0 = "vector.extract"(%arg0) {position = [0]} : (vector<2xf32>) -> vector<1xf32>
-  %1 = "vector.extract"(%arg0) {position = [1]} : (vector<2xf32>) -> f32
+  %0 = "vector.extract"(%arg0) <{position = array<i64: 0>}> : (vector<2xf32>) -> vector<1xf32>
+  %1 = "vector.extract"(%arg0) <{position = array<i64: 1>}> : (vector<2xf32>) -> f32
   return %0, %1: vector<1xf32>, f32
 }
 

diff  --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index 242b0728e9953b..16fb631af25834 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -133,7 +133,7 @@ func.func @extract_position_rank_overflow(%arg0: vector<4x8x16xf32>) {
 
 func.func @extract_position_rank_overflow_generic(%arg0: vector<4x8x16xf32>) {
   // expected-error at +1 {{expected position attribute of rank no greater than vector rank}}
-  %1 = "vector.extract" (%arg0) { position = [0, 0, 0, 0] } : (vector<4x8x16xf32>) -> (vector<16xf32>)
+  %1 = "vector.extract" (%arg0) <{position = array<i64: 0, 0, 0, 0>}> : (vector<4x8x16xf32>) -> (vector<16xf32>)
 }
 
 // -----


        


More information about the Mlir-commits mailing list