[Mlir-commits] [mlir] [mlir][vector] Add vector.to_elements unrolling (PR #157142)

Erick Ochoa Lopez llvmlistbot at llvm.org
Thu Sep 11 10:40:55 PDT 2025


https://github.com/amd-eochoalo updated https://github.com/llvm/llvm-project/pull/157142

>From d661413b787a15a91f772e2570333aad95166a68 Mon Sep 17 00:00:00 2001
From: hanhanW <hanhan0912 at gmail.com>
Date: Thu, 4 Sep 2025 17:19:24 -0700
Subject: [PATCH 01/21] [mlir][vector] Add support for lowering n-D
 vector.to_elements op.

The revision adds a pattern that flattens `vector.to_elements` ops of
rank 2 or higher into a `vector.shape_cast` followed by a 1-D
`vector.to_elements`.

It also adds the lowering pattern to ConvertVectorToLLVMPass and
completes the tests.

It fixes the end-to-end lowering breakage on the LLVM path that was introduced by https://github.com/llvm/llvm-project/commit/b4c31dc98dfc929728904cd96f0f4cf812c4d5b5.

Signed-off-by: hanhanW <hanhan0912 at gmail.com>
---
 .../Vector/Transforms/LoweringPatterns.h      |  6 +++
 .../VectorToLLVM/ConvertVectorToLLVMPass.cpp  |  1 +
 .../Dialect/Vector/Transforms/CMakeLists.txt  |  1 +
 .../Transforms/LowerVectorToElements.cpp      | 52 +++++++++++++++++++
 .../VectorToLLVM/vector-to-llvm.mlir          | 40 ++++++++++++++
 .../Vector/vector-to-elements-lowering.mlir   | 22 ++++++++
 .../Dialect/Vector/TestVectorTransforms.cpp   | 24 +++++++++
 7 files changed, 146 insertions(+)
 create mode 100644 mlir/lib/Dialect/Vector/Transforms/LowerVectorToElements.cpp
 create mode 100644 mlir/test/Dialect/Vector/vector-to-elements-lowering.mlir

diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
index 47f96112a9433..e0f744841db2b 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
@@ -311,6 +311,12 @@ void populateVectorToFromElementsToShuffleTreePatterns(
 void populateVectorFromElementsLoweringPatterns(RewritePatternSet &patterns,
                                                 PatternBenefit benefit = 1);
 
+/// Populate the pattern set with the following patterns:
+///
+/// [FlattenToElements]
+void populateVectorToElementsLoweringPatterns(RewritePatternSet &patterns,
+                                              PatternBenefit benefit = 1);
+
 /// Populate the pattern set with the following patterns:
 ///
 /// [ContractionOpToMatmulOpLowering]
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
index 9852df6970fdc..0b44ca7ceee42 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
@@ -95,6 +95,7 @@ void ConvertVectorToLLVMPass::runOnOperation() {
     populateVectorRankReducingFMAPattern(patterns);
     populateVectorGatherLoweringPatterns(patterns);
     populateVectorFromElementsLoweringPatterns(patterns);
+    populateVectorToElementsLoweringPatterns(patterns);
     if (armI8MM) {
       if (armNeon)
         arm_neon::populateLowerContractionToNeonI8MMPatterns(patterns);
diff --git a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
index acbf2b746037b..d74007f13a95b 100644
--- a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
@@ -11,6 +11,7 @@ add_mlir_dialect_library(MLIRVectorTransforms
   LowerVectorScan.cpp
   LowerVectorShapeCast.cpp
   LowerVectorStep.cpp
+  LowerVectorToElements.cpp
   LowerVectorToFromElementsToShuffleTree.cpp
   LowerVectorTransfer.cpp
   LowerVectorTranspose.cpp
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorToElements.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorToElements.cpp
new file mode 100644
index 0000000000000..014034b8f9737
--- /dev/null
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorToElements.cpp
@@ -0,0 +1,52 @@
+//===- LowerVectorToElements.cpp - Lower 'vector.to_elements' op ----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements target-independent rewrites and utilities to lower the
+// 'vector.to_elements' operation.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
+
+#define DEBUG_TYPE "lower-vector-to-elements"
+
+using namespace mlir;
+
+namespace {
+
+/// Flattens 2 or more dimensional `vector.to_elements` ops by
+/// `vector.shape_cast` + `vector.to_elements`.
+struct FlattenToElements : OpRewritePattern<vector::ToElementsOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::ToElementsOp op,
+                                PatternRewriter &rewriter) const override {
+    VectorType vecType = op.getSource().getType();
+    if (vecType.getRank() <= 1)
+      return rewriter.notifyMatchFailure(
+          op, "the rank is already less than or equal to 1");
+    if (vecType.getNumScalableDims() > 0)
+      return rewriter.notifyMatchFailure(
+          op, "scalable vector is not yet supported");
+    auto vec1DType =
+        VectorType::get({vecType.getNumElements()}, vecType.getElementType());
+    Value shapeCast = vector::ShapeCastOp::create(rewriter, op.getLoc(),
+                                                  vec1DType, op.getSource());
+    rewriter.replaceOpWithNewOp<vector::ToElementsOp>(op, op.getResultTypes(),
+                                                      shapeCast);
+    return success();
+  }
+};
+
+} // namespace
+
+void mlir::vector::populateVectorToElementsLoweringPatterns(
+    RewritePatternSet &patterns, PatternBenefit benefit) {
+  patterns.add<FlattenToElements>(patterns.getContext(), benefit);
+}
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index 07d335117de01..bf4b05f7874de 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -1774,3 +1774,43 @@ func.func @from_elements_3d(%arg0: f32, %arg1: f32, %arg2: f32, %arg3: f32) -> v
   %0 = vector.from_elements %arg0, %arg1, %arg2, %arg3 : vector<2x1x2xf32>
   return %0 : vector<2x1x2xf32>
 }
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// vector.to_elements
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: func @to_elements_1d(
+// CHECK-SAME:    %[[ARG0:.+]]: vector<2xf32>
+// CHECK:         %[[C0:.+]] = llvm.mlir.constant(0 : i64) : i64
+// CHECK:         %[[V0:.+]] = llvm.extractelement %[[ARG0]][%[[C0]] : i64] : vector<2xf32>
+// CHECK:         %[[C1:.+]] = llvm.mlir.constant(1 : i64) : i64
+// CHECK:         %[[V1:.+]] = llvm.extractelement %[[ARG0]][%[[C1]] : i64] : vector<2xf32>
+// CHECK:         return %[[V0]], %[[V1]]
+func.func @to_elements_1d(%arg0: vector<2xf32>) -> (f32, f32) {
+  %0:2 = vector.to_elements %arg0 : vector<2xf32>
+  return %0#0, %0#1 : f32, f32
+}
+
+// -----
+
+// NOTE: We flatten multi-dimensional to_elements ops with pattern
+// `FlattenToElements` and then convert the 1-D to_elements ops to llvm.
+
+// CHECK-LABEL: func @to_elements_2d(
+// CHECK-SAME:    %[[ARG0:.+]]: vector<2x2xf32>
+// CHECK:         %[[CAST:.+]] = builtin.unrealized_conversion_cast %[[ARG0]] : vector<2x2xf32> to !llvm.array<2 x vector<2xf32>>
+// CHECK:         %[[C0:.+]] = llvm.mlir.constant(0 : i64) : i64
+// CHECK:         %[[V0:.+]] = llvm.extractelement %{{.+}}[%[[C0]] : i64] : vector<4xf32>
+// CHECK:         %[[C1:.+]] = llvm.mlir.constant(1 : i64) : i64
+// CHECK:         %[[V1:.+]] = llvm.extractelement %{{.+}}[%[[C1]] : i64] : vector<4xf32>
+// CHECK:         %[[C2:.+]] = llvm.mlir.constant(2 : i64) : i64
+// CHECK:         %[[V2:.+]] = llvm.extractelement %{{.+}}[%[[C2]] : i64] : vector<4xf32>
+// CHECK:         %[[C3:.+]] = llvm.mlir.constant(3 : i64) : i64
+// CHECK:         %[[V3:.+]] = llvm.extractelement %{{.+}}[%[[C3]] : i64] : vector<4xf32>
+// CHECK:         return %[[V0]], %[[V1]], %[[V2]], %[[V3]]
+func.func @to_elements_2d(%arg0: vector<2x2xf32>) -> (f32, f32, f32, f32) {
+  %0:4 = vector.to_elements %arg0 : vector<2x2xf32>
+  return %0#0, %0#1, %0#2, %0#3 : f32, f32, f32, f32
+}
diff --git a/mlir/test/Dialect/Vector/vector-to-elements-lowering.mlir b/mlir/test/Dialect/Vector/vector-to-elements-lowering.mlir
new file mode 100644
index 0000000000000..a57521c4db467
--- /dev/null
+++ b/mlir/test/Dialect/Vector/vector-to-elements-lowering.mlir
@@ -0,0 +1,22 @@
+// RUN: mlir-opt %s -test-flatten-vector-to-elements -split-input-file | FileCheck %s
+
+// CHECK-LABEL: func.func @to_elements_1d(
+// CHECK-SAME:    %[[ARG0:.+]]: vector<2xf32>
+// CHECK:         %[[RES:.+]]:2 = vector.to_elements %[[ARG0]] : vector<2xf32>
+// CHECK:         return %[[RES]]#0, %[[RES]]#1
+func.func @to_elements_1d(%arg0: vector<2xf32>) -> (f32, f32) {
+  %0:2 = vector.to_elements %arg0 : vector<2xf32>
+  return %0#0, %0#1 : f32, f32
+}
+
+// -----
+
+// CHECK-LABEL: func.func @to_elements_2d(
+// CHECK-SAME:    %[[ARG0:.+]]: vector<2x2xf32>
+// CHECK:         %[[CAST:.+]] = vector.shape_cast %[[ARG0]]
+// CHECK:         %[[RES:.+]]:4 = vector.to_elements %[[CAST]] : vector<4xf32>
+// CHECK:         return %[[RES]]#0, %[[RES]]#1, %[[RES]]#2, %[[RES]]#3
+func.func @to_elements_2d(%arg0: vector<2x2xf32>) -> (f32, f32, f32, f32) {
+  %0:4 = vector.to_elements %arg0 : vector<2x2xf32>
+  return %0#0, %0#1, %0#2, %0#3 : f32, f32, f32, f32
+}
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index bb1598ee3efe5..560a1331bdaf0 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -808,6 +808,28 @@ struct TestUnrollVectorFromElements
   }
 };
 
+struct TestFlattenVectorToElements
+    : public PassWrapper<TestFlattenVectorToElements,
+                         OperationPass<func::FuncOp>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestFlattenVectorToElements)
+
+  StringRef getArgument() const final {
+    return "test-flatten-vector-to-elements";
+  }
+  StringRef getDescription() const final {
+    return "Test flattening patterns for to_elements ops";
+  }
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<func::FuncDialect, vector::VectorDialect>();
+  }
+
+  void runOnOperation() override {
+    RewritePatternSet patterns(&getContext());
+    populateVectorToElementsLoweringPatterns(patterns);
+    (void)applyPatternsGreedily(getOperation(), std::move(patterns));
+  }
+};
+
 struct TestFoldArithExtensionIntoVectorContractPatterns
     : public PassWrapper<TestFoldArithExtensionIntoVectorContractPatterns,
                          OperationPass<func::FuncOp>> {
@@ -1083,6 +1105,8 @@ void registerTestVectorLowerings() {
 
   PassRegistration<TestUnrollVectorFromElements>();
 
+  PassRegistration<TestFlattenVectorToElements>();
+
   PassRegistration<TestFoldArithExtensionIntoVectorContractPatterns>();
 
   PassRegistration<TestVectorEmulateMaskedLoadStore>();

>From edf4019862d3448a31bd2d4052fc9d11259a7e37 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Fri, 5 Sep 2025 09:48:27 -0700
Subject: [PATCH 02/21] Add new populate patterns for flattening.

---
 .../mlir/Dialect/Vector/Transforms/LoweringPatterns.h       | 6 ++++++
 .../lib/Dialect/Vector/Transforms/LowerVectorToElements.cpp | 5 +++++
 mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp       | 2 +-
 3 files changed, 12 insertions(+), 1 deletion(-)

diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
index e0f744841db2b..c39c9d4ae00c9 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
@@ -317,6 +317,12 @@ void populateVectorFromElementsLoweringPatterns(RewritePatternSet &patterns,
 void populateVectorToElementsLoweringPatterns(RewritePatternSet &patterns,
                                               PatternBenefit benefit = 1);
 
+/// Populate the pattern set with the following patterns:
+///
+/// [FlattenToElements]
+void populateVectorToElementsFlatteningPatterns(RewritePatternSet &patterns,
+                                                PatternBenefit benefit = 1);
+
 /// Populate the pattern set with the following patterns:
 ///
 /// [ContractionOpToMatmulOpLowering]
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorToElements.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorToElements.cpp
index 014034b8f9737..33c5d2cb33369 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorToElements.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorToElements.cpp
@@ -50,3 +50,8 @@ void mlir::vector::populateVectorToElementsLoweringPatterns(
     RewritePatternSet &patterns, PatternBenefit benefit) {
   patterns.add<FlattenToElements>(patterns.getContext(), benefit);
 }
+
+void mlir::vector::populateVectorToElementsFlatteningPatterns(
+    RewritePatternSet &patterns, PatternBenefit benefit) {
+  patterns.add<FlattenToElements>(patterns.getContext(), benefit);
+}
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index 560a1331bdaf0..01a00509c7331 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -825,7 +825,7 @@ struct TestFlattenVectorToElements
 
   void runOnOperation() override {
     RewritePatternSet patterns(&getContext());
-    populateVectorToElementsLoweringPatterns(patterns);
+    populateVectorToElementsFlatteningPatterns(patterns);
     (void)applyPatternsGreedily(getOperation(), std::move(patterns));
   }
 };

>From 12fd6fc77b44f25020aee2b0193a02117d2fc1e1 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Fri, 5 Sep 2025 08:49:00 -0700
Subject: [PATCH 03/21] [mlir][vector] Add function to unroll vectors.

Extract n vector<axbx...> from vector<nxaxbx...>

This patch adds a utility function that unrolls a vector value along
its leading dimension. This differs from the existing utility function
(`unrollVectorOp`), which unrolls vector operations.
---
 .../mlir/Dialect/Vector/Utils/VectorUtils.h   |  3 +++
 mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp | 21 +++++++++++++++++++
 2 files changed, 24 insertions(+)

diff --git a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
index ace26990601c8..95f2ee5a7ac1d 100644
--- a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
+++ b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
@@ -255,6 +255,9 @@ using UnrollVectorOpFn =
 LogicalResult unrollVectorOp(Operation *op, PatternRewriter &rewriter,
                              UnrollVectorOpFn unrollFn);
 
+LogicalResult unrollVectorValue(Value vector, PatternRewriter &rewriter,
+                                SmallVector<Value> &values);
+
 } // namespace vector
 
 /// Constructs a permutation map of invariant memref indices to vector
diff --git a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
index 841e1384e03b3..cbedd9563fc29 100644
--- a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
+++ b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
@@ -393,6 +393,27 @@ vector::isValidMaskedInputVector(ArrayRef<int64_t> shape,
   return success();
 }
 
+LogicalResult vector::unrollVectorValue(Value vector, PatternRewriter &rewriter,
+                                        SmallVector<Value> &subvectors) {
+  assert(isa<VectorType>(vector.getType()) && "expected vector type");
+  VectorType ty = cast<VectorType>(vector.getType());
+  Location loc = vector.getLoc();
+  if (ty.getRank() < 2)
+    return rewriter.notifyMatchFailure(loc, "already 1-D");
+
+  // Unrolling doesn't take vscale into account. Pattern is disabled for
+  // vectors with leading scalable dim(s).
+  if (ty.getScalableDims().front())
+    return rewriter.notifyMatchFailure(loc, "cannot unroll scalable dim");
+
+  // We just need zero indices for the all dimensions except the leading one.
+  for (int64_t i = 0, e = ty.getShape().front(); i < e; ++i) {
+    subvectors.push_back(vector::ExtractOp::create(rewriter, loc, vector, i));
+  }
+
+  return success();
+}
+
 LogicalResult vector::unrollVectorOp(Operation *op, PatternRewriter &rewriter,
                                      vector::UnrollVectorOpFn unrollFn) {
   assert(op->getNumResults() == 1 && "expected single result");

>From 0415e3b066c258e2b12aba4e074f62c07b1ee3d1 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Fri, 5 Sep 2025 09:39:24 -0700
Subject: [PATCH 04/21] [mlir][vector] Add vector.to_elements unrolling.

---
 .../Vector/Transforms/LoweringPatterns.h      |  2 +-
 .../Transforms/LowerVectorToElements.cpp      | 28 ++++++++++++++++++-
 2 files changed, 28 insertions(+), 2 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
index c39c9d4ae00c9..31150a2afc19f 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
@@ -313,7 +313,7 @@ void populateVectorFromElementsLoweringPatterns(RewritePatternSet &patterns,
 
 /// Populate the pattern set with the following patterns:
 ///
-/// [FlattenToElements]
+/// [UnrollToElements]
 void populateVectorToElementsLoweringPatterns(RewritePatternSet &patterns,
                                               PatternBenefit benefit = 1);
 
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorToElements.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorToElements.cpp
index 33c5d2cb33369..b86e8b274770f 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorToElements.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorToElements.cpp
@@ -20,6 +20,32 @@ using namespace mlir;
 
 namespace {
 
+struct UnrollToElements : OpRewritePattern<vector::ToElementsOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::ToElementsOp op,
+                                PatternRewriter &rewriter) const override {
+    SmallVector<Value> vectors;
+    LogicalResult match =
+        mlir::vector::unrollVectorValue(op.getSource(), rewriter, vectors);
+    if (failed(match)) {
+      return match;
+    }
+
+    // May be large vector.
+    std::vector<Value> results;
+    for (const auto &vector : vectors) {
+      // we need to replace the current result
+      auto subElements =
+          rewriter.create<vector::ToElementsOp>(op.getLoc(), vector);
+      results.insert(results.end(), subElements.getResults().begin(),
+                     subElements.getResults().end());
+    }
+    rewriter.replaceOp(op, results);
+    return success();
+  }
+};
+
 /// Flattens 2 or more dimensional `vector.to_elements` ops by
 /// `vector.shape_cast` + `vector.to_elements`.
 struct FlattenToElements : OpRewritePattern<vector::ToElementsOp> {
@@ -48,7 +74,7 @@ struct FlattenToElements : OpRewritePattern<vector::ToElementsOp> {
 
 void mlir::vector::populateVectorToElementsLoweringPatterns(
     RewritePatternSet &patterns, PatternBenefit benefit) {
-  patterns.add<FlattenToElements>(patterns.getContext(), benefit);
+  patterns.add<UnrollToElements>(patterns.getContext(), benefit);
 }
 
 void mlir::vector::populateVectorToElementsFlatteningPatterns(

>From 1b21117c35c588486aa737bdc7f5a37f5c52c1fb Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Fri, 5 Sep 2025 10:02:39 -0700
Subject: [PATCH 05/21] Split tests for unrolling and flattening to elements

---
 .../Vector/vector-to-elements-flattening.mlir | 22 +++++++++++++++++
 .../Vector/vector-to-elements-lowering.mlir   | 10 ++++----
 .../Dialect/Vector/TestVectorTransforms.cpp   | 24 +++++++++++++++++++
 3 files changed, 52 insertions(+), 4 deletions(-)
 create mode 100644 mlir/test/Dialect/Vector/vector-to-elements-flattening.mlir

diff --git a/mlir/test/Dialect/Vector/vector-to-elements-flattening.mlir b/mlir/test/Dialect/Vector/vector-to-elements-flattening.mlir
new file mode 100644
index 0000000000000..a57521c4db467
--- /dev/null
+++ b/mlir/test/Dialect/Vector/vector-to-elements-flattening.mlir
@@ -0,0 +1,22 @@
+// RUN: mlir-opt %s -test-flatten-vector-to-elements -split-input-file | FileCheck %s
+
+// CHECK-LABEL: func.func @to_elements_1d(
+// CHECK-SAME:    %[[ARG0:.+]]: vector<2xf32>
+// CHECK:         %[[RES:.+]]:2 = vector.to_elements %[[ARG0]] : vector<2xf32>
+// CHECK:         return %[[RES]]#0, %[[RES]]#1
+func.func @to_elements_1d(%arg0: vector<2xf32>) -> (f32, f32) {
+  %0:2 = vector.to_elements %arg0 : vector<2xf32>
+  return %0#0, %0#1 : f32, f32
+}
+
+// -----
+
+// CHECK-LABEL: func.func @to_elements_2d(
+// CHECK-SAME:    %[[ARG0:.+]]: vector<2x2xf32>
+// CHECK:         %[[CAST:.+]] = vector.shape_cast %[[ARG0]]
+// CHECK:         %[[RES:.+]]:4 = vector.to_elements %[[CAST]] : vector<4xf32>
+// CHECK:         return %[[RES]]#0, %[[RES]]#1, %[[RES]]#2, %[[RES]]#3
+func.func @to_elements_2d(%arg0: vector<2x2xf32>) -> (f32, f32, f32, f32) {
+  %0:4 = vector.to_elements %arg0 : vector<2x2xf32>
+  return %0#0, %0#1, %0#2, %0#3 : f32, f32, f32, f32
+}
diff --git a/mlir/test/Dialect/Vector/vector-to-elements-lowering.mlir b/mlir/test/Dialect/Vector/vector-to-elements-lowering.mlir
index a57521c4db467..e302dbd174322 100644
--- a/mlir/test/Dialect/Vector/vector-to-elements-lowering.mlir
+++ b/mlir/test/Dialect/Vector/vector-to-elements-lowering.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -test-flatten-vector-to-elements -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -test-unroll-vector-to-elements -split-input-file | FileCheck %s
 
 // CHECK-LABEL: func.func @to_elements_1d(
 // CHECK-SAME:    %[[ARG0:.+]]: vector<2xf32>
@@ -13,9 +13,11 @@ func.func @to_elements_1d(%arg0: vector<2xf32>) -> (f32, f32) {
 
 // CHECK-LABEL: func.func @to_elements_2d(
 // CHECK-SAME:    %[[ARG0:.+]]: vector<2x2xf32>
-// CHECK:         %[[CAST:.+]] = vector.shape_cast %[[ARG0]]
-// CHECK:         %[[RES:.+]]:4 = vector.to_elements %[[CAST]] : vector<4xf32>
-// CHECK:         return %[[RES]]#0, %[[RES]]#1, %[[RES]]#2, %[[RES]]#3
+// CHECK:         %[[VEC0:.+]] = vector.extract %[[ARG0]][0] : vector<2xf32> from vector<2x2xf32>
+// CHECK:         %[[VEC1:.+]] = vector.extract %[[ARG0]][1] : vector<2xf32> from vector<2x2xf32>
+// CHECK:         %[[RES0:.+]]:2 = vector.to_elements %[[VEC0]] : vector<2xf32>
+// CHECK:         %[[RES1:.+]]:2 = vector.to_elements %[[VEC1]] : vector<2xf32>
+// CHECK:         return %[[RES0]]#0, %[[RES0]]#1, %[[RES1]]#0, %[[RES1]]#1
 func.func @to_elements_2d(%arg0: vector<2x2xf32>) -> (f32, f32, f32, f32) {
   %0:4 = vector.to_elements %arg0 : vector<2x2xf32>
   return %0#0, %0#1, %0#2, %0#3 : f32, f32, f32, f32
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index 01a00509c7331..093134c119cea 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -830,6 +830,28 @@ struct TestFlattenVectorToElements
   }
 };
 
+struct TestUnrollVectorToElements
+    : public PassWrapper<TestUnrollVectorToElements,
+                         OperationPass<func::FuncOp>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestUnrollVectorToElements)
+
+  StringRef getArgument() const final {
+    return "test-unroll-vector-to-elements";
+  }
+  StringRef getDescription() const final {
+    return "Test unrolling patterns for to_elements ops";
+  }
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<func::FuncDialect, vector::VectorDialect>();
+  }
+
+  void runOnOperation() override {
+    RewritePatternSet patterns(&getContext());
+    populateVectorToElementsLoweringPatterns(patterns);
+    (void)applyPatternsGreedily(getOperation(), std::move(patterns));
+  }
+};
+
 struct TestFoldArithExtensionIntoVectorContractPatterns
     : public PassWrapper<TestFoldArithExtensionIntoVectorContractPatterns,
                          OperationPass<func::FuncOp>> {
@@ -1105,6 +1127,8 @@ void registerTestVectorLowerings() {
 
   PassRegistration<TestUnrollVectorFromElements>();
 
+  PassRegistration<TestUnrollVectorToElements>();
+
   PassRegistration<TestFlattenVectorToElements>();
 
   PassRegistration<TestFoldArithExtensionIntoVectorContractPatterns>();

>From 567af4bde157e7c266dcf0df5be645d370a90086 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Fri, 5 Sep 2025 10:09:47 -0700
Subject: [PATCH 06/21] Fix test

---
 .../VectorToLLVM/vector-to-llvm.mlir          | 20 ++++++++++---------
 1 file changed, 11 insertions(+), 9 deletions(-)

diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index bf4b05f7874de..2d33888854ea7 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -1795,21 +1795,23 @@ func.func @to_elements_1d(%arg0: vector<2xf32>) -> (f32, f32) {
 
 // -----
 
-// NOTE: We flatten multi-dimensional to_elements ops with pattern
-// `FlattenToElements` and then convert the 1-D to_elements ops to llvm.
+// NOTE: We unroll multi-dimensional to_elements ops with pattern
+// `UnrollToElements` and then convert the 1-D to_elements ops to llvm.
 
 // CHECK-LABEL: func @to_elements_2d(
 // CHECK-SAME:    %[[ARG0:.+]]: vector<2x2xf32>
 // CHECK:         %[[CAST:.+]] = builtin.unrealized_conversion_cast %[[ARG0]] : vector<2x2xf32> to !llvm.array<2 x vector<2xf32>>
+// CHECK:         %[[V0:.+]] = llvm.extractvalue %[[CAST]][0] : !llvm.array<2 x vector<2xf32>>
+// CHECK:         %[[V1:.+]] = llvm.extractvalue %[[CAST]][1] : !llvm.array<2 x vector<2xf32>>
 // CHECK:         %[[C0:.+]] = llvm.mlir.constant(0 : i64) : i64
-// CHECK:         %[[V0:.+]] = llvm.extractelement %{{.+}}[%[[C0]] : i64] : vector<4xf32>
+// CHECK:         %[[R0:.+]] = llvm.extractelement %[[V0]][%[[C0]] : i64] : vector<2xf32>
 // CHECK:         %[[C1:.+]] = llvm.mlir.constant(1 : i64) : i64
-// CHECK:         %[[V1:.+]] = llvm.extractelement %{{.+}}[%[[C1]] : i64] : vector<4xf32>
-// CHECK:         %[[C2:.+]] = llvm.mlir.constant(2 : i64) : i64
-// CHECK:         %[[V2:.+]] = llvm.extractelement %{{.+}}[%[[C2]] : i64] : vector<4xf32>
-// CHECK:         %[[C3:.+]] = llvm.mlir.constant(3 : i64) : i64
-// CHECK:         %[[V3:.+]] = llvm.extractelement %{{.+}}[%[[C3]] : i64] : vector<4xf32>
-// CHECK:         return %[[V0]], %[[V1]], %[[V2]], %[[V3]]
+// CHECK:         %[[R1:.+]] = llvm.extractelement %[[V0]][%[[C1]] : i64] : vector<2xf32>
+// CHECK:         %[[C0:.+]] = llvm.mlir.constant(0 : i64) : i64
+// CHECK:         %[[R2:.+]] = llvm.extractelement %[[V1]][%[[C0]] : i64] : vector<2xf32>
+// CHECK:         %[[C1:.+]] = llvm.mlir.constant(1 : i64) : i64
+// CHECK:         %[[R3:.+]] = llvm.extractelement %[[V1]][%[[C1]] : i64] : vector<2xf32>
+// CHECK:         return %[[R0]], %[[R1]], %[[R2]], %[[R3]]
 func.func @to_elements_2d(%arg0: vector<2x2xf32>) -> (f32, f32, f32, f32) {
   %0:4 = vector.to_elements %arg0 : vector<2x2xf32>
   return %0#0, %0#1, %0#2, %0#3 : f32, f32, f32, f32

>From c58b5c41bcb66ea43547f87ad811bb57fa0b152e Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Fri, 5 Sep 2025 12:25:12 -0700
Subject: [PATCH 07/21] Address review comments

---
 .../Transforms/LowerVectorToElements.cpp      | 25 +++++++++----------
 mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp |  2 --
 2 files changed, 12 insertions(+), 15 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorToElements.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorToElements.cpp
index b86e8b274770f..b897b15d7d690 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorToElements.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorToElements.cpp
@@ -20,26 +20,25 @@ using namespace mlir;
 
 namespace {
 
-struct UnrollToElements : OpRewritePattern<vector::ToElementsOp> {
+struct UnrollToElements final : OpRewritePattern<vector::ToElementsOp> {
   using OpRewritePattern::OpRewritePattern;
 
   LogicalResult matchAndRewrite(vector::ToElementsOp op,
                                 PatternRewriter &rewriter) const override {
     SmallVector<Value> vectors;
-    LogicalResult match =
-        mlir::vector::unrollVectorValue(op.getSource(), rewriter, vectors);
-    if (failed(match)) {
+    if (LogicalResult match =
+            vector::unrollVectorValue(op.getSource(), rewriter, vectors);
+        failed(match)) {
       return match;
     }
 
-    // May be large vector.
-    std::vector<Value> results;
-    for (const auto &vector : vectors) {
+    // May be a large vector.
+    SmallVector<Value, 0> results;
+    for (const Value &vector : vectors) {
       // we need to replace the current result
       auto subElements =
           rewriter.create<vector::ToElementsOp>(op.getLoc(), vector);
-      results.insert(results.end(), subElements.getResults().begin(),
-                     subElements.getResults().end());
+      llvm::append_range(results, subElements.getResults());
     }
     rewriter.replaceOp(op, results);
     return success();
@@ -48,7 +47,7 @@ struct UnrollToElements : OpRewritePattern<vector::ToElementsOp> {
 
 /// Flattens 2 or more dimensional `vector.to_elements` ops by
 /// `vector.shape_cast` + `vector.to_elements`.
-struct FlattenToElements : OpRewritePattern<vector::ToElementsOp> {
+struct FlattenToElements final : OpRewritePattern<vector::ToElementsOp> {
   using OpRewritePattern::OpRewritePattern;
 
   LogicalResult matchAndRewrite(vector::ToElementsOp op,
@@ -57,9 +56,9 @@ struct FlattenToElements : OpRewritePattern<vector::ToElementsOp> {
     if (vecType.getRank() <= 1)
       return rewriter.notifyMatchFailure(
           op, "the rank is already less than or equal to 1");
-    if (vecType.getNumScalableDims() > 0)
-      return rewriter.notifyMatchFailure(
-          op, "scalable vector is not yet supported");
+
+    assert(vecType.getNumScalableDims() == 0 &&
+           "scalable vector is not yet supported");
     auto vec1DType =
         VectorType::get({vecType.getNumElements()}, vecType.getElementType());
     Value shapeCast = vector::ShapeCastOp::create(rewriter, op.getLoc(),
diff --git a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
index cbedd9563fc29..8d67cd7e80382 100644
--- a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
+++ b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
@@ -395,7 +395,6 @@ vector::isValidMaskedInputVector(ArrayRef<int64_t> shape,
 
 LogicalResult vector::unrollVectorValue(Value vector, PatternRewriter &rewriter,
                                         SmallVector<Value> &subvectors) {
-  assert(isa<VectorType>(vector.getType()) && "expected vector type");
   VectorType ty = cast<VectorType>(vector.getType());
   Location loc = vector.getLoc();
   if (ty.getRank() < 2)
@@ -406,7 +405,6 @@ LogicalResult vector::unrollVectorValue(Value vector, PatternRewriter &rewriter,
   if (ty.getScalableDims().front())
     return rewriter.notifyMatchFailure(loc, "cannot unroll scalable dim");
 
-  // We just need zero indices for the all dimensions except the leading one.
   for (int64_t i = 0, e = ty.getShape().front(); i < e; ++i) {
     subvectors.push_back(vector::ExtractOp::create(rewriter, loc, vector, i));
   }

>From 871f5a5873a765a89983a8d967d458faed6b3f2e Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Fri, 5 Sep 2025 12:26:52 -0700
Subject: [PATCH 08/21] Remove unnecessary comment

---
 mlir/lib/Dialect/Vector/Transforms/LowerVectorToElements.cpp | 1 -
 1 file changed, 1 deletion(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorToElements.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorToElements.cpp
index b897b15d7d690..43ba1fe885e85 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorToElements.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorToElements.cpp
@@ -35,7 +35,6 @@ struct UnrollToElements final : OpRewritePattern<vector::ToElementsOp> {
     // May be a large vector.
     SmallVector<Value, 0> results;
     for (const Value &vector : vectors) {
-      // we need to replace the current result
       auto subElements =
           rewriter.create<vector::ToElementsOp>(op.getLoc(), vector);
       llvm::append_range(results, subElements.getResults());

>From 244eed587f232ae8c67af133887b6e6f3e8e351d Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Fri, 5 Sep 2025 13:00:39 -0700
Subject: [PATCH 09/21] Address review comments

---
 .../lib/Dialect/Vector/Transforms/LowerVectorToElements.cpp | 6 ++----
 1 file changed, 2 insertions(+), 4 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorToElements.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorToElements.cpp
index 43ba1fe885e85..e1bb07287ae9b 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorToElements.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorToElements.cpp
@@ -26,10 +26,8 @@ struct UnrollToElements final : OpRewritePattern<vector::ToElementsOp> {
   LogicalResult matchAndRewrite(vector::ToElementsOp op,
                                 PatternRewriter &rewriter) const override {
     SmallVector<Value> vectors;
-    if (LogicalResult match =
-            vector::unrollVectorValue(op.getSource(), rewriter, vectors);
-        failed(match)) {
-      return match;
+    if (failed(vector::unrollVectorValue(op.getSource(), rewriter, vectors))) {
+      return failure();
     }
 
     // May be a large vector.

>From 9933387796fba4c5155b718df9f1f0d5e20d47e4 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Tue, 9 Sep 2025 12:40:49 -0700
Subject: [PATCH 10/21] Add comment

---
 mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h | 7 +++++++
 1 file changed, 7 insertions(+)

diff --git a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
index 95f2ee5a7ac1d..985f90f6ed955 100644
--- a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
+++ b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
@@ -255,6 +255,13 @@ using UnrollVectorOpFn =
 LogicalResult unrollVectorOp(Operation *op, PatternRewriter &rewriter,
                              UnrollVectorOpFn unrollFn);
 
+/// Generic utility for mapping values of type vector<nxaxbx...>
+/// to n values of type vector<axbx...>
+/// Follows the following pattern:
+/// 1. Check if already 1-D. If so, return failure.
+/// 2. Check for scalable dimensions. If so, return failure.
+/// 3. Returns the values of n vector.extract operations corresponding
+///    to the outermost dimension.
 LogicalResult unrollVectorValue(Value vector, PatternRewriter &rewriter,
                                 SmallVector<Value> &values);
 

>From 351086b8983a77343a4716e5b1f237d07ec3e488 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Tue, 9 Sep 2025 12:18:09 -0700
Subject: [PATCH 11/21] Disable and move vector flattening

---
 .../Vector/Transforms/LoweringPatterns.h      |  6 ----
 .../Transforms/LowerVectorToElements.cpp      | 29 -------------------
 .../Vector/Transforms/VectorLinearize.cpp     | 24 +++++++++++++++
 .../Vector/vector-to-elements-flattening.mlir | 22 --------------
 .../Dialect/Vector/TestVectorTransforms.cpp   | 24 ---------------
 5 files changed, 24 insertions(+), 81 deletions(-)
 delete mode 100644 mlir/test/Dialect/Vector/vector-to-elements-flattening.mlir

diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
index 31150a2afc19f..f56124cb4fb95 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
@@ -317,12 +317,6 @@ void populateVectorFromElementsLoweringPatterns(RewritePatternSet &patterns,
 void populateVectorToElementsLoweringPatterns(RewritePatternSet &patterns,
                                               PatternBenefit benefit = 1);
 
-/// Populate the pattern set with the following patterns:
-///
-/// [FlattenToElements]
-void populateVectorToElementsFlatteningPatterns(RewritePatternSet &patterns,
-                                                PatternBenefit benefit = 1);
-
 /// Populate the pattern set with the following patterns:
 ///
 /// [ContractionOpToMatmulOpLowering]
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorToElements.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorToElements.cpp
index e1bb07287ae9b..718b947f13715 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorToElements.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorToElements.cpp
@@ -42,38 +42,9 @@ struct UnrollToElements final : OpRewritePattern<vector::ToElementsOp> {
   }
 };
 
-/// Flattens 2 or more dimensional `vector.to_elements` ops by
-/// `vector.shape_cast` + `vector.to_elements`.
-struct FlattenToElements final : OpRewritePattern<vector::ToElementsOp> {
-  using OpRewritePattern::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(vector::ToElementsOp op,
-                                PatternRewriter &rewriter) const override {
-    VectorType vecType = op.getSource().getType();
-    if (vecType.getRank() <= 1)
-      return rewriter.notifyMatchFailure(
-          op, "the rank is already less than or equal to 1");
-
-    assert(vecType.getNumScalableDims() == 0 &&
-           "scalable vector is not yet supported");
-    auto vec1DType =
-        VectorType::get({vecType.getNumElements()}, vecType.getElementType());
-    Value shapeCast = vector::ShapeCastOp::create(rewriter, op.getLoc(),
-                                                  vec1DType, op.getSource());
-    rewriter.replaceOpWithNewOp<vector::ToElementsOp>(op, op.getResultTypes(),
-                                                      shapeCast);
-    return success();
-  }
-};
-
 } // namespace
 
 void mlir::vector::populateVectorToElementsLoweringPatterns(
     RewritePatternSet &patterns, PatternBenefit benefit) {
   patterns.add<UnrollToElements>(patterns.getContext(), benefit);
 }
-
-void mlir::vector::populateVectorToElementsFlatteningPatterns(
-    RewritePatternSet &patterns, PatternBenefit benefit) {
-  patterns.add<FlattenToElements>(patterns.getContext(), benefit);
-}
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index 7dde6311fa809..0f5e9259d4c19 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -798,6 +798,30 @@ struct LinearizeVectorFromElements final
   }
 };
 
+/// Flattens 2 or more dimensional `vector.to_elements` ops by
+/// `vector.shape_cast` + `vector.to_elements`.
+struct FlattenToElements final : OpRewritePattern<vector::ToElementsOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::ToElementsOp op,
+                                PatternRewriter &rewriter) const override {
+    VectorType vecType = op.getSource().getType();
+    if (vecType.getRank() <= 1)
+      return rewriter.notifyMatchFailure(
+          op, "the rank is already less than or equal to 1");
+
+    assert(vecType.getNumScalableDims() == 0 &&
+           "scalable vector is not yet supported");
+    auto vec1DType =
+        VectorType::get({vecType.getNumElements()}, vecType.getElementType());
+    Value shapeCast = vector::ShapeCastOp::create(rewriter, op.getLoc(),
+                                                  vec1DType, op.getSource());
+    rewriter.replaceOpWithNewOp<vector::ToElementsOp>(op, op.getResultTypes(),
+                                                      shapeCast);
+    return success();
+  }
+};
+
 } // namespace
 
 /// This method defines the set of operations that are linearizable, and hence
diff --git a/mlir/test/Dialect/Vector/vector-to-elements-flattening.mlir b/mlir/test/Dialect/Vector/vector-to-elements-flattening.mlir
deleted file mode 100644
index a57521c4db467..0000000000000
--- a/mlir/test/Dialect/Vector/vector-to-elements-flattening.mlir
+++ /dev/null
@@ -1,22 +0,0 @@
-// RUN: mlir-opt %s -test-flatten-vector-to-elements -split-input-file | FileCheck %s
-
-// CHECK-LABEL: func.func @to_elements_1d(
-// CHECK-SAME:    %[[ARG0:.+]]: vector<2xf32>
-// CHECK:         %[[RES:.+]]:2 = vector.to_elements %[[ARG0]] : vector<2xf32>
-// CHECK:         return %[[RES]]#0, %[[RES]]#1
-func.func @to_elements_1d(%arg0: vector<2xf32>) -> (f32, f32) {
-  %0:2 = vector.to_elements %arg0 : vector<2xf32>
-  return %0#0, %0#1 : f32, f32
-}
-
-// -----
-
-// CHECK-LABEL: func.func @to_elements_2d(
-// CHECK-SAME:    %[[ARG0:.+]]: vector<2x2xf32>
-// CHECK:         %[[CAST:.+]] = vector.shape_cast %[[ARG0]]
-// CHECK:         %[[RES:.+]]:4 = vector.to_elements %[[CAST]] : vector<4xf32>
-// CHECK:         return %[[RES]]#0, %[[RES]]#1, %[[RES]]#2, %[[RES]]#3
-func.func @to_elements_2d(%arg0: vector<2x2xf32>) -> (f32, f32, f32, f32) {
-  %0:4 = vector.to_elements %arg0 : vector<2x2xf32>
-  return %0#0, %0#1, %0#2, %0#3 : f32, f32, f32, f32
-}
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index 093134c119cea..d6596cd341df7 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -808,28 +808,6 @@ struct TestUnrollVectorFromElements
   }
 };
 
-struct TestFlattenVectorToElements
-    : public PassWrapper<TestFlattenVectorToElements,
-                         OperationPass<func::FuncOp>> {
-  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestFlattenVectorToElements)
-
-  StringRef getArgument() const final {
-    return "test-flatten-vector-to-elements";
-  }
-  StringRef getDescription() const final {
-    return "Test flattening patterns for to_elements ops";
-  }
-  void getDependentDialects(DialectRegistry &registry) const override {
-    registry.insert<func::FuncDialect, vector::VectorDialect>();
-  }
-
-  void runOnOperation() override {
-    RewritePatternSet patterns(&getContext());
-    populateVectorToElementsFlatteningPatterns(patterns);
-    (void)applyPatternsGreedily(getOperation(), std::move(patterns));
-  }
-};
-
 struct TestUnrollVectorToElements
     : public PassWrapper<TestUnrollVectorToElements,
                          OperationPass<func::FuncOp>> {
@@ -1129,8 +1107,6 @@ void registerTestVectorLowerings() {
 
   PassRegistration<TestUnrollVectorToElements>();
 
-  PassRegistration<TestFlattenVectorToElements>();
-
   PassRegistration<TestFoldArithExtensionIntoVectorContractPatterns>();
 
   PassRegistration<TestVectorEmulateMaskedLoadStore>();

>From 9f7d15d1c0ba453db4caa9b86281c3dc5f622bfc Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Tue, 9 Sep 2025 12:30:50 -0700
Subject: [PATCH 12/21] Re-enable and rename flattening to linearize

---
 .../Vector/Transforms/VectorLinearize.cpp     | 49 +++++++++++++------
 mlir/test/Dialect/Vector/linearize.mlir       | 23 +++++++++
 2 files changed, 57 insertions(+), 15 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index 0f5e9259d4c19..54eb182a9680f 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -798,26 +798,45 @@ struct LinearizeVectorFromElements final
   }
 };
 
-/// Flattens 2 or more dimensional `vector.to_elements` ops by
-/// `vector.shape_cast` + `vector.to_elements`.
-struct FlattenToElements final : OpRewritePattern<vector::ToElementsOp> {
-  using OpRewritePattern::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(vector::ToElementsOp op,
-                                PatternRewriter &rewriter) const override {
-    VectorType vecType = op.getSource().getType();
+/// This pattern linearizes the operand in `vector.to_elements` operations
+/// by converting the result type to a 1-D vector while preserving all element
+/// values. The transformation creates a linearized `vector.shape_cast`
+/// followed by a `vector.to_elements`.
+///
+/// Example:
+///
+///     %0:4 = vector.to_elements %v : vector<2x2xf32>
+///
+/// is converted to:
+///
+///     %vector_cast = vector.shape_cast %v : vector<2x2xf32> to vector<4xf32>
+///     %0:4 = vector.to_elements %vector_cast : vector<4xf32>
+///
+struct LinearizeVectorToElements final
+    : public OpConversionPattern<vector::ToElementsOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LinearizeVectorToElements(const TypeConverter &typeConverter,
+                            MLIRContext *context, PatternBenefit benefit = 1)
+      : OpConversionPattern(typeConverter, context, benefit) {}
+
+  LogicalResult
+  matchAndRewrite(vector::ToElementsOp toElementsOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+
+    VectorType vecType = toElementsOp.getSource().getType();
     if (vecType.getRank() <= 1)
       return rewriter.notifyMatchFailure(
-          op, "the rank is already less than or equal to 1");
+          toElementsOp, "the rank is already less than or equal to 1");
 
     assert(vecType.getNumScalableDims() == 0 &&
            "scalable vector is not yet supported");
     auto vec1DType =
         VectorType::get({vecType.getNumElements()}, vecType.getElementType());
-    Value shapeCast = vector::ShapeCastOp::create(rewriter, op.getLoc(),
-                                                  vec1DType, op.getSource());
-    rewriter.replaceOpWithNewOp<vector::ToElementsOp>(op, op.getResultTypes(),
-                                                      shapeCast);
+    Value shapeCast = vector::ShapeCastOp::create(
+        rewriter, toElementsOp.getLoc(), vec1DType, toElementsOp.getSource());
+    rewriter.replaceOpWithNewOp<vector::ToElementsOp>(
+        toElementsOp, toElementsOp.getResultTypes(), shapeCast);
     return success();
   }
 };
@@ -914,8 +933,8 @@ void mlir::vector::populateVectorLinearizeBasePatterns(
   patterns
       .add<LinearizeConstantLike, LinearizeVectorizable, LinearizeVectorBitCast,
            LinearizeVectorSplat, LinearizeVectorCreateMask, LinearizeVectorLoad,
-           LinearizeVectorStore, LinearizeVectorFromElements>(
-          typeConverter, patterns.getContext());
+           LinearizeVectorStore, LinearizeVectorFromElements,
+           LinearizeVectorToElements>(typeConverter, patterns.getContext());
 }
 
 void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns(
diff --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize.mlir
index 5e8bfd0698b33..fe697c8b9c057 100644
--- a/mlir/test/Dialect/Vector/linearize.mlir
+++ b/mlir/test/Dialect/Vector/linearize.mlir
@@ -538,3 +538,26 @@ func.func @test_vector_from_elements(%arg0: f32, %arg1: f32, %arg2: f32, %arg3:
   %1 = vector.from_elements %arg0, %arg1, %arg2, %arg3 : vector<2x2xf32>
   return %1 : vector<2x2xf32>
 }
+
+// -----
+
+// CHECK-LABEL: func.func @to_elements_1d(
+// CHECK-SAME:    %[[ARG0:.+]]: vector<2xf32>
+// CHECK:         %[[RES:.+]]:2 = vector.to_elements %[[ARG0]] : vector<2xf32>
+// CHECK:         return %[[RES]]#0, %[[RES]]#1
+func.func @to_elements_1d(%arg0: vector<2xf32>) -> (f32, f32) {
+  %0:2 = vector.to_elements %arg0 : vector<2xf32>
+  return %0#0, %0#1 : f32, f32
+}
+
+// -----
+
+// CHECK-LABEL: func.func @to_elements_2d(
+// CHECK-SAME:    %[[ARG0:.+]]: vector<2x2xf32>
+// CHECK:         %[[CAST:.+]] = vector.shape_cast %[[ARG0]]
+// CHECK:         %[[RES:.+]]:4 = vector.to_elements %[[CAST]] : vector<4xf32>
+// CHECK:         return %[[RES]]#0, %[[RES]]#1, %[[RES]]#2, %[[RES]]#3
+func.func @to_elements_2d(%arg0: vector<2x2xf32>) -> (f32, f32, f32, f32) {
+  %0:4 = vector.to_elements %arg0 : vector<2x2xf32>
+  return %0#0, %0#1, %0#2, %0#3 : f32, f32, f32, f32
+}

>From a8f9b9cb13b1677bd5be6d662f8865f739671208 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Tue, 9 Sep 2025 12:33:21 -0700
Subject: [PATCH 13/21] Remove LinearizeToElements

---
 .../Vector/Transforms/VectorLinearize.cpp     | 47 +------------------
 mlir/test/Dialect/Vector/linearize.mlir       | 23 ---------
 2 files changed, 2 insertions(+), 68 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index 54eb182a9680f..7dde6311fa809 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -798,49 +798,6 @@ struct LinearizeVectorFromElements final
   }
 };
 
-/// This pattern linearizes the operand in `vector.to_elements` operations
-/// by converting the result type to a 1-D vector while preserving all element
-/// values. The transformation creates a linearized `vector.shape_cast`
-/// followed by a `vector.to_elements`.
-///
-/// Example:
-///
-///     %0:4 = vector.to_elements %v : vector<2x2xf32>
-///
-/// is converted to:
-///
-///     %vector_cast = vector.shape_cast %v : vector<2x2xf32> to vector<4xf32>
-///     %0:4 = vector.to_elements %vector_cast : vector<4xf32>
-///
-struct LinearizeVectorToElements final
-    : public OpConversionPattern<vector::ToElementsOp> {
-  using OpConversionPattern::OpConversionPattern;
-
-  LinearizeVectorToElements(const TypeConverter &typeConverter,
-                            MLIRContext *context, PatternBenefit benefit = 1)
-      : OpConversionPattern(typeConverter, context, benefit) {}
-
-  LogicalResult
-  matchAndRewrite(vector::ToElementsOp toElementsOp, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-
-    VectorType vecType = toElementsOp.getSource().getType();
-    if (vecType.getRank() <= 1)
-      return rewriter.notifyMatchFailure(
-          toElementsOp, "the rank is already less than or equal to 1");
-
-    assert(vecType.getNumScalableDims() == 0 &&
-           "scalable vector is not yet supported");
-    auto vec1DType =
-        VectorType::get({vecType.getNumElements()}, vecType.getElementType());
-    Value shapeCast = vector::ShapeCastOp::create(
-        rewriter, toElementsOp.getLoc(), vec1DType, toElementsOp.getSource());
-    rewriter.replaceOpWithNewOp<vector::ToElementsOp>(
-        toElementsOp, toElementsOp.getResultTypes(), shapeCast);
-    return success();
-  }
-};
-
 } // namespace
 
 /// This method defines the set of operations that are linearizable, and hence
@@ -933,8 +890,8 @@ void mlir::vector::populateVectorLinearizeBasePatterns(
   patterns
       .add<LinearizeConstantLike, LinearizeVectorizable, LinearizeVectorBitCast,
            LinearizeVectorSplat, LinearizeVectorCreateMask, LinearizeVectorLoad,
-           LinearizeVectorStore, LinearizeVectorFromElements,
-           LinearizeVectorToElements>(typeConverter, patterns.getContext());
+           LinearizeVectorStore, LinearizeVectorFromElements>(
+          typeConverter, patterns.getContext());
 }
 
 void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns(
diff --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize.mlir
index fe697c8b9c057..5e8bfd0698b33 100644
--- a/mlir/test/Dialect/Vector/linearize.mlir
+++ b/mlir/test/Dialect/Vector/linearize.mlir
@@ -538,26 +538,3 @@ func.func @test_vector_from_elements(%arg0: f32, %arg1: f32, %arg2: f32, %arg3:
   %1 = vector.from_elements %arg0, %arg1, %arg2, %arg3 : vector<2x2xf32>
   return %1 : vector<2x2xf32>
 }
-
-// -----
-
-// CHECK-LABEL: func.func @to_elements_1d(
-// CHECK-SAME:    %[[ARG0:.+]]: vector<2xf32>
-// CHECK:         %[[RES:.+]]:2 = vector.to_elements %[[ARG0]] : vector<2xf32>
-// CHECK:         return %[[RES]]#0, %[[RES]]#1
-func.func @to_elements_1d(%arg0: vector<2xf32>) -> (f32, f32) {
-  %0:2 = vector.to_elements %arg0 : vector<2xf32>
-  return %0#0, %0#1 : f32, f32
-}
-
-// -----
-
-// CHECK-LABEL: func.func @to_elements_2d(
-// CHECK-SAME:    %[[ARG0:.+]]: vector<2x2xf32>
-// CHECK:         %[[CAST:.+]] = vector.shape_cast %[[ARG0]]
-// CHECK:         %[[RES:.+]]:4 = vector.to_elements %[[CAST]] : vector<4xf32>
-// CHECK:         return %[[RES]]#0, %[[RES]]#1, %[[RES]]#2, %[[RES]]#3
-func.func @to_elements_2d(%arg0: vector<2x2xf32>) -> (f32, f32, f32, f32) {
-  %0:4 = vector.to_elements %arg0 : vector<2x2xf32>
-  return %0#0, %0#1, %0#2, %0#3 : f32, f32, f32, f32
-}

>From 0568cbabd9c81e47cea30b0d01864ea0466f4633 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Wed, 10 Sep 2025 07:04:00 -0700
Subject: [PATCH 14/21] Improve API of unrollVectorValue.

Parameters are now:
* a TypedValue<VectorType> instead of a plain Value
* a RewriterBase & instead of a PatternRewriter &.

The return type is now:
* FailureOr<SmallVector<Value>>, instead of filling a
  SmallVector<Value> output parameter and returning LogicalResult.
---
 .../mlir/Dialect/Vector/Utils/VectorUtils.h        | 14 +++++---------
 .../Vector/Transforms/LowerVectorToElements.cpp    |  8 ++++++--
 mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp      |  8 +++++---
 3 files changed, 16 insertions(+), 14 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
index 985f90f6ed955..97163c4532378 100644
--- a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
+++ b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
@@ -255,15 +255,11 @@ using UnrollVectorOpFn =
 LogicalResult unrollVectorOp(Operation *op, PatternRewriter &rewriter,
                              UnrollVectorOpFn unrollFn);
 
-/// Generic utility for mapping values of type vector<nxaxbx...>
-/// to n values of type vector<axbx...>
-/// Follows the following pattern:
-/// 1. Check if already 1-D. If so, return failure.
-/// 2. Check for scalable dimensions. If so, return failure.
-/// 3. Returns the values of n vector.extract operations corresponding
-///    to the outermost dimension.
-LogicalResult unrollVectorValue(Value vector, PatternRewriter &rewriter,
-                                SmallVector<Value> &values);
+/// Generic utility for unrolling values of type vector<NxAxBx...>
+/// to N values of type vector<AxBx...> using vector.extract. If the input
+/// is rank-1 or has leading scalable dimension, failure is returned.
+FailureOr<SmallVector<Value>> unrollVectorValue(TypedValue<VectorType>,
+                                                RewriterBase &);
 
 } // namespace vector
 
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorToElements.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorToElements.cpp
index 718b947f13715..56d0f61d3c9cc 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorToElements.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorToElements.cpp
@@ -25,10 +25,14 @@ struct UnrollToElements final : OpRewritePattern<vector::ToElementsOp> {
 
   LogicalResult matchAndRewrite(vector::ToElementsOp op,
                                 PatternRewriter &rewriter) const override {
-    SmallVector<Value> vectors;
-    if (failed(vector::unrollVectorValue(op.getSource(), rewriter, vectors))) {
+
+    TypedValue<VectorType> source = op.getSource();
+    FailureOr<SmallVector<Value>> result =
+        vector::unrollVectorValue(source, rewriter);
+    if (failed(result)) {
       return failure();
     }
+    SmallVector<Value> vectors = *result;
 
     // May be a large vector.
     SmallVector<Value, 0> results;
diff --git a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
index 8d67cd7e80382..d8e96a294005b 100644
--- a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
+++ b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
@@ -393,8 +393,10 @@ vector::isValidMaskedInputVector(ArrayRef<int64_t> shape,
   return success();
 }
 
-LogicalResult vector::unrollVectorValue(Value vector, PatternRewriter &rewriter,
-                                        SmallVector<Value> &subvectors) {
+FailureOr<SmallVector<Value>>
+vector::unrollVectorValue(TypedValue<VectorType> vector,
+                          RewriterBase &rewriter) {
+  SmallVector<Value> subvectors;
   VectorType ty = cast<VectorType>(vector.getType());
   Location loc = vector.getLoc();
   if (ty.getRank() < 2)
@@ -409,7 +411,7 @@ LogicalResult vector::unrollVectorValue(Value vector, PatternRewriter &rewriter,
     subvectors.push_back(vector::ExtractOp::create(rewriter, loc, vector, i));
   }
 
-  return success();
+  return subvectors;
 }
 
 LogicalResult vector::unrollVectorOp(Operation *op, PatternRewriter &rewriter,

>From 00cc94bf6e01f703cb2a3ea941406039d39d8f56 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Wed, 10 Sep 2025 11:43:00 -0700
Subject: [PATCH 15/21] Documentation

---
 mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp | 14 ++++++++++++++
 1 file changed, 14 insertions(+)

diff --git a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
index d8e96a294005b..56f20c334c50d 100644
--- a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
+++ b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
@@ -393,6 +393,20 @@ vector::isValidMaskedInputVector(ArrayRef<int64_t> shape,
   return success();
 }
 
+/// Takes a 2+ dimensional vector as an input
+/// returns n vector values produced by n vector.extract operations.
+/// I.e. calling unrollVectorValue([[%v]], rewriter) such that
+///
+///   %v : vector<nxaxb...>
+///
+/// will produce the following IR changes
+///
+///   %v0 = vector.extract %v[0] : vector<axbx...>
+///   %v1 = vector.extract %v[1] : vector<axbx...>
+///   ...
+///   %vnminusone = vector.extract %v[n-1] : vector<axbx...>
+///
+/// and returns SmallVector<Value> r = {[[%v0]], [[%v1]], ..., [[%vnminusone]]}
 FailureOr<SmallVector<Value>>
 vector::unrollVectorValue(TypedValue<VectorType> vector,
                           RewriterBase &rewriter) {

>From 6894af06a787c03f7a72b9d62025d373170e3a88 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Wed, 10 Sep 2025 11:49:47 -0700
Subject: [PATCH 16/21] Add transform.apply_patterns.vector.unroll_to_elements

---
 .../Dialect/Vector/TransformOps/VectorTransformOps.td | 11 +++++++++++
 .../Vector/TransformOps/VectorTransformOps.cpp        |  5 +++++
 mlir/test/python/dialects/transform_vector_ext.py     |  2 ++
 3 files changed, 18 insertions(+)

diff --git a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
index 07a4117a37b2c..72a69a056c46e 100644
--- a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
+++ b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
@@ -265,6 +265,17 @@ def ApplyUnrollFromElementsPatternsOp : Op<Transform_Dialect,
   let assemblyFormat = "attr-dict";
 }
 
+def ApplyUnrollToElementsPatternsOp : Op<Transform_Dialect,
+    "apply_patterns.vector.unroll_to_elements",
+    [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
+  let description = [{
+    Indicates that vector to_elements operations should be unrolled
+    along the outermost dimension.
+  }];
+
+  let assemblyFormat = "attr-dict";
+}
+
 def ApplyLowerScanPatternsOp : Op<Transform_Dialect,
     "apply_patterns.vector.lower_scan",
     [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
diff --git a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
index fe066dc04ad55..6bb390aa09d3e 100644
--- a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
+++ b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
@@ -144,6 +144,11 @@ void transform::ApplyUnrollFromElementsPatternsOp::populatePatterns(
   vector::populateVectorFromElementsLoweringPatterns(patterns);
 }
 
+void transform::ApplyUnrollToElementsPatternsOp::populatePatterns(
+    RewritePatternSet &patterns) {
+  vector::populateVectorToElementsLoweringPatterns(patterns);
+}
+
 void transform::ApplyLowerScanPatternsOp::populatePatterns(
     RewritePatternSet &patterns) {
   vector::populateVectorScanLoweringPatterns(patterns);
diff --git a/mlir/test/python/dialects/transform_vector_ext.py b/mlir/test/python/dialects/transform_vector_ext.py
index 5a648fe073315..28902b012f7cb 100644
--- a/mlir/test/python/dialects/transform_vector_ext.py
+++ b/mlir/test/python/dialects/transform_vector_ext.py
@@ -48,6 +48,8 @@ def non_configurable_patterns():
     vector.ApplyLowerGatherPatternsOp()
     # CHECK: transform.apply_patterns.vector.unroll_from_elements
     vector.ApplyUnrollFromElementsPatternsOp()
+    # CHECK: transform.apply_patterns.vector.unroll_to_elements
+    vector.ApplyUnrollToElementsPatternsOp()
     # CHECK: transform.apply_patterns.vector.lower_scan
     vector.ApplyLowerScanPatternsOp()
     # CHECK: transform.apply_patterns.vector.lower_shape_cast

>From 82004a21be08227d5b2b7656813646d74af154c0 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Wed, 10 Sep 2025 11:58:16 -0700
Subject: [PATCH 17/21] Minor changes

---
 .../lib/Dialect/Vector/Transforms/LowerVectorToElements.cpp | 4 ++--
 mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp               | 6 +++---
 2 files changed, 5 insertions(+), 5 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorToElements.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorToElements.cpp
index 56d0f61d3c9cc..82a4fab138191 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorToElements.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorToElements.cpp
@@ -20,7 +20,7 @@ using namespace mlir;
 
 namespace {
 
-struct UnrollToElements final : OpRewritePattern<vector::ToElementsOp> {
+struct UnrollToElements final : public OpRewritePattern<vector::ToElementsOp> {
   using OpRewritePattern::OpRewritePattern;
 
   LogicalResult matchAndRewrite(vector::ToElementsOp op,
@@ -38,7 +38,7 @@ struct UnrollToElements final : OpRewritePattern<vector::ToElementsOp> {
     SmallVector<Value, 0> results;
     for (const Value &vector : vectors) {
       auto subElements =
-          rewriter.create<vector::ToElementsOp>(op.getLoc(), vector);
+          vector::ToElementsOp::create(rewriter, op.getLoc(), vector);
       llvm::append_range(results, subElements.getResults());
     }
     rewriter.replaceOp(op, results);
diff --git a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
index 56f20c334c50d..39dc7a4f284a6 100644
--- a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
+++ b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
@@ -401,10 +401,10 @@ vector::isValidMaskedInputVector(ArrayRef<int64_t> shape,
 ///
 /// will produce the following IR changes
 ///
-///   %v0 = vector.extract %v[0] : vector<axbx...>
-///   %v1 = vector.extract %v[1] : vector<axbx...>
+///   %v0 = vector.extract %v[0] : vector<axbx...> from vector<nxaxbx...>
+///   %v1 = vector.extract %v[1] : vector<axbx...> from vector<nxaxbx...>
 ///   ...
-///   %vnminusone = vector.extract %v[n-1] : vector<axbx...>
+///   %vnminusone = vector.extract %v[n-1] : vector<axbx...> from ...
 ///
 /// and returns SmallVector<Value> r = {[[%v0]], [[%v1]], ..., [[%vnminusone]]}
 FailureOr<SmallVector<Value>>

>From cb6cf99c8b687b6493d833bf55ada9764f9137a4 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Thu, 11 Sep 2025 10:05:38 -0700
Subject: [PATCH 18/21] Remove comment and use inline storage for SmallVector

---
 mlir/lib/Dialect/Vector/Transforms/LowerVectorToElements.cpp | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorToElements.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorToElements.cpp
index 82a4fab138191..a53a183ec31bc 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorToElements.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorToElements.cpp
@@ -34,8 +34,7 @@ struct UnrollToElements final : public OpRewritePattern<vector::ToElementsOp> {
     }
     SmallVector<Value> vectors = *result;
 
-    // May be a large vector.
-    SmallVector<Value, 0> results;
+    SmallVector<Value> results;
     for (const Value &vector : vectors) {
       auto subElements =
           vector::ToElementsOp::create(rewriter, op.getLoc(), vector);

>From db83d2f4af38a3fbfdb29a3dfff284edd712ccd3 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Thu, 11 Sep 2025 10:21:41 -0700
Subject: [PATCH 19/21] Add test with transform interpreter

---
 .../Vector/vector-to-elements-lowering.mlir   | 25 +++++++++++++++++++
 1 file changed, 25 insertions(+)

diff --git a/mlir/test/Dialect/Vector/vector-to-elements-lowering.mlir b/mlir/test/Dialect/Vector/vector-to-elements-lowering.mlir
index e302dbd174322..18bcf7da7959a 100644
--- a/mlir/test/Dialect/Vector/vector-to-elements-lowering.mlir
+++ b/mlir/test/Dialect/Vector/vector-to-elements-lowering.mlir
@@ -1,4 +1,5 @@
 // RUN: mlir-opt %s -test-unroll-vector-to-elements -split-input-file | FileCheck %s
+// RUN: mlir-opt %s --transform-interpreter --split-input-file | FileCheck %s
 
 // CHECK-LABEL: func.func @to_elements_1d(
 // CHECK-SAME:    %[[ARG0:.+]]: vector<2xf32>
@@ -9,6 +10,18 @@ func.func @to_elements_1d(%arg0: vector<2xf32>) -> (f32, f32) {
   return %0#0, %0#1 : f32, f32
 }
 
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
+    %f = transform.structured.match ops{["func.func"]} in %module_op
+      : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %f {
+      transform.apply_patterns.vector.transfer_permutation_patterns
+      transform.apply_patterns.vector.unroll_to_elements
+    } : !transform.any_op
+    transform.yield
+  }
+}
+
 // -----
 
 // CHECK-LABEL: func.func @to_elements_2d(
@@ -22,3 +35,15 @@ func.func @to_elements_2d(%arg0: vector<2x2xf32>) -> (f32, f32, f32, f32) {
   %0:4 = vector.to_elements %arg0 : vector<2x2xf32>
   return %0#0, %0#1, %0#2, %0#3 : f32, f32, f32, f32
 }
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
+    %f = transform.structured.match ops{["func.func"]} in %module_op
+      : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %f {
+      transform.apply_patterns.vector.transfer_permutation_patterns
+      transform.apply_patterns.vector.unroll_to_elements
+    } : !transform.any_op
+    transform.yield
+  }
+}

>From ccec33c15fc18ab41cd96a93a23ef09cac85c5e0 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Thu, 11 Sep 2025 10:28:14 -0700
Subject: [PATCH 20/21] Use transform dialect library file

---
 mlir/test/Dialect/Vector/lit.local.cfg        |  2 ++
 .../Dialect/Vector/td/unroll-elements.mlir    | 11 ++++++++
 .../Vector/vector-to-elements-lowering.mlir   | 27 ++-----------------
 3 files changed, 15 insertions(+), 25 deletions(-)
 create mode 100644 mlir/test/Dialect/Vector/lit.local.cfg
 create mode 100644 mlir/test/Dialect/Vector/td/unroll-elements.mlir

diff --git a/mlir/test/Dialect/Vector/lit.local.cfg b/mlir/test/Dialect/Vector/lit.local.cfg
new file mode 100644
index 0000000000000..62743008a3e3a
--- /dev/null
+++ b/mlir/test/Dialect/Vector/lit.local.cfg
@@ -0,0 +1,2 @@
+# Skip the directory with input TD sequences
+config.excludes = ["td"]
diff --git a/mlir/test/Dialect/Vector/td/unroll-elements.mlir b/mlir/test/Dialect/Vector/td/unroll-elements.mlir
new file mode 100644
index 0000000000000..40a90a33b0ac4
--- /dev/null
+++ b/mlir/test/Dialect/Vector/td/unroll-elements.mlir
@@ -0,0 +1,11 @@
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @unroll_to_elements(%module_op: !transform.any_op {transform.readonly}) {
+    %f = transform.structured.match ops{["func.func"]} in %module_op
+      : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %f {
+      transform.apply_patterns.vector.transfer_permutation_patterns
+      transform.apply_patterns.vector.unroll_to_elements
+    } : !transform.any_op
+    transform.yield
+  }
+}
diff --git a/mlir/test/Dialect/Vector/vector-to-elements-lowering.mlir b/mlir/test/Dialect/Vector/vector-to-elements-lowering.mlir
index 18bcf7da7959a..9ec0d76599c41 100644
--- a/mlir/test/Dialect/Vector/vector-to-elements-lowering.mlir
+++ b/mlir/test/Dialect/Vector/vector-to-elements-lowering.mlir
@@ -1,5 +1,6 @@
 // RUN: mlir-opt %s -test-unroll-vector-to-elements -split-input-file | FileCheck %s
-// RUN: mlir-opt %s --transform-interpreter --split-input-file | FileCheck %s
+// RUN: mlir-opt %s -transform-preload-library='transform-library-paths=%p/td/unroll-elements.mlir' \
+// RUN: -transform-interpreter=entry-point=unroll_to_elements | FileCheck %s
 
 // CHECK-LABEL: func.func @to_elements_1d(
 // CHECK-SAME:    %[[ARG0:.+]]: vector<2xf32>
@@ -10,18 +11,6 @@ func.func @to_elements_1d(%arg0: vector<2xf32>) -> (f32, f32) {
   return %0#0, %0#1 : f32, f32
 }
 
-module attributes {transform.with_named_sequence} {
-  transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
-    %f = transform.structured.match ops{["func.func"]} in %module_op
-      : (!transform.any_op) -> !transform.any_op
-    transform.apply_patterns to %f {
-      transform.apply_patterns.vector.transfer_permutation_patterns
-      transform.apply_patterns.vector.unroll_to_elements
-    } : !transform.any_op
-    transform.yield
-  }
-}
-
 // -----
 
 // CHECK-LABEL: func.func @to_elements_2d(
@@ -35,15 +24,3 @@ func.func @to_elements_2d(%arg0: vector<2x2xf32>) -> (f32, f32, f32, f32) {
   %0:4 = vector.to_elements %arg0 : vector<2x2xf32>
   return %0#0, %0#1, %0#2, %0#3 : f32, f32, f32, f32
 }
-
-module attributes {transform.with_named_sequence} {
-  transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
-    %f = transform.structured.match ops{["func.func"]} in %module_op
-      : (!transform.any_op) -> !transform.any_op
-    transform.apply_patterns to %f {
-      transform.apply_patterns.vector.transfer_permutation_patterns
-      transform.apply_patterns.vector.unroll_to_elements
-    } : !transform.any_op
-    transform.yield
-  }
-}

>From 7e52d00fe71355b1d986e5ada996960fbd573e26 Mon Sep 17 00:00:00 2001
From: Erick Ochoa Lopez <eochoalo at amd.com>
Date: Thu, 11 Sep 2025 13:40:35 -0400
Subject: [PATCH 21/21] Update mlir/test/Dialect/Vector/lit.local.cfg

Co-authored-by: Jakub Kuderski <kubakuderski at gmail.com>
---
 mlir/test/Dialect/Vector/lit.local.cfg | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/test/Dialect/Vector/lit.local.cfg b/mlir/test/Dialect/Vector/lit.local.cfg
index 62743008a3e3a..3e9e8f8497624 100644
--- a/mlir/test/Dialect/Vector/lit.local.cfg
+++ b/mlir/test/Dialect/Vector/lit.local.cfg
@@ -1,2 +1,2 @@
-# Skip the directory with input TD sequences
+# Skip the directory with input TD sequences.
 config.excludes = ["td"]



More information about the Mlir-commits mailing list