[Mlir-commits] [mlir] 69c984b - [mlir][tosa] Fix padding for tosa.conv2d and tosa.depthwise_conv2d decomposition
Rob Suderman
llvmlistbot at llvm.org
Tue Dec 13 17:38:30 PST 2022
Author: Rob Suderman
Date: 2022-12-13T17:37:36-08:00
New Revision: 69c984b6b803f00371dcf028bc9cf9b07911d1d6
URL: https://github.com/llvm/llvm-project/commit/69c984b6b803f00371dcf028bc9cf9b07911d1d6
DIFF: https://github.com/llvm/llvm-project/commit/69c984b6b803f00371dcf028bc9cf9b07911d1d6.diff
LOG: [mlir][tosa] Fix padding for tosa.conv2d and tosa.depthwise_conv2d decomposition
Decomposition did not take padding into account when decomposing into a fully
connected operation.
Reviewed By: NatashaKnk
Differential Revision: https://reviews.llvm.org/D139500
Added:
Modified:
mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h
mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeConv2D.cpp
mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp
mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp
mlir/test/Dialect/Tosa/tosa-decompose-conv2d.mlir
mlir/test/Dialect/Tosa/tosa-decompose-depthwise.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h b/mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h
index 5c30ddf921dc7..6e60e5aaba05f 100644
--- a/mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h
+++ b/mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h
@@ -38,6 +38,9 @@ Value clampFloatHelper(Location loc, Value arg, Value min, Value max,
Value clampIntHelper(Location loc, Value arg, Value min, Value max,
OpBuilder &rewriter);
+// Determines whether the integer value falls within the range of integer type.
+bool validIntegerRange(IntegerType ty, int64_t value);
+
// Returns the values in an attribute as an array of values.
template <typename T>
void getValuesFromIntArrayAttribute(ArrayAttr attr,
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeConv2D.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeConv2D.cpp
index d70562400ac7a..9563740f5e2ad 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeConv2D.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeConv2D.cpp
@@ -13,6 +13,7 @@
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Tosa/Transforms/Passes.h"
+#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
using namespace mlir;
using namespace mlir::tosa;
@@ -56,6 +57,49 @@ struct Conv2DIsFullyConnected : public OpRewritePattern<tosa::Conv2DOp> {
if (weightShape[1] != 1 || weightShape[2] != 1)
return failure();
+ auto padAttr = op.getPad();
+ llvm::SmallVector<int64_t> pad(8, 0);
+ for (auto it : llvm::enumerate(padAttr.getValue()))
+ pad[it.index() + 2] =
+ it.value().cast<IntegerAttr>().getValue().getSExtValue();
+
+ if (llvm::any_of(pad, [](int64_t p) { return p != 0; })) {
+ Type inputETy = inputType.getElementType();
+ Attribute zeroAttr = rewriter.getZeroAttr(inputETy);
+ if (op.getQuantizationInfo()) {
+ auto quantizationInfo = op.getQuantizationInfo();
+ int64_t iZp = quantizationInfo->getInputZp();
+
+ if (!validIntegerRange(inputETy.cast<IntegerType>(), iZp))
+ return rewriter.notifyMatchFailure(
+ op, "tosa.conv op quantization has zp outside of input range");
+
+ zeroAttr = rewriter.getIntegerAttr(inputETy, iZp);
+ }
+
+ llvm::SmallVector<int64_t> newShape(inputType.getShape());
+
+ for (int i = 0, s = newShape.size(); i < s; ++i) {
+ if (newShape[i] != ShapedType::kDynamic) {
+ newShape[i] += pad[i * 2] + pad[i * 2 + 1];
+ }
+ }
+
+ auto padSizeTy = RankedTensorType::get({4, 2}, rewriter.getI64Type());
+ auto padSize =
+ DenseIntElementsAttr::get(padSizeTy, ArrayRef<int64_t>(pad));
+ Value padSizeVal =
+ rewriter.create<tosa::ConstOp>(op->getLoc(), padSizeTy, padSize);
+
+ auto padTy = RankedTensorType::get({}, inputETy);
+ auto padAttr = DenseElementsAttr::get(padTy, zeroAttr);
+ Value padVal =
+ rewriter.create<tosa::ConstOp>(op->getLoc(), padTy, padAttr);
+ inputType = RankedTensorType::get(newShape, inputETy);
+ input = rewriter.create<tosa::PadOp>(op->getLoc(), inputType, input,
+ padSizeVal, padVal);
+ }
+
// Reshape input to [N,IH,IW,IC] -> [N * IH * IW, IC].
ArrayRef<int64_t> inputShape = inputType.getShape();
int64_t combined = ShapedType::kDynamic;
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp
index d82a13ea6f515..f26d289ff2ac9 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp
@@ -31,18 +31,12 @@ struct DepthwiseConv2DIsMul : public OpRewritePattern<tosa::DepthwiseConv2DOp> {
ShapedType inputType = input.getType().cast<ShapedType>();
ShapedType weightType = weight.getType().cast<ShapedType>();
ShapedType resultType = op.getOutput().getType().cast<ShapedType>();
- Type inputEType = inputType.getElementType();
if (!(inputType.hasStaticShape() && weightType.hasStaticShape() &&
resultType.hasStaticShape())) {
return failure();
}
- // Quantization information needs to still be performed.
- if (op.getQuantizationInfo() || !inputEType.isa<FloatType>()) {
- return failure();
- }
-
// Stride must be 1 for this optimization.
for (Attribute stride : op.getStride().getValue()) {
if (!stride.cast<IntegerAttr>().getValue().isOne()) {
@@ -60,39 +54,88 @@ struct DepthwiseConv2DIsMul : public OpRewritePattern<tosa::DepthwiseConv2DOp> {
ArrayRef<int64_t> inputShape = inputType.getShape();
llvm::SmallVector<int64_t, 2> revisedInputShape{
inputShape[0], inputShape[1], inputShape[2], inputShape[3], 1};
- auto revisedInputShapeType = RankedTensorType::get(
+ inputType = RankedTensorType::get(
revisedInputShape,
input.getType().dyn_cast<RankedTensorType>().getElementType());
- auto reshapedInput = rewriter
- .create<tosa::ReshapeOp>(
- op.getLoc(), revisedInputShapeType, input,
- rewriter.getI64ArrayAttr(revisedInputShape))
- .getResult();
-
- // Reshape kernel to [KH, KW, C, M] -> [1, 1, 1, C, M].
- llvm::SmallVector<int64_t, 2> revisedWeightShape{1, 1, 1, weightShape[2],
- weightShape[3]};
- auto revisedWeightShapeType = RankedTensorType::get(
- revisedWeightShape,
- weight.getType().dyn_cast<RankedTensorType>().getElementType());
- auto reshapedWeight = rewriter
- .create<tosa::ReshapeOp>(
- op.getLoc(), revisedWeightShapeType, weight,
- rewriter.getI64ArrayAttr(revisedWeightShape))
- .getResult();
+ input = rewriter
+ .create<tosa::ReshapeOp>(
+ op.getLoc(), inputType, input,
+ rewriter.getI64ArrayAttr(revisedInputShape))
+ .getResult();
+
+ if (inputType.getElementType() != resultType.getElementType()) {
+ inputType = inputType.clone(resultType.getElementType());
+ input = rewriter.create<tosa::CastOp>(op.getLoc(), inputType, input);
+ }
+
+ if (weightType.getElementType() != resultType.getElementType()) {
+ weightType = weightType.clone(resultType.getElementType());
+ weight = rewriter.create<tosa::CastOp>(op.getLoc(), weightType, weight);
+ }
+
+ if (auto quantizationInfo = op.getQuantizationInfo()) {
+ auto iZp = quantizationInfo->getInputZp();
+ auto wZp = quantizationInfo->getWeightZp();
+
+ auto applyZp = [&](Value val, int64_t zp) -> Value {
+ if (zp == 0)
+ return val;
+ auto ety = val.getType().cast<ShapedType>().getElementType();
+ auto zpTy = RankedTensorType::get({}, ety);
+ auto zpAttr =
+ DenseElementsAttr::get(zpTy, rewriter.getIntegerAttr(ety, zp));
+ auto zpVal = rewriter.create<tosa::ConstOp>(op.getLoc(), zpTy, zpAttr);
+ return rewriter.create<tosa::SubOp>(op.getLoc(), val.getType(), val,
+ zpVal);
+ };
+
+ input = applyZp(input, iZp);
+ weight = applyZp(weight, wZp);
+ }
+
+ auto padAttr = op.getPad();
+ llvm::SmallVector<int64_t> pad(10, 0);
+ for (auto it : llvm::enumerate(padAttr.getValue()))
+ pad[it.index() + 2] =
+ it.value().cast<IntegerAttr>().getValue().getSExtValue();
+
+ if (llvm::any_of(pad, [](int64_t p) { return p != 0; })) {
+ Type inputETy = inputType.getElementType();
+ Attribute zeroAttr = rewriter.getZeroAttr(inputETy);
+
+ llvm::SmallVector<int64_t> newShape(inputType.getShape());
+ for (int i = 0, s = pad.size(); i < s; ++i) {
+ if (newShape[i / 2] != ShapedType::kDynamic) {
+ newShape[i / 2] += pad[i];
+ }
+ }
+
+ auto padSizeTy = RankedTensorType::get({5, 2}, rewriter.getI64Type());
+ auto padSize =
+ DenseIntElementsAttr::get(padSizeTy, ArrayRef<int64_t>(pad));
+ Value padSizeVal =
+ rewriter.create<tosa::ConstOp>(op->getLoc(), padSizeTy, padSize);
+
+ auto padTy = RankedTensorType::get({}, inputETy);
+ auto padAttr = DenseElementsAttr::get(padTy, zeroAttr);
+ Value padVal =
+ rewriter.create<tosa::ConstOp>(op->getLoc(), padTy, padAttr);
+ inputType = RankedTensorType::get(newShape, inputETy);
+ input = rewriter.create<tosa::PadOp>(op->getLoc(), inputType, input,
+ padSizeVal, padVal);
+ }
// Perform an elementwise mul over the reshaped input and weight.
- llvm::SmallVector<int64_t, 2> mulShape{inputShape[0], inputShape[1],
- inputShape[2], inputShape[3],
- weightShape[3]};
+ llvm::SmallVector<int64_t, 2> mulShape{
+ inputType.getDimSize(0), inputType.getDimSize(1),
+ inputType.getDimSize(2), inputType.getDimSize(3), weightShape[3]};
auto mulShapeType = RankedTensorType::get(
mulShape,
weight.getType().dyn_cast<RankedTensorType>().getElementType());
- Value mulValue =
- rewriter
- .create<tosa::MulOp>(op.getLoc(), mulShapeType, reshapedInput,
- reshapedWeight, /*shift=*/0)
- .getResult();
+ Value mulValue = rewriter
+ .create<tosa::MulOp>(op.getLoc(), mulShapeType, input,
+ weight, /*shift=*/0)
+ .getResult();
// Reshape output to [N, H, W, C * M].
auto outputShape = op.getOutput().getType().cast<ShapedType>().getShape();
diff --git a/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp b/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp
index c511ca1a6c76e..346ff860d2884 100644
--- a/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp
+++ b/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp
@@ -46,3 +46,17 @@ Value mlir::tosa::clampIntHelper(Location loc, Value arg, Value min, Value max,
rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, max, arg);
return rewriter.create<arith::SelectOp>(loc, largerThanMax, max, minOrArg);
}
+
+bool mlir::tosa::validIntegerRange(IntegerType ty, int64_t value) {
+ uint64_t bitwidth = ty.getIntOrFloatBitWidth();
+ if (ty.getSignedness() == IntegerType::Unsigned) {
+ uint64_t uvalue = value;
+ APInt intMin = APInt::getMinValue(bitwidth);
+ APInt intMax = APInt::getMaxValue(bitwidth);
+ return uvalue >= intMin.getZExtValue() && uvalue <= intMax.getZExtValue();
+ }
+
+ APInt intMin = APInt::getSignedMinValue(bitwidth);
+ APInt intMax = APInt::getSignedMaxValue(bitwidth);
+ return value >= intMin.getSExtValue() && value <= intMax.getSExtValue();
+}
diff --git a/mlir/test/Dialect/Tosa/tosa-decompose-conv2d.mlir b/mlir/test/Dialect/Tosa/tosa-decompose-conv2d.mlir
index e651140bfde18..09025c635fe3f 100644
--- a/mlir/test/Dialect/Tosa/tosa-decompose-conv2d.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-decompose-conv2d.mlir
@@ -54,3 +54,17 @@ func.func @conv_with_dynamic_dim(%arg0: tensor<?x14x14x64xi8>, %arg1: tensor<384
return %0 : tensor<?x14x14x384xi32>
}
+// -----
+
+// CHECK-LABEL: @conv2d_as_fully_connected_padded
+func.func @conv2d_as_fully_connected_padded(%arg0: tensor<4x10x10x2xi8>, %arg1: tensor<3x1x1x2xi8>, %arg2: tensor<3xi32>) -> tensor<4x12x12x3xi32> {
+ // CHECK-DAG: %[[PAD_SHAPE:.+]] = "tosa.const"() {value = dense<{{\[\[}}0, 0], [1, 1], [1, 1], [0, 0]]> : tensor<4x2xi64>}
+ // CHECK-DAG: %[[PAD_VAL:.+]] = "tosa.const"() {value = dense<42> : tensor<i8>}
+ // CHECK-DAG: %[[PAD:.+]] = "tosa.pad"(%arg0, %[[PAD_SHAPE]], %[[PAD_VAL]]) : (tensor<4x10x10x2xi8>, tensor<4x2xi64>, tensor<i8>) -> tensor<4x12x12x2xi8>
+ // CHECK-DAG: %[[RESHAPE_INPUT:.+]] = "tosa.reshape"(%[[PAD]]) {new_shape = [576, 2]}
+ // CHECK-DAG: %[[RESHAPE_FILTER:.+]] = "tosa.reshape"(%arg1) {new_shape = [3, 2]}
+ // CHECK-DAG: %[[FULLY:.+]] = "tosa.fully_connected"(%[[RESHAPE_INPUT]], %[[RESHAPE_FILTER]], %arg2) {quantization_info = #tosa.conv_quant<input_zp = 42, weight_zp = 24>}
+ // CHECK: %[[RESHAPE:.+]] = "tosa.reshape"(%[[FULLY]]) {new_shape = [4, 12, 12, 3]}
+ %0 = "tosa.conv2d"(%arg0, %arg1, %arg2) {pad = [1, 1, 1, 1], stride = [1, 1], dilation = [1, 1], quantization_info = #tosa.conv_quant<input_zp = 42, weight_zp = 24>} : (tensor<4x10x10x2xi8>, tensor<3x1x1x2xi8>, tensor<3xi32>) -> tensor<4x12x12x3xi32>
+ return %0 : tensor<4x12x12x3xi32>
+}
\ No newline at end of file
diff --git a/mlir/test/Dialect/Tosa/tosa-decompose-depthwise.mlir b/mlir/test/Dialect/Tosa/tosa-decompose-depthwise.mlir
index 14adb96ec900c..2450d1eb84d1d 100644
--- a/mlir/test/Dialect/Tosa/tosa-decompose-depthwise.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-decompose-depthwise.mlir
@@ -7,9 +7,7 @@ func.func @depthwise_conv2d_as_mul(%arg0: tensor<4x10x10x2xf32>, %arg1: tensor<1
// CHECK-NOT: "tosa.depthwise_conv2d"
// CHECK: %[[VAR0:.*]] = "tosa.reshape"(%arg0) {new_shape = [4, 10, 10, 2, 1]}
// CHECK-SAME: -> tensor<4x10x10x2x1xf32>
- // CHECK: %[[VAR1:.*]] = "tosa.reshape"(%arg1) {new_shape = [1, 1, 1, 2, 3]}
- // CHECK-SAME: -> tensor<1x1x1x2x3xf32>
- // CHECK: %[[VAR2:.*]] = "tosa.mul"(%[[VAR0]], %[[VAR1]])
+ // CHECK: %[[VAR2:.*]] = "tosa.mul"(%[[VAR0]], %arg1)
// CHECK-SAME: -> tensor<4x10x10x2x3xf32>
// CHECK: %[[VAR3:.*]] = "tosa.reshape"(%[[VAR2]]) {new_shape = [4, 10, 10, 6]}
// CHECK-SAME: -> tensor<4x10x10x6xf32>
@@ -24,9 +22,31 @@ func.func @depthwise_conv2d_as_mul(%arg0: tensor<4x10x10x2xf32>, %arg1: tensor<1
// CHECK-LABEL: @depthwise_conv2d_as_mul_q
func.func @depthwise_conv2d_as_mul_q(%arg0: tensor<4x10x10x2xi8>, %arg1: tensor<1x1x2x3xi8>, %arg2: tensor<6xi32>) -> tensor<4x10x10x6xi32> {
- // CHECK: "tosa.depthwise_conv2d"
- %0 = "tosa.depthwise_conv2d"(%arg0, %arg1, %arg2) {pad = [0, 0, 0, 0], stride = [1, 1], dilation = [1, 1], quantization_info = #tosa.conv_quant<input_zp = 0, weight_zp = 0>} : (tensor<4x10x10x2xi8>, tensor<1x1x2x3xi8>, tensor<6xi32>) -> tensor<4x10x10x6xi32>
+ // CHECK: %[[iZp:.+]] = "tosa.const"() {value = dense<7> : tensor<i32>}
+ // CHECK: %[[wZp:.+]] = "tosa.const"() {value = dense<11> : tensor<i32>}
+ // CHECK: %[[rIn:.+]] = "tosa.reshape"(%arg0) {new_shape = [4, 10, 10, 2, 1]}
+ // CHECK: %[[cIn:.+]] = "tosa.cast"(%[[rIn]]) : (tensor<4x10x10x2x1xi8>) -> tensor<4x10x10x2x1xi32>
+ // CHECK: %[[cWe:.+]] = "tosa.cast"(%arg1) : (tensor<1x1x2x3xi8>) -> tensor<1x1x2x3xi32>
+ // CHECK: %[[sIn:.+]] = "tosa.sub"(%[[cIn]], %[[iZp]])
+ // CHECK: %[[sWe:.+]] = "tosa.sub"(%[[cWe]], %[[wZp]])
+ // CHECK: %[[mul:.+]] = "tosa.mul"(%[[sIn]], %[[sWe]]) {shift = 0 : i32}
+ // CHECK: %[[reO:.+]] = "tosa.reshape"(%[[mul]]) {new_shape = [4, 10, 10, 6]}
+ // CHECK: %[[add:.+]] = "tosa.add"(%[[reO]], %arg2)
+ %0 = "tosa.depthwise_conv2d"(%arg0, %arg1, %arg2) {pad = [0, 0, 0, 0], stride = [1, 1], dilation = [1, 1], quantization_info = #tosa.conv_quant<input_zp = 7, weight_zp = 11>} : (tensor<4x10x10x2xi8>, tensor<1x1x2x3xi8>, tensor<6xi32>) -> tensor<4x10x10x6xi32>
return %0 : tensor<4x10x10x6xi32>
}
// -----
+
+// CHECK-LABEL: @depthwise_conv2d_as_mul_padded
+func.func @depthwise_conv2d_as_mul_padded(%arg0: tensor<4x10x10x2xf32>, %arg1: tensor<1x1x2x3xf32>, %arg2: tensor<6xf32>) -> tensor<4x12x12x6xf32> {
+ // CHECK: %[[pad:.+]] = "tosa.const"() {value = dense<{{\[\[}}0, 0], [1, 1], [1, 1], [0, 0], [0, 0]]> : tensor<5x2xi64>}
+ // CHECK: %[[zero:.+]] = "tosa.const"() {value = dense<0.000000e+00> : tensor<f32>}
+ // CHECK: %[[reIn:.+]] = "tosa.reshape"(%arg0) {new_shape = [4, 10, 10, 2, 1]}
+ // CHECK: %[[padded:.+]] = "tosa.pad"(%[[reIn]], %[[pad]], %[[zero]]) : (tensor<4x10x10x2x1xf32>, tensor<5x2xi64>, tensor<f32>) -> tensor<4x12x12x2x1xf32>
+  // CHECK: %[[mul:.+]] = "tosa.mul"(%[[padded]], %arg1) {shift = 0 : i32}
+ // CHECK: %[[reOut:.+]] = "tosa.reshape"(%[[mul]]) {new_shape = [4, 12, 12, 6]}
+ // CHECK: %[[add:.+]] = "tosa.add"(%[[reOut]], %arg2)
+ %0 = "tosa.depthwise_conv2d"(%arg0, %arg1, %arg2) {pad = [1, 1, 1, 1], stride = [1, 1], dilation = [1, 1]} : (tensor<4x10x10x2xf32>, tensor<1x1x2x3xf32>, tensor<6xf32>) -> tensor<4x12x12x6xf32>
+ return %0 : tensor<4x12x12x6xf32>
+}
More information about the Mlir-commits
mailing list