[Mlir-commits] [mlir] [MLIR][TOSA] Guard scatter lowering against unranked operand (PR #178188)
llvmlistbot at llvm.org
Tue Jan 27 04:19:11 PST 2026
llvmbot wrote:
@llvm/pr-subscribers-mlir
Author: Ayush Kumar Gaur (Ayush3941)
<details>
<summary>Changes</summary>
### What's the Problem
`--tosa-to-scf` crashes when it lowers a `tosa.scatter` whose operands are unranked. The lowering assumes `RankedTensorType` and later builds `tensor.extract_slice`, which asserts on unranked tensors.
### Why It Happened
The current pattern does not check operand ranks before rewriting, even though the implementation hardcodes 3-D (N, W, C) semantics and cannot handle unranked tensors or tensors of any other rank.
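To make the hardcoded 3-D assumption concrete, here is a plain C++ reference loop nest for scatter. It is a sketch assuming the TOSA shape convention values_in `[N, K, C]`, indices `[N, W]`, input `[N, W, C]`, with each tensor flattened row-major. `scatterReference` is a hypothetical illustration, not the code in `TosaToSCF.cpp`. Every offset computation below only makes sense when values/input have exactly three dimensions and indices exactly two, which is the same precondition the lowering silently relied on.

```cpp
#include <cstddef>
#include <cstdint>
#include <stdexcept>
#include <vector>

// Hypothetical reference semantics for scatter on flat, row-major buffers.
// Assumed shapes: valuesIn [N, K, C], indices [N, W], input [N, W, C].
std::vector<int32_t> scatterReference(const std::vector<int32_t>& valuesIn,
                                      const std::vector<int32_t>& indices,
                                      const std::vector<int32_t>& input,
                                      std::size_t N, std::size_t K,
                                      std::size_t W, std::size_t C) {
  if (valuesIn.size() != N * K * C || indices.size() != N * W ||
      input.size() != N * W * C)
    throw std::invalid_argument("operand sizes do not match N, K, W, C");

  std::vector<int32_t> out = valuesIn;  // the result starts as a copy of values_in
  for (std::size_t n = 0; n < N; ++n) {
    for (std::size_t w = 0; w < W; ++w) {
      int32_t k = indices[n * W + w];   // destination row within batch n
      if (k < 0 || static_cast<std::size_t>(k) >= K)
        throw std::out_of_range("scatter index out of range");  // sketch's choice
      // Move one C-length row from input[n, w, :] to out[n, k, :].
      for (std::size_t c = 0; c < C; ++c)
        out[(n * K + static_cast<std::size_t>(k)) * C + c] =
            input[(n * W + w) * C + c];
    }
  }
  return out;
}
```

The two outer loops over N and W, each moving a whole C-length row, loosely mirror the loop nest of row-sized slice extractions and insertions that the SCF lowering builds; with an unranked operand there is no N, W, or C to read off, so such a nest cannot be formed.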
### What's the Fix
Add rank and type guards to `ScatterOpConverter` so that the rewrite fails unless all operands are ranked with the expected ranks (values/input: rank 3, indices: rank 2). Legalization then fails gracefully with a diagnostic instead of crashing.
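The decision the new guard makes can be modeled outside MLIR. The sketch below assumes a hypothetical `TensorTypeModel` that stands in for MLIR's ranked/unranked tensor types (an empty `shape` plays the role of `tensor<*xT>`); `scatterOperandsAreLowerable` is likewise a hypothetical name. It only mirrors the boolean logic of the `dyn_cast<RankedTensorType>` plus `getRank()` checks in the diff and is not MLIR API.

```cpp
#include <cstdint>
#include <optional>
#include <vector>

// Hypothetical stand-in for an MLIR tensor type: a ranked tensor carries a
// shape; an unranked tensor (tensor<*xT>) carries none.
struct TensorTypeModel {
  std::optional<std::vector<int64_t>> shape;  // std::nullopt means unranked
  bool isRanked() const { return shape.has_value(); }
  int64_t getRank() const { return static_cast<int64_t>(shape->size()); }
};

// Mirrors the guard's logic: proceed only when every operand is ranked and
// values/input are rank 3 while indices are rank 2.
bool scatterOperandsAreLowerable(const TensorTypeModel& values,
                                 const TensorTypeModel& indices,
                                 const TensorTypeModel& input) {
  // Unranked operand: the real pattern returns failure() at this point.
  if (!values.isRanked() || !indices.isRanked() || !input.isRanked())
    return false;
  return values.getRank() == 3 && input.getRank() == 3 &&
         indices.getRank() == 2;
}
```

When this check is false, the real pattern returns `failure()`. The op therefore stays unconverted, and because the pass marks `tosa.scatter` illegal, the driver reports "failed to legalize operation 'tosa.scatter' that was explicitly marked illegal", which is exactly the diagnostic the new test expects.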
---
Full diff: https://github.com/llvm/llvm-project/pull/178188.diff
2 Files Affected:
- (modified) mlir/lib/Conversion/TosaToSCF/TosaToSCF.cpp (+8)
- (added) mlir/test/Conversion/TosaToSCF/tosa-to-scf-invalid.mlir (+9)
``````````diff
diff --git a/mlir/lib/Conversion/TosaToSCF/TosaToSCF.cpp b/mlir/lib/Conversion/TosaToSCF/TosaToSCF.cpp
index aa6b4164e9876..9b1fb90d3ae17 100644
--- a/mlir/lib/Conversion/TosaToSCF/TosaToSCF.cpp
+++ b/mlir/lib/Conversion/TosaToSCF/TosaToSCF.cpp
@@ -101,6 +101,14 @@ class ScatterOpConverter : public OpRewritePattern<tosa::ScatterOp> {
auto input = scatter.getInput();
auto loc = scatter.getLoc();
+ auto valuesType = dyn_cast<RankedTensorType>(valuesIn.getType());
+ auto indicesType = dyn_cast<RankedTensorType>(indices.getType());
+ auto inputType = dyn_cast<RankedTensorType>(input.getType());
+ if (!valuesType || !indicesType || !inputType ||
+ valuesType.getRank() != 3 || inputType.getRank() != 3 ||
+ indicesType.getRank() != 2)
+ return failure();
+
// N, W, C are chosen to match the TOSA spec
auto dimN = createTensorDim(rewriter, loc, input, 0);
auto dimW = createTensorDim(rewriter, loc, input, 1);
diff --git a/mlir/test/Conversion/TosaToSCF/tosa-to-scf-invalid.mlir b/mlir/test/Conversion/TosaToSCF/tosa-to-scf-invalid.mlir
new file mode 100644
index 0000000000000..c1bb458f42113
--- /dev/null
+++ b/mlir/test/Conversion/TosaToSCF/tosa-to-scf-invalid.mlir
@@ -0,0 +1,9 @@
+// RUN: mlir-opt --split-input-file --tosa-to-scf %s -verify-diagnostics -o -
+
+// CHECK-LABEL: @scatter_unranked
+func.func @scatter_unranked(%v: tensor<*xi32>, %idx: tensor<*xi32>, %inp: tensor<*xi32>) -> tensor<*xi32> {
+ // expected-error @+1 {{failed to legalize operation 'tosa.scatter' that was explicitly marked illegal}}
+ %0 = tosa.scatter %v, %idx, %inp
+ : (tensor<*xi32>, tensor<*xi32>, tensor<*xi32>) -> tensor<*xi32>
+ return %0 : tensor<*xi32>
+}
``````````
</details>
https://github.com/llvm/llvm-project/pull/178188