[Mlir-commits] [mlir] 64f694a - [mlir][tosa] Move tosa canonicalizers to optional optimization pass
Rob Suderman
llvmlistbot at llvm.org
Thu Dec 16 23:36:20 PST 2021
Author: Aaron DeBattista
Date: 2021-12-16T23:33:54-08:00
New Revision: 64f694acaf9279c0902f2f150b48191fb91057fb
URL: https://github.com/llvm/llvm-project/commit/64f694acaf9279c0902f2f150b48191fb91057fb
DIFF: https://github.com/llvm/llvm-project/commit/64f694acaf9279c0902f2f150b48191fb91057fb.diff
LOG: [mlir][tosa] Move tosa canonicalizers to optional optimization pass
TOSA's canonicalizers that change dense operations should be moved to a
separate optimization pass to avoid canonicalizing to operations not supported
for relevant backends.
Reviewed By: rsuderman
Differential Revision: https://reviews.llvm.org/D115890
Added:
mlir/lib/Dialect/Tosa/Transforms/TosaOptimization.cpp
mlir/test/Dialect/Tosa/operation_optimization.mlir
Modified:
mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h
mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt
mlir/test/Dialect/Tosa/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index 982880e027e04..597813481d684 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -118,8 +118,6 @@ def Tosa_Conv2DOp : Tosa_Op<"conv2d", [
let builders = [Tosa_ConvOpQuantInfoBuilder];
let verifier = [{ return verifyConvOp(*this); }];
-
- let hasCanonicalizer = 1;
}
//===----------------------------------------------------------------------===//
@@ -187,8 +185,6 @@ def Tosa_DepthwiseConv2DOp : Tosa_Op<"depthwise_conv2d", [
let builders = [Tosa_ConvOpQuantInfoBuilder];
let verifier = [{ return verifyConvOp(*this); }];
-
- let hasCanonicalizer = 1;
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h
index 278402eb93b01..e94daca0df728 100644
--- a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h
@@ -22,6 +22,7 @@ namespace tosa {
std::unique_ptr<Pass> createTosaDecomposeTransposeConvPass();
std::unique_ptr<Pass> createTosaInferShapesPass();
std::unique_ptr<Pass> createTosaMakeBroadcastablePass();
+std::unique_ptr<Pass> createTosaOptimizationPass();
std::unique_ptr<Pass> createTosaTestQuantUtilAPIPass();
#define GEN_PASS_REGISTRATION
diff --git a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
index 7d6af621675b8..4a75482ba832a 100644
--- a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
@@ -1,4 +1,4 @@
-//===-- Passes.td - TOSA optimization pass declarations ----*- tablegen -*-===//
+//===-- Passes.td - TOSA pass declarations ----*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@@ -6,7 +6,7 @@
//
//===----------------------------------------------------------------------===//
//
-// This file declares the optimization passes for the TOSA Dialect in MLIR.
+// This file declares the passes for the TOSA Dialect in MLIR.
//
//===----------------------------------------------------------------------===//
@@ -58,4 +58,13 @@ def TosaMakeBroadcastable : FunctionPass<"tosa-make-broadcastable"> {
let constructor = "createTosaMakeBroadcastablePass()";
}
+// Optional function pass (flag: tosa-optimization), constructed through
+// createTosaOptimizationPass().
+def TosaOptimization : FunctionPass<"tosa-optimization"> {
+  let summary = "TOSA operation optimizations";
+  // The [{ ... }] literal is already a string; wrapping the text in double
+  // quotes would make the quotes part of the description.
+  let description = [{
+    Pass to perform optimizations on TOSA operations
+  }];
+
+  let constructor = "createTosaOptimizationPass()";
+}
+
#endif // MLIR_DIALECT_TOSA_TRANSFORMS_PASSES
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 9809e57e3a9a0..f61ce68893d69 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -423,197 +423,6 @@ void PadOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
results.insert<MaterializePadValue>(context);
}
-struct Conv2DIsFullyConnected : public OpRewritePattern<tosa::Conv2DOp> {
- using OpRewritePattern::OpRewritePattern;
-
- LogicalResult matchAndRewrite(tosa::Conv2DOp op,
- PatternRewriter &rewriter) const override {
- Value input = op.input();
- Value weight = op.weight();
- ShapedType inputType = input.getType().cast<ShapedType>();
- ShapedType weightType = weight.getType().cast<ShapedType>();
- ShapedType resultType = op.getType().cast<ShapedType>();
-
- if (!inputType.hasStaticShape() || !weightType.hasRank()) {
- return failure();
- }
-
- for (Attribute pad : op.pad().getValue()) {
- if (!pad.cast<IntegerAttr>().getValue().isZero()) {
- return failure();
- }
- }
-
- // Stride must be 1 for this optimization.
- for (Attribute stride : op.stride().getValue()) {
- if (!stride.cast<IntegerAttr>().getValue().isOne()) {
- return failure();
- }
- }
-
- // Only works for a 1x1 kernel.
- ArrayRef<int64_t> weightShape = weightType.getShape();
- if (weightShape[1] != 1 || weightShape[2] != 1) {
- return failure();
- }
-
- // Reshape input to [N,IH,IW,IC] -> [N * IH * IW, IC].
- ArrayRef<int64_t> inputShape = inputType.getShape();
- llvm::SmallVector<int64_t, 2> revisedInputShape{
- inputShape[0] * inputShape[1] * inputShape[2], inputShape[3]};
- auto revisedInputShapeType =
- RankedTensorType::get(revisedInputShape, inputType.getElementType());
- auto reshapedInput = rewriter
- .create<tosa::ReshapeOp>(
- op.getLoc(), revisedInputShapeType, input,
- rewriter.getI64ArrayAttr(revisedInputShape))
- .getResult();
-
- // Reshape kernel to [OC,KH,KW,IC] -> [OC, IC].
- llvm::SmallVector<int64_t, 2> revisedWeightShape{weightShape[0],
- weightShape[3]};
- auto revisedWeightShapeType =
- RankedTensorType::get(revisedWeightShape, weightType.getElementType());
- auto reshapedWeight = rewriter
- .create<tosa::ReshapeOp>(
- op.getLoc(), revisedWeightShapeType, weight,
- rewriter.getI64ArrayAttr(revisedWeightShape))
- .getResult();
-
- // Perform a fully connected network over the reshaped input and weight.
- llvm::SmallVector<int64_t, 2> fullyConnectedShape{
- inputShape[0] * inputShape[1] * inputShape[2], weightShape[0]};
- auto fullyConnectedShapeType =
- RankedTensorType::get(fullyConnectedShape, resultType.getElementType());
-
- Value fullyConnectedValue;
- if (op.quantization_info()) {
- fullyConnectedValue =
- rewriter
- .create<tosa::FullyConnectedOp>(
- op.getLoc(), fullyConnectedShapeType, reshapedInput,
- reshapedWeight, op.bias(), op.quantization_info().getValue())
- .getResult();
- } else {
- fullyConnectedValue = rewriter
- .create<tosa::FullyConnectedOp>(
- op.getLoc(), fullyConnectedShapeType,
- reshapedInput, reshapedWeight, op.bias())
- .getResult();
- }
-
- // Reshape output to [N, IH, IW, OC].
- llvm::SmallVector<int64_t, 4> outputShape{inputShape[0], inputShape[1],
- inputShape[2], weightShape[0]};
- rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(
- op, resultType, fullyConnectedValue,
- rewriter.getI64ArrayAttr(outputShape));
- return success();
- }
-};
-
-void Conv2DOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
- MLIRContext *context) {
- results.insert<Conv2DIsFullyConnected>(context);
-}
-
-struct DepthwiseConv2DMulOptimization
- : public OpRewritePattern<tosa::DepthwiseConv2DOp> {
- using OpRewritePattern::OpRewritePattern;
-
- LogicalResult matchAndRewrite(tosa::DepthwiseConv2DOp op,
- PatternRewriter &rewriter) const override {
- Value input = op.input();
- Value weight = op.weight();
- ShapedType inputType = input.getType().cast<ShapedType>();
- ShapedType weightType = weight.getType().cast<ShapedType>();
- ShapedType resultType = op.output().getType().cast<ShapedType>();
- Type inputEType = inputType.getElementType();
-
- if (!(inputType.hasStaticShape() && weightType.hasStaticShape() &&
- resultType.hasStaticShape())) {
- return failure();
- }
-
- // Quantization information needs to still be performed.
- if (op.quantization_info() || !inputEType.isa<FloatType>()) {
- return failure();
- }
-
- // Stride must be 1 for this optimization.
- for (Attribute stride : op.stride().getValue()) {
- if (!stride.cast<IntegerAttr>().getValue().isOne()) {
- return failure();
- }
- }
-
- // Only works for a 1x1 kernel.
- ArrayRef<int64_t> weightShape = weightType.getShape();
- if (weightShape[0] != 1 || weightShape[1] != 1) {
- return failure();
- }
-
- // Reshape input to [N, H, W, C] -> [N, H, W, C, 1].
- ArrayRef<int64_t> inputShape = inputType.getShape();
- llvm::SmallVector<int64_t, 2> revisedInputShape{
- inputShape[0], inputShape[1], inputShape[2], inputShape[3], 1};
- auto revisedInputShapeType = RankedTensorType::get(
- revisedInputShape,
- input.getType().dyn_cast<RankedTensorType>().getElementType());
- auto reshapedInput = rewriter
- .create<tosa::ReshapeOp>(
- op.getLoc(), revisedInputShapeType, input,
- rewriter.getI64ArrayAttr(revisedInputShape))
- .getResult();
-
- // Reshape kernel to [KH, KW, C, M] -> [1, 1, 1, C, M].
- llvm::SmallVector<int64_t, 2> revisedWeightShape{1, 1, 1, weightShape[2],
- weightShape[3]};
- auto revisedWeightShapeType = RankedTensorType::get(
- revisedWeightShape,
- weight.getType().dyn_cast<RankedTensorType>().getElementType());
- auto reshapedWeight = rewriter
- .create<tosa::ReshapeOp>(
- op.getLoc(), revisedWeightShapeType, weight,
- rewriter.getI64ArrayAttr(revisedWeightShape))
- .getResult();
-
- // Perform an elementwise mul over the reshaped input and weight.
- llvm::SmallVector<int64_t, 2> mulShape{inputShape[0], inputShape[1],
- inputShape[2], inputShape[3],
- weightShape[3]};
- auto mulShapeType = RankedTensorType::get(
- mulShape,
- weight.getType().dyn_cast<RankedTensorType>().getElementType());
- Value mulValue =
- rewriter
- .create<tosa::MulOp>(op.getLoc(), mulShapeType, reshapedInput,
- reshapedWeight, /*shift=*/0)
- .getResult();
-
- // Reshape output to [N, H, W, C * M].
- auto outputShape = op.output().getType().cast<ShapedType>().getShape();
- auto outputShapeType = RankedTensorType::get(
- outputShape,
- input.getType().dyn_cast<RankedTensorType>().getElementType());
- auto outputValue =
- rewriter.create<tosa::ReshapeOp>(op.getLoc(), outputShapeType, mulValue,
- rewriter.getI64ArrayAttr(outputShape));
-
- // Add in the bias.
- rewriter
- .replaceOpWithNewOp<tosa::AddOp>(op, outputShapeType, outputValue,
- op.bias())
- .getResult();
- return success();
- }
-};
-
-void DepthwiseConv2DOp::getCanonicalizationPatterns(
- OwningRewritePatternList &results, MLIRContext *context) {
- results.insert<DepthwiseConv2DMulOptimization>(context);
-}
-
struct MaxPool2dIsNoOp : public OpRewritePattern<tosa::MaxPool2dOp> {
using OpRewritePattern::OpRewritePattern;
@@ -747,7 +556,8 @@ OpFoldResult TransposeOp::fold(ArrayRef<Attribute> operands) {
// TOSA Operator Verifiers.
//===----------------------------------------------------------------------===//
-template <typename T> static LogicalResult verifyConvOp(T op) {
+template <typename T>
+static LogicalResult verifyConvOp(T op) {
// All TOSA conv ops have an input() and weight().
auto inputType = op.input().getType().template dyn_cast<RankedTensorType>();
auto weightType = op.weight().getType().template dyn_cast<RankedTensorType>();
diff --git a/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt
index b5e90bbeecc59..016575fc7735b 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt
@@ -2,6 +2,7 @@ add_mlir_dialect_library(MLIRTosaTransforms
TosaDecomposeTransposeConv.cpp
TosaInferShapes.cpp
TosaMakeBroadcastable.cpp
+ TosaOptimization.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Tosa/Transforms
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaOptimization.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaOptimization.cpp
new file mode 100644
index 0000000000000..61b1618f68dad
--- /dev/null
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaOptimization.cpp
@@ -0,0 +1,243 @@
+//===- TosaOptimization.cpp ------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Pass to perform optimizations on TOSA operations
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Analysis/DataFlowAnalysis.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Tosa/IR/TosaOps.h"
+#include "mlir/Dialect/Tosa/Transforms/PassDetail.h"
+#include "mlir/Dialect/Tosa/Transforms/Passes.h"
+#include "mlir/Dialect/Tosa/Utils/ShapeUtils.h"
+#include "mlir/IR/BlockAndValueMapping.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/Support/FormatVariadic.h"
+
+using namespace mlir;
+using namespace mlir::tosa;
+
+#define PASS_NAME "tosa-optimization"
+#define DEBUG_TYPE PASS_NAME
+
+namespace {
+
+/// Rewrites a tosa.conv2d whose kernel is 1x1, with unit strides and zero
+/// padding, as reshape -> tosa.fully_connected -> reshape.
+struct Conv2DIsFullyConnected : public OpRewritePattern<tosa::Conv2DOp> {
+  explicit Conv2DIsFullyConnected(MLIRContext *context)
+      : OpRewritePattern(context) {}
+
+  /// Returns failure() without rewriting unless the input shape is static, the
+  /// weight is ranked, every pad entry is 0, every stride entry is 1 and the
+  /// kernel is 1x1. Otherwise replaces `op` and returns success().
+  LogicalResult matchAndRewrite(tosa::Conv2DOp op,
+                                PatternRewriter &rewriter) const override {
+    Value input = op.input();
+    Value weight = op.weight();
+    ShapedType inputType = input.getType().cast<ShapedType>();
+    ShapedType weightType = weight.getType().cast<ShapedType>();
+    ShapedType resultType = op.getType().cast<ShapedType>();
+
+    if (!inputType.hasStaticShape() || !weightType.hasRank()) {
+      return failure();
+    }
+
+    // Padding must be 0 for this optimization: the result is reshaped back to
+    // [N, IH, IW, OC], which only matches the output of an unpadded conv.
+    for (Attribute pad : op.pad().getValue()) {
+      if (!pad.cast<IntegerAttr>().getValue().isZero()) {
+        return failure();
+      }
+    }
+
+    // Stride must be 1 for this optimization.
+    for (Attribute stride : op.stride().getValue()) {
+      if (!stride.cast<IntegerAttr>().getValue().isOne()) {
+        return failure();
+      }
+    }
+
+    // Only works for a 1x1 kernel.
+    ArrayRef<int64_t> weightShape = weightType.getShape();
+    if (weightShape[1] != 1 || weightShape[2] != 1) {
+      return failure();
+    }
+
+    // Row count once the batch and both spatial dimensions are collapsed.
+    ArrayRef<int64_t> inputShape = inputType.getShape();
+    int64_t flattenedRows = inputShape[0] * inputShape[1] * inputShape[2];
+
+    // Reshape input to [N,IH,IW,IC] -> [N * IH * IW, IC].
+    llvm::SmallVector<int64_t, 2> revisedInputShape{flattenedRows,
+                                                    inputShape[3]};
+    auto revisedInputShapeType =
+        RankedTensorType::get(revisedInputShape, inputType.getElementType());
+    auto reshapedInput = rewriter
+                             .create<tosa::ReshapeOp>(
+                                 op.getLoc(), revisedInputShapeType, input,
+                                 rewriter.getI64ArrayAttr(revisedInputShape))
+                             .getResult();
+
+    // Reshape kernel to [OC,KH,KW,IC] -> [OC, IC].
+    llvm::SmallVector<int64_t, 2> revisedWeightShape{weightShape[0],
+                                                     weightShape[3]};
+    auto revisedWeightShapeType =
+        RankedTensorType::get(revisedWeightShape, weightType.getElementType());
+    auto reshapedWeight = rewriter
+                              .create<tosa::ReshapeOp>(
+                                  op.getLoc(), revisedWeightShapeType, weight,
+                                  rewriter.getI64ArrayAttr(revisedWeightShape))
+                              .getResult();
+
+    // Perform a fully connected network over the reshaped input and weight.
+    llvm::SmallVector<int64_t, 2> fullyConnectedShape{flattenedRows,
+                                                      weightShape[0]};
+    auto fullyConnectedShapeType =
+        RankedTensorType::get(fullyConnectedShape, resultType.getElementType());
+
+    // Forward the quantization attribute when the convolution carries one.
+    Value fullyConnectedValue;
+    if (op.quantization_info()) {
+      fullyConnectedValue =
+          rewriter
+              .create<tosa::FullyConnectedOp>(
+                  op.getLoc(), fullyConnectedShapeType, reshapedInput,
+                  reshapedWeight, op.bias(), op.quantization_info().getValue())
+              .getResult();
+    } else {
+      fullyConnectedValue = rewriter
+                                .create<tosa::FullyConnectedOp>(
+                                    op.getLoc(), fullyConnectedShapeType,
+                                    reshapedInput, reshapedWeight, op.bias())
+                                .getResult();
+    }
+
+    // Reshape output to [N, IH, IW, OC].
+    llvm::SmallVector<int64_t, 4> outputShape{inputShape[0], inputShape[1],
+                                              inputShape[2], weightShape[0]};
+    rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(
+        op, resultType, fullyConnectedValue,
+        rewriter.getI64ArrayAttr(outputShape));
+    return success();
+  }
+};
+
+/// Rewrites a float tosa.depthwise_conv2d whose kernel is 1x1, with unit
+/// strides and zero padding, as reshapes of the input and weight followed by
+/// tosa.mul, a reshape to the result shape, and a tosa.add of the bias.
+struct DepthwiseConv2DIsMul : public OpRewritePattern<tosa::DepthwiseConv2DOp> {
+  explicit DepthwiseConv2DIsMul(MLIRContext *context)
+      : OpRewritePattern(context) {}
+
+  /// Returns failure() without rewriting unless input, weight and result
+  /// shapes are static, the op has no quantization info, the input element
+  /// type is a float type, every pad entry is 0, every stride entry is 1 and
+  /// the kernel is 1x1. Otherwise replaces `op` and returns success().
+  LogicalResult matchAndRewrite(tosa::DepthwiseConv2DOp op,
+                                PatternRewriter &rewriter) const override {
+    Value input = op.input();
+    Value weight = op.weight();
+    ShapedType inputType = input.getType().cast<ShapedType>();
+    ShapedType weightType = weight.getType().cast<ShapedType>();
+    ShapedType resultType = op.output().getType().cast<ShapedType>();
+    Type inputEType = inputType.getElementType();
+    Type weightEType = weightType.getElementType();
+
+    if (!(inputType.hasStaticShape() && weightType.hasStaticShape() &&
+          resultType.hasStaticShape())) {
+      return failure();
+    }
+
+    // The quantized form is not handled by this rewrite; only plain float
+    // depthwise convolutions are converted.
+    if (op.quantization_info() || !inputEType.isa<FloatType>()) {
+      return failure();
+    }
+
+    // Padding must be 0 for this optimization: the product below has the
+    // input's spatial extent and is reshaped directly to the result shape.
+    for (Attribute pad : op.pad().getValue()) {
+      if (!pad.cast<IntegerAttr>().getValue().isZero()) {
+        return failure();
+      }
+    }
+
+    // Stride must be 1 for this optimization.
+    for (Attribute stride : op.stride().getValue()) {
+      if (!stride.cast<IntegerAttr>().getValue().isOne()) {
+        return failure();
+      }
+    }
+
+    // Only works for a 1x1 kernel.
+    ArrayRef<int64_t> weightShape = weightType.getShape();
+    if (weightShape[0] != 1 || weightShape[1] != 1) {
+      return failure();
+    }
+
+    // Reshape input to [N, H, W, C] -> [N, H, W, C, 1].
+    ArrayRef<int64_t> inputShape = inputType.getShape();
+    llvm::SmallVector<int64_t, 5> revisedInputShape{
+        inputShape[0], inputShape[1], inputShape[2], inputShape[3], 1};
+    auto revisedInputShapeType =
+        RankedTensorType::get(revisedInputShape, inputEType);
+    auto reshapedInput = rewriter
+                             .create<tosa::ReshapeOp>(
+                                 op.getLoc(), revisedInputShapeType, input,
+                                 rewriter.getI64ArrayAttr(revisedInputShape))
+                             .getResult();
+
+    // Reshape kernel to [KH, KW, C, M] -> [1, 1, 1, C, M].
+    llvm::SmallVector<int64_t, 5> revisedWeightShape{1, 1, 1, weightShape[2],
+                                                     weightShape[3]};
+    auto revisedWeightShapeType =
+        RankedTensorType::get(revisedWeightShape, weightEType);
+    auto reshapedWeight = rewriter
+                              .create<tosa::ReshapeOp>(
+                                  op.getLoc(), revisedWeightShapeType, weight,
+                                  rewriter.getI64ArrayAttr(revisedWeightShape))
+                              .getResult();
+
+    // Perform an elementwise mul over the reshaped input and weight.
+    llvm::SmallVector<int64_t, 5> mulShape{inputShape[0], inputShape[1],
+                                           inputShape[2], inputShape[3],
+                                           weightShape[3]};
+    auto mulShapeType = RankedTensorType::get(mulShape, weightEType);
+    Value mulValue =
+        rewriter
+            .create<tosa::MulOp>(op.getLoc(), mulShapeType, reshapedInput,
+                                 reshapedWeight, /*shift=*/0)
+            .getResult();
+
+    // Reshape output to [N, H, W, C * M].
+    auto outputShape = resultType.getShape();
+    auto outputShapeType = RankedTensorType::get(outputShape, inputEType);
+    auto outputValue =
+        rewriter.create<tosa::ReshapeOp>(op.getLoc(), outputShapeType, mulValue,
+                                         rewriter.getI64ArrayAttr(outputShape));
+
+    // Add in the bias.
+    rewriter.replaceOpWithNewOp<tosa::AddOp>(op, outputShapeType, outputValue,
+                                             op.bias());
+    return success();
+  }
+};
+
+/// Function pass registered under PASS_NAME; its work is done in
+/// runOnFunction().
+class TosaOptimization : public PassWrapper<TosaOptimization, FunctionPass> {
+public:
+  explicit TosaOptimization() = default;
+
+  /// Command-line argument that selects this pass.
+  StringRef getArgument() const final { return PASS_NAME; }
+
+  /// One-line, human-readable description of the pass.
+  StringRef getDescription() const final {
+    return "Applies TOSA Operation Optimizations";
+  }
+
+  void runOnFunction() override;
+};
+
+// Registers both rewrite patterns and applies them greedily to the current
+// function; the pass is flagged as failed when the driver reports failure.
+void TosaOptimization::runOnFunction() {
+  MLIRContext *ctx = &getContext();
+
+  OwningRewritePatternList patterns(ctx);
+  patterns.insert<Conv2DIsFullyConnected, DepthwiseConv2DIsMul>(ctx);
+
+  if (failed(applyPatternsAndFoldGreedily(getFunction(), std::move(patterns))))
+    signalPassFailure();
+}
+
+} // namespace
+
+// Factory for the TOSA optimization pass; returns a freshly constructed
+// instance owned by the caller.
+std::unique_ptr<Pass> mlir::tosa::createTosaOptimizationPass() {
+  std::unique_ptr<Pass> pass = std::make_unique<TosaOptimization>();
+  return pass;
+}
diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
index fa5304bcfb042..1f2f9a1c66a45 100644
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -68,52 +68,6 @@ func @concat_fold_cast(%arg0: tensor<?x1xf32>) -> tensor<?x?xf32> {
// -----
-// CHECK-LABEL: @conv2d_as_fully_connected
-func @conv2d_as_fully_connected(%arg0: tensor<4x10x10x2xf32>, %arg1: tensor<3x1x1x2xf32>, %arg2: tensor<3xf32>) -> tensor<4x10x10x3xf32> {
- // CHECK-NOT: "tosa.conv2d"
- // CHECK: %[[VAR0:.*]] = "tosa.reshape"(%arg0) {new_shape = [400, 2]}
- // CHECK-SAME: -> tensor<400x2xf32>
- // CHECK: %[[VAR1:.*]] = "tosa.reshape"(%arg1) {new_shape = [3, 2]}
- // CHECK-SAME: -> tensor<3x2xf32>
- // CHECK: %[[VAR2:.*]] = "tosa.fully_connected"(%[[VAR0]], %[[VAR1]], %arg2)
- // CHECK-SAME: -> tensor<400x3xf32>
- // CHECK: %[[VAR3:.*]] = "tosa.reshape"(%[[VAR2]]) {new_shape = [4, 10, 10, 3]}
- // CHECK-SAME: -> tensor<4x10x10x3xf32>
- // CHECK: return %[[VAR3]]
- %0 = "tosa.conv2d"(%arg0, %arg1, %arg2) {pad = [0, 0, 0, 0], stride = [1, 1], dilation = [1, 1]} : (tensor<4x10x10x2xf32>, tensor<3x1x1x2xf32>, tensor<3xf32>) -> tensor<4x10x10x3xf32>
- return %0 : tensor<4x10x10x3xf32>
-}
-
-// -----
-
-// CHECK-LABEL: @conv2d_as_fully_connected_quant
-func @conv2d_as_fully_connected_quant(%arg0: tensor<4x10x10x2xi8>, %arg1: tensor<3x1x1x2xi8>, %arg2: tensor<3xi32>) -> tensor<4x10x10x3xi32> {
- // CHECK-NOT: "tosa.conv2d"
- // CHECK: %[[VAR0:.*]] = "tosa.reshape"(%arg0) {new_shape = [400, 2]}
- // CHECK-SAME: -> tensor<400x2xi8>
- // CHECK: %[[VAR1:.*]] = "tosa.reshape"(%arg1) {new_shape = [3, 2]}
- // CHECK-SAME: -> tensor<3x2xi8>
- // CHECK: %[[VAR2:.*]] = "tosa.fully_connected"(%[[VAR0]], %[[VAR1]], %arg2)
- // CHECK-SAME: quantization_info = {input_zp = 42 : i32, weight_zp = 24 : i32}
- // CHECK-SAME: -> tensor<400x3xi32>
- // CHECK: %[[VAR3:.*]] = "tosa.reshape"(%[[VAR2]]) {new_shape = [4, 10, 10, 3]}
- // CHECK-SAME: -> tensor<4x10x10x3xi32>
- // CHECK: return %[[VAR3]]
- %0 = "tosa.conv2d"(%arg0, %arg1, %arg2) {pad = [0, 0, 0, 0], stride = [1, 1], dilation = [1, 1], quantization_info = {input_zp = 42 : i32, weight_zp = 24 : i32}} : (tensor<4x10x10x2xi8>, tensor<3x1x1x2xi8>, tensor<3xi32>) -> tensor<4x10x10x3xi32>
- return %0 : tensor<4x10x10x3xi32>
-}
-
-// -----
-
-// CHECK-LABEL: @conv2d_padded
-func @conv2d_padded(%arg0: tensor<4x10x10x2xf32>, %arg1: tensor<3x1x1x2xf32>, %arg2: tensor<3xf32>) -> tensor<4x12x12x3xf32> {
- // CHECK: "tosa.conv2d"
- %0 = "tosa.conv2d"(%arg0, %arg1, %arg2) {pad = [1, 1, 1, 1], stride = [1, 1], dilation = [1, 1]} : (tensor<4x10x10x2xf32>, tensor<3x1x1x2xf32>, tensor<3xf32>) -> tensor<4x12x12x3xf32>
- return %0 : tensor<4x12x12x3xf32>
-}
-
-// -----
-
// CHECK-LABEL: @conv2d_stride_2
func @conv2d_stride_2(%arg0: tensor<4x10x10x2xf32>) -> tensor<4x10x10x3xf32> {
// CHECK: "tosa.conv2d"
@@ -136,35 +90,6 @@ func @conv2d_weight_2x2(%arg0: tensor<4x10x10x1xf32>) -> tensor<4x10x10x1xf32> {
// -----
-// CHECK-LABEL: @depthwise_conv2d_as_mul
-func @depthwise_conv2d_as_mul(%arg0: tensor<4x10x10x2xf32>, %arg1: tensor<1x1x2x3xf32>, %arg2: tensor<6xf32>) -> tensor<4x10x10x6xf32> {
- // CHECK-NOT: "tosa.depthwise_conv2d"
- // CHECK: %[[VAR0:.*]] = "tosa.reshape"(%arg0) {new_shape = [4, 10, 10, 2, 1]}
- // CHECK-SAME: -> tensor<4x10x10x2x1xf32>
- // CHECK: %[[VAR1:.*]] = "tosa.reshape"(%arg1) {new_shape = [1, 1, 1, 2, 3]}
- // CHECK-SAME: -> tensor<1x1x1x2x3xf32>
- // CHECK: %[[VAR2:.*]] = "tosa.mul"(%[[VAR0]], %[[VAR1]])
- // CHECK-SAME: -> tensor<4x10x10x2x3xf32>
- // CHECK: %[[VAR3:.*]] = "tosa.reshape"(%[[VAR2]]) {new_shape = [4, 10, 10, 6]}
- // CHECK-SAME: -> tensor<4x10x10x6xf32>
- // CHECK: %[[VAR4:.*]] = "tosa.add"(%[[VAR3]], %arg2)
- // CHECK-SAME: -> tensor<4x10x10x6xf32>
- // CHECK: return %[[VAR4]]
- %0 = "tosa.depthwise_conv2d"(%arg0, %arg1, %arg2) {pad = [0, 0, 0, 0], stride = [1, 1], dilation = [1, 1]} : (tensor<4x10x10x2xf32>, tensor<1x1x2x3xf32>, tensor<6xf32>) -> tensor<4x10x10x6xf32>
- return %0 : tensor<4x10x10x6xf32>
-}
-
-// -----
-
-// CHECK-LABEL: @depthwise_conv2d_as_mul_q
-func @depthwise_conv2d_as_mul_q(%arg0: tensor<4x10x10x2xi8>, %arg1: tensor<1x1x2x3xi8>, %arg2: tensor<6xi32>) -> tensor<4x10x10x6xi32> {
- // CHECK: "tosa.depthwise_conv2d"
- %0 = "tosa.depthwise_conv2d"(%arg0, %arg1, %arg2) {pad = [0, 0, 0, 0], stride = [1, 1], dilation = [1, 1], quantization_info = {input_zp = 0 : i32, weight_zp = 0 : i32}} : (tensor<4x10x10x2xi8>, tensor<1x1x2x3xi8>, tensor<6xi32>) -> tensor<4x10x10x6xi32>
- return %0 : tensor<4x10x10x6xi32>
-}
-
-// -----
-
// CHECK-LABEL: @depthwise_conv2d_stride_2
func @depthwise_conv2d_stride_2(%arg0: tensor<4x10x10x2xf32>, %arg1: tensor<1x1x2x3xf32>, %arg2: tensor<6xf32>) -> tensor<4x10x10x6xf32> {
// CHECK: "tosa.depthwise_conv2d"
diff --git a/mlir/test/Dialect/Tosa/operation_optimization.mlir b/mlir/test/Dialect/Tosa/operation_optimization.mlir
new file mode 100644
index 0000000000000..aa65b96bad4ef
--- /dev/null
+++ b/mlir/test/Dialect/Tosa/operation_optimization.mlir
@@ -0,0 +1,69 @@
+// RUN: mlir-opt --split-input-file --tosa-optimization %s | FileCheck %s
+
+// -----
+
+// Float conv2d with a 1x1 kernel, zero pad and unit stride: the pass is
+// expected to replace it with reshape -> fully_connected -> reshape.
+// CHECK-LABEL: @conv2d_as_fully_connected
+func @conv2d_as_fully_connected(%arg0: tensor<4x10x10x2xf32>, %arg1: tensor<3x1x1x2xf32>, %arg2: tensor<3xf32>) -> tensor<4x10x10x3xf32> {
+ // CHECK-NOT: "tosa.conv2d"
+ // CHECK: %[[VAR0:.*]] = "tosa.reshape"(%arg0) {new_shape = [400, 2]}
+ // CHECK-SAME: -> tensor<400x2xf32>
+ // CHECK: %[[VAR1:.*]] = "tosa.reshape"(%arg1) {new_shape = [3, 2]}
+ // CHECK-SAME: -> tensor<3x2xf32>
+ // CHECK: %[[VAR2:.*]] = "tosa.fully_connected"(%[[VAR0]], %[[VAR1]], %arg2)
+ // CHECK-SAME: -> tensor<400x3xf32>
+ // CHECK: %[[VAR3:.*]] = "tosa.reshape"(%[[VAR2]]) {new_shape = [4, 10, 10, 3]}
+ // CHECK-SAME: -> tensor<4x10x10x3xf32>
+ // CHECK: return %[[VAR3]]
+ %0 = "tosa.conv2d"(%arg0, %arg1, %arg2) {pad = [0, 0, 0, 0], stride = [1, 1], dilation = [1, 1]} : (tensor<4x10x10x2xf32>, tensor<3x1x1x2xf32>, tensor<3xf32>) -> tensor<4x10x10x3xf32>
+ return %0 : tensor<4x10x10x3xf32>
+}
+
+// -----
+
+// Quantized (i8 in, i32 out) variant of the 1x1 conv2d rewrite: the
+// quantization_info attribute must be carried over to tosa.fully_connected.
+// CHECK-LABEL: @conv2d_as_fully_connected_quant
+func @conv2d_as_fully_connected_quant(%arg0: tensor<4x10x10x2xi8>, %arg1: tensor<3x1x1x2xi8>, %arg2: tensor<3xi32>) -> tensor<4x10x10x3xi32> {
+ // CHECK-NOT: "tosa.conv2d"
+ // CHECK: %[[VAR0:.*]] = "tosa.reshape"(%arg0) {new_shape = [400, 2]}
+ // CHECK-SAME: -> tensor<400x2xi8>
+ // CHECK: %[[VAR1:.*]] = "tosa.reshape"(%arg1) {new_shape = [3, 2]}
+ // CHECK-SAME: -> tensor<3x2xi8>
+ // CHECK: %[[VAR2:.*]] = "tosa.fully_connected"(%[[VAR0]], %[[VAR1]], %arg2)
+ // CHECK-SAME: quantization_info = {input_zp = 42 : i32, weight_zp = 24 : i32}
+ // CHECK-SAME: -> tensor<400x3xi32>
+ // CHECK: %[[VAR3:.*]] = "tosa.reshape"(%[[VAR2]]) {new_shape = [4, 10, 10, 3]}
+ // CHECK-SAME: -> tensor<4x10x10x3xi32>
+ // CHECK: return %[[VAR3]]
+ %0 = "tosa.conv2d"(%arg0, %arg1, %arg2) {pad = [0, 0, 0, 0], stride = [1, 1], dilation = [1, 1], quantization_info = {input_zp = 42 : i32, weight_zp = 24 : i32}} : (tensor<4x10x10x2xi8>, tensor<3x1x1x2xi8>, tensor<3xi32>) -> tensor<4x10x10x3xi32>
+ return %0 : tensor<4x10x10x3xi32>
+}
+
+// -----
+
+// Float depthwise_conv2d with a 1x1 kernel, zero pad and unit stride: the pass
+// is expected to emit reshape(input), reshape(weight), mul, reshape, add(bias).
+// CHECK-LABEL: @depthwise_conv2d_as_mul
+func @depthwise_conv2d_as_mul(%arg0: tensor<4x10x10x2xf32>, %arg1: tensor<1x1x2x3xf32>, %arg2: tensor<6xf32>) -> tensor<4x10x10x6xf32> {
+ // CHECK-NOT: "tosa.depthwise_conv2d"
+ // CHECK: %[[VAR0:.*]] = "tosa.reshape"(%arg0) {new_shape = [4, 10, 10, 2, 1]}
+ // CHECK-SAME: -> tensor<4x10x10x2x1xf32>
+ // CHECK: %[[VAR1:.*]] = "tosa.reshape"(%arg1) {new_shape = [1, 1, 1, 2, 3]}
+ // CHECK-SAME: -> tensor<1x1x1x2x3xf32>
+ // CHECK: %[[VAR2:.*]] = "tosa.mul"(%[[VAR0]], %[[VAR1]])
+ // CHECK-SAME: -> tensor<4x10x10x2x3xf32>
+ // CHECK: %[[VAR3:.*]] = "tosa.reshape"(%[[VAR2]]) {new_shape = [4, 10, 10, 6]}
+ // CHECK-SAME: -> tensor<4x10x10x6xf32>
+ // CHECK: %[[VAR4:.*]] = "tosa.add"(%[[VAR3]], %arg2)
+ // CHECK-SAME: -> tensor<4x10x10x6xf32>
+ // CHECK: return %[[VAR4]]
+ %0 = "tosa.depthwise_conv2d"(%arg0, %arg1, %arg2) {pad = [0, 0, 0, 0], stride = [1, 1], dilation = [1, 1]} : (tensor<4x10x10x2xf32>, tensor<1x1x2x3xf32>, tensor<6xf32>) -> tensor<4x10x10x6xf32>
+ return %0 : tensor<4x10x10x6xf32>
+}
+
+// -----
+
+// Negative case: a depthwise_conv2d carrying quantization_info must be left in
+// place by the pass.
+// CHECK-LABEL: @depthwise_conv2d_as_mul_q
+func @depthwise_conv2d_as_mul_q(%arg0: tensor<4x10x10x2xi8>, %arg1: tensor<1x1x2x3xi8>, %arg2: tensor<6xi32>) -> tensor<4x10x10x6xi32> {
+ // CHECK: "tosa.depthwise_conv2d"
+ %0 = "tosa.depthwise_conv2d"(%arg0, %arg1, %arg2) {pad = [0, 0, 0, 0], stride = [1, 1], dilation = [1, 1], quantization_info = {input_zp = 0 : i32, weight_zp = 0 : i32}} : (tensor<4x10x10x2xi8>, tensor<1x1x2x3xi8>, tensor<6xi32>) -> tensor<4x10x10x6xi32>
+ return %0 : tensor<4x10x10x6xi32>
+}
+
+// -----
More information about the Mlir-commits
mailing list