[Mlir-commits] [mlir] [mlir][vector] shape_cast(broadcast) -> broadcast canonicalization (PR #134939)

Kunwar Grover llvmlistbot at llvm.org
Fri Apr 11 07:14:57 PDT 2025


================
@@ -5792,26 +5791,17 @@ class ShapeCastBroadcastFolder final : public OpRewritePattern<ShapeCastOp> {
     if (!broadcastOp)
       return failure();
 
-    ArrayRef<int64_t> broadcastSourceShape;
-    if (auto srcType = dyn_cast<VectorType>(broadcastOp.getSourceType()))
-      broadcastSourceShape = srcType.getShape();
-    ArrayRef<int64_t> shapeCastTargetShape =
-        shapeCastOp.getResultVectorType().getShape();
-
-    // If `broadcastSourceShape` is a suffix of the result, we can just replace
-    // with a broadcast to the final shape.
-    if (broadcastSourceShape ==
-        shapeCastTargetShape.take_back(broadcastSourceShape.size())) {
-      rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
-          shapeCastOp, shapeCastOp.getResultVectorType(),
-          broadcastOp.getSource());
-      return success();
-    }
-
-    // Otherwise, if the final result has the same element count, we can replace
-    // with a shape cast.
-    if (auto srcType = dyn_cast<VectorType>(broadcastOp.getSourceType())) {
-      if (srcType.getNumElements() ==
+    auto srcVectorType = dyn_cast<VectorType>(broadcastOp.getSourceType());
+    bool srcIsScalar = !srcVectorType;
+
+    // Replace Y = ShapeCast(Broadcast(X)) with Y = ShapeCast(X).
+    // Example:
+    // %0 = vector.broadcast %in : vector<3x4xf32> to vector<1x3x4xf32>
+    // %1 = vector.shape_cast %0 : vector<1x3x4xf32> to vector<12xf32>
+    // to
+    // %1 = vector.shape_cast %in : vector<3x4xf32> to vector<12xf32>
+    if (srcVectorType) {
+      if (srcVectorType.getNumElements() ==
           shapeCastOp.getResultVectorType().getNumElements()) {
         rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
             shapeCastOp, shapeCastOp.getResultVectorType(),
----------------
Groverkss wrote:

It's OK to land for now, but this should be a folder, not a canonicalization pattern.

https://github.com/llvm/llvm-project/pull/134939


More information about the Mlir-commits mailing list