[Mlir-commits] [mlir] [mlir][vector] Adds ToElementsToTargetShape pattern. (PR #166476)
Jakub Kuderski
llvmlistbot at llvm.org
Fri Nov 7 09:51:31 PST 2025
================
@@ -834,11 +843,101 @@ struct UnrollBroadcastPattern : public OpRewritePattern<vector::BroadcastOp> {
vector::UnrollVectorOptions options;
};
+/// Takes a 1-dimensional `vector.to_elements` op and attempts to change it to
+/// the target shape.
+///
+/// ```
+/// // In SPIR-V's default environment, vectors of size 8
+/// // are not allowed.
+/// %elements:8 = vector.to_elements %v : vector<8xf32>
+///
+/// ===>
+///
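+/// %v_0_to_3 = vector.extract_strided_slice %v
+///     {offsets = [0], sizes = [4], strides = [1]}
+///     : vector<8xf32> to vector<4xf32>
+/// %v_4_to_7 = vector.extract_strided_slice %v
+///     {offsets = [4], sizes = [4], strides = [1]}
+///     : vector<8xf32> to vector<4xf32>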
+/// %elements_0:4 = vector.to_elements %v_0_to_3 : vector<4xf32>
+/// %elements_1:4 = vector.to_elements %v_4_to_7 : vector<4xf32>
+/// ```
+///
+/// This pattern may fail to apply if the size of the source vector is not
+/// divisible by the native (target) size, or if the source vector is already
+/// in the target shape, in which case the op is left unchanged.
+struct ToElementsToTargetShape final
+ : public OpRewritePattern<vector::ToElementsOp> {
+ ToElementsToTargetShape(MLIRContext *context,
+ const vector::UnrollVectorOptions &options,
+ PatternBenefit benefit = 1)
+ : OpRewritePattern<vector::ToElementsOp>(context, benefit),
+ options(options) {}
+
+ LogicalResult matchAndRewrite(vector::ToElementsOp op,
+ PatternRewriter &rewriter) const override {
+ std::optional<SmallVector<int64_t>> targetShape =
+ getTargetShape(options, op);
+ if (!targetShape)
+ return failure();
+
+ // We have
+ // source_rank = N * target_rank
+ int64_t source_rank = op.getSourceVectorType().getShape().front();
+ int64_t target_rank = targetShape->front();
+ int64_t N = source_rank / target_rank;
+
+ // Transformation where
+ // s = source_rank and
+ // t = target_rank
+ // ```
+ // %e:s = vector.to_elements %v : vector<sxf32>
+ //
+ // ===>
+ //
+ // // N vector.extract_strided_slice of size t
+ // %v0 = vector.extract_strided_slice %v
+ // {offsets = [0*t], sizes = [t], strides = [1]}
+ // : vector<txf32> from vector<sxf32>
+ // %v1 = vector.extract_strided_slice %v
+ // {offsets = [1*t], sizes = [t], strides = [1]}
+ // : vector<txf32> from vector<sxf32>
+ // ...
+ // %vNminus1 = vector.extract_strided_slice %v
+ // {offsets = [(N-1)*t], sizes = [t], strides = [1]}
+ // : vector<txf32> from vector<sxf32>
+ //
+ // // N vector.to_elements of size t vectors.
+ // %e0:t = vector.to_elements %v0 : vector<txf32>
+ // %e1:t = vector.to_elements %v1 : vector<txf32>
+ // ...
+ // %eNminus1:t = vector.to_elements %vNminus1 : vector<txf32>
+ // ```
+ SmallVector<Value> subVectors;
+ SmallVector<int64_t> strides(targetShape->size(), 1);
+ for (int64_t i = 0; i < N; i++) {
+ SmallVector<int64_t> elementOffsets = {i * target_rank};
+ Value subVector = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
+ op.getLoc(), op.getSource(), elementOffsets, *targetShape, strides);
+ subVectors.push_back(subVector);
+ }
+
+ SmallVector<Value> elements;
+ for (const Value subVector : subVectors) {
+ auto elementsOp =
----------------
kuhar wrote:
```suggestion
for (Value subVector : subVectors) {
auto elementsOp =
```
https://github.com/llvm/llvm-project/pull/166476
More information about the Mlir-commits
mailing list