[Mlir-commits] [mlir] fc760c0 - [mlir][vector] Fold cancelling vector.shape_cast(vector.broadcast)
Lei Zhang
llvmlistbot at llvm.org
Fri Apr 22 06:02:23 PDT 2022
Author: Lei Zhang
Date: 2022-04-22T08:58:26-04:00
New Revision: fc760c026058aa347b5aa16b548995e89ffe8d31
URL: https://github.com/llvm/llvm-project/commit/fc760c026058aa347b5aa16b548995e89ffe8d31
DIFF: https://github.com/llvm/llvm-project/commit/fc760c026058aa347b5aa16b548995e89ffe8d31.diff
LOG: [mlir][vector] Fold cancelling vector.shape_cast(vector.broadcast)
A vector.broadcast may inject only size-one dimensions. If it is
followed by a vector.shape_cast back to the broadcast's source type,
the op pair cancels out, just as consecutive shape_cast ops do.
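The fold condition can be illustrated outside MLIR. The C++ sketch below is a simplified, hypothetical model and not MLIR's API: a type is either a scalar (std::nullopt) or a vector shape with an implied f32 element type, and the helper names isLegalBroadcast, foldsAway and addsOnlyUnitDims are illustrative only. It mirrors the check added to ShapeCastOp::fold, where the pair folds only when the broadcast's source type equals the shape_cast result type. Because shape_cast preserves the element count, that equality means the broadcast cannot have grown the vector, so every dimension it added must be 1.

```cpp
#include <cassert>
#include <cstdint>
#include <functional>
#include <numeric>
#include <optional>
#include <vector>

// Hypothetical stand-in for an MLIR type: std::nullopt models a scalar
// (e.g. f32); otherwise it holds a vector shape (element type assumed f32).
using Shape = std::vector<int64_t>;
using Type = std::optional<Shape>;

// Total number of elements in a vector shape.
int64_t numElements(const Shape &s) {
  return std::accumulate(s.begin(), s.end(), int64_t{1},
                         std::multiplies<int64_t>());
}

// Simplified model of vector.broadcast legality: trailing dimensions are
// aligned, and each source dim must match the result dim or be 1 (stretched).
bool isLegalBroadcast(const Type &src, const Shape &dst) {
  if (!src) return true;  // A scalar may be broadcast to any vector shape.
  const Shape &s = *src;
  if (s.size() > dst.size()) return false;
  size_t lead = dst.size() - s.size();
  for (size_t i = 0; i < s.size(); ++i)
    if (s[i] != 1 && s[i] != dst[lead + i]) return false;
  return true;
}

// Model of the new fold: shape_cast(broadcast(x)) folds to x when x's type
// equals the shape_cast result type. A scalar source never matches a vector.
bool foldsAway(const Type &bcastSrc, const Shape &bcastResult,
               const Shape &castResult) {
  assert(isLegalBroadcast(bcastSrc, bcastResult) && "invalid broadcast");
  assert(numElements(bcastResult) == numElements(castResult) &&
         "invalid shape_cast");
  return bcastSrc.has_value() && *bcastSrc == castResult;
}

// True if dst only prepends size-one dimensions to src and keeps the same
// element count (i.e. nothing was stretched).
bool addsOnlyUnitDims(const Shape &src, const Shape &dst) {
  if (src.size() > dst.size()) return false;
  size_t lead = dst.size() - src.size();
  for (size_t i = 0; i < lead; ++i)
    if (dst[i] != 1) return false;
  return numElements(src) == numElements(dst);
}
```

The calls foldsAway(Shape{4}, Shape{1, 1, 4}, Shape{4}), foldsAway(std::nullopt, Shape{1, 1, 1}, Shape{1}) and foldsAway(Shape{4}, Shape{1, 2, 4}, Shape{8}) correspond to the three test cases added in canonicalize.mlir: only the first folds.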
Reviewed By: mravishankar
Differential Revision: https://reviews.llvm.org/D124094
Added:
Modified:
mlir/lib/Dialect/Vector/IR/VectorOps.cpp
mlir/test/Dialect/Vector/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index c935349310f78..fbf6675671152 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -4088,7 +4088,7 @@ LogicalResult ShapeCastOp::verify() {
}
OpFoldResult ShapeCastOp::fold(ArrayRef<Attribute> operands) {
- // Nop shape cast.
+ // No-op shape cast.
if (getSource().getType() == getResult().getType())
return getSource();
@@ -4113,6 +4113,13 @@ OpFoldResult ShapeCastOp::fold(ArrayRef<Attribute> operands) {
setOperand(otherOp.getSource());
return getResult();
}
+
+ // Cancelling broadcast and shape cast ops.
+ if (auto bcastOp = getSource().getDefiningOp<BroadcastOp>()) {
+ if (bcastOp.getSourceType() == getType())
+ return bcastOp.getSource();
+ }
+
return {};
}
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index d033b476d497c..824c455aec716 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -635,6 +635,39 @@ func.func @dont_fold_expand_collapse(%arg0: vector<1x1x64xf32>) -> vector<8x8xf3
// -----
+// CHECK-LABEL: func @fold_broadcast_shapecast
+// CHECK-SAME: (%[[V:.+]]: vector<4xf32>)
+// CHECK: return %[[V]]
+func @fold_broadcast_shapecast(%arg0: vector<4xf32>) -> vector<4xf32> {
+ %0 = vector.broadcast %arg0 : vector<4xf32> to vector<1x1x4xf32>
+ %1 = vector.shape_cast %0 : vector<1x1x4xf32> to vector<4xf32>
+ return %1 : vector<4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @dont_fold_broadcast_shapecast_scalar
+// CHECK: vector.broadcast
+// CHECK: vector.shape_cast
+func @dont_fold_broadcast_shapecast_scalar(%arg0: f32) -> vector<1xf32> {
+ %0 = vector.broadcast %arg0 : f32 to vector<1x1x1xf32>
+ %1 = vector.shape_cast %0 : vector<1x1x1xf32> to vector<1xf32>
+ return %1 : vector<1xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @dont_fold_broadcast_shapecast_diff_shape
+// CHECK: vector.broadcast
+// CHECK: vector.shape_cast
+func @dont_fold_broadcast_shapecast_diff_shape(%arg0: vector<4xf32>) -> vector<8xf32> {
+ %0 = vector.broadcast %arg0 : vector<4xf32> to vector<1x2x4xf32>
+ %1 = vector.shape_cast %0 : vector<1x2x4xf32> to vector<8xf32>
+ return %1 : vector<8xf32>
+}
+
+// -----
+
// CHECK-LABEL: fold_vector_transfers
func.func @fold_vector_transfers(%A: memref<?x8xf32>) -> (vector<4x8xf32>, vector<4x9xf32>) {
%c0 = arith.constant 0 : index