[Mlir-commits] [mlir] [MLIR][Shape] Support >2 args in `shape.broadcast` folder (PR #126808)

Jacques Pienaar llvmlistbot at llvm.org
Fri Mar 7 09:26:27 PST 2025


Mateusz Sokół <mat646 at gmail.com>


================
@@ -649,24 +649,29 @@ OpFoldResult BroadcastOp::fold(FoldAdaptor adaptor) {
     return getShapes().front();
   }
 
-  // TODO: Support folding with more than 2 input shapes
-  if (getShapes().size() > 2)
+  if (!adaptor.getShapes().front())
     return nullptr;
 
-  if (!adaptor.getShapes()[0] || !adaptor.getShapes()[1])
-    return nullptr;
-  auto lhsShape = llvm::to_vector<6>(
-      llvm::cast<DenseIntElementsAttr>(adaptor.getShapes()[0])
-          .getValues<int64_t>());
-  auto rhsShape = llvm::to_vector<6>(
-      llvm::cast<DenseIntElementsAttr>(adaptor.getShapes()[1])
+  SmallVector<int64_t, 6> resultShape(
+      llvm::cast<DenseIntElementsAttr>(adaptor.getShapes().front())
           .getValues<int64_t>());
-  SmallVector<int64_t, 6> resultShape;
 
-  // If the shapes are not compatible, we can't fold it.
-  // TODO: Fold to an "error".
-  if (!OpTrait::util::getBroadcastedShape(lhsShape, rhsShape, resultShape))
-    return nullptr;
+  for (auto next : adaptor.getShapes().drop_front()) {
+    if (!next)
+      return nullptr;
+    auto nextShape = llvm::to_vector<6>(
+        llvm::cast<DenseIntElementsAttr>(next).getValues<int64_t>());
+
+    SmallVector<int64_t, 6> tmpShape;
+    // If the shapes are not compatible, we can't fold it.
+    // TODO: Fold to an "error".
+    if (!OpTrait::util::getBroadcastedShape(resultShape, nextShape, tmpShape))
+      return nullptr;
+
+    resultShape.clear();
+    std::copy(tmpShape.begin(), tmpShape.end(),
----------------
jpienaar wrote:

Was this what clang-format produced?
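
For context on the hunk above: the new folder handles more than two operands by broadcasting the accumulated result shape with each remaining shape in turn, bailing out as soon as one pair is incompatible. Below is a minimal standalone sketch of that left fold, assuming static dimensions only and using `std::vector` in place of `SmallVector`. The names `broadcastPair` and `broadcastAll` are hypothetical and are not MLIR APIs; `broadcastPair` only approximates what `OpTrait::util::getBroadcastedShape` does for static shapes.

```cpp
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <optional>
#include <utility>
#include <vector>

using Shape = std::vector<int64_t>;

// Hypothetical stand-in for OpTrait::util::getBroadcastedShape.
// Assumption: all dimensions are static (no dynamic-size sentinel).
std::optional<Shape> broadcastPair(const Shape &lhs, const Shape &rhs) {
  Shape result(std::max(lhs.size(), rhs.size()), 1);
  // Walk both shapes from the trailing dimension; missing dims count as 1.
  for (std::size_t i = 0; i < result.size(); ++i) {
    int64_t l = i < lhs.size() ? lhs[lhs.size() - 1 - i] : 1;
    int64_t r = i < rhs.size() ? rhs[rhs.size() - 1 - i] : 1;
    if (l != r && l != 1 && r != 1)
      return std::nullopt; // Incompatible extents: nothing to fold.
    result[result.size() - 1 - i] = (l == 1) ? r : l;
  }
  return result;
}

// Left fold over all shapes, mirroring the loop in the diff.
std::optional<Shape> broadcastAll(const std::vector<Shape> &shapes) {
  if (shapes.empty())
    return std::nullopt;
  Shape result = shapes.front();
  for (std::size_t i = 1; i < shapes.size(); ++i) {
    std::optional<Shape> next = broadcastPair(result, shapes[i]);
    if (!next)
      return std::nullopt;
    // Move assignment replaces the accumulator in one step.
    result = std::move(*next);
  }
  return result;
}
```

In this sketch the accumulator is replaced by move assignment rather than by `clear()` followed by `std::copy`; that is one possible way to write the update, not necessarily what the PR settled on.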

https://github.com/llvm/llvm-project/pull/126808

