[Mlir-commits] [mlir] 54c9984 - [mlir][Python] Fix generation of accessors for Optional
Alex Zinenko
llvmlistbot at llvm.org
Thu Nov 18 00:43:04 PST 2021
Author: Michal Terepeta
Date: 2021-11-18T09:42:57+01:00
New Revision: 54c99842079997b0fe208acdab01e540c0d81b51
URL: https://github.com/llvm/llvm-project/commit/54c99842079997b0fe208acdab01e540c0d81b51
DIFF: https://github.com/llvm/llvm-project/commit/54c99842079997b0fe208acdab01e540c0d81b51.diff
LOG: [mlir][Python] Fix generation of accessors for Optional
Previously, when there was only one `Optional` operand/result in the
list, the accessor would always return `None`; e.g., for a single
optional result we would generate:
```
return self.operation.results[0] if len(self.operation.results) > 1 else None
```
But what we really want is to return `None` only if the length of
`results` is smaller than the total number of element groups (i.e.,
the optional operand/result is in fact missing).
This commit also renames a few local variables in the generator to make
the distinction between `isVariadic()` and `isVariableLength()` a bit
clearer.
Reviewed By: ftynse
Differential Revision: https://reviews.llvm.org/D113855
Added:
Modified:
mlir/python/mlir/dialects/_linalg_ops_ext.py
mlir/test/mlir-tblgen/op-python-bindings.td
mlir/test/python/dialects/linalg/ops.py
mlir/test/python/dialects/python_test.py
mlir/test/python/python_test_ops.td
mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
Removed:
################################################################################
diff --git a/mlir/python/mlir/dialects/_linalg_ops_ext.py b/mlir/python/mlir/dialects/_linalg_ops_ext.py
index b7641c0a4b53c..d6c57547ee163 100644
--- a/mlir/python/mlir/dialects/_linalg_ops_ext.py
+++ b/mlir/python/mlir/dialects/_linalg_ops_ext.py
@@ -36,15 +36,6 @@ def __init__(self, output: Value, value: Value, *, loc=None, ip=None):
OpView.__init__(self, op)
linalgDialect = Context.current.get_dialect_descriptor("linalg")
fill_builtin_region(linalgDialect, self.operation)
- # TODO: self.result is None. When len(results) == 1 we expect it to be
- # results[0] as per _linalg_ops_gen.py. This seems like an orthogonal bug
- # in the generator of _linalg_ops_gen.py where we have:
- # ```
- # def result(self):
- # return self.operation.results[0] \
- # if len(self.operation.results) > 1 else None
- # ```
-
class InitTensorOp:
"""Extends the linalg.init_tensor op."""
diff --git a/mlir/test/mlir-tblgen/op-python-bindings.td b/mlir/test/mlir-tblgen/op-python-bindings.td
index becce13050a18..aa9977e047f15 100644
--- a/mlir/test/mlir-tblgen/op-python-bindings.td
+++ b/mlir/test/mlir-tblgen/op-python-bindings.td
@@ -304,7 +304,7 @@ def OneOptionalOperandOp : TestOp<"one_optional_operand"> {
// CHECK: @builtins.property
// CHECK: def optional(self):
- // CHECK: return self.operation.operands[1] if len(self.operation.operands) > 2 else None
+ // CHECK: return None if len(self.operation.operands) < 2 else self.operation.operands[1]
}
diff --git a/mlir/test/python/dialects/linalg/ops.py b/mlir/test/python/dialects/linalg/ops.py
index d788292f3424b..e5b96c260eaad 100644
--- a/mlir/test/python/dialects/linalg/ops.py
+++ b/mlir/test/python/dialects/linalg/ops.py
@@ -68,10 +68,7 @@ def testFill():
@builtin.FuncOp.from_py_func(RankedTensorType.get((12, -1), f32))
def fill_tensor(out):
zero = arith.ConstantOp(value=FloatAttr.get(f32, 0.), result=f32).result
- # TODO: FillOp.result is None. When len(results) == 1 we expect it to
- # be results[0] as per _linalg_ops_gen.py. This seems like an
- # orthogonal bug in the generator of _linalg_ops_gen.py.
- return linalg.FillOp(output=out, value=zero).results[0]
+ return linalg.FillOp(output=out, value=zero).result
# CHECK-LABEL: func @fill_buffer
# CHECK-SAME: %[[OUT:[0-9a-z]+]]: memref<12x?xf32>
diff --git a/mlir/test/python/dialects/python_test.py b/mlir/test/python/dialects/python_test.py
index 2267b59cd4d77..f9da91fba4cdf 100644
--- a/mlir/test/python/dialects/python_test.py
+++ b/mlir/test/python/dialects/python_test.py
@@ -207,3 +207,21 @@ def resultTypesDefinedByTraits():
print(implied.flt.type)
# CHECK: index
print(implied.index.type)
+
+
+# CHECK-LABEL: TEST: testOptionalOperandOp
+@run
+def testOptionalOperandOp():
+ with Context() as ctx, Location.unknown():
+ test.register_python_test_dialect(ctx)
+
+ module = Module.create()
+ with InsertionPoint(module.body):
+
+ op1 = test.OptionalOperandOp(None)
+ # CHECK: op1.input is None: True
+ print(f"op1.input is None: {op1.input is None}")
+
+ op2 = test.OptionalOperandOp(op1)
+ # CHECK: op2.input is None: False
+ print(f"op2.input is None: {op2.input is None}")
diff --git a/mlir/test/python/python_test_ops.td b/mlir/test/python/python_test_ops.td
index 0f947e7e536bd..6ee71dbf8b123 100644
--- a/mlir/test/python/python_test_ops.td
+++ b/mlir/test/python/python_test_ops.td
@@ -76,4 +76,9 @@ def FirstAttrDeriveAttrOp : TestOp<"first_attr_derive_attr_op",
let results = (outs AnyType:$one, AnyType:$two, AnyType:$three);
}
+def OptionalOperandOp : TestOp<"optional_operand_op"> {
+ let arguments = (ins Optional<AnyType>:$input);
+ let results = (outs I32:$result);
+}
+
#endif // PYTHON_TEST_OPS
diff --git a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
index 8babff25db07b..fb634a1be3957 100644
--- a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
@@ -109,10 +109,13 @@ constexpr const char *opSingleAfterVariableTemplate = R"Py(
/// {1} is either 'operand' or 'result';
/// {2} is the total number of element groups;
/// {3} is the position of the current group in the group list.
+/// This works if we have only one variable-length group (and it's the optional
+/// operand/result): we can deduce it's absent if the `len(operation.{1}s)` is
+/// smaller than the total number of groups.
constexpr const char *opOneOptionalTemplate = R"Py(
@builtins.property
def {0}(self):
- return self.operation.{1}s[{3}] if len(self.operation.{1}s) > {2} else None
+ return None if len(self.operation.{1}s) < {2} else self.operation.{1}s[{3}]
)Py";
/// Template for the variadic group accessor in the single variadic group case:
@@ -311,7 +314,7 @@ static std::string attrSizedTraitForKind(const char *kind) {
/// `operand` or `result` and is used verbatim in the emitted code.
static void emitElementAccessors(
const Operator &op, raw_ostream &os, const char *kind,
- llvm::function_ref<unsigned(const Operator &)> getNumVariadic,
+ llvm::function_ref<unsigned(const Operator &)> getNumVariableLength,
llvm::function_ref<int(const Operator &)> getNumElements,
llvm::function_ref<const NamedTypeConstraint &(const Operator &, int)>
getElement) {
@@ -326,12 +329,12 @@ static void emitElementAccessors(
llvm::StringRef(kind).drop_front());
std::string attrSizedTrait = attrSizedTraitForKind(kind);
- unsigned numVariadic = getNumVariadic(op);
+ unsigned numVariableLength = getNumVariableLength(op);
- // If there is only one variadic element group, its size can be inferred from
- // the total number of elements. If there are none, the generation is
- // straightforward.
- if (numVariadic <= 1) {
+ // If there is only one variable-length element group, its size can be
+ // inferred from the total number of elements. If there are none, the
+ // generation is straightforward.
+ if (numVariableLength <= 1) {
bool seenVariableLength = false;
for (int i = 0, e = getNumElements(op); i < e; ++i) {
const NamedTypeConstraint &element = getElement(op, i);
@@ -364,7 +367,7 @@ static void emitElementAccessors(
const NamedTypeConstraint &element = getElement(op, i);
if (!element.name.empty()) {
os << llvm::formatv(opVariadicEqualPrefixTemplate,
- sanitizeName(element.name), kind, numVariadic,
+ sanitizeName(element.name), kind, numVariableLength,
numPrecedingSimple, numPrecedingVariadic);
os << llvm::formatv(element.isVariableLength()
? opVariadicEqualVariadicTemplate
@@ -414,20 +417,20 @@ static const NamedTypeConstraint &getResult(const Operator &op, int i) {
/// Emits accessors to Op operands.
static void emitOperandAccessors(const Operator &op, raw_ostream &os) {
- auto getNumVariadic = [](const Operator &oper) {
+ auto getNumVariableLengthOperands = [](const Operator &oper) {
return oper.getNumVariableLengthOperands();
};
- emitElementAccessors(op, os, "operand", getNumVariadic, getNumOperands,
- getOperand);
+ emitElementAccessors(op, os, "operand", getNumVariableLengthOperands,
+ getNumOperands, getOperand);
}
/// Emits accessors Op results.
static void emitResultAccessors(const Operator &op, raw_ostream &os) {
- auto getNumVariadic = [](const Operator &oper) {
+ auto getNumVariableLengthResults = [](const Operator &oper) {
return oper.getNumVariableLengthResults();
};
- emitElementAccessors(op, os, "result", getNumVariadic, getNumResults,
- getResult);
+ emitElementAccessors(op, os, "result", getNumVariableLengthResults,
+ getNumResults, getResult);
}
/// Emits accessors to Op attributes.
More information about the Mlir-commits
mailing list