[Mlir-commits] [mlir] e2d8b60 - Revert "[mlir][tosa] Add tosa.conv2d as fully_connected canonicalization"

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Dec 3 14:37:23 PST 2021


Author: natashaknk
Date: 2021-12-03T14:35:48-08:00
New Revision: e2d8b607427d7d01bd9bdc007af4f4a4ac9833c6

URL: https://github.com/llvm/llvm-project/commit/e2d8b607427d7d01bd9bdc007af4f4a4ac9833c6
DIFF: https://github.com/llvm/llvm-project/commit/e2d8b607427d7d01bd9bdc007af4f4a4ac9833c6.diff

LOG: Revert "[mlir][tosa] Add tosa.conv2d as fully_connected canonicalization"

This reverts commit 13bdb7ab4a7acaea7144a042fe583d45fbb9b5c4. That commit introduced (or uncovered) an unintended bug in models containing Conv2D operations.

Reviewed By: mravishankar

Differential Revision: https://reviews.llvm.org/D115079

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
    mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
    mlir/test/Dialect/Tosa/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index 554023dc03814..4de90058ba786 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -118,8 +118,6 @@ def Tosa_Conv2DOp : Tosa_Op<"conv2d", [
   let builders = [Tosa_ConvOpQuantInfoBuilder];
 
   let verifier = [{ return verifyConvOp(*this); }];
-
-  let hasCanonicalizer = 1;
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 15830331eb9a0..2a9d5d4c93321 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -423,100 +423,6 @@ void PadOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
   results.insert<MaterializePadValue>(context);
 }
 
-struct Conv2DFullyConnectedOptimization
-    : public OpRewritePattern<tosa::Conv2DOp> {
-  using OpRewritePattern::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(tosa::Conv2DOp op,
-                                PatternRewriter &rewriter) const override {
-    Value input = op.input();
-    Value weight = op.weight();
-    ShapedType inputType = input.getType().cast<ShapedType>();
-    ShapedType weightType = weight.getType().cast<ShapedType>();
-
-    if (!inputType.hasStaticShape() || !weightType.hasRank()) {
-      return failure();
-    }
-
-    // Stride must be 1 for this optimization.
-    for (Attribute stride : op.stride().getValue()) {
-      if (!stride.cast<IntegerAttr>().getValue().isOne()) {
-        return failure();
-      }
-    }
-
-    // Only works for a 1x1 kernel.
-    ArrayRef<int64_t> weightShape = weightType.getShape();
-    if (weightShape[1] != 1 || weightShape[2] != 1) {
-      return failure();
-    }
-
-    // Reshape input to [N,IH,IW,IC] -> [N * IH * IW, IC].
-    ArrayRef<int64_t> inputShape = inputType.getShape();
-    llvm::SmallVector<int64_t, 2> revisedInputShape{
-        inputShape[0] * inputShape[1] * inputShape[2], inputShape[3]};
-    auto revisedInputShapeType = RankedTensorType::get(
-        revisedInputShape,
-        input.getType().dyn_cast<RankedTensorType>().getElementType());
-    auto reshapedInput = rewriter
-                             .create<tosa::ReshapeOp>(
-                                 op.getLoc(), revisedInputShapeType, input,
-                                 rewriter.getI64ArrayAttr(revisedInputShape))
-                             .getResult();
-
-    // Reshape kernel to [OC,KH,KW,IC] -> [OC, IC].
-    llvm::SmallVector<int64_t, 2> revisedWeightShape{weightShape[0],
-                                                     weightShape[3]};
-    auto revisedWeightShapeType = RankedTensorType::get(
-        revisedWeightShape,
-        weight.getType().dyn_cast<RankedTensorType>().getElementType());
-    auto reshapedWeight = rewriter
-                              .create<tosa::ReshapeOp>(
-                                  op.getLoc(), revisedWeightShapeType, weight,
-                                  rewriter.getI64ArrayAttr(revisedWeightShape))
-                              .getResult();
-
-    // Perform a fully connected network over the reshaped input and weight.
-    llvm::SmallVector<int64_t, 2> fullyConnectedShape{
-        inputShape[0] * inputShape[1] * inputShape[2], inputShape[3]};
-    auto fullyConnectedShapeType = RankedTensorType::get(
-        fullyConnectedShape,
-        weight.getType().dyn_cast<RankedTensorType>().getElementType());
-
-    Value fullyConnectedValue;
-    if (op.quantization_info()) {
-      fullyConnectedValue =
-          rewriter
-              .create<tosa::FullyConnectedOp>(
-                  op.getLoc(), fullyConnectedShapeType, reshapedInput,
-                  reshapedWeight, op.bias(), op.quantization_info().getValue())
-              .getResult();
-    } else {
-      fullyConnectedValue = rewriter
-                                .create<tosa::FullyConnectedOp>(
-                                    op.getLoc(), fullyConnectedShapeType,
-                                    reshapedInput, reshapedWeight, op.bias())
-                                .getResult();
-    }
-
-    // Reshape output to [N, IH, IW, OC].
-    llvm::SmallVector<int64_t, 4> outputShape{inputShape[0], inputShape[1],
-                                              inputShape[2], weightShape[0]};
-    auto outputShapeType = RankedTensorType::get(
-        outputShape,
-        input.getType().dyn_cast<RankedTensorType>().getElementType());
-    rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(
-        op, outputShapeType, fullyConnectedValue,
-        rewriter.getI64ArrayAttr(outputShape));
-    return success();
-  }
-};
-
-void Conv2DOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
-                                           MLIRContext *context) {
-  results.insert<Conv2DFullyConnectedOptimization>(context);
-}
-
 //===----------------------------------------------------------------------===//
 // Operator Folders.
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
index 39554d1563769..70f26650fe610 100644
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -66,48 +66,12 @@ func @concat_fold_cast(%arg0: tensor<?x1xf32>) -> tensor<?x?xf32> {
   return %0 : tensor<?x?xf32>
 }
 
-// -----
-
-// CHECK-LABEL: @conv2d_as_fully_connected
-func @conv2d_as_fully_connected(%arg0: tensor<4x10x10x2xf32>, %arg1: tensor<3x1x1x2xf32>, %arg2: tensor<3xf32>) -> tensor<4x10x10x3xf32> {
-  // CHECK-NOT: "tosa.conv2d"
-  // CHECK: %[[VAR0:.*]] = "tosa.reshape"(%arg0) {new_shape = [400, 2]}
-  // CHECK: %[[VAR1:.*]] = "tosa.reshape"(%arg1) {new_shape = [3, 2]}
-  // CHECK: %[[VAR2:.*]] = "tosa.fully_connected"(%[[VAR0]], %[[VAR1]], %arg2)
-  // CHECK: %[[VAR3:.*]] = "tosa.reshape"(%[[VAR2]]) {new_shape = [4, 10, 10, 3]}
-  // CHECK: return %[[VAR3]]
-  %0 = "tosa.conv2d"(%arg0, %arg1, %arg2) {pad = [0, 0, 0, 0], stride = [1, 1], dilation = [1, 1]} : (tensor<4x10x10x2xf32>, tensor<3x1x1x2xf32>, tensor<3xf32>) -> tensor<4x10x10x3xf32>
-  return %0 : tensor<4x10x10x3xf32>
-}
-
-// -----
-
-// CHECK-LABEL: @conv2d_stride_2
-func @conv2d_stride_2(%arg0: tensor<4x10x10x2xf32>) -> tensor<4x10x10x3xf32> {
-  // CHECK: "tosa.conv2d"
-  %weight = "tosa.const"() {value = dense<[[[[1.0, 1.0]]], [[[1.0, 1.0]]], [[[1.0, 1.0]]]]> : tensor<3x1x1x2xf32>} : ()-> tensor<3x1x1x2xf32>
-  %bias = "tosa.const"() {value = dense<0.0> : tensor<3xf32>} : ()-> tensor<3xf32>
-  %0 = "tosa.conv2d"(%arg0, %weight, %bias) {pad = [0, 0, 0, 0], stride = [2, 2], dilation = [1, 1]} : (tensor<4x10x10x2xf32>, tensor<3x1x1x2xf32>, tensor<3xf32>) -> tensor<4x10x10x3xf32>
-  return %0 : tensor<4x10x10x3xf32>
-}
-
-// -----
-
-// CHECK-LABEL: @conv2d_weight_2x2
-func @conv2d_weight_2x2(%arg0: tensor<4x10x10x1xf32>) -> tensor<4x10x10x1xf32> {
-  // CHECK: "tosa.conv2d"
-  %weight = "tosa.const"() {value = dense<[[[[1.0], [1.0]], [[1.0], [1.0]]]]> : tensor<1x2x2x1xf32>} : ()-> tensor<1x2x2x1xf32>
-  %bias = "tosa.const"() {value = dense<0.0> : tensor<1xf32>} : ()-> tensor<1xf32>
-  %0 = "tosa.conv2d"(%arg0, %weight, %bias) {pad = [0, 0, 0, 0], stride = [1, 1], dilation = [1, 1]} : (tensor<4x10x10x1xf32>, tensor<1x2x2x1xf32>, tensor<1xf32>) -> tensor<4x10x10x1xf32>
-  return %0 : tensor<4x10x10x1xf32>
-}
-
 // ----
 
 // CHECK-LABEL: @pad_noop
 func @pad_noop(%arg0: tensor<?x?xf32>) -> tensor<?x?xf32> {
   // CHECK: return %arg0
-  %0 = "tosa.const"() { value = dense<0> : tensor<2x2xi32>} : () -> tensor<2x2xi32>
+  %0 = "tosa.const"() { value = dense<0> : tensor<2x2xi32>} : () -> tensor<2x2xi32> 
   %1 = "tosa.pad"(%arg0, %0) : (tensor<?x?xf32>, tensor<2x2xi32>) -> tensor<?x?xf32>
   return %1 : tensor<?x?xf32>
 }
@@ -118,7 +82,7 @@ func @pad_noop(%arg0: tensor<?x?xf32>) -> tensor<?x?xf32> {
 func @pad_determine_val_i32(%arg0: tensor<?x?xi32>, %arg1 : tensor<2x2xi32>) -> tensor<?x?xi32> {
   // CHECK: %[[ZERO:.+]] = "tosa.const"() {value = dense<0> : tensor<i32>}
   // CHECK: "tosa.pad"(%arg0, %arg1, %[[ZERO]])
-  %0 = "tosa.const"() { value = dense<[[1, 0], [0, 1]]> : tensor<2x2xi32>} : () -> tensor<2x2xi32>
+  %0 = "tosa.const"() { value = dense<[[1, 0], [0, 1]]> : tensor<2x2xi32>} : () -> tensor<2x2xi32> 
   %1 = "tosa.pad"(%arg0, %arg1) : (tensor<?x?xi32>, tensor<2x2xi32>) -> tensor<?x?xi32>
   return %1 : tensor<?x?xi32>
 }
@@ -129,7 +93,7 @@ func @pad_determine_val_i32(%arg0: tensor<?x?xi32>, %arg1 : tensor<2x2xi32>) ->
 func @pad_determine_val_f32(%arg0: tensor<?x?xf32>, %arg1 : tensor<2x2xi32>) -> tensor<?x?xf32> {
   // CHECK: %[[ZERO:.+]] = "tosa.const"() {value = dense<0.000000e+00> : tensor<f32>}
   // CHECK: "tosa.pad"(%arg0, %arg1, %[[ZERO]])
-  %0 = "tosa.const"() { value = dense<[[1, 0], [0, 1]]> : tensor<2x2xi32>} : () -> tensor<2x2xi32>
+  %0 = "tosa.const"() { value = dense<[[1, 0], [0, 1]]> : tensor<2x2xi32>} : () -> tensor<2x2xi32> 
   %1 = "tosa.pad"(%arg0, %arg1) : (tensor<?x?xf32>, tensor<2x2xi32>) -> tensor<?x?xf32>
   return %1 : tensor<?x?xf32>
 }
@@ -140,7 +104,7 @@ func @pad_determine_val_f32(%arg0: tensor<?x?xf32>, %arg1 : tensor<2x2xi32>) ->
 func @pad_determine_val_quant(%arg0: tensor<?x?xi32>, %arg1 : tensor<2x2xi32>) -> tensor<?x?xi32> {
   // CHECK: %[[ZERO:.+]] = "tosa.const"() {value = dense<42> : tensor<i32>}
   // CHECK: "tosa.pad"(%arg0, %arg1, %[[ZERO]])
-  %0 = "tosa.const"() { value = dense<[[1, 0], [0, 1]]> : tensor<2x2xi32>} : () -> tensor<2x2xi32>
+  %0 = "tosa.const"() { value = dense<[[1, 0], [0, 1]]> : tensor<2x2xi32>} : () -> tensor<2x2xi32> 
   %1 = "tosa.pad"(%arg0, %arg1) { quantization_info = {input_zp = 42:i32} } : (tensor<?x?xi32>, tensor<2x2xi32>) -> tensor<?x?xi32>
   return %1 : tensor<?x?xi32>
 }


        


More information about the Mlir-commits mailing list