[Mlir-commits] [mlir] [mlir][python] Fix how the mlir variadic Python accessor `_ods_equally_sized_accessor` is used (#101132) (PR #106003)

Kasper Nielsen llvmlistbot at llvm.org
Fri Aug 30 06:32:05 PDT 2024


https://github.com/kasper0406 updated https://github.com/llvm/llvm-project/pull/106003

>From c35573e237a135321e334b95452283abfa315601 Mon Sep 17 00:00:00 2001
From: Kasper Nielsen <kasper0406 at gmail.com>
Date: Wed, 21 Aug 2024 16:55:48 +0200
Subject: [PATCH 1/8] Fix MLIR Python bindings when trying to access variadic
 elements

---
 mlir/python/mlir/dialects/_ods_common.py      | 2 +-
 mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/mlir/python/mlir/dialects/_ods_common.py b/mlir/python/mlir/dialects/_ods_common.py
index 1e7e8244ed4420..0b56a376d23813 100644
--- a/mlir/python/mlir/dialects/_ods_common.py
+++ b/mlir/python/mlir/dialects/_ods_common.py
@@ -65,7 +65,7 @@ def equally_sized_accessor(
           group.
     """
 
-    total_variadic_length = len(elements) - n_variadic + 1
+    total_variadic_length = len(elements) - n_preceding_simple
     # This should be enforced by the C++-side trait verifier.
     assert total_variadic_length % n_variadic == 0
 
diff --git a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
index 052020acdcb764..97f2cc6c3f5763 100644
--- a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
@@ -145,7 +145,7 @@ constexpr const char *opOneVariadicTemplate = R"Py(
 constexpr const char *opVariadicEqualPrefixTemplate = R"Py(
   @builtins.property
   def {0}(self):
-    start, pg = _ods_equally_sized_accessor(operation.{1}s, {2}, {3}, {4}))Py";
+    start, pg = _ods_equally_sized_accessor(self.operation.{1}s, {2}, {3}, {4}))Py";
 
 /// Second part of the template for equally-sized case, accessing a single
 /// element:

>From c1e68b4ebd0b0e0d4638ec23dc53da4e1886d911 Mon Sep 17 00:00:00 2001
From: Kasper Nielsen <kasper0406 at gmail.com>
Date: Sun, 25 Aug 2024 21:15:15 +0200
Subject: [PATCH 2/8] Add and fix tests

---
 mlir/test/mlir-tblgen/op-python-bindings.td | 12 ++--
 mlir/test/python/dialects/ods_helpers.py    | 67 +++++++++++++++++++++
 2 files changed, 73 insertions(+), 6 deletions(-)

diff --git a/mlir/test/mlir-tblgen/op-python-bindings.td b/mlir/test/mlir-tblgen/op-python-bindings.td
index 9f202ba08608c6..d0642968854fe5 100644
--- a/mlir/test/mlir-tblgen/op-python-bindings.td
+++ b/mlir/test/mlir-tblgen/op-python-bindings.td
@@ -480,17 +480,17 @@ def SameVariadicOperandSizeOp : TestOp<"same_variadic_operand",
                                        [SameVariadicOperandSize]> {
   // CHECK: @builtins.property
   // CHECK: def variadic1(self):
-  // CHECK:   start, pg = _ods_equally_sized_accessor(operation.operands, 2, 0, 0)
+  // CHECK:   start, pg = _ods_equally_sized_accessor(self.operation.operands, 2, 0, 0)
   // CHECK:   return self.operation.operands[start:start + pg]
   //
   // CHECK: @builtins.property
   // CHECK: def non_variadic(self):
-  // CHECK:   start, pg = _ods_equally_sized_accessor(operation.operands, 2, 0, 1)
+  // CHECK:   start, pg = _ods_equally_sized_accessor(self.operation.operands, 2, 0, 1)
   // CHECK:   return self.operation.operands[start]
   //
   // CHECK: @builtins.property
   // CHECK: def variadic2(self):
-  // CHECK:   start, pg = _ods_equally_sized_accessor(operation.operands, 2, 1, 1)
+  // CHECK:   start, pg = _ods_equally_sized_accessor(self.operation.operands, 2, 1, 1)
   // CHECK:   return self.operation.operands[start:start + pg]
   let arguments = (ins Variadic<AnyType>:$variadic1, AnyType:$non_variadic,
                    Variadic<AnyType>:$variadic2);
@@ -506,17 +506,17 @@ def SameVariadicResultSizeOp : TestOp<"same_variadic_result",
                                       [SameVariadicResultSize]> {
   // CHECK: @builtins.property
   // CHECK: def variadic1(self):
-  // CHECK:   start, pg = _ods_equally_sized_accessor(operation.results, 2, 0, 0)
+  // CHECK:   start, pg = _ods_equally_sized_accessor(self.operation.results, 2, 0, 0)
   // CHECK:   return self.operation.results[start:start + pg]
   //
   // CHECK: @builtins.property
   // CHECK: def non_variadic(self):
-  // CHECK:   start, pg = _ods_equally_sized_accessor(operation.results, 2, 0, 1)
+  // CHECK:   start, pg = _ods_equally_sized_accessor(self.operation.results, 2, 0, 1)
   // CHECK:   return self.operation.results[start]
   //
   // CHECK: @builtins.property
   // CHECK: def variadic2(self):
-  // CHECK:   start, pg = _ods_equally_sized_accessor(operation.results, 2, 1, 1)
+  // CHECK:   start, pg = _ods_equally_sized_accessor(self.operation.results, 2, 1, 1)
   // CHECK:   return self.operation.results[start:start + pg]
   let results = (outs Variadic<AnyType>:$variadic1, AnyType:$non_variadic,
                  Variadic<AnyType>:$variadic2);
diff --git a/mlir/test/python/dialects/ods_helpers.py b/mlir/test/python/dialects/ods_helpers.py
index 0d2a18e0eb0af2..cb0d0528a6a866 100644
--- a/mlir/test/python/dialects/ods_helpers.py
+++ b/mlir/test/python/dialects/ods_helpers.py
@@ -3,6 +3,7 @@
 import gc
 
 from mlir.ir import *
+from mlir.dialects._ods_common import equally_sized_accessor
 
 
 def run(f):
@@ -208,3 +209,69 @@ class TestOp(OpView):
 
 
 run(testOdsBuildDefaultCastError)
+
+
+def testOdsEquallySizedAccessor():
+    class TestOpMultiResultSegments(OpView):
+        OPERATION_NAME = "custom.test_op"
+        _ODS_REGIONS = (1, True)
+
+    with Context() as ctx, Location.unknown():
+        ctx.allow_unregistered_dialects = True
+        m = Module.create()
+        with InsertionPoint(m.body):
+            v = add_dummy_value()
+            ts = [IntegerType.get_signless(i * 8) for i in range(4)]
+
+            op = TestOpMultiResultSegments.build_generic(
+                results=[ts[0], ts[1], ts[2], ts[3]], operands=[v]
+            )
+            start, pg = equally_sized_accessor(op.results, 3, 1, 0)
+            # CHECK: start: 1, pg: 1
+            print(f"start: {start}, pg: {pg}")
+            # CHECK: i8
+            print(op.results[start].type)
+
+            start, pg = equally_sized_accessor(op.results, 3, 1, 1)
+            # CHECK: start: 2, pg: 1
+            print(f"start: {start}, pg: {pg}")
+            # CHECK: i16
+            print(op.results[start].type)
+
+
+run(testOdsEquallySizedAccessor)
+
+
+def testOdsEquallySizedAccessorMultipleSegments():
+    class TestOpMultiResultSegments(OpView):
+        OPERATION_NAME = "custom.test_op"
+        _ODS_REGIONS = (1, True)
+        _ODS_RESULT_SEGMENTS = [0, -1, -1]
+
+    def types(lst):
+        return [e.type for e in lst]
+
+    with Context() as ctx, Location.unknown():
+        ctx.allow_unregistered_dialects = True
+        m = Module.create()
+        with InsertionPoint(m.body):
+            v = add_dummy_value()
+            ts = [IntegerType.get_signless(i * 8) for i in range(7)]
+
+            op = TestOpMultiResultSegments.build_generic(
+                results=[ts[0], [ts[1], ts[2], ts[3]], [ts[4], ts[5], ts[6]]], operands=[v]
+            )
+            start, pg = equally_sized_accessor(op.results, 2, 1, 0)
+            # CHECK: start: 1, pg: 3
+            print(f"start: {start}, pg: {pg}")
+            # CHECK: [IntegerType(i8), IntegerType(i16), IntegerType(i24)]
+            print(types(op.results[start:start + pg]))
+
+            start, pg = equally_sized_accessor(op.results, 2, 1, 1)
+            # CHECK: start: 4, pg: 3
+            print(f"start: {start}, pg: {pg}")
+            # CHECK: [IntegerType(i32), IntegerType(i40), IntegerType(i48)]
+            print(types(op.results[start:start + pg]))
+
+
+run(testOdsEquallySizedAccessorMultipleSegments)

>From 6be3f4ce254b1527c39370d979fc32c26f798f5f Mon Sep 17 00:00:00 2001
From: Kasper Nielsen <kasper0406 at gmail.com>
Date: Mon, 26 Aug 2024 08:42:36 +0200
Subject: [PATCH 3/8] Fix code style

---
 mlir/test/python/dialects/ods_helpers.py | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/mlir/test/python/dialects/ods_helpers.py b/mlir/test/python/dialects/ods_helpers.py
index cb0d0528a6a866..5e1f594503136a 100644
--- a/mlir/test/python/dialects/ods_helpers.py
+++ b/mlir/test/python/dialects/ods_helpers.py
@@ -259,19 +259,20 @@ def types(lst):
             ts = [IntegerType.get_signless(i * 8) for i in range(7)]
 
             op = TestOpMultiResultSegments.build_generic(
-                results=[ts[0], [ts[1], ts[2], ts[3]], [ts[4], ts[5], ts[6]]], operands=[v]
+                results=[ts[0], [ts[1], ts[2], ts[3]], [ts[4], ts[5], ts[6]]],
+                operands=[v]
             )
             start, pg = equally_sized_accessor(op.results, 2, 1, 0)
             # CHECK: start: 1, pg: 3
             print(f"start: {start}, pg: {pg}")
             # CHECK: [IntegerType(i8), IntegerType(i16), IntegerType(i24)]
-            print(types(op.results[start:start + pg]))
+            print(types(op.results[start : start + pg]))
 
             start, pg = equally_sized_accessor(op.results, 2, 1, 1)
             # CHECK: start: 4, pg: 3
             print(f"start: {start}, pg: {pg}")
             # CHECK: [IntegerType(i32), IntegerType(i40), IntegerType(i48)]
-            print(types(op.results[start:start + pg]))
+            print(types(op.results[start : start + pg]))
 
 
 run(testOdsEquallySizedAccessorMultipleSegments)

>From 81fd53c6db1f2796eef7f5cd1161715976b0d979 Mon Sep 17 00:00:00 2001
From: Kasper Nielsen <kasper0406 at gmail.com>
Date: Mon, 26 Aug 2024 11:25:40 +0200
Subject: [PATCH 4/8] Another coding style check

---
 mlir/test/CMakeLists.txt                 | 2 ++
 mlir/test/python/dialects/ods_helpers.py | 2 +-
 2 files changed, 3 insertions(+), 1 deletion(-)

diff --git a/mlir/test/CMakeLists.txt b/mlir/test/CMakeLists.txt
index df95e5db11f1e0..0e8de4d15014d7 100644
--- a/mlir/test/CMakeLists.txt
+++ b/mlir/test/CMakeLists.txt
@@ -211,6 +211,8 @@ if(MLIR_ENABLE_BINDINGS_PYTHON)
   )
 endif()
 
+set(MLIR_TEST_DEPENDS MLIRPythonModules)
+
 # This target can be used to just build the dependencies
 # for the check-mlir target without executing the tests.
 # This is useful for bots when splitting the build step
diff --git a/mlir/test/python/dialects/ods_helpers.py b/mlir/test/python/dialects/ods_helpers.py
index 5e1f594503136a..0bacd0e5906883 100644
--- a/mlir/test/python/dialects/ods_helpers.py
+++ b/mlir/test/python/dialects/ods_helpers.py
@@ -260,7 +260,7 @@ def types(lst):
 
             op = TestOpMultiResultSegments.build_generic(
                 results=[ts[0], [ts[1], ts[2], ts[3]], [ts[4], ts[5], ts[6]]],
-                operands=[v]
+                operands=[v],
             )
             start, pg = equally_sized_accessor(op.results, 2, 1, 0)
             # CHECK: start: 1, pg: 3

>From 9db5bb63b5faf7c3c21bbebde844a31ed64b1063 Mon Sep 17 00:00:00 2001
From: Kasper Nielsen <kasper0406 at gmail.com>
Date: Mon, 26 Aug 2024 12:34:41 +0200
Subject: [PATCH 5/8] Do not skip tests

---
 mlir/test/CMakeLists.txt | 2 --
 1 file changed, 2 deletions(-)

diff --git a/mlir/test/CMakeLists.txt b/mlir/test/CMakeLists.txt
index 0e8de4d15014d7..df95e5db11f1e0 100644
--- a/mlir/test/CMakeLists.txt
+++ b/mlir/test/CMakeLists.txt
@@ -211,8 +211,6 @@ if(MLIR_ENABLE_BINDINGS_PYTHON)
   )
 endif()
 
-set(MLIR_TEST_DEPENDS MLIRPythonModules)
-
 # This target can be used to just build the dependencies
 # for the check-mlir target without executing the tests.
 # This is useful for bots when splitting the build step

>From b861ec0d1c551fbd76b16438b9e93eed44e53a48 Mon Sep 17 00:00:00 2001
From: Kasper Nielsen <kasper0406 at gmail.com>
Date: Thu, 29 Aug 2024 16:49:43 +0200
Subject: [PATCH 6/8] Rename pg to elements_per_group

---
 mlir/test/mlir-tblgen/op-python-bindings.td   | 20 ++++++-------
 mlir/test/python/dialects/ods_helpers.py      | 28 +++++++++----------
 mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp |  4 +--
 3 files changed, 26 insertions(+), 26 deletions(-)

diff --git a/mlir/test/mlir-tblgen/op-python-bindings.td b/mlir/test/mlir-tblgen/op-python-bindings.td
index d0642968854fe5..279c9920936fcc 100644
--- a/mlir/test/mlir-tblgen/op-python-bindings.td
+++ b/mlir/test/mlir-tblgen/op-python-bindings.td
@@ -480,18 +480,18 @@ def SameVariadicOperandSizeOp : TestOp<"same_variadic_operand",
                                        [SameVariadicOperandSize]> {
   // CHECK: @builtins.property
   // CHECK: def variadic1(self):
-  // CHECK:   start, pg = _ods_equally_sized_accessor(self.operation.operands, 2, 0, 0)
-  // CHECK:   return self.operation.operands[start:start + pg]
+  // CHECK:   start, elements_per_group = _ods_equally_sized_accessor(self.operation.operands, 2, 0, 0)
+  // CHECK:   return self.operation.operands[start:start + elements_per_group]
   //
   // CHECK: @builtins.property
   // CHECK: def non_variadic(self):
-  // CHECK:   start, pg = _ods_equally_sized_accessor(self.operation.operands, 2, 0, 1)
+  // CHECK:   start, elements_per_group = _ods_equally_sized_accessor(self.operation.operands, 2, 0, 1)
   // CHECK:   return self.operation.operands[start]
   //
   // CHECK: @builtins.property
   // CHECK: def variadic2(self):
-  // CHECK:   start, pg = _ods_equally_sized_accessor(self.operation.operands, 2, 1, 1)
-  // CHECK:   return self.operation.operands[start:start + pg]
+  // CHECK:   start, elements_per_group = _ods_equally_sized_accessor(self.operation.operands, 2, 1, 1)
+  // CHECK:   return self.operation.operands[start:start + elements_per_group]
   let arguments = (ins Variadic<AnyType>:$variadic1, AnyType:$non_variadic,
                    Variadic<AnyType>:$variadic2);
 }
@@ -506,18 +506,18 @@ def SameVariadicResultSizeOp : TestOp<"same_variadic_result",
                                       [SameVariadicResultSize]> {
   // CHECK: @builtins.property
   // CHECK: def variadic1(self):
-  // CHECK:   start, pg = _ods_equally_sized_accessor(self.operation.results, 2, 0, 0)
-  // CHECK:   return self.operation.results[start:start + pg]
+  // CHECK:   start, elements_per_group = _ods_equally_sized_accessor(self.operation.results, 2, 0, 0)
+  // CHECK:   return self.operation.results[start:start + elements_per_group]
   //
   // CHECK: @builtins.property
   // CHECK: def non_variadic(self):
-  // CHECK:   start, pg = _ods_equally_sized_accessor(self.operation.results, 2, 0, 1)
+  // CHECK:   start, elements_per_group = _ods_equally_sized_accessor(self.operation.results, 2, 0, 1)
   // CHECK:   return self.operation.results[start]
   //
   // CHECK: @builtins.property
   // CHECK: def variadic2(self):
-  // CHECK:   start, pg = _ods_equally_sized_accessor(self.operation.results, 2, 1, 1)
-  // CHECK:   return self.operation.results[start:start + pg]
+  // CHECK:   start, elements_per_group = _ods_equally_sized_accessor(self.operation.results, 2, 1, 1)
+  // CHECK:   return self.operation.results[start:start + elements_per_group]
   let results = (outs Variadic<AnyType>:$variadic1, AnyType:$non_variadic,
                  Variadic<AnyType>:$variadic2);
 }
diff --git a/mlir/test/python/dialects/ods_helpers.py b/mlir/test/python/dialects/ods_helpers.py
index 0bacd0e5906883..111555963b441e 100644
--- a/mlir/test/python/dialects/ods_helpers.py
+++ b/mlir/test/python/dialects/ods_helpers.py
@@ -226,15 +226,15 @@ class TestOpMultiResultSegments(OpView):
             op = TestOpMultiResultSegments.build_generic(
                 results=[ts[0], ts[1], ts[2], ts[3]], operands=[v]
             )
-            start, pg = equally_sized_accessor(op.results, 3, 1, 0)
-            # CHECK: start: 1, pg: 1
-            print(f"start: {start}, pg: {pg}")
+            start, elements_per_group = equally_sized_accessor(op.results, 3, 1, 0)
+            # CHECK: start: 1, elements_per_group: 1
+            print(f"start: {start}, elements_per_group: {elements_per_group}")
             # CHECK: i8
             print(op.results[start].type)
 
-            start, pg = equally_sized_accessor(op.results, 3, 1, 1)
-            # CHECK: start: 2, pg: 1
-            print(f"start: {start}, pg: {pg}")
+            start, elements_per_group = equally_sized_accessor(op.results, 3, 1, 1)
+            # CHECK: start: 2, elements_per_group: 1
+            print(f"start: {start}, elements_per_group: {elements_per_group}")
             # CHECK: i16
             print(op.results[start].type)
 
@@ -262,17 +262,17 @@ def types(lst):
                 results=[ts[0], [ts[1], ts[2], ts[3]], [ts[4], ts[5], ts[6]]],
                 operands=[v],
             )
-            start, pg = equally_sized_accessor(op.results, 2, 1, 0)
-            # CHECK: start: 1, pg: 3
-            print(f"start: {start}, pg: {pg}")
+            start, elements_per_group = equally_sized_accessor(op.results, 2, 1, 0)
+            # CHECK: start: 1, elements_per_group: 3
+            print(f"start: {start}, elements_per_group: {elements_per_group}")
             # CHECK: [IntegerType(i8), IntegerType(i16), IntegerType(i24)]
-            print(types(op.results[start : start + pg]))
+            print(types(op.results[start : start + elements_per_group]))
 
-            start, pg = equally_sized_accessor(op.results, 2, 1, 1)
-            # CHECK: start: 4, pg: 3
-            print(f"start: {start}, pg: {pg}")
+            start, elements_per_group = equally_sized_accessor(op.results, 2, 1, 1)
+            # CHECK: start: 4, elements_per_group: 3
+            print(f"start: {start}, elements_per_group: {elements_per_group}")
             # CHECK: [IntegerType(i32), IntegerType(i40), IntegerType(i48)]
-            print(types(op.results[start : start + pg]))
+            print(types(op.results[start : start + elements_per_group]))
 
 
 run(testOdsEquallySizedAccessorMultipleSegments)
diff --git a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
index 97f2cc6c3f5763..8e9467bc8cfb74 100644
--- a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
@@ -145,7 +145,7 @@ constexpr const char *opOneVariadicTemplate = R"Py(
 constexpr const char *opVariadicEqualPrefixTemplate = R"Py(
   @builtins.property
   def {0}(self):
-    start, pg = _ods_equally_sized_accessor(self.operation.{1}s, {2}, {3}, {4}))Py";
+    start, elements_per_group = _ods_equally_sized_accessor(self.operation.{1}s, {2}, {3}, {4}))Py";
 
 /// Second part of the template for equally-sized case, accessing a single
 /// element:
@@ -158,7 +158,7 @@ constexpr const char *opVariadicEqualSimpleTemplate = R"Py(
 /// group:
 ///   {0} is either 'operand' or 'result'.
 constexpr const char *opVariadicEqualVariadicTemplate = R"Py(
-    return self.operation.{0}s[start:start + pg]
+    return self.operation.{0}s[start:start + elements_per_group]
 )Py";
 
 /// Template for an attribute-sized group accessor:

>From 816dea164a7c780164ff057e05d43b8c5d811c8b Mon Sep 17 00:00:00 2001
From: Kasper Nielsen <kasper0406 at gmail.com>
Date: Fri, 30 Aug 2024 12:58:34 +0200
Subject: [PATCH 7/8] Additional fixes and better tests

---
 mlir/python/mlir/dialects/_ods_common.py      |  5 +-
 mlir/test/mlir-tblgen/op-python-bindings.td   | 12 ++--
 mlir/test/python/dialects/ods_helpers.py      |  8 +--
 mlir/test/python/dialects/python_test.py      | 49 ++++++++++++++++
 mlir/test/python/python_test_ops.td           | 12 ++++
 mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp | 58 +++++++++++--------
 6 files changed, 107 insertions(+), 37 deletions(-)

diff --git a/mlir/python/mlir/dialects/_ods_common.py b/mlir/python/mlir/dialects/_ods_common.py
index 0b56a376d23813..d40d936cdc83d6 100644
--- a/mlir/python/mlir/dialects/_ods_common.py
+++ b/mlir/python/mlir/dialects/_ods_common.py
@@ -51,13 +51,14 @@ def segmented_accessor(elements, raw_segments, idx):
 
 
 def equally_sized_accessor(
-    elements, n_variadic, n_preceding_simple, n_preceding_variadic
+    elements, n_simple, n_variadic, n_preceding_simple, n_preceding_variadic
 ):
     """
     Returns a starting position and a number of elements per variadic group
     assuming equally-sized groups and the given numbers of preceding groups.
 
       elements: a sequential container.
+      n_simple: the number of non-variadic groups in the container.
       n_variadic: the number of variadic groups in the container.
       n_preceding_simple: the number of non-variadic groups preceding the current
           group.
@@ -65,7 +66,7 @@ def equally_sized_accessor(
           group.
     """
 
-    total_variadic_length = len(elements) - n_preceding_simple
+    total_variadic_length = len(elements) - n_simple
     # This should be enforced by the C++-side trait verifier.
     assert total_variadic_length % n_variadic == 0
 
diff --git a/mlir/test/mlir-tblgen/op-python-bindings.td b/mlir/test/mlir-tblgen/op-python-bindings.td
index 279c9920936fcc..ba85cb8406b31a 100644
--- a/mlir/test/mlir-tblgen/op-python-bindings.td
+++ b/mlir/test/mlir-tblgen/op-python-bindings.td
@@ -480,17 +480,17 @@ def SameVariadicOperandSizeOp : TestOp<"same_variadic_operand",
                                        [SameVariadicOperandSize]> {
   // CHECK: @builtins.property
   // CHECK: def variadic1(self):
-  // CHECK:   start, elements_per_group = _ods_equally_sized_accessor(self.operation.operands, 2, 0, 0)
+  // CHECK:   start, elements_per_group = _ods_equally_sized_accessor(self.operation.operands, 1, 2, 0, 0)
   // CHECK:   return self.operation.operands[start:start + elements_per_group]
   //
   // CHECK: @builtins.property
   // CHECK: def non_variadic(self):
-  // CHECK:   start, elements_per_group = _ods_equally_sized_accessor(self.operation.operands, 2, 0, 1)
+  // CHECK:   start, elements_per_group = _ods_equally_sized_accessor(self.operation.operands, 1, 2, 0, 1)
   // CHECK:   return self.operation.operands[start]
   //
   // CHECK: @builtins.property
   // CHECK: def variadic2(self):
-  // CHECK:   start, elements_per_group = _ods_equally_sized_accessor(self.operation.operands, 2, 1, 1)
+  // CHECK:   start, elements_per_group = _ods_equally_sized_accessor(self.operation.operands, 1, 2, 1, 1)
   // CHECK:   return self.operation.operands[start:start + elements_per_group]
   let arguments = (ins Variadic<AnyType>:$variadic1, AnyType:$non_variadic,
                    Variadic<AnyType>:$variadic2);
@@ -506,17 +506,17 @@ def SameVariadicResultSizeOp : TestOp<"same_variadic_result",
                                       [SameVariadicResultSize]> {
   // CHECK: @builtins.property
   // CHECK: def variadic1(self):
-  // CHECK:   start, elements_per_group = _ods_equally_sized_accessor(self.operation.results, 2, 0, 0)
+  // CHECK:   start, elements_per_group = _ods_equally_sized_accessor(self.operation.results, 1, 2, 0, 0)
   // CHECK:   return self.operation.results[start:start + elements_per_group]
   //
   // CHECK: @builtins.property
   // CHECK: def non_variadic(self):
-  // CHECK:   start, elements_per_group = _ods_equally_sized_accessor(self.operation.results, 2, 0, 1)
+  // CHECK:   start, elements_per_group = _ods_equally_sized_accessor(self.operation.results, 1, 2, 0, 1)
   // CHECK:   return self.operation.results[start]
   //
   // CHECK: @builtins.property
   // CHECK: def variadic2(self):
-  // CHECK:   start, elements_per_group = _ods_equally_sized_accessor(self.operation.results, 2, 1, 1)
+  // CHECK:   start, elements_per_group = _ods_equally_sized_accessor(self.operation.results, 1, 2, 1, 1)
   // CHECK:   return self.operation.results[start:start + elements_per_group]
   let results = (outs Variadic<AnyType>:$variadic1, AnyType:$non_variadic,
                  Variadic<AnyType>:$variadic2);
diff --git a/mlir/test/python/dialects/ods_helpers.py b/mlir/test/python/dialects/ods_helpers.py
index 111555963b441e..6f02153e08db5e 100644
--- a/mlir/test/python/dialects/ods_helpers.py
+++ b/mlir/test/python/dialects/ods_helpers.py
@@ -226,13 +226,13 @@ class TestOpMultiResultSegments(OpView):
             op = TestOpMultiResultSegments.build_generic(
                 results=[ts[0], ts[1], ts[2], ts[3]], operands=[v]
             )
-            start, elements_per_group = equally_sized_accessor(op.results, 3, 1, 0)
+            start, elements_per_group = equally_sized_accessor(op.results, 1, 3, 1, 0)
             # CHECK: start: 1, elements_per_group: 1
             print(f"start: {start}, elements_per_group: {elements_per_group}")
             # CHECK: i8
             print(op.results[start].type)
 
-            start, elements_per_group = equally_sized_accessor(op.results, 3, 1, 1)
+            start, elements_per_group = equally_sized_accessor(op.results, 1, 3, 1, 1)
             # CHECK: start: 2, elements_per_group: 1
             print(f"start: {start}, elements_per_group: {elements_per_group}")
             # CHECK: i16
@@ -262,13 +262,13 @@ def types(lst):
                 results=[ts[0], [ts[1], ts[2], ts[3]], [ts[4], ts[5], ts[6]]],
                 operands=[v],
             )
-            start, elements_per_group = equally_sized_accessor(op.results, 2, 1, 0)
+            start, elements_per_group = equally_sized_accessor(op.results, 1, 2, 1, 0)
             # CHECK: start: 1, elements_per_group: 3
             print(f"start: {start}, elements_per_group: {elements_per_group}")
             # CHECK: [IntegerType(i8), IntegerType(i16), IntegerType(i24)]
             print(types(op.results[start : start + elements_per_group]))
 
-            start, elements_per_group = equally_sized_accessor(op.results, 2, 1, 1)
+            start, elements_per_group = equally_sized_accessor(op.results, 1, 2, 1, 1)
             # CHECK: start: 4, elements_per_group: 3
             print(f"start: {start}, elements_per_group: {elements_per_group}")
             # CHECK: [IntegerType(i32), IntegerType(i40), IntegerType(i48)]
diff --git a/mlir/test/python/dialects/python_test.py b/mlir/test/python/dialects/python_test.py
index a76f3f2b5e4583..ec2d2f61dccf52 100644
--- a/mlir/test/python/dialects/python_test.py
+++ b/mlir/test/python/dialects/python_test.py
@@ -555,3 +555,52 @@ def testInferTypeOpInterface():
             two_operands = test.InferResultsVariadicInputsOp(single=zero, doubled=zero)
             # CHECK: f32
             print(two_operands.result.type)
+
+
+# CHECK-LABEL: TEST: testVariadicResultAccess
+ at run
+def testVariadicResultAccess():
+    def types(lst):
+        return [e.type for e in lst]
+
+    with Context() as ctx, Location.unknown(ctx):
+        module = Module.create()
+        with InsertionPoint(module.body):
+            i8 = IntegerType.get_signless(8)
+            i16 = IntegerType.get_signless(16)
+            i24 = IntegerType.get_signless(24)
+            i32 = IntegerType.get_signless(32)
+            i40 = IntegerType.get_signless(40)
+
+            variadic_result = test.SameVariadicResultSizeOp([i8, i16], i24, [i32, i40])
+            # CHECK: i24
+            print(variadic_result.non_variadic.type)
+            # CHECK: [IntegerType(i8), IntegerType(i16)]
+            print(types(variadic_result.variadic1))
+            # CHECK: [IntegerType(i32), IntegerType(i40)]
+            print(types(variadic_result.variadic2))
+
+
+# CHECK-LABEL: TEST: testVariadicOperandAccess
+ at run
+def testVariadicOperandAccess():
+    def values(lst):
+        return [str(e) for e in lst]
+
+    with Context() as ctx, Location.unknown(ctx):
+        module = Module.create()
+        with InsertionPoint(module.body):
+            i32 = IntegerType.get_signless(32)
+            zero = arith.ConstantOp(i32, 0)
+            one = arith.ConstantOp(i32, 1)
+            two = arith.ConstantOp(i32, 2)
+            three = arith.ConstantOp(i32, 3)
+            four = arith.ConstantOp(i32, 4)
+
+            variadic_operands = test.SameVariadicOperandSizeOp([zero, one], two, [three, four])
+            # CHECK: Value(%{{.*}} = arith.constant 2 : i32)
+            print(variadic_operands.non_variadic)
+            # CHECK: ['Value(%{{.*}} = arith.constant 0 : i32)', 'Value(%{{.*}} = arith.constant 1 : i32)']
+            print(values(variadic_operands.variadic1))
+            # CHECK: ['Value(%{{.*}} = arith.constant 3 : i32)', 'Value(%{{.*}} = arith.constant 4 : i32)']
+            print(values(variadic_operands.variadic2))
diff --git a/mlir/test/python/python_test_ops.td b/mlir/test/python/python_test_ops.td
index 6211fb9987c76a..67145d2da1ca85 100644
--- a/mlir/test/python/python_test_ops.td
+++ b/mlir/test/python/python_test_ops.td
@@ -227,4 +227,16 @@ def OptionalOperandOp : TestOp<"optional_operand_op"> {
   let results = (outs I32:$result);
 }
 
+def SameVariadicResultSizeOp : TestOp<"same_variadic_result",
+                                      [SameVariadicResultSize]> {
+  let results = (outs Variadic<AnyType>:$variadic1, AnyType:$non_variadic,
+                 Variadic<AnyType>:$variadic2);
+}
+
+def SameVariadicOperandSizeOp : TestOp<"same_variadic_operand",
+                                       [SameVariadicOperandSize]> {
+  let arguments = (ins Variadic<AnyType>:$variadic1, AnyType:$non_variadic,
+                   Variadic<AnyType>:$variadic2);
+}
+
 #endif // PYTHON_TEST_OPS
diff --git a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
index 8e9467bc8cfb74..c6c9de72e409d0 100644
--- a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
@@ -139,13 +139,14 @@ constexpr const char *opOneVariadicTemplate = R"Py(
 /// First part of the template for equally-sized variadic group accessor:
 ///   {0} is the name of the accessor;
 ///   {1} is either 'operand' or 'result';
-///   {2} is the total number of variadic groups;
-///   {3} is the number of non-variadic groups preceding the current group;
-///   {3} is the number of variadic groups preceding the current group.
+///   {2} is the total number of non-variadic groups;
+///   {3} is the total number of variadic groups;
+///   {4} is the number of non-variadic groups preceding the current group;
+///   {5} is the number of variadic groups preceding the current group.
 constexpr const char *opVariadicEqualPrefixTemplate = R"Py(
   @builtins.property
   def {0}(self):
-    start, elements_per_group = _ods_equally_sized_accessor(self.operation.{1}s, {2}, {3}, {4}))Py";
+    start, elements_per_group = _ods_equally_sized_accessor(self.operation.{1}s, {2}, {3}, {4}, {5}))Py";
 
 /// Second part of the template for equally-sized case, accessing a single
 /// element:
@@ -324,8 +325,8 @@ static std::string attrSizedTraitForKind(const char *kind) {
 /// `operand` or `result` and is used verbatim in the emitted code.
 static void emitElementAccessors(
     const Operator &op, raw_ostream &os, const char *kind,
-    llvm::function_ref<unsigned(const Operator &)> getNumVariableLength,
-    llvm::function_ref<int(const Operator &)> getNumElements,
+    unsigned numVariadicGroups,
+    unsigned numElements,
     llvm::function_ref<const NamedTypeConstraint &(const Operator &, int)>
         getElement) {
   assert(llvm::is_contained(
@@ -339,14 +340,12 @@ static void emitElementAccessors(
                     llvm::StringRef(kind).drop_front());
   std::string attrSizedTrait = attrSizedTraitForKind(kind);
 
-  unsigned numVariableLength = getNumVariableLength(op);
-
   // If there is only one variable-length element group, its size can be
   // inferred from the total number of elements. If there are none, the
   // generation is straightforward.
-  if (numVariableLength <= 1) {
+  if (numVariadicGroups <= 1) {
     bool seenVariableLength = false;
-    for (int i = 0, e = getNumElements(op); i < e; ++i) {
+    for (unsigned i = 0; i < numElements; ++i) {
       const NamedTypeConstraint &element = getElement(op, i);
       if (element.isVariableLength())
         seenVariableLength = true;
@@ -356,11 +355,11 @@ static void emitElementAccessors(
         os << llvm::formatv(element.isOptional() ? opOneOptionalTemplate
                                                  : opOneVariadicTemplate,
                             sanitizeName(element.name), kind,
-                            getNumElements(op), i);
+                            numElements, i);
       } else if (seenVariableLength) {
         os << llvm::formatv(opSingleAfterVariableTemplate,
                             sanitizeName(element.name), kind,
-                            getNumElements(op), i);
+                            numElements, i);
       } else {
         os << llvm::formatv(opSingleTemplate, sanitizeName(element.name), kind,
                             i);
@@ -371,13 +370,24 @@ static void emitElementAccessors(
 
   // Handle the operations where variadic groups have the same size.
   if (op.getTrait(sameSizeTrait)) {
+    // Count the number of simple (non-variadic) elements.
+    unsigned numSimpleLength = 0;
+    for (unsigned i = 0; i < numElements; ++i) {
+      const NamedTypeConstraint &element = getElement(op, i);
+      if (!element.isVariableLength()) {
+        ++numSimpleLength;
+      }
+    }
+
+    // Generate the accessors
     int numPrecedingSimple = 0;
     int numPrecedingVariadic = 0;
-    for (int i = 0, e = getNumElements(op); i < e; ++i) {
+    for (unsigned i = 0; i < numElements; ++i) {
       const NamedTypeConstraint &element = getElement(op, i);
       if (!element.name.empty()) {
         os << llvm::formatv(opVariadicEqualPrefixTemplate,
-                            sanitizeName(element.name), kind, numVariableLength,
+                            sanitizeName(element.name), kind,
+                            numSimpleLength, numVariadicGroups,
                             numPrecedingSimple, numPrecedingVariadic);
         os << llvm::formatv(element.isVariableLength()
                                 ? opVariadicEqualVariadicTemplate
@@ -396,7 +406,7 @@ static void emitElementAccessors(
   // provided as an attribute. For non-variadic elements, make sure to return
   // an element rather than a singleton container.
   if (op.getTrait(attrSizedTrait)) {
-    for (int i = 0, e = getNumElements(op); i < e; ++i) {
+    for (unsigned i = 0; i < numElements; ++i) {
       const NamedTypeConstraint &element = getElement(op, i);
       if (element.name.empty())
         continue;
@@ -427,20 +437,18 @@ static const NamedTypeConstraint &getResult(const Operator &op, int i) {
 
 /// Emits accessors to Op operands.
 static void emitOperandAccessors(const Operator &op, raw_ostream &os) {
-  auto getNumVariableLengthOperands = [](const Operator &oper) {
-    return oper.getNumVariableLengthOperands();
-  };
-  emitElementAccessors(op, os, "operand", getNumVariableLengthOperands,
-                       getNumOperands, getOperand);
+  emitElementAccessors(op, os, "operand",
+                       op.getNumVariableLengthOperands(),
+                       getNumOperands(op),
+                       getOperand);
 }
 
 /// Emits accessors Op results.
 static void emitResultAccessors(const Operator &op, raw_ostream &os) {
-  auto getNumVariableLengthResults = [](const Operator &oper) {
-    return oper.getNumVariableLengthResults();
-  };
-  emitElementAccessors(op, os, "result", getNumVariableLengthResults,
-                       getNumResults, getResult);
+  emitElementAccessors(op, os, "result",
+                       op.getNumVariableLengthResults(),
+                       getNumResults(op),
+                       getResult);
 }
 
 /// Emits accessors to Op attributes.

>From f5536972a5b10b153869459aaa7867d7de5e66fe Mon Sep 17 00:00:00 2001
From: Kasper Nielsen <kasper0406 at gmail.com>
Date: Fri, 30 Aug 2024 15:31:29 +0200
Subject: [PATCH 8/8] Fix code style

---
 mlir/test/python/dialects/python_test.py      |  4 ++-
 mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp | 27 +++++++------------
 2 files changed, 13 insertions(+), 18 deletions(-)

diff --git a/mlir/test/python/dialects/python_test.py b/mlir/test/python/dialects/python_test.py
index ec2d2f61dccf52..62d315810613c6 100644
--- a/mlir/test/python/dialects/python_test.py
+++ b/mlir/test/python/dialects/python_test.py
@@ -597,7 +597,9 @@ def values(lst):
             three = arith.ConstantOp(i32, 3)
             four = arith.ConstantOp(i32, 4)
 
-            variadic_operands = test.SameVariadicOperandSizeOp([zero, one], two, [three, four])
+            variadic_operands = test.SameVariadicOperandSizeOp(
+                [zero, one], two, [three, four]
+            )
             # CHECK: Value(%{{.*}} = arith.constant 2 : i32)
             print(variadic_operands.non_variadic)
             # CHECK: ['Value(%{{.*}} = arith.constant 0 : i32)', 'Value(%{{.*}} = arith.constant 1 : i32)']
diff --git a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
index c6c9de72e409d0..553ab6adc65b06 100644
--- a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
@@ -325,8 +325,7 @@ static std::string attrSizedTraitForKind(const char *kind) {
 /// `operand` or `result` and is used verbatim in the emitted code.
 static void emitElementAccessors(
     const Operator &op, raw_ostream &os, const char *kind,
-    unsigned numVariadicGroups,
-    unsigned numElements,
+    unsigned numVariadicGroups, unsigned numElements,
     llvm::function_ref<const NamedTypeConstraint &(const Operator &, int)>
         getElement) {
   assert(llvm::is_contained(
@@ -354,12 +353,10 @@ static void emitElementAccessors(
       if (element.isVariableLength()) {
         os << llvm::formatv(element.isOptional() ? opOneOptionalTemplate
                                                  : opOneVariadicTemplate,
-                            sanitizeName(element.name), kind,
-                            numElements, i);
+                            sanitizeName(element.name), kind, numElements, i);
       } else if (seenVariableLength) {
         os << llvm::formatv(opSingleAfterVariableTemplate,
-                            sanitizeName(element.name), kind,
-                            numElements, i);
+                            sanitizeName(element.name), kind, numElements, i);
       } else {
         os << llvm::formatv(opSingleTemplate, sanitizeName(element.name), kind,
                             i);
@@ -386,9 +383,9 @@ static void emitElementAccessors(
       const NamedTypeConstraint &element = getElement(op, i);
       if (!element.name.empty()) {
         os << llvm::formatv(opVariadicEqualPrefixTemplate,
-                            sanitizeName(element.name), kind,
-                            numSimpleLength, numVariadicGroups,
-                            numPrecedingSimple, numPrecedingVariadic);
+                            sanitizeName(element.name), kind, numSimpleLength,
+                            numVariadicGroups, numPrecedingSimple,
+                            numPrecedingVariadic);
         os << llvm::formatv(element.isVariableLength()
                                 ? opVariadicEqualVariadicTemplate
                                 : opVariadicEqualSimpleTemplate,
@@ -437,18 +434,14 @@ static const NamedTypeConstraint &getResult(const Operator &op, int i) {
 
 /// Emits accessors to Op operands.
 static void emitOperandAccessors(const Operator &op, raw_ostream &os) {
-  emitElementAccessors(op, os, "operand",
-                       op.getNumVariableLengthOperands(),
-                       getNumOperands(op),
-                       getOperand);
+  emitElementAccessors(op, os, "operand", op.getNumVariableLengthOperands(),
+                       getNumOperands(op), getOperand);
 }
 
 /// Emits accessors Op results.
 static void emitResultAccessors(const Operator &op, raw_ostream &os) {
-  emitElementAccessors(op, os, "result",
-                       op.getNumVariableLengthResults(),
-                       getNumResults(op),
-                       getResult);
+  emitElementAccessors(op, os, "result", op.getNumVariableLengthResults(),
+                       getNumResults(op), getResult);
 }
 
 /// Emits accessors to Op attributes.



More information about the Mlir-commits mailing list