[Mlir-commits] [mlir] 1072196 - [tosa] Add duplicate indices check for Scatter (#143736)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Jun 13 10:12:28 PDT 2025


Author: Tai Ly
Date: 2025-06-13T18:12:25+01:00
New Revision: 1072196c2737fcf921ad52e9a44c13423789111b

URL: https://github.com/llvm/llvm-project/commit/1072196c2737fcf921ad52e9a44c13423789111b
DIFF: https://github.com/llvm/llvm-project/commit/1072196c2737fcf921ad52e9a44c13423789111b.diff

LOG: [tosa] Add duplicate indices check for Scatter (#143736)

The TOSA scatter operator disallows duplicate indices (per batch).
This patch adds a check to the validation pass for duplicate values
in the scatter operator's constant indices.

Signed-off-by: Tai Ly <tai.ly at arm.com>

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h
    mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
    mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp
    mlir/test/Dialect/Tosa/invalid.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h b/mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h
index 096510a09e324..6f3b0916a7a60 100644
--- a/mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h
+++ b/mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h
@@ -243,6 +243,11 @@ bool getConstShapeValues(Operation *op,
 // returns a small vector of int64_t values that attr contains
 SmallVector<int64_t> convertFromIntAttr(const DenseElementsAttr &attr,
                                         const int rank);
+
+// returns true iff the constant scatter indices contain no duplicate value
+// within any batch (a run of indicesType.getShape().back() consecutive values)
+bool hasUniqueConstantScatterIndices(ShapedType indicesType,
+                                     DenseIntElementsAttr indicesAttr);
 } // namespace tosa
 } // namespace mlir
 

diff  --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
index d33fc902de3a1..229f42d3178b5 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
@@ -1244,10 +1244,36 @@ bool checkErrorIfCondIf(Operation *op) {
   return true;
 }
 
+// ERROR_IF check for tosa.scatter: when the indices operand is a
+// compile-time constant, no index value may repeat within a batch. Ops that
+// are not scatters, and scatters with non-constant indices, return true.
+bool checkErrorIfScatter(Operation *op) {
+  auto scatter = dyn_cast<tosa::ScatterOp>(op);
+  if (!scatter)
+    return true;
+
+  // Only constant indices can be inspected at validation time.
+  DenseIntElementsAttr constIndices;
+  if (!matchPattern(scatter.getIndices(), m_Constant(&constIndices)))
+    return true;
+
+  auto const shapedTy = dyn_cast<ShapedType>(scatter.getIndices().getType());
+  if (!shapedTy || !shapedTy.hasRank()) {
+    op->emitOpError("expect ranked indices tensor");
+    return false;
+  }
+
+  bool const unique = hasUniqueConstantScatterIndices(shapedTy, constIndices);
+  if (!unique)
+    op->emitOpError("indices values contain duplicates");
+  return unique;
+}
+
 LogicalResult TosaValidation::applyErrorIfCheck(Operation *op) {
   if (!checkErrorIfResize(op) || !checkErrorIfMul(op) ||
       !checkErrorIfTable(op) || !checkErrorIfRescale(op) ||
-      !checkErrorIfPad(op) || !checkErrorIfCondIf(op))
+      !checkErrorIfPad(op) || !checkErrorIfCondIf(op) ||
+      !checkErrorIfScatter(op)) // rejects duplicate constant scatter indices
     return failure();
   return success();
 }

diff  --git a/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp b/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp
index e1b3be74b50fd..9844abcc34cb1 100644
--- a/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp
+++ b/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp
@@ -213,3 +213,30 @@ mlir::tosa::convertFromIntAttr(const DenseElementsAttr &attr, const int rank) {
   }
   return {};
 }
+
+// Returns true iff no batch of the constant scatter indices (each contiguous
+// run of indicesShape.back() values) contains a repeated value. Rank-0 shapes
+// and dynamic or zero-sized last dimensions are not inspected; they yield true.
+bool mlir::tosa::hasUniqueConstantScatterIndices(
+    ShapedType indicesType, DenseIntElementsAttr indicesAttr) {
+  llvm::ArrayRef<int64_t> const indicesShape = indicesType.getShape();
+  if (indicesShape.empty() || ShapedType::isDynamic(indicesShape.back()) ||
+      indicesShape.back() == 0)
+    return true; // nothing to split; also avoids a modulo by zero below
+  const int64_t lastDimSize = indicesShape.back();
+
+  auto const indicesValues = indicesAttr.getValues<int32_t>();
+  const int64_t numValues = indicesValues.size();
+  assert(
+      (numValues % lastDimSize == 0) &&
+      "Constant indices data length should be a multiple of indicesShape[-1]");
+  // Sort each batch in a reused scratch buffer, then look for equal neighbours.
+  std::vector<int32_t> batch(lastDimSize);
+  for (int64_t start = 0; start < numValues; start += lastDimSize) {
+    std::copy_n(indicesValues.begin() + start, lastDimSize, batch.begin());
+    std::sort(batch.begin(), batch.end());
+    if (std::adjacent_find(batch.begin(), batch.end()) != batch.end())
+      return false; // found duplicate values in indices in batch
+  }
+  return true;
+}

diff  --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index a4617fc6fba8b..805522799a6d8 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -2015,3 +2015,13 @@ func.func @test_rescale_output_unsigned(%arg0: tensor<1x1xi8>) -> (tensor<1x1xui
   %r = tosa.rescale %arg0, %1, %0, %3, %2 {input_unsigned = false, output_unsigned = true, per_channel = false, rounding_mode = "SINGLE_ROUND", scale32 = true} : (tensor<1x1xi8>, tensor<1xi32>, tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x1xui8>
   return %r : tensor<1x1xui8>
 }
+
+// -----
+
+// CHECK-LABEL: test_scatter_duplicate_indices
+func.func @test_scatter_duplicate_indices(%arg0: tensor<2x52x3xf32>, %arg2: tensor<2x12x3xf32>) -> tensor<2x52x3xf32> {
+  %indices = "tosa.const"() { values = dense<[[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], [1, 2, 3, 4, 5, 6, 7, 8, 9, 3, 11, 12]]> : tensor<2x12xi32> } : () -> tensor<2x12xi32>
+  // expected-error@+1 {{'tosa.scatter' op indices values contain duplicates}}
+  %0 = tosa.scatter %arg0, %indices, %arg2 : (tensor<2x52x3xf32>, tensor<2x12xi32>, tensor<2x12x3xf32>) -> tensor<2x52x3xf32>
+  return %0 : tensor<2x52x3xf32>
+}


        


More information about the Mlir-commits mailing list