[Mlir-commits] [mlir] [MLIR][Vector] Add linearization patterns for interleave/deinterleave (PR #197123)

Artem Kroviakov llvmlistbot at llvm.org
Tue May 12 01:33:10 PDT 2026


https://github.com/akroviakov created https://github.com/llvm/llvm-project/pull/197123

This PR extends the `VectorLinearize` patterns to cover the `vector.interleave` and `vector.deinterleave` ops. Linearizing them is trivial: both ops act only on the innermost dimension of a vector, so flattening their operands and results commutes with the op.

>From 6789790f25d641e441f70d5e18f6c3196aa444fb Mon Sep 17 00:00:00 2001
From: Artem Kroviakov <artem.kroviakov at intel.com>
Date: Tue, 12 May 2026 08:29:32 +0000
Subject: [PATCH] [MLIR][Vector] Add linearization patterns for
 interleave/deinterleave

---
 .../Vector/Transforms/VectorLinearize.cpp     | 57 ++++++++++++++++++-
 mlir/test/Dialect/Vector/linearize.mlir       | 28 +++++++++
 2 files changed, 84 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index c17d3862c0eac..584fa28c0897a 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -861,6 +861,60 @@ struct LinearizeVectorBroadcast final
   }
 };
 
+/// Rewrites an N-D `vector.interleave` into one on flattened (1-D) operands
+/// and result. Interleaving only acts on the innermost dimension, so it
+/// commutes with flattening: flattened operand element k lands at result
+/// index 2k (lhs) or 2k+1 (rhs) in both forms, hence any N-D shape works.
+///
+/// Example:
+///   vector.interleave %a, %b : vector<4x1xT> -> vector<4x2xT>
+/// becomes:
+///   vector.interleave %a_flat, %b_flat : vector<4xT> -> vector<8xT>
+struct LinearizeVectorInterleave final
+    : public OpConversionPattern<vector::InterleaveOp> {
+  using Base::Base;
+  LinearizeVectorInterleave(const TypeConverter &typeConverter,
+                            MLIRContext *context, PatternBenefit benefit = 1)
+      : OpConversionPattern(typeConverter, context, benefit) {}
+
+  LogicalResult
+  matchAndRewrite(vector::InterleaveOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    const TypeConverter *converter = getTypeConverter();
+    auto flatTy = converter->convertType<VectorType>(op.getResultVectorType());
+    if (!flatTy)
+      return rewriter.notifyMatchFailure(op, "failed to linearize result type");
+    rewriter.replaceOpWithNewOp<vector::InterleaveOp>(
+        op, flatTy, adaptor.getLhs(), adaptor.getRhs());
+    return success();
+  }
+};
+
+/// Rewrites an N-D `vector.deinterleave` to read a flattened (1-D) source.
+/// Its innermost size 2N is even, so flat index i*2N+j has the parity of j:
+/// even/odd lanes of the flattened source match the N-D split for any shape.
+///
+/// Example:
+///   %even, %odd = vector.deinterleave %src : vector<4x2xT> -> vector<4x1xT>
+/// becomes:
+///   %even, %odd = vector.deinterleave %src_flat : vector<8xT> -> vector<4xT>
+struct LinearizeVectorDeinterleave final
+    : public OpConversionPattern<vector::DeinterleaveOp> {
+  using Base::Base;
+  LinearizeVectorDeinterleave(const TypeConverter &typeConverter,
+                              MLIRContext *context, PatternBenefit benefit = 1)
+      : OpConversionPattern(typeConverter, context, benefit) {}
+
+  LogicalResult
+  matchAndRewrite(vector::DeinterleaveOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // Both result types are inferred from the (already flattened) source.
+    Value flatSource = adaptor.getSource();
+    rewriter.replaceOpWithNewOp<vector::DeinterleaveOp>(op, flatSource);
+    return success();
+  }
+};
+
 } // namespace
 
 /// This method defines the set of operations that are linearizable, and hence
@@ -952,7 +1006,8 @@ void mlir::vector::populateVectorLinearizeBasePatterns(
       .add<LinearizeConstantLike, LinearizeVectorizable, LinearizeVectorBitCast,
            LinearizeVectorCreateMask, LinearizeVectorLoad, LinearizeVectorStore,
            LinearizeVectorBroadcast, LinearizeVectorFromElements,
-           LinearizeVectorToElements>(typeConverter, patterns.getContext());
+           LinearizeVectorToElements, LinearizeVectorInterleave,
+           LinearizeVectorDeinterleave>(typeConverter, patterns.getContext());
 }
 
 void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns(
diff --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize.mlir
index cbbc833d7a51d..c2e64a6ec16e9 100644
--- a/mlir/test/Dialect/Vector/linearize.mlir
+++ b/mlir/test/Dialect/Vector/linearize.mlir
@@ -575,3 +575,31 @@ func.func @to_elements_2d(%arg0: vector<2x2xf32>) -> (f32, f32, f32, f32) {
   %0:4 = vector.to_elements %arg0 : vector<2x2xf32>
   return %0#0, %0#1, %0#2, %0#3 : f32, f32, f32, f32
 }
+
+// -----
+
+// CHECK-LABEL: linearize_vector_interleave
+// CHECK-SAME: (%[[LHS_2D:.*]]: vector<2x4xf32>, %[[RHS_2D:.*]]: vector<2x4xf32>) -> vector<2x8xf32>
+func.func @linearize_vector_interleave(%arg0: vector<2x4xf32>, %arg1: vector<2x4xf32>) -> vector<2x8xf32> {
+  // CHECK-DAG: %[[LHS_1D:.*]] = vector.shape_cast %[[LHS_2D]] : vector<2x4xf32> to vector<8xf32>
+  // CHECK-DAG: %[[RHS_1D:.*]] = vector.shape_cast %[[RHS_2D]] : vector<2x4xf32> to vector<8xf32>
+  // CHECK: %[[RES_1D:.*]] = vector.interleave %[[LHS_1D]], %[[RHS_1D]] : vector<8xf32> -> vector<16xf32>
+  // CHECK: %[[RES_2D:.*]] = vector.shape_cast %[[RES_1D]] : vector<16xf32> to vector<2x8xf32>
+  // CHECK: return %[[RES_2D]] : vector<2x8xf32>
+  %0 = vector.interleave %arg0, %arg1 : vector<2x4xf32> -> vector<2x8xf32>
+  return %0 : vector<2x8xf32>
+}
+
+// -----
+
+// CHECK-LABEL: linearize_vector_deinterleave
+// CHECK-SAME: (%[[ARG0:.*]]: vector<2x8xf32>) -> (vector<2x4xf32>, vector<2x4xf32>)
+func.func @linearize_vector_deinterleave(%arg0: vector<2x8xf32>) -> (vector<2x4xf32>, vector<2x4xf32>) {
+  // CHECK: %[[SRC:.*]] = vector.shape_cast %[[ARG0]] : vector<2x8xf32> to vector<16xf32>
+  // CHECK: %[[EVEN:.*]], %[[ODD:.*]] = vector.deinterleave %[[SRC]] : vector<16xf32> -> vector<8xf32>
+  // CHECK-DAG: %[[EVEN_CAST:.*]] = vector.shape_cast %[[EVEN]] : vector<8xf32> to vector<2x4xf32>
+  // CHECK-DAG: %[[ODD_CAST:.*]] = vector.shape_cast %[[ODD]] : vector<8xf32> to vector<2x4xf32>
+  // CHECK: return %[[EVEN_CAST]], %[[ODD_CAST]] : vector<2x4xf32>, vector<2x4xf32>
+  %even, %odd = vector.deinterleave %arg0 : vector<2x8xf32> -> vector<2x4xf32>
+  return %even, %odd : vector<2x4xf32>, vector<2x4xf32>
+}



More information about the Mlir-commits mailing list