[Mlir-commits] [mlir] 98f6289 - [mlir][Vector] Add support for Value indices to vector.extract/insert
Diego Caballero
llvmlistbot at llvm.org
Thu Sep 21 17:41:54 PDT 2023
Author: Diego Caballero
Date: 2023-09-22T00:39:32Z
New Revision: 98f6289a34bdaf7bc6cda8768e26e4405fc7726e
URL: https://github.com/llvm/llvm-project/commit/98f6289a34bdaf7bc6cda8768e26e4405fc7726e
DIFF: https://github.com/llvm/llvm-project/commit/98f6289a34bdaf7bc6cda8768e26e4405fc7726e.diff
LOG: [mlir][Vector] Add support for Value indices to vector.extract/insert
`vector.extract/insert` ops only support constant indices. This PR
extends them so that arbitrary `index` values can also be used as positions,
alone or mixed with constant indices.
This work is part of the RFC: https://discourse.llvm.org/t/rfc-psa-remove-vector-extractelement-and-vector-insertelement-ops-in-favor-of-vector-extract-and-vector-insert-ops
Differential Revision: https://reviews.llvm.org/D155034
Added:
Modified:
mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp
mlir/lib/Dialect/NVGPU/Transforms/CreateAsyncGroups.cpp
mlir/lib/Dialect/Vector/IR/VectorOps.cpp
mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp
mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
mlir/test/Dialect/Vector/invalid.mlir
mlir/test/Dialect/Vector/ops.mlir
mlir/test/Dialect/Vector/vector-transforms.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
index fcf7eb4a616b073..fc0c80036ff79ad 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
@@ -131,6 +131,24 @@ inline bool isReductionIterator(Attribute attr) {
return cast<IteratorTypeAttr>(attr).getValue() == IteratorType::reduction;
}
+/// Returns the integer numbers in `values`. `values` are expected to be
+/// constant operations.
+SmallVector<int64_t> getAsIntegers(ArrayRef<Value> values);
+
+/// Returns the integer numbers in `foldResults`. `foldResults` are expected to
+/// be constant operations.
+SmallVector<int64_t> getAsIntegers(ArrayRef<OpFoldResult> foldResults);
+
+/// Convert `foldResults` into Values. Integer attributes are converted to
+/// constant op.
+SmallVector<Value> getAsValues(OpBuilder &builder, Location loc,
+ ArrayRef<OpFoldResult> foldResults);
+
+/// Returns the constant index ops in `values`. `values` are expected to be
+/// constant operations.
+SmallVector<arith::ConstantIndexOp>
+getAsConstantIndexOps(ArrayRef<Value> values);
+
//===----------------------------------------------------------------------===//
// Vector Masking Utilities
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 701eefcc1e7da6a..ea96f2660126870 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -523,9 +523,7 @@ def Vector_ExtractOp :
Vector_Op<"extract", [Pure,
PredOpTrait<"operand and result have same element type",
TCresVTEtIsSameAsOpBase<0, 0>>,
- InferTypeOpAdaptorWithIsCompatible]>,
- Arguments<(ins AnyVectorOfAnyRank:$vector, DenseI64ArrayAttr:$position)>,
- Results<(outs AnyType)> {
+ InferTypeOpAdaptorWithIsCompatible]> {
let summary = "extract operation";
let description = [{
Takes an n-D vector and a k-D position and extracts the (n-k)-D vector at
@@ -535,21 +533,55 @@ def Vector_ExtractOp :
```mlir
%1 = vector.extract %0[3]: vector<4x8x16xf32>
- %2 = vector.extract %0[3, 3, 3]: vector<4x8x16xf32>
+ %2 = vector.extract %0[2, 1, 3]: vector<4x8x16xf32>
%3 = vector.extract %1[]: vector<f32>
+ %4 = vector.extract %0[%a, %b, %c]: vector<4x8x16xf32>
+ %5 = vector.extract %0[2, %b]: vector<4x8x16xf32>
```
}];
+
+ let arguments = (ins
+ AnyVectorOfAnyRank:$vector,
+ Variadic<Index>:$dynamic_position,
+ DenseI64ArrayAttr:$static_position
+ );
+ let results = (outs AnyType:$result);
+
let builders = [
- // Convenience builder which assumes the values in `position` are defined by
- // ConstantIndexOp.
- OpBuilder<(ins "Value":$source, "ValueRange":$position)>
+ OpBuilder<(ins "Value":$source, "int64_t":$position)>,
+ OpBuilder<(ins "Value":$source, "OpFoldResult":$position)>,
+ OpBuilder<(ins "Value":$source, "ArrayRef<int64_t>":$position)>,
+ OpBuilder<(ins "Value":$source, "ArrayRef<OpFoldResult>":$position)>,
];
+
let extraClassDeclaration = [{
VectorType getSourceVectorType() {
return ::llvm::cast<VectorType>(getVector().getType());
}
+
+ /// Return a vector with all the static and dynamic position indices.
+ SmallVector<OpFoldResult> getMixedPosition() {
+ OpBuilder builder(getContext());
+ return getMixedValues(getStaticPosition(), getDynamicPosition(), builder);
+ }
+
+ unsigned getNumIndices() {
+ return getStaticPosition().size();
+ }
+
+ bool hasDynamicPosition() {
+ auto dynPos = getDynamicPosition();
+ return std::any_of(dynPos.begin(), dynPos.end(),
+ [](Value operand) { return operand != nullptr; });
+ }
}];
- let assemblyFormat = "$vector `` $position attr-dict `:` type($vector)";
+
+ let assemblyFormat = [{
+ $vector ``
+ custom<DynamicIndexList>($dynamic_position, $static_position)
+ attr-dict `:` type($vector)
+ }];
+
let hasCanonicalizer = 1;
let hasFolder = 1;
let hasVerifier = 1;
@@ -638,9 +670,7 @@ def Vector_InsertOp :
Vector_Op<"insert", [Pure,
PredOpTrait<"source operand and result have same element type",
TCresVTEtIsSameAsOpBase<0, 0>>,
- AllTypesMatch<["dest", "res"]>]>,
- Arguments<(ins AnyType:$source, AnyVectorOfAnyRank:$dest, DenseI64ArrayAttr:$position)>,
- Results<(outs AnyVectorOfAnyRank:$res)> {
+ AllTypesMatch<["dest", "result"]>]> {
let summary = "insert operation";
let description = [{
Takes an n-D source vector, an (n+k)-D destination vector and a k-D position
@@ -651,24 +681,53 @@ def Vector_InsertOp :
```mlir
%2 = vector.insert %0, %1[3] : vector<8x16xf32> into vector<4x8x16xf32>
- %5 = vector.insert %3, %4[3, 3, 3] : f32 into vector<4x8x16xf32>
+ %5 = vector.insert %3, %4[2, 1, 3] : f32 into vector<4x8x16xf32>
%8 = vector.insert %6, %7[] : f32 into vector<f32>
- %11 = vector.insert %9, %10[3, 3, 3] : vector<f32> into vector<4x8x16xf32>
+ %11 = vector.insert %9, %10[%a, %b, %c] : vector<f32> into vector<4x8x16xf32>
+ %12 = vector.insert %4, %10[2, %b] : vector<16xf32> into vector<4x8x16xf32>
```
}];
- let assemblyFormat = [{
- $source `,` $dest $position attr-dict `:` type($source) `into` type($dest)
- }];
+
+ let arguments = (ins
+ AnyType:$source,
+ AnyVectorOfAnyRank:$dest,
+ Variadic<Index>:$dynamic_position,
+ DenseI64ArrayAttr:$static_position
+ );
+ let results = (outs AnyVectorOfAnyRank:$result);
let builders = [
- // Convenience builder which assumes all values are constant indices.
- OpBuilder<(ins "Value":$source, "Value":$dest, "ValueRange":$position)>
+ OpBuilder<(ins "Value":$source, "Value":$dest, "int64_t":$position)>,
+ OpBuilder<(ins "Value":$source, "Value":$dest, "OpFoldResult":$position)>,
+ OpBuilder<(ins "Value":$source, "Value":$dest, "ArrayRef<int64_t>":$position)>,
+ OpBuilder<(ins "Value":$source, "Value":$dest, "ArrayRef<OpFoldResult>":$position)>,
];
+
let extraClassDeclaration = [{
Type getSourceType() { return getSource().getType(); }
VectorType getDestVectorType() {
return ::llvm::cast<VectorType>(getDest().getType());
}
+
+ /// Return a vector with all the static and dynamic position indices.
+ SmallVector<OpFoldResult> getMixedPosition() {
+ OpBuilder builder(getContext());
+ return getMixedValues(getStaticPosition(), getDynamicPosition(), builder);
+ }
+
+ unsigned getNumIndices() {
+ return getStaticPosition().size();
+ }
+
+ bool hasDynamicPosition() {
+ return llvm::any_of(getDynamicPosition(),
+ [](Value operand) { return operand != nullptr; });
+ }
+ }];
+
+ let assemblyFormat = [{
+ $source `,` $dest custom<DynamicIndexList>($dynamic_position, $static_position)
+ attr-dict `:` type($source) `into` type($dest)
}];
let hasCanonicalizer = 1;
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 2c08257fc3089b6..3f77c5b5f24e9b5 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -126,6 +126,18 @@ static Value castDataPtr(ConversionPatternRewriter &rewriter, Location loc,
return rewriter.create<LLVM::BitcastOp>(loc, pType, ptr);
}
+/// Convert `foldResult` into a Value. Integer attribute is converted to
+/// an LLVM constant op.
+static Value getAsLLVMValue(OpBuilder &builder, Location loc,
+ OpFoldResult foldResult) {
+ if (auto attr = foldResult.dyn_cast<Attribute>()) {
+ auto intAttr = cast<IntegerAttr>(attr);
+ return builder.create<LLVM::ConstantOp>(loc, intAttr).getResult();
+ }
+
+ return foldResult.get<Value>();
+}
+
namespace {
/// Trivial Vector to LLVM conversions
@@ -1079,41 +1091,53 @@ class VectorExtractOpConversion
auto loc = extractOp->getLoc();
auto resultType = extractOp.getResult().getType();
auto llvmResultType = typeConverter->convertType(resultType);
- ArrayRef<int64_t> positionArray = extractOp.getPosition();
-
// Bail if result type cannot be lowered.
if (!llvmResultType)
return failure();
+ SmallVector<OpFoldResult> positionVec;
+ for (auto [idx, pos] : llvm::enumerate(extractOp.getMixedPosition())) {
+ if (pos.is<Value>())
+ // Make sure we use the value that has been already converted to LLVM.
+ positionVec.push_back(adaptor.getDynamicPosition()[idx]);
+ else
+ positionVec.push_back(pos);
+ }
+
// Extract entire vector. Should be handled by folder, but just to be safe.
- if (positionArray.empty()) {
+ ArrayRef<OpFoldResult> position(positionVec);
+ if (position.empty()) {
rewriter.replaceOp(extractOp, adaptor.getVector());
return success();
}
// One-shot extraction of vector from array (only requires extractvalue).
if (isa<VectorType>(resultType)) {
+ if (extractOp.hasDynamicPosition())
+ return failure();
+
Value extracted = rewriter.create<LLVM::ExtractValueOp>(
- loc, adaptor.getVector(), positionArray);
+ loc, adaptor.getVector(), getAsIntegers(position));
rewriter.replaceOp(extractOp, extracted);
return success();
}
// Potential extraction of 1-D vector from array.
Value extracted = adaptor.getVector();
- if (positionArray.size() > 1) {
- extracted = rewriter.create<LLVM::ExtractValueOp>(
- loc, extracted, positionArray.drop_back());
- }
+ if (position.size() > 1) {
+ if (extractOp.hasDynamicPosition())
+ return failure();
- // Remaining extraction of element from 1-D LLVM vector
- auto i64Type = IntegerType::get(rewriter.getContext(), 64);
- auto constant =
- rewriter.create<LLVM::ConstantOp>(loc, i64Type, positionArray.back());
- extracted =
- rewriter.create<LLVM::ExtractElementOp>(loc, extracted, constant);
- rewriter.replaceOp(extractOp, extracted);
+ SmallVector<int64_t> nMinusOnePosition =
+ getAsIntegers(position.drop_back());
+ extracted = rewriter.create<LLVM::ExtractValueOp>(loc, extracted,
+ nMinusOnePosition);
+ }
+ Value lastPosition = getAsLLVMValue(rewriter, loc, position.back());
+ // Remaining extraction of element from 1-D LLVM vector.
+ rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(extractOp, extracted,
+ lastPosition);
return success();
}
};
@@ -1194,23 +1218,34 @@ class VectorInsertOpConversion
auto sourceType = insertOp.getSourceType();
auto destVectorType = insertOp.getDestVectorType();
auto llvmResultType = typeConverter->convertType(destVectorType);
- ArrayRef<int64_t> positionArray = insertOp.getPosition();
-
// Bail if result type cannot be lowered.
if (!llvmResultType)
return failure();
+ SmallVector<OpFoldResult> positionVec;
+ for (auto [idx, pos] : llvm::enumerate(insertOp.getMixedPosition())) {
+ if (pos.is<Value>())
+ // Make sure we use the value that has been already converted to LLVM.
+ positionVec.push_back(adaptor.getDynamicPosition()[idx]);
+ else
+ positionVec.push_back(pos);
+ }
+
// Overwrite entire vector with value. Should be handled by folder, but
// just to be safe.
- if (positionArray.empty()) {
+ ArrayRef<OpFoldResult> position(positionVec);
+ if (position.empty()) {
rewriter.replaceOp(insertOp, adaptor.getSource());
return success();
}
// One-shot insertion of a vector into an array (only requires insertvalue).
if (isa<VectorType>(sourceType)) {
+ if (insertOp.hasDynamicPosition())
+ return failure();
+
Value inserted = rewriter.create<LLVM::InsertValueOp>(
- loc, adaptor.getDest(), adaptor.getSource(), positionArray);
+ loc, adaptor.getDest(), adaptor.getSource(), getAsIntegers(position));
rewriter.replaceOp(insertOp, inserted);
return success();
}
@@ -1218,24 +1253,28 @@ class VectorInsertOpConversion
// Potential extraction of 1-D vector from array.
Value extracted = adaptor.getDest();
auto oneDVectorType = destVectorType;
- if (positionArray.size() > 1) {
+ if (position.size() > 1) {
+ if (insertOp.hasDynamicPosition())
+ return failure();
+
oneDVectorType = reducedVectorTypeBack(destVectorType);
extracted = rewriter.create<LLVM::ExtractValueOp>(
- loc, extracted, positionArray.drop_back());
+ loc, extracted, getAsIntegers(position.drop_back()));
}
// Insertion of an element into a 1-D LLVM vector.
- auto i64Type = IntegerType::get(rewriter.getContext(), 64);
- auto constant =
- rewriter.create<LLVM::ConstantOp>(loc, i64Type, positionArray.back());
Value inserted = rewriter.create<LLVM::InsertElementOp>(
loc, typeConverter->convertType(oneDVectorType), extracted,
- adaptor.getSource(), constant);
+ adaptor.getSource(), getAsLLVMValue(rewriter, loc, position.back()));
// Potential insertion of resulting 1-D vector into array.
- if (positionArray.size() > 1) {
+ if (position.size() > 1) {
+ if (insertOp.hasDynamicPosition())
+ return failure();
+
inserted = rewriter.create<LLVM::InsertValueOp>(
- loc, adaptor.getDest(), inserted, positionArray.drop_back());
+ loc, adaptor.getDest(), inserted,
+ getAsIntegers(position.drop_back()));
}
rewriter.replaceOp(insertOp, inserted);
diff --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
index 1aeed4594f94505..f8fd89c542c0699 100644
--- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
+++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
@@ -1063,10 +1063,11 @@ struct UnrollTransferReadConversion
/// If the result of the TransferReadOp has exactly one user, which is a
/// vector::InsertOp, return that operation's indices.
void getInsertionIndices(TransferReadOp xferOp,
- SmallVector<int64_t, 8> &indices) const {
- if (auto insertOp = getInsertOp(xferOp))
- indices.assign(insertOp.getPosition().begin(),
- insertOp.getPosition().end());
+ SmallVectorImpl<OpFoldResult> &indices) const {
+ if (auto insertOp = getInsertOp(xferOp)) {
+ auto pos = insertOp.getMixedPosition();
+ indices.append(pos.begin(), pos.end());
+ }
}
/// Rewrite the op: Unpack one dimension. Can handle masks, out-of-bounds
@@ -1110,9 +1111,9 @@ struct UnrollTransferReadConversion
getXferIndices(b, xferOp, iv, xferIndices);
// Indices for the new vector.insert op.
- SmallVector<int64_t, 8> insertionIndices;
+ SmallVector<OpFoldResult, 8> insertionIndices;
getInsertionIndices(xferOp, insertionIndices);
- insertionIndices.push_back(i);
+ insertionIndices.push_back(rewriter.getIndexAttr(i));
auto inBoundsAttr = dropFirstElem(b, xferOp.getInBoundsAttr());
auto newXferOp = b.create<vector::TransferReadOp>(
@@ -1195,10 +1196,11 @@ struct UnrollTransferWriteConversion
/// If the input of the given TransferWriteOp is an ExtractOp, return its
/// indices.
void getExtractionIndices(TransferWriteOp xferOp,
- SmallVector<int64_t, 8> &indices) const {
- if (auto extractOp = getExtractOp(xferOp))
- indices.assign(extractOp.getPosition().begin(),
- extractOp.getPosition().end());
+ SmallVectorImpl<OpFoldResult> &indices) const {
+ if (auto extractOp = getExtractOp(xferOp)) {
+ auto pos = extractOp.getMixedPosition();
+ indices.append(pos.begin(), pos.end());
+ }
}
/// Rewrite the op: Unpack one dimension. Can handle masks, out-of-bounds
@@ -1235,9 +1237,9 @@ struct UnrollTransferWriteConversion
getXferIndices(b, xferOp, iv, xferIndices);
// Indices for the new vector.extract op.
- SmallVector<int64_t, 8> extractionIndices;
+ SmallVector<OpFoldResult, 8> extractionIndices;
getExtractionIndices(xferOp, extractionIndices);
- extractionIndices.push_back(i);
+ extractionIndices.push_back(b.getI64IntegerAttr(i));
auto extracted =
b.create<vector::ExtractOp>(loc, vec, extractionIndices);
diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index a8c68abc8bcbf5c..9b29179f3687165 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -35,11 +35,25 @@
using namespace mlir;
-/// Gets the first integer value from `attr`, assuming it is an integer array
-/// attribute.
+/// Returns the integer value from the first valid input element, assuming Value
+/// inputs are defined by a constant index ops and Attribute inputs are integer
+/// attributes.
+static uint64_t getFirstIntValue(ValueRange values) {
+ return values[0].getDefiningOp<arith::ConstantIndexOp>().value();
+}
+static uint64_t getFirstIntValue(ArrayRef<Attribute> attr) {
+ return cast<IntegerAttr>(attr[0]).getInt();
+}
static uint64_t getFirstIntValue(ArrayAttr attr) {
return (*attr.getAsValueRange<IntegerAttr>().begin()).getZExtValue();
}
+static uint64_t getFirstIntValue(ArrayRef<OpFoldResult> foldResults) {
+ auto attr = foldResults[0].dyn_cast<Attribute>();
+ if (attr)
+ return getFirstIntValue(attr);
+
+ return getFirstIntValue(ValueRange{foldResults[0].get<Value>()});
+}
/// Returns the number of bits for the given scalar/vector type.
static int getNumBits(Type type) {
@@ -141,9 +155,7 @@ struct VectorExtractOpConvert final
LogicalResult
matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- // Only support extracting a scalar value now.
- VectorType resultVectorType = dyn_cast<VectorType>(extractOp.getType());
- if (resultVectorType && resultVectorType.getNumElements() > 1)
+ if (extractOp.hasDynamicPosition())
return failure();
Type dstType = getTypeConverter()->convertType(extractOp.getType());
@@ -155,7 +167,7 @@ struct VectorExtractOpConvert final
return success();
}
- int32_t id = extractOp.getPosition()[0];
+ int32_t id = getFirstIntValue(extractOp.getMixedPosition());
rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(
extractOp, adaptor.getVector(), id);
return success();
@@ -235,7 +247,7 @@ struct VectorInsertOpConvert final
return success();
}
- int32_t id = insertOp.getPosition()[0];
+ int32_t id = getFirstIntValue(insertOp.getMixedPosition());
rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(
insertOp, adaptor.getSource(), adaptor.getDest(), id);
return success();
diff --git a/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp b/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp
index 906c13a6579f158..1084fbc890053b9 100644
--- a/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp
@@ -516,7 +516,7 @@ struct ExtensionOverExtract final : NarrowingPattern<vector::ExtractOp> {
return failure();
Value newExtract = rewriter.create<vector::ExtractOp>(
- op.getLoc(), ext->getIn(), op.getPosition());
+ op.getLoc(), ext->getIn(), op.getMixedPosition());
ext->recreateAndReplace(rewriter, op, newExtract);
return success();
}
@@ -645,8 +645,9 @@ struct ExtensionOverInsert final
vector::InsertOp origInsert,
Value narrowValue,
Value narrowDest) const override {
- return rewriter.create<vector::InsertOp>(
- origInsert.getLoc(), narrowValue, narrowDest, origInsert.getPosition());
+ return rewriter.create<vector::InsertOp>(origInsert.getLoc(), narrowValue,
+ narrowDest,
+ origInsert.getMixedPosition());
}
};
diff --git a/mlir/lib/Dialect/NVGPU/Transforms/CreateAsyncGroups.cpp b/mlir/lib/Dialect/NVGPU/Transforms/CreateAsyncGroups.cpp
index ad2180d501148f1..f63825cdc8f6179 100644
--- a/mlir/lib/Dialect/NVGPU/Transforms/CreateAsyncGroups.cpp
+++ b/mlir/lib/Dialect/NVGPU/Transforms/CreateAsyncGroups.cpp
@@ -74,7 +74,7 @@ static FailureOr<TransferMask> getMaskOp(Operation *loadOp) {
if (auto maskOp =
extractOp.getVector().getDefiningOp<vector::CreateMaskOp>())
return TransferMask{maskOp,
- SmallVector<int64_t>(extractOp.getPosition())};
+ SmallVector<int64_t>(extractOp.getStaticPosition())};
// All other cases: not supported.
return failure();
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 7b9c5f9b879e8c4..85d21938d0ab711 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -223,6 +223,48 @@ static LogicalResult incSlicePosition(MutableArrayRef<int64_t> position,
return failure();
}
+/// Returns the integer numbers in `values`. `values` are expected to be
+/// constant operations.
+SmallVector<int64_t> vector::getAsIntegers(ArrayRef<Value> values) {
+ SmallVector<int64_t> ints;
+ llvm::transform(values, std::back_inserter(ints), [](Value value) {
+ auto constOp = value.getDefiningOp<arith::ConstantIndexOp>();
+ assert(constOp && "Unexpected non-constant index");
+ return constOp.value();
+ });
+ return ints;
+}
+
+/// Returns the integer numbers in `foldResults`. `foldResults` are expected to
+/// be constant operations.
+SmallVector<int64_t> vector::getAsIntegers(ArrayRef<OpFoldResult> foldResults) {
+ SmallVector<int64_t> ints;
+ llvm::transform(
+ foldResults, std::back_inserter(ints), [](OpFoldResult foldResult) {
+ assert(foldResult.is<Attribute>() && "Unexpected non-constant index");
+ return cast<IntegerAttr>(foldResult.get<Attribute>()).getInt();
+ });
+ return ints;
+}
+
+/// Convert `foldResults` into Values. Integer attributes are converted to
+/// constant op.
+SmallVector<Value> vector::getAsValues(OpBuilder &builder, Location loc,
+ ArrayRef<OpFoldResult> foldResults) {
+ SmallVector<Value> values;
+ llvm::transform(foldResults, std::back_inserter(values),
+ [&](OpFoldResult foldResult) {
+ if (auto attr = foldResult.dyn_cast<Attribute>())
+ return builder
+ .create<arith::ConstantIndexOp>(
+ loc, cast<IntegerAttr>(attr).getInt())
+ .getResult();
+
+ return foldResult.get<Value>();
+ });
+ return values;
+}
+
//===----------------------------------------------------------------------===//
// CombiningKindAttr
//===----------------------------------------------------------------------===//
@@ -389,12 +431,11 @@ struct ElideUnitDimsInMultiDimReduction
} else {
// This means we are reducing all the dimensions, and all reduction
// dimensions are of size 1. So a simple extraction would do.
- SmallVector<int64_t> zeroAttr(shape.size(), 0);
+ SmallVector<int64_t> zeroIdx(shape.size(), 0);
if (mask)
- mask = rewriter.create<vector::ExtractOp>(loc, rewriter.getI1Type(),
- mask, zeroAttr);
- cast = rewriter.create<vector::ExtractOp>(
- loc, reductionOp.getDestType(), reductionOp.getSource(), zeroAttr);
+ mask = rewriter.create<vector::ExtractOp>(loc, mask, zeroIdx);
+ cast = rewriter.create<vector::ExtractOp>(loc, reductionOp.getSource(),
+ zeroIdx);
}
Value result = vector::makeArithReduction(
@@ -574,11 +615,9 @@ struct ElideSingleElementReduction : public OpRewritePattern<ReductionOp> {
mask = rewriter.create<ExtractElementOp>(loc, mask);
result = rewriter.create<ExtractElementOp>(loc, reductionOp.getVector());
} else {
- if (mask) {
- mask = rewriter.create<ExtractOp>(loc, rewriter.getI1Type(), mask, 0);
- }
- result = rewriter.create<ExtractOp>(loc, reductionOp.getType(),
- reductionOp.getVector(), 0);
+ if (mask)
+ mask = rewriter.create<ExtractOp>(loc, mask, 0);
+ result = rewriter.create<ExtractOp>(loc, reductionOp.getVector(), 0);
}
if (Value acc = reductionOp.getAcc())
@@ -1148,12 +1187,29 @@ OpFoldResult vector::ExtractElementOp::fold(FoldAdaptor adaptor) {
// ExtractOp
//===----------------------------------------------------------------------===//
-// Convenience builder which assumes the values are constant indices.
void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
- Value source, ValueRange position) {
- SmallVector<int64_t> positionConstants = llvm::to_vector(llvm::map_range(
- position, [](Value pos) { return getConstantIntValue(pos).value(); }));
- build(builder, result, source, positionConstants);
+ Value source, int64_t position) {
+ build(builder, result, source, ArrayRef<int64_t>{position});
+}
+
+void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
+ Value source, OpFoldResult position) {
+ build(builder, result, source, ArrayRef<OpFoldResult>{position});
+}
+
+void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
+ Value source, ArrayRef<int64_t> position) {
+ build(builder, result, source, /*dynamic_position=*/ArrayRef<Value>(),
+ builder.getDenseI64ArrayAttr(position));
+}
+
+void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
+ Value source, ArrayRef<OpFoldResult> position) {
+ SmallVector<int64_t> staticPos;
+ SmallVector<Value> dynamicPos;
+ dispatchIndexOpFoldResults(position, dynamicPos, staticPos);
+ build(builder, result, source, dynamicPos,
+ builder.getDenseI64ArrayAttr(staticPos));
}
LogicalResult
@@ -1161,12 +1217,12 @@ ExtractOp::inferReturnTypes(MLIRContext *, std::optional<Location>,
ExtractOp::Adaptor adaptor,
SmallVectorImpl<Type> &inferredReturnTypes) {
auto vectorType = llvm::cast<VectorType>(adaptor.getVector().getType());
- if (static_cast<int64_t>(adaptor.getPosition().size()) ==
+ if (static_cast<int64_t>(adaptor.getStaticPosition().size()) ==
vectorType.getRank()) {
inferredReturnTypes.push_back(vectorType.getElementType());
} else {
- auto n =
- std::min<size_t>(adaptor.getPosition().size(), vectorType.getRank());
+ auto n = std::min<size_t>(adaptor.getStaticPosition().size(),
+ vectorType.getRank());
inferredReturnTypes.push_back(VectorType::get(
vectorType.getShape().drop_front(n), vectorType.getElementType(),
vectorType.getScalableDims().drop_front(n)));
@@ -1188,17 +1244,20 @@ bool ExtractOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
}
LogicalResult vector::ExtractOp::verify() {
- ArrayRef<int64_t> position = getPosition();
+ auto position = getMixedPosition();
if (position.size() > static_cast<unsigned>(getSourceVectorType().getRank()))
return emitOpError(
"expected position attribute of rank no greater than vector rank");
- for (const auto &en : llvm::enumerate(position)) {
- if (en.value() < 0 ||
- en.value() >= getSourceVectorType().getDimSize(en.index()))
- return emitOpError("expected position attribute #")
- << (en.index() + 1)
- << " to be a non-negative integer smaller than the corresponding "
- "vector dimension";
+ for (auto [idx, pos] : llvm::enumerate(position)) {
+ if (pos.is<Attribute>()) {
+ int64_t constIdx = cast<IntegerAttr>(pos.get<Attribute>()).getInt();
+ if (constIdx < 0 || constIdx >= getSourceVectorType().getDimSize(idx)) {
+ return emitOpError("expected position attribute #")
+ << (idx + 1)
+ << " to be a non-negative integer smaller than the "
+ "corresponding vector dimension";
+ }
+ }
}
return success();
}
@@ -1216,20 +1275,24 @@ static LogicalResult foldExtractOpFromExtractChain(ExtractOp extractOp) {
if (!extractOp.getVector().getDefiningOp<ExtractOp>())
return failure();
- SmallVector<int64_t, 4> globalPosition;
+ // TODO: Canonicalization for dynamic position not implemented yet.
+ if (extractOp.hasDynamicPosition())
+ return failure();
+
+ SmallVector<int64_t> globalPosition;
ExtractOp currentOp = extractOp;
- ArrayRef<int64_t> extrPos = currentOp.getPosition();
+ ArrayRef<int64_t> extrPos = currentOp.getStaticPosition();
globalPosition.append(extrPos.rbegin(), extrPos.rend());
while (ExtractOp nextOp = currentOp.getVector().getDefiningOp<ExtractOp>()) {
currentOp = nextOp;
- ArrayRef<int64_t> extrPos = currentOp.getPosition();
+ ArrayRef<int64_t> extrPos = currentOp.getStaticPosition();
globalPosition.append(extrPos.rbegin(), extrPos.rend());
}
- extractOp.setOperand(currentOp.getVector());
+ extractOp.setOperand(0, currentOp.getVector());
// OpBuilder is only used as a helper to build an I64ArrayAttr.
OpBuilder b(extractOp.getContext());
std::reverse(globalPosition.begin(), globalPosition.end());
- extractOp.setPosition(globalPosition);
+ extractOp.setStaticPosition(globalPosition);
return success();
}
@@ -1335,19 +1398,23 @@ class ExtractFromInsertTransposeChainState {
ExtractFromInsertTransposeChainState::ExtractFromInsertTransposeChainState(
ExtractOp e)
: extractOp(e), vectorRank(extractOp.getSourceVectorType().getRank()),
- extractedRank(extractOp.getPosition().size()) {
- assert(vectorRank >= extractedRank && "extracted pos overflow");
+ extractedRank(extractOp.getNumIndices()) {
+ assert(vectorRank >= extractedRank && "Extracted position overflow");
sentinels.reserve(vectorRank - extractedRank);
for (int64_t i = 0, e = vectorRank - extractedRank; i < e; ++i)
sentinels.push_back(-(i + 1));
- extractPosition.assign(extractOp.getPosition().begin(),
- extractOp.getPosition().end());
+ extractPosition.assign(extractOp.getStaticPosition().begin(),
+ extractOp.getStaticPosition().end());
llvm::append_range(extractPosition, sentinels);
}
// Case 1. If we hit a transpose, just compose the map and iterate.
// Invariant: insert + transpose do not change rank, we can always compose.
LogicalResult ExtractFromInsertTransposeChainState::handleTransposeOp() {
+ // TODO: Canonicalization for dynamic position not implemented yet.
+ if (extractOp.hasDynamicPosition())
+ return failure();
+
if (!nextTransposeOp)
return failure();
auto permutation = extractVector<unsigned>(nextTransposeOp.getTransp());
@@ -1361,7 +1428,11 @@ LogicalResult ExtractFromInsertTransposeChainState::handleTransposeOp() {
LogicalResult
ExtractFromInsertTransposeChainState::handleInsertOpWithMatchingPos(
Value &res) {
- ArrayRef<int64_t> insertedPos = nextInsertOp.getPosition();
+ // TODO: Canonicalization for dynamic position not implemented yet.
+ if (extractOp.hasDynamicPosition() || nextInsertOp.hasDynamicPosition())
+ return failure();
+
+ ArrayRef<int64_t> insertedPos = nextInsertOp.getStaticPosition();
if (insertedPos != llvm::ArrayRef(extractPosition).take_front(extractedRank))
return failure();
// Case 2.a. early-exit fold.
@@ -1375,7 +1446,11 @@ ExtractFromInsertTransposeChainState::handleInsertOpWithMatchingPos(
/// This method updates the internal state.
LogicalResult
ExtractFromInsertTransposeChainState::handleInsertOpWithPrefixPos(Value &res) {
- ArrayRef<int64_t> insertedPos = nextInsertOp.getPosition();
+ // TODO: Canonicalization for dynamic position not implemented yet.
+ if (extractOp.hasDynamicPosition() || nextInsertOp.hasDynamicPosition())
+ return failure();
+
+ ArrayRef<int64_t> insertedPos = nextInsertOp.getStaticPosition();
if (!isContainedWithin(insertedPos, extractPosition))
return failure();
// Set leading dims to zero.
@@ -1395,19 +1470,29 @@ ExtractFromInsertTransposeChainState::handleInsertOpWithPrefixPos(Value &res) {
/// internal tranposition in the result).
Value ExtractFromInsertTransposeChainState::tryToFoldExtractOpInPlace(
Value source) {
+ // TODO: Canonicalization for dynamic position not implemented yet.
+ if (extractOp.hasDynamicPosition())
+ return Value();
+
// If we can't fold (either internal transposition, or nothing to fold), bail.
bool nothingToFold = (source == extractOp.getVector());
if (nothingToFold || !canFold())
return Value();
+
// Otherwise, fold by updating the op inplace and return its result.
OpBuilder b(extractOp.getContext());
- extractOp.setPosition(ArrayRef(extractPosition).take_front(extractedRank));
+ extractOp.setStaticPosition(
+ ArrayRef(extractPosition).take_front(extractedRank));
extractOp.getVectorMutable().assign(source);
return extractOp.getResult();
}
/// Iterate over producing insert and transpose ops until we find a fold.
Value ExtractFromInsertTransposeChainState::fold() {
+ // TODO: Canonicalization for dynamic position not implemented yet.
+ if (extractOp.hasDynamicPosition())
+ return Value();
+
Value valueToExtractFrom = extractOp.getVector();
updateStateForNextIteration(valueToExtractFrom);
while (nextInsertOp || nextTransposeOp) {
@@ -1431,7 +1516,7 @@ Value ExtractFromInsertTransposeChainState::fold() {
// Case 4: extractPositionRef intersects insertedPosRef on non-sentinel
// values. This is a more difficult case and we bail.
- ArrayRef<int64_t> insertedPos = nextInsertOp.getPosition();
+ ArrayRef<int64_t> insertedPos = nextInsertOp.getStaticPosition();
if (isContainedWithin(extractPosition, insertedPos) ||
intersectsWhereNonNegative(extractPosition, insertedPos))
return Value();
@@ -1457,6 +1542,10 @@ static bool hasZeroDimVectors(Operation *op) {
/// Fold extractOp with scalar result coming from BroadcastOp or SplatOp.
static Value foldExtractFromBroadcast(ExtractOp extractOp) {
+ // TODO: Canonicalization for dynamic position not implemented yet.
+ if (extractOp.hasDynamicPosition())
+ return Value();
+
Operation *defOp = extractOp.getVector().getDefiningOp();
if (!defOp || !isa<vector::BroadcastOp, SplatOp>(defOp))
return Value();
@@ -1497,7 +1586,7 @@ static Value foldExtractFromBroadcast(ExtractOp extractOp) {
// extract position to `0` when extracting from the source operand.
llvm::SetVector<int64_t> broadcastedUnitDims =
broadcastOp.computeBroadcastedUnitDims();
- SmallVector<int64_t> extractPos(extractOp.getPosition());
+ SmallVector<int64_t> extractPos(extractOp.getStaticPosition());
int64_t broadcastRankDiff = broadcastDstRank - broadcastSrcRank;
for (int64_t i = broadcastRankDiff, e = extractPos.size(); i < e; ++i)
if (broadcastedUnitDims.contains(i))
@@ -1509,13 +1598,17 @@ static Value foldExtractFromBroadcast(ExtractOp extractOp) {
std::next(extractPos.begin(), extractPos.size() - rankDiff));
// OpBuilder is only used as a helper to build an I64ArrayAttr.
OpBuilder b(extractOp.getContext());
- extractOp.setOperand(source);
- extractOp.setPosition(extractPos);
+ extractOp.setOperand(0, source);
+ extractOp.setStaticPosition(extractPos);
return extractOp.getResult();
}
// Fold extractOp with source coming from ShapeCast op.
static Value foldExtractFromShapeCast(ExtractOp extractOp) {
+ // TODO: Canonicalization for dynamic position not implemented yet.
+ if (extractOp.hasDynamicPosition())
+ return Value();
+
auto shapeCastOp = extractOp.getVector().getDefiningOp<vector::ShapeCastOp>();
if (!shapeCastOp)
return Value();
@@ -1549,7 +1642,7 @@ static Value foldExtractFromShapeCast(ExtractOp extractOp) {
}
// Extract the strides associated with the extract op vector source. Then use
// this to calculate a linearized position for the extract.
- SmallVector<int64_t> extractedPos(extractOp.getPosition());
+ SmallVector<int64_t> extractedPos(extractOp.getStaticPosition());
std::reverse(extractedPos.begin(), extractedPos.end());
SmallVector<int64_t, 4> strides;
int64_t stride = 1;
@@ -1575,13 +1668,17 @@ static Value foldExtractFromShapeCast(ExtractOp extractOp) {
SmallVector<int64_t, 4> newPosition = delinearize(position, newStrides);
// OpBuilder is only used as a helper to build an I64ArrayAttr.
OpBuilder b(extractOp.getContext());
- extractOp.setPosition(newPosition);
- extractOp.setOperand(shapeCastOp.getSource());
+ extractOp.setStaticPosition(newPosition);
+ extractOp.setOperand(0, shapeCastOp.getSource());
return extractOp.getResult();
}
/// Fold an ExtractOp from ExtractStridedSliceOp.
static Value foldExtractFromExtractStrided(ExtractOp extractOp) {
+ // TODO: Canonicalization for dynamic position not implemented yet.
+ if (extractOp.hasDynamicPosition())
+ return Value();
+
auto extractStridedSliceOp =
extractOp.getVector().getDefiningOp<vector::ExtractStridedSliceOp>();
if (!extractStridedSliceOp)
@@ -1615,19 +1712,25 @@ static Value foldExtractFromExtractStrided(ExtractOp extractOp) {
if (destinationRank > extractStridedSliceOp.getSourceVectorType().getRank() -
sliceOffsets.size())
return Value();
- SmallVector<int64_t> extractedPos(extractOp.getPosition());
+
+ SmallVector<int64_t> extractedPos(extractOp.getStaticPosition());
assert(extractedPos.size() >= sliceOffsets.size());
for (size_t i = 0, e = sliceOffsets.size(); i < e; i++)
extractedPos[i] = extractedPos[i] + sliceOffsets[i];
extractOp.getVectorMutable().assign(extractStridedSliceOp.getVector());
+
// OpBuilder is only used as a helper to build an I64ArrayAttr.
OpBuilder b(extractOp.getContext());
- extractOp.setPosition(extractedPos);
+ extractOp.setStaticPosition(extractedPos);
return extractOp.getResult();
}
/// Fold extract_op fed from a chain of insertStridedSlice ops.
static Value foldExtractStridedOpFromInsertChain(ExtractOp extractOp) {
+ // TODO: Canonicalization for dynamic position not implemented yet.
+ if (extractOp.hasDynamicPosition())
+ return Value();
+
int64_t destinationRank =
llvm::isa<VectorType>(extractOp.getType())
? llvm::cast<VectorType>(extractOp.getType()).getRank()
@@ -1647,7 +1750,7 @@ static Value foldExtractStridedOpFromInsertChain(ExtractOp extractOp) {
if (destinationRank > insertOp.getSourceVectorType().getRank())
return Value();
auto insertOffsets = extractVector<int64_t>(insertOp.getOffsets());
- ArrayRef<int64_t> extractOffsets = extractOp.getPosition();
+ ArrayRef<int64_t> extractOffsets = extractOp.getStaticPosition();
if (llvm::any_of(insertOp.getStrides(), [](Attribute attr) {
return llvm::cast<IntegerAttr>(attr).getInt() != 1;
@@ -1687,7 +1790,7 @@ static Value foldExtractStridedOpFromInsertChain(ExtractOp extractOp) {
extractOp.getVectorMutable().assign(insertOp.getSource());
// OpBuilder is only used as a helper to build an I64ArrayAttr.
OpBuilder b(extractOp.getContext());
- extractOp.setPosition(offsetDiffs);
+ extractOp.setStaticPosition(offsetDiffs);
return extractOp.getResult();
}
// If the chunk extracted is disjoint from the chunk inserted, keep
@@ -1698,7 +1801,7 @@ static Value foldExtractStridedOpFromInsertChain(ExtractOp extractOp) {
}
OpFoldResult ExtractOp::fold(FoldAdaptor) {
- if (getPosition().empty())
+ if (getNumIndices() == 0)
return getVector();
if (succeeded(foldExtractOpFromExtractChain(*this)))
return getResult();
@@ -1788,6 +1891,10 @@ class ExtractOpNonSplatConstantFolder final
LogicalResult matchAndRewrite(ExtractOp extractOp,
PatternRewriter &rewriter) const override {
+ // TODO: Canonicalization for dynamic position not implemented yet.
+ if (extractOp.hasDynamicPosition())
+ return failure();
+
// Return if 'ExtractOp' operand is not defined by a compatible vector
// ConstantOp.
Value sourceVector = extractOp.getVector();
@@ -1807,7 +1914,7 @@ class ExtractOpNonSplatConstantFolder final
// Calculate the linearized position of the continuous chunk of elements to
// extract.
llvm::SmallVector<int64_t> completePositions(vecTy.getRank(), 0);
- copy(extractOp.getPosition(), completePositions.begin());
+ copy(extractOp.getStaticPosition(), completePositions.begin());
int64_t elemBeginPosition =
linearize(completePositions, computeStrides(vecTy.getShape()));
auto denseValuesBegin = dense.value_begin<TypedAttr>() + elemBeginPosition;
@@ -2322,18 +2429,38 @@ OpFoldResult vector::InsertElementOp::fold(FoldAdaptor adaptor) {
// InsertOp
//===----------------------------------------------------------------------===//
-// Convenience builder which assumes the values are constant indices.
-void InsertOp::build(OpBuilder &builder, OperationState &result, Value source,
- Value dest, ValueRange position) {
- SmallVector<int64_t, 4> positionConstants =
- llvm::to_vector<4>(llvm::map_range(position, [](Value pos) {
- return getConstantIntValue(pos).value();
- }));
- build(builder, result, source, dest, positionConstants);
+void vector::InsertOp::build(OpBuilder &builder, OperationState &result,
+ Value source, Value dest, int64_t position) {
+ build(builder, result, source, dest, ArrayRef<int64_t>{position});
+}
+
+void vector::InsertOp::build(OpBuilder &builder, OperationState &result,
+ Value source, Value dest, OpFoldResult position) {
+ build(builder, result, source, dest, ArrayRef<OpFoldResult>{position});
+}
+
+void vector::InsertOp::build(OpBuilder &builder, OperationState &result,
+ Value source, Value dest,
+ ArrayRef<int64_t> position) {
+ SmallVector<OpFoldResult> posVals;
+ posVals.reserve(position.size());
+ llvm::transform(position, std::back_inserter(posVals),
+ [&](int64_t pos) { return builder.getI64IntegerAttr(pos); });
+ build(builder, result, source, dest, posVals);
+}
+
+void vector::InsertOp::build(OpBuilder &builder, OperationState &result,
+ Value source, Value dest,
+ ArrayRef<OpFoldResult> position) {
+ SmallVector<int64_t> staticPos;
+ SmallVector<Value> dynamicPos;
+ dispatchIndexOpFoldResults(position, dynamicPos, staticPos);
+ build(builder, result, source, dest, dynamicPos,
+ builder.getDenseI64ArrayAttr(staticPos));
}
LogicalResult InsertOp::verify() {
- ArrayRef<int64_t> position = getPosition();
+ SmallVector<OpFoldResult> position = getMixedPosition();
auto destVectorType = getDestVectorType();
if (position.size() > static_cast<unsigned>(destVectorType.getRank()))
return emitOpError(
@@ -2348,13 +2475,17 @@ LogicalResult InsertOp::verify() {
(position.size() != static_cast<unsigned>(destVectorType.getRank())))
return emitOpError(
"expected position attribute rank to match the dest vector rank");
- for (const auto &en : llvm::enumerate(position)) {
- int64_t attr = en.value();
- if (attr < 0 || attr >= destVectorType.getDimSize(en.index()))
- return emitOpError("expected position attribute #")
- << (en.index() + 1)
- << " to be a non-negative integer smaller than the corresponding "
- "dest vector dimension";
+ for (auto [idx, pos] : llvm::enumerate(position)) {
+ if (auto attr = pos.dyn_cast<Attribute>()) {
+ int64_t constIdx = cast<IntegerAttr>(attr).getInt();
+ if (constIdx < 0 || constIdx >= destVectorType.getDimSize(idx)) {
+ return emitOpError("expected position attribute #")
+ << (idx + 1)
+ << " to be a non-negative integer smaller than the "
+ "corresponding "
+ "dest vector dimension";
+ }
+ }
}
return success();
}
@@ -2411,6 +2542,10 @@ class InsertOpConstantFolder final : public OpRewritePattern<InsertOp> {
LogicalResult matchAndRewrite(InsertOp op,
PatternRewriter &rewriter) const override {
+ // TODO: Canonicalization for dynamic position not implemented yet.
+ if (op.hasDynamicPosition())
+ return failure();
+
// Return if 'InsertOp' operand is not defined by a compatible vector
// ConstantOp.
TypedValue<VectorType> destVector = op.getDest();
@@ -2437,7 +2572,7 @@ class InsertOpConstantFolder final : public OpRewritePattern<InsertOp> {
// Calculate the linearized position of the continuous chunk of elements to
// insert.
llvm::SmallVector<int64_t> completePositions(destTy.getRank(), 0);
- copy(op.getPosition(), completePositions.begin());
+ copy(op.getStaticPosition(), completePositions.begin());
int64_t insertBeginPosition =
linearize(completePositions, computeStrides(destTy.getShape()));
@@ -2468,7 +2603,7 @@ void InsertOp::getCanonicalizationPatterns(RewritePatternSet &results,
// value. This happens when the source and destination vectors have identical
// sizes.
OpFoldResult vector::InsertOp::fold(FoldAdaptor adaptor) {
- if (getPosition().empty())
+ if (getNumIndices() == 0)
return getSource();
return {};
}
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
index 64ab0abda26e640..7560db2332cf8d9 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
@@ -89,20 +89,20 @@ static Value reshapeLoad(Location loc, Value val, VectorType type,
PatternRewriter &rewriter) {
if (index == -1)
return val;
- Type lowType = type.getRank() > 1 ? VectorType::Builder(type).dropDim(0)
- : type.getElementType();
+
// At extraction dimension?
if (index == 0)
- return rewriter.create<vector::ExtractOp>(loc, lowType, val, pos);
+ return rewriter.create<vector::ExtractOp>(loc, val, pos);
+
// Unroll leading dimensions.
- VectorType vType = cast<VectorType>(lowType);
+ VectorType vType = VectorType::Builder(type).dropDim(0);
VectorType resType = VectorType::Builder(type).dropDim(index);
Value result = rewriter.create<arith::ConstantOp>(
loc, resType, rewriter.getZeroAttr(resType));
for (int64_t d = 0, e = resType.getDimSize(0); d < e; d++) {
- Value ext = rewriter.create<vector::ExtractOp>(loc, vType, val, d);
+ Value ext = rewriter.create<vector::ExtractOp>(loc, val, d);
Value load = reshapeLoad(loc, ext, vType, index - 1, pos, rewriter);
- result = rewriter.create<vector::InsertOp>(loc, resType, load, result, d);
+ result = rewriter.create<vector::InsertOp>(loc, load, result, d);
}
return result;
}
@@ -117,16 +117,15 @@ static Value reshapeStore(Location loc, Value val, Value result,
return val;
// At insertion dimension?
if (index == 0)
- return rewriter.create<vector::InsertOp>(loc, type, val, result, pos);
+ return rewriter.create<vector::InsertOp>(loc, val, result, pos);
+
// Unroll leading dimensions.
- VectorType lowType = VectorType::Builder(type).dropDim(0);
- Type insType = lowType.getRank() > 1 ? VectorType::Builder(lowType).dropDim(0)
- : lowType.getElementType();
+ VectorType vType = VectorType::Builder(type).dropDim(0);
for (int64_t d = 0, e = type.getDimSize(0); d < e; d++) {
- Value ext = rewriter.create<vector::ExtractOp>(loc, lowType, result, d);
- Value ins = rewriter.create<vector::ExtractOp>(loc, insType, val, d);
- Value sto = reshapeStore(loc, ins, ext, lowType, index - 1, pos, rewriter);
- result = rewriter.create<vector::InsertOp>(loc, type, sto, result, d);
+ Value ext = rewriter.create<vector::ExtractOp>(loc, result, d);
+ Value ins = rewriter.create<vector::ExtractOp>(loc, val, d);
+ Value sto = reshapeStore(loc, ins, ext, vType, index - 1, pos, rewriter);
+ result = rewriter.create<vector::InsertOp>(loc, sto, result, d);
}
return result;
}
@@ -1175,7 +1174,7 @@ class OuterProductOpLowering : public OpRewritePattern<vector::OuterProductOp> {
loc, a, op.getRhs(), r, kind, rewriter, isInt, extrMask);
if (!m.has_value())
return failure();
- result = rewriter.create<vector::InsertOp>(loc, resType, *m, result, d);
+ result = rewriter.create<vector::InsertOp>(loc, *m, result, d);
}
rewriter.replaceOp(rootOp, result);
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp
index 95b5ea011c82569..887d1af7645419f 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp
@@ -79,7 +79,7 @@ class CreateMaskOpLowering : public OpRewritePattern<vector::CreateMaskOp> {
Value val = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt,
bnd, idx);
Value sel = rewriter.create<arith::SelectOp>(loc, val, trueVal, falseVal);
- result = rewriter.create<vector::InsertOp>(loc, dstType, sel, result, d);
+ result = rewriter.create<vector::InsertOp>(loc, sel, result, d);
}
rewriter.replaceOp(op, result);
return success();
@@ -151,8 +151,8 @@ class ConstantMaskOpLowering : public OpRewritePattern<vector::ConstantMaskOp> {
Value result = rewriter.create<arith::ConstantOp>(
loc, dstType, rewriter.getZeroAttr(dstType));
for (int64_t d = 0; d < trueDimSize; d++)
- result =
- rewriter.create<vector::InsertOp>(loc, dstType, trueVal, result, d);
+ result = rewriter.create<vector::InsertOp>(loc, trueVal, result, d);
+
rewriter.replaceOp(op, result);
return success();
}
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index 2a50947e976dffb..f4486ea117a2934 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -1040,13 +1040,17 @@ struct WarpOpExtract : public OpRewritePattern<WarpExecuteOnLane0Op> {
"vector.extract does not support rank 0 sources");
// "vector.extract %v[] : vector<...xf32>" can be canonicalized to %v.
- if (extractOp.getPosition().empty())
+ if (extractOp.getNumIndices() == 0)
return failure();
// Rewrite vector.extract with 1d source to vector.extractelement.
if (extractSrcType.getRank() == 1) {
- assert(extractOp.getPosition().size() == 1 && "expected 1 index");
- int64_t pos = extractOp.getPosition()[0];
+ if (extractOp.hasDynamicPosition())
+ // TODO: Dynamic position not supported yet.
+ return failure();
+
+ assert(extractOp.getNumIndices() == 1 && "expected 1 index");
+ int64_t pos = extractOp.getStaticPosition()[0];
rewriter.setInsertionPoint(extractOp);
rewriter.replaceOpWithNewOp<vector::ExtractElementOp>(
extractOp, extractOp.getVector(),
@@ -1070,7 +1074,7 @@ struct WarpOpExtract : public OpRewritePattern<WarpExecuteOnLane0Op> {
Value distributedVec = newWarpOp->getResult(newRetIndices[0]);
// Extract from distributed vector.
Value newExtract = rewriter.create<vector::ExtractOp>(
- loc, distributedVec, extractOp.getPosition());
+ loc, distributedVec, extractOp.getMixedPosition());
rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
newExtract);
return success();
@@ -1096,7 +1100,7 @@ struct WarpOpExtract : public OpRewritePattern<WarpExecuteOnLane0Op> {
SmallVector<int64_t> newDistributedShape(extractSrcType.getShape().begin(),
extractSrcType.getShape().end());
for (int i = 0; i < distributedType.getRank(); ++i)
- newDistributedShape[i + extractOp.getPosition().size()] =
+ newDistributedShape[i + extractOp.getNumIndices()] =
distributedType.getDimSize(i);
auto newDistributedType =
VectorType::get(newDistributedShape, distributedType.getElementType());
@@ -1108,7 +1112,7 @@ struct WarpOpExtract : public OpRewritePattern<WarpExecuteOnLane0Op> {
Value distributedVec = newWarpOp->getResult(newRetIndices[0]);
// Extract from distributed vector.
Value newExtract = rewriter.create<vector::ExtractOp>(
- loc, distributedVec, extractOp.getPosition());
+ loc, distributedVec, extractOp.getMixedPosition());
rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
newExtract);
return success();
@@ -1297,13 +1301,17 @@ struct WarpOpInsert : public OpRewritePattern<WarpExecuteOnLane0Op> {
Location loc = insertOp.getLoc();
// "vector.insert %v, %v[] : ..." can be canonicalized to %v.
- if (insertOp.getPosition().empty())
+ if (insertOp.getNumIndices() == 0)
return failure();
// Rewrite vector.insert with 1d dest to vector.insertelement.
if (insertOp.getDestVectorType().getRank() == 1) {
- assert(insertOp.getPosition().size() == 1 && "expected 1 index");
- int64_t pos = insertOp.getPosition()[0];
+ if (insertOp.hasDynamicPosition())
+ // TODO: Dynamic position not supported yet.
+ return failure();
+
+ assert(insertOp.getNumIndices() == 1 && "expected 1 index");
+ int64_t pos = insertOp.getStaticPosition()[0];
rewriter.setInsertionPoint(insertOp);
rewriter.replaceOpWithNewOp<vector::InsertElementOp>(
insertOp, insertOp.getSource(), insertOp.getDest(),
@@ -1323,7 +1331,7 @@ struct WarpOpInsert : public OpRewritePattern<WarpExecuteOnLane0Op> {
Value distributedSrc = newWarpOp->getResult(newRetIndices[0]);
Value distributedDest = newWarpOp->getResult(newRetIndices[1]);
Value newResult = rewriter.create<vector::InsertOp>(
- loc, distributedSrc, distributedDest, insertOp.getPosition());
+ loc, distributedSrc, distributedDest, insertOp.getMixedPosition());
rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
newResult);
return success();
@@ -1354,7 +1362,7 @@ struct WarpOpInsert : public OpRewritePattern<WarpExecuteOnLane0Op> {
// Case 2: distrDestDim = 0 (dim of size 128) => distrSrcDim = -1. In that
// case, one lane will insert the source vector<96xf32>. The other
// lanes will not do anything.
- int64_t distrSrcDim = distrDestDim - insertOp.getPosition().size();
+ int64_t distrSrcDim = distrDestDim - insertOp.getNumIndices();
if (distrSrcDim >= 0)
distrSrcShape[distrSrcDim] = distrDestType.getDimSize(distrDestDim);
auto distrSrcType =
@@ -1374,11 +1382,12 @@ struct WarpOpInsert : public OpRewritePattern<WarpExecuteOnLane0Op> {
if (distrSrcDim >= 0) {
// Every lane inserts a small piece.
newResult = rewriter.create<vector::InsertOp>(
- loc, distributedSrc, distributedDest, insertOp.getPosition());
+ loc, distributedSrc, distributedDest, insertOp.getMixedPosition());
} else {
// One lane inserts the entire source vector.
int64_t elementsPerLane = distrDestType.getDimSize(distrDestDim);
- SmallVector<int64_t> newPos(insertOp.getPosition());
+ SmallVector<OpFoldResult> pos = insertOp.getMixedPosition();
+ SmallVector<int64_t> newPos = getAsIntegers(pos);
// tid of inserting lane: pos / elementsPerLane
Value insertingLane = rewriter.create<arith::ConstantIndexOp>(
loc, newPos[distrDestDim] / elementsPerLane);
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
index 913c826dd912470..6bbb293fa2a6b5c 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
@@ -6,6 +6,7 @@
//
//===----------------------------------------------------------------------===//
+#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
@@ -176,14 +177,16 @@ struct CastAwayInsertLeadingOneDim : public OpRewritePattern<vector::InsertOp> {
// type has leading unit dims, we also trim the position array accordingly,
// then (2) if source type also has leading unit dims, we need to append
// zeroes to the position array accordingly.
- unsigned oldPosRank = insertOp.getPosition().size();
+ unsigned oldPosRank = insertOp.getNumIndices();
unsigned newPosRank = std::max<int64_t>(0, oldPosRank - dstDropCount);
- SmallVector<int64_t> newPositions =
- llvm::to_vector(insertOp.getPosition().take_back(newPosRank));
- newPositions.resize(newDstType.getRank() - newSrcRank, 0);
+ SmallVector<OpFoldResult> oldPosition = insertOp.getMixedPosition();
+ SmallVector<OpFoldResult> newPosition =
+ llvm::to_vector(ArrayRef(oldPosition).take_back(newPosRank));
+ newPosition.resize(newDstType.getRank() - newSrcRank,
+ rewriter.getI64IntegerAttr(0));
auto newInsertOp = rewriter.create<vector::InsertOp>(
- loc, newDstType, newSrcVector, newDstVector, newPositions);
+ loc, newSrcVector, newDstVector, newPosition);
rewriter.replaceOpWithNewOp<vector::BroadcastOp>(insertOp, oldDstType,
newInsertOp);
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
index f715c543eb17955..603b88f11c8e007 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
@@ -707,10 +707,10 @@ class RewriteScalarExtractOfTransferRead
auto xferOp = extractOp.getVector().getDefiningOp<vector::TransferReadOp>();
SmallVector<Value> newIndices(xferOp.getIndices().begin(),
xferOp.getIndices().end());
- for (const auto &it : llvm::enumerate(extractOp.getPosition())) {
- int64_t offset = it.value();
- int64_t idx =
- newIndices.size() - extractOp.getPosition().size() + it.index();
+ for (auto [i, pos] : llvm::enumerate(extractOp.getMixedPosition())) {
+ assert(pos.is<Attribute>() && "Unexpected non-constant index");
+ int64_t offset = cast<IntegerAttr>(pos.get<Attribute>()).getInt();
+ int64_t idx = newIndices.size() - extractOp.getNumIndices() + i;
OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
rewriter, extractOp.getLoc(),
rewriter.getAffineSymbolExpr(0) + offset, {newIndices[idx]});
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index b2a5aef5ee62d0f..b891d62ee508e30 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -598,27 +598,34 @@ struct BubbleDownVectorBitCastForExtract
unsigned expandRatio =
castDstType.getNumElements() / castSrcType.getNumElements();
- uint64_t index = extractOp.getPosition()[0];
+ auto getFirstIntValue = [](ArrayRef<OpFoldResult> values) -> uint64_t {
+ assert(values[0].is<Attribute>() && "Unexpected non-constant index");
+ return cast<IntegerAttr>(values[0].get<Attribute>()).getInt();
+ };
+
+ uint64_t index = getFirstIntValue(extractOp.getMixedPosition());
// Get the single scalar (as a vector) in the source value that packs the
// desired scalar. E.g. extract vector<1xf32> from vector<4xf32>
- VectorType oneScalarType =
- VectorType::get({1}, castSrcType.getElementType());
+ Location loc = extractOp.getLoc();
Value packedValue = rewriter.create<vector::ExtractOp>(
- extractOp.getLoc(), oneScalarType, castOp.getSource(),
- index / expandRatio);
+ loc, castOp.getSource(), index / expandRatio);
+ Type packedVecType = VectorType::get(/*shape=*/{1}, packedValue.getType());
+ Value zero = rewriter.create<arith::ConstantOp>(
+ loc, packedVecType, rewriter.getZeroAttr(packedVecType));
+ packedValue = rewriter.create<vector::InsertOp>(loc, packedValue, zero,
+ /*position=*/0);
// Cast it to a vector with the desired scalar's type.
// E.g. f32 -> vector<2xf16>
VectorType packedType =
VectorType::get({expandRatio}, castDstType.getElementType());
- Value castedValue = rewriter.create<vector::BitCastOp>(
- extractOp.getLoc(), packedType, packedValue);
+ Value castedValue =
+ rewriter.create<vector::BitCastOp>(loc, packedType, packedValue);
// Finally extract the desired scalar.
- rewriter.replaceOpWithNewOp<vector::ExtractOp>(
- extractOp, extractOp.getType(), castedValue, index % expandRatio);
-
+ rewriter.replaceOpWithNewOp<vector::ExtractOp>(extractOp, castedValue,
+ index % expandRatio);
return success();
}
};
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index b07c4bd67be2dc7..41ab06f2e23b501 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -728,6 +728,17 @@ func.func @extract_element_from_vec_3d(%arg0: vector<4x3x16xf32>) -> f32 {
// -----
+func.func @extract_element_with_value_1d(%arg0: vector<16xf32>, %arg1: index) -> f32 {
+ %0 = vector.extract %arg0[%arg1]: vector<16xf32>
+ return %0 : f32
+}
+// CHECK-LABEL: @extract_element_with_value_1d
+// CHECK-SAME: %[[VEC:.+]]: vector<16xf32>, %[[INDEX:.+]]: index
+// CHECK: %[[UC:.+]] = builtin.unrealized_conversion_cast %[[INDEX]] : index to i64
+// CHECK: llvm.extractelement %[[VEC]][%[[UC]] : i64] : vector<16xf32>
+
+// -----
+
// CHECK-LABEL: @insert_element_0d
// CHECK-SAME: %[[A:.*]]: f32,
func.func @insert_element_0d(%a: f32, %b: vector<f32>) -> vector<f32> {
@@ -830,6 +841,19 @@ func.func @insert_element_into_vec_3d(%arg0: f32, %arg1: vector<4x8x16xf32>) ->
// -----
+func.func @insert_element_with_value_1d(%arg0: vector<16xf32>, %arg1: f32, %arg2: index)
+ -> vector<16xf32> {
+ %0 = vector.insert %arg1, %arg0[%arg2]: f32 into vector<16xf32>
+ return %0 : vector<16xf32>
+}
+
+// CHECK-LABEL: @insert_element_with_value_1d
+// CHECK-SAME: %[[DST:.+]]: vector<16xf32>, %[[SRC:.+]]: f32, %[[INDEX:.+]]: index
+// CHECK: %[[UC:.+]] = builtin.unrealized_conversion_cast %[[INDEX]] : index to i64
+// CHECK: llvm.insertelement %[[SRC]], %[[DST]][%[[UC]] : i64] : vector<16xf32>
+
+// -----
+
func.func @vector_type_cast(%arg0: memref<8x8x8xf32>) -> memref<vector<8x8x8xf32>> {
%0 = vector.type_cast %arg0: memref<8x8x8xf32> to memref<vector<8x8x8xf32>>
return %0 : memref<vector<8x8x8xf32>>
diff --git a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
index f60a522cbfdba56..266161d5268e985 100644
--- a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
+++ b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
@@ -155,8 +155,8 @@ func.func @broadcast(%arg0 : f32) -> (vector<4xf32>, vector<2xf32>) {
// CHECK: spirv.CompositeExtract %[[ARG]][0 : i32] : vector<2xf32>
// CHECK: spirv.CompositeExtract %[[ARG]][1 : i32] : vector<2xf32>
func.func @extract(%arg0 : vector<2xf32>) -> (vector<1xf32>, f32) {
- %0 = "vector.extract"(%arg0) <{position = array<i64: 0>}> : (vector<2xf32>) -> vector<1xf32>
- %1 = "vector.extract"(%arg0) <{position = array<i64: 1>}> : (vector<2xf32>) -> f32
+ %0 = "vector.extract"(%arg0) <{static_position = array<i64: 0>}> : (vector<2xf32>) -> vector<1xf32>
+ %1 = "vector.extract"(%arg0) <{static_position = array<i64: 1>}> : (vector<2xf32>) -> f32
return %0, %1: vector<1xf32>, f32
}
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index 26772b929493585..549fe7a6a61f6ac 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -133,7 +133,7 @@ func.func @extract_position_rank_overflow(%arg0: vector<4x8x16xf32>) {
func.func @extract_position_rank_overflow_generic(%arg0: vector<4x8x16xf32>) {
// expected-error@+1 {{expected position attribute of rank no greater than vector rank}}
- %1 = "vector.extract" (%arg0) <{position = array<i64: 0, 0, 0, 0>}> : (vector<4x8x16xf32>) -> (vector<16xf32>)
+ %1 = "vector.extract" (%arg0) <{static_position = array<i64: 0, 0, 0, 0>}> : (vector<4x8x16xf32>) -> (vector<16xf32>)
}
// -----
diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
index 3a23ee14ca14fa0..f879cd122469a65 100644
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -206,8 +206,9 @@ func.func @extract_element(%a: vector<16xf32>) -> f32 {
return %1 : f32
}
-// CHECK-LABEL: @extract
-func.func @extract(%arg0: vector<4x8x16xf32>) -> (vector<4x8x16xf32>, vector<8x16xf32>, vector<16xf32>, f32) {
+// CHECK-LABEL: @extract_const_idx
+func.func @extract_const_idx(%arg0: vector<4x8x16xf32>)
+ -> (vector<4x8x16xf32>, vector<8x16xf32>, vector<16xf32>, f32) {
// CHECK: vector.extract {{.*}}[] : vector<4x8x16xf32>
%0 = vector.extract %arg0[] : vector<4x8x16xf32>
// CHECK: vector.extract {{.*}}[3] : vector<4x8x16xf32>
@@ -219,6 +220,19 @@ func.func @extract(%arg0: vector<4x8x16xf32>) -> (vector<4x8x16xf32>, vector<8x1
return %0, %1, %2, %3 : vector<4x8x16xf32>, vector<8x16xf32>, vector<16xf32>, f32
}
+// CHECK-LABEL: @extract_val_idx
+// CHECK-SAME: %[[VEC:.+]]: vector<4x8x16xf32>, %[[IDX:.+]]: index
+func.func @extract_val_idx(%arg0: vector<4x8x16xf32>, %idx: index)
+ -> (vector<8x16xf32>, vector<16xf32>, f32) {
+ // CHECK: vector.extract %[[VEC]][%[[IDX]]] : vector<4x8x16xf32>
+ %0 = vector.extract %arg0[%idx] : vector<4x8x16xf32>
+ // CHECK-NEXT: vector.extract %[[VEC]][%[[IDX]], %[[IDX]]] : vector<4x8x16xf32>
+ %1 = vector.extract %arg0[%idx, %idx] : vector<4x8x16xf32>
+ // CHECK-NEXT: vector.extract %[[VEC]][%[[IDX]], 5, %[[IDX]]] : vector<4x8x16xf32>
+ %2 = vector.extract %arg0[%idx, 5, %idx] : vector<4x8x16xf32>
+ return %0, %1, %2 : vector<8x16xf32>, vector<16xf32>, f32
+}
+
// CHECK-LABEL: @extract_0d
func.func @extract_0d(%a: vector<f32>) -> f32 {
// CHECK-NEXT: vector.extract %{{.*}}[] : vector<f32>
@@ -242,8 +256,9 @@ func.func @insert_element(%a: f32, %b: vector<16xf32>) -> vector<16xf32> {
return %1 : vector<16xf32>
}
-// CHECK-LABEL: @insert
-func.func @insert(%a: f32, %b: vector<16xf32>, %c: vector<8x16xf32>, %res: vector<4x8x16xf32>) -> vector<4x8x16xf32> {
+// CHECK-LABEL: @insert_const_idx
+func.func @insert_const_idx(%a: f32, %b: vector<16xf32>, %c: vector<8x16xf32>,
+ %res: vector<4x8x16xf32>) -> vector<4x8x16xf32> {
// CHECK: vector.insert %{{.*}}, %{{.*}}[3] : vector<8x16xf32> into vector<4x8x16xf32>
%1 = vector.insert %c, %res[3] : vector<8x16xf32> into vector<4x8x16xf32>
// CHECK: vector.insert %{{.*}}, %{{.*}}[3, 3] : vector<16xf32> into vector<4x8x16xf32>
@@ -255,6 +270,19 @@ func.func @insert(%a: f32, %b: vector<16xf32>, %c: vector<8x16xf32>, %res: vecto
return %4 : vector<4x8x16xf32>
}
+// CHECK-LABEL: @insert_val_idx
+// CHECK-SAME: %[[A:.+]]: f32, %[[B:.+]]: vector<16xf32>, %[[C:.+]]: vector<8x16xf32>, %[[IDX:.+]]: index
+func.func @insert_val_idx(%a: f32, %b: vector<16xf32>, %c: vector<8x16xf32>,
+ %idx: index, %res: vector<4x8x16xf32>) -> vector<4x8x16xf32> {
+ // CHECK: vector.insert %[[C]], %{{.*}}[%[[IDX]]] : vector<8x16xf32> into vector<4x8x16xf32>
+ %0 = vector.insert %c, %res[%idx] : vector<8x16xf32> into vector<4x8x16xf32>
+ // CHECK: vector.insert %[[B]], %{{.*}}[%[[IDX]], %[[IDX]]] : vector<16xf32> into vector<4x8x16xf32>
+ %1 = vector.insert %b, %res[%idx, %idx] : vector<16xf32> into vector<4x8x16xf32>
+ // CHECK: vector.insert %[[A]], %{{.*}}[%[[IDX]], 5, %[[IDX]]] : f32 into vector<4x8x16xf32>
+ %2 = vector.insert %a, %res[%idx, 5, %idx] : f32 into vector<4x8x16xf32>
+ return %2 : vector<4x8x16xf32>
+}
+
// CHECK-LABEL: @insert_0d
func.func @insert_0d(%a: f32, %b: vector<f32>, %c: vector<2x3xf32>) -> (vector<f32>, vector<2x3xf32>) {
// CHECK-NEXT: vector.insert %{{.*}}, %{{.*}}[] : f32 into vector<f32>
@@ -1007,7 +1035,7 @@ func.func @contraction_masked_scalable(%A: vector<3x4xf32>,
%C: vector<3x[8]xf32>,
%M : vector<3x[8]x4xi1>) -> vector<3x[8]xf32> {
// CHECK: vector.mask %[[M]] { vector.contract {indexing_maps = [#{{.*}}, #{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %[[A]], %[[B]], %[[C]] : vector<3x4xf32>, vector<4x[8]xf32> into vector<3x[8]xf32> } : vector<3x[8]x4xi1> -> vector<3x[8]xf32>
- %0 = vector.mask %M { vector.contract #matmat_trait %A, %B, %C : vector<3x4xf32>, vector<4x[8]xf32> into vector<3x[8]xf32> }
+ %0 = vector.mask %M { vector.contract #matmat_trait %A, %B, %C : vector<3x4xf32>, vector<4x[8]xf32> into vector<3x[8]xf32> }
: vector<3x[8]x4xi1> -> vector<3x[8]xf32>
return %0 : vector<3x[8]xf32>
}
diff --git a/mlir/test/Dialect/Vector/vector-transforms.mlir b/mlir/test/Dialect/Vector/vector-transforms.mlir
index dfc564ca6fe4836..27bbe1bb0d0349d 100644
--- a/mlir/test/Dialect/Vector/vector-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-transforms.mlir
@@ -286,11 +286,13 @@ func.func @contraction4x4_ikj_xfer_read_tensor(%arg0 : tensor<4x2xf32>,
func.func @bubble_down_bitcast_in_extract(%src: vector<4xf32>) -> (f16, f16) {
%0 = vector.bitcast %src : vector<4xf32> to vector<8xf16>
// CHECK: %[[EXTRACT1:.+]] = vector.extract %[[SRC]][1] : vector<4xf32>
- // CHECK: %[[CAST1:.+]] = vector.bitcast %[[EXTRACT1]] : vector<1xf32> to vector<2xf16>
+ // CHECK: %[[INSERT1:.+]] = vector.insert %[[EXTRACT1]], %{{.+}} [0] : f32 into vector<1xf32>
+ // CHECK: %[[CAST1:.+]] = vector.bitcast %[[INSERT1]] : vector<1xf32> to vector<2xf16>
// CHECK: %[[EXTRACT2:.+]] = vector.extract %[[CAST1]][1] : vector<2xf16>
%1 = vector.extract %0[3] : vector<8xf16>
// CHECK: %[[EXTRACT3:.+]] = vector.extract %[[SRC]][2] : vector<4xf32>
- // CHECK: %[[CAST2:.+]] = vector.bitcast %[[EXTRACT3]] : vector<1xf32> to vector<2xf16>
+ // CHECK: %[[INSERT3:.+]] = vector.insert %[[EXTRACT3]], %{{.+}} [0] : f32 into vector<1xf32>
+ // CHECK: %[[CAST2:.+]] = vector.bitcast %[[INSERT3]] : vector<1xf32> to vector<2xf16>
// CHECK: %[[EXTRACT4:.+]] = vector.extract %[[CAST2]][0] : vector<2xf16>
%2 = vector.extract %0[4] : vector<8xf16>
// CHECK: return %[[EXTRACT2]], %[[EXTRACT4]]
More information about the Mlir-commits
mailing list