[Mlir-commits] [mlir] 2d0ba5e - [mlir][tosa] Fix tosa.reshape failures due to implicit broadcasting

Rob Suderman llvmlistbot at llvm.org
Thu Jul 29 15:23:10 PDT 2021


Author: Rob Suderman
Date: 2021-07-29T15:21:57-07:00
New Revision: 2d0ba5e1446f0025603bbe064090737c5510bcf4

URL: https://github.com/llvm/llvm-project/commit/2d0ba5e1446f0025603bbe064090737c5510bcf4
DIFF: https://github.com/llvm/llvm-project/commit/2d0ba5e1446f0025603bbe064090737c5510bcf4.diff

LOG: [mlir][tosa] Fix tosa.reshape failures due to implicit broadcasting

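The TOSA make-broadcastable pass needs the output shape to determine whether
the operation includes additional broadcasting: previously the reshape target
was computed from the higher-rank input's shape, which fails when that input is
itself broadcast (e.g. 64x64x1 - 1x17 -> 64x64x17). TosaToLinalg now also sets
the new_shape attribute on the reshapes it creates. Also add canonicalizations
for tosa.reshape that remove unneeded reshapes.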

Reviewed By: NatashaKnk

Differential Revision: https://reviews.llvm.org/D106846

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
    mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
    mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
    mlir/lib/Dialect/Tosa/Transforms/TosaMakeBroadcastable.cpp
    mlir/test/Dialect/Tosa/broadcast.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index caa9d055a535f..f17e13c66a449 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -1414,6 +1414,8 @@ def Tosa_ReshapeOp: Tosa_Op<"reshape", [
     No data conversion happens during a reshape operation.
   }];
 
+  let hasCanonicalizer = 1;
+
   let arguments = (ins
     Tosa_Tensor:$input1,
     I64ArrayAttr:$new_shape

diff  --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index e44b2457e9250..ea3fe0aec59fd 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -638,7 +638,8 @@ elementwiseMatchAndRewriteHelper(Operation *operation,
 
     if (newShape.size() != rank) {
       operand = rewriter.create<tosa::ReshapeOp>(
-          loc, RankedTensorType::get(newShape, type.getElementType()), operand);
+          loc, RankedTensorType::get(newShape, type.getElementType()), operand,
+          rewriter.getI64ArrayAttr(newShape));
     }
 
     operands.push_back(operand);

diff  --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 5ae614f125f46..cfc9220c97501 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -17,6 +17,7 @@
 #include "mlir/Dialect/Tosa/Utils/QuantUtils.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Matchers.h"
+#include "mlir/IR/PatternMatch.h"
 #include "mlir/Transforms/FoldUtils.h"
 #include "mlir/Transforms/InliningUtils.h"
 #include "mlir/Transforms/RegionUtils.h"
@@ -101,6 +102,48 @@ Operation *TosaDialect::materializeConstant(OpBuilder &builder, Attribute value,
   return nullptr;
 }
 
+//===----------------------------------------------------------------------===//
+// Operator Canonicalizers.
+//===----------------------------------------------------------------------===//
+
+// Folds reshape(x) -> x if ranked types match and <= 1 dim is dynamic.
+struct RemoveReshapeNoop : public OpRewritePattern<tosa::ReshapeOp> {
+  using OpRewritePattern<tosa::ReshapeOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(tosa::ReshapeOp op,
+                                PatternRewriter &rewriter) const override {
+    auto inputTy = op.input1().getType().dyn_cast<RankedTensorType>();
+    if (!inputTy || inputTy != op.getType() || inputTy.getNumDynamicDims() > 1)
+      return failure(); // With >1 dynamic dim, equal types can hide a change.
+    rewriter.replaceOp(op, op.input1());
+    return success();
+  }
+};
+
+// Collapses reshape(reshape(x)) into a single reshape(x) that keeps the outer
+// op's result type and new_shape. The inner reshape is not erased here, so any
+// other users of it are unaffected.
+struct ReshapeReshapeOptimization : public OpRewritePattern<tosa::ReshapeOp> {
+  using OpRewritePattern<tosa::ReshapeOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tosa::ReshapeOp op,
+                                PatternRewriter &rewriter) const override {
+    // getDefiningOp<T>() yields null both for block arguments and for
+    // producers that are not tosa.reshape; either way there is nothing to do.
+    auto producer = op.input1().getDefiningOp<tosa::ReshapeOp>();
+    if (!producer)
+      return failure();
+    // Bypass the producer: feed its input straight into a fresh reshape.
+    rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(
+        op, op.getType(), producer.input1(), op.new_shape());
+    return success();
+  }
+};
+
+void ReshapeOp::getCanonicalizationPatterns(
+    OwningRewritePatternList &results, MLIRContext *context) {
+  results.add<ReshapeReshapeOptimization, RemoveReshapeNoop>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // Operator Folders.
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/Tosa/Transforms/TosaMakeBroadcastable.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaMakeBroadcastable.cpp
index 98df91198aca1..4e3706ec992b9 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaMakeBroadcastable.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaMakeBroadcastable.cpp
@@ -143,7 +143,8 @@ static LogicalResult reshapeLowerToHigher(PatternRewriter &rewriter,
 
   SmallVector<int64_t, 4> reshapeOutputShape;
 
-  computeReshapeOutput(higherRankShape, lowerRankShape, reshapeOutputShape);
+  computeReshapeOutput(outputType.getShape(), lowerRankShape,
+                       reshapeOutputShape);
 
   auto reshapeInputType = lowerTensorValue.getType().cast<RankedTensorType>();
   auto reshapeOutputType = RankedTensorType::get(

diff  --git a/mlir/test/Dialect/Tosa/broadcast.mlir b/mlir/test/Dialect/Tosa/broadcast.mlir
index 98d2352b739ad..16fb75450675b 100644
--- a/mlir/test/Dialect/Tosa/broadcast.mlir
+++ b/mlir/test/Dialect/Tosa/broadcast.mlir
@@ -136,6 +136,15 @@ func @test_broadcast18(%arg0: tensor<14x1xf32>, %arg1: tensor<1x15xf32>) -> tens
   return %0 : tensor<14x15xf32>
 }
 
+// -----
+// CHECK-LABEL: broadcast19
+func @test_broadcast19(%arg0: tensor<64x64x1xf32>, %arg1: tensor<1x17xf32>) -> tensor<64x64x17xf32> {
+  //  CHECK: %[[VAR0:.*]] = "tosa.reshape"(%arg1) {new_shape = [1, 1, 17]}
+  //  CHECK: "tosa.sub"(%arg0, %[[VAR0]])
+  %0 = "tosa.sub"(%arg0, %arg1) : (tensor<64x64x1xf32>, tensor<1x17xf32>) -> tensor<64x64x17xf32>
+  return %0 : tensor<64x64x17xf32>
+}
+
 // -----
 // CHECK-LABEL: broadcast_mul
 func @test_broadcast_mul(%arg0: tensor<15x14xi32>, %arg1: tensor<17x16x15x14xi32>) -> tensor<17x16x15x14xi32> {


        


More information about the Mlir-commits mailing list