[Mlir-commits] [mlir] [mlir][spirv] Implement SPIR-V lowering for `vector.deinterleave` (PR #95313)
Angel Zhang
llvmlistbot at llvm.org
Thu Jun 13 13:22:12 PDT 2024
https://github.com/angelz913 updated https://github.com/llvm/llvm-project/pull/95313
>From 6bd95761385ae10e7acaacc0e13ea5891ab8762d Mon Sep 17 00:00:00 2001
From: Angel Zhang <angel.zhang at amd.com>
Date: Wed, 12 Jun 2024 20:50:31 +0000
Subject: [PATCH 01/12] [mlir][spirv] Implement SPIR-V lowering for
vector.deinterleave
1. Added a conversion for vector.deinterleave to the VectorToSPIRV pass.
2. Added LIT tests for the new conversion.
---
.../VectorToSPIRV/VectorToSPIRV.cpp | 74 ++++++++++++++++++-
.../VectorToSPIRV/vector-to-spirv.mlir | 50 +++++++++++++
2 files changed, 121 insertions(+), 3 deletions(-)
diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index 92168cfa36147..b9a086cfc91a4 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -618,6 +618,74 @@ struct VectorInterleaveOpConvert final
}
};
+struct VectorDeinterleaveOpConvert final
+ : public OpConversionPattern<vector::DeinterleaveOp> {
+ using OpConversionPattern::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(vector::DeinterleaveOp deinterleaveOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+
+ // Bail out if the type converter cannot map the result vector type.
+ VectorType oldResultType = deinterleaveOp.getResultVectorType();
+ Type newResultType = getTypeConverter()->convertType(oldResultType);
+ if (!newResultType)
+ return rewriter.notifyMatchFailure(deinterleaveOp,
+ "unsupported result vector type");
+
+ // Get location.
+ Location loc = deinterleaveOp->getLoc();
+
+ // Deinterleave the indices.
+ VectorType sourceType = deinterleaveOp.getSourceVectorType();
+ int n = sourceType.getNumElements();
+
+ // Output vectors of size 1 are converted to scalars by the type converter.
+ // We cannot use `spirv::VectorShuffleOp` directly in this case, and need to
+ // use `spirv::CompositeExtractOp`.
+ if (n == 2) {
+ spirv::CompositeExtractOp compositeExtractZero =
+ rewriter.create<spirv::CompositeExtractOp>(
+ loc, newResultType, adaptor.getSource(),
+ rewriter.getI32ArrayAttr({0}));
+
+ spirv::CompositeExtractOp compositeExtractOne =
+ rewriter.create<spirv::CompositeExtractOp>(
+ loc, newResultType, adaptor.getSource(),
+ rewriter.getI32ArrayAttr({1}));
+
+ rewriter.replaceOp(deinterleaveOp,
+ {compositeExtractZero, compositeExtractOne});
+ return success();
+ }
+
+ // Indices for `res1`.
+ auto seqEven = llvm::seq<int64_t>(n / 2);
+ auto indicesEven =
+ llvm::map_to_vector(seqEven, [](int i) { return i * 2; });
+
+ // Indices for `res2`.
+ auto seqOdd = llvm::seq<int64_t>(n / 2);
+ auto indicesOdd =
+ llvm::map_to_vector(seqOdd, [](int i) { return i * 2 + 1; });
+
+ // Create two SPIR-V shuffles.
+ spirv::VectorShuffleOp shuffleEven =
+ rewriter.create<spirv::VectorShuffleOp>(
+ loc, newResultType, adaptor.getSource(), adaptor.getSource(),
+ rewriter.getI32ArrayAttr(indicesEven));
+
+ spirv::VectorShuffleOp shuffleOdd = rewriter.create<spirv::VectorShuffleOp>(
+ loc, newResultType, adaptor.getSource(), adaptor.getSource(),
+ rewriter.getI32ArrayAttr(indicesOdd));
+
+ // Replace deinterleaveOp with SPIR-V shuffles.
+ rewriter.replaceOp(deinterleaveOp, {shuffleEven, shuffleOdd});
+
+ return success();
+ }
+};
+
struct VectorLoadOpConverter final
: public OpConversionPattern<vector::LoadOp> {
using OpConversionPattern::OpConversionPattern;
@@ -862,9 +930,9 @@ void mlir::populateVectorToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
VectorReductionFloatMinMax<CL_FLOAT_MAX_MIN_OPS>,
VectorReductionFloatMinMax<GL_FLOAT_MAX_MIN_OPS>, VectorShapeCast,
VectorInsertStridedSliceOpConvert, VectorShuffleOpConvert,
- VectorInterleaveOpConvert, VectorSplatPattern, VectorLoadOpConverter,
- VectorStoreOpConverter>(typeConverter, patterns.getContext(),
- PatternBenefit(1));
+ VectorInterleaveOpConvert, VectorDeinterleaveOpConvert,
+ VectorSplatPattern, VectorLoadOpConverter, VectorStoreOpConverter>(
+ typeConverter, patterns.getContext(), PatternBenefit(1));
// Make sure that the more specialized dot product pattern has higher benefit
// than the generic one that extracts all elements.
diff --git a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
index 2592d0fc04111..87823ab9afc0f 100644
--- a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
+++ b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
@@ -507,6 +507,56 @@ func.func @interleave_size1(%a: vector<1xf32>, %b: vector<1xf32>) -> vector<2xf3
// -----
+// CHECK-LABEL: func @deinterleave_return0
+// CHECK-SAME: (%[[ARG0:.+]]: vector<4xf32>)
+// CHECK: %[[SHUFFLE0:.*]] = spirv.VectorShuffle [0 : i32, 2 : i32] %[[ARG0]], %[[ARG0]] : vector<4xf32>, vector<4xf32> -> vector<2xf32>
+// CHECK: %[[SHUFFLE1:.*]] = spirv.VectorShuffle [1 : i32, 3 : i32] %[[ARG0]], %[[ARG0]] : vector<4xf32>, vector<4xf32> -> vector<2xf32>
+// CHECK: return %[[SHUFFLE0]]
+func.func @deinterleave_return0(%a: vector<4xf32>) -> vector<2xf32> {
+ %0, %1 = vector.deinterleave %a : vector<4xf32> -> vector<2xf32>
+ return %0 : vector<2xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @deinterleave_return1
+// CHECK-SAME: (%[[ARG0:.+]]: vector<4xf32>)
+// CHECK: %[[SHUFFLE0:.*]] = spirv.VectorShuffle [0 : i32, 2 : i32] %[[ARG0]], %[[ARG0]] : vector<4xf32>, vector<4xf32> -> vector<2xf32>
+// CHECK: %[[SHUFFLE1:.*]] = spirv.VectorShuffle [1 : i32, 3 : i32] %[[ARG0]], %[[ARG0]] : vector<4xf32>, vector<4xf32> -> vector<2xf32>
+// CHECK: return %[[SHUFFLE1]]
+func.func @deinterleave_return1(%a: vector<4xf32>) -> vector<2xf32> {
+ %0, %1 = vector.deinterleave %a : vector<4xf32> -> vector<2xf32>
+ return %1 : vector<2xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @deinterleave_scalar_return0
+// CHECK-SAME: (%[[ARG0:.+]]: vector<2xf32>)
+// CHECK: %[[EXTRACT0:.*]] = spirv.CompositeExtract %[[ARG0]][0 : i32] : vector<2xf32>
+// CHECK: %[[EXTRACT1:.*]] = spirv.CompositeExtract %[[ARG0]][1 : i32] : vector<2xf32>
+// CHECK: %[[RES:.*]] = builtin.unrealized_conversion_cast %[[EXTRACT0]] : f32 to vector<1xf32>
+// CHECK: return %[[RES]]
+func.func @deinterleave_scalar_return0(%a: vector<2xf32>) -> vector<1xf32> {
+ %0, %1 = vector.deinterleave %a: vector<2xf32> -> vector<1xf32>
+ return %0 : vector<1xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @deinterleave_scalar_return1
+// CHECK-SAME: (%[[ARG0:.+]]: vector<2xf32>)
+// CHECK: %[[EXTRACT0:.*]] = spirv.CompositeExtract %[[ARG0]][0 : i32] : vector<2xf32>
+// CHECK: %[[EXTRACT1:.*]] = spirv.CompositeExtract %[[ARG0]][1 : i32] : vector<2xf32>
+// CHECK: %[[RES:.*]] = builtin.unrealized_conversion_cast %[[EXTRACT1]] : f32 to vector<1xf32>
+// CHECK: return %[[RES]]
+func.func @deinterleave_scalar_return1(%a: vector<2xf32>) -> vector<1xf32> {
+ %0, %1 = vector.deinterleave %a: vector<2xf32> -> vector<1xf32>
+ return %1 : vector<1xf32>
+}
+
+// -----
+
// CHECK-LABEL: func @reduction_add
// CHECK-SAME: (%[[V:.+]]: vector<4xi32>)
// CHECK: %[[S0:.+]] = spirv.CompositeExtract %[[V]][0 : i32] : vector<4xi32>
>From 9de11f1be9a183e9559363e4225cb23e61768c29 Mon Sep 17 00:00:00 2001
From: Angel Zhang <anzhouzhang913 at gmail.com>
Date: Wed, 12 Jun 2024 17:30:36 -0400
Subject: [PATCH 02/12] Remove comment
Co-authored-by: Jakub Kuderski <kubakuderski at gmail.com>
---
mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp | 1 -
1 file changed, 1 deletion(-)
diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index b9a086cfc91a4..0c0fc473b2190 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -633,7 +633,6 @@ struct VectorDeinterleaveOpConvert final
return rewriter.notifyMatchFailure(deinterleaveOp,
"unsupported result vector type");
- // Get location.
Location loc = deinterleaveOp->getLoc();
// Deinterleave the indices.
>From f88cdf670432c67158e71287de8fe0ab35799efb Mon Sep 17 00:00:00 2001
From: Angel Zhang <anzhouzhang913 at gmail.com>
Date: Wed, 12 Jun 2024 17:31:55 -0400
Subject: [PATCH 03/12] Fix naming style
Co-authored-by: Jakub Kuderski <kubakuderski at gmail.com>
---
mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index 0c0fc473b2190..3cdd4ee524946 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -648,7 +648,7 @@ struct VectorDeinterleaveOpConvert final
loc, newResultType, adaptor.getSource(),
rewriter.getI32ArrayAttr({0}));
- spirv::CompositeExtractOp compositeExtractOne =
+ auto elem1 =
rewriter.create<spirv::CompositeExtractOp>(
loc, newResultType, adaptor.getSource(),
rewriter.getI32ArrayAttr({1}));
>From 444aea26d7aba7d0c02296e550c974000a0c3ac6 Mon Sep 17 00:00:00 2001
From: Angel Zhang <anzhouzhang913 at gmail.com>
Date: Wed, 12 Jun 2024 17:32:05 -0400
Subject: [PATCH 04/12] Fix naming style
Co-authored-by: Jakub Kuderski <kubakuderski at gmail.com>
---
mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp | 3 +--
1 file changed, 1 insertion(+), 2 deletions(-)
diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index 3cdd4ee524946..06eba43df3556 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -653,8 +653,7 @@ struct VectorDeinterleaveOpConvert final
loc, newResultType, adaptor.getSource(),
rewriter.getI32ArrayAttr({1}));
- rewriter.replaceOp(deinterleaveOp,
- {compositeExtractZero, compositeExtractOne});
+ rewriter.replaceOp(deinterleaveOp, {elem0, elem1});
return success();
}
>From c6142bdd44f947b7f52252daf6c82c14179d8a05 Mon Sep 17 00:00:00 2001
From: Angel Zhang <anzhouzhang913 at gmail.com>
Date: Wed, 12 Jun 2024 17:32:37 -0400
Subject: [PATCH 05/12] Fix naming style
Co-authored-by: Jakub Kuderski <kubakuderski at gmail.com>
---
mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index 06eba43df3556..e32af3abe1568 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -643,7 +643,7 @@ struct VectorDeinterleaveOpConvert final
// We cannot use `spirv::VectorShuffleOp` directly in this case, and need to
// use `spirv::CompositeExtractOp`.
if (n == 2) {
- spirv::CompositeExtractOp compositeExtractZero =
+ auto elem0 =
rewriter.create<spirv::CompositeExtractOp>(
loc, newResultType, adaptor.getSource(),
rewriter.getI32ArrayAttr({0}));
>From b6c33100f4478312f0a28051d716b41a7b985b4a Mon Sep 17 00:00:00 2001
From: Angel Zhang <anzhouzhang913 at gmail.com>
Date: Wed, 12 Jun 2024 17:32:55 -0400
Subject: [PATCH 06/12] Remove empty line
Co-authored-by: Jakub Kuderski <kubakuderski at gmail.com>
---
mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp | 1 -
1 file changed, 1 deletion(-)
diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index e32af3abe1568..0b3aad349c8a4 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -679,7 +679,6 @@ struct VectorDeinterleaveOpConvert final
// Replace deinterleaveOp with SPIR-V shuffles.
rewriter.replaceOp(deinterleaveOp, {shuffleEven, shuffleOdd});
-
return success();
}
};
>From 389eaccc33825000bb1267622834d2f1248b6a7b Mon Sep 17 00:00:00 2001
From: Angel Zhang <anzhouzhang913 at gmail.com>
Date: Wed, 12 Jun 2024 17:43:07 -0400
Subject: [PATCH 07/12] Remove comment
Co-authored-by: Jakub Kuderski <kubakuderski at gmail.com>
---
mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp | 1 -
1 file changed, 1 deletion(-)
diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index 0b3aad349c8a4..7f5274047df2f 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -677,7 +677,6 @@ struct VectorDeinterleaveOpConvert final
loc, newResultType, adaptor.getSource(), adaptor.getSource(),
rewriter.getI32ArrayAttr(indicesOdd));
- // Replace deinterleaveOp with SPIR-V shuffles.
rewriter.replaceOp(deinterleaveOp, {shuffleEven, shuffleOdd});
return success();
}
>From e72114a7cb666a424d07fc957f0a37b3862280ed Mon Sep 17 00:00:00 2001
From: Angel Zhang <angel.zhang at amd.com>
Date: Wed, 12 Jun 2024 21:46:55 +0000
Subject: [PATCH 08/12] Hoist adaptor source into a local and reflow calls
---
.../Conversion/VectorToSPIRV/VectorToSPIRV.cpp | 17 +++++++----------
1 file changed, 7 insertions(+), 10 deletions(-)
diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index 7f5274047df2f..aaf2ce39e0052 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -636,6 +636,7 @@ struct VectorDeinterleaveOpConvert final
Location loc = deinterleaveOp->getLoc();
// Deinterleave the indices.
+ Value sourceVector = adaptor.getSource();
VectorType sourceType = deinterleaveOp.getSourceVectorType();
int n = sourceType.getNumElements();
@@ -643,15 +644,11 @@ struct VectorDeinterleaveOpConvert final
// We cannot use `spirv::VectorShuffleOp` directly in this case, and need to
// use `spirv::CompositeExtractOp`.
if (n == 2) {
- auto elem0 =
- rewriter.create<spirv::CompositeExtractOp>(
- loc, newResultType, adaptor.getSource(),
- rewriter.getI32ArrayAttr({0}));
+ auto elem0 = rewriter.create<spirv::CompositeExtractOp>(
+ loc, newResultType, sourceVector, rewriter.getI32ArrayAttr({0}));
- auto elem1 =
- rewriter.create<spirv::CompositeExtractOp>(
- loc, newResultType, adaptor.getSource(),
- rewriter.getI32ArrayAttr({1}));
+ auto elem1 = rewriter.create<spirv::CompositeExtractOp>(
+ loc, newResultType, sourceVector, rewriter.getI32ArrayAttr({1}));
rewriter.replaceOp(deinterleaveOp, {elem0, elem1});
return success();
@@ -670,11 +667,11 @@ struct VectorDeinterleaveOpConvert final
// Create two SPIR-V shuffles.
spirv::VectorShuffleOp shuffleEven =
rewriter.create<spirv::VectorShuffleOp>(
- loc, newResultType, adaptor.getSource(), adaptor.getSource(),
+ loc, newResultType, sourceVector, sourceVector,
rewriter.getI32ArrayAttr(indicesEven));
spirv::VectorShuffleOp shuffleOdd = rewriter.create<spirv::VectorShuffleOp>(
- loc, newResultType, adaptor.getSource(), adaptor.getSource(),
+ loc, newResultType, sourceVector, sourceVector,
rewriter.getI32ArrayAttr(indicesOdd));
rewriter.replaceOp(deinterleaveOp, {shuffleEven, shuffleOdd});
>From b7a00c20d35a13c59c950e2b4888ae2bf0d39e67 Mon Sep 17 00:00:00 2001
From: Angel Zhang <angel.zhang at amd.com>
Date: Thu, 13 Jun 2024 12:54:34 +0000
Subject: [PATCH 09/12] Use auto for VectorShuffleOp results
---
mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp | 9 ++++-----
1 file changed, 4 insertions(+), 5 deletions(-)
diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index aaf2ce39e0052..b785df3f755e8 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -665,12 +665,11 @@ struct VectorDeinterleaveOpConvert final
llvm::map_to_vector(seqOdd, [](int i) { return i * 2 + 1; });
// Create two SPIR-V shuffles.
- spirv::VectorShuffleOp shuffleEven =
- rewriter.create<spirv::VectorShuffleOp>(
- loc, newResultType, sourceVector, sourceVector,
- rewriter.getI32ArrayAttr(indicesEven));
+ auto shuffleEven = rewriter.create<spirv::VectorShuffleOp>(
+ loc, newResultType, sourceVector, sourceVector,
+ rewriter.getI32ArrayAttr(indicesEven));
- spirv::VectorShuffleOp shuffleOdd = rewriter.create<spirv::VectorShuffleOp>(
+ auto shuffleOdd = rewriter.create<spirv::VectorShuffleOp>(
loc, newResultType, sourceVector, sourceVector,
rewriter.getI32ArrayAttr(indicesOdd));
>From 9a3396c3084e492c6325ebeb1ffa9ed64628ec90 Mon Sep 17 00:00:00 2001
From: Angel Zhang <angel.zhang at amd.com>
Date: Thu, 13 Jun 2024 13:18:46 +0000
Subject: [PATCH 10/12] Merge duplicate deinterleave tests
---
.../VectorToSPIRV/vector-to-spirv.mlir | 44 +++++--------------
1 file changed, 10 insertions(+), 34 deletions(-)
diff --git a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
index 87823ab9afc0f..6c6a9a1d0c6c5 100644
--- a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
+++ b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
@@ -507,52 +507,28 @@ func.func @interleave_size1(%a: vector<1xf32>, %b: vector<1xf32>) -> vector<2xf3
// -----
-// CHECK-LABEL: func @deinterleave_return0
+// CHECK-LABEL: func @deinterleave
// CHECK-SAME: (%[[ARG0:.+]]: vector<4xf32>)
// CHECK: %[[SHUFFLE0:.*]] = spirv.VectorShuffle [0 : i32, 2 : i32] %[[ARG0]], %[[ARG0]] : vector<4xf32>, vector<4xf32> -> vector<2xf32>
// CHECK: %[[SHUFFLE1:.*]] = spirv.VectorShuffle [1 : i32, 3 : i32] %[[ARG0]], %[[ARG0]] : vector<4xf32>, vector<4xf32> -> vector<2xf32>
-// CHECK: return %[[SHUFFLE0]]
-func.func @deinterleave_return0(%a: vector<4xf32>) -> vector<2xf32> {
+// CHECK: return %[[SHUFFLE0]], %[[SHUFFLE1]]
+func.func @deinterleave(%a: vector<4xf32>) -> (vector<2xf32>, vector<2xf32>) {
%0, %1 = vector.deinterleave %a : vector<4xf32> -> vector<2xf32>
- return %0 : vector<2xf32>
+ return %0, %1 : vector<2xf32>, vector<2xf32>
}
// -----
-// CHECK-LABEL: func @deinterleave_return1
-// CHECK-SAME: (%[[ARG0:.+]]: vector<4xf32>)
-// CHECK: %[[SHUFFLE0:.*]] = spirv.VectorShuffle [0 : i32, 2 : i32] %[[ARG0]], %[[ARG0]] : vector<4xf32>, vector<4xf32> -> vector<2xf32>
-// CHECK: %[[SHUFFLE1:.*]] = spirv.VectorShuffle [1 : i32, 3 : i32] %[[ARG0]], %[[ARG0]] : vector<4xf32>, vector<4xf32> -> vector<2xf32>
-// CHECK: return %[[SHUFFLE1]]
-func.func @deinterleave_return1(%a: vector<4xf32>) -> vector<2xf32> {
- %0, %1 = vector.deinterleave %a : vector<4xf32> -> vector<2xf32>
- return %1 : vector<2xf32>
-}
-
-// -----
-
-// CHECK-LABEL: func @deinterleave_scalar_return0
+// CHECK-LABEL: func @deinterleave_scalar
// CHECK-SAME: (%[[ARG0:.+]]: vector<2xf32>)
// CHECK: %[[EXTRACT0:.*]] = spirv.CompositeExtract %[[ARG0]][0 : i32] : vector<2xf32>
// CHECK: %[[EXTRACT1:.*]] = spirv.CompositeExtract %[[ARG0]][1 : i32] : vector<2xf32>
-// CHECK: %[[RES:.*]] = builtin.unrealized_conversion_cast %[[EXTRACT0]] : f32 to vector<1xf32>
-// CHECK: return %[[RES]]
-func.func @deinterleave_scalar_return0(%a: vector<2xf32>) -> vector<1xf32> {
+// CHECK: %[[CAST0:.*]] = builtin.unrealized_conversion_cast %[[EXTRACT0]] : f32 to vector<1xf32>
+// CHECK: %[[CAST1:.*]] = builtin.unrealized_conversion_cast %[[EXTRACT1]] : f32 to vector<1xf32>
+// CHECK: return %[[CAST0]], %[[CAST1]]
+func.func @deinterleave_scalar(%a: vector<2xf32>) -> (vector<1xf32>, vector<1xf32>) {
%0, %1 = vector.deinterleave %a: vector<2xf32> -> vector<1xf32>
- return %0 : vector<1xf32>
-}
-
-// -----
-
-// CHECK-LABEL: func @deinterleave_scalar_return1
-// CHECK-SAME: (%[[ARG0:.+]]: vector<2xf32>)
-// CHECK: %[[EXTRACT0:.*]] = spirv.CompositeExtract %[[ARG0]][0 : i32] : vector<2xf32>
-// CHECK: %[[EXTRACT1:.*]] = spirv.CompositeExtract %[[ARG0]][1 : i32] : vector<2xf32>
-// CHECK: %[[RES:.*]] = builtin.unrealized_conversion_cast %[[EXTRACT1]] : f32 to vector<1xf32>
-// CHECK: return %[[RES]]
-func.func @deinterleave_scalar_return1(%a: vector<2xf32>) -> vector<1xf32> {
- %0, %1 = vector.deinterleave %a: vector<2xf32> -> vector<1xf32>
- return %1 : vector<1xf32>
+ return %0, %1 : vector<1xf32>, vector<1xf32>
}
// -----
>From ff169d5ab725ec31b5a193b304c5e9ca4de99a87 Mon Sep 17 00:00:00 2001
From: Angel Zhang <anzhouzhang913 at gmail.com>
Date: Thu, 13 Jun 2024 10:34:34 -0400
Subject: [PATCH 11/12] Change comment
Co-authored-by: Jakub Kuderski <kubakuderski at gmail.com>
---
mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index b785df3f755e8..dfa117fe86cfb 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -654,7 +654,7 @@ struct VectorDeinterleaveOpConvert final
return success();
}
- // Indices for `res1`.
+ // Indices for `shuffleEven` (result 0).
auto seqEven = llvm::seq<int64_t>(n / 2);
auto indicesEven =
llvm::map_to_vector(seqEven, [](int i) { return i * 2; });
>From 0335c7a90ea9aa247c05a1a33e48fff12980411c Mon Sep 17 00:00:00 2001
From: Angel Zhang <anzhouzhang913 at gmail.com>
Date: Thu, 13 Jun 2024 10:34:45 -0400
Subject: [PATCH 12/12] Change comment
Co-authored-by: Jakub Kuderski <kubakuderski at gmail.com>
---
mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index dfa117fe86cfb..8baa31a235950 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -659,7 +659,7 @@ struct VectorDeinterleaveOpConvert final
auto indicesEven =
llvm::map_to_vector(seqEven, [](int i) { return i * 2; });
- // Indices for `res2`.
+ // Indices for `shuffleOdd` (result 1).
auto seqOdd = llvm::seq<int64_t>(n / 2);
auto indicesOdd =
llvm::map_to_vector(seqOdd, [](int i) { return i * 2 + 1; });
More information about the Mlir-commits
mailing list