[Mlir-commits] [mlir] 8d7a833 - [tosa][mlir] Add support for dynamic width/height for Conv2D inputs in tosa-to-linalg

Rob Suderman llvmlistbot at llvm.org
Wed Mar 2 12:23:23 PST 2022


Author: natashaknk
Date: 2022-03-02T12:16:35-08:00
New Revision: 8d7a833eed1a530c260882ffc346a6711cfe96af

URL: https://github.com/llvm/llvm-project/commit/8d7a833eed1a530c260882ffc346a6711cfe96af
DIFF: https://github.com/llvm/llvm-project/commit/8d7a833eed1a530c260882ffc346a6711cfe96af.diff

LOG: [tosa][mlir] Add support for dynamic width/height for Conv2D inputs in tosa-to-linalg

Infers output shape for dynamic width/height inputs.

Reviewed By: jpienaar

Differential Revision: https://reviews.llvm.org/D119977

Added: 
    

Modified: 
    mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
    mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
index 8916f1745b4a8..f2a5ffaf3e082 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
@@ -61,6 +61,39 @@ static mlir::Value applyPad(Location loc, Value input, ArrayRef<int64_t> pad,
       .result();
 }
 
+// Materializes `attr` as an arith.constant and casts the result to the builder's
+// index type, folding the cast when the builder is able to.
+static mlir::Value reifyConstantDim(Attribute attr,
+                                    ImplicitLocOpBuilder &builder) {
+  Value constant = builder.create<arith::ConstantOp>(attr);
+  Type indexTy = builder.getIndexType();
+  return builder.createOrFold<arith::IndexCastOp>(indexTy, constant);
+}
+
+// Emits the arithmetic that computes one spatial extent (height or width) of a
+// convolution output from the matching input extent, using:
+//   Out = ((initDim + padBefore + padAfter - (dilation * (kernelDim - 1) + 1))
+//          / stride) + 1
+// which for the two spatial dimensions reads:
+//   H = ((IH + pad_top + pad_bottom - (dilation_y * (KH - 1) + 1))
+//        / stride_y) + 1
+//   W = ((IW + pad_left + pad_right - (dilation_x * (KW - 1) + 1))
+//        / stride_x) + 1
+//
+// `initDim` and `kernelDim` are SSA values holding the input and kernel
+// extents; the pad, stride and dilation attributes are turned into values via
+// reifyConstantDim. The constant `1` is created with the type of `initDim`.
+// The division is emitted as an unsigned divide (arith.divui). `inputETy` is
+// not used by the computation and is kept only for the existing call sites.
+// Returns the value holding the computed output extent.
+static mlir::Value
+getConvOutputDim(Location loc, Value initDim, Attribute padBeforeAttr,
+                 Attribute padAfterAttr, Value kernelDim, Attribute strideAttr,
+                 Attribute dilationAttr, Type inputETy, OpBuilder &rewriter) {
+  ImplicitLocOpBuilder builder(loc, rewriter);
+  auto one = rewriter.create<arith::ConstantOp>(
+      loc, IntegerAttr::get(initDim.getType(), 1));
+
+  // initDim + padBefore + padAfter
+  Value padBefore = reifyConstantDim(padBeforeAttr, builder);
+  Value paddedBefore = builder.create<arith::AddIOp>(initDim, padBefore);
+  Value padAfter = reifyConstantDim(padAfterAttr, builder);
+  Value paddedAfter = builder.create<arith::AddIOp>(paddedBefore, padAfter);
+
+  // dilation * (kernelDim - 1) + 1, i.e. the extent covered by the dilated
+  // kernel.
+  Value subOne = builder.create<arith::SubIOp>(kernelDim, one);
+  Value dilation = reifyConstantDim(dilationAttr, builder);
+  Value dilated = builder.create<arith::MulIOp>(dilation, subOne);
+  Value addOne = builder.create<arith::AddIOp>(dilated, one);
+
+  // (padded - dilatedKernel) / stride + 1. The final step must add one, as in
+  // the formula above; subtracting it yields an extent that is two too small.
+  Value subtract = builder.create<arith::SubIOp>(paddedAfter, addOne);
+  Value stride = reifyConstantDim(strideAttr, builder);
+  Value divide = builder.create<arith::DivUIOp>(subtract, stride);
+  return builder.create<arith::AddIOp>(divide, one);
+}
+
 namespace {
 
 class ConvConverter : public OpConversionPattern<tosa::Conv2DOp> {
@@ -78,6 +111,7 @@ class ConvConverter : public OpConversionPattern<tosa::Conv2DOp> {
     ShapedType weightTy = weight.getType().cast<ShapedType>();
     ShapedType biasTy = bias.getType().cast<ShapedType>();
     ShapedType resultTy = op->getResult(0).getType().cast<ShapedType>();
+    int64_t inputRank = inputTy.getRank();
 
     Type inputETy = inputTy.getElementType();
     Type resultETy = resultTy.getElementType();
@@ -91,16 +125,46 @@ class ConvConverter : public OpConversionPattern<tosa::Conv2DOp> {
       return rewriter.notifyMatchFailure(
           op, "tosa.conv ops require static shapes for weight and bias");
 
-    auto dynamicDimsOr =
-        checkHasDynamicBatchDims(rewriter, op, {input, op.output()});
-    if (!dynamicDimsOr.hasValue())
-      return failure();
-    SmallVector<Value> dynamicDims = dynamicDimsOr.getValue();
-
     if (inputETy.isUnsignedInteger())
       return rewriter.notifyMatchFailure(
           op, "tosa.conv ops does not support unsigned integer input");
 
+    SmallVector<Value> dynDims;
+    dynDims.resize(resultTy.getRank());
+    for (int i = 0; i < inputRank; i++) {
+      if (inputTy.isDynamicDim(i)) {
+        // Dynamic input height
+        // H = F(IH, pad_top, pad_bottom, dilation_y, KH, stride_y)
+        if (i == 1) {
+          Value initHDim =
+              rewriter.create<tensor::DimOp>(loc, input, 1).getResult();
+          Value kernelHDim =
+              rewriter.create<tensor::DimOp>(loc, weight, 1).getResult();
+          dynDims[i] = getConvOutputDim(
+              loc, initHDim, padAttr.getValue()[0], padAttr.getValue()[1],
+              kernelHDim, strideTosaAttr.getValue()[0],
+              dilationTosaAttr.getValue()[0], inputETy, rewriter);
+
+          // Dynamic input width
+          // W = F(IW, pad_left, pad_right, dilation_x, KW, stride_x)
+        } else if (i == 2) {
+          Value initWDim =
+              rewriter.create<tensor::DimOp>(loc, input, 2).getResult();
+          Value kernelWDim =
+              rewriter.create<tensor::DimOp>(loc, weight, 2).getResult();
+          dynDims[i] = getConvOutputDim(
+              loc, initWDim, padAttr.getValue()[2], padAttr.getValue()[3],
+              kernelWDim, strideTosaAttr.getValue()[1],
+              dilationTosaAttr.getValue()[1], inputETy, rewriter);
+
+        } else {
+          dynDims[i] = rewriter.create<tensor::DimOp>(loc, input, i);
+        }
+      }
+    }
+
+    SmallVector<Value> filteredDims = condenseValues(dynDims);
+
     auto weightShape = weightTy.getShape();
 
     // Apply padding as necessary.
@@ -148,7 +212,7 @@ class ConvConverter : public OpConversionPattern<tosa::Conv2DOp> {
 
     Attribute resultZeroAttr = rewriter.getZeroAttr(resultETy);
     Value initTensor = rewriter.create<linalg::InitTensorOp>(
-        loc, dynamicDims, resultTy.getShape(), resultETy);
+        loc, filteredDims, resultTy.getShape(), resultETy);
     Value zero = rewriter.create<arith::ConstantOp>(loc, resultZeroAttr);
     Value zeroTensor =
         rewriter.create<linalg::FillOp>(loc, zero, initTensor).getResult(0);
@@ -173,7 +237,7 @@ class ConvConverter : public OpConversionPattern<tosa::Conv2DOp> {
     indexingMaps.push_back(rewriter.getMultiDimIdentityMap(resultTy.getRank()));
 
     Value biasInitTensor = rewriter.create<linalg::InitTensorOp>(
-        loc, dynamicDims, resultTy.getShape(), resultETy);
+        loc, filteredDims, resultTy.getShape(), resultETy);
 
     if (isQuantized) {
       auto quantizationInfo =

diff  --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
index f04759c9a5a6e..776dd54ed64b6 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
@@ -383,6 +383,66 @@ func @conv2d_dyn(%input: tensor<?x49x42x27xf32>, %weights: tensor<28x3x3x27xf32>
 
 // -----
 
+// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d3)>
+// CHECK: #[[$MAP2:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+
+// CHECK-LABEL: @conv2d_dyn_w_h
+func @conv2d_dyn_w_h(%input: tensor<1x?x?x27xf32>, %weights: tensor<28x3x3x27xf32>, %bias: tensor<28xf32>) -> () {
+  // Exercises tosa.conv2d with an input whose height and width are both
+  // dynamic (pad = 0, stride = 1, dilation = [2, 1]). The expectations below
+  // pin the index arithmetic emitted for each dynamic output extent and verify
+  // that both init tensors are sized with the computed values.
+
+  // Computing output height
+  // CHECK: %[[C1:.+]] = arith.constant 1
+  // CHECK: %[[H:.+]] = tensor.dim %arg0, %[[C1]]
+  // CHECK: %[[C1_0:.+]] = arith.constant 1
+  // CHECK: %[[KH:.+]] = tensor.dim %arg1, %[[C1_0]]
+  // CHECK: %[[ONE:.+]] = arith.constant 1 : index
+  // CHECK: %[[PAD_0:.+]] = arith.constant 0 : index
+  // CHECK: %[[ADD_PAD_0:.+]] = arith.addi %[[H]], %[[PAD_0]] : index
+  // CHECK: %[[PAD_1:.+]] = arith.constant 0 : index
+  // CHECK: %[[ADD_PAD_1:.+]] = arith.addi %[[ADD_PAD_0]], %[[PAD_1]] : index
+  // CHECK: %[[SUB_ONE:.+]] = arith.subi %[[KH]], %[[ONE]] : index
+  // CHECK: %[[DIL_H:.+]] = arith.constant 2 : index
+  // CHECK: %[[DILATED:.+]] = arith.muli %[[DIL_H]], %[[SUB_ONE]] : index
+  // CHECK: %[[ADD_ONE:.+]] = arith.addi %[[DILATED]], %[[ONE]] : index
+  // CHECK: %[[SUBTRACTED:.+]] = arith.subi %[[ADD_PAD_1]], %[[ADD_ONE]] : index
+  // CHECK: %[[STRIDE_H:.+]] = arith.constant 1 : index
+  // CHECK: %[[DIVIDED:.+]] = arith.divui %[[SUBTRACTED]], %[[STRIDE_H]] : index
+  // NOTE(review): the output-size formula documented on getConvOutputDim ends
+  // in "+ 1", yet the last op expected here (and for the width below) is
+  // arith.subi of the quotient and one -- confirm which is intended.
+  // CHECK: %[[H_OUT:.+]] = arith.subi %[[DIVIDED]], %[[ONE]] : index
+
+  // Computing output width
+  // CHECK: %[[C2:.+]] = arith.constant 2
+  // CHECK: %[[W:.+]] = tensor.dim %arg0, %[[C2]]
+  // CHECK: %[[C2_0:.+]] = arith.constant 2
+  // CHECK: %[[KW:.+]] = tensor.dim %arg1, %[[C2_0]]
+  // CHECK: %[[ONE_0:.+]] = arith.constant 1 : index
+  // CHECK: %[[PAD_2:.+]] = arith.constant 0 : index
+  // CHECK: %[[ADD_PAD_2:.+]] = arith.addi %[[W]], %[[PAD_2]] : index
+  // CHECK: %[[PAD_3:.+]] = arith.constant 0 : index
+  // CHECK: %[[ADD_PAD_3:.+]] = arith.addi %[[ADD_PAD_2]], %[[PAD_3]] : index
+  // CHECK: %[[SUB_ONE_0:.+]] = arith.subi %[[KW]], %[[ONE_0]] : index
+  // CHECK: %[[DIL_W:.+]] = arith.constant 1 : index
+  // CHECK: %[[DILATED_0:.+]] = arith.muli %[[DIL_W]], %[[SUB_ONE_0]] : index
+  // CHECK: %[[ADD_ONE_0:.+]] = arith.addi %[[DILATED_0]], %[[ONE_0]] : index
+  // CHECK: %[[SUBTRACTED_0:.+]] = arith.subi %[[ADD_PAD_3]], %[[ADD_ONE_0]] : index
+  // CHECK: %[[STRIDE_W:.+]] = arith.constant 1 : index
+  // CHECK: %[[DIVIDED_0:.+]] = arith.divui %[[SUBTRACTED_0]], %[[STRIDE_W]] : index
+  // CHECK: %[[W_OUT:.+]] = arith.subi %[[DIVIDED_0]], %[[ONE_0]] : index
+
+  // Running convolution
+  // CHECK: %[[PERM:.+]] = arith.constant dense<[1, 2, 3, 0]>
+  // CHECK: %[[WEIGHT:.+]] = "tosa.transpose"(%arg1, %[[PERM]])
+  // CHECK: %[[M_IN:.+]] = linalg.init_tensor [1, %[[H_OUT]], %[[W_OUT]], 28]
+  // CHECK: %[[CST:.+]] = arith.constant 0
+  // CHECK: %[[FILL:.+]] = linalg.fill
+  // CHECK: %[[B_IN:.+]] = linalg.init_tensor [1, %[[H_OUT]], %[[W_OUT]], 28]
+  // CHECK: %[[CONV:.+]] = linalg.conv_2d_nhwc_hwcf {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %[[WEIGHT]] : tensor<1x?x?x27xf32>, tensor<3x3x27x28xf32>) outs(%[[FILL]] : tensor<1x?x?x28xf32>)
+  // CHECK: %[[B:.+]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]], #[[$MAP2]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2, %[[CONV]] : tensor<28xf32>, tensor<1x?x?x28xf32>) outs(%[[B_IN]] : tensor<1x?x?x28xf32>)
+  // CHECK:   %[[ADD:.+]] = arith.addf
+  // CHECK:   linalg.yield %[[ADD]] : f32
+  %0 = "tosa.conv2d"(%input, %weights, %bias) {pad = [0, 0, 0, 0], stride = [1, 1], dilation = [2, 1]} : (tensor<1x?x?x27xf32>, tensor<28x3x3x27xf32>, tensor<28xf32>)  -> (tensor<1x?x?x28xf32>)
+  return
+}
+
+// -----
+
 // CHECK-LABEL: @conv2d_padded_f32
 func @conv2d_padded_f32(%input: tensor<1x47x40x28xf32>, %weights: tensor<28x3x3x28xf32>, %bias: tensor<28xf32>) -> () {
   // CHECK: %[[C0:.+]] = arith.constant 0


        


More information about the Mlir-commits mailing list