[Mlir-commits] [mlir] [mlir][vector] Adds ToElementsToTargetShape pattern. (PR #166476)

Jakub Kuderski llvmlistbot at llvm.org
Fri Nov 7 09:51:31 PST 2025


================
@@ -834,11 +843,101 @@ struct UnrollBroadcastPattern : public OpRewritePattern<vector::BroadcastOp> {
   vector::UnrollVectorOptions options;
 };
 
+/// Takes a 1-D `vector.to_elements` op and attempts to split it into
+/// `vector.to_elements` ops on vectors of the target shape.
+///
+/// ```
+/// // In SPIR-V's default environment, vectors of size 8
+/// // are not allowed.
+/// %elements:8 = vector.to_elements %v : vector<8xf32>
+///
+/// ===>
+///
+/// %v_0_to_3 = vector.extract_strided_slice %v
+///   {offsets = [0], sizes = [4], strides = [1]}
+///   : vector<8xf32> to vector<4xf32>
+/// %v_4_to_7 = vector.extract_strided_slice %v
+///   {offsets = [4], sizes = [4], strides = [1]}
+///   : vector<8xf32> to vector<4xf32>
+/// %elements_0:4 = vector.to_elements %v_0_to_3 : vector<4xf32>
+/// %elements_1:4 = vector.to_elements %v_4_to_7 : vector<4xf32>
+/// ```
+///
+/// This pattern may fail to apply if the source vector size is not divisible
+/// by the native (target) size, or if the source vector already has the
+/// target shape, in which case it is skipped.
+struct ToElementsToTargetShape final
+    : public OpRewritePattern<vector::ToElementsOp> {
+  ToElementsToTargetShape(MLIRContext *context,
+                          const vector::UnrollVectorOptions &options,
+                          PatternBenefit benefit = 1)
+      : OpRewritePattern<vector::ToElementsOp>(context, benefit),
+        options(options) {}
+
+  LogicalResult matchAndRewrite(vector::ToElementsOp op,
+                                PatternRewriter &rewriter) const override {
+    std::optional<SmallVector<int64_t>> targetShape =
+        getTargetShape(options, op);
+    if (!targetShape)
+      return failure();
+
+    // We have
+    // source_rank = N * target_rank
+    int64_t source_rank = op.getSourceVectorType().getShape().front();
+    int64_t target_rank = targetShape->front();
+    int64_t N = source_rank / target_rank;
+
+    // Transformation where
+    // s = source_rank and
+    // t = target_rank
+    // ```
+    // %e:s = vector.to_elements %v : vector<sxf32>
+    //
+    // ===>
+    //
+    // // N vector.extract_strided_slice of size t
+    // %v0 = vector.extract_strided_slice %v
+    //   {offsets = [0*t], sizes = [t], strides = [1]}
+    //   : vector<txf32> from vector<sxf32>
+    // %v1 = vector.extract_strided_slice %v
+    //   {offsets = [1*t], sizes = [t], strides = [1]}
+    //   : vector<txf32> from vector<sxf32>
+    // ...
+    // %vNminus1 = vector.extract_strided_slice $v
+    //   {offsets = [(N-1)*t], sizes = [t], strides = [1]}
+    //   : vector<txf32> from vector<sxf32>
+    //
+    // // N vector.to_elements of size t vectors.
+    // %e0:t = vector.to_elements %v0 : vector<txf32>
+    // %e1:t = vector.to_elements %v1 : vector<txf32>
+    // ...
+    // %eNminus1:t = vector.to_elements %vNminus1 : vector<txf32>
+    // ```
+    SmallVector<Value> subVectors;
+    SmallVector<int64_t> strides(targetShape->size(), 1);
+    for (int64_t i = 0; i < N; i++) {
+      SmallVector<int64_t> elementOffsets = {i * target_rank};
+      Value subVector = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
+          op.getLoc(), op.getSource(), elementOffsets, *targetShape, strides);
+      subVectors.push_back(subVector);
+    }
+
+    SmallVector<Value> elements;
+    for (const Value subVector : subVectors) {
+      auto elementsOp =
----------------
kuhar wrote:

```suggestion
    for (Value subVector : subVectors) {
      auto elementsOp =
```

https://github.com/llvm/llvm-project/pull/166476


More information about the Mlir-commits mailing list