[Mlir-commits] [mlir] 7e1fb9a - [mlir][tosa] Add conv2d lowering to linalg.conv2d operator for FP

Rob Suderman llvmlistbot at llvm.org
Tue Apr 13 13:28:47 PDT 2021


Author: Rob Suderman
Date: 2021-04-13T13:26:02-07:00
New Revision: 7e1fb9a0d2d731e21f845bf9c14db6df1047f991

URL: https://github.com/llvm/llvm-project/commit/7e1fb9a0d2d731e21f845bf9c14db6df1047f991
DIFF: https://github.com/llvm/llvm-project/commit/7e1fb9a0d2d731e21f845bf9c14db6df1047f991.diff

LOG: [mlir][tosa] Add conv2d lowering to linalg.conv2d operator for FP

Handles lowering conv2d to linalg's convolution operator. This implementation
only supports statically shaped f32 tensors but handles all strides,
dilations, and padding values.

Differential Revision: https://reviews.llvm.org/D100061

Added: 
    

Modified: 
    mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
    mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index ef317aa3e9b46..de27feca3e103 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -740,6 +740,109 @@ class FullyConnectedConverter
   }
 };
 
+/// Lowers static-shaped f32 tosa.conv2d to linalg::ConvInputNHWCFilterHWCFOp.
+/// The bias is broadcast into the output operand, the weight is transposed
+/// with permutation [1, 2, 3, 0], and nonzero padding becomes a tosa.pad.
+class Conv2DConverter : public OpConversionPattern<tosa::Conv2DOp> {
+public:
+  using OpConversionPattern<tosa::Conv2DOp>::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(tosa::Conv2DOp op, ArrayRef<Value> args,
+                  ConversionPatternRewriter &rewriter) const final {
+    Location loc = op.getLoc();
+    Value input = op.input();
+    Value weight = op.weight();
+    Value bias = op.bias();
+    ShapedType inputTy = input.getType().cast<ShapedType>();
+    ShapedType weightTy = weight.getType().cast<ShapedType>();
+    ShapedType biasTy = bias.getType().cast<ShapedType>();
+    ShapedType resultTy = op.getType().cast<ShapedType>();
+
+    Type inputETy = inputTy.getElementType();
+    Type weightETy = weightTy.getElementType();
+    Type biasETy = biasTy.getElementType();
+    Type resultETy = resultTy.getElementType();
+
+    if (!inputTy.hasStaticShape() || !weightTy.hasStaticShape() ||
+        !biasTy.hasStaticShape() || !resultTy.hasStaticShape())
+      return rewriter.notifyMatchFailure(op,
+                                         "tosa.conv2d requires static shapes");
+
+    auto inputShape = inputTy.getShape();
+    auto weightShape = weightTy.getShape();
+    // TODO(suderman): Support other types.
+    if (!inputETy.isF32() || !weightETy.isF32() || !biasETy.isF32() ||
+        !resultETy.isF32())
+      return rewriter.notifyMatchFailure(op,
+                                         "tosa.conv2d requires f32 types");
+
+    // Broadcast the initial value to the output tensor before convolving.
+    SmallVector<AffineMap, 4> indexingMaps;
+    indexingMaps.push_back(AffineMap::get(/*dimCount=*/4, /*symbolCount=*/0,
+                                          {rewriter.getAffineDimExpr(3)},
+                                          rewriter.getContext()));
+    indexingMaps.push_back(rewriter.getMultiDimIdentityMap(resultTy.getRank()));
+
+    Value initTensor = rewriter.create<linalg::InitTensorOp>(
+        loc, resultTy.getShape(), resultTy.getElementType());
+    Value biasBroadcast =
+        rewriter
+            .create<linalg::GenericOp>(
+                loc, resultTy, bias, initTensor, indexingMaps,
+                getNParallelLoopsAttrs(resultTy.getRank()),
+                [&](OpBuilder &nestedBuilder, Location nestedLoc,
+                    ValueRange bbArgs) {
+                  nestedBuilder.create<linalg::YieldOp>(nestedLoc, bbArgs[0]);
+                })
+            .getResult(0);
+
+    // Transpose weights tensor to be in dim order: spatial dims,
+    // input channels, and output channels.
+    SmallVector<int64_t> permutation{1, 2, 3, 0};
+    auto permutationAttr = DenseIntElementsAttr::get(
+        RankedTensorType::get({4}, rewriter.getI64Type()), permutation);
+    Value permutationValue = rewriter.create<ConstantOp>(loc, permutationAttr);
+    // Type the transposed kernel with the weight's own element type.
+    SmallVector<int64_t> newKernelShape{weightShape[1], weightShape[2],
+                                        weightShape[3], weightShape[0]};
+    Type newKernelTy = RankedTensorType::get(newKernelShape, weightETy);
+    Value transposedKernel = rewriter.create<tosa::TransposeOp>(
+        loc, newKernelTy, weight, permutationValue);
+
+    // Extract the attributes for convolution.
+    llvm::SmallVector<int64_t> stride, dilation, pad;
+    getValuesFromIntArrayAttribute(op.stride(), stride);
+    getValuesFromIntArrayAttribute(op.dilation(), dilation);
+    getValuesFromIntArrayAttribute(op.pad(), pad);
+
+    // Input should be padded if necessary: pad[0]/pad[1] extend dim 1 and
+    // pad[2]/pad[3] extend dim 2; dims 0 and 3 are left unpadded.
+    if (llvm::any_of(pad, [](int64_t p) { return p; })) {
+      llvm::SmallVector<int64_t, 8> newPad{0,      0,      pad[0], pad[1],
+                                           pad[2], pad[3], 0,      0};
+      auto padAttr = DenseIntElementsAttr::get(
+          RankedTensorType::get({4, 2}, rewriter.getI64Type()), newPad);
+      Value padValue = rewriter.create<ConstantOp>(loc, padAttr);
+      SmallVector<int64_t, 4> paddedShape{
+          inputShape[0], inputShape[1] + pad[0] + pad[1],
+          inputShape[2] + pad[2] + pad[3], inputShape[3]};
+      Type paddedTy = RankedTensorType::get(paddedShape, inputETy);
+      input = rewriter.create<tosa::PadOp>(loc, paddedTy, input, padValue);
+    }
+
+    auto strideAttr = DenseIntElementsAttr::get(
+        RankedTensorType::get({2}, rewriter.getI64Type()), stride);
+    auto dilationAttr = DenseIntElementsAttr::get(
+        RankedTensorType::get({2}, rewriter.getI64Type()), dilation);
+
+    auto convOp = rewriter.create<linalg::ConvInputNHWCFilterHWCFOp>(
+        loc, resultTy, ValueRange{input, transposedKernel},
+        ValueRange{biasBroadcast}, dilationAttr, strideAttr);
+    rewriter.replaceOp(op, convOp.getResult(0));
+    return success();
+  }
+};
+
 class ReshapeConverter : public OpConversionPattern<tosa::ReshapeOp> {
 public:
   using OpConversionPattern<tosa::ReshapeOp>::OpConversionPattern;
@@ -1693,6 +1796,7 @@ void mlir::tosa::populateTosaToLinalgOnTensorsConversionPatterns(
       ReduceConverter<tosa::ReduceProdOp>,
       ArgMaxConverter,
       ConcatConverter,
+      Conv2DConverter,
       PadConverter,
       ReshapeConverter,
       RescaleConverter,

diff  --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index b33c18f46ddf3..119ceea14afe3 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -923,3 +923,28 @@ func @max_pool_i32(%arg0: tensor<1x6x34x62xi32>) -> () {
   %0 = "tosa.max_pool2d"(%arg0) {pad = [0, 0, 0, 0], kernel = [3, 3], stride = [1, 1]} : (tensor<1x6x34x62xi32>)  -> (tensor<1x4x32x62xi32>)
   return
 }
+
+// -----
+
+// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d3)>
+// CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+// CHECK: #[[$MAP2:.*]] = affine_map<(d0, d1, d2, d3) -> (d3, d0, d1, d2)>
+// Unpadded f32 conv2d: expect bias broadcast, kernel transpose, then the conv.
+func @conv2d_f32(%input: tensor<1x49x42x28xf32>, %weights: tensor<28x3x3x28xf32>, %bias: tensor<28xf32>) -> () {
+  // CHECK: %[[INIT:.+]] = linalg.init_tensor [1, 45, 40, 28] : tensor<1x45x40x28xf32>
+  // CHECK: %[[BROADCAST:.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<28xf32>) outs(%[[INIT]] : tensor<1x45x40x28xf32>)
+  // CHECK: ^bb0(%arg3: f32, %arg4: f32):
+  // CHECK:   linalg.yield %arg3 : f32
+  // CHECK: %[[INITKERNEL:.+]] = linalg.init_tensor [3, 3, 28, 28]
+  // CHECK: %[[TRANSPOSEKERNEL:.+]] = linalg.generic {indexing_maps = [#[[$MAP2]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg1 : tensor<28x3x3x28xf32>) outs(%[[INITKERNEL]] : tensor<3x3x28x28xf32>)
+  // CHECK: linalg.conv_2d_input_nhwc_filter_hwcf {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %[[TRANSPOSEKERNEL]] : tensor<1x49x42x28xf32>, tensor<3x3x28x28xf32>) outs(%[[BROADCAST]] : tensor<1x45x40x28xf32>)
+  %0 = "tosa.conv2d"(%input, %weights, %bias) {pad = [0, 0, 0, 0], stride = [1, 1], dilation = [2, 1]} : (tensor<1x49x42x28xf32>, tensor<28x3x3x28xf32>, tensor<28xf32>)  -> (tensor<1x45x40x28xf32>)
+  return
+}
+// Padded f32 conv2d: expect linalg.pad_tensor on the input, then the conv.
+func @conv2d_padded_f32(%input: tensor<1x47x40x28xf32>, %weights: tensor<28x3x3x28xf32>, %bias: tensor<28xf32>) -> () {
+  // CHECK: linalg.pad_tensor %arg0
+  // CHECK: linalg.conv_2d_input_nhwc_filter_hwcf
+  %0 = "tosa.conv2d"(%input, %weights, %bias) {pad = [1, 1, 1, 1], stride = [1, 1], dilation = [2, 1]} : (tensor<1x47x40x28xf32>, tensor<28x3x3x28xf32>, tensor<28xf32>)  -> (tensor<1x45x40x28xf32>)
+  return
+}


        


More information about the Mlir-commits mailing list