[Mlir-commits] [mlir] 36e68c1 - [mlir][spirv] Add support for converting vector.shuffle
Lei Zhang
llvmlistbot at llvm.org
Fri Feb 4 13:31:34 PST 2022
Author: Lei Zhang
Date: 2022-02-04T16:31:11-05:00
New Revision: 36e68c11ad6f20d058dfad001470bede8d9731dd
URL: https://github.com/llvm/llvm-project/commit/36e68c11ad6f20d058dfad001470bede8d9731dd
DIFF: https://github.com/llvm/llvm-project/commit/36e68c11ad6f20d058dfad001470bede8d9731dd.diff
LOG: [mlir][spirv] Add support for converting vector.shuffle
Reviewed By: ThomasRaoux
Differential Revision: https://reviews.llvm.org/D119030
Added:
Modified:
mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
mlir/test/Conversion/VectorToSPIRV/simple.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index b40052df603b9..4061f57909cac 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -18,7 +18,9 @@
#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/Transforms/DialectConversion.h"
+#include "llvm/ADT/STLExtras.h"
#include <numeric>
using namespace mlir;
@@ -264,6 +266,43 @@ class VectorSplatPattern final : public OpConversionPattern<vector::SplatOp> {
}
};
+/// Converts vector.shuffle to SPIR-V.
+///
+/// The SPIR-V type converter turns 1-element vectors into scalars, while
+/// spv.VectorShuffle requires both of its operands to be vectors. Therefore:
+///  - when both source operands have more than one element, the op maps
+///    directly onto spv.VectorShuffle with the same mask;
+///  - otherwise each selected element is extracted individually (or used
+///    as-is when its operand became a scalar) and the result is assembled
+///    with spv.CompositeConstruct.
+///
+/// The pattern does not match when the result type is not a valid SPIR-V
+/// composite (e.g. multi-dimensional vectors) or cannot be converted.
+struct VectorShuffleOpConvert final
+    : public OpConversionPattern<vector::ShuffleOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::ShuffleOp shuffleOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto oldResultType = shuffleOp.getVectorType();
+    if (!spirv::CompositeType::isValid(oldResultType))
+      return failure();
+    Type newResultType = getTypeConverter()->convertType(oldResultType);
+    if (!newResultType)
+      return failure();
+
+    // Mask entries index into the concatenation of v1 and v2.
+    SmallVector<int32_t, 4> mask = llvm::to_vector<4>(
+        llvm::map_range(shuffleOp.mask(), [](Attribute attr) -> int32_t {
+          return attr.cast<IntegerAttr>().getValue().getZExtValue();
+        }));
+
+    VectorType oldV1Type = shuffleOp.getV1VectorType();
+    VectorType oldV2Type = shuffleOp.getV2VectorType();
+
+    // spv.VectorShuffle needs vector operands. Checking only v1 here would
+    // hand it a scalar whenever v2 is a 1-element vector.
+    if (oldV1Type.getNumElements() > 1 && oldV2Type.getNumElements() > 1) {
+      rewriter.replaceOpWithNewOp<spirv::VectorShuffleOp>(
+          shuffleOp, newResultType, adaptor.v1(), adaptor.v2(),
+          rewriter.getI32ArrayAttr(mask));
+      return success();
+    }
+
+    // At least one operand was converted to a scalar. Pick every requested
+    // element from the operand it belongs to; indices past v1 refer to v2
+    // and must be rebased by v1's element count.
+    Location loc = shuffleOp.getLoc();
+    int32_t numV1Elems = oldV1Type.getNumElements();
+    SmallVector<Value, 4> newOperands;
+    newOperands.reserve(mask.size());
+    for (int32_t shuffleIdx : mask) {
+      Value source = adaptor.v1();
+      int32_t elementIdx = shuffleIdx;
+      if (elementIdx >= numV1Elems) {
+        source = adaptor.v2();
+        elementIdx -= numV1Elems;
+      }
+
+      if (source.getType().isa<VectorType>()) {
+        newOperands.push_back(rewriter.create<spirv::CompositeExtractOp>(
+            loc, source, elementIdx));
+      } else {
+        // A scalar stands in for a 1-element vector, so only index 0 exists.
+        assert(elementIdx == 0 && "invalid scalar element index");
+        newOperands.push_back(source);
+      }
+    }
+    rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(
+        shuffleOp, newResultType, newOperands);
+    return success();
+  }
+};
+
} // namespace
void mlir::populateVectorToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
@@ -272,6 +311,6 @@ void mlir::populateVectorToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
VectorExtractElementOpConvert, VectorExtractOpConvert,
VectorExtractStridedSliceOpConvert, VectorFmaOpConvert,
VectorInsertElementOpConvert, VectorInsertOpConvert,
- VectorInsertStridedSliceOpConvert, VectorSplatPattern>(
- typeConverter, patterns.getContext());
+ VectorInsertStridedSliceOpConvert, VectorShuffleOpConvert,
+ VectorSplatPattern>(typeConverter, patterns.getContext());
}
diff --git a/mlir/test/Conversion/VectorToSPIRV/simple.mlir b/mlir/test/Conversion/VectorToSPIRV/simple.mlir
index d93aa21f7dc36..949e59951a4f4 100644
--- a/mlir/test/Conversion/VectorToSPIRV/simple.mlir
+++ b/mlir/test/Conversion/VectorToSPIRV/simple.mlir
@@ -179,3 +179,34 @@ func @splat(%f : f32) -> vector<4xf32> {
%splat = vector.splat %f : vector<4xf32>
return %splat : vector<4xf32>
}
+
+// -----
+
+// Both operands are 1-element vectors, which become scalars after SPIR-V type
+// conversion, so the shuffle is lowered to spv.CompositeConstruct.
+// CHECK-LABEL: func @shuffle
+// CHECK-SAME: %[[ARG0:.+]]: vector<1xf32>, %[[ARG1:.+]]: vector<1xf32>
+// CHECK: %[[V0:.+]] = builtin.unrealized_conversion_cast %[[ARG0]]
+// CHECK: %[[V1:.+]] = builtin.unrealized_conversion_cast %[[ARG1]]
+// CHECK: spv.CompositeConstruct %[[V0]], %[[V1]], %[[V1]], %[[V0]] : vector<4xf32>
+func @shuffle(%v0 : vector<1xf32>, %v1: vector<1xf32>) -> vector<4xf32> {
+ %shuffle = vector.shuffle %v0, %v1 [0, 1, 1, 0] : vector<1xf32>, vector<1xf32>
+ return %shuffle : vector<4xf32>
+}
+
+// -----
+
+// Both operands have more than one element, so the shuffle maps directly
+// onto spv.VectorShuffle with the same mask.
+// CHECK-LABEL: func @shuffle
+// CHECK-SAME: %[[V0:.+]]: vector<3xf32>, %[[V1:.+]]: vector<3xf32>
+// CHECK: spv.VectorShuffle [3 : i32, 2 : i32, 5 : i32, 1 : i32] %[[V0]] : vector<3xf32>, %[[V1]] : vector<3xf32> -> vector<4xf32>
+func @shuffle(%v0 : vector<3xf32>, %v1: vector<3xf32>) -> vector<4xf32> {
+ %shuffle = vector.shuffle %v0, %v1 [3, 2, 5, 1] : vector<3xf32>, vector<3xf32>
+ return %shuffle : vector<4xf32>
+}
+
+// -----
+
+// Multi-dimensional vectors are not valid SPIR-V composite types, so the
+// pattern does not apply and vector.shuffle is left unconverted.
+// CHECK-LABEL: func @shuffle
+func @shuffle(%v0 : vector<2x16xf32>, %v1: vector<1x16xf32>) -> vector<3x16xf32> {
+ // CHECK: vector.shuffle
+ %shuffle = vector.shuffle %v0, %v1 [0, 1, 2] : vector<2x16xf32>, vector<1x16xf32>
+ return %shuffle : vector<3x16xf32>
+}
More information about the Mlir-commits
mailing list