[Mlir-commits] [mlir] Fix how the mlir variadic Python accessor `_ods_equally_sized_accessor` is used (#101132) (PR #106003)

Kasper Nielsen llvmlistbot at llvm.org
Sun Aug 25 12:23:25 PDT 2024


https://github.com/kasper0406 created https://github.com/llvm/llvm-project/pull/106003

As reported in https://github.com/llvm/llvm-project/issues/101132, this fixes two bugs:

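1. The generated properties for variadic operands and results must pass `self.operation.operands` (or `self.operation.results`) to the accessor, not `operation.operands`: the bare name `operation` is not defined inside the generated property.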
2. The `equally_sized_accessor` function computes the total length of the variadic groups incorrectly (as `len(elements) - n_variadic + 1`), so the resulting start index and group size are wrong.

This code previously had no tests. I have added tests for the `equally_sized_accessor` function and updated the generated-binding CHECK lines accordingly.


>From c35573e237a135321e334b95452283abfa315601 Mon Sep 17 00:00:00 2001
From: Kasper Nielsen <kasper0406 at gmail.com>
Date: Wed, 21 Aug 2024 16:55:48 +0200
Subject: [PATCH 1/2] Fix MLIR Python bindings when trying to access variadic
 elements

---
 mlir/python/mlir/dialects/_ods_common.py      | 2 +-
 mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/mlir/python/mlir/dialects/_ods_common.py b/mlir/python/mlir/dialects/_ods_common.py
index 1e7e8244ed4420..0b56a376d23813 100644
--- a/mlir/python/mlir/dialects/_ods_common.py
+++ b/mlir/python/mlir/dialects/_ods_common.py
@@ -65,7 +65,7 @@ def equally_sized_accessor(
           group.
     """
 
-    total_variadic_length = len(elements) - n_variadic + 1
+    total_variadic_length = len(elements) - n_preceding_simple
     # This should be enforced by the C++-side trait verifier.
     assert total_variadic_length % n_variadic == 0
 
diff --git a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
index 052020acdcb764..97f2cc6c3f5763 100644
--- a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
@@ -145,7 +145,7 @@ constexpr const char *opOneVariadicTemplate = R"Py(
 constexpr const char *opVariadicEqualPrefixTemplate = R"Py(
   @builtins.property
   def {0}(self):
-    start, pg = _ods_equally_sized_accessor(operation.{1}s, {2}, {3}, {4}))Py";
+    start, pg = _ods_equally_sized_accessor(self.operation.{1}s, {2}, {3}, {4}))Py";
 
 /// Second part of the template for equally-sized case, accessing a single
 /// element:

>From c1e68b4ebd0b0e0d4638ec23dc53da4e1886d911 Mon Sep 17 00:00:00 2001
From: Kasper Nielsen <kasper0406 at gmail.com>
Date: Sun, 25 Aug 2024 21:15:15 +0200
Subject: [PATCH 2/2] Add and fix tests

---
 mlir/test/mlir-tblgen/op-python-bindings.td | 12 ++--
 mlir/test/python/dialects/ods_helpers.py    | 67 +++++++++++++++++++++
 2 files changed, 73 insertions(+), 6 deletions(-)

diff --git a/mlir/test/mlir-tblgen/op-python-bindings.td b/mlir/test/mlir-tblgen/op-python-bindings.td
index 9f202ba08608c6..d0642968854fe5 100644
--- a/mlir/test/mlir-tblgen/op-python-bindings.td
+++ b/mlir/test/mlir-tblgen/op-python-bindings.td
@@ -480,17 +480,17 @@ def SameVariadicOperandSizeOp : TestOp<"same_variadic_operand",
                                        [SameVariadicOperandSize]> {
   // CHECK: @builtins.property
   // CHECK: def variadic1(self):
-  // CHECK:   start, pg = _ods_equally_sized_accessor(operation.operands, 2, 0, 0)
+  // CHECK:   start, pg = _ods_equally_sized_accessor(self.operation.operands, 2, 0, 0)
   // CHECK:   return self.operation.operands[start:start + pg]
   //
   // CHECK: @builtins.property
   // CHECK: def non_variadic(self):
-  // CHECK:   start, pg = _ods_equally_sized_accessor(operation.operands, 2, 0, 1)
+  // CHECK:   start, pg = _ods_equally_sized_accessor(self.operation.operands, 2, 0, 1)
   // CHECK:   return self.operation.operands[start]
   //
   // CHECK: @builtins.property
   // CHECK: def variadic2(self):
-  // CHECK:   start, pg = _ods_equally_sized_accessor(operation.operands, 2, 1, 1)
+  // CHECK:   start, pg = _ods_equally_sized_accessor(self.operation.operands, 2, 1, 1)
   // CHECK:   return self.operation.operands[start:start + pg]
   let arguments = (ins Variadic<AnyType>:$variadic1, AnyType:$non_variadic,
                    Variadic<AnyType>:$variadic2);
@@ -506,17 +506,17 @@ def SameVariadicResultSizeOp : TestOp<"same_variadic_result",
                                       [SameVariadicResultSize]> {
   // CHECK: @builtins.property
   // CHECK: def variadic1(self):
-  // CHECK:   start, pg = _ods_equally_sized_accessor(operation.results, 2, 0, 0)
+  // CHECK:   start, pg = _ods_equally_sized_accessor(self.operation.results, 2, 0, 0)
   // CHECK:   return self.operation.results[start:start + pg]
   //
   // CHECK: @builtins.property
   // CHECK: def non_variadic(self):
-  // CHECK:   start, pg = _ods_equally_sized_accessor(operation.results, 2, 0, 1)
+  // CHECK:   start, pg = _ods_equally_sized_accessor(self.operation.results, 2, 0, 1)
   // CHECK:   return self.operation.results[start]
   //
   // CHECK: @builtins.property
   // CHECK: def variadic2(self):
-  // CHECK:   start, pg = _ods_equally_sized_accessor(operation.results, 2, 1, 1)
+  // CHECK:   start, pg = _ods_equally_sized_accessor(self.operation.results, 2, 1, 1)
   // CHECK:   return self.operation.results[start:start + pg]
   let results = (outs Variadic<AnyType>:$variadic1, AnyType:$non_variadic,
                  Variadic<AnyType>:$variadic2);
diff --git a/mlir/test/python/dialects/ods_helpers.py b/mlir/test/python/dialects/ods_helpers.py
index 0d2a18e0eb0af2..cb0d0528a6a866 100644
--- a/mlir/test/python/dialects/ods_helpers.py
+++ b/mlir/test/python/dialects/ods_helpers.py
@@ -3,6 +3,7 @@
 import gc
 
 from mlir.ir import *
+from mlir.dialects._ods_common import equally_sized_accessor
 
 
 def run(f):
@@ -208,3 +209,69 @@ class TestOp(OpView):
 
 
 run(testOdsBuildDefaultCastError)
+
+
+def testOdsEquallySizedAccessor():
+    class TestOpMultiResultSegments(OpView):
+        OPERATION_NAME = "custom.test_op"
+        _ODS_REGIONS = (1, True)
+
+    with Context() as ctx, Location.unknown():
+        ctx.allow_unregistered_dialects = True
+        m = Module.create()
+        with InsertionPoint(m.body):
+            v = add_dummy_value()
+            ts = [IntegerType.get_signless(i * 8) for i in range(4)]
+
+            op = TestOpMultiResultSegments.build_generic(
+                results=[ts[0], ts[1], ts[2], ts[3]], operands=[v]
+            )
+            start, pg = equally_sized_accessor(op.results, 3, 1, 0)
+            # CHECK: start: 1, pg: 1
+            print(f"start: {start}, pg: {pg}")
+            # CHECK: i8
+            print(op.results[start].type)
+
+            start, pg = equally_sized_accessor(op.results, 3, 1, 1)
+            # CHECK: start: 2, pg: 1
+            print(f"start: {start}, pg: {pg}")
+            # CHECK: i16
+            print(op.results[start].type)
+
+
+run(testOdsEquallySizedAccessor)
+
+
+def testOdsEquallySizedAccessorMultipleSegments():
+    class TestOpMultiResultSegments(OpView):
+        OPERATION_NAME = "custom.test_op"
+        _ODS_REGIONS = (1, True)
+        _ODS_RESULT_SEGMENTS = [0, -1, -1]
+
+    def types(lst):
+        return [e.type for e in lst]
+
+    with Context() as ctx, Location.unknown():
+        ctx.allow_unregistered_dialects = True
+        m = Module.create()
+        with InsertionPoint(m.body):
+            v = add_dummy_value()
+            ts = [IntegerType.get_signless(i * 8) for i in range(7)]
+
+            op = TestOpMultiResultSegments.build_generic(
+                results=[ts[0], [ts[1], ts[2], ts[3]], [ts[4], ts[5], ts[6]]], operands=[v]
+            )
+            start, pg = equally_sized_accessor(op.results, 2, 1, 0)
+            # CHECK: start: 1, pg: 3
+            print(f"start: {start}, pg: {pg}")
+            # CHECK: [IntegerType(i8), IntegerType(i16), IntegerType(i24)]
+            print(types(op.results[start:start + pg]))
+
+            start, pg = equally_sized_accessor(op.results, 2, 1, 1)
+            # CHECK: start: 4, pg: 3
+            print(f"start: {start}, pg: {pg}")
+            # CHECK: [IntegerType(i32), IntegerType(i40), IntegerType(i48)]
+            print(types(op.results[start:start + pg]))
+
+
+run(testOdsEquallySizedAccessorMultipleSegments)



More information about the Mlir-commits mailing list