[Mlir-commits] [mlir] 3e41bec - [mlir][tosa] Revert added support for dynamic height/width for pooling in tosa-to-linalg
Jacques Pienaar
llvmlistbot at llvm.org
Wed Sep 21 09:29:27 PDT 2022
Author: natashaknk
Date: 2022-09-21T09:29:17-07:00
New Revision: 3e41bec03ebdde08b6da7733b5c8340993258cb0
URL: https://github.com/llvm/llvm-project/commit/3e41bec03ebdde08b6da7733b5c8340993258cb0
DIFF: https://github.com/llvm/llvm-project/commit/3e41bec03ebdde08b6da7733b5c8340993258cb0.diff
LOG: [mlir][tosa] Revert added support for dynamic height/width for pooling in tosa-to-linalg
Partial rollback of D133389
Reviewed By: jpienaar
Differential Revision: https://reviews.llvm.org/D134370
Added:
Modified:
mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
index 250d4525ee1a..4a31bd620776 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
@@ -73,13 +73,12 @@ static mlir::Value reifyConstantDim(Attribute attr,
// Calculating the output width/height using the formula:
// H = ((IH+pad_top+pad_bottom-(dilation_y*(KH-1)+1))/stride_y)+1
// W = ((IW+pad_left+pad_right-(dilation_x*(KW-1)+1))/stride_x)+1
-static mlir::Value getConvOrPoolOutputDim(Location loc, Value inputDim,
- Attribute padBeforeAttr,
- Attribute padAfterAttr,
- Value kernelDim, Attribute strideAttr,
- Attribute dilationAttr, Type inputETy,
- ImplicitLocOpBuilder &builder) {
- auto one = builder.create<arith::ConstantOp>(
+static mlir::Value
+getConvOutputDim(Location loc, Value inputDim, Attribute padBeforeAttr,
+ Attribute padAfterAttr, Value kernelDim, Attribute strideAttr,
+ Attribute dilationAttr, Type inputETy, OpBuilder &rewriter) {
+ ImplicitLocOpBuilder builder(loc, rewriter);
+ auto one = rewriter.create<arith::ConstantOp>(
loc, IntegerAttr::get(inputDim.getType(), 1));
Value padBefore = reifyConstantDim(padBeforeAttr, builder);
Value paddedBefore = builder.create<arith::AddIOp>(inputDim, padBefore);
@@ -97,27 +96,11 @@ static mlir::Value getConvOrPoolOutputDim(Location loc, Value inputDim,
return builder.create<arith::AddIOp>(divide, one);
}
-// For convolution, the kernel is a value.
-Value getKernelDim(Location loc, Value kernel, uint64_t dim,
- ImplicitLocOpBuilder &builder) {
- return builder.create<tensor::DimOp>(loc, kernel, dim).getResult();
-}
-
-// For pooling, the kernel is an attribute.
-Value getKernelDim(Location loc, ArrayAttr kernel, uint64_t dim,
- ImplicitLocOpBuilder &builder) {
- auto kernelArr = kernel.getValue();
- if (dim >= kernelArr.size()) return nullptr;
- Attribute kernelDimAttr = kernelArr[dim];
- return reifyConstantDim(kernelDimAttr, builder);
-}
-
-// Creates a vector of the dynamic output dims convolution and pooling ops.
-template <typename T>
-static SmallVector<Value> inferDynamicDimsForConvOrPool(
- Location loc, Value input, T weight, ShapedType resultTy, ArrayAttr padAttr,
- ArrayAttr strideAttr, ArrayAttr dilationAttr, int64_t weightHDim,
- int64_t weightWDim, OpBuilder &rewriter) {
+// Creates a vector of the dynamic output dims for Conv2D and Depthwise_Conv2D
+static SmallVector<Value> inferDynamicDimsForConv(
+ Location loc, Value input, Value weight, ShapedType resultTy,
+ ArrayAttr padAttr, ArrayAttr strideAttr, ArrayAttr dilationAttr,
+ int64_t weightHDim, int64_t weightWDim, OpBuilder &rewriter) {
ShapedType inputTy = input.getType().cast<ShapedType>();
Type inputETy = inputTy.getElementType();
int64_t inputRank = inputTy.getRank();
@@ -131,29 +114,30 @@ static SmallVector<Value> inferDynamicDimsForConvOrPool(
dynDims[i] = rewriter.create<tensor::DimOp>(loc, input, i);
}
- ImplicitLocOpBuilder builder(loc, rewriter);
// Dynamic input height
if (inputTy.isDynamicDim(heightDim)) {
- Value inputHDim =
- builder.create<tensor::DimOp>(loc, input, heightDim).getResult();
- Value kernelHDim = getKernelDim(loc, weight, weightHDim, builder);
+ Value initHDim =
+ rewriter.create<tensor::DimOp>(loc, input, heightDim).getResult();
+ Value kernelHDim =
+ rewriter.create<tensor::DimOp>(loc, weight, weightHDim).getResult();
// H = F(IH, pad_top, pad_bottom, dilation_y, KH, stride_y)
- dynDims[heightDim] = getConvOrPoolOutputDim(
- loc, inputHDim, padAttr.getValue()[0], padAttr.getValue()[1],
- kernelHDim, strideAttr.getValue()[0], dilationAttr.getValue()[0],
- inputETy, builder);
+ dynDims[heightDim] = getConvOutputDim(
+ loc, initHDim, padAttr.getValue()[0], padAttr.getValue()[1], kernelHDim,
+ strideAttr.getValue()[0], dilationAttr.getValue()[0], inputETy,
+ rewriter);
}
// Dynamic input weight
if (inputTy.isDynamicDim(weightDim)) {
- Value inputWDim =
- builder.create<tensor::DimOp>(loc, input, weightDim).getResult();
- Value kernelWDim = getKernelDim(loc, weight, weightWDim, builder);
+ Value initWDim =
+ rewriter.create<tensor::DimOp>(loc, input, weightDim).getResult();
+ Value kernelWDim =
+ rewriter.create<tensor::DimOp>(loc, weight, weightWDim).getResult();
// W = F(IW, pad_left, pad_right, dilation_x, KW, stride_x)
- dynDims[weightDim] = getConvOrPoolOutputDim(
- loc, inputWDim, padAttr.getValue()[2], padAttr.getValue()[3],
- kernelWDim, strideAttr.getValue()[1], dilationAttr.getValue()[1],
- inputETy, builder);
+ dynDims[weightDim] = getConvOutputDim(
+ loc, initWDim, padAttr.getValue()[2], padAttr.getValue()[3], kernelWDim,
+ strideAttr.getValue()[1], dilationAttr.getValue()[1], inputETy,
+ rewriter);
}
SmallVector<Value> filteredDims = condenseValues(dynDims);
@@ -207,7 +191,7 @@ class ConvConverter : public OpConversionPattern<tosa::Conv2DOp> {
return rewriter.notifyMatchFailure(
op, "tosa.conv ops does not support unsigned integer input");
- SmallVector<Value> filteredDims = inferDynamicDimsForConvOrPool(
+ SmallVector<Value> filteredDims = inferDynamicDimsForConv(
loc, input, weight, resultTy, padAttr, strideTosaAttr, dilationTosaAttr,
/*weightHDim=*/1, /*weightWDim=*/2, rewriter);
@@ -372,7 +356,7 @@ class DepthwiseConvConverter
op, "tosa.depthwise_conv ops require static shapes");
// Compute output dynamic dims
- SmallVector<Value> filteredDims = inferDynamicDimsForConvOrPool(
+ SmallVector<Value> filteredDims = inferDynamicDimsForConv(
loc, input, weight, resultTy, padAttr, strideTosaAttr, dilationTosaAttr,
0, 1, rewriter);
@@ -708,15 +692,11 @@ class MaxPool2dConverter : public OpRewritePattern<tosa::MaxPool2dOp> {
ShapedType resultTy = op.getType().template cast<ShapedType>();
Type resultETy = inputTy.getElementType();
- auto kernelAttr = op.getKernel().cast<ArrayAttr>();
- auto padAttr = op.getPad().cast<ArrayAttr>();
- auto strideTosaAttr = op.getStride().cast<ArrayAttr>();
- ArrayAttr dilationTosaAttr = rewriter.getI64ArrayAttr({1, 1});
-
- SmallVector<Value> filteredDims = inferDynamicDimsForConvOrPool(
- loc, input, kernelAttr, resultTy, padAttr, strideTosaAttr,
- dilationTosaAttr,
- /*weightHDim=*/0, /*weightWDim=*/1, rewriter);
+ auto dynamicDimsOr =
+ checkHasDynamicBatchDims(rewriter, op, {input, op.getOutput()});
+ if (!dynamicDimsOr.has_value())
+ return failure();
+ SmallVector<Value> dynamicDims = dynamicDimsOr.value();
// Determine what the initial value needs to be for the max pool op.
Attribute initialAttr;
@@ -753,7 +733,7 @@ class MaxPool2dConverter : public OpRewritePattern<tosa::MaxPool2dOp> {
// Create the linalg op that performs pooling.
Value initTensor = rewriter.create<linalg::InitTensorOp>(
- loc, filteredDims, resultTy.getShape(), resultTy.getElementType());
+ loc, dynamicDims, resultTy.getShape(), resultTy.getElementType());
Value filledInitTensor =
rewriter
@@ -789,15 +769,11 @@ class AvgPool2dConverter : public OpRewritePattern<tosa::AvgPool2dOp> {
inElementTy.isa<IntegerType>() ? rewriter.getI32Type() : inElementTy;
ShapedType accTy = resultTy.clone(accETy);
- auto kernelAttr = op.getKernel().cast<ArrayAttr>();
- auto padArrayAttr = op.getPad().cast<ArrayAttr>();
- auto strideTosaAttr = op.getStride().cast<ArrayAttr>();
- ArrayAttr dilationTosaAttr = rewriter.getI64ArrayAttr({1, 1});
-
- SmallVector<Value> filteredDims = inferDynamicDimsForConvOrPool(
- loc, input, kernelAttr, resultTy, padArrayAttr, strideTosaAttr,
- dilationTosaAttr,
- /*weightHDim=*/0, /*weightWDim=*/1, rewriter);
+ auto dynamicDimsOr =
+ checkHasDynamicBatchDims(rewriter, op, {input, op.getOutput()});
+ if (!dynamicDimsOr.has_value())
+ return failure();
+ SmallVector<Value> dynamicDims = dynamicDimsOr.value();
// Apply padding as necessary.
llvm::SmallVector<int64_t> pad;
@@ -819,7 +795,7 @@ class AvgPool2dConverter : public OpRewritePattern<tosa::AvgPool2dOp> {
// Create the linalg op that performs pooling.
Value poolInitTensor = rewriter.create<linalg::InitTensorOp>(
- loc, filteredDims, accTy.getShape(), accETy);
+ loc, dynamicDims, accTy.getShape(), accETy);
Value filledInitTensor =
rewriter
@@ -844,7 +820,7 @@ class AvgPool2dConverter : public OpRewritePattern<tosa::AvgPool2dOp> {
auto affineMap = rewriter.getMultiDimIdentityMap(resultTy.getRank());
Value genericInitTensor = rewriter.create<linalg::InitTensorOp>(
- loc, filteredDims, resultTy.getShape(), resultETy);
+ loc, dynamicDims, resultTy.getShape(), resultETy);
auto genericOp = rewriter.create<linalg::GenericOp>(
loc, ArrayRef<Type>({resultTy}), ValueRange{poolingOp},
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
index 1240070d78b4..d956e822f9ef 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
@@ -165,21 +165,15 @@ func.func @max_pool_padded(%arg0: tensor<1x6x34x62xf32>) -> () {
}
// CHECK-LABEL: @max_pool_dyn
-func.func @max_pool_dyn(%arg0: tensor<?x?x?x64xf32>) -> () {
- // CHECK: %[[C0:.+]] = arith.constant 0 : index
- // CHECK: %[[DIM0:.+]] = tensor.dim %arg0, %[[C0]] : tensor<?x?x?x64xf32>
- // CHECK: %[[C1:.+]] = arith.constant 1 : index
- // CHECK: %[[DIM1:.+]] = tensor.dim %arg0, %[[C1]] : tensor<?x?x?x64xf32>
- // CHECK: arith.constant 2 : index
- // CHECK: %[[C2:.+]] = arith.constant 2 : index
- // CHECK: %[[DIM2:.+]] = tensor.dim %arg0, %[[C2]] : tensor<?x?x?x64xf32>
- // CHECK: %[[PAD:.+]] = tensor.pad %arg0
+func.func @max_pool_dyn(%arg0: tensor<?x6x34x62xf32>) -> () {
+ // CHECK: %[[C0:.+]] = arith.constant 0
+ // CHECK: %[[BATCH:.+]] = tensor.dim %arg0, %[[C0]]
// CHECK: %[[CONST:.+]] = arith.constant -3.40282347E+38
- // CHECK: %[[INIT:.+]] = linalg.init_tensor
- // CHECK: %[[FILL:.+]] = linalg.fill ins(%cst_18 : f32) outs(%20 : tensor<?x?x?x64xf32>) -> tensor<?x?x?x64xf32>
+ // CHECK: %[[INIT:.+]] = linalg.init_tensor [%[[BATCH]], 4, 32, 62]
+ // CHECK: %[[FILL:.+]] = linalg.fill ins(%[[CONST]]{{.*}}outs(%[[INIT]]
// CHECK: %[[KERNEL:.+]] = linalg.init_tensor [3, 3]
- // CHECK: linalg.pooling_nhwc_max {dilations = dense<1> : vector<2xi64>, strides = dense<2> : vector<2xi64>} ins(%[[PAD]], %[[KERNEL]] : tensor<?x?x?x64xf32>, tensor<3x3xf32>) outs(%[[FILL]] : tensor<?x?x?x64xf32>) -> tensor<?x?x?x64xf32>
- %0 = "tosa.max_pool2d"(%arg0) {kernel = [3, 3], pad = [1, 1, 1, 1], stride = [2, 2]} : (tensor<?x?x?x64xf32>) -> (tensor<?x?x?x64xf32>)
+ // CHECK: linalg.pooling_nhwc_max {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>} ins(%arg0, %[[KERNEL]] : tensor<?x6x34x62xf32>, tensor<3x3xf32>) outs(%[[FILL]] : tensor<?x4x32x62xf32>)
+ %0 = "tosa.max_pool2d"(%arg0) {pad = [0, 0, 0, 0], kernel = [3, 3], stride = [1, 1]} : (tensor<?x6x34x62xf32>) -> (tensor<?x4x32x62xf32>)
return
}
@@ -285,25 +279,6 @@ func.func @avg_pool_dyn(%arg0: tensor<?x6x34x62xf32>) -> (tensor<?x5x33x62xf32>)
return %0 : tensor<?x5x33x62xf32>
}
-// CHECK-LABEL: @avg_pool_dyn_h
-func.func @avg_pool_dyn_h(%arg0: tensor<2x?x34x62xf32>) -> (tensor<2x?x33x62xf32>) {
- // CHECK: %[[C1:.+]] = arith.constant 1
- // CHECK: %[[DIM1:.+]] = tensor.dim %arg0, %[[C1]]
- // CHECK: arith.addi
- // CHECK: arith.addi
- // CHECK: arith.addi
- // CHECK: %[[RESULT:.+]] = arith.addi
- // CHECK: %[[PAD:.+]] = tensor.pad %arg0 low[0, 1, 1, 0] high[0, 1, 1, 0]
- // CHECK: %[[POOLINIT:.+]] = linalg.init_tensor [2, %[[RESULT]], 33, 62]
- // CHECK: %[[FILL:.+]] = linalg.fill
- // CHECK: %[[KERNEL:.+]] = linalg.init_tensor [4, 4]
- // CHECK: %[[POOL:.+]] = linalg.pooling_nhwc_sum {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>} ins(%[[PAD]], %[[KERNEL]] : tensor<2x?x36x62xf32>, tensor<4x4xf32>) outs(%[[FILL]] : tensor<2x?x33x62xf32>)
- // CHECK: %[[INIT:.+]] = linalg.init_tensor [2, %[[RESULT]], 33, 62]
- // CHECK: %[[GENERIC:.+]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[POOL]] : tensor<2x?x33x62xf32>) outs(%[[INIT]] : tensor<2x?x33x62xf32>)
- %0 = "tosa.avg_pool2d"(%arg0) {pad = [1, 1, 1, 1], kernel = [4, 4], stride = [1, 1]} : (tensor<2x?x34x62xf32>) -> (tensor<2x?x33x62xf32>)
- return %0 : tensor<2x?x33x62xf32>
-}
-
// -----
// CHECK-LABEL: @avg_pool_i8
More information about the Mlir-commits
mailing list