[Mlir-commits] [mlir] 2ef71cb - [mlir] Add additional Canonicalization of shape.cstr_broadcastable.

Tres Popp llvmlistbot at llvm.org
Thu Jul 9 02:23:47 PDT 2020


Author: Tres Popp
Date: 2020-07-09T11:23:25+02:00
New Revision: 2ef71cb7fdb76acd2dc69584d05dacd041a7a522

URL: https://github.com/llvm/llvm-project/commit/2ef71cb7fdb76acd2dc69584d05dacd041a7a522
DIFF: https://github.com/llvm/llvm-project/commit/2ef71cb7fdb76acd2dc69584d05dacd041a7a522.diff

LOG: [mlir] Add additional Canonicalization of shape.cstr_broadcastable.

Summary:
The added canonicalizations and foldings are:
- Folding when either input is an attribute describing a scalar shape,
which can always be broadcast.
- Canonicalization when it can be determined that either input shape is
a scalar.
- Canonicalization when the partially specified input shapes can be
proven to always be broadcastable (an illustrative example follows the
revision link below).

Differential Revision: https://reviews.llvm.org/D83194
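
For illustration, a hedged example in the style of the new tests at the
bottom of this diff (the function name here is invented): the shape of a
rank-0 tensor is statically known to broadcast with any other shape, so the
constraint below canonicalizes to a passing witness.

func @example_scalar(%a : tensor<f32>, %b : tensor<?x?xf32>) {
  %0 = shape.shape_of %a : tensor<f32>
  %1 = shape.shape_of %b : tensor<?x?xf32>
  // After canonicalization: %2 = shape.const_witness true
  %2 = shape.cstr_broadcastable %0, %1
  "consume.witness"(%2) : (!shape.witness) -> ()
  return
}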

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Traits.h
    mlir/lib/Dialect/Shape/IR/Shape.cpp
    mlir/lib/Dialect/Traits.cpp
    mlir/test/Dialect/Shape/canonicalize.mlir

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/Traits.h b/mlir/include/mlir/Dialect/Traits.h
index 2b4fb6d77855..aecceaac0a42 100644
--- a/mlir/include/mlir/Dialect/Traits.h
+++ b/mlir/include/mlir/Dialect/Traits.h
@@ -47,6 +47,21 @@ namespace util {
 bool getBroadcastedShape(ArrayRef<int64_t> shape1, ArrayRef<int64_t> shape2,
                          SmallVectorImpl<int64_t> &resultShape);
 
+/// Returns true if a broadcast between the two shapes is guaranteed to be
+/// successful and not result in an error. A false return does not guarantee
+/// that the shapes are not broadcastable; it may mean that they are provably
+/// not broadcastable, or it may mean that this function does not have enough
+/// information to decide.
+///
+/// Conceptually, this returns true if getBroadcastedShape would have returned
+/// true and vice versa, with one exception. If a dimension is unknown in both
+/// shapes, getBroadcastedShape would return true and produce a result with an
+/// unknown dimension, while this function will return false because it is
+/// possible for both shapes to have a dimension greater than 1 and different,
+/// which would fail to broadcast.
+bool staticallyKnownBroadcastable(ArrayRef<int64_t> shape1,
+                                  ArrayRef<int64_t> shape2);
+
 /// Returns the result broadcast composition type from the two given types by
 /// following NumPy broadcast semantics. Returned type may have dynamic shape if
 /// either of the input types has dynamic shape. Returns null type if the two

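As a hedged illustration of the exception described in the new documentation
above (the function name is invented; a rank-1 variant is exercised by the
"two unknowns" test at the bottom of this diff): for the shapes [?, 2] and
[?, 2], getBroadcastedShape succeeds with the result [?, 2], but
staticallyKnownBroadcastable must return false, since at runtime the two
leading dimensions could be, say, 3 and 5, in which case the broadcast would
fail. IR like the following is therefore left unchanged:

func @example_two_unknowns(%a : tensor<?x2xf32>, %b : tensor<?x2xf32>) {
  %0 = shape.shape_of %a : tensor<?x2xf32>
  %1 = shape.shape_of %b : tensor<?x2xf32>
  // Not folded away; the witness must still be verified at runtime.
  %2 = shape.cstr_broadcastable %0, %1
  "consume.witness"(%2) : (!shape.witness) -> ()
  return
}
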
diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
index e251a3887cd4..0a0608bbcda4 100644
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -317,21 +317,100 @@ OpFoldResult ConstShapeOp::fold(ArrayRef<Attribute>) { return shapeAttr(); }
 // CstrBroadcastableOp
 //===----------------------------------------------------------------------===//
 
+namespace {
+// Given an input shape Value, try to obtain the shape's values.
+LogicalResult getShapeVec(Value input, SmallVectorImpl<int64_t> &shapeValues) {
+  if (auto inputOp = input.getDefiningOp<ShapeOfOp>()) {
+    auto type = inputOp.arg().getType().dyn_cast<ShapedType>();
+    if (!type || !type.hasRank())
+      return failure();
+    shapeValues = llvm::to_vector<6>(type.getShape());
+    return success();
+  } else if (auto inputOp = input.getDefiningOp<ConstShapeOp>()) {
+    shapeValues = llvm::to_vector<6>(inputOp.shape().getValues<int64_t>());
+    return success();
+  } else {
+    return failure();
+  }
+}
+
+// For shapes produced by certain operations, we can obtain partial
+// information and can sometimes use it to determine that the shapes will
+// be broadcastable.
+struct CstrBroadcastablePartialInfo
+    : public OpRewritePattern<CstrBroadcastableOp> {
+  using OpRewritePattern<CstrBroadcastableOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(CstrBroadcastableOp op,
+                                PatternRewriter &rewriter) const override {
+    SmallVector<int64_t, 6> lhsShape, rhsShape;
+    if (failed(getShapeVec(op.lhs(), lhsShape)))
+      return failure();
+    if (failed(getShapeVec(op.rhs(), rhsShape)))
+      return failure();
+    if (!OpTrait::util::staticallyKnownBroadcastable(lhsShape, rhsShape))
+      return failure();
+
+    rewriter.replaceOpWithNewOp<ConstWitnessOp>(op.getOperation(), true);
+    return success();
+  }
+};
+
+// Scalars are always broadcastable.
+struct CstrBroadcastableScalar : public OpRewritePattern<CstrBroadcastableOp> {
+  using OpRewritePattern<CstrBroadcastableOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(CstrBroadcastableOp op,
+                                PatternRewriter &rewriter) const override {
+    SmallVector<int64_t, 6> shape;
+    if (failed(getShapeVec(op.lhs(), shape)) || shape.size() > 0)
+      return failure();
+    if (failed(getShapeVec(op.rhs(), shape)) || shape.size() > 0)
+      return failure();
+
+    rewriter.replaceOpWithNewOp<ConstWitnessOp>(op.getOperation(), true);
+    return success();
+  }
+};
+
+} // namespace
+
 void CstrBroadcastableOp::getCanonicalizationPatterns(
     OwningRewritePatternList &patterns, MLIRContext *context) {
-  // If inputs are equal, return passing witness
-  patterns.insert<CstrBroadcastableEqOps>(context);
+  // These canonicalization patterns deliberately overlap with the folding
+  // logic below, in case additional shape information is inferred at a point
+  // where folding is not triggered again.
+  patterns.insert<CstrBroadcastableEqOps, CstrBroadcastablePartialInfo,
+                  CstrBroadcastableScalar>(context);
 }
 
 OpFoldResult CstrBroadcastableOp::fold(ArrayRef<Attribute> operands) {
-  if (!operands[0] || !operands[1])
+  // If either operand is a scalar shape, broadcasting always succeeds.
+  if (operands[0] &&
+      operands[0].cast<DenseIntElementsAttr>().getNumElements() == 0)
+    return BoolAttr::get(true, getContext());
+  if (operands[1] &&
+      operands[1].cast<DenseIntElementsAttr>().getNumElements() == 0)
+    return BoolAttr::get(true, getContext());
+
+  if (operands[0] && operands[1]) {
+    auto lhsShape = llvm::to_vector<6>(
+        operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
+    auto rhsShape = llvm::to_vector<6>(
+        operands[1].cast<DenseIntElementsAttr>().getValues<int64_t>());
+    if (OpTrait::util::staticallyKnownBroadcastable(lhsShape, rhsShape))
+      return BoolAttr::get(true, getContext());
+  }
+
+  // Lastly, see if folding can be completed based on what is statically known
+  // about the input shapes.
+  SmallVector<int64_t, 6> lhsShape, rhsShape;
+  if (failed(getShapeVec(lhs(), lhsShape)))
     return nullptr;
-  auto lhsShape = llvm::to_vector<6>(
-      operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
-  auto rhsShape = llvm::to_vector<6>(
-      operands[1].cast<DenseIntElementsAttr>().getValues<int64_t>());
-  SmallVector<int64_t, 6> resultShape;
-  if (OpTrait::util::getBroadcastedShape(lhsShape, rhsShape, resultShape))
+  if (failed(getShapeVec(rhs(), rhsShape)))
+    return nullptr;
+
+  if (OpTrait::util::staticallyKnownBroadcastable(lhsShape, rhsShape))
     return BoolAttr::get(true, getContext());
 
   // Because a failing witness result here represents an eventual assertion

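A hedged sketch of the new final folding step (the function name is
invented): even when neither operand folds to a constant attribute (the lhs
shape below contains a dynamic dimension), the shapes can be recovered from
the defining shape.shape_of ops, and the constraint folds to a passing
witness because the known dimensions already guarantee broadcastability.

func @example_partial_info(%a : tensor<?x3xf32>, %b : tensor<3xf32>) {
  %0 = shape.shape_of %a : tensor<?x3xf32>
  %1 = shape.shape_of %b : tensor<3xf32>
  // Folds to: %2 = shape.const_witness true
  %2 = shape.cstr_broadcastable %0, %1
  "consume.witness"(%2) : (!shape.witness) -> ()
  return
}
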
diff --git a/mlir/lib/Dialect/Traits.cpp b/mlir/lib/Dialect/Traits.cpp
index c974e2fc097b..2a557c489e0b 100644
--- a/mlir/lib/Dialect/Traits.cpp
+++ b/mlir/lib/Dialect/Traits.cpp
@@ -13,6 +13,23 @@
 
 using namespace mlir;
 
+bool OpTrait::util::staticallyKnownBroadcastable(ArrayRef<int64_t> shape1,
+                                                 ArrayRef<int64_t> shape2) {
+  // Two dimensions are compatible when
+  //   1. they are defined and equal, or
+  //   2. one of them is 1
+  return llvm::all_of(llvm::zip(llvm::reverse(shape1), llvm::reverse(shape2)),
+                      [](auto dimensions) {
+                        auto dim1 = std::get<0>(dimensions);
+                        auto dim2 = std::get<1>(dimensions);
+                        if (dim1 == 1 || dim2 == 1)
+                          return true;
+                        if (dim1 == dim2 && !ShapedType::isDynamic(dim1))
+                          return true;
+                        return false;
+                      });
+}
+
 bool OpTrait::util::getBroadcastedShape(ArrayRef<int64_t> shape1,
                                         ArrayRef<int64_t> shape2,
                                         SmallVectorImpl<int64_t> &resultShape) {

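A hedged IR-level example of the second compatibility rule above (the
function name is invented): a static dimension of 1 broadcasts even against
an unknown dimension, so the constraint below canonicalizes to a passing
witness.

func @example_one_dim(%a : tensor<1x3xf32>, %b : tensor<?x3xf32>) {
  %0 = shape.shape_of %a : tensor<1x3xf32>
  %1 = shape.shape_of %b : tensor<?x3xf32>
  // Canonicalizes to: %2 = shape.const_witness true
  %2 = shape.cstr_broadcastable %0, %1
  "consume.witness"(%2) : (!shape.witness) -> ()
  return
}
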
diff --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir
index 1b9f3924b8b0..1665ef73f3e3 100644
--- a/mlir/test/Dialect/Shape/canonicalize.mlir
+++ b/mlir/test/Dialect/Shape/canonicalize.mlir
@@ -403,8 +403,8 @@ func @f() {
 
 // -----
 // Broadcastable with non-broadcastable constant shapes is always false
-// CHECK-LABEL: func @f
-func @f() {
+// CHECK-LABEL: func @static_non_broadcastable
+func @static_non_broadcastable() {
   // CHECK-NEXT: shape.const_shape
   // CHECK-NEXT: shape.const_shape
   // CHECK-NEXT: shape.cstr_broadcastable
@@ -515,3 +515,49 @@ func @size_to_index_to_size(%size : !shape.size) -> !shape.size {
   return %result : !shape.size
 }
 
+// -----
+
+// Canonicalize scalar cstr_broadcastable checks
+// CHECK-LABEL: @cstr_broadcastable_scalar
+func @cstr_broadcastable_scalar(%arg0 : tensor<?xf32>) {
+  // CHECK-NEXT: shape.const_witness true
+  // CHECK-NEXT: consume.witness
+  // CHECK-NEXT: return
+  %0 = shape.const_shape []
+  %1 = shape.shape_of %arg0 : tensor<?xf32>
+  %2 = shape.cstr_broadcastable %0, %1
+  "consume.witness"(%2) : (!shape.witness) -> ()
+  return
+}
+
+// -----
+
+// Do not canonicalize cstr_broadcastable checks with 2 unknowns
+// CHECK-LABEL: @cstr_broadcastable_unknown
+func @cstr_broadcastable_unknown(%arg0 : tensor<?xf32>, %arg1 : tensor<?xf32>) {
+  // CHECK-NEXT: shape.shape_of %arg0
+  // CHECK-NEXT: shape.shape_of %arg1
+  // CHECK-NEXT: shape.cstr_broadcastable
+  // CHECK-NEXT: consume.witness
+  // CHECK-NEXT: return
+  %0 = shape.shape_of %arg0 : tensor<?xf32>
+  %1 = shape.shape_of %arg1 : tensor<?xf32>
+  %2 = shape.cstr_broadcastable %0, %1
+  "consume.witness"(%2) : (!shape.witness) -> ()
+  return
+}
+
+// -----
+
+// Scalars are safe to broadcast with unranked shapes.
+// CHECK-LABEL: @cstr_broadcastable_scalar_unranked
+func @cstr_broadcastable_scalar_unranked(%arg0 : tensor<*xf32>, %arg1 : tensor<index>) {
+  // CHECK-NEXT: shape.const_witness true
+  // CHECK-NEXT: consume.witness
+  // CHECK-NEXT: return
+  %0 = shape.shape_of %arg1 : tensor<index>
+  %1 = shape.shape_of %arg0 : tensor<*xf32>
+  %2 = shape.cstr_broadcastable %0, %1
+  "consume.witness"(%2) : (!shape.witness) -> ()
+  return
+}

More information about the Mlir-commits mailing list