[Mlir-commits] [llvm] [mlir] [mlir] Add PDL C & Python usage (PR #94714)

Jacques Pienaar llvmlistbot at llvm.org
Thu Jun 6 18:38:16 PDT 2024


https://github.com/jpienaar created https://github.com/llvm/llvm-project/pull/94714

Following a rather direct approach to expose PDL usage from C and then Python. This doesn't yet plumb through support for custom matchers via this interface, so it is constrained to the basics initially.

This also exposes the greedy rewrite driver. Currently the only way to define patterns is via PDL (to keep the change small). The creation of the PDL pattern module could be improved to prevent users from accessing the module used to construct it after construction. No ergonomic work has been done yet.

>From f70e85e83e36e1c786d607fa89b126d248bc6f3f Mon Sep 17 00:00:00 2001
From: Jacques Pienaar <jpienaar at google.com>
Date: Fri, 7 Jun 2024 01:26:36 +0000
Subject: [PATCH] [mlir] Add PDL C & Python usage

Following a rather direct approach to expose PDL usage from C and then Python. This doesn't yet plumb through support for custom matchers via this interface, so it is constrained to the basics initially.

Signed-off-by: Jacques Pienaar <jpienaar at google.com>
---
 mlir/include/mlir-c/Bindings/Python/Interop.h |  21 ++++
 mlir/include/mlir-c/Rewrite.h                 |  60 ++++++++++
 .../mlir/Bindings/Python/PybindAdaptors.h     |  20 ++++
 mlir/lib/Bindings/Python/IRModule.h           |   1 +
 mlir/lib/Bindings/Python/MainModule.cpp       |   4 +
 mlir/lib/Bindings/Python/Rewrite.cpp          | 110 ++++++++++++++++++
 mlir/lib/Bindings/Python/Rewrite.h            |  22 ++++
 mlir/lib/CAPI/Transforms/CMakeLists.txt       |   3 +
 mlir/lib/CAPI/Transforms/Rewrite.cpp          |  83 +++++++++++++
 mlir/python/CMakeLists.txt                    |   2 +
 mlir/python/mlir/rewrite.py                   |   5 +
 mlir/test/python/integration/dialects/pdl.py  |  67 +++++++++++
 .../llvm-project-overlay/mlir/BUILD.bazel     |  12 +-
 .../mlir/python/BUILD.bazel                   |   7 ++
 14 files changed, 416 insertions(+), 1 deletion(-)
 create mode 100644 mlir/include/mlir-c/Rewrite.h
 create mode 100644 mlir/lib/Bindings/Python/Rewrite.cpp
 create mode 100644 mlir/lib/Bindings/Python/Rewrite.h
 create mode 100644 mlir/lib/CAPI/Transforms/Rewrite.cpp
 create mode 100644 mlir/python/mlir/rewrite.py
 create mode 100644 mlir/test/python/integration/dialects/pdl.py

diff --git a/mlir/include/mlir-c/Bindings/Python/Interop.h b/mlir/include/mlir-c/Bindings/Python/Interop.h
index 0a36e97c2ae68..a33190c380d37 100644
--- a/mlir/include/mlir-c/Bindings/Python/Interop.h
+++ b/mlir/include/mlir-c/Bindings/Python/Interop.h
@@ -39,6 +39,7 @@
 #include "mlir-c/IR.h"
 #include "mlir-c/IntegerSet.h"
 #include "mlir-c/Pass.h"
+#include "mlir-c/Rewrite.h"
 
 // The 'mlir' Python package is relocatable and supports co-existing in multiple
 // projects. Each project must define its outer package prefix with this define
@@ -284,6 +285,26 @@ static inline MlirModule mlirPythonCapsuleToModule(PyObject *capsule) {
   return module;
 }
 
+/** Creates a capsule object encapsulating the raw C-API
+ * MlirFrozenRewritePatternSet.
+ * The returned capsule does not extend or affect ownership of any Python
+ * objects that reference the module in any way. */
+static inline PyObject *
+mlirPythonFrozenRewritePatternSetToCapsule(MlirFrozenRewritePatternSet pm) {
+  return PyCapsule_New(MLIR_PYTHON_GET_WRAPPED_POINTER(pm),
+                       MLIR_PYTHON_CAPSULE_PASS_MANAGER, NULL);
+}
+
+/** Extracts an MlirFrozenRewritePatternSet from a capsule as produced from
+ * mlirPythonFrozenRewritePatternSetToCapsule. If the capsule is not of the
+ * right type, then a null module is returned. */
+static inline MlirFrozenRewritePatternSet
+mlirPythonCapsuleToFrozenRewritePatternSet(PyObject *capsule) {
+  void *ptr = PyCapsule_GetPointer(capsule, MLIR_PYTHON_CAPSULE_PASS_MANAGER);
+  MlirFrozenRewritePatternSet pm = {ptr};
+  return pm;
+}
+
 /** Creates a capsule object encapsulating the raw C-API MlirPassManager.
  * The returned capsule does not extend or affect ownership of any Python
  * objects that reference the module in any way. */
diff --git a/mlir/include/mlir-c/Rewrite.h b/mlir/include/mlir-c/Rewrite.h
new file mode 100644
index 0000000000000..45218a1cd4ebd
--- /dev/null
+++ b/mlir/include/mlir-c/Rewrite.h
@@ -0,0 +1,60 @@
+//===-- mlir-c/Rewrite.h - Helpers for C API to Rewrites ----------*- C -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM
+// Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This header declares the registration and creation method for
+// rewrite patterns.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_C_REWRITE_H
+#define MLIR_C_REWRITE_H
+
+#include "mlir-c/IR.h"
+#include "mlir-c/Support.h"
+#include "mlir/Config/mlir-config.h"
+
+// Give the API C linkage when included from C++ so that the definitions in
+// the C++ implementation file are callable from C clients.
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+//===----------------------------------------------------------------------===//
+/// Opaque type declarations (see mlir-c/IR.h for more details).
+//===----------------------------------------------------------------------===//
+
+#define DEFINE_C_API_STRUCT(name, storage)                                     \
+  struct name {                                                                \
+    storage *ptr;                                                              \
+  };                                                                           \
+  typedef struct name name
+
+DEFINE_C_API_STRUCT(MlirFrozenRewritePatternSet, void);
+DEFINE_C_API_STRUCT(MlirGreedyRewriteDriverConfig, void);
+DEFINE_C_API_STRUCT(MlirRewritePatternSet, void);
+
+/// Freezes `op` into a newly allocated frozen rewrite pattern set. The
+/// patterns are moved out of `op`, which must not be used afterwards. The
+/// result must be released with mlirFrozenRewritePatternSetDestroy.
+MLIR_CAPI_EXPORTED MlirFrozenRewritePatternSet
+mlirFreezeRewritePattern(MlirRewritePatternSet op);
+
+/// Destroys a frozen rewrite pattern set created by mlirFreezeRewritePattern.
+MLIR_CAPI_EXPORTED void
+mlirFrozenRewritePatternSetDestroy(MlirFrozenRewritePatternSet op);
+
+/// Greedily applies `patterns` to `op`, folding operations along the way.
+/// Returns failure if the greedy rewrite driver reports failure (e.g. when it
+/// does not converge). The driver configuration is currently ignored; pass a
+/// null handle.
+MLIR_CAPI_EXPORTED MlirLogicalResult mlirApplyPatternsAndFoldGreedily(
+    MlirModule op, MlirFrozenRewritePatternSet patterns,
+    MlirGreedyRewriteDriverConfig config);
+
+#if MLIR_ENABLE_PDL_IN_PATTERNMATCH
+DEFINE_C_API_STRUCT(MlirPDLPatternModule, void);
+
+/// Creates a PDL pattern module from `op`. The PDL pattern module takes
+/// ownership of the module.
+MLIR_CAPI_EXPORTED MlirPDLPatternModule
+mlirPDLPatternModuleFromModule(MlirModule op);
+
+/// Destroys a PDL pattern module.
+MLIR_CAPI_EXPORTED void mlirPDLPatternModuleDestroy(MlirPDLPatternModule op);
+
+/// Moves the patterns held by `op` into a newly allocated rewrite pattern
+/// set. `op` itself is not freed and must still be destroyed with
+/// mlirPDLPatternModuleDestroy.
+MLIR_CAPI_EXPORTED MlirRewritePatternSet
+mlirRewritePatternSetFromPDLPatternModule(MlirPDLPatternModule op);
+#endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH
+
+#undef DEFINE_C_API_STRUCT
+
+#ifdef __cplusplus
+}
+#endif
+
+#endif // MLIR_C_REWRITE_H
diff --git a/mlir/include/mlir/Bindings/Python/PybindAdaptors.h b/mlir/include/mlir/Bindings/Python/PybindAdaptors.h
index d8f22c7aa1709..39ee1551ccb2e 100644
--- a/mlir/include/mlir/Bindings/Python/PybindAdaptors.h
+++ b/mlir/include/mlir/Bindings/Python/PybindAdaptors.h
@@ -198,6 +198,26 @@ struct type_caster<MlirModule> {
   };
 };
 
+/// Casts object <-> MlirFrozenRewritePatternSet.
+template <>
+struct type_caster<MlirFrozenRewritePatternSet> {
+  PYBIND11_TYPE_CASTER(MlirFrozenRewritePatternSet,
+                       _("MlirFrozenRewritePatternSet"));
+  /// Extracts the raw handle from any object exposing the C-API capsule.
+  /// Rejects the object (returns false) when the capsule does not hold a
+  /// frozen rewrite pattern set, instead of handing a null handle to C++.
+  bool load(handle src, bool) {
+    py::object capsule = mlirApiObjectToCapsule(src);
+    value = mlirPythonCapsuleToFrozenRewritePatternSet(capsule.ptr());
+    return value.ptr != nullptr;
+  }
+  /// Wraps a raw handle into an `mlir.rewrite.FrozenRewritePatternSet` via
+  /// that class's C-API factory.
+  static handle cast(MlirFrozenRewritePatternSet v, return_value_policy,
+                     handle) {
+    py::object capsule = py::reinterpret_steal<py::object>(
+        mlirPythonFrozenRewritePatternSetToCapsule(v));
+    return py::module::import(MAKE_MLIR_PYTHON_QUALNAME("rewrite"))
+        .attr("FrozenRewritePatternSet")
+        .attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule)
+        .release();
+  };
+};
+
 /// Casts object <-> MlirOperation.
 template <>
 struct type_caster<MlirOperation> {
diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h
index b038a0c54d29b..6f0dc2690a5e6 100644
--- a/mlir/lib/Bindings/Python/IRModule.h
+++ b/mlir/lib/Bindings/Python/IRModule.h
@@ -22,6 +22,7 @@
 #include "mlir-c/Diagnostics.h"
 #include "mlir-c/IR.h"
 #include "mlir-c/IntegerSet.h"
+#include "mlir-c/Transforms.h"
 #include "mlir/Bindings/Python/PybindAdaptors.h"
 #include "llvm/ADT/DenseMap.h"
 
diff --git a/mlir/lib/Bindings/Python/MainModule.cpp b/mlir/lib/Bindings/Python/MainModule.cpp
index 17272472ccca4..8da1ab16a4514 100644
--- a/mlir/lib/Bindings/Python/MainModule.cpp
+++ b/mlir/lib/Bindings/Python/MainModule.cpp
@@ -11,6 +11,7 @@
 #include "Globals.h"
 #include "IRModule.h"
 #include "Pass.h"
+#include "Rewrite.h"
 
 namespace py = pybind11;
 using namespace mlir;
@@ -116,6 +117,9 @@ PYBIND11_MODULE(_mlir, m) {
   populateIRInterfaces(irModule);
   populateIRTypes(irModule);
 
+  auto rewriteModule = m.def_submodule("rewrite", "MLIR Rewrite Bindings");
+  populateRewriteSubmodule(rewriteModule);
+
   // Define and populate PassManager submodule.
   auto passModule =
       m.def_submodule("passmanager", "MLIR Pass Management Bindings");
diff --git a/mlir/lib/Bindings/Python/Rewrite.cpp b/mlir/lib/Bindings/Python/Rewrite.cpp
new file mode 100644
index 0000000000000..1d8128be9f082
--- /dev/null
+++ b/mlir/lib/Bindings/Python/Rewrite.cpp
@@ -0,0 +1,110 @@
+//===- Rewrite.cpp - Rewrite ----------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "Rewrite.h"
+
+#include "IRModule.h"
+#include "mlir-c/Bindings/Python/Interop.h"
+#include "mlir-c/Rewrite.h"
+#include "mlir/Config/mlir-config.h"
+
+namespace py = pybind11;
+using namespace mlir;
+using namespace py::literals;
+using namespace mlir::python;
+
+namespace {
+
+#if MLIR_ENABLE_PDL_IN_PATTERNMATCH
+/// Python-facing owner of an MlirPDLPatternModule handle.
+///
+/// The handle is released with mlirPDLPatternModuleDestroy when the wrapper
+/// is destroyed; a moved-from wrapper holds a null handle and frees nothing.
+class PyPDLPatternModule {
+public:
+  PyPDLPatternModule(MlirPDLPatternModule module) : module(module) {}
+
+  PyPDLPatternModule(PyPDLPatternModule &&other) noexcept
+      : module(other.module) {
+    // Empty the source so exactly one wrapper destroys the handle.
+    other.module = MlirPDLPatternModule{nullptr};
+  }
+
+  ~PyPDLPatternModule() {
+    if (!module.ptr)
+      return;
+    mlirPDLPatternModuleDestroy(module);
+  }
+
+  /// Returns the owned handle without transferring ownership.
+  MlirPDLPatternModule get() { return module; }
+
+private:
+  MlirPDLPatternModule module;
+};
+#endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH
+
+/// Python-facing owner of an MlirFrozenRewritePatternSet handle.
+///
+/// The handle is released with mlirFrozenRewritePatternSetDestroy when the
+/// wrapper is destroyed; a moved-from wrapper holds a null handle.
+class PyFrozenRewritePatternSet {
+public:
+  PyFrozenRewritePatternSet(MlirFrozenRewritePatternSet set) : set(set) {}
+
+  PyFrozenRewritePatternSet(PyFrozenRewritePatternSet &&other) noexcept
+      : set(other.set) {
+    // Empty the source so exactly one wrapper destroys the handle.
+    other.set = MlirFrozenRewritePatternSet{nullptr};
+  }
+
+  ~PyFrozenRewritePatternSet() {
+    if (!set.ptr)
+      return;
+    mlirFrozenRewritePatternSetDestroy(set);
+  }
+
+  /// Returns the owned handle without transferring ownership.
+  MlirFrozenRewritePatternSet get() { return set; }
+
+  /// Returns a non-owning capsule wrapping the raw C-API handle.
+  pybind11::object getCapsule() {
+    PyObject *raw = mlirPythonFrozenRewritePatternSetToCapsule(set);
+    return py::reinterpret_steal<py::object>(raw);
+  }
+
+  /// Builds a new owning wrapper around the handle stored in `capsule`,
+  /// re-raising the pending Python error if the capsule holds no pattern set.
+  /// NOTE(review): the new wrapper frees the handle when it dies, so wrapping
+  /// a capsule taken from a still-live wrapper would free the same set twice;
+  /// confirm callers only pass capsules whose ownership is being handed over.
+  static pybind11::object createFromCapsule(pybind11::object capsule) {
+    MlirFrozenRewritePatternSet extracted =
+        mlirPythonCapsuleToFrozenRewritePatternSet(capsule.ptr());
+    if (!extracted.ptr)
+      throw py::error_already_set();
+    return py::cast(PyFrozenRewritePatternSet(extracted),
+                    py::return_value_policy::move);
+  }
+
+private:
+  MlirFrozenRewritePatternSet set;
+};
+
+} // namespace
+
+/// Populates the `mlir.rewrite` submodule.
+void mlir::python::populateRewriteSubmodule(py::module &m) {
+  //----------------------------------------------------------------------------
+  // Mapping of the PDL pattern module (only when PDL support is built in).
+  //----------------------------------------------------------------------------
+#if MLIR_ENABLE_PDL_IN_PATTERNMATCH
+  py::class_<PyPDLPatternModule>(m, "PDLModule", py::module_local())
+      .def(py::init<>([](MlirModule module) {
+             // The PDL pattern module takes ownership of `module`.
+             // NOTE(review): the Python `Module` passed in presumably still
+             // considers itself the owner too; confirm no double destruction.
+             return mlirPDLPatternModuleFromModule(module);
+           }),
+           "module"_a, "Create a PDL module from the given module.")
+      .def("freeze", [](PyPDLPatternModule &self) {
+        // Moves the PDL patterns into a pattern set and freezes it. pybind11
+        // takes ownership of the returned raw pointer (default policy).
+        return new PyFrozenRewritePatternSet(mlirFreezeRewritePattern(
+            mlirRewritePatternSetFromPDLPatternModule(self.get())));
+      });
+#endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH
+  //----------------------------------------------------------------------------
+  // Mapping of the frozen pattern set and the greedy rewrite driver.
+  //----------------------------------------------------------------------------
+  py::class_<PyFrozenRewritePatternSet>(m, "FrozenRewritePatternSet",
+                                        py::module_local())
+      .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
+                             &PyFrozenRewritePatternSet::getCapsule)
+      .def(MLIR_PYTHON_CAPI_FACTORY_ATTR,
+           &PyFrozenRewritePatternSet::createFromCapsule);
+  m.def(
+      "apply_patterns_and_fold_greedily",
+      [](MlirModule module, MlirFrozenRewritePatternSet set) {
+        // A null config handle: the C API does not plumb the config yet.
+        auto status = mlirApplyPatternsAndFoldGreedily(module, set, {});
+        if (mlirLogicalResultIsFailure(status))
+          // FIXME: Not sure this is the right error to throw here.
+          throw py::value_error("pattern application failed to converge");
+      },
+      "module"_a, "set"_a,
+      "Applies the given patterns to the given module greedily while folding "
+      "results.");
+}
diff --git a/mlir/lib/Bindings/Python/Rewrite.h b/mlir/lib/Bindings/Python/Rewrite.h
new file mode 100644
index 0000000000000..997b80adda303
--- /dev/null
+++ b/mlir/lib/Bindings/Python/Rewrite.h
@@ -0,0 +1,22 @@
+//===- Rewrite.h - Rewrite Submodules of pybind module --------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_BINDINGS_PYTHON_REWRITE_H
+#define MLIR_BINDINGS_PYTHON_REWRITE_H
+
+#include "PybindUtils.h"
+
+namespace mlir::python {
+
+/// Registers the `mlir.rewrite` bindings on the given pybind11 submodule.
+void populateRewriteSubmodule(pybind11::module &m);
+
+} // namespace mlir::python
+
+#endif // MLIR_BINDINGS_PYTHON_REWRITE_H
diff --git a/mlir/lib/CAPI/Transforms/CMakeLists.txt b/mlir/lib/CAPI/Transforms/CMakeLists.txt
index 2638025a8c359..6c67aa09fdf40 100644
--- a/mlir/lib/CAPI/Transforms/CMakeLists.txt
+++ b/mlir/lib/CAPI/Transforms/CMakeLists.txt
@@ -1,6 +1,9 @@
 add_mlir_upstream_c_api_library(MLIRCAPITransforms
   Passes.cpp
+  Rewrite.cpp
 
   LINK_LIBS PUBLIC
+  MLIRIR
+  # Rewrite.cpp uses FrozenRewritePatternSet, which lives in MLIRRewrite.
+  MLIRRewrite
   MLIRTransforms
+  MLIRTransformUtils
 )
diff --git a/mlir/lib/CAPI/Transforms/Rewrite.cpp b/mlir/lib/CAPI/Transforms/Rewrite.cpp
new file mode 100644
index 0000000000000..0de1958398f63
--- /dev/null
+++ b/mlir/lib/CAPI/Transforms/Rewrite.cpp
@@ -0,0 +1,83 @@
+//===- Rewrite.cpp - C API for Rewrite Patterns ---------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir-c/Rewrite.h"
+#include "mlir-c/Transforms.h"
+#include "mlir/CAPI/IR.h"
+#include "mlir/CAPI/Support.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Rewrite/FrozenRewritePatternSet.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+using namespace mlir;
+
+/// Returns the RewritePatternSet behind a (non-null) C handle. File-local, so
+/// given internal linkage.
+static inline mlir::RewritePatternSet &unwrap(MlirRewritePatternSet set) {
+  assert(set.ptr && "unexpected null rewrite pattern set");
+  return *(static_cast<mlir::RewritePatternSet *>(set.ptr));
+}
+
+/// Wraps a RewritePatternSet pointer in its C handle; ownership is unchanged.
+static inline MlirRewritePatternSet wrap(mlir::RewritePatternSet *set) {
+  return {set};
+}
+
+/// Returns the FrozenRewritePatternSet behind a (non-null) C handle.
+static inline mlir::FrozenRewritePatternSet *
+unwrap(MlirFrozenRewritePatternSet set) {
+  assert(set.ptr && "unexpected null frozen rewrite pattern set");
+  return static_cast<mlir::FrozenRewritePatternSet *>(set.ptr);
+}
+
+/// Wraps a FrozenRewritePatternSet pointer in its C handle; ownership is
+/// unchanged.
+static inline MlirFrozenRewritePatternSet
+wrap(mlir::FrozenRewritePatternSet *set) {
+  return {set};
+}
+
+/// Freezes `op` into a newly heap-allocated FrozenRewritePatternSet; the
+/// patterns are moved out of `op`.
+MlirFrozenRewritePatternSet
+mlirFreezeRewritePattern(MlirRewritePatternSet op) {
+  // NOTE(review): the RewritePatternSet object behind `op` is not deleted
+  // here and this API has no destroy function for it, so it leaks. Deleting
+  // it would also destroy the PDL module it holds, which the Python bindings
+  // may still own; resolve module ownership before changing this.
+  mlir::RewritePatternSet &source = unwrap(op);
+  auto *frozen = new mlir::FrozenRewritePatternSet(std::move(source));
+  return wrap(frozen);
+}
+
+/// Frees the FrozenRewritePatternSet behind `op`. The handle is passed by
+/// value, so the caller's copy is not cleared and must not be reused.
+void mlirFrozenRewritePatternSetDestroy(MlirFrozenRewritePatternSet op) {
+  mlir::FrozenRewritePatternSet *frozen = unwrap(op);
+  delete frozen;
+}
+
+/// Runs the greedy pattern rewrite driver with `patterns` over `op`.
+MlirLogicalResult
+mlirApplyPatternsAndFoldGreedily(MlirModule op,
+                                 MlirFrozenRewritePatternSet patterns,
+                                 MlirGreedyRewriteDriverConfig /*config*/) {
+  // The driver config is not plumbed through yet; the C++ default is used.
+  const mlir::FrozenRewritePatternSet &frozen = *unwrap(patterns);
+  mlir::LogicalResult result =
+      mlir::applyPatternsAndFoldGreedily(unwrap(op), frozen);
+  return wrap(result);
+}
+
+#if MLIR_ENABLE_PDL_IN_PATTERNMATCH
+/// Returns the PDLPatternModule behind a (non-null) C handle. File-local, so
+/// given internal linkage.
+static inline mlir::PDLPatternModule *unwrap(MlirPDLPatternModule module) {
+  assert(module.ptr && "unexpected null PDL pattern module");
+  return static_cast<mlir::PDLPatternModule *>(module.ptr);
+}
+
+/// Wraps a PDLPatternModule pointer in its C handle; ownership is unchanged.
+static inline MlirPDLPatternModule wrap(mlir::PDLPatternModule *module) {
+  return {module};
+}
+
+MlirPDLPatternModule mlirPDLPatternModuleFromModule(MlirModule op) {
+  // The new PDLPatternModule takes ownership of the module via OwningOpRef.
+  // NOTE(review): a Python `Module` wrapping `op` presumably still considers
+  // itself the owner as well; confirm the module is not destroyed twice.
+  return wrap(new mlir::PDLPatternModule(
+      mlir::OwningOpRef<mlir::ModuleOp>(unwrap(op))));
+}
+
+void mlirPDLPatternModuleDestroy(MlirPDLPatternModule op) {
+  // Frees the pointee only; the caller's by-value handle is left untouched.
+  delete unwrap(op);
+}
+
+MlirRewritePatternSet
+mlirRewritePatternSetFromPDLPatternModule(MlirPDLPatternModule op) {
+  // Moves the patterns out of `op`. The moved-from PDLPatternModule object is
+  // not freed here and still needs mlirPDLPatternModuleDestroy.
+  auto *set = new mlir::RewritePatternSet(std::move(*unwrap(op)));
+  return wrap(set);
+}
+#endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH
diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt
index d8f2d1989fdea..d03036e17749d 100644
--- a/mlir/python/CMakeLists.txt
+++ b/mlir/python/CMakeLists.txt
@@ -21,6 +21,7 @@ declare_mlir_python_sources(MLIRPythonSources.Core.Python
     _mlir_libs/__init__.py
     ir.py
     passmanager.py
+    rewrite.py
     dialects/_ods_common.py
 
     # The main _mlir module has submodules: include stubs from each.
@@ -448,6 +449,7 @@ declare_mlir_python_extension(MLIRPythonExtension.Core
     IRModule.cpp
     IRTypes.cpp
     Pass.cpp
+    Rewrite.cpp
 
     # Headers must be included explicitly so they are installed.
     Globals.h
diff --git a/mlir/python/mlir/rewrite.py b/mlir/python/mlir/rewrite.py
new file mode 100644
index 0000000000000..5bc1bba7ae9a7
--- /dev/null
+++ b/mlir/python/mlir/rewrite.py
@@ -0,0 +1,5 @@
+#  Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+#  See https://llvm.org/LICENSE.txt for license information.
+#  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+from ._mlir_libs._mlir.rewrite import *
diff --git a/mlir/test/python/integration/dialects/pdl.py b/mlir/test/python/integration/dialects/pdl.py
new file mode 100644
index 0000000000000..f2eb93fe953ef
--- /dev/null
+++ b/mlir/test/python/integration/dialects/pdl.py
@@ -0,0 +1,67 @@
+# RUN: %PYTHON %s 2>&1 | FileCheck %s
+
+from mlir.dialects import arith, func, pdl
+from mlir.dialects.builtin import module
+from mlir.ir import *
+from mlir.rewrite import *
+
+
+def construct_and_print_in_module(f):
+  print("\nTEST:", f.__name__)
+  with Context(), Location.unknown():
+    module = Module.create()
+    with InsertionPoint(module.body):
+      module = f(module)
+    if module is not None:
+      print(module)
+  return f
+
+
+# CHECK-LABEL: TEST: test_add_to_mul
+# CHECK: arith.muli
+ at construct_and_print_in_module
+def test_add_to_mul(module_):
+  index_type = IndexType.get()
+
+  # Create a test case.
+  @module(sym_name="ir")
+  def ir():
+    @func.func(index_type, index_type)
+    def add_func(a, b):
+      return arith.addi(a, b)
+
+  # Create a rewrite from add to mul. This will match
+  # - operation name is arith.addi
+  # - operands are index types.
+  # - there are two operands.
+  with Location.unknown():
+    m = Module.create()
+    with InsertionPoint(m.body):
+      # Change all arith.addi with index types to arith.muli.
+      pattern = pdl.PatternOp(1, "addi_to_mul")
+      with InsertionPoint(pattern.body):
+        # Match arith.addi with index types.
+        index_type = pdl.TypeOp(IndexType.get())
+        operand0 = pdl.OperandOp(index_type)
+        operand1 = pdl.OperandOp(index_type)
+        op0 = pdl.OperationOp(
+            name="arith.addi", args=[operand0, operand1], types=[index_type]
+        )
+
+        # Replace the matched op with arith.muli.
+        rewrite = pdl.RewriteOp(op0)
+        with InsertionPoint(rewrite.add_body()):
+          newOp = pdl.OperationOp(
+              name="arith.muli", args=[operand0, operand1], types=[index_type]
+          )
+          pdl.ReplaceOp(op0, with_op=newOp)
+
+  # Create a PDL module from module and freeze it. At this point the ownership
+  # of the module is transferred to the PDL module. This ownership transfer is
+  # not yet captured Python side/has sharp edges. So best to construct the
+  # module and PDL module in same scope.
+  # FIXME: This should be made more robust.
+  frozen = PDLModule(m).freeze()
+  # Could apply frozen pattern set multiple times.
+  apply_patterns_and_fold_greedily(module_, frozen)
+  return module_
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index 5d2248a8fe360..2d15955548961 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -420,6 +420,7 @@ mlir_c_api_cc_library(
         "include/mlir-c/Interfaces.h",
         "include/mlir-c/Pass.h",
         "include/mlir-c/RegisterEverything.h",
+        "include/mlir-c/Rewrite.h",
         "include/mlir-c/Support.h",
         "include/mlir/CAPI/AffineExpr.h",
         "include/mlir/CAPI/AffineMap.h",
@@ -866,7 +867,10 @@ mlir_c_api_cc_library(
 
 mlir_c_api_cc_library(
     name = "CAPITransforms",
-    srcs = ["lib/CAPI/Transforms/Passes.cpp"],
+    srcs = [
+        "lib/CAPI/Transforms/Passes.cpp",
+        "lib/CAPI/Transforms/Rewrite.cpp",
+    ],
     hdrs = ["include/mlir-c/Transforms.h"],
     capi_deps = [
         ":CAPIIR",
@@ -876,7 +880,10 @@ mlir_c_api_cc_library(
     ],
     includes = ["include"],
     deps = [
+        ":IR",
         ":Pass",
+        ":Rewrite",
+        ":TransformUtils",
         ":Transforms",
     ],
 )
@@ -939,6 +946,7 @@ cc_library(
     textual_hdrs = glob(MLIR_BINDINGS_PYTHON_HEADERS),
     deps = [
         ":CAPIIRHeaders",
+        ":CAPITransformsHeaders",
         "@local_config_python//:python_headers",
         "@pybind11",
     ],
@@ -957,6 +965,7 @@ cc_library(
     textual_hdrs = glob(MLIR_BINDINGS_PYTHON_HEADERS),
     deps = [
         ":CAPIIR",
+        ":CAPITransforms",
         "@local_config_python//:python_headers",
         "@pybind11",
     ],
@@ -981,6 +990,7 @@ MLIR_PYTHON_BINDINGS_SOURCES = [
     "lib/Bindings/Python/IRModule.cpp",
     "lib/Bindings/Python/IRTypes.cpp",
     "lib/Bindings/Python/Pass.cpp",
+    "lib/Bindings/Python/Rewrite.cpp",
 ]
 
 cc_library(
diff --git a/utils/bazel/llvm-project-overlay/mlir/python/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/python/BUILD.bazel
index add150de69faf..254cab0db4a5d 100644
--- a/utils/bazel/llvm-project-overlay/mlir/python/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/python/BUILD.bazel
@@ -82,6 +82,13 @@ filegroup(
     ],
 )
 
+# Python module re-exporting the native `_mlir.rewrite` bindings.
+filegroup(
+    name = "RewritePyFiles",
+    srcs = ["mlir/rewrite.py"],
+)
+
 filegroup(
     name = "RuntimePyFiles",
     srcs = glob([



More information about the Mlir-commits mailing list