[Mlir-commits] [mlir] b4444dc - [mlir][vector] Use `DenseI64ArrayAttr` for shuffle masks (#101163)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Jul 30 07:00:17 PDT 2024
Author: Benjamin Maxwell
Date: 2024-07-30T15:00:14+01:00
New Revision: b4444dca47c41436aa781bfd38aac6eca856ef23
URL: https://github.com/llvm/llvm-project/commit/b4444dca47c41436aa781bfd38aac6eca856ef23
DIFF: https://github.com/llvm/llvm-project/commit/b4444dca47c41436aa781bfd38aac6eca856ef23.diff
LOG: [mlir][vector] Use `DenseI64ArrayAttr` for shuffle masks (#101163)
Follow-on from #100997. This again removes some boilerplate conversions
to/from IntegerAttr and int64_t (otherwise, this is an NFC).
Added:
Modified:
mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
mlir/lib/Dialect/Vector/IR/VectorOps.cpp
mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp
mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 3cdbd21874567..434ff3956c250 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -421,7 +421,7 @@ def Vector_ShuffleOp :
TCresVTEtIsSameAsOpBase<0, 1>>,
InferTypeOpAdaptor]>,
Arguments<(ins AnyFixedVector:$v1, AnyFixedVector:$v2,
- I64ArrayAttr:$mask)>,
+ DenseI64ArrayAttr:$mask)>,
Results<(outs AnyVector:$vector)> {
let summary = "shuffle operation";
let description = [{
@@ -459,11 +459,7 @@ def Vector_ShuffleOp :
: vector<f32>, vector<f32> ; yields vector<2xf32>
```
}];
- let builders = [
- OpBuilder<(ins "Value":$v1, "Value":$v2, "ArrayRef<int64_t>")>
- ];
- let hasFolder = 1;
- let hasCanonicalizer = 1;
+
let extraClassDeclaration = [{
VectorType getV1VectorType() {
return ::llvm::cast<VectorType>(getV1().getType());
@@ -475,7 +471,10 @@ def Vector_ShuffleOp :
return ::llvm::cast<VectorType>(getVector().getType());
}
}];
+
let assemblyFormat = "operands $mask attr-dict `:` type(operands)";
+
+ let hasFolder = 1;
let hasVerifier = 1;
let hasCanonicalizer = 1;
}
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index f6b1c42dcd24c..53e18a2e9d299 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -994,7 +994,7 @@ class VectorShuffleOpConversion
auto v2Type = shuffleOp.getV2VectorType();
auto vectorType = shuffleOp.getResultVectorType();
Type llvmType = typeConverter->convertType(vectorType);
- auto maskArrayAttr = shuffleOp.getMask();
+ ArrayRef<int64_t> mask = shuffleOp.getMask();
// Bail if result type cannot be lowered.
if (!llvmType)
@@ -1015,7 +1015,7 @@ class VectorShuffleOpConversion
if (rank <= 1 && v1Type == v2Type) {
Value llvmShuffleOp = rewriter.create<LLVM::ShuffleVectorOp>(
loc, adaptor.getV1(), adaptor.getV2(),
- LLVM::convertArrayToIndices<int32_t>(maskArrayAttr));
+ llvm::to_vector_of<int32_t>(mask));
rewriter.replaceOp(shuffleOp, llvmShuffleOp);
return success();
}
@@ -1029,8 +1029,7 @@ class VectorShuffleOpConversion
eltType = cast<VectorType>(llvmType).getElementType();
Value insert = rewriter.create<LLVM::UndefOp>(loc, llvmType);
int64_t insPos = 0;
- for (const auto &en : llvm::enumerate(maskArrayAttr)) {
- int64_t extPos = cast<IntegerAttr>(en.value()).getInt();
+ for (int64_t extPos : mask) {
Value value = adaptor.getV1();
if (extPos >= v1Dim) {
extPos -= v1Dim;
diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index 890706bf1bb2e..21b8858989839 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -527,10 +527,7 @@ struct VectorShuffleOpConvert final
return rewriter.notifyMatchFailure(shuffleOp,
"unsupported result vector type");
- SmallVector<int32_t, 4> mask = llvm::map_to_vector<4>(
- shuffleOp.getMask(), [](Attribute attr) -> int32_t {
- return cast<IntegerAttr>(attr).getValue().getZExtValue();
- });
+ auto mask = llvm::to_vector_of<int32_t>(shuffleOp.getMask());
VectorType oldV1Type = shuffleOp.getV1VectorType();
VectorType oldV2Type = shuffleOp.getV2VectorType();
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 669ae586e5786..5047bd925d4c5 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -2464,11 +2464,6 @@ void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
// ShuffleOp
//===----------------------------------------------------------------------===//
-void ShuffleOp::build(OpBuilder &builder, OperationState &result, Value v1,
- Value v2, ArrayRef<int64_t> mask) {
- build(builder, result, v1, v2, getVectorSubscriptAttr(builder, mask));
-}
-
LogicalResult ShuffleOp::verify() {
VectorType resultType = getResultVectorType();
VectorType v1Type = getV1VectorType();
@@ -2491,8 +2486,8 @@ LogicalResult ShuffleOp::verify() {
return emitOpError("dimension mismatch");
}
// Verify mask length.
- auto maskAttr = getMask().getValue();
- int64_t maskLength = maskAttr.size();
+ ArrayRef<int64_t> mask = getMask();
+ int64_t maskLength = mask.size();
if (maskLength <= 0)
return emitOpError("invalid mask length");
if (maskLength != resultType.getDimSize(0))
@@ -2500,10 +2495,9 @@ LogicalResult ShuffleOp::verify() {
// Verify all indices.
int64_t indexSize = (v1Type.getRank() == 0 ? 1 : v1Type.getDimSize(0)) +
(v2Type.getRank() == 0 ? 1 : v2Type.getDimSize(0));
- for (const auto &en : llvm::enumerate(maskAttr)) {
- auto attr = llvm::dyn_cast<IntegerAttr>(en.value());
- if (!attr || attr.getInt() < 0 || attr.getInt() >= indexSize)
- return emitOpError("mask index #") << (en.index() + 1) << " out of range";
+ for (auto [idx, maskPos] : llvm::enumerate(mask)) {
+ if (maskPos < 0 || maskPos >= indexSize)
+ return emitOpError("mask index #") << (idx + 1) << " out of range";
}
return success();
}
@@ -2527,13 +2521,12 @@ ShuffleOp::inferReturnTypes(MLIRContext *, std::optional<Location>,
return success();
}
-static bool isStepIndexArray(ArrayAttr idxArr, uint64_t begin, size_t width) {
- uint64_t expected = begin;
- return idxArr.size() == width &&
- llvm::all_of(idxArr.getAsValueRange<IntegerAttr>(),
- [&expected](auto attr) {
- return attr.getZExtValue() == expected++;
- });
+template <typename T>
+static bool isStepIndexArray(ArrayRef<T> idxArr, uint64_t begin, size_t width) {
+ T expected = begin;
+ return idxArr.size() == width && llvm::all_of(idxArr, [&expected](T value) {
+ return value == expected++;
+ });
}
OpFoldResult vector::ShuffleOp::fold(FoldAdaptor adaptor) {
@@ -2568,8 +2561,7 @@ OpFoldResult vector::ShuffleOp::fold(FoldAdaptor adaptor) {
SmallVector<Attribute> results;
auto lhsElements = llvm::cast<DenseElementsAttr>(lhs).getValues<Attribute>();
auto rhsElements = llvm::cast<DenseElementsAttr>(rhs).getValues<Attribute>();
- for (const auto &index : this->getMask().getAsValueRange<IntegerAttr>()) {
- int64_t i = index.getZExtValue();
+ for (int64_t i : this->getMask()) {
if (i >= lhsSize) {
results.push_back(rhsElements[i - lhsSize]);
} else {
@@ -2590,13 +2582,13 @@ struct Canonicalize0DShuffleOp : public OpRewritePattern<ShuffleOp> {
LogicalResult matchAndRewrite(ShuffleOp shuffleOp,
PatternRewriter &rewriter) const override {
VectorType v1VectorType = shuffleOp.getV1VectorType();
- ArrayAttr mask = shuffleOp.getMask();
+ ArrayRef<int64_t> mask = shuffleOp.getMask();
if (v1VectorType.getRank() > 0)
return failure();
if (mask.size() != 1)
return failure();
VectorType resType = VectorType::Builder(v1VectorType).setShape({1});
- if (llvm::cast<IntegerAttr>(mask[0]).getInt() == 0)
+ if (mask[0] == 0)
rewriter.replaceOpWithNewOp<vector::BroadcastOp>(shuffleOp, resType,
shuffleOp.getV1());
else
@@ -2651,11 +2643,11 @@ class ShuffleInterleave : public OpRewritePattern<ShuffleOp> {
op, "ShuffleOp types don't match an interleave");
}
- ArrayAttr shuffleMask = op.getMask();
+ ArrayRef<int64_t> shuffleMask = op.getMask();
int64_t resultVectorSize = resultType.getNumElements();
for (int i = 0, e = resultVectorSize / 2; i < e; ++i) {
- int64_t maskValueA = cast<IntegerAttr>(shuffleMask[i * 2]).getInt();
- int64_t maskValueB = cast<IntegerAttr>(shuffleMask[(i * 2) + 1]).getInt();
+ int64_t maskValueA = shuffleMask[i * 2];
+ int64_t maskValueB = shuffleMask[(i * 2) + 1];
if (maskValueA != i || maskValueB != (resultVectorSize / 2) + i)
return rewriter.notifyMatchFailure(op,
"ShuffleOp mask not interleaving");
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp
index 37216cea7b615..ec2ef3fc7501c 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp
@@ -225,8 +225,7 @@ class Convert1DExtractStridedSliceIntoShuffle
off += stride)
offsets.push_back(off);
rewriter.replaceOpWithNewOp<ShuffleOp>(op, dstType, op.getVector(),
- op.getVector(),
- rewriter.getI64ArrayAttr(offsets));
+ op.getVector(), offsets);
return success();
}
};
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index 4a3ae1b850517..868397f2daaae 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -232,8 +232,7 @@ struct LinearizeVectorExtractStridedSlice final
}
// Perform a shuffle to extract the kD vector.
rewriter.replaceOpWithNewOp<vector::ShuffleOp>(
- extractOp, dstType, srcVector, srcVector,
- rewriter.getI64ArrayAttr(indices));
+ extractOp, dstType, srcVector, srcVector, indices);
return success();
}
@@ -298,20 +297,17 @@ struct LinearizeVectorShuffle final
// that needs to be shuffled to the destination vector. If shuffleSliceLen >
// 1 we need to shuffle the slices (consecutive shuffleSliceLen number of
// elements) instead of scalars.
- ArrayAttr mask = shuffleOp.getMask();
+ ArrayRef<int64_t> mask = shuffleOp.getMask();
int64_t totalSizeOfShuffledElmnts = mask.size() * shuffleSliceLen;
llvm::SmallVector<int64_t, 2> indices(totalSizeOfShuffledElmnts);
- for (auto [i, value] :
- llvm::enumerate(mask.getAsValueRange<IntegerAttr>())) {
-
- int64_t v = value.getZExtValue();
+ for (auto [i, value] : llvm::enumerate(mask)) {
std::iota(indices.begin() + shuffleSliceLen * i,
indices.begin() + shuffleSliceLen * (i + 1),
- shuffleSliceLen * v);
+ shuffleSliceLen * value);
}
- rewriter.replaceOpWithNewOp<vector::ShuffleOp>(
- shuffleOp, dstType, vec1, vec2, rewriter.getI64ArrayAttr(indices));
+ rewriter.replaceOpWithNewOp<vector::ShuffleOp>(shuffleOp, dstType, vec1,
+ vec2, indices);
return success();
}
@@ -368,8 +364,7 @@ struct LinearizeVectorExtract final
llvm::SmallVector<int64_t, 2> indices(size);
std::iota(indices.begin(), indices.end(), linearizedOffset);
rewriter.replaceOpWithNewOp<vector::ShuffleOp>(
- extractOp, dstTy, adaptor.getVector(), adaptor.getVector(),
- rewriter.getI64ArrayAttr(indices));
+ extractOp, dstTy, adaptor.getVector(), adaptor.getVector(), indices);
return success();
}
@@ -452,8 +447,7 @@ struct LinearizeVectorInsert final
// [offset+srcNumElements, end)
rewriter.replaceOpWithNewOp<vector::ShuffleOp>(
- insertOp, dstTy, adaptor.getDest(), adaptor.getSource(),
- rewriter.getI64ArrayAttr(indices));
+ insertOp, dstTy, adaptor.getDest(), adaptor.getSource(), indices);
return success();
}
More information about the Mlir-commits mailing list