[Mlir-commits] [mlir] 87a6ccf - [mlir][tosa] Add dynamic width/height for pooling in tosa to linalg

Rob Suderman llvmlistbot at llvm.org
Thu Sep 8 10:05:32 PDT 2022


Author: natashaknk
Date: 2022-09-08T09:50:09-07:00
New Revision: 87a6ccf0948c3ed22925ac0319bfa451fa97accb

URL: https://github.com/llvm/llvm-project/commit/87a6ccf0948c3ed22925ac0319bfa451fa97accb
DIFF: https://github.com/llvm/llvm-project/commit/87a6ccf0948c3ed22925ac0319bfa451fa97accb.diff

LOG: [mlir][tosa] Add dynamic width/height for pooling in tosa to linalg

Adds support for dynamic width/height on pooling inputs, reusing the
machinery from the similar convolution work.
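
For example, with the new avg_pool_dyn_h test below (IH = 34,
pad_top = pad_bottom = 1, KH = 4, dilation_y = 1, stride_y = 1), the
output-height formula in TosaToLinalgNamed.cpp gives

  H = ((34 + 1 + 1 - (1*(4-1) + 1)) / 1) + 1 = ((36 - 4) / 1) + 1 = 33

which matches the static height in the expected result type
tensor<2x?x33x62xf32>.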

Reviewed By: rsuderman

Differential Revision: https://reviews.llvm.org/D133389
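
The patch makes the dynamic-dim inference generic over the kernel
operand: convolutions carry the kernel as an SSA value, pooling ops
carry it as an ArrayAttr, and two getKernelDim overloads hide the
difference. Below is a minimal standalone C++ sketch of that
overload-dispatch pattern; every name in it is an illustrative
stand-in, not the MLIR API.

  #include <cstdio>
  #include <vector>

  struct Dim { int size; };  // stand-in for an SSA value holding an extent

  // Conv-style: the kernel extent comes from the weight tensor's shape.
  Dim getKernelDim(const std::vector<Dim> &weightShape, unsigned dim) {
    return weightShape[dim];
  }

  // Pool-style: the kernel extent comes from a constant attribute.
  Dim getKernelDim(const std::vector<int> &kernelAttr, unsigned dim) {
    return Dim{kernelAttr[dim]};
  }

  // One generic routine serves both op kinds; overload resolution on
  // KernelT picks the right accessor, as inferDynamicDimsForConvOrPool
  // does for Value vs. ArrayAttr kernels.
  template <typename KernelT>
  int inferOutputDim(int inputDim, int padBefore, int padAfter, int dilation,
                     int stride, const KernelT &kernel, unsigned dim) {
    int k = getKernelDim(kernel, dim).size;
    return (inputDim + padBefore + padAfter - (dilation * (k - 1) + 1)) /
               stride + 1;
  }

  int main() {
    std::vector<int> poolKernel = {4, 4};  // like tosa.avg_pool2d's kernel
    // IH = 34, pads 1+1, KH = 4, dilation 1, stride 1 -> 33 (avg_pool_dyn_h).
    std::printf("H = %d\n", inferOutputDim(34, 1, 1, 1, 1, poolKernel, 0));
    return 0;
  }

The real patch does the same with mlir::Value vs. mlir::ArrayAttr
kernels and an ImplicitLocOpBuilder instead of these stand-ins.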

Added: 
    

Modified: 
    mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
    mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir

Removed: 
    


################################################################################
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
index 9ed3262198cf0..250d4525ee1a0 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
@@ -73,15 +73,16 @@ static mlir::Value reifyConstantDim(Attribute attr,
 // Calculating the output width/height using the formula:
 // H = ((IH+pad_top+pad_bottom-(dilation_y*(KH-1)+1))/stride_y)+1
 // W = ((IW+pad_left+pad_right-(dilation_x*(KW-1)+1))/stride_x)+1
-static mlir::Value
-getConvOutputDim(Location loc, Value initDim, Attribute padBeforeAttr,
-                 Attribute padAfterAttr, Value kernelDim, Attribute strideAttr,
-                 Attribute dilationAttr, Type inputETy, OpBuilder &rewriter) {
-  ImplicitLocOpBuilder builder(loc, rewriter);
-  auto one = rewriter.create<arith::ConstantOp>(
-      loc, IntegerAttr::get(initDim.getType(), 1));
+static mlir::Value getConvOrPoolOutputDim(Location loc, Value inputDim,
+                                          Attribute padBeforeAttr,
+                                          Attribute padAfterAttr,
+                                          Value kernelDim, Attribute strideAttr,
+                                          Attribute dilationAttr, Type inputETy,
+                                          ImplicitLocOpBuilder &builder) {
+  auto one = builder.create<arith::ConstantOp>(
+      loc, IntegerAttr::get(inputDim.getType(), 1));
   Value padBefore = reifyConstantDim(padBeforeAttr, builder);
-  Value paddedBefore = builder.create<arith::AddIOp>(initDim, padBefore);
+  Value paddedBefore = builder.create<arith::AddIOp>(inputDim, padBefore);
   Value padAfter = reifyConstantDim(padAfterAttr, builder);
   Value paddedAfter = builder.create<arith::AddIOp>(paddedBefore, padAfter);
 
@@ -96,11 +97,27 @@ getConvOutputDim(Location loc, Value initDim, Attribute padBeforeAttr,
   return builder.create<arith::AddIOp>(divide, one);
 }
 
-// Creates a vector of the dynamic output dims for Conv2D and Depthwise_Conv2D
-static SmallVector<Value> inferDynamicDimsForConv(
-    Location loc, Value input, Value weight, ShapedType resultTy,
-    ArrayAttr padAttr, ArrayAttr strideAttr, ArrayAttr dilationAttr,
-    int64_t weightHDim, int64_t weightWDim, OpBuilder &rewriter) {
+// For convolution, the kernel is a value.
+Value getKernelDim(Location loc, Value kernel, uint64_t dim,
+                   ImplicitLocOpBuilder &builder) {
+  return builder.create<tensor::DimOp>(loc, kernel, dim).getResult();
+}
+
+// For pooling, the kernel is an attribute.
+Value getKernelDim(Location loc, ArrayAttr kernel, uint64_t dim,
+                   ImplicitLocOpBuilder &builder) {
+  auto kernelArr = kernel.getValue();
+  if (dim >= kernelArr.size()) return nullptr;
+  Attribute kernelDimAttr = kernelArr[dim];
+  return reifyConstantDim(kernelDimAttr, builder);
+}
+
+// Creates a vector of the dynamic output dims for convolution and pooling ops.
+template <typename T>
+static SmallVector<Value> inferDynamicDimsForConvOrPool(
+    Location loc, Value input, T weight, ShapedType resultTy, ArrayAttr padAttr,
+    ArrayAttr strideAttr, ArrayAttr dilationAttr, int64_t weightHDim,
+    int64_t weightWDim, OpBuilder &rewriter) {
   ShapedType inputTy = input.getType().cast<ShapedType>();
   Type inputETy = inputTy.getElementType();
   int64_t inputRank = inputTy.getRank();
@@ -114,30 +131,29 @@ static SmallVector<Value> inferDynamicDimsForConv(
       dynDims[i] = rewriter.create<tensor::DimOp>(loc, input, i);
   }
 
+  ImplicitLocOpBuilder builder(loc, rewriter);
   // Dynamic input height
   if (inputTy.isDynamicDim(heightDim)) {
-    Value initHDim =
-        rewriter.create<tensor::DimOp>(loc, input, heightDim).getResult();
-    Value kernelHDim =
-        rewriter.create<tensor::DimOp>(loc, weight, weightHDim).getResult();
+    Value inputHDim =
+        builder.create<tensor::DimOp>(loc, input, heightDim).getResult();
+    Value kernelHDim = getKernelDim(loc, weight, weightHDim, builder);
     // H = F(IH, pad_top, pad_bottom, dilation_y, KH, stride_y)
-    dynDims[heightDim] = getConvOutputDim(
-        loc, initHDim, padAttr.getValue()[0], padAttr.getValue()[1], kernelHDim,
-        strideAttr.getValue()[0], dilationAttr.getValue()[0], inputETy,
-        rewriter);
+    dynDims[heightDim] = getConvOrPoolOutputDim(
+        loc, inputHDim, padAttr.getValue()[0], padAttr.getValue()[1],
+        kernelHDim, strideAttr.getValue()[0], dilationAttr.getValue()[0],
+        inputETy, builder);
   }
 
   // Dynamic input width
   if (inputTy.isDynamicDim(weightDim)) {
-    Value initWDim =
-        rewriter.create<tensor::DimOp>(loc, input, weightDim).getResult();
-    Value kernelWDim =
-        rewriter.create<tensor::DimOp>(loc, weight, weightWDim).getResult();
+    Value inputWDim =
+        builder.create<tensor::DimOp>(loc, input, weightDim).getResult();
+    Value kernelWDim = getKernelDim(loc, weight, weightWDim, builder);
     // W = F(IW, pad_left, pad_right, dilation_x, KW, stride_x)
-    dynDims[weightDim] = getConvOutputDim(
-        loc, initWDim, padAttr.getValue()[2], padAttr.getValue()[3], kernelWDim,
-        strideAttr.getValue()[1], dilationAttr.getValue()[1], inputETy,
-        rewriter);
+    dynDims[weightDim] = getConvOrPoolOutputDim(
+        loc, inputWDim, padAttr.getValue()[2], padAttr.getValue()[3],
+        kernelWDim, strideAttr.getValue()[1], dilationAttr.getValue()[1],
+        inputETy, builder);
   }
 
   SmallVector<Value> filteredDims = condenseValues(dynDims);
@@ -191,7 +207,7 @@ class ConvConverter : public OpConversionPattern<tosa::Conv2DOp> {
       return rewriter.notifyMatchFailure(
           op, "tosa.conv ops does not support unsigned integer input");
 
-    SmallVector<Value> filteredDims = inferDynamicDimsForConv(
+    SmallVector<Value> filteredDims = inferDynamicDimsForConvOrPool(
         loc, input, weight, resultTy, padAttr, strideTosaAttr, dilationTosaAttr,
         /*weightHDim=*/1, /*weightWDim=*/2, rewriter);
 
@@ -356,7 +372,7 @@ class DepthwiseConvConverter
           op, "tosa.depthwise_conv ops require static shapes");
 
     // Compute output dynamic dims
-    SmallVector<Value> filteredDims = inferDynamicDimsForConv(
+    SmallVector<Value> filteredDims = inferDynamicDimsForConvOrPool(
         loc, input, weight, resultTy, padAttr, strideTosaAttr, dilationTosaAttr,
         0, 1, rewriter);
 
@@ -692,11 +708,15 @@ class MaxPool2dConverter : public OpRewritePattern<tosa::MaxPool2dOp> {
     ShapedType resultTy = op.getType().template cast<ShapedType>();
     Type resultETy = inputTy.getElementType();
 
-    auto dynamicDimsOr =
-        checkHasDynamicBatchDims(rewriter, op, {input, op.getOutput()});
-    if (!dynamicDimsOr.has_value())
-      return failure();
-    SmallVector<Value> dynamicDims = dynamicDimsOr.value();
+    auto kernelAttr = op.getKernel().cast<ArrayAttr>();
+    auto padAttr = op.getPad().cast<ArrayAttr>();
+    auto strideTosaAttr = op.getStride().cast<ArrayAttr>();
+    ArrayAttr dilationTosaAttr = rewriter.getI64ArrayAttr({1, 1});
+
+    SmallVector<Value> filteredDims = inferDynamicDimsForConvOrPool(
+        loc, input, kernelAttr, resultTy, padAttr, strideTosaAttr,
+        dilationTosaAttr,
+        /*weightHDim=*/0, /*weightWDim=*/1, rewriter);
 
     // Determine what the initial value needs to be for the max pool op.
     Attribute initialAttr;
@@ -733,7 +753,7 @@ class MaxPool2dConverter : public OpRewritePattern<tosa::MaxPool2dOp> {
 
     // Create the linalg op that performs pooling.
     Value initTensor = rewriter.create<linalg::InitTensorOp>(
-        loc, dynamicDims, resultTy.getShape(), resultTy.getElementType());
+        loc, filteredDims, resultTy.getShape(), resultTy.getElementType());
 
     Value filledInitTensor =
         rewriter
@@ -769,11 +789,15 @@ class AvgPool2dConverter : public OpRewritePattern<tosa::AvgPool2dOp> {
         inElementTy.isa<IntegerType>() ? rewriter.getI32Type() : inElementTy;
     ShapedType accTy = resultTy.clone(accETy);
 
-    auto dynamicDimsOr =
-        checkHasDynamicBatchDims(rewriter, op, {input, op.getOutput()});
-    if (!dynamicDimsOr.has_value())
-      return failure();
-    SmallVector<Value> dynamicDims = dynamicDimsOr.value();
+    auto kernelAttr = op.getKernel().cast<ArrayAttr>();
+    auto padArrayAttr = op.getPad().cast<ArrayAttr>();
+    auto strideTosaAttr = op.getStride().cast<ArrayAttr>();
+    ArrayAttr dilationTosaAttr = rewriter.getI64ArrayAttr({1, 1});
+
+    SmallVector<Value> filteredDims = inferDynamicDimsForConvOrPool(
+        loc, input, kernelAttr, resultTy, padArrayAttr, strideTosaAttr,
+        dilationTosaAttr,
+        /*weightHDim=*/0, /*weightWDim=*/1, rewriter);
 
     // Apply padding as necessary.
     llvm::SmallVector<int64_t> pad;
@@ -795,7 +819,7 @@ class AvgPool2dConverter : public OpRewritePattern<tosa::AvgPool2dOp> {
 
     // Create the linalg op that performs pooling.
     Value poolInitTensor = rewriter.create<linalg::InitTensorOp>(
-        loc, dynamicDims, accTy.getShape(), accETy);
+        loc, filteredDims, accTy.getShape(), accETy);
 
     Value filledInitTensor =
         rewriter
@@ -820,7 +844,7 @@ class AvgPool2dConverter : public OpRewritePattern<tosa::AvgPool2dOp> {
     auto affineMap = rewriter.getMultiDimIdentityMap(resultTy.getRank());
 
     Value genericInitTensor = rewriter.create<linalg::InitTensorOp>(
-        loc, dynamicDims, resultTy.getShape(), resultETy);
+        loc, filteredDims, resultTy.getShape(), resultETy);
 
     auto genericOp = rewriter.create<linalg::GenericOp>(
         loc, ArrayRef<Type>({resultTy}), ValueRange{poolingOp},

diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
index d956e822f9efa..1240070d78b4e 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
@@ -165,15 +165,21 @@ func.func @max_pool_padded(%arg0: tensor<1x6x34x62xf32>) -> () {
 }
 
 // CHECK-LABEL: @max_pool_dyn
-func.func @max_pool_dyn(%arg0: tensor<?x6x34x62xf32>) -> () {
-  // CHECK: %[[C0:.+]] = arith.constant 0
-  // CHECK: %[[BATCH:.+]] = tensor.dim %arg0, %[[C0]]
+func.func @max_pool_dyn(%arg0: tensor<?x?x?x64xf32>) -> () {
+  // CHECK: %[[C0:.+]] = arith.constant 0 : index
+  // CHECK: %[[DIM0:.+]] = tensor.dim %arg0, %[[C0]] : tensor<?x?x?x64xf32>
+  // CHECK: %[[C1:.+]] = arith.constant 1 : index
+  // CHECK: %[[DIM1:.+]] = tensor.dim %arg0, %[[C1]] : tensor<?x?x?x64xf32>
+  // CHECK: arith.constant 2 : index
+  // CHECK: %[[C2:.+]] = arith.constant 2 : index
+  // CHECK: %[[DIM2:.+]] = tensor.dim %arg0, %[[C2]] : tensor<?x?x?x64xf32>
+  // CHECK: %[[PAD:.+]] = tensor.pad %arg0
   // CHECK: %[[CONST:.+]] = arith.constant -3.40282347E+38
-  // CHECK: %[[INIT:.+]] = linalg.init_tensor [%[[BATCH]], 4, 32, 62]
-  // CHECK: %[[FILL:.+]] = linalg.fill ins(%[[CONST]]{{.*}}outs(%[[INIT]]
+  // CHECK: %[[INIT:.+]] = linalg.init_tensor
+  // CHECK: %[[FILL:.+]] = linalg.fill ins(%[[CONST]] : f32) outs(%[[INIT]] : tensor<?x?x?x64xf32>) -> tensor<?x?x?x64xf32>
   // CHECK: %[[KERNEL:.+]] = linalg.init_tensor [3, 3]
-  // CHECK: linalg.pooling_nhwc_max {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>} ins(%arg0, %[[KERNEL]] : tensor<?x6x34x62xf32>, tensor<3x3xf32>) outs(%[[FILL]] : tensor<?x4x32x62xf32>)
-  %0 = "tosa.max_pool2d"(%arg0) {pad = [0, 0, 0, 0], kernel = [3, 3], stride = [1, 1]} : (tensor<?x6x34x62xf32>)  -> (tensor<?x4x32x62xf32>)
+  // CHECK: linalg.pooling_nhwc_max {dilations = dense<1> : vector<2xi64>, strides = dense<2> : vector<2xi64>} ins(%[[PAD]], %[[KERNEL]] : tensor<?x?x?x64xf32>, tensor<3x3xf32>) outs(%[[FILL]] : tensor<?x?x?x64xf32>) -> tensor<?x?x?x64xf32>
+  %0 = "tosa.max_pool2d"(%arg0) {kernel = [3, 3], pad = [1, 1, 1, 1], stride = [2, 2]} : (tensor<?x?x?x64xf32>) -> (tensor<?x?x?x64xf32>)
   return
 }
 
@@ -279,6 +285,25 @@ func.func @avg_pool_dyn(%arg0: tensor<?x6x34x62xf32>) -> (tensor<?x5x33x62xf32>)
   return %0 : tensor<?x5x33x62xf32>
 }
 
+// CHECK-LABEL: @avg_pool_dyn_h
+func.func @avg_pool_dyn_h(%arg0: tensor<2x?x34x62xf32>) -> (tensor<2x?x33x62xf32>) {
+  // CHECK: %[[C1:.+]] = arith.constant 1
+  // CHECK: %[[DIM1:.+]] = tensor.dim %arg0, %[[C1]]
+  // CHECK: arith.addi
+  // CHECK: arith.addi
+  // CHECK: arith.addi
+  // CHECK: %[[RESULT:.+]] = arith.addi
+  // CHECK: %[[PAD:.+]] = tensor.pad %arg0 low[0, 1, 1, 0] high[0, 1, 1, 0]
+  // CHECK: %[[POOLINIT:.+]] = linalg.init_tensor [2, %[[RESULT]], 33, 62]
+  // CHECK: %[[FILL:.+]] = linalg.fill
+  // CHECK: %[[KERNEL:.+]] = linalg.init_tensor [4, 4]
+  // CHECK: %[[POOL:.+]] = linalg.pooling_nhwc_sum {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>} ins(%[[PAD]], %[[KERNEL]] : tensor<2x?x36x62xf32>, tensor<4x4xf32>) outs(%[[FILL]] : tensor<2x?x33x62xf32>)
+  // CHECK: %[[INIT:.+]] = linalg.init_tensor [2, %[[RESULT]], 33, 62]
+  // CHECK: %[[GENERIC:.+]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[POOL]] : tensor<2x?x33x62xf32>) outs(%[[INIT]] : tensor<2x?x33x62xf32>)
+  %0 = "tosa.avg_pool2d"(%arg0) {pad = [1, 1, 1, 1], kernel = [4, 4], stride = [1, 1]} : (tensor<2x?x34x62xf32>)  -> (tensor<2x?x33x62xf32>)
+  return %0 : tensor<2x?x33x62xf32>
+}
+
 // -----
 
 // CHECK-LABEL: @avg_pool_i8

More information about the Mlir-commits mailing list