[Mlir-commits] [mlir] [mlir][vector] Add result type to `interleave` assembly format (PR #93392)

Jakub Kuderski llvmlistbot at llvm.org
Mon May 27 07:59:40 PDT 2024


https://github.com/kuhar updated https://github.com/llvm/llvm-project/pull/93392

>From 9040d7d1ef07b4f446c8d2d3b4ea317b921398ae Mon Sep 17 00:00:00 2001
From: Jakub Kuderski <jakub at nod-labs.com>
Date: Sat, 25 May 2024 23:03:40 -0400
Subject: [PATCH 1/2] [mlir][vector] Add result type to `interleave` assembly
 format

This makes the result type more obvious, especially in less trivial cases
such as 0-d inputs producing 1-d results, or interactions with scalable
vector types. Note that `vector.deinterleave` uses the same format with an
explicit result type.

Also improve examples and clean up surrounding code.
---
 .../mlir/Dialect/Vector/IR/VectorOps.td       | 38 +++++++++----------
 .../Transforms/LowerVectorInterleave.cpp      | 15 ++++----
 .../Transforms/VectorEmulateNarrowType.cpp    |  6 +--
 .../VectorToLLVM/vector-to-llvm.mlir          | 22 +++++------
 .../VectorToSPIRV/vector-to-spirv.mlir        |  2 +-
 mlir/test/Dialect/Vector/canonicalize.mlir    |  7 ++--
 mlir/test/Dialect/Vector/ops.mlir             | 12 +++---
 ...vector-interleave-lowering-transforms.mlir | 20 +++++-----
 .../Vector/vector-interleave-to-shuffle.mlir  |  5 +--
 .../CPU/ArmSVE/test-scalable-interleave.mlir  |  2 +-
 .../Dialect/Vector/CPU/test-interleave.mlir   |  2 +-
 11 files changed, 61 insertions(+), 70 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 2bb7540ef0b0f..e043320b56411 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -480,24 +480,25 @@ def Vector_ShuffleOp :
   let hasCanonicalizer = 1;
 }
 
-def Vector_InterleaveOp :
-  Vector_Op<"interleave", [Pure,
-    AllTypesMatch<["lhs", "rhs"]>,
-    TypesMatchWith<
+def ResultIsDoubleSourceVectorType : TypesMatchWith<
     "type of 'result' is double the width of the inputs",
     "lhs", "result",
     [{
       [&]() -> ::mlir::VectorType {
-        auto vectorType = ::llvm::cast<mlir::VectorType>($_self);
+        auto vectorType = ::llvm::cast<::mlir::VectorType>($_self);
         ::mlir::VectorType::Builder builder(vectorType);
         if (vectorType.getRank() == 0) {
-          static constexpr int64_t v2xty_shape[] = { 2 };
-          return builder.setShape(v2xty_shape);
+          static constexpr int64_t v2xTyShape[] = {2};
+          return builder.setShape(v2xTyShape);
         }
         auto lastDim = vectorType.getRank() - 1;
         return builder.setDim(lastDim, vectorType.getDimSize(lastDim) * 2);
       }()
-    }]>]> {
+    }]>;
+
+def Vector_InterleaveOp :
+  Vector_Op<"interleave", [Pure, AllTypesMatch<["lhs", "rhs"]>,
+    ResultIsDoubleSourceVectorType]> {
   let summary = "constructs a vector by interleaving two input vectors";
   let description = [{
     The interleave operation constructs a new vector by interleaving the
@@ -513,16 +514,15 @@ def Vector_InterleaveOp :
 
     Example:
     ```mlir
-    %0 = vector.interleave %a, %b
-               : vector<[4]xi32>     ; yields vector<[8]xi32>
-    %1 = vector.interleave %c, %d
-               : vector<8xi8>        ; yields vector<16xi8>
-    %2 = vector.interleave %e, %f
-               : vector<f16>         ; yields vector<2xf16>
-    %3 = vector.interleave %g, %h
-               : vector<2x4x[2]xf64> ; yields vector<2x4x[4]xf64>
-    %4 = vector.interleave %i, %j
-               : vector<6x3xf32>     ; yields vector<6x6xf32>
+    %a = arith.constant dense<[0, 1]> : vector<2xi32>
+    %b = arith.constant dense<[2, 3]> : vector<2xi32>
+    %0 = vector.interleave %a, %b : vector<2xi32> -> vector<4xi32>
+    // The value of `%0` is `[0, 2, 1, 3]`.
+
+    %1 = vector.interleave %c, %d : vector<f16> -> vector<2xf16>
+    %2 = vector.interleave %e, %f : vector<6x3xf32> -> vector<6x6xf32>
+    %3 = vector.interleave %g, %h : vector<[4]xi32> -> vector<[8]xi32>
+    %4 = vector.interleave %i, %j : vector<2x4x[2]xf64> -> vector<2x4x[4]xf64>
     ```
   }];
 
@@ -530,7 +530,7 @@ def Vector_InterleaveOp :
   let results = (outs AnyVector:$result);
 
   let assemblyFormat = [{
-    $lhs `,` $rhs  attr-dict `:` type($lhs)
+    $lhs `,` $rhs  attr-dict `:` type($lhs) `->` type($result)
   }];
 
   let extraClassDeclaration = [{
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorInterleave.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorInterleave.cpp
index 5326760c9b4eb..77c97b2f1497c 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorInterleave.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorInterleave.cpp
@@ -30,7 +30,7 @@ namespace {
 /// Example:
 ///
 /// ```mlir
-/// vector.interleave %a, %b : vector<1x2x3x4xi64>
+/// vector.interleave %a, %b : vector<1x2x3x4xi64> -> vector<1x2x3x8xi64>
 /// ```
 /// Would be unrolled to:
 /// ```mlir
@@ -39,14 +39,15 @@ namespace {
 ///        : vector<4xi64> from vector<1x2x3x4xi64>  |
 /// %1 = vector.extract %b[0, 0, 0]                  |
 ///        : vector<4xi64> from vector<1x2x3x4xi64>  | - Repeated 6x for
-/// %2 = vector.interleave %0, %1 : vector<4xi64>    |   all leading positions
+/// %2 = vector.interleave %0, %1                    |   all leading positions
+///        : vector<4xi64> -> vector<8xi64>          |
 /// %3 = vector.insert %2, %result [0, 0, 0]         |
 ///        : vector<8xi64> into vector<1x2x3x8xi64>  ┘
 /// ```
 ///
 /// Note: If any leading dimension before the `targetRank` is scalable the
 /// unrolling will stop before the scalable dimension.
-class UnrollInterleaveOp : public OpRewritePattern<vector::InterleaveOp> {
+class UnrollInterleaveOp final : public OpRewritePattern<vector::InterleaveOp> {
 public:
   UnrollInterleaveOp(int64_t targetRank, MLIRContext *context,
                      PatternBenefit benefit = 1)
@@ -84,7 +85,7 @@ class UnrollInterleaveOp : public OpRewritePattern<vector::InterleaveOp> {
 /// Example:
 ///
 /// ```mlir
-/// vector.interleave %a, %b : vector<7xi16>
+/// vector.interleave %a, %b : vector<7xi16> -> vector<14xi16>
 /// ```
 ///
 /// Is rewritten into:
@@ -93,10 +94,8 @@ class UnrollInterleaveOp : public OpRewritePattern<vector::InterleaveOp> {
 /// vector.shuffle %arg0, %arg1 [0, 7, 1, 8, 2, 9, 3, 10, 4, 11, 5, 12, 6, 13]
 ///   : vector<7xi16>, vector<7xi16>
 /// ```
-class InterleaveToShuffle : public OpRewritePattern<vector::InterleaveOp> {
-public:
-  InterleaveToShuffle(MLIRContext *context, PatternBenefit benefit = 1)
-      : OpRewritePattern(context, benefit) {};
+struct InterleaveToShuffle final : OpRewritePattern<vector::InterleaveOp> {
+  using OpRewritePattern::OpRewritePattern;
 
   LogicalResult matchAndRewrite(vector::InterleaveOp op,
                                 PatternRewriter &rewriter) const override {
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index 6025c4ad7c145..59b6cb3ae667a 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -1090,7 +1090,7 @@ struct RewriteExtOfBitCast : OpRewritePattern<ExtOpType> {
 ///        %1 = arith.shli %0, 4 : vector<4xi8>
 ///        %2 = arith.shrsi %1, 4 : vector<4xi8>
 ///        %3 = arith.shrsi %0, 4 : vector<4xi8>
-///        %4 = vector.interleave %2, %3 : vector<4xi8>
+///        %4 = vector.interleave %2, %3 : vector<4xi8> -> vector<8xi8>
 ///        %5 = arith.extsi %4 : vector<8xi8> to vector<8xi32>
 ///
 ///    arith.sitofp %in : vector<8xi4> to vector<8xf32>
@@ -1099,7 +1099,7 @@ struct RewriteExtOfBitCast : OpRewritePattern<ExtOpType> {
 ///        %1 = arith.shli %0, 4 : vector<4xi8>
 ///        %2 = arith.shrsi %1, 4 : vector<4xi8>
 ///        %3 = arith.shrsi %0, 4 : vector<4xi8>
-///        %4 = vector.interleave %2, %3 : vector<4xi8>
+///        %4 = vector.interleave %2, %3 : vector<4xi8> -> vector<8xi8>
 ///        %5 = arith.sitofp %4 : vector<8xi8> to vector<8xf32>
 ///
 /// Example (unsigned):
@@ -1108,7 +1108,7 @@ struct RewriteExtOfBitCast : OpRewritePattern<ExtOpType> {
 ///        %0 = vector.bitcast %in : vector<8xi4> to vector<4xi8>
 ///        %1 = arith.andi %0, 15 : vector<4xi8>
 ///        %2 = arith.shrui %0, 4 : vector<4xi8>
-///        %3 = vector.interleave %1, %2 : vector<4xi8>
+///        %3 = vector.interleave %1, %2 : vector<4xi8> -> vector<8xi8>
 ///        %4 = arith.extui %3 : vector<8xi8> to vector<8xi32>
 ///
 template <typename ConversionOpType, bool isSigned>
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index 439f1e920e392..a7a0ca3d43b01 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -2495,7 +2495,7 @@ func.func @vector_interleave_0d(%a: vector<i8>, %b: vector<i8>) -> vector<2xi8>
   // CHECK: %[[RHS_RANK1:.*]] = builtin.unrealized_conversion_cast %[[RHS]] : vector<i8> to vector<1xi8>
   // CHECK: %[[ZIP:.*]] = llvm.shufflevector %[[LHS_RANK1]], %[[RHS_RANK1]] [0, 1] : vector<1xi8>
   // CHECK: return %[[ZIP]]
-  %0 = vector.interleave %a, %b : vector<i8>
+  %0 = vector.interleave %a, %b : vector<i8> -> vector<2xi8>
   return %0 : vector<2xi8>
 }
 
@@ -2503,11 +2503,10 @@ func.func @vector_interleave_0d(%a: vector<i8>, %b: vector<i8>) -> vector<2xi8>
 
 // CHECK-LABEL: @vector_interleave_1d
 //  CHECK-SAME:     %[[LHS:.*]]: vector<8xf32>, %[[RHS:.*]]: vector<8xf32>)
-func.func @vector_interleave_1d(%a: vector<8xf32>, %b: vector<8xf32>) -> vector<16xf32>
-{
+func.func @vector_interleave_1d(%a: vector<8xf32>, %b: vector<8xf32>) -> vector<16xf32> {
   // CHECK: %[[ZIP:.*]] = llvm.shufflevector %[[LHS]], %[[RHS]] [0, 8, 1, 9, 2, 10, 3, 11, 4, 12, 5, 13, 6, 14, 7, 15] : vector<8xf32>
   // CHECK: return %[[ZIP]]
-  %0 = vector.interleave %a, %b : vector<8xf32>
+  %0 = vector.interleave %a, %b : vector<8xf32> -> vector<16xf32>
   return %0 : vector<16xf32>
 }
 
@@ -2515,11 +2514,10 @@ func.func @vector_interleave_1d(%a: vector<8xf32>, %b: vector<8xf32>) -> vector<
 
 // CHECK-LABEL: @vector_interleave_1d_scalable
 //  CHECK-SAME:     %[[LHS:.*]]: vector<[4]xi32>, %[[RHS:.*]]: vector<[4]xi32>)
-func.func @vector_interleave_1d_scalable(%a: vector<[4]xi32>, %b: vector<[4]xi32>) -> vector<[8]xi32>
-{
+func.func @vector_interleave_1d_scalable(%a: vector<[4]xi32>, %b: vector<[4]xi32>) -> vector<[8]xi32> {
   // CHECK: %[[ZIP:.*]] = "llvm.intr.vector.interleave2"(%[[LHS]], %[[RHS]]) : (vector<[4]xi32>, vector<[4]xi32>) -> vector<[8]xi32>
   // CHECK: return %[[ZIP]]
-  %0 = vector.interleave %a, %b : vector<[4]xi32>
+  %0 = vector.interleave %a, %b : vector<[4]xi32> -> vector<[8]xi32>
   return %0 : vector<[8]xi32>
 }
 
@@ -2527,11 +2525,10 @@ func.func @vector_interleave_1d_scalable(%a: vector<[4]xi32>, %b: vector<[4]xi32
 
 // CHECK-LABEL: @vector_interleave_2d
 //  CHECK-SAME:     %[[LHS:.*]]: vector<2x3xi8>, %[[RHS:.*]]: vector<2x3xi8>)
-func.func @vector_interleave_2d(%a: vector<2x3xi8>, %b: vector<2x3xi8>) -> vector<2x6xi8>
-{
+func.func @vector_interleave_2d(%a: vector<2x3xi8>, %b: vector<2x3xi8>) -> vector<2x6xi8> {
   // CHECK: llvm.shufflevector
   // CHECK-NOT: vector.interleave {{.*}} : vector<2x3xi8>
-  %0 = vector.interleave %a, %b : vector<2x3xi8>
+  %0 = vector.interleave %a, %b : vector<2x3xi8> -> vector<2x6xi8>
   return %0 : vector<2x6xi8>
 }
 
@@ -2539,10 +2536,9 @@ func.func @vector_interleave_2d(%a: vector<2x3xi8>, %b: vector<2x3xi8>) -> vecto
 
 // CHECK-LABEL: @vector_interleave_2d_scalable
 //  CHECK-SAME:     %[[LHS:.*]]: vector<2x[8]xi16>, %[[RHS:.*]]: vector<2x[8]xi16>)
-func.func @vector_interleave_2d_scalable(%a: vector<2x[8]xi16>, %b: vector<2x[8]xi16>) -> vector<2x[16]xi16>
-{
+func.func @vector_interleave_2d_scalable(%a: vector<2x[8]xi16>, %b: vector<2x[8]xi16>) -> vector<2x[16]xi16> {
   // CHECK: llvm.intr.vector.interleave2
   // CHECK-NOT: vector.interleave {{.*}} : vector<2x[8]xi16>
-  %0 = vector.interleave %a, %b : vector<2x[8]xi16>
+  %0 = vector.interleave %a, %b : vector<2x[8]xi16> -> vector<2x[16]xi16>
   return %0 : vector<2x[16]xi16>
 }
diff --git a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
index a7542086aa766..b24088d951259 100644
--- a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
+++ b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
@@ -488,7 +488,7 @@ func.func @shuffle(%v0 : vector<1xi32>, %v1: vector<1xi32>) -> vector<2xi32> {
 //       CHECK: %[[SHUFFLE:.*]] = spirv.VectorShuffle [0 : i32, 2 : i32, 1 : i32, 3 : i32] %[[ARG0]], %[[ARG1]] : vector<2xf32>, vector<2xf32> -> vector<4xf32>
 //       CHECK: return %[[SHUFFLE]]
 func.func @interleave(%a: vector<2xf32>, %b: vector<2xf32>) -> vector<4xf32> {
-  %0 = vector.interleave %a, %b : vector<2xf32>
+  %0 = vector.interleave %a, %b : vector<2xf32> -> vector<4xf32>
   return %0 : vector<4xf32>
 }
 
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 61a5f2a96e1c1..22af91e0eb327 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -2576,9 +2576,8 @@ func.func @load_store_forwarding_rank_mismatch(%v0: vector<4x1x1xf32>, %arg0: te
 
 // CHECK-LABEL: func.func @rank_0_shuffle_to_interleave(
 //  CHECK-SAME:     %[[LHS:.*]]: vector<f64>, %[[RHS:.*]]: vector<f64>)
-func.func @rank_0_shuffle_to_interleave(%arg0: vector<f64>, %arg1: vector<f64>) -> vector<2xf64>
-{
-  // CHECK: %[[ZIP:.*]] = vector.interleave %[[LHS]], %[[RHS]] : vector<f64>
+func.func @rank_0_shuffle_to_interleave(%arg0: vector<f64>, %arg1: vector<f64>) -> vector<2xf64> {
+  // CHECK: %[[ZIP:.*]] = vector.interleave %[[LHS]], %[[RHS]] : vector<f64> -> vector<2xf64>
   // CHECK: return %[[ZIP]]
   %0 = vector.shuffle %arg0, %arg1 [0, 1] : vector<f64>, vector<f64>
   return %0 : vector<2xf64>
@@ -2589,7 +2588,7 @@ func.func @rank_0_shuffle_to_interleave(%arg0: vector<f64>, %arg1: vector<f64>)
 // CHECK-LABEL: func.func @rank_1_shuffle_to_interleave(
 //  CHECK-SAME:     %[[LHS:.*]]: vector<6xi32>, %[[RHS:.*]]: vector<6xi32>)
 func.func @rank_1_shuffle_to_interleave(%arg0: vector<6xi32>, %arg1: vector<6xi32>) -> vector<12xi32> {
-  // CHECK: %[[ZIP:.*]] = vector.interleave %[[LHS]], %[[RHS]] : vector<6xi32>
+  // CHECK: %[[ZIP:.*]] = vector.interleave %[[LHS]], %[[RHS]] : vector<6xi32> -> vector<12xi32>
   // CHECK: return %[[ZIP]]
   %0 = vector.shuffle %arg0, %arg1 [0, 6, 1, 7, 2, 8, 3, 9, 4, 10, 5, 11] : vector<6xi32>, vector<6xi32>
   return %0 : vector<12xi32>
diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
index 9d8101d3eee97..c868c881d079a 100644
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -1084,36 +1084,36 @@ func.func @fastmath(%x: vector<42xf32>) -> f32 {
 
 // CHECK-LABEL: @interleave_0d
 func.func @interleave_0d(%a: vector<f32>, %b: vector<f32>) -> vector<2xf32> {
-  // CHECK: vector.interleave %{{.*}}, %{{.*}} : vector<f32>
-  %0 = vector.interleave %a, %b : vector<f32>
+  // CHECK: vector.interleave %{{.*}}, %{{.*}} : vector<f32> -> vector<2xf32>
+  %0 = vector.interleave %a, %b : vector<f32> -> vector<2xf32>
   return %0 : vector<2xf32>
 }
 
 // CHECK-LABEL: @interleave_1d
 func.func @interleave_1d(%a: vector<4xf32>, %b: vector<4xf32>) -> vector<8xf32> {
   // CHECK: vector.interleave %{{.*}}, %{{.*}} : vector<4xf32>
-  %0 = vector.interleave %a, %b : vector<4xf32>
+  %0 = vector.interleave %a, %b : vector<4xf32> -> vector<8xf32>
   return %0 : vector<8xf32>
 }
 
 // CHECK-LABEL: @interleave_1d_scalable
 func.func @interleave_1d_scalable(%a: vector<[8]xi16>, %b: vector<[8]xi16>) -> vector<[16]xi16> {
   // CHECK: vector.interleave %{{.*}}, %{{.*}} : vector<[8]xi16>
-  %0 = vector.interleave %a, %b : vector<[8]xi16>
+  %0 = vector.interleave %a, %b : vector<[8]xi16> -> vector<[16]xi16>
   return %0 : vector<[16]xi16>
 }
 
 // CHECK-LABEL: @interleave_2d
 func.func @interleave_2d(%a: vector<2x8xf32>, %b: vector<2x8xf32>) -> vector<2x16xf32> {
   // CHECK: vector.interleave %{{.*}}, %{{.*}} : vector<2x8xf32>
-  %0 = vector.interleave %a, %b : vector<2x8xf32>
+  %0 = vector.interleave %a, %b : vector<2x8xf32> -> vector<2x16xf32>
   return %0 : vector<2x16xf32>
 }
 
 // CHECK-LABEL: @interleave_2d_scalable
 func.func @interleave_2d_scalable(%a: vector<2x[2]xf64>, %b: vector<2x[2]xf64>) -> vector<2x[4]xf64> {
   // CHECK: vector.interleave %{{.*}}, %{{.*}} : vector<2x[2]xf64>
-  %0 = vector.interleave %a, %b : vector<2x[2]xf64>
+  %0 = vector.interleave %a, %b : vector<2x[2]xf64> -> vector<2x[4]xf64>
   return %0 : vector<2x[4]xf64>
 }
 
diff --git a/mlir/test/Dialect/Vector/vector-interleave-lowering-transforms.mlir b/mlir/test/Dialect/Vector/vector-interleave-lowering-transforms.mlir
index 3dd4857860eb1..598f7d70b4f1b 100644
--- a/mlir/test/Dialect/Vector/vector-interleave-lowering-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-interleave-lowering-transforms.mlir
@@ -2,8 +2,7 @@
 
 // CHECK-LABEL: @vector_interleave_2d
 //  CHECK-SAME:     %[[LHS:.*]]: vector<2x3xi8>, %[[RHS:.*]]: vector<2x3xi8>)
-func.func @vector_interleave_2d(%a: vector<2x3xi8>, %b: vector<2x3xi8>) -> vector<2x6xi8>
-{
+func.func @vector_interleave_2d(%a: vector<2x3xi8>, %b: vector<2x3xi8>) -> vector<2x6xi8> {
   // CHECK-DAG: %[[CST:.*]] = arith.constant dense<0>
   // CHECK-DAG: %[[LHS_0:.*]] = vector.extract %[[LHS]][0]
   // CHECK-DAG: %[[RHS_0:.*]] = vector.extract %[[RHS]][0]
@@ -14,14 +13,13 @@ func.func @vector_interleave_2d(%a: vector<2x3xi8>, %b: vector<2x3xi8>) -> vecto
   // CHECK-DAG: %[[RES_0:.*]] = vector.insert %[[ZIP_0]], %[[CST]] [0]
   // CHECK-DAG: %[[RES_1:.*]] = vector.insert %[[ZIP_1]], %[[RES_0]] [1]
   // CHECK-NEXT: return %[[RES_1]] : vector<2x6xi8>
-  %0 = vector.interleave %a, %b : vector<2x3xi8>
+  %0 = vector.interleave %a, %b : vector<2x3xi8> -> vector<2x6xi8>
   return %0 : vector<2x6xi8>
 }
 
 // CHECK-LABEL: @vector_interleave_2d_scalable
 //  CHECK-SAME:     %[[LHS:.*]]: vector<2x[8]xi16>, %[[RHS:.*]]: vector<2x[8]xi16>)
-func.func @vector_interleave_2d_scalable(%a: vector<2x[8]xi16>, %b: vector<2x[8]xi16>) -> vector<2x[16]xi16>
-{
+func.func @vector_interleave_2d_scalable(%a: vector<2x[8]xi16>, %b: vector<2x[8]xi16>) -> vector<2x[16]xi16> {
   // CHECK-DAG: %[[CST:.*]] = arith.constant dense<0>
   // CHECK-DAG: %[[LHS_0:.*]] = vector.extract %[[LHS]][0]
   // CHECK-DAG: %[[RHS_0:.*]] = vector.extract %[[RHS]][0]
@@ -32,7 +30,7 @@ func.func @vector_interleave_2d_scalable(%a: vector<2x[8]xi16>, %b: vector<2x[8]
   // CHECK-DAG: %[[RES_0:.*]] = vector.insert %[[ZIP_0]], %[[CST]] [0]
   // CHECK-DAG: %[[RES_1:.*]] = vector.insert %[[ZIP_1]], %[[RES_0]] [1]
   // CHECK-NEXT: return %[[RES_1]] : vector<2x[16]xi16>
-  %0 = vector.interleave %a, %b : vector<2x[8]xi16>
+  %0 = vector.interleave %a, %b : vector<2x[8]xi16> -> vector<2x[16]xi16>
   return %0 : vector<2x[16]xi16>
 }
 
@@ -44,17 +42,17 @@ func.func @vector_interleave_4d(%a: vector<1x2x3x4xi64>, %b: vector<1x2x3x4xi64>
   // CHECK: %[[RHS_0:.*]] = vector.extract %[[RHS]][0, 0, 0] : vector<4xi64> from vector<1x2x3x4xi64>
   // CHECK: %[[ZIP_0:.*]] = vector.interleave %[[LHS_0]], %[[RHS_0]] : vector<4xi64>
   // CHECK: %[[RES_0:.*]] = vector.insert %[[ZIP_0]], %{{.*}} [0, 0, 0] : vector<8xi64> into vector<1x2x3x8xi64>
-  // CHECK-COUNT-5: vector.interleave %{{.*}}, %{{.*}} : vector<4xi64>
-  %0 = vector.interleave %a, %b : vector<1x2x3x4xi64>
+  // CHECK-COUNT-5: vector.interleave %{{.*}}, %{{.*}} : vector<4xi64> -> vector<8xi64>
+  %0 = vector.interleave %a, %b : vector<1x2x3x4xi64> -> vector<1x2x3x8xi64>
   return %0 : vector<1x2x3x8xi64>
 }
 
 // CHECK-LABEL: @vector_interleave_nd_with_scalable_dim
-func.func @vector_interleave_nd_with_scalable_dim(%a: vector<1x3x[2]x2x3x4xf16>, %b: vector<1x3x[2]x2x3x4xf16>) -> vector<1x3x[2]x2x3x8xf16>
-{
+func.func @vector_interleave_nd_with_scalable_dim(
+  %a: vector<1x3x[2]x2x3x4xf16>, %b: vector<1x3x[2]x2x3x4xf16>) -> vector<1x3x[2]x2x3x8xf16> {
   // The scalable dim blocks unrolling so only the first two dims are unrolled.
   // CHECK-COUNT-3: vector.interleave %{{.*}}, %{{.*}} : vector<[2]x2x3x4xf16>
-  %0 = vector.interleave %a, %b : vector<1x3x[2]x2x3x4xf16>
+  %0 = vector.interleave %a, %b : vector<1x3x[2]x2x3x4xf16> -> vector<1x3x[2]x2x3x8xf16>
   return %0 : vector<1x3x[2]x2x3x8xf16>
 }
 
diff --git a/mlir/test/Dialect/Vector/vector-interleave-to-shuffle.mlir b/mlir/test/Dialect/Vector/vector-interleave-to-shuffle.mlir
index ed3b3396bf3ea..d59cd4e6765ba 100644
--- a/mlir/test/Dialect/Vector/vector-interleave-to-shuffle.mlir
+++ b/mlir/test/Dialect/Vector/vector-interleave-to-shuffle.mlir
@@ -1,9 +1,8 @@
 // RUN: mlir-opt %s --transform-interpreter | FileCheck %s
 
 // CHECK-LABEL: @vector_interleave_to_shuffle
-func.func @vector_interleave_to_shuffle(%a: vector<7xi16>, %b: vector<7xi16>) -> vector<14xi16>
-{
-  %0 = vector.interleave %a, %b : vector<7xi16>
+func.func @vector_interleave_to_shuffle(%a: vector<7xi16>, %b: vector<7xi16>) -> vector<14xi16> {
+  %0 = vector.interleave %a, %b : vector<7xi16> -> vector<14xi16>
   return %0 : vector<14xi16>
 }
 // CHECK: vector.shuffle %arg0, %arg1 [0, 7, 1, 8, 2, 9, 3, 10, 4, 11, 5, 12, 6, 13] : vector<7xi16>, vector<7xi16>
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/test-scalable-interleave.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/test-scalable-interleave.mlir
index 07989bd71f501..e9f1bbeafacdd 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/test-scalable-interleave.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/test-scalable-interleave.mlir
@@ -17,7 +17,7 @@ func.func @entry() {
   // CHECK: ( 1, 1, 1, 1
   // CHECK: ( 2, 2, 2, 2
 
-  %v3 = vector.interleave %v1, %v2 : vector<[4]xf32>
+  %v3 = vector.interleave %v1, %v2 : vector<[4]xf32> -> vector<[8]xf32>
   vector.print %v3 : vector<[8]xf32>
   // CHECK: ( 1, 2, 1, 2, 1, 2, 1, 2
 
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/test-interleave.mlir b/mlir/test/Integration/Dialect/Vector/CPU/test-interleave.mlir
index 0bc78af6aba03..d6962cbe2776a 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/test-interleave.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/test-interleave.mlir
@@ -16,7 +16,7 @@ func.func @entry() {
   // CHECK: ( ( 1, 1, 1, 1 ), ( 1, 1, 1, 1 ) )
   // CHECK: ( ( 2, 2, 2, 2 ), ( 2, 2, 2, 2 ) )
 
-  %v3 = vector.interleave %v1, %v2 : vector<2x4xf32>
+  %v3 = vector.interleave %v1, %v2 : vector<2x4xf32> -> vector<2x8xf32>
   vector.print %v3 : vector<2x8xf32>
   // CHECK: ( ( 1, 2, 1, 2, 1, 2, 1, 2 ), ( 1, 2, 1, 2, 1, 2, 1, 2 ) )
 

>From 8ee982880bf550fbe78e8cff7df22bd195325256 Mon Sep 17 00:00:00 2001
From: Jakub Kuderski <jakub at nod-labs.com>
Date: Mon, 27 May 2024 10:59:29 -0400
Subject: [PATCH 2/2] Improve comments

---
 mlir/include/mlir/Dialect/Vector/IR/VectorOps.td | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index e043320b56411..56d866ac5b40c 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -516,9 +516,10 @@ def Vector_InterleaveOp :
     ```mlir
     %a = arith.constant dense<[0, 1]> : vector<2xi32>
     %b = arith.constant dense<[2, 3]> : vector<2xi32>
-    %0 = vector.interleave %a, %b : vector<2xi32> -> vector<4xi32>
     // The value of `%0` is `[0, 2, 1, 3]`.
+    %0 = vector.interleave %a, %b : vector<2xi32> -> vector<4xi32>
 
+    // Examples showing allowed input and result types.
     %1 = vector.interleave %c, %d : vector<f16> -> vector<2xf16>
     %2 = vector.interleave %e, %f : vector<6x3xf32> -> vector<6x6xf32>
     %3 = vector.interleave %g, %h : vector<[4]xi32> -> vector<[8]xi32>



More information about the Mlir-commits mailing list