[Mlir-commits] [mlir] 934638c - [MLIR][TOSA] Guard scatter lowering against unranked operand (#178188)

llvmlistbot at llvm.org
Tue Jan 27 09:27:45 PST 2026


Author: Ayush Kumar Gaur
Date: 2026-01-27T17:27:40Z
New Revision: 934638cc6e0548a1eaae90a32761d9e11822020a

URL: https://github.com/llvm/llvm-project/commit/934638cc6e0548a1eaae90a32761d9e11822020a
DIFF: https://github.com/llvm/llvm-project/commit/934638cc6e0548a1eaae90a32761d9e11822020a.diff

LOG: [MLIR][TOSA] Guard scatter lowering against unranked operand (#178188)

### What's the problem

--tosa-to-scf crashes when lowering tosa.scatter if operands are
unranked, because the lowering assumes RankedTensorType and later builds
tensor.extract_slice, which asserts on unranked tensors.

### Why it happened

The current pattern does not check operand ranks before lowering, even
though the implementation hardcodes 3D (N,W,C) semantics and cannot
handle unranked or differently-ranked tensors.
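The hard-coded semantics can be written as a plain loop nest. In TOSA, `values_in` has shape [N, K, C], `indices` has shape [N, W], and `input` has shape [N, W, C]. Each `input[n][w]` row is written to `values_out[n][indices[n][w]]`. The following is a hypothetical, standalone C++ model, not the ScatterOpConverter code. It uses nested `std::vector` in place of tensors, and its bounds check is an addition made for this sketch. It shows why the lowering needs exactly rank-3 values/input and rank-2 indices.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <stdexcept>
#include <vector>

// Hypothetical nested-vector stand-ins for the TOSA tensors.
using Tensor3 = std::vector<std::vector<std::vector<int32_t>>>; // [N][K or W][C]
using Tensor2 = std::vector<std::vector<int32_t>>;              // [N][W]

// Reference model of tosa.scatter (sketch):
//   values_out = values_in
//   values_out[n][indices[n][w]][c] = input[n][w][c]
Tensor3 scatterReference(const Tensor3 &valuesIn, const Tensor2 &indices,
                         const Tensor3 &input) {
  Tensor3 valuesOut = valuesIn; // start from a copy of values_in
  for (std::size_t n = 0; n < input.size(); ++n) {
    for (std::size_t w = 0; w < input[n].size(); ++w) {
      const int32_t k = indices[n][w];
      // Bounds check added for this sketch only.
      if (k < 0 || static_cast<std::size_t>(k) >= valuesOut[n].size())
        throw std::out_of_range("scatter index out of range");
      for (std::size_t c = 0; c < input[n][w].size(); ++c)
        valuesOut[n][k][c] = input[n][w][c];
    }
  }
  return valuesOut;
}
```

With N=1, K=3, C=2, zero-filled `values_in`, indices `{{2, 0}}` and input rows `{1,2}`, `{3,4}`, row 2 of the output becomes `{1,2}`, row 0 becomes `{3,4}`, and row 1 keeps its zeros.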

### What's the fix

Add rank/type guards in ScatterOpConverter and fail the rewrite unless
operands are ranked with expected ranks (values/input: rank-3, indices:
rank-2), allowing legalization to fail gracefully instead of crashing.
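The guard described above can be modeled outside MLIR as follows. Note that the committed diff tests only `isa<RankedTensorType>` on each operand. The explicit rank-3/rank-2 comparison in this sketch follows the wording of this message and is an assumption about how such a check could look. `ToyTensorType` and `MatchResult` are hypothetical stand-ins, not MLIR types. The point is to reject bad operands before any rank-dependent IR (tensor dims, `tensor.extract_slice`) would be built.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>
#include <string>

// Hypothetical stand-in for a tensor type: std::nullopt models an unranked
// tensor such as tensor<*xi32>; otherwise the value is the rank.
struct ToyTensorType {
  std::optional<int64_t> rank;
};

// Hypothetical mirror of a rewrite pattern's outcome: on failure, `reason`
// carries the diagnostic that notifyMatchFailure would record.
struct MatchResult {
  bool succeeded;
  std::string reason;
};

// Sketch of the guard: fail the match instead of asserting later.
MatchResult checkScatterOperands(const ToyTensorType &valuesIn,
                                 const ToyTensorType &indices,
                                 const ToyTensorType &input) {
  if (!valuesIn.rank || !indices.rank || !input.rank)
    return {false, "expected ranked tensor operands for scatter lowering"};
  // Assumed extra check: N,W,C layout needs rank-3 values/input, rank-2 indices.
  if (*valuesIn.rank != 3 || *input.rank != 3 || *indices.rank != 2)
    return {false, "expected rank-3 values/input and rank-2 indices"};
  return {true, ""};
}
```

Failing the match rather than asserting leaves `tosa.scatter` unconverted. The conversion driver then reports the op as illegal, which is the diagnostic the new test expects.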

Fixes #177966

Added: 
    mlir/test/Conversion/TosaToSCF/tosa-to-scf-invalid.mlir

Modified: 
    mlir/lib/Conversion/TosaToSCF/TosaToSCF.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/TosaToSCF/TosaToSCF.cpp b/mlir/lib/Conversion/TosaToSCF/TosaToSCF.cpp
index aa6b4164e9876..b46026b855b90 100644
--- a/mlir/lib/Conversion/TosaToSCF/TosaToSCF.cpp
+++ b/mlir/lib/Conversion/TosaToSCF/TosaToSCF.cpp
@@ -101,6 +101,13 @@ class ScatterOpConverter : public OpRewritePattern<tosa::ScatterOp> {
     auto input = scatter.getInput();
     auto loc = scatter.getLoc();
 
+    if (!isa<RankedTensorType>(valuesIn.getType()) ||
+        !isa<RankedTensorType>(indices.getType()) ||
+        !isa<RankedTensorType>(input.getType())) {
+      return rewriter.notifyMatchFailure(
+          scatter, "expected ranked tensor operands for scatter lowering");
+    }
+
     // N, W, C are chosen to match the TOSA spec
     auto dimN = createTensorDim(rewriter, loc, input, 0);
     auto dimW = createTensorDim(rewriter, loc, input, 1);

diff  --git a/mlir/test/Conversion/TosaToSCF/tosa-to-scf-invalid.mlir b/mlir/test/Conversion/TosaToSCF/tosa-to-scf-invalid.mlir
new file mode 100644
index 0000000000000..c1bb458f42113
--- /dev/null
+++ b/mlir/test/Conversion/TosaToSCF/tosa-to-scf-invalid.mlir
@@ -0,0 +1,9 @@
+// RUN: mlir-opt --split-input-file --tosa-to-scf %s -verify-diagnostics -o -
+
+// CHECK-LABEL: @scatter_unranked
+func.func @scatter_unranked(%v: tensor<*xi32>, %idx: tensor<*xi32>, %inp: tensor<*xi32>) -> tensor<*xi32> {
+  // expected-error @+1 {{failed to legalize operation 'tosa.scatter' that was explicitly marked illegal}}
+  %0 = tosa.scatter %v, %idx, %inp
+      : (tensor<*xi32>, tensor<*xi32>, tensor<*xi32>) -> tensor<*xi32>
+  return %0 : tensor<*xi32>
+}

