[Mlir-commits] [mlir] [mlir][vector] Add vector.to_elements unrolling (PR #157142)
Erick Ochoa Lopez
llvmlistbot at llvm.org
Tue Sep 9 12:41:51 PDT 2025
https://github.com/amd-eochoalo updated https://github.com/llvm/llvm-project/pull/157142
>From d661413b787a15a91f772e2570333aad95166a68 Mon Sep 17 00:00:00 2001
From: hanhanW <hanhan0912 at gmail.com>
Date: Thu, 4 Sep 2025 17:19:24 -0700
Subject: [PATCH 01/13] [mlir][vector] Add support for lowering n-D
vector.to_elements op.
The revision adds a pattern that flattens `vector.to_elements` ops of rank 2
or higher into a `vector.shape_cast` followed by a 1-D `vector.to_elements`.
It also adds the lowering pattern to ConvertVectorToLLVMPass and
completes the tests.
It fixes the end-to-end lowering breakage on the LLVM path introduced by https://github.com/llvm/llvm-project/commit/b4c31dc98dfc929728904cd96f0f4cf812c4d5b5.
Signed-off-by: hanhanW <hanhan0912 at gmail.com>
---
.../Vector/Transforms/LoweringPatterns.h | 6 +++
.../VectorToLLVM/ConvertVectorToLLVMPass.cpp | 1 +
.../Dialect/Vector/Transforms/CMakeLists.txt | 1 +
.../Transforms/LowerVectorToElements.cpp | 52 +++++++++++++++++++
.../VectorToLLVM/vector-to-llvm.mlir | 40 ++++++++++++++
.../Vector/vector-to-elements-lowering.mlir | 22 ++++++++
.../Dialect/Vector/TestVectorTransforms.cpp | 24 +++++++++
7 files changed, 146 insertions(+)
create mode 100644 mlir/lib/Dialect/Vector/Transforms/LowerVectorToElements.cpp
create mode 100644 mlir/test/Dialect/Vector/vector-to-elements-lowering.mlir
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
index 47f96112a9433..e0f744841db2b 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
@@ -311,6 +311,12 @@ void populateVectorToFromElementsToShuffleTreePatterns(
void populateVectorFromElementsLoweringPatterns(RewritePatternSet &patterns,
PatternBenefit benefit = 1);
+/// Populate the pattern set with the following patterns:
+///
+/// [FlattenToElements]
+void populateVectorToElementsLoweringPatterns(RewritePatternSet &patterns,
+ PatternBenefit benefit = 1);
+
/// Populate the pattern set with the following patterns:
///
/// [ContractionOpToMatmulOpLowering]
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
index 9852df6970fdc..0b44ca7ceee42 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
@@ -95,6 +95,7 @@ void ConvertVectorToLLVMPass::runOnOperation() {
populateVectorRankReducingFMAPattern(patterns);
populateVectorGatherLoweringPatterns(patterns);
populateVectorFromElementsLoweringPatterns(patterns);
+ populateVectorToElementsLoweringPatterns(patterns);
if (armI8MM) {
if (armNeon)
arm_neon::populateLowerContractionToNeonI8MMPatterns(patterns);
diff --git a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
index acbf2b746037b..d74007f13a95b 100644
--- a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
@@ -11,6 +11,7 @@ add_mlir_dialect_library(MLIRVectorTransforms
LowerVectorScan.cpp
LowerVectorShapeCast.cpp
LowerVectorStep.cpp
+ LowerVectorToElements.cpp
LowerVectorToFromElementsToShuffleTree.cpp
LowerVectorTransfer.cpp
LowerVectorTranspose.cpp
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorToElements.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorToElements.cpp
new file mode 100644
index 0000000000000..014034b8f9737
--- /dev/null
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorToElements.cpp
@@ -0,0 +1,52 @@
+//===- LowerVectorToElements.cpp - Lower 'vector.to_elements' op ----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements target-independent rewrites and utilities to lower the
+// 'vector.to_elements' operation.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
+
+#define DEBUG_TYPE "lower-vector-to-elements"
+
+using namespace mlir;
+
+namespace {
+
+/// Flattens 2 or more dimensional `vector.to_elements` ops by
+/// `vector.shape_cast` + `vector.to_elements`.
+struct FlattenToElements : OpRewritePattern<vector::ToElementsOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(vector::ToElementsOp op,
+ PatternRewriter &rewriter) const override {
+ VectorType vecType = op.getSource().getType();
+ if (vecType.getRank() <= 1)
+ return rewriter.notifyMatchFailure(
+ op, "the rank is already less than or equal to 1");
+ if (vecType.getNumScalableDims() > 0)
+ return rewriter.notifyMatchFailure(
+ op, "scalable vector is not yet supported");
+ auto vec1DType =
+ VectorType::get({vecType.getNumElements()}, vecType.getElementType());
+ Value shapeCast = vector::ShapeCastOp::create(rewriter, op.getLoc(),
+ vec1DType, op.getSource());
+ rewriter.replaceOpWithNewOp<vector::ToElementsOp>(op, op.getResultTypes(),
+ shapeCast);
+ return success();
+ }
+};
+
+} // namespace
+
+void mlir::vector::populateVectorToElementsLoweringPatterns(
+ RewritePatternSet &patterns, PatternBenefit benefit) {
+ patterns.add<FlattenToElements>(patterns.getContext(), benefit);
+}
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index 07d335117de01..bf4b05f7874de 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -1774,3 +1774,43 @@ func.func @from_elements_3d(%arg0: f32, %arg1: f32, %arg2: f32, %arg3: f32) -> v
%0 = vector.from_elements %arg0, %arg1, %arg2, %arg3 : vector<2x1x2xf32>
return %0 : vector<2x1x2xf32>
}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// vector.to_elements
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: func @to_elements_1d(
+// CHECK-SAME: %[[ARG0:.+]]: vector<2xf32>
+// CHECK: %[[C0:.+]] = llvm.mlir.constant(0 : i64) : i64
+// CHECK: %[[V0:.+]] = llvm.extractelement %[[ARG0]][%[[C0]] : i64] : vector<2xf32>
+// CHECK: %[[C1:.+]] = llvm.mlir.constant(1 : i64) : i64
+// CHECK: %[[V1:.+]] = llvm.extractelement %[[ARG0]][%[[C1]] : i64] : vector<2xf32>
+// CHECK: return %[[V0]], %[[V1]]
+func.func @to_elements_1d(%arg0: vector<2xf32>) -> (f32, f32) {
+ %0:2 = vector.to_elements %arg0 : vector<2xf32>
+ return %0#0, %0#1 : f32, f32
+}
+
+// -----
+
+// NOTE: We flatten multi-dimensional to_elements ops with pattern
+// `FlattenToElements` and then convert the 1-D to_elements ops to llvm.
+
+// CHECK-LABEL: func @to_elements_2d(
+// CHECK-SAME: %[[ARG0:.+]]: vector<2x2xf32>
+// CHECK: %[[CAST:.+]] = builtin.unrealized_conversion_cast %[[ARG0]] : vector<2x2xf32> to !llvm.array<2 x vector<2xf32>>
+// CHECK: %[[C0:.+]] = llvm.mlir.constant(0 : i64) : i64
+// CHECK: %[[V0:.+]] = llvm.extractelement %{{.+}}[%[[C0]] : i64] : vector<4xf32>
+// CHECK: %[[C1:.+]] = llvm.mlir.constant(1 : i64) : i64
+// CHECK: %[[V1:.+]] = llvm.extractelement %{{.+}}[%[[C1]] : i64] : vector<4xf32>
+// CHECK: %[[C2:.+]] = llvm.mlir.constant(2 : i64) : i64
+// CHECK: %[[V2:.+]] = llvm.extractelement %{{.+}}[%[[C2]] : i64] : vector<4xf32>
+// CHECK: %[[C3:.+]] = llvm.mlir.constant(3 : i64) : i64
+// CHECK: %[[V3:.+]] = llvm.extractelement %{{.+}}[%[[C3]] : i64] : vector<4xf32>
+// CHECK: return %[[V0]], %[[V1]], %[[V2]], %[[V3]]
+func.func @to_elements_2d(%arg0: vector<2x2xf32>) -> (f32, f32, f32, f32) {
+ %0:4 = vector.to_elements %arg0 : vector<2x2xf32>
+ return %0#0, %0#1, %0#2, %0#3 : f32, f32, f32, f32
+}
diff --git a/mlir/test/Dialect/Vector/vector-to-elements-lowering.mlir b/mlir/test/Dialect/Vector/vector-to-elements-lowering.mlir
new file mode 100644
index 0000000000000..a57521c4db467
--- /dev/null
+++ b/mlir/test/Dialect/Vector/vector-to-elements-lowering.mlir
@@ -0,0 +1,22 @@
+// RUN: mlir-opt %s -test-flatten-vector-to-elements -split-input-file | FileCheck %s
+
+// CHECK-LABEL: func.func @to_elements_1d(
+// CHECK-SAME: %[[ARG0:.+]]: vector<2xf32>
+// CHECK: %[[RES:.+]]:2 = vector.to_elements %[[ARG0]] : vector<2xf32>
+// CHECK: return %[[RES]]#0, %[[RES]]#1
+func.func @to_elements_1d(%arg0: vector<2xf32>) -> (f32, f32) {
+ %0:2 = vector.to_elements %arg0 : vector<2xf32>
+ return %0#0, %0#1 : f32, f32
+}
+
+// -----
+
+// CHECK-LABEL: func.func @to_elements_2d(
+// CHECK-SAME: %[[ARG0:.+]]: vector<2x2xf32>
+// CHECK: %[[CAST:.+]] = vector.shape_cast %[[ARG0]]
+// CHECK: %[[RES:.+]]:4 = vector.to_elements %[[CAST]] : vector<4xf32>
+// CHECK: return %[[RES]]#0, %[[RES]]#1, %[[RES]]#2, %[[RES]]#3
+func.func @to_elements_2d(%arg0: vector<2x2xf32>) -> (f32, f32, f32, f32) {
+ %0:4 = vector.to_elements %arg0 : vector<2x2xf32>
+ return %0#0, %0#1, %0#2, %0#3 : f32, f32, f32, f32
+}
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index bb1598ee3efe5..560a1331bdaf0 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -808,6 +808,28 @@ struct TestUnrollVectorFromElements
}
};
+struct TestFlattenVectorToElements
+ : public PassWrapper<TestFlattenVectorToElements,
+ OperationPass<func::FuncOp>> {
+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestFlattenVectorToElements)
+
+ StringRef getArgument() const final {
+ return "test-flatten-vector-to-elements";
+ }
+ StringRef getDescription() const final {
+ return "Test flattening patterns for to_elements ops";
+ }
+ void getDependentDialects(DialectRegistry &registry) const override {
+ registry.insert<func::FuncDialect, vector::VectorDialect>();
+ }
+
+ void runOnOperation() override {
+ RewritePatternSet patterns(&getContext());
+ populateVectorToElementsLoweringPatterns(patterns);
+ (void)applyPatternsGreedily(getOperation(), std::move(patterns));
+ }
+};
+
struct TestFoldArithExtensionIntoVectorContractPatterns
: public PassWrapper<TestFoldArithExtensionIntoVectorContractPatterns,
OperationPass<func::FuncOp>> {
@@ -1083,6 +1105,8 @@ void registerTestVectorLowerings() {
PassRegistration<TestUnrollVectorFromElements>();
+ PassRegistration<TestFlattenVectorToElements>();
+
PassRegistration<TestFoldArithExtensionIntoVectorContractPatterns>();
PassRegistration<TestVectorEmulateMaskedLoadStore>();
>From edf4019862d3448a31bd2d4052fc9d11259a7e37 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Fri, 5 Sep 2025 09:48:27 -0700
Subject: [PATCH 02/13] Add a separate populate function for the flattening pattern.
---
.../mlir/Dialect/Vector/Transforms/LoweringPatterns.h | 6 ++++++
.../lib/Dialect/Vector/Transforms/LowerVectorToElements.cpp | 5 +++++
mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp | 2 +-
3 files changed, 12 insertions(+), 1 deletion(-)
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
index e0f744841db2b..c39c9d4ae00c9 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
@@ -317,6 +317,12 @@ void populateVectorFromElementsLoweringPatterns(RewritePatternSet &patterns,
void populateVectorToElementsLoweringPatterns(RewritePatternSet &patterns,
PatternBenefit benefit = 1);
+/// Populate the pattern set with the following patterns:
+///
+/// [FlattenToElements]
+void populateVectorToElementsFlatteningPatterns(RewritePatternSet &patterns,
+ PatternBenefit benefit = 1);
+
/// Populate the pattern set with the following patterns:
///
/// [ContractionOpToMatmulOpLowering]
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorToElements.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorToElements.cpp
index 014034b8f9737..33c5d2cb33369 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorToElements.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorToElements.cpp
@@ -50,3 +50,8 @@ void mlir::vector::populateVectorToElementsLoweringPatterns(
RewritePatternSet &patterns, PatternBenefit benefit) {
patterns.add<FlattenToElements>(patterns.getContext(), benefit);
}
+
+void mlir::vector::populateVectorToElementsFlatteningPatterns(
+ RewritePatternSet &patterns, PatternBenefit benefit) {
+ patterns.add<FlattenToElements>(patterns.getContext(), benefit);
+}
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index 560a1331bdaf0..01a00509c7331 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -825,7 +825,7 @@ struct TestFlattenVectorToElements
void runOnOperation() override {
RewritePatternSet patterns(&getContext());
- populateVectorToElementsLoweringPatterns(patterns);
+ populateVectorToElementsFlatteningPatterns(patterns);
(void)applyPatternsGreedily(getOperation(), std::move(patterns));
}
};
>From 12fd6fc77b44f25020aee2b0193a02117d2fc1e1 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Fri, 5 Sep 2025 08:49:00 -0700
Subject: [PATCH 03/13] [mlir][vector] Add function to unroll vectors.
Extracts n values of type vector<axbx...> from a value of type vector<nxaxbx...>.
This patch adds a utility function that unrolls vector values.
This differs from the existing utility, `unrollVectorOp`, which
unrolls vector operations.
---
.../mlir/Dialect/Vector/Utils/VectorUtils.h | 3 +++
mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp | 21 +++++++++++++++++++
2 files changed, 24 insertions(+)
diff --git a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
index ace26990601c8..95f2ee5a7ac1d 100644
--- a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
+++ b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
@@ -255,6 +255,9 @@ using UnrollVectorOpFn =
LogicalResult unrollVectorOp(Operation *op, PatternRewriter &rewriter,
UnrollVectorOpFn unrollFn);
+LogicalResult unrollVectorValue(Value vector, PatternRewriter &rewriter,
+ SmallVector<Value> &values);
+
} // namespace vector
/// Constructs a permutation map of invariant memref indices to vector
diff --git a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
index 841e1384e03b3..cbedd9563fc29 100644
--- a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
+++ b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
@@ -393,6 +393,27 @@ vector::isValidMaskedInputVector(ArrayRef<int64_t> shape,
return success();
}
+LogicalResult vector::unrollVectorValue(Value vector, PatternRewriter &rewriter,
+ SmallVector<Value> &subvectors) {
+ assert(isa<VectorType>(vector.getType()) && "expected vector type");
+ VectorType ty = cast<VectorType>(vector.getType());
+ Location loc = vector.getLoc();
+ if (ty.getRank() < 2)
+ return rewriter.notifyMatchFailure(loc, "already 1-D");
+
+ // Unrolling doesn't take vscale into account. Pattern is disabled for
+ // vectors with leading scalable dim(s).
+ if (ty.getScalableDims().front())
+ return rewriter.notifyMatchFailure(loc, "cannot unroll scalable dim");
+
+ // We just need zero indices for the all dimensions except the leading one.
+ for (int64_t i = 0, e = ty.getShape().front(); i < e; ++i) {
+ subvectors.push_back(vector::ExtractOp::create(rewriter, loc, vector, i));
+ }
+
+ return success();
+}
+
LogicalResult vector::unrollVectorOp(Operation *op, PatternRewriter &rewriter,
vector::UnrollVectorOpFn unrollFn) {
assert(op->getNumResults() == 1 && "expected single result");
>From 0415e3b066c258e2b12aba4e074f62c07b1ee3d1 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Fri, 5 Sep 2025 09:39:24 -0700
Subject: [PATCH 04/13] [mlir][vector] Add vector.to_elements unrolling.
---
.../Vector/Transforms/LoweringPatterns.h | 2 +-
.../Transforms/LowerVectorToElements.cpp | 28 ++++++++++++++++++-
2 files changed, 28 insertions(+), 2 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
index c39c9d4ae00c9..31150a2afc19f 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
@@ -313,7 +313,7 @@ void populateVectorFromElementsLoweringPatterns(RewritePatternSet &patterns,
/// Populate the pattern set with the following patterns:
///
-/// [FlattenToElements]
+/// [UnrollToElements]
void populateVectorToElementsLoweringPatterns(RewritePatternSet &patterns,
PatternBenefit benefit = 1);
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorToElements.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorToElements.cpp
index 33c5d2cb33369..b86e8b274770f 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorToElements.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorToElements.cpp
@@ -20,6 +20,32 @@ using namespace mlir;
namespace {
+struct UnrollToElements : OpRewritePattern<vector::ToElementsOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(vector::ToElementsOp op,
+ PatternRewriter &rewriter) const override {
+ SmallVector<Value> vectors;
+ LogicalResult match =
+ mlir::vector::unrollVectorValue(op.getSource(), rewriter, vectors);
+ if (failed(match)) {
+ return match;
+ }
+
+ // May be large vector.
+ std::vector<Value> results;
+ for (const auto &vector : vectors) {
+ // we need to replace the current result
+ auto subElements =
+ rewriter.create<vector::ToElementsOp>(op.getLoc(), vector);
+ results.insert(results.end(), subElements.getResults().begin(),
+ subElements.getResults().end());
+ }
+ rewriter.replaceOp(op, results);
+ return success();
+ }
+};
+
/// Flattens 2 or more dimensional `vector.to_elements` ops by
/// `vector.shape_cast` + `vector.to_elements`.
struct FlattenToElements : OpRewritePattern<vector::ToElementsOp> {
@@ -48,7 +74,7 @@ struct FlattenToElements : OpRewritePattern<vector::ToElementsOp> {
void mlir::vector::populateVectorToElementsLoweringPatterns(
RewritePatternSet &patterns, PatternBenefit benefit) {
- patterns.add<FlattenToElements>(patterns.getContext(), benefit);
+ patterns.add<UnrollToElements>(patterns.getContext(), benefit);
}
void mlir::vector::populateVectorToElementsFlatteningPatterns(
>From 1b21117c35c588486aa737bdc7f5a37f5c52c1fb Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Fri, 5 Sep 2025 10:02:39 -0700
Subject: [PATCH 05/13] Split tests for unrolling and flattening vector.to_elements
---
.../Vector/vector-to-elements-flattening.mlir | 22 +++++++++++++++++
.../Vector/vector-to-elements-lowering.mlir | 10 ++++----
.../Dialect/Vector/TestVectorTransforms.cpp | 24 +++++++++++++++++++
3 files changed, 52 insertions(+), 4 deletions(-)
create mode 100644 mlir/test/Dialect/Vector/vector-to-elements-flattening.mlir
diff --git a/mlir/test/Dialect/Vector/vector-to-elements-flattening.mlir b/mlir/test/Dialect/Vector/vector-to-elements-flattening.mlir
new file mode 100644
index 0000000000000..a57521c4db467
--- /dev/null
+++ b/mlir/test/Dialect/Vector/vector-to-elements-flattening.mlir
@@ -0,0 +1,22 @@
+// RUN: mlir-opt %s -test-flatten-vector-to-elements -split-input-file | FileCheck %s
+
+// CHECK-LABEL: func.func @to_elements_1d(
+// CHECK-SAME: %[[ARG0:.+]]: vector<2xf32>
+// CHECK: %[[RES:.+]]:2 = vector.to_elements %[[ARG0]] : vector<2xf32>
+// CHECK: return %[[RES]]#0, %[[RES]]#1
+func.func @to_elements_1d(%arg0: vector<2xf32>) -> (f32, f32) {
+ %0:2 = vector.to_elements %arg0 : vector<2xf32>
+ return %0#0, %0#1 : f32, f32
+}
+
+// -----
+
+// CHECK-LABEL: func.func @to_elements_2d(
+// CHECK-SAME: %[[ARG0:.+]]: vector<2x2xf32>
+// CHECK: %[[CAST:.+]] = vector.shape_cast %[[ARG0]]
+// CHECK: %[[RES:.+]]:4 = vector.to_elements %[[CAST]] : vector<4xf32>
+// CHECK: return %[[RES]]#0, %[[RES]]#1, %[[RES]]#2, %[[RES]]#3
+func.func @to_elements_2d(%arg0: vector<2x2xf32>) -> (f32, f32, f32, f32) {
+ %0:4 = vector.to_elements %arg0 : vector<2x2xf32>
+ return %0#0, %0#1, %0#2, %0#3 : f32, f32, f32, f32
+}
diff --git a/mlir/test/Dialect/Vector/vector-to-elements-lowering.mlir b/mlir/test/Dialect/Vector/vector-to-elements-lowering.mlir
index a57521c4db467..e302dbd174322 100644
--- a/mlir/test/Dialect/Vector/vector-to-elements-lowering.mlir
+++ b/mlir/test/Dialect/Vector/vector-to-elements-lowering.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -test-flatten-vector-to-elements -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -test-unroll-vector-to-elements -split-input-file | FileCheck %s
// CHECK-LABEL: func.func @to_elements_1d(
// CHECK-SAME: %[[ARG0:.+]]: vector<2xf32>
@@ -13,9 +13,11 @@ func.func @to_elements_1d(%arg0: vector<2xf32>) -> (f32, f32) {
// CHECK-LABEL: func.func @to_elements_2d(
// CHECK-SAME: %[[ARG0:.+]]: vector<2x2xf32>
-// CHECK: %[[CAST:.+]] = vector.shape_cast %[[ARG0]]
-// CHECK: %[[RES:.+]]:4 = vector.to_elements %[[CAST]] : vector<4xf32>
-// CHECK: return %[[RES]]#0, %[[RES]]#1, %[[RES]]#2, %[[RES]]#3
+// CHECK: %[[VEC0:.+]] = vector.extract %[[ARG0]][0] : vector<2xf32> from vector<2x2xf32>
+// CHECK: %[[VEC1:.+]] = vector.extract %[[ARG0]][1] : vector<2xf32> from vector<2x2xf32>
+// CHECK: %[[RES0:.+]]:2 = vector.to_elements %[[VEC0]] : vector<2xf32>
+// CHECK: %[[RES1:.+]]:2 = vector.to_elements %[[VEC1]] : vector<2xf32>
+// CHECK: return %[[RES0]]#0, %[[RES0]]#1, %[[RES1]]#0, %[[RES1]]#1
func.func @to_elements_2d(%arg0: vector<2x2xf32>) -> (f32, f32, f32, f32) {
%0:4 = vector.to_elements %arg0 : vector<2x2xf32>
return %0#0, %0#1, %0#2, %0#3 : f32, f32, f32, f32
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index 01a00509c7331..093134c119cea 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -830,6 +830,28 @@ struct TestFlattenVectorToElements
}
};
+struct TestUnrollVectorToElements
+ : public PassWrapper<TestUnrollVectorToElements,
+ OperationPass<func::FuncOp>> {
+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestUnrollVectorToElements)
+
+ StringRef getArgument() const final {
+ return "test-unroll-vector-to-elements";
+ }
+ StringRef getDescription() const final {
+ return "Test unrolling patterns for to_elements ops";
+ }
+ void getDependentDialects(DialectRegistry &registry) const override {
+ registry.insert<func::FuncDialect, vector::VectorDialect>();
+ }
+
+ void runOnOperation() override {
+ RewritePatternSet patterns(&getContext());
+ populateVectorToElementsLoweringPatterns(patterns);
+ (void)applyPatternsGreedily(getOperation(), std::move(patterns));
+ }
+};
+
struct TestFoldArithExtensionIntoVectorContractPatterns
: public PassWrapper<TestFoldArithExtensionIntoVectorContractPatterns,
OperationPass<func::FuncOp>> {
@@ -1105,6 +1127,8 @@ void registerTestVectorLowerings() {
PassRegistration<TestUnrollVectorFromElements>();
+ PassRegistration<TestUnrollVectorToElements>();
+
PassRegistration<TestFlattenVectorToElements>();
PassRegistration<TestFoldArithExtensionIntoVectorContractPatterns>();
>From 567af4bde157e7c266dcf0df5be645d370a90086 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Fri, 5 Sep 2025 10:09:47 -0700
Subject: [PATCH 06/13] Fix test
---
.../VectorToLLVM/vector-to-llvm.mlir | 20 ++++++++++---------
1 file changed, 11 insertions(+), 9 deletions(-)
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index bf4b05f7874de..2d33888854ea7 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -1795,21 +1795,23 @@ func.func @to_elements_1d(%arg0: vector<2xf32>) -> (f32, f32) {
// -----
-// NOTE: We flatten multi-dimensional to_elements ops with pattern
-// `FlattenToElements` and then convert the 1-D to_elements ops to llvm.
+// NOTE: We unroll multi-dimensional to_elements ops with pattern
+// `UnrollToElements` and then convert the 1-D to_elements ops to llvm.
// CHECK-LABEL: func @to_elements_2d(
// CHECK-SAME: %[[ARG0:.+]]: vector<2x2xf32>
// CHECK: %[[CAST:.+]] = builtin.unrealized_conversion_cast %[[ARG0]] : vector<2x2xf32> to !llvm.array<2 x vector<2xf32>>
+// CHECK: %[[V0:.+]] = llvm.extractvalue %[[CAST]][0] : !llvm.array<2 x vector<2xf32>>
+// CHECK: %[[V1:.+]] = llvm.extractvalue %[[CAST]][1] : !llvm.array<2 x vector<2xf32>>
// CHECK: %[[C0:.+]] = llvm.mlir.constant(0 : i64) : i64
-// CHECK: %[[V0:.+]] = llvm.extractelement %{{.+}}[%[[C0]] : i64] : vector<4xf32>
+// CHECK: %[[R0:.+]] = llvm.extractelement %[[V0]][%[[C0]] : i64] : vector<2xf32>
// CHECK: %[[C1:.+]] = llvm.mlir.constant(1 : i64) : i64
-// CHECK: %[[V1:.+]] = llvm.extractelement %{{.+}}[%[[C1]] : i64] : vector<4xf32>
-// CHECK: %[[C2:.+]] = llvm.mlir.constant(2 : i64) : i64
-// CHECK: %[[V2:.+]] = llvm.extractelement %{{.+}}[%[[C2]] : i64] : vector<4xf32>
-// CHECK: %[[C3:.+]] = llvm.mlir.constant(3 : i64) : i64
-// CHECK: %[[V3:.+]] = llvm.extractelement %{{.+}}[%[[C3]] : i64] : vector<4xf32>
-// CHECK: return %[[V0]], %[[V1]], %[[V2]], %[[V3]]
+// CHECK: %[[R1:.+]] = llvm.extractelement %[[V0]][%[[C1]] : i64] : vector<2xf32>
+// CHECK: %[[C0:.+]] = llvm.mlir.constant(0 : i64) : i64
+// CHECK: %[[R2:.+]] = llvm.extractelement %[[V1]][%[[C0]] : i64] : vector<2xf32>
+// CHECK: %[[C1:.+]] = llvm.mlir.constant(1 : i64) : i64
+// CHECK: %[[R3:.+]] = llvm.extractelement %[[V1]][%[[C1]] : i64] : vector<2xf32>
+// CHECK: return %[[R0]], %[[R1]], %[[R2]], %[[R3]]
func.func @to_elements_2d(%arg0: vector<2x2xf32>) -> (f32, f32, f32, f32) {
%0:4 = vector.to_elements %arg0 : vector<2x2xf32>
return %0#0, %0#1, %0#2, %0#3 : f32, f32, f32, f32
>From c58b5c41bcb66ea43547f87ad811bb57fa0b152e Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Fri, 5 Sep 2025 12:25:12 -0700
Subject: [PATCH 07/13] Address review comments
---
.../Transforms/LowerVectorToElements.cpp | 25 +++++++++----------
mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp | 2 --
2 files changed, 12 insertions(+), 15 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorToElements.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorToElements.cpp
index b86e8b274770f..b897b15d7d690 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorToElements.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorToElements.cpp
@@ -20,26 +20,25 @@ using namespace mlir;
namespace {
-struct UnrollToElements : OpRewritePattern<vector::ToElementsOp> {
+struct UnrollToElements final : OpRewritePattern<vector::ToElementsOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(vector::ToElementsOp op,
PatternRewriter &rewriter) const override {
SmallVector<Value> vectors;
- LogicalResult match =
- mlir::vector::unrollVectorValue(op.getSource(), rewriter, vectors);
- if (failed(match)) {
+ if (LogicalResult match =
+ vector::unrollVectorValue(op.getSource(), rewriter, vectors);
+ failed(match)) {
return match;
}
- // May be large vector.
- std::vector<Value> results;
- for (const auto &vector : vectors) {
+ // May be a large vector.
+ SmallVector<Value, 0> results;
+ for (const Value &vector : vectors) {
// we need to replace the current result
auto subElements =
rewriter.create<vector::ToElementsOp>(op.getLoc(), vector);
- results.insert(results.end(), subElements.getResults().begin(),
- subElements.getResults().end());
+ llvm::append_range(results, subElements.getResults());
}
rewriter.replaceOp(op, results);
return success();
@@ -48,7 +47,7 @@ struct UnrollToElements : OpRewritePattern<vector::ToElementsOp> {
/// Flattens 2 or more dimensional `vector.to_elements` ops by
/// `vector.shape_cast` + `vector.to_elements`.
-struct FlattenToElements : OpRewritePattern<vector::ToElementsOp> {
+struct FlattenToElements final : OpRewritePattern<vector::ToElementsOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(vector::ToElementsOp op,
@@ -57,9 +56,9 @@ struct FlattenToElements : OpRewritePattern<vector::ToElementsOp> {
if (vecType.getRank() <= 1)
return rewriter.notifyMatchFailure(
op, "the rank is already less than or equal to 1");
- if (vecType.getNumScalableDims() > 0)
- return rewriter.notifyMatchFailure(
- op, "scalable vector is not yet supported");
+
+ assert(vecType.getNumScalableDims() == 0 &&
+ "scalable vector is not yet supported");
auto vec1DType =
VectorType::get({vecType.getNumElements()}, vecType.getElementType());
Value shapeCast = vector::ShapeCastOp::create(rewriter, op.getLoc(),
diff --git a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
index cbedd9563fc29..8d67cd7e80382 100644
--- a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
+++ b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
@@ -395,7 +395,6 @@ vector::isValidMaskedInputVector(ArrayRef<int64_t> shape,
LogicalResult vector::unrollVectorValue(Value vector, PatternRewriter &rewriter,
SmallVector<Value> &subvectors) {
- assert(isa<VectorType>(vector.getType()) && "expected vector type");
VectorType ty = cast<VectorType>(vector.getType());
Location loc = vector.getLoc();
if (ty.getRank() < 2)
@@ -406,7 +405,6 @@ LogicalResult vector::unrollVectorValue(Value vector, PatternRewriter &rewriter,
if (ty.getScalableDims().front())
return rewriter.notifyMatchFailure(loc, "cannot unroll scalable dim");
- // We just need zero indices for the all dimensions except the leading one.
for (int64_t i = 0, e = ty.getShape().front(); i < e; ++i) {
subvectors.push_back(vector::ExtractOp::create(rewriter, loc, vector, i));
}
>From 871f5a5873a765a89983a8d967d458faed6b3f2e Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Fri, 5 Sep 2025 12:26:52 -0700
Subject: [PATCH 08/13] Remove unnecessary comment
---
mlir/lib/Dialect/Vector/Transforms/LowerVectorToElements.cpp | 1 -
1 file changed, 1 deletion(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorToElements.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorToElements.cpp
index b897b15d7d690..43ba1fe885e85 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorToElements.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorToElements.cpp
@@ -35,7 +35,6 @@ struct UnrollToElements final : OpRewritePattern<vector::ToElementsOp> {
// May be a large vector.
SmallVector<Value, 0> results;
for (const Value &vector : vectors) {
- // we need to replace the current result
auto subElements =
rewriter.create<vector::ToElementsOp>(op.getLoc(), vector);
llvm::append_range(results, subElements.getResults());
>From 244eed587f232ae8c67af133887b6e6f3e8e351d Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Fri, 5 Sep 2025 13:00:39 -0700
Subject: [PATCH 09/13] Address review comments
---
.../lib/Dialect/Vector/Transforms/LowerVectorToElements.cpp | 6 ++----
1 file changed, 2 insertions(+), 4 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorToElements.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorToElements.cpp
index 43ba1fe885e85..e1bb07287ae9b 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorToElements.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorToElements.cpp
@@ -26,10 +26,8 @@ struct UnrollToElements final : OpRewritePattern<vector::ToElementsOp> {
LogicalResult matchAndRewrite(vector::ToElementsOp op,
PatternRewriter &rewriter) const override {
SmallVector<Value> vectors;
- if (LogicalResult match =
- vector::unrollVectorValue(op.getSource(), rewriter, vectors);
- failed(match)) {
- return match;
+ if (failed(vector::unrollVectorValue(op.getSource(), rewriter, vectors))) {
+ return failure();
}
// May be a large vector.
>From 9933387796fba4c5155b718df9f1f0d5e20d47e4 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Tue, 9 Sep 2025 12:40:49 -0700
Subject: [PATCH 10/13] Add comment
---
mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h | 7 +++++++
1 file changed, 7 insertions(+)
diff --git a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
index 95f2ee5a7ac1d..985f90f6ed955 100644
--- a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
+++ b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
@@ -255,6 +255,13 @@ using UnrollVectorOpFn =
LogicalResult unrollVectorOp(Operation *op, PatternRewriter &rewriter,
UnrollVectorOpFn unrollFn);
+/// Generic utility for mapping a value of type vector<nxaxbx...>
+/// to n values of type vector<axbx...>.
+/// It proceeds as follows:
+/// 1. If the vector is already 1-D, return failure.
+/// 2. If the vector has scalable dimensions, return failure.
+/// 3. Otherwise, return (via `values`) the results of n vector.extract
+/// operations along the outermost dimension.
LogicalResult unrollVectorValue(Value vector, PatternRewriter &rewriter,
SmallVector<Value> &values);
>From 351086b8983a77343a4716e5b1f237d07ec3e488 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Tue, 9 Sep 2025 12:18:09 -0700
Subject: [PATCH 11/13] Disable and move vector flattening
---
.../Vector/Transforms/LoweringPatterns.h | 6 ----
.../Transforms/LowerVectorToElements.cpp | 29 -------------------
.../Vector/Transforms/VectorLinearize.cpp | 24 +++++++++++++++
.../Vector/vector-to-elements-flattening.mlir | 22 --------------
.../Dialect/Vector/TestVectorTransforms.cpp | 24 ---------------
5 files changed, 24 insertions(+), 81 deletions(-)
delete mode 100644 mlir/test/Dialect/Vector/vector-to-elements-flattening.mlir
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
index 31150a2afc19f..f56124cb4fb95 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
@@ -317,12 +317,6 @@ void populateVectorFromElementsLoweringPatterns(RewritePatternSet &patterns,
void populateVectorToElementsLoweringPatterns(RewritePatternSet &patterns,
PatternBenefit benefit = 1);
-/// Populate the pattern set with the following patterns:
-///
-/// [FlattenToElements]
-void populateVectorToElementsFlatteningPatterns(RewritePatternSet &patterns,
- PatternBenefit benefit = 1);
-
/// Populate the pattern set with the following patterns:
///
/// [ContractionOpToMatmulOpLowering]
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorToElements.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorToElements.cpp
index e1bb07287ae9b..718b947f13715 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorToElements.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorToElements.cpp
@@ -42,38 +42,9 @@ struct UnrollToElements final : OpRewritePattern<vector::ToElementsOp> {
}
};
-/// Flattens 2 or more dimensional `vector.to_elements` ops by
-/// `vector.shape_cast` + `vector.to_elements`.
-struct FlattenToElements final : OpRewritePattern<vector::ToElementsOp> {
- using OpRewritePattern::OpRewritePattern;
-
- LogicalResult matchAndRewrite(vector::ToElementsOp op,
- PatternRewriter &rewriter) const override {
- VectorType vecType = op.getSource().getType();
- if (vecType.getRank() <= 1)
- return rewriter.notifyMatchFailure(
- op, "the rank is already less than or equal to 1");
-
- assert(vecType.getNumScalableDims() == 0 &&
- "scalable vector is not yet supported");
- auto vec1DType =
- VectorType::get({vecType.getNumElements()}, vecType.getElementType());
- Value shapeCast = vector::ShapeCastOp::create(rewriter, op.getLoc(),
- vec1DType, op.getSource());
- rewriter.replaceOpWithNewOp<vector::ToElementsOp>(op, op.getResultTypes(),
- shapeCast);
- return success();
- }
-};
-
} // namespace
void mlir::vector::populateVectorToElementsLoweringPatterns(
RewritePatternSet &patterns, PatternBenefit benefit) {
patterns.add<UnrollToElements>(patterns.getContext(), benefit);
}
-
-void mlir::vector::populateVectorToElementsFlatteningPatterns(
- RewritePatternSet &patterns, PatternBenefit benefit) {
- patterns.add<FlattenToElements>(patterns.getContext(), benefit);
-}
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index 7dde6311fa809..0f5e9259d4c19 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -798,6 +798,30 @@ struct LinearizeVectorFromElements final
}
};
+/// Flattens 2 or more dimensional `vector.to_elements` ops by
+/// `vector.shape_cast` + `vector.to_elements`.
+struct FlattenToElements final : OpRewritePattern<vector::ToElementsOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(vector::ToElementsOp op,
+ PatternRewriter &rewriter) const override {
+ VectorType vecType = op.getSource().getType();
+ if (vecType.getRank() <= 1)
+ return rewriter.notifyMatchFailure(
+ op, "the rank is already less than or equal to 1");
+
+ assert(vecType.getNumScalableDims() == 0 &&
+ "scalable vector is not yet supported");
+ auto vec1DType =
+ VectorType::get({vecType.getNumElements()}, vecType.getElementType());
+ Value shapeCast = vector::ShapeCastOp::create(rewriter, op.getLoc(),
+ vec1DType, op.getSource());
+ rewriter.replaceOpWithNewOp<vector::ToElementsOp>(op, op.getResultTypes(),
+ shapeCast);
+ return success();
+ }
+};
+
} // namespace
/// This method defines the set of operations that are linearizable, and hence
diff --git a/mlir/test/Dialect/Vector/vector-to-elements-flattening.mlir b/mlir/test/Dialect/Vector/vector-to-elements-flattening.mlir
deleted file mode 100644
index a57521c4db467..0000000000000
--- a/mlir/test/Dialect/Vector/vector-to-elements-flattening.mlir
+++ /dev/null
@@ -1,22 +0,0 @@
-// RUN: mlir-opt %s -test-flatten-vector-to-elements -split-input-file | FileCheck %s
-
-// CHECK-LABEL: func.func @to_elements_1d(
-// CHECK-SAME: %[[ARG0:.+]]: vector<2xf32>
-// CHECK: %[[RES:.+]]:2 = vector.to_elements %[[ARG0]] : vector<2xf32>
-// CHECK: return %[[RES]]#0, %[[RES]]#1
-func.func @to_elements_1d(%arg0: vector<2xf32>) -> (f32, f32) {
- %0:2 = vector.to_elements %arg0 : vector<2xf32>
- return %0#0, %0#1 : f32, f32
-}
-
-// -----
-
-// CHECK-LABEL: func.func @to_elements_2d(
-// CHECK-SAME: %[[ARG0:.+]]: vector<2x2xf32>
-// CHECK: %[[CAST:.+]] = vector.shape_cast %[[ARG0]]
-// CHECK: %[[RES:.+]]:4 = vector.to_elements %[[CAST]] : vector<4xf32>
-// CHECK: return %[[RES]]#0, %[[RES]]#1, %[[RES]]#2, %[[RES]]#3
-func.func @to_elements_2d(%arg0: vector<2x2xf32>) -> (f32, f32, f32, f32) {
- %0:4 = vector.to_elements %arg0 : vector<2x2xf32>
- return %0#0, %0#1, %0#2, %0#3 : f32, f32, f32, f32
-}
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index 093134c119cea..d6596cd341df7 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -808,28 +808,6 @@ struct TestUnrollVectorFromElements
}
};
-struct TestFlattenVectorToElements
- : public PassWrapper<TestFlattenVectorToElements,
- OperationPass<func::FuncOp>> {
- MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestFlattenVectorToElements)
-
- StringRef getArgument() const final {
- return "test-flatten-vector-to-elements";
- }
- StringRef getDescription() const final {
- return "Test flattening patterns for to_elements ops";
- }
- void getDependentDialects(DialectRegistry ®istry) const override {
- registry.insert<func::FuncDialect, vector::VectorDialect>();
- }
-
- void runOnOperation() override {
- RewritePatternSet patterns(&getContext());
- populateVectorToElementsFlatteningPatterns(patterns);
- (void)applyPatternsGreedily(getOperation(), std::move(patterns));
- }
-};
-
struct TestUnrollVectorToElements
: public PassWrapper<TestUnrollVectorToElements,
OperationPass<func::FuncOp>> {
@@ -1129,8 +1107,6 @@ void registerTestVectorLowerings() {
PassRegistration<TestUnrollVectorToElements>();
- PassRegistration<TestFlattenVectorToElements>();
-
PassRegistration<TestFoldArithExtensionIntoVectorContractPatterns>();
PassRegistration<TestVectorEmulateMaskedLoadStore>();
>From 9f7d15d1c0ba453db4caa9b86281c3dc5f622bfc Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Tue, 9 Sep 2025 12:30:50 -0700
Subject: [PATCH 12/13] Re-enable and rename flattening to linearize
---
.../Vector/Transforms/VectorLinearize.cpp | 49 +++++++++++++------
mlir/test/Dialect/Vector/linearize.mlir | 23 +++++++++
2 files changed, 57 insertions(+), 15 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index 0f5e9259d4c19..54eb182a9680f 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -798,26 +798,45 @@ struct LinearizeVectorFromElements final
}
};
-/// Flattens 2 or more dimensional `vector.to_elements` ops by
-/// `vector.shape_cast` + `vector.to_elements`.
-struct FlattenToElements final : OpRewritePattern<vector::ToElementsOp> {
- using OpRewritePattern::OpRewritePattern;
-
- LogicalResult matchAndRewrite(vector::ToElementsOp op,
- PatternRewriter &rewriter) const override {
- VectorType vecType = op.getSource().getType();
+/// This pattern linearizes the operand in `vector.to_elements` operations
+/// by converting the result type to a 1-D vector while preserving all element
+/// values. The transformation creates a linearized `vector.shape_cast`
+/// followed by a `vector.to_elements`.
+///
+/// Example:
+///
+/// %0:4 = vector.to_elements %v : vector<2x2xf32>
+///
+/// is converted to:
+///
+/// %vector_cast = vector.shape_cast %v : vector<2x2xf32> to vector<4xf32>
+/// %0:4 = vector.to_elements %vector_cast : vector<4xf32>
+///
+struct LinearizeVectorToElements final
+ : public OpConversionPattern<vector::ToElementsOp> {
+ using OpConversionPattern::OpConversionPattern;
+
+ LinearizeVectorToElements(const TypeConverter &typeConverter,
+ MLIRContext *context, PatternBenefit benefit = 1)
+ : OpConversionPattern(typeConverter, context, benefit) {}
+
+ LogicalResult
+ matchAndRewrite(vector::ToElementsOp toElementsOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+
+ VectorType vecType = toElementsOp.getSource().getType();
if (vecType.getRank() <= 1)
return rewriter.notifyMatchFailure(
- op, "the rank is already less than or equal to 1");
+ toElementsOp, "the rank is already less than or equal to 1");
assert(vecType.getNumScalableDims() == 0 &&
"scalable vector is not yet supported");
auto vec1DType =
VectorType::get({vecType.getNumElements()}, vecType.getElementType());
- Value shapeCast = vector::ShapeCastOp::create(rewriter, op.getLoc(),
- vec1DType, op.getSource());
- rewriter.replaceOpWithNewOp<vector::ToElementsOp>(op, op.getResultTypes(),
- shapeCast);
+ Value shapeCast = vector::ShapeCastOp::create(
+ rewriter, toElementsOp.getLoc(), vec1DType, toElementsOp.getSource());
+ rewriter.replaceOpWithNewOp<vector::ToElementsOp>(
+ toElementsOp, toElementsOp.getResultTypes(), shapeCast);
return success();
}
};
@@ -914,8 +933,8 @@ void mlir::vector::populateVectorLinearizeBasePatterns(
patterns
.add<LinearizeConstantLike, LinearizeVectorizable, LinearizeVectorBitCast,
LinearizeVectorSplat, LinearizeVectorCreateMask, LinearizeVectorLoad,
- LinearizeVectorStore, LinearizeVectorFromElements>(
- typeConverter, patterns.getContext());
+ LinearizeVectorStore, LinearizeVectorFromElements,
+ LinearizeVectorToElements>(typeConverter, patterns.getContext());
}
void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns(
diff --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize.mlir
index 5e8bfd0698b33..fe697c8b9c057 100644
--- a/mlir/test/Dialect/Vector/linearize.mlir
+++ b/mlir/test/Dialect/Vector/linearize.mlir
@@ -538,3 +538,26 @@ func.func @test_vector_from_elements(%arg0: f32, %arg1: f32, %arg2: f32, %arg3:
%1 = vector.from_elements %arg0, %arg1, %arg2, %arg3 : vector<2x2xf32>
return %1 : vector<2x2xf32>
}
+
+// -----
+
+// CHECK-LABEL: func.func @to_elements_1d(
+// CHECK-SAME: %[[ARG0:.+]]: vector<2xf32>
+// CHECK: %[[RES:.+]]:2 = vector.to_elements %[[ARG0]] : vector<2xf32>
+// CHECK: return %[[RES]]#0, %[[RES]]#1
+func.func @to_elements_1d(%arg0: vector<2xf32>) -> (f32, f32) {
+ %0:2 = vector.to_elements %arg0 : vector<2xf32>
+ return %0#0, %0#1 : f32, f32
+}
+
+// -----
+
+// CHECK-LABEL: func.func @to_elements_2d(
+// CHECK-SAME: %[[ARG0:.+]]: vector<2x2xf32>
+// CHECK: %[[CAST:.+]] = vector.shape_cast %[[ARG0]]
+// CHECK: %[[RES:.+]]:4 = vector.to_elements %[[CAST]] : vector<4xf32>
+// CHECK: return %[[RES]]#0, %[[RES]]#1, %[[RES]]#2, %[[RES]]#3
+func.func @to_elements_2d(%arg0: vector<2x2xf32>) -> (f32, f32, f32, f32) {
+ %0:4 = vector.to_elements %arg0 : vector<2x2xf32>
+ return %0#0, %0#1, %0#2, %0#3 : f32, f32, f32, f32
+}
>From a8f9b9cb13b1677bd5be6d662f8865f739671208 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Tue, 9 Sep 2025 12:33:21 -0700
Subject: [PATCH 13/13] Remove LinearizeToElements
---
.../Vector/Transforms/VectorLinearize.cpp | 47 +------------------
mlir/test/Dialect/Vector/linearize.mlir | 23 ---------
2 files changed, 2 insertions(+), 68 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index 54eb182a9680f..7dde6311fa809 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -798,49 +798,6 @@ struct LinearizeVectorFromElements final
}
};
-/// This pattern linearizes the operand in `vector.to_elements` operations
-/// by converting the result type to a 1-D vector while preserving all element
-/// values. The transformation creates a linearized `vector.shape_cast`
-/// followed by a `vector.to_elements`.
-///
-/// Example:
-///
-/// %0:4 = vector.to_elements %v : vector<2x2xf32>
-///
-/// is converted to:
-///
-/// %vector_cast = vector.shape_cast %v : vector<2x2xf32> to vector<4xf32>
-/// %0:4 = vector.to_elements %vector_cast : vector<4xf32>
-///
-struct LinearizeVectorToElements final
- : public OpConversionPattern<vector::ToElementsOp> {
- using OpConversionPattern::OpConversionPattern;
-
- LinearizeVectorToElements(const TypeConverter &typeConverter,
- MLIRContext *context, PatternBenefit benefit = 1)
- : OpConversionPattern(typeConverter, context, benefit) {}
-
- LogicalResult
- matchAndRewrite(vector::ToElementsOp toElementsOp, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
-
- VectorType vecType = toElementsOp.getSource().getType();
- if (vecType.getRank() <= 1)
- return rewriter.notifyMatchFailure(
- toElementsOp, "the rank is already less than or equal to 1");
-
- assert(vecType.getNumScalableDims() == 0 &&
- "scalable vector is not yet supported");
- auto vec1DType =
- VectorType::get({vecType.getNumElements()}, vecType.getElementType());
- Value shapeCast = vector::ShapeCastOp::create(
- rewriter, toElementsOp.getLoc(), vec1DType, toElementsOp.getSource());
- rewriter.replaceOpWithNewOp<vector::ToElementsOp>(
- toElementsOp, toElementsOp.getResultTypes(), shapeCast);
- return success();
- }
-};
-
} // namespace
/// This method defines the set of operations that are linearizable, and hence
@@ -933,8 +890,8 @@ void mlir::vector::populateVectorLinearizeBasePatterns(
patterns
.add<LinearizeConstantLike, LinearizeVectorizable, LinearizeVectorBitCast,
LinearizeVectorSplat, LinearizeVectorCreateMask, LinearizeVectorLoad,
- LinearizeVectorStore, LinearizeVectorFromElements,
- LinearizeVectorToElements>(typeConverter, patterns.getContext());
+ LinearizeVectorStore, LinearizeVectorFromElements>(
+ typeConverter, patterns.getContext());
}
void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns(
diff --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize.mlir
index fe697c8b9c057..5e8bfd0698b33 100644
--- a/mlir/test/Dialect/Vector/linearize.mlir
+++ b/mlir/test/Dialect/Vector/linearize.mlir
@@ -538,26 +538,3 @@ func.func @test_vector_from_elements(%arg0: f32, %arg1: f32, %arg2: f32, %arg3:
%1 = vector.from_elements %arg0, %arg1, %arg2, %arg3 : vector<2x2xf32>
return %1 : vector<2x2xf32>
}
-
-// -----
-
-// CHECK-LABEL: func.func @to_elements_1d(
-// CHECK-SAME: %[[ARG0:.+]]: vector<2xf32>
-// CHECK: %[[RES:.+]]:2 = vector.to_elements %[[ARG0]] : vector<2xf32>
-// CHECK: return %[[RES]]#0, %[[RES]]#1
-func.func @to_elements_1d(%arg0: vector<2xf32>) -> (f32, f32) {
- %0:2 = vector.to_elements %arg0 : vector<2xf32>
- return %0#0, %0#1 : f32, f32
-}
-
-// -----
-
-// CHECK-LABEL: func.func @to_elements_2d(
-// CHECK-SAME: %[[ARG0:.+]]: vector<2x2xf32>
-// CHECK: %[[CAST:.+]] = vector.shape_cast %[[ARG0]]
-// CHECK: %[[RES:.+]]:4 = vector.to_elements %[[CAST]] : vector<4xf32>
-// CHECK: return %[[RES]]#0, %[[RES]]#1, %[[RES]]#2, %[[RES]]#3
-func.func @to_elements_2d(%arg0: vector<2x2xf32>) -> (f32, f32, f32, f32) {
- %0:4 = vector.to_elements %arg0 : vector<2x2xf32>
- return %0#0, %0#1, %0#2, %0#3 : f32, f32, f32, f32
-}
More information about the Mlir-commits
mailing list