[Mlir-commits] [mlir] [mlir] Move vector.{to_elements, from_elements} unrolling to `VectorUnroll.cpp` (PR #159118)

Erick Ochoa Lopez llvmlistbot at llvm.org
Thu Sep 18 07:23:08 PDT 2025


https://github.com/amd-eochoalo updated https://github.com/llvm/llvm-project/pull/159118

>From 08655a9b64e846a5dcd1d77ec0481292392032e3 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Tue, 16 Sep 2025 09:00:46 -0700
Subject: [PATCH 01/11] [mlir] Add vector.{to_elements,from_elements} unrolling
 to VectorToSPIRV

---
 .../SPIRV/Transforms/SPIRVConversion.cpp      |  2 +
 .../ConvertToSPIRV/vector-unroll.mlir         | 44 +++++++++++++++++++
 2 files changed, 46 insertions(+)

diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
index 49f4ce8de7c76..98e294b40456f 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
@@ -1495,6 +1495,8 @@ LogicalResult mlir::spirv::unrollVectorsInFuncBodies(Operation *op) {
     RewritePatternSet patterns(context);
     auto options = vector::UnrollVectorOptions().setNativeShapeFn(
         [](auto op) { return mlir::spirv::getNativeVectorShape(op); });
+    vector::populateVectorFromElementsLoweringPatterns(patterns);
+    vector::populateVectorToElementsLoweringPatterns(patterns);
     populateVectorUnrollPatterns(patterns, options);
     if (failed(applyPatternsGreedily(op, std::move(patterns))))
       return failure();
diff --git a/mlir/test/Conversion/ConvertToSPIRV/vector-unroll.mlir b/mlir/test/Conversion/ConvertToSPIRV/vector-unroll.mlir
index c85f4334ff2e5..0957f67690b97 100644
--- a/mlir/test/Conversion/ConvertToSPIRV/vector-unroll.mlir
+++ b/mlir/test/Conversion/ConvertToSPIRV/vector-unroll.mlir
@@ -96,3 +96,47 @@ func.func @transpose(%arg0 : vector<2x3xi32>) -> (vector<3x2xi32>) {
   %0 = vector.transpose %arg0, [1, 0] : vector<2x3xi32> to vector<3x2xi32>
   return %0 : vector<3x2xi32>
 }
+
+// -----
+
+// In order to verify that the pattern is applied,
+// we need to make sure that the 2d vector does not
+// come from the parameters. Otherwise, the pattern
+// in unrollVectorsInSignatures, which splits the 2d vector
+// parameter, will take precedence. Similarly, avoid
+// returning a vector, as another pattern would take precedence.
+
+// CHECK-LABEL: @unroll_to_elements_2d
+func.func @unroll_to_elements_2d() -> (f32, f32, f32, f32) {
+  %1 = "test.op"() : () -> (vector<2x2xf32>)
+  // CHECK: %[[VEC2D:.+]] = "test.op"
+  // CHECK: %[[VEC0:.+]] = vector.extract %[[VEC2D]][0] : vector<2xf32> from vector<2x2xf32>
+  // CHECK: %[[VEC1:.+]] = vector.extract %[[VEC2D]][1] : vector<2xf32> from vector<2x2xf32>
+  // CHECK: %[[RES0:.+]]:2 = vector.to_elements %[[VEC0]]
+  // CHECK: %[[RES1:.+]]:2 = vector.to_elements %[[VEC1]]
+  %2:4 = vector.to_elements %1 : vector<2x2xf32>
+  return %2#0, %2#1, %2#2, %2#3 : f32, f32, f32, f32
+}
+
+// -----
+
+// In order to verify that the pattern is applied,
+// we need to make sure that the 2d vector is used
+// by an operation and that extracts are not folded away.
+// In other words, we can neither use "test.op" nor return
+// the value `%0 = vector.from_elements` directly.
+
+// CHECK-LABEL: @unroll_from_elements_2d
+// CHECK-SAME: (%[[ARG0:.+]]: f32, %[[ARG1:.+]]: f32, %[[ARG2:.+]]: f32, %[[ARG3:.+]]: f32)
+func.func @unroll_from_elements_2d(%arg0: f32, %arg1: f32, %arg2: f32, %arg3: f32) -> (vector<2x2xf32>) {
+  // CHECK: %[[VEC0:.+]] = vector.from_elements %[[ARG0]], %[[ARG1]] : vector<2xf32>
+  // CHECK: %[[VEC1:.+]] = vector.from_elements %[[ARG2]], %[[ARG3]] : vector<2xf32>
+  %0 = vector.from_elements %arg0, %arg1, %arg2, %arg3 : vector<2x2xf32>
+
+  // CHECK: %[[RES0:.+]] = arith.addf %[[VEC0]], %[[VEC0]]
+  // CHECK: %[[RES1:.+]] = arith.addf %[[VEC1]], %[[VEC1]]
+  %1 = arith.addf %0, %0 : vector<2x2xf32>
+
+  // return %[[RES0]], %[[RES1]] : vector<2xf32>, vector<2xf32>
+  return %1 : vector<2x2xf32>
+}

>From c891a27212669a5d54f383a11ac3defb16a92f8f Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Tue, 16 Sep 2025 11:25:02 -0700
Subject: [PATCH 02/11] populate patterns inside populateVectorUnrollPatterns

---
 mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp | 2 --
 mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp   | 3 +++
 2 files changed, 3 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
index 98e294b40456f..49f4ce8de7c76 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
@@ -1495,8 +1495,6 @@ LogicalResult mlir::spirv::unrollVectorsInFuncBodies(Operation *op) {
     RewritePatternSet patterns(context);
     auto options = vector::UnrollVectorOptions().setNativeShapeFn(
         [](auto op) { return mlir::spirv::getNativeVectorShape(op); });
-    vector::populateVectorFromElementsLoweringPatterns(patterns);
-    vector::populateVectorToElementsLoweringPatterns(patterns);
     populateVectorUnrollPatterns(patterns, options);
     if (failed(applyPatternsGreedily(op, std::move(patterns))))
       return failure();
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
index e8ecb0c0be846..a75e680afe1fb 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
@@ -12,6 +12,7 @@
 
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
 #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
 #include "mlir/Interfaces/VectorInterfaces.h"
 #include "llvm/ADT/MapVector.h"
@@ -814,6 +815,8 @@ struct UnrollBroadcastPattern : public OpRewritePattern<vector::BroadcastOp> {
 void mlir::vector::populateVectorUnrollPatterns(
     RewritePatternSet &patterns, const UnrollVectorOptions &options,
     PatternBenefit benefit) {
+  populateVectorToElementsLoweringPatterns(patterns);
+  populateVectorFromElementsLoweringPatterns(patterns);
   patterns.add<UnrollTransferReadPattern, UnrollTransferWritePattern,
                UnrollContractionPattern, UnrollElementwisePattern,
                UnrollReductionPattern, UnrollMultiReductionPattern,

>From e1f9605b9fd07517e112e1ad7340de354c6f29f0 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Tue, 16 Sep 2025 11:43:01 -0700
Subject: [PATCH 03/11] Copy over UnrollToElements

---
 .../Vector/Transforms/VectorRewritePatterns.h |  6 ++++
 .../Vector/Transforms/VectorUnroll.cpp        | 32 ++++++++++++++++++-
 2 files changed, 37 insertions(+), 1 deletion(-)

diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
index 0138f477cadea..32fcb948b9cf7 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
@@ -322,6 +322,12 @@ void populateVectorUnrollPatterns(RewritePatternSet &patterns,
                                   const UnrollVectorOptions &options,
                                   PatternBenefit benefit = 1);
 
+/// Populate the pattern set with the following patterns:
+///
+/// [UnrollToElements]
+void populateVectorToElementsUnrollPatterns(RewritePatternSet &patterns,
+                                            PatternBenefit benefit = 1);
+
 /// Collect a set of leading one dimension removal patterns.
 ///
 /// These patterns insert vector.shape_cast to remove leading one dimensions
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
index a75e680afe1fb..bcfaa843a306f 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
@@ -810,13 +810,38 @@ struct UnrollBroadcastPattern : public OpRewritePattern<vector::BroadcastOp> {
   vector::UnrollVectorOptions options;
 };
 
+struct UnrollToElements final : public OpRewritePattern<vector::ToElementsOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::ToElementsOp op,
+                                PatternRewriter &rewriter) const override {
+
+    TypedValue<VectorType> source = op.getSource();
+    FailureOr<SmallVector<Value>> result =
+        vector::unrollVectorValue(source, rewriter);
+    if (failed(result)) {
+      return failure();
+    }
+    SmallVector<Value> vectors = *result;
+
+    SmallVector<Value> results;
+    for (const Value &vector : vectors) {
+      auto subElements =
+          vector::ToElementsOp::create(rewriter, op.getLoc(), vector);
+      llvm::append_range(results, subElements.getResults());
+    }
+    rewriter.replaceOp(op, results);
+    return success();
+  }
+};
+
 } // namespace
 
 void mlir::vector::populateVectorUnrollPatterns(
     RewritePatternSet &patterns, const UnrollVectorOptions &options,
     PatternBenefit benefit) {
-  populateVectorToElementsLoweringPatterns(patterns);
   populateVectorFromElementsLoweringPatterns(patterns);
+  patterns.add<UnrollToElements>(patterns.getContext(), benefit);
   patterns.add<UnrollTransferReadPattern, UnrollTransferWritePattern,
                UnrollContractionPattern, UnrollElementwisePattern,
                UnrollReductionPattern, UnrollMultiReductionPattern,
@@ -824,3 +849,8 @@ void mlir::vector::populateVectorUnrollPatterns(
                UnrollStorePattern, UnrollBroadcastPattern>(
       patterns.getContext(), options, benefit);
 }
+
+void mlir::vector::populateVectorToElementsUnrollPatterns(
+    RewritePatternSet &patterns, PatternBenefit benefit) {
+  patterns.add<UnrollToElements>(patterns.getContext(), benefit);
+}

>From 5c3e7d5d60c1fdd4e3e09399c68d04a78ee8748a Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Tue, 16 Sep 2025 11:55:05 -0700
Subject: [PATCH 04/11] Removes reference to previous UnrollToElements pattern

---
 .../Vector/Transforms/LoweringPatterns.h      |  6 ---
 .../VectorToLLVM/ConvertVectorToLLVMPass.cpp  |  2 +-
 .../TransformOps/VectorTransformOps.cpp       |  2 +-
 .../Dialect/Vector/Transforms/CMakeLists.txt  |  1 -
 .../Transforms/LowerVectorToElements.cpp      | 53 -------------------
 5 files changed, 2 insertions(+), 62 deletions(-)
 delete mode 100644 mlir/lib/Dialect/Vector/Transforms/LowerVectorToElements.cpp

diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
index f56124cb4fb95..47f96112a9433 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
@@ -311,12 +311,6 @@ void populateVectorToFromElementsToShuffleTreePatterns(
 void populateVectorFromElementsLoweringPatterns(RewritePatternSet &patterns,
                                                 PatternBenefit benefit = 1);
 
-/// Populate the pattern set with the following patterns:
-///
-/// [UnrollToElements]
-void populateVectorToElementsLoweringPatterns(RewritePatternSet &patterns,
-                                              PatternBenefit benefit = 1);
-
 /// Populate the pattern set with the following patterns:
 ///
 /// [ContractionOpToMatmulOpLowering]
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
index 0b44ca7ceee42..9cdfeea2b81bf 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
@@ -95,7 +95,7 @@ void ConvertVectorToLLVMPass::runOnOperation() {
     populateVectorRankReducingFMAPattern(patterns);
     populateVectorGatherLoweringPatterns(patterns);
     populateVectorFromElementsLoweringPatterns(patterns);
-    populateVectorToElementsLoweringPatterns(patterns);
+    populateVectorToElementsUnrollPatterns(patterns);
     if (armI8MM) {
       if (armNeon)
         arm_neon::populateLowerContractionToNeonI8MMPatterns(patterns);
diff --git a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
index 18f105ef62e38..a3350a3332862 100644
--- a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
+++ b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
@@ -151,7 +151,7 @@ void transform::ApplyUnrollFromElementsPatternsOp::populatePatterns(
 
 void transform::ApplyUnrollToElementsPatternsOp::populatePatterns(
     RewritePatternSet &patterns) {
-  vector::populateVectorToElementsLoweringPatterns(patterns);
+  vector::populateVectorToElementsUnrollPatterns(patterns);
 }
 
 void transform::ApplyLowerScanPatternsOp::populatePatterns(
diff --git a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
index d74007f13a95b..acbf2b746037b 100644
--- a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
@@ -11,7 +11,6 @@ add_mlir_dialect_library(MLIRVectorTransforms
   LowerVectorScan.cpp
   LowerVectorShapeCast.cpp
   LowerVectorStep.cpp
-  LowerVectorToElements.cpp
   LowerVectorToFromElementsToShuffleTree.cpp
   LowerVectorTransfer.cpp
   LowerVectorTranspose.cpp
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorToElements.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorToElements.cpp
deleted file mode 100644
index a53a183ec31bc..0000000000000
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorToElements.cpp
+++ /dev/null
@@ -1,53 +0,0 @@
-//===- LowerVectorToElements.cpp - Lower 'vector.to_elements' op ----------===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-//
-// This file implements target-independent rewrites and utilities to lower the
-// 'vector.to_elements' operation.
-//
-//===----------------------------------------------------------------------===//
-
-#include "mlir/Dialect/Vector/IR/VectorOps.h"
-#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
-
-#define DEBUG_TYPE "lower-vector-to-elements"
-
-using namespace mlir;
-
-namespace {
-
-struct UnrollToElements final : public OpRewritePattern<vector::ToElementsOp> {
-  using OpRewritePattern::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(vector::ToElementsOp op,
-                                PatternRewriter &rewriter) const override {
-
-    TypedValue<VectorType> source = op.getSource();
-    FailureOr<SmallVector<Value>> result =
-        vector::unrollVectorValue(source, rewriter);
-    if (failed(result)) {
-      return failure();
-    }
-    SmallVector<Value> vectors = *result;
-
-    SmallVector<Value> results;
-    for (const Value &vector : vectors) {
-      auto subElements =
-          vector::ToElementsOp::create(rewriter, op.getLoc(), vector);
-      llvm::append_range(results, subElements.getResults());
-    }
-    rewriter.replaceOp(op, results);
-    return success();
-  }
-};
-
-} // namespace
-
-void mlir::vector::populateVectorToElementsLoweringPatterns(
-    RewritePatternSet &patterns, PatternBenefit benefit) {
-  patterns.add<UnrollToElements>(patterns.getContext(), benefit);
-}

>From 04431c5f5755c7ed2b71f831014dc8d6f8563af7 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Tue, 16 Sep 2025 12:05:05 -0700
Subject: [PATCH 05/11] Copy over UnrollFromElements

---
 .../Vector/Transforms/VectorRewritePatterns.h |  8 ++++
 .../Vector/Transforms/VectorUnroll.cpp        | 46 ++++++++++++++++++-
 2 files changed, 52 insertions(+), 2 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
index 32fcb948b9cf7..c42b8748f60de 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
@@ -328,6 +328,14 @@ void populateVectorUnrollPatterns(RewritePatternSet &patterns,
 void populateVectorToElementsUnrollPatterns(RewritePatternSet &patterns,
                                             PatternBenefit benefit = 1);
 
+/// Populate the pattern set with the following patterns:
+///
+/// [UnrollFromElements]
+/// Unrolls 2 or more dimensional `vector.from_elements` ops by unrolling the
+/// outermost dimension.
+void populateVectorFromElementsUnrollPatterns(RewritePatternSet &patterns,
+                                              PatternBenefit benefit = 1);
+
 /// Collect a set of leading one dimension removal patterns.
 ///
 /// These patterns insert vector.shape_cast to remove leading one dimensions
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
index bcfaa843a306f..ba82aa766180f 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
@@ -835,13 +835,50 @@ struct UnrollToElements final : public OpRewritePattern<vector::ToElementsOp> {
   }
 };
 
+/// Unrolls 2 or more dimensional `vector.from_elements` ops by unrolling the
+/// outermost dimension. For example:
+/// ```
+/// %v = vector.from_elements %e0, %e1, %e2, %e3, %e4, %e5 : vector<2x3xf32>
+///
+/// ==>
+///
+/// %0   = ub.poison : vector<2x3xf32>
+/// %v0  = vector.from_elements %e0, %e1, %e2 : vector<3xf32>
+/// %1   = vector.insert %v0, %0 [0] : vector<3xf32> into vector<2x3xf32>
+/// %v1  = vector.from_elements %e3, %e4, %e5 : vector<3xf32>
+/// %v   = vector.insert %v1, %1 [1] : vector<3xf32> into vector<2x3xf32>
+/// ```
+///
+/// When applied exhaustively, this will produce a sequence of 1-d from_elements
+/// ops.
+struct UnrollFromElements : OpRewritePattern<vector::FromElementsOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::FromElementsOp op,
+                                PatternRewriter &rewriter) const override {
+    ValueRange allElements = op.getElements();
+
+    auto unrollFromElementsFn = [&](PatternRewriter &rewriter, Location loc,
+                                    VectorType subTy, int64_t index) {
+      size_t subTyNumElements = subTy.getNumElements();
+      assert((index + 1) * subTyNumElements <= allElements.size() &&
+             "out of bounds");
+      ValueRange subElements =
+          allElements.slice(index * subTyNumElements, subTyNumElements);
+      return vector::FromElementsOp::create(rewriter, loc, subTy, subElements);
+    };
+
+    return unrollVectorOp(op, rewriter, unrollFromElementsFn);
+  }
+};
+
 } // namespace
 
 void mlir::vector::populateVectorUnrollPatterns(
     RewritePatternSet &patterns, const UnrollVectorOptions &options,
     PatternBenefit benefit) {
-  populateVectorFromElementsLoweringPatterns(patterns);
-  patterns.add<UnrollToElements>(patterns.getContext(), benefit);
+  patterns.add<UnrollFromElements, UnrollToElements>(patterns.getContext(),
+                                                     benefit);
   patterns.add<UnrollTransferReadPattern, UnrollTransferWritePattern,
                UnrollContractionPattern, UnrollElementwisePattern,
                UnrollReductionPattern, UnrollMultiReductionPattern,
@@ -854,3 +891,8 @@ void mlir::vector::populateVectorToElementsUnrollPatterns(
     RewritePatternSet &patterns, PatternBenefit benefit) {
   patterns.add<UnrollToElements>(patterns.getContext(), benefit);
 }
+
+void mlir::vector::populateVectorFromElementsUnrollPatterns(
+    RewritePatternSet &patterns, PatternBenefit benefit) {
+  patterns.add<UnrollFromElements>(patterns.getContext(), benefit);
+}

>From 70a667e035a6de981e7b21e759337f8bb47728bb Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Tue, 16 Sep 2025 12:10:48 -0700
Subject: [PATCH 06/11] Removes reference to previous UnrollFromElements
 pattern

---
 .../Vector/Transforms/LoweringPatterns.h      |  8 ---
 .../GPUCommon/GPUToLLVMConversion.cpp         |  2 +-
 .../GPUToNVVM/LowerGpuOpsToNVVMOps.cpp        |  2 +-
 .../VectorToLLVM/ConvertVectorToLLVMPass.cpp  |  2 +-
 .../TransformOps/VectorTransformOps.cpp       |  2 +-
 .../Dialect/Vector/Transforms/CMakeLists.txt  |  1 -
 .../Transforms/LowerVectorFromElements.cpp    | 65 -------------------
 7 files changed, 4 insertions(+), 78 deletions(-)
 delete mode 100644 mlir/lib/Dialect/Vector/Transforms/LowerVectorFromElements.cpp

diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
index 47f96112a9433..e03f0dabece52 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
@@ -303,14 +303,6 @@ void populateVectorRankReducingFMAPattern(RewritePatternSet &patterns);
 void populateVectorToFromElementsToShuffleTreePatterns(
     RewritePatternSet &patterns, PatternBenefit benefit = 1);
 
-/// Populate the pattern set with the following patterns:
-///
-/// [UnrollFromElements]
-/// Unrolls 2 or more dimensional `vector.from_elements` ops by unrolling the
-/// outermost dimension.
-void populateVectorFromElementsLoweringPatterns(RewritePatternSet &patterns,
-                                                PatternBenefit benefit = 1);
-
 /// Populate the pattern set with the following patterns:
 ///
 /// [ContractionOpToMatmulOpLowering]
diff --git a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
index e516118f75207..5994b64f3d9a5 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
@@ -534,7 +534,7 @@ void GpuToLLVMConversionPass::runOnOperation() {
                                                    /*maxTransferRank=*/1);
     // Transform N-D vector.from_elements to 1-D vector.from_elements before
     // conversion.
-    vector::populateVectorFromElementsLoweringPatterns(patterns);
+    vector::populateVectorFromElementsUnrollPatterns(patterns);
     if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
       return signalPassFailure();
   }
diff --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
index 76a7e0f3831a2..a95263bb55f69 100644
--- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
+++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
@@ -372,7 +372,7 @@ struct LowerGpuOpsToNVVMOpsPass final
       populateGpuRewritePatterns(patterns);
       // Transform N-D vector.from_elements to 1-D vector.from_elements before
       // conversion.
-      vector::populateVectorFromElementsLoweringPatterns(patterns);
+      vector::populateVectorFromElementsUnrollPatterns(patterns);
       if (failed(applyPatternsGreedily(m, std::move(patterns))))
         return signalPassFailure();
     }
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
index 9cdfeea2b81bf..cae490e5f03e7 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
@@ -94,7 +94,7 @@ void ConvertVectorToLLVMPass::runOnOperation() {
     populateVectorStepLoweringPatterns(patterns);
     populateVectorRankReducingFMAPattern(patterns);
     populateVectorGatherLoweringPatterns(patterns);
-    populateVectorFromElementsLoweringPatterns(patterns);
+    populateVectorFromElementsUnrollPatterns(patterns);
     populateVectorToElementsUnrollPatterns(patterns);
     if (armI8MM) {
       if (armNeon)
diff --git a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
index a3350a3332862..7faa222a9e574 100644
--- a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
+++ b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
@@ -146,7 +146,7 @@ void transform::ApplyLowerGatherPatternsOp::populatePatterns(
 
 void transform::ApplyUnrollFromElementsPatternsOp::populatePatterns(
     RewritePatternSet &patterns) {
-  vector::populateVectorFromElementsLoweringPatterns(patterns);
+  vector::populateVectorFromElementsUnrollPatterns(patterns);
 }
 
 void transform::ApplyUnrollToElementsPatternsOp::populatePatterns(
diff --git a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
index acbf2b746037b..9e287fc109990 100644
--- a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
@@ -3,7 +3,6 @@ add_mlir_dialect_library(MLIRVectorTransforms
   LowerVectorBitCast.cpp
   LowerVectorBroadcast.cpp
   LowerVectorContract.cpp
-  LowerVectorFromElements.cpp
   LowerVectorGather.cpp
   LowerVectorInterleave.cpp
   LowerVectorMask.cpp
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorFromElements.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorFromElements.cpp
deleted file mode 100644
index c22fd54cef46b..0000000000000
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorFromElements.cpp
+++ /dev/null
@@ -1,65 +0,0 @@
-//===- LowerVectorFromElements.cpp - Lower 'vector.from_elements' op -----===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-//
-// This file implements target-independent rewrites and utilities to lower the
-// 'vector.from_elements' operation.
-//
-//===----------------------------------------------------------------------===//
-
-#include "mlir/Dialect/Vector/IR/VectorOps.h"
-#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
-
-#define DEBUG_TYPE "lower-vector-from-elements"
-
-using namespace mlir;
-
-namespace {
-
-/// Unrolls 2 or more dimensional `vector.from_elements` ops by unrolling the
-/// outermost dimension. For example:
-/// ```
-/// %v = vector.from_elements %e0, %e1, %e2, %e3, %e4, %e5 : vector<2x3xf32>
-///
-/// ==>
-///
-/// %0   = ub.poison : vector<2x3xf32>
-/// %v0  = vector.from_elements %e0, %e1, %e2 : vector<3xf32>
-/// %1   = vector.insert %v0, %0 [0] : vector<3xf32> into vector<2x3xf32>
-/// %v1  = vector.from_elements %e3, %e4, %e5 : vector<3xf32>
-/// %v   = vector.insert %v1, %1 [1] : vector<3xf32> into vector<2x3xf32>
-/// ```
-///
-/// When applied exhaustively, this will produce a sequence of 1-d from_elements
-/// ops.
-struct UnrollFromElements : OpRewritePattern<vector::FromElementsOp> {
-  using OpRewritePattern::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(vector::FromElementsOp op,
-                                PatternRewriter &rewriter) const override {
-    ValueRange allElements = op.getElements();
-
-    auto unrollFromElementsFn = [&](PatternRewriter &rewriter, Location loc,
-                                    VectorType subTy, int64_t index) {
-      size_t subTyNumElements = subTy.getNumElements();
-      assert((index + 1) * subTyNumElements <= allElements.size() &&
-             "out of bounds");
-      ValueRange subElements =
-          allElements.slice(index * subTyNumElements, subTyNumElements);
-      return vector::FromElementsOp::create(rewriter, loc, subTy, subElements);
-    };
-
-    return unrollVectorOp(op, rewriter, unrollFromElementsFn);
-  }
-};
-
-} // namespace
-
-void mlir::vector::populateVectorFromElementsLoweringPatterns(
-    RewritePatternSet &patterns, PatternBenefit benefit) {
-  patterns.add<UnrollFromElements>(patterns.getContext(), benefit);
-}

>From c054f16b16be2d8e3cee1234c6b1443b81ba70d3 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Tue, 16 Sep 2025 12:26:11 -0700
Subject: [PATCH 07/11] Adds UnrollVectorOptions to ToElements and FromElements
 patterns

---
 .../Vector/Transforms/VectorUnroll.cpp        | 30 ++++++++++++++-----
 1 file changed, 22 insertions(+), 8 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
index ba82aa766180f..6f8c667d58f48 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
@@ -811,7 +811,11 @@ struct UnrollBroadcastPattern : public OpRewritePattern<vector::BroadcastOp> {
 };
 
 struct UnrollToElements final : public OpRewritePattern<vector::ToElementsOp> {
-  using OpRewritePattern::OpRewritePattern;
+  UnrollToElements(MLIRContext *context,
+                   const vector::UnrollVectorOptions &options,
+                   PatternBenefit benefit = 1)
+      : OpRewritePattern<vector::ToElementsOp>(context, benefit),
+        options(options) {}
 
   LogicalResult matchAndRewrite(vector::ToElementsOp op,
                                 PatternRewriter &rewriter) const override {
@@ -833,6 +837,9 @@ struct UnrollToElements final : public OpRewritePattern<vector::ToElementsOp> {
     rewriter.replaceOp(op, results);
     return success();
   }
+
+private:
+  vector::UnrollVectorOptions options;
 };
 
 /// Unrolls 2 or more dimensional `vector.from_elements` ops by unrolling the
@@ -852,7 +859,11 @@ struct UnrollToElements final : public OpRewritePattern<vector::ToElementsOp> {
 /// When applied exhaustively, this will produce a sequence of 1-d from_elements
 /// ops.
 struct UnrollFromElements : OpRewritePattern<vector::FromElementsOp> {
-  using OpRewritePattern::OpRewritePattern;
+  UnrollFromElements(MLIRContext *context,
+                     const vector::UnrollVectorOptions &options,
+                     PatternBenefit benefit = 1)
+      : OpRewritePattern<vector::FromElementsOp>(context, benefit),
+        options(options) {}
 
   LogicalResult matchAndRewrite(vector::FromElementsOp op,
                                 PatternRewriter &rewriter) const override {
@@ -870,6 +881,9 @@ struct UnrollFromElements : OpRewritePattern<vector::FromElementsOp> {
 
     return unrollVectorOp(op, rewriter, unrollFromElementsFn);
   }
+
+private:
+  vector::UnrollVectorOptions options;
 };
 
 } // namespace
@@ -877,22 +891,22 @@ struct UnrollFromElements : OpRewritePattern<vector::FromElementsOp> {
 void mlir::vector::populateVectorUnrollPatterns(
     RewritePatternSet &patterns, const UnrollVectorOptions &options,
     PatternBenefit benefit) {
-  patterns.add<UnrollFromElements, UnrollToElements>(patterns.getContext(),
-                                                     benefit);
   patterns.add<UnrollTransferReadPattern, UnrollTransferWritePattern,
                UnrollContractionPattern, UnrollElementwisePattern,
                UnrollReductionPattern, UnrollMultiReductionPattern,
                UnrollTransposePattern, UnrollGatherPattern, UnrollLoadPattern,
-               UnrollStorePattern, UnrollBroadcastPattern>(
-      patterns.getContext(), options, benefit);
+               UnrollStorePattern, UnrollBroadcastPattern, UnrollFromElements,
+               UnrollToElements>(patterns.getContext(), options, benefit);
 }
 
 void mlir::vector::populateVectorToElementsUnrollPatterns(
     RewritePatternSet &patterns, PatternBenefit benefit) {
-  patterns.add<UnrollToElements>(patterns.getContext(), benefit);
+  patterns.add<UnrollToElements>(patterns.getContext(), UnrollVectorOptions(),
+                                 benefit);
 }
 
 void mlir::vector::populateVectorFromElementsUnrollPatterns(
     RewritePatternSet &patterns, PatternBenefit benefit) {
-  patterns.add<UnrollFromElements>(patterns.getContext(), benefit);
+  patterns.add<UnrollFromElements>(patterns.getContext(), UnrollVectorOptions(),
+                                   benefit);
 }

>From 5a266fbbfc6977b8f8bf21234133f2a4b30c45c5 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Wed, 17 Sep 2025 13:09:50 -0400
Subject: [PATCH 08/11] Address review comments

---
 mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
index 0af1882a66d30..1c80c6b6a39aa 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
@@ -829,7 +829,7 @@ struct UnrollToElements final : public OpRewritePattern<vector::ToElementsOp> {
     SmallVector<Value> vectors = *result;
 
     SmallVector<Value> results;
-    for (const Value &vector : vectors) {
+    for (Value vector : vectors) {
       auto subElements =
           vector::ToElementsOp::create(rewriter, op.getLoc(), vector);
       llvm::append_range(results, subElements.getResults());

>From f0283ca5f4ed806fe5cb88cf768b7b21d3049710 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Thu, 18 Sep 2025 08:43:13 -0400
Subject: [PATCH 09/11] Update documentation

---
 .../Dialect/Vector/Transforms/VectorRewritePatterns.h     | 8 ++------
 1 file changed, 2 insertions(+), 6 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
index b1effc8642383..69438011d2287 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
@@ -322,15 +322,11 @@ void populateVectorUnrollPatterns(RewritePatternSet &patterns,
                                   const UnrollVectorOptions &options,
                                   PatternBenefit benefit = 1);
 
-/// Populate the pattern set with the following patterns:
-///
-/// [UnrollToElements]
+/// Unrolls 2 or more dimensional `vector.to_elements` ops by unrolling the
+/// outermost dimension of the operand.
 void populateVectorToElementsUnrollPatterns(RewritePatternSet &patterns,
                                             PatternBenefit benefit = 1);
 
-/// Populate the pattern set with the following patterns:
-///
-/// [UnrollFromElements]
 /// Unrolls 2 or more dimensional `vector.from_elements` ops by unrolling the
 /// outermost dimension.
 void populateVectorFromElementsUnrollPatterns(RewritePatternSet &patterns,

>From 5d7afb1efd4a2b05918044c1bf0b46773b41619c Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Thu, 18 Sep 2025 08:44:46 -0400
Subject: [PATCH 10/11] Reword documentation 'exhaustive' -> 'fixed-point'

---
 mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
index 1c80c6b6a39aa..e7dd5958cf72a 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
@@ -931,7 +931,8 @@ struct UnrollStepPattern : public OpRewritePattern<vector::StepOp> {
 /// %v   = vector.insert %v1, %1 [1] : vector<3xf32> into vector<2x3xf32>
 /// ```
 ///
-/// When applied exhaustively, this will produce a sequence of 1-d from_elements
+/// When this pattern is applied until a fixed-point is reached,
+/// this will produce a sequence of 1-d from_elements
 /// ops.
 struct UnrollFromElements : OpRewritePattern<vector::FromElementsOp> {
   UnrollFromElements(MLIRContext *context,

>From 44145b5c525a179ef117c17d89c4fa6bd37481a8 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Thu, 18 Sep 2025 08:48:06 -0400
Subject: [PATCH 11/11] Add documentation to UnrollToElements

---
 .../Dialect/Vector/Transforms/VectorUnroll.cpp  | 17 +++++++++++++++++
 1 file changed, 17 insertions(+)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
index e7dd5958cf72a..14639c5f1cdd3 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
@@ -810,6 +810,23 @@ struct UnrollBroadcastPattern : public OpRewritePattern<vector::BroadcastOp> {
   vector::UnrollVectorOptions options;
 };
 
+/// Unrolls 2 or more dimensional `vector.to_elements` ops by unrolling the
+/// outermost dimension of the operand. For example:
+///
+/// ```
+/// %0:8 = vector.to_elements %v : vector<2x2x2xf32>
+///
+/// ==>
+///
+/// %v0 = vector.extract %v[0] : vector<2x2xf32> from vector<2x2x2xf32>
+/// %v1 = vector.extract %v[1] : vector<2x2xf32> from vector<2x2x2xf32>
+/// %0:4 = vector.to_elements %v0 : vector<2x2xf32>
+/// %1:4 = vector.to_elements %v1 : vector<2x2xf32>
+/// ```
+///
+/// When this pattern is applied until a fixed-point is reached,
+/// this will produce a sequence of 1-d to_elements
+/// ops.
 struct UnrollToElements final : public OpRewritePattern<vector::ToElementsOp> {
   UnrollToElements(MLIRContext *context,
                    const vector::UnrollVectorOptions &options,



More information about the Mlir-commits mailing list