[Mlir-commits] [mlir] 7202344 - Add a `mlirModuleGetBody()` accessor to the C API and bind it in Python

Mehdi Amini llvmlistbot at llvm.org
Wed Oct 28 10:54:06 PDT 2020


Author: Mehdi Amini
Date: 2020-10-28T17:53:52Z
New Revision: 72023442c1eb3169389f469d4b804aff93497758

URL: https://github.com/llvm/llvm-project/commit/72023442c1eb3169389f469d4b804aff93497758
DIFF: https://github.com/llvm/llvm-project/commit/72023442c1eb3169389f469d4b804aff93497758.diff

LOG: Add a `mlirModuleGetBody()` accessor to the C API and bind it in Python

Getting the body of a module is a common need, which justifies a
dedicated accessor instead of forcing users to unwrap it manually
(module -> operation -> region 0 -> first block).

Differential Revision: https://reviews.llvm.org/D90287

Added: 
    

Modified: 
    mlir/include/mlir-c/IR.h
    mlir/lib/Bindings/Python/IRModules.cpp
    mlir/lib/CAPI/IR/IR.cpp
    mlir/test/Bindings/Python/dialects.py
    mlir/test/CAPI/ir.c

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h
index a08fe77da37c..af0ab1fdf341 100644
--- a/mlir/include/mlir-c/IR.h
+++ b/mlir/include/mlir-c/IR.h
@@ -175,6 +175,9 @@ MlirModule mlirModuleCreateParse(MlirContext context, const char *module);
 /** Gets the context that a module was created with. */
 MlirContext mlirModuleGetContext(MlirModule module);
 
+/** Gets the body of the module, i.e. the only block it contains. */
+MlirBlock mlirModuleGetBody(MlirModule module);
+
 /** Checks whether a module is null. */
 static inline int mlirModuleIsNull(MlirModule module) { return !module.ptr; }
 

diff  --git a/mlir/lib/Bindings/Python/IRModules.cpp b/mlir/lib/Bindings/Python/IRModules.cpp
index 2fba7fa5e283..4a46d9161d76 100644
--- a/mlir/lib/Bindings/Python/IRModules.cpp
+++ b/mlir/lib/Bindings/Python/IRModules.cpp
@@ -2234,6 +2234,16 @@ void mlir::python::populateIRSubmodule(py::module &m) {
                 .releaseObject();
           },
           "Accesses the module as an operation")
+      .def_property_readonly(
+          "body",
+          [](PyModule &self) {
+            PyOperationRef module_op = PyOperation::forOperation(
+                self.getContext(), mlirModuleGetOperation(self.get()),
+                self.getRef().releaseObject());
+            PyBlock returnBlock(module_op, mlirModuleGetBody(self.get()));
+            return returnBlock;
+          },
+          "Return the block for this module")
       .def(
           "dump",
           [](PyModule &self) {

diff  --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp
index fdc40bc6c4f1..f3c91d1fae24 100644
--- a/mlir/lib/CAPI/IR/IR.cpp
+++ b/mlir/lib/CAPI/IR/IR.cpp
@@ -148,6 +148,10 @@ MlirContext mlirModuleGetContext(MlirModule module) {
   return wrap(unwrap(module).getContext());
 }
 
+MlirBlock mlirModuleGetBody(MlirModule module) {
+  return wrap(unwrap(module).getBody());
+}
+
 void mlirModuleDestroy(MlirModule module) {
   // Transfer ownership to an OwningModuleRef so that its destructor is called.
   OwningModuleRef(unwrap(module));

diff  --git a/mlir/test/Bindings/Python/dialects.py b/mlir/test/Bindings/Python/dialects.py
index bc88e8668f4d..ef95163c7743 100644
--- a/mlir/test/Bindings/Python/dialects.py
+++ b/mlir/test/Bindings/Python/dialects.py
@@ -73,7 +73,7 @@ def testCustomOpView():
   f32 = mlir.ir.F32Type.get(ctx)
   loc = ctx.get_unknown_location()
   m = ctx.create_module(loc)
-  m_block = m.operation.regions[0].blocks[0]
+  m_block = m.body
   # TODO: Remove integer insertion in favor of InsertionPoint and/or op-based.
   ip = [0]
   def createInput():

diff  --git a/mlir/test/CAPI/ir.c b/mlir/test/CAPI/ir.c
index 1d382c32fb42..87c8b647e6a0 100644
--- a/mlir/test/CAPI/ir.c
+++ b/mlir/test/CAPI/ir.c
@@ -67,9 +67,7 @@ void populateLoopBody(MlirContext ctx, MlirBlock loopBody,
 
 MlirModule makeAdd(MlirContext ctx, MlirLocation location) {
   MlirModule moduleOp = mlirModuleCreateEmpty(location);
-  MlirOperation module = mlirModuleGetOperation(moduleOp);
-  MlirRegion moduleBodyRegion = mlirOperationGetRegion(module, 0);
-  MlirBlock moduleBody = mlirRegionGetFirstBlock(moduleBodyRegion);
+  MlirBlock moduleBody = mlirModuleGetBody(moduleOp);
 
   MlirType memrefType = mlirTypeParseGet(ctx, "memref<?xf32>");
   MlirType funcBodyArgTypes[] = {memrefType, memrefType};


        


More information about the Mlir-commits mailing list