[llvm] [mlir] [mlir][spirv] Add vector.interleave to spirv.VectorShuffle conversion (PR #93240)
Angel Zhang via llvm-commits
llvm-commits at lists.llvm.org
Mon May 27 10:48:40 PDT 2024
https://github.com/angelz913 updated https://github.com/llvm/llvm-project/pull/93240
>From cdc9def9d70feb5e3acaea7491c31d06039460d0 Mon Sep 17 00:00:00 2001
From: Angel Zhang <angel.zhang at amd.com>
Date: Thu, 23 May 2024 21:09:49 +0000
Subject: [PATCH 1/6] [mlir][spirv] Add vector.interleave to
spirv.VectorShuffle conversion
---
.../VectorToSPIRV/VectorToSPIRV.cpp | 44 ++++++++++++++++---
1 file changed, 39 insertions(+), 5 deletions(-)
diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index c2dd37f481466..95464ef6d438e 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -578,6 +578,42 @@ struct VectorShuffleOpConvert final
}
};
+struct VectorInterleaveOpConvert final
+ : public OpConversionPattern<vector::InterleaveOp> {
+ using OpConversionPattern::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(vector::InterleaveOp interleaveOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ // Check the source vector type
+ auto sourceType = interleaveOp.getSourceVectorType();
+ if (sourceType.getRank() != 1 || sourceType.isScalable()) {
+ return rewriter.notifyMatchFailure(interleaveOp,
+ "unsupported source vector type");
+ }
+
+ // Check the result vector type
+ auto oldResultType = interleaveOp.getResultVectorType();
+ Type newResultType = getTypeConverter()->convertType(oldResultType);
+ if (!newResultType)
+ return rewriter.notifyMatchFailure(interleaveOp,
+ "unsupported result vector type");
+
+ // Interleave the indices
+ int n = sourceType.getNumElements();
+ auto seq = llvm::seq<int64_t>(2 * n);
+ auto indices = llvm::to_vector(
+ llvm::map_range(seq, [n](int i) { return (i % 2 ? n : 0) + i / 2; }));
+
+ // Emit a SPIR-V shuffle.
+ rewriter.replaceOpWithNewOp<spirv::VectorShuffleOp>(
+ interleaveOp, newResultType, adaptor.getLhs(), adaptor.getRhs(),
+ rewriter.getI32ArrayAttr(indices));
+
+ return success();
+ }
+};
+
struct VectorLoadOpConverter final
: public OpConversionPattern<vector::LoadOp> {
using OpConversionPattern::OpConversionPattern;
@@ -822,16 +858,14 @@ void mlir::populateVectorToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
VectorReductionFloatMinMax<CL_FLOAT_MAX_MIN_OPS>,
VectorReductionFloatMinMax<GL_FLOAT_MAX_MIN_OPS>, VectorShapeCast,
VectorInsertStridedSliceOpConvert, VectorShuffleOpConvert,
- VectorSplatPattern, VectorLoadOpConverter, VectorStoreOpConverter>(
- typeConverter, patterns.getContext(), PatternBenefit(1));
+ VectorInterleaveOpConvert, VectorSplatPattern, VectorLoadOpConverter,
+ VectorStoreOpConverter>(typeConverter, patterns.getContext(),
+ PatternBenefit(1));
// Make sure that the more specialized dot product pattern has higher benefit
// than the generic one that extracts all elements.
patterns.add<VectorReductionToFPDotProd>(typeConverter, patterns.getContext(),
PatternBenefit(2));
-
- // Need this until vector.interleave is handled.
- vector::populateVectorInterleaveToShufflePatterns(patterns);
}
void mlir::populateVectorReductionToSPIRVDotProductPatterns(
>From 51b25d480e8661b124dceaa7dcb2840320bc01e7 Mon Sep 17 00:00:00 2001
From: Angel Zhang <anzhouzhang913 at gmail.com>
Date: Mon, 27 May 2024 08:33:14 -0400
Subject: [PATCH 2/6] Use VectorType for sourceType
Co-authored-by: Jakub Kuderski <kubakuderski at gmail.com>
---
mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index 95464ef6d438e..aa3670f81fea3 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -586,7 +586,7 @@ struct VectorInterleaveOpConvert final
matchAndRewrite(vector::InterleaveOp interleaveOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// Check the source vector type
- auto sourceType = interleaveOp.getSourceVectorType();
+ VectorType sourceType = interleaveOp.getSourceVectorType();
if (sourceType.getRank() != 1 || sourceType.isScalable()) {
return rewriter.notifyMatchFailure(interleaveOp,
"unsupported source vector type");
>From fedd990a69c8ba1b79c4905c29467d9e352ba96c Mon Sep 17 00:00:00 2001
From: Angel Zhang <angel.zhang at amd.com>
Date: Mon, 27 May 2024 12:42:35 +0000
Subject: [PATCH 3/6] Use VectorType for oldResultType
---
mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index aa3670f81fea3..0af0595eebe0d 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -593,7 +593,7 @@ struct VectorInterleaveOpConvert final
}
// Check the result vector type
- auto oldResultType = interleaveOp.getResultVectorType();
+ VectorType oldResultType = interleaveOp.getResultVectorType();
Type newResultType = getTypeConverter()->convertType(oldResultType);
if (!newResultType)
return rewriter.notifyMatchFailure(interleaveOp,
>From cf476130280e8f5bda6e77bcee7c1aff12070443 Mon Sep 17 00:00:00 2001
From: Angel Zhang <angel.zhang at amd.com>
Date: Mon, 27 May 2024 16:06:10 +0000
Subject: [PATCH 4/6] Handle one-element input vector case and remove
cmake/bazel dependencies
---
mlir/lib/Conversion/VectorToSPIRV/CMakeLists.txt | 1 -
.../Conversion/VectorToSPIRV/VectorToSPIRV.cpp | 15 ++++++++++++++-
.../Conversion/VectorToSPIRV/vector-to-spirv.mlir | 13 +++++++++++++
utils/bazel/llvm-project-overlay/mlir/BUILD.bazel | 1 -
4 files changed, 27 insertions(+), 3 deletions(-)
diff --git a/mlir/lib/Conversion/VectorToSPIRV/CMakeLists.txt b/mlir/lib/Conversion/VectorToSPIRV/CMakeLists.txt
index 113983146f5be..bb9f793d7fe0f 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/CMakeLists.txt
+++ b/mlir/lib/Conversion/VectorToSPIRV/CMakeLists.txt
@@ -14,6 +14,5 @@ add_mlir_conversion_library(MLIRVectorToSPIRV
MLIRSPIRVDialect
MLIRSPIRVConversion
MLIRVectorDialect
- MLIRVectorTransforms
MLIRTransforms
)
diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index 0af0595eebe0d..a63ef5ab451eb 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -601,6 +601,19 @@ struct VectorInterleaveOpConvert final
// Interleave the indices
int n = sourceType.getNumElements();
+
+ // Input vectors of size 1 are converted to scalars by the type converter.
+ // We cannot use spirv::VectorShuffleOp directly in this case, and need to
+ // use spirv::CompositeConstructOp.
+ if (n == 1) {
+ SmallVector<Value> newOperands(2);
+ newOperands[0] = adaptor.getLhs();
+ newOperands[1] = adaptor.getRhs();
+ rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(
+ interleaveOp, newResultType, newOperands);
+ return success();
+ }
+
auto seq = llvm::seq<int64_t>(2 * n);
auto indices = llvm::to_vector(
llvm::map_range(seq, [n](int i) { return (i % 2 ? n : 0) + i / 2; }));
@@ -609,7 +622,7 @@ struct VectorInterleaveOpConvert final
rewriter.replaceOpWithNewOp<spirv::VectorShuffleOp>(
interleaveOp, newResultType, adaptor.getLhs(), adaptor.getRhs(),
rewriter.getI32ArrayAttr(indices));
-
+
return success();
}
};
diff --git a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
index a7542086aa766..41e5823cb4ec7 100644
--- a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
+++ b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
@@ -494,6 +494,19 @@ func.func @interleave(%a: vector<2xf32>, %b: vector<2xf32>) -> vector<4xf32> {
// -----
+// CHECK-LABEL: func @interleave_size1
+// CHECK-SAME: (%[[ARG0:.+]]: vector<1xf32>, %[[ARG1:.+]]: vector<1xf32>)
+// CHECK: %[[V0:.*]] = builtin.unrealized_conversion_cast %[[ARG0]] : vector<1xf32> to f32
+// CHECK: %[[V1:.*]] = builtin.unrealized_conversion_cast %[[ARG1]] : vector<1xf32> to f32
+// CHECK: %[[RES:.*]] = spirv.CompositeConstruct %[[V0]], %[[V1]] : (f32, f32) -> vector<2xf32>
+// CHECK: return %[[RES]]
+func.func @interleave_size1(%a: vector<1xf32>, %b: vector<1xf32>) -> vector<2xf32> {
+ %0 = vector.interleave %a, %b : vector<1xf32>
+ return %0 : vector<2xf32>
+}
+
+// -----
+
// CHECK-LABEL: func @reduction_add
// CHECK-SAME: (%[[V:.+]]: vector<4xi32>)
// CHECK: %[[S0:.+]] = spirv.CompositeExtract %[[V]][0 : i32] : vector<4xi32>
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index 614fe511d43a5..0c5fa11bc85ec 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -4979,7 +4979,6 @@ cc_library(
":VectorToLLVM",
":VectorToSCF",
":VectorTransformOpsIncGen",
- ":VectorTransforms",
":X86VectorTransforms",
],
)
>From 78d46085ff64a7f6dad53ad5cc061c6514a9826b Mon Sep 17 00:00:00 2001
From: Angel Zhang <angel.zhang at amd.com>
Date: Mon, 27 May 2024 16:39:02 +0000
Subject: [PATCH 5/6] Remove check for source type and reformat code
---
.../lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp | 14 ++++----------
1 file changed, 4 insertions(+), 10 deletions(-)
diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index a63ef5ab451eb..69f89d087dd3c 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -585,13 +585,6 @@ struct VectorInterleaveOpConvert final
LogicalResult
matchAndRewrite(vector::InterleaveOp interleaveOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- // Check the source vector type
- VectorType sourceType = interleaveOp.getSourceVectorType();
- if (sourceType.getRank() != 1 || sourceType.isScalable()) {
- return rewriter.notifyMatchFailure(interleaveOp,
- "unsupported source vector type");
- }
-
// Check the result vector type
VectorType oldResultType = interleaveOp.getResultVectorType();
Type newResultType = getTypeConverter()->convertType(oldResultType);
@@ -600,10 +593,11 @@ struct VectorInterleaveOpConvert final
"unsupported result vector type");
// Interleave the indices
+ VectorType sourceType = interleaveOp.getSourceVectorType();
int n = sourceType.getNumElements();
- // Input vectors of size 1 are converted to scalars by the type converter.
- // We cannot use spirv::VectorShuffleOp directly in this case, and need to
+ // Input vectors of size 1 are converted to scalars by the type converter.
+ // We cannot use spirv::VectorShuffleOp directly in this case, and need to
// use spirv::CompositeConstructOp.
if (n == 1) {
SmallVector<Value> newOperands(2);
@@ -622,7 +616,7 @@ struct VectorInterleaveOpConvert final
rewriter.replaceOpWithNewOp<spirv::VectorShuffleOp>(
interleaveOp, newResultType, adaptor.getLhs(), adaptor.getRhs(),
rewriter.getI32ArrayAttr(indices));
-
+
return success();
}
};
>From 24c6d248e9cf1db12bf91fe1a33baf24de1ed9a6 Mon Sep 17 00:00:00 2001
From: Angel Zhang <anzhouzhang913 at gmail.com>
Date: Mon, 27 May 2024 13:39:33 -0400
Subject: [PATCH 6/6] Reformat code
Co-authored-by: Jakub Kuderski <kubakuderski at gmail.com>
---
mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp | 12 +++++-------
1 file changed, 5 insertions(+), 7 deletions(-)
diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index 69f89d087dd3c..043b0741729d6 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -585,24 +585,22 @@ struct VectorInterleaveOpConvert final
LogicalResult
matchAndRewrite(vector::InterleaveOp interleaveOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- // Check the result vector type
+ // Check the result vector type.
VectorType oldResultType = interleaveOp.getResultVectorType();
Type newResultType = getTypeConverter()->convertType(oldResultType);
if (!newResultType)
return rewriter.notifyMatchFailure(interleaveOp,
"unsupported result vector type");
- // Interleave the indices
+ // Interleave the indices.
VectorType sourceType = interleaveOp.getSourceVectorType();
int n = sourceType.getNumElements();
// Input vectors of size 1 are converted to scalars by the type converter.
- // We cannot use spirv::VectorShuffleOp directly in this case, and need to
- // use spirv::CompositeConstructOp.
+ // We cannot use `spirv::VectorShuffleOp` directly in this case, and need to
+ // use `spirv::CompositeConstructOp`.
if (n == 1) {
- SmallVector<Value> newOperands(2);
- newOperands[0] = adaptor.getLhs();
- newOperands[1] = adaptor.getRhs();
+ Value newOperands[] = {adaptor.getLhs(), adaptor.getRhs()};
rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(
interleaveOp, newResultType, newOperands);
return success();
More information about the llvm-commits
mailing list