[Mlir-commits] [mlir] 405d6ca - [mlir] Add pattern to handle trivial shape_cast in SPIR-V
Jerry Wu
llvmlistbot at llvm.org
Mon Jun 26 16:10:51 PDT 2023
Author: Jerry Wu
Date: 2023-06-26T23:10:43Z
New Revision: 405d6cada4286bf58c2a2ee6858108a34e623b97
URL: https://github.com/llvm/llvm-project/commit/405d6cada4286bf58c2a2ee6858108a34e623b97
DIFF: https://github.com/llvm/llvm-project/commit/405d6cada4286bf58c2a2ee6858108a34e623b97.diff
LOG: [mlir] Add pattern to handle trivial shape_cast in SPIR-V
Handle the trivial cases of vector.shape_cast: casts whose converted source and result types are identical, and casts of size-1 vectors.
Differential Revision: https://reviews.llvm.org/D153719
Added:
Modified:
mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index 2bc1e36855f1f..903441943f200 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -50,6 +50,36 @@ static int getNumBits(Type type) {
namespace {
+ /// Lowers trivial vector.shape_cast ops: casts whose converted result type
+ /// equals the converted source type, and casts whose result vector has a
+ /// single element. Both are replaced by the converted source value; any other
+ /// shape_cast reports a match failure and is left untouched.
+ struct VectorShapeCast final : public OpConversionPattern<vector::ShapeCastOp> {
+ using OpConversionPattern::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(vector::ShapeCastOp shapeCastOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ Type dstType = getTypeConverter()->convertType(shapeCastOp.getType());
+ if (!dstType)
+ return rewriter.notifyMatchFailure(shapeCastOp,
+ "cannot convert result type");
+
+ // If dstType is the same as the converted source type, or the result has a
+ // single element, the op can be replaced directly by the converted source.
+ if (dstType == adaptor.getSource().getType() ||
+ shapeCastOp.getResultVectorType().getNumElements() == 1) {
+ rewriter.replaceOp(shapeCastOp, adaptor.getSource());
+ return success();
+ }
+
+ // Lowering of general shape_casts (differing converted types and a result
+ // that is not single-element) has not been implemented yet.
+ return rewriter.notifyMatchFailure(shapeCastOp,
+ "unsupported non-trivial shape_cast");
+ }
+ };
+
struct VectorBitcastConvert final
: public OpConversionPattern<vector::BitCastOp> {
using OpConversionPattern::OpConversionPattern;
@@ -551,15 +574,15 @@ struct VectorReductionToDotProd final : OpRewritePattern<vector::ReductionOp> {
+ /// Adds the Vector-to-SPIR-V conversion patterns to `patterns`; every pattern
+ /// is constructed with `typeConverter` and the context of `patterns`.
void mlir::populateVectorToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
RewritePatternSet &patterns) {
- patterns.add<
- VectorBitcastConvert, VectorBroadcastConvert,
- VectorExtractElementOpConvert, VectorExtractOpConvert,
- VectorExtractStridedSliceOpConvert, VectorFmaOpConvert<spirv::GLFmaOp>,
- VectorFmaOpConvert<spirv::CLFmaOp>, VectorInsertElementOpConvert,
- VectorInsertOpConvert, VectorReductionPattern<GL_MAX_MIN_OPS>,
- VectorReductionPattern<CL_MAX_MIN_OPS>, VectorInsertStridedSliceOpConvert,
- VectorShuffleOpConvert, VectorSplatPattern>(typeConverter,
- patterns.getContext());
+ patterns.add<VectorBitcastConvert, VectorBroadcastConvert,
+ VectorExtractElementOpConvert, VectorExtractOpConvert,
+ VectorExtractStridedSliceOpConvert,
+ VectorFmaOpConvert<spirv::GLFmaOp>,
+ VectorFmaOpConvert<spirv::CLFmaOp>, VectorInsertElementOpConvert,
+ VectorInsertOpConvert, VectorReductionPattern<GL_MAX_MIN_OPS>,
+ VectorReductionPattern<CL_MAX_MIN_OPS>, VectorShapeCast,
+ VectorInsertStridedSliceOpConvert, VectorShuffleOpConvert,
+ VectorSplatPattern>(typeConverter, patterns.getContext());
}
void mlir::populateVectorReductionToSPIRVDotProductPatterns(
diff --git a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
index bedd3d11e6f93..535da1328d34c 100644
--- a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
+++ b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
@@ -568,3 +568,30 @@ func.func @reduction_minui(%v : vector<3xi32>, %s: i32) -> i32 {
%reduce = vector.reduction <minui>, %v, %s : vector<3xi32> into i32
return %reduce : i32
}
+
+// -----
+
+// A shape_cast whose source and result types are identical should be replaced
+// by its source operand, so no vector.shape_cast may remain.
+// CHECK-LABEL: @shape_cast_same_type
+// CHECK-SAME: (%[[ARG0:.*]]: vector<2xf32>)
+// CHECK-NOT: vector.shape_cast
+// CHECK: return %[[ARG0]]
+func.func @shape_cast_same_type(%arg0 : vector<2xf32>) -> vector<2xf32> {
+ %1 = vector.shape_cast %arg0 : vector<2xf32> to vector<2xf32>
+ return %1 : vector<2xf32>
+}
+
+// -----
+
+// A shape_cast of a single-element vector is replaced by its scalar converted
+// source; only unrealized_conversion_casts to and from f32 remain.
+// CHECK-LABEL: @shape_cast_size1_vector
+// CHECK-SAME: (%[[ARG0:.*]]: vector<f32>)
+// CHECK: %[[R0:.+]] = builtin.unrealized_conversion_cast %[[ARG0]] : vector<f32> to f32
+// CHECK: %[[R1:.+]] = builtin.unrealized_conversion_cast %[[R0]] : f32 to vector<1xf32>
+// CHECK: return %[[R1]]
+func.func @shape_cast_size1_vector(%arg0 : vector<f32>) -> vector<1xf32> {
+ %1 = vector.shape_cast %arg0 : vector<f32> to vector<1xf32>
+ return %1 : vector<1xf32>
+}
More information about the Mlir-commits
mailing list