[Mlir-commits] [mlir] 597cde1 - [mlir][spirv] Implement SPIR-V lowering for `vector.deinterleave` (#95313)
llvmlistbot at llvm.org
Thu Jun 13 14:16:57 PDT 2024
Author: Angel Zhang
Date: 2024-06-13T17:16:53-04:00
New Revision: 597cde155a008364c83870b24306fbae93e80cf8
URL: https://github.com/llvm/llvm-project/commit/597cde155a008364c83870b24306fbae93e80cf8
DIFF: https://github.com/llvm/llvm-project/commit/597cde155a008364c83870b24306fbae93e80cf8.diff
LOG: [mlir][spirv] Implement SPIR-V lowering for `vector.deinterleave` (#95313)
1. Added a conversion for `vector.deinterleave` to the `VectorToSPIRV` pass.
2. Added LIT tests for the new conversion.
---------
Co-authored-by: Jakub Kuderski <kubakuderski at gmail.com>
Added:
Modified:
mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index 92168cfa36147..8baa31a235950 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -618,6 +618,66 @@ struct VectorInterleaveOpConvert final
}
};
+struct VectorDeinterleaveOpConvert final
+    : public OpConversionPattern<vector::DeinterleaveOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  /// Lowers `vector.deinterleave` to SPIR-V: result 0 receives the source
+  /// elements at even positions and result 1 the elements at odd positions.
+  LogicalResult
+  matchAndRewrite(vector::DeinterleaveOp deinterleaveOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // Both results are created with this one type; bail out if it has no
+    // SPIR-V equivalent.
+    VectorType oldResultType = deinterleaveOp.getResultVectorType();
+    Type newResultType = getTypeConverter()->convertType(oldResultType);
+    if (!newResultType)
+      return rewriter.notifyMatchFailure(deinterleaveOp,
+                                         "unsupported result vector type");
+
+    Location loc = deinterleaveOp->getLoc();
+    Value sourceVector = adaptor.getSource();
+
+    // Output vectors of size 1 are converted to scalars by the type converter.
+    // `spirv::VectorShuffleOp` cannot produce a scalar, so in that case pull
+    // the two source elements out with `spirv::CompositeExtractOp`. Testing
+    // the converted type itself (rather than the source element count) keeps
+    // this branch in sync with what the type converter actually produced.
+    if (!isa<VectorType>(newResultType)) {
+      auto elem0 = rewriter.create<spirv::CompositeExtractOp>(
+          loc, newResultType, sourceVector, rewriter.getI32ArrayAttr({0}));
+      auto elem1 = rewriter.create<spirv::CompositeExtractOp>(
+          loc, newResultType, sourceVector, rewriter.getI32ArrayAttr({1}));
+      rewriter.replaceOp(deinterleaveOp, {elem0, elem1});
+      return success();
+    }
+
+    // Each result holds half of the source elements. Both shuffle operands
+    // are the source vector, so every index refers to the first operand.
+    int64_t numResultElements =
+        deinterleaveOp.getSourceVectorType().getNumElements() / 2;
+    auto resultPositions = llvm::seq<int64_t>(numResultElements);
+    // `getI32ArrayAttr` takes 32-bit indices; convert explicitly.
+    auto indicesEven = llvm::map_to_vector(resultPositions, [](int64_t i) {
+      return static_cast<int32_t>(2 * i);
+    });
+    auto indicesOdd = llvm::map_to_vector(resultPositions, [](int64_t i) {
+      return static_cast<int32_t>(2 * i + 1);
+    });
+
+    auto shuffleEven = rewriter.create<spirv::VectorShuffleOp>(
+        loc, newResultType, sourceVector, sourceVector,
+        rewriter.getI32ArrayAttr(indicesEven));
+    auto shuffleOdd = rewriter.create<spirv::VectorShuffleOp>(
+        loc, newResultType, sourceVector, sourceVector,
+        rewriter.getI32ArrayAttr(indicesOdd));
+
+    rewriter.replaceOp(deinterleaveOp, {shuffleEven, shuffleOdd});
+    return success();
+  }
+};
+
struct VectorLoadOpConverter final
: public OpConversionPattern<vector::LoadOp> {
using OpConversionPattern::OpConversionPattern;
@@ -862,9 +922,9 @@ void mlir::populateVectorToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
VectorReductionFloatMinMax<CL_FLOAT_MAX_MIN_OPS>,
VectorReductionFloatMinMax<GL_FLOAT_MAX_MIN_OPS>, VectorShapeCast,
VectorInsertStridedSliceOpConvert, VectorShuffleOpConvert,
- VectorInterleaveOpConvert, VectorSplatPattern, VectorLoadOpConverter,
- VectorStoreOpConverter>(typeConverter, patterns.getContext(),
- PatternBenefit(1));
+ VectorInterleaveOpConvert, VectorDeinterleaveOpConvert,
+ VectorSplatPattern, VectorLoadOpConverter, VectorStoreOpConverter>(
+ typeConverter, patterns.getContext(), PatternBenefit(1));
// Make sure that the more specialized dot product pattern has higher benefit
// than the generic one that extracts all elements.
diff --git a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
index 2592d0fc04111..6c6a9a1d0c6c5 100644
--- a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
+++ b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
@@ -507,6 +507,32 @@ func.func @interleave_size1(%a: vector<1xf32>, %b: vector<1xf32>) -> vector<2xf3
// -----
+// Deinterleaving a 4-element vector yields two 2-element vectors, each built
+// by shuffling the source with itself.
+// CHECK-LABEL: func @deinterleave
+// CHECK-SAME:  (%[[SRC:.+]]: vector<4xf32>)
+// CHECK:       %[[EVEN:.*]] = spirv.VectorShuffle [0 : i32, 2 : i32] %[[SRC]], %[[SRC]] : vector<4xf32>, vector<4xf32> -> vector<2xf32>
+// CHECK:       %[[ODD:.*]] = spirv.VectorShuffle [1 : i32, 3 : i32] %[[SRC]], %[[SRC]] : vector<4xf32>, vector<4xf32> -> vector<2xf32>
+// CHECK:       return %[[EVEN]], %[[ODD]]
+func.func @deinterleave(%src: vector<4xf32>) -> (vector<2xf32>, vector<2xf32>) {
+  %even, %odd = vector.deinterleave %src : vector<4xf32> -> vector<2xf32>
+  return %even, %odd : vector<2xf32>, vector<2xf32>
+}
+
+// -----
+
+// Single-element results become scalars after type conversion, so the lowering
+// extracts each source element instead of shuffling.
+// CHECK-LABEL: func @deinterleave_scalar
+// CHECK-SAME: (%[[ARG0:.+]]: vector<2xf32>)
+// CHECK: %[[EXTRACT0:.*]] = spirv.CompositeExtract %[[ARG0]][0 : i32] : vector<2xf32>
+// CHECK: %[[EXTRACT1:.*]] = spirv.CompositeExtract %[[ARG0]][1 : i32] : vector<2xf32>
+// CHECK: %[[CAST0:.*]] = builtin.unrealized_conversion_cast %[[EXTRACT0]] : f32 to vector<1xf32>
+// CHECK: %[[CAST1:.*]] = builtin.unrealized_conversion_cast %[[EXTRACT1]] : f32 to vector<1xf32>
+// CHECK: return %[[CAST0]], %[[CAST1]]
+func.func @deinterleave_scalar(%a: vector<2xf32>) -> (vector<1xf32>, vector<1xf32>) {
+  %0, %1 = vector.deinterleave %a : vector<2xf32> -> vector<1xf32>
+  return %0, %1 : vector<1xf32>, vector<1xf32>
+}
+
+// -----
+
// CHECK-LABEL: func @reduction_add
// CHECK-SAME: (%[[V:.+]]: vector<4xi32>)
// CHECK: %[[S0:.+]] = spirv.CompositeExtract %[[V]][0 : i32] : vector<4xi32>
More information about the Mlir-commits mailing list