[Mlir-commits] [mlir] [mlir][tosa] Fix crash in inferReturnTypes for ReduceOps (PR #69843)

llvmlistbot at llvm.org
Sat Oct 21 06:38:39 PDT 2023


llvmbot wrote:



@llvm/pr-subscribers-mlir

Author: Felix Schneider (ubfx)

<details>
<summary>Changes</summary>

The `tosa.reduce_*` ops take an `axis` attribute that selects the dimension along which the reduction takes place. Shape inference can crash when the input tensor's rank is too low for that axis to exist, i.e. when `axis >= rank`.

Fix https://github.com/llvm/llvm-project/issues/68187
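To make the crash condition concrete, here is a minimal sketch that models the guard before and after the patch as plain predicates on `(rank, axis)`. The helper names `oldGuardFallsBack`, `newGuardFallsBack` and `axisIndexIsSafe` are hypothetical and do not exist in MLIR. They only restate the two `if` conditions from the diff below.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical helper: the guard before the patch. It only rejected
// rank-0 operands and never looked at the axis, so an axis >= rank
// slipped through to `outputShape[axisVal] = 1`.
bool oldGuardFallsBack(int64_t rank, int64_t axis) {
  (void)axis;  // unused by the old check
  return rank == 0;
}

// Hypothetical helper: the guard after the patch. Inference falls back
// to the input type whenever axis >= rank (this covers rank 0 as well).
bool newGuardFallsBack(int64_t rank, int64_t axis) {
  return rank <= axis;
}

// Indexing a shape of length `rank` at `axis` is in bounds only if
// 0 <= axis < rank.
bool axisIndexIsSafe(int64_t rank, int64_t axis) {
  return axis >= 0 && axis < rank;
}
```

For example, a rank-1 operand with `axis = 1` passes the old guard even though index 1 is out of bounds, while the new guard makes inference fall back to the input type.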


---
Full diff: https://github.com/llvm/llvm-project/pull/69843.diff


1 File Affected:

- (modified) mlir/lib/Dialect/Tosa/IR/TosaOps.cpp (+2-2) 


``````````diff
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index e03904a1611fc42..5292465477b1094 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -1109,14 +1109,14 @@ LogicalResult tosa::ScatterOp::inferReturnTypeComponents(
 static LogicalResult ReduceInferReturnTypes(
     ShapeAdaptor operandShape, Type inputType, IntegerAttr axis,
     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
-  if (!operandShape.hasRank() || operandShape.getRank() == 0) {
+  int64_t axisVal = axis.getValue().getSExtValue();
+  if (!operandShape.hasRank() || operandShape.getRank() <= axisVal) {
     inferredReturnShapes.push_back(ShapedTypeComponents(inputType));
     return success();
   }
 
   SmallVector<int64_t> outputShape;
   operandShape.getDims(outputShape);
-  int64_t axisVal = axis.getValue().getSExtValue();
   outputShape[axisVal] = 1;
   inferredReturnShapes.push_back(ShapedTypeComponents(outputShape, inputType));
   return success();

``````````
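For readers without an MLIR build, the patched `ReduceInferReturnTypes` logic can be approximated in standalone C++. This is a sketch under stated assumptions: `std::optional<std::vector<int64_t>>` stands in for `ShapeAdaptor` (with `std::nullopt` meaning "unranked"), and returning `std::nullopt` stands in for pushing `ShapedTypeComponents(inputType)`. `inferReduceShape` is a hypothetical name, not an MLIR API.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>
#include <vector>

// Hypothetical stand-in for ShapeAdaptor: std::nullopt models an
// unranked operand; otherwise the vector holds the dimension sizes.
using Shape = std::optional<std::vector<int64_t>>;

// Mirrors the patched logic. Returns std::nullopt when inference should
// fall back to the input type, otherwise the inferred output shape.
// Like the patch, this assumes a non-negative axis.
std::optional<std::vector<int64_t>> inferReduceShape(const Shape &operand,
                                                     int64_t axis) {
  // Unranked operands, and operands whose rank is too low to contain
  // `axis`, keep the input type instead of indexing past the end.
  if (!operand.has_value() ||
      static_cast<int64_t>(operand->size()) <= axis)
    return std::nullopt;

  std::vector<int64_t> outputShape = *operand;
  outputShape[axis] = 1;  // the reduced dimension is kept with size 1
  return outputShape;
}
```

Reducing a `2x3x4` operand along axis 1 yields `2x1x4`, while a rank-1 operand with axis 1 now takes the fallback path instead of writing out of bounds.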

</details>


https://github.com/llvm/llvm-project/pull/69843

