[llvm] [mlir] [mlir] Add PDL C & Python usage (PR #94714)

Maksim Levental via llvm-commits llvm-commits at lists.llvm.org
Sun Jun 9 17:18:18 PDT 2024


================
@@ -0,0 +1,67 @@
+# RUN: %PYTHON %s 2>&1 | FileCheck %s
+
+from mlir.dialects import arith, func, pdl
+from mlir.dialects.builtin import module
+from mlir.ir import *
+from mlir.rewrite import *
+
+
+def construct_and_print_in_module(f):
+    print("\nTEST:", f.__name__)
+    with Context(), Location.unknown():
+        module = Module.create()
+        with InsertionPoint(module.body):
+            module = f(module)
+        if module is not None:
+            print(module)
+    return f
+
+
+# CHECK-LABEL: TEST: test_add_to_mul
+# CHECK: arith.muli
+@construct_and_print_in_module
+def test_add_to_mul(module_):
+    index_type = IndexType.get()
+
+    # Create a test case.
+    @module(sym_name="ir")
+    def ir():
+        @func.func(index_type, index_type)
+        def add_func(a, b):
+            return arith.addi(a, b)
+
+    # Create a rewrite from add to mul. This will match
+    # - operation name is arith.addi
+    # - operands are index types.
+    # - there are two operands.
+    with Location.unknown():
+        m = Module.create()
+        with InsertionPoint(m.body):
+            # Change all arith.addi with index types to arith.muli.
+            pattern = pdl.PatternOp(1, "addi_to_mul")
+            with InsertionPoint(pattern.body):
+                # Match arith.addi with index types.
+                index_type = pdl.TypeOp(IndexType.get())
+                operand0 = pdl.OperandOp(index_type)
+                operand1 = pdl.OperandOp(index_type)
+                op0 = pdl.OperationOp(
+                    name="arith.addi", args=[operand0, operand1], types=[index_type]
+                )
+
+                # Replace the matched op with arith.muli.
+                rewrite = pdl.RewriteOp(op0)
+                with InsertionPoint(rewrite.add_body()):
----------------
makslevental wrote:

```suggestion
                @pdl.rewrite()
                def rew():
```

https://github.com/llvm/llvm-project/pull/94714


More information about the llvm-commits mailing list