[Mlir-commits] [mlir] [MLIR][Shape] Support >2 args in `shape.broadcast` folder (PR #126808)

Mateusz Sokół llvmlistbot at llvm.org
Wed Feb 19 08:57:38 PST 2025


https://github.com/mtsokol updated https://github.com/llvm/llvm-project/pull/126808

>From 47398c24c666293ca93b1fb01f8576732afa9cf4 Mon Sep 17 00:00:00 2001
From: Mateusz Sokół <mat646 at gmail.com>
Date: Tue, 11 Feb 2025 22:06:43 +0000
Subject: [PATCH 1/2] [MLIR][Shape] Support >2 args in `shape.broadcast` folder

---
 mlir/lib/Dialect/Shape/IR/Shape.cpp       | 36 +++++++++++++++--------
 mlir/lib/Dialect/Traits.cpp               |  2 +-
 mlir/test/Dialect/Shape/canonicalize.mlir | 13 ++++++++
 3 files changed, 37 insertions(+), 14 deletions(-)

diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
index 2200af0f67a86..13faa4921518a 100644
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -649,24 +649,34 @@ OpFoldResult BroadcastOp::fold(FoldAdaptor adaptor) {
     return getShapes().front();
   }
 
-  // TODO: Support folding with more than 2 input shapes
-  if (getShapes().size() > 2)
+  if (!adaptor.getShapes().front())
     return nullptr;
 
-  if (!adaptor.getShapes()[0] || !adaptor.getShapes()[1])
-    return nullptr;
-  auto lhsShape = llvm::to_vector<6>(
-      llvm::cast<DenseIntElementsAttr>(adaptor.getShapes()[0])
-          .getValues<int64_t>());
-  auto rhsShape = llvm::to_vector<6>(
-      llvm::cast<DenseIntElementsAttr>(adaptor.getShapes()[1])
+  auto firstShape = llvm::to_vector<6>(
+      llvm::cast<DenseIntElementsAttr>(adaptor.getShapes().front())
           .getValues<int64_t>());
+
   SmallVector<int64_t, 6> resultShape;
+  resultShape.clear();
+  std::copy(firstShape.begin(), firstShape.end(),
+            std::back_inserter(resultShape));
 
-  // If the shapes are not compatible, we can't fold it.
-  // TODO: Fold to an "error".
-  if (!OpTrait::util::getBroadcastedShape(lhsShape, rhsShape, resultShape))
-    return nullptr;
+  for (auto next : adaptor.getShapes().drop_front()) {
+    if (!next)
+      return nullptr;
+    auto nextShape = llvm::to_vector<6>(
+        llvm::cast<DenseIntElementsAttr>(next).getValues<int64_t>());
+
+    SmallVector<int64_t, 6> tmpShape;
+    // If the shapes are not compatible, we can't fold it.
+    // TODO: Fold to an "error".
+    if (!OpTrait::util::getBroadcastedShape(resultShape, nextShape, tmpShape))
+      return nullptr;
+
+    resultShape.clear();
+    std::copy(tmpShape.begin(), tmpShape.end(),
+              std::back_inserter(resultShape));
+  }
 
   Builder builder(getContext());
   return builder.getIndexTensorAttr(resultShape);
diff --git a/mlir/lib/Dialect/Traits.cpp b/mlir/lib/Dialect/Traits.cpp
index a7aa25eae2644..6e62a33037eb8 100644
--- a/mlir/lib/Dialect/Traits.cpp
+++ b/mlir/lib/Dialect/Traits.cpp
@@ -84,7 +84,7 @@ bool OpTrait::util::getBroadcastedShape(ArrayRef<int64_t> shape1,
     if (ShapedType::isDynamic(*i1) || ShapedType::isDynamic(*i2)) {
       // One or both dimensions is unknown. Follow TensorFlow behavior:
       // - If either dimension is greater than 1, we assume that the program is
-      //   correct, and the other dimension will be broadcast to match it.
+      //   correct, and the other dimension will be broadcasted to match it.
       // - If either dimension is 1, the other dimension is the output.
       if (*i1 > 1) {
         *iR = *i1;
diff --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir
index cf439c9c1b854..9ed4837a2fe7e 100644
--- a/mlir/test/Dialect/Shape/canonicalize.mlir
+++ b/mlir/test/Dialect/Shape/canonicalize.mlir
@@ -86,6 +86,19 @@ func.func @broadcast() -> !shape.shape {
 
 // -----
 
+// Variadic case including extent tensors.
+// CHECK-LABEL: @broadcast_variadic
+func.func @broadcast_variadic() -> !shape.shape {
+  // CHECK: shape.const_shape [7, 2, 10] : !shape.shape
+  %0 = shape.const_shape [2, 1] : tensor<2xindex>
+  %1 = shape.const_shape [7, 2, 1] : tensor<3xindex>
+  %2 = shape.const_shape [1, 10] : tensor<2xindex>
+  %3 = shape.broadcast %0, %1, %2 : tensor<2xindex>, tensor<3xindex>, tensor<2xindex> -> !shape.shape
+  return %3 : !shape.shape
+}
+
+// -----
+
 // Rhs is a scalar.
 // CHECK-LABEL: func @f
 func.func @f(%arg0 : !shape.shape) -> !shape.shape {

>From b2591990b983ca9e6ca380523412d90735a0adbb Mon Sep 17 00:00:00 2001
From: Mateusz Sokół <mat646 at gmail.com>
Date: Wed, 19 Feb 2025 16:57:17 +0000
Subject: [PATCH 2/2] Initialize `resultShape` directly

---
 mlir/lib/Dialect/Shape/IR/Shape.cpp | 7 +------
 1 file changed, 1 insertion(+), 6 deletions(-)

diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
index 13faa4921518a..8d8e861c84157 100644
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -652,15 +652,10 @@ OpFoldResult BroadcastOp::fold(FoldAdaptor adaptor) {
   if (!adaptor.getShapes().front())
     return nullptr;
 
-  auto firstShape = llvm::to_vector<6>(
+  SmallVector<int64_t, 6> resultShape(
       llvm::cast<DenseIntElementsAttr>(adaptor.getShapes().front())
           .getValues<int64_t>());
 
-  SmallVector<int64_t, 6> resultShape;
-  resultShape.clear();
-  std::copy(firstShape.begin(), firstShape.end(),
-            std::back_inserter(resultShape));
-
   for (auto next : adaptor.getShapes().drop_front()) {
     if (!next)
       return nullptr;



More information about the Mlir-commits mailing list