[Mlir-commits] [mlir] 64f694a - [mlir][tosa] Move tosa canonicalizers to optional optimization pass

Rob Suderman llvmlistbot at llvm.org
Thu Dec 16 23:36:20 PST 2021


Author: Aaron DeBattista
Date: 2021-12-16T23:33:54-08:00
New Revision: 64f694acaf9279c0902f2f150b48191fb91057fb

URL: https://github.com/llvm/llvm-project/commit/64f694acaf9279c0902f2f150b48191fb91057fb
DIFF: https://github.com/llvm/llvm-project/commit/64f694acaf9279c0902f2f150b48191fb91057fb.diff

LOG: [mlir][tosa] Move tosa canonicalizers to optional optimization pass

TOSA's canonicalizers that change dense operations should be moved to a
separate optimization pass to avoid canonicalizing to operations not supported
by the relevant backends.

Reviewed By: rsuderman

Differential Revision: https://reviews.llvm.org/D115890

Added: 
    mlir/lib/Dialect/Tosa/Transforms/TosaOptimization.cpp
    mlir/test/Dialect/Tosa/operation_optimization.mlir

Modified: 
    mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
    mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h
    mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
    mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
    mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt
    mlir/test/Dialect/Tosa/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index 982880e027e04..597813481d684 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -118,8 +118,6 @@ def Tosa_Conv2DOp : Tosa_Op<"conv2d", [
   let builders = [Tosa_ConvOpQuantInfoBuilder];
 
   let verifier = [{ return verifyConvOp(*this); }];
-
-  let hasCanonicalizer = 1;
 }
 
 //===----------------------------------------------------------------------===//
@@ -187,8 +185,6 @@ def Tosa_DepthwiseConv2DOp : Tosa_Op<"depthwise_conv2d", [
   let builders = [Tosa_ConvOpQuantInfoBuilder];
 
   let verifier = [{ return verifyConvOp(*this); }];
-
-  let hasCanonicalizer = 1;
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h
index 278402eb93b01..e94daca0df728 100644
--- a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h
@@ -22,6 +22,7 @@ namespace tosa {
 std::unique_ptr<Pass> createTosaDecomposeTransposeConvPass();
 std::unique_ptr<Pass> createTosaInferShapesPass();
 std::unique_ptr<Pass> createTosaMakeBroadcastablePass();
+std::unique_ptr<Pass> createTosaOptimizationPass();
 std::unique_ptr<Pass> createTosaTestQuantUtilAPIPass();
 
 #define GEN_PASS_REGISTRATION

diff  --git a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
index 7d6af621675b8..4a75482ba832a 100644
--- a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
@@ -1,4 +1,4 @@
-//===-- Passes.td - TOSA optimization pass declarations ----*- tablegen -*-===//
+//===-- Passes.td - TOSA pass declarations ----*- tablegen -*-===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -6,7 +6,7 @@
 //
 //===----------------------------------------------------------------------===//
 //
-// This file declares the optimization passes for the TOSA Dialect in MLIR.
+// This file declares the passes for the TOSA Dialect in MLIR.
 //
 //===----------------------------------------------------------------------===//
 
@@ -58,4 +58,13 @@ def TosaMakeBroadcastable : FunctionPass<"tosa-make-broadcastable"> {
   let constructor = "createTosaMakeBroadcastablePass()";
 }
 
+def TosaOptimization : FunctionPass<"tosa-optimization"> {
+  let summary = "TOSA operation optimizations";
+  let description = [{
+    "Pass to perform optimizations on TOSA operations"
+  }];
+
+  let constructor = "createTosaOptimizationPass()";
+}
+
 #endif // MLIR_DIALECT_TOSA_TRANSFORMS_PASSES

diff  --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 9809e57e3a9a0..f61ce68893d69 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -423,197 +423,6 @@ void PadOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
   results.insert<MaterializePadValue>(context);
 }
 
-struct Conv2DIsFullyConnected : public OpRewritePattern<tosa::Conv2DOp> {
-  using OpRewritePattern::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(tosa::Conv2DOp op,
-                                PatternRewriter &rewriter) const override {
-    Value input = op.input();
-    Value weight = op.weight();
-    ShapedType inputType = input.getType().cast<ShapedType>();
-    ShapedType weightType = weight.getType().cast<ShapedType>();
-    ShapedType resultType = op.getType().cast<ShapedType>();
-
-    if (!inputType.hasStaticShape() || !weightType.hasRank()) {
-      return failure();
-    }
-
-    for (Attribute pad : op.pad().getValue()) {
-      if (!pad.cast<IntegerAttr>().getValue().isZero()) {
-        return failure();
-      }
-    }
-
-    // Stride must be 1 for this optimization.
-    for (Attribute stride : op.stride().getValue()) {
-      if (!stride.cast<IntegerAttr>().getValue().isOne()) {
-        return failure();
-      }
-    }
-
-    // Only works for a 1x1 kernel.
-    ArrayRef<int64_t> weightShape = weightType.getShape();
-    if (weightShape[1] != 1 || weightShape[2] != 1) {
-      return failure();
-    }
-
-    // Reshape input to [N,IH,IW,IC] -> [N * IH * IW, IC].
-    ArrayRef<int64_t> inputShape = inputType.getShape();
-    llvm::SmallVector<int64_t, 2> revisedInputShape{
-        inputShape[0] * inputShape[1] * inputShape[2], inputShape[3]};
-    auto revisedInputShapeType =
-        RankedTensorType::get(revisedInputShape, inputType.getElementType());
-    auto reshapedInput = rewriter
-                             .create<tosa::ReshapeOp>(
-                                 op.getLoc(), revisedInputShapeType, input,
-                                 rewriter.getI64ArrayAttr(revisedInputShape))
-                             .getResult();
-
-    // Reshape kernel to [OC,KH,KW,IC] -> [OC, IC].
-    llvm::SmallVector<int64_t, 2> revisedWeightShape{weightShape[0],
-                                                     weightShape[3]};
-    auto revisedWeightShapeType =
-        RankedTensorType::get(revisedWeightShape, weightType.getElementType());
-    auto reshapedWeight = rewriter
-                              .create<tosa::ReshapeOp>(
-                                  op.getLoc(), revisedWeightShapeType, weight,
-                                  rewriter.getI64ArrayAttr(revisedWeightShape))
-                              .getResult();
-
-    // Perform a fully connected network over the reshaped input and weight.
-    llvm::SmallVector<int64_t, 2> fullyConnectedShape{
-        inputShape[0] * inputShape[1] * inputShape[2], weightShape[0]};
-    auto fullyConnectedShapeType =
-        RankedTensorType::get(fullyConnectedShape, resultType.getElementType());
-
-    Value fullyConnectedValue;
-    if (op.quantization_info()) {
-      fullyConnectedValue =
-          rewriter
-              .create<tosa::FullyConnectedOp>(
-                  op.getLoc(), fullyConnectedShapeType, reshapedInput,
-                  reshapedWeight, op.bias(), op.quantization_info().getValue())
-              .getResult();
-    } else {
-      fullyConnectedValue = rewriter
-                                .create<tosa::FullyConnectedOp>(
-                                    op.getLoc(), fullyConnectedShapeType,
-                                    reshapedInput, reshapedWeight, op.bias())
-                                .getResult();
-    }
-
-    // Reshape output to [N, IH, IW, OC].
-    llvm::SmallVector<int64_t, 4> outputShape{inputShape[0], inputShape[1],
-                                              inputShape[2], weightShape[0]};
-    rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(
-        op, resultType, fullyConnectedValue,
-        rewriter.getI64ArrayAttr(outputShape));
-    return success();
-  }
-};
-
-void Conv2DOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
-                                           MLIRContext *context) {
-  results.insert<Conv2DIsFullyConnected>(context);
-}
-
-struct DepthwiseConv2DMulOptimization
-    : public OpRewritePattern<tosa::DepthwiseConv2DOp> {
-  using OpRewritePattern::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(tosa::DepthwiseConv2DOp op,
-                                PatternRewriter &rewriter) const override {
-    Value input = op.input();
-    Value weight = op.weight();
-    ShapedType inputType = input.getType().cast<ShapedType>();
-    ShapedType weightType = weight.getType().cast<ShapedType>();
-    ShapedType resultType = op.output().getType().cast<ShapedType>();
-    Type inputEType = inputType.getElementType();
-
-    if (!(inputType.hasStaticShape() && weightType.hasStaticShape() &&
-          resultType.hasStaticShape())) {
-      return failure();
-    }
-
-    // Quantization information needs to still be performed.
-    if (op.quantization_info() || !inputEType.isa<FloatType>()) {
-      return failure();
-    }
-
-    // Stride must be 1 for this optimization.
-    for (Attribute stride : op.stride().getValue()) {
-      if (!stride.cast<IntegerAttr>().getValue().isOne()) {
-        return failure();
-      }
-    }
-
-    // Only works for a 1x1 kernel.
-    ArrayRef<int64_t> weightShape = weightType.getShape();
-    if (weightShape[0] != 1 || weightShape[1] != 1) {
-      return failure();
-    }
-
-    // Reshape input to [N, H, W, C] -> [N, H, W, C, 1].
-    ArrayRef<int64_t> inputShape = inputType.getShape();
-    llvm::SmallVector<int64_t, 2> revisedInputShape{
-        inputShape[0], inputShape[1], inputShape[2], inputShape[3], 1};
-    auto revisedInputShapeType = RankedTensorType::get(
-        revisedInputShape,
-        input.getType().dyn_cast<RankedTensorType>().getElementType());
-    auto reshapedInput = rewriter
-                             .create<tosa::ReshapeOp>(
-                                 op.getLoc(), revisedInputShapeType, input,
-                                 rewriter.getI64ArrayAttr(revisedInputShape))
-                             .getResult();
-
-    // Reshape kernel to [KH, KW, C, M] -> [1, 1, 1, C, M].
-    llvm::SmallVector<int64_t, 2> revisedWeightShape{1, 1, 1, weightShape[2],
-                                                     weightShape[3]};
-    auto revisedWeightShapeType = RankedTensorType::get(
-        revisedWeightShape,
-        weight.getType().dyn_cast<RankedTensorType>().getElementType());
-    auto reshapedWeight = rewriter
-                              .create<tosa::ReshapeOp>(
-                                  op.getLoc(), revisedWeightShapeType, weight,
-                                  rewriter.getI64ArrayAttr(revisedWeightShape))
-                              .getResult();
-
-    // Perform an elementwise mul over the reshaped input and weight.
-    llvm::SmallVector<int64_t, 2> mulShape{inputShape[0], inputShape[1],
-                                           inputShape[2], inputShape[3],
-                                           weightShape[3]};
-    auto mulShapeType = RankedTensorType::get(
-        mulShape,
-        weight.getType().dyn_cast<RankedTensorType>().getElementType());
-    Value mulValue =
-        rewriter
-            .create<tosa::MulOp>(op.getLoc(), mulShapeType, reshapedInput,
-                                 reshapedWeight, /*shift=*/0)
-            .getResult();
-
-    // Reshape output to [N, H, W, C * M].
-    auto outputShape = op.output().getType().cast<ShapedType>().getShape();
-    auto outputShapeType = RankedTensorType::get(
-        outputShape,
-        input.getType().dyn_cast<RankedTensorType>().getElementType());
-    auto outputValue =
-        rewriter.create<tosa::ReshapeOp>(op.getLoc(), outputShapeType, mulValue,
-                                         rewriter.getI64ArrayAttr(outputShape));
-
-    // Add in the bias.
-    rewriter
-        .replaceOpWithNewOp<tosa::AddOp>(op, outputShapeType, outputValue,
-                                         op.bias())
-        .getResult();
-    return success();
-  }
-};
-
-void DepthwiseConv2DOp::getCanonicalizationPatterns(
-    OwningRewritePatternList &results, MLIRContext *context) {
-  results.insert<DepthwiseConv2DMulOptimization>(context);
-}
-
 struct MaxPool2dIsNoOp : public OpRewritePattern<tosa::MaxPool2dOp> {
   using OpRewritePattern::OpRewritePattern;
 
@@ -747,7 +556,8 @@ OpFoldResult TransposeOp::fold(ArrayRef<Attribute> operands) {
 // TOSA Operator Verifiers.
 //===----------------------------------------------------------------------===//
 
-template <typename T> static LogicalResult verifyConvOp(T op) {
+template <typename T>
+static LogicalResult verifyConvOp(T op) {
   // All TOSA conv ops have an input() and weight().
   auto inputType = op.input().getType().template dyn_cast<RankedTensorType>();
   auto weightType = op.weight().getType().template dyn_cast<RankedTensorType>();

diff  --git a/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt
index b5e90bbeecc59..016575fc7735b 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt
@@ -2,6 +2,7 @@ add_mlir_dialect_library(MLIRTosaTransforms
   TosaDecomposeTransposeConv.cpp
   TosaInferShapes.cpp
   TosaMakeBroadcastable.cpp
+  TosaOptimization.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Tosa/Transforms

diff  --git a/mlir/lib/Dialect/Tosa/Transforms/TosaOptimization.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaOptimization.cpp
new file mode 100644
index 0000000000000..61b1618f68dad
--- /dev/null
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaOptimization.cpp
@@ -0,0 +1,243 @@
+//===- TosaOptimization.cpp ------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Pass to perform optimizations on TOSA operations
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Analysis/DataFlowAnalysis.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Tosa/IR/TosaOps.h"
+#include "mlir/Dialect/Tosa/Transforms/PassDetail.h"
+#include "mlir/Dialect/Tosa/Transforms/Passes.h"
+#include "mlir/Dialect/Tosa/Utils/ShapeUtils.h"
+#include "mlir/IR/BlockAndValueMapping.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/Support/FormatVariadic.h"
+
+using namespace mlir;
+using namespace mlir::tosa;
+
+#define PASS_NAME "tosa-optimization"
+#define DEBUG_TYPE PASS_NAME
+
+namespace {
+
+struct Conv2DIsFullyConnected : public OpRewritePattern<tosa::Conv2DOp> {
+  explicit Conv2DIsFullyConnected(MLIRContext *context)
+      : OpRewritePattern(context) {}
+
+  LogicalResult matchAndRewrite(tosa::Conv2DOp op,
+                                PatternRewriter &rewriter) const {
+    Value input = op.input();
+    Value weight = op.weight();
+    ShapedType inputType = input.getType().cast<ShapedType>();
+    ShapedType weightType = weight.getType().cast<ShapedType>();
+    ShapedType resultType = op.getType().cast<ShapedType>();
+
+    if (!inputType.hasStaticShape() || !weightType.hasRank()) {
+      return failure();
+    }
+
+    // Stride must be 1 for this optimization.
+    for (Attribute stride : op.stride().getValue()) {
+      if (!stride.cast<IntegerAttr>().getValue().isOne()) {
+        return failure();
+      }
+    }
+
+    // Only works for a 1x1 kernel.
+    ArrayRef<int64_t> weightShape = weightType.getShape();
+    if (weightShape[1] != 1 || weightShape[2] != 1) {
+      return failure();
+    }
+
+    // Reshape input to [N,IH,IW,IC] -> [N * IH * IW, IC].
+    ArrayRef<int64_t> inputShape = inputType.getShape();
+    llvm::SmallVector<int64_t, 2> revisedInputShape{
+        inputShape[0] * inputShape[1] * inputShape[2], inputShape[3]};
+    auto revisedInputShapeType = RankedTensorType::get(
+        revisedInputShape,
+        input.getType().dyn_cast<RankedTensorType>().getElementType());
+    auto reshapedInput = rewriter
+                             .create<tosa::ReshapeOp>(
+                                 op.getLoc(), revisedInputShapeType, input,
+                                 rewriter.getI64ArrayAttr(revisedInputShape))
+                             .getResult();
+
+    // Reshape kernel to [OC,KH,KW,IC] -> [OC, IC].
+    llvm::SmallVector<int64_t, 2> revisedWeightShape{weightShape[0],
+                                                     weightShape[3]};
+    auto revisedWeightShapeType = RankedTensorType::get(
+        revisedWeightShape,
+        weight.getType().dyn_cast<RankedTensorType>().getElementType());
+    auto reshapedWeight = rewriter
+                              .create<tosa::ReshapeOp>(
+                                  op.getLoc(), revisedWeightShapeType, weight,
+                                  rewriter.getI64ArrayAttr(revisedWeightShape))
+                              .getResult();
+
+    // Perform a fully connected network over the reshaped input and weight.
+    llvm::SmallVector<int64_t, 2> fullyConnectedShape{
+        inputShape[0] * inputShape[1] * inputShape[2], weightShape[0]};
+    auto fullyConnectedShapeType = RankedTensorType::get(
+        fullyConnectedShape,
+        resultType.dyn_cast<ShapedType>().getElementType());
+
+    Value fullyConnectedValue;
+    if (op.quantization_info()) {
+      fullyConnectedValue =
+          rewriter
+              .create<tosa::FullyConnectedOp>(
+                  op.getLoc(), fullyConnectedShapeType, reshapedInput,
+                  reshapedWeight, op.bias(), op.quantization_info().getValue())
+              .getResult();
+    } else {
+      fullyConnectedValue = rewriter
+                                .create<tosa::FullyConnectedOp>(
+                                    op.getLoc(), fullyConnectedShapeType,
+                                    reshapedInput, reshapedWeight, op.bias())
+                                .getResult();
+    }
+
+    // Reshape output to [N, IH, IW, OC].
+    llvm::SmallVector<int64_t, 4> outputShape{inputShape[0], inputShape[1],
+                                              inputShape[2], weightShape[0]};
+    rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(
+        op, resultType, fullyConnectedValue,
+        rewriter.getI64ArrayAttr(outputShape));
+    return success();
+  }
+};
+
+struct DepthwiseConv2DIsMul : public OpRewritePattern<tosa::DepthwiseConv2DOp> {
+  explicit DepthwiseConv2DIsMul(MLIRContext *context)
+      : OpRewritePattern(context) {}
+
+  LogicalResult matchAndRewrite(tosa::DepthwiseConv2DOp op,
+                                PatternRewriter &rewriter) const {
+    Value input = op.input();
+    Value weight = op.weight();
+    ShapedType inputType = input.getType().cast<ShapedType>();
+    ShapedType weightType = weight.getType().cast<ShapedType>();
+    ShapedType resultType = op.output().getType().cast<ShapedType>();
+    Type inputEType = inputType.getElementType();
+
+    if (!(inputType.hasStaticShape() && weightType.hasStaticShape() &&
+          resultType.hasStaticShape())) {
+      return failure();
+    }
+
+    // Quantization information needs to still be performed.
+    if (op.quantization_info() || !inputEType.isa<FloatType>()) {
+      return failure();
+    }
+
+    // Stride must be 1 for this optimization.
+    for (Attribute stride : op.stride().getValue()) {
+      if (!stride.cast<IntegerAttr>().getValue().isOne()) {
+        return failure();
+      }
+    }
+
+    // Only works for a 1x1 kernel.
+    ArrayRef<int64_t> weightShape = weightType.getShape();
+    if (weightShape[0] != 1 || weightShape[1] != 1) {
+      return failure();
+    }
+
+    // Reshape input to [N, H, W, C] -> [N, H, W, C, 1].
+    ArrayRef<int64_t> inputShape = inputType.getShape();
+    llvm::SmallVector<int64_t, 2> revisedInputShape{
+        inputShape[0], inputShape[1], inputShape[2], inputShape[3], 1};
+    auto revisedInputShapeType = RankedTensorType::get(
+        revisedInputShape,
+        input.getType().dyn_cast<RankedTensorType>().getElementType());
+    auto reshapedInput = rewriter
+                             .create<tosa::ReshapeOp>(
+                                 op.getLoc(), revisedInputShapeType, input,
+                                 rewriter.getI64ArrayAttr(revisedInputShape))
+                             .getResult();
+
+    // Reshape kernel to [KH, KW, C, M] -> [1, 1, 1, C, M].
+    llvm::SmallVector<int64_t, 2> revisedWeightShape{1, 1, 1, weightShape[2],
+                                                     weightShape[3]};
+    auto revisedWeightShapeType = RankedTensorType::get(
+        revisedWeightShape,
+        weight.getType().dyn_cast<RankedTensorType>().getElementType());
+    auto reshapedWeight = rewriter
+                              .create<tosa::ReshapeOp>(
+                                  op.getLoc(), revisedWeightShapeType, weight,
+                                  rewriter.getI64ArrayAttr(revisedWeightShape))
+                              .getResult();
+
+    // Perform an elementwise mul over the reshaped input and weight.
+    llvm::SmallVector<int64_t, 2> mulShape{inputShape[0], inputShape[1],
+                                           inputShape[2], inputShape[3],
+                                           weightShape[3]};
+    auto mulShapeType = RankedTensorType::get(
+        mulShape,
+        weight.getType().dyn_cast<RankedTensorType>().getElementType());
+    Value mulValue =
+        rewriter
+            .create<tosa::MulOp>(op.getLoc(), mulShapeType, reshapedInput,
+                                 reshapedWeight, /*shift=*/0)
+            .getResult();
+
+    // Reshape output to [N, H, W, C * M].
+    auto outputShape = op.output().getType().cast<ShapedType>().getShape();
+    auto outputShapeType = RankedTensorType::get(
+        outputShape,
+        input.getType().dyn_cast<RankedTensorType>().getElementType());
+    auto outputValue =
+        rewriter.create<tosa::ReshapeOp>(op.getLoc(), outputShapeType, mulValue,
+                                         rewriter.getI64ArrayAttr(outputShape));
+
+    // Add in the bias.
+    rewriter
+        .replaceOpWithNewOp<tosa::AddOp>(op, outputShapeType, outputValue,
+                                         op.bias())
+        .getResult();
+    return success();
+  }
+};
+
+class TosaOptimization : public PassWrapper<TosaOptimization, FunctionPass> {
+public:
+  explicit TosaOptimization() {}
+  void runOnFunction() override;
+
+  StringRef getArgument() const final { return PASS_NAME; }
+  StringRef getDescription() const final {
+    return "Applies TOSA Operation Optimizations";
+  }
+};
+
+void TosaOptimization::runOnFunction() {
+  OwningRewritePatternList patterns(&getContext());
+
+  patterns.insert<Conv2DIsFullyConnected>(&getContext());
+  patterns.insert<DepthwiseConv2DIsMul>(&getContext());
+
+  auto func = getFunction();
+  if (applyPatternsAndFoldGreedily(func, std::move(patterns)).failed()) {
+    signalPassFailure();
+  }
+}
+
+} // namespace
+
+std::unique_ptr<Pass> mlir::tosa::createTosaOptimizationPass() {
+  return std::make_unique<TosaOptimization>();
+}

diff  --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
index fa5304bcfb042..1f2f9a1c66a45 100644
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -68,52 +68,6 @@ func @concat_fold_cast(%arg0: tensor<?x1xf32>) -> tensor<?x?xf32> {
 
 // -----
 
-// CHECK-LABEL: @conv2d_as_fully_connected
-func @conv2d_as_fully_connected(%arg0: tensor<4x10x10x2xf32>, %arg1: tensor<3x1x1x2xf32>, %arg2: tensor<3xf32>) -> tensor<4x10x10x3xf32> {
-  // CHECK-NOT: "tosa.conv2d"
-  // CHECK: %[[VAR0:.*]] = "tosa.reshape"(%arg0) {new_shape = [400, 2]}
-  // CHECK-SAME: -> tensor<400x2xf32>
-  // CHECK: %[[VAR1:.*]] = "tosa.reshape"(%arg1) {new_shape = [3, 2]}
-  // CHECK-SAME: -> tensor<3x2xf32>
-  // CHECK: %[[VAR2:.*]] = "tosa.fully_connected"(%[[VAR0]], %[[VAR1]], %arg2)
-  // CHECK-SAME: -> tensor<400x3xf32>
-  // CHECK: %[[VAR3:.*]] = "tosa.reshape"(%[[VAR2]]) {new_shape = [4, 10, 10, 3]}
-  // CHECK-SAME: -> tensor<4x10x10x3xf32>
-  // CHECK: return %[[VAR3]]
-  %0 = "tosa.conv2d"(%arg0, %arg1, %arg2) {pad = [0, 0, 0, 0], stride = [1, 1], dilation = [1, 1]} : (tensor<4x10x10x2xf32>, tensor<3x1x1x2xf32>, tensor<3xf32>) -> tensor<4x10x10x3xf32>
-  return %0 : tensor<4x10x10x3xf32>
-}
-
-// -----
-
-// CHECK-LABEL: @conv2d_as_fully_connected_quant
-func @conv2d_as_fully_connected_quant(%arg0: tensor<4x10x10x2xi8>, %arg1: tensor<3x1x1x2xi8>, %arg2: tensor<3xi32>) -> tensor<4x10x10x3xi32> {
-  // CHECK-NOT: "tosa.conv2d"
-  // CHECK: %[[VAR0:.*]] = "tosa.reshape"(%arg0) {new_shape = [400, 2]}
-  // CHECK-SAME: -> tensor<400x2xi8>
-  // CHECK: %[[VAR1:.*]] = "tosa.reshape"(%arg1) {new_shape = [3, 2]}
-  // CHECK-SAME: -> tensor<3x2xi8>
-  // CHECK: %[[VAR2:.*]] = "tosa.fully_connected"(%[[VAR0]], %[[VAR1]], %arg2)
-  // CHECK-SAME: quantization_info = {input_zp = 42 : i32, weight_zp = 24 : i32}
-  // CHECK-SAME: -> tensor<400x3xi32>
-  // CHECK: %[[VAR3:.*]] = "tosa.reshape"(%[[VAR2]]) {new_shape = [4, 10, 10, 3]}
-  // CHECK-SAME: -> tensor<4x10x10x3xi32>
-  // CHECK: return %[[VAR3]]
-  %0 = "tosa.conv2d"(%arg0, %arg1, %arg2) {pad = [0, 0, 0, 0], stride = [1, 1], dilation = [1, 1], quantization_info = {input_zp = 42 : i32, weight_zp = 24 : i32}} : (tensor<4x10x10x2xi8>, tensor<3x1x1x2xi8>, tensor<3xi32>) -> tensor<4x10x10x3xi32>
-  return %0 : tensor<4x10x10x3xi32>
-}
-
-// -----
-
-// CHECK-LABEL: @conv2d_padded
-func @conv2d_padded(%arg0: tensor<4x10x10x2xf32>, %arg1: tensor<3x1x1x2xf32>, %arg2: tensor<3xf32>) -> tensor<4x12x12x3xf32> {
-  // CHECK: "tosa.conv2d"
-  %0 = "tosa.conv2d"(%arg0, %arg1, %arg2) {pad = [1, 1, 1, 1], stride = [1, 1], dilation = [1, 1]} : (tensor<4x10x10x2xf32>, tensor<3x1x1x2xf32>, tensor<3xf32>) -> tensor<4x12x12x3xf32>
-  return %0 : tensor<4x12x12x3xf32>
-}
-
-// -----
-
 // CHECK-LABEL: @conv2d_stride_2
 func @conv2d_stride_2(%arg0: tensor<4x10x10x2xf32>) -> tensor<4x10x10x3xf32> {
   // CHECK: "tosa.conv2d"
@@ -136,35 +90,6 @@ func @conv2d_weight_2x2(%arg0: tensor<4x10x10x1xf32>) -> tensor<4x10x10x1xf32> {
 
 // -----
 
-// CHECK-LABEL: @depthwise_conv2d_as_mul
-func @depthwise_conv2d_as_mul(%arg0: tensor<4x10x10x2xf32>, %arg1: tensor<1x1x2x3xf32>, %arg2: tensor<6xf32>) -> tensor<4x10x10x6xf32> {
-  // CHECK-NOT: "tosa.depthwise_conv2d"
-  // CHECK: %[[VAR0:.*]] = "tosa.reshape"(%arg0) {new_shape = [4, 10, 10, 2, 1]}
-  // CHECK-SAME: -> tensor<4x10x10x2x1xf32>
-  // CHECK: %[[VAR1:.*]] = "tosa.reshape"(%arg1) {new_shape = [1, 1, 1, 2, 3]}
-  // CHECK-SAME: -> tensor<1x1x1x2x3xf32>
-  // CHECK: %[[VAR2:.*]] = "tosa.mul"(%[[VAR0]], %[[VAR1]])
-  // CHECK-SAME: -> tensor<4x10x10x2x3xf32>
-  // CHECK: %[[VAR3:.*]] = "tosa.reshape"(%[[VAR2]]) {new_shape = [4, 10, 10, 6]}
-  // CHECK-SAME: -> tensor<4x10x10x6xf32>
-  // CHECK: %[[VAR4:.*]] = "tosa.add"(%[[VAR3]], %arg2)
-  // CHECK-SAME: -> tensor<4x10x10x6xf32>
-  // CHECK: return %[[VAR4]]
-  %0 = "tosa.depthwise_conv2d"(%arg0, %arg1, %arg2) {pad = [0, 0, 0, 0], stride = [1, 1], dilation = [1, 1]} : (tensor<4x10x10x2xf32>, tensor<1x1x2x3xf32>, tensor<6xf32>) -> tensor<4x10x10x6xf32>
-  return %0 : tensor<4x10x10x6xf32>
-}
-
-// -----
-
-// CHECK-LABEL: @depthwise_conv2d_as_mul_q
-func @depthwise_conv2d_as_mul_q(%arg0: tensor<4x10x10x2xi8>, %arg1: tensor<1x1x2x3xi8>, %arg2: tensor<6xi32>) -> tensor<4x10x10x6xi32> {
-  // CHECK: "tosa.depthwise_conv2d"
-  %0 = "tosa.depthwise_conv2d"(%arg0, %arg1, %arg2) {pad = [0, 0, 0, 0], stride = [1, 1], dilation = [1, 1], quantization_info = {input_zp = 0 : i32, weight_zp = 0 : i32}} : (tensor<4x10x10x2xi8>, tensor<1x1x2x3xi8>, tensor<6xi32>) -> tensor<4x10x10x6xi32>
-  return %0 : tensor<4x10x10x6xi32>
-}
-
-// -----
-
 // CHECK-LABEL: @depthwise_conv2d_stride_2
 func @depthwise_conv2d_stride_2(%arg0: tensor<4x10x10x2xf32>, %arg1: tensor<1x1x2x3xf32>, %arg2: tensor<6xf32>) -> tensor<4x10x10x6xf32> {
   // CHECK: "tosa.depthwise_conv2d"

diff  --git a/mlir/test/Dialect/Tosa/operation_optimization.mlir b/mlir/test/Dialect/Tosa/operation_optimization.mlir
new file mode 100644
index 0000000000000..aa65b96bad4ef
--- /dev/null
+++ b/mlir/test/Dialect/Tosa/operation_optimization.mlir
@@ -0,0 +1,69 @@
+// RUN: mlir-opt --split-input-file --tosa-optimization %s | FileCheck %s
+
+// -----
+
+// CHECK-LABEL: @conv2d_as_fully_connected
+func @conv2d_as_fully_connected(%arg0: tensor<4x10x10x2xf32>, %arg1: tensor<3x1x1x2xf32>, %arg2: tensor<3xf32>) -> tensor<4x10x10x3xf32> {
+  // CHECK-NOT: "tosa.conv2d"
+  // CHECK: %[[VAR0:.*]] = "tosa.reshape"(%arg0) {new_shape = [400, 2]}
+  // CHECK-SAME: -> tensor<400x2xf32>
+  // CHECK: %[[VAR1:.*]] = "tosa.reshape"(%arg1) {new_shape = [3, 2]}
+  // CHECK-SAME: -> tensor<3x2xf32>
+  // CHECK: %[[VAR2:.*]] = "tosa.fully_connected"(%[[VAR0]], %[[VAR1]], %arg2)
+  // CHECK-SAME: -> tensor<400x3xf32>
+  // CHECK: %[[VAR3:.*]] = "tosa.reshape"(%[[VAR2]]) {new_shape = [4, 10, 10, 3]}
+  // CHECK-SAME: -> tensor<4x10x10x3xf32>
+  // CHECK: return %[[VAR3]]
+  %0 = "tosa.conv2d"(%arg0, %arg1, %arg2) {pad = [0, 0, 0, 0], stride = [1, 1], dilation = [1, 1]} : (tensor<4x10x10x2xf32>, tensor<3x1x1x2xf32>, tensor<3xf32>) -> tensor<4x10x10x3xf32>
+  return %0 : tensor<4x10x10x3xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @conv2d_as_fully_connected_quant
+func @conv2d_as_fully_connected_quant(%arg0: tensor<4x10x10x2xi8>, %arg1: tensor<3x1x1x2xi8>, %arg2: tensor<3xi32>) -> tensor<4x10x10x3xi32> {
+  // CHECK-NOT: "tosa.conv2d"
+  // CHECK: %[[VAR0:.*]] = "tosa.reshape"(%arg0) {new_shape = [400, 2]}
+  // CHECK-SAME: -> tensor<400x2xi8>
+  // CHECK: %[[VAR1:.*]] = "tosa.reshape"(%arg1) {new_shape = [3, 2]}
+  // CHECK-SAME: -> tensor<3x2xi8>
+  // CHECK: %[[VAR2:.*]] = "tosa.fully_connected"(%[[VAR0]], %[[VAR1]], %arg2)
+  // CHECK-SAME: quantization_info = {input_zp = 42 : i32, weight_zp = 24 : i32}
+  // CHECK-SAME: -> tensor<400x3xi32>
+  // CHECK: %[[VAR3:.*]] = "tosa.reshape"(%[[VAR2]]) {new_shape = [4, 10, 10, 3]}
+  // CHECK-SAME: -> tensor<4x10x10x3xi32>
+  // CHECK: return %[[VAR3]]
+  %0 = "tosa.conv2d"(%arg0, %arg1, %arg2) {pad = [0, 0, 0, 0], stride = [1, 1], dilation = [1, 1], quantization_info = {input_zp = 42 : i32, weight_zp = 24 : i32}} : (tensor<4x10x10x2xi8>, tensor<3x1x1x2xi8>, tensor<3xi32>) -> tensor<4x10x10x3xi32>
+  return %0 : tensor<4x10x10x3xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @depthwise_conv2d_as_mul
+func @depthwise_conv2d_as_mul(%arg0: tensor<4x10x10x2xf32>, %arg1: tensor<1x1x2x3xf32>, %arg2: tensor<6xf32>) -> tensor<4x10x10x6xf32> {
+  // CHECK-NOT: "tosa.depthwise_conv2d"
+  // CHECK: %[[VAR0:.*]] = "tosa.reshape"(%arg0) {new_shape = [4, 10, 10, 2, 1]}
+  // CHECK-SAME: -> tensor<4x10x10x2x1xf32>
+  // CHECK: %[[VAR1:.*]] = "tosa.reshape"(%arg1) {new_shape = [1, 1, 1, 2, 3]}
+  // CHECK-SAME: -> tensor<1x1x1x2x3xf32>
+  // CHECK: %[[VAR2:.*]] = "tosa.mul"(%[[VAR0]], %[[VAR1]])
+  // CHECK-SAME: -> tensor<4x10x10x2x3xf32>
+  // CHECK: %[[VAR3:.*]] = "tosa.reshape"(%[[VAR2]]) {new_shape = [4, 10, 10, 6]}
+  // CHECK-SAME: -> tensor<4x10x10x6xf32>
+  // CHECK: %[[VAR4:.*]] = "tosa.add"(%[[VAR3]], %arg2)
+  // CHECK-SAME: -> tensor<4x10x10x6xf32>
+  // CHECK: return %[[VAR4]]
+  %0 = "tosa.depthwise_conv2d"(%arg0, %arg1, %arg2) {pad = [0, 0, 0, 0], stride = [1, 1], dilation = [1, 1]} : (tensor<4x10x10x2xf32>, tensor<1x1x2x3xf32>, tensor<6xf32>) -> tensor<4x10x10x6xf32>
+  return %0 : tensor<4x10x10x6xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @depthwise_conv2d_as_mul_q
+func @depthwise_conv2d_as_mul_q(%arg0: tensor<4x10x10x2xi8>, %arg1: tensor<1x1x2x3xi8>, %arg2: tensor<6xi32>) -> tensor<4x10x10x6xi32> {
+  // CHECK: "tosa.depthwise_conv2d"
+  %0 = "tosa.depthwise_conv2d"(%arg0, %arg1, %arg2) {pad = [0, 0, 0, 0], stride = [1, 1], dilation = [1, 1], quantization_info = {input_zp = 0 : i32, weight_zp = 0 : i32}} : (tensor<4x10x10x2xi8>, tensor<1x1x2x3xi8>, tensor<6xi32>) -> tensor<4x10x10x6xi32>
+  return %0 : tensor<4x10x10x6xi32>
+}
+
+// -----


        


More information about the Mlir-commits mailing list