[Mlir-commits] [mlir] f596acc - [mlir][tosa] Small refactor to the functionality of Depthwise_Conv2D to add the bias at the end of the convolution
    Rob Suderman 
    llvmlistbot at llvm.org
       
    Wed Sep  1 10:07:26 PDT 2021
    
    
  
Author: natashaknk
Date: 2021-09-01T10:01:00-07:00
New Revision: f596acc74d4bccd034955042e385a2d5e2ba4f05
URL: https://github.com/llvm/llvm-project/commit/f596acc74d4bccd034955042e385a2d5e2ba4f05
DIFF: https://github.com/llvm/llvm-project/commit/f596acc74d4bccd034955042e385a2d5e2ba4f05.diff
LOG: [mlir][tosa] Small refactor to the functionality of Depthwise_Conv2D to add the bias at the end of the convolution
Follow-up to the Conv2d and fully_connected lowering adjustments
Reviewed By: rsuderman
Differential Revision: https://reviews.llvm.org/D108949
Added: 
    
Modified: 
    mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
    mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
Removed: 
    
################################################################################
diff  --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 02db33cd01ec5..e6be286f43b42 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -1129,27 +1129,6 @@ class DepthwiseConvConverter
 
     input = applyPad(loc, input, pad, zeroAttr, rewriter);
 
-    // Broadcast the initial value to the output tensor before convolving.
-    SmallVector<AffineMap, 4> indexingMaps;
-    indexingMaps.push_back(AffineMap::get(
-        /*dimCount=*/resultTy.getRank(), /*symbolCount=*/0,
-        {rewriter.getAffineDimExpr(3)}, rewriter.getContext()));
-    indexingMaps.push_back(rewriter.getMultiDimIdentityMap(resultTy.getRank()));
-
-    Value initTensor =
-        rewriter.create<linalg::InitTensorOp>(loc, resultShape, resultETy);
-
-    Value biasBroadcast =
-        rewriter
-            .create<linalg::GenericOp>(
-                loc, resultTy, bias, initTensor, indexingMaps,
-                getNParallelLoopsAttrs(resultTy.getRank()),
-                [&](OpBuilder &nestedBuilder, Location nestedLoc,
-                    ValueRange args) {
-                  nestedBuilder.create<linalg::YieldOp>(nestedLoc, args[0]);
-                })
-            .getResult(0);
-
     // Extract the attributes for convolution.
     llvm::SmallVector<int64_t> stride, dilation;
     getValuesFromIntArrayAttribute(strideTosaAttr, stride);
@@ -1165,28 +1144,69 @@ class DepthwiseConvConverter
                                weightShape[2], weightShape[3]},
                               resultETy);
 
-    Value biasReshape =
-        rewriter.create<tosa::ReshapeOp>(loc, linalgConvTy, biasBroadcast);
-    Value conv;
+    // Indexing maps for broadcasting and adding the bias after convolving.
+    SmallVector<AffineMap, 4> indexingMaps;
+    indexingMaps.push_back(AffineMap::get(
+        /*dimCount=*/resultTy.getRank(), /*symbolCount=*/0,
+        {rewriter.getAffineDimExpr(3)}, rewriter.getContext()));
+    indexingMaps.push_back(rewriter.getMultiDimIdentityMap(resultTy.getRank()));
+    indexingMaps.push_back(rewriter.getMultiDimIdentityMap(resultTy.getRank()));
+
+    Attribute resultZeroAttr = rewriter.getZeroAttr(resultETy);
+    Value initTensor = rewriter.create<linalg::InitTensorOp>(
+        loc, linalgConvTy.getShape(), resultETy);
+    Value zero = rewriter.create<ConstantOp>(loc, resultZeroAttr);
+    Value zeroTensor =
+        rewriter.create<linalg::FillOp>(loc, zero, initTensor).getResult(0);
+
+    Value biasInitTensor = rewriter.create<linalg::InitTensorOp>(
+        loc, resultTy.getShape(), resultETy);
     if (!isQuantized) {
-      conv = rewriter
-                 .create<linalg::DepthwiseConv2DNhwcOp>(
-                     loc, linalgConvTy, ValueRange{input, weight},
-                     ValueRange{biasReshape}, strideAttr, dilationAttr)
-                 .getResult(0);
+      Value conv = rewriter
+                       .create<linalg::DepthwiseConv2DNhwcOp>(
+                           loc, linalgConvTy, ValueRange{input, weight},
+                           ValueRange{zeroTensor}, strideAttr, dilationAttr)
+                       .getResult(0);
+      Value convReshape = rewriter.create<tosa::ReshapeOp>(loc, resultTy, conv);
+      Value result =
+          rewriter
+              .create<linalg::GenericOp>(
+                  loc, resultTy, ValueRange({bias, convReshape}),
+                  biasInitTensor, indexingMaps,
+                  getNParallelLoopsAttrs(resultTy.getRank()),
+                  [&](OpBuilder &nestedBuilder, Location nestedLoc,
+                      ValueRange args) {
+                    Value added =
+                        nestedBuilder.create<AddFOp>(loc, args[0], args[1]);
+                    nestedBuilder.create<linalg::YieldOp>(nestedLoc, added);
+                  })
+              .getResult(0);
+      rewriter.replaceOp(op, result);
     } else {
       auto iZpVal = rewriter.create<ConstantOp>(loc, iZp);
       auto kZpVal = rewriter.create<ConstantOp>(loc, kZp);
-      conv =
+      Value conv =
           rewriter
               .create<linalg::DepthwiseConv2DNhwcQOp>(
                   loc, linalgConvTy, ValueRange{input, weight, iZpVal, kZpVal},
-                  ValueRange{biasReshape}, strideAttr, dilationAttr)
+                  ValueRange{zeroTensor}, strideAttr, dilationAttr)
+              .getResult(0);
+      Value convReshape = rewriter.create<tosa::ReshapeOp>(loc, resultTy, conv);
+      Value result =
+          rewriter
+              .create<linalg::GenericOp>(
+                  loc, resultTy, ValueRange({bias, convReshape}),
+                  biasInitTensor, indexingMaps,
+                  getNParallelLoopsAttrs(resultTy.getRank()),
+                  [&](OpBuilder &nestedBuilder, Location nestedLoc,
+                      ValueRange args) {
+                    Value added =
+                        nestedBuilder.create<AddIOp>(loc, args[0], args[1]);
+                    nestedBuilder.create<linalg::YieldOp>(nestedLoc, added);
+                  })
               .getResult(0);
+      rewriter.replaceOp(op, result);
     }
-
-    Value reshape = rewriter.create<tosa::ReshapeOp>(loc, resultTy, conv);
-    rewriter.replaceOp(op, reshape);
     return success();
   }
 };
diff  --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index 0cadf6e7caadf..9cf3eba69d1ad 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -1404,14 +1404,17 @@ func @conv2d_quant(%arg0 : tensor<1x12x12x1xi8>, %arg1 : tensor<1024x3x3x1xi8>,
 
 // CHECK-LABEL: @depthwise_conv
 func @depthwise_conv(%arg0 : tensor<1x7x5x3xf32>, %arg1 : tensor<3x1x3x11xf32>, %arg2 : tensor<33xf32>) -> () {
-  // CHECK: [[INIT:%.+]] = linalg.init_tensor [1, 5, 5, 33]
-  // CHECK: [[BIAS:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<33xf32>) outs([[INIT]] : tensor<1x5x5x33xf32>) {
-  // CHECK: ^bb0(%arg3: f32, %arg4: f32):  // no predecessors
-  // CHECK:   linalg.yield %arg3 : f32
+  // CHECK: [[INIT:%.+]] = linalg.init_tensor [1, 5, 5, 3, 11]
+  // CHECK: [[CST0:%.+]] = constant 0
+  // CHECK: [[FILL:%.+]] = linalg.fill([[CST0]], [[INIT]])
+  // CHECK: [[OUT:%.+]] = linalg.init_tensor [1, 5, 5, 33]
+  // CHECK: [[DEPTH:%.+]] = linalg.depthwise_conv2D_nhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<1x7x5x3xf32>, tensor<3x1x3x11xf32>) outs([[FILL]] : tensor<1x5x5x3x11xf32>)
+  // CHECK: [[COLLAPSED:%.+]] = linalg.tensor_collapse_shape [[DEPTH]] {{\[}}[0], [1], [2], [3, 4]]
+  // CHECK: [[BIAS:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2, [[COLLAPSED]] : tensor<33xf32>, tensor<1x5x5x33xf32>) outs([[OUT]] : tensor<1x5x5x33xf32>) {
+  // CHECK: ^bb0(%arg3: f32, %arg4: f32, %arg5: f32):  // no predecessors
+  // CHECK:   [[ADD:%.+]] = addf %arg3, %arg4 : f32
+  // CHECK:   linalg.yield [[ADD]] : f32
   // CHECK: } -> tensor<1x5x5x33xf32>
-  // CHECK: [[DBIAS:%.+]] = linalg.tensor_expand_shape [[BIAS]] {{\[}}[0], [1], [2], [3, 4]]
-  // CHECK: [[DEPTH:%.+]] = linalg.depthwise_conv2D_nhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<1x7x5x3xf32>, tensor<3x1x3x11xf32>) outs([[DBIAS]] : tensor<1x5x5x3x11xf32>)
-  // CHECK: linalg.tensor_collapse_shape %3 {{\[}}[0], [1], [2], [3, 4]]
   %2 = "tosa.depthwise_conv2d"(%arg0, %arg1, %arg2) { pad = [0, 0, 0, 0], stride = [1, 1], dilation = [1, 1] } : (tensor<1x7x5x3xf32>, tensor<3x1x3x11xf32>, tensor<33xf32>)  -> (tensor<1x5x5x33xf32>)
   return
 }
@@ -1423,14 +1426,17 @@ func @depthwise_conv(%arg0 : tensor<1x7x5x3xf32>, %arg1 : tensor<3x1x3x11xf32>,
 
 // CHECK-LABEL: @depthwise_conv_strides
 func @depthwise_conv_strides(%arg0 : tensor<1x11x9x3xf32>, %arg1 : tensor<3x1x3x11xf32>, %arg2 : tensor<33xf32>) -> () {
-  // CHECK: [[INIT:%.+]] = linalg.init_tensor [1, 5, 5, 33]
-  // CHECK: [[BIAS:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<33xf32>) outs([[INIT]] : tensor<1x5x5x33xf32>) {
-  // CHECK: ^bb0(%arg3: f32, %arg4: f32):  // no predecessors
-  // CHECK:   linalg.yield %arg3 : f32
+  // CHECK: [[INIT:%.+]] = linalg.init_tensor [1, 5, 5, 3, 11]
+  // CHECK: [[CST0:%.+]] = constant 0
+  // CHECK: [[FILL:%.+]] = linalg.fill([[CST0]], [[INIT]])
+  // CHECK: [[OUT:%.+]] = linalg.init_tensor [1, 5, 5, 33]
+  // CHECK: [[DEPTH:%.+]] = linalg.depthwise_conv2D_nhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<1x11x9x3xf32>, tensor<3x1x3x11xf32>) outs([[FILL]] : tensor<1x5x5x3x11xf32>)
+  // CHECK: [[COLLAPSED:%.+]] = linalg.tensor_collapse_shape [[DEPTH]] {{\[}}[0], [1], [2], [3, 4]]
+  // CHECK: [[BIAS:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2, [[COLLAPSED]] : tensor<33xf32>, tensor<1x5x5x33xf32>) outs([[OUT]] : tensor<1x5x5x33xf32>) {
+  // CHECK: ^bb0(%arg3: f32, %arg4: f32, %arg5: f32):  // no predecessors
+  // CHECK:   [[ADD:%.+]] = addf %arg3, %arg4 : f32
+  // CHECK:   linalg.yield [[ADD]] : f32
   // CHECK: } -> tensor<1x5x5x33xf32>
-  // CHECK: [[DBIAS:%.+]] = linalg.tensor_expand_shape [[BIAS]] {{\[}}[0], [1], [2], [3, 4]]
-  // CHECK: [[DEPTH:%.+]] = linalg.depthwise_conv2D_nhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<1x11x9x3xf32>, tensor<3x1x3x11xf32>) outs([[DBIAS]] : tensor<1x5x5x3x11xf32>)
-  // CHECK: linalg.tensor_collapse_shape [[DEPTH]] {{\[}}[0], [1], [2], [3, 4]]
   %2 = "tosa.depthwise_conv2d"(%arg0, %arg1, %arg2) { pad = [0, 0, 0, 0], stride = [2, 2], dilation = [1, 1] } : (tensor<1x11x9x3xf32>, tensor<3x1x3x11xf32>, tensor<33xf32>)  -> (tensor<1x5x5x33xf32>)
   return
 }
@@ -1442,20 +1448,23 @@ func @depthwise_conv_strides(%arg0 : tensor<1x11x9x3xf32>, %arg1 : tensor<3x1x3x
 
 // CHECK-LABEL: @depthwise_conv_quant
 func @depthwise_conv_quant(%arg0 : tensor<1x12x12x4xi8>, %arg1 : tensor<3x3x4x128xi8>, %arg2 : tensor<512xi32>) -> () {
-  // CHECK: %[[PADV:.+]] = constant -128
-  // CHECK: %[[PAD:.+]] = linalg.pad_tensor %arg0 low[0, 1, 1, 0] high[0, 1, 1, 0]
-  // CHECK:   linalg.yield %[[PADV]]
-
-  // CHECK: [[INIT:%.+]] = linalg.init_tensor [1, 12, 12, 512]
-  // CHECK: [[BIAS:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<512xi32>) outs([[INIT]] : tensor<1x12x12x512xi32>) {
-  // CHECK: ^bb0(%arg3: i32, %arg4: i32):  // no predecessors
-  // CHECK:   linalg.yield %arg3 : i32
+  // CHECK: [[PADV:%.+]] = constant -128
+  // CHECK: [[PAD:%.+]] = linalg.pad_tensor %arg0 low[0, 1, 1, 0] high[0, 1, 1, 0]
+  // CHECK:   linalg.yield [[PADV]]
+
+  // CHECK: [[INIT:%.+]] = linalg.init_tensor [1, 12, 12, 4, 128]
+  // CHECK: [[CST0:%.+]] = constant 0
+  // CHECK: [[FILL:%.+]] = linalg.fill([[CST0]], [[INIT]])
+  // CHECK: [[OUT:%.+]] = linalg.init_tensor [1, 12, 12, 512]
+  // CHECK: [[C128:%.+]] = constant -128
+  // CHECK: [[C42:%.+]] = constant 42
+  // CHECK: [[DEPTH:%.+]] = linalg.depthwise_conv2D_nhwc_q {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins([[PAD]], %arg1, [[C128]], [[C42]] : tensor<1x14x14x4xi8>, tensor<3x3x4x128xi8>, i32, i32) outs([[FILL]] : tensor<1x12x12x4x128xi32>)
+  // CHECK: [[COLLAPSED:%.+]] = linalg.tensor_collapse_shape [[DEPTH]] {{\[}}[0], [1], [2], [3, 4]]
+  // CHECK: [[BIAS:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2, [[COLLAPSED]] : tensor<512xi32>, tensor<1x12x12x512xi32>) outs([[OUT]] : tensor<1x12x12x512xi32>) {
+  // CHECK: ^bb0(%arg3: i32, %arg4: i32, %arg5: i32):  // no predecessors
+  // CHECK:   [[ADD:%.+]] = addi %arg3, %arg4 : i32
+  // CHECK:   linalg.yield [[ADD]] : i32
   // CHECK: } -> tensor<1x12x12x512xi32>
-  // CHECK: %[[DBIAS:.+]] = linalg.tensor_expand_shape [[BIAS]] {{\[}}[0], [1], [2], [3, 4]]
-  // CHECK: %[[C128:.+]] = constant -128
-  // CHECK: %[[C42:.+]] = constant 42
-  // CHECK: %[[DEPTH:.+]] = linalg.depthwise_conv2D_nhwc_q {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%[[PAD]], %arg1, %[[C128]], %[[C42]] : tensor<1x14x14x4xi8>, tensor<3x3x4x128xi8>, i32, i32) outs(%[[DBIAS]] : tensor<1x12x12x4x128xi32>)
-  // CHECK: linalg.tensor_collapse_shape %[[DEPTH]] {{\[}}[0], [1], [2], [3, 4]]
   %0 = "tosa.depthwise_conv2d"(%arg0, %arg1, %arg2) {pad = [1, 1, 1, 1], quantization_info = {input_zp = -128 : i32, weight_zp = 42 : i32}, stride = [1, 1], dilation = [1, 1] } : (tensor<1x12x12x4xi8>, tensor<3x3x4x128xi8>, tensor<512xi32>)  -> tensor<1x12x12x512xi32>
   return
 }
@@ -1467,16 +1476,19 @@ func @depthwise_conv_quant(%arg0 : tensor<1x12x12x4xi8>, %arg1 : tensor<3x3x4x12
 
 // CHECK-LABEL: @depthwise_conv_quant_dilations
 func @depthwise_conv_quant_dilations(%arg0 : tensor<1x14x14x4xi8>, %arg1 : tensor<3x3x4x128xi8>, %arg2 : tensor<512xi32>) -> () {
-  // CHECK: [[INIT:%.+]] = linalg.init_tensor [1, 10, 10, 512]
-  // CHECK: [[BIAS:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<512xi32>) outs([[INIT]] : tensor<1x10x10x512xi32>) {
-  // CHECK: ^bb0(%arg3: i32, %arg4: i32):  // no predecessors
-  // CHECK:   linalg.yield %arg3 : i32
+  // CHECK: [[INIT:%.+]] = linalg.init_tensor [1, 10, 10, 4, 128]
+  // CHECK: [[CST0:%.+]] = constant 0
+  // CHECK: [[FILL:%.+]] = linalg.fill([[CST0]], [[INIT]])
+  // CHECK: [[OUT:%.+]] = linalg.init_tensor [1, 10, 10, 512]
+  // CHECK: [[C128:%.+]] = constant -128
+  // CHECK: [[C42:%.+]] = constant 42
+  // CHECK: [[DEPTH:%.+]] = linalg.depthwise_conv2D_nhwc_q {dilations = dense<2> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1, [[C128]], [[C42]] : tensor<1x14x14x4xi8>, tensor<3x3x4x128xi8>, i32, i32) outs([[FILL]] : tensor<1x10x10x4x128xi32>)
+  // CHECK: [[COLLAPSED:%.+]] = linalg.tensor_collapse_shape [[DEPTH]] {{\[}}[0], [1], [2], [3, 4]]
+  // CHECK: [[BIAS:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2, [[COLLAPSED]] : tensor<512xi32>, tensor<1x10x10x512xi32>) outs([[OUT]] : tensor<1x10x10x512xi32>) {
+  // CHECK: ^bb0(%arg3: i32, %arg4: i32, %arg5: i32):  // no predecessors
+  // CHECK:   [[ADD:%.+]] = addi %arg3, %arg4 : i32
+  // CHECK:   linalg.yield [[ADD]] : i32
   // CHECK: } -> tensor<1x10x10x512xi32>
-  // CHECK: [[DBIAS:%.+]] = linalg.tensor_expand_shape [[BIAS]] {{\[}}[0], [1], [2], [3, 4]]
-  // CHECK: %[[C128:.+]] = constant -128
-  // CHECK: %[[C42:.+]] = constant 42
-  // CHECK: [[DEPTH:%.+]] = linalg.depthwise_conv2D_nhwc_q {dilations = dense<2> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1, %[[C128]], %[[C42]] : tensor<1x14x14x4xi8>, tensor<3x3x4x128xi8>, i32, i32) outs([[DBIAS]] : tensor<1x10x10x4x128xi32>)
-  // CHECK: linalg.tensor_collapse_shape %3 {{\[}}[0], [1], [2], [3, 4]]
   %0 = "tosa.depthwise_conv2d"(%arg0, %arg1, %arg2) {pad = [0, 0, 0, 0], quantization_info = {input_zp = -128 : i32, weight_zp = 42 : i32}, stride = [1, 1], dilation = [2, 2] } : (tensor<1x14x14x4xi8>, tensor<3x3x4x128xi8>, tensor<512xi32>)  -> tensor<1x10x10x512xi32>
   return
 }
        
    
    
More information about the Mlir-commits
mailing list