[Mlir-commits] [mlir] 37107e1 - [mlir][python] Add generic operation parse APIs

Rahul Kayaith llvmlistbot at llvm.org
Wed Mar 1 15:17:21 PST 2023


Author: rkayaith
Date: 2023-03-01T18:17:12-05:00
New Revision: 37107e177e4a0a9ceab397f2667fe4dab98fb729

URL: https://github.com/llvm/llvm-project/commit/37107e177e4a0a9ceab397f2667fe4dab98fb729
DIFF: https://github.com/llvm/llvm-project/commit/37107e177e4a0a9ceab397f2667fe4dab98fb729.diff

LOG: [mlir][python] Add generic operation parse APIs

Currently the bindings only allow for parsing IR with a top-level
`builtin.module` op, since the parse APIs insert an implicit module op.
This change adds `Operation.parse`, which returns whatever top-level op
is actually in the source.

To simplify parsing of specific operations, `OpView.parse` is also
added, which handles the error checking for `OpView` subclasses.

Reviewed By: ftynse, stellaraccident

Differential Revision: https://reviews.llvm.org/D143352

Added: 
    

Modified: 
    mlir/include/mlir-c/IR.h
    mlir/include/mlir/Parser/Parser.h
    mlir/lib/Bindings/Python/IRCore.cpp
    mlir/lib/Bindings/Python/IRModule.h
    mlir/lib/CAPI/IR/IR.cpp
    mlir/lib/Parser/Parser.cpp
    mlir/test/python/ir/operation.py

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h
index 023b99f42ba43..84d226b40b71a 100644
--- a/mlir/include/mlir-c/IR.h
+++ b/mlir/include/mlir-c/IR.h
@@ -422,6 +422,16 @@ mlirOpPrintingFlagsAssumeVerified(MlirOpPrintingFlags flags);
 ///   - Result type inference is enabled and cannot be performed.
 MLIR_CAPI_EXPORTED MlirOperation mlirOperationCreate(MlirOperationState *state);
 
+/// Parses an operation, giving ownership to the caller. If parsing fails a null
+/// operation will be returned, and an error diagnostic emitted.
+///
+/// `sourceStr` may be either the text assembly format, or binary bytecode
+/// format. `sourceName` is used as the file name of the source; any IR without
+/// locations will get a `FileLineColLoc` location with `sourceName` as the file
+/// name.
+MLIR_CAPI_EXPORTED MlirOperation mlirOperationCreateParse(
+    MlirContext context, MlirStringRef sourceStr, MlirStringRef sourceName);
+
 /// Creates a deep copy of an operation. The operation is not inserted and
 /// ownership is transferred to the caller.
 MLIR_CAPI_EXPORTED MlirOperation mlirOperationClone(MlirOperation op);

diff  --git a/mlir/include/mlir/Parser/Parser.h b/mlir/include/mlir/Parser/Parser.h
index 1f38a2e8c7e02..828760fcbefa9 100644
--- a/mlir/include/mlir/Parser/Parser.h
+++ b/mlir/include/mlir/Parser/Parser.h
@@ -138,11 +138,14 @@ LogicalResult parseSourceFile(llvm::StringRef filename,
 /// If the block is non-empty, the operations are placed before the current
 /// terminator. If parsing is successful, success is returned. Otherwise, an
 /// error message is emitted through the error handler registered in the
-/// context, and failure is returned. If `sourceFileLoc` is non-null, it is
-/// populated with a file location representing the start of the source file
-/// that is being parsed.
+/// context, and failure is returned.
+/// `sourceName` is used as the file name of the source; any IR without
+/// locations will get a `FileLineColLoc` location with `sourceName` as the file
+/// name. If `sourceFileLoc` is non-null, it is populated with a file location
+/// representing the start of the source file that is being parsed.
 LogicalResult parseSourceString(llvm::StringRef sourceStr, Block *block,
                                 const ParserConfig &config,
+                                StringRef sourceName = "",
                                 LocationAttr *sourceFileLoc = nullptr);
 
 namespace detail {
@@ -235,12 +238,17 @@ parseSourceFile(llvm::StringRef filename,
 /// failure is returned. `ContainerOpT` is required to have a single region
 /// containing a single block, and must implement the
 /// `SingleBlockImplicitTerminator` trait.
+/// `sourceName` is used as the file name of the source; any IR without
+/// locations will get a `FileLineColLoc` location with `sourceName` as the file
+/// name.
 template <typename ContainerOpT = Operation *>
 inline OwningOpRef<ContainerOpT> parseSourceString(llvm::StringRef sourceStr,
-                                                   const ParserConfig &config) {
+                                                   const ParserConfig &config,
+                                                   StringRef sourceName = "") {
   LocationAttr sourceFileLoc;
   Block block;
-  if (failed(parseSourceString(sourceStr, &block, config, &sourceFileLoc)))
+  if (failed(parseSourceString(sourceStr, &block, config, sourceName,
+                               &sourceFileLoc)))
     return OwningOpRef<ContainerOpT>();
   return detail::constructContainerOpForParserIfNecessary<ContainerOpT>(
       &block, config.getContext(), sourceFileLoc);

diff  --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index e09f0fdeee901..12d37da5b098d 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -20,8 +20,8 @@
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/SmallVector.h"
 
-#include <utility>
 #include <optional>
+#include <utility>
 
 namespace py = pybind11;
 using namespace mlir;
@@ -1059,6 +1059,19 @@ PyOperationRef PyOperation::createDetached(PyMlirContextRef contextRef,
   return created;
 }
 
+PyOperationRef PyOperation::parse(PyMlirContextRef contextRef,
+                                  const std::string &sourceStr,
+                                  const std::string &sourceName) {
+  MlirOperation op =
+      mlirOperationCreateParse(contextRef->get(), toMlirStringRef(sourceStr),
+                               toMlirStringRef(sourceName));
+  // TODO: Include error diagnostic messages in the exception message
+  if (mlirOperationIsNull(op))
+    throw py::value_error(
+        "Unable to parse operation assembly (see diagnostics)");
+  return PyOperation::createDetached(std::move(contextRef), op);
+}
+
 void PyOperation::checkValid() const {
   if (!valid) {
     throw SetPyError(PyExc_RuntimeError, "the operation has been invalidated");
@@ -2769,6 +2782,17 @@ void mlir::python::populateIRCore(py::module &m) {
                   py::arg("successors") = py::none(), py::arg("regions") = 0,
                   py::arg("loc") = py::none(), py::arg("ip") = py::none(),
                   kOperationCreateDocstring)
+      .def_static(
+          "parse",
+          [](const std::string &sourceStr, const std::string &sourceName,
+             DefaultingPyMlirContext context) {
+            return PyOperation::parse(context->getRef(), sourceStr, sourceName)
+                ->createOpView();
+          },
+          py::arg("source"), py::kw_only(), py::arg("source_name") = "",
+          py::arg("context") = py::none(),
+          "Parses an operation. Supports both text assembly format and binary "
+          "bytecode format.")
       .def_property_readonly("parent",
                              [](PyOperation &self) -> py::object {
                                auto parent = self.getParentOperation();
@@ -2820,6 +2844,30 @@ void mlir::python::populateIRCore(py::module &m) {
       py::arg("successors") = py::none(), py::arg("regions") = py::none(),
       py::arg("loc") = py::none(), py::arg("ip") = py::none(),
       "Builds a specific, generated OpView based on class level attributes.");
+  opViewClass.attr("parse") = classmethod(
+      [](const py::object &cls, const std::string &sourceStr,
+         const std::string &sourceName, DefaultingPyMlirContext context) {
+        PyOperationRef parsed =
+            PyOperation::parse(context->getRef(), sourceStr, sourceName);
+
+        // Check if the expected operation was parsed, and cast to the
+        // appropriate `OpView` subclass if successful.
+        // NOTE: This accesses attributes that have been automatically added to
+        // `OpView` subclasses, and is not intended to be used on `OpView`
+        // directly.
+        std::string clsOpName =
+            py::cast<std::string>(cls.attr("OPERATION_NAME"));
+        MlirStringRef parsedOpName =
+            mlirIdentifierStr(mlirOperationGetName(*parsed.get()));
+        if (!mlirStringRefEqual(parsedOpName, toMlirStringRef(clsOpName)))
+          throw py::value_error(
+              "Expected a '" + clsOpName + "' op, got: '" +
+              std::string(parsedOpName.data, parsedOpName.length) + "'");
+        return cls.attr("_Raw")(parsed.getObject());
+      },
+      py::arg("cls"), py::arg("source"), py::kw_only(),
+      py::arg("source_name") = "", py::arg("context") = py::none(),
+      "Parses a specific, generated OpView based on class level attributes");
 
   //----------------------------------------------------------------------------
   // Mapping of PyRegion.

diff  --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h
index 37115acbe0665..fa4bc1c3db1bf 100644
--- a/mlir/lib/Bindings/Python/IRModule.h
+++ b/mlir/lib/Bindings/Python/IRModule.h
@@ -9,9 +9,9 @@
 #ifndef MLIR_BINDINGS_PYTHON_IRMODULES_H
 #define MLIR_BINDINGS_PYTHON_IRMODULES_H
 
+#include <optional>
 #include <utility>
 #include <vector>
-#include <optional>
 
 #include "PybindUtils.h"
 
@@ -548,6 +548,12 @@ class PyOperation : public PyOperationBase, public BaseContextObject {
   createDetached(PyMlirContextRef contextRef, MlirOperation operation,
                  pybind11::object parentKeepAlive = pybind11::object());
 
+  /// Parses a source string (either text assembly or bytecode), creating a
+  /// detached operation.
+  static PyOperationRef parse(PyMlirContextRef contextRef,
+                              const std::string &sourceStr,
+                              const std::string &sourceName);
+
   /// Detaches the operation from its parent block and updates its state
   /// accordingly.
   void detachFromParent() {

diff  --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp
index e83f0f8240aea..051559acd440c 100644
--- a/mlir/lib/CAPI/IR/IR.cpp
+++ b/mlir/lib/CAPI/IR/IR.cpp
@@ -368,6 +368,15 @@ MlirOperation mlirOperationCreate(MlirOperationState *state) {
   return result;
 }
 
+MlirOperation mlirOperationCreateParse(MlirContext context,
+                                       MlirStringRef sourceStr,
+                                       MlirStringRef sourceName) {
+
+  return wrap(
+      parseSourceString(unwrap(sourceStr), unwrap(context), unwrap(sourceName))
+          .release());
+}
+
 MlirOperation mlirOperationClone(MlirOperation op) {
   return wrap(unwrap(op)->clone());
 }

diff  --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp
index 57dd3eeb2714a..6f8f46f30281f 100644
--- a/mlir/lib/Parser/Parser.cpp
+++ b/mlir/lib/Parser/Parser.cpp
@@ -89,8 +89,9 @@ LogicalResult mlir::parseSourceFile(
 
 LogicalResult mlir::parseSourceString(llvm::StringRef sourceStr, Block *block,
                                       const ParserConfig &config,
+                                      StringRef sourceName,
                                       LocationAttr *sourceFileLoc) {
-  auto memBuffer = llvm::MemoryBuffer::getMemBuffer(sourceStr);
+  auto memBuffer = llvm::MemoryBuffer::getMemBuffer(sourceStr, sourceName);
   if (!memBuffer)
     return failure();
 

diff  --git a/mlir/test/python/ir/operation.py b/mlir/test/python/ir/operation.py
index f393cf92c3c1a..bca27a680bdea 100644
--- a/mlir/test/python/ir/operation.py
+++ b/mlir/test/python/ir/operation.py
@@ -4,6 +4,7 @@
 import io
 import itertools
 from mlir.ir import *
+from mlir.dialects.builtin import ModuleOp
 
 
 def run(f):
@@ -900,3 +901,31 @@ def testOperationHash():
   with ctx, Location.unknown():
     op = Operation.create("custom.op1")
     assert hash(op) == hash(op.operation)
+
+
+# CHECK-LABEL: TEST: testOperationParse
+@run
+def testOperationParse():
+  with Context() as ctx:
+    ctx.allow_unregistered_dialects = True
+
+    # Generic operation parsing.
+    m = Operation.parse('module {}')
+    o = Operation.parse('"test.foo"() : () -> ()')
+    assert isinstance(m, ModuleOp)
+    assert type(o) is OpView
+
+    # Parsing specific operation.
+    m = ModuleOp.parse('module {}')
+    assert isinstance(m, ModuleOp)
+    try:
+      ModuleOp.parse('"test.foo"() : () -> ()')
+    except ValueError as e:
+      # CHECK: error: Expected a 'builtin.module' op, got: 'test.foo'
+      print(f"error: {e}")
+    else:
+      assert False, "expected error"
+
+    o = Operation.parse('"test.foo"() : () -> ()', source_name="my-source-string")
+    # CHECK: op_with_source_name: "test.foo"() : () -> () loc("my-source-string":1:1)
+    print(f"op_with_source_name: {o.get_asm(enable_debug_info=True, use_local_scope=True)}")


        


More information about the Mlir-commits mailing list