[Mlir-commits] [mlir] [mlir][vector] Move hidden function to op definition (PR #140813)
James Newling
llvmlistbot at llvm.org
Thu May 22 12:06:05 PDT 2025
https://github.com/newling updated https://github.com/llvm/llvm-project/pull/140813
>From 207b0e11365b61db78896741d17d2505f8c3f891 Mon Sep 17 00:00:00 2001
From: James Newling <james.newling at gmail.com>
Date: Tue, 20 May 2025 15:25:12 -0700
Subject: [PATCH 1/2] first commit in preparation
---
.../mlir/Dialect/Vector/IR/VectorOps.td | 4 +
mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 114 ++++++++++++++++
.../Vector/Transforms/VectorLinearize.cpp | 126 ++----------------
3 files changed, 132 insertions(+), 112 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 3f5564541554e..f8412863b18c9 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -1114,6 +1114,8 @@ def Vector_InsertStridedSliceOp :
return ::llvm::cast<IntegerAttr>(attr).getInt() != 1;
});
}
+ // \return The indices in dest that the values are inserted to.
+ FailureOr<SmallVector<int64_t>> getLinearIndices();
}];
let hasFolder = 1;
@@ -1254,6 +1256,8 @@ def Vector_ExtractStridedSliceOp :
return ::llvm::cast<IntegerAttr>(attr).getInt() != 1;
});
}
+ // \return The indices in source that the values are taken from.
+ FailureOr<SmallVector<int64_t>> getLinearIndices();
}];
let hasCanonicalizer = 1;
let hasFolder = 1;
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 41777347975da..e800b7b7c9ff6 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -3182,6 +3182,101 @@ void InsertStridedSliceOp::build(OpBuilder &builder, OperationState &result,
stridesAttr);
}
+/// Convert an array of attributes into a vector of integers, if possible.
+///
+/// \return failure() if `attrs` is null or if any element is not an
+/// IntegerAttr; otherwise the integer values, in the same order as `attrs`.
+static FailureOr<SmallVector<int64_t>> intsFromArrayAttr(ArrayAttr attrs) {
+ if (!attrs)
+ return failure();
+ SmallVector<int64_t> ints;
+ ints.reserve(attrs.size());
+ for (auto attr : attrs) {
+ if (auto intAttr = dyn_cast<IntegerAttr>(attr)) {
+ ints.push_back(intAttr.getInt());
+ } else {
+ // A single non-integer element makes the whole conversion fail.
+ return failure();
+ }
+ }
+ return ints;
+}
+
+/// Consider inserting a vector of shape `small` into a vector of shape `large`,
+/// at position `offsets`: this function enumerates all the indices in `large`
+/// that are written to. The enumeration is with row-major ordering.
+///
+/// Example: insert a 1x2 vector into a 4x5 vector at position (1,3). The 2
+/// positions written to are (1,3) and (1,4), which have linearized indices 8
+/// and 9. So [8,9] is returned.
+///
+/// The length of the returned vector is equal to the number of elements in
+/// the shape `small` (i.e. the product of dimensions of `small`).
+///
+/// `small` may have lower rank than `large` (missing leading dimensions are
+/// treated as size 1), and `offsets` may have fewer entries than `large`
+/// (missing trailing offsets are treated as 0). Only these rank relations are
+/// asserted; the caller must ensure the slice actually fits within `large`.
+static SmallVector<int64_t>
+getStridedSliceInsertionIndices(ArrayRef<int64_t> small,
+ ArrayRef<int64_t> large,
+ ArrayRef<int64_t> offsets) {
+
+ // Example of alignment between `large`, `small` and `offsets`:
+ // large   = 4, 5, 6, 7, 8
+ // small   =    1, 6, 7, 8
+ // offsets = 2, 3, 0
+ //
+ // `offsets` has implicit trailing 0s, `small` has implicit leading 1s.
+ assert((large.size() >= small.size()) &&
+ "rank of 'large' cannot be lower than rank of 'small'");
+ assert((large.size() >= offsets.size()) &&
+ "rank of 'large' cannot be lower than the number of offsets");
+ unsigned delta = large.size() - small.size();
+ unsigned nOffsets = offsets.size();
+ // Size of `small` along dimension `i` of `large` (1 for the implicit
+ // leading dimensions).
+ auto getSmall = [&](int64_t i) -> int64_t {
+ return i >= delta ? small[i - delta] : 1;
+ };
+ // Offset along dimension `i` of `large` (0 for the implicit trailing
+ // dimensions).
+ auto getOffset = [&](int64_t i) -> int64_t {
+ return i < nOffsets ? offsets[i] : 0;
+ };
+
+ // Start from the single index 0 and walk the dimensions of `large` from
+ // innermost (last) to outermost. At each step, `nextIndices` is built by
+ // replicating the current `indices` once per position of the slice along
+ // dimension `i`, each replica shifted by that position's linearized offset.
+ // Replicas are laid out one after another, so inner dimensions vary fastest
+ // (row-major order).
+ SmallVector<int64_t> indices{0};
+ // Linearized distance between neighbouring positions along dimension `i` of
+ // `large`, i.e. the product of the sizes of all dimensions inner to `i`.
+ int64_t stride = 1;
+ for (int i = large.size() - 1; i >= 0; --i) {
+ int64_t currentSize = indices.size();
+ int64_t smallSize = getSmall(i);
+ int64_t nextSize = currentSize * smallSize;
+ SmallVector<int64_t> nextIndices(nextSize);
+ int64_t *base = nextIndices.begin();
+ int64_t offset = getOffset(i) * stride;
+ for (int j = 0; j < smallSize; ++j) {
+ for (int k = 0; k < currentSize; ++k) {
+ base[k] = indices[k] + offset;
+ }
+ offset += stride;
+ base += currentSize;
+ }
+ stride *= large[i];
+ indices = std::move(nextIndices);
+ }
+ return indices;
+}
+
+/// Enumerate, in row-major order, the linearized indices of the destination
+/// vector that this op writes to (one index per element of the source vector).
+/// Fails if any stride is not 1, if the result type is scalable, or if the
+/// offsets are not all IntegerAttrs.
+FailureOr<SmallVector<int64_t>> InsertStridedSliceOp::getLinearIndices() {
+
+ // Stride > 1 to be considered if/when insert_strided_slice supports it.
+ if (hasNonUnitStrides())
+ return failure();
+
+ // Only when the destination has a static size can the indices be enumerated.
+ // NOTE(review): this reads the result type (`getType()`), assumed to equal
+ // the destination type; `getDestVectorType()` would state the intent directly.
+ if (getType().isScalable())
+ return failure();
+
+ // Only when the offsets are all static can the indices be enumerated.
+ FailureOr<SmallVector<int64_t>> offsets = intsFromArrayAttr(getOffsets());
+ if (failed(offsets))
+ return failure();
+
+ return getStridedSliceInsertionIndices(getSourceVectorType().getShape(),
+ getDestVectorType().getShape(),
+ offsets.value());
+}
+
// TODO: Should be moved to Tablegen ConfinedAttr attributes.
template <typename OpType>
static LogicalResult isIntegerArrayAttrSmallerThanShape(OpType op,
@@ -3638,6 +3733,25 @@ void ExtractStridedSliceOp::build(OpBuilder &builder, OperationState &result,
stridesAttr);
}
+/// Enumerate, in row-major order, the linearized indices of the source vector
+/// that this op reads from (one index per element of the result vector).
+/// Fails if any stride is not 1, if the source type is scalable, or if the
+/// offsets are not all IntegerAttrs.
+FailureOr<SmallVector<int64_t>> ExtractStridedSliceOp::getLinearIndices() {
+
+ // Stride > 1 to be considered if/when extract_strided_slice supports it.
+ if (hasNonUnitStrides())
+ return failure();
+
+ // Only when the source has a static size can the indices be enumerated.
+ if (getSourceVectorType().isScalable())
+ return failure();
+
+ // Only when the offsets are all static can the indices be enumerated.
+ FailureOr<SmallVector<int64_t>> offsets = intsFromArrayAttr(getOffsets());
+ if (failed(offsets))
+ return failure();
+
+ // Extracting a slice of the result's shape at `offsets` reads exactly the
+ // source positions that inserting a slice of that shape would write, so the
+ // insertion enumeration is reused with (result, source) as (small, large).
+ return getStridedSliceInsertionIndices(
+ getType().getShape(), getSourceVectorType().getShape(), offsets.value());
+}
+
LogicalResult ExtractStridedSliceOp::verify() {
auto type = getSourceVectorType();
auto offsets = getOffsetsAttr();
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index 678a88627ca82..6cf818bbd0695 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -109,90 +109,6 @@ struct LinearizeVectorizable final
}
};
-template <typename TOp>
-static bool stridesAllOne(TOp op) {
- static_assert(
- std::is_same_v<TOp, vector::ExtractStridedSliceOp> ||
- std::is_same_v<TOp, vector::InsertStridedSliceOp>,
- "expected vector.extract_strided_slice or vector.insert_strided_slice");
- ArrayAttr strides = op.getStrides();
- return llvm::all_of(strides, isOneInteger);
-}
-
-/// Convert an array of attributes into a vector of integers, if possible.
-static FailureOr<SmallVector<int64_t>> intsFromArrayAttr(ArrayAttr attrs) {
- if (!attrs)
- return failure();
- SmallVector<int64_t> ints;
- ints.reserve(attrs.size());
- for (auto attr : attrs) {
- if (auto intAttr = dyn_cast<IntegerAttr>(attr)) {
- ints.push_back(intAttr.getInt());
- } else {
- return failure();
- }
- }
- return ints;
-}
-
-/// Consider inserting a vector of shape `small` into a vector of shape `large`,
-/// at position `offsets`: this function enumeratates all the indices in `large`
-/// that are written to. The enumeration is with row-major ordering.
-///
-/// Example: insert a 1x2 vector into a 4x5 vector at position (1,3). The 2
-/// positions written to are (1,3) and (1,4), which have linearized indices 8
-/// and 9. So [8,9] is returned.
-///
-/// The length of the returned vector is equal to the number of elements in
-/// the shape `small` (i.e. the product of dimensions of `small`).
-SmallVector<int64_t> static getStridedSliceInsertionIndices(
- ArrayRef<int64_t> small, ArrayRef<int64_t> large,
- ArrayRef<int64_t> offsets) {
-
- // Example of alignment between, `large`, `small` and `offsets`:
- // large = 4, 5, 6, 7, 8
- // small = 1, 6, 7, 8
- // offsets = 2, 3, 0
- //
- // `offsets` has implicit trailing 0s, `small` has implicit leading 1s.
- assert((large.size() >= small.size()) &&
- "rank of 'large' cannot be lower than rank of 'small'");
- assert((large.size() >= offsets.size()) &&
- "rank of 'large' cannot be lower than the number of offsets");
- unsigned delta = large.size() - small.size();
- unsigned nOffsets = offsets.size();
- auto getSmall = [&](int64_t i) -> int64_t {
- return i >= delta ? small[i - delta] : 1;
- };
- auto getOffset = [&](int64_t i) -> int64_t {
- return i < nOffsets ? offsets[i] : 0;
- };
-
- // Using 2 vectors of indices, at each iteration populate the updated set of
- // indices based on the old set of indices, and the size of the small vector
- // in the current iteration.
- SmallVector<int64_t> indices{0};
- int64_t stride = 1;
- for (int i = large.size() - 1; i >= 0; --i) {
- int64_t currentSize = indices.size();
- int64_t smallSize = getSmall(i);
- int64_t nextSize = currentSize * smallSize;
- SmallVector<int64_t> nextIndices(nextSize);
- int64_t *base = nextIndices.begin();
- int64_t offset = getOffset(i) * stride;
- for (int j = 0; j < smallSize; ++j) {
- for (int k = 0; k < currentSize; ++k) {
- base[k] = indices[k] + offset;
- }
- offset += stride;
- base += currentSize;
- }
- stride *= large[i];
- indices = std::move(nextIndices);
- }
- return indices;
-}
-
/// This pattern converts a vector.extract_strided_slice operation into a
/// vector.shuffle operation that has a rank-1 (linearized) operand and result.
///
@@ -231,30 +147,23 @@ struct LinearizeVectorExtractStridedSlice final
// Expect a legalization failure if the strides are not all 1 (if ever the
// verifier for extract_strided_slice allows non-1 strides).
- if (!stridesAllOne(extractStridedSliceOp)) {
+ if (extractStridedSliceOp.hasNonUnitStrides()) {
return rewriter.notifyMatchFailure(
extractStridedSliceOp,
"extract_strided_slice with strides != 1 not supported");
}
- FailureOr<SmallVector<int64_t>> offsets =
- intsFromArrayAttr(extractStridedSliceOp.getOffsets());
- if (failed(offsets)) {
+ FailureOr<SmallVector<int64_t>> indices =
+ extractStridedSliceOp.getLinearIndices();
+ if (failed(indices)) {
return rewriter.notifyMatchFailure(extractStridedSliceOp,
- "failed to get integer offsets");
+ "failed to get indices");
}
- ArrayRef<int64_t> inputShape =
- extractStridedSliceOp.getSourceVectorType().getShape();
-
- ArrayRef<int64_t> outputShape = extractStridedSliceOp.getType().getShape();
-
- SmallVector<int64_t> indices = getStridedSliceInsertionIndices(
- outputShape, inputShape, offsets.value());
-
Value srcVector = adaptor.getVector();
- rewriter.replaceOpWithNewOp<vector::ShuffleOp>(
- extractStridedSliceOp, flatOutputType, srcVector, srcVector, indices);
+ rewriter.replaceOpWithNewOp<vector::ShuffleOp>(extractStridedSliceOp,
+ flatOutputType, srcVector,
+ srcVector, indices.value());
return success();
}
};
@@ -298,31 +207,24 @@ struct LinearizeVectorInsertStridedSlice final
// Expect a legalization failure if the strides are not all 1 (if ever the
// verifier for insert_strided_slice allows non-1 strides).
- if (!stridesAllOne(insertStridedSliceOp)) {
+ if (insertStridedSliceOp.hasNonUnitStrides()) {
return rewriter.notifyMatchFailure(
insertStridedSliceOp,
"insert_strided_slice with strides != 1 not supported");
}
- VectorType inputType = insertStridedSliceOp.getValueToStore().getType();
- ArrayRef<int64_t> inputShape = inputType.getShape();
-
VectorType outputType = insertStridedSliceOp.getType();
- ArrayRef<int64_t> outputShape = outputType.getShape();
int64_t nOutputElements = outputType.getNumElements();
- FailureOr<SmallVector<int64_t>> offsets =
- intsFromArrayAttr(insertStridedSliceOp.getOffsets());
- if (failed(offsets)) {
+ FailureOr<SmallVector<int64_t>> sliceIndices =
+ insertStridedSliceOp.getLinearIndices();
+ if (failed(sliceIndices))
return rewriter.notifyMatchFailure(insertStridedSliceOp,
- "failed to get integer offsets");
- }
- SmallVector<int64_t> sliceIndices = getStridedSliceInsertionIndices(
- inputShape, outputShape, offsets.value());
+ "failed to get indices");
SmallVector<int64_t> indices(nOutputElements);
std::iota(indices.begin(), indices.end(), 0);
- for (auto [index, sliceIndex] : llvm::enumerate(sliceIndices)) {
+ for (auto [index, sliceIndex] : llvm::enumerate(sliceIndices.value())) {
indices[sliceIndex] = index + nOutputElements;
}
>From 1cc6345d5cde6972b3bd1d1708f0c4f152af349b Mon Sep 17 00:00:00 2001
From: James Newling <james.newling at gmail.com>
Date: Thu, 22 May 2025 12:06:23 -0700
Subject: [PATCH 2/2] comment improvement
---
mlir/include/mlir/Dialect/Vector/IR/VectorOps.td | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index f8412863b18c9..481523ff10c3f 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -1114,7 +1114,7 @@ def Vector_InsertStridedSliceOp :
return ::llvm::cast<IntegerAttr>(attr).getInt() != 1;
});
}
- // \return The indices in dest that the values are inserted to.
+ // \return The indices in `dest` where values are stored.
FailureOr<SmallVector<int64_t>> getLinearIndices();
}];
@@ -1256,7 +1256,7 @@ def Vector_ExtractStridedSliceOp :
return ::llvm::cast<IntegerAttr>(attr).getInt() != 1;
});
}
- // \return The indices in source that the values are taken from.
+ // \return The indices in `source` where values are extracted.
FailureOr<SmallVector<int64_t>> getLinearIndices();
}];
let hasCanonicalizer = 1;
More information about the Mlir-commits
mailing list