[Mlir-commits] [mlir] 5d3489e - [mlir][transform][linalg][python] Replace pdl.operation => transform.any_op. (#66392)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Sep 15 00:06:12 PDT 2023


Author: Ingo Müller
Date: 2023-09-15T09:06:07+02:00
New Revision: 5d3489e940d47fcf2108c518b630670c3183e1c1

URL: https://github.com/llvm/llvm-project/commit/5d3489e940d47fcf2108c518b630670c3183e1c1
DIFF: https://github.com/llvm/llvm-project/commit/5d3489e940d47fcf2108c518b630670c3183e1c1.diff

LOG: [mlir][transform][linalg][python] Replace pdl.operation => transform.any_op. (#66392)

For some reason, the mix-ins of the Python bindings of the structured
transform ops used the PDL type (`pdl.operation`) for "any op". However, PDL
isn't involved here, so it makes more sense to use the corresponding type of
the transform dialect (`transform.any_op`). This PR changes that.

Added: 
    

Modified: 
    mlir/python/mlir/dialects/_structured_transform_ops_ext.py

Removed: 
    


################################################################################
diff  --git a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py
index 212fbc5badcbce8..c5134b6e718f3b8 100644
--- a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py
+++ b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py
@@ -4,7 +4,7 @@
 
 try:
     from ..ir import *
-    from ..dialects import pdl, transform
+    from ..dialects import transform
 except ImportError as e:
     raise RuntimeError("Error loading imports from extension module") from e
 
@@ -203,7 +203,8 @@ class DecomposeOp:
     """Specialization for DecomposeOp class."""
 
     def __init__(self, target: Union[Operation, Value], *, loc=None, ip=None):
-        super().__init__(pdl.OperationType.get(), target, loc=loc, ip=ip)
+        transformed_type = transform.AnyOpType.get()
+        super().__init__(transformed_type, target, loc=loc, ip=ip)
 
 
 class FuseIntoContainingOp:
@@ -274,7 +275,8 @@ class GeneralizeOp:
     """Specialization for GeneralizeOp class."""
 
     def __init__(self, target: Union[Operation, Value], *, loc=None, ip=None):
-        super().__init__(pdl.OperationType.get(), target, loc=loc, ip=ip)
+        transformed_type = transform.AnyOpType.get()
+        super().__init__(transformed_type, target, loc=loc, ip=ip)
 
 
 class InterchangeOp:
@@ -288,9 +290,9 @@ def __init__(
         loc=None,
         ip=None,
     ):
-        pdl_operation_type = pdl.OperationType.get()
+        transformed_type = transform.AnyOpType.get()
         super().__init__(
-            pdl_operation_type,
+            transformed_type,
             target,
             iterator_interchange=iterator_interchange,
             loc=loc,
@@ -503,11 +505,11 @@ def __init__(
     ):
         transpose_paddings = _get_int_array_array_attr(transpose_paddings)
 
-        pdl_operation_type = pdl.OperationType.get()
+        any_op_type = transform.AnyOpType.get()
         super().__init__(
-            pdl_operation_type,
-            pdl_operation_type,
-            pdl_operation_type,
+            any_op_type,
+            any_op_type,
+            any_op_type,
             target,
             padding_values=padding_values,
             padding_dimensions=padding_dimensions,
@@ -524,8 +526,8 @@ class ScalarizeOp:
     """Specialization for ScalarizeOp class."""
 
     def __init__(self, target: Union[Operation, Value], *, loc=None, ip=None):
-        pdl_operation_type = pdl.OperationType.get()
-        super().__init__(pdl_operation_type, target, loc=loc, ip=ip)
+        result_type = transform.AnyOpType.get()
+        super().__init__(result_type, target, loc=loc, ip=ip)
 
 
 class SplitOp:
@@ -736,9 +738,9 @@ def __init__(
         loc=None,
         ip=None,
     ):
-        pdl_operation_type = pdl.OperationType.get()
+        transformed_type = transform.AnyOpType.get()
         super().__init__(
-            pdl_operation_type,
+            transformed_type,
             target,
             disable_multi_reduction_to_contract_patterns=disable_multi_reduction_to_contract_patterns,
             disable_transfer_permutation_map_lowering_patterns=disable_transfer_permutation_map_lowering_patterns,


        


More information about the Mlir-commits mailing list