[Mlir-commits] [mlir] fac9f45 - [tosa][mlir] Add dynamic width/height support for depthwise convolution in tosa-to-linalg
Rob Suderman
llvmlistbot at llvm.org
Thu Apr 7 10:51:21 PDT 2022
Author: natashaknk
Date: 2022-04-07T10:50:06-07:00
New Revision: fac9f45e050065ca419dc222bec413cb9fe92a57
URL: https://github.com/llvm/llvm-project/commit/fac9f45e050065ca419dc222bec413cb9fe92a57
DIFF: https://github.com/llvm/llvm-project/commit/fac9f45e050065ca419dc222bec413cb9fe92a57.diff
LOG: [tosa][mlir] Add dynamic width/height support for depthwise convolution in tosa-to-linalg
In addition, fixed a small bug with padding incorrectly inferring the output shape for dynamic inputs in convolution
Reviewed By: jpienaar
Differential Revision: https://reviews.llvm.org/D121872
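
For reference (not part of the patch): the output-size rule that getConvOutputDim materializes as arith ops, written below as plain integer arithmetic and checked against the existing static depthwise test in tosa-to-linalg-named.mlir (1x7x5x3 input, 3x1x3x11 weights, no padding, unit stride and dilation -> 1x5x5x3x11). A minimal illustrative C++ sketch:

#include <cassert>
#include <cstdint>

// Out = ((In + padBefore + padAfter - (dilation * (Kernel - 1) + 1)) / stride) + 1
static int64_t convOutDim(int64_t in, int64_t padBefore, int64_t padAfter,
                          int64_t kernel, int64_t stride, int64_t dilation) {
  return (in + padBefore + padAfter - (dilation * (kernel - 1) + 1)) / stride + 1;
}

int main() {
  assert(convOutDim(7, 0, 0, 3, 1, 1) == 5); // output height of the static test
  assert(convOutDim(5, 0, 0, 1, 1, 1) == 5); // output width of the static test
  return 0;
}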
Added:
Modified:
mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
index f5fe8f8389308..4c3547a47eb2e 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
@@ -48,7 +48,10 @@ static mlir::Value applyPad(Location loc, Value input, ArrayRef<int64_t> pad,
for (int i = 0, s = inputShape.size(); i < s; i++) {
auto lowPad = pad[i * 2];
auto highPad = pad[i * 2 + 1];
- paddedShape.push_back(inputShape[i] + highPad + lowPad);
+ if (ShapedType::isDynamic(inputShape[i]))
+ paddedShape.push_back(inputShape[i]);
+ else
+ paddedShape.push_back(inputShape[i] + highPad + lowPad);
lowIndices.push_back(rewriter.getIndexAttr(lowPad));
highIndices.push_back(rewriter.getIndexAttr(highPad));
}
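
The hunk above is the padding fix mentioned in the log: when a dimension is dynamic, folding the pad amounts into the dynamic-size sentinel would produce a bogus static extent, so the padded result type keeps the dimension dynamic. A standalone sketch of the same guard (illustrative only, the helper name is hypothetical):

#include "mlir/IR/BuiltinTypes.h"

// Keep dynamic dims dynamic; only fold pad amounts into static extents.
static int64_t paddedExtent(int64_t dim, int64_t lowPad, int64_t highPad) {
  if (mlir::ShapedType::isDynamic(dim))
    return dim; // the pad op still records low/high pad as index attrs
  return dim + lowPad + highPad;
}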
@@ -68,7 +71,6 @@ static mlir::Value reifyConstantDim(Attribute attr,
}
// Calculating the output width/height using the formula:
-// Out =((initDim+padBefore+padAttr-(dilation*(kernelDim-1)+1))/stride+1
// H = ((IH+pad_top+pad_bottom-(dilation_y*(KH-1)+1))/stride_y)+1
// W = ((IW+pad_left+pad_right-(dilation_x*(KW-1)+1))/stride_x)+1
static mlir::Value
@@ -94,6 +96,54 @@ getConvOutputDim(Location loc, Value initDim, Attribute padBeforeAttr,
return builder.create<arith::SubIOp>(divide, one);
}
+// Creates a vector of the dynamic output dims for Conv2D and Depthwise_Conv2D
+static SmallVector<Value> inferDynamicDimsForConv(
+ Location loc, Value input, Value weight, ShapedType resultTy,
+ ArrayAttr padAttr, ArrayAttr strideAttr, ArrayAttr dilationAttr,
+ int64_t weightHDim, int64_t weightWDim, OpBuilder &rewriter) {
+ ShapedType inputTy = input.getType().cast<ShapedType>();
+ Type inputETy = inputTy.getElementType();
+ int64_t inputRank = inputTy.getRank();
+ int64_t heightDim = 1;
+ int64_t weightDim = 2;
+
+ SmallVector<Value> dynDims;
+ dynDims.resize(resultTy.getRank());
+ for (int i = 0; i < inputRank; i++) {
+ if (inputTy.isDynamicDim(i) && i != heightDim && i != weightDim)
+ dynDims[i] = rewriter.create<tensor::DimOp>(loc, input, i);
+ }
+
+ // Dynamic input height
+ if (inputTy.isDynamicDim(heightDim)) {
+ Value initHDim =
+ rewriter.create<tensor::DimOp>(loc, input, heightDim).getResult();
+ Value kernelHDim =
+ rewriter.create<tensor::DimOp>(loc, weight, weightHDim).getResult();
+ // H = F(IH, pad_top, pad_bottom, dilation_y, KH, stride_y)
+ dynDims[heightDim] = getConvOutputDim(
+ loc, initHDim, padAttr.getValue()[0], padAttr.getValue()[1], kernelHDim,
+ strideAttr.getValue()[0], dilationAttr.getValue()[0], inputETy,
+ rewriter);
+ }
+
+ // Dynamic input weight
+ if (inputTy.isDynamicDim(weightDim)) {
+ Value initWDim =
+ rewriter.create<tensor::DimOp>(loc, input, weightDim).getResult();
+ Value kernelWDim =
+ rewriter.create<tensor::DimOp>(loc, weight, weightWDim).getResult();
+ // W = F(IW, pad_left, pad_right, dilation_x, KW, stride_x)
+ dynDims[weightDim] = getConvOutputDim(
+ loc, initWDim, padAttr.getValue()[2], padAttr.getValue()[3], kernelWDim,
+ strideAttr.getValue()[1], dilationAttr.getValue()[1], inputETy,
+ rewriter);
+ }
+
+ SmallVector<Value> filteredDims = condenseValues(dynDims);
+ return filteredDims;
+}
+
namespace {
class ConvConverter : public OpConversionPattern<tosa::Conv2DOp> {
@@ -111,7 +161,6 @@ class ConvConverter : public OpConversionPattern<tosa::Conv2DOp> {
ShapedType weightTy = weight.getType().cast<ShapedType>();
ShapedType biasTy = bias.getType().cast<ShapedType>();
ShapedType resultTy = op->getResult(0).getType().cast<ShapedType>();
- int64_t inputRank = inputTy.getRank();
Type inputETy = inputTy.getElementType();
Type resultETy = resultTy.getElementType();
@@ -129,41 +178,9 @@ class ConvConverter : public OpConversionPattern<tosa::Conv2DOp> {
return rewriter.notifyMatchFailure(
op, "tosa.conv ops does not support unsigned integer input");
- SmallVector<Value> dynDims;
- dynDims.resize(resultTy.getRank());
- for (int i = 0; i < inputRank; i++) {
- if (inputTy.isDynamicDim(i)) {
- // Dynamic input height
- // H = F(IH, pad_top, pad_bottom, dilation_y, KH, sride_y)
- if (i == 1) {
- Value initHDim =
- rewriter.create<tensor::DimOp>(loc, input, 1).getResult();
- Value kernelHDim =
- rewriter.create<tensor::DimOp>(loc, weight, 1).getResult();
- dynDims[i] = getConvOutputDim(
- loc, initHDim, padAttr.getValue()[0], padAttr.getValue()[1],
- kernelHDim, strideTosaAttr.getValue()[0],
- dilationTosaAttr.getValue()[0], inputETy, rewriter);
-
- // Dynamic input weight
- // W = F(IH, pad_left, pad_right, dilation_x, KW, sride_x)
- } else if (i == 2) {
- Value initWDim =
- rewriter.create<tensor::DimOp>(loc, input, 2).getResult();
- Value kernelWDim =
- rewriter.create<tensor::DimOp>(loc, weight, 2).getResult();
- dynDims[i] = getConvOutputDim(
- loc, initWDim, padAttr.getValue()[2], padAttr.getValue()[3],
- kernelWDim, strideTosaAttr.getValue()[1],
- dilationTosaAttr.getValue()[1], inputETy, rewriter);
-
- } else {
- dynDims[i] = rewriter.create<tensor::DimOp>(loc, input, i);
- }
- }
- }
-
- SmallVector<Value> filteredDims = condenseValues(dynDims);
+ SmallVector<Value> filteredDims = inferDynamicDimsForConv(
+ loc, input, weight, resultTy, padAttr, strideTosaAttr, dilationTosaAttr,
+ /*weightHDim=*/1, /*weightWDim=*/2, rewriter);
auto weightShape = weightTy.getShape();
@@ -322,6 +339,15 @@ class DepthwiseConvConverter
auto strideTosaAttr = op->getAttr("stride").cast<ArrayAttr>();
auto dilationTosaAttr = op->getAttr("dilation").cast<ArrayAttr>();
+ if (!weightTy.hasStaticShape() || !biasTy.hasStaticShape())
+ return rewriter.notifyMatchFailure(
+ op, "tosa.depthwise_conv ops require static shapes");
+
+ // Compute output dynamic dims
+ SmallVector<Value> filteredDims = inferDynamicDimsForConv(
+ loc, input, weight, resultTy, padAttr, strideTosaAttr, dilationTosaAttr,
+ 0, 1, rewriter);
+
bool isQuantized = op->hasAttr("quantization_info");
IntegerAttr iZp;
IntegerAttr kZp;
@@ -334,16 +360,6 @@ class DepthwiseConvConverter
quantizationInfo.weight_zp().getValue().getSExtValue());
}
- if (!weightTy.hasStaticShape() || !biasTy.hasStaticShape())
- return rewriter.notifyMatchFailure(
- op, "tosa.depthwise_conv ops require static shapes");
-
- auto dynamicDimsOr =
- checkHasDynamicBatchDims(rewriter, op, {input, op.output()});
- if (!dynamicDimsOr.hasValue())
- return failure();
- SmallVector<Value> dynamicDims = dynamicDimsOr.getValue();
-
auto weightShape = weightTy.getShape();
auto resultShape = resultTy.getShape();
@@ -401,7 +417,7 @@ class DepthwiseConvConverter
Attribute resultZeroAttr = rewriter.getZeroAttr(resultETy);
Value initTensor = rewriter.create<linalg::InitTensorOp>(
- loc, dynamicDims, linalgConvTy.getShape(), resultETy);
+ loc, filteredDims, linalgConvTy.getShape(), resultETy);
Value zero = rewriter.create<arith::ConstantOp>(loc, resultZeroAttr);
Value zeroTensor = rewriter
.create<linalg::FillOp>(loc, ValueRange{zero},
@@ -409,7 +425,7 @@ class DepthwiseConvConverter
.result();
Value biasInitTensor = rewriter.create<linalg::InitTensorOp>(
- loc, dynamicDims, resultTy.getShape(), resultETy);
+ loc, filteredDims, resultTy.getShape(), resultETy);
if (!isQuantized) {
Value conv = rewriter
.create<linalg::DepthwiseConv2DNhwcHwcmOp>(
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
index d769c5e96250e..91e74358e4d3d 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
@@ -479,7 +479,7 @@ func @depthwise_conv(%arg0 : tensor<1x7x5x3xf32>, %arg1 : tensor<3x1x3x11xf32>,
// CHECK: [[DEPTH:%.+]] = linalg.depthwise_conv_2d_nhwc_hwcm {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<1x7x5x3xf32>, tensor<3x1x3x11xf32>) outs([[FILL]] : tensor<1x5x5x3x11xf32>)
// CHECK: [[COLLAPSED:%.+]] = "tosa.reshape"([[DEPTH]]) {new_shape = [1, 5, 5, 33]}
// CHECK: [[BIAS:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2, [[COLLAPSED]] : tensor<33xf32>, tensor<1x5x5x33xf32>) outs([[OUT]] : tensor<1x5x5x33xf32>) {
- // CHECK: ^bb0(%arg3: f32, %arg4: f32, %arg5: f32):
+ // CHECK: ^bb0(%arg3: f32, %arg4: f32, %arg5: f32):
// CHECK: [[ADD:%.+]] = arith.addf %arg3, %arg4 : f32
// CHECK: linalg.yield [[ADD]] : f32
// CHECK: } -> tensor<1x5x5x33xf32>
@@ -503,7 +503,7 @@ func @depthwise_conv_dyn(%arg0 : tensor<?x7x5x3xf32>, %arg1 : tensor<3x1x3x11xf3
// CHECK: %[[DEPTH:.+]] = linalg.depthwise_conv_2d_nhwc_hwcm {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<?x7x5x3xf32>, tensor<3x1x3x11xf32>) outs(%[[FILL]] : tensor<?x5x5x3x11xf32>)
// CHECK: %[[COLLAPSED:.+]] = "tosa.reshape"(%[[DEPTH]]) {new_shape = [-1, 5, 5, 33]}
// CHECK: %[[BIAS:.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2, %[[COLLAPSED]] : tensor<33xf32>, tensor<?x5x5x33xf32>) outs(%[[OUT]] : tensor<?x5x5x33xf32>) {
- // CHECK: ^bb0(%arg3: f32, %arg4: f32, %arg5: f32):
+ // CHECK: ^bb0(%arg3: f32, %arg4: f32, %arg5: f32):
// CHECK: %[[ADD:.+]] = arith.addf %arg3, %arg4 : f32
// CHECK: linalg.yield %[[ADD]] : f32
// CHECK: } -> tensor<?x5x5x33xf32>
@@ -584,3 +584,19 @@ func @depthwise_conv_quant_dilations(%arg0 : tensor<1x14x14x4xi8>, %arg1 : tenso
%0 = "tosa.depthwise_conv2d"(%arg0, %arg1, %arg2) {pad = [0, 0, 0, 0], quantization_info = {input_zp = -128 : i32, weight_zp = 42 : i32}, stride = [1, 1], dilation = [2, 2] } : (tensor<1x14x14x4xi8>, tensor<3x3x4x128xi8>, tensor<512xi32>) -> tensor<1x10x10x512xi32>
return
}
+
+// CHECK-LABEL: @depthwise_conv2d_dyn_w_h
+func @depthwise_conv2d_dyn_w_h(%arg0: tensor<2x?x?x3xf32>, %arg1: tensor<3x6x3x5xf32>, %arg2: tensor<15xf32>) {
+ // CHECK: arith.addi
+ // CHECK: arith.subi
+ // CHECK: arith.muli
+ // CHECK: arith.divui
+ // CHECK: %[[PADDED:.+]] = tensor.pad %arg0 low[0, 1, 3, 0] high[0, 2, 4, 0] {
+ // CHECK: ^bb0(%arg3: index, %arg4: index, %arg5: index, %arg6: index):
+ // CHECK: tensor.yield %cst : f32
+ // CHECK: } : tensor<2x?x?x3xf32> to tensor<2x?x?x3xf32>
+ // CHECK: %[[CONV:.+]] = linalg.depthwise_conv_2d_nhwc_hwcm {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<[1, 2]> : tensor<2xi64>} ins(%[[PADDED]], %arg1 : tensor<2x?x?x3xf32>, tensor<3x6x3x5xf32>) outs(%22 : tensor<2x?x?x3x5xf32>) -> tensor<2x?x?x3x5xf32>
+ // CHECK: %[[RESHAPED:.+]] = "tosa.reshape"(%[[CONV]]) {new_shape = [2, -1, -1, 15]} : (tensor<2x?x?x3x5xf32>) -> tensor<2x?x?x15xf32>
+ %0 = "tosa.depthwise_conv2d"(%arg0, %arg1, %arg2) {pad = [1, 2, 3, 4], dilation = [2, 1], stride = [1, 2]} : (tensor<2x?x?x3xf32>, tensor<3x6x3x5xf32>, tensor<15xf32>) -> tensor<2x?x?x15xf32>
+ return
+}
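
As a sanity check on the new test (not part of the patch): plugging its attributes (pad = [1, 2, 3, 4], 3x6 kernel, dilation = [2, 1], stride = [1, 2]) into the same output-size rule, assuming the dynamic height and width both happen to be 10 at runtime:

#include <cassert>

int main() {
  // H = ((10 + 1 + 2 - (2 * (3 - 1) + 1)) / 1) + 1 = 9
  assert(((10 + 1 + 2 - (2 * (3 - 1) + 1)) / 1) + 1 == 9);
  // W = ((10 + 3 + 4 - (1 * (6 - 1) + 1)) / 2) + 1 = 6  (the lowering uses unsigned division)
  assert(((10 + 3 + 4 - (1 * (6 - 1) + 1)) / 2) + 1 == 6);
  return 0;
}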