[Mlir-commits] [mlir] fc760c0 - [mlir][vector] Fold cancelling vector.shape_cast(vector.broadcast)

Lei Zhang llvmlistbot at llvm.org
Fri Apr 22 06:02:23 PDT 2022


Author: Lei Zhang
Date: 2022-04-22T08:58:26-04:00
New Revision: fc760c026058aa347b5aa16b548995e89ffe8d31

URL: https://github.com/llvm/llvm-project/commit/fc760c026058aa347b5aa16b548995e89ffe8d31
DIFF: https://github.com/llvm/llvm-project/commit/fc760c026058aa347b5aa16b548995e89ffe8d31.diff

LOG: [mlir][vector] Fold cancelling vector.shape_cast(vector.broadcast)

vector.broadcast can inject dimensions that are all of size one. If
it's followed by a vector.shape_cast back to the original type, we
can cancel the op pair, just as we cancel consecutive shape_cast ops.

Reviewed By: mravishankar

Differential Revision: https://reviews.llvm.org/D124094
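The condition the new fold checks can be modeled outside MLIR as a minimal sketch. This is an illustration under stated assumptions, not MLIR code: ToyType, numElements, and foldsToBroadcastSource are hypothetical names. The real fold simply compares mlir::Type values with bcastOp.getSourceType() == getType(). Because shape_cast preserves the element count, a shape_cast result type that equals the broadcast source type implies the broadcast only added size-one dimensions and replicated no data. A scalar source (f32) never equals a vector type, so that case is left alone.

```cpp
#include <cassert>
#include <cstdint>
#include <string>
#include <vector>

// Hypothetical stand-in for an MLIR type: a scalar has isVector == false
// and an empty shape; a vector has isVector == true and a non-empty shape.
struct ToyType {
  bool isVector;
  std::vector<int64_t> shape;
  std::string elementType;

  bool operator==(const ToyType &other) const {
    return isVector == other.isVector && shape == other.shape &&
           elementType == other.elementType;
  }
};

// Number of elements held by the type (1 for a scalar).
int64_t numElements(const ToyType &t) {
  int64_t n = 1;
  for (int64_t d : t.shape) n *= d;
  return n;
}

// Models the pair
//   %b = vector.broadcast %x : broadcastSrc to broadcastDst
//   %s = vector.shape_cast %b : broadcastDst to shapeCastDst
// and returns true when %s can be replaced by %x.
bool foldsToBroadcastSource(const ToyType &broadcastSrc,
                            const ToyType &broadcastDst,
                            const ToyType &shapeCastDst) {
  // A valid shape_cast never changes the number of elements.
  assert(numElements(broadcastDst) == numElements(shapeCastDst) &&
         "shape_cast must preserve the element count");
  // Same check as the fold: shape_cast result type == broadcast source type.
  return broadcastSrc == shapeCastDst;
}
```

With vector<4xf32> broadcast to vector<1x1x4xf32> and cast back to vector<4xf32>, the function returns true. For f32 to vector<1x1x1xf32> to vector<1xf32> and for vector<4xf32> to vector<1x2x4xf32> to vector<8xf32>, it returns false. These three cases mirror the three tests added below.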

Added: 
    

Modified: 
    mlir/lib/Dialect/Vector/IR/VectorOps.cpp
    mlir/test/Dialect/Vector/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index c935349310f78..fbf6675671152 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -4088,7 +4088,7 @@ LogicalResult ShapeCastOp::verify() {
 }
 
 OpFoldResult ShapeCastOp::fold(ArrayRef<Attribute> operands) {
-  // Nop shape cast.
+  // No-op shape cast.
   if (getSource().getType() == getResult().getType())
     return getSource();
 
@@ -4113,6 +4113,13 @@ OpFoldResult ShapeCastOp::fold(ArrayRef<Attribute> operands) {
     setOperand(otherOp.getSource());
     return getResult();
   }
+
+  // Cancelling broadcast and shape cast ops.
+  if (auto bcastOp = getSource().getDefiningOp<BroadcastOp>()) {
+    if (bcastOp.getSourceType() == getType())
+      return bcastOp.getSource();
+  }
+
   return {};
 }
 

diff  --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index d033b476d497c..824c455aec716 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -635,6 +635,39 @@ func.func @dont_fold_expand_collapse(%arg0: vector<1x1x64xf32>) -> vector<8x8xf3
 
 // -----
 
+// CHECK-LABEL: func @fold_broadcast_shapecast
+//  CHECK-SAME: (%[[V:.+]]: vector<4xf32>)
+//       CHECK:   return %[[V]]
+func @fold_broadcast_shapecast(%arg0: vector<4xf32>) -> vector<4xf32> {
+    %0 = vector.broadcast %arg0 : vector<4xf32> to vector<1x1x4xf32>
+    %1 = vector.shape_cast %0 : vector<1x1x4xf32> to vector<4xf32>
+    return %1 : vector<4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @dont_fold_broadcast_shapecast_scalar
+//       CHECK:   vector.broadcast
+//       CHECK:   vector.shape_cast
+func @dont_fold_broadcast_shapecast_scalar(%arg0: f32) -> vector<1xf32> {
+    %0 = vector.broadcast %arg0 : f32 to vector<1x1x1xf32>
+    %1 = vector.shape_cast %0 : vector<1x1x1xf32> to vector<1xf32>
+    return %1 : vector<1xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @dont_fold_broadcast_shapecast_diff_shape
+//       CHECK:   vector.broadcast
+//       CHECK:   vector.shape_cast
+func @dont_fold_broadcast_shapecast_diff_shape(%arg0: vector<4xf32>) -> vector<8xf32> {
+    %0 = vector.broadcast %arg0 : vector<4xf32> to vector<1x2x4xf32>
+    %1 = vector.shape_cast %0 : vector<1x2x4xf32> to vector<8xf32>
+    return %1 : vector<8xf32>
+}
+
+// -----
+
 // CHECK-LABEL: fold_vector_transfers
 func.func @fold_vector_transfers(%A: memref<?x8xf32>) -> (vector<4x8xf32>, vector<4x9xf32>) {
   %c0 = arith.constant 0 : index
