[Mlir-commits] [mlir] [mlir][tosa] Check for 0-ranked-tensors during fold (PR #68512)
Sarthak Gupta
llvmlistbot at llvm.org
Sun Oct 8 02:00:17 PDT 2023
https://github.com/gptsarthak updated https://github.com/llvm/llvm-project/pull/68512
>From 958b38725e0824d3039647f1c5614cc6f312b128 Mon Sep 17 00:00:00 2001
From: gptsarthak <sarthakgpt95 at gmail.com>
Date: Sun, 8 Oct 2023 14:29:55 +0530
Subject: [PATCH] [mlir][tosa] Check for 0-ranked-tensors during fold
Calling getDimSize() on a rank-0 tensor triggers an assertion failure, so check the rank before querying a dimension size.
Fixes https://github.com/llvm/llvm-project/issues/67761
---
mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp | 5 +++--
1 file changed, 3 insertions(+), 2 deletions(-)
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index e69c40f2b052395..f8b3b97321868a8 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -771,7 +771,7 @@ OpFoldResult ConstOp::fold(FoldAdaptor adaptor) { return getValueAttr(); }
ShapedType inputTy = llvm::cast<ShapedType>(getInput().getType()); \
if (!inputTy.hasRank()) \
return {}; \
- if (inputTy.getDimSize(getAxis()) == 1) \
+ if (inputTy.getRank() != 0 && inputTy.getDimSize(getAxis()) == 1) \
return getInput(); \
return {}; \
}
@@ -874,7 +874,8 @@ OpFoldResult ReverseOp::fold(FoldAdaptor adaptor) {
return operandAttr;
// If the dim-length is 1, tosa.reverse is a no-op.
- if (operandTy.hasRank() && operandTy.getDimSize(axis) == 1)
+ if (operandTy.hasRank() && operandTy.getRank() != 0 &&
+ operandTy.getDimSize(axis) == 1)
return operand;
return {};
More information about the Mlir-commits
mailing list