[Mlir-commits] [mlir] b812e3d - [mlir][vector] Add LinearizeVectorToElements (#157740)
    llvmlistbot at llvm.org 
    llvmlistbot at llvm.org
       
    Thu Sep 11 10:58:46 PDT 2025
    
    
  
Author: Erick Ochoa Lopez
Date: 2025-09-11T13:58:42-04:00
New Revision: b812e3d61a9230424cec92e05f073f080f62eed5
URL: https://github.com/llvm/llvm-project/commit/b812e3d61a9230424cec92e05f073f080f62eed5
DIFF: https://github.com/llvm/llvm-project/commit/b812e3d61a9230424cec92e05f073f080f62eed5.diff
LOG: [mlir][vector] Add LinearizeVectorToElements (#157740)
Co-authored-by: James Newling <james.newling at gmail.com>
Added: 
    
Modified: 
    mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
    mlir/test/Dialect/Vector/linearize.mlir
Removed: 
    
################################################################################
diff  --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index 7dde6311fa809..12acf4b3f07f5 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -798,6 +798,59 @@ struct LinearizeVectorFromElements final
   }
 };
 
+/// This pattern linearizes the operand of `vector.to_elements` operations
+/// by converting the source to a 1-D vector while preserving all element
+/// values and their (row-major) order. The rewritten `vector.to_elements`
+/// consumes the type-converted operand, and a `vector.shape_cast` to the 1-D
+/// type is inserted only if that operand is not already 1-D.
+///
+/// Example:
+///
+///     %0:4 = vector.to_elements %v : vector<2x2xf32>
+///
+/// is converted to:
+///
+///     %vector_cast = vector.shape_cast %v : vector<2x2xf32> to vector<4xf32>
+///     %0:4 = vector.to_elements %vector_cast : vector<4xf32>
+///
+struct LinearizeVectorToElements final
+    : public OpConversionPattern<vector::ToElementsOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LinearizeVectorToElements(const TypeConverter &typeConverter,
+                            MLIRContext *context, PatternBenefit benefit = 1)
+      : OpConversionPattern(typeConverter, context, benefit) {}
+
+  LogicalResult
+  matchAndRewrite(vector::ToElementsOp toElementsOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+
+    VectorType vecType = toElementsOp.getSource().getType();
+    if (vecType.getRank() <= 1)
+      return rewriter.notifyMatchFailure(
+          toElementsOp, "the rank is already less than or equal to 1");
+
+    assert(vecType.getNumScalableDims() == 0 &&
+           "to_elements does not support scalable vectors");
+    auto vec1DType =
+        VectorType::get({vecType.getNumElements()}, vecType.getElementType());
+
+    // Use the adaptor's (type-converted) operand rather than the original
+    // one. If the producer has already been linearized this is its 1-D
+    // replacement, so no n-D -> 1-D `vector.shape_cast` round trip is built.
+    Value source = adaptor.getSource();
+    if (source.getType() != vec1DType)
+      source = vector::ShapeCastOp::create(rewriter, toElementsOp.getLoc(),
+                                           vec1DType, source);
+
+    auto newToElementsOp =
+        vector::ToElementsOp::create(rewriter, toElementsOp.getLoc(),
+                                     toElementsOp.getResultTypes(), source);
+    rewriter.replaceOp(toElementsOp, newToElementsOp);
+    return success();
+  }
+};
+
 } // namespace
 
 /// This method defines the set of operations that are linearizable, and hence
@@ -890,8 +943,8 @@ void mlir::vector::populateVectorLinearizeBasePatterns(
   patterns
       .add<LinearizeConstantLike, LinearizeVectorizable, LinearizeVectorBitCast,
            LinearizeVectorSplat, LinearizeVectorCreateMask, LinearizeVectorLoad,
-           LinearizeVectorStore, LinearizeVectorFromElements>(
-          typeConverter, patterns.getContext());
+           LinearizeVectorStore, LinearizeVectorFromElements,
+           LinearizeVectorToElements>(typeConverter, patterns.getContext());
 }
 
 void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns(
diff  --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize.mlir
index 5e8bfd0698b33..fe697c8b9c057 100644
--- a/mlir/test/Dialect/Vector/linearize.mlir
+++ b/mlir/test/Dialect/Vector/linearize.mlir
@@ -538,3 +538,39 @@ func.func @test_vector_from_elements(%arg0: f32, %arg1: f32, %arg2: f32, %arg3:
   %1 = vector.from_elements %arg0, %arg1, %arg2, %arg3 : vector<2x2xf32>
   return %1 : vector<2x2xf32>
 }
+
+// -----
+
+// CHECK-LABEL: func.func @to_elements_1d(
+// CHECK-SAME:    %[[ARG0:.+]]: vector<2xf32>
+// CHECK-NOT:     vector.shape_cast
+// CHECK:         %[[RES:.+]]:2 = vector.to_elements %[[ARG0]] : vector<2xf32>
+// CHECK:         return %[[RES]]#0, %[[RES]]#1
+func.func @to_elements_1d(%arg0: vector<2xf32>) -> (f32, f32) {
+  %0:2 = vector.to_elements %arg0 : vector<2xf32>
+  return %0#0, %0#1 : f32, f32
+}
+
+// -----
+
+// CHECK-LABEL: func.func @to_elements_2d(
+// CHECK-SAME:    %[[ARG0:.+]]: vector<2x2xf32>
+// CHECK:         %[[CAST:.+]] = vector.shape_cast %[[ARG0]] : vector<2x2xf32> to vector<4xf32>
+// CHECK:         %[[RES:.+]]:4 = vector.to_elements %[[CAST]] : vector<4xf32>
+// CHECK:         return %[[RES]]#0, %[[RES]]#1, %[[RES]]#2, %[[RES]]#3
+func.func @to_elements_2d(%arg0: vector<2x2xf32>) -> (f32, f32, f32, f32) {
+  %0:4 = vector.to_elements %arg0 : vector<2x2xf32>
+  return %0#0, %0#1, %0#2, %0#3 : f32, f32, f32, f32
+}
+
+// -----
+
+// CHECK-LABEL: func.func @to_elements_3d(
+// CHECK-SAME:    %[[ARG0:.+]]: vector<2x1x2xf32>
+// CHECK:         %[[CAST:.+]] = vector.shape_cast %[[ARG0]] : vector<2x1x2xf32> to vector<4xf32>
+// CHECK:         %[[RES:.+]]:4 = vector.to_elements %[[CAST]] : vector<4xf32>
+// CHECK:         return %[[RES]]#0, %[[RES]]#1, %[[RES]]#2, %[[RES]]#3
+func.func @to_elements_3d(%arg0: vector<2x1x2xf32>) -> (f32, f32, f32, f32) {
+  %0:4 = vector.to_elements %arg0 : vector<2x1x2xf32>
+  return %0#0, %0#1, %0#2, %0#3 : f32, f32, f32, f32
+}
        
    
    
More information about the Mlir-commits
mailing list