[Mlir-commits] [mlir] f0c51cb - [MLIR][Shape] Add canonicalizations for `shape.broadcast`

Frederik Gossen llvmlistbot at llvm.org
Thu Apr 22 05:11:46 PDT 2021


Author: Frederik Gossen
Date: 2021-04-22T14:11:23+02:00
New Revision: f0c51cb2d4562fb97280326202eab23ff7b29e3f

URL: https://github.com/llvm/llvm-project/commit/f0c51cb2d4562fb97280326202eab23ff7b29e3f
DIFF: https://github.com/llvm/llvm-project/commit/f0c51cb2d4562fb97280326202eab23ff7b29e3f.diff

LOG: [MLIR][Shape] Add canonicalizations for `shape.broadcast`

Eliminate empty shapes from the operands, partially fold all constant shape
operands, and fix the regular folder (`BroadcastOp::fold`).

Differential Revision: https://reviews.llvm.org/D100634

Added: 
    

Modified: 
    mlir/lib/Dialect/Shape/IR/Shape.cpp
    mlir/test/Dialect/Shape/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
index 1455271071b05..1bec838dfbb2c 100644
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -470,34 +470,33 @@ void AssumingAllOp::build(OpBuilder &b, OperationState &state,
 //===----------------------------------------------------------------------===//
 
 OpFoldResult BroadcastOp::fold(ArrayRef<Attribute> operands) {
-  if (operands.size() == 1)
+  if (shapes().size() == 1) {
+    // Forwarding the operand is only a valid fold if no cast is needed.
+    // Otherwise, leave the op to the canonicalization patterns.
+    if (shapes().front().getType() != getType())
+      return nullptr;
     return shapes().front();
+  }
 
   // TODO: Support folding with more than 2 input shapes
   if (shapes().size() > 2)
     return nullptr;
 
-  if (!operands[1])
-    return nullptr;
-
-  auto rhsShape = llvm::to_vector<6>(
-      operands[1].cast<DenseIntElementsAttr>().getValues<int64_t>());
-  if (rhsShape.empty())
-    return shapes()[0];
-
-  if (!operands[0])
+  // Only two constant operands can be folded. The size check also guards
+  // against indexing into the operands of an op without operands.
+  if (shapes().size() != 2 || !operands[0] || !operands[1])
     return nullptr;
-
   auto lhsShape = llvm::to_vector<6>(
       operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
-  if (lhsShape.empty())
-    return shapes()[1];
-
+  auto rhsShape = llvm::to_vector<6>(
+      operands[1].cast<DenseIntElementsAttr>().getValues<int64_t>());
   SmallVector<int64_t, 6> resultShape;
+
   // If the shapes are not compatible, we can't fold it.
   // TODO: Fold to an "error".
   if (!OpTrait::util::getBroadcastedShape(lhsShape, rhsShape, resultShape))
     return nullptr;
+
   Builder builder(getContext());
   return builder.getIndexTensorAttr(resultShape);
 }
@@ -531,6 +530,47 @@ struct RemoveDuplicateOperandsPattern : public OpRewritePattern<OpTy> {
   }
 };
 
+/// Removes operands that are statically known to be empty shapes. Assumes the
+/// empty shape is the neutral element of `OpTy`, as it is for broadcasting.
+template <typename OpTy>
+struct RemoveEmptyShapeOperandsPattern : public OpRewritePattern<OpTy> {
+  using OpRewritePattern<OpTy>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(OpTy op,
+                                PatternRewriter &rewriter) const override {
+    // A shape is known to be empty if it is an extent tensor of static size 0
+    // or is defined by an empty `shape.const_shape`.
+    auto isPotentiallyNonEmptyShape = [](Value shape) {
+      if (auto extentTensorTy = shape.getType().dyn_cast<RankedTensorType>()) {
+        if (extentTensorTy.getDimSize(0) == 0)
+          return false;
+      }
+      if (auto constShape = shape.getDefiningOp<ConstShapeOp>())
+        return constShape.shape().size() != 0;
+      return true;
+    };
+    auto newOperands = llvm::to_vector<8>(
+        llvm::make_filter_range(op->getOperands(), isPotentiallyNonEmptyShape));
+
+    // If all operands are empty, so is the result. Replace the op with a
+    // constant instead of building an equivalent op without operands.
+    if (newOperands.empty()) {
+      rewriter.replaceOpWithNewOp<ConstShapeOp>(
+          op, op->getResultTypes().front(), rewriter.getIndexTensorAttr({}));
+      return success();
+    }
+
+    // Reduce op to equivalent without empty shape operands.
+    if (newOperands.size() < op.getNumOperands()) {
+      rewriter.replaceOpWithNewOp<OpTy>(op, op->getResultTypes(), newOperands,
+                                        op->getAttrs());
+      return success();
+    }
+
+    return failure();
+  }
+};
+
 struct BroadcastForwardSingleOperandPattern
     : public OpRewritePattern<BroadcastOp> {
   using OpRewritePattern<BroadcastOp>::OpRewritePattern;
@@ -538,18 +578,72 @@ struct BroadcastForwardSingleOperandPattern
   LogicalResult matchAndRewrite(BroadcastOp op,
                                 PatternRewriter &rewriter) const override {
     if (op.getNumOperands() == 1) {
-      rewriter.replaceOp(op, op.shapes().front());
+      Value uniqueShapeOperand = op.shapes().front();
+
+      // The operand may be typed differently than the result, e.g. an extent
+      // tensor feeding a `!shape.shape` result. Cast it where possible and
+      // otherwise keep the op rather than producing ill-typed IR.
+      if (uniqueShapeOperand.getType() != op.getType()) {
+        if (!op.getType().isa<ShapeType>())
+          return failure();
+        uniqueShapeOperand = rewriter.create<FromExtentTensorOp>(
+            op.getLoc(), op.getType(), uniqueShapeOperand);
+      }
+
+      rewriter.replaceOp(op, uniqueShapeOperand);
       return success();
     }
     return failure();
   }
 };
+
+/// Broadcasts all constant shape operands into a single constant operand.
+/// Constant operands incompatible with the accumulated shape are kept as is.
+struct BroadcastFoldConstantOperandsPattern
+    : public OpRewritePattern<BroadcastOp> {
+  using OpRewritePattern<BroadcastOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(BroadcastOp op,
+                                PatternRewriter &rewriter) const override {
+    SmallVector<int64_t, 8> foldedConstantShape;
+    SmallVector<Value, 8> newShapeOperands;
+    for (Value shape : op.shapes()) {
+      if (auto constShape = shape.getDefiningOp<ConstShapeOp>()) {
+        SmallVector<int64_t, 8> newFoldedConstantShape;
+        if (OpTrait::util::getBroadcastedShape(
+                foldedConstantShape,
+                llvm::to_vector<8>(constShape.shape().getValues<int64_t>()),
+                newFoldedConstantShape)) {
+          foldedConstantShape = std::move(newFoldedConstantShape);
+          continue;
+        }
+      }
+      newShapeOperands.push_back(shape);
+    }
+
+    // Need at least two constant operands to fold anything.
+    if (op.getNumOperands() - newShapeOperands.size() < 2)
+      return failure();
+
+    auto foldedConstantOperandsTy = RankedTensorType::get(
+        {static_cast<int64_t>(foldedConstantShape.size())},
+        rewriter.getIndexType());
+    newShapeOperands.push_back(rewriter.create<ConstShapeOp>(
+        op.getLoc(), foldedConstantOperandsTy,
+        rewriter.getIndexTensorAttr(foldedConstantShape)));
+    rewriter.replaceOpWithNewOp<BroadcastOp>(op, op.getType(),
+                                             newShapeOperands);
+    return success();
+  }
+};
 } // namespace
 
 void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                               MLIRContext *context) {
-  patterns.add<BroadcastForwardSingleOperandPattern,
-               RemoveDuplicateOperandsPattern<BroadcastOp>>(context);
+  patterns.add<BroadcastFoldConstantOperandsPattern,
+               BroadcastForwardSingleOperandPattern,
+               RemoveDuplicateOperandsPattern<BroadcastOp>,
+               RemoveEmptyShapeOperandsPattern<BroadcastOp>>(context);
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir
index 7a5170ec3504a..9c698c4934463 100644
--- a/mlir/test/Dialect/Shape/canonicalize.mlir
+++ b/mlir/test/Dialect/Shape/canonicalize.mlir
@@ -120,6 +120,36 @@ func @f() -> !shape.shape {
 
 // -----
 
+// All operands but one are known empty shapes; the remaining one is forwarded.
+// CHECK-LABEL: @all_but_one_empty
+// CHECK-SAME:  (%[[ARG:.*]]: !shape.shape)
+func @all_but_one_empty(%arg0 : !shape.shape) -> !shape.shape {
+  // CHECK: return %[[ARG]]
+  %0 = shape.const_shape [] : !shape.shape
+  %1 = shape.const_shape [] : tensor<0xindex>
+  %2 = shape.broadcast %0, %arg0, %1, %0 : !shape.shape, !shape.shape,
+      tensor<0xindex>, !shape.shape -> !shape.shape
+  return %2 : !shape.shape
+}
+
+// -----
+
+// Partial folding: all constant operands are broadcast into a single constant.
+// CHECK-LABEL: @partial_folding
+// CHECK-SAME:  (%[[ARG:.*]]: !shape.shape)
+func @partial_folding(%arg0 : !shape.shape) -> !shape.shape {
+  // CHECK: %[[CST_SHAPE:.*]] = constant dense<[1, 2, 3]> : tensor<3xindex>
+  // CHECK: %[[RESULT:.*]] = shape.broadcast %[[ARG]], %[[CST_SHAPE]] : !shape.shape, tensor<3xindex> -> !shape.shape
+  // CHECK: return %[[RESULT]]
+  %0 = shape.const_shape [2, 1] : !shape.shape
+  %1 = shape.const_shape [1, 2, 3] : tensor<3xindex>
+  %2 = shape.broadcast %0, %arg0, %1, %0 : !shape.shape, !shape.shape,
+      tensor<3xindex>, !shape.shape -> !shape.shape
+  return %2 : !shape.shape
+}
+
+// -----
+
 // Incompatible shapes. No folding.
 // CHECK-LABEL: func @f
 func @f() -> !shape.shape {


        


More information about the Mlir-commits mailing list