[Mlir-commits] [mlir] [mlir][TosaToLinalg] Only support ranked tensor for `reduce` and `gather` (PR #131805)

Longsheng Mou llvmlistbot at llvm.org
Tue Mar 18 07:02:51 PDT 2025


https://github.com/CoTinker created https://github.com/llvm/llvm-project/pull/131805

This PR adds ranked-tensor checks to the TosaToLinalg converters for `reduce` and `gather`, so that unranked operands or results produce a pattern match failure instead of a crash. Fixes #131087.

>From eb66a8e5dc7897e946f1cb53052dd0d0faa024b2 Mon Sep 17 00:00:00 2001
From: Longsheng Mou <longshengmou at gmail.com>
Date: Tue, 18 Mar 2025 21:56:55 +0800
Subject: [PATCH] [mlir][TosaToLinalg] Only support ranked tensor for reduce
 and gather

This PR adds ranked-tensor checks to the converters for reduce and gather
to prevent a crash on unranked tensors.
---
 .../lib/Conversion/TosaToLinalg/TosaToLinalg.cpp | 15 ++++++++-------
 .../TosaToLinalg/tosa-to-linalg-invalid.mlir     | 16 ++++++++++++++++
 2 files changed, 24 insertions(+), 7 deletions(-)

diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index b5e0efa71b3cc..c0a25a56dbe2a 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -1176,8 +1176,11 @@ template <typename OpTy>
 static LogicalResult reduceMatchAndRewriteHelper(OpTy op, uint64_t axis,
                                                  PatternRewriter &rewriter) {
   auto loc = op->getLoc();
-  auto inputTy = cast<ShapedType>(op->getOperand(0).getType());
-  auto resultTy = cast<ShapedType>(op->getResult(0).getType());
+  auto inputTy = dyn_cast<RankedTensorType>(op->getOperand(0).getType());
+  auto resultTy = dyn_cast<RankedTensorType>(op->getResult(0).getType());
+  if (!inputTy || !resultTy)
+    return rewriter.notifyMatchFailure(op, "unranked tensors not supported");
+
   auto elementTy = resultTy.getElementType();
   Value input = op->getOperand(0);
 
@@ -2380,11 +2383,9 @@ class GatherConverter : public OpConversionPattern<tosa::GatherOp> {
     auto input = adaptor.getOperands()[0];
     auto indices = adaptor.getOperands()[1];
 
-    auto valuesTy =
-        dyn_cast_or_null<RankedTensorType>(op.getValues().getType());
-    auto resultTy = cast<ShapedType>(op.getType());
-
-    if (!valuesTy)
+    auto valuesTy = dyn_cast<RankedTensorType>(op.getValues().getType());
+    auto resultTy = dyn_cast<RankedTensorType>(op.getType());
+    if (!valuesTy || !resultTy)
       return rewriter.notifyMatchFailure(op, "unranked tensors not supported");
 
     auto dynamicDims = inferDynamicDimsForGather(
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir
index d1a9671873de0..d00846a4c3e02 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir
@@ -57,3 +57,19 @@ func.func @cast_unsupported_type(%arg0: tensor<13x21x3xi32>) -> tensor<13x21x3x!
   %0 = tosa.cast %arg0 : (tensor<13x21x3xi32>) -> tensor<13x21x3x!quant.uniform<i16:f32, 0.078431375324726104:128>>
   return %0 : tensor<13x21x3x!quant.uniform<i16:f32, 0.078431375324726104:128>>
 }
+
+// -----
+
+func.func @unranked_reduce(%arg0: tensor<*xf32>) -> tensor<*xf32> {
+  // expected-error at +1 {{failed to legalize operation 'tosa.reduce_sum'}}
+  %0 = tosa.reduce_sum %arg0 {axis = 0 : i32} : (tensor<*xf32>) -> tensor<*xf32>
+  return %0 : tensor<*xf32>
+}
+
+// -----
+
+func.func @unranked_gather(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x26xi32>) -> tensor<*xf32> {
+  // expected-error at +1 {{failed to legalize operation 'tosa.gather'}}
+  %0 = tosa.gather %arg0, %arg1 : (tensor<13x21x3xf32>, tensor<13x26xi32>) -> tensor<*xf32>
+  return %0 : tensor<*xf32>
+}



More information about the Mlir-commits mailing list