[Mlir-commits] [llvm] [mlir] [mlir] Add PDL C & Python usage (PR #94714)
Jacques Pienaar
llvmlistbot at llvm.org
Sun Jun 9 15:21:46 PDT 2024
https://github.com/jpienaar updated https://github.com/llvm/llvm-project/pull/94714
>From f70e85e83e36e1c786d607fa89b126d248bc6f3f Mon Sep 17 00:00:00 2001
From: Jacques Pienaar <jpienaar at google.com>
Date: Fri, 7 Jun 2024 01:26:36 +0000
Subject: [PATCH 1/2] [mlir] Add PDL C & Python usage
Following a rather direct approach to expose PDL usage from C and then Python. This doesn't yet plumb through support for custom matchers via this interface, so it is constrained to the basics initially.
Signed-off-by: Jacques Pienaar <jpienaar at google.com>
---
mlir/include/mlir-c/Bindings/Python/Interop.h | 21 ++++
mlir/include/mlir-c/Rewrite.h | 60 ++++++++++
.../mlir/Bindings/Python/PybindAdaptors.h | 20 ++++
mlir/lib/Bindings/Python/IRModule.h | 1 +
mlir/lib/Bindings/Python/MainModule.cpp | 4 +
mlir/lib/Bindings/Python/Rewrite.cpp | 110 ++++++++++++++++++
mlir/lib/Bindings/Python/Rewrite.h | 22 ++++
mlir/lib/CAPI/Transforms/CMakeLists.txt | 3 +
mlir/lib/CAPI/Transforms/Rewrite.cpp | 83 +++++++++++++
mlir/python/CMakeLists.txt | 2 +
mlir/python/mlir/rewrite.py | 5 +
mlir/test/python/integration/dialects/pdl.py | 67 +++++++++++
.../llvm-project-overlay/mlir/BUILD.bazel | 12 +-
.../mlir/python/BUILD.bazel | 7 ++
14 files changed, 416 insertions(+), 1 deletion(-)
create mode 100644 mlir/include/mlir-c/Rewrite.h
create mode 100644 mlir/lib/Bindings/Python/Rewrite.cpp
create mode 100644 mlir/lib/Bindings/Python/Rewrite.h
create mode 100644 mlir/lib/CAPI/Transforms/Rewrite.cpp
create mode 100644 mlir/python/mlir/rewrite.py
create mode 100644 mlir/test/python/integration/dialects/pdl.py
diff --git a/mlir/include/mlir-c/Bindings/Python/Interop.h b/mlir/include/mlir-c/Bindings/Python/Interop.h
index 0a36e97c2ae68..a33190c380d37 100644
--- a/mlir/include/mlir-c/Bindings/Python/Interop.h
+++ b/mlir/include/mlir-c/Bindings/Python/Interop.h
@@ -39,6 +39,7 @@
#include "mlir-c/IR.h"
#include "mlir-c/IntegerSet.h"
#include "mlir-c/Pass.h"
+#include "mlir-c/Rewrite.h"
// The 'mlir' Python package is relocatable and supports co-existing in multiple
// projects. Each project must define its outer package prefix with this define
@@ -284,6 +285,26 @@ static inline MlirModule mlirPythonCapsuleToModule(PyObject *capsule) {
return module;
}
+/** Name of the capsule wrapping an MlirFrozenRewritePatternSet. It is kept
+ * distinct from every other capsule name (in particular the PassManager one)
+ * so that PyCapsule_GetPointer rejects capsules of the wrong kind. */
+#ifndef MLIR_PYTHON_CAPSULE_FROZEN_REWRITE_PATTERN_SET
+#define MLIR_PYTHON_CAPSULE_FROZEN_REWRITE_PATTERN_SET                         \
+  MAKE_MLIR_PYTHON_QUALNAME("rewrite.FrozenRewritePatternSet._CAPIPtr")
+#endif
+
+/** Creates a capsule object encapsulating the raw C-API
+ * MlirFrozenRewritePatternSet.
+ * The returned capsule does not extend or affect ownership of any Python
+ * objects that reference the pattern set in any way. */
+static inline PyObject *
+mlirPythonFrozenRewritePatternSetToCapsule(MlirFrozenRewritePatternSet patterns) {
+  return PyCapsule_New(MLIR_PYTHON_GET_WRAPPED_POINTER(patterns),
+                       MLIR_PYTHON_CAPSULE_FROZEN_REWRITE_PATTERN_SET, NULL);
+}
+
+/** Extracts an MlirFrozenRewritePatternSet from a capsule as produced from
+ * mlirPythonFrozenRewritePatternSetToCapsule. If the capsule is not of the
+ * right type, then a null pattern set is returned (PyCapsule_GetPointer also
+ * sets a Python exception in that case). */
+static inline MlirFrozenRewritePatternSet
+mlirPythonCapsuleToFrozenRewritePatternSet(PyObject *capsule) {
+  void *ptr = PyCapsule_GetPointer(
+      capsule, MLIR_PYTHON_CAPSULE_FROZEN_REWRITE_PATTERN_SET);
+  MlirFrozenRewritePatternSet patterns = {ptr};
+  return patterns;
+}
+
/** Creates a capsule object encapsulating the raw C-API MlirPassManager.
* The returned capsule does not extend or affect ownership of any Python
* objects that reference the module in any way. */
diff --git a/mlir/include/mlir-c/Rewrite.h b/mlir/include/mlir-c/Rewrite.h
new file mode 100644
index 0000000000000..45218a1cd4ebd
--- /dev/null
+++ b/mlir/include/mlir-c/Rewrite.h
@@ -0,0 +1,60 @@
+//===-- mlir-c/Rewrite.h - Helpers for C API to Rewrites ----------*- C -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM
+// Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This header declares the registration and creation method for
+// rewrite patterns.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_C_REWRITE_H
+#define MLIR_C_REWRITE_H
+
+#include "mlir-c/IR.h"
+#include "mlir-c/Support.h"
+#include "mlir/Config/mlir-config.h"
+
+// Give the declarations below C linkage so the API is usable from C, like
+// every other mlir-c header.
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+//===----------------------------------------------------------------------===//
+/// Opaque type declarations (see mlir-c/IR.h for more details).
+//===----------------------------------------------------------------------===//
+
+#define DEFINE_C_API_STRUCT(name, storage)                                     \
+  struct name {                                                                \
+    storage *ptr;                                                              \
+  };                                                                           \
+  typedef struct name name
+
+DEFINE_C_API_STRUCT(MlirFrozenRewritePatternSet, void);
+DEFINE_C_API_STRUCT(MlirGreedyRewriteDriverConfig, void);
+DEFINE_C_API_STRUCT(MlirRewritePatternSet, void);
+
+/// Freezes the patterns held by `op` into a new frozen pattern set that can be
+/// applied repeatedly. The patterns are moved out of `op`, which must not be
+/// used afterwards. Release the result with mlirFrozenRewritePatternSetDestroy.
+MLIR_CAPI_EXPORTED MlirFrozenRewritePatternSet
+mlirFreezeRewritePattern(MlirRewritePatternSet op);
+
+/// Destroys the given frozen pattern set.
+MLIR_CAPI_EXPORTED void
+mlirFrozenRewritePatternSetDestroy(MlirFrozenRewritePatternSet op);
+
+/// Greedily applies `patterns` to the module `op`, folding as it goes, and
+/// returns the result of the greedy rewrite driver (failure indicates that it
+/// did not converge).
+/// NOTE: `config` is currently not consulted; default driver settings apply.
+MLIR_CAPI_EXPORTED MlirLogicalResult mlirApplyPatternsAndFoldGreedily(
+    MlirModule op, MlirFrozenRewritePatternSet patterns,
+    MlirGreedyRewriteDriverConfig config);
+
+#if MLIR_ENABLE_PDL_IN_PATTERNMATCH
+DEFINE_C_API_STRUCT(MlirPDLPatternModule, void);
+
+/// Creates a PDL pattern module from the given module. The PDL pattern module
+/// takes ownership of `op`. Release the result with
+/// mlirPDLPatternModuleDestroy.
+MLIR_CAPI_EXPORTED MlirPDLPatternModule
+mlirPDLPatternModuleFromModule(MlirModule op);
+
+/// Destroys the given PDL pattern module.
+MLIR_CAPI_EXPORTED void mlirPDLPatternModuleDestroy(MlirPDLPatternModule op);
+
+/// Creates a rewrite pattern set by moving the patterns out of `op`. `op`
+/// itself must still be destroyed with mlirPDLPatternModuleDestroy. The
+/// result is intended to be passed to mlirFreezeRewritePattern.
+MLIR_CAPI_EXPORTED MlirRewritePatternSet
+mlirRewritePatternSetFromPDLPatternModule(MlirPDLPatternModule op);
+#endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH
+
+#undef DEFINE_C_API_STRUCT
+
+#ifdef __cplusplus
+}
+#endif
+
+#endif // MLIR_C_REWRITE_H
diff --git a/mlir/include/mlir/Bindings/Python/PybindAdaptors.h b/mlir/include/mlir/Bindings/Python/PybindAdaptors.h
index d8f22c7aa1709..39ee1551ccb2e 100644
--- a/mlir/include/mlir/Bindings/Python/PybindAdaptors.h
+++ b/mlir/include/mlir/Bindings/Python/PybindAdaptors.h
@@ -198,6 +198,26 @@ struct type_caster<MlirModule> {
};
};
+/// Casts object <-> MlirFrozenRewritePatternSet.
+template <>
+struct type_caster<MlirFrozenRewritePatternSet> {
+  PYBIND11_TYPE_CASTER(MlirFrozenRewritePatternSet,
+                       _("MlirFrozenRewritePatternSet"));
+  bool load(handle src, bool) {
+    py::object capsule = mlirApiObjectToCapsule(src);
+    value = mlirPythonCapsuleToFrozenRewritePatternSet(capsule.ptr());
+    // A null handle means the capsule did not wrap a frozen pattern set;
+    // reject the conversion instead of passing null into the C API.
+    return value.ptr != nullptr;
+  }
+  static handle cast(MlirFrozenRewritePatternSet v, return_value_policy,
+                     handle) {
+    py::object capsule = py::reinterpret_steal<py::object>(
+        mlirPythonFrozenRewritePatternSetToCapsule(v));
+    return py::module::import(MAKE_MLIR_PYTHON_QUALNAME("rewrite"))
+        .attr("FrozenRewritePatternSet")
+        .attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule)
+        .release();
+  }
+};
+
/// Casts object <-> MlirOperation.
template <>
struct type_caster<MlirOperation> {
diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h
index b038a0c54d29b..6f0dc2690a5e6 100644
--- a/mlir/lib/Bindings/Python/IRModule.h
+++ b/mlir/lib/Bindings/Python/IRModule.h
@@ -22,6 +22,7 @@
#include "mlir-c/Diagnostics.h"
#include "mlir-c/IR.h"
#include "mlir-c/IntegerSet.h"
+#include "mlir-c/Transforms.h"
#include "mlir/Bindings/Python/PybindAdaptors.h"
#include "llvm/ADT/DenseMap.h"
diff --git a/mlir/lib/Bindings/Python/MainModule.cpp b/mlir/lib/Bindings/Python/MainModule.cpp
index 17272472ccca4..8da1ab16a4514 100644
--- a/mlir/lib/Bindings/Python/MainModule.cpp
+++ b/mlir/lib/Bindings/Python/MainModule.cpp
@@ -11,6 +11,7 @@
#include "Globals.h"
#include "IRModule.h"
#include "Pass.h"
+#include "Rewrite.h"
namespace py = pybind11;
using namespace mlir;
@@ -116,6 +117,9 @@ PYBIND11_MODULE(_mlir, m) {
populateIRInterfaces(irModule);
populateIRTypes(irModule);
+ auto rewriteModule = m.def_submodule("rewrite", "MLIR Rewrite Bindings");
+ populateRewriteSubmodule(rewriteModule);
+
// Define and populate PassManager submodule.
auto passModule =
m.def_submodule("passmanager", "MLIR Pass Management Bindings");
diff --git a/mlir/lib/Bindings/Python/Rewrite.cpp b/mlir/lib/Bindings/Python/Rewrite.cpp
new file mode 100644
index 0000000000000..1d8128be9f082
--- /dev/null
+++ b/mlir/lib/Bindings/Python/Rewrite.cpp
@@ -0,0 +1,110 @@
+//===- Rewrite.cpp - Rewrite ----------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "Rewrite.h"
+
+#include "IRModule.h"
+#include "mlir-c/Bindings/Python/Interop.h"
+#include "mlir-c/Rewrite.h"
+#include "mlir/Config/mlir-config.h"
+
+namespace py = pybind11;
+using namespace mlir;
+using namespace py::literals;
+using namespace mlir::python;
+
+namespace {
+
+#if MLIR_ENABLE_PDL_IN_PATTERNMATCH
+/// Owns an MlirPDLPatternModule handle and destroys it on destruction.
+/// Move-only (the user-declared move constructor suppresses copying); a
+/// moved-from instance holds a null handle and destroys nothing.
+class PyPDLPatternModule {
+public:
+  PyPDLPatternModule(MlirPDLPatternModule handle) : module(handle) {}
+
+  PyPDLPatternModule(PyPDLPatternModule &&other) noexcept
+      : module(other.module) {
+    other.module = MlirPDLPatternModule{nullptr};
+  }
+
+  ~PyPDLPatternModule() {
+    if (module.ptr == nullptr)
+      return;
+    mlirPDLPatternModuleDestroy(module);
+  }
+
+  /// Returns the wrapped handle; ownership stays with this object.
+  MlirPDLPatternModule get() { return module; }
+
+private:
+  /// Owned handle, or null once moved from.
+  MlirPDLPatternModule module;
+};
+#endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH
+
+/// Owns an MlirFrozenRewritePatternSet handle and destroys it on destruction.
+/// Move-only; a moved-from instance holds a null handle and destroys nothing.
+class PyFrozenRewritePatternSet {
+public:
+  PyFrozenRewritePatternSet(MlirFrozenRewritePatternSet handle) : set(handle) {}
+
+  PyFrozenRewritePatternSet(PyFrozenRewritePatternSet &&other) noexcept
+      : set(other.set) {
+    other.set = MlirFrozenRewritePatternSet{nullptr};
+  }
+
+  ~PyFrozenRewritePatternSet() {
+    if (set.ptr == nullptr)
+      return;
+    mlirFrozenRewritePatternSetDestroy(set);
+  }
+
+  /// Returns the wrapped handle; ownership stays with this object.
+  MlirFrozenRewritePatternSet get() { return set; }
+
+  /// Returns a new, non-owning capsule around the handle (exposed as the
+  /// MLIR_PYTHON_CAPI_PTR_ATTR property).
+  pybind11::object getCapsule() {
+    PyObject *raw = mlirPythonFrozenRewritePatternSetToCapsule(set);
+    return py::reinterpret_steal<py::object>(raw);
+  }
+
+  /// Builds an owning Python wrapper from a capsule (exposed as
+  /// MLIR_PYTHON_CAPI_FACTORY_ATTR). Re-raises the pending Python error when
+  /// the capsule yields a null handle.
+  static pybind11::object createFromCapsule(pybind11::object capsule) {
+    MlirFrozenRewritePatternSet handle =
+        mlirPythonCapsuleToFrozenRewritePatternSet(capsule.ptr());
+    if (handle.ptr == nullptr)
+      throw py::error_already_set();
+    return py::cast(PyFrozenRewritePatternSet(handle),
+                    py::return_value_policy::move);
+  }
+
+private:
+  /// Owned handle, or null once moved from.
+  MlirFrozenRewritePatternSet set;
+};
+
+} // namespace
+
+/// Populates the `mlir.rewrite` submodule.
+void mlir::python::populateRewriteSubmodule(py::module &m) {
+  //----------------------------------------------------------------------------
+  // Mapping of the rewrite pattern classes and the greedy driver.
+  //----------------------------------------------------------------------------
+#if MLIR_ENABLE_PDL_IN_PATTERNMATCH
+  py::class_<PyPDLPatternModule>(m, "PDLModule", py::module_local())
+      .def(py::init<>([](MlirModule module) {
+             // mlirPDLPatternModuleFromModule adopts the module it is given,
+             // while the Python Module object passed in still references
+             // `module`. Hand over a deep clone so the PDL module solely owns
+             // what it receives and the Python Module remains valid.
+             MlirOperation clone =
+                 mlirOperationClone(mlirModuleGetOperation(module));
+             return mlirPDLPatternModuleFromModule(
+                 mlirModuleFromOperation(clone));
+           }),
+           "module"_a, "Create a PDL module from the given module.")
+      .def("freeze", [](PyPDLPatternModule &self) {
+        // Moves the patterns out of `self` into a new frozen set.
+        return new PyFrozenRewritePatternSet(mlirFreezeRewritePattern(
+            mlirRewritePatternSetFromPDLPatternModule(self.get())));
+      });
+#endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH
+  py::class_<PyFrozenRewritePatternSet>(m, "FrozenRewritePatternSet",
+                                        py::module_local())
+      .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
+                             &PyFrozenRewritePatternSet::getCapsule)
+      .def(MLIR_PYTHON_CAPI_FACTORY_ATTR,
+           &PyFrozenRewritePatternSet::createFromCapsule);
+  m.def(
+      "apply_patterns_and_fold_greedily",
+      [](MlirModule module, MlirFrozenRewritePatternSet set) {
+        auto status = mlirApplyPatternsAndFoldGreedily(module, set, {});
+        if (mlirLogicalResultIsFailure(status))
+          // FIXME: Not sure this is the right error to throw here.
+          throw py::value_error("pattern application failed to converge");
+      },
+      "module"_a, "set"_a,
+      "Applies the given patterns to the given module greedily while folding "
+      "results.");
+}
diff --git a/mlir/lib/Bindings/Python/Rewrite.h b/mlir/lib/Bindings/Python/Rewrite.h
new file mode 100644
index 0000000000000..997b80adda303
--- /dev/null
+++ b/mlir/lib/Bindings/Python/Rewrite.h
@@ -0,0 +1,22 @@
+//===- Rewrite.h - Rewrite Submodules of pybind module --------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_BINDINGS_PYTHON_REWRITE_H
+#define MLIR_BINDINGS_PYTHON_REWRITE_H
+
+#include "PybindUtils.h"
+
+namespace mlir::python {
+
+/// Registers the rewrite-pattern bindings on `m`, the `rewrite` submodule of
+/// the native `_mlir` extension.
+void populateRewriteSubmodule(pybind11::module &m);
+
+} // namespace mlir::python
+
+#endif // MLIR_BINDINGS_PYTHON_REWRITE_H
diff --git a/mlir/lib/CAPI/Transforms/CMakeLists.txt b/mlir/lib/CAPI/Transforms/CMakeLists.txt
index 2638025a8c359..6c67aa09fdf40 100644
--- a/mlir/lib/CAPI/Transforms/CMakeLists.txt
+++ b/mlir/lib/CAPI/Transforms/CMakeLists.txt
@@ -1,6 +1,9 @@
add_mlir_upstream_c_api_library(MLIRCAPITransforms
Passes.cpp
+ Rewrite.cpp
LINK_LIBS PUBLIC
+ MLIRIR
MLIRTransforms
+ MLIRTransformUtils
)
diff --git a/mlir/lib/CAPI/Transforms/Rewrite.cpp b/mlir/lib/CAPI/Transforms/Rewrite.cpp
new file mode 100644
index 0000000000000..0de1958398f63
--- /dev/null
+++ b/mlir/lib/CAPI/Transforms/Rewrite.cpp
@@ -0,0 +1,83 @@
+//===- Rewrite.cpp - C API for Rewrite Patterns ---------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir-c/Rewrite.h"
+#include "mlir-c/Transforms.h"
+#include "mlir/CAPI/IR.h"
+#include "mlir/CAPI/Support.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Rewrite/FrozenRewritePatternSet.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+using namespace mlir;
+
+// File-local conversions between the C handles and the C++ objects. They are
+// `static` so they cannot clash (ODR) with same-named helpers elsewhere.
+
+/// Returns the RewritePatternSet referenced by `set`, which must be non-null.
+static inline mlir::RewritePatternSet &unwrap(MlirRewritePatternSet set) {
+  assert(set.ptr && "unexpected null rewrite pattern set");
+  return *static_cast<mlir::RewritePatternSet *>(set.ptr);
+}
+
+/// Wraps `set` in a C handle; wrapping does not transfer ownership.
+static inline MlirRewritePatternSet wrap(mlir::RewritePatternSet *set) {
+  return {set};
+}
+
+/// Returns the FrozenRewritePatternSet referenced by `set`, which must be
+/// non-null.
+static inline mlir::FrozenRewritePatternSet *
+unwrap(MlirFrozenRewritePatternSet set) {
+  assert(set.ptr && "unexpected null frozen rewrite pattern set");
+  return static_cast<mlir::FrozenRewritePatternSet *>(set.ptr);
+}
+
+/// Wraps `set` in a C handle; wrapping does not transfer ownership.
+static inline MlirFrozenRewritePatternSet
+wrap(mlir::FrozenRewritePatternSet *set) {
+  return {set};
+}
+
+MlirFrozenRewritePatternSet mlirFreezeRewritePattern(MlirRewritePatternSet op) {
+  // The incoming set is heap-allocated by the C API (see
+  // mlirRewritePatternSetFromPDLPatternModule) and the API offers no other way
+  // to free it. Consume it: move its patterns into the frozen set, then
+  // release the now-empty container instead of leaking it.
+  mlir::RewritePatternSet &patterns = unwrap(op);
+  auto *frozen = new mlir::FrozenRewritePatternSet(std::move(patterns));
+  delete &patterns;
+  return wrap(frozen);
+}
+
+void mlirFrozenRewritePatternSetDestroy(MlirFrozenRewritePatternSet op) {
+  // `op` is passed by value, so the caller's handle cannot be nulled here;
+  // the caller must not use it after this call.
+  delete unwrap(op);
+}
+
+MlirLogicalResult
+mlirApplyPatternsAndFoldGreedily(MlirModule op,
+                                 MlirFrozenRewritePatternSet patterns,
+                                 MlirGreedyRewriteDriverConfig) {
+  // The driver config handle is not plumbed through yet, so the driver runs
+  // with its default configuration.
+  mlir::ModuleOp module = unwrap(op);
+  const mlir::FrozenRewritePatternSet &frozen = *unwrap(patterns);
+  mlir::LogicalResult result =
+      mlir::applyPatternsAndFoldGreedily(module, frozen);
+  return wrap(result);
+}
+
+#if MLIR_ENABLE_PDL_IN_PATTERNMATCH
+/// Returns the PDLPatternModule referenced by `module`, which must be
+/// non-null. `static` keeps this file-local helper out of the global symbol
+/// namespace.
+static inline mlir::PDLPatternModule *unwrap(MlirPDLPatternModule module) {
+  assert(module.ptr && "unexpected null PDL pattern module");
+  return static_cast<mlir::PDLPatternModule *>(module.ptr);
+}
+
+/// Wraps `module` in a C handle; wrapping does not transfer ownership.
+static inline MlirPDLPatternModule wrap(mlir::PDLPatternModule *module) {
+  return {module};
+}
+
+MlirPDLPatternModule mlirPDLPatternModuleFromModule(MlirModule op) {
+  // The new PDLPatternModule adopts `op` through an OwningOpRef, i.e. it takes
+  // ownership of the module.
+  mlir::OwningOpRef<mlir::ModuleOp> owned(unwrap(op));
+  auto *pdlModule = new mlir::PDLPatternModule(std::move(owned));
+  return wrap(pdlModule);
+}
+
+void mlirPDLPatternModuleDestroy(MlirPDLPatternModule op) {
+  // `op` is passed by value, so the caller's handle cannot be nulled here;
+  // the caller must not use it after this call.
+  delete unwrap(op);
+}
+
+MlirRewritePatternSet
+mlirRewritePatternSetFromPDLPatternModule(MlirPDLPatternModule op) {
+  // Move the patterns out of `op` into a fresh heap-allocated set. `op` itself
+  // is left in a moved-from state and must still be destroyed by its owner.
+  auto *set = new mlir::RewritePatternSet(std::move(*unwrap(op)));
+  return wrap(set);
+}
+#endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH
diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt
index d8f2d1989fdea..d03036e17749d 100644
--- a/mlir/python/CMakeLists.txt
+++ b/mlir/python/CMakeLists.txt
@@ -21,6 +21,7 @@ declare_mlir_python_sources(MLIRPythonSources.Core.Python
_mlir_libs/__init__.py
ir.py
passmanager.py
+ rewrite.py
dialects/_ods_common.py
# The main _mlir module has submodules: include stubs from each.
@@ -448,6 +449,7 @@ declare_mlir_python_extension(MLIRPythonExtension.Core
IRModule.cpp
IRTypes.cpp
Pass.cpp
+ Rewrite.cpp
# Headers must be included explicitly so they are installed.
Globals.h
diff --git a/mlir/python/mlir/rewrite.py b/mlir/python/mlir/rewrite.py
new file mode 100644
index 0000000000000..5bc1bba7ae9a7
--- /dev/null
+++ b/mlir/python/mlir/rewrite.py
@@ -0,0 +1,5 @@
+# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+from ._mlir_libs._mlir.rewrite import *
diff --git a/mlir/test/python/integration/dialects/pdl.py b/mlir/test/python/integration/dialects/pdl.py
new file mode 100644
index 0000000000000..f2eb93fe953ef
--- /dev/null
+++ b/mlir/test/python/integration/dialects/pdl.py
@@ -0,0 +1,67 @@
+# RUN: %PYTHON %s 2>&1 | FileCheck %s
+
+from mlir.dialects import arith, func, pdl
+from mlir.dialects.builtin import module
+from mlir.ir import *
+from mlir.rewrite import *
+
+
+def construct_and_print_in_module(f):
+    """Test decorator: runs `f` on a fresh module and prints the outcome.
+
+    `f` is called with an empty module whose body is the active insertion
+    point; if it returns something other than None, that value is printed.
+    """
+    print("\nTEST:", f.__name__)
+    with Context(), Location.unknown():
+        mod = Module.create()
+        with InsertionPoint(mod.body):
+            result = f(mod)
+        if result is not None:
+            print(result)
+    return f
+
+
+# CHECK-LABEL: TEST: test_add_to_mul
+# CHECK: arith.muli
+@construct_and_print_in_module
+def test_add_to_mul(module_):
+    """Applies a PDL pattern rewriting index `arith.addi` into `arith.muli`."""
+    index_type = IndexType.get()
+
+    # Create a test case: a function adding two index values.
+    @module(sym_name="ir")
+    def ir():
+        @func.func(index_type, index_type)
+        def add_func(a, b):
+            return arith.addi(a, b)
+
+    # Create a rewrite from add to mul. This will match when
+    # - the operation name is arith.addi,
+    # - the operands are of index type, and
+    # - there are two operands.
+    with Location.unknown():
+        m = Module.create()
+        with InsertionPoint(m.body):
+            # Change all arith.addi with index types to arith.muli.
+            pattern = pdl.PatternOp(1, "addi_to_mul")
+            with InsertionPoint(pattern.body):
+                # Match arith.addi with index types. A separate name keeps the
+                # PDL type handle from shadowing the builtin index type above.
+                pdl_index_type = pdl.TypeOp(IndexType.get())
+                operand0 = pdl.OperandOp(pdl_index_type)
+                operand1 = pdl.OperandOp(pdl_index_type)
+                op0 = pdl.OperationOp(
+                    name="arith.addi",
+                    args=[operand0, operand1],
+                    types=[pdl_index_type],
+                )
+
+                # Replace the matched op with arith.muli.
+                rewrite = pdl.RewriteOp(op0)
+                with InsertionPoint(rewrite.add_body()):
+                    newOp = pdl.OperationOp(
+                        name="arith.muli",
+                        args=[operand0, operand1],
+                        types=[pdl_index_type],
+                    )
+                    pdl.ReplaceOp(op0, with_op=newOp)
+
+    # Create a PDL module from `m` and freeze it. At this point ownership of
+    # the module is transferred to the PDL module. This transfer is not yet
+    # reflected on the Python side and has sharp edges, so it is best to
+    # construct the module and the PDL module in the same scope.
+    # FIXME: This should be made more robust.
+    frozen = PDLModule(m).freeze()
+    # The frozen pattern set could be applied multiple times.
+    apply_patterns_and_fold_greedily(module_, frozen)
+    return module_
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index 5d2248a8fe360..2d15955548961 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -420,6 +420,7 @@ mlir_c_api_cc_library(
"include/mlir-c/Interfaces.h",
"include/mlir-c/Pass.h",
"include/mlir-c/RegisterEverything.h",
+ "include/mlir-c/Rewrite.h",
"include/mlir-c/Support.h",
"include/mlir/CAPI/AffineExpr.h",
"include/mlir/CAPI/AffineMap.h",
@@ -866,7 +867,10 @@ mlir_c_api_cc_library(
mlir_c_api_cc_library(
name = "CAPITransforms",
- srcs = ["lib/CAPI/Transforms/Passes.cpp"],
+ srcs = [
+ "lib/CAPI/Transforms/Passes.cpp",
+ "lib/CAPI/Transforms/Rewrite.cpp",
+ ],
hdrs = ["include/mlir-c/Transforms.h"],
capi_deps = [
":CAPIIR",
@@ -876,7 +880,10 @@ mlir_c_api_cc_library(
],
includes = ["include"],
deps = [
+ ":IR",
":Pass",
+ ":Rewrite",
+ ":TransformUtils",
":Transforms",
],
)
@@ -939,6 +946,7 @@ cc_library(
textual_hdrs = glob(MLIR_BINDINGS_PYTHON_HEADERS),
deps = [
":CAPIIRHeaders",
+ ":CAPITransformsHeaders",
"@local_config_python//:python_headers",
"@pybind11",
],
@@ -957,6 +965,7 @@ cc_library(
textual_hdrs = glob(MLIR_BINDINGS_PYTHON_HEADERS),
deps = [
":CAPIIR",
+ ":CAPITransforms",
"@local_config_python//:python_headers",
"@pybind11",
],
@@ -981,6 +990,7 @@ MLIR_PYTHON_BINDINGS_SOURCES = [
"lib/Bindings/Python/IRModule.cpp",
"lib/Bindings/Python/IRTypes.cpp",
"lib/Bindings/Python/Pass.cpp",
+ "lib/Bindings/Python/Rewrite.cpp",
]
cc_library(
diff --git a/utils/bazel/llvm-project-overlay/mlir/python/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/python/BUILD.bazel
index add150de69faf..254cab0db4a5d 100644
--- a/utils/bazel/llvm-project-overlay/mlir/python/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/python/BUILD.bazel
@@ -82,6 +82,13 @@ filegroup(
],
)
+filegroup(
+ name = "RewritePyFiles",
+ srcs = [
+ "mlir/rewrite.py",
+ ],
+)
+
filegroup(
name = "RuntimePyFiles",
srcs = glob([
>From a689d266d1567722b116df54c9ad7a50150548a5 Mon Sep 17 00:00:00 2001
From: Jacques Pienaar <jpienaar at google.com>
Date: Sun, 9 Jun 2024 22:21:36 +0000
Subject: [PATCH 2/2] Fix formatting
Signed-off-by: Jacques Pienaar <jpienaar at google.com>
---
.../mlir/Bindings/Python/PybindAdaptors.h | 3 +-
mlir/test/python/integration/dialects/pdl.py | 96 +++++++++----------
2 files changed, 50 insertions(+), 49 deletions(-)
diff --git a/mlir/include/mlir/Bindings/Python/PybindAdaptors.h b/mlir/include/mlir/Bindings/Python/PybindAdaptors.h
index 39ee1551ccb2e..441b1b55f5d1d 100644
--- a/mlir/include/mlir/Bindings/Python/PybindAdaptors.h
+++ b/mlir/include/mlir/Bindings/Python/PybindAdaptors.h
@@ -199,7 +199,8 @@ struct type_caster<MlirModule> {
};
/// Casts object <-> MlirFrozenRewritePatternSet.
-template <> struct type_caster<MlirFrozenRewritePatternSet> {
+template <>
+struct type_caster<MlirFrozenRewritePatternSet> {
PYBIND11_TYPE_CASTER(MlirFrozenRewritePatternSet,
_("MlirFrozenRewritePatternSet"));
bool load(handle src, bool) {
diff --git a/mlir/test/python/integration/dialects/pdl.py b/mlir/test/python/integration/dialects/pdl.py
index f2eb93fe953ef..04441af8ccb17 100644
--- a/mlir/test/python/integration/dialects/pdl.py
+++ b/mlir/test/python/integration/dialects/pdl.py
@@ -7,61 +7,61 @@
def construct_and_print_in_module(f):
- print("\nTEST:", f.__name__)
- with Context(), Location.unknown():
- module = Module.create()
- with InsertionPoint(module.body):
- module = f(module)
- if module is not None:
- print(module)
- return f
+ # Print the header line that the CHECK-LABEL directives below match on.
+ print("\nTEST:", f.__name__)
+ with Context(), Location.unknown():
+ module = Module.create()
+ with InsertionPoint(module.body):
+ module = f(module)
+ if module is not None:
+ print(module)
+ return f
# CHECK-LABEL: TEST: test_add_to_mul
# CHECK: arith.muli
@construct_and_print_in_module
def test_add_to_mul(module_):
- index_type = IndexType.get()
+ # Builtin index type used for the signature of the function under test.
+ index_type = IndexType.get()
- # Create a test case.
- @module(sym_name="ir")
- def ir():
- @func.func(index_type, index_type)
- def add_func(a, b):
- return arith.addi(a, b)
+ # Create a test case.
+ @module(sym_name="ir")
+ def ir():
+ @func.func(index_type, index_type)
+ def add_func(a, b):
+ return arith.addi(a, b)
- # Create a rewrite from add to mul. This will match
- # - operation name is arith.addi
- # - operands are index types.
- # - there are two operands.
- with Location.unknown():
- m = Module.create()
- with InsertionPoint(m.body):
- # Change all arith.addi with index types to arith.muli.
- pattern = pdl.PatternOp(1, "addi_to_mul")
- with InsertionPoint(pattern.body):
- # Match arith.addi with index types.
- index_type = pdl.TypeOp(IndexType.get())
- operand0 = pdl.OperandOp(index_type)
- operand1 = pdl.OperandOp(index_type)
- op0 = pdl.OperationOp(
- name="arith.addi", args=[operand0, operand1], types=[index_type]
- )
+ # Create a rewrite from add to mul. This will match
+ # - operation name is arith.addi
+ # - operands are index types.
+ # - there are two operands.
+ with Location.unknown():
+ m = Module.create()
+ with InsertionPoint(m.body):
+ # Change all arith.addi with index types to arith.muli.
+ pattern = pdl.PatternOp(1, "addi_to_mul")
+ with InsertionPoint(pattern.body):
+ # Match arith.addi with index types.
+ # NOTE(review): this rebinds `index_type` (the builtin IndexType above) to a PDL type handle.
+ index_type = pdl.TypeOp(IndexType.get())
+ operand0 = pdl.OperandOp(index_type)
+ operand1 = pdl.OperandOp(index_type)
+ op0 = pdl.OperationOp(
+ name="arith.addi", args=[operand0, operand1], types=[index_type]
+ )
- # Replace the matched op with arith.muli.
- rewrite = pdl.RewriteOp(op0)
- with InsertionPoint(rewrite.add_body()):
- newOp = pdl.OperationOp(
- name="arith.muli", args=[operand0, operand1], types=[index_type]
- )
- pdl.ReplaceOp(op0, with_op=newOp)
+ # Replace the matched op with arith.muli.
+ rewrite = pdl.RewriteOp(op0)
+ with InsertionPoint(rewrite.add_body()):
+ newOp = pdl.OperationOp(
+ name="arith.muli", args=[operand0, operand1], types=[index_type]
+ )
+ pdl.ReplaceOp(op0, with_op=newOp)
- # Create a PDL module from module and freeze it. At this point the ownership
- # of the module is transferred to the PDL module. This ownership transfer is
- # not yet captured Python side/has sharp edges. So best to construct the
- # module and PDL module in same scope.
- # FIXME: This should be made more robust.
- frozen = PDLModule(m).freeze()
- # Could apply frozen pattern set multiple times.
- apply_patterns_and_fold_greedily(module_, frozen)
- return module_
+ # Create a PDL module from module and freeze it. At this point the ownership
+ # of the module is transferred to the PDL module. This ownership transfer is
+ # not yet captured Python side/has sharp edges. So best to construct the
+ # module and PDL module in same scope.
+ # FIXME: This should be made more robust.
+ frozen = PDLModule(m).freeze()
+ # Could apply frozen pattern set multiple times.
+ apply_patterns_and_fold_greedily(module_, frozen)
+ return module_
More information about the Mlir-commits
mailing list