[Mlir-commits] [mlir] [mlir][tosa] Check for 0-ranked-tensors during fold (PR #68512)
Sarthak Gupta
llvmlistbot at llvm.org
Tue Oct 10 05:29:01 PDT 2023
https://github.com/gptsarthak updated https://github.com/llvm/llvm-project/pull/68512
>From 4531a8c0736721af3a0b9be2660da01b35786389 Mon Sep 17 00:00:00 2001
From: gptsarthak <sarthakgpt95 at gmail.com>
Date: Tue, 10 Oct 2023 17:58:45 +0530
Subject: [PATCH] [mlir][tosa] Check for 0-ranked-tensors during fold
Calling getDimSize() on a rank-0 tensor before checking its rank triggers an assertion failure.
Fixes https://github.com/llvm/llvm-project/issues/67761
---
mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp | 5 +++--
mlir/lib/Dialect/Tosa/IR/TosaOps.cpp | 2 +-
mlir/test/Dialect/Tosa/canonicalize.mlir | 10 ++++++++++
3 files changed, 14 insertions(+), 3 deletions(-)
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index e69c40f2b052395..7444f70a46e9355 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -771,7 +771,7 @@ OpFoldResult ConstOp::fold(FoldAdaptor adaptor) { return getValueAttr(); }
ShapedType inputTy = llvm::cast<ShapedType>(getInput().getType()); \
if (!inputTy.hasRank()) \
return {}; \
- if (inputTy.getDimSize(getAxis()) == 1) \
+ if (inputTy.getRank() == 0 || inputTy.getDimSize(getAxis()) == 1) \
return getInput(); \
return {}; \
}
@@ -874,7 +874,8 @@ OpFoldResult ReverseOp::fold(FoldAdaptor adaptor) {
return operandAttr;
// If the dim-length is 1, tosa.reverse is a no-op.
- if (operandTy.hasRank() && operandTy.getDimSize(axis) == 1)
+ if (operandTy.hasRank() &&
+ (operandTy.getRank() == 0 || operandTy.getDimSize(axis) == 1))
return operand;
return {};
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 0b92a3cb7a6203d..1298518e7b6e61a 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -1015,7 +1015,7 @@ LogicalResult tosa::ScatterOp::inferReturnTypeComponents(
static LogicalResult ReduceInferReturnTypes(
ShapeAdaptor operandShape, Type inputType, IntegerAttr axis,
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
- if (!operandShape.hasRank()) {
+ if (!operandShape.hasRank() || operandShape.getRank() == 0) {
inferredReturnShapes.push_back(ShapedTypeComponents(inputType));
return success();
}
diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
index 323864ea9013048..44890a65f40f26b 100644
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -591,3 +591,13 @@ func.func @fold_abs_abs(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
%1 = tosa.abs %0 : (tensor<?x1xf32>) -> tensor<?x1xf32>
return %1 : tensor<?x1xf32>
}
+
+// -----
+
+// CHECK-LABEL: @fold_reduce_rank_zero
+func.func nested @fold_reduce_rank_zero() {
+  %0 = tensor.empty() : tensor<i32>  // rank-0 operand: folders must not query getDimSize(axis)
+  %1 = tosa.reduce_min %0 {axis = 0 : i32} : (tensor<i32>) -> tensor<i32>
+  %2 = tosa.reverse %0 {axis = 0 : i32} : (tensor<i32>) -> tensor<i32>
+  return
+}
More information about the Mlir-commits
mailing list