[Mlir-commits] [mlir] 1072196 - [tosa] Add duplicate indices check for Scatter (#143736)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Jun 13 10:12:28 PDT 2025
Author: Tai Ly
Date: 2025-06-13T18:12:25+01:00
New Revision: 1072196c2737fcf921ad52e9a44c13423789111b
URL: https://github.com/llvm/llvm-project/commit/1072196c2737fcf921ad52e9a44c13423789111b
DIFF: https://github.com/llvm/llvm-project/commit/1072196c2737fcf921ad52e9a44c13423789111b.diff
LOG: [tosa] Add duplicate indices check for Scatter (#143736)
The TOSA scatter operator disallows duplicate indices (per batch).
This patch adds a check to the validation pass for duplicate values
in the scatter operator's constant indices.
Signed-off-by: Tai Ly <tai.ly at arm.com>
Added:
Modified:
mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h
mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp
mlir/test/Dialect/Tosa/invalid.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h b/mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h
index 096510a09e324..6f3b0916a7a60 100644
--- a/mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h
+++ b/mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h
@@ -243,6 +243,11 @@ bool getConstShapeValues(Operation *op,
// returns a small vector of int64_t values that attr contains
SmallVector<int64_t> convertFromIntAttr(const DenseElementsAttr &attr,
const int rank);
+
+// returns true iff the constant indices of a scatter op are unique within
+// each batch, i.e. no value repeats inside any run of indicesShape[-1]
+// consecutive entries of indicesAttr
+bool hasUniqueConstantScatterIndices(ShapedType indicesType,
+ DenseIntElementsAttr indicesAttr);
} // namespace tosa
} // namespace mlir
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
index d33fc902de3a1..229f42d3178b5 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
@@ -1244,10 +1244,36 @@ bool checkErrorIfCondIf(Operation *op) {
return true;
}
+// Error-if check for tosa.scatter. When the indices operand is a constant,
+// its values must pass hasUniqueConstantScatterIndices. Returns true when
+// `op` is not a scatter, when its indices are not constant, or when the check
+// passes; otherwise emits an op error and returns false.
+bool checkErrorIfScatter(Operation *op) {
+  auto scatter = dyn_cast<tosa::ScatterOp>(op);
+  if (!scatter)
+    return true;
+
+  // reports `message` on the op and yields the failing result
+  auto reject = [op](const char *message) {
+    op->emitOpError(message);
+    return false;
+  };
+
+  // only constant indices can be inspected; anything else passes
+  auto indices = scatter.getIndices();
+  DenseIntElementsAttr constIndices;
+  if (!matchPattern(indices, m_Constant(&constIndices)))
+    return true;
+
+  auto const shapedIndices = dyn_cast<ShapedType>(indices.getType());
+  if (!(shapedIndices && shapedIndices.hasRank()))
+    return reject("expect ranked indices tensor");
+
+  if (!hasUniqueConstantScatterIndices(shapedIndices, constIndices))
+    return reject("indices values contain duplicates");
+
+  return true;
+}
+
+// Runs the checkErrorIf* checks on `op` in order; returns failure() as soon
+// as one of them returns false, and success() when all of them pass.
LogicalResult TosaValidation::applyErrorIfCheck(Operation *op) {
if (!checkErrorIfResize(op) || !checkErrorIfMul(op) ||
!checkErrorIfTable(op) || !checkErrorIfRescale(op) ||
- !checkErrorIfPad(op) || !checkErrorIfCondIf(op))
+ !checkErrorIfPad(op) || !checkErrorIfCondIf(op) ||
+ !checkErrorIfScatter(op))
return failure();
return success();
}
diff --git a/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp b/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp
index e1b3be74b50fd..9844abcc34cb1 100644
--- a/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp
+++ b/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp
@@ -213,3 +213,30 @@ mlir::tosa::convertFromIntAttr(const DenseElementsAttr &attr, const int rank) {
}
return {};
}
+
+// Returns true iff no batch of the constant scatter indices contains a
+// repeated value. The values of `indicesAttr` are read as a flat sequence of
+// int32_t and split into consecutive batches of indicesShape[-1] entries;
+// each batch is sorted in a scratch buffer and scanned for equal neighbours.
+//
+// A rank-0 type, or a last dimension that is not a positive size, leaves no
+// batch to compare and is reported as unique rather than indexing the shape
+// out of bounds or taking a remainder by zero.
+bool mlir::tosa::hasUniqueConstantScatterIndices(
+    ShapedType indicesType, DenseIntElementsAttr indicesAttr) {
+  llvm::ArrayRef<int64_t> const indicesShape = indicesType.getShape();
+  if (indicesShape.empty() || indicesShape.back() <= 0)
+    return true;
+  const int64_t lastDimSize = indicesShape.back();
+
+  auto const indicesValues = indicesAttr.getValues<int32_t>();
+  const int64_t numValues = static_cast<int64_t>(indicesValues.size());
+  assert(
+      (numValues % lastDimSize == 0) &&
+      "Constant indices data length should be a multiple of indicesShape[-1]");
+
+  // The scratch buffer uses the same element type as the values read above,
+  // so negative indices are compared as-is.
+  std::vector<int32_t> batch(static_cast<size_t>(lastDimSize));
+
+  // Iterate by batch count so that a trailing partial batch is never read,
+  // even in builds where the assert above is compiled out.
+  auto batchBegin = indicesValues.begin();
+  for (int64_t remaining = numValues / lastDimSize; remaining > 0;
+       --remaining, batchBegin += lastDimSize) {
+    std::copy(batchBegin, batchBegin + lastDimSize, batch.begin());
+    std::sort(batch.begin(), batch.end());
+    if (std::adjacent_find(batch.begin(), batch.end()) != batch.end()) {
+      // found duplicate values in indices in batch
+      return false;
+    }
+  }
+
+  return true;
+}
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index a4617fc6fba8b..805522799a6d8 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -2015,3 +2015,13 @@ func.func @test_rescale_output_unsigned(%arg0: tensor<1x1xi8>) -> (tensor<1x1xui
%r = tosa.rescale %arg0, %1, %0, %3, %2 {input_unsigned = false, output_unsigned = true, per_channel = false, rounding_mode = "SINGLE_ROUND", scale32 = true} : (tensor<1x1xi8>, tensor<1xi32>, tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x1xui8>
return %r : tensor<1x1xui8>
}
+
+// -----
+
+// The second batch of the constant indices repeats the value 3, so the
+// scatter below is expected to be diagnosed.
+// CHECK-LABEL: test_scatter_duplicate_indices
+func.func @test_scatter_duplicate_indices(%arg0: tensor<2x52x3xf32>, %arg2: tensor<2x12x3xf32>) -> tensor<2x52x3xf32> {
+  %indices = "tosa.const"() { values = dense<[[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], [1, 2, 3, 4, 5, 6, 7, 8, 9, 3, 11, 12]]> : tensor<2x12xi32> } : () -> tensor<2x12xi32>
+  // expected-error@+1 {{'tosa.scatter' op indices values contain duplicates}}
+  %0 = tosa.scatter %arg0, %indices, %arg2 : (tensor<2x52x3xf32>, tensor<2x12xi32>, tensor<2x12x3xf32>) -> tensor<2x52x3xf32>
+  return %0 : tensor<2x52x3xf32>
+}
More information about the Mlir-commits
mailing list