[Mlir-commits] [mlir] 5911a29 - [mlir][tosa] Add tosa.depthwise_conv2d as tosa.mul canonicalization

Rob Suderman llvmlistbot at llvm.org
Mon Dec 6 17:33:53 PST 2021


Author: not-jenni
Date: 2021-12-06T17:28:52-08:00
New Revision: 5911a29aa92065549031c998ce03ef75d8d61118

URL: https://github.com/llvm/llvm-project/commit/5911a29aa92065549031c998ce03ef75d8d61118
DIFF: https://github.com/llvm/llvm-project/commit/5911a29aa92065549031c998ce03ef75d8d61118.diff

LOG: [mlir][tosa] Add tosa.depthwise_conv2d as tosa.mul canonicalization

For a 1x1 weight and a stride of 1, the input and weight can be reshaped and
multiplied elementwise, then reshaped back to the output shape with the bias
added afterwards.
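
To illustrate the rewrite (a sketch only: the SSA names %input, %weight and
%bias are placeholders, and the shapes simply mirror the new test added
below), a depthwise convolution such as

  %0 = "tosa.depthwise_conv2d"(%input, %weight, %bias)
         {pad = [0, 0, 0, 0], stride = [1, 1], dilation = [1, 1]}
       : (tensor<4x10x10x2xf32>, tensor<1x1x2x3xf32>, tensor<6xf32>)
       -> tensor<4x10x10x6xf32>

canonicalizes to roughly

  %0 = "tosa.reshape"(%input) {new_shape = [4, 10, 10, 2, 1]}
       : (tensor<4x10x10x2xf32>) -> tensor<4x10x10x2x1xf32>
  %1 = "tosa.reshape"(%weight) {new_shape = [1, 1, 1, 2, 3]}
       : (tensor<1x1x2x3xf32>) -> tensor<1x1x1x2x3xf32>
  %2 = "tosa.mul"(%0, %1) {shift = 0 : i32}
       : (tensor<4x10x10x2x1xf32>, tensor<1x1x1x2x3xf32>)
       -> tensor<4x10x10x2x3xf32>
  %3 = "tosa.reshape"(%2) {new_shape = [4, 10, 10, 6]}
       : (tensor<4x10x10x2x3xf32>) -> tensor<4x10x10x6xf32>
  %4 = "tosa.add"(%3, %bias)
       : (tensor<4x10x10x6xf32>, tensor<6xf32>) -> tensor<4x10x10x6xf32>

i.e. the broadcasted elementwise multiply replaces the 1x1 convolution and
the surrounding reshapes restore the NHWC output layout.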

Reviewed By: rsuderman, KoolJBlack

Differential Revision: https://reviews.llvm.org/D115207

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
    mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
    mlir/test/Dialect/Tosa/canonicalize.mlir

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index 554023dc03814..173f26db6c934 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -187,6 +187,8 @@ def Tosa_DepthwiseConv2DOp : Tosa_Op<"depthwise_conv2d", [
   let builders = [Tosa_ConvOpQuantInfoBuilder];
 
   let verifier = [{ return verifyConvOp(*this); }];
+
+  let hasCanonicalizer = 1;
 }
 
 //===----------------------------------------------------------------------===//

diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 51c41e3334adf..cefe13f57dbba 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -515,6 +515,97 @@ void Conv2DOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
   results.insert<Conv2DFullyConnectedOptimization>(context);
 }
 
+struct DepthwiseConv2DMulOptimization
+    : public OpRewritePattern<tosa::DepthwiseConv2DOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tosa::DepthwiseConv2DOp op,
+                                PatternRewriter &rewriter) const override {
+    Value input = op.input();
+    Value weight = op.weight();
+    ShapedType inputType = input.getType().cast<ShapedType>();
+    ShapedType weightType = weight.getType().cast<ShapedType>();
+    ShapedType resultType = op.output().getType().cast<ShapedType>();
+
+    if (!(inputType.hasStaticShape() && weightType.hasStaticShape() &&
+          resultType.hasStaticShape())) {
+      return failure();
+    }
+
+    // Stride must be 1 for this optimization.
+    for (Attribute stride : op.stride().getValue()) {
+      if (!stride.cast<IntegerAttr>().getValue().isOne()) {
+        return failure();
+      }
+    }
+
+    // Only works for a 1x1 kernel.
+    ArrayRef<int64_t> weightShape = weightType.getShape();
+    if (weightShape[0] != 1 || weightShape[1] != 1) {
+      return failure();
+    }
+
+    // Reshape input to [N, H, W, C] -> [N, H, W, C, 1].
+    ArrayRef<int64_t> inputShape = inputType.getShape();
+    llvm::SmallVector<int64_t, 2> revisedInputShape{
+        inputShape[0], inputShape[1], inputShape[2], inputShape[3], 1};
+    auto revisedInputShapeType = RankedTensorType::get(
+        revisedInputShape,
+        input.getType().dyn_cast<RankedTensorType>().getElementType());
+    auto reshapedInput = rewriter
+                             .create<tosa::ReshapeOp>(
+                                 op.getLoc(), revisedInputShapeType, input,
+                                 rewriter.getI64ArrayAttr(revisedInputShape))
+                             .getResult();
+
+    // Reshape kernel to [KH, KW, C, M] -> [1, 1, 1, C, M].
+    llvm::SmallVector<int64_t, 2> revisedWeightShape{1, 1, 1, weightShape[2],
+                                                     weightShape[3]};
+    auto revisedWeightShapeType = RankedTensorType::get(
+        revisedWeightShape,
+        weight.getType().dyn_cast<RankedTensorType>().getElementType());
+    auto reshapedWeight = rewriter
+                              .create<tosa::ReshapeOp>(
+                                  op.getLoc(), revisedWeightShapeType, weight,
+                                  rewriter.getI64ArrayAttr(revisedWeightShape))
+                              .getResult();
+
+    // Perform an elementwise mul over the reshaped input and weight.
+    llvm::SmallVector<int64_t, 2> mulShape{inputShape[0], inputShape[1],
+                                           inputShape[2], inputShape[3],
+                                           weightShape[3]};
+    auto mulShapeType = RankedTensorType::get(
+        mulShape,
+        weight.getType().dyn_cast<RankedTensorType>().getElementType());
+    Value mulValue =
+        rewriter
+            .create<tosa::MulOp>(op.getLoc(), mulShapeType, reshapedInput,
+                                 reshapedWeight, /*shift=*/0)
+            .getResult();
+
+    // Reshape output to [N, H, W, C * M].
+    auto outputShape = op.output().getType().cast<ShapedType>().getShape();
+    auto outputShapeType = RankedTensorType::get(
+        outputShape,
+        input.getType().dyn_cast<RankedTensorType>().getElementType());
+    auto outputValue =
+        rewriter.create<tosa::ReshapeOp>(op.getLoc(), outputShapeType, mulValue,
+                                         rewriter.getI64ArrayAttr(outputShape));
+
+    // Add in the bias.
+    rewriter
+        .replaceOpWithNewOp<tosa::AddOp>(op, outputShapeType, outputValue,
+                                         op.bias())
+        .getResult();
+    return success();
+  }
+};
+
+void DepthwiseConv2DOp::getCanonicalizationPatterns(
+    OwningRewritePatternList &results, MLIRContext *context) {
+  results.insert<DepthwiseConv2DMulOptimization>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // Operator Folders.
 //===----------------------------------------------------------------------===//

diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
index c4d105ca438e6..ed659ee91964d 100644
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -106,6 +106,44 @@ func @conv2d_weight_2x2(%arg0: tensor<4x10x10x1xf32>) -> tensor<4x10x10x1xf32> {
   return %0 : tensor<4x10x10x1xf32>
 }
 
+// -----
+
+// CHECK-LABEL: @depthwise_conv2d_as_mul
+func @depthwise_conv2d_as_mul(%arg0: tensor<4x10x10x2xf32>, %arg1: tensor<1x1x2x3xf32>, %arg2: tensor<6xf32>) -> tensor<4x10x10x6xf32> {
+  // CHECK-NOT: "tosa.depthwise_conv2d"
+  // CHECK: %[[VAR0:.*]] = "tosa.reshape"(%arg0) {new_shape = [4, 10, 10, 2, 1]}
+  // CHECK-SAME: -> tensor<4x10x10x2x1xf32>
+  // CHECK: %[[VAR1:.*]] = "tosa.reshape"(%arg1) {new_shape = [1, 1, 1, 2, 3]}
+  // CHECK-SAME: -> tensor<1x1x1x2x3xf32>
+  // CHECK: %[[VAR2:.*]] = "tosa.mul"(%[[VAR0]], %[[VAR1]])
+  // CHECK-SAME: -> tensor<4x10x10x2x3xf32>
+  // CHECK: %[[VAR3:.*]] = "tosa.reshape"(%[[VAR2]]) {new_shape = [4, 10, 10, 6]}
+  // CHECK-SAME: -> tensor<4x10x10x6xf32>
+  // CHECK: %[[VAR4:.*]] = "tosa.add"(%[[VAR3]], %arg2)
+  // CHECK-SAME: -> tensor<4x10x10x6xf32>
+  // CHECK: return %[[VAR4]]
+  %0 = "tosa.depthwise_conv2d"(%arg0, %arg1, %arg2) {pad = [0, 0, 0, 0], stride = [1, 1], dilation = [1, 1]} : (tensor<4x10x10x2xf32>, tensor<1x1x2x3xf32>, tensor<6xf32>) -> tensor<4x10x10x6xf32>
+  return %0 : tensor<4x10x10x6xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @depthwise_conv2d_stride_2
+func @depthwise_conv2d_stride_2(%arg0: tensor<4x10x10x2xf32>, %arg1: tensor<1x1x2x3xf32>, %arg2: tensor<6xf32>) -> tensor<4x10x10x6xf32> {
+  // CHECK: "tosa.depthwise_conv2d"
+  %0 = "tosa.depthwise_conv2d"(%arg0, %arg1, %arg2) {pad = [0, 0, 0, 0], stride = [2, 2], dilation = [1, 1]} : (tensor<4x10x10x2xf32>, tensor<1x1x2x3xf32>, tensor<6xf32>) -> tensor<4x10x10x6xf32>
+  return %0 : tensor<4x10x10x6xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @depthwise_conv2d_weight_2x2
+func @depthwise_conv2d_weight_2x2(%arg0: tensor<4x10x10x2xf32>, %arg1: tensor<2x2x2x3xf32>, %arg2: tensor<6xf32>) -> tensor<4x10x10x6xf32> {
+  // CHECK: "tosa.depthwise_conv2d"
+  %0 = "tosa.depthwise_conv2d"(%arg0, %arg1, %arg2) {pad = [0, 0, 0, 0], stride = [1, 1], dilation = [1, 1]} : (tensor<4x10x10x2xf32>, tensor<2x2x2x3xf32>, tensor<6xf32>) -> tensor<4x10x10x6xf32>
+  return %0 : tensor<4x10x10x6xf32>
+}
+
 // ----
 
 // CHECK-LABEL: @pad_noop
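To exercise the new pattern locally, the added checks in canonicalize.mlir
are driven by that file's existing RUN line (not shown in this diff); it is
the usual canonicalization setup, roughly of the form

  // RUN: mlir-opt --split-input-file --canonicalize %s | FileCheck %s

With that, @depthwise_conv2d_as_mul should show the depthwise_conv2d rewritten
into the reshape/mul/reshape/add sequence, while the stride-2 and 2x2-weight
cases are left untouched.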


More information about the Mlir-commits mailing list