[Mlir-commits] [mlir] [MLIR][Shape] Support >2 args in `shape.broadcast` folder (PR #126808)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Feb 11 14:09:10 PST 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Mateusz Sokół (mtsokol)
<details>
<summary>Changes</summary>
Hi!
As the title says, this PR extends the `shape.broadcast` folder to support more than two arguments. It folds the operands pairwise, calling `getBroadcastedShape` on the running result and each subsequent shape.
---
Full diff: https://github.com/llvm/llvm-project/pull/126808.diff
3 Files Affected:
- (modified) mlir/lib/Dialect/Shape/IR/Shape.cpp (+21-13)
- (modified) mlir/lib/Dialect/Traits.cpp (+1-1)
- (modified) mlir/test/Dialect/Shape/canonicalize.mlir (+13)
``````````diff
diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
index 65efc88e9c403..daa33ea865a5c 100644
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -649,24 +649,32 @@ OpFoldResult BroadcastOp::fold(FoldAdaptor adaptor) {
return getShapes().front();
}
- // TODO: Support folding with more than 2 input shapes
- if (getShapes().size() > 2)
+ if (!adaptor.getShapes().front())
return nullptr;
- if (!adaptor.getShapes()[0] || !adaptor.getShapes()[1])
- return nullptr;
- auto lhsShape = llvm::to_vector<6>(
- llvm::cast<DenseIntElementsAttr>(adaptor.getShapes()[0])
- .getValues<int64_t>());
- auto rhsShape = llvm::to_vector<6>(
- llvm::cast<DenseIntElementsAttr>(adaptor.getShapes()[1])
+ auto firstShape = llvm::to_vector<6>(
+ llvm::cast<DenseIntElementsAttr>(adaptor.getShapes().front())
.getValues<int64_t>());
+
SmallVector<int64_t, 6> resultShape;
+ resultShape.clear();
+ std::copy(firstShape.begin(), firstShape.end(), std::back_inserter(resultShape));
- // If the shapes are not compatible, we can't fold it.
- // TODO: Fold to an "error".
- if (!OpTrait::util::getBroadcastedShape(lhsShape, rhsShape, resultShape))
- return nullptr;
+ for (auto next : adaptor.getShapes().drop_front()) {
+ if (!next)
+ return nullptr;
+ auto nextShape = llvm::to_vector<6>(
+ llvm::cast<DenseIntElementsAttr>(next).getValues<int64_t>());
+
+ SmallVector<int64_t, 6> tmpShape;
+ // If the shapes are not compatible, we can't fold it.
+ // TODO: Fold to an "error".
+ if (!OpTrait::util::getBroadcastedShape(resultShape, nextShape, tmpShape))
+ return nullptr;
+
+ resultShape.clear();
+ std::copy(tmpShape.begin(), tmpShape.end(), std::back_inserter(resultShape));
+ }
Builder builder(getContext());
return builder.getIndexTensorAttr(resultShape);
diff --git a/mlir/lib/Dialect/Traits.cpp b/mlir/lib/Dialect/Traits.cpp
index a7aa25eae2644..6e62a33037eb8 100644
--- a/mlir/lib/Dialect/Traits.cpp
+++ b/mlir/lib/Dialect/Traits.cpp
@@ -84,7 +84,7 @@ bool OpTrait::util::getBroadcastedShape(ArrayRef<int64_t> shape1,
if (ShapedType::isDynamic(*i1) || ShapedType::isDynamic(*i2)) {
// One or both dimensions is unknown. Follow TensorFlow behavior:
// - If either dimension is greater than 1, we assume that the program is
- // correct, and the other dimension will be broadcast to match it.
+ // correct, and the other dimension will be broadcasted to match it.
// - If either dimension is 1, the other dimension is the output.
if (*i1 > 1) {
*iR = *i1;
diff --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir
index cf439c9c1b854..9ed4837a2fe7e 100644
--- a/mlir/test/Dialect/Shape/canonicalize.mlir
+++ b/mlir/test/Dialect/Shape/canonicalize.mlir
@@ -86,6 +86,19 @@ func.func @broadcast() -> !shape.shape {
// -----
+// Variadic case including extent tensors.
+// CHECK-LABEL: @broadcast_variadic
+func.func @broadcast_variadic() -> !shape.shape {
+ // CHECK: shape.const_shape [7, 2, 10] : !shape.shape
+ %0 = shape.const_shape [2, 1] : tensor<2xindex>
+ %1 = shape.const_shape [7, 2, 1] : tensor<3xindex>
+ %2 = shape.const_shape [1, 10] : tensor<2xindex>
+ %3 = shape.broadcast %0, %1, %2 : tensor<2xindex>, tensor<3xindex>, tensor<2xindex> -> !shape.shape
+ return %3 : !shape.shape
+}
+
+// -----
+
// Rhs is a scalar.
// CHECK-LABEL: func @f
func.func @f(%arg0 : !shape.shape) -> !shape.shape {
``````````
</details>
https://github.com/llvm/llvm-project/pull/126808
More information about the Mlir-commits
mailing list