[Mlir-commits] [mlir] 63a35f3 - [mlir][Shape] Generalize cstr_broadcastable folding for n-ary broadcasts

Benjamin Kramer llvmlistbot at llvm.org
Wed Feb 17 02:53:13 PST 2021


Author: Benjamin Kramer
Date: 2021-02-17T11:44:52+01:00
New Revision: 63a35f35ecf8a46e63af750a88faecfe1c6354a6

URL: https://github.com/llvm/llvm-project/commit/63a35f35ecf8a46e63af750a88faecfe1c6354a6
DIFF: https://github.com/llvm/llvm-project/commit/63a35f35ecf8a46e63af750a88faecfe1c6354a6.diff

LOG: [mlir][Shape] Generalize cstr_broadcastable folding for n-ary broadcasts

This is still fairly tricky code, but I tried to untangle it a bit.
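
For context: a list of shapes broadcasts successfully when, aligned at their
trailing dimensions, every position holds extents that are all 1 or all agree
on one value; a dynamic extent is only provably safe when every other extent
at that position is 1. For example, the ternary broadcast of [8, 1], [1, 8]
and [1, 1] is statically known to succeed (yielding [8, 8]), so its
cstr_broadcastable now folds to a true witness.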

Differential Revision: https://reviews.llvm.org/D96800

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Traits.h
    mlir/lib/Dialect/Shape/IR/Shape.cpp
    mlir/lib/Dialect/Traits.cpp
    mlir/test/Dialect/Shape/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Traits.h b/mlir/include/mlir/Dialect/Traits.h
index aecceaac0a42..c51cadf048ae 100644
--- a/mlir/include/mlir/Dialect/Traits.h
+++ b/mlir/include/mlir/Dialect/Traits.h
@@ -47,7 +47,7 @@ namespace util {
 bool getBroadcastedShape(ArrayRef<int64_t> shape1, ArrayRef<int64_t> shape2,
                          SmallVectorImpl<int64_t> &resultShape);
 
-/// Returns true if a broadcast between the 2 shapes is guaranteed to be
+/// Returns true if a broadcast between n shapes is guaranteed to be
 /// successful and not result in an error. False does not guarantee that the
 /// shapes are not broadcastable; it might mean that they are not
 /// broadcastable or it might mean that this function does not have enough
@@ -59,6 +59,7 @@ bool getBroadcastedShape(ArrayRef<int64_t> shape1, ArrayRef<int64_t> shape2,
 /// dimension, while this function will return false because it's possible for
 /// both shapes to have a dimension greater than 1 and different, which would
 /// fail to broadcast.
+bool staticallyKnownBroadcastable(ArrayRef<SmallVector<int64_t, 6>> shapes);
 bool staticallyKnownBroadcastable(ArrayRef<int64_t> shape1,
                                   ArrayRef<int64_t> shape2);
 
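
A minimal sketch of calling the new n-ary overload (hypothetical caller,
assuming the usual llvm/ADT/SmallVector.h and mlir/Dialect/Traits.h includes;
note the element type must be SmallVector<int64_t, 6> to match the
declaration above):

    llvm::SmallVector<llvm::SmallVector<int64_t, 6>, 6> extents;
    extents.push_back({8, 1});
    extents.push_back({1, 8});
    extents.push_back({1, 1});
    // True: at every trailing position the extents are all 1 or agree on
    // one static value.
    bool ok = mlir::OpTrait::util::staticallyKnownBroadcastable(extents);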

diff  --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
index 058c0c58dda2..b1199fb17b20 100644
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -490,38 +490,48 @@ void CstrBroadcastableOp::getCanonicalizationPatterns(
   patterns.insert<CstrBroadcastableEqOps>(context);
 }
 
-OpFoldResult CstrBroadcastableOp::fold(ArrayRef<Attribute> operands) {
-  // TODO: Add folding for the nary case
-  if (operands.size() != 2)
-    return nullptr;
+// Return true if at most one attribute does not represent a scalar
+// broadcast; unknown (null) attributes count as non-scalar.
+static bool hasAtMostSingleNonScalar(ArrayRef<Attribute> attributes) {
+  bool nonScalarSeen = false;
+  for (Attribute a : attributes) {
+    if (!a || a.cast<DenseIntElementsAttr>().getNumElements() != 0) {
+      if (nonScalarSeen)
+        return false;
+      nonScalarSeen = true;
+    }
+  }
+  return true;
+}
 
-  // Both operands are not needed if one is a scalar.
-  if (operands[0] &&
-      operands[0].cast<DenseIntElementsAttr>().getNumElements() == 0)
-    return BoolAttr::get(getContext(), true);
-  if (operands[1] &&
-      operands[1].cast<DenseIntElementsAttr>().getNumElements() == 0)
+OpFoldResult CstrBroadcastableOp::fold(ArrayRef<Attribute> operands) {
+  // Broadcasting trivially succeeds if at most one operand is a non-scalar.
+  if (hasAtMostSingleNonScalar(operands))
     return BoolAttr::get(getContext(), true);
 
-  if (operands[0] && operands[1]) {
-    auto lhsShape = llvm::to_vector<6>(
-        operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
-    auto rhsShape = llvm::to_vector<6>(
-        operands[1].cast<DenseIntElementsAttr>().getValues<int64_t>());
-    SmallVector<int64_t, 6> resultShape;
-    if (OpTrait::util::staticallyKnownBroadcastable(lhsShape, rhsShape))
-      return BoolAttr::get(getContext(), true);
-  }
+  if ([&] {
+        SmallVector<SmallVector<int64_t, 6>, 6> extents;
+        for (const auto &operand : operands) {
+          if (!operand)
+            return false;
+          extents.push_back(llvm::to_vector<6>(
+              operand.cast<DenseIntElementsAttr>().getValues<int64_t>()));
+        }
+        return OpTrait::util::staticallyKnownBroadcastable(extents);
+      }())
+    return BoolAttr::get(getContext(), true);
 
   // Lastly, see if folding can be completed based on what constraints are known
   // on the input shapes.
-  SmallVector<int64_t, 6> lhsShape, rhsShape;
-  if (failed(getShapeVec(shapes()[0], lhsShape)))
-    return nullptr;
-  if (failed(getShapeVec(shapes()[1], rhsShape)))
-    return nullptr;
-
-  if (OpTrait::util::staticallyKnownBroadcastable(lhsShape, rhsShape))
+  if ([&] {
+        SmallVector<SmallVector<int64_t, 6>, 6> extents;
+        for (const auto &shape : shapes()) {
+          extents.emplace_back();
+          if (failed(getShapeVec(shape, extents.back())))
+            return false;
+        }
+        return OpTrait::util::staticallyKnownBroadcastable(extents);
+      }())
     return BoolAttr::get(getContext(), true);
 
   // Because a failing witness result here represents an eventual assertion

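The fold above is structured as a sequence of strategies, each wrapped in an
immediately invoked lambda so that a missing input can abandon one strategy
with a plain `return false` without returning from fold() itself. A
self-contained sketch of the idiom (plain standard-library types, not MLIR
code):

    #include <optional>
    #include <vector>

    std::optional<bool> tryFoldAllKnown(const std::vector<std::optional<int>> &ops) {
      if ([&] {
            for (const auto &op : ops)
              if (!op)
                return false; // an input is unknown; this strategy fails
            return true;      // every input is known; the strategy applies
          }())
        return true;          // fold to a constant result
      return std::nullopt;    // fall through to the next strategy
    }
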
diff  --git a/mlir/lib/Dialect/Traits.cpp b/mlir/lib/Dialect/Traits.cpp
index b7d1bc876f2d..50f203644741 100644
--- a/mlir/lib/Dialect/Traits.cpp
+++ b/mlir/lib/Dialect/Traits.cpp
@@ -15,19 +15,45 @@ using namespace mlir;
 
 bool OpTrait::util::staticallyKnownBroadcastable(ArrayRef<int64_t> shape1,
                                                  ArrayRef<int64_t> shape2) {
-  // Two dimensions are compatible when
-  //   1. they are defined and equal, or
-  //   2. one of them is 1
-  return llvm::all_of(llvm::zip(llvm::reverse(shape1), llvm::reverse(shape2)),
-                      [](auto dimensions) {
-                        auto dim1 = std::get<0>(dimensions);
-                        auto dim2 = std::get<1>(dimensions);
-                        if (dim1 == 1 || dim2 == 1)
-                          return true;
-                        if (dim1 == dim2 && !ShapedType::isDynamic(dim1))
-                          return true;
-                        return false;
-                      });
+  SmallVector<SmallVector<int64_t, 6>, 2> extents;
+  extents.emplace_back(shape1.begin(), shape1.end());
+  extents.emplace_back(shape2.begin(), shape2.end());
+  return staticallyKnownBroadcastable(extents);
+}
+
+bool OpTrait::util::staticallyKnownBroadcastable(
+    ArrayRef<SmallVector<int64_t, 6>> shapes) {
+  assert(!shapes.empty() && "Expected at least one shape");
+  size_t maxRank = shapes[0].size();
+  for (size_t i = 1; i != shapes.size(); ++i)
+    maxRank = std::max(maxRank, shapes[i].size());
+
+  // We look backwards through every column of `shapes`.
+  for (size_t i = 0; i != maxRank; ++i) {
+    bool seenDynamic = false;
+    Optional<int64_t> nonOneDim;
+    for (ArrayRef<int64_t> extent : shapes) {
+      int64_t dim = i >= extent.size() ? 1 : extent[extent.size() - i - 1];
+
+      if (dim == 1)
+        continue;
+
+      // Dimensions are compatible when
+      //   1. One is dynamic and the rest are 1, or
+      if (ShapedType::isDynamic(dim)) {
+        if (seenDynamic || nonOneDim)
+          return false;
+        seenDynamic = true;
+      }
+
+      //   2. All are 1 or a specific constant.
+      if (nonOneDim && dim != *nonOneDim)
+        return false;
+
+      nonOneDim = dim;
+    }
+  }
+  return true;
 }
 
 bool OpTrait::util::getBroadcastedShape(ArrayRef<int64_t> shape1,

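For readers outside the MLIR tree, here is a self-contained restatement of
the column-wise check above, using only standard-library types and -1 as a
stand-in for the dynamic-dimension sentinel (an assumption for illustration;
the real code asks ShapedType::isDynamic):

    #include <algorithm>
    #include <cstdint>
    #include <optional>
    #include <vector>

    bool knownBroadcastable(const std::vector<std::vector<int64_t>> &shapes) {
      size_t maxRank = 0;
      for (const auto &s : shapes)
        maxRank = std::max(maxRank, s.size());
      // Walk dimension positions from the trailing dimension outwards.
      for (size_t i = 0; i != maxRank; ++i) {
        bool seenDynamic = false;
        std::optional<int64_t> nonOneDim;
        for (const auto &s : shapes) {
          // Shapes shorter than the current position behave as extent 1.
          int64_t dim = i >= s.size() ? 1 : s[s.size() - i - 1];
          if (dim == 1)
            continue;
          // A dynamic extent is only provably safe if everything else at
          // this position is 1.
          if (dim == -1) {
            if (seenDynamic || nonOneDim)
              return false;
            seenDynamic = true;
          }
          // Two different non-one static extents cannot broadcast.
          if (nonOneDim && dim != *nonOneDim)
            return false;
          nonOneDim = dim;
        }
      }
      return true;
    }
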
diff  --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir
index ba7e479eb347..3a04c16cb9f0 100644
--- a/mlir/test/Dialect/Shape/canonicalize.mlir
+++ b/mlir/test/Dialect/Shape/canonicalize.mlir
@@ -600,6 +600,92 @@ func @broadcastable_on_extent_tensors(%arg : tensor<?xindex>) {
   return
 }
 
+// -----
+// Fold ternary broadcastable
+// CHECK-LABEL: func @f
+func @f() {
+  // CHECK-NEXT: shape.const_witness true
+  // CHECK-NEXT: consume.witness
+  // CHECK-NEXT: return
+  %cs0 = shape.const_shape [8, 1] : !shape.shape
+  %cs1 = shape.const_shape [1, 8] : !shape.shape
+  %cs2 = shape.const_shape [1, 1] : !shape.shape
+  %0 = shape.cstr_broadcastable %cs0, %cs1, %cs2 : !shape.shape, !shape.shape, !shape.shape
+  "consume.witness"(%0) : (!shape.witness) -> ()
+  return
+}
+
+// -----
+// Fold ternary broadcastable with a dynamic extent
+// CHECK-LABEL: func @f
+func @f() {
+  // CHECK-NEXT: shape.const_witness true
+  // CHECK-NEXT: consume.witness
+  // CHECK-NEXT: return
+  %cs0 = shape.const_shape [8, 1] : !shape.shape
+  %cs1 = shape.const_shape [1, -1] : !shape.shape
+  %0 = shape.cstr_broadcastable %cs0, %cs0, %cs1 : !shape.shape, !shape.shape, !shape.shape
+  "consume.witness"(%0) : (!shape.witness) -> ()
+  return
+}
+
+// -----
+// Two static shapes and one with a dynamic extent cannot be proven broadcastable at compile time
+// CHECK-LABEL: func @f
+func @f() {
+  // CHECK: shape.cstr_broadcastable
+  // CHECK-NEXT: consume.witness
+  // CHECK-NEXT: return
+  %cs0 = shape.const_shape [8, 1] : !shape.shape
+  %cs1 = shape.const_shape [1, 8] : !shape.shape
+  %cs2 = shape.const_shape [1, -1] : !shape.shape
+  %0 = shape.cstr_broadcastable %cs0, %cs1, %cs2 : !shape.shape, !shape.shape, !shape.shape
+  "consume.witness"(%0) : (!shape.witness) -> ()
+  return
+}
+
+// -----
+// One static shape and two with dynamic extents cannot be proven broadcastable at compile time
+// CHECK-LABEL: func @f
+func @f() {
+  // CHECK: shape.cstr_broadcastable
+  // CHECK-NEXT: consume.witness
+  // CHECK-NEXT: return
+  %cs0 = shape.const_shape [8, 1] : !shape.shape
+  %cs1 = shape.const_shape [1, -1] : !shape.shape
+  %cs2 = shape.const_shape [1, -1] : !shape.shape
+  %0 = shape.cstr_broadcastable %cs0, %cs1, %cs2 : !shape.shape, !shape.shape, !shape.shape
+  "consume.witness"(%0) : (!shape.witness) -> ()
+  return
+}
+
+// -----
+// Broadcastable with two scalars and one unknown shape can be constant folded
+// CHECK-LABEL: func @f
+func @f(%arg0 : !shape.shape) {
+  // CHECK-NEXT: shape.const_witness true
+  // CHECK-NEXT: consume.witness
+  // CHECK-NEXT: return
+  %cs0 = shape.const_shape [] : !shape.shape
+  %0 = shape.cstr_broadcastable %cs0, %cs0, %arg0 : !shape.shape, !shape.shape, !shape.shape
+  "consume.witness"(%0) : (!shape.witness) -> ()
+  return
+}
+
+// -----
+// One scalar and one non-scalar and one unknown cannot be folded.
+// CHECK-LABEL: func @f
+func @f(%arg0 : !shape.shape) {
+  // CHECK: shape.cstr_broadcastable
+  // CHECK-NEXT: consume.witness
+  // CHECK-NEXT: return
+  %cs0 = shape.const_shape [] : !shape.shape
+  %cs1 = shape.const_shape [2] : !shape.shape
+  %0 = shape.cstr_broadcastable %cs0, %cs1, %arg0 : !shape.shape, !shape.shape, !shape.shape
+  "consume.witness"(%0) : (!shape.witness) -> ()
+  return
+}
+
 // -----
 
 // Fold `rank` based on constant shape.

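In summary, the new test cases exercise both folding paths:

    [8, 1], [1, 8],  [1, 1]   -> folds (all static and compatible)
    [8, 1], [8, 1],  [1, -1]  -> folds (the dynamic extent only meets 1s)
    [8, 1], [1, 8],  [1, -1]  -> kept  (dynamic extent meets a non-one 8)
    [8, 1], [1, -1], [1, -1]  -> kept  (two dynamic extents in one column)
    [],     [],      %arg0    -> folds (at most one non-scalar operand)
    [],     [2],     %arg0    -> kept  (two possibly non-scalar operands)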
