[Mlir-commits] [mlir] 4d0d295 - [mlir][python] Allow specifying block arg locations

Rahul Kayaith llvmlistbot at llvm.org
Tue May 9 09:40:22 PDT 2023


Author: Rahul Kayaith
Date: 2023-05-09T12:40:17-04:00
New Revision: 4d0d295b618edfc937d5bf247f0853df5c70cb96

URL: https://github.com/llvm/llvm-project/commit/4d0d295b618edfc937d5bf247f0853df5c70cb96
DIFF: https://github.com/llvm/llvm-project/commit/4d0d295b618edfc937d5bf247f0853df5c70cb96.diff

LOG: [mlir][python] Allow specifying block arg locations

Currently blocks are always created with UnknownLocs for their arguments. This
adds an `arg_locs` argument to all block creation APIs, which takes an optional
sequence of locations to use, one per block argument. If no locations are
supplied, the current Location context is used.

Reviewed By: ftynse

Differential Revision: https://reviews.llvm.org/D150084

Added: 
    

Modified: 
    mlir/lib/Bindings/Python/IRCore.cpp
    mlir/python/mlir/dialects/_func_ops_ext.py
    mlir/test/python/ir/blocks.py

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 7ffa464009fc8..2158a4cb56206 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -193,6 +193,31 @@ static MlirStringRef toMlirStringRef(const std::string &s) {
   return mlirStringRefCreate(s.data(), s.size());
 }
 
+/// Create a block from `pyArgTypes`, locating arguments at `pyArgLocs` or, if
+/// omitted, the current Location context. Throws ValueError on count mismatch.
+static MlirBlock createBlock(const py::sequence &pyArgTypes,
+                             const std::optional<py::sequence> &pyArgLocs) {
+  SmallVector<MlirType> argTypes;
+  argTypes.reserve(pyArgTypes.size());
+  for (const auto &pyType : pyArgTypes)
+    argTypes.push_back(pyType.cast<PyType &>());
+
+  SmallVector<MlirLocation> argLocs;
+  if (pyArgLocs) {
+    argLocs.reserve(pyArgLocs->size());
+    for (const auto &pyLoc : *pyArgLocs)
+      argLocs.push_back(pyLoc.cast<PyLocation &>());
+  } else if (!argTypes.empty()) {
+    argLocs.assign(argTypes.size(), DefaultingPyLocation::resolve());
+  }
+
+  if (argTypes.size() != argLocs.size())
+    throw py::value_error(("Expected " + Twine(argTypes.size()) +
+                           " locations, got: " + Twine(argLocs.size()))
+                              .str());
+  return mlirBlockCreate(argTypes.size(), argTypes.data(), argLocs.data());
+}
+
 /// Wrapper for the global LLVM debugging flag.
 struct PyGlobalDebugFlag {
   static void set(py::object &o, bool enable) { mlirEnableGlobalDebug(enable); }
@@ -364,21 +389,10 @@ class PyBlockList {
     throw SetPyError(PyExc_IndexError, "attempt to access out of bounds block");
   }
 
-  PyBlock appendBlock(const py::args &pyArgTypes) {
+  PyBlock appendBlock(const py::args &pyArgTypes,
+                      const std::optional<py::sequence> &pyArgLocs) {
     operation->checkValid();
-    llvm::SmallVector<MlirType, 4> argTypes;
-    llvm::SmallVector<MlirLocation, 4> argLocs;
-    argTypes.reserve(pyArgTypes.size());
-    argLocs.reserve(pyArgTypes.size());
-    for (auto &pyArg : pyArgTypes) {
-      argTypes.push_back(pyArg.cast<PyType &>());
-      // TODO: Pass in a proper location here.
-      argLocs.push_back(
-          mlirLocationUnknownGet(mlirTypeGetContext(argTypes.back())));
-    }
-
-    MlirBlock block =
-        mlirBlockCreate(argTypes.size(), argTypes.data(), argLocs.data());
+    MlirBlock block = createBlock(pyArgTypes, pyArgLocs);
     mlirRegionAppendOwnedBlock(region, block);
     return PyBlock(operation, block);
   }
@@ -388,7 +402,8 @@ class PyBlockList {
         .def("__getitem__", &PyBlockList::dunderGetItem)
         .def("__iter__", &PyBlockList::dunderIter)
         .def("__len__", &PyBlockList::dunderLen)
-        .def("append", &PyBlockList::appendBlock, kAppendBlockDocstring);
+        .def("append", &PyBlockList::appendBlock, kAppendBlockDocstring,
+             py::arg("arg_locs") = std::nullopt);
   }
 
 private:
@@ -2966,27 +2981,17 @@ void mlir::python::populateIRCore(py::module &m) {
           "Returns a forward-optimized sequence of operations.")
       .def_static(
           "create_at_start",
-          [](PyRegion &parent, py::list pyArgTypes) {
+          [](PyRegion &parent, const py::list &pyArgTypes,
+             const std::optional<py::sequence> &pyArgLocs) {
             parent.checkValid();
-            llvm::SmallVector<MlirType, 4> argTypes;
-            llvm::SmallVector<MlirLocation, 4> argLocs;
-            argTypes.reserve(pyArgTypes.size());
-            argLocs.reserve(pyArgTypes.size());
-            for (auto &pyArg : pyArgTypes) {
-              argTypes.push_back(pyArg.cast<PyType &>());
-              // TODO: Pass in a proper location here.
-              argLocs.push_back(
-                  mlirLocationUnknownGet(mlirTypeGetContext(argTypes.back())));
-            }
-
-            MlirBlock block = mlirBlockCreate(argTypes.size(), argTypes.data(),
-                                              argLocs.data());
+            MlirBlock block = createBlock(pyArgTypes, pyArgLocs);
             mlirRegionInsertOwnedBlock(parent, 0, block);
             return PyBlock(parent.getParentOperation(), block);
           },
           py::arg("parent"), py::arg("arg_types") = py::list(),
+          py::arg("arg_locs") = std::nullopt,
           "Creates and returns a new Block at the beginning of the given "
-          "region (with given argument types).")
+          "region (with given argument types and locations).")
       .def(
           "append_to",
           [](PyBlock &self, PyRegion &region) {
@@ -2998,50 +3003,30 @@ void mlir::python::populateIRCore(py::module &m) {
           "Append this block to a region, transferring ownership if necessary")
       .def(
           "create_before",
-          [](PyBlock &self, py::args pyArgTypes) {
+          [](PyBlock &self, const py::args &pyArgTypes,
+             const std::optional<py::sequence> &pyArgLocs) {
             self.checkValid();
-            llvm::SmallVector<MlirType, 4> argTypes;
-            llvm::SmallVector<MlirLocation, 4> argLocs;
-            argTypes.reserve(pyArgTypes.size());
-            argLocs.reserve(pyArgTypes.size());
-            for (auto &pyArg : pyArgTypes) {
-              argTypes.push_back(pyArg.cast<PyType &>());
-              // TODO: Pass in a proper location here.
-              argLocs.push_back(
-                  mlirLocationUnknownGet(mlirTypeGetContext(argTypes.back())));
-            }
-
-            MlirBlock block = mlirBlockCreate(argTypes.size(), argTypes.data(),
-                                              argLocs.data());
+            MlirBlock block = createBlock(pyArgTypes, pyArgLocs);
             MlirRegion region = mlirBlockGetParentRegion(self.get());
             mlirRegionInsertOwnedBlockBefore(region, self.get(), block);
             return PyBlock(self.getParentOperation(), block);
           },
+          py::arg("arg_locs") = std::nullopt,
           "Creates and returns a new Block before this block "
-          "(with given argument types).")
+          "(with given argument types and locations).")
       .def(
           "create_after",
-          [](PyBlock &self, py::args pyArgTypes) {
+          [](PyBlock &self, const py::args &pyArgTypes,
+             const std::optional<py::sequence> &pyArgLocs) {
             self.checkValid();
-            llvm::SmallVector<MlirType, 4> argTypes;
-            llvm::SmallVector<MlirLocation, 4> argLocs;
-            argTypes.reserve(pyArgTypes.size());
-            argLocs.reserve(pyArgTypes.size());
-            for (auto &pyArg : pyArgTypes) {
-              argTypes.push_back(pyArg.cast<PyType &>());
-
-              // TODO: Pass in a proper location here.
-              argLocs.push_back(
-                  mlirLocationUnknownGet(mlirTypeGetContext(argTypes.back())));
-            }
-            MlirBlock block = mlirBlockCreate(argTypes.size(), argTypes.data(),
-                                              argLocs.data());
+            MlirBlock block = createBlock(pyArgTypes, pyArgLocs);
             MlirRegion region = mlirBlockGetParentRegion(self.get());
             mlirRegionInsertOwnedBlockAfter(region, self.get(), block);
             return PyBlock(self.getParentOperation(), block);
           },
+          py::arg("arg_locs") = std::nullopt,
           "Creates and returns a new Block after this block "
-          "(with given argument types).")
+          "(with given argument types and locations).")
       .def(
           "__iter__",
           [](PyBlock &self) {

diff  --git a/mlir/python/mlir/dialects/_func_ops_ext.py b/mlir/python/mlir/dialects/_func_ops_ext.py
index 79577463d9199..56df423d30a0f 100644
--- a/mlir/python/mlir/dialects/_func_ops_ext.py
+++ b/mlir/python/mlir/dialects/_func_ops_ext.py
@@ -90,7 +90,7 @@ def entry_block(self):
       raise IndexError('External function does not have a body')
     return self.regions[0].blocks[0]
 
-  def add_entry_block(self):
+  def add_entry_block(self, arg_locs: Optional[Sequence[Location]] = None):
     """
     Add an entry block to the function body using the function signature to
     infer block arguments.
@@ -98,7 +98,7 @@ def add_entry_block(self):
     """
     if not self.is_external:
       raise IndexError('The function already has an entry block!')
-    self.body.blocks.append(*self.type.inputs)
+    self.body.blocks.append(*self.type.inputs, arg_locs=arg_locs)
     return self.body.blocks[0]
 
   @property

diff  --git a/mlir/test/python/ir/blocks.py b/mlir/test/python/ir/blocks.py
index 47aafca7e2d56..e929d79e6c5cc 100644
--- a/mlir/test/python/ir/blocks.py
+++ b/mlir/test/python/ir/blocks.py
@@ -18,28 +18,28 @@ def run(f):
 
 
 # CHECK-LABEL: TEST: testBlockCreation
-# CHECK: func @test(%[[ARG0:.*]]: i32, %[[ARG1:.*]]: i16)
+# CHECK: func @test(%[[ARG0:.*]]: i32 loc("arg0"), %[[ARG1:.*]]: i16 loc("arg1"))
 # CHECK:   cf.br ^bb1(%[[ARG1]] : i16)
-# CHECK: ^bb1(%[[PHI0:.*]]: i16):
+# CHECK: ^bb1(%[[PHI0:.*]]: i16 loc("middle")):
 # CHECK:   cf.br ^bb2(%[[ARG0]] : i32)
-# CHECK: ^bb2(%[[PHI1:.*]]: i32):
+# CHECK: ^bb2(%[[PHI1:.*]]: i32 loc("successor")):
 # CHECK:   return
 @run
 def testBlockCreation():
   with Context() as ctx, Location.unknown():
-    module = Module.create()
+    module = builtin.ModuleOp()
     with InsertionPoint(module.body):
       f_type = FunctionType.get(
           [IntegerType.get_signless(32),
            IntegerType.get_signless(16)], [])
       f_op = func.FuncOp("test", f_type)
-      entry_block = f_op.add_entry_block()
+      entry_block = f_op.add_entry_block([Location.name("arg0"), Location.name("arg1")])
       i32_arg, i16_arg = entry_block.arguments
-      successor_block = entry_block.create_after(i32_arg.type)
+      successor_block = entry_block.create_after(i32_arg.type, arg_locs=[Location.name("successor")])
       with InsertionPoint(successor_block) as successor_ip:
         assert successor_ip.block == successor_block
         func.ReturnOp([])
-      middle_block = successor_block.create_before(i16_arg.type)
+      middle_block = successor_block.create_before(i16_arg.type, arg_locs=[Location.name("middle")])
 
       with InsertionPoint(entry_block) as entry_ip:
         assert entry_ip.block == entry_block
@@ -48,27 +48,57 @@ def testBlockCreation():
       with InsertionPoint(middle_block) as middle_ip:
         assert middle_ip.block == middle_block
         cf.BranchOp([i32_arg], dest=successor_block)
-    print(module.operation)
+    module.print(enable_debug_info=True)
     # Ensure region back references are coherent.
     assert entry_block.region == middle_block.region == successor_block.region
 
 
+# CHECK-LABEL: TEST: testBlockCreationArgLocs
+@run
+def testBlockCreationArgLocs():
+  with Context() as ctx:
+    ctx.allow_unregistered_dialects = True
+    f32 = F32Type.get()
+    op = Operation.create("test", regions=1, loc=Location.unknown())
+    blocks = op.regions[0].blocks
+
+    with Location.name("default_loc"):
+      blocks.append(f32)
+    blocks.append()
+    # CHECK:      ^bb0(%{{.+}}: f32 loc("default_loc")):
+    # CHECK-NEXT: ^bb1:
+    op.print(enable_debug_info=True)
+
+    try:
+      blocks.append(f32)
+    except RuntimeError as err:
+      # CHECK: Missing loc: An MLIR function requires a Location but none was provided
+      print("Missing loc:", err)
+
+    try:
+      blocks.append(f32, f32, arg_locs=[Location.unknown()])
+    except ValueError as err:
+      # CHECK: Wrong loc count: Expected 2 locations, got: 1
+      print("Wrong loc count:", err)
+
+
 # CHECK-LABEL: TEST: testFirstBlockCreation
-# CHECK: func @test(%{{.*}}: f32)
+# CHECK: func @test(%{{.*}}: f32 loc("arg_loc"))
 # CHECK:   return
 @run
 def testFirstBlockCreation():
   with Context() as ctx, Location.unknown():
-    module = Module.create()
+    module = builtin.ModuleOp()
     f32 = F32Type.get()
     with InsertionPoint(module.body):
       f = func.FuncOp("test", ([f32], []))
-      entry_block = Block.create_at_start(f.operation.regions[0], [f32])
+      entry_block = Block.create_at_start(f.operation.regions[0],
+                                          [f32], [Location.name("arg_loc")])
       with InsertionPoint(entry_block):
         func.ReturnOp([])
 
-    print(module)
-    assert module.operation.verify()
+    module.print(enable_debug_info=True)
+    assert module.verify()
     assert f.body.blocks[0] == entry_block
 
 


        


More information about the Mlir-commits mailing list