[Mlir-commits] [mlir] 2d1362e - Add Location, Region and Block to MLIR Python bindings.

Stella Laurenzo llvmlistbot at llvm.org
Fri Aug 28 15:29:45 PDT 2020


Author: Stella Laurenzo
Date: 2020-08-28T15:26:05-07:00
New Revision: 2d1362e09af2e35c93aca59852211a735d865a54

URL: https://github.com/llvm/llvm-project/commit/2d1362e09af2e35c93aca59852211a735d865a54
DIFF: https://github.com/llvm/llvm-project/commit/2d1362e09af2e35c93aca59852211a735d865a54.diff

LOG: Add Location, Region and Block to MLIR Python bindings.

* This is just enough to create regions/blocks and iterate over them.
* Does not yet implement the preferred iteration strategy (Python pseudo-containers).
* Refinements need to come after doing basic mappings of operations and values so that the whole hierarchy can be used.

Differential Revision: https://reviews.llvm.org/D86683

Added: 
    mlir/test/Bindings/Python/ir_location.py
    mlir/test/Bindings/Python/ir_operation.py

Modified: 
    mlir/include/mlir-c/IR.h
    mlir/lib/Bindings/Python/IRModules.cpp
    mlir/lib/Bindings/Python/IRModules.h
    mlir/lib/Bindings/Python/PybindUtils.h
    mlir/lib/CAPI/IR/IR.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h
index 225b53166306..340f8c5d78ff 100644
--- a/mlir/include/mlir-c/IR.h
+++ b/mlir/include/mlir-c/IR.h
@@ -85,6 +85,9 @@ typedef void (*MlirStringCallback)(const char *, intptr_t, void *);
 /** Creates an MLIR context and transfers its ownership to the caller. */
 MlirContext mlirContextCreate();
 
+/** Checks if two handles refer to the same context; non-zero if they do. */
+int mlirContextEqual(MlirContext ctx1, MlirContext ctx2);
+
 /** Takes an MLIR context owned by the caller and destroys it. */
 void mlirContextDestroy(MlirContext context);
 
@@ -315,6 +318,9 @@ void mlirValuePrint(MlirValue value, MlirStringCallback callback,
 /** Parses a type. The type is owned by the context. */
 MlirType mlirTypeParseGet(MlirContext context, const char *type);
 
+/** Gets the context that a type was created in (and is owned by). */
+MlirContext mlirTypeGetContext(MlirType type);
+
 /** Checks whether a type is null. */
 inline int mlirTypeIsNull(MlirType type) { return !type.ptr; }
 

diff  --git a/mlir/lib/Bindings/Python/IRModules.cpp b/mlir/lib/Bindings/Python/IRModules.cpp
index 2f5735f83975..19da019149f9 100644
--- a/mlir/lib/Bindings/Python/IRModules.cpp
+++ b/mlir/lib/Bindings/Python/IRModules.cpp
@@ -28,13 +28,59 @@ Returns a new MlirModule or raises a ValueError if the parsing fails.
 See also: https://mlir.llvm.org/docs/LangRef/
 )";
 
-static const char kContextParseType[] = R"(Parses the assembly form of a type.
+static const char kContextParseTypeDocstring[] =
+    R"(Parses the assembly form of a type.
 
 Returns a Type object or raises a ValueError if the type cannot be parsed.
 
 See also: https://mlir.llvm.org/docs/LangRef/#type-system
 )";
 
+static const char kContextGetUnknownLocationDocstring[] =
+    R"(Gets a Location representing an unknown location)";
+
+static const char kContextGetFileLocationDocstring[] =
+    R"(Gets a Location representing a file, line and column)";
+
+static const char kContextCreateBlockDocstring[] =
+    R"(Creates a detached block)";
+
+static const char kContextCreateRegionDocstring[] =
+    R"(Creates a detached region)";
+
+static const char kRegionAppendBlockDocstring[] =
+    R"(Appends a block to a region.
+
+Raises:
+  ValueError: If the block is already attached to another region.
+)";
+
+static const char kRegionInsertBlockDocstring[] =
+    R"(Inserts a block at a position in a region.
+
+Raises:
+  ValueError: If the block is already attached to another region.
+)";
+
+static const char kRegionFirstBlockDocstring[] =
+    R"(Gets the first block in a region.
+
+Blocks can also be accessed via the `blocks` container.
+
+Raises:
+  IndexError: If the region has no blocks.
+)";
+
+static const char kBlockNextInRegionDocstring[] =
+    R"(Gets the next block in the enclosing region.
+
+Blocks can also be accessed via the `blocks` container of the owning region.
+This method exists to mirror the lower level API and should not be preferred.
+
+Raises:
+  IndexError: If there are no further blocks.
+)";
+
 static const char kOperationStrDunderDocstring[] =
     R"(Prints the assembly form of the operation with default options.
 
@@ -106,6 +152,24 @@ struct PySinglePartStringAccumulator {
 
 } // namespace
 
+//------------------------------------------------------------------------------
+// PyBlock and PyRegion: detached -> attached ownership transitions.
+//------------------------------------------------------------------------------
+
+void PyRegion::attachToParent() {
+  if (!detached) {
+    throw SetPyError(PyExc_ValueError, "Region is already attached to an op");
+  }
+  detached = false;
+}
+
+void PyBlock::attachToParent() {
+  if (!detached) {
+    throw SetPyError(PyExc_ValueError, "Block is already attached to a region");
+  }
+  detached = false;
+}
+
 //------------------------------------------------------------------------------
 // PyAttribute.
 //------------------------------------------------------------------------------
@@ -454,7 +518,59 @@ void mlir::python::populateIRSubmodule(py::module &m) {
             }
             return PyType(type);
           },
-          py::keep_alive<0, 1>(), kContextParseType);
+          py::keep_alive<0, 1>(), kContextParseTypeDocstring)
+      .def(
+          "get_unknown_location",
+          [](PyMlirContext &self) {
+            return PyLocation(mlirLocationUnknownGet(self.context));
+          },
+          py::keep_alive<0, 1>(), kContextGetUnknownLocationDocstring)
+      .def(
+          "get_file_location",
+          [](PyMlirContext &self, const std::string &filename, int line, int col) {
+            return PyLocation(mlirLocationFileLineColGet(
+                self.context, filename.c_str(), line, col));
+          },
+          py::keep_alive<0, 1>(), kContextGetFileLocationDocstring,
+          py::arg("filename"), py::arg("line"), py::arg("col"))
+      .def(
+          "create_region",
+          [](PyMlirContext &self) {
+            // The creating context is explicitly captured on regions to
+            // detect illegal assemblies of objects from multiple contexts,
+            // which would invalidate the memory model.
+            return PyRegion(self.context, mlirRegionCreate(),
+                            /*detached=*/true);
+          },
+          py::keep_alive<0, 1>(), kContextCreateRegionDocstring)
+      .def(
+          "create_block",
+          [](PyMlirContext &self, std::vector<PyType> pyTypes) {
+            // In order for the keep_alive to extend the proper lifetime, all
+            // types must be from the same context. The list may be empty.
+            for (const PyType &pyType : pyTypes) {
+              if (!mlirContextEqual(mlirTypeGetContext(pyType.type),
+                                    self.context)) {
+                throw SetPyError(
+                    PyExc_ValueError,
+                    "All types used to construct a block must be from "
+                    "the same context as the block");
+              }
+            }
+            llvm::SmallVector<MlirType, 4> types(pyTypes.begin(),
+                                                 pyTypes.end());
+            return PyBlock(self.context,
+                           mlirBlockCreate(types.size(), types.data()),
+                           /*detached=*/true);
+          },
+          py::keep_alive<0, 1>(), kContextCreateBlockDocstring);
+
+  py::class_<PyLocation>(m, "Location").def("__repr__", [](PyLocation &self) {
+    PyPrintAccumulator printAccum;
+    mlirLocationPrint(self.loc, printAccum.getCallback(),
+                      printAccum.getUserData());
+    return printAccum.join();
+  });
 
   // Mapping of Module
   py::class_<PyModule>(m, "Module")
@@ -475,6 +591,80 @@ void mlir::python::populateIRSubmodule(py::module &m) {
           },
           kOperationStrDunderDocstring);
 
+  // Mapping of PyRegion.
+  py::class_<PyRegion>(m, "Region")
+      .def(
+          "append_block",
+          [](PyRegion &self, PyBlock &block) {
+            if (!mlirContextEqual(self.context, block.context)) {
+              throw SetPyError(
+                  PyExc_ValueError,
+                  "Block must have been created from the same context as "
+                  "this region");
+            }
+            // Once attached, the region owns the block.
+            block.attachToParent();
+            mlirRegionAppendOwnedBlock(self.region, block.block);
+          },
+          // The attached block (argument 2) must keep its region alive.
+          py::keep_alive<2, 1>(), kRegionAppendBlockDocstring)
+      .def(
+          "insert_block",
+          [](PyRegion &self, int pos, PyBlock &block) {
+            if (!mlirContextEqual(self.context, block.context)) {
+              throw SetPyError(
+                  PyExc_ValueError,
+                  "Block must have been created from the same context as "
+                  "this region");
+            }
+            block.attachToParent();
+            // TODO: Make this return a failure and raise if out of bounds.
+            mlirRegionInsertOwnedBlock(self.region, pos, block.block);
+          },
+          // The attached block (argument 3) must keep its region alive.
+          py::keep_alive<3, 1>(), kRegionInsertBlockDocstring)
+      .def_property_readonly(
+          "first_block",
+          // Call policies passed to def_property_readonly are not applied to
+          // the getter, so keep_alive is attached to the cpp_function itself.
+          py::cpp_function(
+              [](PyRegion &self) {
+                MlirBlock block = mlirRegionGetFirstBlock(self.region);
+                if (mlirBlockIsNull(block)) {
+                  throw SetPyError(PyExc_IndexError, "Region has no blocks");
+                }
+                return PyBlock(self.context, block, /*detached=*/false);
+              },
+              py::keep_alive<0, 1>()),
+          kRegionFirstBlockDocstring);
+
+  // Mapping of PyBlock.
+  py::class_<PyBlock>(m, "Block")
+      .def_property_readonly(
+          "next_in_region",
+          // As with first_block, keep_alive must be attached to the getter's
+          // cpp_function to take effect.
+          py::cpp_function(
+              [](PyBlock &self) {
+                MlirBlock block = mlirBlockGetNextInRegion(self.block);
+                if (mlirBlockIsNull(block)) {
+                  throw SetPyError(PyExc_IndexError,
+                                   "Attempt to read past last block");
+                }
+                return PyBlock(self.context, block, /*detached=*/false);
+              },
+              py::keep_alive<0, 1>()),
+          kBlockNextInRegionDocstring)
+      .def(
+          "__str__",
+          [](PyBlock &self) {
+            PyPrintAccumulator printAccum;
+            mlirBlockPrint(self.block, printAccum.getCallback(),
+                           printAccum.getUserData());
+            return printAccum.join();
+          },
+          "Prints the assembly form of the block.");
+
   // Mapping of Type.
   py::class_<PyAttribute>(m, "Attribute")
       .def(

diff  --git a/mlir/lib/Bindings/Python/IRModules.h b/mlir/lib/Bindings/Python/IRModules.h
index 1edfc1cead3e..20fe8014e138 100644
--- a/mlir/lib/Bindings/Python/IRModules.h
+++ b/mlir/lib/Bindings/Python/IRModules.h
@@ -28,6 +28,13 @@ class PyMlirContext {
   MlirContext context;
 };
 
+/// Non-owning value wrapper around an MlirLocation (it has no destructor).
+class PyLocation {
+public:
+  PyLocation(MlirLocation loc) : loc(loc) {}
+  MlirLocation loc;
+};
+
 /// Wrapper around MlirModule.
 class PyModule {
 public:
@@ -45,6 +52,72 @@ class PyModule {
   MlirModule module;
 };
 
+/// Wrapper around an MlirRegion.
+/// Note that a region can exist in a detached state (where this instance is
+/// responsible for destroying it) or an attached state (where its owner is
+/// responsible).
+///
+/// This Python wrapper retains a redundant reference to its creating context
+/// in order to facilitate checking that parts of the operation hierarchy
+/// are only assembled from the same context.
+class PyRegion {
+public:
+  PyRegion(MlirContext context, MlirRegion region, bool detached)
+      : context(context), region(region), detached(detached) {}
+  PyRegion(PyRegion &&other) noexcept
+      : context(other.context), region(other.region), detached(other.detached) {
+    other.detached = false; // The moved-from wrapper no longer owns it.
+  }
+  ~PyRegion() {
+    if (detached)
+      mlirRegionDestroy(region);
+  }
+
+  // Call prior to attaching the region to a parent.
+  // This will transition to the attached state and will throw an exception
+  // if already attached.
+  void attachToParent();
+
+  MlirContext context;
+  MlirRegion region;
+
+private:
+  bool detached;
+};
+
+/// Wrapper around an MlirBlock.
+/// Note that blocks can exist in a detached state (where this instance is
+/// responsible for destroying it) or an attached state (where its owner is
+/// responsible).
+///
+/// This Python wrapper retains a redundant reference to its creating context
+/// in order to facilitate checking that parts of the operation hierarchy
+/// are only assembled from the same context.
+class PyBlock {
+public:
+  PyBlock(MlirContext context, MlirBlock block, bool detached)
+      : context(context), block(block), detached(detached) {}
+  PyBlock(PyBlock &&other) noexcept
+      : context(other.context), block(other.block), detached(other.detached) {
+    other.detached = false; // The moved-from wrapper no longer owns it.
+  }
+  ~PyBlock() {
+    if (detached)
+      mlirBlockDestroy(block);
+  }
+
+  // Call prior to attaching the block to a parent.
+  // This will transition to the attached state and will throw an exception
+  // if already attached.
+  void attachToParent();
+
+  MlirContext context;
+  MlirBlock block;
+
+private:
+  bool detached;
+};
+
 /// Wrapper around the generic MlirAttribute.
 /// The lifetime of a type is bound by the PyContext that created it.
 class PyAttribute {
@@ -84,6 +157,8 @@ class PyType {
 public:
   PyType(MlirType type) : type(type) {}
   bool operator==(const PyType &other);
+  // Allows a range of PyType to be copied into MlirType storage.
+  operator MlirType() const { return type; }
 
   MlirType type;
 };

diff  --git a/mlir/lib/Bindings/Python/PybindUtils.h b/mlir/lib/Bindings/Python/PybindUtils.h
index 0c0e069fae03..29ab06a25055 100644
--- a/mlir/lib/Bindings/Python/PybindUtils.h
+++ b/mlir/lib/Bindings/Python/PybindUtils.h
@@ -10,6 +10,7 @@
 #define MLIR_BINDINGS_PYTHON_PYBINDUTILS_H
 
 #include <pybind11/pybind11.h>
+#include <pybind11/stl.h> // Automatic std::vector <-> list conversions.
 
 #include "llvm/ADT/Twine.h"
 

diff  --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp
index 23337213e2c7..2a008a2114d6 100644
--- a/mlir/lib/CAPI/IR/IR.cpp
+++ b/mlir/lib/CAPI/IR/IR.cpp
@@ -55,6 +55,10 @@ MlirContext mlirContextCreate() {
   return wrap(context);
 }
 
+int mlirContextEqual(MlirContext ctx1, MlirContext ctx2) {
+  return unwrap(ctx1) == unwrap(ctx2); // Pointer identity of the contexts.
+}
+
 void mlirContextDestroy(MlirContext context) { delete unwrap(context); }
 
 /* ========================================================================== */
@@ -350,6 +354,10 @@ MlirType mlirTypeParseGet(MlirContext context, const char *type) {
   return wrap(mlir::parseType(type, unwrap(context)));
 }
 
+MlirContext mlirTypeGetContext(MlirType type) {
+  return wrap(unwrap(type).getContext()); // The context owning the type.
+}
+
 int mlirTypeEqual(MlirType t1, MlirType t2) { return unwrap(t1) == unwrap(t2); }
 
 void mlirTypePrint(MlirType type, MlirStringCallback callback, void *userData) {

diff  --git a/mlir/test/Bindings/Python/ir_location.py b/mlir/test/Bindings/Python/ir_location.py
new file mode 100644
index 000000000000..a24962ad476d
--- /dev/null
+++ b/mlir/test/Bindings/Python/ir_location.py
@@ -0,0 +1,31 @@
+# RUN: %PYTHON %s | FileCheck %s
+
+import mlir
+
+def run(f):  # Prints a label line for FileCheck, then runs the test.
+  print("\nTEST:", f.__name__)
+  f()
+
+# CHECK-LABEL: TEST: testUnknown
+def testUnknown():
+  ctx = mlir.ir.Context()
+  loc = ctx.get_unknown_location()
+  # CHECK: unknown str: loc(unknown)
+  print("unknown str:", str(loc))  # Location only binds __repr__.
+  # CHECK: unknown repr: loc(unknown)
+  print("unknown repr:", repr(loc))
+
+run(testUnknown)
+
+
+# CHECK-LABEL: TEST: testFileLineCol
+def testFileLineCol():
+  ctx = mlir.ir.Context()
+  loc = ctx.get_file_location("foo.txt", 123, 56)
+  # CHECK: file str: loc("foo.txt":123:56)
+  print("file str:", str(loc))
+  # CHECK: file repr: loc("foo.txt":123:56)
+  print("file repr:", repr(loc))
+
+run(testFileLineCol)
+

diff  --git a/mlir/test/Bindings/Python/ir_operation.py b/mlir/test/Bindings/Python/ir_operation.py
new file mode 100644
index 000000000000..c4246844f690
--- /dev/null
+++ b/mlir/test/Bindings/Python/ir_operation.py
@@ -0,0 +1,73 @@
+# RUN: %PYTHON %s | FileCheck %s
+
+import mlir
+
+def run(f):  # Prints a label line for FileCheck, then runs the test.
+  print("\nTEST:", f.__name__)
+  f()
+
+
+# CHECK-LABEL: TEST: testDetachedRegionBlock
+def testDetachedRegionBlock():
+  ctx = mlir.ir.Context()
+  t = mlir.ir.F32Type(ctx)
+  region = ctx.create_region()
+  block = ctx.create_block([t, t])
+  # CHECK: <<UNLINKED BLOCK>>
+  print(block)
+
+run(testDetachedRegionBlock)
+
+
+# CHECK-LABEL: TEST: testBlockTypeContextMismatch
+def testBlockTypeContextMismatch():
+  c1 = mlir.ir.Context()
+  c2 = mlir.ir.Context()
+  t1 = mlir.ir.F32Type(c1)
+  t2 = mlir.ir.F32Type(c2)
+  try:
+    block = c1.create_block([t1, t2])
+  except ValueError as e:
+    # CHECK: ERROR: All types used to construct a block must be from the same context as the block
+    print("ERROR:", e)
+  else:
+    raise RuntimeError("Expected exception not raised")
+
+run(testBlockTypeContextMismatch)
+
+
+# CHECK-LABEL: TEST: testBlockAppend
+def testBlockAppend():
+  ctx = mlir.ir.Context()
+  t = mlir.ir.F32Type(ctx)
+  region = ctx.create_region()
+  try:
+    region.first_block
+  except IndexError:
+    pass
+  else:
+    raise RuntimeError("Expected exception not raised")
+
+  block = ctx.create_block([t, t])
+  region.append_block(block)
+  try:
+    region.append_block(block)
+  except ValueError:
+    pass
+  else:
+    raise RuntimeError("Expected exception not raised")
+
+  block2 = ctx.create_block([t])
+  region.insert_block(1, block2)
+  # CHECK: <<UNLINKED BLOCK>>
+  block_first = region.first_block
+  print(block_first)
+  block_next = block_first.next_in_region
+  try:
+    block_next = block_next.next_in_region
+  except IndexError:
+    pass
+  else:
+    raise RuntimeError("Expected exception not raised")
+
+run(testBlockAppend)


        


More information about the Mlir-commits mailing list