[Mlir-commits] [mlir] e0537d1 - [TOSA] Refactor TosaMakeBroadcastable pass

Eric Kunze llvmlistbot at llvm.org
Wed May 24 14:56:15 PDT 2023


Author: Tai Ly
Date: 2023-05-24T14:43:33-07:00
New Revision: e0537d1ad4b9aa41928ecf7eff75d161f456059f

URL: https://github.com/llvm/llvm-project/commit/e0537d1ad4b9aa41928ecf7eff75d161f456059f
DIFF: https://github.com/llvm/llvm-project/commit/e0537d1ad4b9aa41928ecf7eff75d161f456059f.diff

LOG: [TOSA] Refactor TosaMakeBroadcastable pass

This refactors the EqualizeRanks utility function out of the
TosaMakeBroadcastable pass and exposes it so it can be used to reshape
operator inputs to equal ranks.
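
For anyone wanting to call the newly exposed utility, here is a minimal
sketch (not part of this patch; the operand accessors are hypothetical)
of using it from a rewrite pattern:

    #include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"

    // Inside some matchAndRewrite(op, rewriter):
    Value lhs = op.getInput1(); // hypothetical accessors
    Value rhs = op.getInput2();
    // On success, EqualizeRanks reshapes the lower-ranked operand to the
    // higher rank by left-padding its shape with size-1 dimensions, and
    // updates lhs/rhs in place.
    if (failed(mlir::tosa::EqualizeRanks(rewriter, op.getLoc(), lhs, rhs)))
      return failure();
    // lhs and rhs now have equal ranks and can feed a broadcasting op.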

Signed-off-by: Tai Ly <tai.ly at arm.com>

Differential Revision: https://reviews.llvm.org/D150283

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h
    mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp
    mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp
    mlir/lib/Dialect/Tosa/Transforms/TosaMakeBroadcastable.cpp
    mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp
    mlir/test/Dialect/Tosa/tosa-decompose-depthwise.mlir
    mlir/test/Dialect/Tosa/tosa-decompose-transpose-conv.mlir

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h b/mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h
index f425d376fbc7e..ca59b221d03eb 100644
--- a/mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h
+++ b/mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h
@@ -72,6 +72,13 @@ checkHasDynamicBatchDims(PatternRewriter &rewriter, Op op,
   return dynamicDims;
 }
 
+/// Common code to create a reshape op where necessary to make the ranks of two
+/// values equal. input1 and input2 will be updated when a rank has
+/// changed. The caller is expected to use these to rewrite the original
+/// operator with the RESHAPE now in the graph.
+LogicalResult EqualizeRanks(PatternRewriter &rewriter, Location loc,
+                            Value &input1, Value &input2);
+
 } // namespace tosa
 } // namespace mlir
 
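As a concrete example of the documented behavior (taken from the test
updates below): equalizing a bias of type tensor<6xf32> against a
result of type tensor<4x10x10x6xf32> inserts
"tosa.reshape"(%bias) <{new_shape = array<i64: 1, 1, 1, 6>}>, turning
the bias into tensor<1x1x1x6xf32> before the final tosa.add.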

diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp
index 488e46d1339a1..e6fba211dc37a 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp
@@ -13,6 +13,7 @@
 
 #include "mlir/Dialect/Tosa/IR/TosaOps.h"
 #include "mlir/Dialect/Tosa/Transforms/Passes.h"
+#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
 #include "mlir/Pass/Pass.h"
 
 using namespace mlir;
@@ -77,7 +78,9 @@ struct DepthwiseConv2DIsMul : public OpRewritePattern<tosa::DepthwiseConv2DOp> {
         if (zp == 0)
           return val;
         auto ety = cast<ShapedType>(val.getType()).getElementType();
-        auto zpTy = RankedTensorType::get({}, ety);
+        std::vector<int64_t> shape(cast<ShapedType>(val.getType()).getRank(),
+                                   1);
+        auto zpTy = RankedTensorType::get(shape, ety);
         auto zpAttr =
             DenseElementsAttr::get(zpTy, rewriter.getIntegerAttr(ety, zp));
         auto zpVal = rewriter.create<tosa::ConstOp>(op.getLoc(), zpTy, zpAttr);
@@ -127,6 +130,11 @@ struct DepthwiseConv2DIsMul : public OpRewritePattern<tosa::DepthwiseConv2DOp> {
     auto mulShapeType = RankedTensorType::get(
         mulShape,
         dyn_cast<RankedTensorType>(weight.getType()).getElementType());
+
+    if (EqualizeRanks(rewriter, op.getLoc(), input, weight).failed()) {
+      return failure();
+    }
+
     Value mulValue = rewriter
                          .create<tosa::MulOp>(op.getLoc(), mulShapeType, input,
                                               weight, /*shift=*/0)
@@ -137,14 +145,18 @@ struct DepthwiseConv2DIsMul : public OpRewritePattern<tosa::DepthwiseConv2DOp> {
     auto outputShapeType = RankedTensorType::get(
         outputShape,
         dyn_cast<RankedTensorType>(input.getType()).getElementType());
-    auto outputValue = rewriter.create<tosa::ReshapeOp>(
+    Value outputValue = rewriter.create<tosa::ReshapeOp>(
         op.getLoc(), outputShapeType, mulValue,
         rewriter.getDenseI64ArrayAttr(outputShape));
 
+    Value bias = op.getBias();
+    if (EqualizeRanks(rewriter, op.getLoc(), outputValue, bias).failed()) {
+      return failure();
+    }
+
     // Add in the bias.
     rewriter
-        .replaceOpWithNewOp<tosa::AddOp>(op, outputShapeType, outputValue,
-                                         op.getBias())
+        .replaceOpWithNewOp<tosa::AddOp>(op, outputShapeType, outputValue, bias)
         .getResult();
     return success();
   }
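
Note the zero-point change in this pattern: the zero-point constants
are now materialized at the operand's rank with all dimensions equal to
1 (e.g. tensor<1x1x1x1x1xi32> rather than tensor<i32>), so the
following tosa.sub already operates on equal-rank operands; the updated
CHECK lines in tosa-decompose-depthwise.mlir below reflect this.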

diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp
index 50a556dfc6945..2339fb7fde3dc 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp
@@ -17,6 +17,7 @@
 
 #include "mlir/Dialect/Tosa/IR/TosaOps.h"
 #include "mlir/Dialect/Tosa/Transforms/Passes.h"
+#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
 #include "mlir/Dialect/Tosa/Utils/ShapeUtils.h"
 #include "mlir/Pass/Pass.h"
 
@@ -365,10 +366,14 @@ class TransposeConvStridedConverter
     Value resultPaddingVal = createOpAndInfer<tosa::ConstOp>(
         rewriter, loc, resultPaddingAttr.getType(), resultPaddingAttr);
 
-    auto resultPad = createOpAndInfer<tosa::PadOp>(
+    Value resultPad = createOpAndInfer<tosa::PadOp>(
         rewriter, loc, UnrankedTensorType::get(resultETy), slice,
         resultPaddingVal);
 
+    if (EqualizeRanks(rewriter, op.getLoc(), resultPad, bias).failed()) {
+      return failure();
+    }
+
     rewriter.replaceOpWithNewOp<tosa::AddOp>(op, op.getType(), resultPad, bias);
     return success();
   }
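
Note that resultPad is now declared as Value rather than auto (as is
outputValue in the depthwise pass) because EqualizeRanks takes Value&
and may rebind its arguments to the newly created reshape results.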

diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaMakeBroadcastable.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaMakeBroadcastable.cpp
index bcfcbbbbcee69..829db2a86c44a 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaMakeBroadcastable.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaMakeBroadcastable.cpp
@@ -14,6 +14,7 @@
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Tosa/IR/TosaOps.h"
 #include "mlir/Dialect/Tosa/Transforms/Passes.h"
+#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
 #include "mlir/Dialect/Tosa/Utils/QuantUtils.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
@@ -28,60 +29,17 @@ namespace tosa {
 using namespace mlir;
 using namespace mlir::tosa;
 
-/// There are two potential ways implementing broadcast:
-/// a. https://www.tensorflow.org/xla/broadcasting#formal_definition
-/// b. https://numpy.org/doc/stable/user/basics.broadcasting.html
-/// This pass implements b (numpy style) now.
-
-/// In this pass, we insert RESHAPE operators to increase the rank of the
-/// lower rank operand as a first step in the broadcasting process. The TOSA
-/// operators that support broadcast require that the rank of the operands
-/// are equal.
-
-// Examples:
-// If lower=[c], higher=[a, b, c], [c] reshaped into [1, 1, c].
-// If lower=[b, c], higher=[a, b, c], [b, c] reshaped into [1, b, c].
-// If lower=[a], higher=[a, a], [a] reshaped into [1, a].
-// If lower=[a], target=[a, b, a], [a] reshaped into [1, 1, a].
-// If lower=[], target=[a, b, c], [] reshaped into [1, 1, 1].
-
-static LogicalResult
-computeReshapeOutput(ArrayRef<int64_t> higherRankShape,
-                     ArrayRef<int64_t> lowerRankShape,
-                     SmallVectorImpl<int64_t> &reshapeOutputShape) {
-  // Initialize new shapes with [1] * higherRank.
-  int64_t higherRank = higherRankShape.size();
-  int64_t lowerRank = lowerRankShape.size();
-
-  reshapeOutputShape.assign(higherRank, 1);
-
-  int64_t higherRankDim;
-  int64_t lowerRankDim;
-
-  for (int64_t i = higherRank - 1, j = lowerRank - 1; i >= 0 && j >= 0;
-       i--, j--) {
-    higherRankDim = higherRankShape[i];
-    lowerRankDim = lowerRankShape[j];
-
-    if (lowerRankDim == 1 && higherRankDim > 1)
-      reshapeOutputShape[i] = 1;
-    else if ((lowerRankDim > 1 && higherRankDim == 1) ||
-             (lowerRankDim == higherRankDim))
-      reshapeOutputShape[i] = lowerRankDim;
-    else if (higherRankDim != lowerRankDim)
-      return failure();
-  }
-  return success();
-}
+namespace {
 
 /// Common code to create the reshape op where necessary to make the rank of the
 /// operations equal. input1 and input2 will be updated when the rank has
 /// changed. The caller is expected to use these to rewrite the original
 /// operator with the RESHAPE now in the graph.
-static LogicalResult reshapeLowerToHigher(PatternRewriter &rewriter,
-                                          Location loc,
-                                          RankedTensorType outputType,
-                                          Value &input1, Value &input2) {
+/// Returns failure when (1) no reshape is needed, or (2) output_type is
+/// specified and it has a different rank.
+LogicalResult reshapeLowerToHigher(PatternRewriter &rewriter, Location loc,
+                                   RankedTensorType outputType, Value &input1,
+                                   Value &input2) {
   auto input1Ty = dyn_cast<RankedTensorType>(input1.getType());
   auto input2Ty = dyn_cast<RankedTensorType>(input2.getType());
 
@@ -96,54 +54,28 @@ static LogicalResult reshapeLowerToHigher(PatternRewriter &rewriter,
     return rewriter.notifyMatchFailure(loc,
                                        "cannot rewrite as its already correct");
 
-  Value higherTensorValue, lowerTensorValue;
-  if (input1Rank > input2Rank) {
-    higherTensorValue = input1;
-    lowerTensorValue = input2;
-  } else {
-    higherTensorValue = input2;
-    lowerTensorValue = input1;
+  Value input1_copy = input1;
+  Value input2_copy = input2;
+  if (EqualizeRanks(rewriter, loc, input1_copy, input2_copy).failed()) {
+    return rewriter.notifyMatchFailure(loc, "failed to reshape inputs");
   }
 
-  ArrayRef<int64_t> higherRankShape =
-      cast<RankedTensorType>(higherTensorValue.getType()).getShape();
-  ArrayRef<int64_t> lowerRankShape =
-      cast<RankedTensorType>(lowerTensorValue.getType()).getShape();
-
-  SmallVector<int64_t, 4> reshapeOutputShape;
-
-  if (computeReshapeOutput(higherRankShape, lowerRankShape, reshapeOutputShape)
-          .failed())
-    return rewriter.notifyMatchFailure(loc, "fail to compute a reshape type");
-
-  auto reshapeInputType = cast<RankedTensorType>(lowerTensorValue.getType());
-  auto reshapeOutputType = RankedTensorType::get(
-      ArrayRef<int64_t>(reshapeOutputShape), reshapeInputType.getElementType());
-
   // Verify the rank agrees with the output type if the output type is ranked.
   if (outputType) {
-    if (outputType.getShape().size() != reshapeOutputShape.size() ||
-        outputType.getShape().size() != higherRankShape.size())
+    if (outputType.getRank() !=
+            input1_copy.getType().cast<RankedTensorType>().getRank() ||
+        outputType.getRank() !=
+            input2_copy.getType().cast<RankedTensorType>().getRank())
       return rewriter.notifyMatchFailure(
           loc, "the reshaped type doesn't agrees with the ranked output type");
   }
 
-  auto reshapeLower = rewriter.create<tosa::ReshapeOp>(
-      loc, reshapeOutputType, lowerTensorValue,
-      rewriter.getDenseI64ArrayAttr(reshapeOutputShape));
-
-  if (input1Rank > input2Rank) {
-    input1 = higherTensorValue;
-    input2 = reshapeLower.getResult();
-  } else {
-    input1 = reshapeLower.getResult();
-    input2 = higherTensorValue;
-  }
+  input1 = input1_copy;
+  input2 = input2_copy;
 
   return success();
 }
 
-namespace {
 template <typename OpTy>
 struct ConvertTosaOp : public OpRewritePattern<OpTy> {
   using OpRewritePattern<OpTy>::OpRewritePattern;
@@ -268,8 +200,10 @@ struct ConvertTosaOp<tosa::SelectOp> : public OpRewritePattern<tosa::SelectOp> {
     int32_t result1Rank = cast<RankedTensorType>(input1.getType()).getRank();
     int32_t result2Rank = cast<RankedTensorType>(input2.getType()).getRank();
     int32_t result3Rank = cast<RankedTensorType>(input3.getType()).getRank();
+    int32_t outputRank = outputType.getRank();
 
-    if ((result1Rank != result2Rank) || (result2Rank != result3Rank))
+    if ((result1Rank != result2Rank) || (result2Rank != result3Rank) ||
+        (result1Rank != outputRank))
       return rewriter.notifyMatchFailure(
           tosaOp, "not all ranks are aligned with each other");
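
For orientation, the (pre-existing) ConvertTosaOp pattern drives
reshapeLowerToHigher roughly like this; a simplified sketch, not the
verbatim source:

    template <typename OpTy>
    struct ConvertTosaOp : public OpRewritePattern<OpTy> {
      using OpRewritePattern<OpTy>::OpRewritePattern;
      LogicalResult matchAndRewrite(OpTy tosaOp,
                                    PatternRewriter &rewriter) const override {
        Value input1 = tosaOp->getOperand(0);
        Value input2 = tosaOp->getOperand(1);
        auto outputType = dyn_cast<RankedTensorType>(tosaOp.getType());
        // Insert a reshape on the lower-ranked operand if necessary.
        if (failed(reshapeLowerToHigher(rewriter, tosaOp.getLoc(), outputType,
                                        input1, input2)))
          return failure();
        // Recreate the op with the rank-equalized operands.
        rewriter.replaceOpWithNewOp<OpTy>(tosaOp, outputType, input1, input2);
        return success();
      }
    };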
 

diff --git a/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp b/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp
index 346ff860d2884..8f84a064382f4 100644
--- a/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp
+++ b/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp
@@ -11,6 +11,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
+#include "mlir/Dialect/Tosa/IR/TosaOps.h"
 
 using namespace mlir;
 using namespace mlir::tosa;
@@ -60,3 +61,96 @@ bool mlir::tosa::validIntegerRange(IntegerType ty, int64_t value) {
   APInt intMax = APInt::getSignedMaxValue(bitwidth);
   return value >= intMin.getSExtValue() && value <= intMax.getSExtValue();
 }
+
+namespace {
+// Given two tensors of high and low ranks, derive the output shape
+// to reshape the lower rank to.
+// Examples:
+// If lower=[c], higher=[a, b, c], [c] reshaped into [1, 1, c].
+// If lower=[b, c], higher=[a, b, c], [b, c] reshaped into [1, b, c].
+// If lower=[a], higher=[a, a], [a] reshaped into [1, a].
+// If lower=[a], target=[a, b, a], [a] reshaped into [1, 1, a].
+// If lower=[], target=[a, b, c], [] reshaped into [1, 1, 1].
+LogicalResult
+computeReshapeOutput(ArrayRef<int64_t> higherRankShape,
+                     ArrayRef<int64_t> lowerRankShape,
+                     SmallVectorImpl<int64_t> &reshapeOutputShape) {
+  // Initialize new shapes with [1] * higherRank.
+  int64_t higherRank = higherRankShape.size();
+  int64_t lowerRank = lowerRankShape.size();
+
+  reshapeOutputShape.assign(higherRank, 1);
+
+  int64_t higherRankDim;
+  int64_t lowerRankDim;
+
+  for (int64_t i = higherRank - 1, j = lowerRank - 1; i >= 0 && j >= 0;
+       i--, j--) {
+    higherRankDim = higherRankShape[i];
+    lowerRankDim = lowerRankShape[j];
+
+    if (lowerRankDim == 1 && higherRankDim > 1)
+      reshapeOutputShape[i] = 1;
+    else if ((lowerRankDim > 1 && higherRankDim == 1) ||
+             (lowerRankDim == higherRankDim))
+      reshapeOutputShape[i] = lowerRankDim;
+    else if (higherRankDim != lowerRankDim)
+      return failure();
+  }
+  return success();
+}
+} // namespace
+
+LogicalResult mlir::tosa::EqualizeRanks(PatternRewriter &rewriter, Location loc,
+                                        Value &input1, Value &input2) {
+  auto input1Ty = input1.getType().dyn_cast<RankedTensorType>();
+  auto input2Ty = input2.getType().dyn_cast<RankedTensorType>();
+
+  if (!input1Ty || !input2Ty) {
+    return failure();
+  }
+
+  int64_t input1Rank = input1Ty.getRank();
+  int64_t input2Rank = input2Ty.getRank();
+
+  if (input1Rank == input2Rank)
+    return success();
+
+  Value higherTensorValue, lowerTensorValue;
+  if (input1Rank > input2Rank) {
+    higherTensorValue = input1;
+    lowerTensorValue = input2;
+  } else {
+    higherTensorValue = input2;
+    lowerTensorValue = input1;
+  }
+
+  ArrayRef<int64_t> higherRankShape =
+      higherTensorValue.getType().cast<RankedTensorType>().getShape();
+  ArrayRef<int64_t> lowerRankShape =
+      lowerTensorValue.getType().cast<RankedTensorType>().getShape();
+
+  SmallVector<int64_t, 4> reshapeOutputShape;
+
+  if (computeReshapeOutput(higherRankShape, lowerRankShape, reshapeOutputShape)
+          .failed())
+    return failure();
+
+  auto reshapeInputType = lowerTensorValue.getType().cast<RankedTensorType>();
+  auto reshapeOutputType = RankedTensorType::get(
+      ArrayRef<int64_t>(reshapeOutputShape), reshapeInputType.getElementType());
+
+  auto reshapeLower = rewriter.create<tosa::ReshapeOp>(
+      loc, reshapeOutputType, lowerTensorValue,
+      rewriter.getDenseI64ArrayAttr(reshapeOutputShape));
+
+  if (input1Rank > input2Rank) {
+    input1 = higherTensorValue;
+    input2 = reshapeLower.getResult();
+  } else {
+    input1 = reshapeLower.getResult();
+    input2 = higherTensorValue;
+  }
+
+  return success();
+}
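
A worked trace of the relocated logic (matching the transpose-conv test
update below): for input1 : tensor<2x35x47x5xf32> and
input2 : tensor<5xf32>, computeReshapeOutput walks both shapes from the
right, matching 5 against 5 and filling the remaining leading
dimensions with 1, so reshapeOutputShape = [1, 1, 1, 5]; EqualizeRanks
then creates "tosa.reshape"(%input2) <{new_shape = array<i64: 1, 1, 1, 5>}>
and returns the reshaped value through input2.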

diff --git a/mlir/test/Dialect/Tosa/tosa-decompose-depthwise.mlir b/mlir/test/Dialect/Tosa/tosa-decompose-depthwise.mlir
index e835991273ec5..59e7d35bf77b2 100644
--- a/mlir/test/Dialect/Tosa/tosa-decompose-depthwise.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-decompose-depthwise.mlir
@@ -7,13 +7,17 @@ func.func @depthwise_conv2d_as_mul(%arg0: tensor<4x10x10x2xf32>, %arg1: tensor<1
   // CHECK-NOT: "tosa.depthwise_conv2d"
   // CHECK: %[[VAR0:.*]] = "tosa.reshape"(%arg0) <{new_shape = array<i64: 4, 10, 10, 2, 1>}
   // CHECK-SAME: -> tensor<4x10x10x2x1xf32>
-  // CHECK: %[[VAR2:.*]] = "tosa.mul"(%[[VAR0]], %arg1)
+  // CHECK: %[[VAR1:.*]] = "tosa.reshape"(%arg1) <{new_shape = array<i64: 1, 1, 1, 2, 3>}
+  // CHECK-SAME: -> tensor<1x1x1x2x3xf32>
+  // CHECK: %[[VAR2:.*]] = "tosa.mul"(%[[VAR0]], %[[VAR1]])
   // CHECK-SAME: -> tensor<4x10x10x2x3xf32>
   // CHECK: %[[VAR3:.*]] = "tosa.reshape"(%[[VAR2]]) <{new_shape = array<i64: 4, 10, 10, 6>}
   // CHECK-SAME: -> tensor<4x10x10x6xf32>
-  // CHECK: %[[VAR4:.*]] = "tosa.add"(%[[VAR3]], %arg2)
+  // CHECK: %[[VAR4:.*]] = "tosa.reshape"(%arg2) <{new_shape = array<i64: 1, 1, 1, 6>}
+  // CHECK-SAME: -> tensor<1x1x1x6xf32>
+  // CHECK: %[[VAR5:.*]] = "tosa.add"(%[[VAR3]], %[[VAR4]])
   // CHECK-SAME: -> tensor<4x10x10x6xf32>
-  // CHECK: return %[[VAR4]]
+  // CHECK: return %[[VAR5]]
   %0 = "tosa.depthwise_conv2d"(%arg0, %arg1, %arg2) {pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, dilation = array<i64: 1, 1>} : (tensor<4x10x10x2xf32>, tensor<1x1x2x3xf32>, tensor<6xf32>) -> tensor<4x10x10x6xf32>
   return %0 : tensor<4x10x10x6xf32>
 }
@@ -22,16 +26,18 @@ func.func @depthwise_conv2d_as_mul(%arg0: tensor<4x10x10x2xf32>, %arg1: tensor<1
 
 // CHECK-LABEL: @depthwise_conv2d_as_mul_q
 func.func @depthwise_conv2d_as_mul_q(%arg0: tensor<4x10x10x2xi8>, %arg1: tensor<1x1x2x3xi8>, %arg2: tensor<6xi32>) -> tensor<4x10x10x6xi32> {
-  // CHECK: %[[iZp:.+]] = "tosa.const"() <{value = dense<7> : tensor<i32>}
-  // CHECK: %[[wZp:.+]] = "tosa.const"() <{value = dense<11> : tensor<i32>}
+  // CHECK: %[[iZp:.+]] = "tosa.const"() <{value = dense<7> : tensor<1x1x1x1x1xi32>}
+  // CHECK: %[[wZp:.+]] = "tosa.const"() <{value = dense<11> : tensor<1x1x1x1xi32>}
   // CHECK: %[[rIn:.+]] = "tosa.reshape"(%arg0) <{new_shape = array<i64: 4, 10, 10, 2, 1>}
   // CHECK: %[[cIn:.+]] = "tosa.cast"(%[[rIn]]) : (tensor<4x10x10x2x1xi8>) -> tensor<4x10x10x2x1xi32>
   // CHECK: %[[cWe:.+]] = "tosa.cast"(%arg1) : (tensor<1x1x2x3xi8>) -> tensor<1x1x2x3xi32>
   // CHECK: %[[sIn:.+]] = "tosa.sub"(%[[cIn]], %[[iZp]])
   // CHECK: %[[sWe:.+]] = "tosa.sub"(%[[cWe]], %[[wZp]])
-  // CHECK: %[[mul:.+]] = "tosa.mul"(%[[sIn]], %[[sWe]]) <{shift = 0 : i32}
+  // CHECK: %[[resWe:.+]] = "tosa.reshape"(%[[sWe]]) <{new_shape = array<i64: 1, 1, 1, 2, 3>}
+  // CHECK: %[[mul:.+]] = "tosa.mul"(%[[sIn]], %[[resWe]]) <{shift = 0 : i32}
   // CHECK: %[[reO:.+]] = "tosa.reshape"(%[[mul]]) <{new_shape = array<i64: 4, 10, 10, 6>}
-  // CHECK: %[[add:.+]] = "tosa.add"(%[[reO]], %arg2)
+  // CHECK: %[[reArg2:.+]] = "tosa.reshape"(%arg2) <{new_shape = array<i64: 1, 1, 1, 6>}
+  // CHECK: %[[add:.+]] = "tosa.add"(%[[reO]], %[[reArg2]])
   %0 = "tosa.depthwise_conv2d"(%arg0, %arg1, %arg2) {pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, dilation = array<i64: 1, 1>, quantization_info = #tosa.conv_quant<input_zp = 7, weight_zp = 11>} : (tensor<4x10x10x2xi8>, tensor<1x1x2x3xi8>, tensor<6xi32>) -> tensor<4x10x10x6xi32>
   return %0 : tensor<4x10x10x6xi32>
 }
@@ -44,9 +50,11 @@ func.func @depthwise_conv2d_as_mul_padded(%arg0: tensor<4x10x10x2xf32>, %arg1: t
   // CHECK: %[[zero:.+]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<f32>}
   // CHECK: %[[reIn:.+]] = "tosa.reshape"(%arg0) <{new_shape = array<i64: 4, 10, 10, 2, 1>}
   // CHECK: %[[padded:.+]] = "tosa.pad"(%[[reIn]], %[[pad]], %[[zero]]) : (tensor<4x10x10x2x1xf32>, tensor<5x2xi64>, tensor<f32>) -> tensor<4x12x12x2x1xf32>
-  // CHECK: %[[mul:.+]] = "tosa.mul"(%3, %arg1) <{shift = 0 : i32}
+  // CHECK: %[[reArg1:.+]] = "tosa.reshape"(%arg1) <{new_shape = array<i64: 1, 1, 1, 2, 3>}
+  // CHECK: %[[mul:.+]] = "tosa.mul"(%3, %[[reArg1]]) <{shift = 0 : i32}
   // CHECK: %[[reOut:.+]] = "tosa.reshape"(%[[mul]]) <{new_shape = array<i64: 4, 12, 12, 6>}
-  // CHECK: %[[add:.+]] = "tosa.add"(%[[reOut]], %arg2)
+  // CHECK: %[[reArg2:.+]] = "tosa.reshape"(%arg2) <{new_shape = array<i64: 1, 1, 1, 6>}
+  // CHECK: %[[add:.+]] = "tosa.add"(%[[reOut]], %[[reArg2]])
   %0 = "tosa.depthwise_conv2d"(%arg0, %arg1, %arg2) {pad = array<i64: 1, 1, 1, 1>, stride = array<i64: 1, 1>, dilation = array<i64: 1, 1>} : (tensor<4x10x10x2xf32>, tensor<1x1x2x3xf32>, tensor<6xf32>) -> tensor<4x12x12x6xf32>
   return %0 : tensor<4x12x12x6xf32>
 }

diff --git a/mlir/test/Dialect/Tosa/tosa-decompose-transpose-conv.mlir b/mlir/test/Dialect/Tosa/tosa-decompose-transpose-conv.mlir
index 6ccf510804d99..daac52eddf204 100644
--- a/mlir/test/Dialect/Tosa/tosa-decompose-transpose-conv.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-decompose-transpose-conv.mlir
@@ -28,7 +28,7 @@ func.func @transpose_conv2d_quantized(%arg0: tensor<2x16x14x3xi8>, %arg1: tensor
 func.func @transpose_conv2d_quantized_padded(%arg0: tensor<2x16x14x3xi8>, %arg1: tensor<5x3x6x3xi8>, %arg2: tensor<5xi32>) -> (tensor<2x21x26x5xi32>) {
   // CHECK-DAG: %[[REV0:.+]] = "tosa.reverse"(%0) <{axis = 2 : i64}
   // CHECK-DAG: %[[REV1:.+]] = "tosa.reverse"(%arg1) <{axis = 1 : i64}
-  // CHECK: "tosa.conv2d"(%arg0, %1, %arg2) 
+  // CHECK: "tosa.conv2d"(%arg0, %1, %arg2)
   // CHECK-SAME: dilation = array<i64: 1, 1>, pad = array<i64: 3, 4, 8, 9>,
   // CHECK-SAME: quantization_info = #tosa.conv_quant<input_zp = -22, weight_zp = 42>, stride = array<i64: 1, 1>}
   %0 = "tosa.transpose_conv2d"(%arg0, %arg1, %arg2) {
@@ -65,7 +65,8 @@ func.func @transpose_conv2d_strided(%arg0: tensor<2x17x15x3xf32>, %arg1: tensor<
   // CHECK-DAG: %[[TRANS_OUT:.+]] = "tosa.transpose"(%[[RESHAPE_OUT_1]], %[[TRANS2]])
   // CHECK-DAG: %[[RESHAPE_OUT_2:.+]] = "tosa.reshape"(%[[TRANS_OUT]]) <{new_shape = array<i64: 2, 36, 48, 5>}
   // CHECK-DAG: %[[SLICE:.+]] = "tosa.slice"(%[[RESHAPE_OUT_2]]) <{size = array<i64: 2, 35, 47, 5>, start = array<i64: 0, 0, 0, 0>}
-  // CHECK: %[[ADD:.+]] = "tosa.add"(%[[SLICE]], %arg2)
+  // CHECK-DAG: %[[RESHAPE_ARG2:.+]] = "tosa.reshape"(%arg2) <{new_shape = array<i64: 1, 1, 1, 5>}
+  // CHECK: %[[ADD:.+]] = "tosa.add"(%[[SLICE]], %[[RESHAPE_ARG2]])
   %0 = "tosa.transpose_conv2d"(%arg0, %arg1, %arg2) {out_pad = array<i64: 0, 0, 0, 0>, out_shape = array<i64: -1, -1, -1, -1>, stride = array<i64: 2, 3>} : (tensor<2x17x15x3xf32>, tensor<5x3x5x3xf32>, tensor<5xf32>) -> tensor<2x35x47x5xf32>
   %1 = tensor.cast %0 : tensor<2x35x47x5xf32> to tensor<2x?x?x5xf32>
   return %1 : tensor<2x?x?x5xf32>
@@ -97,8 +98,9 @@ func.func @transpose_conv2d_strided_quantized(%arg0: tensor<2x17x15x3xi8>, %arg1
   // CHECK-DAG: %[[TRANS_OUT:.+]] = "tosa.transpose"(%[[RESHAPE_OUT_1]], %[[TRANS2]])
   // CHECK-DAG: %[[RESHAPE_OUT_2:.+]] = "tosa.reshape"(%[[TRANS_OUT]]) <{new_shape = array<i64: 2, 36, 48, 5>}
   // CHECK-DAG: %[[SLICE:.+]] = "tosa.slice"(%[[RESHAPE_OUT_2]]) <{size = array<i64: 2, 35, 47, 5>, start = array<i64: 0, 0, 0, 0>}
-  // CHECK: %[[ADD:.+]] = "tosa.add"(%[[SLICE]], %arg2)
-  %0 = "tosa.transpose_conv2d"(%arg0, %arg1, %arg2) <{out_pad = array<i64: 0, 0, 0, 0>, quantization_info = #tosa.conv_quant<input_zp = -22, weight_zp = 42>, out_shape = array<i64: -1, -1, -1, -1>, stride = array<i64: 2, 3>}> : (tensor<2x17x15x3xi8>, tensor<5x3x5x3xi8>, tensor<5xi32>) -> tensor<2x35x47x5xi32>
+  // CHECK-DAG: %[[RESHAPE_ARG2:.+]] = "tosa.reshape"(%arg2) <{new_shape = array<i64: 1, 1, 1, 5>}
+  // CHECK: %[[ADD:.+]] = "tosa.add"(%[[SLICE]], %[[RESHAPE_ARG2]])
+  %0 = "tosa.transpose_conv2d"(%arg0, %arg1, %arg2) {out_pad = array<i64: 0, 0, 0, 0>, quantization_info = #tosa.conv_quant<input_zp = -22, weight_zp = 42>, out_shape = array<i64: -1, -1, -1, -1>, stride = array<i64: 2, 3>} : (tensor<2x17x15x3xi8>, tensor<5x3x5x3xi8>, tensor<5xi32>) -> tensor<2x35x47x5xi32>
   return %0 : tensor<2x35x47x5xi32>
 }
 
@@ -106,14 +108,14 @@ func.func @transpose_conv2d_strided_quantized(%arg0: tensor<2x17x15x3xi8>, %arg1
 
 // CHECK-LABEL: @transpose_conv2d_strided_overpad
 func.func @transpose_conv2d_strided_overpad(%arg0 : tensor<1x16x1x1xi8>, %arg1 : tensor<1x2x1x1xi8>, %arg2 : tensor<1xi32>) -> (tensor<1x19x2x1xi32>) {
-  // CHECK: %[[WEIGHT_PAD:.+]] = "tosa.const"() 
+  // CHECK: %[[WEIGHT_PAD:.+]] = "tosa.const"()
   // CHECK-SAME{literal}: value = dense<[[0, 0], [0, 0], [0, 1], [0, 0]]> : tensor<4x2xi32>
   // CHECK: %[[WEIGHT_PERMS:.+]] = "tosa.const"() <{value = dense<[2, 4, 0, 1, 3, 5]> : tensor<6xi32>}
-  // CHECK: %[[INPUT_PAD:.+]] = "tosa.const"() 
+  // CHECK: %[[INPUT_PAD:.+]] = "tosa.const"()
   // CHECK-SAME{literal}: value = dense<[[0, 0], [1, 1], [0, 0], [0, 0]]> : tensor<4x2xi32>}
   // CHECK: %[[ZERO:.+]] = "tosa.const"() <{value = dense<0> : tensor<2xi32>}
   // CHECK: %[[RESULT_PERMS:.+]] = "tosa.const"() <{value = dense<[0, 1, 3, 2, 4, 5]> : tensor<6xi32>}
-  // CHECK: %[[RESULT_PAD:.+]] = "tosa.const"() 
+  // CHECK: %[[RESULT_PAD:.+]] = "tosa.const"()
   // CHECK-SAME{literal}: value = dense<[[0, 0], [2, 0], [0, 0], [0, 0]]> : tensor<4x2xi32>}
   // CHECK: %[[PAD_WEIGHT:.+]] = "tosa.pad"(%arg1, %[[WEIGHT_PAD]]) <{quantization_info = #tosa.pad_quant<input_zp = 93>}
   // CHECK: %[[RESHAPE_WEIGHT_0:.+]] = "tosa.reshape"(%[[PAD_WEIGHT]]) <{new_shape = array<i64: 1, 2, 1, 1, 2, 1>}
@@ -121,13 +123,14 @@ func.func @transpose_conv2d_strided_overpad(%arg0 : tensor<1x16x1x1xi8>, %arg1 :
   // CHECK: %[[RESHAPE_WEIGHT_1:.+]] = "tosa.reshape"(%[[TRANSPOSE_WEIGHT]]) <{new_shape = array<i64: 2, 2, 1, 1>}
   // CHECK: %[[REVERSE:.+]] = "tosa.reverse"(%[[RESHAPE_WEIGHT_1]]) <{axis = 1 : i64}
   // CHECK: %[[PAD_INPUT:.+]] = "tosa.pad"(%arg0, %[[INPUT_PAD]]) <{quantization_info = #tosa.pad_quant<input_zp = -103>}
-  // CHECK: %[[CONV:.+]] = "tosa.conv2d"(%[[PAD_INPUT]], %[[REVERSE]], %[[ZERO]]) 
+  // CHECK: %[[CONV:.+]] = "tosa.conv2d"(%[[PAD_INPUT]], %[[REVERSE]], %[[ZERO]])
   // CHECK-SAME{literal}: dilation = [1, 1], pad = [0, 0, 0, 0], quantization_info = #tosa.conv_quant<input_zp = -103, weight_zp = 93>, stride = [1, 1]}
   // CHECK: %[[RESHAPE_RESULT_0:.+]] = "tosa.reshape"(%[[CONV]]) <{new_shape = array<i64: 1, 17, 1, 1, 2, 1>}
   // CHECK: %[[TRANSPOSE_RESULT:.+]] = "tosa.transpose"(%[[RESHAPE_RESULT_0]], %[[RESULT_PERMS]])
   // CHECK: %[[RESHAPE_RESULT_1:.+]] = "tosa.reshape"(%[[TRANSPOSE_RESULT]]) <{new_shape = array<i64: 1, 17, 2, 1>}
   // CHECK: %[[PAD_RESULT:.+]] = "tosa.pad"(%[[RESHAPE_RESULT_1]], %[[RESULT_PAD]])
-  // CHECK: %[[ADD:.+]] = "tosa.add"(%[[PAD_RESULT]], %arg2)
+  // CHECK: %[[RESHAPE_ARG2:.+]] = "tosa.reshape"(%arg2) <{new_shape = array<i64: 1, 1, 1, 1>}
+  // CHECK: %[[ADD:.+]] = "tosa.add"(%[[PAD_RESULT]], %[[RESHAPE_ARG2]])
   %2 =  "tosa.transpose_conv2d"(%arg0, %arg1, %arg2)  {
     out_pad = array<i64: 2, 0, 0, 1>,
     out_shape = array<i64: 1, -1, -1, 1>,
