[Mlir-commits] [mlir] eb04f32 - [tosa] Add legalization for conv3d

Rob Suderman llvmlistbot at llvm.org
Thu Oct 6 13:04:30 PDT 2022

Author: TatWai Chong
Date: 2022-10-06T12:50:38-07:00
New Revision: eb04f321c344175e4510c3747d83a308bde96d68

URL: https://github.com/llvm/llvm-project/commit/eb04f321c344175e4510c3747d83a308bde96d68
DIFF: https://github.com/llvm/llvm-project/commit/eb04f321c344175e4510c3747d83a308bde96d68.diff

LOG: [tosa] Add legalization for conv3d

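Update the existing shape inference for tosa.conv3d to match the TOSA spec: use the Conv3DOp adaptor (instead of Conv2DOp) and read the input and weight spatial dimensions in depth, height, width order.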

Reviewed By: rsuderman

Differential Revision: https://reviews.llvm.org/D133062




diff  --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 8df30279cb7e0..841a27479bada 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -1070,7 +1070,7 @@ LogicalResult Conv3DOp::inferReturnTypeComponents(
     ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
   llvm::SmallVector<int64_t> outputShape(5, ShapedType::kDynamicSize);
-  Conv2DOp::Adaptor adaptor(operands.getValues(), attributes);
+  Conv3DOp::Adaptor adaptor(operands.getValues(), attributes);
   int32_t inputWidth = ShapedType::kDynamicSize;
   int32_t inputHeight = ShapedType::kDynamicSize;
@@ -1084,55 +1084,54 @@ LogicalResult Conv3DOp::inferReturnTypeComponents(
   ShapeAdaptor inputShape = operands.getShape(adaptor.getInput());
   if (inputShape.hasRank()) {
     outputShape[0] = inputShape.getDimSize(0);
-    inputHeight = inputShape.getDimSize(1);
-    inputWidth = inputShape.getDimSize(2);
-    inputDepth = inputShape.getDimSize(3);
+    inputDepth = inputShape.getDimSize(1);
+    inputHeight = inputShape.getDimSize(2);
+    inputWidth = inputShape.getDimSize(3);
   // Weight shapes describes the filter width/height and the output channels.
   ShapeAdaptor weightShape = operands.getShape(adaptor.getWeight());
   if (weightShape.hasRank()) {
     outputShape[4] = weightShape.getDimSize(0);
-    weightHeight = weightShape.getDimSize(1);
-    weightWidth = weightShape.getDimSize(2);
-    weightDepth = weightShape.getDimSize(3);
+    weightDepth = weightShape.getDimSize(1);
+    weightHeight = weightShape.getDimSize(2);
+    weightWidth = weightShape.getDimSize(3);
   // Bias shape can describe the output channels.
   ShapeAdaptor biasShape = operands.getShape(adaptor.getBias());
-  if (biasShape.hasRank()) {
-    outputShape[4] =
-        (outputShape[4] == -1) ? biasShape.getDimSize(0) : outputShape[4];
+  if (biasShape.hasRank() && ShapedType::isDynamic(outputShape[4])) {
+    outputShape[4] = biasShape.getDimSize(0);
   llvm::SmallVector<int64_t> dilation;
-  llvm::SmallVector<int64_t> padding;
+  llvm::SmallVector<int64_t> pad;
   llvm::SmallVector<int64_t> stride;
   getI64Values(adaptor.getDilation(), dilation);
-  getI64Values(adaptor.getPad(), padding);
+  getI64Values(adaptor.getPad(), pad);
   getI64Values(adaptor.getStride(), stride);
-  if (!ShapedType::isDynamic(inputHeight) &&
-      !ShapedType::isDynamic(weightHeight)) {
-    int32_t inputSize = inputHeight + padding[0] + padding[1];
-    int32_t filterSize = (weightHeight - 1) * dilation[0] + 1;
+  if (!ShapedType::isDynamic(inputDepth) &&
+      !ShapedType::isDynamic(weightDepth)) {
+    int32_t inputSize = inputDepth + pad[0] + pad[1];
+    int32_t filterSize = (weightDepth - 1) * dilation[0] + 1;
     int32_t unstridedResult = inputSize - filterSize + 1;
     outputShape[1] = (unstridedResult - 1) / stride[0] + 1;
-  if (!ShapedType::isDynamic(inputWidth) &&
-      !ShapedType::isDynamic(weightWidth)) {
-    int32_t inputSize = inputWidth + padding[2] + padding[3];
-    int32_t filterSize = (weightWidth - 1) * dilation[1] + 1;
+  if (!ShapedType::isDynamic(inputHeight) &&
+      !ShapedType::isDynamic(weightHeight)) {
+    int32_t inputSize = inputHeight + pad[2] + pad[3];
+    int32_t filterSize = (weightHeight - 1) * dilation[1] + 1;
     int32_t unstridedResult = inputSize - filterSize + 1;
     outputShape[2] = (unstridedResult - 1) / stride[1] + 1;
-  if (!ShapedType::isDynamic(inputDepth) &&
-      !ShapedType::isDynamic(weightDepth)) {
-    int32_t inputSize = inputDepth + padding[4] + padding[5];
-    int32_t filterSize = (weightDepth - 1) * dilation[2] + 1;
+  if (!ShapedType::isDynamic(inputWidth) &&
+      !ShapedType::isDynamic(weightWidth)) {
+    int32_t inputSize = inputWidth + pad[4] + pad[5];
+    int32_t filterSize = (weightWidth - 1) * dilation[2] + 1;
     int32_t unstridedResult = inputSize - filterSize + 1;
     outputShape[3] = (unstridedResult - 1) / stride[2] + 1;


More information about the Mlir-commits mailing list