[Mlir-commits] [mlir] [MLIR][TOSA] Guard scatter lowering against unranked operand (PR #178188)
Ayush Kumar Gaur
llvmlistbot at llvm.org
Tue Jan 27 05:50:42 PST 2026
https://github.com/Ayush3941 updated https://github.com/llvm/llvm-project/pull/178188
>From 738d8cf3df00590d0371e785e6223d3bfdf5632d Mon Sep 17 00:00:00 2001
From: Ayush3941 <ayushkgaur1 at gmail.com>
Date: Tue, 27 Jan 2026 07:09:53 -0500
Subject: [PATCH 1/4] [MLIR][TOSA] Guard scatter lowering against unranked
operand
---
mlir/lib/Conversion/TosaToSCF/TosaToSCF.cpp | 8 ++++++++
mlir/test/Conversion/TosaToSCF/tosa-to-scf-invalid.mlir | 9 +++++++++
2 files changed, 17 insertions(+)
create mode 100644 mlir/test/Conversion/TosaToSCF/tosa-to-scf-invalid.mlir
diff --git a/mlir/lib/Conversion/TosaToSCF/TosaToSCF.cpp b/mlir/lib/Conversion/TosaToSCF/TosaToSCF.cpp
index aa6b4164e9876..9b1fb90d3ae17 100644
--- a/mlir/lib/Conversion/TosaToSCF/TosaToSCF.cpp
+++ b/mlir/lib/Conversion/TosaToSCF/TosaToSCF.cpp
@@ -101,6 +101,14 @@ class ScatterOpConverter : public OpRewritePattern<tosa::ScatterOp> {
auto input = scatter.getInput();
auto loc = scatter.getLoc();
+ auto valuesType = dyn_cast<RankedTensorType>(valuesIn.getType());
+ auto indicesType = dyn_cast<RankedTensorType>(indices.getType());
+ auto inputType = dyn_cast<RankedTensorType>(input.getType());
+ if (!valuesType || !indicesType || !inputType ||
+ valuesType.getRank() != 3 || inputType.getRank() != 3 ||
+ indicesType.getRank() != 2)
+ return failure();
+
// N, W, C are chosen to match the TOSA spec
auto dimN = createTensorDim(rewriter, loc, input, 0);
auto dimW = createTensorDim(rewriter, loc, input, 1);
diff --git a/mlir/test/Conversion/TosaToSCF/tosa-to-scf-invalid.mlir b/mlir/test/Conversion/TosaToSCF/tosa-to-scf-invalid.mlir
new file mode 100644
index 0000000000000..c1bb458f42113
--- /dev/null
+++ b/mlir/test/Conversion/TosaToSCF/tosa-to-scf-invalid.mlir
@@ -0,0 +1,9 @@
+// RUN: mlir-opt --split-input-file --tosa-to-scf %s -verify-diagnostics -o -
+
+// CHECK-LABEL: @scatter_unranked
+func.func @scatter_unranked(%v: tensor<*xi32>, %idx: tensor<*xi32>, %inp: tensor<*xi32>) -> tensor<*xi32> {
+ // expected-error @+1 {{failed to legalize operation 'tosa.scatter' that was explicitly marked illegal}}
+ %0 = tosa.scatter %v, %idx, %inp
+ : (tensor<*xi32>, tensor<*xi32>, tensor<*xi32>) -> tensor<*xi32>
+ return %0 : tensor<*xi32>
+}
>From c93d72690dfaac7edde4a835d35b3bfefb9037a3 Mon Sep 17 00:00:00 2001
From: Ayush3941 <ayushkgaur1 at gmail.com>
Date: Tue, 27 Jan 2026 07:23:16 -0500
Subject: [PATCH 2/4] [MLIR][TOSA] Guard scatter lowering against unranked
operand (fix formatting)
---
mlir/lib/Conversion/TosaToSCF/TosaToSCF.cpp | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/mlir/lib/Conversion/TosaToSCF/TosaToSCF.cpp b/mlir/lib/Conversion/TosaToSCF/TosaToSCF.cpp
index 9b1fb90d3ae17..6b9186b7c26c4 100644
--- a/mlir/lib/Conversion/TosaToSCF/TosaToSCF.cpp
+++ b/mlir/lib/Conversion/TosaToSCF/TosaToSCF.cpp
@@ -101,9 +101,9 @@ class ScatterOpConverter : public OpRewritePattern<tosa::ScatterOp> {
auto input = scatter.getInput();
auto loc = scatter.getLoc();
- auto valuesType = dyn_cast<RankedTensorType>(valuesIn.getType());
+ auto valuesType = dyn_cast<RankedTensorType>(valuesIn.getType());
auto indicesType = dyn_cast<RankedTensorType>(indices.getType());
- auto inputType = dyn_cast<RankedTensorType>(input.getType());
+ auto inputType = dyn_cast<RankedTensorType>(input.getType());
if (!valuesType || !indicesType || !inputType ||
valuesType.getRank() != 3 || inputType.getRank() != 3 ||
indicesType.getRank() != 2)
>From 5d2dcf3af47f71f4e45138291ea8bd3481c485ed Mon Sep 17 00:00:00 2001
From: Ayush3941 <ayushkgaur1 at gmail.com>
Date: Tue, 27 Jan 2026 08:46:38 -0500
Subject: [PATCH 3/4] [MLIR][TOSA] Guard scatter lowering against unranked
operand (v2: check ranked types only, report via notifyMatchFailure)
---
mlir/lib/Conversion/TosaToSCF/TosaToSCF.cpp | 13 ++++++-------
1 file changed, 6 insertions(+), 7 deletions(-)
diff --git a/mlir/lib/Conversion/TosaToSCF/TosaToSCF.cpp b/mlir/lib/Conversion/TosaToSCF/TosaToSCF.cpp
index 6b9186b7c26c4..a7b1ffde8d522 100644
--- a/mlir/lib/Conversion/TosaToSCF/TosaToSCF.cpp
+++ b/mlir/lib/Conversion/TosaToSCF/TosaToSCF.cpp
@@ -101,13 +101,12 @@ class ScatterOpConverter : public OpRewritePattern<tosa::ScatterOp> {
auto input = scatter.getInput();
auto loc = scatter.getLoc();
- auto valuesType = dyn_cast<RankedTensorType>(valuesIn.getType());
- auto indicesType = dyn_cast<RankedTensorType>(indices.getType());
- auto inputType = dyn_cast<RankedTensorType>(input.getType());
- if (!valuesType || !indicesType || !inputType ||
- valuesType.getRank() != 3 || inputType.getRank() != 3 ||
- indicesType.getRank() != 2)
- return failure();
+ if (!isa<RankedTensorType>(valuesIn.getType()) ||
+ !isa<RankedTensorType>(indices.getType()) ||
+ !isa<RankedTensorType>(input.getType())) {
+ return rewriter.notifyMatchFailure(scatter,
+ "expected ranked tensor operands for scatter lowering");
+ }
// N, W, C are chosen to match the TOSA spec
auto dimN = createTensorDim(rewriter, loc, input, 0);
>From 7dfae55145aebdc90ef0caa4cb7b1414ee348b08 Mon Sep 17 00:00:00 2001
From: Ayush3941 <ayushkgaur1 at gmail.com>
Date: Tue, 27 Jan 2026 08:47:08 -0500
Subject: [PATCH 4/4] [MLIR][TOSA] Guard scatter lowering against unranked
operand (v2: fix formatting)
---
mlir/lib/Conversion/TosaToSCF/TosaToSCF.cpp | 8 ++++----
1 file changed, 4 insertions(+), 4 deletions(-)
diff --git a/mlir/lib/Conversion/TosaToSCF/TosaToSCF.cpp b/mlir/lib/Conversion/TosaToSCF/TosaToSCF.cpp
index a7b1ffde8d522..b46026b855b90 100644
--- a/mlir/lib/Conversion/TosaToSCF/TosaToSCF.cpp
+++ b/mlir/lib/Conversion/TosaToSCF/TosaToSCF.cpp
@@ -102,10 +102,10 @@ class ScatterOpConverter : public OpRewritePattern<tosa::ScatterOp> {
auto loc = scatter.getLoc();
if (!isa<RankedTensorType>(valuesIn.getType()) ||
- !isa<RankedTensorType>(indices.getType()) ||
- !isa<RankedTensorType>(input.getType())) {
- return rewriter.notifyMatchFailure(scatter,
- "expected ranked tensor operands for scatter lowering");
+ !isa<RankedTensorType>(indices.getType()) ||
+ !isa<RankedTensorType>(input.getType())) {
+ return rewriter.notifyMatchFailure(
+ scatter, "expected ranked tensor operands for scatter lowering");
}
// N, W, C are chosen to match the TOSA spec
More information about the Mlir-commits
mailing list