[Mlir-commits] [mlir] [mlir][tosa] Fix scatter duplicate indices check for int64 (PR #168085)

Luke Hutton llvmlistbot at llvm.org
Fri Nov 14 08:40:49 PST 2025


https://github.com/lhutton1 created https://github.com/llvm/llvm-project/pull/168085

This commit fixes the validation check for duplicate indices in the TOSA scatter operation when using int64 index tensors. Previously, use of int64 index tensors would cause a crash.

From 92b2303105e13e7a6bcfe804ee0fcc44577c56a7 Mon Sep 17 00:00:00 2001
From: Luke Hutton <luke.hutton at arm.com>
Date: Fri, 14 Nov 2025 14:57:04 +0000
Subject: [PATCH] [mlir][tosa] Fix scatter duplicate indices check for int64

This commit fixes the validation check for duplicate indices
in the TOSA scatter operation when using int64 index tensors.
Previously, use of int64 index tensors would cause a crash.

Change-Id: Ib234ad655d382863cc1fcb31877190d0d20d455e
---
 mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp      |  9 +++++----
 mlir/test/Dialect/Tosa/invalid.mlir                  | 12 +++++++++++-
 .../Tosa/tosa-validation-version-1p1-valid.mlir      | 10 ++++++++++
 3 files changed, 26 insertions(+), 5 deletions(-)

diff --git a/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp b/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp
index ac5d6207259eb..62c015a85ee36 100644
--- a/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp
+++ b/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp
@@ -216,22 +216,23 @@ mlir::tosa::convertFromIntAttr(const DenseElementsAttr &attr, const int rank) {
 
 bool mlir::tosa::hasUniqueConstantScatterIndices(
     ShapedType indicesType, DenseIntElementsAttr indicesAttr) {
-  llvm::ArrayRef<int64_t> const indicesShape = indicesType.getShape();
+  const llvm::ArrayRef<int64_t> indicesShape = indicesType.getShape();
   const unsigned int indicesRank = indicesShape.size();
   const unsigned int lastDimSize = indicesShape[indicesRank - 1];
 
   // check each batch of indices from the flat indicesAttr values
   // for duplicates
-  auto const indicesValues = indicesAttr.getValues<int32_t>();
+  auto const indicesValues = indicesAttr.getValues<APInt>();
   assert(
       (indicesValues.size() % lastDimSize == 0) &&
       "Constant indices data length should be a multiple of indicesShape[-1]");
 
-  std::vector<uint64_t> indices(lastDimSize);
+  std::vector<APInt> indices(lastDimSize);
   for (auto beg = indicesValues.begin(); beg < indicesValues.end();
        beg += lastDimSize) {
     std::copy(beg, beg + lastDimSize, indices.begin());
-    std::sort(indices.begin(), indices.end());
+    std::sort(indices.begin(), indices.end(),
+              [](const APInt &a, const APInt &b) { return a.slt(b); });
     if (std::adjacent_find(indices.begin(), indices.end()) != indices.end()) {
       // found duplicate values in indices in batch
       return false;
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index c9e03ca53a729..3d24928487ed2 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -4,7 +4,7 @@
 // validation flow.
 //--------------------------------------------------------------------------------------------------
 
-// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-attach-target="profiles=pro_int,pro_fp extensions=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,doubleround,inexactround" -tosa-validate="strict-op-spec-alignment"
+// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-attach-target="specification_version=1.1.draft profiles=pro_int,pro_fp extensions=int16,int4,int64,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,doubleround,inexactround" -tosa-validate="strict-op-spec-alignment"
 
 
 func.func @test_cast(%arg0: tensor<i1>) -> tensor<5xi32> {
@@ -2044,6 +2044,16 @@ func.func @test_scatter_duplicate_indices(%arg0: tensor<2x52x3xf32>, %arg2: tens
 
 // -----
 
+// CHECK-LABEL: test_scatter_duplicate_indices_int64
+func.func @test_scatter_duplicate_indices_int64(%arg0: tensor<2x52x3xf32>, %arg2: tensor<2x12x3xf32>) -> tensor<2x52x3xf32> {
+  %indices = "tosa.const"() { values = dense<[[1, 2, 3, 4, 5, 6, 7, 8, 9, 3, 11, 12], [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]]> : tensor<2x12xi64> } : () -> tensor<2x12xi64>
+  // expected-error@+1 {{'tosa.scatter' op indices values contain duplicates}}
+  %0 = tosa.scatter %arg0, %indices, %arg2 : (tensor<2x52x3xf32>, tensor<2x12xi64>, tensor<2x12x3xf32>) -> tensor<2x52x3xf32>
+  return %0 : tensor<2x52x3xf32>
+}
+
+// -----
+
 func.func @test_reduce_all_unsupported_data_types(%arg0: tensor<2x12x11xf32>) -> tensor<1x12x11xf32> {
   // expected-error@+1 {{'tosa.reduce_all' op illegal: operation operand/result data types did not align with any profile or extension, got (f32,f32), did you mean (i1,i1)?}}
   %0 = tosa.reduce_all %arg0 {axis = 0 : i32} : (tensor<2x12x11xf32>) -> tensor<1x12x11xf32>
diff --git a/mlir/test/Dialect/Tosa/tosa-validation-version-1p1-valid.mlir b/mlir/test/Dialect/Tosa/tosa-validation-version-1p1-valid.mlir
index acbff73b8b948..c285ae3cf44ee 100644
--- a/mlir/test/Dialect/Tosa/tosa-validation-version-1p1-valid.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-validation-version-1p1-valid.mlir
@@ -2,6 +2,7 @@
 
 // -----
 
+// CHECK-LABEL: test_matmul_fp8_mixed_precision_operands
 func.func @test_matmul_fp8_mixed_precision_operands(%arg0: tensor<1x14x19xf8E4M3FN>, %arg1: tensor<1x19x28xf8E5M2>) -> tensor<1x14x28xf16> {
   %azp0 = "tosa.const"() <{values = dense<0.0> : tensor<1xf8E4M3FN>}> : () -> tensor<1xf8E4M3FN>
   %bzp0 = "tosa.const"() <{values = dense<0.0> : tensor<1xf8E5M2>}> : () -> tensor<1xf8E5M2>
@@ -146,3 +147,12 @@ func.func @test_argmax_bf16_i64(%arg0: tensor<12x8x16xbf16>) -> tensor<12x16xi64
   %0 = tosa.argmax %arg0 { axis = 1 : i32 } : (tensor<12x8x16xbf16>) -> tensor<12x16xi64>
   return %0 : tensor<12x16xi64>
 }
+
+// -----
+
+// CHECK-LABEL: test_scatter_const_indices_int64
+func.func @test_scatter_const_indices_int64(%arg0: tensor<2x52x3xf32>, %arg2: tensor<2x12x3xf32>) -> tensor<2x52x3xf32> {
+  %indices = "tosa.const"() { values = dense<[[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]]> : tensor<2x12xi64> } : () -> tensor<2x12xi64>
+  %0 = tosa.scatter %arg0, %indices, %arg2 : (tensor<2x52x3xf32>, tensor<2x12xi64>, tensor<2x12x3xf32>) -> tensor<2x52x3xf32>
+  return %0 : tensor<2x52x3xf32>
+}



More information about the Mlir-commits mailing list