[Mlir-commits] [mlir] 858d488 - [MLIR][Shape] Ensure to preserve op type of `shape.broadcast`
Frederik Gossen
llvmlistbot at llvm.org
Mon Apr 26 08:56:20 PDT 2021
Author: Frederik Gossen
Date: 2021-04-26T17:55:39+02:00
New Revision: 858d4885dcc2242910971607bf46428d5b563d95
URL: https://github.com/llvm/llvm-project/commit/858d4885dcc2242910971607bf46428d5b563d95
DIFF: https://github.com/llvm/llvm-project/commit/858d4885dcc2242910971607bf46428d5b563d95.diff
LOG: [MLIR][Shape] Ensure to preserve op type of `shape.broadcast`
Ensure that the correct type is preserved during folding and canonicalization.
A `shape.broadcast` of a single operand can only be folded away if the operand's
type matches the op's result type; otherwise a cast would be required.
Differential Revision: https://reviews.llvm.org/D101158
Added:
Modified:
mlir/lib/Dialect/Shape/IR/Shape.cpp
mlir/test/Dialect/Shape/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
index 1bec838dfbb2..392bcddb920a 100644
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -470,8 +470,12 @@ void AssumingAllOp::build(OpBuilder &b, OperationState &state,
//===----------------------------------------------------------------------===//
OpFoldResult BroadcastOp::fold(ArrayRef<Attribute> operands) {
- if (shapes().size() == 1)
+ if (shapes().size() == 1) {
+ // Otherwise, we need a cast which would be a canonicalization, not folding.
+ if (shapes().front().getType() != getType())
+ return nullptr;
return shapes().front();
+ }
// TODO: Support folding with more than 2 input shapes
if (shapes().size() > 2)
@@ -556,8 +560,10 @@ struct BroadcastForwardSingleOperandPattern
PatternRewriter &rewriter) const override {
if (op.getNumOperands() == 1) {
Value uniqueShapeOperand = op.shapes().front();
- rewriter.replaceOp(op, uniqueShapeOperand);
- return success();
+ if (uniqueShapeOperand.getType() == op.getType()) {
+ rewriter.replaceOp(op, uniqueShapeOperand);
+ return success();
+ }
}
return failure();
}
diff --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir
index 9c698c493446..95d0de628b75 100644
--- a/mlir/test/Dialect/Shape/canonicalize.mlir
+++ b/mlir/test/Dialect/Shape/canonicalize.mlir
@@ -1217,10 +1217,21 @@ func @broadcast_on_duplicate_shapes(%a : !shape.shape, %b : !shape.shape)
// -----
// CHECK-LABEL: @broadcast_on_single_operand
-// CHECK-SAME: (%[[A:.*]]: tensor<3xindex>)
-func @broadcast_on_single_operand(%a : tensor<3xindex>) {
+// CHECK-SAME: (%[[A:.*]]: tensor<?xindex>)
+func @broadcast_on_single_operand(%a : tensor<?xindex>) {
// CHECK-NOT: broadcast
// CHECK: "use"(%[[A]])
+ %0 = shape.broadcast %a : tensor<?xindex> -> tensor<?xindex>
+ "use"(%0) : (tensor<?xindex>) -> ()
+ return
+}
+
+// -----
+
+// CHECK-LABEL: @broadcast_on_single_operand
+// CHECK-SAME: (%[[A:.*]]: tensor<3xindex>)
+func @broadcast_on_single_operand(%a : tensor<3xindex>) {
+ // CHECK: broadcast %[[A]]
%0 = shape.broadcast %a : tensor<3xindex> -> tensor<?xindex>
"use"(%0) : (tensor<?xindex>) -> ()
return
More information about the Mlir-commits
mailing list