[Mlir-commits] [mlir] ed21c92 - [mlir] Introduce Python bindings for the PDL dialect

Denys Shabalin llvmlistbot at llvm.org
Wed Jan 19 02:20:00 PST 2022


Author: Denys Shabalin
Date: 2022-01-19T11:19:56+01:00
New Revision: ed21c9276a4cc88d60cbaddc56132b1793ca30c7

URL: https://github.com/llvm/llvm-project/commit/ed21c9276a4cc88d60cbaddc56132b1793ca30c7
DIFF: https://github.com/llvm/llvm-project/commit/ed21c9276a4cc88d60cbaddc56132b1793ca30c7.diff

LOG: [mlir] Introduce Python bindings for the PDL dialect

This change adds full python bindings for PDL, including types and operations
with additional mixins to make operation construction more similar to the PDL
syntax.

Reviewed By: ftynse

Differential Revision: https://reviews.llvm.org/D117458

Added: 
    mlir/lib/Bindings/Python/DialectPDL.cpp
    mlir/python/mlir/_mlir_libs/_mlir/dialects/pdl.pyi
    mlir/python/mlir/dialects/PDLOps.td
    mlir/python/mlir/dialects/_pdl_ops_ext.py
    mlir/python/mlir/dialects/pdl.py
    mlir/test/python/dialects/pdl_ops.py
    mlir/test/python/dialects/pdl_types.py

Modified: 
    mlir/include/mlir-c/Dialect/PDL.h
    mlir/lib/CAPI/Dialect/PDL.cpp
    mlir/python/CMakeLists.txt
    mlir/python/mlir/dialects/_ods_common.py
    mlir/test/CAPI/pdl.c

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir-c/Dialect/PDL.h b/mlir/include/mlir-c/Dialect/PDL.h
index 5e0a2bc955f65..1b152899948ca 100644
--- a/mlir/include/mlir-c/Dialect/PDL.h
+++ b/mlir/include/mlir-c/Dialect/PDL.h
@@ -49,6 +49,8 @@ MLIR_CAPI_EXPORTED bool mlirTypeIsAPDLRangeType(MlirType type);
 
 MLIR_CAPI_EXPORTED MlirType mlirPDLRangeTypeGet(MlirType elementType);
 
+MLIR_CAPI_EXPORTED MlirType mlirPDLRangeTypeGetElementType(MlirType type);
+
 //===---------------------------------------------------------------------===//
 // TypeType
 //===---------------------------------------------------------------------===//

diff  --git a/mlir/lib/Bindings/Python/DialectPDL.cpp b/mlir/lib/Bindings/Python/DialectPDL.cpp
new file mode 100644
index 0000000000000..8d0b1014ae74f
--- /dev/null
+++ b/mlir/lib/Bindings/Python/DialectPDL.cpp
@@ -0,0 +1,102 @@
+//===- DialectPDL.cpp - 'pdl' dialect submodule ---------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir-c/Dialect/PDL.h"
+#include "mlir-c/IR.h"
+#include "mlir/Bindings/Python/PybindAdaptors.h"
+
+namespace py = pybind11;
+using namespace llvm;
+using namespace mlir;
+using namespace mlir::python;
+using namespace mlir::python::adaptors;
+
+void populateDialectPDLSubmodule(const pybind11::module &m) {
+  //===-------------------------------------------------------------------===//
+  // PDLType
+  //===-------------------------------------------------------------------===//
+
+  auto pdlType = mlir_type_subclass(m, "PDLType", mlirTypeIsAPDLType);
+
+  //===-------------------------------------------------------------------===//
+  // AttributeType
+  //===-------------------------------------------------------------------===//
+
+  auto attributeType =
+      mlir_type_subclass(m, "AttributeType", mlirTypeIsAPDLAttributeType);
+  attributeType.def_classmethod(
+      "get",
+      [](py::object cls, MlirContext ctx) {
+        return cls(mlirPDLAttributeTypeGet(ctx));
+      },
+      "Get an instance of AttributeType in given context.", py::arg("cls"),
+      py::arg("context") = py::none());
+
+  //===-------------------------------------------------------------------===//
+  // OperationType
+  //===-------------------------------------------------------------------===//
+
+  auto operationType =
+      mlir_type_subclass(m, "OperationType", mlirTypeIsAPDLOperationType);
+  operationType.def_classmethod(
+      "get",
+      [](py::object cls, MlirContext ctx) {
+        return cls(mlirPDLOperationTypeGet(ctx));
+      },
+      "Get an instance of OperationType in given context.", py::arg("cls"),
+      py::arg("context") = py::none());
+
+  //===-------------------------------------------------------------------===//
+  // RangeType
+  //===-------------------------------------------------------------------===//
+
+  auto rangeType = mlir_type_subclass(m, "RangeType", mlirTypeIsAPDLRangeType);
+  rangeType.def_classmethod(
+      "get",
+      [](py::object cls, MlirType elementType) {
+        return cls(mlirPDLRangeTypeGet(elementType));
+      },
+      "Gets an instance of RangeType in the same context as the provided "
+      "element type.",
+      py::arg("cls"), py::arg("element_type"));
+  rangeType.def_property_readonly(
+      "element_type",
+      [](MlirType type) { return mlirPDLRangeTypeGetElementType(type); },
+      "Get the element type.");
+
+  //===-------------------------------------------------------------------===//
+  // TypeType
+  //===-------------------------------------------------------------------===//
+
+  auto typeType = mlir_type_subclass(m, "TypeType", mlirTypeIsAPDLTypeType);
+  typeType.def_classmethod(
+      "get",
+      [](py::object cls, MlirContext ctx) {
+        return cls(mlirPDLTypeTypeGet(ctx));
+      },
+      "Get an instance of TypeType in given context.", py::arg("cls"),
+      py::arg("context") = py::none());
+
+  //===-------------------------------------------------------------------===//
+  // ValueType
+  //===-------------------------------------------------------------------===//
+
+  auto valueType = mlir_type_subclass(m, "ValueType", mlirTypeIsAPDLValueType);
+  valueType.def_classmethod(
+      "get",
+      [](py::object cls, MlirContext ctx) {
+        return cls(mlirPDLValueTypeGet(ctx));
+      },
+      "Get an instance of ValueType in given context.", py::arg("cls"),
+      py::arg("context") = py::none());
+}
+
+PYBIND11_MODULE(_mlirDialectsPDL, m) {
+  m.doc() = "MLIR PDL dialect.";
+  populateDialectPDLSubmodule(m);
+}

diff  --git a/mlir/lib/CAPI/Dialect/PDL.cpp b/mlir/lib/CAPI/Dialect/PDL.cpp
index 42b4ec24b008e..497b2cb1fb3ff 100644
--- a/mlir/lib/CAPI/Dialect/PDL.cpp
+++ b/mlir/lib/CAPI/Dialect/PDL.cpp
@@ -60,6 +60,10 @@ MlirType mlirPDLRangeTypeGet(MlirType elementType) {
   return wrap(pdl::RangeType::get(unwrap(elementType)));
 }
 
+MlirType mlirPDLRangeTypeGetElementType(MlirType type) {
+  return wrap(unwrap(type).cast<pdl::RangeType>().getElementType());
+}
+
 //===---------------------------------------------------------------------===//
 // TypeType
 //===---------------------------------------------------------------------===//

diff  --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt
index 2a9d7a7f47be1..77d6b08322644 100644
--- a/mlir/python/CMakeLists.txt
+++ b/mlir/python/CMakeLists.txt
@@ -123,6 +123,15 @@ declare_mlir_python_sources(
     dialects/quant.py
     _mlir_libs/_mlir/dialects/quant.pyi)
 
+declare_mlir_python_sources(
+  MLIRPythonSources.Dialects.pdl
+  ADD_TO_PARENT MLIRPythonSources.Dialects
+  ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
+  SOURCES
+    dialects/pdl.py
+    dialects/_pdl_ops_ext.py
+    _mlir_libs/_mlir/dialects/pdl.pyi)
+
 declare_mlir_dialect_python_bindings(
   ADD_TO_PARENT MLIRPythonSources.Dialects
   ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
@@ -243,6 +252,19 @@ declare_mlir_python_extension(MLIRPythonExtension.Dialects.Quant.Pybind
     MLIRCAPIQuant
 )
 
+declare_mlir_python_extension(MLIRPythonExtension.Dialects.PDL.Pybind
+  MODULE_NAME _mlirDialectsPDL
+  ADD_TO_PARENT MLIRPythonSources.Dialects.pdl
+  ROOT_DIR "${PYTHON_SOURCE_DIR}"
+  SOURCES
+    DialectPDL.cpp
+  PRIVATE_LINK_LIBS
+    LLVMSupport
+  EMBED_CAPI_LINK_LIBS
+    MLIRCAPIIR
+    MLIRCAPIPDL
+)
+
 declare_mlir_python_extension(MLIRPythonExtension.Dialects.SparseTensor.Pybind
   MODULE_NAME _mlirDialectsSparseTensor
   ADD_TO_PARENT MLIRPythonSources.Dialects.sparse_tensor

diff  --git a/mlir/python/mlir/_mlir_libs/_mlir/dialects/pdl.pyi b/mlir/python/mlir/_mlir_libs/_mlir/dialects/pdl.pyi
new file mode 100644
index 0000000000000..8ec944d191c6f
--- /dev/null
+++ b/mlir/python/mlir/_mlir_libs/_mlir/dialects/pdl.pyi
@@ -0,0 +1,64 @@
+#  Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+#  See https://llvm.org/LICENSE.txt for license information.
+#  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+from typing import Optional
+
+from mlir.ir import Type, Context
+
+__all__ = [
+    'PDLType',
+    'AttributeType',
+    'OperationType',
+    'RangeType',
+    'TypeType',
+    'ValueType',
+]
+
+
+class PDLType(Type):
+  @staticmethod
+  def isinstance(type: Type) -> bool: ...
+
+
+class AttributeType(Type):
+  @staticmethod
+  def isinstance(type: Type) -> bool: ...
+
+  @staticmethod
+  def get(context: Optional[Context] = None) -> AttributeType: ...
+
+
+class OperationType(Type):
+  @staticmethod
+  def isinstance(type: Type) -> bool: ...
+
+  @staticmethod
+  def get(context: Optional[Context] = None) -> OperationType: ...
+
+
+class RangeType(Type):
+  @staticmethod
+  def isinstance(type: Type) -> bool: ...
+
+  @staticmethod
+  def get(element_type: Type) -> RangeType: ...
+
+  @property
+  def element_type(self) -> Type: ...
+
+
+class TypeType(Type):
+  @staticmethod
+  def isinstance(type: Type) -> bool: ...
+
+  @staticmethod
+  def get(context: Optional[Context] = None) -> TypeType: ...
+
+
+class ValueType(Type):
+  @staticmethod
+  def isinstance(type: Type) -> bool: ...
+
+  @staticmethod
+  def get(context: Optional[Context] = None) -> ValueType: ...

diff  --git a/mlir/python/mlir/dialects/PDLOps.td b/mlir/python/mlir/dialects/PDLOps.td
new file mode 100644
index 0000000000000..e4e6a83cd03f0
--- /dev/null
+++ b/mlir/python/mlir/dialects/PDLOps.td
@@ -0,0 +1,15 @@
+//===-- PDLOps.td - Entry point for PDLOps bind ------------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef PYTHON_BINDINGS_PDL_OPS
+#define PYTHON_BINDINGS_PDL_OPS
+
+include "mlir/Bindings/Python/Attributes.td"
+include "mlir/Dialect/PDL/IR/PDLOps.td"
+
+#endif

diff  --git a/mlir/python/mlir/dialects/_ods_common.py b/mlir/python/mlir/dialects/_ods_common.py
index 6bb84e97800dd..0c66593ce3c32 100644
--- a/mlir/python/mlir/dialects/_ods_common.py
+++ b/mlir/python/mlir/dialects/_ods_common.py
@@ -144,7 +144,8 @@ def get_op_result_or_value(
 
 
 def get_op_results_or_values(
-    arg: _Union[_cext.ir.OpView, _cext.ir.Operation, _Sequence[_cext.ir.Value]]
+    arg: _Union[_cext.ir.OpView, _cext.ir.Operation,
+                _Sequence[_Union[_cext.ir.OpView, _cext.ir.Operation, _cext.ir.Value]]]
 ) -> _Union[_Sequence[_cext.ir.Value], _cext.ir.OpResultList]:
   """Returns the given sequence of values or the results of the given op.
 
@@ -157,4 +158,4 @@ def get_op_results_or_values(
   elif isinstance(arg, _cext.ir.Operation):
     return arg.results
   else:
-    return arg
+    return [get_op_result_or_value(element) for element in arg]

diff  --git a/mlir/python/mlir/dialects/_pdl_ops_ext.py b/mlir/python/mlir/dialects/_pdl_ops_ext.py
new file mode 100644
index 0000000000000..364db53854f87
--- /dev/null
+++ b/mlir/python/mlir/dialects/_pdl_ops_ext.py
@@ -0,0 +1,284 @@
+#  Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+#  See https://llvm.org/LICENSE.txt for license information.
+#  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+try:
+  from ..ir import *
+  from ..dialects import pdl
+except ImportError as e:
+  raise RuntimeError("Error loading imports from extension module") from e
+
+from typing import Union, Optional, Sequence, List, Mapping
+from ._ods_common import get_op_result_or_value as _get_value, get_op_results_or_values as _get_values
+
+
+def _get_int_attr(bits: int, value: Union[IntegerAttr, int]) -> IntegerAttr:
+  """Converts the given value to signless integer attribute of given bit width."""
+  if isinstance(value, int):
+    ty = IntegerType.get_signless(bits)
+    return IntegerAttr.get(ty, value)
+  else:
+    return value
+
+
+def _get_array_attr(attrs: Union[ArrayAttr, Sequence[Attribute]]) -> ArrayAttr:
+  """Converts the given value to array attribute."""
+  if isinstance(attrs, ArrayAttr):
+    return attrs
+  else:
+    return ArrayAttr.get(list(attrs))
+
+
+def _get_str_array_attr(attrs: Union[ArrayAttr, Sequence[str]]) -> ArrayAttr:
+  """Converts the given value to string array attribute."""
+  if isinstance(attrs, ArrayAttr):
+    return attrs
+  else:
+    return ArrayAttr.get([StringAttr.get(s) for s in attrs])
+
+
+def _get_str_attr(name: Union[StringAttr, str]) -> Optional[StringAttr]:
+  """Converts the given value to string attribute."""
+  if isinstance(name, str):
+    return StringAttr.get(name)
+  else:
+    return name
+
+
+def _get_type_attr(type: Union[TypeAttr, Type]) -> TypeAttr:
+  """Converts the given value to type attribute."""
+  if isinstance(type, Type):
+    return TypeAttr.get(type)
+  else:
+    return type
+
+
+class ApplyNativeConstraintOp:
+  """Specialization for PDL apply native constraint op class."""
+
+  def __init__(self,
+               name: Union[str, StringAttr],
+               args: Sequence[Union[OpView, Operation, Value]] = [],
+               params: Optional[Union[ArrayAttr, Sequence[Attribute]]] = None,
+               *,
+               loc=None,
+               ip=None):
+    name = _get_str_attr(name)
+    args = _get_values(args)
+    params = params if params is None else _get_array_attr(params)
+    super().__init__(name, args, params, loc=loc, ip=ip)
+
+
+class ApplyNativeRewriteOp:
+  """Specialization for PDL apply native rewrite op class."""
+
+  def __init__(self,
+               results: Sequence[Type],
+               name: Union[str, StringAttr],
+               args: Sequence[Union[OpView, Operation, Value]] = [],
+               params: Optional[Union[ArrayAttr, Sequence[Attribute]]] = None,
+               *,
+               loc=None,
+               ip=None):
+    name = _get_str_attr(name)
+    args = _get_values(args)
+    params = params if params is None else _get_array_attr(params)
+    super().__init__(results, name, args, params, loc=loc, ip=ip)
+
+
+class AttributeOp:
+  """Specialization for PDL attribute op class."""
+
+  def __init__(self,
+               type: Optional[Union[OpView, Operation, Value]] = None,
+               value: Optional[Attribute] = None,
+               *,
+               loc=None,
+               ip=None):
+    type = type if type is None else _get_value(type)
+    result = pdl.AttributeType.get()
+    super().__init__(result, type, value, loc=loc, ip=ip)
+
+
+class EraseOp:
+  """Specialization for PDL erase op class."""
+
+  def __init__(self,
+               operation: Optional[Union[OpView, Operation, Value]] = None,
+               *,
+               loc=None,
+               ip=None):
+    operation = _get_value(operation)
+    super().__init__(operation, loc=loc, ip=ip)
+
+
+class OperandOp:
+  """Specialization for PDL operand op class."""
+
+  def __init__(self,
+               type: Optional[Union[OpView, Operation, Value]] = None,
+               *,
+               loc=None,
+               ip=None):
+    type = type if type is None else _get_value(type)
+    result = pdl.ValueType.get()
+    super().__init__(result, type, loc=loc, ip=ip)
+
+
+class OperandsOp:
+  """Specialization for PDL operands op class."""
+
+  def __init__(self,
+               types: Optional[Union[OpView, Operation, Value]] = None,
+               *,
+               loc=None,
+               ip=None):
+    types = types if types is None else _get_value(types)
+    result = pdl.RangeType.get(pdl.ValueType.get())
+    super().__init__(result, types, loc=loc, ip=ip)
+
+
+class OperationOp:
+  """Specialization for PDL operation op class."""
+
+  def __init__(self,
+               name: Optional[Union[str, StringAttr]] = None,
+               args: Sequence[Union[OpView, Operation, Value]] = [],
+               attributes: Mapping[str, Union[OpView, Operation, Value]] = {},
+               types: Sequence[Union[OpView, Operation, Value]] = [],
+               *,
+               loc=None,
+               ip=None):
+    name = name if name is None else _get_str_attr(name)
+    args = _get_values(args)
+    attributeNames = []
+    attributeValues = []
+    for attrName, attrValue in attributes.items():
+      attributeNames.append(StringAttr.get(attrName))
+      attributeValues.append(_get_value(attrValue))
+    attributeNames = ArrayAttr.get(attributeNames)
+    types = _get_values(types)
+    result = pdl.OperationType.get()
+    super().__init__(result, name, args, attributeValues, attributeNames, types, loc=loc, ip=ip)
+
+
+class PatternOp:
+  """Specialization for PDL pattern op class."""
+
+  def __init__(self,
+               benefit: Union[IntegerAttr, int],
+               name: Optional[Union[StringAttr, str]] = None,
+               *,
+               loc=None,
+               ip=None):
+    """Creates a PDL `pattern` operation."""
+    name_attr = None if name is None else _get_str_attr(name)
+    benefit_attr = _get_int_attr(16, benefit)
+    super().__init__(benefit_attr, name_attr, loc=loc, ip=ip)
+    self.regions[0].blocks.append()
+
+  @property
+  def body(self):
+    """Return the body (block) of the pattern."""
+    return self.regions[0].blocks[0]
+
+
+class ReplaceOp:
+  """Specialization for PDL replace op class."""
+
+  def __init__(self,
+               op: Union[OpView, Operation, Value],
+               *,
+               with_op: Optional[Union[OpView, Operation, Value]] = None,
+               with_values: Sequence[Union[OpView, Operation, Value]] = [],
+               loc=None,
+               ip=None):
+    op = _get_value(op)
+    with_op = with_op if with_op is None else _get_value(with_op)
+    with_values = _get_values(with_values)
+    super().__init__(op, with_op, with_values, loc=loc, ip=ip)
+
+
+class ResultOp:
+  """Specialization for PDL result op class."""
+
+  def __init__(self,
+               parent: Union[OpView, Operation, Value],
+               index: Union[IntegerAttr, int],
+               *,
+               loc=None,
+               ip=None):
+    index = _get_int_attr(32, index)
+    parent = _get_value(parent)
+    result = pdl.ValueType.get()
+    super().__init__(result, parent, index, loc=loc, ip=ip)
+
+
+class ResultsOp:
+  """Specialization for PDL results op class."""
+
+  def __init__(self,
+               result: Type,
+               parent: Union[OpView, Operation, Value],
+               index: Optional[Union[IntegerAttr, int]] = None,
+               *,
+               loc=None,
+               ip=None):
+    parent = _get_value(parent)
+    index = index if index is None else _get_int_attr(32, index)
+    super().__init__(result, parent, index, loc=loc, ip=ip)
+
+
+class RewriteOp:
+  """Specialization for PDL rewrite op class."""
+
+  def __init__(self,
+               root: Optional[Union[OpView, Operation, Value]] = None,
+               name: Optional[Union[StringAttr, str]] = None,
+               args: Sequence[Union[OpView, Operation, Value]] = [],
+               params: Optional[Union[ArrayAttr, Sequence[Attribute]]] = None,
+               *,
+               loc=None,
+               ip=None):
+    root = root if root is None else _get_value(root)
+    name = name if name is None else _get_str_attr(name)
+    args = _get_values(args)
+    params = params if params is None else _get_array_attr(params)
+    super().__init__(root, name, args, params, loc=loc, ip=ip)
+
+  def add_body(self):
+    """Add body (block) to the rewrite."""
+    self.regions[0].blocks.append()
+    return self.body
+
+  @property
+  def body(self):
+    """Return the body (block) of the rewrite."""
+    return self.regions[0].blocks[0]
+
+
+class TypeOp:
+  """Specialization for PDL type op class."""
+
+  def __init__(self,
+               type: Optional[Union[TypeAttr, Type]] = None,
+               *,
+               loc=None,
+               ip=None):
+    type = type if type is None else _get_type_attr(type)
+    result = pdl.TypeType.get()
+    super().__init__(result, type, loc=loc, ip=ip)
+
+
+class TypesOp:
+  """Specialization for PDL types op class."""
+
+  def __init__(self,
+               types: Sequence[Union[TypeAttr, Type]] = [],
+               *,
+               loc=None,
+               ip=None):
+    types = _get_array_attr([_get_type_attr(ty) for ty in types])
+    types = None if not types else types
+    result = pdl.RangeType.get(pdl.TypeType.get())
+    super().__init__(result, types, loc=loc, ip=ip)

diff  --git a/mlir/python/mlir/dialects/pdl.py b/mlir/python/mlir/dialects/pdl.py
new file mode 100644
index 0000000000000..dda2b7d652196
--- /dev/null
+++ b/mlir/python/mlir/dialects/pdl.py
@@ -0,0 +1,6 @@
+#  Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+#  See https://llvm.org/LICENSE.txt for license information.
+#  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+from ._pdl_ops_gen import *
+from .._mlir_libs._mlirDialectsPDL import *

diff  --git a/mlir/test/CAPI/pdl.c b/mlir/test/CAPI/pdl.c
index b7c59cc4579dd..c7ef98b4d6e76 100644
--- a/mlir/test/CAPI/pdl.c
+++ b/mlir/test/CAPI/pdl.c
@@ -146,6 +146,7 @@ void testRangeType(MlirContext ctx) {
   MlirType parsedType = mlirTypeParseGet(
       ctx, mlirStringRefCreateFromCString("!pdl.range<type>"));
   MlirType constructedType = mlirPDLRangeTypeGet(typeType);
+  MlirType elementType = mlirPDLRangeTypeGetElementType(constructedType);
 
   assert(!mlirTypeIsNull(typeType) && "couldn't get PDLTypeType");
   assert(!mlirTypeIsNull(parsedType) && "couldn't parse PDLAttributeType");
@@ -191,11 +192,15 @@ void testRangeType(MlirContext ctx) {
 
   // CHECK: equal: 1
   fprintf(stderr, "equal: %d\n", mlirTypeEqual(parsedType, constructedType));
+  // CHECK: equal: 1
+  fprintf(stderr, "equal: %d\n", mlirTypeEqual(typeType, elementType));
 
   // CHECK: !pdl.range<type>
   mlirTypeDump(parsedType);
   // CHECK: !pdl.range<type>
   mlirTypeDump(constructedType);
+  // CHECK: !pdl.type
+  mlirTypeDump(elementType);
 
   fprintf(stderr, "\n\n");
 }

diff  --git a/mlir/test/python/dialects/pdl_ops.py b/mlir/test/python/dialects/pdl_ops.py
new file mode 100644
index 0000000000000..9b5ce4c533b8e
--- /dev/null
+++ b/mlir/test/python/dialects/pdl_ops.py
@@ -0,0 +1,318 @@
+# RUN: %PYTHON %s | FileCheck %s
+
+from mlir.ir import *
+from mlir.dialects.pdl import *
+
+
+def constructAndPrintInModule(f):
+  print("\nTEST:", f.__name__)
+  with Context(), Location.unknown():
+    module = Module.create()
+    with InsertionPoint(module.body):
+      f()
+    print(module)
+  return f
+
+
+# CHECK: module  {
+# CHECK:   pdl.pattern @operations : benefit(1)  {
+# CHECK:     %0 = pdl.attribute
+# CHECK:     %1 = pdl.type
+# CHECK:     %2 = pdl.operation  {"attr" = %0} -> (%1 : !pdl.type)
+# CHECK:     %3 = pdl.result 0 of %2
+# CHECK:     %4 = pdl.operand
+# CHECK:     %5 = pdl.operation(%3, %4 : !pdl.value, !pdl.value)
+# CHECK:     pdl.rewrite %5 with "rewriter"
+# CHECK:   }
+# CHECK: }
+@constructAndPrintInModule
+def test_operations():
+  pattern = PatternOp(1, "operations")
+  with InsertionPoint(pattern.body):
+    attr = AttributeOp()
+    ty = TypeOp()
+    op0 = OperationOp(attributes={"attr": attr}, types=[ty])
+    op0_result = ResultOp(op0, 0)
+    input = OperandOp()
+    root = OperationOp(args=[op0_result, input])
+    RewriteOp(root, "rewriter")
+
+
+# CHECK: module  {
+# CHECK:   pdl.pattern @rewrite_with_args : benefit(1)  {
+# CHECK:     %0 = pdl.operand
+# CHECK:     %1 = pdl.operation(%0 : !pdl.value)
+# CHECK:     pdl.rewrite %1 with "rewriter"(%0 : !pdl.value)
+# CHECK:   }
+# CHECK: }
+@constructAndPrintInModule
+def test_rewrite_with_args():
+  pattern = PatternOp(1, "rewrite_with_args")
+  with InsertionPoint(pattern.body):
+    input = OperandOp()
+    root = OperationOp(args=[input])
+    RewriteOp(root, "rewriter", args=[input])
+
+# CHECK: module  {
+# CHECK:   pdl.pattern @rewrite_with_params : benefit(1)  {
+# CHECK:     %0 = pdl.operation
+# CHECK:     pdl.rewrite %0 with "rewriter" ["I am param"]
+# CHECK:   }
+# CHECK: }
+@constructAndPrintInModule
+def test_rewrite_with_params():
+  pattern = PatternOp(1, "rewrite_with_params")
+  with InsertionPoint(pattern.body):
+    op = OperationOp()
+    RewriteOp(op, "rewriter", params=[StringAttr.get("I am param")])
+
+# CHECK: module  {
+# CHECK:   pdl.pattern @rewrite_with_args_and_params : benefit(1)  {
+# CHECK:     %0 = pdl.operand
+# CHECK:     %1 = pdl.operation(%0 : !pdl.value)
+# CHECK:     pdl.rewrite %1 with "rewriter" ["I am param"](%0 : !pdl.value)
+# CHECK:   }
+# CHECK: }
+@constructAndPrintInModule
+def test_rewrite_with_args_and_params():
+  pattern = PatternOp(1, "rewrite_with_args_and_params")
+  with InsertionPoint(pattern.body):
+    input = OperandOp()
+    root = OperationOp(args=[input])
+    RewriteOp(root, "rewriter", params=[StringAttr.get("I am param")], args=[input])
+
+# CHECK: module  {
+# CHECK:   pdl.pattern @rewrite_multi_root_optimal : benefit(1)  {
+# CHECK:     %0 = pdl.operand
+# CHECK:     %1 = pdl.operand
+# CHECK:     %2 = pdl.type
+# CHECK:     %3 = pdl.operation(%0 : !pdl.value)  -> (%2 : !pdl.type)
+# CHECK:     %4 = pdl.result 0 of %3
+# CHECK:     %5 = pdl.operation(%4 : !pdl.value)
+# CHECK:     %6 = pdl.operation(%1 : !pdl.value)  -> (%2 : !pdl.type)
+# CHECK:     %7 = pdl.result 0 of %6
+# CHECK:     %8 = pdl.operation(%4, %7 : !pdl.value, !pdl.value)
+# CHECK:     pdl.rewrite with "rewriter" ["I am param"](%5, %8 : !pdl.operation, !pdl.operation)
+# CHECK:   }
+# CHECK: }
+@constructAndPrintInModule
+def test_rewrite_multi_root_optimal():
+  pattern = PatternOp(1, "rewrite_multi_root_optimal")
+  with InsertionPoint(pattern.body):
+    input1 = OperandOp()
+    input2 = OperandOp()
+    ty = TypeOp()
+    op1 = OperationOp(args=[input1], types=[ty])
+    val1 = ResultOp(op1, 0)
+    root1 = OperationOp(args=[val1])
+    op2 = OperationOp(args=[input2], types=[ty])
+    val2 = ResultOp(op2, 0)
+    root2 = OperationOp(args=[val1, val2])
+    RewriteOp(name="rewriter", params=[StringAttr.get("I am param")], args=[root1, root2])
+
+# CHECK: module  {
+# CHECK:   pdl.pattern @rewrite_multi_root_forced : benefit(1)  {
+# CHECK:     %0 = pdl.operand
+# CHECK:     %1 = pdl.operand
+# CHECK:     %2 = pdl.type
+# CHECK:     %3 = pdl.operation(%0 : !pdl.value)  -> (%2 : !pdl.type)
+# CHECK:     %4 = pdl.result 0 of %3
+# CHECK:     %5 = pdl.operation(%4 : !pdl.value)
+# CHECK:     %6 = pdl.operation(%1 : !pdl.value)  -> (%2 : !pdl.type)
+# CHECK:     %7 = pdl.result 0 of %6
+# CHECK:     %8 = pdl.operation(%4, %7 : !pdl.value, !pdl.value)
+# CHECK:     pdl.rewrite %5 with "rewriter" ["I am param"](%8 : !pdl.operation)
+# CHECK:   }
+# CHECK: }
+@constructAndPrintInModule
+def test_rewrite_multi_root_forced():
+  pattern = PatternOp(1, "rewrite_multi_root_forced")
+  with InsertionPoint(pattern.body):
+    input1 = OperandOp()
+    input2 = OperandOp()
+    ty = TypeOp()
+    op1 = OperationOp(args=[input1], types=[ty])
+    val1 = ResultOp(op1, 0)
+    root1 = OperationOp(args=[val1])
+    op2 = OperationOp(args=[input2], types=[ty])
+    val2 = ResultOp(op2, 0)
+    root2 = OperationOp(args=[val1, val2])
+    RewriteOp(root1, name="rewriter", params=[StringAttr.get("I am param")], args=[root2])
+
+# CHECK: module  {
+# CHECK:   pdl.pattern @rewrite_add_body : benefit(1)  {
+# CHECK:     %0 = pdl.type : i32
+# CHECK:     %1 = pdl.type
+# CHECK:     %2 = pdl.operation  -> (%0, %1 : !pdl.type, !pdl.type)
+# CHECK:     pdl.rewrite %2  {
+# CHECK:       %3 = pdl.type
+# CHECK:       %4 = pdl.operation "foo.op"  -> (%0, %3 : !pdl.type, !pdl.type)
+# CHECK:       pdl.replace %2 with %4
+# CHECK:     }
+# CHECK:   }
+# CHECK: }
+@constructAndPrintInModule
+def test_rewrite_add_body():
+  pattern = PatternOp(1, "rewrite_add_body")
+  with InsertionPoint(pattern.body):
+    ty1 = TypeOp(IntegerType.get_signless(32))
+    ty2 = TypeOp()
+    root = OperationOp(types=[ty1, ty2])
+    rewrite = RewriteOp(root)
+    with InsertionPoint(rewrite.add_body()):
+      ty3 = TypeOp()
+      newOp = OperationOp(name="foo.op", types=[ty1, ty3])
+      ReplaceOp(root, with_op=newOp)
+
+# CHECK: module  {
+# CHECK:   pdl.pattern @rewrite_type : benefit(1)  {
+# CHECK:     %0 = pdl.type : i32
+# CHECK:     %1 = pdl.type
+# CHECK:     %2 = pdl.operation  -> (%0, %1 : !pdl.type, !pdl.type)
+# CHECK:     pdl.rewrite %2  {
+# CHECK:       %3 = pdl.operation "foo.op"  -> (%0, %1 : !pdl.type, !pdl.type)
+# CHECK:     }
+# CHECK:   }
+# CHECK: }
+@constructAndPrintInModule
+def test_rewrite_type():
+  pattern = PatternOp(1, "rewrite_type")
+  with InsertionPoint(pattern.body):
+    ty1 = TypeOp(IntegerType.get_signless(32))
+    ty2 = TypeOp()
+    root = OperationOp(types=[ty1, ty2])
+    rewrite = RewriteOp(root)
+    with InsertionPoint(rewrite.add_body()):
+      newOp = OperationOp(name="foo.op", types=[ty1, ty2])
+
+# CHECK: module  {
+# CHECK:   pdl.pattern @rewrite_types : benefit(1)  {
+# CHECK:     %0 = pdl.types
+# CHECK:     %1 = pdl.operation  -> (%0 : !pdl.range<type>)
+# CHECK:     pdl.rewrite %1  {
+# CHECK:       %2 = pdl.types : [i32, i64]
+# CHECK:       %3 = pdl.operation "foo.op"  -> (%0, %2 : !pdl.range<type>, !pdl.range<type>)
+# CHECK:     }
+# CHECK:   }
+# CHECK: }
+@constructAndPrintInModule
+def test_rewrite_types():
+  pattern = PatternOp(1, "rewrite_types")
+  with InsertionPoint(pattern.body):
+    types = TypesOp()
+    root = OperationOp(types=[types])
+    rewrite = RewriteOp(root)
+    with InsertionPoint(rewrite.add_body()):
+      otherTypes = TypesOp([IntegerType.get_signless(32), IntegerType.get_signless(64)])
+      newOp = OperationOp(name="foo.op", types=[types, otherTypes])
+
+# CHECK: module  {
+# CHECK:   pdl.pattern @rewrite_operands : benefit(1)  {
+# CHECK:     %0 = pdl.types
+# CHECK:     %1 = pdl.operands : %0
+# CHECK:     %2 = pdl.operation(%1 : !pdl.range<value>)
+# CHECK:     pdl.rewrite %2  {
+# CHECK:       %3 = pdl.operation "foo.op"  -> (%0 : !pdl.range<type>)
+# CHECK:     }
+# CHECK:   }
+# CHECK: }
+@constructAndPrintInModule
+def test_rewrite_operands():
+  pattern = PatternOp(1, "rewrite_operands")
+  with InsertionPoint(pattern.body):
+    types = TypesOp()
+    operands = OperandsOp(types)
+    root = OperationOp(args=[operands])
+    rewrite = RewriteOp(root)
+    with InsertionPoint(rewrite.add_body()):
+      newOp = OperationOp(name="foo.op", types=[types])
+
+# CHECK: module  {
+# CHECK:   pdl.pattern @native_rewrite : benefit(1)  {
+# CHECK:     %0 = pdl.operation
+# CHECK:     pdl.rewrite %0  {
+# CHECK:       pdl.apply_native_rewrite "NativeRewrite"(%0 : !pdl.operation)
+# CHECK:     }
+# CHECK:   }
+# CHECK: }
+@constructAndPrintInModule
+def test_native_rewrite():
+  pattern = PatternOp(1, "native_rewrite")
+  with InsertionPoint(pattern.body):
+    root = OperationOp()
+    rewrite = RewriteOp(root)
+    with InsertionPoint(rewrite.add_body()):
+      ApplyNativeRewriteOp([], "NativeRewrite", args=[root])
+
+# CHECK: module  {
+# CHECK:   pdl.pattern @attribute_with_value : benefit(1)  {
+# CHECK:     %0 = pdl.operation
+# CHECK:     pdl.rewrite %0  {
+# CHECK:       %1 = pdl.attribute "value"
+# CHECK:       pdl.apply_native_rewrite "NativeRewrite"(%1 : !pdl.attribute)
+# CHECK:     }
+# CHECK:   }
+# CHECK: }
+@constructAndPrintInModule
+def test_attribute_with_value():
+  pattern = PatternOp(1, "attribute_with_value")
+  with InsertionPoint(pattern.body):
+    root = OperationOp()
+    rewrite = RewriteOp(root)
+    with InsertionPoint(rewrite.add_body()):
+      attr = AttributeOp(value=Attribute.parse('"value"'))
+      ApplyNativeRewriteOp([], "NativeRewrite", args=[attr])
+
+# CHECK: module  {
+# CHECK:   pdl.pattern @erase : benefit(1)  {
+# CHECK:     %0 = pdl.operation
+# CHECK:     pdl.rewrite %0  {
+# CHECK:       pdl.erase %0
+# CHECK:     }
+# CHECK:   }
+# CHECK: }
+@constructAndPrintInModule
+def test_erase():
+  pattern = PatternOp(1, "erase")
+  with InsertionPoint(pattern.body):
+    root = OperationOp()
+    rewrite = RewriteOp(root)
+    with InsertionPoint(rewrite.add_body()):
+      EraseOp(root)
+
+# CHECK: module  {
+# CHECK:   pdl.pattern @operation_results : benefit(1)  {
+# CHECK:     %0 = pdl.types
+# CHECK:     %1 = pdl.operation  -> (%0 : !pdl.range<type>)
+# CHECK:     %2 = pdl.results of %1
+# CHECK:     %3 = pdl.operation(%2 : !pdl.range<value>)
+# CHECK:     pdl.rewrite %3 with "rewriter"
+# CHECK:   }
+# CHECK: }
+@constructAndPrintInModule
+def test_operation_results():
+  valueRange = RangeType.get(ValueType.get())
+  pattern = PatternOp(1, "operation_results")
+  with InsertionPoint(pattern.body):
+    types = TypesOp()
+    inputOp = OperationOp(types=[types])
+    results = ResultsOp(valueRange, inputOp)
+    root = OperationOp(args=[results])
+    RewriteOp(root, name="rewriter")
+
+# CHECK: module  {
+# CHECK:   pdl.pattern : benefit(1)  {
+# CHECK:     %0 = pdl.type
+# CHECK:     pdl.apply_native_constraint "typeConstraint" [](%0 : !pdl.type)
+# CHECK:     %1 = pdl.operation  -> (%0 : !pdl.type)
+# CHECK:     pdl.rewrite %1 with "rewrite"
+# CHECK:   }
+# CHECK: }
+@constructAndPrintInModule
+def test_apply_native_constraint():
+  pattern = PatternOp(1)
+  with InsertionPoint(pattern.body):
+    resultType = TypeOp()
+    ApplyNativeConstraintOp("typeConstraint", args=[resultType], params=[])
+    root = OperationOp(types=[resultType])
+    RewriteOp(root, name="rewrite")

diff  --git a/mlir/test/python/dialects/pdl_types.py b/mlir/test/python/dialects/pdl_types.py
new file mode 100644
index 0000000000000..16a41e2a4c1ce
--- /dev/null
+++ b/mlir/test/python/dialects/pdl_types.py
@@ -0,0 +1,150 @@
+# RUN: %PYTHON %s | FileCheck %s
+
+from mlir.ir import *
+from mlir.dialects import pdl
+
+
+def run(f):
+  print("\nTEST:", f.__name__)
+  f()
+  return f
+
+
+# CHECK-LABEL: TEST: test_attribute_type
+@run
+def test_attribute_type():
+  with Context():
+    parsedType = Type.parse("!pdl.attribute")
+    constructedType = pdl.AttributeType.get()
+
+    assert pdl.AttributeType.isinstance(parsedType)
+    assert not pdl.OperationType.isinstance(parsedType)
+    assert not pdl.RangeType.isinstance(parsedType)
+    assert not pdl.TypeType.isinstance(parsedType)
+    assert not pdl.ValueType.isinstance(parsedType)
+
+    assert pdl.AttributeType.isinstance(constructedType)
+    assert not pdl.OperationType.isinstance(constructedType)
+    assert not pdl.RangeType.isinstance(constructedType)
+    assert not pdl.TypeType.isinstance(constructedType)
+    assert not pdl.ValueType.isinstance(constructedType)
+
+    assert parsedType == constructedType
+
+    # CHECK: !pdl.attribute
+    print(parsedType)
+    # CHECK: !pdl.attribute
+    print(constructedType)
+
+
+# CHECK-LABEL: TEST: test_operation_type
+@run
+def test_operation_type():
+  with Context():
+    parsedType = Type.parse("!pdl.operation")
+    constructedType = pdl.OperationType.get()
+
+    assert not pdl.AttributeType.isinstance(parsedType)
+    assert pdl.OperationType.isinstance(parsedType)
+    assert not pdl.RangeType.isinstance(parsedType)
+    assert not pdl.TypeType.isinstance(parsedType)
+    assert not pdl.ValueType.isinstance(parsedType)
+
+    assert not pdl.AttributeType.isinstance(constructedType)
+    assert pdl.OperationType.isinstance(constructedType)
+    assert not pdl.RangeType.isinstance(constructedType)
+    assert not pdl.TypeType.isinstance(constructedType)
+    assert not pdl.ValueType.isinstance(constructedType)
+
+    assert parsedType == constructedType
+
+    # CHECK: !pdl.operation
+    print(parsedType)
+    # CHECK: !pdl.operation
+    print(constructedType)
+
+
+# CHECK-LABEL: TEST: test_range_type
+@run
+def test_range_type():
+  with Context():
+    typeType = Type.parse("!pdl.type")
+    parsedType = Type.parse("!pdl.range<type>")
+    constructedType = pdl.RangeType.get(typeType)
+    elementType = constructedType.element_type
+
+    assert not pdl.AttributeType.isinstance(parsedType)
+    assert not pdl.OperationType.isinstance(parsedType)
+    assert pdl.RangeType.isinstance(parsedType)
+    assert not pdl.TypeType.isinstance(parsedType)
+    assert not pdl.ValueType.isinstance(parsedType)
+
+    assert not pdl.AttributeType.isinstance(constructedType)
+    assert not pdl.OperationType.isinstance(constructedType)
+    assert pdl.RangeType.isinstance(constructedType)
+    assert not pdl.TypeType.isinstance(constructedType)
+    assert not pdl.ValueType.isinstance(constructedType)
+
+    assert parsedType == constructedType
+    assert elementType == typeType
+
+    # CHECK: !pdl.range<type>
+    print(parsedType)
+    # CHECK: !pdl.range<type>
+    print(constructedType)
+    # CHECK: !pdl.type
+    print(elementType)
+
+
+# CHECK-LABEL: TEST: test_type_type
+@run
+def test_type_type():
+  with Context():
+    parsedType = Type.parse("!pdl.type")
+    constructedType = pdl.TypeType.get()
+
+    assert not pdl.AttributeType.isinstance(parsedType)
+    assert not pdl.OperationType.isinstance(parsedType)
+    assert not pdl.RangeType.isinstance(parsedType)
+    assert pdl.TypeType.isinstance(parsedType)
+    assert not pdl.ValueType.isinstance(parsedType)
+
+    assert not pdl.AttributeType.isinstance(constructedType)
+    assert not pdl.OperationType.isinstance(constructedType)
+    assert not pdl.RangeType.isinstance(constructedType)
+    assert pdl.TypeType.isinstance(constructedType)
+    assert not pdl.ValueType.isinstance(constructedType)
+
+    assert parsedType == constructedType
+
+    # CHECK: !pdl.type
+    print(parsedType)
+    # CHECK: !pdl.type
+    print(constructedType)
+
+
+# CHECK-LABEL: TEST: test_value_type
+@run
+def test_value_type():
+  with Context():
+    parsedType = Type.parse("!pdl.value")
+    constructedType = pdl.ValueType.get()
+
+    assert not pdl.AttributeType.isinstance(parsedType)
+    assert not pdl.OperationType.isinstance(parsedType)
+    assert not pdl.RangeType.isinstance(parsedType)
+    assert not pdl.TypeType.isinstance(parsedType)
+    assert pdl.ValueType.isinstance(parsedType)
+
+    assert not pdl.AttributeType.isinstance(constructedType)
+    assert not pdl.OperationType.isinstance(constructedType)
+    assert not pdl.RangeType.isinstance(constructedType)
+    assert not pdl.TypeType.isinstance(constructedType)
+    assert pdl.ValueType.isinstance(constructedType)
+
+    assert parsedType == constructedType
+
+    # CHECK: !pdl.value
+    print(parsedType)
+    # CHECK: !pdl.value
+    print(constructedType)


        


More information about the Mlir-commits mailing list