[Mlir-commits] [mlir] 63a35f3 - [mlir][Shape] Generalize cstr_broadcastable folding for n-ary broadcasts
Benjamin Kramer
llvmlistbot at llvm.org
Wed Feb 17 02:53:13 PST 2021
Author: Benjamin Kramer
Date: 2021-02-17T11:44:52+01:00
New Revision: 63a35f35ecf8a46e63af750a88faecfe1c6354a6
URL: https://github.com/llvm/llvm-project/commit/63a35f35ecf8a46e63af750a88faecfe1c6354a6
DIFF: https://github.com/llvm/llvm-project/commit/63a35f35ecf8a46e63af750a88faecfe1c6354a6.diff
LOG: [mlir][Shape] Generalize cstr_broadcastable folding for n-ary broadcasts
This is still fairly tricky code, but I tried to untangle it a bit.
Differential Revision: https://reviews.llvm.org/D96800
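For context, the generalized fold rests on the n-ary overload of
staticallyKnownBroadcastable introduced below. A minimal sketch of calling it
(hedged: the shapes and the wrapper function are illustrative, not part of the
commit):

    #include "mlir/Dialect/Traits.h"
    #include "llvm/ADT/SmallVector.h"
    #include <cstdint>

    // Three shapes at once; before this change the fold bailed out on any
    // shape.cstr_broadcastable with more than two operands.
    bool ternaryExample() {
      llvm::SmallVector<llvm::SmallVector<int64_t, 6>, 6> shapes;
      shapes.push_back({8, 1});
      shapes.push_back({1, 8});
      shapes.push_back({1, 1});
      // True: per trailing dimension, every extent is 1 or one agreed constant.
      return mlir::OpTrait::util::staticallyKnownBroadcastable(shapes);
    }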
Added:
Modified:
mlir/include/mlir/Dialect/Traits.h
mlir/lib/Dialect/Shape/IR/Shape.cpp
mlir/lib/Dialect/Traits.cpp
mlir/test/Dialect/Shape/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Traits.h b/mlir/include/mlir/Dialect/Traits.h
index aecceaac0a42..c51cadf048ae 100644
--- a/mlir/include/mlir/Dialect/Traits.h
+++ b/mlir/include/mlir/Dialect/Traits.h
@@ -47,7 +47,7 @@ namespace util {
bool getBroadcastedShape(ArrayRef<int64_t> shape1, ArrayRef<int64_t> shape2,
SmallVectorImpl<int64_t> &resultShape);
-/// Returns true if a broadcast between the 2 shapes is guaranteed to be
+/// Returns true if a broadcast between n shapes is guaranteed to be
/// successful and not result in an error. A false result does not guarantee
/// that the shapes are not broadcastable; it may mean that they are in fact
/// not broadcastable, or simply that this function does not have enough
@@ -59,6 +59,7 @@ bool getBroadcastedShape(ArrayRef<int64_t> shape1, ArrayRef<int64_t> shape2,
/// dimension, while this function will return false because it's possible for
/// both shapes to have a dimension greater than 1 and different, which would
/// fail to broadcast.
+bool staticallyKnownBroadcastable(ArrayRef<SmallVector<int64_t, 6>> shapes);
bool staticallyKnownBroadcastable(ArrayRef<int64_t> shape1,
ArrayRef<int64_t> shape2);
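To make the doc comment above concrete, here is a hedged example of the
false-negative case it describes (the wrapper function is illustrative; -1 is
the dynamic-extent marker that ShapedType::isDynamic checks at this revision):

    #include "mlir/Dialect/Traits.h"
    #include "llvm/ADT/SmallVector.h"
    #include <cstdint>

    // Two fully dynamic 1-D shapes may broadcast at runtime (e.g. both turn
    // out to be 4), but either could also be 2 vs 3, so success is not
    // statically guaranteed and the function conservatively returns false.
    bool dynamicVsDynamic() {
      const int64_t dyn = -1;
      llvm::SmallVector<llvm::SmallVector<int64_t, 6>, 2> shapes;
      shapes.push_back({dyn});
      shapes.push_back({dyn});
      return mlir::OpTrait::util::staticallyKnownBroadcastable(shapes); // false
    }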
diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
index 058c0c58dda2..b1199fb17b20 100644
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -490,38 +490,48 @@ void CstrBroadcastableOp::getCanonicalizationPatterns(
patterns.insert<CstrBroadcastableEqOps>(context);
}
-OpFoldResult CstrBroadcastableOp::fold(ArrayRef<Attribute> operands) {
- // TODO: Add folding for the nary case
- if (operands.size() != 2)
- return nullptr;
+// Return true if there is at most one attribute that does not represent a
+// scalar shape.
+static bool hasAtMostSingleNonScalar(ArrayRef<Attribute> attributes) {
+ bool nonScalarSeen = false;
+ for (Attribute a : attributes) {
+ if (!a || a.cast<DenseIntElementsAttr>().getNumElements() != 0) {
+ if (nonScalarSeen)
+ return false;
+ nonScalarSeen = true;
+ }
+ }
+ return true;
+}
- // Both operands are not needed if one is a scalar.
- if (operands[0] &&
- operands[0].cast<DenseIntElementsAttr>().getNumElements() == 0)
- return BoolAttr::get(getContext(), true);
- if (operands[1] &&
- operands[1].cast<DenseIntElementsAttr>().getNumElements() == 0)
+OpFoldResult CstrBroadcastableOp::fold(ArrayRef<Attribute> operands) {
+ // No broadcasting is needed if all operands but one are scalar.
+ if (hasAtMostSingleNonScalar(operands))
return BoolAttr::get(getContext(), true);
- if (operands[0] && operands[1]) {
- auto lhsShape = llvm::to_vector<6>(
- operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
- auto rhsShape = llvm::to_vector<6>(
- operands[1].cast<DenseIntElementsAttr>().getValues<int64_t>());
- SmallVector<int64_t, 6> resultShape;
- if (OpTrait::util::staticallyKnownBroadcastable(lhsShape, rhsShape))
- return BoolAttr::get(getContext(), true);
- }
+ if ([&] {
+ SmallVector<SmallVector<int64_t, 6>, 6> extents;
+ for (const auto &operand : operands) {
+ if (!operand)
+ return false;
+ extents.push_back(llvm::to_vector<6>(
+ operand.cast<DenseIntElementsAttr>().getValues<int64_t>()));
+ }
+ return OpTrait::util::staticallyKnownBroadcastable(extents);
+ }())
+ return BoolAttr::get(getContext(), true);
// Lastly, see if folding can be completed based on what constraints are known
// on the input shapes.
- SmallVector<int64_t, 6> lhsShape, rhsShape;
- if (failed(getShapeVec(shapes()[0], lhsShape)))
- return nullptr;
- if (failed(getShapeVec(shapes()[1], rhsShape)))
- return nullptr;
-
- if (OpTrait::util::staticallyKnownBroadcastable(lhsShape, rhsShape))
+ if ([&] {
+ SmallVector<SmallVector<int64_t, 6>, 6> extents;
+ for (const auto &shape : shapes()) {
+ extents.emplace_back();
+ if (failed(getShapeVec(shape, extents.back())))
+ return false;
+ }
+ return OpTrait::util::staticallyKnownBroadcastable(extents);
+ }())
return BoolAttr::get(getContext(), true);
// Because a failing witness result here represents an eventual assertion
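The first fold above relies on the fact that a rank-0 (scalar) shape
broadcasts with anything, so when at most one operand can be non-scalar the
constraint holds even if that operand's shape is completely unknown. A
standalone sketch of that rule (simplified and hypothetical: the real helper
works on Attributes and treats a null attribute as potentially non-scalar,
modeled here with std::optional):

    #include <cstdint>
    #include <optional>
    #include <vector>

    // Each operand is a known extent list or unknown (nullopt).
    using MaybeShape = std::optional<std::vector<int64_t>>;

    // True when every operand except at most one is known to be rank 0, in
    // which case the broadcast constraint cannot fail.
    static bool atMostOneNonScalar(const std::vector<MaybeShape> &operands) {
      bool nonScalarSeen = false;
      for (const MaybeShape &s : operands) {
        if (!s || !s->empty()) { // unknown, or known rank > 0
          if (nonScalarSeen)
            return false;
          nonScalarSeen = true;
        }
      }
      return true;
    }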
diff --git a/mlir/lib/Dialect/Traits.cpp b/mlir/lib/Dialect/Traits.cpp
index b7d1bc876f2d..50f203644741 100644
--- a/mlir/lib/Dialect/Traits.cpp
+++ b/mlir/lib/Dialect/Traits.cpp
@@ -15,19 +15,45 @@ using namespace mlir;
bool OpTrait::util::staticallyKnownBroadcastable(ArrayRef<int64_t> shape1,
ArrayRef<int64_t> shape2) {
- // Two dimensions are compatible when
- // 1. they are defined and equal, or
- // 2. one of them is 1
- return llvm::all_of(llvm::zip(llvm::reverse(shape1), llvm::reverse(shape2)),
- [](auto dimensions) {
- auto dim1 = std::get<0>(dimensions);
- auto dim2 = std::get<1>(dimensions);
- if (dim1 == 1 || dim2 == 1)
- return true;
- if (dim1 == dim2 && !ShapedType::isDynamic(dim1))
- return true;
- return false;
- });
+ SmallVector<SmallVector<int64_t, 6>, 2> extents;
+ extents.emplace_back(shape1.begin(), shape1.end());
+ extents.emplace_back(shape2.begin(), shape2.end());
+ return staticallyKnownBroadcastable(extents);
+}
+
+bool OpTrait::util::staticallyKnownBroadcastable(
+ ArrayRef<SmallVector<int64_t, 6>> shapes) {
+ assert(!shapes.empty() && "Expected at least one shape");
+ size_t maxRank = shapes[0].size();
+ for (size_t i = 1; i != shapes.size(); ++i)
+ maxRank = std::max(maxRank, shapes[i].size());
+
+ // We look backwards through every column of `shapes`.
+ for (size_t i = 0; i != maxRank; ++i) {
+ bool seenDynamic = false;
+ Optional<int64_t> nonOneDim;
+ for (ArrayRef<int64_t> extent : shapes) {
+ int64_t dim = i >= extent.size() ? 1 : extent[extent.size() - i - 1];
+
+ if (dim == 1)
+ continue;
+
+ // Dimensions are compatible when
+ // 1. One is dynamic, the rest are 1
+ if (ShapedType::isDynamic(dim)) {
+ if (seenDynamic || nonOneDim)
+ return false;
+ seenDynamic = true;
+ }
+
+ // 2. All are 1 or a specific constant.
+ if (nonOneDim && dim != *nonOneDim)
+ return false;
+
+ nonOneDim = dim;
+ }
+ }
+ return true;
}
bool OpTrait::util::getBroadcastedShape(ArrayRef<int64_t> shape1,
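A hedged worked example of the column-wise rule implemented above (the
exercise function is illustrative; -1 models a dynamic extent, matching the
-1 used by the tests below):

    #include "mlir/Dialect/Traits.h"
    #include "llvm/ADT/SmallVector.h"
    #include <cassert>
    #include <cstdint>

    void exerciseRule() {
      using Shapes = llvm::SmallVector<llvm::SmallVector<int64_t, 6>, 6>;
      // Each trailing column holds at most one non-one constant: foldable.
      //   col 0: 1, 8, 1   col 1: 8, 1, 1
      assert(mlir::OpTrait::util::staticallyKnownBroadcastable(
          Shapes{{8, 1}, {1, 8}, {1, 1}}));
      // A dynamic extent is fine when every other extent in its column is 1.
      assert(mlir::OpTrait::util::staticallyKnownBroadcastable(
          Shapes{{8, 1}, {1, -1}}));
      // Dynamic next to the non-one constant 8 in column 0: unknown => false.
      assert(!mlir::OpTrait::util::staticallyKnownBroadcastable(
          Shapes{{8, 1}, {1, 8}, {1, -1}}));
    }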
diff --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir
index ba7e479eb347..3a04c16cb9f0 100644
--- a/mlir/test/Dialect/Shape/canonicalize.mlir
+++ b/mlir/test/Dialect/Shape/canonicalize.mlir
@@ -600,6 +600,92 @@ func @broadcastable_on_extent_tensors(%arg : tensor<?xindex>) {
return
}
+// -----
+// Fold ternary broadcastable
+// CHECK-LABEL: func @f
+func @f() {
+ // CHECK-NEXT: shape.const_witness true
+ // CHECK-NEXT: consume.witness
+ // CHECK-NEXT: return
+ %cs0 = shape.const_shape [8, 1] : !shape.shape
+ %cs1 = shape.const_shape [1, 8] : !shape.shape
+ %cs2 = shape.const_shape [1, 1] : !shape.shape
+ %0 = shape.cstr_broadcastable %cs0, %cs1, %cs2 : !shape.shape, !shape.shape, !shape.shape
+ "consume.witness"(%0) : (!shape.witness) -> ()
+ return
+}
+
+// -----
+// Fold ternary broadcastable with dynamic extents
+// CHECK-LABEL: func @f
+func @f() {
+ // CHECK-NEXT: shape.const_witness true
+ // CHECK-NEXT: consume.witness
+ // CHECK-NEXT: return
+ %cs0 = shape.const_shape [8, 1] : !shape.shape
+ %cs1 = shape.const_shape [1, -1] : !shape.shape
+ %0 = shape.cstr_broadcastable %cs0, %cs0, %cs1 : !shape.shape, !shape.shape, !shape.shape
+ "consume.witness"(%0) : (!shape.witness) -> ()
+ return
+}
+
+// -----
+// A dynamic extent alongside a conflicting static extent cannot be proven broadcastable at compile time
+// CHECK-LABEL: func @f
+func @f() {
+ // CHECK: shape.cstr_broadcastable
+ // CHECK-NEXT: consume.witness
+ // CHECK-NEXT: return
+ %cs0 = shape.const_shape [8, 1] : !shape.shape
+ %cs1 = shape.const_shape [1, 8] : !shape.shape
+ %cs2 = shape.const_shape [1, -1] : !shape.shape
+ %0 = shape.cstr_broadcastable %cs0, %cs1, %cs2 : !shape.shape, !shape.shape, !shape.shape
+ "consume.witness"(%0) : (!shape.witness) -> ()
+ return
+}
+
+// -----
+// Two dynamic extents in the same dimension cannot be proven broadcastable at compile time
+// CHECK-LABEL: func @f
+func @f() {
+ // CHECK: shape.cstr_broadcastable
+ // CHECK-NEXT: consume.witness
+ // CHECK-NEXT: return
+ %cs0 = shape.const_shape [8, 1] : !shape.shape
+ %cs1 = shape.const_shape [1, -1] : !shape.shape
+ %cs2 = shape.const_shape [1, -1] : !shape.shape
+ %0 = shape.cstr_broadcastable %cs0, %cs1, %cs2 : !shape.shape, !shape.shape, !shape.shape
+ "consume.witness"(%0) : (!shape.witness) -> ()
+ return
+}
+
+// -----
+// Two scalars and an arbitrary shape are always broadcastable and can be constant folded
+// CHECK-LABEL: func @f
+func @f(%arg0 : !shape.shape) {
+ // CHECK-NEXT: shape.const_witness true
+ // CHECK-NEXT: consume.witness
+ // CHECK-NEXT: return
+ %cs0 = shape.const_shape [] : !shape.shape
+ %0 = shape.cstr_broadcastable %cs0, %cs0, %arg0 : !shape.shape, !shape.shape, !shape.shape
+ "consume.witness"(%0) : (!shape.witness) -> ()
+ return
+}
+
+// -----
+// One scalar and one non-scalar and one unknown cannot be folded.
+// CHECK-LABEL: func @f
+func @f(%arg0 : !shape.shape) {
+ // CHECK: shape.cstr_broadcastable
+ // CHECK-NEXT: consume.witness
+ // CHECK-NEXT: return
+ %cs0 = shape.const_shape [] : !shape.shape
+ %cs1 = shape.const_shape [2] : !shape.shape
+ %0 = shape.cstr_broadcastable %cs0, %cs1, %arg0 : !shape.shape, !shape.shape, !shape.shape
+ "consume.witness"(%0) : (!shape.witness) -> ()
+ return
+}
+
// -----
// Fold `rank` based on constant shape.