[Mlir-commits] [mlir] 1ce70c1 - [MLIR] Canonicalize broadcast operations on single shapes
Frederik Gossen
llvmlistbot at llvm.org
Thu Mar 18 01:00:19 PDT 2021
Author: Frederik Gossen
Date: 2021-03-18T08:59:50+01:00
New Revision: 1ce70c15ed3b9c84d6d73abd74f6605bccdf2e7b
URL: https://github.com/llvm/llvm-project/commit/1ce70c15ed3b9c84d6d73abd74f6605bccdf2e7b
DIFF: https://github.com/llvm/llvm-project/commit/1ce70c15ed3b9c84d6d73abd74f6605bccdf2e7b.diff
LOG: [MLIR] Canonicalize broadcast operations on single shapes
This covers cases that are not folded away because the extent tensor type
becomes more concrete in the process.
Differential Revision: https://reviews.llvm.org/D98782
Added:
Modified:
mlir/lib/Dialect/Shape/IR/Shape.cpp
mlir/test/Dialect/Shape/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
index ed8dcfc13549..33719951f3e9 100644
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -414,11 +414,59 @@ struct RemoveDuplicateOperandsPattern : public OpRewritePattern<OpTy> {
return failure();
}
};
+
+/// Replaces a `shape.broadcast` that has exactly one operand with that
+/// operand.
+///
+/// The operand type may be more concrete than the result type (e.g.
+/// `tensor<3xindex>` vs. `tensor<?xindex>`, or an extent tensor vs.
+/// `!shape.shape`), which is why the folder does not handle these cases.
+/// The forwarded value is cast back to the op's result type so that users
+/// requiring an exact type (e.g. `return`) stay valid.
+struct BroadcastForwardSingleOperandPattern
+    : public OpRewritePattern<BroadcastOp> {
+  using OpRewritePattern<BroadcastOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(BroadcastOp op,
+                                PatternRewriter &rewriter) const override {
+    if (op.getNumOperands() != 1)
+      return failure();
+
+    Value replacement = op.shapes().front();
+    Type resultTy = op.getType();
+    Type operandTy = replacement.getType();
+
+    // Replacing a value with one of a different type would invalidate
+    // users that check types, so insert an explicit cast when they differ.
+    if (operandTy != resultTy) {
+      Location loc = op.getLoc();
+      if (resultTy.isa<ShapeType>()) {
+        // Extent tensor operand, `!shape.shape` result.
+        replacement =
+            rewriter.create<FromExtentTensorOp>(loc, resultTy, replacement);
+      } else if (operandTy.isa<ShapeType>()) {
+        // `!shape.shape` operand, extent tensor result: no cast is inserted
+        // for this direction (presumably the op verifier rejects it anyway),
+        // so leave the op unchanged.
+        return failure();
+      } else {
+        // Two extent tensor types; assumes their static extents are
+        // cast-compatible (e.g. `tensor<3xindex>` to `tensor<?xindex>`).
+        replacement =
+            rewriter.create<tensor::CastOp>(loc, resultTy, replacement);
+      }
+    }
+
+    rewriter.replaceOp(op, replacement);
+    return success();
+  }
+};
} // namespace
void BroadcastOp::getCanonicalizationPatterns(
OwningRewritePatternList &patterns, MLIRContext *context) {
- patterns.insert<RemoveDuplicateOperandsPattern<BroadcastOp>>(context);
+  patterns.insert<BroadcastForwardSingleOperandPattern,
+                  RemoveDuplicateOperandsPattern<BroadcastOp>>(context);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir
index 53f27e4839cf..3399fe0f4e23 100644
--- a/mlir/test/Dialect/Shape/canonicalize.mlir
+++ b/mlir/test/Dialect/Shape/canonicalize.mlir
@@ -1119,3 +1119,29 @@ func @broadcast_on_duplicate_shapes(%a : !shape.shape, %b : !shape.shape)
!shape.shape, !shape.shape, !shape.shape, !shape.shape -> !shape.shape
return %0 : !shape.shape
}
+
+// -----
+
+// CHECK-LABEL: @broadcast_on_single_operand
+// CHECK-SAME: (%[[A:.*]]: tensor<3xindex>)
+func @broadcast_on_single_operand(%a : tensor<3xindex>) {
+  // CHECK-NOT: broadcast
+  // CHECK: %[[CAST:.*]] = tensor.cast %[[A]] : tensor<3xindex> to tensor<?xindex>
+  // CHECK: "use"(%[[CAST]])
+  %0 = shape.broadcast %a : tensor<3xindex> -> tensor<?xindex>
+  "use"(%0) : (tensor<?xindex>) -> ()
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @broadcast_on_single_extent_tensor_to_shape
+// CHECK-SAME: (%[[A:.*]]: tensor<?xindex>)
+func @broadcast_on_single_extent_tensor_to_shape(%a : tensor<?xindex>)
+    -> !shape.shape {
+  // CHECK-NOT: broadcast
+  // CHECK: %[[RES:.*]] = shape.from_extent_tensor %[[A]] : tensor<?xindex>
+  // CHECK: return %[[RES]] : !shape.shape
+  %0 = shape.broadcast %a : tensor<?xindex> -> !shape.shape
+  return %0 : !shape.shape
+}
More information about the Mlir-commits mailing list