[Mlir-commits] [mlir] [mlir][tosa] Fix crash in inferReturnTypes for ReduceOps (PR #69843)
llvmlistbot at llvm.org
Sat Oct 21 06:38:39 PDT 2023
llvmbot wrote:
@llvm/pr-subscribers-mlir
Author: Felix Schneider (ubfx)
<details>
<summary>Changes</summary>
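The `tosa.reduce_*` ops take an `axis` attribute that selects the dimension along which the reduction takes place. Shape inference can crash when the input tensor's rank is too low for that axis to exist (rank <= axis), because `ReduceInferReturnTypes` only guarded against unranked and rank-0 operands before writing `outputShape[axisVal]`.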
Fix https://github.com/llvm/llvm-project/issues/68187
---
Full diff: https://github.com/llvm/llvm-project/pull/69843.diff
1 File Affected:
- (modified) mlir/lib/Dialect/Tosa/IR/TosaOps.cpp (+2-2)
``````````diff
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index e03904a1611fc42..5292465477b1094 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -1109,14 +1109,14 @@ LogicalResult tosa::ScatterOp::inferReturnTypeComponents(
static LogicalResult ReduceInferReturnTypes(
ShapeAdaptor operandShape, Type inputType, IntegerAttr axis,
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
- if (!operandShape.hasRank() || operandShape.getRank() == 0) {
+ int64_t axisVal = axis.getValue().getSExtValue();
+ if (!operandShape.hasRank() || operandShape.getRank() <= axisVal) {
inferredReturnShapes.push_back(ShapedTypeComponents(inputType));
return success();
}
SmallVector<int64_t> outputShape;
operandShape.getDims(outputShape);
- int64_t axisVal = axis.getValue().getSExtValue();
outputShape[axisVal] = 1;
inferredReturnShapes.push_back(ShapedTypeComponents(outputShape, inputType));
return success();
``````````
</details>
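The patched control flow can be illustrated outside MLIR. The following is a minimal standalone C++ sketch, not the TOSA implementation: `Shape` and `inferReduceShape` are hypothetical names, `std::nullopt` stands in both for an unranked operand and for the patch's fallback of pushing `ShapedTypeComponents(inputType)` without a shape, and the negative-axis check is an extra assumption the patch itself does not make.

```cpp
#include <cstdint>
#include <optional>
#include <vector>

// Hypothetical stand-in for a possibly unranked tensor shape:
// std::nullopt models an unranked operand, otherwise one entry per dimension.
using Shape = std::optional<std::vector<int64_t>>;

// Mirrors the patched ReduceInferReturnTypes control flow. Returns the
// inferred result shape, or std::nullopt when no shape is inferred.
Shape inferReduceShape(const Shape &operandShape, int64_t axis) {
  // Bail out before touching the dimensions when the operand is unranked
  // or the requested axis does not exist (rank <= axis).
  // The axis < 0 check is an extra guard added for this sketch only.
  if (!operandShape || axis < 0 ||
      static_cast<int64_t>(operandShape->size()) <= axis)
    return std::nullopt;

  // Copy the operand dimensions and collapse the reduced axis to size 1.
  std::vector<int64_t> outputShape = *operandShape;
  outputShape[static_cast<size_t>(axis)] = 1;
  return outputShape;
}
```

For example, a `{2, 3}` operand with `axis = 1` yields `{2, 1}`, while `axis = 2` (the out-of-range case behind the crash) yields `std::nullopt` instead of writing past the end of the vector. A rank-0 operand with `axis = 0` is also caught by the `rank <= axis` comparison, which is why the old `getRank() == 0` check could be dropped.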
https://github.com/llvm/llvm-project/pull/69843