[Mlir-commits] [mlir] [mlir][VectorOps] Add unrolling for n-D vector.interleave ops (3/4) (PR #80967)

Benjamin Maxwell llvmlistbot at llvm.org
Thu Feb 15 05:26:18 PST 2024


https://github.com/MacDue updated https://github.com/llvm/llvm-project/pull/80967

>From c3a579080dfe4b6cfa044581de9fc19c952f52cf Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Wed, 7 Feb 2024 10:00:45 +0000
Subject: [PATCH 1/2] [mlir][VectorOps] Add unrolling for n-D vector.interleave
 ops

This unrolls n-D vector.interleave ops like:

```mlir
vector.interleave %i, %j : vector<6x3xf32>
```

to a sequence of 1-D operations, which can then be directly lowered to
LLVM.
---
 .../Vector/TransformOps/VectorTransformOps.td | 14 +++
 .../Vector/Transforms/LoweringPatterns.h      |  8 ++
 .../VectorToLLVM/ConvertVectorToLLVMPass.cpp  |  1 +
 .../TransformOps/VectorTransformOps.cpp       |  5 ++
 .../Dialect/Vector/Transforms/CMakeLists.txt  |  1 +
 .../Transforms/LowerVectorInterleave.cpp      | 88 +++++++++++++++++++
 .../VectorToLLVM/vector-to-llvm.mlir          | 24 +++++
 ...vector-interleave-lowering-transforms.mlir | 49 +++++++++++
 8 files changed, 190 insertions(+)
 create mode 100644 mlir/lib/Dialect/Vector/Transforms/LowerVectorInterleave.cpp
 create mode 100644 mlir/test/Dialect/Vector/vector-interleave-lowering-transforms.mlir

diff --git a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
index ce88360aa52e9d..83df5fe27d7a4a 100644
--- a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
+++ b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
@@ -292,6 +292,20 @@ def ApplyLowerTransposePatternsOp : Op<Transform_Dialect,
   }];
 }
 
+def ApplyLowerInterleavePatternsOp : Op<Transform_Dialect,
+    "apply_patterns.vector.lower_interleave",
+    [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
+  let description = [{
+    Indicates that vector interleave operations should be lowered to
+    finer-grained vector primitives.
+
+    This is usually a late step that is run after bufferization as part of the
+    process of lowering to e.g. LLVM or NVVM.
+  }];
+
+  let assemblyFormat = "attr-dict";
+}
+
 def ApplyRewriteNarrowTypePatternsOp : Op<Transform_Dialect,
     "apply_patterns.vector.rewrite_narrow_types",
     [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
index 57b39f5f52c6d3..1cd3bab46396e3 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
@@ -264,6 +264,14 @@ void populateVectorMaskLoweringPatternsForSideEffectingOps(
 void populateVectorMaskedLoadStoreEmulationPatterns(RewritePatternSet &patterns,
                                                     PatternBenefit benefit = 1);
 
+/// Populate the pattern set with the following patterns:
+///
+/// [InterleaveOpLowering]
+/// Progressive lowering of InterleaveOp to ExtractOp + InsertOp + lower-D
+/// InterleaveOp until dim 1.
+void populateVectorInterleaveLoweringPatterns(RewritePatternSet &patterns,
+                                              PatternBenefit benefit = 1);
+
 } // namespace vector
 } // namespace mlir
 #endif // MLIR_DIALECT_VECTOR_TRANSFORMS_LOWERINGPATTERNS_H
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
index ff8e78a668e0f1..e3a436c4a94009 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
@@ -68,6 +68,7 @@ void LowerVectorToLLVMPass::runOnOperation() {
     populateVectorContractLoweringPatterns(patterns, VectorTransformsOptions());
     populateVectorMaskOpLoweringPatterns(patterns);
     populateVectorShapeCastLoweringPatterns(patterns);
+    populateVectorInterleaveLoweringPatterns(patterns);
     populateVectorTransposeLoweringPatterns(patterns,
                                             VectorTransformsOptions());
     // Vector transfer ops with rank > 1 should be lowered with VectorToSCF.
diff --git a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
index 19922c4295fe03..6c2cfd8833dddc 100644
--- a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
+++ b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
@@ -159,6 +159,11 @@ void transform::ApplyLowerTransposePatternsOp::populatePatterns(
   }
 }
 
+void transform::ApplyLowerInterleavePatternsOp::populatePatterns(
+    RewritePatternSet &patterns) {
+  vector::populateVectorInterleaveLoweringPatterns(patterns);
+}
+
 void transform::ApplyRewriteNarrowTypePatternsOp::populatePatterns(
     RewritePatternSet &patterns) {
   populateVectorNarrowTypeRewritePatterns(patterns);
diff --git a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
index adf961ff935ffb..c4b6abd3e23615 100644
--- a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
@@ -4,6 +4,7 @@ add_mlir_dialect_library(MLIRVectorTransforms
   LowerVectorBroadcast.cpp
   LowerVectorContract.cpp
   LowerVectorGather.cpp
+  LowerVectorInterleave.cpp
   LowerVectorMask.cpp
   LowerVectorMultiReduction.cpp
   LowerVectorScan.cpp
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorInterleave.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorInterleave.cpp
new file mode 100644
index 00000000000000..fffa63d13820a5
--- /dev/null
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorInterleave.cpp
@@ -0,0 +1,88 @@
+//===- LowerVectorInterleave.cpp - Lower 'vector.interleave' operation ----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements target-independent rewrites and utilities to lower the
+// 'vector.interleave' operation.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/PatternMatch.h"
+
+#define DEBUG_TYPE "vector-interleave-lowering"
+
+using namespace mlir;
+using namespace mlir::vector;
+
+namespace {
+
+/// Progressive lowering of InterleaveOp.
+///
+/// Each leading dimension is unrolled until the result of the interleave is
+/// rank 1 (or the dimension is scalable, so can't be unrolled).
+///
+/// Example:
+///
+/// ```
+/// %0 = vector.interleave %lhs, %rhs : vector<2x...8xty>
+/// ```
+/// Becomes:
+/// ```
+/// %lhs_0 = vector.extract %lhs[0]
+/// %rhs_0 = vector.extract %rhs[0]
+/// %lhs_1 = vector.extract %lhs[1]
+/// %rhs_1 = vector.extract %rhs[1]
+/// %zip_0 = vector.interleave %lhs_0, %rhs_0
+/// %zip_1 = vector.interleave %lhs_1, %rhs_1
+/// %res_0 = vector.insert %zip_0, %undef[0]
+///     %0 = vector.insert %zip_1, %res_0[1]
+/// ```
+///
+/// If %zip_0 and %zip_1 still have a rank > 1 they will be unrolled again
+/// following the same pattern.
+class InterleaveOpLowering : public OpRewritePattern<vector::InterleaveOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::InterleaveOp op,
+                                PatternRewriter &rewriter) const override {
+    VectorType resultType = op.getResultVectorType();
+    // 1-D vector.interleave ops can be directly lowered to LLVM (later).
+    if (resultType.getRank() == 1)
+      return failure();
+
+    // Below we unroll the leading (or front) dimension. If that dimension is
+    // scalable we can't unroll it.
+    if (resultType.getScalableDims().front())
+      return failure();
+
+    // n-D case: Unroll the leading dimension.
+    auto loc = op.getLoc();
+    Value result = rewriter.create<arith::ConstantOp>(
+        loc, resultType, rewriter.getZeroAttr(resultType));
+    for (int idx = 0, end = resultType.getDimSize(0); idx < end; ++idx) {
+      Value extractLhs = rewriter.create<ExtractOp>(loc, op.getLhs(), idx);
+      Value extractRhs = rewriter.create<ExtractOp>(loc, op.getRhs(), idx);
+      Value interleave =
+          rewriter.create<InterleaveOp>(loc, extractLhs, extractRhs);
+      result = rewriter.create<InsertOp>(loc, interleave, result, idx);
+    }
+
+    rewriter.replaceOp(op, result);
+    return success();
+  }
+};
+
+} // namespace
+
+void mlir::vector::populateVectorInterleaveLoweringPatterns(
+    RewritePatternSet &patterns, PatternBenefit benefit) {
+  patterns.add<InterleaveOpLowering>(patterns.getContext(), benefit);
+}
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index a46f2e101f3c35..e94e51d49a98b7 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -2497,3 +2497,27 @@ func.func @vector_interleave_1d_scalable(%a: vector<[4]xi32>, %b: vector<[4]xi32
   %0 = vector.interleave %a, %b : vector<[4]xi32>
   return %0 : vector<[8]xi32>
 }
+
+// -----
+
+// CHECK-LABEL: @vector_interleave_2d
+//  CHECK-SAME:     %[[LHS:.*]]: vector<2x3xi8>, %[[RHS:.*]]: vector<2x3xi8>)
+func.func @vector_interleave_2d(%a: vector<2x3xi8>, %b: vector<2x3xi8>) -> vector<2x6xi8>
+{
+  // CHECK: llvm.shufflevector
+  // CHECK-NOT: vector.interleave {{.*}} : vector<2x3xi8>
+  %0 = vector.interleave %a, %b : vector<2x3xi8>
+  return %0 : vector<2x6xi8>
+}
+
+// -----
+
+// CHECK-LABEL: @vector_interleave_2d_scalable
+//  CHECK-SAME:     %[[LHS:.*]]: vector<2x[8]xi16>, %[[RHS:.*]]: vector<2x[8]xi16>)
+func.func @vector_interleave_2d_scalable(%a: vector<2x[8]xi16>, %b: vector<2x[8]xi16>) -> vector<2x[16]xi16>
+{
+  // CHECK: llvm.intr.experimental.vector.interleave2
+  // CHECK-NOT: vector.interleave {{.*}} : vector<2x[8]xi16>
+  %0 = vector.interleave %a, %b : vector<2x[8]xi16>
+  return %0 : vector<2x[16]xi16>
+}
diff --git a/mlir/test/Dialect/Vector/vector-interleave-lowering-transforms.mlir b/mlir/test/Dialect/Vector/vector-interleave-lowering-transforms.mlir
new file mode 100644
index 00000000000000..05be19c3c81f9a
--- /dev/null
+++ b/mlir/test/Dialect/Vector/vector-interleave-lowering-transforms.mlir
@@ -0,0 +1,49 @@
+// RUN: mlir-opt %s --transform-interpreter | FileCheck %s
+
+// CHECK-LABEL: @vector_interleave_2d
+//  CHECK-SAME:     %[[LHS:.*]]: vector<2x3xi8>, %[[RHS:.*]]: vector<2x3xi8>)
+func.func @vector_interleave_2d(%a: vector<2x3xi8>, %b: vector<2x3xi8>) -> vector<2x6xi8>
+{
+  // CHECK-DAG: %[[CST:.*]] = arith.constant dense<0>
+  // CHECK-DAG: %[[LHS_0:.*]] = vector.extract %[[LHS]][0]
+  // CHECK-DAG: %[[RHS_0:.*]] = vector.extract %[[RHS]][0]
+  // CHECK-DAG: %[[LHS_1:.*]] = vector.extract %[[LHS]][1]
+  // CHECK-DAG: %[[RHS_1:.*]] = vector.extract %[[RHS]][1]
+  // CHECK-DAG: %[[ZIP_0:.*]] = vector.interleave %[[LHS_0]], %[[RHS_0]]
+  // CHECK-DAG: %[[ZIP_1:.*]] = vector.interleave %[[LHS_1]], %[[RHS_1]]
+  // CHECK-DAG: %[[RES_0:.*]] = vector.insert %[[ZIP_0]], %[[CST]] [0]
+  // CHECK-DAG: %[[RES_1:.*]] = vector.insert %[[ZIP_1]], %[[RES_0]] [1]
+  // CHECK-NEXT: return %[[RES_1]] : vector<2x6xi8>
+  %0 = vector.interleave %a, %b : vector<2x3xi8>
+  return %0 : vector<2x6xi8>
+}
+
+// CHECK-LABEL: @vector_interleave_2d_scalable
+//  CHECK-SAME:     %[[LHS:.*]]: vector<2x[8]xi16>, %[[RHS:.*]]: vector<2x[8]xi16>)
+func.func @vector_interleave_2d_scalable(%a: vector<2x[8]xi16>, %b: vector<2x[8]xi16>) -> vector<2x[16]xi16>
+{
+  // CHECK-DAG: %[[CST:.*]] = arith.constant dense<0>
+  // CHECK-DAG: %[[LHS_0:.*]] = vector.extract %[[LHS]][0]
+  // CHECK-DAG: %[[RHS_0:.*]] = vector.extract %[[RHS]][0]
+  // CHECK-DAG: %[[LHS_1:.*]] = vector.extract %[[LHS]][1]
+  // CHECK-DAG: %[[RHS_1:.*]] = vector.extract %[[RHS]][1]
+  // CHECK-DAG: %[[ZIP_0:.*]] = vector.interleave %[[LHS_0]], %[[RHS_0]]
+  // CHECK-DAG: %[[ZIP_1:.*]] = vector.interleave %[[LHS_1]], %[[RHS_1]]
+  // CHECK-DAG: %[[RES_0:.*]] = vector.insert %[[ZIP_0]], %[[CST]] [0]
+  // CHECK-DAG: %[[RES_1:.*]] = vector.insert %[[ZIP_1]], %[[RES_0]] [1]
+  // CHECK-NEXT: return %[[RES_1]] : vector<2x[16]xi16>
+  %0 = vector.interleave %a, %b : vector<2x[8]xi16>
+  return %0 : vector<2x[16]xi16>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
+    %f = transform.structured.match ops{["func.func"]} in %module_op
+      : (!transform.any_op) -> !transform.any_op
+
+    transform.apply_patterns to %f {
+      transform.apply_patterns.vector.lower_interleave
+    } : !transform.any_op
+    transform.yield
+  }
+}

>From b25cd9385ef27d027cf500c78f6304cf5d510c15 Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Thu, 15 Feb 2024 13:22:22 +0000
Subject: [PATCH 2/2] Fixup: Create and use `vector::createUnrollIterator()`

Instead of progressively unrolling a leading dimension at a time, this
now uses `vector::createUnrollIterator()` which returns an iterator for
all leading dimensions of a vector type (until a target rank).
---
 .../Vector/Transforms/LoweringPatterns.h      |  7 +-
 .../mlir/Dialect/Vector/Utils/VectorUtils.h   | 23 +++++++
 .../Transforms/LowerVectorInterleave.cpp      | 64 +++++++++----------
 mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp | 22 +++++++
 ...vector-interleave-lowering-transforms.mlir | 22 +++++++
 5 files changed, 103 insertions(+), 35 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
index 1cd3bab46396e3..350d2777cadf50 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
@@ -266,10 +266,11 @@ void populateVectorMaskedLoadStoreEmulationPatterns(RewritePatternSet &patterns,
 
 /// Populate the pattern set with the following patterns:
 ///
-/// [InterleaveOpLowering]
-/// Progressive lowering of InterleaveOp to ExtractOp + InsertOp + lower-D
-/// InterleaveOp until dim 1.
+/// [UnrollInterleaveOp]
+/// A one-shot unrolling of InterleaveOp to (one or more) ExtractOp +
+/// InterleaveOp (of `targetRank`) + InsertOp.
 void populateVectorInterleaveLoweringPatterns(RewritePatternSet &patterns,
+                                              int64_t targetRank = 1,
                                               PatternBenefit benefit = 1);
 
 } // namespace vector
diff --git a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
index 2ab456d4fdbf11..d72074de8525e7 100644
--- a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
+++ b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
@@ -9,6 +9,7 @@
 #ifndef MLIR_DIALECT_VECTOR_UTILS_VECTORUTILS_H_
 #define MLIR_DIALECT_VECTOR_UTILS_VECTORUTILS_H_
 
+#include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/Support/LLVM.h"
@@ -75,6 +76,28 @@ FailureOr<std::pair<int, int>> isTranspose2DSlice(vector::TransposeOp op);
 ///         vector<2x1x2x2xi32> from memref<5x4x3x2xi32>)
 bool isContiguousSlice(MemRefType memrefType, VectorType vectorType);
 
+/// Returns an iterator for all positions in the leading dimensions of `vType`
+/// up to the `targetRank`. If any leading dimension is scalable (so cannot be
+/// unrolled), it will return an iterator for positions up to the first scalable
+/// dimension.
+///
+/// If no leading dimensions can be unrolled an empty optional will be returned.
+///
+/// Example:
+///
+///   For vType = vector<2x3x4> and targetRank = 1
+///
+///   The resulting iterator will yield:
+///     [0, 0], [0, 1], [0, 2], [1, 0], [1, 1], [1, 2]
+///
+///   For vType = vector<3x[4]x5> and targetRank = 0
+///
+///   The scalable dimension blocks unrolling so the iterator yields only:
+///     [0], [1], [2]
+///
+std::optional<StaticTileOffsetRange>
+createUnrollIterator(VectorType vType, int64_t targetRank = 1);
+
 } // namespace vector
 
 /// Constructs a permutation map of invariant memref indices to vector
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorInterleave.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorInterleave.cpp
index fffa63d13820a5..5158a1a1b23eae 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorInterleave.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorInterleave.cpp
@@ -13,6 +13,7 @@
 
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
+#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/PatternMatch.h"
 
@@ -23,66 +24,65 @@ using namespace mlir::vector;
 
 namespace {
 
-/// Progressive lowering of InterleaveOp.
-///
-/// Each leading dimension is unrolled until the result of the interleave is
-/// rank 1 (or the dimension is scalable, so can't be unrolled).
+/// A one-shot unrolling of vector.interleave to the `targetRank`.
 ///
 /// Example:
 ///
+/// ```mlir
+/// vector.interleave %a, %b : vector<1x2x3x4xi64>
 /// ```
-/// %0 = vector.interleave %lhs, %rhs : vector<2x...8xty>
-/// ```
-/// Becomes:
-/// ```
-/// %lhs_0 = vector.extract %lhs[0]
-/// %rhs_0 = vector.extract %rhs[0]
-/// %lhs_1 = vector.extract %lhs[1]
-/// %rhs_1 = vector.extract %rhs[1]
-/// %zip_0 = vector.interleave %lhs_0, %rhs_0
-/// %zip_1 = vector.interleave %lhs_1, %rhs_1
-/// %res_0 = vector.insert %zip_0, %undef[0]
-///     %0 = vector.insert %zip_1, %res_0[1]
+/// Would be unrolled to:
+/// ```mlir
+/// %result = arith.constant dense<0> : vector<1x2x3x8xi64>
+/// %0 = vector.extract %a[0, 0, 0]                 ─┐
+///        : vector<4xi64> from vector<1x2x3x4xi64>  |
+/// %1 = vector.extract %b[0, 0, 0]                  |
+///        : vector<4xi64> from vector<1x2x3x4xi64>  | - Repeated 6x for
+/// %2 = vector.interleave %0, %1 : vector<4xi64>    |   all leading positions
+/// %3 = vector.insert %2, %result [0, 0, 0]         |
+///        : vector<8xi64> into vector<1x2x3x8xi64>  ┘
 /// ```
 ///
-/// If %zip_0 and %zip_1 still have a rank > 1 they will be unrolled again
-/// following the same pattern.
-class InterleaveOpLowering : public OpRewritePattern<vector::InterleaveOp> {
+/// Note: If any leading dimension before the `targetRank` is scalable the
+/// unrolling will stop before the scalable dimension.
+class UnrollInterleaveOp : public OpRewritePattern<vector::InterleaveOp> {
 public:
-  using OpRewritePattern::OpRewritePattern;
+  UnrollInterleaveOp(int64_t targetRank, MLIRContext *context,
+                     PatternBenefit benefit = 1)
+      : OpRewritePattern(context, benefit), targetRank(targetRank){};
 
   LogicalResult matchAndRewrite(vector::InterleaveOp op,
                                 PatternRewriter &rewriter) const override {
     VectorType resultType = op.getResultVectorType();
-    // 1-D vector.interleave ops can be directly lowered to LLVM (later).
-    if (resultType.getRank() == 1)
+    if (resultType.getRank() <= targetRank)
       return failure();
 
-    // Below we unroll the leading (or front) dimension. If that dimension is
-    // scalable we can't unroll it.
-    if (resultType.getScalableDims().front())
+    auto unrollIterator = vector::createUnrollIterator(resultType, targetRank);
+    if (!unrollIterator)
       return failure();
 
-    // n-D case: Unroll the leading dimension.
     auto loc = op.getLoc();
     Value result = rewriter.create<arith::ConstantOp>(
         loc, resultType, rewriter.getZeroAttr(resultType));
-    for (int idx = 0, end = resultType.getDimSize(0); idx < end; ++idx) {
-      Value extractLhs = rewriter.create<ExtractOp>(loc, op.getLhs(), idx);
-      Value extractRhs = rewriter.create<ExtractOp>(loc, op.getRhs(), idx);
+    for (auto position : *unrollIterator) {
+      Value extractLhs = rewriter.create<ExtractOp>(loc, op.getLhs(), position);
+      Value extractRhs = rewriter.create<ExtractOp>(loc, op.getRhs(), position);
       Value interleave =
           rewriter.create<InterleaveOp>(loc, extractLhs, extractRhs);
-      result = rewriter.create<InsertOp>(loc, interleave, result, idx);
+      result = rewriter.create<InsertOp>(loc, interleave, result, position);
     }
 
     rewriter.replaceOp(op, result);
     return success();
   }
+
+private:
+  int64_t targetRank = 1;
 };
 
 } // namespace
 
 void mlir::vector::populateVectorInterleaveLoweringPatterns(
-    RewritePatternSet &patterns, PatternBenefit benefit) {
-  patterns.add<InterleaveOpLowering>(patterns.getContext(), benefit);
+    RewritePatternSet &patterns, int64_t targetRank, PatternBenefit benefit) {
+  patterns.add<UnrollInterleaveOp>(targetRank, patterns.getContext(), benefit);
 }
diff --git a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
index 377f3d8c557474..d888e2d8308d5b 100644
--- a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
+++ b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
@@ -303,3 +303,25 @@ bool vector::isContiguousSlice(MemRefType memrefType, VectorType vectorType) {
 
   return llvm::all_of(leadingDims, [](auto x) { return x == 1; });
 }
+
+std::optional<StaticTileOffsetRange>
+vector::createUnrollIterator(VectorType vType, int64_t targetRank) {
+  if (vType.getRank() <= targetRank)
+    return {};
+  // Attempt to unroll until targetRank or the first scalable dimension (which
+  // cannot be unrolled).
+  auto shapeToUnroll = vType.getShape().drop_back(targetRank);
+  auto scalableDimsToUnroll = vType.getScalableDims().drop_back(targetRank);
+  auto it =
+      std::find(scalableDimsToUnroll.begin(), scalableDimsToUnroll.end(), true);
+  auto firstScalableDim = it - scalableDimsToUnroll.begin();
+  if (firstScalableDim == 0)
+    return {};
+  // All scalable dimensions should be removed now.
+  scalableDimsToUnroll = scalableDimsToUnroll.slice(0, firstScalableDim);
+  assert(!llvm::is_contained(scalableDimsToUnroll, true) &&
+         "unexpected leading scalable dimension");
+  // Create an unroll iterator for leading dimensions.
+  shapeToUnroll = shapeToUnroll.slice(0, firstScalableDim);
+  return StaticTileOffsetRange(shapeToUnroll, /*unrollStep=*/1);
+}
diff --git a/mlir/test/Dialect/Vector/vector-interleave-lowering-transforms.mlir b/mlir/test/Dialect/Vector/vector-interleave-lowering-transforms.mlir
index 05be19c3c81f9a..3dd4857860eb13 100644
--- a/mlir/test/Dialect/Vector/vector-interleave-lowering-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-interleave-lowering-transforms.mlir
@@ -36,6 +36,28 @@ func.func @vector_interleave_2d_scalable(%a: vector<2x[8]xi16>, %b: vector<2x[8]
   return %0 : vector<2x[16]xi16>
 }
 
+// CHECK-LABEL: @vector_interleave_4d
+//  CHECK-SAME:     %[[LHS:.*]]: vector<1x2x3x4xi64>, %[[RHS:.*]]: vector<1x2x3x4xi64>)
+func.func @vector_interleave_4d(%a: vector<1x2x3x4xi64>, %b: vector<1x2x3x4xi64>) -> vector<1x2x3x8xi64>
+{
+  // CHECK: %[[LHS_0:.*]] = vector.extract %[[LHS]][0, 0, 0] : vector<4xi64> from vector<1x2x3x4xi64>
+  // CHECK: %[[RHS_0:.*]] = vector.extract %[[RHS]][0, 0, 0] : vector<4xi64> from vector<1x2x3x4xi64>
+  // CHECK: %[[ZIP_0:.*]] = vector.interleave %[[LHS_0]], %[[RHS_0]] : vector<4xi64>
+  // CHECK: %[[RES_0:.*]] = vector.insert %[[ZIP_0]], %{{.*}} [0, 0, 0] : vector<8xi64> into vector<1x2x3x8xi64>
+  // CHECK-COUNT-5: vector.interleave %{{.*}}, %{{.*}} : vector<4xi64>
+  %0 = vector.interleave %a, %b : vector<1x2x3x4xi64>
+  return %0 : vector<1x2x3x8xi64>
+}
+
+// CHECK-LABEL: @vector_interleave_nd_with_scalable_dim
+func.func @vector_interleave_nd_with_scalable_dim(%a: vector<1x3x[2]x2x3x4xf16>, %b: vector<1x3x[2]x2x3x4xf16>) -> vector<1x3x[2]x2x3x8xf16>
+{
+  // The scalable dim blocks unrolling so only the first two dims are unrolled.
+  // CHECK-COUNT-3: vector.interleave %{{.*}}, %{{.*}} : vector<[2]x2x3x4xf16>
+  %0 = vector.interleave %a, %b : vector<1x3x[2]x2x3x4xf16>
+  return %0 : vector<1x3x[2]x2x3x8xf16>
+}
+
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
     %f = transform.structured.match ops{["func.func"]} in %module_op



More information about the Mlir-commits mailing list