[Mlir-commits] [mlir] e0537d1 - [TOSA] Refactor TosaMakeBroadcastable pass
Eric Kunze
llvmlistbot at llvm.org
Wed May 24 14:56:15 PDT 2023
Author: Tai Ly
Date: 2023-05-24T14:43:33-07:00
New Revision: e0537d1ad4b9aa41928ecf7eff75d161f456059f
URL: https://github.com/llvm/llvm-project/commit/e0537d1ad4b9aa41928ecf7eff75d161f456059f
DIFF: https://github.com/llvm/llvm-project/commit/e0537d1ad4b9aa41928ecf7eff75d161f456059f.diff
LOG: [TOSA] Refactor TosaMakeBroadcastable pass
This refactors the EqualizeRanks utility function out of the
TosaMakeBroadcastable pass and exposes it so that it can be used to
reshape operator inputs to equal ranks.
Signed-off-by: Tai Ly <tai.ly at arm.com>
Differential Revision: https://reviews.llvm.org/D150283
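For context, here is a minimal sketch of how a downstream pattern might call
the newly exposed utility. The helper name and its operands are illustrative
assumptions, not part of this commit; only tosa::EqualizeRanks itself comes
from the patch.

#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
#include "mlir/IR/PatternMatch.h"

using namespace mlir;

// Hypothetical helper: make two operands broadcast-compatible, then emit
// a tosa.add. EqualizeRanks reshapes the lower-rank value (e.g.
// tensor<6xf32> -> tensor<1x1x1x6xf32> against a rank-4 partner) and
// updates the Value references it is given.
static LogicalResult rewriteAsBroadcastAdd(PatternRewriter &rewriter,
                                           Operation *op, Type resultType,
                                           Value lhs, Value rhs) {
  // Fails if an operand is unranked or the shapes cannot be broadcast.
  if (tosa::EqualizeRanks(rewriter, op->getLoc(), lhs, rhs).failed())
    return failure();
  rewriter.replaceOpWithNewOp<tosa::AddOp>(op, resultType, lhs, rhs);
  return success();
}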
Added:
Modified:
mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h
mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp
mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp
mlir/lib/Dialect/Tosa/Transforms/TosaMakeBroadcastable.cpp
mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp
mlir/test/Dialect/Tosa/tosa-decompose-depthwise.mlir
mlir/test/Dialect/Tosa/tosa-decompose-transpose-conv.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h b/mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h
index f425d376fbc7e..ca59b221d03eb 100644
--- a/mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h
+++ b/mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h
@@ -72,6 +72,13 @@ checkHasDynamicBatchDims(PatternRewriter &rewriter, Op op,
return dynamicDims;
}
+/// Common code to create a reshape op where necessary to make the ranks of
+/// two values equal. input1 and input2 will be updated when a rank has
+/// changed. The caller is expected to use these to rewrite the original
+/// operator with the RESHAPE now in the graph.
+LogicalResult EqualizeRanks(PatternRewriter &rewriter, Location loc,
+ Value &input1, Value &input2);
+
} // namespace tosa
} // namespace mlir
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp
index 488e46d1339a1..e6fba211dc37a 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp
@@ -13,6 +13,7 @@
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Tosa/Transforms/Passes.h"
+#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
#include "mlir/Pass/Pass.h"
using namespace mlir;
@@ -77,7 +78,9 @@ struct DepthwiseConv2DIsMul : public OpRewritePattern<tosa::DepthwiseConv2DOp> {
if (zp == 0)
return val;
auto ety = cast<ShapedType>(val.getType()).getElementType();
- auto zpTy = RankedTensorType::get({}, ety);
+ std::vector<int64_t> shape(cast<ShapedType>(val.getType()).getRank(),
+ 1);
+ auto zpTy = RankedTensorType::get(shape, ety);
auto zpAttr =
DenseElementsAttr::get(zpTy, rewriter.getIntegerAttr(ety, zp));
auto zpVal = rewriter.create<tosa::ConstOp>(op.getLoc(), zpTy, zpAttr);
@@ -127,6 +130,11 @@ struct DepthwiseConv2DIsMul : public OpRewritePattern<tosa::DepthwiseConv2DOp> {
auto mulShapeType = RankedTensorType::get(
mulShape,
dyn_cast<RankedTensorType>(weight.getType()).getElementType());
+
+ if (EqualizeRanks(rewriter, op.getLoc(), input, weight).failed()) {
+ return failure();
+ }
+
Value mulValue = rewriter
.create<tosa::MulOp>(op.getLoc(), mulShapeType, input,
weight, /*shift=*/0)
@@ -137,14 +145,18 @@ struct DepthwiseConv2DIsMul : public OpRewritePattern<tosa::DepthwiseConv2DOp> {
auto outputShapeType = RankedTensorType::get(
outputShape,
dyn_cast<RankedTensorType>(input.getType()).getElementType());
- auto outputValue = rewriter.create<tosa::ReshapeOp>(
+ Value outputValue = rewriter.create<tosa::ReshapeOp>(
op.getLoc(), outputShapeType, mulValue,
rewriter.getDenseI64ArrayAttr(outputShape));
+ Value bias = op.getBias();
+ if (EqualizeRanks(rewriter, op.getLoc(), outputValue, bias).failed()) {
+ return failure();
+ }
+
// Add in the bias.
rewriter
- .replaceOpWithNewOp<tosa::AddOp>(op, outputShapeType, outputValue,
- op.getBias())
+ .replaceOpWithNewOp<tosa::AddOp>(op, outputShapeType, outputValue, bias)
.getResult();
return success();
}
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp
index 50a556dfc6945..2339fb7fde3dc 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp
@@ -17,6 +17,7 @@
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Tosa/Transforms/Passes.h"
+#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
#include "mlir/Dialect/Tosa/Utils/ShapeUtils.h"
#include "mlir/Pass/Pass.h"
@@ -365,10 +366,14 @@ class TransposeConvStridedConverter
Value resultPaddingVal = createOpAndInfer<tosa::ConstOp>(
rewriter, loc, resultPaddingAttr.getType(), resultPaddingAttr);
- auto resultPad = createOpAndInfer<tosa::PadOp>(
+ Value resultPad = createOpAndInfer<tosa::PadOp>(
rewriter, loc, UnrankedTensorType::get(resultETy), slice,
resultPaddingVal);
+ if (EqualizeRanks(rewriter, op.getLoc(), resultPad, bias).failed()) {
+ return failure();
+ }
+
rewriter.replaceOpWithNewOp<tosa::AddOp>(op, op.getType(), resultPad, bias);
return success();
}
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaMakeBroadcastable.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaMakeBroadcastable.cpp
index bcfcbbbbcee69..829db2a86c44a 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaMakeBroadcastable.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaMakeBroadcastable.cpp
@@ -14,6 +14,7 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Tosa/Transforms/Passes.h"
+#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
#include "mlir/Dialect/Tosa/Utils/QuantUtils.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
@@ -28,60 +29,17 @@ namespace tosa {
using namespace mlir;
using namespace mlir::tosa;
-/// There are two potential ways implementing broadcast:
-/// a. https://www.tensorflow.org/xla/broadcasting#formal_definition
-/// b. https://numpy.org/doc/stable/user/basics.broadcasting.html
-/// This pass implements b (numpy style) now.
-
-/// In this pass, we insert RESHAPE operators to increase the rank of the
-/// lower rank operand as a first step in the broadcasting process. The TOSA
-/// operators that support broadcast require that the rank of the operands
-/// are equal.
-
-// Examples:
-// If lower=[c], higher=[a, b, c], [c] reshaped into [1, 1, c].
-// If lower=[b, c], higher=[a, b, c], [b, c] reshaped into [1, b, c].
-// If lower=[a], higher=[a, a], [a] reshaped into [1, a].
-// If lower=[a], target=[a, b, a], [a] reshaped into [1, 1, a].
-// If lower=[], target=[a, b, c], [] reshaped into [1, 1, 1].
-
-static LogicalResult
-computeReshapeOutput(ArrayRef<int64_t> higherRankShape,
- ArrayRef<int64_t> lowerRankShape,
- SmallVectorImpl<int64_t> &reshapeOutputShape) {
- // Initialize new shapes with [1] * higherRank.
- int64_t higherRank = higherRankShape.size();
- int64_t lowerRank = lowerRankShape.size();
-
- reshapeOutputShape.assign(higherRank, 1);
-
- int64_t higherRankDim;
- int64_t lowerRankDim;
-
- for (int64_t i = higherRank - 1, j = lowerRank - 1; i >= 0 && j >= 0;
- i--, j--) {
- higherRankDim = higherRankShape[i];
- lowerRankDim = lowerRankShape[j];
-
- if (lowerRankDim == 1 && higherRankDim > 1)
- reshapeOutputShape[i] = 1;
- else if ((lowerRankDim > 1 && higherRankDim == 1) ||
- (lowerRankDim == higherRankDim))
- reshapeOutputShape[i] = lowerRankDim;
- else if (higherRankDim != lowerRankDim)
- return failure();
- }
- return success();
-}
+namespace {
/// Common code to create a reshape op where necessary to make the ranks of the
/// operands equal. input1 and input2 will be updated when a rank has
/// changed. The caller is expected to use these to rewrite the original
/// operator with the RESHAPE now in the graph.
-static LogicalResult reshapeLowerToHigher(PatternRewriter &rewriter,
- Location loc,
- RankedTensorType outputType,
- Value &input1, Value &input2) {
+/// Returns failure when (1) no reshape is needed, or (2) output_type is
+/// specified and has a different rank.
+LogicalResult reshapeLowerToHigher(PatternRewriter &rewriter, Location loc,
+ RankedTensorType outputType, Value &input1,
+ Value &input2) {
auto input1Ty = dyn_cast<RankedTensorType>(input1.getType());
auto input2Ty = dyn_cast<RankedTensorType>(input2.getType());
@@ -96,54 +54,28 @@ static LogicalResult reshapeLowerToHigher(PatternRewriter &rewriter,
return rewriter.notifyMatchFailure(loc,
"cannot rewrite as its already correct");
- Value higherTensorValue, lowerTensorValue;
- if (input1Rank > input2Rank) {
- higherTensorValue = input1;
- lowerTensorValue = input2;
- } else {
- higherTensorValue = input2;
- lowerTensorValue = input1;
+ Value input1_copy = input1;
+ Value input2_copy = input2;
+ if (EqualizeRanks(rewriter, loc, input1_copy, input2_copy).failed()) {
+ return rewriter.notifyMatchFailure(loc, "failed to reshape inputs");
}
- ArrayRef<int64_t> higherRankShape =
- cast<RankedTensorType>(higherTensorValue.getType()).getShape();
- ArrayRef<int64_t> lowerRankShape =
- cast<RankedTensorType>(lowerTensorValue.getType()).getShape();
-
- SmallVector<int64_t, 4> reshapeOutputShape;
-
- if (computeReshapeOutput(higherRankShape, lowerRankShape, reshapeOutputShape)
- .failed())
- return rewriter.notifyMatchFailure(loc, "fail to compute a reshape type");
-
- auto reshapeInputType = cast<RankedTensorType>(lowerTensorValue.getType());
- auto reshapeOutputType = RankedTensorType::get(
- ArrayRef<int64_t>(reshapeOutputShape), reshapeInputType.getElementType());
-
// Verify the rank agrees with the output type if the output type is ranked.
if (outputType) {
- if (outputType.getShape().size() != reshapeOutputShape.size() ||
- outputType.getShape().size() != higherRankShape.size())
+ if (outputType.getRank() !=
+ input1_copy.getType().cast<RankedTensorType>().getRank() ||
+ outputType.getRank() !=
+ input2_copy.getType().cast<RankedTensorType>().getRank())
return rewriter.notifyMatchFailure(
loc, "the reshaped type doesn't agrees with the ranked output type");
}
- auto reshapeLower = rewriter.create<tosa::ReshapeOp>(
- loc, reshapeOutputType, lowerTensorValue,
- rewriter.getDenseI64ArrayAttr(reshapeOutputShape));
-
- if (input1Rank > input2Rank) {
- input1 = higherTensorValue;
- input2 = reshapeLower.getResult();
- } else {
- input1 = reshapeLower.getResult();
- input2 = higherTensorValue;
- }
+ input1 = input1_copy;
+ input2 = input2_copy;
return success();
}
-namespace {
template <typename OpTy>
struct ConvertTosaOp : public OpRewritePattern<OpTy> {
using OpRewritePattern<OpTy>::OpRewritePattern;
@@ -268,8 +200,10 @@ struct ConvertTosaOp<tosa::SelectOp> : public OpRewritePattern<tosa::SelectOp> {
int32_t result1Rank = cast<RankedTensorType>(input1.getType()).getRank();
int32_t result2Rank = cast<RankedTensorType>(input2.getType()).getRank();
int32_t result3Rank = cast<RankedTensorType>(input3.getType()).getRank();
+ int32_t outputRank = outputType.getRank();
- if ((result1Rank != result2Rank) || (result2Rank != result3Rank))
+ if ((result1Rank != result2Rank) || (result2Rank != result3Rank) ||
+ (result1Rank != outputRank))
return rewriter.notifyMatchFailure(
tosaOp, "not all ranks are aligned with each other");
diff --git a/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp b/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp
index 346ff860d2884..8f84a064382f4 100644
--- a/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp
+++ b/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp
@@ -11,6 +11,7 @@
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
+#include "mlir/Dialect/Tosa/IR/TosaOps.h"
using namespace mlir;
using namespace mlir::tosa;
@@ -60,3 +61,96 @@ bool mlir::tosa::validIntegerRange(IntegerType ty, int64_t value) {
APInt intMax = APInt::getSignedMaxValue(bitwidth);
return value >= intMin.getSExtValue() && value <= intMax.getSExtValue();
}
+
+namespace {
+// Given two tensors of high and low ranks, derive the output shape
+// to reshape the lower rank to.
+// Examples:
+// If lower=[c], higher=[a, b, c], [c] reshaped into [1, 1, c].
+// If lower=[b, c], higher=[a, b, c], [b, c] reshaped into [1, b, c].
+// If lower=[a], higher=[a, a], [a] reshaped into [1, a].
+// If lower=[a], target=[a, b, a], [a] reshaped into [1, 1, a].
+// If lower=[], target=[a, b, c], [] reshaped into [1, 1, 1].
+LogicalResult
+computeReshapeOutput(ArrayRef<int64_t> higherRankShape,
+ ArrayRef<int64_t> lowerRankShape,
+ SmallVectorImpl<int64_t> &reshapeOutputShape) {
+ // Initialize new shapes with [1] * higherRank.
+ int64_t higherRank = higherRankShape.size();
+ int64_t lowerRank = lowerRankShape.size();
+
+ reshapeOutputShape.assign(higherRank, 1);
+
+ int64_t higherRankDim;
+ int64_t lowerRankDim;
+
+ for (int64_t i = higherRank - 1, j = lowerRank - 1; i >= 0 && j >= 0;
+ i--, j--) {
+ higherRankDim = higherRankShape[i];
+ lowerRankDim = lowerRankShape[j];
+
+ if (lowerRankDim == 1 && higherRankDim > 1)
+ reshapeOutputShape[i] = 1;
+ else if ((lowerRankDim > 1 && higherRankDim == 1) ||
+ (lowerRankDim == higherRankDim))
+ reshapeOutputShape[i] = lowerRankDim;
+ else if (higherRankDim != lowerRankDim)
+ return failure();
+ }
+ return success();
+}
+} // namespace
+
+LogicalResult mlir::tosa::EqualizeRanks(PatternRewriter &rewriter, Location loc,
+ Value &input1, Value &input2) {
+ auto input1Ty = input1.getType().dyn_cast<RankedTensorType>();
+ auto input2Ty = input2.getType().dyn_cast<RankedTensorType>();
+
+ if (!input1Ty || !input2Ty) {
+ return failure();
+ }
+
+ int64_t input1Rank = input1Ty.getRank();
+ int64_t input2Rank = input2Ty.getRank();
+
+ if (input1Rank == input2Rank)
+ return success();
+
+ Value higherTensorValue, lowerTensorValue;
+ if (input1Rank > input2Rank) {
+ higherTensorValue = input1;
+ lowerTensorValue = input2;
+ } else {
+ higherTensorValue = input2;
+ lowerTensorValue = input1;
+ }
+
+ ArrayRef<int64_t> higherRankShape =
+ higherTensorValue.getType().cast<RankedTensorType>().getShape();
+ ArrayRef<int64_t> lowerRankShape =
+ lowerTensorValue.getType().cast<RankedTensorType>().getShape();
+
+ SmallVector<int64_t, 4> reshapeOutputShape;
+
+ if (computeReshapeOutput(higherRankShape, lowerRankShape, reshapeOutputShape)
+ .failed())
+ return failure();
+
+ auto reshapeInputType = lowerTensorValue.getType().cast<RankedTensorType>();
+ auto reshapeOutputType = RankedTensorType::get(
+ ArrayRef<int64_t>(reshapeOutputShape), reshapeInputType.getElementType());
+
+ auto reshapeLower = rewriter.create<tosa::ReshapeOp>(
+ loc, reshapeOutputType, lowerTensorValue,
+ rewriter.getDenseI64ArrayAttr(reshapeOutputShape));
+
+ if (input1Rank > input2Rank) {
+ input1 = higherTensorValue;
+ input2 = reshapeLower.getResult();
+ } else {
+ input1 = reshapeLower.getResult();
+ input2 = higherTensorValue;
+ }
+
+ return success();
+}
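As an aside: on success, the reshape rule in computeReshapeOutput above
reduces to right-aligning the lower-rank shape and padding its leading
dimensions with 1. A standalone sketch of just that padding step (an
illustration, not part of the commit, and omitting the dimension-compatibility
checks the real function performs):

#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

// Pad the lower-rank shape with leading 1s up to the higher rank.
static std::vector<int64_t> padToRank(const std::vector<int64_t> &lower,
                                      size_t higherRank) {
  std::vector<int64_t> out(higherRank, 1);
  std::copy(lower.begin(), lower.end(), out.end() - lower.size());
  return out;
}

int main() {
  // Matches the reshaped bias in the tests below: [6] -> [1, 1, 1, 6].
  assert((padToRank({6}, 4) == std::vector<int64_t>{1, 1, 1, 6}));
  // Matches the reshaped weight: [1, 1, 2, 3] -> [1, 1, 1, 2, 3].
  assert((padToRank({1, 1, 2, 3}, 5) ==
          std::vector<int64_t>{1, 1, 1, 2, 3}));
  return 0;
}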
diff --git a/mlir/test/Dialect/Tosa/tosa-decompose-depthwise.mlir b/mlir/test/Dialect/Tosa/tosa-decompose-depthwise.mlir
index e835991273ec5..59e7d35bf77b2 100644
--- a/mlir/test/Dialect/Tosa/tosa-decompose-depthwise.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-decompose-depthwise.mlir
@@ -7,13 +7,17 @@ func.func @depthwise_conv2d_as_mul(%arg0: tensor<4x10x10x2xf32>, %arg1: tensor<1
// CHECK-NOT: "tosa.depthwise_conv2d"
// CHECK: %[[VAR0:.*]] = "tosa.reshape"(%arg0) <{new_shape = array<i64: 4, 10, 10, 2, 1>}
// CHECK-SAME: -> tensor<4x10x10x2x1xf32>
- // CHECK: %[[VAR2:.*]] = "tosa.mul"(%[[VAR0]], %arg1)
+ // CHECK: %[[VAR1:.*]] = "tosa.reshape"(%arg1) <{new_shape = array<i64: 1, 1, 1, 2, 3>}
+ // CHECK-SAME: -> tensor<1x1x1x2x3xf32>
+ // CHECK: %[[VAR2:.*]] = "tosa.mul"(%[[VAR0]], %[[VAR1]])
// CHECK-SAME: -> tensor<4x10x10x2x3xf32>
// CHECK: %[[VAR3:.*]] = "tosa.reshape"(%[[VAR2]]) <{new_shape = array<i64: 4, 10, 10, 6>}
// CHECK-SAME: -> tensor<4x10x10x6xf32>
- // CHECK: %[[VAR4:.*]] = "tosa.add"(%[[VAR3]], %arg2)
+ // CHECK: %[[VAR4:.*]] = "tosa.reshape"(%arg2) <{new_shape = array<i64: 1, 1, 1, 6>}
+ // CHECK-SAME: -> tensor<1x1x1x6xf32>
+ // CHECK: %[[VAR5:.*]] = "tosa.add"(%[[VAR3]], %[[VAR4]])
// CHECK-SAME: -> tensor<4x10x10x6xf32>
- // CHECK: return %[[VAR4]]
+ // CHECK: return %[[VAR5]]
%0 = "tosa.depthwise_conv2d"(%arg0, %arg1, %arg2) {pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, dilation = array<i64: 1, 1>} : (tensor<4x10x10x2xf32>, tensor<1x1x2x3xf32>, tensor<6xf32>) -> tensor<4x10x10x6xf32>
return %0 : tensor<4x10x10x6xf32>
}
@@ -22,16 +26,18 @@ func.func @depthwise_conv2d_as_mul(%arg0: tensor<4x10x10x2xf32>, %arg1: tensor<1
// CHECK-LABEL: @depthwise_conv2d_as_mul_q
func.func @depthwise_conv2d_as_mul_q(%arg0: tensor<4x10x10x2xi8>, %arg1: tensor<1x1x2x3xi8>, %arg2: tensor<6xi32>) -> tensor<4x10x10x6xi32> {
- // CHECK: %[[iZp:.+]] = "tosa.const"() <{value = dense<7> : tensor<i32>}
- // CHECK: %[[wZp:.+]] = "tosa.const"() <{value = dense<11> : tensor<i32>}
+ // CHECK: %[[iZp:.+]] = "tosa.const"() <{value = dense<7> : tensor<1x1x1x1x1xi32>}
+ // CHECK: %[[wZp:.+]] = "tosa.const"() <{value = dense<11> : tensor<1x1x1x1xi32>}
// CHECK: %[[rIn:.+]] = "tosa.reshape"(%arg0) <{new_shape = array<i64: 4, 10, 10, 2, 1>}
// CHECK: %[[cIn:.+]] = "tosa.cast"(%[[rIn]]) : (tensor<4x10x10x2x1xi8>) -> tensor<4x10x10x2x1xi32>
// CHECK: %[[cWe:.+]] = "tosa.cast"(%arg1) : (tensor<1x1x2x3xi8>) -> tensor<1x1x2x3xi32>
// CHECK: %[[sIn:.+]] = "tosa.sub"(%[[cIn]], %[[iZp]])
// CHECK: %[[sWe:.+]] = "tosa.sub"(%[[cWe]], %[[wZp]])
- // CHECK: %[[mul:.+]] = "tosa.mul"(%[[sIn]], %[[sWe]]) <{shift = 0 : i32}
+ // CHECK: %[[resWe:.+]] = "tosa.reshape"(%[[sWe]]) <{new_shape = array<i64: 1, 1, 1, 2, 3>}
+ // CHECK: %[[mul:.+]] = "tosa.mul"(%[[sIn]], %[[resWe]]) <{shift = 0 : i32}
// CHECK: %[[reO:.+]] = "tosa.reshape"(%[[mul]]) <{new_shape = array<i64: 4, 10, 10, 6>}
- // CHECK: %[[add:.+]] = "tosa.add"(%[[reO]], %arg2)
+ // CHECK: %[[reArg2:.+]] = "tosa.reshape"(%arg2) <{new_shape = array<i64: 1, 1, 1, 6>}
+ // CHECK: %[[add:.+]] = "tosa.add"(%[[reO]], %[[reArg2]])
%0 = "tosa.depthwise_conv2d"(%arg0, %arg1, %arg2) {pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, dilation = array<i64: 1, 1>, quantization_info = #tosa.conv_quant<input_zp = 7, weight_zp = 11>} : (tensor<4x10x10x2xi8>, tensor<1x1x2x3xi8>, tensor<6xi32>) -> tensor<4x10x10x6xi32>
return %0 : tensor<4x10x10x6xi32>
}
@@ -44,9 +50,11 @@ func.func @depthwise_conv2d_as_mul_padded(%arg0: tensor<4x10x10x2xf32>, %arg1: t
// CHECK: %[[zero:.+]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<f32>}
// CHECK: %[[reIn:.+]] = "tosa.reshape"(%arg0) <{new_shape = array<i64: 4, 10, 10, 2, 1>}
// CHECK: %[[padded:.+]] = "tosa.pad"(%[[reIn]], %[[pad]], %[[zero]]) : (tensor<4x10x10x2x1xf32>, tensor<5x2xi64>, tensor<f32>) -> tensor<4x12x12x2x1xf32>
- // CHECK: %[[mul:.+]] = "tosa.mul"(%3, %arg1) <{shift = 0 : i32}
+ // CHECK: %[[reArg1:.+]] = "tosa.reshape"(%arg1) <{new_shape = array<i64: 1, 1, 1, 2, 3>}
+ // CHECK: %[[mul:.+]] = "tosa.mul"(%3, %[[reArg1]]) <{shift = 0 : i32}
// CHECK: %[[reOut:.+]] = "tosa.reshape"(%[[mul]]) <{new_shape = array<i64: 4, 12, 12, 6>}
- // CHECK: %[[add:.+]] = "tosa.add"(%[[reOut]], %arg2)
+ // CHECK: %[[reArg2:.+]] = "tosa.reshape"(%arg2) <{new_shape = array<i64: 1, 1, 1, 6>}
+ // CHECK: %[[add:.+]] = "tosa.add"(%[[reOut]], %[[reArg2]])
%0 = "tosa.depthwise_conv2d"(%arg0, %arg1, %arg2) {pad = array<i64: 1, 1, 1, 1>, stride = array<i64: 1, 1>, dilation = array<i64: 1, 1>} : (tensor<4x10x10x2xf32>, tensor<1x1x2x3xf32>, tensor<6xf32>) -> tensor<4x12x12x6xf32>
return %0 : tensor<4x12x12x6xf32>
}
diff --git a/mlir/test/Dialect/Tosa/tosa-decompose-transpose-conv.mlir b/mlir/test/Dialect/Tosa/tosa-decompose-transpose-conv.mlir
index 6ccf510804d99..daac52eddf204 100644
--- a/mlir/test/Dialect/Tosa/tosa-decompose-transpose-conv.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-decompose-transpose-conv.mlir
@@ -28,7 +28,7 @@ func.func @transpose_conv2d_quantized(%arg0: tensor<2x16x14x3xi8>, %arg1: tensor
func.func @transpose_conv2d_quantized_padded(%arg0: tensor<2x16x14x3xi8>, %arg1: tensor<5x3x6x3xi8>, %arg2: tensor<5xi32>) -> (tensor<2x21x26x5xi32>) {
// CHECK-DAG: %[[REV0:.+]] = "tosa.reverse"(%0) <{axis = 2 : i64}
// CHECK-DAG: %[[REV1:.+]] = "tosa.reverse"(%arg1) <{axis = 1 : i64}
- // CHECK: "tosa.conv2d"(%arg0, %1, %arg2)
+ // CHECK: "tosa.conv2d"(%arg0, %1, %arg2)
// CHECK-SAME: dilation = array<i64: 1, 1>, pad = array<i64: 3, 4, 8, 9>,
// CHECK-SAME: quantization_info = #tosa.conv_quant<input_zp = -22, weight_zp = 42>, stride = array<i64: 1, 1>}
%0 = "tosa.transpose_conv2d"(%arg0, %arg1, %arg2) {
@@ -65,7 +65,8 @@ func.func @transpose_conv2d_strided(%arg0: tensor<2x17x15x3xf32>, %arg1: tensor<
// CHECK-DAG: %[[TRANS_OUT:.+]] = "tosa.transpose"(%[[RESHAPE_OUT_1]], %[[TRANS2]])
// CHECK-DAG: %[[RESHAPE_OUT_2:.+]] = "tosa.reshape"(%[[TRANS_OUT]]) <{new_shape = array<i64: 2, 36, 48, 5>}
// CHECK-DAG: %[[SLICE:.+]] = "tosa.slice"(%[[RESHAPE_OUT_2]]) <{size = array<i64: 2, 35, 47, 5>, start = array<i64: 0, 0, 0, 0>}
- // CHECK: %[[ADD:.+]] = "tosa.add"(%[[SLICE]], %arg2)
+ // CHECK-DAG: %[[RESHAPE_ARG2:.+]] = "tosa.reshape"(%arg2) <{new_shape = array<i64: 1, 1, 1, 5>}
+ // CHECK: %[[ADD:.+]] = "tosa.add"(%[[SLICE]], %[[RESHAPE_ARG2]])
%0 = "tosa.transpose_conv2d"(%arg0, %arg1, %arg2) {out_pad = array<i64: 0, 0, 0, 0>, out_shape = array<i64: -1, -1, -1, -1>, stride = array<i64: 2, 3>} : (tensor<2x17x15x3xf32>, tensor<5x3x5x3xf32>, tensor<5xf32>) -> tensor<2x35x47x5xf32>
%1 = tensor.cast %0 : tensor<2x35x47x5xf32> to tensor<2x?x?x5xf32>
return %1 : tensor<2x?x?x5xf32>
@@ -97,8 +98,9 @@ func.func @transpose_conv2d_strided_quantized(%arg0: tensor<2x17x15x3xi8>, %arg1
// CHECK-DAG: %[[TRANS_OUT:.+]] = "tosa.transpose"(%[[RESHAPE_OUT_1]], %[[TRANS2]])
// CHECK-DAG: %[[RESHAPE_OUT_2:.+]] = "tosa.reshape"(%[[TRANS_OUT]]) <{new_shape = array<i64: 2, 36, 48, 5>}
// CHECK-DAG: %[[SLICE:.+]] = "tosa.slice"(%[[RESHAPE_OUT_2]]) <{size = array<i64: 2, 35, 47, 5>, start = array<i64: 0, 0, 0, 0>}
- // CHECK: %[[ADD:.+]] = "tosa.add"(%[[SLICE]], %arg2)
- %0 = "tosa.transpose_conv2d"(%arg0, %arg1, %arg2) <{out_pad = array<i64: 0, 0, 0, 0>, quantization_info = #tosa.conv_quant<input_zp = -22, weight_zp = 42>, out_shape = array<i64: -1, -1, -1, -1>, stride = array<i64: 2, 3>}> : (tensor<2x17x15x3xi8>, tensor<5x3x5x3xi8>, tensor<5xi32>) -> tensor<2x35x47x5xi32>
+ // CHECK-DAG: %[[RESHAPE_ARG2:.+]] = "tosa.reshape"(%arg2) <{new_shape = array<i64: 1, 1, 1, 5>}
+ // CHECK: %[[ADD:.+]] = "tosa.add"(%[[SLICE]], %[[RESHAPE_ARG2]])
+ %0 = "tosa.transpose_conv2d"(%arg0, %arg1, %arg2) {out_pad = array<i64: 0, 0, 0, 0>, quantization_info = #tosa.conv_quant<input_zp = -22, weight_zp = 42>, out_shape = array<i64: -1, -1, -1, -1>, stride = array<i64: 2, 3>} : (tensor<2x17x15x3xi8>, tensor<5x3x5x3xi8>, tensor<5xi32>) -> tensor<2x35x47x5xi32>
return %0 : tensor<2x35x47x5xi32>
}
@@ -106,14 +108,14 @@ func.func @transpose_conv2d_strided_quantized(%arg0: tensor<2x17x15x3xi8>, %arg1
// CHECK-LABEL: @transpose_conv2d_strided_overpad
func.func @transpose_conv2d_strided_overpad(%arg0 : tensor<1x16x1x1xi8>, %arg1 : tensor<1x2x1x1xi8>, %arg2 : tensor<1xi32>) -> (tensor<1x19x2x1xi32>) {
- // CHECK: %[[WEIGHT_PAD:.+]] = "tosa.const"()
+ // CHECK: %[[WEIGHT_PAD:.+]] = "tosa.const"()
// CHECK-SAME{literal}: value = dense<[[0, 0], [0, 0], [0, 1], [0, 0]]> : tensor<4x2xi32>
// CHECK: %[[WEIGHT_PERMS:.+]] = "tosa.const"() <{value = dense<[2, 4, 0, 1, 3, 5]> : tensor<6xi32>}
- // CHECK: %[[INPUT_PAD:.+]] = "tosa.const"()
+ // CHECK: %[[INPUT_PAD:.+]] = "tosa.const"()
// CHECK-SAME{literal}: value = dense<[[0, 0], [1, 1], [0, 0], [0, 0]]> : tensor<4x2xi32>}
// CHECK: %[[ZERO:.+]] = "tosa.const"() <{value = dense<0> : tensor<2xi32>}
// CHECK: %[[RESULT_PERMS:.+]] = "tosa.const"() <{value = dense<[0, 1, 3, 2, 4, 5]> : tensor<6xi32>}
- // CHECK: %[[RESULT_PAD:.+]] = "tosa.const"()
+ // CHECK: %[[RESULT_PAD:.+]] = "tosa.const"()
// CHECK-SAME{literal}: value = dense<[[0, 0], [2, 0], [0, 0], [0, 0]]> : tensor<4x2xi32>}
// CHECK: %[[PAD_WEIGHT:.+]] = "tosa.pad"(%arg1, %[[WEIGHT_PAD]]) <{quantization_info = #tosa.pad_quant<input_zp = 93>}
// CHECK: %[[RESHAPE_WEIGHT_0:.+]] = "tosa.reshape"(%[[PAD_WEIGHT]]) <{new_shape = array<i64: 1, 2, 1, 1, 2, 1>}
@@ -121,13 +123,14 @@ func.func @transpose_conv2d_strided_overpad(%arg0 : tensor<1x16x1x1xi8>, %arg1 :
// CHECK: %[[RESHAPE_WEIGHT_1:.+]] = "tosa.reshape"(%[[TRANSPOSE_WEIGHT]]) <{new_shape = array<i64: 2, 2, 1, 1>}
// CHECK: %[[REVERSE:.+]] = "tosa.reverse"(%[[RESHAPE_WEIGHT_1]]) <{axis = 1 : i64}
// CHECK: %[[PAD_INPUT:.+]] = "tosa.pad"(%arg0, %[[INPUT_PAD]]) <{quantization_info = #tosa.pad_quant<input_zp = -103>}
- // CHECK: %[[CONV:.+]] = "tosa.conv2d"(%[[PAD_INPUT]], %[[REVERSE]], %[[ZERO]])
+ // CHECK: %[[CONV:.+]] = "tosa.conv2d"(%[[PAD_INPUT]], %[[REVERSE]], %[[ZERO]])
// CHECK-SAME{literal}: dilation = [1, 1], pad = [0, 0, 0, 0], quantization_info = #tosa.conv_quant<input_zp = -103, weight_zp = 93>, stride = [1, 1]}
// CHECK: %[[RESHAPE_RESULT_0:.+]] = "tosa.reshape"(%[[CONV]]) <{new_shape = array<i64: 1, 17, 1, 1, 2, 1>}
// CHECK: %[[TRANSPOSE_RESULT:.+]] = "tosa.transpose"(%[[RESHAPE_RESULT_0]], %[[RESULT_PERMS]])
// CHECK: %[[RESHAPE_RESULT_1:.+]] = "tosa.reshape"(%[[TRANSPOSE_RESULT]]) <{new_shape = array<i64: 1, 17, 2, 1>}
// CHECK: %[[PAD_RESULT:.+]] = "tosa.pad"(%[[RESHAPE_RESULT_1]], %[[RESULT_PAD]])
- // CHECK: %[[ADD:.+]] = "tosa.add"(%[[PAD_RESULT]], %arg2)
+ // CHECK: %[[RESHAPE_ARG2:.+]] = "tosa.reshape"(%arg2) <{new_shape = array<i64: 1, 1, 1, 1>}
+ // CHECK: %[[ADD:.+]] = "tosa.add"(%[[PAD_RESULT]], %[[RESHAPE_ARG2]])
%2 = "tosa.transpose_conv2d"(%arg0, %arg1, %arg2) {
out_pad = array<i64: 2, 0, 0, 1>,
out_shape = array<i64: 1, -1, -1, 1>,