[Mlir-commits] [mlir] [MLIR][Shape] Support >2 args in `shape.broadcast` folder (PR #126808)
Mateusz Sokół
llvmlistbot at llvm.org
Sun Mar 9 12:39:24 PDT 2025
================
@@ -649,24 +649,29 @@ OpFoldResult BroadcastOp::fold(FoldAdaptor adaptor) {
return getShapes().front();
}
- // TODO: Support folding with more than 2 input shapes
- if (getShapes().size() > 2)
+ if (!adaptor.getShapes().front())
return nullptr;
- if (!adaptor.getShapes()[0] || !adaptor.getShapes()[1])
- return nullptr;
- auto lhsShape = llvm::to_vector<6>(
- llvm::cast<DenseIntElementsAttr>(adaptor.getShapes()[0])
- .getValues<int64_t>());
- auto rhsShape = llvm::to_vector<6>(
- llvm::cast<DenseIntElementsAttr>(adaptor.getShapes()[1])
+ SmallVector<int64_t, 6> resultShape(
+ llvm::cast<DenseIntElementsAttr>(adaptor.getShapes().front())
.getValues<int64_t>());
- SmallVector<int64_t, 6> resultShape;
- // If the shapes are not compatible, we can't fold it.
- // TODO: Fold to an "error".
- if (!OpTrait::util::getBroadcastedShape(lhsShape, rhsShape, resultShape))
- return nullptr;
+ for (auto next : adaptor.getShapes().drop_front()) {
+ if (!next)
+ return nullptr;
+ auto nextShape = llvm::to_vector<6>(
+ llvm::cast<DenseIntElementsAttr>(next).getValues<int64_t>());
+
+ SmallVector<int64_t, 6> tmpShape;
+ // If the shapes are not compatible, we can't fold it.
+ // TODO: Fold to an "error".
+ if (!OpTrait::util::getBroadcastedShape(resultShape, nextShape, tmpShape))
+ return nullptr;
+
+ resultShape.clear();
+ std::copy(tmpShape.begin(), tmpShape.end(),
----------------
mtsokol wrote:
Yes, that's correct: this formatting was produced by `clang-format`. Here's another place where `std::copy` is formatted the same way:
https://github.com/llvm/llvm-project/blob/74ca5799caea342ac1e8d34ab5be7f45875131b2/clang/include/clang/Lex/MacroInfo.h#L536-L537
https://github.com/llvm/llvm-project/pull/126808
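The loop in the hunk folds the operands pairwise. It seeds `resultShape` with the first constant shape, then broadcasts each remaining shape into that running result, and it bails out on the first non-constant or incompatible operand. Below is a minimal standalone sketch of that left fold. It assumes static dimensions only; the real `OpTrait::util::getBroadcastedShape` also handles dynamic dimensions. `broadcastPair` and `broadcastAll` are hypothetical names for illustration, not MLIR APIs.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <optional>
#include <utility>
#include <vector>

using Shape = std::vector<int64_t>;

// Hypothetical helper: NumPy-style broadcast of two static shapes.
// Dimensions are aligned from the right; each aligned pair must be equal
// or contain a 1. Dynamic dimensions are deliberately not modeled.
std::optional<Shape> broadcastPair(const Shape &lhs, const Shape &rhs) {
  const Shape &longer = lhs.size() >= rhs.size() ? lhs : rhs;
  const Shape &shorter = lhs.size() >= rhs.size() ? rhs : lhs;
  Shape result(longer);
  std::size_t offset = longer.size() - shorter.size();
  for (std::size_t i = 0; i < shorter.size(); ++i) {
    int64_t a = longer[offset + i];
    int64_t b = shorter[i];
    if (a == b || b == 1)
      result[offset + i] = a;
    else if (a == 1)
      result[offset + i] = b;
    else
      return std::nullopt; // incompatible dims: the folder gives up here
  }
  return result;
}

// Hypothetical left fold over all operands, mirroring the loop in the hunk:
// start from the first shape and broadcast each later shape into the result.
std::optional<Shape> broadcastAll(const std::vector<Shape> &shapes) {
  if (shapes.empty())
    return std::nullopt;
  Shape result = shapes.front();
  for (std::size_t i = 1; i < shapes.size(); ++i) {
    std::optional<Shape> next = broadcastPair(result, shapes[i]);
    if (!next)
      return std::nullopt;
    // Replace the running result; same effect as clear() + std::copy.
    result = std::move(*next);
  }
  return result;
}
```

For example, `broadcastAll({{2, 1}, {3, 1, 4}, {1, 4}})` yields `{3, 2, 4}`, while `{2, 3}` and `{4, 3}` are incompatible and yield `std::nullopt`. Move-assigning the temporary is one possible alternative to the `clear()` + `std::copy` pair; both leave the running shape holding the newly broadcast result.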