[Mlir-commits] [llvm] [mlir] [mlir] Add PDL C & Python usage (PR #94714)

Jacques Pienaar llvmlistbot at llvm.org
Sun Jun 9 15:21:46 PDT 2024


https://github.com/jpienaar updated https://github.com/llvm/llvm-project/pull/94714

>From f70e85e83e36e1c786d607fa89b126d248bc6f3f Mon Sep 17 00:00:00 2001
From: Jacques Pienaar <jpienaar at google.com>
Date: Fri, 7 Jun 2024 01:26:36 +0000
Subject: [PATCH 1/2] [mlir] Add PDL C & Python usage

Following a rather direct approach to expose PDL usage from C and then Python. This doesn't yet plumb through support for custom matchers via this interface, so it is constrained to the basics initially.

Signed-off-by: Jacques Pienaar <jpienaar at google.com>
---
 mlir/include/mlir-c/Bindings/Python/Interop.h |  21 ++++
 mlir/include/mlir-c/Rewrite.h                 |  60 ++++++++++
 .../mlir/Bindings/Python/PybindAdaptors.h     |  20 ++++
 mlir/lib/Bindings/Python/IRModule.h           |   1 +
 mlir/lib/Bindings/Python/MainModule.cpp       |   4 +
 mlir/lib/Bindings/Python/Rewrite.cpp          | 110 ++++++++++++++++++
 mlir/lib/Bindings/Python/Rewrite.h            |  22 ++++
 mlir/lib/CAPI/Transforms/CMakeLists.txt       |   3 +
 mlir/lib/CAPI/Transforms/Rewrite.cpp          |  83 +++++++++++++
 mlir/python/CMakeLists.txt                    |   2 +
 mlir/python/mlir/rewrite.py                   |   5 +
 mlir/test/python/integration/dialects/pdl.py  |  67 +++++++++++
 .../llvm-project-overlay/mlir/BUILD.bazel     |  12 +-
 .../mlir/python/BUILD.bazel                   |   7 ++
 14 files changed, 416 insertions(+), 1 deletion(-)
 create mode 100644 mlir/include/mlir-c/Rewrite.h
 create mode 100644 mlir/lib/Bindings/Python/Rewrite.cpp
 create mode 100644 mlir/lib/Bindings/Python/Rewrite.h
 create mode 100644 mlir/lib/CAPI/Transforms/Rewrite.cpp
 create mode 100644 mlir/python/mlir/rewrite.py
 create mode 100644 mlir/test/python/integration/dialects/pdl.py

diff --git a/mlir/include/mlir-c/Bindings/Python/Interop.h b/mlir/include/mlir-c/Bindings/Python/Interop.h
index 0a36e97c2ae68..a33190c380d37 100644
--- a/mlir/include/mlir-c/Bindings/Python/Interop.h
+++ b/mlir/include/mlir-c/Bindings/Python/Interop.h
@@ -39,6 +39,7 @@
 #include "mlir-c/IR.h"
 #include "mlir-c/IntegerSet.h"
 #include "mlir-c/Pass.h"
+#include "mlir-c/Rewrite.h"
 
 // The 'mlir' Python package is relocatable and supports co-existing in multiple
 // projects. Each project must define its outer package prefix with this define
@@ -284,6 +285,26 @@ static inline MlirModule mlirPythonCapsuleToModule(PyObject *capsule) {
   return module;
 }
 
+/** Creates a capsule object encapsulating the raw C-API
+ * MlirFrozenRewritePatternSet.
+ * The returned capsule does not extend or affect ownership of any Python
+ * objects that reference the frozen pattern set in any way. */
+static inline PyObject *
+mlirPythonFrozenRewritePatternSetToCapsule(MlirFrozenRewritePatternSet pm) {
+  return PyCapsule_New(MLIR_PYTHON_GET_WRAPPED_POINTER(pm),
+                       MLIR_PYTHON_CAPSULE_PASS_MANAGER, NULL);
+}
+
+/** Extracts an MlirFrozenRewritePatternSet from a capsule as produced from
+ * mlirPythonFrozenRewritePatternSetToCapsule. If the capsule is not of the
+ * right type, then a null pattern set is returned. */
+static inline MlirFrozenRewritePatternSet
+mlirPythonCapsuleToFrozenRewritePatternSet(PyObject *capsule) {
+  void *ptr = PyCapsule_GetPointer(capsule, MLIR_PYTHON_CAPSULE_PASS_MANAGER);
+  MlirFrozenRewritePatternSet pm = {ptr};
+  return pm;
+}
+
 /** Creates a capsule object encapsulating the raw C-API MlirPassManager.
  * The returned capsule does not extend or affect ownership of any Python
  * objects that reference the module in any way. */
diff --git a/mlir/include/mlir-c/Rewrite.h b/mlir/include/mlir-c/Rewrite.h
new file mode 100644
index 0000000000000..45218a1cd4ebd
--- /dev/null
+++ b/mlir/include/mlir-c/Rewrite.h
@@ -0,0 +1,60 @@
+//===-- mlir-c/Rewrite.h - Helpers for C API to Rewrites ----------*- C -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM
+// Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This header declares the registration and creation method for
+// rewrite patterns.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_C_REWRITE_H
+#define MLIR_C_REWRITE_H
+
+#include "mlir-c/IR.h"
+#include "mlir-c/Support.h"
+#include "mlir/Config/mlir-config.h"
+
+//===----------------------------------------------------------------------===//
+/// Opaque type declarations (see mlir-c/IR.h for more details).
+//===----------------------------------------------------------------------===//
+
+#define DEFINE_C_API_STRUCT(name, storage)                                     \
+  struct name {                                                                \
+    storage *ptr;                                                              \
+  };                                                                           \
+  typedef struct name name
+
+DEFINE_C_API_STRUCT(MlirFrozenRewritePatternSet, void);
+DEFINE_C_API_STRUCT(MlirGreedyRewriteDriverConfig, void);
+DEFINE_C_API_STRUCT(MlirRewritePatternSet, void);
+
+MLIR_CAPI_EXPORTED MlirFrozenRewritePatternSet
+mlirFreezeRewritePattern(MlirRewritePatternSet op);
+
+MLIR_CAPI_EXPORTED void
+mlirFrozenRewritePatternSetDestroy(MlirFrozenRewritePatternSet op);
+
+MLIR_CAPI_EXPORTED MlirLogicalResult mlirApplyPatternsAndFoldGreedily(
+    MlirModule op, MlirFrozenRewritePatternSet patterns,
+    MlirGreedyRewriteDriverConfig);
+
+#if MLIR_ENABLE_PDL_IN_PATTERNMATCH
+DEFINE_C_API_STRUCT(MlirPDLPatternModule, void);
+
+MLIR_CAPI_EXPORTED MlirPDLPatternModule
+mlirPDLPatternModuleFromModule(MlirModule op);
+
+MLIR_CAPI_EXPORTED void mlirPDLPatternModuleDestroy(MlirPDLPatternModule op);
+
+MLIR_CAPI_EXPORTED MlirRewritePatternSet
+mlirRewritePatternSetFromPDLPatternModule(MlirPDLPatternModule op);
+#endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH
+
+#undef DEFINE_C_API_STRUCT
+
+#endif // MLIR_C_REWRITE_H
diff --git a/mlir/include/mlir/Bindings/Python/PybindAdaptors.h b/mlir/include/mlir/Bindings/Python/PybindAdaptors.h
index d8f22c7aa1709..39ee1551ccb2e 100644
--- a/mlir/include/mlir/Bindings/Python/PybindAdaptors.h
+++ b/mlir/include/mlir/Bindings/Python/PybindAdaptors.h
@@ -198,6 +198,26 @@ struct type_caster<MlirModule> {
   };
 };
 
+/// Casts object <-> MlirFrozenRewritePatternSet.
+template <> struct type_caster<MlirFrozenRewritePatternSet> {
+  PYBIND11_TYPE_CASTER(MlirFrozenRewritePatternSet,
+                       _("MlirFrozenRewritePatternSet"));
+  bool load(handle src, bool) {
+    py::object capsule = mlirApiObjectToCapsule(src);
+    value = mlirPythonCapsuleToFrozenRewritePatternSet(capsule.ptr());
+    return true; // !mlirModuleIsNull(value);
+  }
+  static handle cast(MlirFrozenRewritePatternSet v, return_value_policy,
+                     handle) {
+    py::object capsule = py::reinterpret_steal<py::object>(
+        mlirPythonFrozenRewritePatternSetToCapsule(v));
+    return py::module::import(MAKE_MLIR_PYTHON_QUALNAME("rewrite"))
+        .attr("FrozenRewritePatternSet")
+        .attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule)
+        .release();
+  };
+};
+
 /// Casts object <-> MlirOperation.
 template <>
 struct type_caster<MlirOperation> {
diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h
index b038a0c54d29b..6f0dc2690a5e6 100644
--- a/mlir/lib/Bindings/Python/IRModule.h
+++ b/mlir/lib/Bindings/Python/IRModule.h
@@ -22,6 +22,7 @@
 #include "mlir-c/Diagnostics.h"
 #include "mlir-c/IR.h"
 #include "mlir-c/IntegerSet.h"
+#include "mlir-c/Transforms.h"
 #include "mlir/Bindings/Python/PybindAdaptors.h"
 #include "llvm/ADT/DenseMap.h"
 
diff --git a/mlir/lib/Bindings/Python/MainModule.cpp b/mlir/lib/Bindings/Python/MainModule.cpp
index 17272472ccca4..8da1ab16a4514 100644
--- a/mlir/lib/Bindings/Python/MainModule.cpp
+++ b/mlir/lib/Bindings/Python/MainModule.cpp
@@ -11,6 +11,7 @@
 #include "Globals.h"
 #include "IRModule.h"
 #include "Pass.h"
+#include "Rewrite.h"
 
 namespace py = pybind11;
 using namespace mlir;
@@ -116,6 +117,9 @@ PYBIND11_MODULE(_mlir, m) {
   populateIRInterfaces(irModule);
   populateIRTypes(irModule);
 
+  auto rewriteModule = m.def_submodule("rewrite", "MLIR Rewrite Bindings");
+  populateRewriteSubmodule(rewriteModule);
+
   // Define and populate PassManager submodule.
   auto passModule =
       m.def_submodule("passmanager", "MLIR Pass Management Bindings");
diff --git a/mlir/lib/Bindings/Python/Rewrite.cpp b/mlir/lib/Bindings/Python/Rewrite.cpp
new file mode 100644
index 0000000000000..1d8128be9f082
--- /dev/null
+++ b/mlir/lib/Bindings/Python/Rewrite.cpp
@@ -0,0 +1,110 @@
+//===- Rewrite.cpp - Rewrite ----------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "Rewrite.h"
+
+#include "IRModule.h"
+#include "mlir-c/Bindings/Python/Interop.h"
+#include "mlir-c/Rewrite.h"
+#include "mlir/Config/mlir-config.h"
+
+namespace py = pybind11;
+using namespace mlir;
+using namespace py::literals;
+using namespace mlir::python;
+
+namespace {
+
+#if MLIR_ENABLE_PDL_IN_PATTERNMATCH
+/// Owning Wrapper around a PDLPatternModule.
+class PyPDLPatternModule {
+public:
+  PyPDLPatternModule(MlirPDLPatternModule module) : module(module) {}
+  PyPDLPatternModule(PyPDLPatternModule &&other) noexcept
+      : module(other.module) {
+    other.module.ptr = nullptr;
+  }
+  ~PyPDLPatternModule() {
+    if (module.ptr != nullptr)
+      mlirPDLPatternModuleDestroy(module);
+  }
+  MlirPDLPatternModule get() { return module; }
+
+private:
+  MlirPDLPatternModule module;
+};
+#endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH
+
+/// Owning Wrapper around a FrozenRewritePatternSet.
+class PyFrozenRewritePatternSet {
+public:
+  PyFrozenRewritePatternSet(MlirFrozenRewritePatternSet set) : set(set) {}
+  PyFrozenRewritePatternSet(PyFrozenRewritePatternSet &&other) noexcept
+      : set(other.set) {
+    other.set.ptr = nullptr;
+  }
+  ~PyFrozenRewritePatternSet() {
+    if (set.ptr != nullptr)
+      mlirFrozenRewritePatternSetDestroy(set);
+  }
+  MlirFrozenRewritePatternSet get() { return set; }
+
+  pybind11::object getCapsule() {
+    return py::reinterpret_steal<py::object>(
+        mlirPythonFrozenRewritePatternSetToCapsule(get()));
+  }
+
+  static pybind11::object createFromCapsule(pybind11::object capsule) {
+    MlirFrozenRewritePatternSet rawPm =
+        mlirPythonCapsuleToFrozenRewritePatternSet(capsule.ptr());
+    if (rawPm.ptr == nullptr)
+      throw py::error_already_set();
+    return py::cast(PyFrozenRewritePatternSet(rawPm),
+                    py::return_value_policy::move);
+  }
+
+private:
+  MlirFrozenRewritePatternSet set;
+};
+
+} // namespace
+
+/// Create the `mlir.rewrite` here.
+void mlir::python::populateRewriteSubmodule(py::module &m) {
+  //----------------------------------------------------------------------------
+  // Mapping of the rewrite pattern set classes and functions
+  //----------------------------------------------------------------------------
+#if MLIR_ENABLE_PDL_IN_PATTERNMATCH
+  py::class_<PyPDLPatternModule>(m, "PDLModule", py::module_local())
+      .def(py::init<>([](MlirModule module) {
+             return mlirPDLPatternModuleFromModule(module);
+           }),
+           "module"_a, "Create a PDL module from the given module.")
+      .def("freeze", [](PyPDLPatternModule &self) {
+        return new PyFrozenRewritePatternSet(mlirFreezeRewritePattern(
+            mlirRewritePatternSetFromPDLPatternModule(self.get())));
+      });
+#endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH
+  py::class_<PyFrozenRewritePatternSet>(m, "FrozenRewritePatternSet",
+                                        py::module_local())
+      .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
+                             &PyFrozenRewritePatternSet::getCapsule)
+      .def(MLIR_PYTHON_CAPI_FACTORY_ATTR,
+           &PyFrozenRewritePatternSet::createFromCapsule);
+  m.def(
+      "apply_patterns_and_fold_greedily",
+      [](MlirModule module, MlirFrozenRewritePatternSet set) {
+        auto status = mlirApplyPatternsAndFoldGreedily(module, set, {});
+        if (mlirLogicalResultIsFailure(status))
+          // FIXME: Not sure this is the right error to throw here.
+          throw py::value_error("pattern application failed to converge");
+      },
+      "module"_a, "set"_a,
+      "Applys the given patterns to the given module greedily while folding "
+      "results.");
+}
diff --git a/mlir/lib/Bindings/Python/Rewrite.h b/mlir/lib/Bindings/Python/Rewrite.h
new file mode 100644
index 0000000000000..997b80adda303
--- /dev/null
+++ b/mlir/lib/Bindings/Python/Rewrite.h
@@ -0,0 +1,22 @@
+//===- Rewrite.h - Rewrite Submodules of pybind module --------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_BINDINGS_PYTHON_REWRITE_H
+#define MLIR_BINDINGS_PYTHON_REWRITE_H
+
+#include "PybindUtils.h"
+
+namespace mlir {
+namespace python {
+
+void populateRewriteSubmodule(pybind11::module &m);
+
+} // namespace python
+} // namespace mlir
+
+#endif // MLIR_BINDINGS_PYTHON_REWRITE_H
diff --git a/mlir/lib/CAPI/Transforms/CMakeLists.txt b/mlir/lib/CAPI/Transforms/CMakeLists.txt
index 2638025a8c359..6c67aa09fdf40 100644
--- a/mlir/lib/CAPI/Transforms/CMakeLists.txt
+++ b/mlir/lib/CAPI/Transforms/CMakeLists.txt
@@ -1,6 +1,9 @@
 add_mlir_upstream_c_api_library(MLIRCAPITransforms
   Passes.cpp
+  Rewrite.cpp
 
   LINK_LIBS PUBLIC
+  MLIRIR
   MLIRTransforms
+  MLIRTransformUtils
 )
diff --git a/mlir/lib/CAPI/Transforms/Rewrite.cpp b/mlir/lib/CAPI/Transforms/Rewrite.cpp
new file mode 100644
index 0000000000000..0de1958398f63
--- /dev/null
+++ b/mlir/lib/CAPI/Transforms/Rewrite.cpp
@@ -0,0 +1,83 @@
+//===- Rewrite.cpp - C API for Rewrite Patterns ---------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir-c/Rewrite.h"
+#include "mlir-c/Transforms.h"
+#include "mlir/CAPI/IR.h"
+#include "mlir/CAPI/Support.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Rewrite/FrozenRewritePatternSet.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+using namespace mlir;
+
+inline mlir::RewritePatternSet &unwrap(MlirRewritePatternSet module) {
+  assert(module.ptr && "unexpected null module");
+  return *(static_cast<mlir::RewritePatternSet *>(module.ptr));
+}
+
+inline MlirRewritePatternSet wrap(mlir::RewritePatternSet *module) {
+  return {module};
+}
+
+inline mlir::FrozenRewritePatternSet *
+unwrap(MlirFrozenRewritePatternSet module) {
+  assert(module.ptr && "unexpected null module");
+  return static_cast<mlir::FrozenRewritePatternSet *>(module.ptr);
+}
+
+inline MlirFrozenRewritePatternSet wrap(mlir::FrozenRewritePatternSet *module) {
+  return {module};
+}
+
+MlirFrozenRewritePatternSet mlirFreezeRewritePattern(MlirRewritePatternSet op) {
+  auto *m = new mlir::FrozenRewritePatternSet(std::move(unwrap(op)));
+  op.ptr = nullptr;
+  return wrap(m);
+}
+
+void mlirFrozenRewritePatternSetDestroy(MlirFrozenRewritePatternSet op) {
+  delete unwrap(op);
+  op.ptr = nullptr;
+}
+
+MlirLogicalResult
+mlirApplyPatternsAndFoldGreedily(MlirModule op,
+                                 MlirFrozenRewritePatternSet patterns,
+                                 MlirGreedyRewriteDriverConfig) {
+  return wrap(
+      mlir::applyPatternsAndFoldGreedily(unwrap(op), *unwrap(patterns)));
+}
+
+#if MLIR_ENABLE_PDL_IN_PATTERNMATCH
+inline mlir::PDLPatternModule *unwrap(MlirPDLPatternModule module) {
+  assert(module.ptr && "unexpected null module");
+  return static_cast<mlir::PDLPatternModule *>(module.ptr);
+}
+
+inline MlirPDLPatternModule wrap(mlir::PDLPatternModule *module) {
+  return {module};
+}
+
+MlirPDLPatternModule mlirPDLPatternModuleFromModule(MlirModule op) {
+  return wrap(new mlir::PDLPatternModule(
+      mlir::OwningOpRef<mlir::ModuleOp>(unwrap(op))));
+}
+
+void mlirPDLPatternModuleDestroy(MlirPDLPatternModule op) {
+  delete unwrap(op);
+  op.ptr = nullptr;
+}
+
+MlirRewritePatternSet
+mlirRewritePatternSetFromPDLPatternModule(MlirPDLPatternModule op) {
+  auto *m = new mlir::RewritePatternSet(std::move(*unwrap(op)));
+  op.ptr = nullptr;
+  return wrap(m);
+}
+#endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH
diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt
index d8f2d1989fdea..d03036e17749d 100644
--- a/mlir/python/CMakeLists.txt
+++ b/mlir/python/CMakeLists.txt
@@ -21,6 +21,7 @@ declare_mlir_python_sources(MLIRPythonSources.Core.Python
     _mlir_libs/__init__.py
     ir.py
     passmanager.py
+    rewrite.py
     dialects/_ods_common.py
 
     # The main _mlir module has submodules: include stubs from each.
@@ -448,6 +449,7 @@ declare_mlir_python_extension(MLIRPythonExtension.Core
     IRModule.cpp
     IRTypes.cpp
     Pass.cpp
+    Rewrite.cpp
 
     # Headers must be included explicitly so they are installed.
     Globals.h
diff --git a/mlir/python/mlir/rewrite.py b/mlir/python/mlir/rewrite.py
new file mode 100644
index 0000000000000..5bc1bba7ae9a7
--- /dev/null
+++ b/mlir/python/mlir/rewrite.py
@@ -0,0 +1,5 @@
+#  Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+#  See https://llvm.org/LICENSE.txt for license information.
+#  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+from ._mlir_libs._mlir.rewrite import *
diff --git a/mlir/test/python/integration/dialects/pdl.py b/mlir/test/python/integration/dialects/pdl.py
new file mode 100644
index 0000000000000..f2eb93fe953ef
--- /dev/null
+++ b/mlir/test/python/integration/dialects/pdl.py
@@ -0,0 +1,67 @@
+# RUN: %PYTHON %s 2>&1 | FileCheck %s
+
+from mlir.dialects import arith, func, pdl
+from mlir.dialects.builtin import module
+from mlir.ir import *
+from mlir.rewrite import *
+
+
+def construct_and_print_in_module(f):
+  print("\nTEST:", f.__name__)
+  with Context(), Location.unknown():
+    module = Module.create()
+    with InsertionPoint(module.body):
+      module = f(module)
+    if module is not None:
+      print(module)
+  return f
+
+
+# CHECK-LABEL: TEST: test_add_to_mul
+# CHECK: arith.muli
+@construct_and_print_in_module
+def test_add_to_mul(module_):
+  index_type = IndexType.get()
+
+  # Create a test case.
+  @module(sym_name="ir")
+  def ir():
+    @func.func(index_type, index_type)
+    def add_func(a, b):
+      return arith.addi(a, b)
+
+  # Create a rewrite from add to mul. This will match
+  # - operation name is arith.addi
+  # - operands are index types.
+  # - there are two operands.
+  with Location.unknown():
+    m = Module.create()
+    with InsertionPoint(m.body):
+      # Change all arith.addi with index types to arith.muli.
+      pattern = pdl.PatternOp(1, "addi_to_mul")
+      with InsertionPoint(pattern.body):
+        # Match arith.addi with index types.
+        index_type = pdl.TypeOp(IndexType.get())
+        operand0 = pdl.OperandOp(index_type)
+        operand1 = pdl.OperandOp(index_type)
+        op0 = pdl.OperationOp(
+            name="arith.addi", args=[operand0, operand1], types=[index_type]
+        )
+
+        # Replace the matched op with arith.muli.
+        rewrite = pdl.RewriteOp(op0)
+        with InsertionPoint(rewrite.add_body()):
+          newOp = pdl.OperationOp(
+              name="arith.muli", args=[operand0, operand1], types=[index_type]
+          )
+          pdl.ReplaceOp(op0, with_op=newOp)
+
+  # Create a PDL module from module and freeze it. At this point the ownership
+  # of the module is transferred to the PDL module. This ownership transfer is
+  # not yet captured Python side/has sharp edges. So best to construct the
+  # module and PDL module in same scope.
+  # FIXME: This should be made more robust.
+  frozen = PDLModule(m).freeze()
+  # Could apply frozen pattern set multiple times.
+  apply_patterns_and_fold_greedily(module_, frozen)
+  return module_
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index 5d2248a8fe360..2d15955548961 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -420,6 +420,7 @@ mlir_c_api_cc_library(
         "include/mlir-c/Interfaces.h",
         "include/mlir-c/Pass.h",
         "include/mlir-c/RegisterEverything.h",
+        "include/mlir-c/Rewrite.h",
         "include/mlir-c/Support.h",
         "include/mlir/CAPI/AffineExpr.h",
         "include/mlir/CAPI/AffineMap.h",
@@ -866,7 +867,10 @@ mlir_c_api_cc_library(
 
 mlir_c_api_cc_library(
     name = "CAPITransforms",
-    srcs = ["lib/CAPI/Transforms/Passes.cpp"],
+    srcs = [
+        "lib/CAPI/Transforms/Passes.cpp",
+        "lib/CAPI/Transforms/Rewrite.cpp",
+    ],
     hdrs = ["include/mlir-c/Transforms.h"],
     capi_deps = [
         ":CAPIIR",
@@ -876,7 +880,10 @@ mlir_c_api_cc_library(
     ],
     includes = ["include"],
     deps = [
+        ":IR",
         ":Pass",
+        ":Rewrite",
+        ":TransformUtils",
         ":Transforms",
     ],
 )
@@ -939,6 +946,7 @@ cc_library(
     textual_hdrs = glob(MLIR_BINDINGS_PYTHON_HEADERS),
     deps = [
         ":CAPIIRHeaders",
+        ":CAPITransformsHeaders",
         "@local_config_python//:python_headers",
         "@pybind11",
     ],
@@ -957,6 +965,7 @@ cc_library(
     textual_hdrs = glob(MLIR_BINDINGS_PYTHON_HEADERS),
     deps = [
         ":CAPIIR",
+        ":CAPITransforms",
         "@local_config_python//:python_headers",
         "@pybind11",
     ],
@@ -981,6 +990,7 @@ MLIR_PYTHON_BINDINGS_SOURCES = [
     "lib/Bindings/Python/IRModule.cpp",
     "lib/Bindings/Python/IRTypes.cpp",
     "lib/Bindings/Python/Pass.cpp",
+    "lib/Bindings/Python/Rewrite.cpp",
 ]
 
 cc_library(
diff --git a/utils/bazel/llvm-project-overlay/mlir/python/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/python/BUILD.bazel
index add150de69faf..254cab0db4a5d 100644
--- a/utils/bazel/llvm-project-overlay/mlir/python/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/python/BUILD.bazel
@@ -82,6 +82,13 @@ filegroup(
     ],
 )
 
+filegroup(
+    name = "RewritePyFiles",
+    srcs = [
+        "mlir/rewrite.py",
+    ],
+)
+
 filegroup(
     name = "RuntimePyFiles",
     srcs = glob([

>From a689d266d1567722b116df54c9ad7a50150548a5 Mon Sep 17 00:00:00 2001
From: Jacques Pienaar <jpienaar at google.com>
Date: Sun, 9 Jun 2024 22:21:36 +0000
Subject: [PATCH 2/2] Fix formatting

Signed-off-by: Jacques Pienaar <jpienaar at google.com>
---
 .../mlir/Bindings/Python/PybindAdaptors.h     |  3 +-
 mlir/test/python/integration/dialects/pdl.py  | 96 +++++++++----------
 2 files changed, 50 insertions(+), 49 deletions(-)

diff --git a/mlir/include/mlir/Bindings/Python/PybindAdaptors.h b/mlir/include/mlir/Bindings/Python/PybindAdaptors.h
index 39ee1551ccb2e..441b1b55f5d1d 100644
--- a/mlir/include/mlir/Bindings/Python/PybindAdaptors.h
+++ b/mlir/include/mlir/Bindings/Python/PybindAdaptors.h
@@ -199,7 +199,8 @@ struct type_caster<MlirModule> {
 };
 
 /// Casts object <-> MlirFrozenRewritePatternSet.
-template <> struct type_caster<MlirFrozenRewritePatternSet> {
+template <>
+struct type_caster<MlirFrozenRewritePatternSet> {
   PYBIND11_TYPE_CASTER(MlirFrozenRewritePatternSet,
                        _("MlirFrozenRewritePatternSet"));
   bool load(handle src, bool) {
diff --git a/mlir/test/python/integration/dialects/pdl.py b/mlir/test/python/integration/dialects/pdl.py
index f2eb93fe953ef..04441af8ccb17 100644
--- a/mlir/test/python/integration/dialects/pdl.py
+++ b/mlir/test/python/integration/dialects/pdl.py
@@ -7,61 +7,61 @@
 
 
 def construct_and_print_in_module(f):
-  print("\nTEST:", f.__name__)
-  with Context(), Location.unknown():
-    module = Module.create()
-    with InsertionPoint(module.body):
-      module = f(module)
-    if module is not None:
-      print(module)
-  return f
+    print("\nTEST:", f.__name__)
+    with Context(), Location.unknown():
+        module = Module.create()
+        with InsertionPoint(module.body):
+            module = f(module)
+        if module is not None:
+            print(module)
+    return f
 
 
 # CHECK-LABEL: TEST: test_add_to_mul
 # CHECK: arith.muli
 @construct_and_print_in_module
 def test_add_to_mul(module_):
-  index_type = IndexType.get()
+    index_type = IndexType.get()
 
-  # Create a test case.
-  @module(sym_name="ir")
-  def ir():
-    @func.func(index_type, index_type)
-    def add_func(a, b):
-      return arith.addi(a, b)
+    # Create a test case.
+    @module(sym_name="ir")
+    def ir():
+        @func.func(index_type, index_type)
+        def add_func(a, b):
+            return arith.addi(a, b)
 
-  # Create a rewrite from add to mul. This will match
-  # - operation name is arith.addi
-  # - operands are index types.
-  # - there are two operands.
-  with Location.unknown():
-    m = Module.create()
-    with InsertionPoint(m.body):
-      # Change all arith.addi with index types to arith.muli.
-      pattern = pdl.PatternOp(1, "addi_to_mul")
-      with InsertionPoint(pattern.body):
-        # Match arith.addi with index types.
-        index_type = pdl.TypeOp(IndexType.get())
-        operand0 = pdl.OperandOp(index_type)
-        operand1 = pdl.OperandOp(index_type)
-        op0 = pdl.OperationOp(
-            name="arith.addi", args=[operand0, operand1], types=[index_type]
-        )
+    # Create a rewrite from add to mul. This will match
+    # - operation name is arith.addi
+    # - operands are index types.
+    # - there are two operands.
+    with Location.unknown():
+        m = Module.create()
+        with InsertionPoint(m.body):
+            # Change all arith.addi with index types to arith.muli.
+            pattern = pdl.PatternOp(1, "addi_to_mul")
+            with InsertionPoint(pattern.body):
+                # Match arith.addi with index types.
+                index_type = pdl.TypeOp(IndexType.get())
+                operand0 = pdl.OperandOp(index_type)
+                operand1 = pdl.OperandOp(index_type)
+                op0 = pdl.OperationOp(
+                    name="arith.addi", args=[operand0, operand1], types=[index_type]
+                )
 
-        # Replace the matched op with arith.muli.
-        rewrite = pdl.RewriteOp(op0)
-        with InsertionPoint(rewrite.add_body()):
-          newOp = pdl.OperationOp(
-              name="arith.muli", args=[operand0, operand1], types=[index_type]
-          )
-          pdl.ReplaceOp(op0, with_op=newOp)
+                # Replace the matched op with arith.muli.
+                rewrite = pdl.RewriteOp(op0)
+                with InsertionPoint(rewrite.add_body()):
+                    newOp = pdl.OperationOp(
+                        name="arith.muli", args=[operand0, operand1], types=[index_type]
+                    )
+                    pdl.ReplaceOp(op0, with_op=newOp)
 
-  # Create a PDL module from module and freeze it. At this point the ownership
-  # of the module is transferred to the PDL module. This ownership transfer is
-  # not yet captured Python side/has sharp edges. So best to construct the
-  # module and PDL module in same scope.
-  # FIXME: This should be made more robust.
-  frozen = PDLModule(m).freeze()
-  # Could apply frozen pattern set multiple times.
-  apply_patterns_and_fold_greedily(module_, frozen)
-  return module_
+    # Create a PDL module from module and freeze it. At this point the ownership
+    # of the module is transferred to the PDL module. This ownership transfer is
+    # not yet captured Python side/has sharp edges. So best to construct the
+    # module and PDL module in same scope.
+    # FIXME: This should be made more robust.
+    frozen = PDLModule(m).freeze()
+    # Could apply frozen pattern set multiple times.
+    apply_patterns_and_fold_greedily(module_, frozen)
+    return module_



More information about the Mlir-commits mailing list