[Mlir-commits] [mlir] 8d7a833 - [tosa][mlir] Add support for dynamic width/height for Conv2D inputs in tosa-to-linalg
Rob Suderman
llvmlistbot at llvm.org
Wed Mar 2 12:23:23 PST 2022
Author: natashaknk
Date: 2022-03-02T12:16:35-08:00
New Revision: 8d7a833eed1a530c260882ffc346a6711cfe96af
URL: https://github.com/llvm/llvm-project/commit/8d7a833eed1a530c260882ffc346a6711cfe96af
DIFF: https://github.com/llvm/llvm-project/commit/8d7a833eed1a530c260882ffc346a6711cfe96af.diff
LOG: [tosa][mlir] Add support for dynamic width/height for Conv2D inputs in tosa-to-linalg
Infers output shape for dynamic width/height inputs.
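For example, a tosa.conv2d with a tensor<1x?x?x27xf32> input now lowers to linalg.conv_2d_nhwc_hwcf, with the dynamic output height and width computed in arith ops from the input dimensions and the pad, dilation, and stride attributes (see the conv2d_dyn_w_h test added below).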
Reviewed By: jpienaar
Differential Revision: https://reviews.llvm.org/D119977
Added:
Modified:
mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
index 8916f1745b4a8..f2a5ffaf3e082 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
@@ -61,6 +61,39 @@ static mlir::Value applyPad(Location loc, Value input, ArrayRef<int64_t> pad,
.result();
}
+static mlir::Value reifyConstantDim(Attribute attr,
+ ImplicitLocOpBuilder &builder) {
+ return builder.createOrFold<arith::IndexCastOp>(
+ builder.getIndexType(), builder.create<arith::ConstantOp>(attr));
+}
+
+// Calculates the output width/height using the formula:
+// Out = ((initDim + padBefore + padAfter - (dilation * (kernelDim - 1) + 1)) / stride) + 1
+// H = ((IH+pad_top+pad_bottom-(dilation_y*(KH-1)+1))/stride_y)+1
+// W = ((IW+pad_left+pad_right-(dilation_x*(KW-1)+1))/stride_x)+1
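+//
+// For example, with IH = 10, pad_top = pad_bottom = 0, dilation_y = 2, KH = 3
+// and stride_y = 1: H = ((10 + 0 + 0 - (2 * (3 - 1) + 1)) / 1) + 1 = 6.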
+static mlir::Value
+getConvOutputDim(Location loc, Value initDim, Attribute padBeforeAttr,
+ Attribute padAfterAttr, Value kernelDim, Attribute strideAttr,
+ Attribute dilationAttr, Type inputETy, OpBuilder &rewriter) {
+ ImplicitLocOpBuilder builder(loc, rewriter);
+ Value one = builder.create<arith::ConstantOp>(
+ IntegerAttr::get(initDim.getType(), 1));
+ Value padBefore = reifyConstantDim(padBeforeAttr, builder);
+ Value paddedBefore = builder.create<arith::AddIOp>(initDim, padBefore);
+ Value padAfter = reifyConstantDim(padAfterAttr, builder);
+ Value paddedAfter = builder.create<arith::AddIOp>(paddedBefore, padAfter);
+
+ Value subOne = builder.create<arith::SubIOp>(kernelDim, one);
+ Value dilation = reifyConstantDim(dilationAttr, builder);
+ Value dilated = builder.create<arith::MulIOp>(dilation, subOne);
+ Value addOne = builder.create<arith::AddIOp>(dilated, one);
+
+ Value subtract = builder.create<arith::SubIOp>(paddedAfter, addOne);
+ Value stride = reifyConstantDim(strideAttr, builder);
+ Value divide = builder.create<arith::DivUIOp>(subtract, stride);
+ return builder.create<arith::AddIOp>(divide, one);
+}
+
namespace {
class ConvConverter : public OpConversionPattern<tosa::Conv2DOp> {
@@ -78,6 +111,7 @@ class ConvConverter : public OpConversionPattern<tosa::Conv2DOp> {
ShapedType weightTy = weight.getType().cast<ShapedType>();
ShapedType biasTy = bias.getType().cast<ShapedType>();
ShapedType resultTy = op->getResult(0).getType().cast<ShapedType>();
+ int64_t inputRank = inputTy.getRank();
Type inputETy = inputTy.getElementType();
Type resultETy = resultTy.getElementType();
@@ -91,16 +125,46 @@ class ConvConverter : public OpConversionPattern<tosa::Conv2DOp> {
return rewriter.notifyMatchFailure(
op, "tosa.conv ops require static shapes for weight and bias");
- auto dynamicDimsOr =
- checkHasDynamicBatchDims(rewriter, op, {input, op.output()});
- if (!dynamicDimsOr.hasValue())
- return failure();
- SmallVector<Value> dynamicDims = dynamicDimsOr.getValue();
-
if (inputETy.isUnsignedInteger())
return rewriter.notifyMatchFailure(
op, "tosa.conv ops does not support unsigned integer input");
+ SmallVector<Value> dynDims;
+ dynDims.resize(resultTy.getRank());
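+ // One entry per output dimension; entries for static dimensions remain
+ // null and are filtered out below.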
+ for (int i = 0; i < inputRank; i++) {
+ if (inputTy.isDynamicDim(i)) {
+ // Dynamic input height
+ // H = F(IH, pad_top, pad_bottom, dilation_y, KH, stride_y)
+ if (i == 1) {
+ Value initHDim =
+ rewriter.create<tensor::DimOp>(loc, input, 1).getResult();
+ Value kernelHDim =
+ rewriter.create<tensor::DimOp>(loc, weight, 1).getResult();
+ dynDims[i] = getConvOutputDim(
+ loc, initHDim, padAttr.getValue()[0], padAttr.getValue()[1],
+ kernelHDim, strideTosaAttr.getValue()[0],
+ dilationTosaAttr.getValue()[0], inputETy, rewriter);
+
+ // Dynamic input width
+ // W = F(IW, pad_left, pad_right, dilation_x, KW, stride_x)
+ } else if (i == 2) {
+ Value initWDim =
+ rewriter.create<tensor::DimOp>(loc, input, 2).getResult();
+ Value kernelWDim =
+ rewriter.create<tensor::DimOp>(loc, weight, 2).getResult();
+ dynDims[i] = getConvOutputDim(
+ loc, initWDim, padAttr.getValue()[2], padAttr.getValue()[3],
+ kernelWDim, strideTosaAttr.getValue()[1],
+ dilationTosaAttr.getValue()[1], inputETy, rewriter);
+
+ } else {
+ dynDims[i] = rewriter.create<tensor::DimOp>(loc, input, i);
+ }
+ }
+ }
+
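+ // condenseValues drops the null entries, keeping only the sizes of the
+ // dynamic output dimensions.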
+ SmallVector<Value> filteredDims = condenseValues(dynDims);
+
auto weightShape = weightTy.getShape();
// Apply padding as necessary.
@@ -148,7 +212,7 @@ class ConvConverter : public OpConversionPattern<tosa::Conv2DOp> {
Attribute resultZeroAttr = rewriter.getZeroAttr(resultETy);
Value initTensor = rewriter.create<linalg::InitTensorOp>(
- loc, dynamicDims, resultTy.getShape(), resultETy);
+ loc, filteredDims, resultTy.getShape(), resultETy);
Value zero = rewriter.create<arith::ConstantOp>(loc, resultZeroAttr);
Value zeroTensor =
rewriter.create<linalg::FillOp>(loc, zero, initTensor).getResult(0);
@@ -173,7 +237,7 @@ class ConvConverter : public OpConversionPattern<tosa::Conv2DOp> {
indexingMaps.push_back(rewriter.getMultiDimIdentityMap(resultTy.getRank()));
Value biasInitTensor = rewriter.create<linalg::InitTensorOp>(
- loc, dynamicDims, resultTy.getShape(), resultETy);
+ loc, filteredDims, resultTy.getShape(), resultETy);
if (isQuantized) {
auto quantizationInfo =
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
index f04759c9a5a6e..776dd54ed64b6 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
@@ -383,6 +383,66 @@ func @conv2d_dyn(%input: tensor<?x49x42x27xf32>, %weights: tensor<28x3x3x27xf32>
// -----
+// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d3)>
+// CHECK: #[[$MAP2:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+
+// CHECK-LABEL: @conv2d_dyn_w_h
+func @conv2d_dyn_w_h(%input: tensor<1x?x?x27xf32>, %weights: tensor<28x3x3x27xf32>, %bias: tensor<28xf32>) -> () {
+ // Computing output height
+ // CHECK: %[[C1:.+]] = arith.constant 1
+ // CHECK: %[[H:.+]] = tensor.dim %arg0, %[[C1]]
+ // CHECK: %[[C1_0:.+]] = arith.constant 1
+ // CHECK: %[[KH:.+]] = tensor.dim %arg1, %[[C1_0]]
+ // CHECK: %[[ONE:.+]] = arith.constant 1 : index
+ // CHECK: %[[PAD_0:.+]] = arith.constant 0 : index
+ // CHECK: %[[ADD_PAD_0:.+]] = arith.addi %[[H]], %[[PAD_0]] : index
+ // CHECK: %[[PAD_1:.+]] = arith.constant 0 : index
+ // CHECK: %[[ADD_PAD_1:.+]] = arith.addi %[[ADD_PAD_0]], %[[PAD_1]] : index
+ // CHECK: %[[SUB_ONE:.+]] = arith.subi %[[KH]], %[[ONE]] : index
+ // CHECK: %[[DIL_H:.+]] = arith.constant 2 : index
+ // CHECK: %[[DILATED:.+]] = arith.muli %[[DIL_H]], %[[SUB_ONE]] : index
+ // CHECK: %[[ADD_ONE:.+]] = arith.addi %[[DILATED]], %[[ONE]] : index
+ // CHECK: %[[SUBTRACTED:.+]] = arith.subi %[[ADD_PAD_1]], %[[ADD_ONE]] : index
+ // CHECK: %[[STRIDE_H:.+]] = arith.constant 1 : index
+ // CHECK: %[[DIVIDED:.+]] = arith.divui %[[SUBTRACTED]], %[[STRIDE_H]] : index
+ // CHECK: %[[H_OUT:.+]] = arith.addi %[[DIVIDED]], %[[ONE]] : index
+
+ // Computing output width
+ // CHECK: %[[C2:.+]] = arith.constant 2
+ // CHECK: %[[W:.+]] = tensor.dim %arg0, %[[C2]]
+ // CHECK: %[[C2_0:.+]] = arith.constant 2
+ // CHECK: %[[KW:.+]] = tensor.dim %arg1, %[[C2_0]]
+ // CHECK: %[[ONE_0:.+]] = arith.constant 1 : index
+ // CHECK: %[[PAD_2:.+]] = arith.constant 0 : index
+ // CHECK: %[[ADD_PAD_2:.+]] = arith.addi %[[W]], %[[PAD_2]] : index
+ // CHECK: %[[PAD_3:.+]] = arith.constant 0 : index
+ // CHECK: %[[ADD_PAD_3:.+]] = arith.addi %[[ADD_PAD_2]], %[[PAD_3]] : index
+ // CHECK: %[[SUB_ONE_0:.+]] = arith.subi %[[KW]], %[[ONE_0]] : index
+ // CHECK: %[[DIL_W:.+]] = arith.constant 1 : index
+ // CHECK: %[[DILATED_0:.+]] = arith.muli %[[DIL_W]], %[[SUB_ONE_0]] : index
+ // CHECK: %[[ADD_ONE_0:.+]] = arith.addi %[[DILATED_0]], %[[ONE_0]] : index
+ // CHECK: %[[SUBTRACTED_0:.+]] = arith.subi %[[ADD_PAD_3]], %[[ADD_ONE_0]] : index
+ // CHECK: %[[STRIDE_W:.+]] = arith.constant 1 : index
+ // CHECK: %[[DIVIDED_0:.+]] = arith.divui %[[SUBTRACTED_0]], %[[STRIDE_W]] : index
+ // CHECK: %[[W_OUT:.+]] = arith.addi %[[DIVIDED_0]], %[[ONE_0]] : index
+
+ // Running convolution
+ // CHECK: %[[PERM:.+]] = arith.constant dense<[1, 2, 3, 0]>
+ // CHECK: %[[WEIGHT:.+]] = "tosa.transpose"(%arg1, %[[PERM]])
+ // CHECK: %[[M_IN:.+]] = linalg.init_tensor [1, %[[H_OUT]], %[[W_OUT]], 28]
+ // CHECK: %[[CST:.+]] = arith.constant 0
+ // CHECK: %[[FILL:.+]] = linalg.fill
+ // CHECK: %[[B_IN:.+]] = linalg.init_tensor [1, %[[H_OUT]], %[[W_OUT]], 28]
+ // CHECK: %[[CONV:.+]] = linalg.conv_2d_nhwc_hwcf {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %[[WEIGHT]] : tensor<1x?x?x27xf32>, tensor<3x3x27x28xf32>) outs(%[[FILL]] : tensor<1x?x?x28xf32>)
+ // CHECK: %[[B:.+]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]], #[[$MAP2]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2, %[[CONV]] : tensor<28xf32>, tensor<1x?x?x28xf32>) outs(%[[B_IN]] : tensor<1x?x?x28xf32>)
+ // CHECK: %[[ADD:.+]] = arith.addf
+ // CHECK: linalg.yield %[[ADD]] : f32
+ %0 = "tosa.conv2d"(%input, %weights, %bias) {pad = [0, 0, 0, 0], stride = [1, 1], dilation = [2, 1]} : (tensor<1x?x?x27xf32>, tensor<28x3x3x27xf32>, tensor<28xf32>) -> (tensor<1x?x?x28xf32>)
+ return
+}
+
+// -----
+
// CHECK-LABEL: @conv2d_padded_f32
func @conv2d_padded_f32(%input: tensor<1x47x40x28xf32>, %weights: tensor<28x3x3x28xf32>, %bias: tensor<28xf32>) -> () {
// CHECK: %[[C0:.+]] = arith.constant 0