[Mlir-commits] [mlir] 54c9984 - [mlir][Python] Fix generation of accessors for Optional

Alex Zinenko llvmlistbot at llvm.org
Thu Nov 18 00:43:04 PST 2021


Author: Michal Terepeta
Date: 2021-11-18T09:42:57+01:00
New Revision: 54c99842079997b0fe208acdab01e540c0d81b51

URL: https://github.com/llvm/llvm-project/commit/54c99842079997b0fe208acdab01e540c0d81b51
DIFF: https://github.com/llvm/llvm-project/commit/54c99842079997b0fe208acdab01e540c0d81b51.diff

LOG: [mlir][Python] Fix generation of accessors for Optional

Previously, when there was only one `Optional` operand/result in the
list, the generated accessor always returned `None`. For example, for a
single optional result we would generate:

```
return self.operation.results[0] if len(self.operation.results) > 1 else None
```

But what we really want is to return `None` only if the length of
`results` is smaller than the total number of element groups (i.e.,
the optional operand/result is in fact missing).

This commit also renames a few local variables in the generator to make
the distinction between `isVariadic()` and `isVariableLength()`
clearer.

Reviewed By: ftynse

Differential Revision: https://reviews.llvm.org/D113855

Added: 
    

Modified: 
    mlir/python/mlir/dialects/_linalg_ops_ext.py
    mlir/test/mlir-tblgen/op-python-bindings.td
    mlir/test/python/dialects/linalg/ops.py
    mlir/test/python/dialects/python_test.py
    mlir/test/python/python_test_ops.td
    mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/python/mlir/dialects/_linalg_ops_ext.py b/mlir/python/mlir/dialects/_linalg_ops_ext.py
index b7641c0a4b53c..d6c57547ee163 100644
--- a/mlir/python/mlir/dialects/_linalg_ops_ext.py
+++ b/mlir/python/mlir/dialects/_linalg_ops_ext.py
@@ -36,15 +36,6 @@ def __init__(self, output: Value, value: Value, *, loc=None, ip=None):
     OpView.__init__(self, op)
     linalgDialect = Context.current.get_dialect_descriptor("linalg")
     fill_builtin_region(linalgDialect, self.operation)
-    # TODO: self.result is None. When len(results) == 1 we expect it to be
-    # results[0] as per _linalg_ops_gen.py. This seems like an orthogonal bug
-    # in the generator of _linalg_ops_gen.py where we have:
-    # ```
-    # def result(self):
-    #   return self.operation.results[0] \
-    #     if len(self.operation.results) > 1 else None
-    # ```
-
 
 class InitTensorOp:
   """Extends the linalg.init_tensor op."""

diff  --git a/mlir/test/mlir-tblgen/op-python-bindings.td b/mlir/test/mlir-tblgen/op-python-bindings.td
index becce13050a18..aa9977e047f15 100644
--- a/mlir/test/mlir-tblgen/op-python-bindings.td
+++ b/mlir/test/mlir-tblgen/op-python-bindings.td
@@ -304,7 +304,7 @@ def OneOptionalOperandOp : TestOp<"one_optional_operand"> {
 
   // CHECK: @builtins.property
   // CHECK: def optional(self):
-  // CHECK:   return self.operation.operands[1] if len(self.operation.operands) > 2 else None
+  // CHECK:   return None if len(self.operation.operands) < 2 else self.operation.operands[1]
 
 }
 

diff  --git a/mlir/test/python/dialects/linalg/ops.py b/mlir/test/python/dialects/linalg/ops.py
index d788292f3424b..e5b96c260eaad 100644
--- a/mlir/test/python/dialects/linalg/ops.py
+++ b/mlir/test/python/dialects/linalg/ops.py
@@ -68,10 +68,7 @@ def testFill():
       @builtin.FuncOp.from_py_func(RankedTensorType.get((12, -1), f32))
       def fill_tensor(out):
         zero = arith.ConstantOp(value=FloatAttr.get(f32, 0.), result=f32).result
-        # TODO: FillOp.result is None. When len(results) == 1 we expect it to
-        # be results[0] as per _linalg_ops_gen.py. This seems like an
-        # orthogonal bug in the generator of _linalg_ops_gen.py.
-        return linalg.FillOp(output=out, value=zero).results[0]
+        return linalg.FillOp(output=out, value=zero).result
 
       # CHECK-LABEL: func @fill_buffer
       #  CHECK-SAME:   %[[OUT:[0-9a-z]+]]: memref<12x?xf32>

diff  --git a/mlir/test/python/dialects/python_test.py b/mlir/test/python/dialects/python_test.py
index 2267b59cd4d77..f9da91fba4cdf 100644
--- a/mlir/test/python/dialects/python_test.py
+++ b/mlir/test/python/dialects/python_test.py
@@ -207,3 +207,21 @@ def resultTypesDefinedByTraits():
       print(implied.flt.type)
       # CHECK: index
       print(implied.index.type)
+
+
+# CHECK-LABEL: TEST: testOptionalOperandOp
+@run
+def testOptionalOperandOp():
+  with Context() as ctx, Location.unknown():
+    test.register_python_test_dialect(ctx)
+
+    module = Module.create()
+    with InsertionPoint(module.body):
+
+      op1 = test.OptionalOperandOp(None)
+      # CHECK: op1.input is None: True
+      print(f"op1.input is None: {op1.input is None}")
+
+      op2 = test.OptionalOperandOp(op1)
+      # CHECK: op2.input is None: False
+      print(f"op2.input is None: {op2.input is None}")

diff  --git a/mlir/test/python/python_test_ops.td b/mlir/test/python/python_test_ops.td
index 0f947e7e536bd..6ee71dbf8b123 100644
--- a/mlir/test/python/python_test_ops.td
+++ b/mlir/test/python/python_test_ops.td
@@ -76,4 +76,9 @@ def FirstAttrDeriveAttrOp : TestOp<"first_attr_derive_attr_op",
   let results = (outs AnyType:$one, AnyType:$two, AnyType:$three);
 }
 
+def OptionalOperandOp : TestOp<"optional_operand_op"> {
+  let arguments = (ins Optional<AnyType>:$input);
+  let results = (outs I32:$result);
+}
+
 #endif // PYTHON_TEST_OPS

diff  --git a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
index 8babff25db07b..fb634a1be3957 100644
--- a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
@@ -109,10 +109,13 @@ constexpr const char *opSingleAfterVariableTemplate = R"Py(
 ///   {1} is either 'operand' or 'result';
 ///   {2} is the total number of element groups;
 ///   {3} is the position of the current group in the group list.
+/// This works if we have only one variable-length group (and it's the optional
+/// operand/result): we can deduce it's absent if the `len(operation.{1}s)` is
+/// smaller than the total number of groups.
 constexpr const char *opOneOptionalTemplate = R"Py(
   @builtins.property
   def {0}(self):
-    return self.operation.{1}s[{3}] if len(self.operation.{1}s) > {2} else None
+    return None if len(self.operation.{1}s) < {2} else self.operation.{1}s[{3}]
 )Py";
 
 /// Template for the variadic group accessor in the single variadic group case:
@@ -311,7 +314,7 @@ static std::string attrSizedTraitForKind(const char *kind) {
 /// `operand` or `result` and is used verbatim in the emitted code.
 static void emitElementAccessors(
     const Operator &op, raw_ostream &os, const char *kind,
-    llvm::function_ref<unsigned(const Operator &)> getNumVariadic,
+    llvm::function_ref<unsigned(const Operator &)> getNumVariableLength,
     llvm::function_ref<int(const Operator &)> getNumElements,
     llvm::function_ref<const NamedTypeConstraint &(const Operator &, int)>
         getElement) {
@@ -326,12 +329,12 @@ static void emitElementAccessors(
                     llvm::StringRef(kind).drop_front());
   std::string attrSizedTrait = attrSizedTraitForKind(kind);
 
-  unsigned numVariadic = getNumVariadic(op);
+  unsigned numVariableLength = getNumVariableLength(op);
 
-  // If there is only one variadic element group, its size can be inferred from
-  // the total number of elements. If there are none, the generation is
-  // straightforward.
-  if (numVariadic <= 1) {
+  // If there is only one variable-length element group, its size can be
+  // inferred from the total number of elements. If there are none, the
+  // generation is straightforward.
+  if (numVariableLength <= 1) {
     bool seenVariableLength = false;
     for (int i = 0, e = getNumElements(op); i < e; ++i) {
       const NamedTypeConstraint &element = getElement(op, i);
@@ -364,7 +367,7 @@ static void emitElementAccessors(
       const NamedTypeConstraint &element = getElement(op, i);
       if (!element.name.empty()) {
         os << llvm::formatv(opVariadicEqualPrefixTemplate,
-                            sanitizeName(element.name), kind, numVariadic,
+                            sanitizeName(element.name), kind, numVariableLength,
                             numPrecedingSimple, numPrecedingVariadic);
         os << llvm::formatv(element.isVariableLength()
                                 ? opVariadicEqualVariadicTemplate
@@ -414,20 +417,20 @@ static const NamedTypeConstraint &getResult(const Operator &op, int i) {
 
 /// Emits accessors to Op operands.
 static void emitOperandAccessors(const Operator &op, raw_ostream &os) {
-  auto getNumVariadic = [](const Operator &oper) {
+  auto getNumVariableLengthOperands = [](const Operator &oper) {
     return oper.getNumVariableLengthOperands();
   };
-  emitElementAccessors(op, os, "operand", getNumVariadic, getNumOperands,
-                       getOperand);
+  emitElementAccessors(op, os, "operand", getNumVariableLengthOperands,
+                       getNumOperands, getOperand);
 }
 
 /// Emits accessors Op results.
 static void emitResultAccessors(const Operator &op, raw_ostream &os) {
-  auto getNumVariadic = [](const Operator &oper) {
+  auto getNumVariableLengthResults = [](const Operator &oper) {
     return oper.getNumVariableLengthResults();
   };
-  emitElementAccessors(op, os, "result", getNumVariadic, getNumResults,
-                       getResult);
+  emitElementAccessors(op, os, "result", getNumVariableLengthResults,
+                       getNumResults, getResult);
 }
 
 /// Emits accessors to Op attributes.


        


More information about the Mlir-commits mailing list