[Mlir-commits] [mlir] 40a02fa - [mlir][tosa] Added tosa to linalg lowering to unstrided transposed conv

Rob Suderman llvmlistbot at llvm.org
Tue Jul 20 15:15:02 PDT 2021


Author: Rob Suderman
Date: 2021-07-20T15:07:08-07:00
New Revision: 40a02fae87ca7de676f6b9d96532c760130ccc68

URL: https://github.com/llvm/llvm-project/commit/40a02fae87ca7de676f6b9d96532c760130ccc68
DIFF: https://github.com/llvm/llvm-project/commit/40a02fae87ca7de676f6b9d96532c760130ccc68.diff

LOG: [mlir][tosa] Added tosa to linalg lowering to unstrided transposed conv

The unstrided transposed conv can be represented as a regular convolution
over a spatially reversed kernel with adjusted padding. Lower to this variant
to handle the basic case. This also transitions the Conv2D lowering from the
TC-defined convolution operation to a YAML-defined one.

Reviewed By: NatashaKnk

Differential Revision: https://reviews.llvm.org/D106389

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
    mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
    mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
    mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
index f41a4cb5703a..7e07275173a0 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
@@ -627,6 +627,88 @@ structured_op: !LinalgStructuredOpConfig
                 - !ScalarExpression
                   scalar_arg: B
 --- !LinalgOpConfig
+metadata: !LinalgOpMetadata
+  name: conv_2d_input_nhwc_filter_ohwi_poly
+  cpp_class_name: Conv2DInputNhwcFilterOhwiPolyOp
+  doc: |-
+    Performs a 2-D convolution.
+
+    Numeric casting is performed on the operands to the inner multiply, promoting
+    them to the same data type as the accumulator/output.
+structured_op: !LinalgStructuredOpConfig
+  args:
+  - !LinalgOperandDefConfig
+    name: I
+    usage: InputOperand
+    type_var: T1
+    shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11, s12]
+      -> (s0, s1, s2, s3)>
+  - !LinalgOperandDefConfig
+    name: K
+    usage: InputOperand
+    type_var: T2
+    shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11, s12]
+      -> (s4, s5, s6, s3)>
+  - !LinalgOperandDefConfig
+    name: O
+    usage: OutputOperand
+    type_var: U
+    shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11, s12]
+      -> (s0, s7, s8, s4)>
+  - !LinalgOperandDefConfig
+    name: strides
+    usage: IndexAttribute
+    type_var: I64
+    attribute_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11,
+      s12] -> (s9, s10)>
+  - !LinalgOperandDefConfig
+    name: dilations
+    usage: IndexAttribute
+    type_var: I64
+    attribute_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11,
+      s12] -> (s11, s12)>
+  indexing_maps: !LinalgIndexingMapsConfig
+    static_indexing_maps:
+    - affine_map<(d0, d1, d2, d3, d4, d5, d6)[s0, s1, s2, s3, s4, s5, s6, s7, s8,
+      s9, s10, s11, s12] -> (d0, d1 * s9 + d3 * s11, d2 * s10 + d4 * s12, d6)>
+    - affine_map<(d0, d1, d2, d3, d4, d5, d6)[s0, s1, s2, s3, s4, s5, s6, s7, s8,
+      s9, s10, s11, s12] -> (d5, d3, d4, d6)>
+    - affine_map<(d0, d1, d2, d3, d4, d5, d6)[s0, s1, s2, s3, s4, s5, s6, s7, s8,
+      s9, s10, s11, s12] -> (d0, d1, d2, d5)>
+  iterator_types:
+  - parallel
+  - parallel
+  - parallel
+  - reduction
+  - reduction
+  - parallel
+  - reduction
+  assignments:
+  - !ScalarAssign
+    arg: O
+    value: !ScalarExpression
+      scalar_apply:
+        fn_name: add
+        operands:
+        - !ScalarExpression
+          scalar_arg: O
+        - !ScalarExpression
+          scalar_apply:
+            fn_name: mul
+            operands:
+            - !ScalarExpression
+              symbolic_cast:
+                type_var: U
+                operands:
+                - !ScalarExpression
+                  scalar_arg: I
+            - !ScalarExpression
+              symbolic_cast:
+                type_var: U
+                operands:
+                - !ScalarExpression
+                  scalar_arg: K
+--- !LinalgOpConfig
 metadata: !LinalgOpMetadata
   name: depthwise_conv_2d_input_nhwc_filter_hwc_poly
   cpp_class_name: DepthwiseConv2DInputNhwcFilterHwcPolyOp

diff  --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index f99f76d4fae7..6509df285a89 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -892,30 +892,6 @@ convolutionMatchAndRewriterHelper(Operation *op,
 
   input = applyPad(loc, input, pad, zeroAttr, rewriter);
 
-  // We need to transpose the Conv2DOp kernel to line up the last input/output
-  // kernels.
-  // TODO(suderman): Eventually we will support specifying the filter channel
-  // ordering then we can avoid transposing the kernel.
-  if (isa<tosa::Conv2DOp>(op)) {
-    int32_t weightRank = weightTy.getRank();
-    SmallVector<int64_t> permutation, transposeWeightShape;
-    permutation.resize(weightRank, 0);
-    transposeWeightShape.resize(weightRank, 0);
-    for (int i = 0; i < weightRank; i++) {
-      permutation[i] = (i + 1) % weightRank;
-      transposeWeightShape[i] = weightShape[permutation[i]];
-    }
-
-    Value permutationValue = rewriter.create<ConstantOp>(
-        loc, DenseIntElementsAttr::get(
-                 RankedTensorType::get({weightRank}, rewriter.getI64Type()),
-                 permutation));
-    Type newWeightTy = RankedTensorType::get(transposeWeightShape, biasETy);
-
-    weight = rewriter.create<tosa::TransposeOp>(loc, newWeightTy, weight,
-                                                permutationValue);
-  }
-
   // Broadcast the initial value to the output tensor before convolving.
   SmallVector<AffineMap, 4> indexingMaps;
   indexingMaps.push_back(AffineMap::get(
@@ -949,9 +925,9 @@ convolutionMatchAndRewriterHelper(Operation *op,
       RankedTensorType::get({2}, rewriter.getI64Type()), dilation);
 
   if (isa<tosa::Conv2DOp>(op)) {
-    rewriter.replaceOpWithNewOp<linalg::ConvInputNHWCFilterHWCFOp>(
+    rewriter.replaceOpWithNewOp<linalg::Conv2DInputNhwcFilterOhwiPolyOp>(
         op, resultTy, ValueRange{input, weight}, ValueRange{biasBroadcast},
-        dilationAttr, strideAttr);
+        strideAttr, dilationAttr);
     return success();
   }
 
@@ -1001,6 +977,84 @@ class ConvConverter : public OpConversionPattern<T> {
   }
 };
 
+class TransposeConvConverter
+    : public OpConversionPattern<tosa::TransposeConv2DOp> {
+public:
+  using OpConversionPattern<tosa::TransposeConv2DOp>::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(tosa::TransposeConv2DOp op, ArrayRef<Value> args,
+                  ConversionPatternRewriter &rewriter) const final {
+    Location loc = op->getLoc();
+    Value input = op->getOperand(0);
+    Value weight = op->getOperand(1);
+    Value bias = op->getOperand(2);
+
+    ShapedType inputTy = input.getType().cast<ShapedType>();
+    ShapedType weightTy = weight.getType().cast<ShapedType>();
+    ShapedType biasTy = bias.getType().cast<ShapedType>();
+    ShapedType resultTy = op->getResult(0).getType().cast<ShapedType>();
+
+    llvm::SmallVector<int64_t> pad;
+    llvm::SmallVector<int64_t> stride;
+    llvm::SmallVector<int64_t> dilation;
+
+    getValuesFromIntArrayAttribute(op.out_pad().cast<ArrayAttr>(), pad);
+    getValuesFromIntArrayAttribute(op.stride().cast<ArrayAttr>(), stride);
+    getValuesFromIntArrayAttribute(op.dilation().cast<ArrayAttr>(), dilation);
+
+    // We have not solved for stride / dilation yet. Dilation should be
+    // straight forward but stride is more complicated. Linalg work is likely
+    // required for efficient implementation.
+    if (llvm::any_of(stride, [](int64_t v) { return v != 1; }))
+      return failure();
+    if (llvm::any_of(dilation, [](int64_t v) { return v != 1; }))
+      return failure();
+
+    if (!inputTy.hasStaticShape() || !weightTy.hasStaticShape() ||
+        !biasTy.hasStaticShape() || !resultTy.hasStaticShape())
+      return failure();
+
+    int64_t inputHeight = inputTy.getDimSize(1);
+    int64_t inputWidth = inputTy.getDimSize(2);
+    int64_t kernelHeight = weightTy.getDimSize(1);
+    int64_t kernelWidth = weightTy.getDimSize(2);
+    int64_t outputHeight = resultTy.getDimSize(1);
+    int64_t outputWidth = resultTy.getDimSize(2);
+
+    int64_t requiredInputHeight = outputHeight + kernelHeight - 1;
+    int64_t requiredInputWidth = outputWidth + kernelWidth - 1;
+
+    llvm::SmallVector<int64_t> newPad(4, 0);
+    newPad[0] = kernelHeight - 1 - pad[0];
+    newPad[2] = kernelWidth - 1 - pad[1];
+
+    newPad[1] = requiredInputHeight - newPad[0] - inputHeight;
+    newPad[3] = requiredInputWidth - newPad[2] - inputWidth;
+
+    auto reverse1 = rewriter.create<tosa::ReverseOp>(
+        loc, weightTy, weight, rewriter.getI64IntegerAttr(1));
+    auto reverse2 = rewriter.create<tosa::ReverseOp>(
+        loc, weightTy, reverse1, rewriter.getI64IntegerAttr(2));
+
+    Value conv2d;
+    if (op.quantization_info().hasValue()) {
+      conv2d = rewriter.create<tosa::Conv2DOp>(
+          loc, resultTy, input, reverse2, bias,
+          rewriter.getI64ArrayAttr(newPad), rewriter.getI64ArrayAttr(stride),
+          rewriter.getI64ArrayAttr(dilation),
+          op.quantization_info().getValue());
+    } else {
+      conv2d = rewriter.create<tosa::Conv2DOp>(
+          loc, resultTy, input, reverse2, bias,
+          rewriter.getI64ArrayAttr(newPad), rewriter.getI64ArrayAttr(stride),
+          rewriter.getI64ArrayAttr(dilation));
+    }
+
+    rewriter.replaceOp(op, conv2d);
+    return success();
+  }
+};
+
 class MatMulConverter : public OpConversionPattern<tosa::MatMulOp> {
 public:
   using OpConversionPattern<tosa::MatMulOp>::OpConversionPattern;
@@ -2456,6 +2510,7 @@ void mlir::tosa::populateTosaToLinalgOnTensorsConversionPatterns(
       ConcatConverter,
       ConvConverter<tosa::Conv2DOp>,
       ConvConverter<tosa::DepthwiseConv2DOp>,
+      TransposeConvConverter,
       GatherConverter,
       PadConverter,
       ReshapeConverter,

diff  --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
index 3ea171f78ef4..ebe9822b1663 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
@@ -145,6 +145,25 @@ def dot(
   C[None] += cast(U, A[D.m]) * cast(U, B[D.m])
 
 
+@linalg_structured_op
+def conv_2d_input_nhwc_filter_ohwi_poly(
+    I=TensorDef(T1, S.N, S.IH, S.IW, S.IC),
+    K=TensorDef(T2, S.OC, S.KH, S.KW, S.IC),
+    O=TensorDef(U, S.N, S.OH, S.OW, S.OC, output=True),
+    strides=AttributeDef(S.SH, S.SW),
+    dilations=AttributeDef(S.DH, S.DW)):
+  """Performs a 2-D convolution.
+
+  Numeric casting is performed on the operands to the inner multiply, promoting
+  them to the same data type as the accumulator/output.
+  """
+  domain(D.n, D.oh, D.ow, D.kh, D.kw, D.oc, D.ic)
+  O[D.n, D.oh, D.ow, D.oc] += cast(
+      U, I[D.n,
+           D.oh * S.SH + D.kh * S.DH,
+           D.ow * S.SW + D.kw * S.DW,
+           D.ic]) * cast(U, K[D.oc, D.kh, D.kw, D.ic])
+
 @linalg_structured_op
 def depthwise_conv_2d_input_nhwc_filter_hwc_poly(
     I=TensorDef(T1, S.N, S.IH, S.IW, S.C),

diff  --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index 1766e8ec144c..38a70fb5f5e4 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -1184,25 +1184,20 @@ func @avg_pool(%arg0: tensor<1x6x34x62xf32>) -> (tensor<1x5x33x62xf32>) {
 
 // -----
 
-// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d3, d0, d1, d2)>
-// CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
 // CHECK: #[[$MAP2:.*]] = affine_map<(d0, d1, d2, d3) -> (d3)>
+// CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
 
-func @conv2d_f32(%input: tensor<1x49x42x28xf32>, %weights: tensor<28x3x3x28xf32>, %bias: tensor<28xf32>) -> () {
-  // CHECK: %[[INIT:.+]] = linalg.init_tensor [3, 3, 28, 28]
-  // CHECK: %[[KERNEL:.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg1 : tensor<28x3x3x28xf32>) outs(%[[INIT]] : tensor<3x3x28x28xf32>)
-  // CHECK: ^bb0(%arg3: f32, %arg4: f32):
-  // CHECK:   linalg.yield %arg3 : f32
+func @conv2d_f32(%input: tensor<1x49x42x27xf32>, %weights: tensor<28x3x3x27xf32>, %bias: tensor<28xf32>) -> () {
   // CHECK: %[[INIT:.+]] = linalg.init_tensor [1, 45, 40, 28]
   // CHECK: %[[BROADCAST:.+]] = linalg.generic {indexing_maps = [#[[$MAP2]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<28xf32>) outs(%[[INIT]] : tensor<1x45x40x28xf32>)
-  // CHECK: linalg.conv_2d_input_nhwc_filter_hwcf {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %[[KERNEL]] : tensor<1x49x42x28xf32>, tensor<3x3x28x28xf32>) outs(%[[BROADCAST]] : tensor<1x45x40x28xf32>)
-  %0 = "tosa.conv2d"(%input, %weights, %bias) {pad = [0, 0, 0, 0], stride = [1, 1], dilation = [2, 1]} : (tensor<1x49x42x28xf32>, tensor<28x3x3x28xf32>, tensor<28xf32>)  -> (tensor<1x45x40x28xf32>)
+  // CHECK: linalg.conv_2d_input_nhwc_filter_ohwi_poly {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<1x49x42x27xf32>, tensor<28x3x3x27xf32>) outs(%[[BROADCAST]] : tensor<1x45x40x28xf32>)
+  %0 = "tosa.conv2d"(%input, %weights, %bias) {pad = [0, 0, 0, 0], stride = [1, 1], dilation = [2, 1]} : (tensor<1x49x42x27xf32>, tensor<28x3x3x27xf32>, tensor<28xf32>)  -> (tensor<1x45x40x28xf32>)
   return
 }
 
 func @conv2d_padded_f32(%input: tensor<1x47x40x28xf32>, %weights: tensor<28x3x3x28xf32>, %bias: tensor<28xf32>) -> () {
   // CHECK: linalg.pad_tensor %arg0
-  // CHECK: linalg.conv_2d_input_nhwc_filter_hwcf
+  // CHECK: linalg.conv_2d_input_nhwc_filter_ohwi_poly
   %0 = "tosa.conv2d"(%input, %weights, %bias) {pad = [1, 1, 1, 1], stride = [1, 1], dilation = [2, 1]} : (tensor<1x47x40x28xf32>, tensor<28x3x3x28xf32>, tensor<28xf32>)  -> (tensor<1x45x40x28xf32>)
   return
 }
@@ -1226,6 +1221,16 @@ func @depthwise_conv(%arg0 : tensor<1x7x5x3xf32>, %arg1 : tensor<3x1x3x11xf32>,
   return
 }
 
+// -----
+
+// CHECK-LABEL: @transpose_conv
+func @transpose_conv(%arg0 : tensor<1x12x12x2xf32>, %arg1 : tensor<4x3x3x2xf32>, %arg2 : tensor<4xf32>) -> () {
+  // CHECK: [[PAD:%.+]] = linalg.pad_tensor %arg0 low[0, 2, 2, 0] high[0, 2, 2, 0]
+  // CHECK: linalg.conv_2d_input_nhwc_filter_ohwi_poly {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins([[PAD]], {{%.+}} : tensor<1x16x16x2xf32>, tensor<4x3x3x2xf32>)
+  %0 = "tosa.transpose_conv2d"(%arg0, %arg1, %arg2) {dilation = [1, 1], out_pad = [0, 0], out_shape = [1, 14, 14, 4], stride = [1, 1]} : (tensor<1x12x12x2xf32>, tensor<4x3x3x2xf32>, tensor<4xf32>) -> tensor<1x14x14x4xf32>
+  return
+}
+
 
 // -----
 


        


More information about the Mlir-commits mailing list