[Mlir-commits] [mlir] 16b75cd - [mlir][vector] Use DenseI64ArrayAttr for ExtractOp/InsertOp positions
Matthias Springer
llvmlistbot at llvm.org
Mon Jul 31 06:29:30 PDT 2023
Author: Matthias Springer
Date: 2023-07-31T15:25:37+02:00
New Revision: 16b75cd2bb439633d29c99a7663f2586e4068ecf
URL: https://github.com/llvm/llvm-project/commit/16b75cd2bb439633d29c99a7663f2586e4068ecf
DIFF: https://github.com/llvm/llvm-project/commit/16b75cd2bb439633d29c99a7663f2586e4068ecf.diff
LOG: [mlir][vector] Use DenseI64ArrayAttr for ExtractOp/InsertOp positions
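`DenseI64ArrayAttr` provides a better API than `I64ArrayAttr`. For example, the generated accessors return `ArrayRef<int64_t>` instead of `ArrayAttr`, and the generated builders accept `ArrayRef<int64_t>` directly, so the hand-written `ArrayRef<int64_t>` builders of `ExtractOp`/`InsertOp` are no longer needed.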
Differential Revision: https://reviews.llvm.org/D156684
Added:
Modified:
mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
mlir/lib/Dialect/Vector/IR/VectorOps.cpp
mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp
mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
mlir/test/Dialect/Vector/invalid.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 588998853e6995..63d96721bfd400 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -573,7 +573,7 @@ def Vector_ExtractOp :
PredOpTrait<"operand and result have same element type",
TCresVTEtIsSameAsOpBase<0, 0>>,
InferTypeOpAdaptorWithIsCompatible]>,
- Arguments<(ins AnyVectorOfAnyRank:$vector, I64ArrayAttr:$position)>,
+ Arguments<(ins AnyVectorOfAnyRank:$vector, DenseI64ArrayAttr:$position)>,
Results<(outs AnyType)> {
let summary = "extract operation";
let description = [{
@@ -589,7 +589,6 @@ def Vector_ExtractOp :
```
}];
let builders = [
- OpBuilder<(ins "Value":$source, "ArrayRef<int64_t>":$position)>,
// Convenience builder which assumes the values in `position` are defined by
// ConstantIndexOp.
OpBuilder<(ins "Value":$source, "ValueRange":$position)>
@@ -689,7 +688,7 @@ def Vector_InsertOp :
PredOpTrait<"source operand and result have same element type",
TCresVTEtIsSameAsOpBase<0, 0>>,
AllTypesMatch<["dest", "res"]>]>,
- Arguments<(ins AnyType:$source, AnyVectorOfAnyRank:$dest, I64ArrayAttr:$position)>,
+ Arguments<(ins AnyType:$source, AnyVectorOfAnyRank:$dest, DenseI64ArrayAttr:$position)>,
Results<(outs AnyVectorOfAnyRank:$res)> {
let summary = "insert operation";
let description = [{
@@ -711,8 +710,6 @@ def Vector_InsertOp :
}];
let builders = [
- OpBuilder<(ins "Value":$source, "Value":$dest,
- "ArrayRef<int64_t>":$position)>,
// Convenience builder which assumes all values are constant indices.
OpBuilder<(ins "Value":$source, "Value":$dest, "ValueRange":$position)>
];
diff --git a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
index da573686967971..409e9365a9f207 100644
--- a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
+++ b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
@@ -807,8 +807,7 @@ createNonLdMatrixLoads(RewriterBase &rewriter, vector::TransferReadOp op,
Value el = rewriter.create<vector::LoadOp>(loc, loadedElType,
op.getSource(), newIndices);
- result = rewriter.create<vector::InsertOp>(loc, el, result,
- rewriter.getI64ArrayAttr(i));
+ result = rewriter.create<vector::InsertOp>(loc, el, result, i);
}
} else {
if (auto vecType = dyn_cast<VectorType>(loadedElType)) {
@@ -832,7 +831,7 @@ createNonLdMatrixLoads(RewriterBase &rewriter, vector::TransferReadOp op,
Value el = rewriter.create<memref::LoadOp>(op.getLoc(), loadedElType,
op.getSource(), newIndices);
result = rewriter.create<vector::InsertOp>(
- op.getLoc(), el, result, rewriter.getI64ArrayAttr({i, innerIdx}));
+ op.getLoc(), el, result, ArrayRef<int64_t>{i, innerIdx});
}
}
}
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index d0c0d8fa0540f9..fc93f0537c47f0 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -1025,44 +1025,37 @@ class VectorExtractOpConversion
auto loc = extractOp->getLoc();
auto resultType = extractOp.getResult().getType();
auto llvmResultType = typeConverter->convertType(resultType);
- auto positionArrayAttr = extractOp.getPosition();
+ ArrayRef<int64_t> positionArray = extractOp.getPosition();
// Bail if result type cannot be lowered.
if (!llvmResultType)
return failure();
// Extract entire vector. Should be handled by folder, but just to be safe.
- if (positionArrayAttr.empty()) {
+ if (positionArray.empty()) {
rewriter.replaceOp(extractOp, adaptor.getVector());
return success();
}
// One-shot extraction of vector from array (only requires extractvalue).
if (isa<VectorType>(resultType)) {
- SmallVector<int64_t> indices;
- for (auto idx : positionArrayAttr.getAsRange<IntegerAttr>())
- indices.push_back(idx.getInt());
Value extracted = rewriter.create<LLVM::ExtractValueOp>(
- loc, adaptor.getVector(), indices);
+ loc, adaptor.getVector(), positionArray);
rewriter.replaceOp(extractOp, extracted);
return success();
}
// Potential extraction of 1-D vector from array.
Value extracted = adaptor.getVector();
- auto positionAttrs = positionArrayAttr.getValue();
- if (positionAttrs.size() > 1) {
- SmallVector<int64_t> nMinusOnePosition;
- for (auto idx : positionAttrs.drop_back())
- nMinusOnePosition.push_back(cast<IntegerAttr>(idx).getInt());
- extracted = rewriter.create<LLVM::ExtractValueOp>(loc, extracted,
- nMinusOnePosition);
+ if (positionArray.size() > 1) {
+ extracted = rewriter.create<LLVM::ExtractValueOp>(
+ loc, extracted, positionArray.drop_back());
}
// Remaining extraction of element from 1-D LLVM vector
- auto position = cast<IntegerAttr>(positionAttrs.back());
auto i64Type = IntegerType::get(rewriter.getContext(), 64);
- auto constant = rewriter.create<LLVM::ConstantOp>(loc, i64Type, position);
+ auto constant =
+ rewriter.create<LLVM::ConstantOp>(loc, i64Type, positionArray.back());
extracted =
rewriter.create<LLVM::ExtractElementOp>(loc, extracted, constant);
rewriter.replaceOp(extractOp, extracted);
@@ -1147,7 +1140,7 @@ class VectorInsertOpConversion
auto sourceType = insertOp.getSourceType();
auto destVectorType = insertOp.getDestVectorType();
auto llvmResultType = typeConverter->convertType(destVectorType);
- auto positionArrayAttr = insertOp.getPosition();
+ ArrayRef<int64_t> positionArray = insertOp.getPosition();
// Bail if result type cannot be lowered.
if (!llvmResultType)
@@ -1155,7 +1148,7 @@ class VectorInsertOpConversion
// Overwrite entire vector with value. Should be handled by folder, but
// just to be safe.
- if (positionArrayAttr.empty()) {
+ if (positionArray.empty()) {
rewriter.replaceOp(insertOp, adaptor.getSource());
return success();
}
@@ -1163,36 +1156,32 @@ class VectorInsertOpConversion
// One-shot insertion of a vector into an array (only requires insertvalue).
if (isa<VectorType>(sourceType)) {
Value inserted = rewriter.create<LLVM::InsertValueOp>(
- loc, adaptor.getDest(), adaptor.getSource(),
- LLVM::convertArrayToIndices(positionArrayAttr));
+ loc, adaptor.getDest(), adaptor.getSource(), positionArray);
rewriter.replaceOp(insertOp, inserted);
return success();
}
// Potential extraction of 1-D vector from array.
Value extracted = adaptor.getDest();
- auto positionAttrs = positionArrayAttr.getValue();
- auto position = cast<IntegerAttr>(positionAttrs.back());
auto oneDVectorType = destVectorType;
- if (positionAttrs.size() > 1) {
+ if (positionArray.size() > 1) {
oneDVectorType = reducedVectorTypeBack(destVectorType);
extracted = rewriter.create<LLVM::ExtractValueOp>(
- loc, extracted,
- LLVM::convertArrayToIndices(positionAttrs.drop_back()));
+ loc, extracted, positionArray.drop_back());
}
// Insertion of an element into a 1-D LLVM vector.
auto i64Type = IntegerType::get(rewriter.getContext(), 64);
- auto constant = rewriter.create<LLVM::ConstantOp>(loc, i64Type, position);
+ auto constant =
+ rewriter.create<LLVM::ConstantOp>(loc, i64Type, positionArray.back());
Value inserted = rewriter.create<LLVM::InsertElementOp>(
loc, typeConverter->convertType(oneDVectorType), extracted,
adaptor.getSource(), constant);
// Potential insertion of resulting 1-D vector into array.
- if (positionAttrs.size() > 1) {
+ if (positionArray.size() > 1) {
inserted = rewriter.create<LLVM::InsertValueOp>(
- loc, adaptor.getDest(), inserted,
- LLVM::convertArrayToIndices(positionAttrs.drop_back()));
+ loc, adaptor.getDest(), inserted, positionArray.drop_back());
}
rewriter.replaceOp(insertOp, inserted);
diff --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
index fc274c989196ca..5e19e422b61116 100644
--- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
+++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
@@ -886,10 +886,9 @@ struct UnrollTransferReadConversion
/// vector::InsertOp, return that operation's indices.
void getInsertionIndices(TransferReadOp xferOp,
SmallVector<int64_t, 8> &indices) const {
- if (auto insertOp = getInsertOp(xferOp)) {
- for (Attribute attr : insertOp.getPosition())
- indices.push_back(dyn_cast<IntegerAttr>(attr).getInt());
- }
+ if (auto insertOp = getInsertOp(xferOp))
+ indices.assign(insertOp.getPosition().begin(),
+ insertOp.getPosition().end());
}
/// Rewrite the op: Unpack one dimension. Can handle masks, out-of-bounds
@@ -1013,10 +1012,9 @@ struct UnrollTransferWriteConversion
/// indices.
void getExtractionIndices(TransferWriteOp xferOp,
SmallVector<int64_t, 8> &indices) const {
- if (auto extractOp = getExtractOp(xferOp)) {
- for (Attribute attr : extractOp.getPosition())
- indices.push_back(dyn_cast<IntegerAttr>(attr).getInt());
- }
+ if (auto extractOp = getExtractOp(xferOp))
+ indices.assign(extractOp.getPosition().begin(),
+ extractOp.getPosition().end());
}
/// Rewrite the op: Unpack one dimension. Can handle masks, out-of-bounds
diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index 903441943f200e..c15b99b5a62d3e 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -152,7 +152,7 @@ struct VectorExtractOpConvert final
return success();
}
- int32_t id = getFirstIntValue(extractOp.getPosition());
+ int32_t id = extractOp.getPosition()[0];
rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(
extractOp, adaptor.getVector(), id);
return success();
@@ -232,7 +232,7 @@ struct VectorInsertOpConvert final
return success();
}
- int32_t id = getFirstIntValue(insertOp.getPosition());
+ int32_t id = insertOp.getPosition()[0];
rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(
insertOp, adaptor.getSource(), adaptor.getDest(), id);
return success();
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 075b139e2f3b1c..20bd3f32fac91c 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -385,8 +385,7 @@ struct ElideUnitDimsInMultiDimReduction
} else {
// This means we are reducing all the dimensions, and all reduction
// dimensions are of size 1. So a simple extraction would do.
- auto zeroAttr =
- rewriter.getI64ArrayAttr(SmallVector<int64_t>(shape.size(), 0));
+ SmallVector<int64_t> zeroAttr(shape.size(), 0);
if (mask)
mask = rewriter.create<vector::ExtractOp>(loc, rewriter.getI1Type(),
mask, zeroAttr);
@@ -560,12 +559,10 @@ struct ElideSingleElementReduction : public OpRewritePattern<ReductionOp> {
result = rewriter.create<ExtractElementOp>(loc, reductionOp.getVector());
} else {
if (mask) {
- mask = rewriter.create<ExtractOp>(loc, rewriter.getI1Type(), mask,
- rewriter.getI64ArrayAttr(0));
+ mask = rewriter.create<ExtractOp>(loc, rewriter.getI1Type(), mask, 0);
}
result = rewriter.create<ExtractOp>(loc, reductionOp.getType(),
- reductionOp.getVector(),
- rewriter.getI64ArrayAttr(0));
+ reductionOp.getVector(), 0);
}
if (Value acc = reductionOp.getAcc())
@@ -1129,18 +1126,11 @@ OpFoldResult vector::ExtractElementOp::fold(FoldAdaptor adaptor) {
// ExtractOp
//===----------------------------------------------------------------------===//
-void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
- Value source, ArrayRef<int64_t> position) {
- build(builder, result, source, getVectorSubscriptAttr(builder, position));
-}
-
// Convenience builder which assumes the values are constant indices.
void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
Value source, ValueRange position) {
- SmallVector<int64_t, 4> positionConstants =
- llvm::to_vector<4>(llvm::map_range(position, [](Value pos) {
- return getConstantIntValue(pos).value();
- }));
+ SmallVector<int64_t> positionConstants = llvm::to_vector(llvm::map_range(
+ position, [](Value pos) { return getConstantIntValue(pos).value(); }));
build(builder, result, source, positionConstants);
}
@@ -1175,15 +1165,13 @@ bool ExtractOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
}
LogicalResult vector::ExtractOp::verify() {
- auto positionAttr = getPosition().getValue();
- if (positionAttr.size() >
- static_cast<unsigned>(getSourceVectorType().getRank()))
+ ArrayRef<int64_t> position = getPosition();
+ if (position.size() > static_cast<unsigned>(getSourceVectorType().getRank()))
return emitOpError(
"expected position attribute of rank no greater than vector rank");
- for (const auto &en : llvm::enumerate(positionAttr)) {
- auto attr = llvm::dyn_cast<IntegerAttr>(en.value());
- if (!attr || attr.getInt() < 0 ||
- attr.getInt() >= getSourceVectorType().getDimSize(en.index()))
+ for (const auto &en : llvm::enumerate(position)) {
+ if (en.value() < 0 ||
+ en.value() >= getSourceVectorType().getDimSize(en.index()))
return emitOpError("expected position attribute #")
<< (en.index() + 1)
<< " to be a non-negative integer smaller than the corresponding "
@@ -1207,18 +1195,18 @@ static LogicalResult foldExtractOpFromExtractChain(ExtractOp extractOp) {
SmallVector<int64_t, 4> globalPosition;
ExtractOp currentOp = extractOp;
- auto extrPos = extractVector<int64_t>(currentOp.getPosition());
+ ArrayRef<int64_t> extrPos = currentOp.getPosition();
globalPosition.append(extrPos.rbegin(), extrPos.rend());
while (ExtractOp nextOp = currentOp.getVector().getDefiningOp<ExtractOp>()) {
currentOp = nextOp;
- auto extrPos = extractVector<int64_t>(currentOp.getPosition());
+ ArrayRef<int64_t> extrPos = currentOp.getPosition();
globalPosition.append(extrPos.rbegin(), extrPos.rend());
}
extractOp.setOperand(currentOp.getVector());
// OpBuilder is only used as a helper to build an I64ArrayAttr.
OpBuilder b(extractOp.getContext());
std::reverse(globalPosition.begin(), globalPosition.end());
- extractOp.setPositionAttr(b.getI64ArrayAttr(globalPosition));
+ extractOp.setPosition(globalPosition);
return success();
}
@@ -1329,7 +1317,8 @@ ExtractFromInsertTransposeChainState::ExtractFromInsertTransposeChainState(
sentinels.reserve(vectorRank - extractedRank);
for (int64_t i = 0, e = vectorRank - extractedRank; i < e; ++i)
sentinels.push_back(-(i + 1));
- extractPosition = extractVector<int64_t>(extractOp.getPosition());
+ extractPosition.assign(extractOp.getPosition().begin(),
+ extractOp.getPosition().end());
llvm::append_range(extractPosition, sentinels);
}
@@ -1349,9 +1338,8 @@ LogicalResult ExtractFromInsertTransposeChainState::handleTransposeOp() {
LogicalResult
ExtractFromInsertTransposeChainState::handleInsertOpWithMatchingPos(
Value &res) {
- auto insertedPos = extractVector<int64_t>(nextInsertOp.getPosition());
- if (ArrayRef(insertedPos) !=
- llvm::ArrayRef(extractPosition).take_front(extractedRank))
+ ArrayRef<int64_t> insertedPos = nextInsertOp.getPosition();
+ if (insertedPos != llvm::ArrayRef(extractPosition).take_front(extractedRank))
return failure();
// Case 2.a. early-exit fold.
res = nextInsertOp.getSource();
@@ -1364,7 +1352,7 @@ ExtractFromInsertTransposeChainState::handleInsertOpWithMatchingPos(
/// This method updates the internal state.
LogicalResult
ExtractFromInsertTransposeChainState::handleInsertOpWithPrefixPos(Value &res) {
- auto insertedPos = extractVector<int64_t>(nextInsertOp.getPosition());
+ ArrayRef<int64_t> insertedPos = nextInsertOp.getPosition();
if (!isContainedWithin(insertedPos, extractPosition))
return failure();
// Set leading dims to zero.
@@ -1390,9 +1378,7 @@ Value ExtractFromInsertTransposeChainState::tryToFoldExtractOpInPlace(
return Value();
// Otherwise, fold by updating the op inplace and return its result.
OpBuilder b(extractOp.getContext());
- extractOp->setAttr(
- extractOp.getPositionAttrName(),
- b.getI64ArrayAttr(ArrayRef(extractPosition).take_front(extractedRank)));
+ extractOp.setPosition(ArrayRef(extractPosition).take_front(extractedRank));
extractOp.getVectorMutable().assign(source);
return extractOp.getResult();
}
@@ -1422,7 +1408,7 @@ Value ExtractFromInsertTransposeChainState::fold() {
// Case 4: extractPositionRef intersects insertedPosRef on non-sentinel
// values. This is a more difficult case and we bail.
- auto insertedPos = extractVector<int64_t>(nextInsertOp.getPosition());
+ ArrayRef<int64_t> insertedPos = nextInsertOp.getPosition();
if (isContainedWithin(extractPosition, insertedPos) ||
intersectsWhereNonNegative(extractPosition, insertedPos))
return Value();
@@ -1487,7 +1473,7 @@ static Value foldExtractFromBroadcast(ExtractOp extractOp) {
// extract position to `0` when extracting from the source operand.
llvm::SetVector<int64_t> broadcastedUnitDims =
broadcastOp.computeBroadcastedUnitDims();
- auto extractPos = extractVector<int64_t>(extractOp.getPosition());
+ SmallVector<int64_t> extractPos(extractOp.getPosition());
for (int64_t i = rankDiff, e = extractPos.size(); i < e; ++i)
if (broadcastedUnitDims.contains(i))
extractPos[i] = 0;
@@ -1498,7 +1484,7 @@ static Value foldExtractFromBroadcast(ExtractOp extractOp) {
// OpBuilder is only used as a helper to build an I64ArrayAttr.
OpBuilder b(extractOp.getContext());
extractOp.setOperand(source);
- extractOp.setPositionAttr(b.getI64ArrayAttr(extractPos));
+ extractOp.setPosition(extractPos);
return extractOp.getResult();
}
@@ -1537,7 +1523,7 @@ static Value foldExtractFromShapeCast(ExtractOp extractOp) {
}
// Extract the strides associated with the extract op vector source. Then use
// this to calculate a linearized position for the extract.
- auto extractedPos = extractVector<int64_t>(extractOp.getPosition());
+ SmallVector<int64_t> extractedPos(extractOp.getPosition());
std::reverse(extractedPos.begin(), extractedPos.end());
SmallVector<int64_t, 4> strides;
int64_t stride = 1;
@@ -1563,7 +1549,7 @@ static Value foldExtractFromShapeCast(ExtractOp extractOp) {
SmallVector<int64_t, 4> newPosition = delinearize(position, newStrides);
// OpBuilder is only used as a helper to build an I64ArrayAttr.
OpBuilder b(extractOp.getContext());
- extractOp.setPositionAttr(b.getI64ArrayAttr(newPosition));
+ extractOp.setPosition(newPosition);
extractOp.setOperand(shapeCastOp.getSource());
return extractOp.getResult();
}
@@ -1603,14 +1589,14 @@ static Value foldExtractFromExtractStrided(ExtractOp extractOp) {
if (destinationRank > extractStridedSliceOp.getSourceVectorType().getRank() -
sliceOffsets.size())
return Value();
- auto extractedPos = extractVector<int64_t>(extractOp.getPosition());
+ SmallVector<int64_t> extractedPos(extractOp.getPosition());
assert(extractedPos.size() >= sliceOffsets.size());
for (size_t i = 0, e = sliceOffsets.size(); i < e; i++)
extractedPos[i] = extractedPos[i] + sliceOffsets[i];
extractOp.getVectorMutable().assign(extractStridedSliceOp.getVector());
// OpBuilder is only used as a helper to build an I64ArrayAttr.
OpBuilder b(extractOp.getContext());
- extractOp.setPositionAttr(b.getI64ArrayAttr(extractedPos));
+ extractOp.setPosition(extractedPos);
return extractOp.getResult();
}
@@ -1635,7 +1621,7 @@ static Value foldExtractStridedOpFromInsertChain(ExtractOp extractOp) {
if (destinationRank > insertOp.getSourceVectorType().getRank())
return Value();
auto insertOffsets = extractVector<int64_t>(insertOp.getOffsets());
- auto extractOffsets = extractVector<int64_t>(extractOp.getPosition());
+ ArrayRef<int64_t> extractOffsets = extractOp.getPosition();
if (llvm::any_of(insertOp.getStrides(), [](Attribute attr) {
return llvm::cast<IntegerAttr>(attr).getInt() != 1;
@@ -1675,7 +1661,7 @@ static Value foldExtractStridedOpFromInsertChain(ExtractOp extractOp) {
extractOp.getVectorMutable().assign(insertOp.getSource());
// OpBuilder is only used as a helper to build an I64ArrayAttr.
OpBuilder b(extractOp.getContext());
- extractOp.setPositionAttr(b.getI64ArrayAttr(offsetDiffs));
+ extractOp.setPosition(offsetDiffs);
return extractOp.getResult();
}
// If the chunk extracted is disjoint from the chunk inserted, keep
@@ -1795,7 +1781,7 @@ class ExtractOpNonSplatConstantFolder final
// Calculate the linearized position of the continuous chunk of elements to
// extract.
llvm::SmallVector<int64_t> completePositions(vecTy.getRank(), 0);
- copy(getI64SubArray(extractOp.getPosition()), completePositions.begin());
+ copy(extractOp.getPosition(), completePositions.begin());
int64_t elemBeginPosition =
linearize(completePositions, computeStrides(vecTy.getShape()));
auto denseValuesBegin = dense.value_begin<TypedAttr>() + elemBeginPosition;
@@ -2288,14 +2274,6 @@ OpFoldResult vector::InsertElementOp::fold(FoldAdaptor adaptor) {
// InsertOp
//===----------------------------------------------------------------------===//
-void InsertOp::build(OpBuilder &builder, OperationState &result, Value source,
- Value dest, ArrayRef<int64_t> position) {
- result.addOperands({source, dest});
- auto positionAttr = getVectorSubscriptAttr(builder, position);
- result.addTypes(dest.getType());
- result.addAttribute(InsertOp::getPositionAttrName(result.name), positionAttr);
-}
-
// Convenience builder which assumes the values are constant indices.
void InsertOp::build(OpBuilder &builder, OperationState &result, Value source,
Value dest, ValueRange position) {
@@ -2307,25 +2285,24 @@ void InsertOp::build(OpBuilder &builder, OperationState &result, Value source,
}
LogicalResult InsertOp::verify() {
- auto positionAttr = getPosition().getValue();
+ ArrayRef<int64_t> position = getPosition();
auto destVectorType = getDestVectorType();
- if (positionAttr.size() > static_cast<unsigned>(destVectorType.getRank()))
+ if (position.size() > static_cast<unsigned>(destVectorType.getRank()))
return emitOpError(
"expected position attribute of rank no greater than dest vector rank");
auto srcVectorType = llvm::dyn_cast<VectorType>(getSourceType());
if (srcVectorType &&
- (static_cast<unsigned>(srcVectorType.getRank()) + positionAttr.size() !=
+ (static_cast<unsigned>(srcVectorType.getRank()) + position.size() !=
static_cast<unsigned>(destVectorType.getRank())))
return emitOpError("expected position attribute rank + source rank to "
"match dest vector rank");
if (!srcVectorType &&
- (positionAttr.size() != static_cast<unsigned>(destVectorType.getRank())))
+ (position.size() != static_cast<unsigned>(destVectorType.getRank())))
return emitOpError(
"expected position attribute rank to match the dest vector rank");
- for (const auto &en : llvm::enumerate(positionAttr)) {
- auto attr = llvm::dyn_cast<IntegerAttr>(en.value());
- if (!attr || attr.getInt() < 0 ||
- attr.getInt() >= destVectorType.getDimSize(en.index()))
+ for (const auto &en : llvm::enumerate(position)) {
+ int64_t attr = en.value();
+ if (attr < 0 || attr >= destVectorType.getDimSize(en.index()))
return emitOpError("expected position attribute #")
<< (en.index() + 1)
<< " to be a non-negative integer smaller than the corresponding "
@@ -2412,7 +2389,7 @@ class InsertOpConstantFolder final : public OpRewritePattern<InsertOp> {
// Calculate the linearized position of the continuous chunk of elements to
// insert.
llvm::SmallVector<int64_t> completePositions(destTy.getRank(), 0);
- copy(getI64SubArray(op.getPosition()), completePositions.begin());
+ copy(op.getPosition(), completePositions.begin());
int64_t insertBeginPosition =
linearize(completePositions, computeStrides(destTy.getShape()));
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
index 986c5f81d60c22..66ac5ffef3e3ed 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
@@ -91,10 +91,8 @@ static Value reshapeLoad(Location loc, Value val, VectorType type,
return val;
Type lowType = VectorType::Builder(type).dropDim(0);
// At extraction dimension?
- if (index == 0) {
- auto posAttr = rewriter.getI64ArrayAttr(pos);
- return rewriter.create<vector::ExtractOp>(loc, lowType, val, posAttr);
- }
+ if (index == 0)
+ return rewriter.create<vector::ExtractOp>(loc, lowType, val, pos);
// Unroll leading dimensions.
VectorType vType = cast<VectorType>(lowType);
Type resType = VectorType::Builder(type).dropDim(index);
@@ -102,11 +100,10 @@ static Value reshapeLoad(Location loc, Value val, VectorType type,
Value result = rewriter.create<arith::ConstantOp>(
loc, resVectorType, rewriter.getZeroAttr(resVectorType));
for (int64_t d = 0, e = resVectorType.getDimSize(0); d < e; d++) {
- auto posAttr = rewriter.getI64ArrayAttr(d);
- Value ext = rewriter.create<vector::ExtractOp>(loc, vType, val, posAttr);
+ Value ext = rewriter.create<vector::ExtractOp>(loc, vType, val, d);
Value load = reshapeLoad(loc, ext, vType, index - 1, pos, rewriter);
- result = rewriter.create<vector::InsertOp>(loc, resVectorType, load, result,
- posAttr);
+ result =
+ rewriter.create<vector::InsertOp>(loc, resVectorType, load, result, d);
}
return result;
}
@@ -120,20 +117,17 @@ static Value reshapeStore(Location loc, Value val, Value result,
if (index == -1)
return val;
// At insertion dimension?
- if (index == 0) {
- auto posAttr = rewriter.getI64ArrayAttr(pos);
- return rewriter.create<vector::InsertOp>(loc, type, val, result, posAttr);
- }
+ if (index == 0)
+ return rewriter.create<vector::InsertOp>(loc, type, val, result, pos);
// Unroll leading dimensions.
Type lowType = VectorType::Builder(type).dropDim(0);
VectorType vType = cast<VectorType>(lowType);
Type insType = VectorType::Builder(vType).dropDim(0);
for (int64_t d = 0, e = type.getDimSize(0); d < e; d++) {
- auto posAttr = rewriter.getI64ArrayAttr(d);
- Value ext = rewriter.create<vector::ExtractOp>(loc, vType, result, posAttr);
- Value ins = rewriter.create<vector::ExtractOp>(loc, insType, val, posAttr);
+ Value ext = rewriter.create<vector::ExtractOp>(loc, vType, result, d);
+ Value ins = rewriter.create<vector::ExtractOp>(loc, insType, val, d);
Value sto = reshapeStore(loc, ins, ext, vType, index - 1, pos, rewriter);
- result = rewriter.create<vector::InsertOp>(loc, type, sto, result, posAttr);
+ result = rewriter.create<vector::InsertOp>(loc, type, sto, result, d);
}
return result;
}
@@ -823,10 +817,8 @@ struct ContractOpToElementwise
newRhs = rewriter.create<vector::TransposeOp>(loc, newRhs, rhsTranspose);
SmallVector<int64_t> lhsOffsets(lhsReductionDims.size(), 0);
SmallVector<int64_t> rhsOffsets(rhsReductionDims.size(), 0);
- newLhs = rewriter.create<vector::ExtractOp>(
- loc, newLhs, rewriter.getI64ArrayAttr(lhsOffsets));
- newRhs = rewriter.create<vector::ExtractOp>(
- loc, newRhs, rewriter.getI64ArrayAttr(rhsOffsets));
+ newLhs = rewriter.create<vector::ExtractOp>(loc, newLhs, lhsOffsets);
+ newRhs = rewriter.create<vector::ExtractOp>(loc, newRhs, rhsOffsets);
std::optional<Value> result =
createContractArithOp(loc, newLhs, newRhs, contractOp.getAcc(),
contractOp.getKind(), rewriter, isInt);
@@ -1167,21 +1159,20 @@ class OuterProductOpLowering : public OpRewritePattern<vector::OuterProductOp> {
Value result = rewriter.create<arith::ConstantOp>(
loc, resType, rewriter.getZeroAttr(resType));
for (int64_t d = 0, e = resType.getDimSize(0); d < e; ++d) {
- auto pos = rewriter.getI64ArrayAttr(d);
- Value x = rewriter.create<vector::ExtractOp>(loc, op.getLhs(), pos);
+ Value x = rewriter.create<vector::ExtractOp>(loc, op.getLhs(), d);
Value a = rewriter.create<vector::BroadcastOp>(loc, rhsType, x);
Value r = nullptr;
if (acc)
- r = rewriter.create<vector::ExtractOp>(loc, acc, pos);
+ r = rewriter.create<vector::ExtractOp>(loc, acc, d);
Value extrMask;
if (mask)
- extrMask = rewriter.create<vector::ExtractOp>(loc, mask, pos);
+ extrMask = rewriter.create<vector::ExtractOp>(loc, mask, d);
std::optional<Value> m = createContractArithOp(
loc, a, op.getRhs(), r, kind, rewriter, isInt, extrMask);
if (!m.has_value())
return failure();
- result = rewriter.create<vector::InsertOp>(loc, resType, *m, result, pos);
+ result = rewriter.create<vector::InsertOp>(loc, resType, *m, result, d);
}
rewriter.replaceOp(rootOp, result);
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp
index a0ed056fc7a328..796bbab38dcbf6 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp
@@ -77,9 +77,7 @@ class CreateMaskOpLowering : public OpRewritePattern<vector::CreateMaskOp> {
Value val = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt,
bnd, idx);
Value sel = rewriter.create<arith::SelectOp>(loc, val, trueVal, falseVal);
- auto pos = rewriter.getI64ArrayAttr(d);
- result =
- rewriter.create<vector::InsertOp>(loc, dstType, sel, result, pos);
+ result = rewriter.create<vector::InsertOp>(loc, dstType, sel, result, d);
}
rewriter.replaceOp(op, result);
return success();
@@ -151,11 +149,9 @@ class ConstantMaskOpLowering : public OpRewritePattern<vector::ConstantMaskOp> {
loc, lowType, rewriter.getI64ArrayAttr(newDimSizes));
Value result = rewriter.create<arith::ConstantOp>(
loc, dstType, rewriter.getZeroAttr(dstType));
- for (int64_t d = 0; d < trueDim; d++) {
- auto pos = rewriter.getI64ArrayAttr(d);
+ for (int64_t d = 0; d < trueDim; d++)
result =
- rewriter.create<vector::InsertOp>(loc, dstType, trueVal, result, pos);
- }
+ rewriter.create<vector::InsertOp>(loc, dstType, trueVal, result, d);
rewriter.replaceOp(op, result);
return success();
}
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index 72aae4956b3e28..9d6c45b4bceaec 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -944,7 +944,7 @@ struct WarpOpExtract : public OpRewritePattern<WarpExecuteOnLane0Op> {
// Rewrite vector.extract with 1d source to vector.extractelement.
if (extractSrcType.getRank() == 1) {
assert(extractOp.getPosition().size() == 1 && "expected 1 index");
- int64_t pos = cast<IntegerAttr>(extractOp.getPosition()[0]).getInt();
+ int64_t pos = extractOp.getPosition()[0];
rewriter.setInsertionPoint(extractOp);
rewriter.replaceOpWithNewOp<vector::ExtractElementOp>(
extractOp, extractOp.getVector(),
@@ -1201,7 +1201,7 @@ struct WarpOpInsert : public OpRewritePattern<WarpExecuteOnLane0Op> {
// Rewrite vector.insert with 1d dest to vector.insertelement.
if (insertOp.getDestVectorType().getRank() == 1) {
assert(insertOp.getPosition().size() == 1 && "expected 1 index");
- int64_t pos = cast<IntegerAttr>(insertOp.getPosition()[0]).getInt();
+ int64_t pos = insertOp.getPosition()[0];
rewriter.setInsertionPoint(insertOp);
rewriter.replaceOpWithNewOp<vector::InsertElementOp>(
insertOp, insertOp.getSource(), insertOp.getDest(),
@@ -1276,10 +1276,7 @@ struct WarpOpInsert : public OpRewritePattern<WarpExecuteOnLane0Op> {
} else {
// One lane inserts the entire source vector.
int64_t elementsPerLane = distrDestType.getDimSize(distrDestDim);
- SmallVector<int64_t> newPos = llvm::to_vector(
- llvm::map_range(insertOp.getPosition(), [](Attribute attr) {
- return cast<IntegerAttr>(attr).getInt();
- }));
+ SmallVector<int64_t> newPos(insertOp.getPosition());
// tid of inserting lane: pos / elementsPerLane
Value insertingLane = rewriter.create<arith::ConstantIndexOp>(
loc, newPos[distrDestDim] / elementsPerLane);
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
index 062950f6456f46..dabbca3f7a2271 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
@@ -165,16 +165,14 @@ struct CastAwayInsertLeadingOneDim : public OpRewritePattern<vector::InsertOp> {
// type has leading unit dims, we also trim the position array accordingly,
// then (2) if source type also has leading unit dims, we need to append
// zeroes to the position array accordingly.
- unsigned oldPosRank = insertOp.getPosition().getValue().size();
+ unsigned oldPosRank = insertOp.getPosition().size();
unsigned newPosRank = std::max<int64_t>(0, oldPosRank - dstDropCount);
- SmallVector<Attribute> newPositions = llvm::to_vector(
- insertOp.getPosition().getValue().take_back(newPosRank));
- newPositions.resize(newDstType.getRank() - newSrcRank,
- rewriter.getI64IntegerAttr(0));
+ SmallVector<int64_t> newPositions =
+ llvm::to_vector(insertOp.getPosition().take_back(newPosRank));
+ newPositions.resize(newDstType.getRank() - newSrcRank, 0);
auto newInsertOp = rewriter.create<vector::InsertOp>(
- loc, newDstType, newSrcVector, newDstVector,
- rewriter.getArrayAttr(newPositions));
+ loc, newDstType, newSrcVector, newDstVector, newPositions);
rewriter.replaceOpWithNewOp<vector::BroadcastOp>(insertOp, oldDstType,
newInsertOp);
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
index dd4948f34d6682..74d4b7636315fd 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
@@ -704,7 +704,7 @@ class RewriteScalarExtractOfTransferRead
SmallVector<Value> newIndices(xferOp.getIndices().begin(),
xferOp.getIndices().end());
for (const auto &it : llvm::enumerate(extractOp.getPosition())) {
- int64_t offset = cast<IntegerAttr>(it.value()).getInt();
+ int64_t offset = it.value();
int64_t idx =
newIndices.size() - extractOp.getPosition().size() + it.index();
OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index abe6d8846a2357..a6177641dc6b43 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -598,11 +598,7 @@ struct BubbleDownVectorBitCastForExtract
unsigned expandRatio =
castDstType.getNumElements() / castSrcType.getNumElements();
- auto getFirstIntValue = [](ArrayAttr attr) -> uint64_t {
- return (*attr.getAsValueRange<IntegerAttr>().begin()).getZExtValue();
- };
-
- uint64_t index = getFirstIntValue(extractOp.getPosition());
+ uint64_t index = extractOp.getPosition()[0];
// Get the single scalar (as a vector) in the source value that packs the
// desired scalar. E.g. extract vector<1xf32> from vector<4xf32>
@@ -610,7 +606,7 @@ struct BubbleDownVectorBitCastForExtract
VectorType::get({1}, castSrcType.getElementType());
Value packedValue = rewriter.create<vector::ExtractOp>(
extractOp.getLoc(), oneScalarType, castOp.getSource(),
- rewriter.getI64ArrayAttr(index / expandRatio));
+ index / expandRatio);
// Cast it to a vector with the desired scalar's type.
// E.g. f32 -> vector<2xf16>
@@ -621,8 +617,7 @@ struct BubbleDownVectorBitCastForExtract
// Finally extract the desired scalar.
rewriter.replaceOpWithNewOp<vector::ExtractOp>(
- extractOp, extractOp.getType(), castedValue,
- rewriter.getI64ArrayAttr(index % expandRatio));
+ extractOp, extractOp.getType(), castedValue, index % expandRatio);
return success();
}
diff --git a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
index 535da1328d34c5..922351265e38b9 100644
--- a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
+++ b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
@@ -155,8 +155,8 @@ func.func @broadcast(%arg0 : f32) -> (vector<4xf32>, vector<2xf32>) {
// CHECK: spirv.CompositeExtract %[[ARG]][0 : i32] : vector<2xf32>
// CHECK: spirv.CompositeExtract %[[ARG]][1 : i32] : vector<2xf32>
func.func @extract(%arg0 : vector<2xf32>) -> (vector<1xf32>, f32) {
- %0 = "vector.extract"(%arg0) {position = [0]} : (vector<2xf32>) -> vector<1xf32>
- %1 = "vector.extract"(%arg0) {position = [1]} : (vector<2xf32>) -> f32
+ %0 = "vector.extract"(%arg0) <{position = array<i64: 0>}> : (vector<2xf32>) -> vector<1xf32>
+ %1 = "vector.extract"(%arg0) <{position = array<i64: 1>}> : (vector<2xf32>) -> f32
return %0, %1: vector<1xf32>, f32
}
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index 242b0728e9953b..16fb631af25834 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -133,7 +133,7 @@ func.func @extract_position_rank_overflow(%arg0: vector<4x8x16xf32>) {
func.func @extract_position_rank_overflow_generic(%arg0: vector<4x8x16xf32>) {
// expected-error at +1 {{expected position attribute of rank no greater than vector rank}}
- %1 = "vector.extract" (%arg0) { position = [0, 0, 0, 0] } : (vector<4x8x16xf32>) -> (vector<16xf32>)
+ %1 = "vector.extract" (%arg0) <{position = array<i64: 0, 0, 0, 0>}> : (vector<4x8x16xf32>) -> (vector<16xf32>)
}
// -----
More information about the Mlir-commits mailing list