[Mlir-commits] [mlir] 2d0ba5e - [mlir][tosa] Fix tosa.reshape failures due to implicit broadcasting
Rob Suderman
llvmlistbot at llvm.org
Thu Jul 29 15:23:10 PDT 2021
Author: Rob Suderman
Date: 2021-07-29T15:21:57-07:00
New Revision: 2d0ba5e1446f0025603bbe064090737c5510bcf4
URL: https://github.com/llvm/llvm-project/commit/2d0ba5e1446f0025603bbe064090737c5510bcf4
DIFF: https://github.com/llvm/llvm-project/commit/2d0ba5e1446f0025603bbe064090737c5510bcf4.diff
LOG: [mlir][tosa] Fix tosa.reshape failures due to implicit broadcasting
The make-broadcastable pass needs the output shape to determine whether the operation
includes additional broadcasting. Also include some canonicalizations for TOSA
to remove unneeded reshapes.
Reviewed By: NatashaKnk
Differential Revision: https://reviews.llvm.org/D106846
Added:
Modified:
mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
mlir/lib/Dialect/Tosa/Transforms/TosaMakeBroadcastable.cpp
mlir/test/Dialect/Tosa/broadcast.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index caa9d055a535f..f17e13c66a449 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -1414,6 +1414,8 @@ def Tosa_ReshapeOp: Tosa_Op<"reshape", [
No data conversion happens during a reshape operation.
}];
+ let hasCanonicalizer = 1;
+
let arguments = (ins
Tosa_Tensor:$input1,
I64ArrayAttr:$new_shape
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index e44b2457e9250..ea3fe0aec59fd 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -638,7 +638,8 @@ elementwiseMatchAndRewriteHelper(Operation *operation,
if (newShape.size() != rank) {
operand = rewriter.create<tosa::ReshapeOp>(
- loc, RankedTensorType::get(newShape, type.getElementType()), operand);
+ loc, RankedTensorType::get(newShape, type.getElementType()), operand,
+ rewriter.getI64ArrayAttr(newShape));
}
operands.push_back(operand);
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 5ae614f125f46..cfc9220c97501 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -17,6 +17,7 @@
#include "mlir/Dialect/Tosa/Utils/QuantUtils.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Matchers.h"
+#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/FoldUtils.h"
#include "mlir/Transforms/InliningUtils.h"
#include "mlir/Transforms/RegionUtils.h"
@@ -101,6 +102,48 @@ Operation *TosaDialect::materializeConstant(OpBuilder &builder, Attribute value,
return nullptr;
}
+//===----------------------------------------------------------------------===//
+// Operator Canonicalizers.
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// Canonicalization pattern that removes a `tosa.reshape` whose result type is
+/// identical to the type of its operand: every use of the reshape's result is
+/// redirected to the reshape's input and the reshape op is replaced.
+///
+/// NOTE(review): the check is plain type equality. For tensor types with
+/// dynamic (`?`) dimensions, equal types do not by themselves prove the
+/// runtime shapes match — confirm this is the intended no-op criterion.
+struct RemoveReshapeNoop : public OpRewritePattern<tosa::ReshapeOp> {
+  using OpRewritePattern<tosa::ReshapeOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tosa::ReshapeOp op,
+                                PatternRewriter &rewriter) const override {
+    // Only fire when the reshape does not change the tensor type at all.
+    if (op.input1().getType() != op.getType())
+      return failure();
+
+    rewriter.replaceOp(op, op.input1());
+    return success();
+  }
+};
+} // namespace
+
+namespace {
+/// Canonicalization pattern that collapses `reshape(reshape(x))` into a single
+/// `reshape(x)`. The replacement keeps the outer reshape's result type and
+/// `new_shape` attribute; only the operand is rewired to the inner reshape's
+/// input. This pattern does not itself erase the inner reshape.
+struct ReshapeReshapeOptimization : public OpRewritePattern<tosa::ReshapeOp> {
+  using OpRewritePattern<tosa::ReshapeOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tosa::ReshapeOp op,
+                                PatternRewriter &rewriter) const override {
+    // `getDefiningOp<OpTy>()` returns a null op both when the operand is a
+    // block argument and when its producer is not a tosa.reshape.
+    auto producer = op.input1().getDefiningOp<tosa::ReshapeOp>();
+    if (!producer)
+      return failure();
+
+    rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(
+        op, op.getType(), producer.input1(), op.new_shape());
+    return success();
+  }
+};
+} // namespace
+
+/// Registers the tosa.reshape canonicalizations, in this order: folding a
+/// reshape-of-reshape chain, then dropping type-preserving (identity) reshapes.
+void ReshapeOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
+                                            MLIRContext *context) {
+  results.insert<ReshapeReshapeOptimization>(context);
+  results.insert<RemoveReshapeNoop>(context);
+}
+
//===----------------------------------------------------------------------===//
// Operator Folders.
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaMakeBroadcastable.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaMakeBroadcastable.cpp
index 98df91198aca1..4e3706ec992b9 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaMakeBroadcastable.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaMakeBroadcastable.cpp
@@ -143,7 +143,8 @@ static LogicalResult reshapeLowerToHigher(PatternRewriter &rewriter,
SmallVector<int64_t, 4> reshapeOutputShape;
- computeReshapeOutput(higherRankShape, lowerRankShape, reshapeOutputShape);
+ computeReshapeOutput(outputType.getShape(), lowerRankShape,
+ reshapeOutputShape);
auto reshapeInputType = lowerTensorValue.getType().cast<RankedTensorType>();
auto reshapeOutputType = RankedTensorType::get(
diff --git a/mlir/test/Dialect/Tosa/broadcast.mlir b/mlir/test/Dialect/Tosa/broadcast.mlir
index 98d2352b739ad..16fb75450675b 100644
--- a/mlir/test/Dialect/Tosa/broadcast.mlir
+++ b/mlir/test/Dialect/Tosa/broadcast.mlir
@@ -136,6 +136,15 @@ func @test_broadcast18(%arg0: tensor<14x1xf32>, %arg1: tensor<1x15xf32>) -> tens
return %0 : tensor<14x15xf32>
}
+// -----
+// CHECK-LABEL: broadcast19
+// Rank-3 operand (64x64x1) against rank-2 operand (1x17): the result shape
+// 64x64x17 equals neither operand shape, so the reshape of the lower-rank
+// operand must be derived from the op's output shape. The directives below
+// expect a reshape to be emitted ahead of the subtraction.
+func @broadcast19(%arg0: tensor<64x64x1xf32>, %arg1: tensor<1x17xf32>) -> (tensor<64x64x17xf32> ) {
+ // CHECK: reshape
+ // CHECK: sub
+ %0 = "tosa.sub"(%arg0, %arg1) : (tensor<64x64x1xf32>, tensor<1x17xf32>) -> tensor<64x64x17xf32>
+ return %0 : tensor<64x64x17xf32>
+}
+
// -----
// CHECK-LABEL: broadcast_mul
func @test_broadcast_mul(%arg0: tensor<15x14xi32>, %arg1: tensor<17x16x15x14xi32>) -> tensor<17x16x15x14xi32> {
More information about the Mlir-commits
mailing list