[Mlir-commits] [mlir] [mlir][vector] Add result type to `interleave` assembly format (PR #93392)
Jakub Kuderski
llvmlistbot at llvm.org
Mon May 27 07:59:40 PDT 2024
https://github.com/kuhar updated https://github.com/llvm/llvm-project/pull/93392
>From 9040d7d1ef07b4f446c8d2d3b4ea317b921398ae Mon Sep 17 00:00:00 2001
From: Jakub Kuderski <jakub at nod-labs.com>
Date: Sat, 25 May 2024 23:03:40 -0400
Subject: [PATCH 1/2] [mlir][vector] Add result type to `interleave` assembly
format
This is to make it more obvious what the result type is, especially
in less trivial cases like 0-d inputs resulting in 1-d results, or
interaction with scalable vector types. Note that `vector.deinterleave`
uses the same format with an explicit result type.
Also improve examples and clean up surrounding code.
---
.../mlir/Dialect/Vector/IR/VectorOps.td | 38 +++++++++----------
.../Transforms/LowerVectorInterleave.cpp | 15 ++++----
.../Transforms/VectorEmulateNarrowType.cpp | 6 +--
.../VectorToLLVM/vector-to-llvm.mlir | 22 +++++------
.../VectorToSPIRV/vector-to-spirv.mlir | 2 +-
mlir/test/Dialect/Vector/canonicalize.mlir | 7 ++--
mlir/test/Dialect/Vector/ops.mlir | 12 +++---
...vector-interleave-lowering-transforms.mlir | 20 +++++-----
.../Vector/vector-interleave-to-shuffle.mlir | 5 +--
.../CPU/ArmSVE/test-scalable-interleave.mlir | 2 +-
.../Dialect/Vector/CPU/test-interleave.mlir | 2 +-
11 files changed, 61 insertions(+), 70 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 2bb7540ef0b0f..e043320b56411 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -480,24 +480,25 @@ def Vector_ShuffleOp :
let hasCanonicalizer = 1;
}
-def Vector_InterleaveOp :
- Vector_Op<"interleave", [Pure,
- AllTypesMatch<["lhs", "rhs"]>,
- TypesMatchWith<
+def ResultIsDoubleSourceVectorType : TypesMatchWith<
"type of 'result' is double the width of the inputs",
"lhs", "result",
[{
[&]() -> ::mlir::VectorType {
- auto vectorType = ::llvm::cast<mlir::VectorType>($_self);
+ auto vectorType = ::llvm::cast<::mlir::VectorType>($_self);
::mlir::VectorType::Builder builder(vectorType);
if (vectorType.getRank() == 0) {
- static constexpr int64_t v2xty_shape[] = { 2 };
- return builder.setShape(v2xty_shape);
+ static constexpr int64_t v2xTyShape[] = {2};
+ return builder.setShape(v2xTyShape);
}
auto lastDim = vectorType.getRank() - 1;
return builder.setDim(lastDim, vectorType.getDimSize(lastDim) * 2);
}()
- }]>]> {
+ }]>;
+
+def Vector_InterleaveOp :
+ Vector_Op<"interleave", [Pure, AllTypesMatch<["lhs", "rhs"]>,
+ ResultIsDoubleSourceVectorType]> {
let summary = "constructs a vector by interleaving two input vectors";
let description = [{
The interleave operation constructs a new vector by interleaving the
@@ -513,16 +514,15 @@ def Vector_InterleaveOp :
Example:
```mlir
- %0 = vector.interleave %a, %b
- : vector<[4]xi32> ; yields vector<[8]xi32>
- %1 = vector.interleave %c, %d
- : vector<8xi8> ; yields vector<16xi8>
- %2 = vector.interleave %e, %f
- : vector<f16> ; yields vector<2xf16>
- %3 = vector.interleave %g, %h
- : vector<2x4x[2]xf64> ; yields vector<2x4x[4]xf64>
- %4 = vector.interleave %i, %j
- : vector<6x3xf32> ; yields vector<6x6xf32>
+ %a = arith.constant dense<[0, 1]> : vector<2xi32>
+ %b = arith.constant dense<[2, 3]> : vector<2xi32>
+ %0 = vector.interleave %a, %b : vector<2xi32> -> vector<4xi32>
+ // The value of `%0` is `[0, 2, 1, 3]`.
+
+ %1 = vector.interleave %c, %d : vector<f16> -> vector<2xf16>
+ %2 = vector.interleave %e, %f : vector<6x3xf32> -> vector<6x6xf32>
+ %3 = vector.interleave %g, %h : vector<[4]xi32> -> vector<[8]xi32>
+ %4 = vector.interleave %i, %j : vector<2x4x[2]xf64> -> vector<2x4x[4]xf64>
```
}];
@@ -530,7 +530,7 @@ def Vector_InterleaveOp :
let results = (outs AnyVector:$result);
let assemblyFormat = [{
- $lhs `,` $rhs attr-dict `:` type($lhs)
+ $lhs `,` $rhs attr-dict `:` type($lhs) `->` type($result)
}];
let extraClassDeclaration = [{
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorInterleave.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorInterleave.cpp
index 5326760c9b4eb..77c97b2f1497c 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorInterleave.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorInterleave.cpp
@@ -30,7 +30,7 @@ namespace {
/// Example:
///
/// ```mlir
-/// vector.interleave %a, %b : vector<1x2x3x4xi64>
+/// vector.interleave %a, %b : vector<1x2x3x4xi64> -> vector<1x2x3x8xi64>
/// ```
/// Would be unrolled to:
/// ```mlir
@@ -39,14 +39,15 @@ namespace {
/// : vector<4xi64> from vector<1x2x3x4xi64> |
/// %1 = vector.extract %b[0, 0, 0] |
/// : vector<4xi64> from vector<1x2x3x4xi64> | - Repeated 6x for
-/// %2 = vector.interleave %0, %1 : vector<4xi64> | all leading positions
+/// %2 = vector.interleave %0, %1 | all leading positions
+/// : vector<4xi64> -> vector<8xi64> |
/// %3 = vector.insert %2, %result [0, 0, 0] |
/// : vector<8xi64> into vector<1x2x3x8xi64> ┘
/// ```
///
/// Note: If any leading dimension before the `targetRank` is scalable the
/// unrolling will stop before the scalable dimension.
-class UnrollInterleaveOp : public OpRewritePattern<vector::InterleaveOp> {
+class UnrollInterleaveOp final : public OpRewritePattern<vector::InterleaveOp> {
public:
UnrollInterleaveOp(int64_t targetRank, MLIRContext *context,
PatternBenefit benefit = 1)
@@ -84,7 +85,7 @@ class UnrollInterleaveOp : public OpRewritePattern<vector::InterleaveOp> {
/// Example:
///
/// ```mlir
-/// vector.interleave %a, %b : vector<7xi16>
+/// vector.interleave %a, %b : vector<7xi16> -> vector<14xi16>
/// ```
///
/// Is rewritten into:
@@ -93,10 +94,8 @@ class UnrollInterleaveOp : public OpRewritePattern<vector::InterleaveOp> {
/// vector.shuffle %arg0, %arg1 [0, 7, 1, 8, 2, 9, 3, 10, 4, 11, 5, 12, 6, 13]
/// : vector<7xi16>, vector<7xi16>
/// ```
-class InterleaveToShuffle : public OpRewritePattern<vector::InterleaveOp> {
-public:
- InterleaveToShuffle(MLIRContext *context, PatternBenefit benefit = 1)
- : OpRewritePattern(context, benefit) {};
+struct InterleaveToShuffle final : OpRewritePattern<vector::InterleaveOp> {
+ using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(vector::InterleaveOp op,
PatternRewriter &rewriter) const override {
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index 6025c4ad7c145..59b6cb3ae667a 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -1090,7 +1090,7 @@ struct RewriteExtOfBitCast : OpRewritePattern<ExtOpType> {
/// %1 = arith.shli %0, 4 : vector<4xi8>
/// %2 = arith.shrsi %1, 4 : vector<4xi8>
/// %3 = arith.shrsi %0, 4 : vector<4xi8>
-/// %4 = vector.interleave %2, %3 : vector<4xi8>
+/// %4 = vector.interleave %2, %3 : vector<4xi8> -> vector<8xi8>
/// %5 = arith.extsi %4 : vector<8xi8> to vector<8xi32>
///
/// arith.sitofp %in : vector<8xi4> to vector<8xf32>
@@ -1099,7 +1099,7 @@ struct RewriteExtOfBitCast : OpRewritePattern<ExtOpType> {
/// %1 = arith.shli %0, 4 : vector<4xi8>
/// %2 = arith.shrsi %1, 4 : vector<4xi8>
/// %3 = arith.shrsi %0, 4 : vector<4xi8>
-/// %4 = vector.interleave %2, %3 : vector<4xi8>
+/// %4 = vector.interleave %2, %3 : vector<4xi8> -> vector<8xi8>
/// %5 = arith.sitofp %4 : vector<8xi8> to vector<8xf32>
///
/// Example (unsigned):
@@ -1108,7 +1108,7 @@ struct RewriteExtOfBitCast : OpRewritePattern<ExtOpType> {
/// %0 = vector.bitcast %in : vector<8xi4> to vector<4xi8>
/// %1 = arith.andi %0, 15 : vector<4xi8>
/// %2 = arith.shrui %0, 4 : vector<4xi8>
-/// %3 = vector.interleave %1, %2 : vector<4xi8>
+/// %3 = vector.interleave %1, %2 : vector<4xi8> -> vector<8xi8>
/// %4 = arith.extui %3 : vector<8xi8> to vector<8xi32>
///
template <typename ConversionOpType, bool isSigned>
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index 439f1e920e392..a7a0ca3d43b01 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -2495,7 +2495,7 @@ func.func @vector_interleave_0d(%a: vector<i8>, %b: vector<i8>) -> vector<2xi8>
// CHECK: %[[RHS_RANK1:.*]] = builtin.unrealized_conversion_cast %[[RHS]] : vector<i8> to vector<1xi8>
// CHECK: %[[ZIP:.*]] = llvm.shufflevector %[[LHS_RANK1]], %[[RHS_RANK1]] [0, 1] : vector<1xi8>
// CHECK: return %[[ZIP]]
- %0 = vector.interleave %a, %b : vector<i8>
+ %0 = vector.interleave %a, %b : vector<i8> -> vector<2xi8>
return %0 : vector<2xi8>
}
@@ -2503,11 +2503,10 @@ func.func @vector_interleave_0d(%a: vector<i8>, %b: vector<i8>) -> vector<2xi8>
// CHECK-LABEL: @vector_interleave_1d
// CHECK-SAME: %[[LHS:.*]]: vector<8xf32>, %[[RHS:.*]]: vector<8xf32>)
-func.func @vector_interleave_1d(%a: vector<8xf32>, %b: vector<8xf32>) -> vector<16xf32>
-{
+func.func @vector_interleave_1d(%a: vector<8xf32>, %b: vector<8xf32>) -> vector<16xf32> {
// CHECK: %[[ZIP:.*]] = llvm.shufflevector %[[LHS]], %[[RHS]] [0, 8, 1, 9, 2, 10, 3, 11, 4, 12, 5, 13, 6, 14, 7, 15] : vector<8xf32>
// CHECK: return %[[ZIP]]
- %0 = vector.interleave %a, %b : vector<8xf32>
+ %0 = vector.interleave %a, %b : vector<8xf32> -> vector<16xf32>
return %0 : vector<16xf32>
}
@@ -2515,11 +2514,10 @@ func.func @vector_interleave_1d(%a: vector<8xf32>, %b: vector<8xf32>) -> vector<
// CHECK-LABEL: @vector_interleave_1d_scalable
// CHECK-SAME: %[[LHS:.*]]: vector<[4]xi32>, %[[RHS:.*]]: vector<[4]xi32>)
-func.func @vector_interleave_1d_scalable(%a: vector<[4]xi32>, %b: vector<[4]xi32>) -> vector<[8]xi32>
-{
+func.func @vector_interleave_1d_scalable(%a: vector<[4]xi32>, %b: vector<[4]xi32>) -> vector<[8]xi32> {
// CHECK: %[[ZIP:.*]] = "llvm.intr.vector.interleave2"(%[[LHS]], %[[RHS]]) : (vector<[4]xi32>, vector<[4]xi32>) -> vector<[8]xi32>
// CHECK: return %[[ZIP]]
- %0 = vector.interleave %a, %b : vector<[4]xi32>
+ %0 = vector.interleave %a, %b : vector<[4]xi32> -> vector<[8]xi32>
return %0 : vector<[8]xi32>
}
@@ -2527,11 +2525,10 @@ func.func @vector_interleave_1d_scalable(%a: vector<[4]xi32>, %b: vector<[4]xi32
// CHECK-LABEL: @vector_interleave_2d
// CHECK-SAME: %[[LHS:.*]]: vector<2x3xi8>, %[[RHS:.*]]: vector<2x3xi8>)
-func.func @vector_interleave_2d(%a: vector<2x3xi8>, %b: vector<2x3xi8>) -> vector<2x6xi8>
-{
+func.func @vector_interleave_2d(%a: vector<2x3xi8>, %b: vector<2x3xi8>) -> vector<2x6xi8> {
// CHECK: llvm.shufflevector
// CHECK-NOT: vector.interleave {{.*}} : vector<2x3xi8>
- %0 = vector.interleave %a, %b : vector<2x3xi8>
+ %0 = vector.interleave %a, %b : vector<2x3xi8> -> vector<2x6xi8>
return %0 : vector<2x6xi8>
}
@@ -2539,10 +2536,9 @@ func.func @vector_interleave_2d(%a: vector<2x3xi8>, %b: vector<2x3xi8>) -> vecto
// CHECK-LABEL: @vector_interleave_2d_scalable
// CHECK-SAME: %[[LHS:.*]]: vector<2x[8]xi16>, %[[RHS:.*]]: vector<2x[8]xi16>)
-func.func @vector_interleave_2d_scalable(%a: vector<2x[8]xi16>, %b: vector<2x[8]xi16>) -> vector<2x[16]xi16>
-{
+func.func @vector_interleave_2d_scalable(%a: vector<2x[8]xi16>, %b: vector<2x[8]xi16>) -> vector<2x[16]xi16> {
// CHECK: llvm.intr.vector.interleave2
// CHECK-NOT: vector.interleave {{.*}} : vector<2x[8]xi16>
- %0 = vector.interleave %a, %b : vector<2x[8]xi16>
+ %0 = vector.interleave %a, %b : vector<2x[8]xi16> -> vector<2x[16]xi16>
return %0 : vector<2x[16]xi16>
}
diff --git a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
index a7542086aa766..b24088d951259 100644
--- a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
+++ b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
@@ -488,7 +488,7 @@ func.func @shuffle(%v0 : vector<1xi32>, %v1: vector<1xi32>) -> vector<2xi32> {
// CHECK: %[[SHUFFLE:.*]] = spirv.VectorShuffle [0 : i32, 2 : i32, 1 : i32, 3 : i32] %[[ARG0]], %[[ARG1]] : vector<2xf32>, vector<2xf32> -> vector<4xf32>
// CHECK: return %[[SHUFFLE]]
func.func @interleave(%a: vector<2xf32>, %b: vector<2xf32>) -> vector<4xf32> {
- %0 = vector.interleave %a, %b : vector<2xf32>
+ %0 = vector.interleave %a, %b : vector<2xf32> -> vector<4xf32>
return %0 : vector<4xf32>
}
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 61a5f2a96e1c1..22af91e0eb327 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -2576,9 +2576,8 @@ func.func @load_store_forwarding_rank_mismatch(%v0: vector<4x1x1xf32>, %arg0: te
// CHECK-LABEL: func.func @rank_0_shuffle_to_interleave(
// CHECK-SAME: %[[LHS:.*]]: vector<f64>, %[[RHS:.*]]: vector<f64>)
-func.func @rank_0_shuffle_to_interleave(%arg0: vector<f64>, %arg1: vector<f64>) -> vector<2xf64>
-{
- // CHECK: %[[ZIP:.*]] = vector.interleave %[[LHS]], %[[RHS]] : vector<f64>
+func.func @rank_0_shuffle_to_interleave(%arg0: vector<f64>, %arg1: vector<f64>) -> vector<2xf64> {
+ // CHECK: %[[ZIP:.*]] = vector.interleave %[[LHS]], %[[RHS]] : vector<f64> -> vector<2xf64>
// CHECK: return %[[ZIP]]
%0 = vector.shuffle %arg0, %arg1 [0, 1] : vector<f64>, vector<f64>
return %0 : vector<2xf64>
@@ -2589,7 +2588,7 @@ func.func @rank_0_shuffle_to_interleave(%arg0: vector<f64>, %arg1: vector<f64>)
// CHECK-LABEL: func.func @rank_1_shuffle_to_interleave(
// CHECK-SAME: %[[LHS:.*]]: vector<6xi32>, %[[RHS:.*]]: vector<6xi32>)
func.func @rank_1_shuffle_to_interleave(%arg0: vector<6xi32>, %arg1: vector<6xi32>) -> vector<12xi32> {
- // CHECK: %[[ZIP:.*]] = vector.interleave %[[LHS]], %[[RHS]] : vector<6xi32>
+ // CHECK: %[[ZIP:.*]] = vector.interleave %[[LHS]], %[[RHS]] : vector<6xi32> -> vector<12xi32>
// CHECK: return %[[ZIP]]
%0 = vector.shuffle %arg0, %arg1 [0, 6, 1, 7, 2, 8, 3, 9, 4, 10, 5, 11] : vector<6xi32>, vector<6xi32>
return %0 : vector<12xi32>
diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
index 9d8101d3eee97..c868c881d079a 100644
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -1084,36 +1084,36 @@ func.func @fastmath(%x: vector<42xf32>) -> f32 {
// CHECK-LABEL: @interleave_0d
func.func @interleave_0d(%a: vector<f32>, %b: vector<f32>) -> vector<2xf32> {
- // CHECK: vector.interleave %{{.*}}, %{{.*}} : vector<f32>
- %0 = vector.interleave %a, %b : vector<f32>
+ // CHECK: vector.interleave %{{.*}}, %{{.*}} : vector<f32> -> vector<2xf32>
+ %0 = vector.interleave %a, %b : vector<f32> -> vector<2xf32>
return %0 : vector<2xf32>
}
// CHECK-LABEL: @interleave_1d
func.func @interleave_1d(%a: vector<4xf32>, %b: vector<4xf32>) -> vector<8xf32> {
// CHECK: vector.interleave %{{.*}}, %{{.*}} : vector<4xf32>
- %0 = vector.interleave %a, %b : vector<4xf32>
+ %0 = vector.interleave %a, %b : vector<4xf32> -> vector<8xf32>
return %0 : vector<8xf32>
}
// CHECK-LABEL: @interleave_1d_scalable
func.func @interleave_1d_scalable(%a: vector<[8]xi16>, %b: vector<[8]xi16>) -> vector<[16]xi16> {
// CHECK: vector.interleave %{{.*}}, %{{.*}} : vector<[8]xi16>
- %0 = vector.interleave %a, %b : vector<[8]xi16>
+ %0 = vector.interleave %a, %b : vector<[8]xi16> -> vector<[16]xi16>
return %0 : vector<[16]xi16>
}
// CHECK-LABEL: @interleave_2d
func.func @interleave_2d(%a: vector<2x8xf32>, %b: vector<2x8xf32>) -> vector<2x16xf32> {
// CHECK: vector.interleave %{{.*}}, %{{.*}} : vector<2x8xf32>
- %0 = vector.interleave %a, %b : vector<2x8xf32>
+ %0 = vector.interleave %a, %b : vector<2x8xf32> -> vector<2x16xf32>
return %0 : vector<2x16xf32>
}
// CHECK-LABEL: @interleave_2d_scalable
func.func @interleave_2d_scalable(%a: vector<2x[2]xf64>, %b: vector<2x[2]xf64>) -> vector<2x[4]xf64> {
// CHECK: vector.interleave %{{.*}}, %{{.*}} : vector<2x[2]xf64>
- %0 = vector.interleave %a, %b : vector<2x[2]xf64>
+ %0 = vector.interleave %a, %b : vector<2x[2]xf64> -> vector<2x[4]xf64>
return %0 : vector<2x[4]xf64>
}
diff --git a/mlir/test/Dialect/Vector/vector-interleave-lowering-transforms.mlir b/mlir/test/Dialect/Vector/vector-interleave-lowering-transforms.mlir
index 3dd4857860eb1..598f7d70b4f1b 100644
--- a/mlir/test/Dialect/Vector/vector-interleave-lowering-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-interleave-lowering-transforms.mlir
@@ -2,8 +2,7 @@
// CHECK-LABEL: @vector_interleave_2d
// CHECK-SAME: %[[LHS:.*]]: vector<2x3xi8>, %[[RHS:.*]]: vector<2x3xi8>)
-func.func @vector_interleave_2d(%a: vector<2x3xi8>, %b: vector<2x3xi8>) -> vector<2x6xi8>
-{
+func.func @vector_interleave_2d(%a: vector<2x3xi8>, %b: vector<2x3xi8>) -> vector<2x6xi8> {
// CHECK-DAG: %[[CST:.*]] = arith.constant dense<0>
// CHECK-DAG: %[[LHS_0:.*]] = vector.extract %[[LHS]][0]
// CHECK-DAG: %[[RHS_0:.*]] = vector.extract %[[RHS]][0]
@@ -14,14 +13,13 @@ func.func @vector_interleave_2d(%a: vector<2x3xi8>, %b: vector<2x3xi8>) -> vecto
// CHECK-DAG: %[[RES_0:.*]] = vector.insert %[[ZIP_0]], %[[CST]] [0]
// CHECK-DAG: %[[RES_1:.*]] = vector.insert %[[ZIP_1]], %[[RES_0]] [1]
// CHECK-NEXT: return %[[RES_1]] : vector<2x6xi8>
- %0 = vector.interleave %a, %b : vector<2x3xi8>
+ %0 = vector.interleave %a, %b : vector<2x3xi8> -> vector<2x6xi8>
return %0 : vector<2x6xi8>
}
// CHECK-LABEL: @vector_interleave_2d_scalable
// CHECK-SAME: %[[LHS:.*]]: vector<2x[8]xi16>, %[[RHS:.*]]: vector<2x[8]xi16>)
-func.func @vector_interleave_2d_scalable(%a: vector<2x[8]xi16>, %b: vector<2x[8]xi16>) -> vector<2x[16]xi16>
-{
+func.func @vector_interleave_2d_scalable(%a: vector<2x[8]xi16>, %b: vector<2x[8]xi16>) -> vector<2x[16]xi16> {
// CHECK-DAG: %[[CST:.*]] = arith.constant dense<0>
// CHECK-DAG: %[[LHS_0:.*]] = vector.extract %[[LHS]][0]
// CHECK-DAG: %[[RHS_0:.*]] = vector.extract %[[RHS]][0]
@@ -32,7 +30,7 @@ func.func @vector_interleave_2d_scalable(%a: vector<2x[8]xi16>, %b: vector<2x[8]
// CHECK-DAG: %[[RES_0:.*]] = vector.insert %[[ZIP_0]], %[[CST]] [0]
// CHECK-DAG: %[[RES_1:.*]] = vector.insert %[[ZIP_1]], %[[RES_0]] [1]
// CHECK-NEXT: return %[[RES_1]] : vector<2x[16]xi16>
- %0 = vector.interleave %a, %b : vector<2x[8]xi16>
+ %0 = vector.interleave %a, %b : vector<2x[8]xi16> -> vector<2x[16]xi16>
return %0 : vector<2x[16]xi16>
}
@@ -44,17 +42,17 @@ func.func @vector_interleave_4d(%a: vector<1x2x3x4xi64>, %b: vector<1x2x3x4xi64>
// CHECK: %[[RHS_0:.*]] = vector.extract %[[RHS]][0, 0, 0] : vector<4xi64> from vector<1x2x3x4xi64>
// CHECK: %[[ZIP_0:.*]] = vector.interleave %[[LHS_0]], %[[RHS_0]] : vector<4xi64>
// CHECK: %[[RES_0:.*]] = vector.insert %[[ZIP_0]], %{{.*}} [0, 0, 0] : vector<8xi64> into vector<1x2x3x8xi64>
- // CHECK-COUNT-5: vector.interleave %{{.*}}, %{{.*}} : vector<4xi64>
- %0 = vector.interleave %a, %b : vector<1x2x3x4xi64>
+ // CHECK-COUNT-5: vector.interleave %{{.*}}, %{{.*}} : vector<4xi64> -> vector<8xi64>
+ %0 = vector.interleave %a, %b : vector<1x2x3x4xi64> -> vector<1x2x3x8xi64>
return %0 : vector<1x2x3x8xi64>
}
// CHECK-LABEL: @vector_interleave_nd_with_scalable_dim
-func.func @vector_interleave_nd_with_scalable_dim(%a: vector<1x3x[2]x2x3x4xf16>, %b: vector<1x3x[2]x2x3x4xf16>) -> vector<1x3x[2]x2x3x8xf16>
-{
+func.func @vector_interleave_nd_with_scalable_dim(
+ %a: vector<1x3x[2]x2x3x4xf16>, %b: vector<1x3x[2]x2x3x4xf16>) -> vector<1x3x[2]x2x3x8xf16> {
// The scalable dim blocks unrolling so only the first two dims are unrolled.
// CHECK-COUNT-3: vector.interleave %{{.*}}, %{{.*}} : vector<[2]x2x3x4xf16>
- %0 = vector.interleave %a, %b : vector<1x3x[2]x2x3x4xf16>
+ %0 = vector.interleave %a, %b : vector<1x3x[2]x2x3x4xf16> -> vector<1x3x[2]x2x3x8xf16>
return %0 : vector<1x3x[2]x2x3x8xf16>
}
diff --git a/mlir/test/Dialect/Vector/vector-interleave-to-shuffle.mlir b/mlir/test/Dialect/Vector/vector-interleave-to-shuffle.mlir
index ed3b3396bf3ea..d59cd4e6765ba 100644
--- a/mlir/test/Dialect/Vector/vector-interleave-to-shuffle.mlir
+++ b/mlir/test/Dialect/Vector/vector-interleave-to-shuffle.mlir
@@ -1,9 +1,8 @@
// RUN: mlir-opt %s --transform-interpreter | FileCheck %s
// CHECK-LABEL: @vector_interleave_to_shuffle
-func.func @vector_interleave_to_shuffle(%a: vector<7xi16>, %b: vector<7xi16>) -> vector<14xi16>
-{
- %0 = vector.interleave %a, %b : vector<7xi16>
+func.func @vector_interleave_to_shuffle(%a: vector<7xi16>, %b: vector<7xi16>) -> vector<14xi16> {
+ %0 = vector.interleave %a, %b : vector<7xi16> -> vector<14xi16>
return %0 : vector<14xi16>
}
// CHECK: vector.shuffle %arg0, %arg1 [0, 7, 1, 8, 2, 9, 3, 10, 4, 11, 5, 12, 6, 13] : vector<7xi16>, vector<7xi16>
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/test-scalable-interleave.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/test-scalable-interleave.mlir
index 07989bd71f501..e9f1bbeafacdd 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/test-scalable-interleave.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/test-scalable-interleave.mlir
@@ -17,7 +17,7 @@ func.func @entry() {
// CHECK: ( 1, 1, 1, 1
// CHECK: ( 2, 2, 2, 2
- %v3 = vector.interleave %v1, %v2 : vector<[4]xf32>
+ %v3 = vector.interleave %v1, %v2 : vector<[4]xf32> -> vector<[8]xf32>
vector.print %v3 : vector<[8]xf32>
// CHECK: ( 1, 2, 1, 2, 1, 2, 1, 2
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/test-interleave.mlir b/mlir/test/Integration/Dialect/Vector/CPU/test-interleave.mlir
index 0bc78af6aba03..d6962cbe2776a 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/test-interleave.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/test-interleave.mlir
@@ -16,7 +16,7 @@ func.func @entry() {
// CHECK: ( ( 1, 1, 1, 1 ), ( 1, 1, 1, 1 ) )
// CHECK: ( ( 2, 2, 2, 2 ), ( 2, 2, 2, 2 ) )
- %v3 = vector.interleave %v1, %v2 : vector<2x4xf32>
+ %v3 = vector.interleave %v1, %v2 : vector<2x4xf32> -> vector<2x8xf32>
vector.print %v3 : vector<2x8xf32>
// CHECK: ( ( 1, 2, 1, 2, 1, 2, 1, 2 ), ( 1, 2, 1, 2, 1, 2, 1, 2 ) )
>From 8ee982880bf550fbe78e8cff7df22bd195325256 Mon Sep 17 00:00:00 2001
From: Jakub Kuderski <jakub at nod-labs.com>
Date: Mon, 27 May 2024 10:59:29 -0400
Subject: [PATCH 2/2] Improve comments
---
mlir/include/mlir/Dialect/Vector/IR/VectorOps.td | 3 ++-
1 file changed, 2 insertions(+), 1 deletion(-)
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index e043320b56411..56d866ac5b40c 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -516,9 +516,10 @@ def Vector_InterleaveOp :
```mlir
%a = arith.constant dense<[0, 1]> : vector<2xi32>
%b = arith.constant dense<[2, 3]> : vector<2xi32>
- %0 = vector.interleave %a, %b : vector<2xi32> -> vector<4xi32>
// The value of `%0` is `[0, 2, 1, 3]`.
+ %0 = vector.interleave %a, %b : vector<2xi32> -> vector<4xi32>
+ // Examples showing allowed input and result types.
%1 = vector.interleave %c, %d : vector<f16> -> vector<2xf16>
%2 = vector.interleave %e, %f : vector<6x3xf32> -> vector<6x6xf32>
%3 = vector.interleave %g, %h : vector<[4]xi32> -> vector<[8]xi32>
More information about the Mlir-commits
mailing list