[Mlir-commits] [mlir] [mlir] Move vector.{to_elements, from_elements} unrolling to `VectorUnroll.cpp` (PR #159118)
Erick Ochoa Lopez
llvmlistbot at llvm.org
Thu Sep 18 07:23:08 PDT 2025
https://github.com/amd-eochoalo updated https://github.com/llvm/llvm-project/pull/159118
>From 08655a9b64e846a5dcd1d77ec0481292392032e3 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Tue, 16 Sep 2025 09:00:46 -0700
Subject: [PATCH 01/11] [mlir] Add vector.{to_elements,from_elements} unrolling
to VectorToSPIRV
---
.../SPIRV/Transforms/SPIRVConversion.cpp | 2 +
.../ConvertToSPIRV/vector-unroll.mlir | 44 +++++++++++++++++++
2 files changed, 46 insertions(+)
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
index 49f4ce8de7c76..98e294b40456f 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
@@ -1495,6 +1495,8 @@ LogicalResult mlir::spirv::unrollVectorsInFuncBodies(Operation *op) {
RewritePatternSet patterns(context);
auto options = vector::UnrollVectorOptions().setNativeShapeFn(
[](auto op) { return mlir::spirv::getNativeVectorShape(op); });
+ vector::populateVectorFromElementsLoweringPatterns(patterns);
+ vector::populateVectorToElementsLoweringPatterns(patterns);
populateVectorUnrollPatterns(patterns, options);
if (failed(applyPatternsGreedily(op, std::move(patterns))))
return failure();
diff --git a/mlir/test/Conversion/ConvertToSPIRV/vector-unroll.mlir b/mlir/test/Conversion/ConvertToSPIRV/vector-unroll.mlir
index c85f4334ff2e5..0957f67690b97 100644
--- a/mlir/test/Conversion/ConvertToSPIRV/vector-unroll.mlir
+++ b/mlir/test/Conversion/ConvertToSPIRV/vector-unroll.mlir
@@ -96,3 +96,47 @@ func.func @transpose(%arg0 : vector<2x3xi32>) -> (vector<3x2xi32>) {
%0 = vector.transpose %arg0, [1, 0] : vector<2x3xi32> to vector<3x2xi32>
return %0 : vector<3x2xi32>
}
+
+// -----
+
+// In order to verify that the pattern is applied,
+// we need to make sure that the 2d vector does not
+// come from the parameters. Otherwise, the pattern
+// in unrollVectorsInSignatures which splits the 2d vector
+// parameter will take precedence. Similarly, let's avoid
+// returning a vector as another pattern would take precedence.
+
+// CHECK-LABEL: @unroll_to_elements_2d
+func.func @unroll_to_elements_2d() -> (f32, f32, f32, f32) {
+ %1 = "test.op"() : () -> (vector<2x2xf32>)
+ // CHECK: %[[VEC2D:.+]] = "test.op"
+ // CHECK: %[[VEC0:.+]] = vector.extract %[[VEC2D]][0] : vector<2xf32> from vector<2x2xf32>
+ // CHECK: %[[VEC1:.+]] = vector.extract %[[VEC2D]][1] : vector<2xf32> from vector<2x2xf32>
+ // CHECK: %[[RES0:.+]]:2 = vector.to_elements %[[VEC0]]
+ // CHECK: %[[RES1:.+]]:2 = vector.to_elements %[[VEC1]]
+ %2:4 = vector.to_elements %1 : vector<2x2xf32>
+ return %2#0, %2#1, %2#2, %2#3 : f32, f32, f32, f32
+}
+
+// -----
+
+// In order to verify that the pattern is applied,
+// we need to make sure that the 2d vector is used
+// by an operation and that extracts are not folded away.
+// In other words, we can't use "test.op" nor return the
+// value `%0 = vector.from_elements`.
+
+// CHECK-LABEL: @unroll_from_elements_2d
+// CHECK-SAME: (%[[ARG0:.+]]: f32, %[[ARG1:.+]]: f32, %[[ARG2:.+]]: f32, %[[ARG3:.+]]: f32)
+func.func @unroll_from_elements_2d(%arg0: f32, %arg1: f32, %arg2: f32, %arg3: f32) -> (vector<2x2xf32>) {
+ // CHECK: %[[VEC0:.+]] = vector.from_elements %[[ARG0]], %[[ARG1]] : vector<2xf32>
+ // CHECK: %[[VEC1:.+]] = vector.from_elements %[[ARG2]], %[[ARG3]] : vector<2xf32>
+ %0 = vector.from_elements %arg0, %arg1, %arg2, %arg3 : vector<2x2xf32>
+
+ // CHECK: %[[RES0:.+]] = arith.addf %[[VEC0]], %[[VEC0]]
+ // CHECK: %[[RES1:.+]] = arith.addf %[[VEC1]], %[[VEC1]]
+ %1 = arith.addf %0, %0 : vector<2x2xf32>
+
+  // return %[[RES0]], %[[RES1]] : vector<2xf32>, vector<2xf32>
+ return %1 : vector<2x2xf32>
+}
>From c891a27212669a5d54f383a11ac3defb16a92f8f Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Tue, 16 Sep 2025 11:25:02 -0700
Subject: [PATCH 02/11] populate patterns inside populateVectorUnrollPatterns
---
mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp | 2 --
mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp | 3 +++
2 files changed, 3 insertions(+), 2 deletions(-)
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
index 98e294b40456f..49f4ce8de7c76 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
@@ -1495,8 +1495,6 @@ LogicalResult mlir::spirv::unrollVectorsInFuncBodies(Operation *op) {
RewritePatternSet patterns(context);
auto options = vector::UnrollVectorOptions().setNativeShapeFn(
[](auto op) { return mlir::spirv::getNativeVectorShape(op); });
- vector::populateVectorFromElementsLoweringPatterns(patterns);
- vector::populateVectorToElementsLoweringPatterns(patterns);
populateVectorUnrollPatterns(patterns, options);
if (failed(applyPatternsGreedily(op, std::move(patterns))))
return failure();
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
index e8ecb0c0be846..a75e680afe1fb 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
@@ -12,6 +12,7 @@
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
#include "mlir/Interfaces/VectorInterfaces.h"
#include "llvm/ADT/MapVector.h"
@@ -814,6 +815,8 @@ struct UnrollBroadcastPattern : public OpRewritePattern<vector::BroadcastOp> {
void mlir::vector::populateVectorUnrollPatterns(
RewritePatternSet &patterns, const UnrollVectorOptions &options,
PatternBenefit benefit) {
+ populateVectorToElementsLoweringPatterns(patterns);
+ populateVectorFromElementsLoweringPatterns(patterns);
patterns.add<UnrollTransferReadPattern, UnrollTransferWritePattern,
UnrollContractionPattern, UnrollElementwisePattern,
UnrollReductionPattern, UnrollMultiReductionPattern,
>From e1f9605b9fd07517e112e1ad7340de354c6f29f0 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Tue, 16 Sep 2025 11:43:01 -0700
Subject: [PATCH 03/11] Copy over UnrollToElements
---
.../Vector/Transforms/VectorRewritePatterns.h | 6 ++++
.../Vector/Transforms/VectorUnroll.cpp | 32 ++++++++++++++++++-
2 files changed, 37 insertions(+), 1 deletion(-)
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
index 0138f477cadea..32fcb948b9cf7 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
@@ -322,6 +322,12 @@ void populateVectorUnrollPatterns(RewritePatternSet &patterns,
const UnrollVectorOptions &options,
PatternBenefit benefit = 1);
+/// Populate the pattern set with the following patterns:
+///
+/// [UnrollToElements]
+void populateVectorToElementsUnrollPatterns(RewritePatternSet &patterns,
+ PatternBenefit benefit = 1);
+
/// Collect a set of leading one dimension removal patterns.
///
/// These patterns insert vector.shape_cast to remove leading one dimensions
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
index a75e680afe1fb..bcfaa843a306f 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
@@ -810,13 +810,38 @@ struct UnrollBroadcastPattern : public OpRewritePattern<vector::BroadcastOp> {
vector::UnrollVectorOptions options;
};
+struct UnrollToElements final : public OpRewritePattern<vector::ToElementsOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(vector::ToElementsOp op,
+ PatternRewriter &rewriter) const override {
+
+ TypedValue<VectorType> source = op.getSource();
+ FailureOr<SmallVector<Value>> result =
+ vector::unrollVectorValue(source, rewriter);
+ if (failed(result)) {
+ return failure();
+ }
+ SmallVector<Value> vectors = *result;
+
+ SmallVector<Value> results;
+ for (const Value &vector : vectors) {
+ auto subElements =
+ vector::ToElementsOp::create(rewriter, op.getLoc(), vector);
+ llvm::append_range(results, subElements.getResults());
+ }
+ rewriter.replaceOp(op, results);
+ return success();
+ }
+};
+
} // namespace
void mlir::vector::populateVectorUnrollPatterns(
RewritePatternSet &patterns, const UnrollVectorOptions &options,
PatternBenefit benefit) {
- populateVectorToElementsLoweringPatterns(patterns);
populateVectorFromElementsLoweringPatterns(patterns);
+ patterns.add<UnrollToElements>(patterns.getContext(), benefit);
patterns.add<UnrollTransferReadPattern, UnrollTransferWritePattern,
UnrollContractionPattern, UnrollElementwisePattern,
UnrollReductionPattern, UnrollMultiReductionPattern,
@@ -824,3 +849,8 @@ void mlir::vector::populateVectorUnrollPatterns(
UnrollStorePattern, UnrollBroadcastPattern>(
patterns.getContext(), options, benefit);
}
+
+void mlir::vector::populateVectorToElementsUnrollPatterns(
+ RewritePatternSet &patterns, PatternBenefit benefit) {
+ patterns.add<UnrollToElements>(patterns.getContext(), benefit);
+}
>From 5c3e7d5d60c1fdd4e3e09399c68d04a78ee8748a Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Tue, 16 Sep 2025 11:55:05 -0700
Subject: [PATCH 04/11] Removes reference to previous UnrollToElements pattern
---
.../Vector/Transforms/LoweringPatterns.h | 6 ---
.../VectorToLLVM/ConvertVectorToLLVMPass.cpp | 2 +-
.../TransformOps/VectorTransformOps.cpp | 2 +-
.../Dialect/Vector/Transforms/CMakeLists.txt | 1 -
.../Transforms/LowerVectorToElements.cpp | 53 -------------------
5 files changed, 2 insertions(+), 62 deletions(-)
delete mode 100644 mlir/lib/Dialect/Vector/Transforms/LowerVectorToElements.cpp
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
index f56124cb4fb95..47f96112a9433 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
@@ -311,12 +311,6 @@ void populateVectorToFromElementsToShuffleTreePatterns(
void populateVectorFromElementsLoweringPatterns(RewritePatternSet &patterns,
PatternBenefit benefit = 1);
-/// Populate the pattern set with the following patterns:
-///
-/// [UnrollToElements]
-void populateVectorToElementsLoweringPatterns(RewritePatternSet &patterns,
- PatternBenefit benefit = 1);
-
/// Populate the pattern set with the following patterns:
///
/// [ContractionOpToMatmulOpLowering]
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
index 0b44ca7ceee42..9cdfeea2b81bf 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
@@ -95,7 +95,7 @@ void ConvertVectorToLLVMPass::runOnOperation() {
populateVectorRankReducingFMAPattern(patterns);
populateVectorGatherLoweringPatterns(patterns);
populateVectorFromElementsLoweringPatterns(patterns);
- populateVectorToElementsLoweringPatterns(patterns);
+ populateVectorToElementsUnrollPatterns(patterns);
if (armI8MM) {
if (armNeon)
arm_neon::populateLowerContractionToNeonI8MMPatterns(patterns);
diff --git a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
index 18f105ef62e38..a3350a3332862 100644
--- a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
+++ b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
@@ -151,7 +151,7 @@ void transform::ApplyUnrollFromElementsPatternsOp::populatePatterns(
void transform::ApplyUnrollToElementsPatternsOp::populatePatterns(
RewritePatternSet &patterns) {
- vector::populateVectorToElementsLoweringPatterns(patterns);
+ vector::populateVectorToElementsUnrollPatterns(patterns);
}
void transform::ApplyLowerScanPatternsOp::populatePatterns(
diff --git a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
index d74007f13a95b..acbf2b746037b 100644
--- a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
@@ -11,7 +11,6 @@ add_mlir_dialect_library(MLIRVectorTransforms
LowerVectorScan.cpp
LowerVectorShapeCast.cpp
LowerVectorStep.cpp
- LowerVectorToElements.cpp
LowerVectorToFromElementsToShuffleTree.cpp
LowerVectorTransfer.cpp
LowerVectorTranspose.cpp
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorToElements.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorToElements.cpp
deleted file mode 100644
index a53a183ec31bc..0000000000000
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorToElements.cpp
+++ /dev/null
@@ -1,53 +0,0 @@
-//===- LowerVectorToElements.cpp - Lower 'vector.to_elements' op ----------===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-//
-// This file implements target-independent rewrites and utilities to lower the
-// 'vector.to_elements' operation.
-//
-//===----------------------------------------------------------------------===//
-
-#include "mlir/Dialect/Vector/IR/VectorOps.h"
-#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
-
-#define DEBUG_TYPE "lower-vector-to-elements"
-
-using namespace mlir;
-
-namespace {
-
-struct UnrollToElements final : public OpRewritePattern<vector::ToElementsOp> {
- using OpRewritePattern::OpRewritePattern;
-
- LogicalResult matchAndRewrite(vector::ToElementsOp op,
- PatternRewriter &rewriter) const override {
-
- TypedValue<VectorType> source = op.getSource();
- FailureOr<SmallVector<Value>> result =
- vector::unrollVectorValue(source, rewriter);
- if (failed(result)) {
- return failure();
- }
- SmallVector<Value> vectors = *result;
-
- SmallVector<Value> results;
- for (const Value &vector : vectors) {
- auto subElements =
- vector::ToElementsOp::create(rewriter, op.getLoc(), vector);
- llvm::append_range(results, subElements.getResults());
- }
- rewriter.replaceOp(op, results);
- return success();
- }
-};
-
-} // namespace
-
-void mlir::vector::populateVectorToElementsLoweringPatterns(
- RewritePatternSet &patterns, PatternBenefit benefit) {
- patterns.add<UnrollToElements>(patterns.getContext(), benefit);
-}
>From 04431c5f5755c7ed2b71f831014dc8d6f8563af7 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Tue, 16 Sep 2025 12:05:05 -0700
Subject: [PATCH 05/11] Copy over UnrollFromElements
---
.../Vector/Transforms/VectorRewritePatterns.h | 8 ++++
.../Vector/Transforms/VectorUnroll.cpp | 46 ++++++++++++++++++-
2 files changed, 52 insertions(+), 2 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
index 32fcb948b9cf7..c42b8748f60de 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
@@ -328,6 +328,14 @@ void populateVectorUnrollPatterns(RewritePatternSet &patterns,
void populateVectorToElementsUnrollPatterns(RewritePatternSet &patterns,
PatternBenefit benefit = 1);
+/// Populate the pattern set with the following patterns:
+///
+/// [UnrollFromElements]
+/// Unrolls 2 or more dimensional `vector.from_elements` ops by unrolling the
+/// outermost dimension.
+void populateVectorFromElementsUnrollPatterns(RewritePatternSet &patterns,
+ PatternBenefit benefit = 1);
+
/// Collect a set of leading one dimension removal patterns.
///
/// These patterns insert vector.shape_cast to remove leading one dimensions
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
index bcfaa843a306f..ba82aa766180f 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
@@ -835,13 +835,50 @@ struct UnrollToElements final : public OpRewritePattern<vector::ToElementsOp> {
}
};
+/// Unrolls 2 or more dimensional `vector.from_elements` ops by unrolling the
+/// outermost dimension. For example:
+/// ```
+/// %v = vector.from_elements %e0, %e1, %e2, %e3, %e4, %e5 : vector<2x3xf32>
+///
+/// ==>
+///
+/// %0 = ub.poison : vector<2x3xf32>
+/// %v0 = vector.from_elements %e0, %e1, %e2 : vector<3xf32>
+/// %1 = vector.insert %v0, %0 [0] : vector<3xf32> into vector<2x3xf32>
+/// %v1 = vector.from_elements %e3, %e4, %e5 : vector<3xf32>
+/// %v = vector.insert %v1, %1 [1] : vector<3xf32> into vector<2x3xf32>
+/// ```
+///
+/// When applied exhaustively, this will produce a sequence of 1-d from_elements
+/// ops.
+struct UnrollFromElements : OpRewritePattern<vector::FromElementsOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(vector::FromElementsOp op,
+ PatternRewriter &rewriter) const override {
+ ValueRange allElements = op.getElements();
+
+ auto unrollFromElementsFn = [&](PatternRewriter &rewriter, Location loc,
+ VectorType subTy, int64_t index) {
+ size_t subTyNumElements = subTy.getNumElements();
+ assert((index + 1) * subTyNumElements <= allElements.size() &&
+ "out of bounds");
+ ValueRange subElements =
+ allElements.slice(index * subTyNumElements, subTyNumElements);
+ return vector::FromElementsOp::create(rewriter, loc, subTy, subElements);
+ };
+
+ return unrollVectorOp(op, rewriter, unrollFromElementsFn);
+ }
+};
+
} // namespace
void mlir::vector::populateVectorUnrollPatterns(
RewritePatternSet &patterns, const UnrollVectorOptions &options,
PatternBenefit benefit) {
- populateVectorFromElementsLoweringPatterns(patterns);
- patterns.add<UnrollToElements>(patterns.getContext(), benefit);
+ patterns.add<UnrollFromElements, UnrollToElements>(patterns.getContext(),
+ benefit);
patterns.add<UnrollTransferReadPattern, UnrollTransferWritePattern,
UnrollContractionPattern, UnrollElementwisePattern,
UnrollReductionPattern, UnrollMultiReductionPattern,
@@ -854,3 +891,8 @@ void mlir::vector::populateVectorToElementsUnrollPatterns(
RewritePatternSet &patterns, PatternBenefit benefit) {
patterns.add<UnrollToElements>(patterns.getContext(), benefit);
}
+
+void mlir::vector::populateVectorFromElementsUnrollPatterns(
+ RewritePatternSet &patterns, PatternBenefit benefit) {
+ patterns.add<UnrollFromElements>(patterns.getContext(), benefit);
+}
>From 70a667e035a6de981e7b21e759337f8bb47728bb Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Tue, 16 Sep 2025 12:10:48 -0700
Subject: [PATCH 06/11] Removes reference to previous UnrollFromElements
pattern
---
.../Vector/Transforms/LoweringPatterns.h | 8 ---
.../GPUCommon/GPUToLLVMConversion.cpp | 2 +-
.../GPUToNVVM/LowerGpuOpsToNVVMOps.cpp | 2 +-
.../VectorToLLVM/ConvertVectorToLLVMPass.cpp | 2 +-
.../TransformOps/VectorTransformOps.cpp | 2 +-
.../Dialect/Vector/Transforms/CMakeLists.txt | 1 -
.../Transforms/LowerVectorFromElements.cpp | 65 -------------------
7 files changed, 4 insertions(+), 78 deletions(-)
delete mode 100644 mlir/lib/Dialect/Vector/Transforms/LowerVectorFromElements.cpp
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
index 47f96112a9433..e03f0dabece52 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
@@ -303,14 +303,6 @@ void populateVectorRankReducingFMAPattern(RewritePatternSet &patterns);
void populateVectorToFromElementsToShuffleTreePatterns(
RewritePatternSet &patterns, PatternBenefit benefit = 1);
-/// Populate the pattern set with the following patterns:
-///
-/// [UnrollFromElements]
-/// Unrolls 2 or more dimensional `vector.from_elements` ops by unrolling the
-/// outermost dimension.
-void populateVectorFromElementsLoweringPatterns(RewritePatternSet &patterns,
- PatternBenefit benefit = 1);
-
/// Populate the pattern set with the following patterns:
///
/// [ContractionOpToMatmulOpLowering]
diff --git a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
index e516118f75207..5994b64f3d9a5 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
@@ -534,7 +534,7 @@ void GpuToLLVMConversionPass::runOnOperation() {
/*maxTransferRank=*/1);
// Transform N-D vector.from_elements to 1-D vector.from_elements before
// conversion.
- vector::populateVectorFromElementsLoweringPatterns(patterns);
+ vector::populateVectorFromElementsUnrollPatterns(patterns);
if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
return signalPassFailure();
}
diff --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
index 76a7e0f3831a2..a95263bb55f69 100644
--- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
+++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
@@ -372,7 +372,7 @@ struct LowerGpuOpsToNVVMOpsPass final
populateGpuRewritePatterns(patterns);
// Transform N-D vector.from_elements to 1-D vector.from_elements before
// conversion.
- vector::populateVectorFromElementsLoweringPatterns(patterns);
+ vector::populateVectorFromElementsUnrollPatterns(patterns);
if (failed(applyPatternsGreedily(m, std::move(patterns))))
return signalPassFailure();
}
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
index 9cdfeea2b81bf..cae490e5f03e7 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
@@ -94,7 +94,7 @@ void ConvertVectorToLLVMPass::runOnOperation() {
populateVectorStepLoweringPatterns(patterns);
populateVectorRankReducingFMAPattern(patterns);
populateVectorGatherLoweringPatterns(patterns);
- populateVectorFromElementsLoweringPatterns(patterns);
+ populateVectorFromElementsUnrollPatterns(patterns);
populateVectorToElementsUnrollPatterns(patterns);
if (armI8MM) {
if (armNeon)
diff --git a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
index a3350a3332862..7faa222a9e574 100644
--- a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
+++ b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
@@ -146,7 +146,7 @@ void transform::ApplyLowerGatherPatternsOp::populatePatterns(
void transform::ApplyUnrollFromElementsPatternsOp::populatePatterns(
RewritePatternSet &patterns) {
- vector::populateVectorFromElementsLoweringPatterns(patterns);
+ vector::populateVectorFromElementsUnrollPatterns(patterns);
}
void transform::ApplyUnrollToElementsPatternsOp::populatePatterns(
diff --git a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
index acbf2b746037b..9e287fc109990 100644
--- a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
@@ -3,7 +3,6 @@ add_mlir_dialect_library(MLIRVectorTransforms
LowerVectorBitCast.cpp
LowerVectorBroadcast.cpp
LowerVectorContract.cpp
- LowerVectorFromElements.cpp
LowerVectorGather.cpp
LowerVectorInterleave.cpp
LowerVectorMask.cpp
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorFromElements.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorFromElements.cpp
deleted file mode 100644
index c22fd54cef46b..0000000000000
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorFromElements.cpp
+++ /dev/null
@@ -1,65 +0,0 @@
-//===- LowerVectorFromElements.cpp - Lower 'vector.from_elements' op -----===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-//
-// This file implements target-independent rewrites and utilities to lower the
-// 'vector.from_elements' operation.
-//
-//===----------------------------------------------------------------------===//
-
-#include "mlir/Dialect/Vector/IR/VectorOps.h"
-#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
-
-#define DEBUG_TYPE "lower-vector-from-elements"
-
-using namespace mlir;
-
-namespace {
-
-/// Unrolls 2 or more dimensional `vector.from_elements` ops by unrolling the
-/// outermost dimension. For example:
-/// ```
-/// %v = vector.from_elements %e0, %e1, %e2, %e3, %e4, %e5 : vector<2x3xf32>
-///
-/// ==>
-///
-/// %0 = ub.poison : vector<2x3xf32>
-/// %v0 = vector.from_elements %e0, %e1, %e2 : vector<3xf32>
-/// %1 = vector.insert %v0, %0 [0] : vector<3xf32> into vector<2x3xf32>
-/// %v1 = vector.from_elements %e3, %e4, %e5 : vector<3xf32>
-/// %v = vector.insert %v1, %1 [1] : vector<3xf32> into vector<2x3xf32>
-/// ```
-///
-/// When applied exhaustively, this will produce a sequence of 1-d from_elements
-/// ops.
-struct UnrollFromElements : OpRewritePattern<vector::FromElementsOp> {
- using OpRewritePattern::OpRewritePattern;
-
- LogicalResult matchAndRewrite(vector::FromElementsOp op,
- PatternRewriter &rewriter) const override {
- ValueRange allElements = op.getElements();
-
- auto unrollFromElementsFn = [&](PatternRewriter &rewriter, Location loc,
- VectorType subTy, int64_t index) {
- size_t subTyNumElements = subTy.getNumElements();
- assert((index + 1) * subTyNumElements <= allElements.size() &&
- "out of bounds");
- ValueRange subElements =
- allElements.slice(index * subTyNumElements, subTyNumElements);
- return vector::FromElementsOp::create(rewriter, loc, subTy, subElements);
- };
-
- return unrollVectorOp(op, rewriter, unrollFromElementsFn);
- }
-};
-
-} // namespace
-
-void mlir::vector::populateVectorFromElementsLoweringPatterns(
- RewritePatternSet &patterns, PatternBenefit benefit) {
- patterns.add<UnrollFromElements>(patterns.getContext(), benefit);
-}
>From c054f16b16be2d8e3cee1234c6b1443b81ba70d3 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Tue, 16 Sep 2025 12:26:11 -0700
Subject: [PATCH 07/11] Adds UnrollVectorOptions to ToElements and FromElements
patterns
---
.../Vector/Transforms/VectorUnroll.cpp | 30 ++++++++++++++-----
1 file changed, 22 insertions(+), 8 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
index ba82aa766180f..6f8c667d58f48 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
@@ -811,7 +811,11 @@ struct UnrollBroadcastPattern : public OpRewritePattern<vector::BroadcastOp> {
};
struct UnrollToElements final : public OpRewritePattern<vector::ToElementsOp> {
- using OpRewritePattern::OpRewritePattern;
+ UnrollToElements(MLIRContext *context,
+ const vector::UnrollVectorOptions &options,
+ PatternBenefit benefit = 1)
+ : OpRewritePattern<vector::ToElementsOp>(context, benefit),
+ options(options) {}
LogicalResult matchAndRewrite(vector::ToElementsOp op,
PatternRewriter &rewriter) const override {
@@ -833,6 +837,9 @@ struct UnrollToElements final : public OpRewritePattern<vector::ToElementsOp> {
rewriter.replaceOp(op, results);
return success();
}
+
+private:
+ vector::UnrollVectorOptions options;
};
/// Unrolls 2 or more dimensional `vector.from_elements` ops by unrolling the
@@ -852,7 +859,11 @@ struct UnrollToElements final : public OpRewritePattern<vector::ToElementsOp> {
/// When applied exhaustively, this will produce a sequence of 1-d from_elements
/// ops.
struct UnrollFromElements : OpRewritePattern<vector::FromElementsOp> {
- using OpRewritePattern::OpRewritePattern;
+ UnrollFromElements(MLIRContext *context,
+ const vector::UnrollVectorOptions &options,
+ PatternBenefit benefit = 1)
+ : OpRewritePattern<vector::FromElementsOp>(context, benefit),
+ options(options) {}
LogicalResult matchAndRewrite(vector::FromElementsOp op,
PatternRewriter &rewriter) const override {
@@ -870,6 +881,9 @@ struct UnrollFromElements : OpRewritePattern<vector::FromElementsOp> {
return unrollVectorOp(op, rewriter, unrollFromElementsFn);
}
+
+private:
+ vector::UnrollVectorOptions options;
};
} // namespace
@@ -877,22 +891,22 @@ struct UnrollFromElements : OpRewritePattern<vector::FromElementsOp> {
void mlir::vector::populateVectorUnrollPatterns(
RewritePatternSet &patterns, const UnrollVectorOptions &options,
PatternBenefit benefit) {
- patterns.add<UnrollFromElements, UnrollToElements>(patterns.getContext(),
- benefit);
patterns.add<UnrollTransferReadPattern, UnrollTransferWritePattern,
UnrollContractionPattern, UnrollElementwisePattern,
UnrollReductionPattern, UnrollMultiReductionPattern,
UnrollTransposePattern, UnrollGatherPattern, UnrollLoadPattern,
- UnrollStorePattern, UnrollBroadcastPattern>(
- patterns.getContext(), options, benefit);
+ UnrollStorePattern, UnrollBroadcastPattern, UnrollFromElements,
+ UnrollToElements>(patterns.getContext(), options, benefit);
}
void mlir::vector::populateVectorToElementsUnrollPatterns(
RewritePatternSet &patterns, PatternBenefit benefit) {
- patterns.add<UnrollToElements>(patterns.getContext(), benefit);
+ patterns.add<UnrollToElements>(patterns.getContext(), UnrollVectorOptions(),
+ benefit);
}
void mlir::vector::populateVectorFromElementsUnrollPatterns(
RewritePatternSet &patterns, PatternBenefit benefit) {
- patterns.add<UnrollFromElements>(patterns.getContext(), benefit);
+ patterns.add<UnrollFromElements>(patterns.getContext(), UnrollVectorOptions(),
+ benefit);
}
>From 5a266fbbfc6977b8f8bf21234133f2a4b30c45c5 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Wed, 17 Sep 2025 13:09:50 -0400
Subject: [PATCH 08/11] Address review comments
---
mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
index 0af1882a66d30..1c80c6b6a39aa 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
@@ -829,7 +829,7 @@ struct UnrollToElements final : public OpRewritePattern<vector::ToElementsOp> {
SmallVector<Value> vectors = *result;
SmallVector<Value> results;
- for (const Value &vector : vectors) {
+ for (Value vector : vectors) {
auto subElements =
vector::ToElementsOp::create(rewriter, op.getLoc(), vector);
llvm::append_range(results, subElements.getResults());
>From f0283ca5f4ed806fe5cb88cf768b7b21d3049710 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Thu, 18 Sep 2025 08:43:13 -0400
Subject: [PATCH 09/11] Update documentation
---
.../Dialect/Vector/Transforms/VectorRewritePatterns.h | 8 ++------
1 file changed, 2 insertions(+), 6 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
index b1effc8642383..69438011d2287 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
@@ -322,15 +322,11 @@ void populateVectorUnrollPatterns(RewritePatternSet &patterns,
const UnrollVectorOptions &options,
PatternBenefit benefit = 1);
-/// Populate the pattern set with the following patterns:
-///
-/// [UnrollToElements]
+/// Unrolls 2 or more dimensional `vector.to_elements` ops by unrolling the
+/// outermost dimension of the operand.
void populateVectorToElementsUnrollPatterns(RewritePatternSet &patterns,
PatternBenefit benefit = 1);
-/// Populate the pattern set with the following patterns:
-///
-/// [UnrollFromElements]
/// Unrolls 2 or more dimensional `vector.from_elements` ops by unrolling the
/// outermost dimension.
void populateVectorFromElementsUnrollPatterns(RewritePatternSet &patterns,
>From 5d7afb1efd4a2b05918044c1bf0b46773b41619c Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Thu, 18 Sep 2025 08:44:46 -0400
Subject: [PATCH 10/11] Reword documentation 'exhaustive' -> 'fixed-point'
---
mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp | 3 ++-
1 file changed, 2 insertions(+), 1 deletion(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
index 1c80c6b6a39aa..e7dd5958cf72a 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
@@ -931,7 +931,8 @@ struct UnrollStepPattern : public OpRewritePattern<vector::StepOp> {
/// %v = vector.insert %v1, %1 [1] : vector<3xf32> into vector<2x3xf32>
/// ```
///
-/// When applied exhaustively, this will produce a sequence of 1-d from_elements
+/// When this pattern is applied until a fixed-point is reached,
+/// this will produce a sequence of 1-d from_elements
/// ops.
struct UnrollFromElements : OpRewritePattern<vector::FromElementsOp> {
UnrollFromElements(MLIRContext *context,
>From 44145b5c525a179ef117c17d89c4fa6bd37481a8 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Thu, 18 Sep 2025 08:48:06 -0400
Subject: [PATCH 11/11] Add documentation to UnrollToElements
---
.../Dialect/Vector/Transforms/VectorUnroll.cpp | 17 +++++++++++++++++
1 file changed, 17 insertions(+)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
index e7dd5958cf72a..14639c5f1cdd3 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
@@ -810,6 +810,23 @@ struct UnrollBroadcastPattern : public OpRewritePattern<vector::BroadcastOp> {
vector::UnrollVectorOptions options;
};
+/// Unrolls 2 or more dimensional `vector.to_elements` ops by unrolling the
+/// outermost dimension of the operand. For example:
+///
+/// ```
+/// %0:8 = vector.to_elements %v : vector<2x2x2xf32>
+///
+/// ==>
+///
+/// %v0 = vector.extract %v[0] : vector<2x2xf32> from vector<2x2x2xf32>
+/// %v1 = vector.extract %v[1] : vector<2x2xf32> from vector<2x2x2xf32>
+/// %0:4 = vector.to_elements %v0 : vector<2x2xf32>
+/// %1:4 = vector.to_elements %v1 : vector<2x2xf32>
+/// ```
+///
+/// When this pattern is applied until a fixed-point is reached,
+/// this will produce a sequence of 1-d to_elements
+/// ops.
struct UnrollToElements final : public OpRewritePattern<vector::ToElementsOp> {
UnrollToElements(MLIRContext *context,
const vector::UnrollVectorOptions &options,
More information about the Mlir-commits
mailing list