[Mlir-commits] [mlir] [mlir][vector] Add LinearizeVectorToElements (PR #157740)
Erick Ochoa Lopez
llvmlistbot at llvm.org
Tue Sep 9 13:04:56 PDT 2025
https://github.com/amd-eochoalo created https://github.com/llvm/llvm-project/pull/157740
None
>From d4251875f8c346634c564e5498f40f86b33bd20f Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Tue, 9 Sep 2025 12:46:01 -0700
Subject: [PATCH] [mlir][vector] Add LinearizeVectorToElements
---
.../Vector/Transforms/VectorLinearize.cpp | 47 ++++++++++++++++++-
mlir/test/Dialect/Vector/linearize.mlir | 23 +++++++++
2 files changed, 68 insertions(+), 2 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index 7dde6311fa809..54eb182a9680f 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -798,6 +798,49 @@ struct LinearizeVectorFromElements final
}
};
+/// This pattern linearizes the source operand of `vector.to_elements` by
+/// flattening it to a 1-D vector. `vector.shape_cast` keeps row-major element
+/// order, so the scalar results are unchanged. The remapped (adaptor) operand
+/// is used, so this pattern adds no cast when it is already of the 1-D type.
+///
+/// Example:
+///   %0:4 = vector.to_elements %v : vector<2x2xf32>
+///
+/// is converted to:
+///   %vector_cast = vector.shape_cast %v : vector<2x2xf32> to vector<4xf32>
+///   %0:4 = vector.to_elements %vector_cast : vector<4xf32>
+///
+struct LinearizeVectorToElements final
+    : public OpConversionPattern<vector::ToElementsOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LinearizeVectorToElements(const TypeConverter &typeConverter,
+                            MLIRContext *context, PatternBenefit benefit = 1)
+      : OpConversionPattern(typeConverter, context, benefit) {}
+
+  LogicalResult
+  matchAndRewrite(vector::ToElementsOp toElementsOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    VectorType vecType = toElementsOp.getSource().getType();
+    if (vecType.getRank() <= 1)
+      return rewriter.notifyMatchFailure(
+          toElementsOp, "the rank is already less than or equal to 1");
+    if (vecType.isScalable())
+      return rewriter.notifyMatchFailure(
+          toElementsOp, "scalable vector is not supported");
+    auto vec1DType =
+        VectorType::get({vecType.getNumElements()}, vecType.getElementType());
+    // Use the remapped operand; only cast if it is not already 1-D.
+    Value source = adaptor.getSource();
+    if (source.getType() != vec1DType)
+      source = vector::ShapeCastOp::create(rewriter, toElementsOp.getLoc(),
+                                           vec1DType, source);
+    rewriter.replaceOpWithNewOp<vector::ToElementsOp>(
+        toElementsOp, toElementsOp.getResultTypes(), source);
+    return success();
+  }
+};
+
} // namespace
/// This method defines the set of operations that are linearizable, and hence
@@ -890,8 +933,8 @@ void mlir::vector::populateVectorLinearizeBasePatterns(
patterns
.add<LinearizeConstantLike, LinearizeVectorizable, LinearizeVectorBitCast,
LinearizeVectorSplat, LinearizeVectorCreateMask, LinearizeVectorLoad,
- LinearizeVectorStore, LinearizeVectorFromElements>(
- typeConverter, patterns.getContext());
+ LinearizeVectorStore, LinearizeVectorFromElements,
+ LinearizeVectorToElements>(typeConverter, patterns.getContext());
}
void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns(
diff --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize.mlir
index 5e8bfd0698b33..fe697c8b9c057 100644
--- a/mlir/test/Dialect/Vector/linearize.mlir
+++ b/mlir/test/Dialect/Vector/linearize.mlir
@@ -538,3 +538,26 @@ func.func @test_vector_from_elements(%arg0: f32, %arg1: f32, %arg2: f32, %arg3:
%1 = vector.from_elements %arg0, %arg1, %arg2, %arg3 : vector<2x2xf32>
return %1 : vector<2x2xf32>
}
+
+// -----
+
+// CHECK-LABEL: func.func @to_elements_1d(
+// CHECK-SAME:    %[[ARG0:.+]]: vector<2xf32>
+// CHECK-NEXT:    %[[RES:.+]]:2 = vector.to_elements %[[ARG0]] : vector<2xf32>
+// CHECK-NEXT:    return %[[RES]]#0, %[[RES]]#1
+func.func @to_elements_1d(%arg0: vector<2xf32>) -> (f32, f32) {
+  %0:2 = vector.to_elements %arg0 : vector<2xf32>
+  return %0#0, %0#1 : f32, f32
+}
+
+// -----
+
+// CHECK-LABEL: func.func @to_elements_2d(
+// CHECK-SAME:    %[[ARG0:.+]]: vector<2x2xf32>
+// CHECK:         %[[CAST:.+]] = vector.shape_cast %[[ARG0]] : vector<2x2xf32> to vector<4xf32>
+// CHECK-NEXT:    %[[RES:.+]]:4 = vector.to_elements %[[CAST]] : vector<4xf32>
+// CHECK-NEXT:    return %[[RES]]#0, %[[RES]]#1, %[[RES]]#2, %[[RES]]#3
+func.func @to_elements_2d(%arg0: vector<2x2xf32>) -> (f32, f32, f32, f32) {
+  %0:4 = vector.to_elements %arg0 : vector<2x2xf32>
+  return %0#0, %0#1, %0#2, %0#3 : f32, f32, f32, f32
+}
More information about the Mlir-commits
mailing list