[Mlir-commits] [mlir] [mlir][VectorOps] Add `vector.interleave` operation (PR #80315)

Benjamin Maxwell llvmlistbot at llvm.org
Tue Feb 6 04:34:15 PST 2024


https://github.com/MacDue updated https://github.com/llvm/llvm-project/pull/80315

From 171007a004eece6287d3ec141403052ab8efef53 Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Thu, 1 Feb 2024 17:54:48 +0000
Subject: [PATCH 1/2] [mlir][VectorOps] Add `vector.interleave` operation

The interleave operation constructs a new vector by interleaving the
elements from the trailing (or final) dimension of two input vectors,
returning a new vector where the trailing dimension is twice the size.

Note that for the n-D case this differs from the interleaving possible
with `vector.shuffle`, which would only operate on the leading dimension.

Another key difference is that this operation supports scalable
vectors, though currently a general LLVM lowering is limited to the
case where only the trailing dimension is scalable.

Example:
```mlir
%0 = vector.interleave %a, %b
            : vector<[4]xi32>     // yields vector<[8]xi32>
%1 = vector.interleave %c, %d
            : vector<8xi8>        // yields vector<16xi8>
%2 = vector.interleave %e, %f
            : vector<f16>         // yields vector<2xf16>
%3 = vector.interleave %g, %h
            : vector<2x4x[2]xf64> // yields vector<2x4x[4]xf64>
%4 = vector.interleave %i, %j
            : vector<6x3xf32>     // yields vector<6x6xf32>
```
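
For the fixed-size rank-1 case this matches an interleaving
`vector.shuffle` mask (a sketch of the equivalence, mirroring the
canonicalization tested below):

```mlir
%0 = vector.interleave %a, %b : vector<4xf32>
// is equivalent to:
%0 = vector.shuffle %a, %b [0, 4, 1, 5, 2, 6, 3, 7]
         : vector<4xf32>, vector<4xf32>
```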
---
 .../mlir/Dialect/Vector/IR/VectorOps.td       | 65 ++++++++++++++
 .../VectorToLLVM/ConvertVectorToLLVM.cpp      | 67 ++++++++++++++-
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp      | 42 +++++++++
 .../VectorToLLVM/vector-to-llvm.mlir          | 85 +++++++++++++++++++
 mlir/test/Dialect/Vector/canonicalize.mlir    | 23 +++++
 .../CPU/ArmSVE/test-scalable-interleave.mlir  | 25 ++++++
 .../Dialect/Vector/CPU/test-interleave.mlir   | 24 ++++++
 7 files changed, 330 insertions(+), 1 deletion(-)
 create mode 100644 mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/test-scalable-interleave.mlir
 create mode 100644 mlir/test/Integration/Dialect/Vector/CPU/test-interleave.mlir

diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index bc08f8d07fb0d..38c49e7da5dee 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -478,6 +478,71 @@ def Vector_ShuffleOp :
   let hasCanonicalizer = 1;
 }
 
+def Vector_InterleaveOp :
+  Vector_Op<"interleave", [Pure,
+    AllTypesMatch<["lhs", "rhs"]>,
+    TypesMatchWith<
+    "type of 'result' is double the width of the inputs",
+    "lhs", "result",
+    [{
+      [&]() -> ::mlir::VectorType {
+        auto vectorType = ::llvm::cast<mlir::VectorType>($_self);
+        ::mlir::VectorType::Builder builder(vectorType);
+        if (vectorType.getRank() == 0) {
+          static constexpr int64_t v2xty_shape[] = { 2 };
+          return builder.setShape(v2xty_shape);
+        }
+        auto lastDim = vectorType.getRank() - 1;
+        return builder.setDim(lastDim, vectorType.getDimSize(lastDim) * 2);
+      }()
+    }]>]> {
+  let summary = "constructs a vector by interleaving two input vectors";
+  let description = [{
+    The interleave operation constructs a new vector by interleaving the
+    elements from the trailing (or final) dimension of two input vectors,
+    returning a new vector where the trailing dimension is twice the size.
+
+    Note that for the n-D case this differs from the interleaving possible with
+    `vector.shuffle`, which would only operate on the leading dimension.
+
+    Another key difference is that this operation supports scalable vectors,
+    though currently a general LLVM lowering is limited to the case where only
+    the trailing dimension is scalable.
+
+    Example:
+    ```mlir
+    %0 = vector.interleave %a, %b
+               : vector<[4]xi32>     // yields vector<[8]xi32>
+    %1 = vector.interleave %c, %d
+               : vector<8xi8>        // yields vector<16xi8>
+    %2 = vector.interleave %e, %f
+               : vector<f16>         // yields vector<2xf16>
+    %3 = vector.interleave %g, %h
+               : vector<2x4x[2]xf64> // yields vector<2x4x[4]xf64>
+    %4 = vector.interleave %i, %j
+               : vector<6x3xf32>     // yields vector<6x6xf32>
+    ```
+  }];
+
+  let arguments = (ins AnyVectorOfAnyRank:$lhs, AnyVectorOfAnyRank:$rhs);
+  let results = (outs AnyVector:$result);
+
+  let assemblyFormat = [{
+    $lhs `,` $rhs attr-dict `:` type($lhs)
+  }];
+
+  let extraClassDeclaration = [{
+    VectorType getSourceVectorType() {
+      return ::llvm::cast<VectorType>(getLhs().getType());
+    }
+    VectorType getResultVectorType() {
+      return ::llvm::cast<VectorType>(getResult().getType());
+    }
+  }];
+
+  let hasCanonicalizer = 1;
+}
+
 def Vector_ExtractElementOp :
   Vector_Op<"extractelement", [Pure,
      TypesMatchWith<"result type matches element type of vector operand",
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index b66b55ae8d57f..4dc62608d1b92 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -1734,6 +1734,70 @@ struct VectorSplatNdOpLowering : public ConvertOpToLLVMPattern<SplatOp> {
   }
 };
 
+struct VectorInterleaveOpLowering
+    : public ConvertOpToLLVMPattern<vector::InterleaveOp> {
+  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
+
+  void initialize() {
+    // This pattern recursively unpacks one dimension at a time. The recursion
+    // is bounded as the rank is strictly decreasing.
+    setHasBoundedRewriteRecursion();
+  }
+
+  LogicalResult
+  matchAndRewrite(vector::InterleaveOp interleaveOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    VectorType resultType = interleaveOp.getResultVectorType();
+
+    // If the result is rank 1, then this directly maps to LLVM.
+    if (resultType.getRank() == 1) {
+      if (resultType.isScalable()) {
+        rewriter.replaceOpWithNewOp<LLVM::experimental_vector_interleave2>(
+            interleaveOp, typeConverter->convertType(resultType),
+            adaptor.getLhs(), adaptor.getRhs());
+        return success();
+      }
+      // Lower fixed-size interleaves to a shufflevector. While the
+      // vector.interleave2 intrinsic supports fixed and scalable vectors, the
+      // LangRef still recommends that fixed-size vectors use shufflevector, see:
+      // https://llvm.org/docs/LangRef.html#id876.
+      int64_t resultVectorSize = resultType.getNumElements();
+      SmallVector<int32_t> interleaveShuffleMask;
+      interleaveShuffleMask.reserve(resultVectorSize);
+      for (int i = 0; i < resultVectorSize / 2; i++) {
+        interleaveShuffleMask.push_back(i);
+        interleaveShuffleMask.push_back((resultVectorSize / 2) + i);
+      }
+      rewriter.replaceOpWithNewOp<LLVM::ShuffleVectorOp>(
+          interleaveOp, adaptor.getLhs(), adaptor.getRhs(),
+          interleaveShuffleMask);
+      return success();
+    }
+
+    // It's not possible to unroll a scalable dimension.
+    if (resultType.getScalableDims().front())
+      return failure();
+
+    // n-D case: Unroll the leading dimension.
+    // This eventually converges to an LLVM lowering.
+    auto loc = interleaveOp.getLoc();
+    Value result = rewriter.create<arith::ConstantOp>(
+        loc, resultType, rewriter.getZeroAttr(resultType));
+    for (int d = 0; d < resultType.getDimSize(0); d++) {
+      Value extractLhs =
+          rewriter.create<ExtractOp>(loc, interleaveOp.getLhs(), d);
+      Value extractRhs =
+          rewriter.create<ExtractOp>(loc, interleaveOp.getRhs(), d);
+      Value dimInterleave =
+          rewriter.create<InterleaveOp>(loc, extractLhs, extractRhs);
+      result = rewriter.create<InsertOp>(loc, dimInterleave, result, d);
+    }
+
+    rewriter.replaceOp(interleaveOp, result);
+    return success();
+  }
+};
+
 } // namespace
 
 /// Populate the given list with patterns that convert from Vector to LLVM.
@@ -1758,7 +1822,8 @@ void mlir::populateVectorToLLVMConversionPatterns(
                VectorExpandLoadOpConversion, VectorCompressStoreOpConversion,
                VectorSplatOpLowering, VectorSplatNdOpLowering,
                VectorScalableInsertOpLowering, VectorScalableExtractOpLowering,
-               MaskedReductionOpConversion>(converter);
+               MaskedReductionOpConversion, VectorInterleaveOpLowering>(
+      converter);
   // Transfer ops with rank > 1 are handled by VectorToSCF.
   populateVectorTransferLoweringPatterns(patterns, /*maxTransferRank=*/1);
 }
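
For illustration, the n-D branch above performs one unrolling step at the
vector-dialect level before the resulting rank-1 ops are converted; a sketch
for a `vector<2x3xi8>` interleave (the test below shows the corresponding
LLVM-dialect output):

```mlir
// One unrolling step of vector.interleave on vector<2x3xi8>.
%cst  = arith.constant dense<0> : vector<2x6xi8>
%lhs0 = vector.extract %a[0] : vector<3xi8> from vector<2x3xi8>
%rhs0 = vector.extract %b[0] : vector<3xi8> from vector<2x3xi8>
%zip0 = vector.interleave %lhs0, %rhs0 : vector<3xi8>
%res0 = vector.insert %zip0, %cst [0] : vector<6xi8> into vector<2x6xi8>
%lhs1 = vector.extract %a[1] : vector<3xi8> from vector<2x3xi8>
%rhs1 = vector.extract %b[1] : vector<3xi8> from vector<2x3xi8>
%zip1 = vector.interleave %lhs1, %rhs1 : vector<3xi8>
%res  = vector.insert %zip1, %res0 [1] : vector<6xi8> into vector<2x6xi8>
```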
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 452354413e883..8aabc35f4c265 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -6308,6 +6308,48 @@ bool WarpExecuteOnLane0Op::areTypesCompatible(Type lhs, Type rhs) {
       verifyDistributedType(lhs, rhs, getWarpSize(), getOperation()));
 }
 
+//===----------------------------------------------------------------------===//
+// InterleaveOp
+//===----------------------------------------------------------------------===//
+
+// The rank 1 case of vector.interleave on fixed-size vectors is equivalent to a
+// vector.shuffle, which (as an older op) is more likely to be matched by
+// existing pipelines.
+struct FoldRank1FixedSizeInterleaveOp : public OpRewritePattern<InterleaveOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(InterleaveOp interleaveOp,
+                                PatternRewriter &rewriter) const override {
+    auto resultType = interleaveOp.getResultVectorType();
+    if (resultType.getRank() != 1)
+      return rewriter.notifyMatchFailure(
+          interleaveOp, "cannot fold interleave with result rank > 1");
+
+    if (resultType.isScalable())
+      return rewriter.notifyMatchFailure(
+          interleaveOp, "cannot fold interleave of scalable vectors");
+
+    int64_t resultVectorSize = resultType.getNumElements();
+    SmallVector<int64_t> interleaveShuffleMask;
+    interleaveShuffleMask.reserve(resultVectorSize);
+    for (int i = 0; i < resultVectorSize / 2; i++) {
+      interleaveShuffleMask.push_back(i);
+      interleaveShuffleMask.push_back((resultVectorSize / 2) + i);
+    }
+
+    rewriter.replaceOpWithNewOp<ShuffleOp>(interleaveOp, interleaveOp.getLhs(),
+                                           interleaveOp.getRhs(),
+                                           interleaveShuffleMask);
+
+    return success();
+  }
+};
+
+void InterleaveOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                               MLIRContext *context) {
+  results.add<FoldRank1FixedSizeInterleaveOp>(context);
+}
+
 Value mlir::vector::makeArithReduction(OpBuilder &b, Location loc,
                                        CombiningKind kind, Value v1, Value acc,
                                        arith::FastMathFlagsAttr fastmath,
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index 1c13b16dfd9af..3cbca65472fb6 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -2460,3 +2460,88 @@ func.func @make_fixed_vector_of_scalable_vector(%f : f64) -> vector<3x[2]xf64>
   %res = vector.broadcast %f : f64 to vector<3x[2]xf64>
   return %res : vector<3x[2]xf64>
 }
+
+// -----
+
+// CHECK-LABEL: @vector_interleave_0d
+//  CHECK-SAME:     %[[LHS:.*]]: vector<i8>, %[[RHS:.*]]: vector<i8>)
+func.func @vector_interleave_0d(%a: vector<i8>, %b: vector<i8>) -> vector<2xi8> {
+  // CHECK: %[[LHS_RANK1:.*]] = builtin.unrealized_conversion_cast %[[LHS]] : vector<i8> to vector<1xi8>
+  // CHECK: %[[RHS_RANK1:.*]] = builtin.unrealized_conversion_cast %[[RHS]] : vector<i8> to vector<1xi8>
+  // CHECK: %[[ZIP:.*]] = llvm.shufflevector %[[LHS_RANK1]], %[[RHS_RANK1]] [0, 1] : vector<1xi8>
+  // CHECK: return %[[ZIP]]
+  %0 = vector.interleave %a, %b : vector<i8>
+  return %0 : vector<2xi8>
+}
+
+// -----
+
+// CHECK-LABEL: @vector_interleave_1d
+//  CHECK-SAME:     %[[LHS:.*]]: vector<8xf32>, %[[RHS:.*]]: vector<8xf32>)
+func.func @vector_interleave_1d(%a: vector<8xf32>, %b: vector<8xf32>) -> vector<16xf32>
+{
+  // CHECK: %[[ZIP:.*]] = llvm.shufflevector %[[LHS]], %[[RHS]] [0, 8, 1, 9, 2, 10, 3, 11, 4, 12, 5, 13, 6, 14, 7, 15] : vector<8xf32>
+  // CHECK: return %[[ZIP]]
+  %0 = vector.interleave %a, %b : vector<8xf32>
+  return %0 : vector<16xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @vector_interleave_1d_scalable
+//  CHECK-SAME:     %[[LHS:.*]]: vector<[4]xi32>, %[[RHS:.*]]: vector<[4]xi32>)
+func.func @vector_interleave_1d_scalable(%a: vector<[4]xi32>, %b: vector<[4]xi32>) -> vector<[8]xi32>
+{
+  // CHECK: %[[ZIP:.*]] = "llvm.intr.experimental.vector.interleave2"(%[[LHS]], %[[RHS]]) : (vector<[4]xi32>, vector<[4]xi32>) -> vector<[8]xi32>
+  // CHECK: return %[[ZIP]]
+  %0 = vector.interleave %a, %b : vector<[4]xi32>
+  return %0 : vector<[8]xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @vector_interleave_2d
+//  CHECK-SAME:     %[[LHS:.*]]: vector<2x3xi8>, %[[RHS:.*]]: vector<2x3xi8>)
+func.func @vector_interleave_2d(%a: vector<2x3xi8>, %b: vector<2x3xi8>) -> vector<2x6xi8>
+{
+  // CHECK: %[[LHS_LLVM:.*]] = builtin.unrealized_conversion_cast %[[LHS]] : vector<2x3xi8> to !llvm.array<2 x vector<3xi8>>
+  // CHECK: %[[RHS_LLVM:.*]] = builtin.unrealized_conversion_cast %[[RHS]] : vector<2x3xi8> to !llvm.array<2 x vector<3xi8>>
+  // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<2x6xi8>
+  // CHECK: %[[CST_LLVM:.*]] = builtin.unrealized_conversion_cast %[[CST]] : vector<2x6xi8> to !llvm.array<2 x vector<6xi8>>
+  // CHECK: %[[LHS_DIM_0:.*]] = llvm.extractvalue %[[LHS_LLVM]][0] : !llvm.array<2 x vector<3xi8>>
+  // CHECK: %[[RHS_DIM_0:.*]] = llvm.extractvalue %[[RHS_LLVM]][0] : !llvm.array<2 x vector<3xi8>>
+  // CHECK: %[[ZIM_DIM_0:.*]] = llvm.shufflevector %[[LHS_DIM_0]], %[[RHS_DIM_0]] [0, 3, 1, 4, 2, 5] : vector<3xi8>
+  // CHECK: %[[RES_0:.*]] = llvm.insertvalue %[[ZIM_DIM_0]], %[[CST_LLVM]][0] : !llvm.array<2 x vector<6xi8>>
+  // CHECK: %[[LHS_DIM_1:.*]] = llvm.extractvalue %[[LHS_LLVM]][1] : !llvm.array<2 x vector<3xi8>>
+  // CHECK: %[[RHS_DIM_1:.*]] = llvm.extractvalue %[[RHS_LLVM]][1] : !llvm.array<2 x vector<3xi8>>
+  // CHECK: %[[ZIM_DIM_1:.*]] = llvm.shufflevector %[[LHS_DIM_1]], %[[RHS_DIM_1]] [0, 3, 1, 4, 2, 5] : vector<3xi8>
+  // CHECK: %[[RES_1:.*]] = llvm.insertvalue %[[ZIM_DIM_1]], %[[RES_0]][1] : !llvm.array<2 x vector<6xi8>>
+  // CHECK: %[[RES:.*]] = builtin.unrealized_conversion_cast %[[RES_1]] : !llvm.array<2 x vector<6xi8>> to vector<2x6xi8>
+  // CHECK: return %[[RES]]
+  %0 = vector.interleave %a, %b : vector<2x3xi8>
+  return %0 : vector<2x6xi8>
+}
+
+// -----
+
+// CHECK-LABEL: @vector_interleave_2d_scalable
+//  CHECK-SAME:     %[[LHS:.*]]: vector<2x[8]xi16>, %[[RHS:.*]]: vector<2x[8]xi16>)
+func.func @vector_interleave_2d_scalable(%a: vector<2x[8]xi16>, %b: vector<2x[8]xi16>) -> vector<2x[16]xi16>
+{
+  // CHECK: %[[LHS_LLVM:.*]] = builtin.unrealized_conversion_cast %arg0 : vector<2x[8]xi16> to !llvm.array<2 x vector<[8]xi16>>
+  // CHECK: %[[RHS_LLVM:.*]] = builtin.unrealized_conversion_cast %arg1 : vector<2x[8]xi16> to !llvm.array<2 x vector<[8]xi16>>
+  // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<2x[16]xi16>
+  // CHECK: %[[CST_LLVM:.*]] = builtin.unrealized_conversion_cast %[[CST]] : vector<2x[16]xi16> to !llvm.array<2 x vector<[16]xi16>>
+  // CHECK: %[[LHS_DIM_0:.*]] = llvm.extractvalue %[[LHS_LLVM]][0] : !llvm.array<2 x vector<[8]xi16>>
+  // CHECK: %[[RHS_DIM_0:.*]] = llvm.extractvalue %[[RHS_LLVM]][0] : !llvm.array<2 x vector<[8]xi16>>
+  // CHECK: %[[ZIM_DIM_0:.*]] = "llvm.intr.experimental.vector.interleave2"(%[[LHS_DIM_0]], %[[RHS_DIM_0]]) : (vector<[8]xi16>, vector<[8]xi16>) -> vector<[16]xi16>
+  // CHECK: %[[RES_0:.*]] = llvm.insertvalue %[[ZIM_DIM_0]], %[[CST_LLVM]][0] : !llvm.array<2 x vector<[16]xi16>>
+  // CHECK: %[[LHS_DIM_1:.*]] = llvm.extractvalue %0[1] : !llvm.array<2 x vector<[8]xi16>>
+  // CHECK: %[[RHS_DIM_1:.*]] = llvm.extractvalue %1[1] : !llvm.array<2 x vector<[8]xi16>>
+  // CHECK: %[[ZIP_DIM_1:.*]] = "llvm.intr.experimental.vector.interleave2"(%[[LHS_DIM_1]], %[[RHS_DIM_1]]) : (vector<[8]xi16>, vector<[8]xi16>) -> vector<[16]xi16>
+  // CHECK: %[[RES_1:.*]] = llvm.insertvalue %[[ZIP_DIM_1]], %[[RES_0]][1] : !llvm.array<2 x vector<[16]xi16>>
+  // CHECK: %[[RES:.*]] = builtin.unrealized_conversion_cast %[[RES_1]] : !llvm.array<2 x vector<[16]xi16>> to vector<2x[16]xi16>
+  // CHECK: return %[[RES]]
+  %0 = vector.interleave %a, %b : vector<2x[8]xi16>
+  return %0 : vector<2x[16]xi16>
+}
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index e6f045e12e519..490ee6a462c6a 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -2567,3 +2567,26 @@ func.func @load_store_forwarding_rank_mismatch(%v0: vector<4x1x1xf32>, %arg0: te
       tensor<4x4x4xf32>, vector<1x100x4x5xf32>
   return %r : vector<1x100x4x5xf32>
 }
+
+// -----
+
+// CHECK-LABEL: func.func @fold_rank_1_vector_interleave(
+//  CHECK-SAME:     %[[LHS:.*]]: vector<6xi32>, %[[RHS:.*]]: vector<6xi32>)
+func.func @fold_rank_1_vector_interleave(%arg0: vector<6xi32>, %arg1: vector<6xi32>) -> vector<12xi32> {
+  // CHECK: %[[ZIP:.*]] = vector.shuffle %[[LHS]], %[[RHS]] [0, 6, 1, 7, 2, 8, 3, 9, 4, 10, 5, 11] : vector<6xi32>, vector<6xi32>
+  // CHECK: return %[[ZIP]] : vector<12xi32>
+  %0 = vector.interleave %arg0, %arg1 : vector<6xi32>
+  return %0 : vector<12xi32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @fold_rank_0_vector_interleave(
+//  CHECK-SAME:     %[[LHS:.*]]: vector<f64>, %[[RHS:.*]]: vector<f64>)
+func.func @fold_rank_0_vector_interleave(%arg0: vector<f64>, %arg1: vector<f64>) -> vector<2xf64>
+{
+  // CHECK: %[[ZIP:.*]] = vector.shuffle %[[LHS]], %[[RHS]] [0, 1] : vector<f64>, vector<f64>
+  // CHECK: return %[[ZIP]] : vector<2xf64>
+  %0 = vector.interleave %arg0, %arg1 : vector<f64>
+  return %0 : vector<2xf64>
+}
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/test-scalable-interleave.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/test-scalable-interleave.mlir
new file mode 100644
index 0000000000000..58dd3d700beff
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/test-scalable-interleave.mlir
@@ -0,0 +1,25 @@
+// RUN: mlir-opt %s -test-lower-to-llvm | \
+// RUN: %mcr_aarch64_cmd -e entry -entry-point-result=void  \
+// RUN:   -shared-libs=%mlir_c_runner_utils,%mlir_arm_runner_utils | \
+// RUN: FileCheck %s
+
+func.func @entry() {
+  %f1 = arith.constant 1.0: f32
+  %f2 = arith.constant 2.0: f32
+  %v1 = vector.splat %f1 : vector<[4]xf32>
+  %v2 = vector.splat %f2 : vector<[4]xf32>
+  vector.print %v1 : vector<[4]xf32>
+  vector.print %v2 : vector<[4]xf32>
+  //
+  // Test vectors:
+  //
+  // CHECK: ( 1, 1, 1, 1
+  // CHECK: ( 2, 2, 2, 2
+
+  %v3 = vector.interleave %v1, %v2 : vector<[4]xf32>
+  vector.print %v3 : vector<[8]xf32>
+  // CHECK: ( 1, 2, 1, 2, 1, 2, 1, 2
+
+  return
+}
+
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/test-interleave.mlir b/mlir/test/Integration/Dialect/Vector/CPU/test-interleave.mlir
new file mode 100644
index 0000000000000..c6dd6287208d4
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Vector/CPU/test-interleave.mlir
@@ -0,0 +1,24 @@
+// RUN: mlir-opt %s -test-lower-to-llvm | \
+// RUN: mlir-cpu-runner -e entry -entry-point-result=void  \
+// RUN:   -shared-libs=%mlir_c_runner_utils | \
+// RUN: FileCheck %s
+
+func.func @entry() {
+  %f1 = arith.constant 1.0: f32
+  %f2 = arith.constant 2.0: f32
+  %v1 = vector.splat %f1 : vector<2x4xf32>
+  %v2 = vector.splat %f2 : vector<2x4xf32>
+  vector.print %v1 : vector<2x4xf32>
+  vector.print %v2 : vector<2x4xf32>
+  //
+  // Test vectors:
+  //
+  // CHECK: ( ( 1, 1, 1, 1 ), ( 1, 1, 1, 1 ) )
+  // CHECK: ( ( 2, 2, 2, 2 ), ( 2, 2, 2, 2 ) )
+
+  %v3 = vector.interleave %v1, %v2 : vector<2x4xf32>
+  vector.print %v3 : vector<2x8xf32>
+  // CHECK: ( ( 1, 2, 1, 2, 1, 2, 1, 2 ), ( 1, 2, 1, 2, 1, 2, 1, 2 ) )
+
+  return
+}

From 795f4b873f2863e801e982aafd8f54ceb0b54351 Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Tue, 6 Feb 2024 12:29:40 +0000
Subject: [PATCH 2/2] Fixups

- Remove vector.interleave -> vector.shuffle canonicalization
- Add vector.shuffle -> vector.interleave canonicalization
- Split vector.interleave unrolling and LLVM lowering
  - Unrolling now done in LowerVectorInterleave.cpp
- Add missing tests to vector ops.mlir
- Fix a few nits
---
 .../mlir/Dialect/Vector/IR/VectorOps.td       |  2 -
 .../Vector/Transforms/LoweringPatterns.h      |  8 ++
 .../VectorToLLVM/ConvertVectorToLLVM.cpp      | 70 +++++----------
 .../VectorToLLVM/ConvertVectorToLLVMPass.cpp  |  1 +
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp      | 85 +++++++++----------
 .../Dialect/Vector/Transforms/CMakeLists.txt  |  1 +
 .../Transforms/LowerVectorInterleave.cpp      | 64 ++++++++++++++
 mlir/test/Dialect/Vector/canonicalize.mlir    | 30 +++----
 mlir/test/Dialect/Vector/ops.mlir             | 35 ++++++++
 .../CPU/ArmSVE/test-scalable-interleave.mlir  |  4 +-
 .../Dialect/Vector/CPU/test-interleave.mlir   |  4 +-
 11 files changed, 192 insertions(+), 112 deletions(-)
 create mode 100644 mlir/lib/Dialect/Vector/Transforms/LowerVectorInterleave.cpp

diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 38c49e7da5dee..6d50b0654bc57 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -539,8 +539,6 @@ def Vector_InterleaveOp :
       return ::llvm::cast<VectorType>(getResult().getType());
     }
   }];
-
-  let hasCanonicalizer = 1;
 }
 
 def Vector_ExtractElementOp :
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
index 57b39f5f52c6d..1cd3bab46396e 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
@@ -264,6 +264,14 @@ void populateVectorMaskLoweringPatternsForSideEffectingOps(
 void populateVectorMaskedLoadStoreEmulationPatterns(RewritePatternSet &patterns,
                                                     PatternBenefit benefit = 1);
 
+/// Populate the pattern set with the following patterns:
+///
+/// [InterleaveOpLowering]
+/// Progressive lowering of InterleaveOp to ExtractOp + InsertOp + lower-D
+/// InterleaveOp until dim 1.
+void populateVectorInterleaveLoweringPatterns(RewritePatternSet &patterns,
+                                              PatternBenefit benefit = 1);
+
 } // namespace vector
 } // namespace mlir
 #endif // MLIR_DIALECT_VECTOR_TRANSFORMS_LOWERINGPATTERNS_H
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 4dc62608d1b92..0d9a451d11ca8 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -1734,66 +1734,40 @@ struct VectorSplatNdOpLowering : public ConvertOpToLLVMPattern<SplatOp> {
   }
 };
 
+/// Conversion pattern for a `vector.interleave`.
+/// This supports both fixed-size and scalable vectors.
 struct VectorInterleaveOpLowering
     : public ConvertOpToLLVMPattern<vector::InterleaveOp> {
   using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
 
-  void initialize() {
-    // This pattern recursively unpacks one dimension at a time. The recursion
-    // is bounded as the rank is strictly decreasing.
-    setHasBoundedRewriteRecursion();
-  }
-
   LogicalResult
   matchAndRewrite(vector::InterleaveOp interleaveOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     VectorType resultType = interleaveOp.getResultVectorType();
-
+    // n-D interleaves should have been lowered already.
+    if (resultType.getRank() != 1)
+      return failure();
     // If the result is rank 1, then this directly maps to LLVM.
-    if (resultType.getRank() == 1) {
-      if (resultType.isScalable()) {
-        rewriter.replaceOpWithNewOp<LLVM::experimental_vector_interleave2>(
-            interleaveOp, typeConverter->convertType(resultType),
-            adaptor.getLhs(), adaptor.getRhs());
-        return success();
-      }
-      // Lower fixed-size interleaves to a shufflevector. While the
-      // vector.interleave2 intrinsic supports fixed and scalable vectors, the
-      // LangRef still recommends that fixed-size vectors use shufflevector, see:
-      // https://llvm.org/docs/LangRef.html#id876.
-      int64_t resultVectorSize = resultType.getNumElements();
-      SmallVector<int32_t> interleaveShuffleMask;
-      interleaveShuffleMask.reserve(resultVectorSize);
-      for (int i = 0; i < resultVectorSize / 2; i++) {
-        interleaveShuffleMask.push_back(i);
-        interleaveShuffleMask.push_back((resultVectorSize / 2) + i);
-      }
-      rewriter.replaceOpWithNewOp<LLVM::ShuffleVectorOp>(
-          interleaveOp, adaptor.getLhs(), adaptor.getRhs(),
-          interleaveShuffleMask);
+    if (resultType.isScalable()) {
+      rewriter.replaceOpWithNewOp<LLVM::experimental_vector_interleave2>(
+          interleaveOp, typeConverter->convertType(resultType),
+          adaptor.getLhs(), adaptor.getRhs());
       return success();
     }
-
-    // It's not possible to unroll a scalable dimension.
-    if (resultType.getScalableDims().front())
-      return failure();
-
-    // n-D case: Unroll the leading dimension.
-    // This eventually converges to an LLVM lowering.
-    auto loc = interleaveOp.getLoc();
-    Value result = rewriter.create<arith::ConstantOp>(
-        loc, resultType, rewriter.getZeroAttr(resultType));
-    for (int d = 0; d < resultType.getDimSize(0); d++) {
-      Value extractLhs =
-          rewriter.create<ExtractOp>(loc, interleaveOp.getLhs(), d);
-      Value extractRhs =
-          rewriter.create<ExtractOp>(loc, interleaveOp.getRhs(), d);
-      Value dimInterleave =
-          rewriter.create<InterleaveOp>(loc, extractLhs, extractRhs);
-      result = rewriter.create<InsertOp>(loc, dimInterleave, result, d);
+    // Lower fixed-size interleaves to a shufflevector. While the
+    // vector.interleave2 intrinsic supports fixed and scalable vectors, the
+    // LangRef still recommends that fixed-size vectors use shufflevector, see:
+    // https://llvm.org/docs/LangRef.html#id876.
+    int64_t resultVectorSize = resultType.getNumElements();
+    SmallVector<int32_t> interleaveShuffleMask;
+    interleaveShuffleMask.reserve(resultVectorSize);
+    for (int i = 0, end = resultVectorSize / 2; i < end; ++i) {
+      interleaveShuffleMask.push_back(i);
+      interleaveShuffleMask.push_back((resultVectorSize / 2) + i);
     }
-
-    rewriter.replaceOp(interleaveOp, result);
+    rewriter.replaceOpWithNewOp<LLVM::ShuffleVectorOp>(
+        interleaveOp, adaptor.getLhs(), adaptor.getRhs(),
+        interleaveShuffleMask);
     return success();
   }
 };
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
index ff8e78a668e0f..e3a436c4a9400 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
@@ -68,6 +68,7 @@ void LowerVectorToLLVMPass::runOnOperation() {
     populateVectorContractLoweringPatterns(patterns, VectorTransformsOptions());
     populateVectorMaskOpLoweringPatterns(patterns);
     populateVectorShapeCastLoweringPatterns(patterns);
+    populateVectorInterleaveLoweringPatterns(patterns);
     populateVectorTransposeLoweringPatterns(patterns,
                                             VectorTransformsOptions());
     // Vector transfer ops with rank > 1 should be lowered with VectorToSCF.
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 8aabc35f4c265..084348e68270c 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -2478,11 +2478,52 @@ class ShuffleSplat final : public OpRewritePattern<ShuffleOp> {
   }
 };
 
+/// Pattern to rewrite a fixed-size interleave expressed as a vector.shuffle
+/// into a vector.interleave.
+class ShuffleInterleave : public OpRewritePattern<ShuffleOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ShuffleOp op,
+                                PatternRewriter &rewriter) const override {
+    VectorType resultType = op.getResultVectorType();
+    if (resultType.isScalable())
+      return rewriter.notifyMatchFailure(
+          op, "ShuffleOp can't represent a scalable interleave");
+
+    if (resultType.getRank() != 1)
+      return rewriter.notifyMatchFailure(
+          op, "ShuffleOp can't represent an n-D interleave");
+
+    VectorType sourceType = op.getV1VectorType();
+    if (sourceType != op.getV2VectorType() ||
+        ArrayRef<int64_t>{sourceType.getNumElements() * 2} !=
+            resultType.getShape()) {
+      return rewriter.notifyMatchFailure(
+          op, "ShuffleOp types don't match an interleave");
+    }
+
+    ArrayAttr shuffleMask = op.getMask();
+    int64_t resultVectorSize = resultType.getNumElements();
+    for (int i = 0, e = resultVectorSize / 2; i < e; ++i) {
+      int64_t maskValueA = cast<IntegerAttr>(shuffleMask[i * 2]).getInt();
+      int64_t maskValueB = cast<IntegerAttr>(shuffleMask[(i * 2) + 1]).getInt();
+      if (maskValueA != i || maskValueB != (resultVectorSize / 2) + i)
+        return rewriter.notifyMatchFailure(op,
+                                           "ShuffleOp mask not interleaving");
+    }
+
+    rewriter.replaceOpWithNewOp<InterleaveOp>(op, op.getV1(), op.getV2());
+    return success();
+  }
+};
+
 } // namespace
 
 void ShuffleOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
-  results.add<ShuffleSplat, Canonicalize0DShuffleOp>(context);
+  results.add<ShuffleSplat, ShuffleInterleave, Canonicalize0DShuffleOp>(
+      context);
 }
 
 //===----------------------------------------------------------------------===//
@@ -6308,48 +6349,6 @@ bool WarpExecuteOnLane0Op::areTypesCompatible(Type lhs, Type rhs) {
       verifyDistributedType(lhs, rhs, getWarpSize(), getOperation()));
 }
 
-//===----------------------------------------------------------------------===//
-// InterleaveOp
-//===----------------------------------------------------------------------===//
-
-// The rank 1 case of vector.interleave on fixed-size vectors is equivalent to a
-// vector.shuffle, which (as an older op) is more likely to be matched by
-// existing pipelines.
-struct FoldRank1FixedSizeInterleaveOp : public OpRewritePattern<InterleaveOp> {
-  using OpRewritePattern::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(InterleaveOp interleaveOp,
-                                PatternRewriter &rewriter) const override {
-    auto resultType = interleaveOp.getResultVectorType();
-    if (resultType.getRank() != 1)
-      return rewriter.notifyMatchFailure(
-          interleaveOp, "cannot fold interleave with result rank > 1");
-
-    if (resultType.isScalable())
-      return rewriter.notifyMatchFailure(
-          interleaveOp, "cannot fold interleave of scalable vectors");
-
-    int64_t resultVectorSize = resultType.getNumElements();
-    SmallVector<int64_t> interleaveShuffleMask;
-    interleaveShuffleMask.reserve(resultVectorSize);
-    for (int i = 0; i < resultVectorSize / 2; i++) {
-      interleaveShuffleMask.push_back(i);
-      interleaveShuffleMask.push_back((resultVectorSize / 2) + i);
-    }
-
-    rewriter.replaceOpWithNewOp<ShuffleOp>(interleaveOp, interleaveOp.getLhs(),
-                                           interleaveOp.getRhs(),
-                                           interleaveShuffleMask);
-
-    return success();
-  }
-};
-
-void InterleaveOp::getCanonicalizationPatterns(RewritePatternSet &results,
-                                               MLIRContext *context) {
-  results.add<FoldRank1FixedSizeInterleaveOp>(context);
-}
-
 Value mlir::vector::makeArithReduction(OpBuilder &b, Location loc,
                                        CombiningKind kind, Value v1, Value acc,
                                        arith::FastMathFlagsAttr fastmath,
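
As an illustration of what the new `ShuffleInterleave` canonicalization
accepts and rejects (the second mask is a hypothetical non-interleaving
example):

```mlir
// Recognized: the mask alternates i and (numResultElements / 2) + i,
// so this shuffle becomes a vector.interleave.
%0 = vector.shuffle %a, %b [0, 4, 1, 5, 2, 6, 3, 7]
         : vector<4xi32>, vector<4xi32>
// -> %0 = vector.interleave %a, %b : vector<4xi32>

// Left as-is: a concatenation mask does not match the interleave form.
%1 = vector.shuffle %a, %b [0, 1, 2, 3, 4, 5, 6, 7]
         : vector<4xi32>, vector<4xi32>
```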
diff --git a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
index daf28882976ef..f221b7462dfd7 100644
--- a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
@@ -4,6 +4,7 @@ add_mlir_dialect_library(MLIRVectorTransforms
   LowerVectorBroadcast.cpp
   LowerVectorContract.cpp
   LowerVectorGather.cpp
+  LowerVectorInterleave.cpp
   LowerVectorMask.cpp
   LowerVectorMultiReduction.cpp
   LowerVectorScan.cpp
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorInterleave.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorInterleave.cpp
new file mode 100644
index 0000000000000..0ca38eba942a5
--- /dev/null
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorInterleave.cpp
@@ -0,0 +1,64 @@
+//===- LowerVectorInterleave.cpp - Lower 'vector.interleave' operation ----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements target-independent rewrites and utilities to lower the
+// 'vector.interleave' operation.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/PatternMatch.h"
+
+#define DEBUG_TYPE "vector-interleave-lowering"
+
+using namespace mlir;
+using namespace mlir::vector;
+
+namespace {
+/// Progressive lowering of InterleaveOp.
+class InterleaveOpLowering : public OpRewritePattern<vector::InterleaveOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::InterleaveOp op,
+                                PatternRewriter &rewriter) const override {
+    VectorType resultType = op.getResultVectorType();
+    // 1-D vector.interleave ops can be directly lowered to LLVM (later).
+    if (resultType.getRank() == 1)
+      return failure();
+
+    // Below we unroll the leading (or front) dimension. If that dimension is
+    // scalable, we can't unroll it.
+    if (resultType.getScalableDims().front())
+      return failure();
+
+    // n-D case: Unroll the leading dimension.
+    auto loc = op.getLoc();
+    Value result = rewriter.create<arith::ConstantOp>(
+        loc, resultType, rewriter.getZeroAttr(resultType));
+    for (int idx = 0, end = resultType.getDimSize(0); idx < end; ++idx) {
+      Value extractLhs = rewriter.create<ExtractOp>(loc, op.getLhs(), idx);
+      Value extractRhs = rewriter.create<ExtractOp>(loc, op.getRhs(), idx);
+      Value interleave =
+          rewriter.create<InterleaveOp>(loc, extractLhs, extractRhs);
+      result = rewriter.create<InsertOp>(loc, interleave, result, idx);
+    }
+
+    rewriter.replaceOp(op, result);
+    return success();
+  }
+};
+
+} // namespace
+
+void mlir::vector::populateVectorInterleaveLoweringPatterns(
+    RewritePatternSet &patterns, PatternBenefit benefit) {
+  patterns.add<InterleaveOpLowering>(patterns.getContext(), benefit);
+}
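
Since `InterleaveOpLowering` peels only the leading dimension, higher-rank
interleaves lower progressively; a sketch of the chain for a rank-3 input
(only the trailing dimension may be scalable):

```mlir
// vector<2x3x[4]xf32> unrolls along dim 0 into two rank-2 interleaves on
// vector<3x[4]xf32> slices, which unroll again into rank-1 interleaves on
// vector<[4]xf32>. VectorToLLVM then converts those rank-1 ops directly.
%0 = vector.interleave %a, %b : vector<2x3x[4]xf32> // yields vector<2x3x[8]xf32>
```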
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 490ee6a462c6a..4c73a6271786e 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -2570,23 +2570,23 @@ func.func @load_store_forwarding_rank_mismatch(%v0: vector<4x1x1xf32>, %arg0: te
 
 // -----
 
-// CHECK-LABEL: func.func @fold_rank_1_vector_interleave(
-//  CHECK-SAME:     %[[LHS:.*]]: vector<6xi32>, %[[RHS:.*]]: vector<6xi32>)
-func.func @fold_rank_1_vector_interleave(%arg0: vector<6xi32>, %arg1: vector<6xi32>) -> vector<12xi32> {
-  // CHECK: %[[ZIP:.*]] = vector.shuffle %[[LHS]], %[[RHS]] [0, 6, 1, 7, 2, 8, 3, 9, 4, 10, 5, 11] : vector<6xi32>, vector<6xi32>
-  // CHECK: return %[[ZIP]] : vector<12xi32>
-  %0 = vector.interleave %arg0, %arg1 : vector<6xi32>
-  return %0 : vector<12xi32>
+// CHECK-LABEL: func.func @rank_0_shuffle_to_interleave(
+//  CHECK-SAME:     %[[LHS:.*]]: vector<f64>, %[[RHS:.*]]: vector<f64>)
+func.func @rank_0_shuffle_to_interleave(%arg0: vector<f64>, %arg1: vector<f64>) -> vector<2xf64>
+{
+  // CHECK: %[[ZIP:.*]] = vector.interleave %[[LHS]], %[[RHS]] : vector<f64>
+  // CHECK: return %[[ZIP]]
+  %0 = vector.shuffle %arg0, %arg1 [0, 1] : vector<f64>, vector<f64>
+  return %0 : vector<2xf64>
 }
 
 // -----
 
-// CHECK-LABEL: func.func @fold_rank_0_vector_interleave(
-//  CHECK-SAME:     %[[LHS:.*]]: vector<f64>, %[[RHS:.*]]: vector<f64>)
-func.func @fold_rank_0_vector_interleave(%arg0: vector<f64>, %arg1: vector<f64>) -> vector<2xf64>
-{
-  // CHECK: %[[ZIP:.*]] = vector.shuffle %[[LHS]], %[[RHS]] [0, 1] : vector<f64>, vector<f64>
-  // CHECK: return %[[ZIP]] : vector<2xf64>
-  %0 = vector.interleave %arg0, %arg1 : vector<f64>
-  return %0 : vector<2xf64>
+// CHECK-LABEL: func.func @rank_1_shuffle_to_interleave(
+//  CHECK-SAME:     %[[LHS:.*]]: vector<6xi32>, %[[RHS:.*]]: vector<6xi32>)
+func.func @rank_1_shuffle_to_interleave(%arg0: vector<6xi32>, %arg1: vector<6xi32>) -> vector<12xi32> {
+  // CHECK: %[[ZIP:.*]] = vector.interleave %[[LHS]], %[[RHS]] : vector<6xi32>
+  // CHECK: return %[[ZIP]]
+  %0 = vector.shuffle %arg0, %arg1 [0, 6, 1, 7, 2, 8, 3, 9, 4, 10, 5, 11] : vector<6xi32>, vector<6xi32>
+  return %0 : vector<12xi32>
 }
diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
index 2f8530e7c171a..79a80be4f8b20 100644
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -1081,3 +1081,38 @@ func.func @fastmath(%x: vector<42xf32>) -> f32 {
   %min = vector.reduction <minnumf>, %x fastmath<reassoc,nnan,ninf> : vector<42xf32> into f32
   return %min: f32
 }
+
+// CHECK-LABEL: @interleave_0d
+func.func @interleave_0d(%a: vector<f32>, %b: vector<f32>) -> vector<2xf32> {
+  // CHECK: vector.interleave %{{.*}}, %{{.*}} : vector<f32>
+  %0 = vector.interleave %a, %b : vector<f32>
+  return %0 : vector<2xf32>
+}
+
+// CHECK-LABEL: @interleave_1d
+func.func @interleave_1d(%a: vector<4xf32>, %b: vector<4xf32>) -> vector<8xf32> {
+  // CHECK: vector.interleave %{{.*}}, %{{.*}} : vector<4xf32>
+  %0 = vector.interleave %a, %b : vector<4xf32>
+  return %0 : vector<8xf32>
+}
+
+// CHECK-LABEL: @interleave_1d_scalable
+func.func @interleave_1d_scalable(%a: vector<[8]xi16>, %b: vector<[8]xi16>) -> vector<[16]xi16> {
+  // CHECK: vector.interleave %{{.*}}, %{{.*}} : vector<[8]xi16>
+  %0 = vector.interleave %a, %b : vector<[8]xi16>
+  return %0 : vector<[16]xi16>
+}
+
+// CHECK-LABEL: @interleave_2d
+func.func @interleave_2d(%a: vector<2x8xf32>, %b: vector<2x8xf32>) -> vector<2x16xf32> {
+  // CHECK: vector.interleave %{{.*}}, %{{.*}} : vector<2x8xf32>
+  %0 = vector.interleave %a, %b : vector<2x8xf32>
+  return %0 : vector<2x16xf32>
+}
+
+// CHECK-LABEL: @interleave_2d_scalable
+func.func @interleave_2d_scalable(%a: vector<2x[2]xf64>, %b: vector<2x[2]xf64>) -> vector<2x[4]xf64> {
+  // CHECK: vector.interleave %{{.*}}, %{{.*}} : vector<2x[2]xf64>
+  %0 = vector.interleave %a, %b : vector<2x[2]xf64>
+  return %0 : vector<2x[4]xf64>
+}
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/test-scalable-interleave.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/test-scalable-interleave.mlir
index 58dd3d700beff..479e50123bc2b 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/test-scalable-interleave.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/test-scalable-interleave.mlir
@@ -4,8 +4,8 @@
 // RUN: FileCheck %s
 
 func.func @entry() {
-  %f1 = arith.constant 1.0: f32
-  %f2 = arith.constant 2.0: f32
+  %f1 = arith.constant 1.0 : f32
+  %f2 = arith.constant 2.0 : f32
   %v1 = vector.splat %f1 : vector<[4]xf32>
   %v2 = vector.splat %f2 : vector<[4]xf32>
   vector.print %v1 : vector<[4]xf32>
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/test-interleave.mlir b/mlir/test/Integration/Dialect/Vector/CPU/test-interleave.mlir
index c6dd6287208d4..69bf0320a3697 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/test-interleave.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/test-interleave.mlir
@@ -4,8 +4,8 @@
 // RUN: FileCheck %s
 
 func.func @entry() {
-  %f1 = arith.constant 1.0: f32
-  %f2 = arith.constant 2.0: f32
+  %f1 = arith.constant 1.0 : f32
+  %f2 = arith.constant 2.0 : f32
   %v1 = vector.splat %f1 : vector<2x4xf32>
   %v2 = vector.splat %f2 : vector<2x4xf32>
   vector.print %v1 : vector<2x4xf32>


