[Mlir-commits] [mlir] 5fdd3a1 - [mlir][vector] Follow-up improvements for multi-dimensional vector.from_elements support (#154664)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Aug 27 21:41:10 PDT 2025
Author: Yang Bai
Date: 2025-08-27T21:41:06-07:00
New Revision: 5fdd3a12e5aaab8411a64ac2e5e162d490c3161d
URL: https://github.com/llvm/llvm-project/commit/5fdd3a12e5aaab8411a64ac2e5e162d490c3161d
DIFF: https://github.com/llvm/llvm-project/commit/5fdd3a12e5aaab8411a64ac2e5e162d490c3161d.diff
LOG: [mlir][vector] Follow-up improvements for multi-dimensional vector.from_elements support (#154664)
This PR is a follow-up to #151175, which supported lowering
multi-dimensional `vector.from_elements` ops to LLVM by introducing an
unrolling pattern.
## Changes
### Add a `vector.shape_cast`-based flattening pattern for `vector.from_elements`
This change introduces a new linearization pattern that uses
`vector.shape_cast` to flatten multi-dimensional `vector.from_elements`
operations. This provides an alternative approach to the unrolling-based
method introduced in #151175.
**Example:**
```mlir
// Before
%v = vector.from_elements %e0, %e1, %e2, %e3 : vector<2x2xf32>
// After
%flat = vector.from_elements %e0, %e1, %e2, %e3 : vector<4xf32>
%result = vector.shape_cast %flat : vector<4xf32> to vector<2x2xf32>
```
---------
Co-authored-by: Yang Bai <yangb at nvidia.com>
Co-authored-by: James Newling <james.newling at gmail.com>
Added:
Modified:
mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
mlir/test/Dialect/Vector/linearize.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index 491b448e9e1e9..7dde6311fa809 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -762,6 +762,42 @@ struct LinearizeVectorStore final
}
};
+/// Linearizes `vector.from_elements` by re-creating it with the 1-D vector
+/// type chosen by the type converter, reusing the element operands unchanged.
+/// The pattern itself emits no `vector.shape_cast`; the cast back to the
+/// original shape presumably comes from the converter's materialization.
+///
+/// Example:
+///
+/// %0 = vector.from_elements %a, %b, %c, %d : vector<2x2xf32>
+///
+/// is converted to:
+///
+/// %0 = vector.from_elements %a, %b, %c, %d : vector<4xf32>
+/// %1 = vector.shape_cast %0 : vector<4xf32> to vector<2x2xf32>
+/// NOTE(review): both checks below are `assert`s, i.e. unchecked in release.
+struct LinearizeVectorFromElements final
+ : public OpConversionPattern<vector::FromElementsOp> {
+ using OpConversionPattern::OpConversionPattern;
+ LinearizeVectorFromElements(const TypeConverter &typeConverter,
+ MLIRContext *context, PatternBenefit benefit = 1)
+ : OpConversionPattern(typeConverter, context, benefit) {}
+ LogicalResult
+ matchAndRewrite(vector::FromElementsOp fromElementsOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ VectorType dstTy =
+ getTypeConverter()->convertType<VectorType>(fromElementsOp.getType());
+ assert(dstTy && "vector type destination expected.");
+ // NOTE(review): uses op operands, not `adaptor`'s; assumes scalars unconverted.
+ OperandRange elements = fromElementsOp.getElements();
+ assert(elements.size() == static_cast<size_t>(dstTy.getNumElements()) &&
+ "expected same number of elements");
+ rewriter.replaceOpWithNewOp<vector::FromElementsOp>(fromElementsOp, dstTy,
+ elements);
+ return success();
+ }
+};
+
} // namespace
/// This method defines the set of operations that are linearizable, and hence
@@ -854,7 +890,8 @@ void mlir::vector::populateVectorLinearizeBasePatterns(
patterns
.add<LinearizeConstantLike, LinearizeVectorizable, LinearizeVectorBitCast,
LinearizeVectorSplat, LinearizeVectorCreateMask, LinearizeVectorLoad,
- LinearizeVectorStore>(typeConverter, patterns.getContext());
+ LinearizeVectorStore, LinearizeVectorFromElements>(
+ typeConverter, patterns.getContext());
}
void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns(
diff --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize.mlir
index 2e630bf93622e..5e8bfd0698b33 100644
--- a/mlir/test/Dialect/Vector/linearize.mlir
+++ b/mlir/test/Dialect/Vector/linearize.mlir
@@ -524,3 +524,17 @@ func.func @linearize_vector_store_scalable(%arg0: memref<2x8xf32>, %arg1: vector
vector.store %arg1, %arg0[%c0, %c0] : memref<2x8xf32>, vector<1x[4]xf32>
return
}
+
+// -----
+
+// Test pattern LinearizeVectorFromElements: 2-D op -> 1-D op + shape_cast.
+
+// CHECK-LABEL: test_vector_from_elements
+// CHECK-SAME: %[[ARG_0:.*]]: f32, %[[ARG_1:.*]]: f32, %[[ARG_2:.*]]: f32, %[[ARG_3:.*]]: f32
+func.func @test_vector_from_elements(%arg0: f32, %arg1: f32, %arg2: f32, %arg3: f32) -> vector<2x2xf32> {
+ // CHECK: %[[FROM_ELEMENTS:.*]] = vector.from_elements %[[ARG_0]], %[[ARG_1]], %[[ARG_2]], %[[ARG_3]] : vector<4xf32>
+ // CHECK: %[[CAST:.*]] = vector.shape_cast %[[FROM_ELEMENTS]] : vector<4xf32> to vector<2x2xf32>
+ // CHECK: return %[[CAST]] : vector<2x2xf32>
+ %1 = vector.from_elements %arg0, %arg1, %arg2, %arg3 : vector<2x2xf32>
+ return %1 : vector<2x2xf32>
+}
More information about the Mlir-commits mailing list