[Mlir-commits] [mlir] [mlir][spirv] Implement SPIR-V lowering for `vector.deinterleave` (PR #95313)

Jakub Kuderski llvmlistbot at llvm.org
Wed Jun 12 14:30:04 PDT 2024


================
@@ -618,6 +618,74 @@ struct VectorInterleaveOpConvert final
   }
 };
 
+struct VectorDeinterleaveOpConvert final
+    : public OpConversionPattern<vector::DeinterleaveOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::DeinterleaveOp deinterleaveOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+
+    // Check the result vector type.
+    VectorType oldResultType = deinterleaveOp.getResultVectorType();
+    Type newResultType = getTypeConverter()->convertType(oldResultType);
+    if (!newResultType)
+      return rewriter.notifyMatchFailure(deinterleaveOp,
+                                         "unsupported result vector type");
+
+    // Get location.
+    Location loc = deinterleaveOp->getLoc();
+
+    // Deinterleave the indices.
+    VectorType sourceType = deinterleaveOp.getSourceVectorType();
+    int n = sourceType.getNumElements();
+
+    // Output vectors of size 1 are converted to scalars by the type converter.
+    // We cannot use `spirv::VectorShuffleOp` directly in this case, and need to
+    // use `spirv::CompositeExtractOp`.
+    if (n == 2) {
+      spirv::CompositeExtractOp compositeExtractZero =
+          rewriter.create<spirv::CompositeExtractOp>(
+              loc, newResultType, adaptor.getSource(),
+              rewriter.getI32ArrayAttr({0}));
+
+      spirv::CompositeExtractOp compositeExtractOne =
+          rewriter.create<spirv::CompositeExtractOp>(
+              loc, newResultType, adaptor.getSource(),
+              rewriter.getI32ArrayAttr({1}));
+
+      rewriter.replaceOp(deinterleaveOp,
+                         {compositeExtractZero, compositeExtractOne});
+      return success();
+    }
+
+    // Indices for `res1`.
+    auto seqEven = llvm::seq<int64_t>(n / 2);
+    auto indicesEven =
+        llvm::map_to_vector(seqEven, [](int i) { return i * 2; });
+
+    // Indices for `res2`.
+    auto seqOdd = llvm::seq<int64_t>(n / 2);
+    auto indicesOdd =
+        llvm::map_to_vector(seqOdd, [](int i) { return i * 2 + 1; });
+
+    // Create two SPIR-V shuffles.
+    spirv::VectorShuffleOp shuffleEven =
+        rewriter.create<spirv::VectorShuffleOp>(
+            loc, newResultType, adaptor.getSource(), adaptor.getSource(),
+            rewriter.getI32ArrayAttr(indicesEven));
+
+    spirv::VectorShuffleOp shuffleOdd = rewriter.create<spirv::VectorShuffleOp>(
+        loc, newResultType, adaptor.getSource(), adaptor.getSource(),
+        rewriter.getI32ArrayAttr(indicesOdd));
+
+    // Replace deinterleaveOp with SPIR-V shuffles.
+    rewriter.replaceOp(deinterleaveOp, {shuffleEven, shuffleOdd});
+
----------------
kuhar wrote:

```suggestion
```

https://github.com/llvm/llvm-project/pull/95313


More information about the Mlir-commits mailing list