[Mlir-commits] [mlir] 37107e1 - [mlir][python] Add generic operation parse APIs
Rahul Kayaith
llvmlistbot at llvm.org
Wed Mar 1 15:17:21 PST 2023
Author: rkayaith
Date: 2023-03-01T18:17:12-05:00
New Revision: 37107e177e4a0a9ceab397f2667fe4dab98fb729
URL: https://github.com/llvm/llvm-project/commit/37107e177e4a0a9ceab397f2667fe4dab98fb729
DIFF: https://github.com/llvm/llvm-project/commit/37107e177e4a0a9ceab397f2667fe4dab98fb729.diff
LOG: [mlir][python] Add generic operation parse APIs
Currently the bindings only allow for parsing IR with a top-level
`builtin.module` op, since the parse APIs insert an implicit module op.
This change adds `Operation.parse`, which returns whatever top-level op
is actually in the source.
To simplify parsing of specific operations, `OpView.parse` is also
added, which handles the error checking for `OpView` subclasses.
Reviewed By: ftynse, stellaraccident
Differential Revision: https://reviews.llvm.org/D143352
Added:
Modified:
mlir/include/mlir-c/IR.h
mlir/include/mlir/Parser/Parser.h
mlir/lib/Bindings/Python/IRCore.cpp
mlir/lib/Bindings/Python/IRModule.h
mlir/lib/CAPI/IR/IR.cpp
mlir/lib/Parser/Parser.cpp
mlir/test/python/ir/operation.py
Removed:
################################################################################
diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h
index 023b99f42ba43..84d226b40b71a 100644
--- a/mlir/include/mlir-c/IR.h
+++ b/mlir/include/mlir-c/IR.h
@@ -422,6 +422,16 @@ mlirOpPrintingFlagsAssumeVerified(MlirOpPrintingFlags flags);
/// - Result type inference is enabled and cannot be performed.
MLIR_CAPI_EXPORTED MlirOperation mlirOperationCreate(MlirOperationState *state);
+/// Parses an operation, giving ownership to the caller. If parsing fails a null
+/// operation will be returned, and an error diagnostic emitted.
+///
+/// `sourceStr` may be either the text assembly format, or binary bytecode
+/// format. `sourceName` is used as the file name of the source; any IR without
+/// locations will get a `FileLineColLoc` location with `sourceName` as the file
+/// name.
+MLIR_CAPI_EXPORTED MlirOperation mlirOperationCreateParse(
+ MlirContext context, MlirStringRef sourceStr, MlirStringRef sourceName);
+
/// Creates a deep copy of an operation. The operation is not inserted and
/// ownership is transferred to the caller.
MLIR_CAPI_EXPORTED MlirOperation mlirOperationClone(MlirOperation op);
diff --git a/mlir/include/mlir/Parser/Parser.h b/mlir/include/mlir/Parser/Parser.h
index 1f38a2e8c7e02..828760fcbefa9 100644
--- a/mlir/include/mlir/Parser/Parser.h
+++ b/mlir/include/mlir/Parser/Parser.h
@@ -138,11 +138,14 @@ LogicalResult parseSourceFile(llvm::StringRef filename,
/// If the block is non-empty, the operations are placed before the current
/// terminator. If parsing is successful, success is returned. Otherwise, an
/// error message is emitted through the error handler registered in the
-/// context, and failure is returned. If `sourceFileLoc` is non-null, it is
-/// populated with a file location representing the start of the source file
-/// that is being parsed.
+/// context, and failure is returned.
+/// `sourceName` is used as the file name of the source; any IR without
+/// locations will get a `FileLineColLoc` location with `sourceName` as the file
+/// name. If `sourceFileLoc` is non-null, it is populated with a file location
+/// representing the start of the source file that is being parsed.
LogicalResult parseSourceString(llvm::StringRef sourceStr, Block *block,
const ParserConfig &config,
+ StringRef sourceName = "",
LocationAttr *sourceFileLoc = nullptr);
namespace detail {
@@ -235,12 +238,17 @@ parseSourceFile(llvm::StringRef filename,
/// failure is returned. `ContainerOpT` is required to have a single region
/// containing a single block, and must implement the
/// `SingleBlockImplicitTerminator` trait.
+/// `sourceName` is used as the file name of the source; any IR without
+/// locations will get a `FileLineColLoc` location with `sourceName` as the file
+/// name.
template <typename ContainerOpT = Operation *>
inline OwningOpRef<ContainerOpT> parseSourceString(llvm::StringRef sourceStr,
- const ParserConfig &config) {
+ const ParserConfig &config,
+ StringRef sourceName = "") {
LocationAttr sourceFileLoc;
Block block;
- if (failed(parseSourceString(sourceStr, &block, config, &sourceFileLoc)))
+ if (failed(parseSourceString(sourceStr, &block, config, sourceName,
+ &sourceFileLoc)))
return OwningOpRef<ContainerOpT>();
return detail::constructContainerOpForParserIfNecessary<ContainerOpT>(
&block, config.getContext(), sourceFileLoc);
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index e09f0fdeee901..12d37da5b098d 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -20,8 +20,8 @@
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"
-#include <utility>
#include <optional>
+#include <utility>
namespace py = pybind11;
using namespace mlir;
@@ -1059,6 +1059,19 @@ PyOperationRef PyOperation::createDetached(PyMlirContextRef contextRef,
return created;
}
+PyOperationRef PyOperation::parse(PyMlirContextRef contextRef,
+ const std::string &sourceStr,
+ const std::string &sourceName) {
+ MlirOperation op =
+ mlirOperationCreateParse(contextRef->get(), toMlirStringRef(sourceStr),
+ toMlirStringRef(sourceName));
+ // TODO: Include error diagnostic messages in the exception message
+ if (mlirOperationIsNull(op))
+ throw py::value_error(
+ "Unable to parse operation assembly (see diagnostics)");
+ return PyOperation::createDetached(std::move(contextRef), op);
+}
+
void PyOperation::checkValid() const {
if (!valid) {
throw SetPyError(PyExc_RuntimeError, "the operation has been invalidated");
@@ -2769,6 +2782,17 @@ void mlir::python::populateIRCore(py::module &m) {
py::arg("successors") = py::none(), py::arg("regions") = 0,
py::arg("loc") = py::none(), py::arg("ip") = py::none(),
kOperationCreateDocstring)
+ .def_static(
+ "parse",
+ [](const std::string &sourceStr, const std::string &sourceName,
+ DefaultingPyMlirContext context) {
+ return PyOperation::parse(context->getRef(), sourceStr, sourceName)
+ ->createOpView();
+ },
+ py::arg("source"), py::kw_only(), py::arg("source_name") = "",
+ py::arg("context") = py::none(),
+ "Parses an operation. Supports both text assembly format and binary "
+ "bytecode format.")
.def_property_readonly("parent",
[](PyOperation &self) -> py::object {
auto parent = self.getParentOperation();
@@ -2820,6 +2844,30 @@ void mlir::python::populateIRCore(py::module &m) {
py::arg("successors") = py::none(), py::arg("regions") = py::none(),
py::arg("loc") = py::none(), py::arg("ip") = py::none(),
"Builds a specific, generated OpView based on class level attributes.");
+ opViewClass.attr("parse") = classmethod(
+ [](const py::object &cls, const std::string &sourceStr,
+ const std::string &sourceName, DefaultingPyMlirContext context) {
+ PyOperationRef parsed =
+ PyOperation::parse(context->getRef(), sourceStr, sourceName);
+
+ // Check if the expected operation was parsed, and cast to the
+ // appropriate `OpView` subclass if successful.
+ // NOTE: This accesses attributes that have been automatically added to
+ // `OpView` subclasses, and is not intended to be used on `OpView`
+ // directly.
+ std::string clsOpName =
+ py::cast<std::string>(cls.attr("OPERATION_NAME"));
+ MlirStringRef parsedOpName =
+ mlirIdentifierStr(mlirOperationGetName(*parsed.get()));
+ if (!mlirStringRefEqual(parsedOpName, toMlirStringRef(clsOpName)))
+ throw py::value_error(
+ "Expected a '" + clsOpName + "' op, got: '" +
+ std::string(parsedOpName.data, parsedOpName.length) + "'");
+ return cls.attr("_Raw")(parsed.getObject());
+ },
+ py::arg("cls"), py::arg("source"), py::kw_only(),
+ py::arg("source_name") = "", py::arg("context") = py::none(),
+ "Parses a specific, generated OpView based on class level attributes");
//----------------------------------------------------------------------------
// Mapping of PyRegion.
diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h
index 37115acbe0665..fa4bc1c3db1bf 100644
--- a/mlir/lib/Bindings/Python/IRModule.h
+++ b/mlir/lib/Bindings/Python/IRModule.h
@@ -9,9 +9,9 @@
#ifndef MLIR_BINDINGS_PYTHON_IRMODULES_H
#define MLIR_BINDINGS_PYTHON_IRMODULES_H
+#include <optional>
#include <utility>
#include <vector>
-#include <optional>
#include "PybindUtils.h"
@@ -548,6 +548,12 @@ class PyOperation : public PyOperationBase, public BaseContextObject {
createDetached(PyMlirContextRef contextRef, MlirOperation operation,
pybind11::object parentKeepAlive = pybind11::object());
+ /// Parses a source string (either text assembly or bytecode), creating a
+ /// detached operation.
+ static PyOperationRef parse(PyMlirContextRef contextRef,
+ const std::string &sourceStr,
+ const std::string &sourceName);
+
/// Detaches the operation from its parent block and updates its state
/// accordingly.
void detachFromParent() {
diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp
index e83f0f8240aea..051559acd440c 100644
--- a/mlir/lib/CAPI/IR/IR.cpp
+++ b/mlir/lib/CAPI/IR/IR.cpp
@@ -368,6 +368,15 @@ MlirOperation mlirOperationCreate(MlirOperationState *state) {
return result;
}
+MlirOperation mlirOperationCreateParse(MlirContext context,
+ MlirStringRef sourceStr,
+ MlirStringRef sourceName) {
+
+ return wrap(
+ parseSourceString(unwrap(sourceStr), unwrap(context), unwrap(sourceName))
+ .release());
+}
+
MlirOperation mlirOperationClone(MlirOperation op) {
return wrap(unwrap(op)->clone());
}
diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp
index 57dd3eeb2714a..6f8f46f30281f 100644
--- a/mlir/lib/Parser/Parser.cpp
+++ b/mlir/lib/Parser/Parser.cpp
@@ -89,8 +89,9 @@ LogicalResult mlir::parseSourceFile(
LogicalResult mlir::parseSourceString(llvm::StringRef sourceStr, Block *block,
const ParserConfig &config,
+ StringRef sourceName,
LocationAttr *sourceFileLoc) {
- auto memBuffer = llvm::MemoryBuffer::getMemBuffer(sourceStr);
+ auto memBuffer = llvm::MemoryBuffer::getMemBuffer(sourceStr, sourceName);
if (!memBuffer)
return failure();
diff --git a/mlir/test/python/ir/operation.py b/mlir/test/python/ir/operation.py
index f393cf92c3c1a..bca27a680bdea 100644
--- a/mlir/test/python/ir/operation.py
+++ b/mlir/test/python/ir/operation.py
@@ -4,6 +4,7 @@
import io
import itertools
from mlir.ir import *
+from mlir.dialects.builtin import ModuleOp
def run(f):
@@ -900,3 +901,31 @@ def testOperationHash():
with ctx, Location.unknown():
op = Operation.create("custom.op1")
assert hash(op) == hash(op.operation)
+
+
+# CHECK-LABEL: TEST: testOperationParse
+@run
+def testOperationParse():
+ with Context() as ctx:
+ ctx.allow_unregistered_dialects = True
+
+ # Generic operation parsing.
+ m = Operation.parse('module {}')
+ o = Operation.parse('"test.foo"() : () -> ()')
+ assert isinstance(m, ModuleOp)
+ assert type(o) is OpView
+
+ # Parsing specific operation.
+ m = ModuleOp.parse('module {}')
+ assert isinstance(m, ModuleOp)
+ try:
+ ModuleOp.parse('"test.foo"() : () -> ()')
+ except ValueError as e:
+ # CHECK: error: Expected a 'builtin.module' op, got: 'test.foo'
+ print(f"error: {e}")
+ else:
+ assert False, "expected error"
+
+ o = Operation.parse('"test.foo"() : () -> ()', source_name="my-source-string")
+ # CHECK: op_with_source_name: "test.foo"() : () -> () loc("my-source-string":1:1)
+ print(f"op_with_source_name: {o.get_asm(enable_debug_info=True, use_local_scope=True)}")
More information about the Mlir-commits
mailing list