[Mlir-commits] [mlir] Fix tosa::TransposeOp::inferReturnTypeComponents() (PR #88656)
Maya Amrami
llvmlistbot at llvm.org
Sun Apr 14 04:08:30 PDT 2024
https://github.com/amrami created https://github.com/llvm/llvm-project/pull/88656
None
From 536570d5aa8200f96a44d90ceabcdd1d4e8a5846 Mon Sep 17 00:00:00 2001
From: Maya Amrami <mayaam88 at gmail.com>
Date: Sun, 14 Apr 2024 11:48:56 +0300
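Subject: [PATCH] Fix tosa::TransposeOp::inferReturnTypeComponents()

The ranked result shapes inferred by
tosa::TransposeOp::inferReturnTypeComponents() were built with
ShapedTypeComponents(outputShape), which supplies only the shape and
leaves the element type unset. Pass the element type of input1 together
with the inferred shape on the rank-0, all-dimensions-equal, and general
paths. Also fix a typo ("permuations") in a nearby comment.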
---
mlir/lib/Dialect/Tosa/IR/TosaOps.cpp | 13 +++++++++----
1 file changed, 9 insertions(+), 4 deletions(-)
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index e06ac9a27ae4cc..e270363d3f3139 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -1012,6 +1012,9 @@ LogicalResult tosa::TransposeOp::inferReturnTypeComponents(
if (permsShape.hasRank() && permsShape.getRank() == 0)
return failure();
+ Type inputType =
+ adaptor.getInput1().getType().cast<TensorType>().getElementType();
+
// If input rank and permutation length is unknown, the output rank is
// unknown.
if (!inputShape.hasRank() || !permsShape.hasRank() ||
@@ -1029,7 +1032,8 @@ LogicalResult tosa::TransposeOp::inferReturnTypeComponents(
SmallVector<int64_t> outputShape;
// Rank-0 means no permutations matter.
if (inputShape.getRank() == 0) {
- inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
+ inferredReturnShapes.push_back(
+ ShapedTypeComponents(outputShape, inputType));
return success();
}
@@ -1046,12 +1050,13 @@ LogicalResult tosa::TransposeOp::inferReturnTypeComponents(
// permutation.
if (allTheSame) {
outputShape.resize(inputShape.getRank(), inputShape.getDimSize(0));
- inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
+ inferredReturnShapes.push_back(
+ ShapedTypeComponents(outputShape, inputType));
return success();
}
outputShape.resize(inputShape.getRank(), ShapedType::kDynamic);
- // If the permuations are a constant we can directly determine the output
+ // If the permutations are a constant we can directly determine the output
// shape.
DenseIntElementsAttr attr;
if (matchPattern(adaptor.getPerms(), m_Constant(&attr)) &&
@@ -1075,7 +1080,7 @@ LogicalResult tosa::TransposeOp::inferReturnTypeComponents(
}
}
- inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
+ inferredReturnShapes.push_back(ShapedTypeComponents(outputShape, inputType));
return success();
}
More information about the Mlir-commits
mailing list