[Mlir-commits] [mlir] [mlir][TosaToLinalg] Only support ranked tensor for `reduce` and `gather` (PR #131805)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Mar 18 07:03:36 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-tosa

Author: Longsheng Mou (CoTinker)

<details>
<summary>Changes</summary>

This PR adds ranked-tensor checks to the `reduce` and `gather` converters to prevent a crash on unranked tensors. Fixes #<!-- -->131087.

---
Full diff: https://github.com/llvm/llvm-project/pull/131805.diff


2 Files Affected:

- (modified) mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp (+8-7) 
- (modified) mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir (+16) 


``````````diff
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index b5e0efa71b3cc..c0a25a56dbe2a 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -1176,8 +1176,11 @@ template <typename OpTy>
 static LogicalResult reduceMatchAndRewriteHelper(OpTy op, uint64_t axis,
                                                  PatternRewriter &rewriter) {
   auto loc = op->getLoc();
-  auto inputTy = cast<ShapedType>(op->getOperand(0).getType());
-  auto resultTy = cast<ShapedType>(op->getResult(0).getType());
+  auto inputTy = dyn_cast<RankedTensorType>(op->getOperand(0).getType());
+  auto resultTy = dyn_cast<RankedTensorType>(op->getResult(0).getType());
+  if (!inputTy || !resultTy)
+    return rewriter.notifyMatchFailure(op, "unranked tensors not supported");
+
   auto elementTy = resultTy.getElementType();
   Value input = op->getOperand(0);
 
@@ -2380,11 +2383,9 @@ class GatherConverter : public OpConversionPattern<tosa::GatherOp> {
     auto input = adaptor.getOperands()[0];
     auto indices = adaptor.getOperands()[1];
 
-    auto valuesTy =
-        dyn_cast_or_null<RankedTensorType>(op.getValues().getType());
-    auto resultTy = cast<ShapedType>(op.getType());
-
-    if (!valuesTy)
+    auto valuesTy = dyn_cast<RankedTensorType>(op.getValues().getType());
+    auto resultTy = dyn_cast<RankedTensorType>(op.getType());
+    if (!valuesTy || !resultTy)
       return rewriter.notifyMatchFailure(op, "unranked tensors not supported");
 
     auto dynamicDims = inferDynamicDimsForGather(
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir
index d1a9671873de0..d00846a4c3e02 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir
@@ -57,3 +57,19 @@ func.func @cast_unsupported_type(%arg0: tensor<13x21x3xi32>) -> tensor<13x21x3x!
   %0 = tosa.cast %arg0 : (tensor<13x21x3xi32>) -> tensor<13x21x3x!quant.uniform<i16:f32, 0.078431375324726104:128>>
   return %0 : tensor<13x21x3x!quant.uniform<i16:f32, 0.078431375324726104:128>>
 }
+
+// -----
+
+func.func @unranked_reduce(%arg0: tensor<*xf32>) -> tensor<*xf32> {
+  // expected-error@+1 {{failed to legalize operation 'tosa.reduce_sum'}}
+  %0 = tosa.reduce_sum %arg0 {axis = 0 : i32} : (tensor<*xf32>) -> tensor<*xf32>
+  return %0 : tensor<*xf32>
+}
+
+// -----
+
+func.func @unranked_gather(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x26xi32>) -> tensor<*xf32> {
+  // expected-error@+1 {{failed to legalize operation 'tosa.gather'}}
+  %0 = tosa.gather %arg0, %arg1 : (tensor<13x21x3xf32>, tensor<13x26xi32>) -> tensor<*xf32>
+  return %0 : tensor<*xf32>
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/131805


More information about the Mlir-commits mailing list