[Mlir-commits] [mlir] [mlir][spirv] Implement SPIR-V lowering for `vector.deinterleave` (PR #95313)

Angel Zhang llvmlistbot at llvm.org
Wed Jun 12 14:30:44 PDT 2024


https://github.com/angelz913 updated https://github.com/llvm/llvm-project/pull/95313

>From 0c1b1b27725efeb62e571fc014da09ae16304d9f Mon Sep 17 00:00:00 2001
From: Angel Zhang <angel.zhang at amd.com>
Date: Wed, 12 Jun 2024 20:50:31 +0000
Subject: [PATCH 1/2] [mlir][spirv] Implement SPIR-V lowering for
 vector.deinterleave

1. Added a conversion for vector.deinterleave to the VectorToSPIRV pass.
2. Added LIT tests for the new conversion.
---
 .../VectorToSPIRV/VectorToSPIRV.cpp           | 74 ++++++++++++++++++-
 .../VectorToSPIRV/vector-to-spirv.mlir        | 50 +++++++++++++
 2 files changed, 121 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index 92168cfa36147..b9a086cfc91a4 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -618,6 +618,74 @@ struct VectorInterleaveOpConvert final
   }
 };
 
+struct VectorDeinterleaveOpConvert final
+    : public OpConversionPattern<vector::DeinterleaveOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::DeinterleaveOp deinterleaveOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // Lowers to two spirv.VectorShuffle ops (CompositeExtract for 2 elements).
+    // Convert the result vector type; fail the match if it is unsupported.
+    VectorType oldResultType = deinterleaveOp.getResultVectorType();
+    Type newResultType = getTypeConverter()->convertType(oldResultType);
+    if (!newResultType)
+      return rewriter.notifyMatchFailure(deinterleaveOp,
+                                         "unsupported result vector type");
+
+    // Get location.
+    Location loc = deinterleaveOp->getLoc();
+
+    // Deinterleave the indices.
+    VectorType sourceType = deinterleaveOp.getSourceVectorType();
+    int n = sourceType.getNumElements();
+
+    // A 2-element source yields two 1-element results, which the type converter
+    // maps to scalars. `spirv::VectorShuffleOp` cannot be used directly in that
+    // case, so elements 0 and 1 are read with `spirv::CompositeExtractOp`.
+    if (n == 2) {
+      spirv::CompositeExtractOp compositeExtractZero =
+          rewriter.create<spirv::CompositeExtractOp>(
+              loc, newResultType, adaptor.getSource(),
+              rewriter.getI32ArrayAttr({0}));
+
+      spirv::CompositeExtractOp compositeExtractOne =
+          rewriter.create<spirv::CompositeExtractOp>(
+              loc, newResultType, adaptor.getSource(),
+              rewriter.getI32ArrayAttr({1}));
+
+      rewriter.replaceOp(deinterleaveOp,
+                         {compositeExtractZero, compositeExtractOne});
+      return success();
+    }
+
+    // Even source positions (0, 2, 4, ...) select the first result (`res1`).
+    auto seqEven = llvm::seq<int64_t>(n / 2);
+    auto indicesEven =
+        llvm::map_to_vector(seqEven, [](int i) { return i * 2; });
+
+    // Odd source positions (1, 3, 5, ...) select the second result (`res2`).
+    auto seqOdd = llvm::seq<int64_t>(n / 2);
+    auto indicesOdd =
+        llvm::map_to_vector(seqOdd, [](int i) { return i * 2 + 1; });
+
+    // Emit one SPIR-V shuffle per result; the source is passed as both operands.
+    spirv::VectorShuffleOp shuffleEven =
+        rewriter.create<spirv::VectorShuffleOp>(
+            loc, newResultType, adaptor.getSource(), adaptor.getSource(),
+            rewriter.getI32ArrayAttr(indicesEven));
+
+    spirv::VectorShuffleOp shuffleOdd = rewriter.create<spirv::VectorShuffleOp>(
+        loc, newResultType, adaptor.getSource(), adaptor.getSource(),
+        rewriter.getI32ArrayAttr(indicesOdd));
+
+    // Replace both results of deinterleaveOp with the even and odd shuffles.
+    rewriter.replaceOp(deinterleaveOp, {shuffleEven, shuffleOdd});
+
+    return success();
+  }
+};
+
 struct VectorLoadOpConverter final
     : public OpConversionPattern<vector::LoadOp> {
   using OpConversionPattern::OpConversionPattern;
@@ -862,9 +930,9 @@ void mlir::populateVectorToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
       VectorReductionFloatMinMax<CL_FLOAT_MAX_MIN_OPS>,
       VectorReductionFloatMinMax<GL_FLOAT_MAX_MIN_OPS>, VectorShapeCast,
       VectorInsertStridedSliceOpConvert, VectorShuffleOpConvert,
-      VectorInterleaveOpConvert, VectorSplatPattern, VectorLoadOpConverter,
-      VectorStoreOpConverter>(typeConverter, patterns.getContext(),
-                              PatternBenefit(1));
+      VectorInterleaveOpConvert, VectorDeinterleaveOpConvert,
+      VectorSplatPattern, VectorLoadOpConverter, VectorStoreOpConverter>(
+      typeConverter, patterns.getContext(), PatternBenefit(1));
 
   // Make sure that the more specialized dot product pattern has higher benefit
   // than the generic one that extracts all elements.
diff --git a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
index 2592d0fc04111..87823ab9afc0f 100644
--- a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
+++ b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
@@ -507,6 +507,56 @@ func.func @interleave_size1(%a: vector<1xf32>, %b: vector<1xf32>) -> vector<2xf3
 
 // -----
 
+// CHECK-LABEL: func @deinterleave_return0
+// CHECK-SAME: (%[[ARG0:.+]]: vector<4xf32>)
+//       CHECK: %[[SHUFFLE0:.*]] = spirv.VectorShuffle [0 : i32, 2 : i32] %[[ARG0]], %[[ARG0]] : vector<4xf32>, vector<4xf32> -> vector<2xf32>
+//       CHECK: %[[SHUFFLE1:.*]] = spirv.VectorShuffle [1 : i32, 3 : i32] %[[ARG0]], %[[ARG0]] : vector<4xf32>, vector<4xf32> -> vector<2xf32>
+//       CHECK: return %[[SHUFFLE0]]
+func.func @deinterleave_return0(%a: vector<4xf32>) -> vector<2xf32> {
+  %0, %1 = vector.deinterleave %a : vector<4xf32> -> vector<2xf32>
+  return %0 : vector<2xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @deinterleave_return1
+// CHECK-SAME: (%[[ARG0:.+]]: vector<4xf32>)
+//       CHECK: %[[SHUFFLE0:.*]] = spirv.VectorShuffle [0 : i32, 2 : i32] %[[ARG0]], %[[ARG0]] : vector<4xf32>, vector<4xf32> -> vector<2xf32>
+//       CHECK: %[[SHUFFLE1:.*]] = spirv.VectorShuffle [1 : i32, 3 : i32] %[[ARG0]], %[[ARG0]] : vector<4xf32>, vector<4xf32> -> vector<2xf32>
+//       CHECK: return %[[SHUFFLE1]]
+func.func @deinterleave_return1(%a: vector<4xf32>) -> vector<2xf32> {
+  %0, %1 = vector.deinterleave %a : vector<4xf32> -> vector<2xf32>
+  return %1 : vector<2xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @deinterleave_scalar_return0
+// CHECK-SAME: (%[[ARG0:.+]]: vector<2xf32>)
+//       CHECK: %[[EXTRACT0:.*]] = spirv.CompositeExtract %[[ARG0]][0 : i32] : vector<2xf32>
+//       CHECK: %[[EXTRACT1:.*]] = spirv.CompositeExtract %[[ARG0]][1 : i32] : vector<2xf32>
+//       CHECK: %[[RES:.*]] = builtin.unrealized_conversion_cast %[[EXTRACT0]] : f32 to vector<1xf32>
+//       CHECK: return %[[RES]]
+func.func @deinterleave_scalar_return0(%a: vector<2xf32>) -> vector<1xf32> {
+  %0, %1 = vector.deinterleave %a: vector<2xf32> -> vector<1xf32>
+  return %0 : vector<1xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @deinterleave_scalar_return1
+// CHECK-SAME: (%[[ARG0:.+]]: vector<2xf32>)
+//       CHECK: %[[EXTRACT0:.*]] = spirv.CompositeExtract %[[ARG0]][0 : i32] : vector<2xf32>
+//       CHECK: %[[EXTRACT1:.*]] = spirv.CompositeExtract %[[ARG0]][1 : i32] : vector<2xf32>
+//       CHECK: %[[RES:.*]] = builtin.unrealized_conversion_cast %[[EXTRACT1]] : f32 to vector<1xf32>
+//       CHECK: return %[[RES]]
+func.func @deinterleave_scalar_return1(%a: vector<2xf32>) -> vector<1xf32> {
+  %0, %1 = vector.deinterleave %a: vector<2xf32> -> vector<1xf32>
+  return %1 : vector<1xf32>
+}
+
+// -----
+
 // CHECK-LABEL: func @reduction_add
 //  CHECK-SAME: (%[[V:.+]]: vector<4xi32>)
 //       CHECK:   %[[S0:.+]] = spirv.CompositeExtract %[[V]][0 : i32] : vector<4xi32>

>From 7074fac5717573e149fb1b85a5e8250919d9a009 Mon Sep 17 00:00:00 2001
From: Angel Zhang <anzhouzhang913 at gmail.com>
Date: Wed, 12 Jun 2024 17:30:36 -0400
Subject: [PATCH 2/2] Update
 mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp

Co-authored-by: Jakub Kuderski <kubakuderski at gmail.com>
---
 mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp | 1 -
 1 file changed, 1 deletion(-)

diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index b9a086cfc91a4..0c0fc473b2190 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -633,7 +633,6 @@ struct VectorDeinterleaveOpConvert final
       return rewriter.notifyMatchFailure(deinterleaveOp,
                                          "unsupported result vector type");
 
-    // Get location.
     Location loc = deinterleaveOp->getLoc();
 
     // Deinterleave the indices.



More information about the Mlir-commits mailing list