[Mlir-commits] [mlir] 9d19250 - [mlir][vector] Add vector.to_elements unrolling (#157142)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Sep 11 10:57:01 PDT 2025


Author: Erick Ochoa Lopez
Date: 2025-09-11T13:56:57-04:00
New Revision: 9d19250610fdaa80600d32fc7f6e06dcefd6bbff

URL: https://github.com/llvm/llvm-project/commit/9d19250610fdaa80600d32fc7f6e06dcefd6bbff
DIFF: https://github.com/llvm/llvm-project/commit/9d19250610fdaa80600d32fc7f6e06dcefd6bbff.diff

LOG: [mlir][vector] Add vector.to_elements unrolling (#157142)

This PR adds support for unrolling the source operand of `vector.to_elements`.

It transforms

```mlir
%0:8 = vector.to_elements %v : vector<2x2x2xf32>
```

to

```mlir
%v0 = vector.extract %v[0] : vector<2x2xf32> from vector<2x2x2xf32>
%v1 = vector.extract %v[1] : vector<2x2xf32> from vector<2x2x2xf32>
%0:4 = vector.to_elements %v0 : vector<2x2xf32>
%1:4 = vector.to_elements %v1 : vector<2x2xf32>
// The 8 results of the original op are %0#0..%0#3 followed by %1#0..%1#3
```

The pattern is applied repeatedly until only 1-D vectors remain; vectors whose leading dimension is scalable are not unrolled.

---------

Signed-off-by: hanhanW <hanhan0912 at gmail.com>
Co-authored-by: hanhanW <hanhan0912 at gmail.com>
Co-authored-by: Jakub Kuderski <kubakuderski at gmail.com>

Added: 
    mlir/lib/Dialect/Vector/Transforms/LowerVectorToElements.cpp
    mlir/test/Dialect/Vector/lit.local.cfg
    mlir/test/Dialect/Vector/td/unroll-elements.mlir
    mlir/test/Dialect/Vector/vector-to-elements-lowering.mlir

Modified: 
    mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
    mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
    mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
    mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
    mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
    mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
    mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
    mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
    mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
    mlir/test/python/dialects/transform_vector_ext.py

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
index 07a4117a37b2c..72a69a056c46e 100644
--- a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
+++ b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
@@ -265,6 +265,17 @@ def ApplyUnrollFromElementsPatternsOp : Op<Transform_Dialect,
   let assemblyFormat = "attr-dict";
 }
 
+def ApplyUnrollToElementsPatternsOp : Op<Transform_Dialect,
+    "apply_patterns.vector.unroll_to_elements",
+    [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
+  let description = [{
+    Indicates that vector.to_elements ops whose source has rank >= 2 should be
+    unrolled along the outermost dimension, unless that dimension is scalable.
+  }];
+
+  let assemblyFormat = "attr-dict";
+}
+
 def ApplyLowerScanPatternsOp : Op<Transform_Dialect,
     "apply_patterns.vector.lower_scan",
     [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {

diff  --git a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
index 47f96112a9433..f56124cb4fb95 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
@@ -311,6 +311,12 @@ void populateVectorToFromElementsToShuffleTreePatterns(
 void populateVectorFromElementsLoweringPatterns(RewritePatternSet &patterns,
                                                 PatternBenefit benefit = 1);
 
+/// Populate the pattern set with the following patterns:
+///
+/// [UnrollToElements]
+void populateVectorToElementsLoweringPatterns(RewritePatternSet &patterns,
+                                              PatternBenefit benefit = 1);
+
 /// Populate the pattern set with the following patterns:
 ///
 /// [ContractionOpToMatmulOpLowering]

diff  --git a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
index ace26990601c8..97163c4532378 100644
--- a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
+++ b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
@@ -255,6 +255,12 @@ using UnrollVectorOpFn =
 LogicalResult unrollVectorOp(Operation *op, PatternRewriter &rewriter,
                              UnrollVectorOpFn unrollFn);
 
+/// Generic utility for unrolling values of type vector<NxAxBx...>
+/// to N values of type vector<AxBx...> using vector.extract. If the input
+/// has rank < 2 or a scalable leading dimension, failure is returned.
+FailureOr<SmallVector<Value>> unrollVectorValue(TypedValue<VectorType> vector,
+                                                RewriterBase &rewriter);
+
 } // namespace vector
 
 /// Constructs a permutation map of invariant memref indices to vector

diff  --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
index 9852df6970fdc..0b44ca7ceee42 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
@@ -95,6 +95,7 @@ void ConvertVectorToLLVMPass::runOnOperation() {
     populateVectorRankReducingFMAPattern(patterns);
     populateVectorGatherLoweringPatterns(patterns);
     populateVectorFromElementsLoweringPatterns(patterns);
+    populateVectorToElementsLoweringPatterns(patterns);
     if (armI8MM) {
       if (armNeon)
         arm_neon::populateLowerContractionToNeonI8MMPatterns(patterns);

diff  --git a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
index fe066dc04ad55..6bb390aa09d3e 100644
--- a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
+++ b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
@@ -144,6 +144,11 @@ void transform::ApplyUnrollFromElementsPatternsOp::populatePatterns(
   vector::populateVectorFromElementsLoweringPatterns(patterns);
 }
 
+void transform::ApplyUnrollToElementsPatternsOp::populatePatterns(
+    RewritePatternSet &patterns) {
+  ::mlir::vector::populateVectorToElementsLoweringPatterns(patterns, 1);
+}
+
 void transform::ApplyLowerScanPatternsOp::populatePatterns(
     RewritePatternSet &patterns) {
   vector::populateVectorScanLoweringPatterns(patterns);

diff  --git a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
index acbf2b746037b..d74007f13a95b 100644
--- a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
@@ -11,6 +11,7 @@ add_mlir_dialect_library(MLIRVectorTransforms
   LowerVectorScan.cpp
   LowerVectorShapeCast.cpp
   LowerVectorStep.cpp
+  LowerVectorToElements.cpp
   LowerVectorToFromElementsToShuffleTree.cpp
   LowerVectorTransfer.cpp
   LowerVectorTranspose.cpp

diff  --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorToElements.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorToElements.cpp
new file mode 100644
index 0000000000000..a53a183ec31bc
--- /dev/null
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorToElements.cpp
@@ -0,0 +1,57 @@
+//===- LowerVectorToElements.cpp - Lower 'vector.to_elements' op ----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements target-independent rewrites and utilities to lower the
+// 'vector.to_elements' operation.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
+#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
+
+#define DEBUG_TYPE "lower-vector-to-elements"
+
+using namespace mlir;
+
+namespace {
+
+/// Rewrites a vector.to_elements whose source has rank >= 2 into one
+/// vector.to_elements per slice of the outermost dimension and concatenates
+/// their results in slice order. Applied repeatedly, it reduces every source
+/// to rank <= 1 unless a leading dimension is scalable.
+struct UnrollToElements final : public OpRewritePattern<vector::ToElementsOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::ToElementsOp op,
+                                PatternRewriter &rewriter) const override {
+    // unrollVectorValue fails without creating IR for rank < 2 or a scalable
+    // leading dimension; leave the op untouched in that case.
+    FailureOr<SmallVector<Value>> subvectors =
+        vector::unrollVectorValue(op.getSource(), rewriter);
+    if (failed(subvectors))
+      return failure();
+
+    SmallVector<Value> elements;
+    elements.reserve(op->getNumResults());
+    for (Value subvector : *subvectors) {
+      auto subElements =
+          vector::ToElementsOp::create(rewriter, op.getLoc(), subvector);
+      llvm::append_range(elements, subElements.getResults());
+    }
+    rewriter.replaceOp(op, elements);
+    return success();
+  }
+};
+
+} // namespace
+
+void mlir::vector::populateVectorToElementsLoweringPatterns(
+    RewritePatternSet &patterns, PatternBenefit benefit) {
+  patterns.add<UnrollToElements>(patterns.getContext(), benefit);
+}

diff  --git a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
index 841e1384e03b3..39dc7a4f284a6 100644
--- a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
+++ b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
@@ -393,6 +393,41 @@ vector::isValidMaskedInputVector(ArrayRef<int64_t> shape,
   return success();
 }
 
+/// Unrolls the outermost dimension of a vector of rank >= 2 by emitting one
+/// vector.extract per leading index. E.g. for
+///
+///   %v : vector<NxAxBx...>
+///
+/// this creates the following ops at the insertion point of `rewriter`:
+///
+///   %v0 = vector.extract %v[0] : vector<AxBx...> from vector<NxAxBx...>
+///   %v1 = vector.extract %v[1] : vector<AxBx...> from vector<NxAxBx...>
+///   ...
+///   %vN_1 = vector.extract %v[N-1] : vector<AxBx...> from vector<NxAxBx...>
+///
+/// and returns {%v0, %v1, ..., %vN_1}. Fails without creating any IR if
+/// `vector` has rank < 2 or a scalable leading dimension.
+FailureOr<SmallVector<Value>>
+vector::unrollVectorValue(TypedValue<VectorType> vector,
+                          RewriterBase &rewriter) {
+  VectorType ty = vector.getType();
+  Location loc = vector.getLoc();
+  if (ty.getRank() < 2)
+    return rewriter.notifyMatchFailure(loc, "already 1-D");
+
+  // Unrolling doesn't take vscale into account. Pattern is disabled for
+  // vectors with leading scalable dim(s).
+  if (ty.getScalableDims().front())
+    return rewriter.notifyMatchFailure(loc, "cannot unroll scalable dim");
+
+  int64_t numSubvectors = ty.getShape().front();
+  SmallVector<Value> subvectors;
+  subvectors.reserve(numSubvectors);
+  for (int64_t i = 0; i < numSubvectors; ++i)
+    subvectors.push_back(vector::ExtractOp::create(rewriter, loc, vector, i));
+  return subvectors;
+}
+
 LogicalResult vector::unrollVectorOp(Operation *op, PatternRewriter &rewriter,
                                      vector::UnrollVectorOpFn unrollFn) {
   assert(op->getNumResults() == 1 && "expected single result");

diff  --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index 07d335117de01..2d33888854ea7 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -1774,3 +1774,45 @@ func.func @from_elements_3d(%arg0: f32, %arg1: f32, %arg2: f32, %arg3: f32) -> v
   %0 = vector.from_elements %arg0, %arg1, %arg2, %arg3 : vector<2x1x2xf32>
   return %0 : vector<2x1x2xf32>
 }
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// vector.to_elements
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: func @to_elements_1d(
+// CHECK-SAME:    %[[ARG0:.+]]: vector<2xf32>
+// CHECK:         %[[C0:.+]] = llvm.mlir.constant(0 : i64) : i64
+// CHECK:         %[[V0:.+]] = llvm.extractelement %[[ARG0]][%[[C0]] : i64] : vector<2xf32>
+// CHECK:         %[[C1:.+]] = llvm.mlir.constant(1 : i64) : i64
+// CHECK:         %[[V1:.+]] = llvm.extractelement %[[ARG0]][%[[C1]] : i64] : vector<2xf32>
+// CHECK:         return %[[V0]], %[[V1]]
+func.func @to_elements_1d(%arg0: vector<2xf32>) -> (f32, f32) {
+  %0:2 = vector.to_elements %arg0 : vector<2xf32>
+  return %0#0, %0#1 : f32, f32
+}
+
+// -----
+
+// NOTE: We unroll multi-dimensional to_elements ops with pattern
+// `UnrollToElements` and then convert the 1-D to_elements ops to llvm.
+
+// CHECK-LABEL: func @to_elements_2d(
+// CHECK-SAME:    %[[ARG0:.+]]: vector<2x2xf32>
+// CHECK:         %[[CAST:.+]] = builtin.unrealized_conversion_cast %[[ARG0]] : vector<2x2xf32> to !llvm.array<2 x vector<2xf32>>
+// CHECK:         %[[V0:.+]] = llvm.extractvalue %[[CAST]][0] : !llvm.array<2 x vector<2xf32>>
+// CHECK:         %[[V1:.+]] = llvm.extractvalue %[[CAST]][1] : !llvm.array<2 x vector<2xf32>>
+// CHECK:         %[[C0:.+]] = llvm.mlir.constant(0 : i64) : i64
+// CHECK:         %[[R0:.+]] = llvm.extractelement %[[V0]][%[[C0]] : i64] : vector<2xf32>
+// CHECK:         %[[C1:.+]] = llvm.mlir.constant(1 : i64) : i64
+// CHECK:         %[[R1:.+]] = llvm.extractelement %[[V0]][%[[C1]] : i64] : vector<2xf32>
+// CHECK:         %[[C0:.+]] = llvm.mlir.constant(0 : i64) : i64
+// CHECK:         %[[R2:.+]] = llvm.extractelement %[[V1]][%[[C0]] : i64] : vector<2xf32>
+// CHECK:         %[[C1:.+]] = llvm.mlir.constant(1 : i64) : i64
+// CHECK:         %[[R3:.+]] = llvm.extractelement %[[V1]][%[[C1]] : i64] : vector<2xf32>
+// CHECK:         return %[[R0]], %[[R1]], %[[R2]], %[[R3]]
+func.func @to_elements_2d(%arg0: vector<2x2xf32>) -> (f32, f32, f32, f32) {
+  %0:4 = vector.to_elements %arg0 : vector<2x2xf32>
+  return %0#0, %0#1, %0#2, %0#3 : f32, f32, f32, f32
+}

diff  --git a/mlir/test/Dialect/Vector/lit.local.cfg b/mlir/test/Dialect/Vector/lit.local.cfg
new file mode 100644
index 0000000000000..3e9e8f8497624
--- /dev/null
+++ b/mlir/test/Dialect/Vector/lit.local.cfg
@@ -0,0 +1,2 @@
+# Skip the directory with input TD sequences.
+config.excludes = ["td"]

diff  --git a/mlir/test/Dialect/Vector/td/unroll-elements.mlir b/mlir/test/Dialect/Vector/td/unroll-elements.mlir
new file mode 100644
index 0000000000000..40a90a33b0ac4
--- /dev/null
+++ b/mlir/test/Dialect/Vector/td/unroll-elements.mlir
@@ -0,0 +1,10 @@
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @unroll_to_elements(%module_op: !transform.any_op {transform.readonly}) {
+    %f = transform.structured.match ops{["func.func"]} in %module_op
+      : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %f {
+      transform.apply_patterns.vector.unroll_to_elements
+    } : !transform.any_op
+    transform.yield
+  }
+}

diff  --git a/mlir/test/Dialect/Vector/vector-to-elements-lowering.mlir b/mlir/test/Dialect/Vector/vector-to-elements-lowering.mlir
new file mode 100644
index 0000000000000..9ec0d76599c41
--- /dev/null
+++ b/mlir/test/Dialect/Vector/vector-to-elements-lowering.mlir
@@ -0,0 +1,26 @@
+// RUN: mlir-opt %s -test-unroll-vector-to-elements -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -transform-preload-library='transform-library-paths=%p/td/unroll-elements.mlir' \
+// RUN: -transform-interpreter=entry-point=unroll_to_elements | FileCheck %s
+
+// CHECK-LABEL: func.func @to_elements_1d(
+// CHECK-SAME:    %[[ARG0:.+]]: vector<2xf32>
+// CHECK:         %[[RES:.+]]:2 = vector.to_elements %[[ARG0]] : vector<2xf32>
+// CHECK:         return %[[RES]]#0, %[[RES]]#1
+func.func @to_elements_1d(%arg0: vector<2xf32>) -> (f32, f32) {
+  %0:2 = vector.to_elements %arg0 : vector<2xf32>
+  return %0#0, %0#1 : f32, f32
+}
+
+// -----
+
+// CHECK-LABEL: func.func @to_elements_2d(
+// CHECK-SAME:    %[[ARG0:.+]]: vector<2x2xf32>
+// CHECK:         %[[VEC0:.+]] = vector.extract %[[ARG0]][0] : vector<2xf32> from vector<2x2xf32>
+// CHECK:         %[[VEC1:.+]] = vector.extract %[[ARG0]][1] : vector<2xf32> from vector<2x2xf32>
+// CHECK:         %[[RES0:.+]]:2 = vector.to_elements %[[VEC0]] : vector<2xf32>
+// CHECK:         %[[RES1:.+]]:2 = vector.to_elements %[[VEC1]] : vector<2xf32>
+// CHECK:         return %[[RES0]]#0, %[[RES0]]#1, %[[RES1]]#0, %[[RES1]]#1
+func.func @to_elements_2d(%arg0: vector<2x2xf32>) -> (f32, f32, f32, f32) {
+  %0:4 = vector.to_elements %arg0 : vector<2x2xf32>
+  return %0#0, %0#1, %0#2, %0#3 : f32, f32, f32, f32
+}

diff  --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index bb1598ee3efe5..d6596cd341df7 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -808,6 +808,28 @@ struct TestUnrollVectorFromElements
   }
 };
 
+/// Test pass that greedily applies the vector.to_elements unrolling patterns
+/// to each func.func it runs on.
+struct TestUnrollVectorToElements
+    : public PassWrapper<TestUnrollVectorToElements,
+                         OperationPass<func::FuncOp>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestUnrollVectorToElements)
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<func::FuncDialect, vector::VectorDialect>();
+  }
+  StringRef getArgument() const final {
+    return "test-unroll-vector-to-elements";
+  }
+  StringRef getDescription() const final {
+    return "Test unrolling patterns for to_elements ops";
+  }
+  void runOnOperation() override {
+    RewritePatternSet unrollPatterns(&getContext());
+    populateVectorToElementsLoweringPatterns(unrollPatterns);
+    (void)applyPatternsGreedily(getOperation(), std::move(unrollPatterns));
+  }
+};
+
 struct TestFoldArithExtensionIntoVectorContractPatterns
     : public PassWrapper<TestFoldArithExtensionIntoVectorContractPatterns,
                          OperationPass<func::FuncOp>> {
@@ -1083,6 +1105,8 @@ void registerTestVectorLowerings() {
 
   PassRegistration<TestUnrollVectorFromElements>();
 
+  PassRegistration<TestUnrollVectorToElements>();
+
   PassRegistration<TestFoldArithExtensionIntoVectorContractPatterns>();
 
   PassRegistration<TestVectorEmulateMaskedLoadStore>();

diff  --git a/mlir/test/python/dialects/transform_vector_ext.py b/mlir/test/python/dialects/transform_vector_ext.py
index 5a648fe073315..28902b012f7cb 100644
--- a/mlir/test/python/dialects/transform_vector_ext.py
+++ b/mlir/test/python/dialects/transform_vector_ext.py
@@ -48,6 +48,8 @@ def non_configurable_patterns():
     vector.ApplyLowerGatherPatternsOp()
     # CHECK: transform.apply_patterns.vector.unroll_from_elements
     vector.ApplyUnrollFromElementsPatternsOp()
+    # CHECK: transform.apply_patterns.vector.unroll_to_elements
+    vector.ApplyUnrollToElementsPatternsOp()
     # CHECK: transform.apply_patterns.vector.lower_scan
     vector.ApplyLowerScanPatternsOp()
     # CHECK: transform.apply_patterns.vector.lower_shape_cast


        


More information about the Mlir-commits mailing list