[Mlir-commits] [mlir] Fix tosa::TransposeOp::inferReturnTypeComponents() (PR #88656)

Maya Amrami llvmlistbot at llvm.org
Sun Apr 14 04:08:30 PDT 2024


https://github.com/amrami created https://github.com/llvm/llvm-project/pull/88656

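(No description provided.) The patch passes the input tensor's element type to the ShapedTypeComponents pushed by tosa::TransposeOp::inferReturnTypeComponents(), so the inferred result carries an element type and not only a shape. It also fixes a typo in a comment.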

From 536570d5aa8200f96a44d90ceabcdd1d4e8a5846 Mon Sep 17 00:00:00 2001
From: Maya Amrami <mayaam88 at gmail.com>
Date: Sun, 14 Apr 2024 11:48:56 +0300
Subject: [PATCH] Fix tosa::TransposeOp::inferReturnTypeComponents()

---
 mlir/lib/Dialect/Tosa/IR/TosaOps.cpp | 13 +++++++++----
 1 file changed, 9 insertions(+), 4 deletions(-)

diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index e06ac9a27ae4cc..e270363d3f3139 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -1012,6 +1012,9 @@ LogicalResult tosa::TransposeOp::inferReturnTypeComponents(
   if (permsShape.hasRank() && permsShape.getRank() == 0)
     return failure();
 
+  Type inputType =
+      adaptor.getInput1().getType().cast<TensorType>().getElementType();
+
   // If input rank and permutation length is unknown, the output rank is
   // unknown.
   if (!inputShape.hasRank() || !permsShape.hasRank() ||
@@ -1029,7 +1032,8 @@ LogicalResult tosa::TransposeOp::inferReturnTypeComponents(
   SmallVector<int64_t> outputShape;
   // Rank-0 means no permutations matter.
   if (inputShape.getRank() == 0) {
-    inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
+    inferredReturnShapes.push_back(
+        ShapedTypeComponents(outputShape, inputType));
     return success();
   }
 
@@ -1046,12 +1050,13 @@ LogicalResult tosa::TransposeOp::inferReturnTypeComponents(
   // permutation.
   if (allTheSame) {
     outputShape.resize(inputShape.getRank(), inputShape.getDimSize(0));
-    inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
+    inferredReturnShapes.push_back(
+        ShapedTypeComponents(outputShape, inputType));
     return success();
   }
 
   outputShape.resize(inputShape.getRank(), ShapedType::kDynamic);
-  // If the permuations are a constant we can directly determine the output
+  // If the permutations are a constant we can directly determine the output
   // shape.
   DenseIntElementsAttr attr;
   if (matchPattern(adaptor.getPerms(), m_Constant(&attr)) &&
@@ -1075,7 +1080,7 @@ LogicalResult tosa::TransposeOp::inferReturnTypeComponents(
     }
   }
 
-  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
+  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape, inputType));
   return success();
 }
 



More information about the Mlir-commits mailing list