[Mlir-commits] [mlir] 0b303da - [mlir][vector] add pattern to cast away lead unit dimension for broadcast op
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu May 6 08:02:46 PDT 2021
Author: thomasraoux
Date: 2021-05-06T08:02:17-07:00
New Revision: 0b303da6f821dcbcb3f72135b2431aaf94045839
URL: https://github.com/llvm/llvm-project/commit/0b303da6f821dcbcb3f72135b2431aaf94045839
DIFF: https://github.com/llvm/llvm-project/commit/0b303da6f821dcbcb3f72135b2431aaf94045839.diff
LOG: [mlir][vector] add pattern to cast away lead unit dimension for broadcast op
Differential Revision: https://reviews.llvm.org/D101955
Added:
Modified:
mlir/lib/Dialect/Vector/VectorTransforms.cpp
mlir/test/Dialect/Vector/vector-transforms.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
index 38042958018a9..f7a2c3e5b51ce 100644
--- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
@@ -3175,6 +3175,31 @@ struct CastAwayTransferWriteLeadingOneDim
}
};
+// Rewrites a vector.broadcast whose result type has leading unit (size-1)
+// dimensions into a broadcast to the trimmed result type, followed by a
+// vector.shape_cast back to the original result type. For example,
+// `vector<1x4xf32> -> vector<1x3x4xf32>` becomes: shape_cast the source to
+// vector<4xf32>, broadcast to vector<3x4xf32>, then shape_cast the result
+// back to vector<1x3x4xf32>.
+//
+// NOTE(review): "Brodcast" in the struct name misspells "Broadcast"; a rename
+// must also update the registration in
+// populateCastAwayVectorLeadingOneDimPatterns.
+struct CastAwayBrodcastLeadingOneDim
+ : public OpRewritePattern<vector::BroadcastOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(vector::BroadcastOp broadcastOp,
+ PatternRewriter &rewriter) const override {
+ // trimLeadingOneDims is defined outside this patch; it is assumed to drop
+ // leading size-1 dims (the test expects vector<1x1x8xf32> to become
+ // vector<8xf32>) -- confirm against its definition.
+ VectorType newDstType = trimLeadingOneDims(broadcastOp.getVectorType());
+ // The result has no leading unit dims to drop: the pattern does not apply.
+ if (newDstType == broadcastOp.getVectorType())
+ return failure();
+ Location loc = broadcastOp.getLoc();
+ // Only vector sources are trimmed. A scalar source fails the dyn_cast,
+ // leaves srcVecType null, and is fed to the new broadcast unchanged.
+ VectorType srcVecType = broadcastOp.getSourceType().dyn_cast<VectorType>();
+ if (srcVecType)
+ srcVecType = trimLeadingOneDims(srcVecType);
+ Value source = broadcastOp.source();
+ // Reshape the source only when trimming actually changed its type.
+ if (srcVecType && srcVecType != broadcastOp.getSourceType()) {
+ source = rewriter.create<vector::ShapeCastOp>(loc, srcVecType, source);
+ }
+ // Broadcast in the trimmed shape, then shape_cast back so the replacement
+ // value has exactly the original op's result type.
+ Value newBroadcastOp =
+ rewriter.create<vector::BroadcastOp>(loc, newDstType, source);
+ rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
+ broadcastOp, broadcastOp.getVectorType(), newBroadcastOp);
+ return success();
+ }
+};
+
// Returns the values in `arrayAttr` as an integer vector.
static SmallVector<int64_t, 4> getIntValueVector(ArrayAttr arrayAttr) {
return llvm::to_vector<4>(
@@ -3771,7 +3796,8 @@ void mlir::vector::populateCastAwayVectorLeadingOneDimPatterns(
patterns.add<CastAwayExtractStridedSliceLeadingOneDim,
CastAwayInsertStridedSliceLeadingOneDim,
CastAwayTransferReadLeadingOneDim,
- CastAwayTransferWriteLeadingOneDim, ShapeCastOpFolder>(
+ CastAwayTransferWriteLeadingOneDim,
+ CastAwayBrodcastLeadingOneDim, ShapeCastOpFolder>(
patterns.getContext());
}
diff --git a/mlir/test/Dialect/Vector/vector-transforms.mlir b/mlir/test/Dialect/Vector/vector-transforms.mlir
index efa55e6f4cae2..47a12958aa144 100644
--- a/mlir/test/Dialect/Vector/vector-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-transforms.mlir
@@ -672,6 +672,23 @@ func @cast_away_transfer_write_leading_one_dims_one_element(%arg0: memref<1x1x1x
return
}
+// Exercises the broadcast leading-unit-dim pattern on three source kinds:
+// a vector source with no leading unit dims (%arg0), a scalar source (%arg1),
+// and a vector source whose own leading unit dim is shape_cast away before
+// the broadcast (%arg2). Every result is shape_cast back to its original type.
+// CHECK-LABEL: func @cast_away_broadcast_leading_one_dims
+func @cast_away_broadcast_leading_one_dims(
+ %arg0: vector<8xf32>, %arg1: f32, %arg2: vector<1x4xf32>) ->
+ (vector<1x1x8xf32>, vector<1x1x4xf32>, vector<1x3x4xf32>) {
+ // CHECK: vector.broadcast %{{.*}} : vector<8xf32> to vector<8xf32>
+ // CHECK: vector.shape_cast %{{.*}} : vector<8xf32> to vector<1x1x8xf32>
+ %0 = vector.broadcast %arg0 : vector<8xf32> to vector<1x1x8xf32>
+ // CHECK: vector.broadcast %{{.*}} : f32 to vector<4xf32>
+ // CHECK: vector.shape_cast %{{.*}} : vector<4xf32> to vector<1x1x4xf32>
+ %1 = vector.broadcast %arg1 : f32 to vector<1x1x4xf32>
+ // CHECK: vector.shape_cast %{{.*}} : vector<1x4xf32> to vector<4xf32>
+ // CHECK: vector.broadcast %{{.*}} : vector<4xf32> to vector<3x4xf32>
+ // CHECK: vector.shape_cast %{{.*}} : vector<3x4xf32> to vector<1x3x4xf32>
+ %2 = vector.broadcast %arg2 : vector<1x4xf32> to vector<1x3x4xf32>
+ return %0, %1, %2: vector<1x1x8xf32>, vector<1x1x4xf32>, vector<1x3x4xf32>
+}
+
// CHECK-LABEL: func @bubble_down_bitcast_in_extract
// CHECK-SAME: %[[SRC:.+]]: vector<4xf32>
func @bubble_down_bitcast_in_extract(%src: vector<4xf32>) -> (f16, f16) {
More information about the Mlir-commits mailing list