[Mlir-commits] [mlir] [mlir][Transforms] Add missing check in applyPermutation (PR #102099)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Aug 9 07:56:02 PDT 2024
https://github.com/DarshanRamakant updated https://github.com/llvm/llvm-project/pull/102099
>From 56f7539bf1427cccc70cfd6fb2ff5fff9a2bb6e5 Mon Sep 17 00:00:00 2001
From: Darshan Bhat <darshanbhatsirsi at gmail.com>
Date: Mon, 5 Aug 2024 20:12:32 +0530
Subject: [PATCH] [mlir][Transforms] Add missing check in applyPermutation
The applyPermutation() utility should make sure
that every permutation index is within the bounds
of the input array. Otherwise, it will trigger a
cryptic array out-of-bounds assertion later.
---
mlir/include/mlir/Dialect/Utils/IndexingUtils.h | 3 +++
mlir/lib/Dialect/Tosa/IR/TosaOps.cpp | 6 ++++++
mlir/test/Dialect/Tosa/invalid.mlir | 11 +++++++++++
3 files changed, 20 insertions(+)
diff --git a/mlir/include/mlir/Dialect/Utils/IndexingUtils.h b/mlir/include/mlir/Dialect/Utils/IndexingUtils.h
index 7849782e5442bd..99218f491ddef4 100644
--- a/mlir/include/mlir/Dialect/Utils/IndexingUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/IndexingUtils.h
@@ -202,6 +202,9 @@ SmallVector<T> applyPermutation(ArrayRef<T> input,
ArrayRef<int64_t> permutation) {
assert(input.size() == permutation.size() &&
"expected input rank to equal permutation rank");
+ assert(
+ llvm::all_of(permutation, [&](size_t s) { return s < input.size(); }) &&
+ "permutation must be within input bounds");
auto permutationRange = llvm::map_range(
llvm::seq<unsigned>(0, input.size()),
[&](int64_t idx) -> T { return input[permutation[idx]]; });
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 4337787e4aeadd..b8b20cef965c94 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -1116,6 +1116,12 @@ LogicalResult tosa::TransposeOp::verify() {
"Unexpectedly found permutation tensor without rank");
if (!isPermutationVector(constantPerms))
return emitOpError() << "expected valid permutation tensor";
+
+ if (inputType.hasRank() && (!inputType.getNumDynamicDims()) &&
+ !llvm::all_of(constantPerms,
+ [&](int64_t s) { return s < inputType.getRank(); })) {
+ return emitOpError() << "permutation must be within input bounds";
+ }
}
return success();
}
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index cb38d4d81ca2ee..79a86ddb32cc3c 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -413,3 +413,14 @@ func.func @test_tile_invalid_multiples() {
%1 = tosa.tile %0 {multiples = array<i64>} : (tensor<4x31x31xf32>) -> tensor<4x31x31xf32>
return
}
+
+// -----
+
+// CHECK-LABEL: @test_invalid_constant_permutation
+func.func @test_invalid_constant_permutation() {
+ // expected-error@+3 {{permutation must be within input bounds}}
+ %0 = tensor.empty() : tensor<3x4x5xi32>
+ %1 = arith.constant dense<[3, 0, 1]> : tensor<3xi32>
+ %2 = tosa.transpose %0, %1 : (tensor<3x4x5xi32>, tensor<3xi32>) -> tensor<3x4x5xi32>
+ return
+}
\ No newline at end of file
More information about the Mlir-commits
mailing list