[llvm-branch-commits] [mlir] c0f3ea8 - [mlir][Python] Add checking process before creating an AffineMap from a permutation.

via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Tue Jan 12 17:36:50 PST 2021


Author: zhanghb97
Date: 2021-01-13T09:32:32+08:00
New Revision: c0f3ea8a08ca9a9ec473f6e9072ccf30dad5def8

URL: https://github.com/llvm/llvm-project/commit/c0f3ea8a08ca9a9ec473f6e9072ccf30dad5def8
DIFF: https://github.com/llvm/llvm-project/commit/c0f3ea8a08ca9a9ec473f6e9072ccf30dad5def8.diff

LOG: [mlir][Python] Add checking process before creating an AffineMap from a permutation.

An invalid permutation will trigger a C++ assertion when attempting to create an AffineMap from the permutation.
This patch adds an `isPermutation` function to check the given permutation before creating the AffineMap.

Reviewed By: mehdi_amini

Differential Revision: https://reviews.llvm.org/D94492

Added: 
    

Modified: 
    mlir/lib/Bindings/Python/IRModules.cpp
    mlir/test/Bindings/Python/ir_affine_map.py

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Bindings/Python/IRModules.cpp b/mlir/lib/Bindings/Python/IRModules.cpp
index 218099bedc6f..493ea5c1e47a 100644
--- a/mlir/lib/Bindings/Python/IRModules.cpp
+++ b/mlir/lib/Bindings/Python/IRModules.cpp
@@ -153,6 +153,21 @@ static MlirStringRef toMlirStringRef(const std::string &s) {
   return mlirStringRefCreate(s.data(), s.size());
 }
 
+template <typename PermutationTy>
+static bool isPermutation(std::vector<PermutationTy> permutation) {
+  llvm::SmallVector<bool, 8> seen(permutation.size(), false);
+  for (auto val : permutation) {
+    if (val < permutation.size()) {
+      if (seen[val])
+        return false;
+      seen[val] = true;
+      continue;
+    }
+    return false;
+  }
+  return true;
+}
+
 //------------------------------------------------------------------------------
 // Collections.
 //------------------------------------------------------------------------------
@@ -3914,6 +3929,9 @@ void mlir::python::populateIRSubmodule(py::module &m) {
           "get_permutation",
           [](std::vector<unsigned> permutation,
              DefaultingPyMlirContext context) {
+            if (!isPermutation(permutation))
+              throw py::cast_error("Invalid permutation when attempting to "
+                                   "create an AffineMap");
             MlirAffineMap affineMap = mlirAffineMapPermutationGet(
                 context->get(), permutation.size(), permutation.data());
             return PyAffineMap(context->getRef(), affineMap);

diff  --git a/mlir/test/Bindings/Python/ir_affine_map.py b/mlir/test/Bindings/Python/ir_affine_map.py
index fe37eb971555..0c99722dbf04 100644
--- a/mlir/test/Bindings/Python/ir_affine_map.py
+++ b/mlir/test/Bindings/Python/ir_affine_map.py
@@ -73,6 +73,12 @@ def testAffineMapGet():
       # CHECK: Invalid expression (None?) when attempting to create an AffineMap
       print(e)
 
+    try:
+      AffineMap.get_permutation([1, 0, 1])
+    except RuntimeError as e:
+      # CHECK: Invalid permutation when attempting to create an AffineMap
+      print(e)
+
     try:
       map3.get_submap([42])
     except ValueError as e:

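The `isPermutation` check added to IRModules.cpp accepts a vector of length n only if every entry lies in [0, n) and no entry repeats. An input such as `[1, 0, 1]`, which the new test uses, fails because index 1 appears twice. The sketch below is a standalone version of that check under a few assumptions that are not part of the patch. It uses `std::vector<bool>` in place of `llvm::SmallVector<bool, 8>` so it builds without LLVM. It takes the input by const reference instead of by value. It also rejects negative entries explicitly when the element type is signed (C++17 `if constexpr`). The binding itself only ever instantiates the template with `unsigned`.

```cpp
#include <cstddef>
#include <type_traits>
#include <vector>

// Standalone sketch of the permutation check (not the code from the patch):
// returns true iff `permutation` contains each of 0..n-1 exactly once.
template <typename PermutationTy>
static bool isPermutation(const std::vector<PermutationTy> &permutation) {
  // Tracks which target positions have already been used.
  std::vector<bool> seen(permutation.size(), false);
  for (PermutationTy val : permutation) {
    // Assumption beyond the patch: for signed element types, reject negative
    // entries before converting to an unsigned index.
    if constexpr (std::is_signed<PermutationTy>::value) {
      if (val < 0)
        return false;
    }
    std::size_t idx = static_cast<std::size_t>(val);
    // An entry outside [0, n) can never belong to a permutation of size n.
    if (idx >= permutation.size())
      return false;
    // A repeated entry means some other position is never used.
    if (seen[idx])
      return false;
    seen[idx] = true;
  }
  return true;
}
```

With this sketch, `{2, 0, 1}` is accepted. `{1, 0, 1}` (repeated entry), `{0, 3, 1}` (out of range) and `{0, -1, 1}` (negative, signed input) are all rejected. Running the check first in the binding lets an invalid permutation surface as a Python exception instead of reaching the C++ assertion inside `mlirAffineMapPermutationGet`.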
