[Mlir-commits] [mlir] [mlir][Transforms] Add missing check in applyPermutation (PR #102099)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Aug 6 20:23:38 PDT 2024


https://github.com/DarshanRamakant updated https://github.com/llvm/llvm-project/pull/102099

>From 721105ef6e72e4e6660acaa8c05715e485bf84cf Mon Sep 17 00:00:00 2001
From: Darshan Bhat <darshanbhatsirsi at gmail.com>
Date: Mon, 5 Aug 2024 20:12:32 +0530
Subject: [PATCH] [mlir][Transforms] Add missing check in applyPermutation

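The applyPermutation() utility should verify that every
permutation index is within the bounds of the input array.
Otherwise an out-of-bounds access triggers a cryptic
assertion later. Also add the corresponding check to the
tosa.transpose verifier so that an out-of-range constant
permutation is reported as a diagnostic instead.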
---
 mlir/include/mlir/Dialect/Utils/IndexingUtils.h       |  3 +++
 mlir/lib/Dialect/Tosa/IR/TosaOps.cpp                  |  8 ++++++++
 .../TosaToLinalg/tosa-to-linalg-invalid.mlir          | 11 +++++++++++
 3 files changed, 22 insertions(+)

diff --git a/mlir/include/mlir/Dialect/Utils/IndexingUtils.h b/mlir/include/mlir/Dialect/Utils/IndexingUtils.h
index 7849782e5442b..99218f491ddef 100644
--- a/mlir/include/mlir/Dialect/Utils/IndexingUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/IndexingUtils.h
@@ -202,6 +202,9 @@ SmallVector<T> applyPermutation(ArrayRef<T> input,
                                 ArrayRef<int64_t> permutation) {
   assert(input.size() == permutation.size() &&
          "expected input rank to equal permutation rank");
+  assert(
+      llvm::all_of(permutation, [&](size_t s) { return s < input.size(); }) &&
+      "permutation must be within input bounds");
   auto permutationRange = llvm::map_range(
       llvm::seq<unsigned>(0, input.size()),
       [&](int64_t idx) -> T { return input[permutation[idx]]; });
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 4337787e4aead..6890b5cab5ede 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -28,7 +28,9 @@
 #include "mlir/Transforms/InliningUtils.h"
 #include "llvm/ADT/APFloat.h"
 #include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/TypeSwitch.h"
+#include <sys/_types/_int64_t.h>
 
 using namespace mlir;
 using namespace mlir::tosa;
@@ -1116,6 +1118,12 @@ LogicalResult tosa::TransposeOp::verify() {
            "Unexpectedly found permutation tensor without rank");
     if (!isPermutationVector(constantPerms))
       return emitOpError() << "expected valid permutation tensor";
+
+    if (inputType.hasRank() && (!inputType.getNumDynamicDims()) &&
+        !llvm::all_of(constantPerms,
+                      [&](int64_t s) { return s < inputType.getRank(); })) {
+      return emitOpError() << "permutation must be within input bounds";
+    }
   }
   return success();
 }
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir
index b78577275a52a..cd42a6d2b73d3 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir
@@ -36,3 +36,14 @@ func.func @rfft2d_with_non_float_type(%arg0 : tensor<1x1x1xi32>) -> (tensor<1x1x
   %real, %imag = tosa.rfft2d %arg0 : (tensor<1x1x1xi32>) -> (tensor<1x1x1xi32>, tensor<1x1x1xi32>)
   return %real, %imag : tensor<1x1x1xi32>, tensor<1x1x1xi32>
 }
+
+// -----
+
+// CHECK-LABEL: @test_invalid_constant_permutation
+func.func @test_invalid_constant_permutation() {
+  // expected-error@+3 {{permutation must be within input bounds}}
+  %input = tensor.empty() : tensor<3x4x5xi32>
+  %perms = arith.constant dense<[3, 0, 1]> : tensor<3xi32>
+  %0 = tosa.transpose %input, %perms : (tensor<3x4x5xi32>, tensor<3xi32>) -> tensor<3x4x5xi32>
+  return
+}



More information about the Mlir-commits mailing list