[Mlir-commits] [mlir] [mlir][python] Fix how the mlir variadic Python accessor `_ods_equally_sized_accessor` is used (#101132) (PR #106003)
Kasper Nielsen
llvmlistbot at llvm.org
Thu Aug 29 08:05:41 PDT 2024
https://github.com/kasper0406 updated https://github.com/llvm/llvm-project/pull/106003
>From c35573e237a135321e334b95452283abfa315601 Mon Sep 17 00:00:00 2001
From: Kasper Nielsen <kasper0406 at gmail.com>
Date: Wed, 21 Aug 2024 16:55:48 +0200
Subject: [PATCH 1/6] Fix MLIR Python bindings when trying to access variadic
elements
---
mlir/python/mlir/dialects/_ods_common.py | 2 +-
mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp | 2 +-
2 files changed, 2 insertions(+), 2 deletions(-)
diff --git a/mlir/python/mlir/dialects/_ods_common.py b/mlir/python/mlir/dialects/_ods_common.py
index 1e7e8244ed4420..0b56a376d23813 100644
--- a/mlir/python/mlir/dialects/_ods_common.py
+++ b/mlir/python/mlir/dialects/_ods_common.py
@@ -65,7 +65,7 @@ def equally_sized_accessor(
group.
"""
- total_variadic_length = len(elements) - n_variadic + 1
+ total_variadic_length = len(elements) - n_preceding_simple
# This should be enforced by the C++-side trait verifier.
assert total_variadic_length % n_variadic == 0
diff --git a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
index 052020acdcb764..97f2cc6c3f5763 100644
--- a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
@@ -145,7 +145,7 @@ constexpr const char *opOneVariadicTemplate = R"Py(
constexpr const char *opVariadicEqualPrefixTemplate = R"Py(
@builtins.property
def {0}(self):
- start, pg = _ods_equally_sized_accessor(operation.{1}s, {2}, {3}, {4}))Py";
+ start, pg = _ods_equally_sized_accessor(self.operation.{1}s, {2}, {3}, {4}))Py";
/// Second part of the template for equally-sized case, accessing a single
/// element:
>From c1e68b4ebd0b0e0d4638ec23dc53da4e1886d911 Mon Sep 17 00:00:00 2001
From: Kasper Nielsen <kasper0406 at gmail.com>
Date: Sun, 25 Aug 2024 21:15:15 +0200
Subject: [PATCH 2/6] Add and fix tests
---
mlir/test/mlir-tblgen/op-python-bindings.td | 12 ++--
mlir/test/python/dialects/ods_helpers.py | 67 +++++++++++++++++++++
2 files changed, 73 insertions(+), 6 deletions(-)
diff --git a/mlir/test/mlir-tblgen/op-python-bindings.td b/mlir/test/mlir-tblgen/op-python-bindings.td
index 9f202ba08608c6..d0642968854fe5 100644
--- a/mlir/test/mlir-tblgen/op-python-bindings.td
+++ b/mlir/test/mlir-tblgen/op-python-bindings.td
@@ -480,17 +480,17 @@ def SameVariadicOperandSizeOp : TestOp<"same_variadic_operand",
[SameVariadicOperandSize]> {
// CHECK: @builtins.property
// CHECK: def variadic1(self):
- // CHECK: start, pg = _ods_equally_sized_accessor(operation.operands, 2, 0, 0)
+ // CHECK: start, pg = _ods_equally_sized_accessor(self.operation.operands, 2, 0, 0)
// CHECK: return self.operation.operands[start:start + pg]
//
// CHECK: @builtins.property
// CHECK: def non_variadic(self):
- // CHECK: start, pg = _ods_equally_sized_accessor(operation.operands, 2, 0, 1)
+ // CHECK: start, pg = _ods_equally_sized_accessor(self.operation.operands, 2, 0, 1)
// CHECK: return self.operation.operands[start]
//
// CHECK: @builtins.property
// CHECK: def variadic2(self):
- // CHECK: start, pg = _ods_equally_sized_accessor(operation.operands, 2, 1, 1)
+ // CHECK: start, pg = _ods_equally_sized_accessor(self.operation.operands, 2, 1, 1)
// CHECK: return self.operation.operands[start:start + pg]
let arguments = (ins Variadic<AnyType>:$variadic1, AnyType:$non_variadic,
Variadic<AnyType>:$variadic2);
@@ -506,17 +506,17 @@ def SameVariadicResultSizeOp : TestOp<"same_variadic_result",
[SameVariadicResultSize]> {
// CHECK: @builtins.property
// CHECK: def variadic1(self):
- // CHECK: start, pg = _ods_equally_sized_accessor(operation.results, 2, 0, 0)
+ // CHECK: start, pg = _ods_equally_sized_accessor(self.operation.results, 2, 0, 0)
// CHECK: return self.operation.results[start:start + pg]
//
// CHECK: @builtins.property
// CHECK: def non_variadic(self):
- // CHECK: start, pg = _ods_equally_sized_accessor(operation.results, 2, 0, 1)
+ // CHECK: start, pg = _ods_equally_sized_accessor(self.operation.results, 2, 0, 1)
// CHECK: return self.operation.results[start]
//
// CHECK: @builtins.property
// CHECK: def variadic2(self):
- // CHECK: start, pg = _ods_equally_sized_accessor(operation.results, 2, 1, 1)
+ // CHECK: start, pg = _ods_equally_sized_accessor(self.operation.results, 2, 1, 1)
// CHECK: return self.operation.results[start:start + pg]
let results = (outs Variadic<AnyType>:$variadic1, AnyType:$non_variadic,
Variadic<AnyType>:$variadic2);
diff --git a/mlir/test/python/dialects/ods_helpers.py b/mlir/test/python/dialects/ods_helpers.py
index 0d2a18e0eb0af2..cb0d0528a6a866 100644
--- a/mlir/test/python/dialects/ods_helpers.py
+++ b/mlir/test/python/dialects/ods_helpers.py
@@ -3,6 +3,7 @@
import gc
from mlir.ir import *
+from mlir.dialects._ods_common import equally_sized_accessor
def run(f):
@@ -208,3 +209,69 @@ class TestOp(OpView):
run(testOdsBuildDefaultCastError)
+
+
+def testOdsEquallySizedAccessor():
+ class TestOpMultiResultSegments(OpView):
+ OPERATION_NAME = "custom.test_op"
+ _ODS_REGIONS = (1, True)
+
+ with Context() as ctx, Location.unknown():
+ ctx.allow_unregistered_dialects = True
+ m = Module.create()
+ with InsertionPoint(m.body):
+ v = add_dummy_value()
+ ts = [IntegerType.get_signless(i * 8) for i in range(4)]
+
+ op = TestOpMultiResultSegments.build_generic(
+ results=[ts[0], ts[1], ts[2], ts[3]], operands=[v]
+ )
+ start, pg = equally_sized_accessor(op.results, 3, 1, 0)
+ # CHECK: start: 1, pg: 1
+ print(f"start: {start}, pg: {pg}")
+ # CHECK: i8
+ print(op.results[start].type)
+
+ start, pg = equally_sized_accessor(op.results, 3, 1, 1)
+ # CHECK: start: 2, pg: 1
+ print(f"start: {start}, pg: {pg}")
+ # CHECK: i16
+ print(op.results[start].type)
+
+
+run(testOdsEquallySizedAccessor)
+
+
+def testOdsEquallySizedAccessorMultipleSegments():
+ class TestOpMultiResultSegments(OpView):
+ OPERATION_NAME = "custom.test_op"
+ _ODS_REGIONS = (1, True)
+ _ODS_RESULT_SEGMENTS = [0, -1, -1]
+
+ def types(lst):
+ return [e.type for e in lst]
+
+ with Context() as ctx, Location.unknown():
+ ctx.allow_unregistered_dialects = True
+ m = Module.create()
+ with InsertionPoint(m.body):
+ v = add_dummy_value()
+ ts = [IntegerType.get_signless(i * 8) for i in range(7)]
+
+ op = TestOpMultiResultSegments.build_generic(
+ results=[ts[0], [ts[1], ts[2], ts[3]], [ts[4], ts[5], ts[6]]], operands=[v]
+ )
+ start, pg = equally_sized_accessor(op.results, 2, 1, 0)
+ # CHECK: start: 1, pg: 3
+ print(f"start: {start}, pg: {pg}")
+ # CHECK: [IntegerType(i8), IntegerType(i16), IntegerType(i24)]
+ print(types(op.results[start:start + pg]))
+
+ start, pg = equally_sized_accessor(op.results, 2, 1, 1)
+ # CHECK: start: 4, pg: 3
+ print(f"start: {start}, pg: {pg}")
+ # CHECK: [IntegerType(i32), IntegerType(i40), IntegerType(i48)]
+ print(types(op.results[start:start + pg]))
+
+
+run(testOdsEquallySizedAccessorMultipleSegments)
>From 6be3f4ce254b1527c39370d979fc32c26f798f5f Mon Sep 17 00:00:00 2001
From: Kasper Nielsen <kasper0406 at gmail.com>
Date: Mon, 26 Aug 2024 08:42:36 +0200
Subject: [PATCH 3/6] Fix code style
---
mlir/test/python/dialects/ods_helpers.py | 7 ++++---
1 file changed, 4 insertions(+), 3 deletions(-)
diff --git a/mlir/test/python/dialects/ods_helpers.py b/mlir/test/python/dialects/ods_helpers.py
index cb0d0528a6a866..5e1f594503136a 100644
--- a/mlir/test/python/dialects/ods_helpers.py
+++ b/mlir/test/python/dialects/ods_helpers.py
@@ -259,19 +259,20 @@ def types(lst):
ts = [IntegerType.get_signless(i * 8) for i in range(7)]
op = TestOpMultiResultSegments.build_generic(
- results=[ts[0], [ts[1], ts[2], ts[3]], [ts[4], ts[5], ts[6]]], operands=[v]
+ results=[ts[0], [ts[1], ts[2], ts[3]], [ts[4], ts[5], ts[6]]],
+ operands=[v]
)
start, pg = equally_sized_accessor(op.results, 2, 1, 0)
# CHECK: start: 1, pg: 3
print(f"start: {start}, pg: {pg}")
# CHECK: [IntegerType(i8), IntegerType(i16), IntegerType(i24)]
- print(types(op.results[start:start + pg]))
+ print(types(op.results[start : start + pg]))
start, pg = equally_sized_accessor(op.results, 2, 1, 1)
# CHECK: start: 4, pg: 3
print(f"start: {start}, pg: {pg}")
# CHECK: [IntegerType(i32), IntegerType(i40), IntegerType(i48)]
- print(types(op.results[start:start + pg]))
+ print(types(op.results[start : start + pg]))
run(testOdsEquallySizedAccessorMultipleSegments)
>From 81fd53c6db1f2796eef7f5cd1161715976b0d979 Mon Sep 17 00:00:00 2001
From: Kasper Nielsen <kasper0406 at gmail.com>
Date: Mon, 26 Aug 2024 11:25:40 +0200
Subject: [PATCH 4/6] Another coding style check
---
mlir/test/CMakeLists.txt | 2 ++
mlir/test/python/dialects/ods_helpers.py | 2 +-
2 files changed, 3 insertions(+), 1 deletion(-)
diff --git a/mlir/test/CMakeLists.txt b/mlir/test/CMakeLists.txt
index df95e5db11f1e0..0e8de4d15014d7 100644
--- a/mlir/test/CMakeLists.txt
+++ b/mlir/test/CMakeLists.txt
@@ -211,6 +211,8 @@ if(MLIR_ENABLE_BINDINGS_PYTHON)
)
endif()
+set(MLIR_TEST_DEPENDS MLIRPythonModules)
+
# This target can be used to just build the dependencies
# for the check-mlir target without executing the tests.
# This is useful for bots when splitting the build step
diff --git a/mlir/test/python/dialects/ods_helpers.py b/mlir/test/python/dialects/ods_helpers.py
index 5e1f594503136a..0bacd0e5906883 100644
--- a/mlir/test/python/dialects/ods_helpers.py
+++ b/mlir/test/python/dialects/ods_helpers.py
@@ -260,7 +260,7 @@ def types(lst):
op = TestOpMultiResultSegments.build_generic(
results=[ts[0], [ts[1], ts[2], ts[3]], [ts[4], ts[5], ts[6]]],
- operands=[v]
+ operands=[v],
)
start, pg = equally_sized_accessor(op.results, 2, 1, 0)
# CHECK: start: 1, pg: 3
>From 9db5bb63b5faf7c3c21bbebde844a31ed64b1063 Mon Sep 17 00:00:00 2001
From: Kasper Nielsen <kasper0406 at gmail.com>
Date: Mon, 26 Aug 2024 12:34:41 +0200
Subject: [PATCH 5/6] Do not skip tests
---
mlir/test/CMakeLists.txt | 2 --
1 file changed, 2 deletions(-)
diff --git a/mlir/test/CMakeLists.txt b/mlir/test/CMakeLists.txt
index 0e8de4d15014d7..df95e5db11f1e0 100644
--- a/mlir/test/CMakeLists.txt
+++ b/mlir/test/CMakeLists.txt
@@ -211,8 +211,6 @@ if(MLIR_ENABLE_BINDINGS_PYTHON)
)
endif()
-set(MLIR_TEST_DEPENDS MLIRPythonModules)
-
# This target can be used to just build the dependencies
# for the check-mlir target without executing the tests.
# This is useful for bots when splitting the build step
>From b861ec0d1c551fbd76b16438b9e93eed44e53a48 Mon Sep 17 00:00:00 2001
From: Kasper Nielsen <kasper0406 at gmail.com>
Date: Thu, 29 Aug 2024 16:49:43 +0200
Subject: [PATCH 6/6] Rename pg to elements_per_group
---
mlir/test/mlir-tblgen/op-python-bindings.td | 20 ++++++-------
mlir/test/python/dialects/ods_helpers.py | 28 +++++++++----------
mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp | 4 +--
3 files changed, 26 insertions(+), 26 deletions(-)
diff --git a/mlir/test/mlir-tblgen/op-python-bindings.td b/mlir/test/mlir-tblgen/op-python-bindings.td
index d0642968854fe5..279c9920936fcc 100644
--- a/mlir/test/mlir-tblgen/op-python-bindings.td
+++ b/mlir/test/mlir-tblgen/op-python-bindings.td
@@ -480,18 +480,18 @@ def SameVariadicOperandSizeOp : TestOp<"same_variadic_operand",
[SameVariadicOperandSize]> {
// CHECK: @builtins.property
// CHECK: def variadic1(self):
- // CHECK: start, pg = _ods_equally_sized_accessor(self.operation.operands, 2, 0, 0)
- // CHECK: return self.operation.operands[start:start + pg]
+ // CHECK: start, elements_per_group = _ods_equally_sized_accessor(self.operation.operands, 2, 0, 0)
+ // CHECK: return self.operation.operands[start:start + elements_per_group]
//
// CHECK: @builtins.property
// CHECK: def non_variadic(self):
- // CHECK: start, pg = _ods_equally_sized_accessor(self.operation.operands, 2, 0, 1)
+ // CHECK: start, elements_per_group = _ods_equally_sized_accessor(self.operation.operands, 2, 0, 1)
// CHECK: return self.operation.operands[start]
//
// CHECK: @builtins.property
// CHECK: def variadic2(self):
- // CHECK: start, pg = _ods_equally_sized_accessor(self.operation.operands, 2, 1, 1)
- // CHECK: return self.operation.operands[start:start + pg]
+ // CHECK: start, elements_per_group = _ods_equally_sized_accessor(self.operation.operands, 2, 1, 1)
+ // CHECK: return self.operation.operands[start:start + elements_per_group]
let arguments = (ins Variadic<AnyType>:$variadic1, AnyType:$non_variadic,
Variadic<AnyType>:$variadic2);
}
@@ -506,18 +506,18 @@ def SameVariadicResultSizeOp : TestOp<"same_variadic_result",
[SameVariadicResultSize]> {
// CHECK: @builtins.property
// CHECK: def variadic1(self):
- // CHECK: start, pg = _ods_equally_sized_accessor(self.operation.results, 2, 0, 0)
- // CHECK: return self.operation.results[start:start + pg]
+ // CHECK: start, elements_per_group = _ods_equally_sized_accessor(self.operation.results, 2, 0, 0)
+ // CHECK: return self.operation.results[start:start + elements_per_group]
//
// CHECK: @builtins.property
// CHECK: def non_variadic(self):
- // CHECK: start, pg = _ods_equally_sized_accessor(self.operation.results, 2, 0, 1)
+ // CHECK: start, elements_per_group = _ods_equally_sized_accessor(self.operation.results, 2, 0, 1)
// CHECK: return self.operation.results[start]
//
// CHECK: @builtins.property
// CHECK: def variadic2(self):
- // CHECK: start, pg = _ods_equally_sized_accessor(self.operation.results, 2, 1, 1)
- // CHECK: return self.operation.results[start:start + pg]
+ // CHECK: start, elements_per_group = _ods_equally_sized_accessor(self.operation.results, 2, 1, 1)
+ // CHECK: return self.operation.results[start:start + elements_per_group]
let results = (outs Variadic<AnyType>:$variadic1, AnyType:$non_variadic,
Variadic<AnyType>:$variadic2);
}
diff --git a/mlir/test/python/dialects/ods_helpers.py b/mlir/test/python/dialects/ods_helpers.py
index 0bacd0e5906883..111555963b441e 100644
--- a/mlir/test/python/dialects/ods_helpers.py
+++ b/mlir/test/python/dialects/ods_helpers.py
@@ -226,15 +226,15 @@ class TestOpMultiResultSegments(OpView):
op = TestOpMultiResultSegments.build_generic(
results=[ts[0], ts[1], ts[2], ts[3]], operands=[v]
)
- start, pg = equally_sized_accessor(op.results, 3, 1, 0)
- # CHECK: start: 1, pg: 1
- print(f"start: {start}, pg: {pg}")
+ start, elements_per_group = equally_sized_accessor(op.results, 3, 1, 0)
+ # CHECK: start: 1, elements_per_group: 1
+ print(f"start: {start}, elements_per_group: {elements_per_group}")
# CHECK: i8
print(op.results[start].type)
- start, pg = equally_sized_accessor(op.results, 3, 1, 1)
- # CHECK: start: 2, pg: 1
- print(f"start: {start}, pg: {pg}")
+ start, elements_per_group = equally_sized_accessor(op.results, 3, 1, 1)
+ # CHECK: start: 2, elements_per_group: 1
+ print(f"start: {start}, elements_per_group: {elements_per_group}")
# CHECK: i16
print(op.results[start].type)
@@ -262,17 +262,17 @@ def types(lst):
results=[ts[0], [ts[1], ts[2], ts[3]], [ts[4], ts[5], ts[6]]],
operands=[v],
)
- start, pg = equally_sized_accessor(op.results, 2, 1, 0)
- # CHECK: start: 1, pg: 3
- print(f"start: {start}, pg: {pg}")
+ start, elements_per_group = equally_sized_accessor(op.results, 2, 1, 0)
+ # CHECK: start: 1, elements_per_group: 3
+ print(f"start: {start}, elements_per_group: {elements_per_group}")
# CHECK: [IntegerType(i8), IntegerType(i16), IntegerType(i24)]
- print(types(op.results[start : start + pg]))
+ print(types(op.results[start : start + elements_per_group]))
- start, pg = equally_sized_accessor(op.results, 2, 1, 1)
- # CHECK: start: 4, pg: 3
- print(f"start: {start}, pg: {pg}")
+ start, elements_per_group = equally_sized_accessor(op.results, 2, 1, 1)
+ # CHECK: start: 4, elements_per_group: 3
+ print(f"start: {start}, elements_per_group: {elements_per_group}")
# CHECK: [IntegerType(i32), IntegerType(i40), IntegerType(i48)]
- print(types(op.results[start : start + pg]))
+ print(types(op.results[start : start + elements_per_group]))
run(testOdsEquallySizedAccessorMultipleSegments)
diff --git a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
index 97f2cc6c3f5763..8e9467bc8cfb74 100644
--- a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
@@ -145,7 +145,7 @@ constexpr const char *opOneVariadicTemplate = R"Py(
constexpr const char *opVariadicEqualPrefixTemplate = R"Py(
@builtins.property
def {0}(self):
- start, pg = _ods_equally_sized_accessor(self.operation.{1}s, {2}, {3}, {4}))Py";
+ start, elements_per_group = _ods_equally_sized_accessor(self.operation.{1}s, {2}, {3}, {4}))Py";
/// Second part of the template for equally-sized case, accessing a single
/// element:
@@ -158,7 +158,7 @@ constexpr const char *opVariadicEqualSimpleTemplate = R"Py(
/// group:
/// {0} is either 'operand' or 'result'.
constexpr const char *opVariadicEqualVariadicTemplate = R"Py(
- return self.operation.{0}s[start:start + pg]
+ return self.operation.{0}s[start:start + elements_per_group]
)Py";
/// Template for an attribute-sized group accessor:
More information about the Mlir-commits
mailing list