[Mlir-commits] [mlir] 6981e5e - [mlir][python] fix constructor generation for optional operands in presence of segment attribute

Alex Zinenko llvmlistbot at llvm.org
Fri Nov 5 04:40:34 PDT 2021


Author: Alex Zinenko
Date: 2021-11-05T12:40:27+01:00
New Revision: 6981e5ec91c98a23753d2dae590156107d857fda

URL: https://github.com/llvm/llvm-project/commit/6981e5ec91c98a23753d2dae590156107d857fda
DIFF: https://github.com/llvm/llvm-project/commit/6981e5ec91c98a23753d2dae590156107d857fda.diff

LOG: [mlir][python] fix constructor generation for optional operands in presence of segment attribute

The ODS-based Python op bindings generator has been generating an incorrect
specification of the operand segment in the presence of both optional and variadic
operand groups: optional groups were treated as variadic whereas they require
separate treatment. Make sure optional groups are now handled separately. Also harden
the tests around generated op constructors, as they could hitherto accept the code
for both optional and variadic arguments.

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D113259

Added: 
    

Modified: 
    mlir/lib/Bindings/Python/IRCore.cpp
    mlir/test/mlir-tblgen/op-python-bindings.td
    mlir/test/python/dialects/vector.py
    mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index d465c1382459c..cf59a67f9c8f0 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -1153,7 +1153,7 @@ PyOpView::buildGeneric(py::object cls, py::list resultTypeList,
       throw py::value_error((llvm::Twine("Operation \"") + name +
                              "\" requires " +
                              llvm::Twine(resultSegmentSpec.size()) +
-                             "result segments but was provided " +
+                             " result segments but was provided " +
                              llvm::Twine(resultTypeList.size()))
                                 .str());
     }
@@ -1164,7 +1164,7 @@ PyOpView::buildGeneric(py::object cls, py::list resultTypeList,
       if (segmentSpec == 1 || segmentSpec == 0) {
         // Unpack unary element.
         try {
-          auto resultType = py::cast<PyType *>(std::get<0>(it.value()));
+          auto *resultType = py::cast<PyType *>(std::get<0>(it.value()));
           if (resultType) {
             resultTypes.push_back(resultType);
             resultSegmentLengths.push_back(1);

diff  --git a/mlir/test/mlir-tblgen/op-python-bindings.td b/mlir/test/mlir-tblgen/op-python-bindings.td
index d6dc56428eb57..becce13050a18 100644
--- a/mlir/test/mlir-tblgen/op-python-bindings.td
+++ b/mlir/test/mlir-tblgen/op-python-bindings.td
@@ -18,7 +18,7 @@ class TestOp<string mnemonic, list<OpTrait> traits = []> :
 // CHECK: @_ods_cext.register_operation(_Dialect)
 // CHECK: class AttrSizedOperandsOp(_ods_ir.OpView):
 // CHECK-LABEL: OPERATION_NAME = "test.attr_sized_operands"
-// CHECK: _ODS_OPERAND_SEGMENTS = [-1,1,-1,]
+// CHECK: _ODS_OPERAND_SEGMENTS = [-1,1,0,]
 def AttrSizedOperandsOp : TestOp<"attr_sized_operands",
                                  [AttrSizedOperandSegments]> {
   // CHECK: def __init__(self, variadic1, non_variadic, variadic2, *, loc=None, ip=None):
@@ -28,7 +28,7 @@ def AttrSizedOperandsOp : TestOp<"attr_sized_operands",
   // CHECK:   regions = None
   // CHECK:   operands.append(_get_op_results_or_values(variadic1))
   // CHECK:   operands.append(_get_op_result_or_value(non_variadic))
-  // CHECK:   if variadic2 is not None: operands.append(_get_op_result_or_value(variadic2))
+  // CHECK:   operands.append(_get_op_result_or_value(variadic2) if variadic2 is not None else None)
   // CHECK:   _ods_successors = None
   // CHECK:   super().__init__(self.build_generic(
   // CHECK:     attributes=attributes, results=results, operands=operands,
@@ -40,6 +40,7 @@ def AttrSizedOperandsOp : TestOp<"attr_sized_operands",
   // CHECK:       self.operation.operands,
   // CHECK:       self.operation.attributes["operand_segment_sizes"], 0)
   // CHECK:   return operand_range
+  // CHECK-NOT: if len(operand_range)
   //
   // CHECK: @builtins.property
   // CHECK: def non_variadic(self):
@@ -61,7 +62,7 @@ def AttrSizedOperandsOp : TestOp<"attr_sized_operands",
 // CHECK: @_ods_cext.register_operation(_Dialect)
 // CHECK: class AttrSizedResultsOp(_ods_ir.OpView):
 // CHECK-LABEL: OPERATION_NAME = "test.attr_sized_results"
-// CHECK: _ODS_RESULT_SEGMENTS = [-1,1,-1,]
+// CHECK: _ODS_RESULT_SEGMENTS = [0,1,-1,]
 def AttrSizedResultsOp : TestOp<"attr_sized_results",
                                [AttrSizedResultSegments]> {
   // CHECK: def __init__(self, variadic1, non_variadic, variadic2, *, loc=None, ip=None):
@@ -71,7 +72,7 @@ def AttrSizedResultsOp : TestOp<"attr_sized_results",
   // CHECK:   regions = None
   // CHECK:   if variadic1 is not None: results.append(variadic1)
   // CHECK:   results.append(non_variadic)
-  // CHECK:   if variadic2 is not None: results.append(variadic2)
+  // CHECK:   results.append(variadic2)
   // CHECK:   _ods_successors = None
   // CHECK:   super().__init__(self.build_generic(
   // CHECK:     attributes=attributes, results=results, operands=operands,
@@ -97,8 +98,9 @@ def AttrSizedResultsOp : TestOp<"attr_sized_results",
   // CHECK:       self.operation.results,
   // CHECK:       self.operation.attributes["result_segment_sizes"], 2)
   // CHECK:   return result_range
+  // CHECK-NOT: if len(result_range)
   let results = (outs Optional<AnyType>:$variadic1, AnyType:$non_variadic,
-                 Optional<AnyType>:$variadic2);
+                 Variadic<AnyType>:$variadic2);
 }
 
 
@@ -277,6 +279,35 @@ def MissingNamesOp : TestOp<"missing_names"> {
   let results = (outs I32:$i32, AnyFloat, I64:$i64);
 }
 
+// CHECK: @_ods_cext.register_operation(_Dialect)
+// CHECK: class OneOptionalOperandOp(_ods_ir.OpView):
+// CHECK-LABEL: OPERATION_NAME = "test.one_optional_operand"
+// CHECK-NOT: _ODS_OPERAND_SEGMENTS
+// CHECK-NOT: _ODS_RESULT_SEGMENTS
+def OneOptionalOperandOp : TestOp<"one_optional_operand"> {
+  let arguments = (ins AnyType:$non_optional, Optional<AnyType>:$optional);
+  // CHECK: def __init__(self, non_optional, optional, *, loc=None, ip=None):
+  // CHECK:   operands = []
+  // CHECK:   results = []
+  // CHECK:   attributes = {}
+  // CHECK:   regions = None
+  // CHECK:   operands.append(_get_op_result_or_value(non_optional))
+  // CHECK:   if optional is not None: operands.append(_get_op_result_or_value(optional))
+  // CHECK:   _ods_successors = None
+  // CHECK:   super().__init__(self.build_generic(
+  // CHECK:     attributes=attributes, results=results, operands=operands,
+  // CHECK:     successors=_ods_successors, regions=regions, loc=loc, ip=ip))
+
+  // CHECK: @builtins.property
+  // CHECK: def non_optional(self):
+  // CHECK:   return self.operation.operands[0]
+
+  // CHECK: @builtins.property
+  // CHECK: def optional(self):
+  // CHECK:   return self.operation.operands[1] if len(self.operation.operands) > 2 else None
+
+}
+
 // CHECK: @_ods_cext.register_operation(_Dialect)
 // CHECK: class OneVariadicOperandOp(_ods_ir.OpView):
 // CHECK-LABEL: OPERATION_NAME = "test.one_variadic_operand"

diff  --git a/mlir/test/python/dialects/vector.py b/mlir/test/python/dialects/vector.py
index 4d7052859e7df..b8db94070d6a2 100644
--- a/mlir/test/python/dialects/vector.py
+++ b/mlir/test/python/dialects/vector.py
@@ -2,25 +2,58 @@
 
 from mlir.ir import *
 import mlir.dialects.builtin as builtin
+import mlir.dialects.std as std
 import mlir.dialects.vector as vector
 
 def run(f):
   print("\nTEST:", f.__name__)
-  f()
+  with Context(), Location.unknown():
+    f()
+  return f
 
 # CHECK-LABEL: TEST: testPrintOp
 @run
 def testPrintOp():
-  with Context() as ctx, Location.unknown():
-    module = Module.create()
-    with InsertionPoint(module.body):
-      @builtin.FuncOp.from_py_func(VectorType.get((12, 5), F32Type.get()))
-      def print_vector(arg):
-        return vector.PrintOp(arg)
-
-    # CHECK-LABEL: func @print_vector(
-    # CHECK-SAME:                     %[[ARG:.*]]: vector<12x5xf32>) {
-    #       CHECK:   vector.print %[[ARG]] : vector<12x5xf32>
-    #       CHECK:   return
-    #       CHECK: }
-    print(module)
+  module = Module.create()
+  with InsertionPoint(module.body):
+
+    @builtin.FuncOp.from_py_func(VectorType.get((12, 5), F32Type.get()))
+    def print_vector(arg):
+      return vector.PrintOp(arg)
+
+  # CHECK-LABEL: func @print_vector(
+  # CHECK-SAME:                     %[[ARG:.*]]: vector<12x5xf32>) {
+  #       CHECK:   vector.print %[[ARG]] : vector<12x5xf32>
+  #       CHECK:   return
+  #       CHECK: }
+  print(module)
+
+
+# CHECK-LABEL: TEST: testTransferReadOp
+ at run
+def testTransferReadOp():
+  module = Module.create()
+  with InsertionPoint(module.body):
+    vector_type = VectorType.get([2, 3], F32Type.get())
+    memref_type = MemRefType.get([-1, -1], F32Type.get())
+    index_type = IndexType.get()
+    mask_type = VectorType.get(vector_type.shape, IntegerType.get_signless(1))
+    identity_map = AffineMap.get_identity(vector_type.rank)
+    identity_map_attr = AffineMapAttr.get(identity_map)
+    func = builtin.FuncOp("transfer_read",
+                          ([memref_type, index_type,
+                            F32Type.get(), mask_type], []))
+    with InsertionPoint(func.add_entry_block()):
+      A, zero, padding, mask = func.arguments
+      vector.TransferReadOp(vector_type, A, [zero, zero], identity_map_attr,
+                            padding, mask, None)
+      vector.TransferReadOp(vector_type, A, [zero, zero], identity_map_attr,
+                            padding, None, None)
+      std.ReturnOp([])
+
+  # CHECK: @transfer_read(%[[MEM:.*]]: memref<?x?xf32>, %[[IDX:.*]]: index,
+  # CHECK: %[[PAD:.*]]: f32, %[[MASK:.*]]: vector<2x3xi1>)
+  # CHECK: vector.transfer_read %[[MEM]][%[[IDX]], %[[IDX]]], %[[PAD]], %[[MASK]]
+  # CHECK: vector.transfer_read %[[MEM]][%[[IDX]], %[[IDX]]], %[[PAD]]
+  # CHECK-NOT: %[[MASK]]
+  print(module)

diff  --git a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
index d9ce2963a8f37..8babff25db07b 100644
--- a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
@@ -67,6 +67,7 @@ class {0}(_ods_ir.OpView):
 /// Each segment spec is either None (default) or an array of integers
 /// where:
 ///   1 = single element (expect non sequence operand/result)
+///   0 = optional element (expect a value or None)
 ///   -1 = operand/result is a sequence corresponding to a variadic
 constexpr const char *opClassSizedSegmentsTemplate = R"Py(
   _ODS_{0}_SEGMENTS = {1}
@@ -505,6 +506,9 @@ constexpr const char *singleResultAppendTemplate = "results.append({0})";
 ///   {0} is the field name.
 constexpr const char *optionalAppendOperandTemplate =
     "if {0} is not None: operands.append(_get_op_result_or_value({0}))";
+constexpr const char *optionalAppendAttrSizedOperandsTemplate =
+    "operands.append(_get_op_result_or_value({0}) if {0} is not None else "
+    "None)";
 constexpr const char *optionalAppendResultTemplate =
     "if {0} is not None: results.append({0})";
 
@@ -693,7 +697,11 @@ populateBuilderLinesOperand(const Operator &op,
     if (!element.isVariableLength()) {
       formatString = singleOperandAppendTemplate;
     } else if (element.isOptional()) {
-      formatString = optionalAppendOperandTemplate;
+      if (sizedSegments) {
+        formatString = optionalAppendAttrSizedOperandsTemplate;
+      } else {
+        formatString = optionalAppendOperandTemplate;
+      }
     } else {
       assert(element.isVariadic() && "unhandled element group type");
       // If emitting with sizedSegments, then we add the actual list-typed
@@ -882,10 +890,10 @@ static void emitSegmentSpec(
   std::string segmentSpec("[");
   for (int i = 0, e = getNumElements(op); i < e; ++i) {
     const NamedTypeConstraint &element = getElement(op, i);
-    if (element.isVariableLength()) {
-      segmentSpec.append("-1,");
-    } else if (element.isOptional()) {
+    if (element.isOptional()) {
       segmentSpec.append("0,");
+    } else if (element.isVariadic()) {
+      segmentSpec.append("-1,");
     } else {
       segmentSpec.append("1,");
     }


        


More information about the Mlir-commits mailing list