[Mlir-commits] [mlir] dfd0708 - [mlir][tosa] Allow optional TOSA decompositions to be populated separately

Rob Suderman llvmlistbot at llvm.org
Tue Jan 11 10:27:57 PST 2022


Author: Aaron DeBattista
Date: 2022-01-11T10:26:30-08:00
New Revision: dfd070820cbae9d6864a7de20ae81757d199fc61

URL: https://github.com/llvm/llvm-project/commit/dfd070820cbae9d6864a7de20ae81757d199fc61
DIFF: https://github.com/llvm/llvm-project/commit/dfd070820cbae9d6864a7de20ae81757d199fc61.diff

LOG: [mlir][tosa] Allow optional TOSA decompositions to be populated separately

Moved all TOSA decomposition patterns so that they can be optionally populated
and used by external rewrites. This avoids decomposing TOSA operations when
backends may benefit from the non-decomposed version.

Reviewed By: rsuderman, mehdi_amini

Differential Revision: https://reviews.llvm.org/D116526

Added: 
    mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeConv2D.cpp
    mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp
    mlir/lib/Dialect/Tosa/Transforms/TosaOptionalDecompositions.cpp
    mlir/test/Dialect/Tosa/tosa-decompose-conv2d.mlir
    mlir/test/Dialect/Tosa/tosa-decompose-depthwise.mlir

Modified: 
    mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h
    mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
    mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp
    mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt
    mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp
    mlir/test/Dialect/Tosa/tosa-decompose-transpose-conv.mlir

Removed: 
    mlir/lib/Dialect/Tosa/Transforms/TosaOptimization.cpp
    mlir/test/Dialect/Tosa/operation_optimization.mlir


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h
index e94daca0df72..1bdfc2f43bf3 100644
--- a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h
@@ -19,11 +19,18 @@
 namespace mlir {
 namespace tosa {
 
-std::unique_ptr<Pass> createTosaDecomposeTransposeConvPass();
+// Expose Rewrite Functions that decompose TOSA Ops into further TOSA Ops.
+// The rewrites can be selectively added to a conversion pass.
+void populateTosaDecomposeConv2D(MLIRContext *ctx, RewritePatternSet &patterns);
+void populateTosaDecomposeTransposeConv(MLIRContext *ctx,
+                                        RewritePatternSet &patterns);
+void populateTosaDecomposeDepthwise(MLIRContext *ctx,
+                                    RewritePatternSet &patterns);
+
 std::unique_ptr<Pass> createTosaInferShapesPass();
 std::unique_ptr<Pass> createTosaMakeBroadcastablePass();
-std::unique_ptr<Pass> createTosaOptimizationPass();
 std::unique_ptr<Pass> createTosaTestQuantUtilAPIPass();
+std::unique_ptr<Pass> createTosaOptionalDecompositions();
 
 #define GEN_PASS_REGISTRATION
 #include "mlir/Dialect/Tosa/Transforms/Passes.h.inc"

diff  --git a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
index 4a75482ba832..fbb3134e0041 100644
--- a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
@@ -15,21 +15,6 @@
 
 include "mlir/Pass/PassBase.td"
 
-def TosaDecomposeTransposeConv : FunctionPass<"tosa-decompose-transpose-conv"> {
-  let summary = "Deompose transpose convolutiions into standard convolutions.";
-  let description = [{
-    Pass that uses shape manipulation and convolution operations to transform
-    a transpose convolution into a regular convolution.
-  }];
-
-  let constructor = "createTosaDecomposeTransposeConvPass()";
-  let dependentDialects = [
-    "StandardOpsDialect",
-    "tensor::TensorDialect",
-    "tosa::TosaDialect",
-  ];
-}
-
 def TosaInferShapes : FunctionPass<"tosa-infer-shapes"> {
   let summary = "Propagate shapes across TOSA operations";
   let description = [{
@@ -58,13 +43,14 @@ def TosaMakeBroadcastable : FunctionPass<"tosa-make-broadcastable"> {
   let constructor = "createTosaMakeBroadcastablePass()";
 }
 
-def TosaOptimization : FunctionPass<"tosa-optimization"> {
-  let summary = "TOSA operation optimizations";
+def TosaOptionalDecompositions : FunctionPass<"tosa-optional-decompositions"> {
+  let summary = "Applies Tosa operations optional decompositions";
   let description = [{
-    "Pass to perform optimizations on TOSA operations"
+    Pass to apply the Tosa operations decompositions 
+    exposed as populate functions in include/mlir/Dialect/Tosa/Transforms/Passes.h
   }];
 
-  let constructor = "createTosaOptimizationPass()";
+  let constructor = "tosa::createTosaOptionalDecompositions()";
 }
 
 #endif // MLIR_DIALECT_TOSA_TRANSFORMS_PASSES

diff  --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp
index 3813ba345137..e75e8d72bc2e 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp
@@ -68,6 +68,7 @@ std::unique_ptr<Pass> mlir::tosa::createTosaToLinalg() {
 }
 
 void mlir::tosa::addTosaToLinalgPasses(OpPassManager &pm) {
+  pm.addNestedPass<FuncOp>(mlir::tosa::createTosaOptionalDecompositions());
   pm.addNestedPass<FuncOp>(createTosaMakeBroadcastablePass());
   pm.addNestedPass<FuncOp>(createTosaToLinalgNamed());
   pm.addNestedPass<FuncOp>(mlir::createCanonicalizerPass());

diff  --git a/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt
index 016575fc7735..a24d9cb65cbd 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt
@@ -1,8 +1,10 @@
 add_mlir_dialect_library(MLIRTosaTransforms
   TosaDecomposeTransposeConv.cpp
+  TosaDecomposeConv2D.cpp
+  TosaDecomposeDepthwise.cpp
   TosaInferShapes.cpp
   TosaMakeBroadcastable.cpp
-  TosaOptimization.cpp
+  TosaOptionalDecompositions.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Tosa/Transforms

diff  --git a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeConv2D.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeConv2D.cpp
new file mode 100644
index 000000000000..4c412f987899
--- /dev/null
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeConv2D.cpp
@@ -0,0 +1,132 @@
+//===- TosaDecomposeConv2D.cpp --------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Decompose TOSA Conv2D operation to a series of TOSA Ops specifically
+// (1) Convert a 1x1 Convolution to a Reshape->FC->Reshape
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Tosa/IR/TosaOps.h"
+#include "mlir/Dialect/Tosa/Transforms/Passes.h"
+#include "mlir/Pass/Pass.h"
+
+using namespace mlir;
+using namespace mlir::tosa;
+
+namespace {
+
+/// Rewrites a tosa.conv2d with a 1x1 kernel, unit stride and zero padding as
+///   tosa.reshape -> tosa.fully_connected -> tosa.reshape.
+/// With a 1x1 kernel each output pixel only depends on the input pixel at the
+/// same position, so the convolution is a fully connected layer applied to
+/// the input flattened to [N * IH * IW, IC].
+struct Conv2DIsFullyConnected : public OpRewritePattern<tosa::Conv2DOp> {
+  explicit Conv2DIsFullyConnected(MLIRContext *context)
+      : OpRewritePattern(context) {}
+
+  LogicalResult matchAndRewrite(tosa::Conv2DOp op,
+                                PatternRewriter &rewriter) const override {
+    Value input = op.input();
+    Value weight = op.weight();
+    ShapedType inputType = input.getType().cast<ShapedType>();
+    ShapedType weightType = weight.getType().cast<ShapedType>();
+    ShapedType resultType = op.getType().cast<ShapedType>();
+
+    // The flattened shapes below are computed from the input dimensions, so
+    // the input shape must be fully static.
+    if (!inputType.hasStaticShape() || !weightType.hasRank()) {
+      return failure();
+    }
+
+    // Stride must be 1 for this optimization.
+    for (Attribute stride : op.stride().getValue()) {
+      if (!stride.cast<IntegerAttr>().getValue().isOne()) {
+        return failure();
+      }
+    }
+
+    // Padding must be 0: the final reshape restores the unpadded
+    // [N, IH, IW, OC] shape, which only matches the result type when nothing
+    // was padded.
+    for (Attribute pad : op.pad().getValue()) {
+      if (pad.cast<IntegerAttr>().getInt() != 0) {
+        return failure();
+      }
+    }
+
+    // Only works for a 1x1 kernel.
+    ArrayRef<int64_t> weightShape = weightType.getShape();
+    if (weightShape[1] != 1 || weightShape[2] != 1) {
+      return failure();
+    }
+
+    // Reshape input to [N, IH, IW, IC] -> [N * IH * IW, IC].
+    ArrayRef<int64_t> inputShape = inputType.getShape();
+    int64_t flattenedBatch = inputShape[0] * inputShape[1] * inputShape[2];
+    llvm::SmallVector<int64_t, 2> revisedInputShape{flattenedBatch,
+                                                    inputShape[3]};
+    auto revisedInputShapeType =
+        RankedTensorType::get(revisedInputShape, inputType.getElementType());
+    auto reshapedInput = rewriter
+                             .create<tosa::ReshapeOp>(
+                                 op.getLoc(), revisedInputShapeType, input,
+                                 rewriter.getI64ArrayAttr(revisedInputShape))
+                             .getResult();
+
+    // Reshape kernel to [OC, KH, KW, IC] -> [OC, IC].
+    llvm::SmallVector<int64_t, 2> revisedWeightShape{weightShape[0],
+                                                     weightShape[3]};
+    auto revisedWeightShapeType = RankedTensorType::get(
+        revisedWeightShape, weightType.getElementType());
+    auto reshapedWeight = rewriter
+                              .create<tosa::ReshapeOp>(
+                                  op.getLoc(), revisedWeightShapeType, weight,
+                                  rewriter.getI64ArrayAttr(revisedWeightShape))
+                              .getResult();
+
+    // Perform a fully connected layer over the reshaped input and weight.
+    llvm::SmallVector<int64_t, 2> fullyConnectedShape{flattenedBatch,
+                                                      weightShape[0]};
+    auto fullyConnectedShapeType =
+        RankedTensorType::get(fullyConnectedShape, resultType.getElementType());
+
+    // Forward the quantization info (zero points), if present.
+    Value fullyConnectedValue;
+    if (op.quantization_info()) {
+      fullyConnectedValue =
+          rewriter
+              .create<tosa::FullyConnectedOp>(
+                  op.getLoc(), fullyConnectedShapeType, reshapedInput,
+                  reshapedWeight, op.bias(), op.quantization_info().getValue())
+              .getResult();
+    } else {
+      fullyConnectedValue = rewriter
+                                .create<tosa::FullyConnectedOp>(
+                                    op.getLoc(), fullyConnectedShapeType,
+                                    reshapedInput, reshapedWeight, op.bias())
+                                .getResult();
+    }
+
+    // Reshape output to [N, IH, IW, OC].
+    llvm::SmallVector<int64_t, 4> outputShape{inputShape[0], inputShape[1],
+                                              inputShape[2], weightShape[0]};
+    rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(
+        op, resultType, fullyConnectedValue,
+        rewriter.getI64ArrayAttr(outputShape));
+    return success();
+  }
+};
+
+} // namespace
+
+/// Adds the TOSA Conv2D decomposition patterns (currently the 1x1
+/// convolution -> fully connected rewrite) to `patterns`.
+void mlir::tosa::populateTosaDecomposeConv2D(MLIRContext *ctx,
+                                             RewritePatternSet &patterns) {
+  patterns.insert<Conv2DIsFullyConnected>(ctx);
+}

diff  --git a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp
new file mode 100644
index 000000000000..685f97353d74
--- /dev/null
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp
@@ -0,0 +1,131 @@
+//===- TosaDecomposeDepthwise.cpp -----------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Decompose TOSA Depthwise operation to a series of TOSA Ops specifically
+// (1) Convert a 1x1 Depthwise to Reshape -> Mul -> Reshape -> Add
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Tosa/IR/TosaOps.h"
+#include "mlir/Dialect/Tosa/Transforms/Passes.h"
+#include "mlir/Pass/Pass.h"
+
+using namespace mlir;
+using namespace mlir::tosa;
+
+namespace {
+
+/// Rewrites a non-quantized floating-point tosa.depthwise_conv2d with a 1x1
+/// kernel, unit stride and zero padding as
+///   reshape(input) * reshape(weight) -> reshape -> add(bias).
+/// With a 1x1 kernel each input channel is simply scaled by its M weights, so
+/// the convolution becomes a broadcasting elementwise multiply.
+struct DepthwiseConv2DIsMul : public OpRewritePattern<tosa::DepthwiseConv2DOp> {
+  explicit DepthwiseConv2DIsMul(MLIRContext *context)
+      : OpRewritePattern(context) {}
+
+  LogicalResult matchAndRewrite(tosa::DepthwiseConv2DOp op,
+                                PatternRewriter &rewriter) const override {
+    Value input = op.input();
+    Value weight = op.weight();
+    ShapedType inputType = input.getType().cast<ShapedType>();
+    ShapedType weightType = weight.getType().cast<ShapedType>();
+    ShapedType resultType = op.output().getType().cast<ShapedType>();
+    Type inputEType = inputType.getElementType();
+
+    if (!(inputType.hasStaticShape() && weightType.hasStaticShape() &&
+          resultType.hasStaticShape())) {
+      return failure();
+    }
+
+    // Quantized convolutions are not supported yet, so only non-quantized
+    // floating-point convolutions are rewritten.
+    if (op.quantization_info() || !inputEType.isa<FloatType>()) {
+      return failure();
+    }
+
+    // Stride must be 1 for this optimization.
+    for (Attribute stride : op.stride().getValue()) {
+      if (!stride.cast<IntegerAttr>().getValue().isOne()) {
+        return failure();
+      }
+    }
+
+    // Padding must be 0: the [N, H, W, C, M] product is reshaped directly to
+    // the result shape, which only holds the same number of elements when
+    // nothing was padded.
+    for (Attribute pad : op.pad().getValue()) {
+      if (pad.cast<IntegerAttr>().getInt() != 0) {
+        return failure();
+      }
+    }
+
+    // Only works for a 1x1 kernel.
+    ArrayRef<int64_t> weightShape = weightType.getShape();
+    if (weightShape[0] != 1 || weightShape[1] != 1) {
+      return failure();
+    }
+
+    // Reshape input to [N, H, W, C] -> [N, H, W, C, 1].
+    ArrayRef<int64_t> inputShape = inputType.getShape();
+    llvm::SmallVector<int64_t, 5> revisedInputShape{
+        inputShape[0], inputShape[1], inputShape[2], inputShape[3], 1};
+    auto revisedInputShapeType =
+        RankedTensorType::get(revisedInputShape, inputEType);
+    auto reshapedInput = rewriter
+                             .create<tosa::ReshapeOp>(
+                                 op.getLoc(), revisedInputShapeType, input,
+                                 rewriter.getI64ArrayAttr(revisedInputShape))
+                             .getResult();
+
+    // Reshape kernel to [KH, KW, C, M] -> [1, 1, 1, C, M].
+    llvm::SmallVector<int64_t, 5> revisedWeightShape{1, 1, 1, weightShape[2],
+                                                     weightShape[3]};
+    auto revisedWeightShapeType = RankedTensorType::get(
+        revisedWeightShape, weightType.getElementType());
+    auto reshapedWeight = rewriter
+                              .create<tosa::ReshapeOp>(
+                                  op.getLoc(), revisedWeightShapeType, weight,
+                                  rewriter.getI64ArrayAttr(revisedWeightShape))
+                              .getResult();
+
+    // Perform an elementwise mul over the reshaped input and weight; the
+    // size-1 dimensions broadcast to [N, H, W, C, M].
+    llvm::SmallVector<int64_t, 5> mulShape{inputShape[0], inputShape[1],
+                                           inputShape[2], inputShape[3],
+                                           weightShape[3]};
+    auto mulShapeType =
+        RankedTensorType::get(mulShape, weightType.getElementType());
+    Value mulValue =
+        rewriter
+            .create<tosa::MulOp>(op.getLoc(), mulShapeType, reshapedInput,
+                                 reshapedWeight, /*shift=*/0)
+            .getResult();
+
+    // Reshape output to [N, H, W, C * M].
+    ArrayRef<int64_t> outputShape = resultType.getShape();
+    auto outputShapeType = RankedTensorType::get(outputShape, inputEType);
+    auto outputValue =
+        rewriter.create<tosa::ReshapeOp>(op.getLoc(), outputShapeType, mulValue,
+                                         rewriter.getI64ArrayAttr(outputShape));
+
+    // Add in the bias.
+    rewriter.replaceOpWithNewOp<tosa::AddOp>(op, outputShapeType, outputValue,
+                                             op.bias());
+    return success();
+  }
+};
+
+} // namespace
+
+/// Adds the TOSA DepthwiseConv2D decomposition patterns (currently the 1x1
+/// depthwise -> elementwise multiply rewrite) to `patterns`.
+void mlir::tosa::populateTosaDecomposeDepthwise(MLIRContext *ctx,
+                                                RewritePatternSet &patterns) {
+  patterns.insert<DepthwiseConv2DIsMul>(ctx);
+}

diff  --git a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp
index 341e78d52792..330add9e248e 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp
@@ -7,17 +7,19 @@
 //
 //===----------------------------------------------------------------------===//
 //
-// Insert reshape to binary op's input if needed to match rank
+// Decompose TOSA TransposeConv operation to a series of TOSA Ops specifically:
+// (1) Convert a dilated TransposeConv2D to a Conv2D, including reversing and
+//     reshaping the weights.
+// (2) Convert a strided TransposeConv2D to a Conv2D, including
+//     transposing/reversing/reshaping of the weights and of the input/output
+//     tensors.
 //
 //===----------------------------------------------------------------------===//
 
-#include "mlir/Dialect/StandardOps/IR/Ops.h"
-#include "mlir/Dialect/Tosa/IR//TosaOps.h"
-#include "mlir/Dialect/Tosa/Transforms/PassDetail.h"
+#include "mlir/Dialect/Tosa/IR/TosaOps.h"
 #include "mlir/Dialect/Tosa/Transforms/Passes.h"
 #include "mlir/Dialect/Tosa/Utils/ShapeUtils.h"
 #include "mlir/Pass/Pass.h"
-#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 
 using namespace mlir;
 using namespace mlir::tosa;
@@ -369,22 +371,10 @@ class TransposeConvStridedConverter
   }
 };
 
-/// Pass that enables broadcast by making all input arrays have the same
-/// number of dimensions. Insert RESHAPE operations to lower rank operand
-struct TosaDecomposeTransposeConv
-    : public TosaDecomposeTransposeConvBase<TosaDecomposeTransposeConv> {
-public:
-  void runOnFunction() override {
-    auto func = getFunction();
-    RewritePatternSet patterns(func.getContext());
-    patterns
-        .insert<TransposeConvDilatedConverter, TransposeConvStridedConverter>(
-            func.getContext());
-    (void)applyPatternsAndFoldGreedily(func, std::move(patterns));
-  }
-};
 } // namespace
 
-std::unique_ptr<Pass> mlir::tosa::createTosaDecomposeTransposeConvPass() {
-  return std::make_unique<TosaDecomposeTransposeConv>();
+void mlir::tosa::populateTosaDecomposeTransposeConv(
+    MLIRContext *ctx, RewritePatternSet &patterns) {
+  patterns.insert<TransposeConvDilatedConverter,
+                  TransposeConvStridedConverter>(ctx);
 }

diff  --git a/mlir/lib/Dialect/Tosa/Transforms/TosaOptimization.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaOptimization.cpp
deleted file mode 100644
index 9a19b63ed198..000000000000
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaOptimization.cpp
+++ /dev/null
@@ -1,243 +0,0 @@
-//===- TosaOptimization.cpp ------------------------------------------===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-//
-// Pass to perform optimizations on TOSA operations
-//
-//===----------------------------------------------------------------------===//
-
-#include "mlir/Analysis/DataFlowAnalysis.h"
-#include "mlir/Dialect/StandardOps/IR/Ops.h"
-#include "mlir/Dialect/Tensor/IR/Tensor.h"
-#include "mlir/Dialect/Tosa/IR/TosaOps.h"
-#include "mlir/Dialect/Tosa/Transforms/PassDetail.h"
-#include "mlir/Dialect/Tosa/Transforms/Passes.h"
-#include "mlir/Dialect/Tosa/Utils/ShapeUtils.h"
-#include "mlir/IR/BlockAndValueMapping.h"
-#include "mlir/IR/Builders.h"
-#include "mlir/IR/BuiltinOps.h"
-#include "mlir/IR/Matchers.h"
-#include "mlir/Pass/Pass.h"
-#include "mlir/Transforms/DialectConversion.h"
-#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
-#include "llvm/Support/FormatVariadic.h"
-
-using namespace mlir;
-using namespace mlir::tosa;
-
-#define PASS_NAME "tosa-optimization"
-#define DEBUG_TYPE PASS_NAME
-
-namespace {
-
-struct Conv2DIsFullyConnected : public OpRewritePattern<tosa::Conv2DOp> {
-  explicit Conv2DIsFullyConnected(MLIRContext *context)
-      : OpRewritePattern(context) {}
-
-  LogicalResult matchAndRewrite(tosa::Conv2DOp op,
-                                PatternRewriter &rewriter) const override {
-    Value input = op.input();
-    Value weight = op.weight();
-    ShapedType inputType = input.getType().cast<ShapedType>();
-    ShapedType weightType = weight.getType().cast<ShapedType>();
-    ShapedType resultType = op.getType().cast<ShapedType>();
-
-    if (!inputType.hasStaticShape() || !weightType.hasRank()) {
-      return failure();
-    }
-
-    // Stride must be 1 for this optimization.
-    for (Attribute stride : op.stride().getValue()) {
-      if (!stride.cast<IntegerAttr>().getValue().isOne()) {
-        return failure();
-      }
-    }
-
-    // Only works for a 1x1 kernel.
-    ArrayRef<int64_t> weightShape = weightType.getShape();
-    if (weightShape[1] != 1 || weightShape[2] != 1) {
-      return failure();
-    }
-
-    // Reshape input to [N,IH,IW,IC] -> [N * IH * IW, IC].
-    ArrayRef<int64_t> inputShape = inputType.getShape();
-    llvm::SmallVector<int64_t, 2> revisedInputShape{
-        inputShape[0] * inputShape[1] * inputShape[2], inputShape[3]};
-    auto revisedInputShapeType = RankedTensorType::get(
-        revisedInputShape,
-        input.getType().dyn_cast<RankedTensorType>().getElementType());
-    auto reshapedInput = rewriter
-                             .create<tosa::ReshapeOp>(
-                                 op.getLoc(), revisedInputShapeType, input,
-                                 rewriter.getI64ArrayAttr(revisedInputShape))
-                             .getResult();
-
-    // Reshape kernel to [OC,KH,KW,IC] -> [OC, IC].
-    llvm::SmallVector<int64_t, 2> revisedWeightShape{weightShape[0],
-                                                     weightShape[3]};
-    auto revisedWeightShapeType = RankedTensorType::get(
-        revisedWeightShape,
-        weight.getType().dyn_cast<RankedTensorType>().getElementType());
-    auto reshapedWeight = rewriter
-                              .create<tosa::ReshapeOp>(
-                                  op.getLoc(), revisedWeightShapeType, weight,
-                                  rewriter.getI64ArrayAttr(revisedWeightShape))
-                              .getResult();
-
-    // Perform a fully connected network over the reshaped input and weight.
-    llvm::SmallVector<int64_t, 2> fullyConnectedShape{
-        inputShape[0] * inputShape[1] * inputShape[2], weightShape[0]};
-    auto fullyConnectedShapeType = RankedTensorType::get(
-        fullyConnectedShape,
-        resultType.dyn_cast<ShapedType>().getElementType());
-
-    Value fullyConnectedValue;
-    if (op.quantization_info()) {
-      fullyConnectedValue =
-          rewriter
-              .create<tosa::FullyConnectedOp>(
-                  op.getLoc(), fullyConnectedShapeType, reshapedInput,
-                  reshapedWeight, op.bias(), op.quantization_info().getValue())
-              .getResult();
-    } else {
-      fullyConnectedValue = rewriter
-                                .create<tosa::FullyConnectedOp>(
-                                    op.getLoc(), fullyConnectedShapeType,
-                                    reshapedInput, reshapedWeight, op.bias())
-                                .getResult();
-    }
-
-    // Reshape output to [N, IH, IW, OC].
-    llvm::SmallVector<int64_t, 4> outputShape{inputShape[0], inputShape[1],
-                                              inputShape[2], weightShape[0]};
-    rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(
-        op, resultType, fullyConnectedValue,
-        rewriter.getI64ArrayAttr(outputShape));
-    return success();
-  }
-};
-
-struct DepthwiseConv2DIsMul : public OpRewritePattern<tosa::DepthwiseConv2DOp> {
-  explicit DepthwiseConv2DIsMul(MLIRContext *context)
-      : OpRewritePattern(context) {}
-
-  LogicalResult matchAndRewrite(tosa::DepthwiseConv2DOp op,
-                                PatternRewriter &rewriter) const override {
-    Value input = op.input();
-    Value weight = op.weight();
-    ShapedType inputType = input.getType().cast<ShapedType>();
-    ShapedType weightType = weight.getType().cast<ShapedType>();
-    ShapedType resultType = op.output().getType().cast<ShapedType>();
-    Type inputEType = inputType.getElementType();
-
-    if (!(inputType.hasStaticShape() && weightType.hasStaticShape() &&
-          resultType.hasStaticShape())) {
-      return failure();
-    }
-
-    // Quantization information needs to still be performed.
-    if (op.quantization_info() || !inputEType.isa<FloatType>()) {
-      return failure();
-    }
-
-    // Stride must be 1 for this optimization.
-    for (Attribute stride : op.stride().getValue()) {
-      if (!stride.cast<IntegerAttr>().getValue().isOne()) {
-        return failure();
-      }
-    }
-
-    // Only works for a 1x1 kernel.
-    ArrayRef<int64_t> weightShape = weightType.getShape();
-    if (weightShape[0] != 1 || weightShape[1] != 1) {
-      return failure();
-    }
-
-    // Reshape input to [N, H, W, C] -> [N, H, W, C, 1].
-    ArrayRef<int64_t> inputShape = inputType.getShape();
-    llvm::SmallVector<int64_t, 2> revisedInputShape{
-        inputShape[0], inputShape[1], inputShape[2], inputShape[3], 1};
-    auto revisedInputShapeType = RankedTensorType::get(
-        revisedInputShape,
-        input.getType().dyn_cast<RankedTensorType>().getElementType());
-    auto reshapedInput = rewriter
-                             .create<tosa::ReshapeOp>(
-                                 op.getLoc(), revisedInputShapeType, input,
-                                 rewriter.getI64ArrayAttr(revisedInputShape))
-                             .getResult();
-
-    // Reshape kernel to [KH, KW, C, M] -> [1, 1, 1, C, M].
-    llvm::SmallVector<int64_t, 2> revisedWeightShape{1, 1, 1, weightShape[2],
-                                                     weightShape[3]};
-    auto revisedWeightShapeType = RankedTensorType::get(
-        revisedWeightShape,
-        weight.getType().dyn_cast<RankedTensorType>().getElementType());
-    auto reshapedWeight = rewriter
-                              .create<tosa::ReshapeOp>(
-                                  op.getLoc(), revisedWeightShapeType, weight,
-                                  rewriter.getI64ArrayAttr(revisedWeightShape))
-                              .getResult();
-
-    // Perform an elementwise mul over the reshaped input and weight.
-    llvm::SmallVector<int64_t, 2> mulShape{inputShape[0], inputShape[1],
-                                           inputShape[2], inputShape[3],
-                                           weightShape[3]};
-    auto mulShapeType = RankedTensorType::get(
-        mulShape,
-        weight.getType().dyn_cast<RankedTensorType>().getElementType());
-    Value mulValue =
-        rewriter
-            .create<tosa::MulOp>(op.getLoc(), mulShapeType, reshapedInput,
-                                 reshapedWeight, /*shift=*/0)
-            .getResult();
-
-    // Reshape output to [N, H, W, C * M].
-    auto outputShape = op.output().getType().cast<ShapedType>().getShape();
-    auto outputShapeType = RankedTensorType::get(
-        outputShape,
-        input.getType().dyn_cast<RankedTensorType>().getElementType());
-    auto outputValue =
-        rewriter.create<tosa::ReshapeOp>(op.getLoc(), outputShapeType, mulValue,
-                                         rewriter.getI64ArrayAttr(outputShape));
-
-    // Add in the bias.
-    rewriter
-        .replaceOpWithNewOp<tosa::AddOp>(op, outputShapeType, outputValue,
-                                         op.bias())
-        .getResult();
-    return success();
-  }
-};
-
-class TosaOptimization : public PassWrapper<TosaOptimization, FunctionPass> {
-public:
-  explicit TosaOptimization() = default;
-  void runOnFunction() override;
-
-  StringRef getArgument() const final { return PASS_NAME; }
-  StringRef getDescription() const final {
-    return "Applies TOSA Operation Optimizations";
-  }
-};
-
-void TosaOptimization::runOnFunction() {
-  OwningRewritePatternList patterns(&getContext());
-
-  patterns.insert<Conv2DIsFullyConnected>(&getContext());
-  patterns.insert<DepthwiseConv2DIsMul>(&getContext());
-
-  auto func = getFunction();
-  if (applyPatternsAndFoldGreedily(func, std::move(patterns)).failed()) {
-    signalPassFailure();
-  }
-}
-
-} // namespace
-
-std::unique_ptr<Pass> mlir::tosa::createTosaOptimizationPass() {
-  return std::make_unique<TosaOptimization>();
-}

diff  --git a/mlir/lib/Dialect/Tosa/Transforms/TosaOptionalDecompositions.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaOptionalDecompositions.cpp
new file mode 100644
index 000000000000..50fd635c8a46
--- /dev/null
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaOptionalDecompositions.cpp
@@ -0,0 +1,48 @@
+//===- TosaOptionalDecompositions.cpp -------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Pass to apply the optional TOSA operation decompositions exposed as
+// populate functions in include/mlir/Dialect/Tosa/Transforms/Passes.h.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Tosa/IR/TosaOps.h"
+#include "mlir/Dialect/Tosa/Transforms/PassDetail.h"
+#include "mlir/Dialect/Tosa/Transforms/Passes.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+using namespace mlir;
+
+namespace {
+
+/// Function pass that greedily applies all optional TOSA decompositions
+/// (Conv2D, TransposeConv and DepthwiseConv2D) to the function.
+struct TosaOptionalDecompositions
+    : public TosaOptionalDecompositionsBase<TosaOptionalDecompositions> {
+  void runOnFunction() override {
+    MLIRContext *ctx = &getContext();
+    RewritePatternSet patterns(ctx);
+    auto func = getFunction();
+
+    mlir::tosa::populateTosaDecomposeConv2D(ctx, patterns);
+    mlir::tosa::populateTosaDecomposeTransposeConv(ctx, patterns);
+    mlir::tosa::populateTosaDecomposeDepthwise(ctx, patterns);
+
+    // Report a failed greedy rewrite as a pass failure.
+    if (applyPatternsAndFoldGreedily(func, std::move(patterns)).failed())
+      signalPassFailure();
+  }
+};
+
+} // namespace
+
+/// Creates the `tosa-optional-decompositions` pass.
+std::unique_ptr<Pass> mlir::tosa::createTosaOptionalDecompositions() {
+  return std::make_unique<TosaOptionalDecompositions>();
+}

diff  --git a/mlir/test/Dialect/Tosa/operation_optimization.mlir b/mlir/test/Dialect/Tosa/tosa-decompose-conv2d.mlir
similarity index 53%
rename from mlir/test/Dialect/Tosa/operation_optimization.mlir
rename to mlir/test/Dialect/Tosa/tosa-decompose-conv2d.mlir
index aa65b96bad4e..cd9864f0c04f 100644
--- a/mlir/test/Dialect/Tosa/operation_optimization.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-decompose-conv2d.mlir
@@ -1,69 +1,40 @@
-// RUN: mlir-opt --split-input-file --tosa-optimization %s | FileCheck %s
-
-// -----
-
-// CHECK-LABEL: @conv2d_as_fully_connected
-func @conv2d_as_fully_connected(%arg0: tensor<4x10x10x2xf32>, %arg1: tensor<3x1x1x2xf32>, %arg2: tensor<3xf32>) -> tensor<4x10x10x3xf32> {
-  // CHECK-NOT: "tosa.conv2d"
-  // CHECK: %[[VAR0:.*]] = "tosa.reshape"(%arg0) {new_shape = [400, 2]}
-  // CHECK-SAME: -> tensor<400x2xf32>
-  // CHECK: %[[VAR1:.*]] = "tosa.reshape"(%arg1) {new_shape = [3, 2]}
-  // CHECK-SAME: -> tensor<3x2xf32>
-  // CHECK: %[[VAR2:.*]] = "tosa.fully_connected"(%[[VAR0]], %[[VAR1]], %arg2)
-  // CHECK-SAME: -> tensor<400x3xf32>
-  // CHECK: %[[VAR3:.*]] = "tosa.reshape"(%[[VAR2]]) {new_shape = [4, 10, 10, 3]}
-  // CHECK-SAME: -> tensor<4x10x10x3xf32>
-  // CHECK: return %[[VAR3]]
-  %0 = "tosa.conv2d"(%arg0, %arg1, %arg2) {pad = [0, 0, 0, 0], stride = [1, 1], dilation = [1, 1]} : (tensor<4x10x10x2xf32>, tensor<3x1x1x2xf32>, tensor<3xf32>) -> tensor<4x10x10x3xf32>
-  return %0 : tensor<4x10x10x3xf32>
-}
-
-// -----
-
-// CHECK-LABEL: @conv2d_as_fully_connected_quant
-func @conv2d_as_fully_connected_quant(%arg0: tensor<4x10x10x2xi8>, %arg1: tensor<3x1x1x2xi8>, %arg2: tensor<3xi32>) -> tensor<4x10x10x3xi32> {
-  // CHECK-NOT: "tosa.conv2d"
-  // CHECK: %[[VAR0:.*]] = "tosa.reshape"(%arg0) {new_shape = [400, 2]}
-  // CHECK-SAME: -> tensor<400x2xi8>
-  // CHECK: %[[VAR1:.*]] = "tosa.reshape"(%arg1) {new_shape = [3, 2]}
-  // CHECK-SAME: -> tensor<3x2xi8>
-  // CHECK: %[[VAR2:.*]] = "tosa.fully_connected"(%[[VAR0]], %[[VAR1]], %arg2)
-  // CHECK-SAME: quantization_info = {input_zp = 42 : i32, weight_zp = 24 : i32}
-  // CHECK-SAME: -> tensor<400x3xi32>
-  // CHECK: %[[VAR3:.*]] = "tosa.reshape"(%[[VAR2]]) {new_shape = [4, 10, 10, 3]}
-  // CHECK-SAME: -> tensor<4x10x10x3xi32>
-  // CHECK: return %[[VAR3]]
-  %0 = "tosa.conv2d"(%arg0, %arg1, %arg2) {pad = [0, 0, 0, 0], stride = [1, 1], dilation = [1, 1], quantization_info = {input_zp = 42 : i32, weight_zp = 24 : i32}} : (tensor<4x10x10x2xi8>, tensor<3x1x1x2xi8>, tensor<3xi32>) -> tensor<4x10x10x3xi32>
-  return %0 : tensor<4x10x10x3xi32>
-}
-
-// -----
-
-// CHECK-LABEL: @depthwise_conv2d_as_mul
-func @depthwise_conv2d_as_mul(%arg0: tensor<4x10x10x2xf32>, %arg1: tensor<1x1x2x3xf32>, %arg2: tensor<6xf32>) -> tensor<4x10x10x6xf32> {
-  // CHECK-NOT: "tosa.depthwise_conv2d"
-  // CHECK: %[[VAR0:.*]] = "tosa.reshape"(%arg0) {new_shape = [4, 10, 10, 2, 1]}
-  // CHECK-SAME: -> tensor<4x10x10x2x1xf32>
-  // CHECK: %[[VAR1:.*]] = "tosa.reshape"(%arg1) {new_shape = [1, 1, 1, 2, 3]}
-  // CHECK-SAME: -> tensor<1x1x1x2x3xf32>
-  // CHECK: %[[VAR2:.*]] = "tosa.mul"(%[[VAR0]], %[[VAR1]])
-  // CHECK-SAME: -> tensor<4x10x10x2x3xf32>
-  // CHECK: %[[VAR3:.*]] = "tosa.reshape"(%[[VAR2]]) {new_shape = [4, 10, 10, 6]}
-  // CHECK-SAME: -> tensor<4x10x10x6xf32>
-  // CHECK: %[[VAR4:.*]] = "tosa.add"(%[[VAR3]], %arg2)
-  // CHECK-SAME: -> tensor<4x10x10x6xf32>
-  // CHECK: return %[[VAR4]]
-  %0 = "tosa.depthwise_conv2d"(%arg0, %arg1, %arg2) {pad = [0, 0, 0, 0], stride = [1, 1], dilation = [1, 1]} : (tensor<4x10x10x2xf32>, tensor<1x1x2x3xf32>, tensor<6xf32>) -> tensor<4x10x10x6xf32>
-  return %0 : tensor<4x10x10x6xf32>
-}
-
-// -----
-
-// CHECK-LABEL: @depthwise_conv2d_as_mul_q
-func @depthwise_conv2d_as_mul_q(%arg0: tensor<4x10x10x2xi8>, %arg1: tensor<1x1x2x3xi8>, %arg2: tensor<6xi32>) -> tensor<4x10x10x6xi32> {
-  // CHECK: "tosa.depthwise_conv2d"
-  %0 = "tosa.depthwise_conv2d"(%arg0, %arg1, %arg2) {pad = [0, 0, 0, 0], stride = [1, 1], dilation = [1, 1], quantization_info = {input_zp = 0 : i32, weight_zp = 0 : i32}} : (tensor<4x10x10x2xi8>, tensor<1x1x2x3xi8>, tensor<6xi32>) -> tensor<4x10x10x6xi32>
-  return %0 : tensor<4x10x10x6xi32>
-}
-
-// -----
+// RUN: mlir-opt --split-input-file --tosa-optional-decompositions %s | FileCheck %s
+
+// -----
+// 1x1 conv2d decomposes to reshape -> fully_connected -> reshape.
+// CHECK-LABEL: @conv2d_as_fully_connected
+func @conv2d_as_fully_connected(%arg0: tensor<4x10x10x2xf32>, %arg1: tensor<3x1x1x2xf32>, %arg2: tensor<3xf32>) -> tensor<4x10x10x3xf32> {
+  // CHECK-NOT: "tosa.conv2d"
+  // CHECK: %[[VAR0:.*]] = "tosa.reshape"(%arg0) {new_shape = [400, 2]}
+  // CHECK-SAME: -> tensor<400x2xf32>
+  // CHECK: %[[VAR1:.*]] = "tosa.reshape"(%arg1) {new_shape = [3, 2]}
+  // CHECK-SAME: -> tensor<3x2xf32>
+  // CHECK: %[[VAR2:.*]] = "tosa.fully_connected"(%[[VAR0]], %[[VAR1]], %arg2)
+  // CHECK-SAME: -> tensor<400x3xf32>
+  // CHECK: %[[VAR3:.*]] = "tosa.reshape"(%[[VAR2]]) {new_shape = [4, 10, 10, 3]}
+  // CHECK-SAME: -> tensor<4x10x10x3xf32>
+  // CHECK: return %[[VAR3]]
+  %0 = "tosa.conv2d"(%arg0, %arg1, %arg2) {pad = [0, 0, 0, 0], stride = [1, 1], dilation = [1, 1]} : (tensor<4x10x10x2xf32>, tensor<3x1x1x2xf32>, tensor<3xf32>) -> tensor<4x10x10x3xf32>
+  return %0 : tensor<4x10x10x3xf32>
+}
+
+// -----
+// Quantized 1x1 conv2d: quantization_info must carry over to fully_connected.
+// CHECK-LABEL: @conv2d_as_fully_connected_quant
+func @conv2d_as_fully_connected_quant(%arg0: tensor<4x10x10x2xi8>, %arg1: tensor<3x1x1x2xi8>, %arg2: tensor<3xi32>) -> tensor<4x10x10x3xi32> {
+  // CHECK-NOT: "tosa.conv2d"
+  // CHECK: %[[VAR0:.*]] = "tosa.reshape"(%arg0) {new_shape = [400, 2]}
+  // CHECK-SAME: -> tensor<400x2xi8>
+  // CHECK: %[[VAR1:.*]] = "tosa.reshape"(%arg1) {new_shape = [3, 2]}
+  // CHECK-SAME: -> tensor<3x2xi8>
+  // CHECK: %[[VAR2:.*]] = "tosa.fully_connected"(%[[VAR0]], %[[VAR1]], %arg2)
+  // CHECK-SAME: quantization_info = {input_zp = 42 : i32, weight_zp = 24 : i32}
+  // CHECK-SAME: -> tensor<400x3xi32>
+  // CHECK: %[[VAR3:.*]] = "tosa.reshape"(%[[VAR2]]) {new_shape = [4, 10, 10, 3]}
+  // CHECK-SAME: -> tensor<4x10x10x3xi32>
+  // CHECK: return %[[VAR3]]
+  %0 = "tosa.conv2d"(%arg0, %arg1, %arg2) {pad = [0, 0, 0, 0], stride = [1, 1], dilation = [1, 1], quantization_info = {input_zp = 42 : i32, weight_zp = 24 : i32}} : (tensor<4x10x10x2xi8>, tensor<3x1x1x2xi8>, tensor<3xi32>) -> tensor<4x10x10x3xi32>
+  return %0 : tensor<4x10x10x3xi32>
+}
+
+// -----

diff  --git a/mlir/test/Dialect/Tosa/tosa-decompose-depthwise.mlir b/mlir/test/Dialect/Tosa/tosa-decompose-depthwise.mlir
new file mode 100644
index 000000000000..e6370d7a8314
--- /dev/null
+++ b/mlir/test/Dialect/Tosa/tosa-decompose-depthwise.mlir
@@ -0,0 +1,32 @@
+// RUN: mlir-opt --split-input-file --tosa-optional-decompositions %s | FileCheck %s
+
+// -----
+// 1x1 depthwise_conv2d decomposes to reshape -> mul -> reshape -> bias add.
+// CHECK-LABEL: @depthwise_conv2d_as_mul
+func @depthwise_conv2d_as_mul(%arg0: tensor<4x10x10x2xf32>, %arg1: tensor<1x1x2x3xf32>, %arg2: tensor<6xf32>) -> tensor<4x10x10x6xf32> {
+  // CHECK-NOT: "tosa.depthwise_conv2d"
+  // CHECK: %[[VAR0:.*]] = "tosa.reshape"(%arg0) {new_shape = [4, 10, 10, 2, 1]}
+  // CHECK-SAME: -> tensor<4x10x10x2x1xf32>
+  // CHECK: %[[VAR1:.*]] = "tosa.reshape"(%arg1) {new_shape = [1, 1, 1, 2, 3]}
+  // CHECK-SAME: -> tensor<1x1x1x2x3xf32>
+  // CHECK: %[[VAR2:.*]] = "tosa.mul"(%[[VAR0]], %[[VAR1]])
+  // CHECK-SAME: -> tensor<4x10x10x2x3xf32>
+  // CHECK: %[[VAR3:.*]] = "tosa.reshape"(%[[VAR2]]) {new_shape = [4, 10, 10, 6]}
+  // CHECK-SAME: -> tensor<4x10x10x6xf32>
+  // CHECK: %[[VAR4:.*]] = "tosa.add"(%[[VAR3]], %arg2)
+  // CHECK-SAME: -> tensor<4x10x10x6xf32>
+  // CHECK: return %[[VAR4]]
+  %0 = "tosa.depthwise_conv2d"(%arg0, %arg1, %arg2) {pad = [0, 0, 0, 0], stride = [1, 1], dilation = [1, 1]} : (tensor<4x10x10x2xf32>, tensor<1x1x2x3xf32>, tensor<6xf32>) -> tensor<4x10x10x6xf32>
+  return %0 : tensor<4x10x10x6xf32>
+}
+
+// -----
+// Despite the name, quantized depthwise_conv2d is expected to stay as-is.
+// CHECK-LABEL: @depthwise_conv2d_as_mul_q
+func @depthwise_conv2d_as_mul_q(%arg0: tensor<4x10x10x2xi8>, %arg1: tensor<1x1x2x3xi8>, %arg2: tensor<6xi32>) -> tensor<4x10x10x6xi32> {
+  // CHECK: "tosa.depthwise_conv2d"
+  %0 = "tosa.depthwise_conv2d"(%arg0, %arg1, %arg2) {pad = [0, 0, 0, 0], stride = [1, 1], dilation = [1, 1], quantization_info = {input_zp = 0 : i32, weight_zp = 0 : i32}} : (tensor<4x10x10x2xi8>, tensor<1x1x2x3xi8>, tensor<6xi32>) -> tensor<4x10x10x6xi32>
+  return %0 : tensor<4x10x10x6xi32>
+}
+
+// -----

diff  --git a/mlir/test/Dialect/Tosa/tosa-decompose-transpose-conv.mlir b/mlir/test/Dialect/Tosa/tosa-decompose-transpose-conv.mlir
index 627622ba796e..d0e9e5e17e84 100644
--- a/mlir/test/Dialect/Tosa/tosa-decompose-transpose-conv.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-decompose-transpose-conv.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt --split-input-file --tosa-decompose-transpose-conv %s | FileCheck %s
+// RUN: mlir-opt --split-input-file --tosa-optional-decompositions %s | FileCheck %s
 
 // CHECK-LABEL: @transpose_conv2d
 func @transpose_conv2d(%arg0: tensor<2x16x14x3xf32>, %arg1: tensor<5x3x6x3xf32>, %arg2: tensor<5xf32>) -> tensor<2x?x?x5xf32> {


        


More information about the Mlir-commits mailing list