[Mlir-commits] [mlir] 783b4d9 - [mlir][tosa] Check for 0-ranked-tensors during fold (#68512)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Oct 20 04:03:28 PDT 2023


Author: Sarthak Gupta
Date: 2023-10-20T12:03:24+01:00
New Revision: 783b4d91c73c992fad32e045ce3265b01028fc99

URL: https://github.com/llvm/llvm-project/commit/783b4d91c73c992fad32e045ce3265b01028fc99
DIFF: https://github.com/llvm/llvm-project/commit/783b4d91c73c992fad32e045ce3265b01028fc99.diff

LOG: [mlir][tosa] Check for 0-ranked-tensors during fold (#68512)

Fixes https://github.com/llvm/llvm-project/issues/67761
Calling `getDimSize()` on a rank-0 tensor, before checking its rank, triggers
an assertion failure. This PR adds an explicit rank-zero check before that call.
Open question: should TOSA operations instead reject rank-0 tensor operands
with a verification error?

Added: 
    

Modified: 
    mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
    mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
    mlir/test/Dialect/Tosa/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index e69c40f2b052395..7444f70a46e9355 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -771,7 +771,7 @@ OpFoldResult ConstOp::fold(FoldAdaptor adaptor) { return getValueAttr(); }
     ShapedType inputTy = llvm::cast<ShapedType>(getInput().getType());         \
     if (!inputTy.hasRank())                                                    \
       return {};                                                               \
-    if (inputTy.getDimSize(getAxis()) == 1)                                    \
+    if (inputTy.getRank() == 0 || inputTy.getDimSize(getAxis()) == 1)          \
       return getInput();                                                       \
     return {};                                                                 \
   }
@@ -874,7 +874,8 @@ OpFoldResult ReverseOp::fold(FoldAdaptor adaptor) {
     return operandAttr;
 
   // If the dim-length is 1, tosa.reverse is a no-op.
-  if (operandTy.hasRank() && operandTy.getDimSize(axis) == 1)
+  if (operandTy.hasRank() &&
+      (operandTy.getRank() == 0 || operandTy.getDimSize(axis) == 1))
     return operand;
 
   return {};

diff  --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 2b7f5bee6b7dcad..e03904a1611fc42 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -1109,7 +1109,7 @@ LogicalResult tosa::ScatterOp::inferReturnTypeComponents(
 static LogicalResult ReduceInferReturnTypes(
     ShapeAdaptor operandShape, Type inputType, IntegerAttr axis,
     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
-  if (!operandShape.hasRank()) {
+  if (!operandShape.hasRank() || operandShape.getRank() == 0) {
     inferredReturnShapes.push_back(ShapedTypeComponents(inputType));
     return success();
   }

diff  --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
index d36cf6a1d94a9f3..dddf15fffbb7aec 100644
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -591,3 +591,15 @@ func.func @fold_abs_abs(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
   %1 = tosa.abs %0 : (tensor<?x1xf32>) -> tensor<?x1xf32>
   return %1 : tensor<?x1xf32>
 }
+
+// -----
+
+// CHECK-LABEL: @fold_reduce_rank_zero
+func.func nested @fold_reduce_rank_zero() {
+  // CHECK-NOT: tosa.reduce_min
+  // CHECK-NOT: tosa.reverse
+  %0 = tensor.empty() : tensor<i32>
+  %1 = tosa.reduce_min %0 {axis = 0 : i32} : (tensor<i32>) -> tensor<1x10xi32>
+  %2 = tosa.reverse %0 {axis = 0 : i32} : (tensor<i32>) -> tensor<1x10xi32>
+  return
+}


        


More information about the Mlir-commits mailing list