[Mlir-commits] [mlir] f0c51cb - [MLIR][Shape] Add canonicalizations for `shape.broadcast`
Frederik Gossen
llvmlistbot at llvm.org
Thu Apr 22 05:11:46 PDT 2021
Author: Frederik Gossen
Date: 2021-04-22T14:11:23+02:00
New Revision: f0c51cb2d4562fb97280326202eab23ff7b29e3f
URL: https://github.com/llvm/llvm-project/commit/f0c51cb2d4562fb97280326202eab23ff7b29e3f
DIFF: https://github.com/llvm/llvm-project/commit/f0c51cb2d4562fb97280326202eab23ff7b29e3f.diff
LOG: [MLIR][Shape] Add canonicalizations for `shape.broadcast`
Eliminate empty shapes from the operands, partially fold all constant shape
operands, and fix normal folding.
Differential Revision: https://reviews.llvm.org/D100634
Added:
Modified:
mlir/lib/Dialect/Shape/IR/Shape.cpp
mlir/test/Dialect/Shape/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
index 1455271071b05..1bec838dfbb2c 100644
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -470,34 +470,26 @@ void AssumingAllOp::build(OpBuilder &b, OperationState &state,
//===----------------------------------------------------------------------===//
OpFoldResult BroadcastOp::fold(ArrayRef<Attribute> operands) {
- if (operands.size() == 1)
+ if (shapes().size() == 1)
return shapes().front();
// TODO: Support folding with more than 2 input shapes
if (shapes().size() > 2)
return nullptr;
- if (!operands[1])
- return nullptr;
-
- auto rhsShape = llvm::to_vector<6>(
- operands[1].cast<DenseIntElementsAttr>().getValues<int64_t>());
- if (rhsShape.empty())
- return shapes()[0];
-
- if (!operands[0])
+ if (!operands[0] || !operands[1])
return nullptr;
-
auto lhsShape = llvm::to_vector<6>(
operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
- if (lhsShape.empty())
- return shapes()[1];
-
+ auto rhsShape = llvm::to_vector<6>(
+ operands[1].cast<DenseIntElementsAttr>().getValues<int64_t>());
SmallVector<int64_t, 6> resultShape;
+
// If the shapes are not compatible, we can't fold it.
// TODO: Fold to an "error".
if (!OpTrait::util::getBroadcastedShape(lhsShape, rhsShape, resultShape))
return nullptr;
+
Builder builder(getContext());
return builder.getIndexTensorAttr(resultShape);
}
@@ -531,6 +523,31 @@ struct RemoveDuplicateOperandsPattern : public OpRewritePattern<OpTy> {
}
};
+template <typename OpTy>
+struct RemoveEmptyShapeOperandsPattern : public OpRewritePattern<OpTy> {
+ using OpRewritePattern<OpTy>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(OpTy op,
+ PatternRewriter &rewriter) const override {
+ auto isPotentiallyNonEmptyShape = [](Value shape) {
+ if (auto constShape = shape.getDefiningOp<ConstShapeOp>())
+ return constShape.shape().size() != 0;
+ return true;
+ };
+ auto newOperands = llvm::to_vector<8>(
+ llvm::make_filter_range(op->getOperands(), isPotentiallyNonEmptyShape));
+
+ // Reduce op to equivalent without empty shape operands.
+ if (newOperands.size() < op.getNumOperands()) {
+ rewriter.replaceOpWithNewOp<OpTy>(op, op->getResultTypes(), newOperands,
+ op->getAttrs());
+ return success();
+ }
+
+ return failure();
+ }
+};
+
struct BroadcastForwardSingleOperandPattern
: public OpRewritePattern<BroadcastOp> {
using OpRewritePattern<BroadcastOp>::OpRewritePattern;
@@ -538,18 +555,59 @@ struct BroadcastForwardSingleOperandPattern
LogicalResult matchAndRewrite(BroadcastOp op,
PatternRewriter &rewriter) const override {
if (op.getNumOperands() == 1) {
- rewriter.replaceOp(op, op.shapes().front());
+ Value uniqueShapeOperand = op.shapes().front();
+ rewriter.replaceOp(op, uniqueShapeOperand);
return success();
}
return failure();
}
};
+
+struct BroadcastFoldConstantOperandsPattern
+ : public OpRewritePattern<BroadcastOp> {
+ using OpRewritePattern<BroadcastOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(BroadcastOp op,
+ PatternRewriter &rewriter) const override {
+ SmallVector<int64_t, 8> foldedConstantShape;
+ SmallVector<Value, 8> newShapeOperands;
+ for (Value shape : op.shapes()) {
+ if (auto constShape = shape.getDefiningOp<ConstShapeOp>()) {
+ SmallVector<int64_t, 8> newFoldedConstantShape;
+ if (OpTrait::util::getBroadcastedShape(
+ foldedConstantShape,
+ llvm::to_vector<8>(constShape.shape().getValues<int64_t>()),
+ newFoldedConstantShape)) {
+ foldedConstantShape = newFoldedConstantShape;
+ continue;
+ }
+ }
+ newShapeOperands.push_back(shape);
+ }
+
+ // Need at least two constant operands to fold anything.
+ if (op.getNumOperands() - newShapeOperands.size() < 2)
+ return failure();
+
+ auto foldedConstantOperandsTy = RankedTensorType::get(
+ {static_cast<int64_t>(foldedConstantShape.size())},
+ rewriter.getIndexType());
+ newShapeOperands.push_back(rewriter.create<ConstShapeOp>(
+ op.getLoc(), foldedConstantOperandsTy,
+ rewriter.getIndexTensorAttr(foldedConstantShape)));
+ rewriter.replaceOpWithNewOp<BroadcastOp>(op, op.getType(),
+ newShapeOperands);
+ return success();
+ }
+};
} // namespace
void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
MLIRContext *context) {
- patterns.add<BroadcastForwardSingleOperandPattern,
- RemoveDuplicateOperandsPattern<BroadcastOp>>(context);
+ patterns.add<BroadcastFoldConstantOperandsPattern,
+ BroadcastForwardSingleOperandPattern,
+ RemoveDuplicateOperandsPattern<BroadcastOp>,
+ RemoveEmptyShapeOperandsPattern<BroadcastOp>>(context);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir
index 7a5170ec3504a..9c698c4934463 100644
--- a/mlir/test/Dialect/Shape/canonicalize.mlir
+++ b/mlir/test/Dialect/Shape/canonicalize.mlir
@@ -120,6 +120,36 @@ func @f() -> !shape.shape {
// -----
+// All but one operands are known empty shapes.
+// CHECK-LABEL: @all_but_one_empty
+// CHECK-SAME: (%[[ARG:.*]]: !shape.shape)
+func @all_but_one_empty(%arg0 : !shape.shape) -> !shape.shape {
+ // CHECK: return %[[ARG]]
+ %0 = shape.const_shape [] : !shape.shape
+ %1 = shape.const_shape [] : tensor<0xindex>
+ %2 = shape.broadcast %0, %arg0, %1, %0 : !shape.shape, !shape.shape,
+ tensor<0xindex>, !shape.shape -> !shape.shape
+ return %2 : !shape.shape
+}
+
+// -----
+
+// Partial folding.
+// CHECK-LABEL: @partial_folding
+// CHECK-SAME: (%[[ARG:.*]]: !shape.shape)
+func @partial_folding(%arg0 : !shape.shape) -> !shape.shape {
+ // CHECK: %[[CST_SHAPE:.*]] = constant dense<[1, 2, 3]> : tensor<3xindex>
+ // CHECK: %[[RESULT:.*]] = shape.broadcast %[[ARG]], %[[CST_SHAPE]] : !shape.shape, tensor<3xindex> -> !shape.shape
+ // CHECK: return %[[RESULT]]
+ %0 = shape.const_shape [2, 1] : !shape.shape
+ %1 = shape.const_shape [1, 2, 3] : tensor<3xindex>
+ %2 = shape.broadcast %0, %arg0, %1, %0 : !shape.shape, !shape.shape,
+ tensor<3xindex>, !shape.shape -> !shape.shape
+ return %2 : !shape.shape
+}
+
+// -----
+
// Incompatible shapes. No folding.
// CHECK-LABEL: func @f
func @f() -> !shape.shape {
More information about the Mlir-commits mailing list