[Mlir-commits] [mlir] [mlir][spirv] Implement SPIR-V lowering for `vector.deinterleave` (PR #95313)
Angel Zhang
llvmlistbot at llvm.org
Thu Jun 13 05:55:24 PDT 2024
================
@@ -618,6 +618,74 @@ struct VectorInterleaveOpConvert final
}
};
+struct VectorDeinterleaveOpConvert final
+ : public OpConversionPattern<vector::DeinterleaveOp> {
+ using OpConversionPattern::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(vector::DeinterleaveOp deinterleaveOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+
+ // Check the result vector type.
+ VectorType oldResultType = deinterleaveOp.getResultVectorType();
+ Type newResultType = getTypeConverter()->convertType(oldResultType);
+ if (!newResultType)
+ return rewriter.notifyMatchFailure(deinterleaveOp,
+ "unsupported result vector type");
+
+ // Get location.
+ Location loc = deinterleaveOp->getLoc();
+
+ // Deinterleave the indices.
+ VectorType sourceType = deinterleaveOp.getSourceVectorType();
+ int n = sourceType.getNumElements();
+
+ // Output vectors of size 1 are converted to scalars by the type converter.
+ // We cannot use `spirv::VectorShuffleOp` directly in this case, and need to
+ // use `spirv::CompositeExtractOp`.
+ if (n == 2) {
+ spirv::CompositeExtractOp compositeExtractZero =
+ rewriter.create<spirv::CompositeExtractOp>(
+ loc, newResultType, adaptor.getSource(),
+ rewriter.getI32ArrayAttr({0}));
+
+ spirv::CompositeExtractOp compositeExtractOne =
+ rewriter.create<spirv::CompositeExtractOp>(
+ loc, newResultType, adaptor.getSource(),
+ rewriter.getI32ArrayAttr({1}));
+
+ rewriter.replaceOp(deinterleaveOp,
+ {compositeExtractZero, compositeExtractOne});
+ return success();
+ }
+
+ // Indices for `res1`.
+ auto seqEven = llvm::seq<int64_t>(n / 2);
+ auto indicesEven =
+ llvm::map_to_vector(seqEven, [](int i) { return i * 2; });
+
+ // Indices for `res2`.
+ auto seqOdd = llvm::seq<int64_t>(n / 2);
+ auto indicesOdd =
+ llvm::map_to_vector(seqOdd, [](int i) { return i * 2 + 1; });
+
+ // Create two SPIR-V shuffles.
+ spirv::VectorShuffleOp shuffleEven =
+ rewriter.create<spirv::VectorShuffleOp>(
+ loc, newResultType, adaptor.getSource(), adaptor.getSource(),
+ rewriter.getI32ArrayAttr(indicesEven));
+
+ spirv::VectorShuffleOp shuffleOdd = rewriter.create<spirv::VectorShuffleOp>(
+ loc, newResultType, adaptor.getSource(), adaptor.getSource(),
----------------
angelz913 wrote:
Fixed
https://github.com/llvm/llvm-project/pull/95313
More information about the Mlir-commits mailing list