[Mlir-commits] [mlir] b55f424 - [MLIR] Add canonicalization for `shape.broadcast`
Frederik Gossen
llvmlistbot at llvm.org
Mon Mar 15 02:11:48 PDT 2021
Author: Frederik Gossen
Date: 2021-03-15T10:11:28+01:00
New Revision: b55f424ffcaca1d639ded583b9dc8ba151d92e2d
URL: https://github.com/llvm/llvm-project/commit/b55f424ffcaca1d639ded583b9dc8ba151d92e2d
DIFF: https://github.com/llvm/llvm-project/commit/b55f424ffcaca1d639ded583b9dc8ba151d92e2d.diff
LOG: [MLIR] Add canonicalization for `shape.broadcast`
Remove redundant (duplicate) operands and fold the op when only one operand is left.
Differential Revision: https://reviews.llvm.org/D98402
Added:
Modified:
mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
mlir/lib/Dialect/Shape/IR/Shape.cpp
mlir/test/Dialect/Shape/canonicalize.mlir
mlir/test/Dialect/Shape/invalid.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
index ae14d81b9e45..a17be3834dc2 100644
--- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
+++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
@@ -89,6 +89,7 @@ def Shape_BroadcastOp : Shape_Op<"broadcast", [Commutative, NoSideEffect]> {
];
let hasFolder = 1;
+ let hasCanonicalizer = 1;
let verifier = [{ return ::verify(*this); }];
}
@@ -277,10 +278,10 @@ def Shape_IsBroadcastableOp : Shape_Op<"is_broadcastable",
};
}];
+ let hasFolder = 1;
let hasCanonicalizer = 1;
let assemblyFormat = "$shapes attr-dict `:` type($shapes)";
- let verifier = [{ return ::verify(*this); }];
}
def Shape_RankOp : Shape_Op<"rank", [NoSideEffect]> {
diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
index 741d065822db..472197c52f4e 100644
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -354,13 +354,16 @@ static LogicalResult verify(AssumingAllOp op) {
//===----------------------------------------------------------------------===//
OpFoldResult BroadcastOp::fold(ArrayRef<Attribute> operands) {
- if (!operands[1])
- return nullptr;
+ if (operands.size() == 1)
+ return shapes().front();
// TODO: Support folding with more than 2 input shapes
if (shapes().size() > 2)
return nullptr;
+ if (!operands[1])
+ return nullptr;
+
auto rhsShape = llvm::to_vector<6>(
operands[1].cast<DenseIntElementsAttr>().getValues<int64_t>());
if (rhsShape.empty())
@@ -384,13 +387,40 @@ OpFoldResult BroadcastOp::fold(ArrayRef<Attribute> operands) {
}
static LogicalResult verify(BroadcastOp op) {
- // Ensure that AssumingAllOp contains at least one operand
- if (op.getNumOperands() < 2)
- return op.emitOpError("required at least 2 input shapes");
-
return verifyShapeOrExtentTensorOp(op);
}
+namespace {
+template <typename OpTy>
+struct RemoveDuplicateOperandsPattern : public OpRewritePattern<OpTy> {
+ using OpRewritePattern<OpTy>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(OpTy op,
+ PatternRewriter &rewriter) const override {
+ // Find unique operands.
+ SmallVector<Value, 2> unique;
+ for (Value v : op.getOperands()) {
+ if (!llvm::is_contained(unique, v))
+ unique.push_back(v);
+ }
+
+ // Reduce op to equivalent with unique operands.
+ if (unique.size() < op.getNumOperands()) {
+ rewriter.replaceOpWithNewOp<OpTy>(op, op->getResultTypes(), unique,
+ op.getAttrs());
+ return success();
+ }
+
+ return failure();
+ }
+};
+} // namespace
+
+void BroadcastOp::getCanonicalizationPatterns(
+ OwningRewritePatternList &patterns, MLIRContext *context) {
+ patterns.insert<RemoveDuplicateOperandsPattern<BroadcastOp>>(context);
+}
+
//===----------------------------------------------------------------------===//
// ConcatOp
//===----------------------------------------------------------------------===//
@@ -772,49 +802,18 @@ void GetExtentOp::build(OpBuilder &builder, OperationState &result, Value shape,
// IsBroadcastableOp
//===----------------------------------------------------------------------===//
-static LogicalResult verify(IsBroadcastableOp op) {
- // Ensure that AssumingAllOp contains at least one operand
- if (op.getNumOperands() < 2)
- return op.emitOpError("required at least 2 input shapes");
- return success();
+void IsBroadcastableOp::getCanonicalizationPatterns(
+ OwningRewritePatternList &patterns, MLIRContext *context) {
+ patterns.insert<RemoveDuplicateOperandsPattern<IsBroadcastableOp>>(context);
}
-namespace {
-struct IsBroadcastableCanonicalizationPattern
- : public OpRewritePattern<IsBroadcastableOp> {
- using OpRewritePattern<IsBroadcastableOp>::OpRewritePattern;
-
- LogicalResult matchAndRewrite(IsBroadcastableOp op,
- PatternRewriter &rewriter) const override {
- // Find unique operands.
- SmallVector<Value, 2> unique;
- for (Value v : op.getOperands()) {
- if (!llvm::is_contained(unique, v))
- unique.push_back(v);
- }
-
- // Can always broadcast fewer than two shapes.
- if (unique.size() < 2) {
- rewriter.replaceOpWithNewOp<mlir::ConstantOp>(op,
- rewriter.getBoolAttr(true));
- return success();
- }
-
- // Reduce op to equivalent with unique operands.
- if (unique.size() < op.getNumOperands()) {
- rewriter.replaceOpWithNewOp<IsBroadcastableOp>(op, rewriter.getI1Type(),
- unique);
- return success();
- }
-
- return failure();
+OpFoldResult IsBroadcastableOp::fold(ArrayRef<Attribute> operands) {
+ // Can always broadcast fewer than two shapes.
+ if (operands.size() < 2) {
+ return BoolAttr::get(getContext(), true);
}
-};
-} // namespace
-void IsBroadcastableOp::getCanonicalizationPatterns(
- OwningRewritePatternList &patterns, MLIRContext *context) {
- patterns.insert<IsBroadcastableCanonicalizationPattern>(context);
+ return nullptr;
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir
index 5589221eba82..53f27e4839cf 100644
--- a/mlir/test/Dialect/Shape/canonicalize.mlir
+++ b/mlir/test/Dialect/Shape/canonicalize.mlir
@@ -1088,9 +1088,34 @@ func @is_broadcastable_on_same_shape(%shape : !shape.shape) -> i1 {
// CHECK-SAME: (%[[A:.*]]: !shape.shape, %[[B:.*]]: !shape.shape)
func @is_broadcastable_on_duplicate_shapes(%a : !shape.shape, %b : !shape.shape)
-> i1 {
- // CHECK: %[[RES:.*]] = shape.is_broadcastable %[[A]], %[[B]]
+ // CHECK: %[[RES:.*]] = shape.is_broadcastable %[[A]], %[[B]] :
// CHECK: return %[[RES]]
%0 = shape.is_broadcastable %a, %b, %a, %a, %a, %b : !shape.shape,
!shape.shape, !shape.shape, !shape.shape, !shape.shape, !shape.shape
return %0 : i1
}
+
+// -----
+
+// CHECK-LABEL: @broadcast_on_same_shape
+// CHECK-SAME: (%[[SHAPE:.*]]: !shape.shape)
+func @broadcast_on_same_shape(%shape : !shape.shape) -> !shape.shape {
+ // CHECK-NOT: broadcast
+ // CHECK: return %[[SHAPE]]
+ %0 = shape.broadcast %shape, %shape, %shape : !shape.shape, !shape.shape,
+ !shape.shape -> !shape.shape
+ return %0 : !shape.shape
+}
+
+// -----
+
+// CHECK-LABEL: @broadcast_on_duplicate_shapes
+// CHECK-SAME: (%[[A:.*]]: !shape.shape, %[[B:.*]]: !shape.shape)
+func @broadcast_on_duplicate_shapes(%a : !shape.shape, %b : !shape.shape)
+ -> !shape.shape {
+ // CHECK: %[[RES:.*]] = shape.broadcast %[[A]], %[[B]] :
+ // CHECK: return %[[RES]]
+ %0 = shape.broadcast %a, %b, %a, %a, %a, %b : !shape.shape, !shape.shape,
+ !shape.shape, !shape.shape, !shape.shape, !shape.shape -> !shape.shape
+ return %0 : !shape.shape
+}
diff --git a/mlir/test/Dialect/Shape/invalid.mlir b/mlir/test/Dialect/Shape/invalid.mlir
index d685e6766072..d42f0fac4b5c 100644
--- a/mlir/test/Dialect/Shape/invalid.mlir
+++ b/mlir/test/Dialect/Shape/invalid.mlir
@@ -249,18 +249,8 @@ module attributes {shape.lib = @fn} { }
// -----
-func @fn(%arg: !shape.shape) -> i1 {
- // expected-error @+1 {{required at least 2 input shapes}}
- %0 = shape.is_broadcastable %arg : !shape.shape
- return %0 : i1
-}
-
-// -----
-
func @fn(%arg: !shape.shape) -> !shape.witness {
// expected-error @+1 {{required at least 2 input shapes}}
%0 = shape.cstr_broadcastable %arg : !shape.shape
return %0 : !shape.witness
}
-
-
More information about the Mlir-commits
mailing list