[Mlir-commits] [mlir] Fix tosa::TransposeOp::inferReturnTypeComponents() (PR #88656)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sun Apr 14 04:09:03 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Maya Amrami (amrami)

<details>
<summary>Changes</summary>



---
Full diff: https://github.com/llvm/llvm-project/pull/88656.diff


1 File Affected:

- (modified) mlir/lib/Dialect/Tosa/IR/TosaOps.cpp (+9-4) 


``````````diff
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index e06ac9a27ae4cc..e270363d3f3139 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -1012,6 +1012,9 @@ LogicalResult tosa::TransposeOp::inferReturnTypeComponents(
   if (permsShape.hasRank() && permsShape.getRank() == 0)
     return failure();
 
+  Type inputType =
+      adaptor.getInput1().getType().cast<TensorType>().getElementType();
+
   // If input rank and permutation length is unknown, the output rank is
   // unknown.
   if (!inputShape.hasRank() || !permsShape.hasRank() ||
@@ -1029,7 +1032,8 @@ LogicalResult tosa::TransposeOp::inferReturnTypeComponents(
   SmallVector<int64_t> outputShape;
   // Rank-0 means no permutations matter.
   if (inputShape.getRank() == 0) {
-    inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
+    inferredReturnShapes.push_back(
+        ShapedTypeComponents(outputShape, inputType));
     return success();
   }
 
@@ -1046,12 +1050,13 @@ LogicalResult tosa::TransposeOp::inferReturnTypeComponents(
   // permutation.
   if (allTheSame) {
     outputShape.resize(inputShape.getRank(), inputShape.getDimSize(0));
-    inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
+    inferredReturnShapes.push_back(
+        ShapedTypeComponents(outputShape, inputType));
     return success();
   }
 
   outputShape.resize(inputShape.getRank(), ShapedType::kDynamic);
-  // If the permuations are a constant we can directly determine the output
+  // If the permutations are a constant we can directly determine the output
   // shape.
   DenseIntElementsAttr attr;
   if (matchPattern(adaptor.getPerms(), m_Constant(&attr)) &&
@@ -1075,7 +1080,7 @@ LogicalResult tosa::TransposeOp::inferReturnTypeComponents(
     }
   }
 
-  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
+  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape, inputType));
   return success();
 }
 

``````````

</details>


https://github.com/llvm/llvm-project/pull/88656


More information about the Mlir-commits mailing list