[Mlir-commits] [mlir] [mlir][tosa] Fix crash in inferReturnTypes for ReduceOps (PR #69843)

Felix Schneider llvmlistbot at llvm.org
Sat Oct 21 06:37:36 PDT 2023


https://github.com/ubfx created https://github.com/llvm/llvm-project/pull/69843

The `tosa.reduce_*` ops take an `axis` attribute that selects the dimension along which the reduction is performed. Shape inference can crash when the input tensor's rank is too low for that axis to exist, i.e. when `axis >= rank`, because the inferred output shape is then indexed out of bounds.

Fix https://github.com/llvm/llvm-project/issues/68187
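The failure mode can be made concrete with a standalone sketch of the pre-fix indexing for a ranked operand of rank >= 1. Unranked and rank-0 operands took an early return before this point, as the removed `getRank() == 0` line in the second patch shows. This is plain C++, not MLIR: `reduceShapePreFix` is a hypothetical name, a `std::vector<int64_t>` stands in for the operand's dims, and `.at()` replaces the original `operator[]` so that an out-of-range axis raises `std::out_of_range` instead of indexing past the end of the shape.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <stdexcept>
#include <vector>

// Hypothetical model of the pre-fix code path after the early return:
// copy the operand dims and set the reduced dimension to 1.
// `.at()` is used instead of `operator[]` so an out-of-range axis throws
// std::out_of_range rather than touching memory past the end.
std::vector<int64_t> reduceShapePreFix(std::vector<int64_t> dims,
                                       int64_t axis) {
  assert(axis >= 0 && "sketch assumes a non-negative axis");
  dims.at(static_cast<std::size_t>(axis)) = 1;  // throws if axis >= rank
  return dims;
}
```

With dims {2, 3} and axis 1 this yields {2, 1}; with dims {4} and axis 1 the `.at()` call throws, which corresponds to the out-of-bounds `outputShape[axisVal]` access in the original code.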


From ec1495d5373f436f1f8c230699bdf315035f49e7 Mon Sep 17 00:00:00 2001
From: Felix Schneider <fx.schn at gmail.com>
Date: Sat, 21 Oct 2023 13:09:40 +0200
Subject: [PATCH 1/2] [mlir][tosa] Fix crash in inferReturnTypes for ReduceOps

The `tosa.reduce_*` ops take an `axis` Attribute that determines along
which dimension the reduction takes place. A crash can occur during
shape inference when the input tensor rank is so low that the given
axis doesn't exist.

Fix https://github.com/llvm/llvm-project/issues/68187
---
 mlir/lib/Dialect/Tosa/IR/TosaOps.cpp | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index e03904a1611fc42..0f616db31c06a5f 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -1117,7 +1117,8 @@ static LogicalResult ReduceInferReturnTypes(
   SmallVector<int64_t> outputShape;
   operandShape.getDims(outputShape);
   int64_t axisVal = axis.getValue().getSExtValue();
-  outputShape[axisVal] = 1;
+  if (axisVal < operandShape.getRank())
+    outputShape[axisVal] = 1;
   inferredReturnShapes.push_back(ShapedTypeComponents(outputShape, inputType));
   return success();
 }
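The first patch keeps the existing structure and only guards the write. Its effect is modeled by the minimal sketch below, again in plain C++ with a hypothetical `reduceShapePatch1` and a `std::vector<int64_t>` standing in for a ranked operand of rank >= 1. The unranked and rank-0 cases still take the unchanged early return and are not modeled. The sketch asserts a non-negative axis, since the patch's comparison only guards the upper bound.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical model of patch 1: the reduced dimension is set to 1 only
// when `axis` names an existing dimension; otherwise the operand dims are
// returned unchanged.
std::vector<int64_t> reduceShapePatch1(std::vector<int64_t> dims,
                                       int64_t axis) {
  assert(axis >= 0 && "sketch assumes a non-negative axis");
  if (axis < static_cast<int64_t>(dims.size()))
    dims[static_cast<std::size_t>(axis)] = 1;  // in range: checked above
  return dims;
}
```

For an out-of-range axis the sketch, like the patch, still reports a ranked result equal to the operand's shape. For example, dims {4} with axis 1 stay {4}.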

From 9c430215ac1e2fbd969b0e3b0e79de3ff46ff6ad Mon Sep 17 00:00:00 2001
From: Felix Schneider <fx.schn at gmail.com>
Date: Sat, 21 Oct 2023 15:33:19 +0200
Subject: [PATCH 2/2] rebase

---
 mlir/lib/Dialect/Tosa/IR/TosaOps.cpp | 7 +++----
 1 file changed, 3 insertions(+), 4 deletions(-)

diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 0f616db31c06a5f..5292465477b1094 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -1109,16 +1109,15 @@ LogicalResult tosa::ScatterOp::inferReturnTypeComponents(
 static LogicalResult ReduceInferReturnTypes(
     ShapeAdaptor operandShape, Type inputType, IntegerAttr axis,
     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
-  if (!operandShape.hasRank() || operandShape.getRank() == 0) {
+  int64_t axisVal = axis.getValue().getSExtValue();
+  if (!operandShape.hasRank() || operandShape.getRank() <= axisVal) {
     inferredReturnShapes.push_back(ShapedTypeComponents(inputType));
     return success();
   }
 
   SmallVector<int64_t> outputShape;
   operandShape.getDims(outputShape);
-  int64_t axisVal = axis.getValue().getSExtValue();
-  if (axisVal < operandShape.getRank())
-    outputShape[axisVal] = 1;
+  outputShape[axisVal] = 1;
   inferredReturnShapes.push_back(ShapedTypeComponents(outputShape, inputType));
   return success();
 }
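The second patch hoists the read of `axis` and folds the bounds check into the early return. An out-of-range axis now produces the same result as an unranked operand: a `ShapedTypeComponents` built from a type alone, which carries no shape. Because `getRank() <= axisVal` holds for every rank-0 operand when the axis is non-negative, the removed `getRank() == 0` condition is subsumed. The plain-C++ sketch below assumes `std::nullopt` as a stand-in for "unranked" and uses a hypothetical `reduceShapePatch2`. As before, it asserts a non-negative axis, which the patch's comparison does not check.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <optional>
#include <vector>

// Hypothetical stand-in for shape components: a ranked shape is a list of
// dims, an unranked one is std::nullopt.
using Shape = std::optional<std::vector<int64_t>>;

// Hypothetical model of patch 2: unranked operands and out-of-range axes
// both yield an unranked result; otherwise the reduced dimension becomes 1.
Shape reduceShapePatch2(const Shape &operand, int64_t axis) {
  assert(axis >= 0 && "sketch assumes a non-negative axis");
  if (!operand || static_cast<int64_t>(operand->size()) <= axis)
    return std::nullopt;  // covers unranked, rank 0, and axis >= rank
  std::vector<int64_t> dims = *operand;
  dims[static_cast<std::size_t>(axis)] = 1;  // in range: guarded above
  return dims;
}
```

An unranked operand, a rank-0 operand with axis 0, and dims {4} with axis 1 all yield `std::nullopt`, while dims {2, 3} with axis 1 yield {2, 1}. Unlike the first patch, the out-of-range case leaves the result shape unknown instead of reporting it as equal to the operand's shape.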
