[Mlir-commits] [mlir] [mlir][TosaToLinalg] Only support ranked tensor for `reduce` and `gather` (PR #131805)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Mar 18 07:03:36 PDT 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-tosa
Author: Longsheng Mou (CoTinker)
<details>
<summary>Changes</summary>
This PR adds ranked-tensor checks to the `reduce` and `gather` converters to prevent a crash on unranked tensors. Fixes #131087.
---
Full diff: https://github.com/llvm/llvm-project/pull/131805.diff
2 Files Affected:
- (modified) mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp (+8-7)
- (modified) mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir (+16)
``````````diff
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index b5e0efa71b3cc..c0a25a56dbe2a 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -1176,8 +1176,11 @@ template <typename OpTy>
static LogicalResult reduceMatchAndRewriteHelper(OpTy op, uint64_t axis,
PatternRewriter &rewriter) {
auto loc = op->getLoc();
- auto inputTy = cast<ShapedType>(op->getOperand(0).getType());
- auto resultTy = cast<ShapedType>(op->getResult(0).getType());
+ auto inputTy = dyn_cast<RankedTensorType>(op->getOperand(0).getType());
+ auto resultTy = dyn_cast<RankedTensorType>(op->getResult(0).getType());
+ if (!inputTy || !resultTy)
+ return rewriter.notifyMatchFailure(op, "unranked tensors not supported");
+
auto elementTy = resultTy.getElementType();
Value input = op->getOperand(0);
@@ -2380,11 +2383,9 @@ class GatherConverter : public OpConversionPattern<tosa::GatherOp> {
auto input = adaptor.getOperands()[0];
auto indices = adaptor.getOperands()[1];
- auto valuesTy =
- dyn_cast_or_null<RankedTensorType>(op.getValues().getType());
- auto resultTy = cast<ShapedType>(op.getType());
-
- if (!valuesTy)
+ auto valuesTy = dyn_cast<RankedTensorType>(op.getValues().getType());
+ auto resultTy = dyn_cast<RankedTensorType>(op.getType());
+ if (!valuesTy || !resultTy)
return rewriter.notifyMatchFailure(op, "unranked tensors not supported");
auto dynamicDims = inferDynamicDimsForGather(
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir
index d1a9671873de0..d00846a4c3e02 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir
@@ -57,3 +57,19 @@ func.func @cast_unsupported_type(%arg0: tensor<13x21x3xi32>) -> tensor<13x21x3x!
%0 = tosa.cast %arg0 : (tensor<13x21x3xi32>) -> tensor<13x21x3x!quant.uniform<i16:f32, 0.078431375324726104:128>>
return %0 : tensor<13x21x3x!quant.uniform<i16:f32, 0.078431375324726104:128>>
}
+
+// -----
+
+func.func @unranked_reduce(%arg0: tensor<*xf32>) -> tensor<*xf32> {
+ // expected-error @+1 {{failed to legalize operation 'tosa.reduce_sum'}}
+ %0 = tosa.reduce_sum %arg0 {axis = 0 : i32} : (tensor<*xf32>) -> tensor<*xf32>
+ return %0 : tensor<*xf32>
+}
+
+// -----
+
+func.func @unranked_gather(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x26xi32>) -> tensor<*xf32> {
+ // expected-error @+1 {{failed to legalize operation 'tosa.gather'}}
+ %0 = tosa.gather %arg0, %arg1 : (tensor<13x21x3xf32>, tensor<13x26xi32>) -> tensor<*xf32>
+ return %0 : tensor<*xf32>
+}
``````````
</details>
https://github.com/llvm/llvm-project/pull/131805
More information about the Mlir-commits mailing list