[Mlir-commits] [mlir] [mlir][tosa] Check for 0-ranked-tensors during fold (PR #68512)

Sarthak Gupta llvmlistbot at llvm.org
Sun Oct 8 01:52:53 PDT 2023


https://github.com/gptsarthak created https://github.com/llvm/llvm-project/pull/68512

Fixes https://github.com/llvm/llvm-project/issues/67761
Calling `getDimSize()` on a 0-ranked tensor before checking its rank triggers an assertion failure. This PR adds a rank check before each such call.
Alternatively, should we emit an error when a TOSA operation receives a 0-ranked tensor?
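To make the failure mode concrete, here is a minimal, self-contained C++ sketch. `ToyShapedType` and `isNoOpAlongAxis` are hypothetical names, not MLIR APIs. The toy type models only the rank query and the kind of index-bounds assertion in `ShapedType::getDimSize()` that a rank-0 input trips. The helper applies the checks in the same order as the patch: ranked first, then non-zero rank, then the size of the dimension along the axis.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>
#include <vector>

// Hypothetical stand-in for mlir::ShapedType; models only what the guard needs.
struct ToyShapedType {
  std::optional<std::vector<int64_t>> shape; // std::nullopt == unranked

  bool hasRank() const { return shape.has_value(); }

  int64_t getRank() const {
    assert(hasRank() && "cannot query the rank of an unranked type");
    return static_cast<int64_t>(shape->size());
  }

  int64_t getDimSize(int64_t idx) const {
    // Bounds assertion of the kind that fires for a rank-0 type:
    // no index satisfies 0 <= idx < 0.
    assert(idx >= 0 && idx < getRank() && "invalid index for shaped type");
    return (*shape)[idx];
  }
};

// Returns true when an op acting along `axis` is a no-op and may fold to its
// input (the reduce and reverse folders in the patch follow this pattern).
bool isNoOpAlongAxis(const ToyShapedType &ty, int64_t axis) {
  if (!ty.hasRank())
    return false;             // unranked: shape unknown, do not fold
  if (ty.getRank() == 0)
    return false;             // rank 0: no dimension to query
  return ty.getDimSize(axis) == 1;
}
```

For a rank-0 type the helper returns `false` before `getDimSize()` is ever called, so the fold simply does not fire. Like the patch, the sketch assumes `axis` is otherwise in range for ranked inputs.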

From 4a9fc20c13a445403057a7741aa94ee5a0b871ae Mon Sep 17 00:00:00 2001
From: gptsarthak <sarthakgpt95 at gmail.com>
Date: Sun, 8 Oct 2023 13:47:09 +0530
Subject: [PATCH] [mlir][tosa] Check for 0-ranked-tensors during fold

Calling getDimSize() on a 0-ranked tensor before checking its rank triggers an assertion failure.
Fixes https://github.com/llvm/llvm-project/issues/67761
---
 mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index e69c40f2b052395..d1182b324d7af6a 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -771,7 +771,7 @@ OpFoldResult ConstOp::fold(FoldAdaptor adaptor) { return getValueAttr(); }
     ShapedType inputTy = llvm::cast<ShapedType>(getInput().getType());         \
     if (!inputTy.hasRank())                                                    \
       return {};                                                               \
-    if (inputTy.getDimSize(getAxis()) == 1)                                    \
+    if (inputTy.getRank() != 0 && inputTy.getDimSize(getAxis()) == 1)          \
       return getInput();                                                       \
     return {};                                                                 \
   }
@@ -874,7 +874,7 @@ OpFoldResult ReverseOp::fold(FoldAdaptor adaptor) {
     return operandAttr;
 
   // If the dim-length is 1, tosa.reverse is a no-op.
-  if (operandTy.hasRank() && operandTy.getDimSize(axis) == 1)
+  if (operandTy.hasRank() && operandTy.getRank() != 0 && operandTy.getDimSize(axis) == 1)
     return operand;
 
   return {};