[Mlir-commits] [mlir] [mlir][spirv] Add integration test for `vector.deinterleave` (PR #95465)

Angel Zhang llvmlistbot at llvm.org
Thu Jun 13 13:15:57 PDT 2024


https://github.com/angelz913 updated https://github.com/llvm/llvm-project/pull/95465

>From d19732f9707d590d9d40d73c382d697a2c834da7 Mon Sep 17 00:00:00 2001
From: Angel Zhang <angel.zhang at amd.com>
Date: Wed, 12 Jun 2024 20:50:31 +0000
Subject: [PATCH 01/13] [mlir][spirv] Implement SPIR-V lowering for
 vector.deinterleave

1. Added a conversion for vector.deinterleave to the VectorToSPIRV pass.
2. Added LIT tests for the new conversion.
---
 .../VectorToSPIRV/VectorToSPIRV.cpp           | 74 ++++++++++++++++++-
 .../VectorToSPIRV/vector-to-spirv.mlir        | 50 +++++++++++++
 2 files changed, 121 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index 92168cfa36147..b9a086cfc91a4 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -618,6 +618,74 @@ struct VectorInterleaveOpConvert final
   }
 };
 
+struct VectorDeinterleaveOpConvert final
+    : public OpConversionPattern<vector::DeinterleaveOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::DeinterleaveOp deinterleaveOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // Split the source into even- and odd-indexed elements (results 0 and 1).
+    // Bail out if the type converter cannot convert the result vector type.
+    VectorType oldResultType = deinterleaveOp.getResultVectorType();
+    Type newResultType = getTypeConverter()->convertType(oldResultType);
+    if (!newResultType)
+      return rewriter.notifyMatchFailure(deinterleaveOp,
+                                         "unsupported result vector type");
+
+    // Get location.
+    Location loc = deinterleaveOp->getLoc();
+
+    // Deinterleave the indices.
+    VectorType sourceType = deinterleaveOp.getSourceVectorType();
+    int n = sourceType.getNumElements();
+
+    // Output vectors of size 1 are converted to scalars by the type converter.
+    // We cannot use `spirv::VectorShuffleOp` directly in this case, and need to
+    // use `spirv::CompositeExtractOp`.
+    if (n == 2) {
+      spirv::CompositeExtractOp compositeExtractZero =
+          rewriter.create<spirv::CompositeExtractOp>(
+              loc, newResultType, adaptor.getSource(),
+              rewriter.getI32ArrayAttr({0}));
+
+      spirv::CompositeExtractOp compositeExtractOne =
+          rewriter.create<spirv::CompositeExtractOp>(
+              loc, newResultType, adaptor.getSource(),
+              rewriter.getI32ArrayAttr({1}));
+
+      rewriter.replaceOp(deinterleaveOp,
+                         {compositeExtractZero, compositeExtractOne});
+      return success();
+    }
+
+    // Indices for `res1`.
+    auto seqEven = llvm::seq<int64_t>(n / 2);
+    auto indicesEven =
+        llvm::map_to_vector(seqEven, [](int i) { return i * 2; });
+
+    // Indices for `res2`.
+    auto seqOdd = llvm::seq<int64_t>(n / 2);
+    auto indicesOdd =
+        llvm::map_to_vector(seqOdd, [](int i) { return i * 2 + 1; });
+
+    // Create two SPIR-V shuffles.
+    spirv::VectorShuffleOp shuffleEven =
+        rewriter.create<spirv::VectorShuffleOp>(
+            loc, newResultType, adaptor.getSource(), adaptor.getSource(),
+            rewriter.getI32ArrayAttr(indicesEven));
+
+    spirv::VectorShuffleOp shuffleOdd = rewriter.create<spirv::VectorShuffleOp>(
+        loc, newResultType, adaptor.getSource(), adaptor.getSource(),
+        rewriter.getI32ArrayAttr(indicesOdd));
+
+    // Replace deinterleaveOp with SPIR-V shuffles.
+    rewriter.replaceOp(deinterleaveOp, {shuffleEven, shuffleOdd});
+
+    return success();
+  }
+};
+
 struct VectorLoadOpConverter final
     : public OpConversionPattern<vector::LoadOp> {
   using OpConversionPattern::OpConversionPattern;
@@ -862,9 +930,9 @@ void mlir::populateVectorToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
       VectorReductionFloatMinMax<CL_FLOAT_MAX_MIN_OPS>,
       VectorReductionFloatMinMax<GL_FLOAT_MAX_MIN_OPS>, VectorShapeCast,
       VectorInsertStridedSliceOpConvert, VectorShuffleOpConvert,
-      VectorInterleaveOpConvert, VectorSplatPattern, VectorLoadOpConverter,
-      VectorStoreOpConverter>(typeConverter, patterns.getContext(),
-                              PatternBenefit(1));
+      VectorInterleaveOpConvert, VectorDeinterleaveOpConvert,
+      VectorSplatPattern, VectorLoadOpConverter, VectorStoreOpConverter>(
+      typeConverter, patterns.getContext(), PatternBenefit(1));
 
   // Make sure that the more specialized dot product pattern has higher benefit
   // than the generic one that extracts all elements.
diff --git a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
index 2592d0fc04111..87823ab9afc0f 100644
--- a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
+++ b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
@@ -507,6 +507,56 @@ func.func @interleave_size1(%a: vector<1xf32>, %b: vector<1xf32>) -> vector<2xf3
 
 // -----
 
+// CHECK-LABEL: func @deinterleave_return0
+// CHECK-SAME: (%[[ARG0:.+]]: vector<4xf32>)
+//       CHECK: %[[SHUFFLE0:.*]] = spirv.VectorShuffle [0 : i32, 2 : i32] %[[ARG0]], %[[ARG0]] : vector<4xf32>, vector<4xf32> -> vector<2xf32>
+//       CHECK: %[[SHUFFLE1:.*]] = spirv.VectorShuffle [1 : i32, 3 : i32] %[[ARG0]], %[[ARG0]] : vector<4xf32>, vector<4xf32> -> vector<2xf32>
+//       CHECK: return %[[SHUFFLE0]]
+func.func @deinterleave_return0(%a: vector<4xf32>) -> vector<2xf32> {
+  %0, %1 = vector.deinterleave %a : vector<4xf32> -> vector<2xf32>
+  return %0 : vector<2xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @deinterleave_return1
+// CHECK-SAME: (%[[ARG0:.+]]: vector<4xf32>)
+//       CHECK: %[[SHUFFLE0:.*]] = spirv.VectorShuffle [0 : i32, 2 : i32] %[[ARG0]], %[[ARG0]] : vector<4xf32>, vector<4xf32> -> vector<2xf32>
+//       CHECK: %[[SHUFFLE1:.*]] = spirv.VectorShuffle [1 : i32, 3 : i32] %[[ARG0]], %[[ARG0]] : vector<4xf32>, vector<4xf32> -> vector<2xf32>
+//       CHECK: return %[[SHUFFLE1]]
+func.func @deinterleave_return1(%a: vector<4xf32>) -> vector<2xf32> {
+  %0, %1 = vector.deinterleave %a : vector<4xf32> -> vector<2xf32>
+  return %1 : vector<2xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @deinterleave_scalar_return0
+// CHECK-SAME: (%[[ARG0:.+]]: vector<2xf32>)
+//       CHECK: %[[EXTRACT0:.*]] = spirv.CompositeExtract %[[ARG0]][0 : i32] : vector<2xf32>
+//       CHECK: %[[EXTRACT1:.*]] = spirv.CompositeExtract %[[ARG0]][1 : i32] : vector<2xf32>
+//       CHECK: %[[RES:.*]] = builtin.unrealized_conversion_cast %[[EXTRACT0]] : f32 to vector<1xf32>
+//       CHECK: return %[[RES]]
+func.func @deinterleave_scalar_return0(%a: vector<2xf32>) -> vector<1xf32> {
+  %0, %1 = vector.deinterleave %a: vector<2xf32> -> vector<1xf32>
+  return %0 : vector<1xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @deinterleave_scalar_return1
+// CHECK-SAME: (%[[ARG0:.+]]: vector<2xf32>)
+//       CHECK: %[[EXTRACT0:.*]] = spirv.CompositeExtract %[[ARG0]][0 : i32] : vector<2xf32>
+//       CHECK: %[[EXTRACT1:.*]] = spirv.CompositeExtract %[[ARG0]][1 : i32] : vector<2xf32>
+//       CHECK: %[[RES:.*]] = builtin.unrealized_conversion_cast %[[EXTRACT1]] : f32 to vector<1xf32>
+//       CHECK: return %[[RES]]
+func.func @deinterleave_scalar_return1(%a: vector<2xf32>) -> vector<1xf32> {
+  %0, %1 = vector.deinterleave %a: vector<2xf32> -> vector<1xf32>
+  return %1 : vector<1xf32>
+}
+
+// -----
+
 // CHECK-LABEL: func @reduction_add
 //  CHECK-SAME: (%[[V:.+]]: vector<4xi32>)
 //       CHECK:   %[[S0:.+]] = spirv.CompositeExtract %[[V]][0 : i32] : vector<4xi32>

>From 2695c00f7f6fad263a1915d0ecff9cd02eeeb24a Mon Sep 17 00:00:00 2001
From: Angel Zhang <anzhouzhang913 at gmail.com>
Date: Wed, 12 Jun 2024 17:30:36 -0400
Subject: [PATCH 02/13] Remove comment

Co-authored-by: Jakub Kuderski <kubakuderski at gmail.com>
---
 mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp | 1 -
 1 file changed, 1 deletion(-)

diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index b9a086cfc91a4..0c0fc473b2190 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -633,7 +633,6 @@ struct VectorDeinterleaveOpConvert final
       return rewriter.notifyMatchFailure(deinterleaveOp,
                                          "unsupported result vector type");
 
-    // Get location.
     Location loc = deinterleaveOp->getLoc();
 
     // Deinterleave the indices.

>From c58482613ee3a0e6d0ae3c2170b7127d0aa081ca Mon Sep 17 00:00:00 2001
From: Angel Zhang <anzhouzhang913 at gmail.com>
Date: Wed, 12 Jun 2024 17:31:55 -0400
Subject: [PATCH 03/13] Fix naming style

Co-authored-by: Jakub Kuderski <kubakuderski at gmail.com>
---
 mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index 0c0fc473b2190..3cdd4ee524946 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -648,7 +648,7 @@ struct VectorDeinterleaveOpConvert final
               loc, newResultType, adaptor.getSource(),
               rewriter.getI32ArrayAttr({0}));
 
-      spirv::CompositeExtractOp compositeExtractOne =
+      auto elem1 =
           rewriter.create<spirv::CompositeExtractOp>(
               loc, newResultType, adaptor.getSource(),
               rewriter.getI32ArrayAttr({1}));

>From 40a61455488a7815eae79872d766305b60bcc941 Mon Sep 17 00:00:00 2001
From: Angel Zhang <anzhouzhang913 at gmail.com>
Date: Wed, 12 Jun 2024 17:32:05 -0400
Subject: [PATCH 04/13] Fix naming style

Co-authored-by: Jakub Kuderski <kubakuderski at gmail.com>
---
 mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index 3cdd4ee524946..06eba43df3556 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -653,8 +653,7 @@ struct VectorDeinterleaveOpConvert final
               loc, newResultType, adaptor.getSource(),
               rewriter.getI32ArrayAttr({1}));
 
-      rewriter.replaceOp(deinterleaveOp,
-                         {compositeExtractZero, compositeExtractOne});
+      rewriter.replaceOp(deinterleaveOp, {elem0, elem1});
       return success();
     }
 

>From c15cd47a926fbf4c6610933b3e3f8d588a43fd6f Mon Sep 17 00:00:00 2001
From: Angel Zhang <anzhouzhang913 at gmail.com>
Date: Wed, 12 Jun 2024 17:32:37 -0400
Subject: [PATCH 05/13] Fix naming style

Co-authored-by: Jakub Kuderski <kubakuderski at gmail.com>
---
 mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index 06eba43df3556..e32af3abe1568 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -643,7 +643,7 @@ struct VectorDeinterleaveOpConvert final
     // We cannot use `spirv::VectorShuffleOp` directly in this case, and need to
     // use `spirv::CompositeExtractOp`.
     if (n == 2) {
-      spirv::CompositeExtractOp compositeExtractZero =
+      auto elem0 =
           rewriter.create<spirv::CompositeExtractOp>(
               loc, newResultType, adaptor.getSource(),
               rewriter.getI32ArrayAttr({0}));

>From 5607048e6d1ddfb94a2ab41e4362346efe40ccb6 Mon Sep 17 00:00:00 2001
From: Angel Zhang <anzhouzhang913 at gmail.com>
Date: Wed, 12 Jun 2024 17:32:55 -0400
Subject: [PATCH 06/13] Remove empty line

Co-authored-by: Jakub Kuderski <kubakuderski at gmail.com>
---
 mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp | 1 -
 1 file changed, 1 deletion(-)

diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index e32af3abe1568..0b3aad349c8a4 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -679,7 +679,6 @@ struct VectorDeinterleaveOpConvert final
 
     // Replace deinterleaveOp with SPIR-V shuffles.
     rewriter.replaceOp(deinterleaveOp, {shuffleEven, shuffleOdd});
-
     return success();
   }
 };

>From cb16983ab686d24600f65fa3f058a0c2b5cc90d6 Mon Sep 17 00:00:00 2001
From: Angel Zhang <anzhouzhang913 at gmail.com>
Date: Wed, 12 Jun 2024 17:43:07 -0400
Subject: [PATCH 07/13] Remove comment

Co-authored-by: Jakub Kuderski <kubakuderski at gmail.com>
---
 mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp | 1 -
 1 file changed, 1 deletion(-)

diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index 0b3aad349c8a4..7f5274047df2f 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -677,7 +677,6 @@ struct VectorDeinterleaveOpConvert final
         loc, newResultType, adaptor.getSource(), adaptor.getSource(),
         rewriter.getI32ArrayAttr(indicesOdd));
 
-    // Replace deinterleaveOp with SPIR-V shuffles.
     rewriter.replaceOp(deinterleaveOp, {shuffleEven, shuffleOdd});
     return success();
   }

>From 9c701622a5cc93524104eb8024258601d0575443 Mon Sep 17 00:00:00 2001
From: Angel Zhang <angel.zhang at amd.com>
Date: Wed, 12 Jun 2024 21:46:55 +0000
Subject: [PATCH 08/13] Refactor code: reuse a single `sourceVector` value

---
 .../Conversion/VectorToSPIRV/VectorToSPIRV.cpp  | 17 +++++++----------
 1 file changed, 7 insertions(+), 10 deletions(-)

diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index 7f5274047df2f..aaf2ce39e0052 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -636,6 +636,7 @@ struct VectorDeinterleaveOpConvert final
     Location loc = deinterleaveOp->getLoc();
 
     // Deinterleave the indices.
+    Value sourceVector = adaptor.getSource();
     VectorType sourceType = deinterleaveOp.getSourceVectorType();
     int n = sourceType.getNumElements();
 
@@ -643,15 +644,11 @@ struct VectorDeinterleaveOpConvert final
     // We cannot use `spirv::VectorShuffleOp` directly in this case, and need to
     // use `spirv::CompositeExtractOp`.
     if (n == 2) {
-      auto elem0 =
-          rewriter.create<spirv::CompositeExtractOp>(
-              loc, newResultType, adaptor.getSource(),
-              rewriter.getI32ArrayAttr({0}));
+      auto elem0 = rewriter.create<spirv::CompositeExtractOp>(
+          loc, newResultType, sourceVector, rewriter.getI32ArrayAttr({0}));
 
-      auto elem1 =
-          rewriter.create<spirv::CompositeExtractOp>(
-              loc, newResultType, adaptor.getSource(),
-              rewriter.getI32ArrayAttr({1}));
+      auto elem1 = rewriter.create<spirv::CompositeExtractOp>(
+          loc, newResultType, sourceVector, rewriter.getI32ArrayAttr({1}));
 
       rewriter.replaceOp(deinterleaveOp, {elem0, elem1});
       return success();
@@ -670,11 +667,11 @@ struct VectorDeinterleaveOpConvert final
     // Create two SPIR-V shuffles.
     spirv::VectorShuffleOp shuffleEven =
         rewriter.create<spirv::VectorShuffleOp>(
-            loc, newResultType, adaptor.getSource(), adaptor.getSource(),
+            loc, newResultType, sourceVector, sourceVector,
             rewriter.getI32ArrayAttr(indicesEven));
 
     spirv::VectorShuffleOp shuffleOdd = rewriter.create<spirv::VectorShuffleOp>(
-        loc, newResultType, adaptor.getSource(), adaptor.getSource(),
+        loc, newResultType, sourceVector, sourceVector,
         rewriter.getI32ArrayAttr(indicesOdd));
 
     rewriter.replaceOp(deinterleaveOp, {shuffleEven, shuffleOdd});

>From deeee41c1871e2706a682d0de0cc17b4fba4d272 Mon Sep 17 00:00:00 2001
From: Angel Zhang <angel.zhang at amd.com>
Date: Thu, 13 Jun 2024 12:54:34 +0000
Subject: [PATCH 09/13] Fix style: declare shuffle ops with `auto`

---
 mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp | 9 ++++-----
 1 file changed, 4 insertions(+), 5 deletions(-)

diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index aaf2ce39e0052..b785df3f755e8 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -665,12 +665,11 @@ struct VectorDeinterleaveOpConvert final
         llvm::map_to_vector(seqOdd, [](int i) { return i * 2 + 1; });
 
     // Create two SPIR-V shuffles.
-    spirv::VectorShuffleOp shuffleEven =
-        rewriter.create<spirv::VectorShuffleOp>(
-            loc, newResultType, sourceVector, sourceVector,
-            rewriter.getI32ArrayAttr(indicesEven));
+    auto shuffleEven = rewriter.create<spirv::VectorShuffleOp>(
+        loc, newResultType, sourceVector, sourceVector,
+        rewriter.getI32ArrayAttr(indicesEven));
 
-    spirv::VectorShuffleOp shuffleOdd = rewriter.create<spirv::VectorShuffleOp>(
+    auto shuffleOdd = rewriter.create<spirv::VectorShuffleOp>(
         loc, newResultType, sourceVector, sourceVector,
         rewriter.getI32ArrayAttr(indicesOdd));
 

>From 9b22dd2a3cd50b944ea752f57789ce09c9f2f6e9 Mon Sep 17 00:00:00 2001
From: Angel Zhang <angel.zhang at amd.com>
Date: Thu, 13 Jun 2024 13:18:46 +0000
Subject: [PATCH 10/13] Remove duplicate tests

---
 .../VectorToSPIRV/vector-to-spirv.mlir        | 44 +++++--------------
 1 file changed, 10 insertions(+), 34 deletions(-)

diff --git a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
index 87823ab9afc0f..6c6a9a1d0c6c5 100644
--- a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
+++ b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
@@ -507,52 +507,28 @@ func.func @interleave_size1(%a: vector<1xf32>, %b: vector<1xf32>) -> vector<2xf3
 
 // -----
 
-// CHECK-LABEL: func @deinterleave_return0
+// CHECK-LABEL: func @deinterleave
 // CHECK-SAME: (%[[ARG0:.+]]: vector<4xf32>)
 //       CHECK: %[[SHUFFLE0:.*]] = spirv.VectorShuffle [0 : i32, 2 : i32] %[[ARG0]], %[[ARG0]] : vector<4xf32>, vector<4xf32> -> vector<2xf32>
 //       CHECK: %[[SHUFFLE1:.*]] = spirv.VectorShuffle [1 : i32, 3 : i32] %[[ARG0]], %[[ARG0]] : vector<4xf32>, vector<4xf32> -> vector<2xf32>
-//       CHECK: return %[[SHUFFLE0]]
-func.func @deinterleave_return0(%a: vector<4xf32>) -> vector<2xf32> {
+//       CHECK: return %[[SHUFFLE0]], %[[SHUFFLE1]]
+func.func @deinterleave(%a: vector<4xf32>) -> (vector<2xf32>, vector<2xf32>) {
   %0, %1 = vector.deinterleave %a : vector<4xf32> -> vector<2xf32>
-  return %0 : vector<2xf32>
+  return %0, %1 : vector<2xf32>, vector<2xf32>
 }
 
 // -----
 
-// CHECK-LABEL: func @deinterleave_return1
-// CHECK-SAME: (%[[ARG0:.+]]: vector<4xf32>)
-//       CHECK: %[[SHUFFLE0:.*]] = spirv.VectorShuffle [0 : i32, 2 : i32] %[[ARG0]], %[[ARG0]] : vector<4xf32>, vector<4xf32> -> vector<2xf32>
-//       CHECK: %[[SHUFFLE1:.*]] = spirv.VectorShuffle [1 : i32, 3 : i32] %[[ARG0]], %[[ARG0]] : vector<4xf32>, vector<4xf32> -> vector<2xf32>
-//       CHECK: return %[[SHUFFLE1]]
-func.func @deinterleave_return1(%a: vector<4xf32>) -> vector<2xf32> {
-  %0, %1 = vector.deinterleave %a : vector<4xf32> -> vector<2xf32>
-  return %1 : vector<2xf32>
-}
-
-// -----
-
-// CHECK-LABEL: func @deinterleave_scalar_return0
+// CHECK-LABEL: func @deinterleave_scalar
 // CHECK-SAME: (%[[ARG0:.+]]: vector<2xf32>)
 //       CHECK: %[[EXTRACT0:.*]] = spirv.CompositeExtract %[[ARG0]][0 : i32] : vector<2xf32>
 //       CHECK: %[[EXTRACT1:.*]] = spirv.CompositeExtract %[[ARG0]][1 : i32] : vector<2xf32>
-//       CHECK: %[[RES:.*]] = builtin.unrealized_conversion_cast %[[EXTRACT0]] : f32 to vector<1xf32>
-//       CHECK: return %[[RES]]
-func.func @deinterleave_scalar_return0(%a: vector<2xf32>) -> vector<1xf32> {
+//       CHECK: %[[CAST0:.*]] = builtin.unrealized_conversion_cast %[[EXTRACT0]] : f32 to vector<1xf32>
+//       CHECK: %[[CAST1:.*]] = builtin.unrealized_conversion_cast %[[EXTRACT1]] : f32 to vector<1xf32>
+//       CHECK: return %[[CAST0]], %[[CAST1]]
+func.func @deinterleave_scalar(%a: vector<2xf32>) -> (vector<1xf32>, vector<1xf32>) {
   %0, %1 = vector.deinterleave %a: vector<2xf32> -> vector<1xf32>
-  return %0 : vector<1xf32>
-}
-
-// -----
-
-// CHECK-LABEL: func @deinterleave_scalar_return1
-// CHECK-SAME: (%[[ARG0:.+]]: vector<2xf32>)
-//       CHECK: %[[EXTRACT0:.*]] = spirv.CompositeExtract %[[ARG0]][0 : i32] : vector<2xf32>
-//       CHECK: %[[EXTRACT1:.*]] = spirv.CompositeExtract %[[ARG0]][1 : i32] : vector<2xf32>
-//       CHECK: %[[RES:.*]] = builtin.unrealized_conversion_cast %[[EXTRACT1]] : f32 to vector<1xf32>
-//       CHECK: return %[[RES]]
-func.func @deinterleave_scalar_return1(%a: vector<2xf32>) -> vector<1xf32> {
-  %0, %1 = vector.deinterleave %a: vector<2xf32> -> vector<1xf32>
-  return %1 : vector<1xf32>
+  return %0, %1 : vector<1xf32>, vector<1xf32>
 }
 
 // -----

>From e54d9ebf175993fe0b7416afdd85e26acbcd27f3 Mon Sep 17 00:00:00 2001
From: Angel Zhang <anzhouzhang913 at gmail.com>
Date: Thu, 13 Jun 2024 10:34:34 -0400
Subject: [PATCH 11/13] Change comment

Co-authored-by: Jakub Kuderski <kubakuderski at gmail.com>
---
 mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index b785df3f755e8..dfa117fe86cfb 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -654,7 +654,7 @@ struct VectorDeinterleaveOpConvert final
       return success();
     }
 
-    // Indices for `res1`.
+    // Indices for `shuffleEven` (result 0).
     auto seqEven = llvm::seq<int64_t>(n / 2);
     auto indicesEven =
         llvm::map_to_vector(seqEven, [](int i) { return i * 2; });

>From d11a74909fa20c66e6b87fa5299fb76c2f10bd66 Mon Sep 17 00:00:00 2001
From: Angel Zhang <anzhouzhang913 at gmail.com>
Date: Thu, 13 Jun 2024 10:34:45 -0400
Subject: [PATCH 12/13] Change comment

Co-authored-by: Jakub Kuderski <kubakuderski at gmail.com>
---
 mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index dfa117fe86cfb..8baa31a235950 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -659,7 +659,7 @@ struct VectorDeinterleaveOpConvert final
     auto indicesEven =
         llvm::map_to_vector(seqEven, [](int i) { return i * 2; });
 
-    // Indices for `res2`.
+    // Indices for `shuffleOdd` (result 1).
     auto seqOdd = llvm::seq<int64_t>(n / 2);
     auto indicesOdd =
         llvm::map_to_vector(seqOdd, [](int i) { return i * 2 + 1; });

>From db7e425e7c95df071169a1256c4969a08c81b572 Mon Sep 17 00:00:00 2001
From: Angel Zhang <angel.zhang at amd.com>
Date: Thu, 13 Jun 2024 20:02:14 +0000
Subject: [PATCH 13/13] [mlir][spirv] Add integration test for
 vector.deinterleave

---
 .../vector-deinterleave.mlir                  | 81 +++++++++++++++++++
 1 file changed, 81 insertions(+)
 create mode 100644 mlir/test/mlir-vulkan-runner/vector-deinterleave.mlir

diff --git a/mlir/test/mlir-vulkan-runner/vector-deinterleave.mlir b/mlir/test/mlir-vulkan-runner/vector-deinterleave.mlir
new file mode 100644
index 0000000000000..36987ee952ec9
--- /dev/null
+++ b/mlir/test/mlir-vulkan-runner/vector-deinterleave.mlir
@@ -0,0 +1,81 @@
+// RUN: mlir-vulkan-runner %s \
+// RUN:  --shared-libs=%vulkan-runtime-wrappers,%mlir_runner_utils \
+// RUN:  --entry-point-result=void | FileCheck %s
+
+// CHECK: [0, 2]
+// CHECK: [1, 3]
+module attributes {
+  gpu.container_module,
+  spirv.target_env = #spirv.target_env<
+    #spirv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]>, #spirv.resource_limits<>>
+} {
+  gpu.module @kernels {
+    gpu.func @kernel_vector_deinterleave(%arg0 : memref<4xi32>, %arg1 : memref<2xi32>, %arg2 : memref<2xi32>)
+      kernel attributes { spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [1, 1, 1]>} {
+      %idx0 = arith.constant 0 : index
+      %idx1 = arith.constant 1 : index
+      %idx2 = arith.constant 2 : index
+      %idx3 = arith.constant 3 : index
+
+      // Gather the four input elements into a vector<4xi32>.
+      %src = arith.constant dense<[0, 0, 0, 0]> : vector<4xi32>
+
+      %val0 = memref.load %arg0[%idx0] : memref<4xi32>
+      %val1 = memref.load %arg0[%idx1] : memref<4xi32>
+      %val2 = memref.load %arg0[%idx2] : memref<4xi32>
+      %val3 = memref.load %arg0[%idx3] : memref<4xi32>
+
+      %src0 = vector.insertelement %val0, %src[%idx0 : index] : vector<4xi32>
+      %src1 = vector.insertelement %val1, %src0[%idx1 : index] : vector<4xi32>
+      %src2 = vector.insertelement %val2, %src1[%idx2 : index] : vector<4xi32>
+      %src3 = vector.insertelement %val3, %src2[%idx3 : index] : vector<4xi32>
+
+      %res0, %res1 = vector.deinterleave %src3 : vector<4xi32> -> vector<2xi32>
+
+      %res0_0 = vector.extractelement %res0[%idx0 : index] : vector<2xi32>
+      %res0_1 = vector.extractelement %res0[%idx1 : index] : vector<2xi32>
+      %res1_0 = vector.extractelement %res1[%idx0 : index] : vector<2xi32>
+      %res1_1 = vector.extractelement %res1[%idx1 : index] : vector<2xi32>
+
+      memref.store %res0_0, %arg1[%idx0]: memref<2xi32>
+      memref.store %res0_1, %arg1[%idx1]: memref<2xi32>
+      memref.store %res1_0, %arg2[%idx0]: memref<2xi32>
+      memref.store %res1_1, %arg2[%idx1]: memref<2xi32>
+
+      gpu.return
+    }
+  }
+
+  func.func @main() {
+    // Allocate 3 buffers.
+    %buf0 = memref.alloc() : memref<4xi32>
+    %buf1 = memref.alloc() : memref<2xi32>
+    %buf2 = memref.alloc() : memref<2xi32>
+
+    %idx0 = arith.constant 0 : index
+    %idx1 = arith.constant 1 : index
+
+    // Initialize input buffer.
+    %buf0_vals = arith.constant dense<[0, 1, 2, 3]> : vector<4xi32>
+    vector.store %buf0_vals, %buf0[%idx0] : memref<4xi32>, vector<4xi32>
+
+    // Initialize output buffers.
+    %value0 = arith.constant 0 : i32
+    %buf3 = memref.cast %buf1 : memref<2xi32> to memref<?xi32>
+    %buf4 = memref.cast %buf2 : memref<2xi32> to memref<?xi32>
+    call @fillResource1DInt(%buf3, %value0) : (memref<?xi32>, i32) -> ()
+    call @fillResource1DInt(%buf4, %value0) : (memref<?xi32>, i32) -> ()
+
+    // The kernel uses no block/thread IDs: launch one 1x1x1 workgroup so each output is written once.
+    gpu.launch_func @kernels::@kernel_vector_deinterleave
+        blocks in (%idx1, %idx1, %idx1) threads in (%idx1, %idx1, %idx1)
+        args(%buf0 : memref<4xi32>, %buf1 : memref<2xi32>, %buf2 : memref<2xi32>)
+    %buf5 = memref.cast %buf3 : memref<?xi32> to memref<*xi32>
+    %buf6 = memref.cast %buf4 : memref<?xi32> to memref<*xi32>
+    call @printMemrefI32(%buf5) : (memref<*xi32>) -> ()
+    call @printMemrefI32(%buf6) : (memref<*xi32>) -> ()
+    return
+  }
+  func.func private @fillResource1DInt(%0 : memref<?xi32>, %1 : i32)
+  func.func private @printMemrefI32(%ptr : memref<*xi32>)
+}



More information about the Mlir-commits mailing list