[Mlir-commits] [mlir] 203d38b - [mlir][tosa] Small refactor to the functionality of Conv2D and Fully_connected to add the bias at the end of the convolution

Rob Suderman llvmlistbot at llvm.org
Mon Aug 30 13:25:03 PDT 2021


Author: natashaknk
Date: 2021-08-30T13:18:43-07:00
New Revision: 203d38b234b8a744f364807ff0f15644453b14c7

URL: https://github.com/llvm/llvm-project/commit/203d38b234b8a744f364807ff0f15644453b14c7
DIFF: https://github.com/llvm/llvm-project/commit/203d38b234b8a744f364807ff0f15644453b14c7.diff

LOG: [mlir][tosa] Small refactor to the functionality of Conv2D and Fully_connected to add the bias at the end of the convolution

This change was made to accommodate a modification to the tiling algorithm.

Reviewed By: rsuderman

Differential Revision: https://reviews.llvm.org/D108746

Added: 
    

Modified: 
    mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
    mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 74239fed3f010..02db33cd01ec5 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -970,26 +970,12 @@ class ConvConverter : public OpConversionPattern<tosa::Conv2DOp> {
     weight = rewriter.create<tosa::TransposeOp>(loc, newWeightTy, weight,
                                                 weightPermValue);
 
-    // Broadcast the initial value to the output tensor before convolving.
-    SmallVector<AffineMap, 4> indexingMaps;
-    indexingMaps.push_back(AffineMap::get(
-        /*dimCount=*/resultTy.getRank(), /*symbolCount=*/0,
-        {rewriter.getAffineDimExpr(3)}, rewriter.getContext()));
-    indexingMaps.push_back(rewriter.getMultiDimIdentityMap(resultTy.getRank()));
-
+    Attribute resultZeroAttr = rewriter.getZeroAttr(resultETy);
     Value initTensor = rewriter.create<linalg::InitTensorOp>(
         loc, resultTy.getShape(), resultETy);
-
-    Value biasBroadcast =
-        rewriter
-            .create<linalg::GenericOp>(
-                loc, resultTy, bias, initTensor, indexingMaps,
-                getNParallelLoopsAttrs(resultTy.getRank()),
-                [&](OpBuilder &nestedBuilder, Location nestedLoc,
-                    ValueRange args) {
-                  nestedBuilder.create<linalg::YieldOp>(nestedLoc, args[0]);
-                })
-            .getResult(0);
+    Value zero = rewriter.create<ConstantOp>(loc, resultZeroAttr);
+    Value zeroTensor =
+        rewriter.create<linalg::FillOp>(loc, zero, initTensor).getResult(0);
 
     // Extract the attributes for convolution.
     llvm::SmallVector<int64_t> stride, dilation;
@@ -1002,7 +988,17 @@ class ConvConverter : public OpConversionPattern<tosa::Conv2DOp> {
     auto dilationAttr = DenseIntElementsAttr::get(
         RankedTensorType::get({2}, rewriter.getI64Type()), dilation);
 
-    Value conv;
+    // Create maps for the bias broadcasting
+    SmallVector<AffineMap, 4> indexingMaps;
+    indexingMaps.push_back(AffineMap::get(
+        /*dimCount=*/resultTy.getRank(), /*symbolCount=*/0,
+        {rewriter.getAffineDimExpr(3)}, rewriter.getContext()));
+    indexingMaps.push_back(rewriter.getMultiDimIdentityMap(resultTy.getRank()));
+    indexingMaps.push_back(rewriter.getMultiDimIdentityMap(resultTy.getRank()));
+
+    Value biasInitTensor = rewriter.create<linalg::InitTensorOp>(
+        loc, resultTy.getShape(), resultETy);
+
     if (isQuantized) {
       auto quantizationInfo =
           op->getAttr("quantization_info").cast<tosa::ConvOpQuantizationAttr>();
@@ -1013,15 +1009,49 @@ class ConvConverter : public OpConversionPattern<tosa::Conv2DOp> {
 
       auto iZpVal = rewriter.create<ConstantOp>(loc, iZp);
       auto kZpVal = rewriter.create<ConstantOp>(loc, kZp);
-      rewriter.replaceOpWithNewOp<linalg::Conv2DNhwcHwcfQOp>(
-          op, resultTy, ValueRange{input, weight, iZpVal, kZpVal},
-          ValueRange{biasBroadcast}, strideAttr, dilationAttr);
+      Value conv =
+          rewriter
+              .create<linalg::Conv2DNhwcHwcfQOp>(
+                  loc, resultTy, ValueRange{input, weight, iZpVal, kZpVal},
+                  ValueRange{zeroTensor}, strideAttr, dilationAttr)
+              ->getResult(0);
+
+      Value result =
+          rewriter
+              .create<linalg::GenericOp>(
+                  loc, resultTy, ValueRange({bias, conv}), biasInitTensor,
+                  indexingMaps, getNParallelLoopsAttrs(resultTy.getRank()),
+                  [&](OpBuilder &nestedBuilder, Location nestedLoc,
+                      ValueRange args) {
+                    Value added =
+                        nestedBuilder.create<AddIOp>(loc, args[0], args[1]);
+                    nestedBuilder.create<linalg::YieldOp>(nestedLoc, added);
+                  })
+              .getResult(0);
+      rewriter.replaceOp(op, result);
       return success();
     }
 
-    rewriter.replaceOpWithNewOp<linalg::Conv2DNhwcHwcfOp>(
-        op, resultTy, ValueRange{input, weight}, ValueRange{biasBroadcast},
-        strideAttr, dilationAttr);
+    Value conv = rewriter
+                     .create<linalg::Conv2DNhwcHwcfOp>(
+                         loc, resultTy, ValueRange{input, weight},
+                         ValueRange{zeroTensor}, strideAttr, dilationAttr)
+                     ->getResult(0);
+
+    Value result =
+        rewriter
+            .create<linalg::GenericOp>(
+                loc, resultTy, ValueRange({bias, conv}), biasInitTensor,
+                indexingMaps, getNParallelLoopsAttrs(resultTy.getRank()),
+                [&](OpBuilder &nestedBuilder, Location nestedLoc,
+                    ValueRange args) {
+                  Value added =
+                      nestedBuilder.create<AddFOp>(loc, args[0], args[1]);
+                  nestedBuilder.create<linalg::YieldOp>(nestedLoc, added);
+                })
+            .getResult(0);
+
+    rewriter.replaceOp(op, result);
     return success();
   }
 };
@@ -1288,6 +1318,8 @@ class FullyConnectedConverter
     auto weightTy = weight.getType().cast<ShapedType>();
     auto weightShape = weightTy.getShape();
 
+    auto outputETy = outputTy.getElementType();
+
     // Creating maps for the output of MatMul and the bias
     SmallVector<AffineMap, 4> indexingMaps;
 
@@ -1297,23 +1329,16 @@ class FullyConnectedConverter
                                           rewriter.getContext()));
 
     indexingMaps.push_back(rewriter.getMultiDimIdentityMap(outputTy.getRank()));
+    indexingMaps.push_back(rewriter.getMultiDimIdentityMap(outputTy.getRank()));
 
-    auto initTensor =
-        rewriter
-            .create<linalg::InitTensorOp>(loc, outputTy.getShape(),
-                                          outputTy.getElementType())
-            ->getResults();
+    auto initTensor = rewriter.create<linalg::InitTensorOp>(
+        loc, outputTy.getShape(), outputTy.getElementType());
 
-    auto linalgOp =
-        rewriter
-            .create<linalg::GenericOp>(
-                loc, outputTy, bias, initTensor, indexingMaps,
-                getNParallelLoopsAttrs(outputTy.getRank()),
-                [&](OpBuilder &nested_builder, Location nested_loc,
-                    ValueRange args) {
-                  nested_builder.create<linalg::YieldOp>(loc, *args.begin());
-                })
-            ->getResults();
+    // When quantized, the input element type is not the same as the output
+    Attribute resultZeroAttr = rewriter.getZeroAttr(outputETy);
+    Value zero = rewriter.create<ConstantOp>(loc, resultZeroAttr);
+    Value zeroTensor =
+        rewriter.create<linalg::FillOp>(loc, zero, initTensor).getResult(0);
 
     SmallVector<int64_t> permutation{1, 0};
     auto permutationAttr = DenseIntElementsAttr::get(
@@ -1327,10 +1352,31 @@ class FullyConnectedConverter
     Value transposedWeight = rewriter.create<tosa::TransposeOp>(
         loc, newWeightTy, weight, permutationValue);
 
+    auto biasInitTensor =
+        rewriter
+            .create<linalg::InitTensorOp>(loc, outputTy.getShape(), outputETy)
+            ->getResults();
+
     if (!op.quantization_info()) {
-      rewriter.replaceOpWithNewOp<linalg::MatmulOp>(
-          op, TypeRange{op.getType()}, ValueRange{input, transposedWeight},
-          linalgOp);
+      Value matmul = rewriter
+                         .create<linalg::MatmulOp>(
+                             loc, TypeRange{op.getType()},
+                             ValueRange{input, transposedWeight}, zeroTensor)
+                         ->getResult(0);
+
+      Value result =
+          rewriter
+              .create<linalg::GenericOp>(
+                  loc, outputTy, ValueRange({bias, matmul}), biasInitTensor,
+                  indexingMaps, getNParallelLoopsAttrs(outputTy.getRank()),
+                  [&](OpBuilder &nestedBuilder, Location nestedLoc,
+                      ValueRange args) {
+                    Value added =
+                        nestedBuilder.create<AddFOp>(loc, args[0], args[1]);
+                    nestedBuilder.create<linalg::YieldOp>(nestedLoc, added);
+                  })
+              .getResult(0);
+      rewriter.replaceOp(op, result);
       return success();
     }
 
@@ -1341,10 +1387,26 @@ class FullyConnectedConverter
     auto outputZp = rewriter.create<ConstantOp>(
         loc, rewriter.getI32IntegerAttr(
                  quantizationInfo.weight_zp().getValue().getSExtValue()));
-    rewriter.replaceOpWithNewOp<linalg::QuantizedMatmulOp>(
-        op, TypeRange{op.getType()},
-        ValueRange{input, transposedWeight, inputZp, outputZp}, linalgOp);
-
+    Value matmul =
+        rewriter
+            .create<linalg::QuantizedMatmulOp>(
+                loc, TypeRange{op.getType()},
+                ValueRange{input, transposedWeight, inputZp, outputZp},
+                zeroTensor)
+            ->getResult(0);
+    Value result =
+        rewriter
+            .create<linalg::GenericOp>(
+                loc, outputTy, ValueRange({bias, matmul}), biasInitTensor,
+                indexingMaps, getNParallelLoopsAttrs(outputTy.getRank()),
+                [&](OpBuilder &nestedBuilder, Location nestedLoc,
+                    ValueRange args) {
+                  Value added =
+                      nestedBuilder.create<AddIOp>(loc, args[0], args[1]);
+                  nestedBuilder.create<linalg::YieldOp>(nestedLoc, added);
+                })
+            .getResult(0);
+    rewriter.replaceOp(op, result);
     return success();
   }
 };

diff  --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index 50e4c78bb2d45..0cadf6e7caadf 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -993,44 +993,55 @@ func @matmul_quantized(%arg0: tensor<1x5x3xi8>, %arg1: tensor<1x3x6xi8>) -> (ten
 
 // -----
 
-// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d1)>
+// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d1, d0)>
 // CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (d0, d1)>
-// CHECK: #[[$MAP2:.*]] = affine_map<(d0, d1) -> (d1, d0)>
+// CHECK: #[[$MAP2:.*]] = affine_map<(d0, d1) -> (d1)>
 
 // CHECK-LABEL: @fully_connected
 func @fully_connected(%arg0: tensor<5x3xf32>, %arg1: tensor<6x3xf32>, %arg2: tensor<6xf32>) -> (tensor<5x6xf32>) {
-  // CHECK: [[INITB:%.+]] = linalg.init_tensor [5, 6]
-  // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel"]} ins(%arg2 : tensor<6xf32>) outs([[INITB]] : tensor<5x6xf32>) {
-  // CHECK: ^bb0([[IN:%.+]]: f32, [[UNUSED:%.+]]: f32):
-  // CHECK:   linalg.yield [[IN]] : f32
+  // CHECK: [[INITT:%.+]] = linalg.init_tensor [5, 6]
+  // CHECK: [[ZERO:%.+]] = constant 0
+  // CHECK: [[FILL:%.+]] = linalg.fill([[ZERO]], [[INITT]])
+  // CHECK: [[PERM:%.+]] = constant dense<[1, 0]>
   // CHECK: [[INITT:%.+]] = linalg.init_tensor [3, 6]
-  // CHECK: [[TRANSPOSE:%.+]] = linalg.generic {indexing_maps = [#[[$MAP2]], #[[$MAP1]]], iterator_types = ["parallel", "parallel"]} ins(%arg1 : tensor<6x3xf32>) outs([[INITT]]
+  // CHECK: [[TRANSPOSE:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel"]} ins(%arg1 : tensor<6x3xf32>) outs([[INITT]] : tensor<3x6xf32>) {
   // CHECK: ^bb0([[IN:%.+]]: f32, [[UNUSED:%.+]]: f32):
   // CHECK:   linalg.yield [[IN]] : f32
-  // CHECK: linalg.matmul ins(%arg0, [[TRANSPOSE]] : tensor<5x3xf32>, tensor<3x6xf32>) outs([[GENERIC]] : tensor<5x6xf32>) -> tensor<5x6xf32>
+  // CHECK: [[INITB:%.+]] = linalg.init_tensor [5, 6]
+  // CHECK: [[MATMUL:%.+]] = linalg.matmul ins(%arg0, [[TRANSPOSE]] : tensor<5x3xf32>, tensor<3x6xf32>) outs([[FILL]] : tensor<5x6xf32>) -> tensor<5x6xf32>
+  // CHECK: [[ADDED:%.+]] = linalg.generic {indexing_maps = [#[[$MAP2]], #[[$MAP1]], #[[$MAP1]]], iterator_types = ["parallel", "parallel"]} ins(%arg2, [[MATMUL]] : tensor<6xf32>, tensor<5x6xf32>) outs([[INITB]] : tensor<5x6xf32>) {
+  // CHECK: ^bb0(%arg3: f32, %arg4: f32, %arg5: f32):
+  // CHECK:   [[ADD:%.+]] = addf %arg3, %arg4 : f32
+  // CHECK:   linalg.yield [[ADD]] : f32
+
   %0 = "tosa.fully_connected"(%arg0, %arg1, %arg2) : (tensor<5x3xf32>, tensor<6x3xf32>, tensor<6xf32>)  -> (tensor<5x6xf32>)
   return %0 : tensor<5x6xf32>
 }
 
 // -----
 
-// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d1)>
+// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d1, d0)>
 // CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (d0, d1)>
-// CHECK: #[[$MAP2:.*]] = affine_map<(d0, d1) -> (d1, d0)>
+// CHECK: #[[$MAP2:.*]] = affine_map<(d0, d1) -> (d1)>
 
 // CHECK-LABEL: @quantized_fully_connected
 func @quantized_fully_connected(%arg0: tensor<5x3xi8>, %arg1: tensor<6x3xi8>, %arg2: tensor<6xi32>) -> (tensor<5x6xi32>) {
-  // CHECK: [[INITB:%.+]] = linalg.init_tensor [5, 6]
-  // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel"]} ins(%arg2 : tensor<6xi32>) outs([[INITB]] : tensor<5x6xi32>) {
-  // CHECK: ^bb0([[IN:%.+]]: i32, [[UNUSED:%.+]]: i32):
-  // CHECK:   linalg.yield [[IN]] : i32
+  // CHECK: [[INITT:%.+]] = linalg.init_tensor [5, 6]
+  // CHECK: [[ZERO:%.+]] = constant 0
+  // CHECK: [[FILL:%.+]] = linalg.fill([[ZERO]], [[INITT]])
+  // CHECK: [[PERM:%.+]] = constant dense<[1, 0]>
   // CHECK: [[INITT:%.+]] = linalg.init_tensor [3, 6]
-  // CHECK: [[TRANSPOSE:%.+]] = linalg.generic {indexing_maps = [#[[$MAP2]], #[[$MAP1]]], iterator_types = ["parallel", "parallel"]} ins(%arg1 : tensor<6x3xi8>) outs([[INITT]]
+  // CHECK: [[TRANSPOSE:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel"]} ins(%arg1 : tensor<6x3xi8>) outs([[INITT]] : tensor<3x6xi8>) {
   // CHECK: ^bb0([[IN:%.+]]: i8, [[UNUSED:%.+]]: i8):
   // CHECK:   linalg.yield [[IN]] : i8
+  // CHECK: [[INITB:%.+]] = linalg.init_tensor [5, 6]
   // CHECK: [[ONE:%.+]] = constant 1 
-  // CHECK: [[TWO:%.+]] = constant 2 
-  // CHECK: linalg.quantized_matmul ins(%arg0, [[TRANSPOSE]], [[ONE]], [[TWO]] : tensor<5x3xi8>, tensor<3x6xi8>, i32, i32) outs([[GENERIC]] : tensor<5x6xi32>) -> tensor<5x6xi32>
+  // CHECK: [[TWO:%.+]] = constant 2
+  // CHECK: [[MATMUL:%.+]] = linalg.quantized_matmul ins(%arg0, [[TRANSPOSE]], [[ONE]], [[TWO]] : tensor<5x3xi8>, tensor<3x6xi8>, i32, i32) outs([[FILL]] : tensor<5x6xi32>) -> tensor<5x6xi32>
+  // CHECK: [[ADDED:%.+]] = linalg.generic {indexing_maps = [#[[$MAP2]], #[[$MAP1]], #[[$MAP1]]], iterator_types = ["parallel", "parallel"]} ins(%arg2, [[MATMUL]] : tensor<6xi32>, tensor<5x6xi32>) outs([[INITB]]
+  // CHECK: ^bb0([[IN1:%.+]]: i32, [[IN2:%.+]]: i32, [[UNUSED:%.+]]: i32):
+  // CHECK:   [[ADD:%.+]] = addi
+  // CHECK:   linalg.yield [[ADD]] : i32
   %0 = "tosa.fully_connected"(%arg0, %arg1, %arg2) {quantization_info = {input_zp = 1:i32, weight_zp = 2:i32}} : (tensor<5x3xi8>, tensor<6x3xi8>, tensor<6xi32>)  -> (tensor<5x6xi32>)
   return %0 : tensor<5x6xi32>
 }
@@ -1350,10 +1361,14 @@ func @conv2d_f32(%input: tensor<1x49x42x27xf32>, %weights: tensor<28x3x3x27xf32>
   // CHECK: %[[W_IN:.+]] = linalg.init_tensor [3, 3, 27, 28]
   // CHECK: %[[W:.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg1 : tensor<28x3x3x27xf32>) outs(%[[W_IN]] : tensor<3x3x27x28xf32>)
   // CHECK:   linalg.yield %arg3 : f32
+  // CHECK: %[[M_IN:.+]] = linalg.init_tensor [1, 45, 40, 28]
+  // CHECK: %[[CST:.+]] = constant 0
+  // CHECK: %[[FILL:.+]] = linalg.fill
   // CHECK: %[[B_IN:.+]] = linalg.init_tensor [1, 45, 40, 28]
-  // CHECK: %[[B:.+]] = linalg.generic {indexing_maps = [#[[$MAP2]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<28xf32>) outs(%[[B_IN]] : tensor<1x45x40x28xf32>)
-  // CHECK:   linalg.yield %arg3 : f32
-  // CHECK: %[[CONV:.+]] = linalg.conv_2d_nhwc_hwcf {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %1 : tensor<1x49x42x27xf32>, tensor<3x3x27x28xf32>) outs(%[[B]] : tensor<1x45x40x28xf32>)
+  // CHECK: %[[CONV:.+]] = linalg.conv_2d_nhwc_hwcf {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %[[W]] : tensor<1x49x42x27xf32>, tensor<3x3x27x28xf32>) outs(%[[FILL]] : tensor<1x45x40x28xf32>)
+  // CHECK: %[[B:.+]] = linalg.generic {indexing_maps = [#[[$MAP2]], #[[$MAP1]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2, %[[CONV]] : tensor<28xf32>, tensor<1x45x40x28xf32>) outs(%[[B_IN]] : tensor<1x45x40x28xf32>)
+  // CHECK:   addf
+  // CHECK:   linalg.yield %7 : f32
   %0 = "tosa.conv2d"(%input, %weights, %bias) {pad = [0, 0, 0, 0], stride = [1, 1], dilation = [2, 1]} : (tensor<1x49x42x27xf32>, tensor<28x3x3x27xf32>, tensor<28xf32>)  -> (tensor<1x45x40x28xf32>)
   return
 }


        


More information about the Mlir-commits mailing list