[Mlir-commits] [mlir] 9a2308e - [mlir][tosa] Minor cleanup of tosa.conv2d canonicalizer
Rob Suderman
llvmlistbot at llvm.org
Thu Dec 16 15:14:05 PST 2021
Author: Rob Suderman
Date: 2021-12-16T15:13:01-08:00
New Revision: 9a2308e170b468df62dd969a41b8c408992cf84e
URL: https://github.com/llvm/llvm-project/commit/9a2308e170b468df62dd969a41b8c408992cf84e
DIFF: https://github.com/llvm/llvm-project/commit/9a2308e170b468df62dd969a41b8c408992cf84e.diff
LOG: [mlir][tosa] Minor cleanup of tosa.conv2d canonicalizer
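Slightly rename the pattern and make better use of variable types in the
tosa.conv2d to tosa.fully_connected lowering: the already-cast ranked types are
reused instead of repeating dyn_cast calls. Also disable the rewrite for padded
convolutions.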
Reviewed By: not-jenni
Differential Revision: https://reviews.llvm.org/D115776
Added:
Modified:
mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
mlir/test/Dialect/Tosa/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 579f0b407ee8e..68c0d015c3a60 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -423,8 +423,7 @@ void PadOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
results.insert<MaterializePadValue>(context);
}
-struct Conv2DFullyConnectedOptimization
- : public OpRewritePattern<tosa::Conv2DOp> {
+struct Conv2DIsFullyConnected : public OpRewritePattern<tosa::Conv2DOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(tosa::Conv2DOp op,
@@ -439,6 +438,12 @@ struct Conv2DFullyConnectedOptimization
return failure();
}
+ for (Attribute pad : op.pad().getValue()) {
+ if (!pad.cast<IntegerAttr>().getValue().isZero()) {
+ return failure();
+ }
+ }
+
// Stride must be 1 for this optimization.
for (Attribute stride : op.stride().getValue()) {
if (!stride.cast<IntegerAttr>().getValue().isOne()) {
@@ -456,9 +461,8 @@ struct Conv2DFullyConnectedOptimization
ArrayRef<int64_t> inputShape = inputType.getShape();
llvm::SmallVector<int64_t, 2> revisedInputShape{
inputShape[0] * inputShape[1] * inputShape[2], inputShape[3]};
- auto revisedInputShapeType = RankedTensorType::get(
- revisedInputShape,
- input.getType().dyn_cast<RankedTensorType>().getElementType());
+ auto revisedInputShapeType =
+ RankedTensorType::get(revisedInputShape, inputType.getElementType());
auto reshapedInput = rewriter
.create<tosa::ReshapeOp>(
op.getLoc(), revisedInputShapeType, input,
@@ -468,9 +472,8 @@ struct Conv2DFullyConnectedOptimization
// Reshape kernel to [OC,KH,KW,IC] -> [OC, IC].
llvm::SmallVector<int64_t, 2> revisedWeightShape{weightShape[0],
weightShape[3]};
- auto revisedWeightShapeType = RankedTensorType::get(
- revisedWeightShape,
- weight.getType().dyn_cast<RankedTensorType>().getElementType());
+ auto revisedWeightShapeType =
+ RankedTensorType::get(revisedWeightShape, weightType.getElementType());
auto reshapedWeight = rewriter
.create<tosa::ReshapeOp>(
op.getLoc(), revisedWeightShapeType, weight,
@@ -480,9 +483,8 @@ struct Conv2DFullyConnectedOptimization
// Perform a fully connected network over the reshaped input and weight.
llvm::SmallVector<int64_t, 2> fullyConnectedShape{
inputShape[0] * inputShape[1] * inputShape[2], weightShape[0]};
- auto fullyConnectedShapeType = RankedTensorType::get(
- fullyConnectedShape,
- resultType.dyn_cast<ShapedType>().getElementType());
+ auto fullyConnectedShapeType =
+ RankedTensorType::get(fullyConnectedShape, resultType.getElementType());
Value fullyConnectedValue;
if (op.quantization_info()) {
@@ -512,7 +514,7 @@ struct Conv2DFullyConnectedOptimization
void Conv2DOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
MLIRContext *context) {
- results.insert<Conv2DFullyConnectedOptimization>(context);
+ results.insert<Conv2DIsFullyConnected>(context);
}
struct DepthwiseConv2DMulOptimization
diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
index 7436bfa8ba9ff..91c2e3ce7feb6 100644
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -105,6 +105,15 @@ func @conv2d_as_fully_connected_quant(%arg0: tensor<4x10x10x2xi8>, %arg1: tensor
// -----
+// CHECK-LABEL: @conv2d_padded
+func @conv2d_padded(%arg0: tensor<4x10x10x2xf32>, %arg1: tensor<3x1x1x2xf32>, %arg2: tensor<3xf32>) -> tensor<4x12x12x3xf32> {
+ // CHECK: "tosa.conv2d"
+ %0 = "tosa.conv2d"(%arg0, %arg1, %arg2) {pad = [1, 1, 1, 1], stride = [1, 1], dilation = [1, 1]} : (tensor<4x10x10x2xf32>, tensor<3x1x1x2xf32>, tensor<3xf32>) -> tensor<4x12x12x3xf32>
+ return %0 : tensor<4x12x12x3xf32>
+}
+
+// -----
+
// CHECK-LABEL: @conv2d_stride_2
func @conv2d_stride_2(%arg0: tensor<4x10x10x2xf32>) -> tensor<4x10x10x3xf32> {
// CHECK: "tosa.conv2d"
More information about the Mlir-commits mailing list