[Mlir-commits] [mlir] [mlir][spirv] Add vector.interleave to spirv.VectorShuffle conversion (PR #93240)
Angel Zhang
llvmlistbot at llvm.org
Thu May 23 14:25:41 PDT 2024
https://github.com/angelz913 updated https://github.com/llvm/llvm-project/pull/93240
>From 98b4b90759a60b0d2b4a54f7f012ab4e5ef04cde Mon Sep 17 00:00:00 2001
From: Angel Zhang <angel.zhang at amd.com>
Date: Thu, 23 May 2024 21:09:49 +0000
Subject: [PATCH 1/2] [mlir][spirv] Add vector.interleave to
spirv.VectorShuffle conversion
---
.../VectorToSPIRV/VectorToSPIRV.cpp | 46 +++++++++++++++++--
1 file changed, 41 insertions(+), 5 deletions(-)
diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index c2dd37f481466..b86ebe1a4bb54 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -578,6 +578,44 @@ struct VectorShuffleOpConvert final
}
};
+/// Lowers a rank-1, fixed-length vector.interleave to one spirv.VectorShuffle
+/// whose indices [0, n, 1, n+1, ...] alternate between lhs and rhs elements.
+struct VectorInterleaveOpConvert final
+    : public OpConversionPattern<vector::InterleaveOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::InterleaveOp interleaveOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // Check the source vector type
+    auto sourceType = interleaveOp.getSourceVectorType();
+    if (sourceType.getRank() != 1 || sourceType.isScalable()) {
+      return rewriter.notifyMatchFailure(interleaveOp,
+                                         "unsupported source vector type");
+    }
+
+    // Check the result vector type
+    auto oldResultType = interleaveOp.getResultVectorType();
+    Type newResultType = getTypeConverter()->convertType(oldResultType);
+    if (!newResultType)
+      return rewriter.notifyMatchFailure(interleaveOp,
+                                         "unsupported result vector type");
+
+    // Interleave the indices
+    int n = sourceType.getNumElements();
+    auto seq = llvm::seq<int64_t>(2 * n);
+    auto indices = llvm::to_vector(llvm::map_range(
+        seq, [n](int i) { return (i % 2 ? n : 0) + i / 2; }));
+
+    // Emit a SPIR-V shuffle.
+    rewriter.replaceOpWithNewOp<spirv::VectorShuffleOp>(
+        interleaveOp, newResultType, adaptor.getLhs(), adaptor.getRhs(),
+        rewriter.getI32ArrayAttr(indices));
+
+    return success();
+  }
+};
+
struct VectorLoadOpConverter final
: public OpConversionPattern<vector::LoadOp> {
using OpConversionPattern::OpConversionPattern;
@@ -821,17 +859,15 @@ void mlir::populateVectorToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
VectorReductionPattern<CL_INT_MAX_MIN_OPS>,
VectorReductionFloatMinMax<CL_FLOAT_MAX_MIN_OPS>,
VectorReductionFloatMinMax<GL_FLOAT_MAX_MIN_OPS>, VectorShapeCast,
- VectorInsertStridedSliceOpConvert, VectorShuffleOpConvert,
- VectorSplatPattern, VectorLoadOpConverter, VectorStoreOpConverter>(
+ VectorInsertStridedSliceOpConvert, VectorShuffleOpConvert,
+ VectorInterleaveOpConvert, VectorSplatPattern,
+ VectorLoadOpConverter, VectorStoreOpConverter>(
typeConverter, patterns.getContext(), PatternBenefit(1));
// Make sure that the more specialized dot product pattern has higher benefit
// than the generic one that extracts all elements.
patterns.add<VectorReductionToFPDotProd>(typeConverter, patterns.getContext(),
PatternBenefit(2));
-
- // Need this until vector.interleave is handled.
- vector::populateVectorInterleaveToShufflePatterns(patterns);
}
void mlir::populateVectorReductionToSPIRVDotProductPatterns(
>From 0622fae6857f836b27d5b32e3f17d5f90c752401 Mon Sep 17 00:00:00 2001
From: Angel Zhang <angel.zhang at amd.com>
Date: Thu, 23 May 2024 21:25:28 +0000
Subject: [PATCH 2/2] Fix style
---
.../Conversion/VectorToSPIRV/VectorToSPIRV.cpp | 16 ++++++++--------
1 file changed, 8 insertions(+), 8 deletions(-)
diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index b86ebe1a4bb54..c01ff41d8b5e1 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -585,12 +585,12 @@ struct VectorInterleaveOpConvert final
LogicalResult
matchAndRewrite(vector::InterleaveOp interleaveOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- // Check the source vector type
+ // Check the source vector type
auto sourceType = interleaveOp.getSourceVectorType();
if (sourceType.getRank() != 1 || sourceType.isScalable()) {
return rewriter.notifyMatchFailure(interleaveOp,
"unsupported source vector type");
- }
+ }
// Check the result vector type
auto oldResultType = interleaveOp.getResultVectorType();
@@ -602,8 +602,8 @@ struct VectorInterleaveOpConvert final
// Interleave the indices
int n = sourceType.getNumElements();
auto seq = llvm::seq<int64_t>(2 * n);
- auto indices = llvm::to_vector(llvm::map_range(
- seq, [n](int i) { return (i % 2 ? n : 0) + i / 2; }));
+ auto indices = llvm::to_vector(
+ llvm::map_range(seq, [n](int i) { return (i % 2 ? n : 0) + i / 2; }));
// Emit a SPIR-V shuffle.
rewriter.replaceOpWithNewOp<spirv::VectorShuffleOp>(
@@ -859,10 +859,10 @@ void mlir::populateVectorToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
VectorReductionPattern<CL_INT_MAX_MIN_OPS>,
VectorReductionFloatMinMax<CL_FLOAT_MAX_MIN_OPS>,
VectorReductionFloatMinMax<GL_FLOAT_MAX_MIN_OPS>, VectorShapeCast,
- VectorInsertStridedSliceOpConvert, VectorShuffleOpConvert,
- VectorInterleaveOpConvert, VectorSplatPattern,
- VectorLoadOpConverter, VectorStoreOpConverter>(
- typeConverter, patterns.getContext(), PatternBenefit(1));
+ VectorInsertStridedSliceOpConvert, VectorShuffleOpConvert,
+ VectorInterleaveOpConvert, VectorSplatPattern, VectorLoadOpConverter,
+ VectorStoreOpConverter>(typeConverter, patterns.getContext(),
+ PatternBenefit(1));
// Make sure that the more specialized dot product pattern has higher benefit
// than the generic one that extracts all elements.
More information about the Mlir-commits
mailing list