[Mlir-commits] [mlir] 9a2308e - [mlir][tosa] Minor cleanup of tosa.conv2d canonicalizer

Rob Suderman llvmlistbot at llvm.org
Thu Dec 16 15:14:05 PST 2021


Author: Rob Suderman
Date: 2021-12-16T15:13:01-08:00
New Revision: 9a2308e170b468df62dd969a41b8c408992cf84e

URL: https://github.com/llvm/llvm-project/commit/9a2308e170b468df62dd969a41b8c408992cf84e
DIFF: https://github.com/llvm/llvm-project/commit/9a2308e170b468df62dd969a41b8c408992cf84e.diff

LOG: [mlir][tosa] Minor cleanup of tosa.conv2d canonicalizer

Slightly renames the tosa.conv2d to tosa.fully_connected canonicalization
pattern and reuses the already-typed tensor variables instead of re-casting
each value's type. Also prevents the rewrite from applying to padded
convolutions.

Reviewed By: not-jenni

Differential Revision: https://reviews.llvm.org/D115776

Added: 
    

Modified: 
    mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
    mlir/test/Dialect/Tosa/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 579f0b407ee8e..68c0d015c3a60 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -423,8 +423,7 @@ void PadOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
   results.insert<MaterializePadValue>(context);
 }
 
-struct Conv2DFullyConnectedOptimization
-    : public OpRewritePattern<tosa::Conv2DOp> {
+struct Conv2DIsFullyConnected : public OpRewritePattern<tosa::Conv2DOp> {
   using OpRewritePattern::OpRewritePattern;
 
   LogicalResult matchAndRewrite(tosa::Conv2DOp op,
@@ -439,6 +438,12 @@ struct Conv2DFullyConnectedOptimization
       return failure();
     }
 
+    for (Attribute pad : op.pad().getValue()) {
+      if (!pad.cast<IntegerAttr>().getValue().isZero()) {
+        return failure();
+      }
+    }
+
     // Stride must be 1 for this optimization.
     for (Attribute stride : op.stride().getValue()) {
       if (!stride.cast<IntegerAttr>().getValue().isOne()) {
@@ -456,9 +461,8 @@ struct Conv2DFullyConnectedOptimization
     ArrayRef<int64_t> inputShape = inputType.getShape();
     llvm::SmallVector<int64_t, 2> revisedInputShape{
         inputShape[0] * inputShape[1] * inputShape[2], inputShape[3]};
-    auto revisedInputShapeType = RankedTensorType::get(
-        revisedInputShape,
-        input.getType().dyn_cast<RankedTensorType>().getElementType());
+    auto revisedInputShapeType =
+        RankedTensorType::get(revisedInputShape, inputType.getElementType());
     auto reshapedInput = rewriter
                              .create<tosa::ReshapeOp>(
                                  op.getLoc(), revisedInputShapeType, input,
@@ -468,9 +472,8 @@ struct Conv2DFullyConnectedOptimization
     // Reshape kernel to [OC,KH,KW,IC] -> [OC, IC].
     llvm::SmallVector<int64_t, 2> revisedWeightShape{weightShape[0],
                                                      weightShape[3]};
-    auto revisedWeightShapeType = RankedTensorType::get(
-        revisedWeightShape,
-        weight.getType().dyn_cast<RankedTensorType>().getElementType());
+    auto revisedWeightShapeType =
+        RankedTensorType::get(revisedWeightShape, weightType.getElementType());
     auto reshapedWeight = rewriter
                               .create<tosa::ReshapeOp>(
                                   op.getLoc(), revisedWeightShapeType, weight,
@@ -480,9 +483,8 @@ struct Conv2DFullyConnectedOptimization
     // Perform a fully connected network over the reshaped input and weight.
     llvm::SmallVector<int64_t, 2> fullyConnectedShape{
         inputShape[0] * inputShape[1] * inputShape[2], weightShape[0]};
-    auto fullyConnectedShapeType = RankedTensorType::get(
-        fullyConnectedShape,
-        resultType.dyn_cast<ShapedType>().getElementType());
+    auto fullyConnectedShapeType =
+        RankedTensorType::get(fullyConnectedShape, resultType.getElementType());
 
     Value fullyConnectedValue;
     if (op.quantization_info()) {
@@ -512,7 +514,7 @@ struct Conv2DFullyConnectedOptimization
 
 void Conv2DOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                            MLIRContext *context) {
-  results.insert<Conv2DFullyConnectedOptimization>(context);
+  results.insert<Conv2DIsFullyConnected>(context);
 }
 
 struct DepthwiseConv2DMulOptimization

diff  --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
index 7436bfa8ba9ff..91c2e3ce7feb6 100644
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -105,6 +105,15 @@ func @conv2d_as_fully_connected_quant(%arg0: tensor<4x10x10x2xi8>, %arg1: tensor
 
 // -----
 
+// CHECK-LABEL: @conv2d_padded
+func @conv2d_padded(%arg0: tensor<4x10x10x2xf32>, %arg1: tensor<3x1x1x2xf32>, %arg2: tensor<3xf32>) -> tensor<4x12x12x3xf32> {
+  // CHECK: "tosa.conv2d"
+  %0 = "tosa.conv2d"(%arg0, %arg1, %arg2) {pad = [1, 1, 1, 1], stride = [1, 1], dilation = [1, 1]} : (tensor<4x10x10x2xf32>, tensor<3x1x1x2xf32>, tensor<3xf32>) -> tensor<4x12x12x3xf32>
+  return %0 : tensor<4x12x12x3xf32>
+}
+
+// -----
+
 // CHECK-LABEL: @conv2d_stride_2
 func @conv2d_stride_2(%arg0: tensor<4x10x10x2xf32>) -> tensor<4x10x10x3xf32> {
   // CHECK: "tosa.conv2d"


        


More information about the Mlir-commits mailing list