[Mlir-commits] [mlir] [mlir][tosa] Fix crash in inferReturnTypes for ReduceOps (PR #69843)
Felix Schneider
llvmlistbot at llvm.org
Sat Oct 21 06:37:36 PDT 2023
https://github.com/ubfx created https://github.com/llvm/llvm-project/pull/69843
The `tosa.reduce_*` ops take an `axis` attribute that determines the dimension along which the reduction takes place. A crash can occur during shape inference when the input tensor's rank is so low that the given axis doesn't exist (i.e. the rank is less than or equal to `axis`).
Fix https://github.com/llvm/llvm-project/issues/68187
From ec1495d5373f436f1f8c230699bdf315035f49e7 Mon Sep 17 00:00:00 2001
From: Felix Schneider <fx.schn at gmail.com>
Date: Sat, 21 Oct 2023 13:09:40 +0200
Subject: [PATCH 1/2] [mlir][tosa] Fix crash in inferReturnTypes for ReduceOps
The `tosa.reduce_*` ops take an `axis` Attribute that determines along
which dimension the reduction takes place. A crash can occur during
shape inference when the input tensor rank is so low that the given
axis doesn't exist.
Fix https://github.com/llvm/llvm-project/issues/68187
---
mlir/lib/Dialect/Tosa/IR/TosaOps.cpp | 3 ++-
1 file changed, 2 insertions(+), 1 deletion(-)
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index e03904a1611fc42..0f616db31c06a5f 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -1117,7 +1117,8 @@ static LogicalResult ReduceInferReturnTypes(
SmallVector<int64_t> outputShape;
operandShape.getDims(outputShape);
int64_t axisVal = axis.getValue().getSExtValue();
- outputShape[axisVal] = 1;
+ if (axisVal < operandShape.getRank())
+ outputShape[axisVal] = 1;
inferredReturnShapes.push_back(ShapedTypeComponents(outputShape, inputType));
return success();
}
From 9c430215ac1e2fbd969b0e3b0e79de3ff46ff6ad Mon Sep 17 00:00:00 2001
From: Felix Schneider <fx.schn at gmail.com>
Date: Sat, 21 Oct 2023 15:33:19 +0200
Subject: [PATCH 2/2] rebase
---
mlir/lib/Dialect/Tosa/IR/TosaOps.cpp | 7 +++----
1 file changed, 3 insertions(+), 4 deletions(-)
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 0f616db31c06a5f..5292465477b1094 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -1109,16 +1109,15 @@ LogicalResult tosa::ScatterOp::inferReturnTypeComponents(
static LogicalResult ReduceInferReturnTypes(
ShapeAdaptor operandShape, Type inputType, IntegerAttr axis,
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
- if (!operandShape.hasRank() || operandShape.getRank() == 0) {
+ int64_t axisVal = axis.getValue().getSExtValue();
+ if (!operandShape.hasRank() || operandShape.getRank() <= axisVal) {
inferredReturnShapes.push_back(ShapedTypeComponents(inputType));
return success();
}
SmallVector<int64_t> outputShape;
operandShape.getDims(outputShape);
- int64_t axisVal = axis.getValue().getSExtValue();
- if (axisVal < operandShape.getRank())
- outputShape[axisVal] = 1;
+ outputShape[axisVal] = 1;
inferredReturnShapes.push_back(ShapedTypeComponents(outputShape, inputType));
return success();
}
More information about the Mlir-commits
mailing list