[Mlir-commits] [mlir] 3766ba4 - [mlir][python] Fix how the mlir variadic Python accessor `_ods_equally_sized_accessor` is used (#101132) (#106003)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Sat Aug 31 00:17:37 PDT 2024
Author: Kasper Nielsen
Date: 2024-08-31T03:17:33-04:00
New Revision: 3766ba44a8945681f4c52acb0331efcff66ef7b1
URL: https://github.com/llvm/llvm-project/commit/3766ba44a8945681f4c52acb0331efcff66ef7b1
DIFF: https://github.com/llvm/llvm-project/commit/3766ba44a8945681f4c52acb0331efcff66ef7b1.diff
LOG: [mlir][python] Fix how the mlir variadic Python accessor `_ods_equally_sized_accessor` is used (#101132) (#106003)
As reported in https://github.com/llvm/llvm-project/issues/101132, this
fixes two bugs:
1. When accessing variadic operands inside an operation, they must be
accessed as `self.operation.operands` instead of `operation.operands`.
2. The implementation of the `equally_sized_accessor` function performs
incorrect arithmetic when calculating the resulting index and group sizes.
I have added a test for the `equally_sized_accessor` function, which did
not have a test previously.
Added:
Modified:
mlir/python/mlir/dialects/_ods_common.py
mlir/test/mlir-tblgen/op-python-bindings.td
mlir/test/python/dialects/ods_helpers.py
mlir/test/python/dialects/python_test.py
mlir/test/python/python_test_ops.td
mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
Removed:
################################################################################
diff --git a/mlir/python/mlir/dialects/_ods_common.py b/mlir/python/mlir/dialects/_ods_common.py
index 1e7e8244ed4420..d40d936cdc83d6 100644
--- a/mlir/python/mlir/dialects/_ods_common.py
+++ b/mlir/python/mlir/dialects/_ods_common.py
@@ -51,13 +51,14 @@ def segmented_accessor(elements, raw_segments, idx):
def equally_sized_accessor(
- elements, n_variadic, n_preceding_simple, n_preceding_variadic
+ elements, n_simple, n_variadic, n_preceding_simple, n_preceding_variadic
):
"""
Returns a starting position and a number of elements per variadic group
assuming equally-sized groups and the given numbers of preceding groups.
elements: a sequential container.
+ n_simple: the number of non-variadic groups in the container.
n_variadic: the number of variadic groups in the container.
n_preceding_simple: the number of non-variadic groups preceding the current
group.
@@ -65,7 +66,7 @@ def equally_sized_accessor(
group.
"""
- total_variadic_length = len(elements) - n_variadic + 1
+ total_variadic_length = len(elements) - n_simple
# This should be enforced by the C++-side trait verifier.
assert total_variadic_length % n_variadic == 0
diff --git a/mlir/test/mlir-tblgen/op-python-bindings.td b/mlir/test/mlir-tblgen/op-python-bindings.td
index 9f202ba08608c6..ba85cb8406b31a 100644
--- a/mlir/test/mlir-tblgen/op-python-bindings.td
+++ b/mlir/test/mlir-tblgen/op-python-bindings.td
@@ -480,18 +480,18 @@ def SameVariadicOperandSizeOp : TestOp<"same_variadic_operand",
[SameVariadicOperandSize]> {
// CHECK: @builtins.property
// CHECK: def variadic1(self):
- // CHECK: start, pg = _ods_equally_sized_accessor(operation.operands, 2, 0, 0)
- // CHECK: return self.operation.operands[start:start + pg]
+ // CHECK: start, elements_per_group = _ods_equally_sized_accessor(self.operation.operands, 1, 2, 0, 0)
+ // CHECK: return self.operation.operands[start:start + elements_per_group]
//
// CHECK: @builtins.property
// CHECK: def non_variadic(self):
- // CHECK: start, pg = _ods_equally_sized_accessor(operation.operands, 2, 0, 1)
+ // CHECK: start, elements_per_group = _ods_equally_sized_accessor(self.operation.operands, 1, 2, 0, 1)
// CHECK: return self.operation.operands[start]
//
// CHECK: @builtins.property
// CHECK: def variadic2(self):
- // CHECK: start, pg = _ods_equally_sized_accessor(operation.operands, 2, 1, 1)
- // CHECK: return self.operation.operands[start:start + pg]
+ // CHECK: start, elements_per_group = _ods_equally_sized_accessor(self.operation.operands, 1, 2, 1, 1)
+ // CHECK: return self.operation.operands[start:start + elements_per_group]
let arguments = (ins Variadic<AnyType>:$variadic1, AnyType:$non_variadic,
Variadic<AnyType>:$variadic2);
}
@@ -506,18 +506,18 @@ def SameVariadicResultSizeOp : TestOp<"same_variadic_result",
[SameVariadicResultSize]> {
// CHECK: @builtins.property
// CHECK: def variadic1(self):
- // CHECK: start, pg = _ods_equally_sized_accessor(operation.results, 2, 0, 0)
- // CHECK: return self.operation.results[start:start + pg]
+ // CHECK: start, elements_per_group = _ods_equally_sized_accessor(self.operation.results, 1, 2, 0, 0)
+ // CHECK: return self.operation.results[start:start + elements_per_group]
//
// CHECK: @builtins.property
// CHECK: def non_variadic(self):
- // CHECK: start, pg = _ods_equally_sized_accessor(operation.results, 2, 0, 1)
+ // CHECK: start, elements_per_group = _ods_equally_sized_accessor(self.operation.results, 1, 2, 0, 1)
// CHECK: return self.operation.results[start]
//
// CHECK: @builtins.property
// CHECK: def variadic2(self):
- // CHECK: start, pg = _ods_equally_sized_accessor(operation.results, 2, 1, 1)
- // CHECK: return self.operation.results[start:start + pg]
+ // CHECK: start, elements_per_group = _ods_equally_sized_accessor(self.operation.results, 1, 2, 1, 1)
+ // CHECK: return self.operation.results[start:start + elements_per_group]
let results = (outs Variadic<AnyType>:$variadic1, AnyType:$non_variadic,
Variadic<AnyType>:$variadic2);
}
diff --git a/mlir/test/python/dialects/ods_helpers.py b/mlir/test/python/dialects/ods_helpers.py
index 0d2a18e0eb0af2..6f02153e08db5e 100644
--- a/mlir/test/python/dialects/ods_helpers.py
+++ b/mlir/test/python/dialects/ods_helpers.py
@@ -3,6 +3,7 @@
import gc
from mlir.ir import *
+from mlir.dialects._ods_common import equally_sized_accessor
def run(f):
@@ -208,3 +209,70 @@ class TestOp(OpView):
run(testOdsBuildDefaultCastError)
+
+
+def testOdsEquallySizedAccessor():
+ class TestOpMultiResultSegments(OpView):
+ OPERATION_NAME = "custom.test_op"
+ _ODS_REGIONS = (1, True)
+
+ with Context() as ctx, Location.unknown():
+ ctx.allow_unregistered_dialects = True
+ m = Module.create()
+ with InsertionPoint(m.body):
+ v = add_dummy_value()
+ ts = [IntegerType.get_signless(i * 8) for i in range(4)]
+
+ op = TestOpMultiResultSegments.build_generic(
+ results=[ts[0], ts[1], ts[2], ts[3]], operands=[v]
+ )
+ start, elements_per_group = equally_sized_accessor(op.results, 1, 3, 1, 0)
+ # CHECK: start: 1, elements_per_group: 1
+ print(f"start: {start}, elements_per_group: {elements_per_group}")
+ # CHECK: i8
+ print(op.results[start].type)
+
+ start, elements_per_group = equally_sized_accessor(op.results, 1, 3, 1, 1)
+ # CHECK: start: 2, elements_per_group: 1
+ print(f"start: {start}, elements_per_group: {elements_per_group}")
+ # CHECK: i16
+ print(op.results[start].type)
+
+
+run(testOdsEquallySizedAccessor)
+
+
+def testOdsEquallySizedAccessorMultipleSegments():
+ class TestOpMultiResultSegments(OpView):
+ OPERATION_NAME = "custom.test_op"
+ _ODS_REGIONS = (1, True)
+ _ODS_RESULT_SEGMENTS = [0, -1, -1]
+
+ def types(lst):
+ return [e.type for e in lst]
+
+ with Context() as ctx, Location.unknown():
+ ctx.allow_unregistered_dialects = True
+ m = Module.create()
+ with InsertionPoint(m.body):
+ v = add_dummy_value()
+ ts = [IntegerType.get_signless(i * 8) for i in range(7)]
+
+ op = TestOpMultiResultSegments.build_generic(
+ results=[ts[0], [ts[1], ts[2], ts[3]], [ts[4], ts[5], ts[6]]],
+ operands=[v],
+ )
+ start, elements_per_group = equally_sized_accessor(op.results, 1, 2, 1, 0)
+ # CHECK: start: 1, elements_per_group: 3
+ print(f"start: {start}, elements_per_group: {elements_per_group}")
+ # CHECK: [IntegerType(i8), IntegerType(i16), IntegerType(i24)]
+ print(types(op.results[start : start + elements_per_group]))
+
+ start, elements_per_group = equally_sized_accessor(op.results, 1, 2, 1, 1)
+ # CHECK: start: 4, elements_per_group: 3
+ print(f"start: {start}, elements_per_group: {elements_per_group}")
+ # CHECK: [IntegerType(i32), IntegerType(i40), IntegerType(i48)]
+ print(types(op.results[start : start + elements_per_group]))
+
+
+run(testOdsEquallySizedAccessorMultipleSegments)
diff --git a/mlir/test/python/dialects/python_test.py b/mlir/test/python/dialects/python_test.py
index a76f3f2b5e4583..948d1225ea489c 100644
--- a/mlir/test/python/dialects/python_test.py
+++ b/mlir/test/python/dialects/python_test.py
@@ -555,3 +555,123 @@ def testInferTypeOpInterface():
two_operands = test.InferResultsVariadicInputsOp(single=zero, doubled=zero)
# CHECK: f32
print(two_operands.result.type)
+
+
+# CHECK-LABEL: TEST: testVariadicOperandAccess
+@run
+def testVariadicOperandAccess():
+ def values(lst):
+ return [str(e) for e in lst]
+
+ with Context() as ctx, Location.unknown(ctx):
+ module = Module.create()
+ with InsertionPoint(module.body):
+ i32 = IntegerType.get_signless(32)
+ zero = arith.ConstantOp(i32, 0)
+ one = arith.ConstantOp(i32, 1)
+ two = arith.ConstantOp(i32, 2)
+ three = arith.ConstantOp(i32, 3)
+ four = arith.ConstantOp(i32, 4)
+
+ variadic_operands = test.SameVariadicOperandSizeOp(
+ [zero, one], two, [three, four]
+ )
+ # CHECK: Value(%{{.*}} = arith.constant 2 : i32)
+ print(variadic_operands.non_variadic)
+ # CHECK: ['Value(%{{.*}} = arith.constant 0 : i32)', 'Value(%{{.*}} = arith.constant 1 : i32)']
+ print(values(variadic_operands.variadic1))
+ # CHECK: ['Value(%{{.*}} = arith.constant 3 : i32)', 'Value(%{{.*}} = arith.constant 4 : i32)']
+ print(values(variadic_operands.variadic2))
+
+
+# CHECK-LABEL: TEST: testVariadicResultAccess
+@run
+def testVariadicResultAccess():
+ def types(lst):
+ return [e.type for e in lst]
+
+ with Context() as ctx, Location.unknown(ctx):
+ module = Module.create()
+ with InsertionPoint(module.body):
+ i = [IntegerType.get_signless(k) for k in range(7)]
+
+ # Test Variadic-Fixed-Variadic
+ op = test.SameVariadicResultSizeOpVFV([i[0], i[1]], i[2], [i[3], i[4]])
+ # CHECK: i2
+ print(op.non_variadic.type)
+ # CHECK: [IntegerType(i0), IntegerType(i1)]
+ print(types(op.variadic1))
+ # CHECK: [IntegerType(i3), IntegerType(i4)]
+ print(types(op.variadic2))
+
+ # Test Variadic-Variadic-Variadic
+ op = test.SameVariadicResultSizeOpVVV(
+ [i[0], i[1]], [i[2], i[3]], [i[4], i[5]]
+ )
+ # CHECK: [IntegerType(i0), IntegerType(i1)]
+ print(types(op.variadic1))
+ # CHECK: [IntegerType(i2), IntegerType(i3)]
+ print(types(op.variadic2))
+ # CHECK: [IntegerType(i4), IntegerType(i5)]
+ print(types(op.variadic3))
+
+ # Test Fixed-Fixed-Variadic
+ op = test.SameVariadicResultSizeOpFFV(i[0], i[1], [i[2], i[3], i[4]])
+ # CHECK: i0
+ print(op.non_variadic1.type)
+ # CHECK: i1
+ print(op.non_variadic2.type)
+ # CHECK: [IntegerType(i2), IntegerType(i3), IntegerType(i4)]
+ print(types(op.variadic))
+
+ # Test Variadic-Variadic-Fixed
+ op = test.SameVariadicResultSizeOpVVF(
+ [i[0], i[1], i[2]], [i[3], i[4], i[5]], i[6]
+ )
+ # CHECK: [IntegerType(i0), IntegerType(i1), IntegerType(i2)]
+ print(types(op.variadic1))
+ # CHECK: [IntegerType(i3), IntegerType(i4), IntegerType(i5)]
+ print(types(op.variadic2))
+ # CHECK: i6
+ print(op.non_variadic.type)
+
+ # Test Fixed-Variadic-Fixed-Variadic-Fixed
+ op = test.SameVariadicResultSizeOpFVFVF(
+ i[0], [i[1], i[2]], i[3], [i[4], i[5]], i[6]
+ )
+ # CHECK: i0
+ print(op.non_variadic1.type)
+ # CHECK: [IntegerType(i1), IntegerType(i2)]
+ print(types(op.variadic1))
+ # CHECK: i3
+ print(op.non_variadic2.type)
+ # CHECK: [IntegerType(i4), IntegerType(i5)]
+ print(types(op.variadic2))
+ # CHECK: i6
+ print(op.non_variadic3.type)
+
+ # Test Fixed-Variadic-Fixed-Variadic-Fixed - Variadic group size 0
+ op = test.SameVariadicResultSizeOpFVFVF(i[0], [], i[1], [], i[2])
+ # CHECK: i0
+ print(op.non_variadic1.type)
+ # CHECK: []
+ print(types(op.variadic1))
+ # CHECK: i1
+ print(op.non_variadic2.type)
+ # CHECK: []
+ print(types(op.variadic2))
+ # CHECK: i2
+ print(op.non_variadic3.type)
+
+ # Test Fixed-Variadic-Fixed-Variadic-Fixed - Variadic group size 1
+ op = test.SameVariadicResultSizeOpFVFVF(i[0], [i[1]], i[2], [i[3]], i[4])
+ # CHECK: i0
+ print(op.non_variadic1.type)
+ # CHECK: [IntegerType(i1)]
+ print(types(op.variadic1))
+ # CHECK: i2
+ print(op.non_variadic2.type)
+ # CHECK: [IntegerType(i3)]
+ print(types(op.variadic2))
+ # CHECK: i4
+ print(op.non_variadic3.type)
diff --git a/mlir/test/python/python_test_ops.td b/mlir/test/python/python_test_ops.td
index 6211fb9987c76a..026e64a3cfc19b 100644
--- a/mlir/test/python/python_test_ops.td
+++ b/mlir/test/python/python_test_ops.td
@@ -227,4 +227,42 @@ def OptionalOperandOp : TestOp<"optional_operand_op"> {
let results = (outs I32:$result);
}
+def SameVariadicOperandSizeOp : TestOp<"same_variadic_operand",
+ [SameVariadicOperandSize]> {
+ let arguments = (ins Variadic<AnyType>:$variadic1, AnyType:$non_variadic,
+ Variadic<AnyType>:$variadic2);
+}
+
+// Check different arrangements of variadic groups
+def SameVariadicResultSizeOpVFV : TestOp<"same_variadic_result_vfv",
+ [SameVariadicResultSize]> {
+ let results = (outs Variadic<AnyType>:$variadic1, AnyType:$non_variadic,
+ Variadic<AnyType>:$variadic2);
+}
+
+def SameVariadicResultSizeOpVVV : TestOp<"same_variadic_result_vvv",
+ [SameVariadicResultSize]> {
+ let results = (outs Variadic<AnyType>:$variadic1, Variadic<AnyType>:$variadic2,
+ Variadic<AnyType>:$variadic3);
+}
+
+def SameVariadicResultSizeOpFFV : TestOp<"same_variadic_result_ffv",
+ [SameVariadicResultSize]> {
+ let results = (outs AnyType:$non_variadic1, AnyType:$non_variadic2,
+ Variadic<AnyType>:$variadic);
+}
+
+def SameVariadicResultSizeOpVVF : TestOp<"same_variadic_result_vvf",
+ [SameVariadicResultSize]> {
+ let results = (outs Variadic<AnyType>:$variadic1, Variadic<AnyType>:$variadic2,
+ AnyType:$non_variadic);
+}
+
+def SameVariadicResultSizeOpFVFVF : TestOp<"same_variadic_result_fvfvf",
+ [SameVariadicResultSize]> {
+ let results = (outs AnyType:$non_variadic1, Variadic<AnyType>:$variadic1,
+ AnyType:$non_variadic2, Variadic<AnyType>:$variadic2,
+ AnyType:$non_variadic3);
+}
+
#endif // PYTHON_TEST_OPS
diff --git a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
index 052020acdcb764..553ab6adc65b06 100644
--- a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
@@ -139,13 +139,14 @@ constexpr const char *opOneVariadicTemplate = R"Py(
/// First part of the template for equally-sized variadic group accessor:
/// {0} is the name of the accessor;
/// {1} is either 'operand' or 'result';
-/// {2} is the total number of variadic groups;
-/// {3} is the number of non-variadic groups preceding the current group;
-/// {3} is the number of variadic groups preceding the current group.
+/// {2} is the total number of non-variadic groups;
+/// {3} is the total number of variadic groups;
+/// {4} is the number of non-variadic groups preceding the current group;
+/// {5} is the number of variadic groups preceding the current group.
constexpr const char *opVariadicEqualPrefixTemplate = R"Py(
@builtins.property
def {0}(self):
- start, pg = _ods_equally_sized_accessor(operation.{1}s, {2}, {3}, {4}))Py";
+ start, elements_per_group = _ods_equally_sized_accessor(self.operation.{1}s, {2}, {3}, {4}, {5}))Py";
/// Second part of the template for equally-sized case, accessing a single
/// element:
@@ -158,7 +159,7 @@ constexpr const char *opVariadicEqualSimpleTemplate = R"Py(
/// group:
/// {0} is either 'operand' or 'result'.
constexpr const char *opVariadicEqualVariadicTemplate = R"Py(
- return self.operation.{0}s[start:start + pg]
+ return self.operation.{0}s[start:start + elements_per_group]
)Py";
/// Template for an attribute-sized group accessor:
@@ -324,8 +325,7 @@ static std::string attrSizedTraitForKind(const char *kind) {
/// `operand` or `result` and is used verbatim in the emitted code.
static void emitElementAccessors(
const Operator &op, raw_ostream &os, const char *kind,
- llvm::function_ref<unsigned(const Operator &)> getNumVariableLength,
- llvm::function_ref<int(const Operator &)> getNumElements,
+ unsigned numVariadicGroups, unsigned numElements,
llvm::function_ref<const NamedTypeConstraint &(const Operator &, int)>
getElement) {
assert(llvm::is_contained(
@@ -339,14 +339,12 @@ static void emitElementAccessors(
llvm::StringRef(kind).drop_front());
std::string attrSizedTrait = attrSizedTraitForKind(kind);
- unsigned numVariableLength = getNumVariableLength(op);
-
// If there is only one variable-length element group, its size can be
// inferred from the total number of elements. If there are none, the
// generation is straightforward.
- if (numVariableLength <= 1) {
+ if (numVariadicGroups <= 1) {
bool seenVariableLength = false;
- for (int i = 0, e = getNumElements(op); i < e; ++i) {
+ for (unsigned i = 0; i < numElements; ++i) {
const NamedTypeConstraint &element = getElement(op, i);
if (element.isVariableLength())
seenVariableLength = true;
@@ -355,12 +353,10 @@ static void emitElementAccessors(
if (element.isVariableLength()) {
os << llvm::formatv(element.isOptional() ? opOneOptionalTemplate
: opOneVariadicTemplate,
- sanitizeName(element.name), kind,
- getNumElements(op), i);
+ sanitizeName(element.name), kind, numElements, i);
} else if (seenVariableLength) {
os << llvm::formatv(opSingleAfterVariableTemplate,
- sanitizeName(element.name), kind,
- getNumElements(op), i);
+ sanitizeName(element.name), kind, numElements, i);
} else {
os << llvm::formatv(opSingleTemplate, sanitizeName(element.name), kind,
i);
@@ -371,14 +367,25 @@ static void emitElementAccessors(
// Handle the operations where variadic groups have the same size.
if (op.getTrait(sameSizeTrait)) {
+ // Count the number of simple elements
+ unsigned numSimpleLength = 0;
+ for (unsigned i = 0; i < numElements; ++i) {
+ const NamedTypeConstraint &element = getElement(op, i);
+ if (!element.isVariableLength()) {
+ ++numSimpleLength;
+ }
+ }
+
+ // Generate the accessors
int numPrecedingSimple = 0;
int numPrecedingVariadic = 0;
- for (int i = 0, e = getNumElements(op); i < e; ++i) {
+ for (unsigned i = 0; i < numElements; ++i) {
const NamedTypeConstraint &element = getElement(op, i);
if (!element.name.empty()) {
os << llvm::formatv(opVariadicEqualPrefixTemplate,
- sanitizeName(element.name), kind, numVariableLength,
- numPrecedingSimple, numPrecedingVariadic);
+ sanitizeName(element.name), kind, numSimpleLength,
+ numVariadicGroups, numPrecedingSimple,
+ numPrecedingVariadic);
os << llvm::formatv(element.isVariableLength()
? opVariadicEqualVariadicTemplate
: opVariadicEqualSimpleTemplate,
@@ -396,7 +403,7 @@ static void emitElementAccessors(
// provided as an attribute. For non-variadic elements, make sure to return
// an element rather than a singleton container.
if (op.getTrait(attrSizedTrait)) {
- for (int i = 0, e = getNumElements(op); i < e; ++i) {
+ for (unsigned i = 0; i < numElements; ++i) {
const NamedTypeConstraint &element = getElement(op, i);
if (element.name.empty())
continue;
@@ -427,20 +434,14 @@ static const NamedTypeConstraint &getResult(const Operator &op, int i) {
/// Emits accessors to Op operands.
static void emitOperandAccessors(const Operator &op, raw_ostream &os) {
- auto getNumVariableLengthOperands = [](const Operator &oper) {
- return oper.getNumVariableLengthOperands();
- };
- emitElementAccessors(op, os, "operand", getNumVariableLengthOperands,
- getNumOperands, getOperand);
+ emitElementAccessors(op, os, "operand", op.getNumVariableLengthOperands(),
+ getNumOperands(op), getOperand);
}
/// Emits accessors Op results.
static void emitResultAccessors(const Operator &op, raw_ostream &os) {
- auto getNumVariableLengthResults = [](const Operator &oper) {
- return oper.getNumVariableLengthResults();
- };
- emitElementAccessors(op, os, "result", getNumVariableLengthResults,
- getNumResults, getResult);
+ emitElementAccessors(op, os, "result", op.getNumVariableLengthResults(),
+ getNumResults(op), getResult);
}
/// Emits accessors to Op attributes.
More information about the Mlir-commits
mailing list