[Mlir-commits] [mlir] ace1d0a - [mlir][python] Normalize asm-printing IR behavior.

Stella Laurenzo llvmlistbot at llvm.org
Sun Nov 28 18:03:25 PST 2021


Author: Stella Laurenzo
Date: 2021-11-28T18:02:01-08:00
New Revision: ace1d0ad3dc43e28715cbe2f3e0a5a76578bda9f

URL: https://github.com/llvm/llvm-project/commit/ace1d0ad3dc43e28715cbe2f3e0a5a76578bda9f
DIFF: https://github.com/llvm/llvm-project/commit/ace1d0ad3dc43e28715cbe2f3e0a5a76578bda9f.diff

LOG: [mlir][python] Normalize asm-printing IR behavior.

While working on an integration, I found a lot of inconsistencies in IR printing and verification. It turns out that we were:
  * Only doing "soft fail" verification when printing an Operation, not when printing a Module.
  * Mishandling failed verification with binary=True IR printing (causing a TypeError when trying to pass a `str` to a `bytes`-based handle).
  * Giving callers no way to skip the built-in verification, even though for systematic integrations it is often desirable to control verification yourself so that you can explicitly handle errors.
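The binary failure mode can be reproduced with the standard library alone. In binary mode, `PyOperationBase::getAsm` (in the diff below) collects output in an `io.BytesIO`, so writing the verification banner as a `str` is rejected. This is a minimal sketch of the pre-patch write path; `write_banner_pre_patch` and `reproduce_binary_failure` are hypothetical names, not part of the bindings.

```python
import io

# Banner text as written by PyOperationBase::print in the diff below.
BANNER = "// Verification failed, printing generic form\n"


def write_banner_pre_patch(file_object):
    # Pre-patch behavior: the banner was always written as a str,
    # regardless of whether the handle was text or binary.
    file_object.write(BANNER)


def reproduce_binary_failure():
    """Return the TypeError message, or None if the write succeeded."""
    # getAsm(binary=True) collects output in an io.BytesIO.
    try:
        write_banner_pre_patch(io.BytesIO())
    except TypeError as exc:
        return str(exc)
    return None


print(reproduce_binary_failure())
```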

This patch:
  * Trues up the "soft fail" semantics by having `Module.__str__` delegate to `Operation.__str__` instead of using a shortcut implementation.
  * Fixes soft fail in the presence of binary=True (and adds an additional happy-path test case to make sure the binary functionality works).
  * Adds an `assume_verified` boolean flag to the `print`/`get_asm` methods which disables internal verification, presupposing that the caller has taken care of it.
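A pure-Python model of the patched control flow may make the three changes easier to follow. `ToyOperation` and `ToyModule` are hypothetical stand-ins, not the real `mlir.ir` classes: the condition in `ToyOperation.print` transliterates the patched branch of `PyOperationBase::print` in IRCore.cpp, and the placeholder strings stand in for the real MLIR assembly the bindings would emit.

```python
import io
import sys

BANNER = "// Verification failed, printing generic form\n"


class ToyOperation:
    """Hypothetical stand-in for mlir.ir.Operation (not the real binding)."""

    def __init__(self, valid):
        self.valid = valid

    def verify(self):
        return self.valid

    def print(self, file=None, binary=False, print_generic_op_form=False,
              assume_verified=False):
        if file is None:
            file = sys.stdout
        # Same condition as the patched C++: skip verification if the caller
        # vouches for the IR or has already asked for generic form.
        if not assume_verified and not print_generic_op_form and not self.verify():
            # Encode the banner to match the handle (the binary=True fix).
            file.write(BANNER.encode() if binary else BANNER)
            print_generic_op_form = True
        body = '"toy.op"() : () -> ()\n' if print_generic_op_form else "toy.op\n"
        file.write(body.encode() if binary else body)

    def get_asm(self, binary=False, print_generic_op_form=False,
                assume_verified=False):
        # Like getAsm: collect output in BytesIO or StringIO, then return it.
        buf = io.BytesIO() if binary else io.StringIO()
        self.print(buf, binary=binary,
                   print_generic_op_form=print_generic_op_form,
                   assume_verified=assume_verified)
        return buf.getvalue()

    def __str__(self):
        return self.get_asm()


class ToyModule:
    """Hypothetical stand-in for mlir.ir.Module."""

    def __init__(self, operation):
        self.operation = operation

    def __str__(self):
        # Post-patch: defer to the operation's __str__ instead of printing
        # directly, so modules get the same soft-fail verification.
        return str(self.operation)


bad = ToyModule(ToyOperation(valid=False))
print(str(bad), end="")
```

Routing `ToyModule.__str__` through the operation is what makes the soft-fail banner appear for modules as well; the `assume_verified` check comes first so a caller who has already verified pays for no second pass.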

It turns out that we had a number of tests which were generating illegal IR, but it wasn't being caught because they were printing the `Module` rather than the operation. All except two were trivially fixed:
  * linalg/ops.py : Had two tests that constructed a Matmul directly but incorrectly. Fixing them made them identical to the next two tests, so I just deleted them (there is no need to test only the verifier at this level).
  * linalg/opdsl/emit_structured_generic.py : Hand-coded conv and pooling tests appear to be using illegally shaped inputs/outputs, causing a verification failure. I just used the `assume_verified=` flag to restore the original behavior and left a TODO. I will get someone who owns that code to fix it properly in a followup (it would also be nice to break this file up into multiple test modules, as it is hard to tell exactly what is failing).

Notes to downstreams:
  * If, like some of our tests, you get verification failures after this patch, it is likely that your IR was always invalid and you will need to fix the root cause. To temporarily revert to the prior (broken) behavior, replace calls like `print(module)` with `print(module.operation.get_asm(assume_verified=True))`.

Differential Revision: https://reviews.llvm.org/D114680

Added: 
    

Modified: 
    mlir/lib/Bindings/Python/IRCore.cpp
    mlir/lib/Bindings/Python/IRModule.h
    mlir/test/python/dialects/builtin.py
    mlir/test/python/dialects/linalg/opdsl/emit_structured_generic.py
    mlir/test/python/dialects/linalg/ops.py
    mlir/test/python/dialects/shape.py
    mlir/test/python/dialects/std.py
    mlir/test/python/ir/module.py
    mlir/test/python/ir/operation.py

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 4c25fd4505b76..c70cfc5654f3b 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -93,6 +93,13 @@ static const char kOperationPrintDocstring[] =
   use_local_Scope: Whether to print in a way that is more optimized for
     multi-threaded access but may not be consistent with how the overall
     module prints.
+  assume_verified: By default, if not printing generic form, the verifier
+    will be run and if it fails, generic form will be printed with a comment
+    about failed verification. While a reasonable default for interactive use,
+    for systematic use, it is often better for the caller to verify explicitly
+    and report failures in a more robust fashion. Set this to True if doing this
+    in order to avoid running a redundant verification. If the IR is actually
+    invalid, behavior is undefined.
 )";
 
 static const char kOperationGetAsmDocstring[] =
@@ -828,14 +835,21 @@ void PyOperation::checkValid() const {
 void PyOperationBase::print(py::object fileObject, bool binary,
                             llvm::Optional<int64_t> largeElementsLimit,
                             bool enableDebugInfo, bool prettyDebugInfo,
-                            bool printGenericOpForm, bool useLocalScope) {
+                            bool printGenericOpForm, bool useLocalScope,
+                            bool assumeVerified) {
   PyOperation &operation = getOperation();
   operation.checkValid();
   if (fileObject.is_none())
     fileObject = py::module::import("sys").attr("stdout");
 
-  if (!printGenericOpForm && !mlirOperationVerify(operation)) {
-    fileObject.attr("write")("// Verification failed, printing generic form\n");
+  if (!assumeVerified && !printGenericOpForm &&
+      !mlirOperationVerify(operation)) {
+    std::string message("// Verification failed, printing generic form\n");
+    if (binary) {
+      fileObject.attr("write")(py::bytes(message));
+    } else {
+      fileObject.attr("write")(py::str(message));
+    }
     printGenericOpForm = true;
   }
 
@@ -857,8 +871,8 @@ void PyOperationBase::print(py::object fileObject, bool binary,
 py::object PyOperationBase::getAsm(bool binary,
                                    llvm::Optional<int64_t> largeElementsLimit,
                                    bool enableDebugInfo, bool prettyDebugInfo,
-                                   bool printGenericOpForm,
-                                   bool useLocalScope) {
+                                   bool printGenericOpForm, bool useLocalScope,
+                                   bool assumeVerified) {
   py::object fileObject;
   if (binary) {
     fileObject = py::module::import("io").attr("BytesIO")();
@@ -870,7 +884,8 @@ py::object PyOperationBase::getAsm(bool binary,
         /*enableDebugInfo=*/enableDebugInfo,
         /*prettyDebugInfo=*/prettyDebugInfo,
         /*printGenericOpForm=*/printGenericOpForm,
-        /*useLocalScope=*/useLocalScope);
+        /*useLocalScope=*/useLocalScope,
+        /*assumeVerified=*/assumeVerified);
 
   return fileObject.attr("getvalue")();
 }
@@ -2149,12 +2164,9 @@ void mlir::python::populateIRCore(py::module &m) {
           kDumpDocstring)
       .def(
           "__str__",
-          [](PyModule &self) {
-            MlirOperation operation = mlirModuleGetOperation(self.get());
-            PyPrintAccumulator printAccum;
-            mlirOperationPrint(operation, printAccum.getCallback(),
-                               printAccum.getUserData());
-            return printAccum.join();
+          [](py::object self) {
+            // Defer to the operation's __str__.
+            return self.attr("operation").attr("__str__")();
           },
           kOperationStrDunderDocstring);
 
@@ -2234,7 +2246,8 @@ void mlir::python::populateIRCore(py::module &m) {
                                /*enableDebugInfo=*/false,
                                /*prettyDebugInfo=*/false,
                                /*printGenericOpForm=*/false,
-                               /*useLocalScope=*/false);
+                               /*useLocalScope=*/false,
+                               /*assumeVerified=*/false);
           },
           "Returns the assembly form of the operation.")
       .def("print", &PyOperationBase::print,
@@ -2244,7 +2257,8 @@ void mlir::python::populateIRCore(py::module &m) {
            py::arg("enable_debug_info") = false,
            py::arg("pretty_debug_info") = false,
            py::arg("print_generic_op_form") = false,
-           py::arg("use_local_scope") = false, kOperationPrintDocstring)
+           py::arg("use_local_scope") = false,
+           py::arg("assume_verified") = false, kOperationPrintDocstring)
       .def("get_asm", &PyOperationBase::getAsm,
            // Careful: Lots of arguments must match up with get_asm method.
            py::arg("binary") = false,
@@ -2252,7 +2266,8 @@ void mlir::python::populateIRCore(py::module &m) {
            py::arg("enable_debug_info") = false,
            py::arg("pretty_debug_info") = false,
            py::arg("print_generic_op_form") = false,
-           py::arg("use_local_scope") = false, kOperationGetAsmDocstring)
+           py::arg("use_local_scope") = false,
+           py::arg("assume_verified") = false, kOperationGetAsmDocstring)
       .def(
           "verify",
           [](PyOperationBase &self) {

diff  --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h
index eb5c2385a165d..dc024a24793e0 100644
--- a/mlir/lib/Bindings/Python/IRModule.h
+++ b/mlir/lib/Bindings/Python/IRModule.h
@@ -394,11 +394,13 @@ class PyOperationBase {
   /// Implements the bound 'print' method and helps with others.
   void print(pybind11::object fileObject, bool binary,
              llvm::Optional<int64_t> largeElementsLimit, bool enableDebugInfo,
-             bool prettyDebugInfo, bool printGenericOpForm, bool useLocalScope);
+             bool prettyDebugInfo, bool printGenericOpForm, bool useLocalScope,
+             bool assumeVerified);
   pybind11::object getAsm(bool binary,
                           llvm::Optional<int64_t> largeElementsLimit,
                           bool enableDebugInfo, bool prettyDebugInfo,
-                          bool printGenericOpForm, bool useLocalScope);
+                          bool printGenericOpForm, bool useLocalScope,
+                          bool assumeVerified);
 
   /// Moves the operation before or after the other operation.
   void moveAfter(PyOperationBase &other);

diff  --git a/mlir/test/python/dialects/builtin.py b/mlir/test/python/dialects/builtin.py
index 8f3a041937b9d..7caf5b5d892ce 100644
--- a/mlir/test/python/dialects/builtin.py
+++ b/mlir/test/python/dialects/builtin.py
@@ -175,7 +175,8 @@ def testBuildFuncOp():
 # CHECK-LABEL: TEST: testFuncArgumentAccess
 @run
 def testFuncArgumentAccess():
-  with Context(), Location.unknown():
+  with Context() as ctx, Location.unknown():
+    ctx.allow_unregistered_dialects = True
     module = Module.create()
     f32 = F32Type.get()
     f64 = F64Type.get()
@@ -185,38 +186,38 @@ def testFuncArgumentAccess():
         std.ReturnOp(func.arguments)
       func.arg_attrs = ArrayAttr.get([
           DictAttr.get({
-              "foo": StringAttr.get("bar"),
-              "baz": UnitAttr.get()
+              "custom_dialect.foo": StringAttr.get("bar"),
+              "custom_dialect.baz": UnitAttr.get()
           }),
-          DictAttr.get({"qux": ArrayAttr.get([])})
+          DictAttr.get({"custom_dialect.qux": ArrayAttr.get([])})
       ])
       func.result_attrs = ArrayAttr.get([
-          DictAttr.get({"res1": FloatAttr.get(f32, 42.0)}),
-          DictAttr.get({"res2": FloatAttr.get(f64, 256.0)})
+          DictAttr.get({"custom_dialect.res1": FloatAttr.get(f32, 42.0)}),
+          DictAttr.get({"custom_dialect.res2": FloatAttr.get(f64, 256.0)})
       ])
 
       other = builtin.FuncOp("other_func", ([f32, f32], []))
       with InsertionPoint(other.add_entry_block()):
         std.ReturnOp([])
       other.arg_attrs = [
-          DictAttr.get({"foo": StringAttr.get("qux")}),
+          DictAttr.get({"custom_dialect.foo": StringAttr.get("qux")}),
           DictAttr.get()
       ]
 
-  # CHECK: [{baz, foo = "bar"}, {qux = []}]
+  # CHECK: [{custom_dialect.baz, custom_dialect.foo = "bar"}, {custom_dialect.qux = []}]
   print(func.arg_attrs)
 
-  # CHECK: [{res1 = 4.200000e+01 : f32}, {res2 = 2.560000e+02 : f64}]
+  # CHECK: [{custom_dialect.res1 = 4.200000e+01 : f32}, {custom_dialect.res2 = 2.560000e+02 : f64}]
   print(func.result_attrs)
 
   # CHECK: func @some_func(
-  # CHECK: %[[ARG0:.*]]: f32 {baz, foo = "bar"},
-  # CHECK: %[[ARG1:.*]]: f32 {qux = []}) ->
-  # CHECK: f32 {res1 = 4.200000e+01 : f32},
-  # CHECK: f32 {res2 = 2.560000e+02 : f64})
+  # CHECK: %[[ARG0:.*]]: f32 {custom_dialect.baz, custom_dialect.foo = "bar"},
+  # CHECK: %[[ARG1:.*]]: f32 {custom_dialect.qux = []}) ->
+  # CHECK: f32 {custom_dialect.res1 = 4.200000e+01 : f32},
+  # CHECK: f32 {custom_dialect.res2 = 2.560000e+02 : f64})
   # CHECK: return %[[ARG0]], %[[ARG1]] : f32, f32
   #
   # CHECK: func @other_func(
-  # CHECK: %{{.*}}: f32 {foo = "qux"},
+  # CHECK: %{{.*}}: f32 {custom_dialect.foo = "qux"},
   # CHECK: %{{.*}}: f32)
   print(module)

diff  --git a/mlir/test/python/dialects/linalg/opdsl/emit_structured_generic.py b/mlir/test/python/dialects/linalg/opdsl/emit_structured_generic.py
index d0c74270950ef..115c2272fbf5d 100644
--- a/mlir/test/python/dialects/linalg/opdsl/emit_structured_generic.py
+++ b/mlir/test/python/dialects/linalg/opdsl/emit_structured_generic.py
@@ -405,4 +405,7 @@ def test_non_default_op_name(input, init_result):
       return non_default_op_name(input, outs=[init_result])
 
 
-print(module)
+# TODO: Fix me! Conv and pooling ops above do not verify, which was uncovered
+# when switching to more robust module verification. For now, reverting to the
+# old behavior which does not verify on module print.
+print(module.operation.get_asm(assume_verified=True))

diff  --git a/mlir/test/python/dialects/linalg/ops.py b/mlir/test/python/dialects/linalg/ops.py
index e5b96c260eaad..4f9f138683b83 100644
--- a/mlir/test/python/dialects/linalg/ops.py
+++ b/mlir/test/python/dialects/linalg/ops.py
@@ -83,49 +83,6 @@ def fill_buffer(out):
   print(module)
 
 
-# CHECK-LABEL: TEST: testStructuredOpOnTensors
-@run
-def testStructuredOpOnTensors():
-  with Context() as ctx, Location.unknown():
-    module = Module.create()
-    f32 = F32Type.get()
-    tensor_type = RankedTensorType.get((2, 3, 4), f32)
-    with InsertionPoint(module.body):
-      func = builtin.FuncOp(
-          name="matmul_test",
-          type=FunctionType.get(
-              inputs=[tensor_type, tensor_type], results=[tensor_type]))
-      with InsertionPoint(func.add_entry_block()):
-        lhs, rhs = func.entry_block.arguments
-        result = linalg.MatmulOp([lhs, rhs], results=[tensor_type]).result
-        std.ReturnOp([result])
-
-  # CHECK: %[[R:.*]] = linalg.matmul ins(%arg0, %arg1 : tensor<2x3x4xf32>, tensor<2x3x4xf32>) -> tensor<2x3x4xf32>
-  print(module)
-
-
-# CHECK-LABEL: TEST: testStructuredOpOnBuffers
-@run
-def testStructuredOpOnBuffers():
-  with Context() as ctx, Location.unknown():
-    module = Module.create()
-    f32 = F32Type.get()
-    memref_type = MemRefType.get((2, 3, 4), f32)
-    with InsertionPoint(module.body):
-      func = builtin.FuncOp(
-          name="matmul_test",
-          type=FunctionType.get(
-              inputs=[memref_type, memref_type, memref_type], results=[]))
-      with InsertionPoint(func.add_entry_block()):
-        lhs, rhs, result = func.entry_block.arguments
-        # TODO: prperly hook up the region.
-        linalg.MatmulOp([lhs, rhs], outputs=[result])
-        std.ReturnOp([])
-
-  # CHECK: linalg.matmul ins(%arg0, %arg1 : memref<2x3x4xf32>, memref<2x3x4xf32>) outs(%arg2 : memref<2x3x4xf32>)
-  print(module)
-
-
 # CHECK-LABEL: TEST: testNamedStructuredOpCustomForm
 @run
 def testNamedStructuredOpCustomForm():

diff  --git a/mlir/test/python/dialects/shape.py b/mlir/test/python/dialects/shape.py
index 7c1c5d6f1cf9b..a798b85843dad 100644
--- a/mlir/test/python/dialects/shape.py
+++ b/mlir/test/python/dialects/shape.py
@@ -22,7 +22,8 @@ def testConstShape():
       @builtin.FuncOp.from_py_func(
           RankedTensorType.get((12, -1), f32))
       def const_shape_tensor(arg):
-        return shape.ConstShapeOp(DenseElementsAttr.get(np.array([10, 20])))
+        return shape.ConstShapeOp(
+          DenseElementsAttr.get(np.array([10, 20]), type=IndexType.get()))
 
     # CHECK-LABEL: func @const_shape_tensor(%arg0: tensor<12x?xf32>)
     # CHECK: shape.const_shape [10, 20] : tensor<2xindex>

diff  --git a/mlir/test/python/dialects/std.py b/mlir/test/python/dialects/std.py
index f6e77ca6156ed..2a3b2df88e4b0 100644
--- a/mlir/test/python/dialects/std.py
+++ b/mlir/test/python/dialects/std.py
@@ -78,8 +78,11 @@ def testConstantIndexOp():
 @constructAndPrintInModule
 def testFunctionCalls():
   foo = builtin.FuncOp("foo", ([], []))
+  foo.sym_visibility = StringAttr.get("private")
   bar = builtin.FuncOp("bar", ([], [IndexType.get()]))
+  bar.sym_visibility = StringAttr.get("private")
   qux = builtin.FuncOp("qux", ([], [F32Type.get()]))
+  qux.sym_visibility = StringAttr.get("private")
 
   with InsertionPoint(builtin.FuncOp("caller", ([], [])).add_entry_block()):
     std.CallOp(foo, [])
@@ -88,9 +91,9 @@ def testFunctionCalls():
     std.ReturnOp([])
 
 
-# CHECK: func @foo()
-# CHECK: func @bar() -> index
-# CHECK: func @qux() -> f32
+# CHECK: func private @foo()
+# CHECK: func private @bar() -> index
+# CHECK: func private @qux() -> f32
 # CHECK: func @caller() {
 # CHECK:   call @foo() : () -> ()
 # CHECK:   %0 = call @bar() : () -> index

diff  --git a/mlir/test/python/ir/module.py b/mlir/test/python/ir/module.py
index abddc66eb47f8..76358eb434c3b 100644
--- a/mlir/test/python/ir/module.py
+++ b/mlir/test/python/ir/module.py
@@ -8,11 +8,13 @@ def run(f):
   f()
   gc.collect()
   assert Context._get_live_count() == 0
+  return f
 
 
 # Verify successful parse.
 # CHECK-LABEL: TEST: testParseSuccess
 # CHECK: module @successfulParse
+@run
 def testParseSuccess():
   ctx = Context()
   module = Module.parse(r"""module @successfulParse {}""", ctx)
@@ -23,12 +25,11 @@ def testParseSuccess():
   module.dump()  # Just outputs to stderr. Verifies that it functions.
   print(str(module))
 
-run(testParseSuccess)
-
 
 # Verify parse error.
 # CHECK-LABEL: TEST: testParseError
 # CHECK: testParseError: Unable to parse module assembly (see diagnostics)
+@run
 def testParseError():
   ctx = Context()
   try:
@@ -38,12 +39,11 @@ def testParseError():
   else:
     print("Exception not produced")
 
-run(testParseError)
-
 
 # Verify successful parse.
 # CHECK-LABEL: TEST: testCreateEmpty
 # CHECK: module {
+@run
 def testCreateEmpty():
   ctx = Context()
   loc = Location.unknown(ctx)
@@ -53,8 +53,6 @@ def testCreateEmpty():
   gc.collect()
   print(str(module))
 
-run(testCreateEmpty)
-
 
 # Verify round-trip of ASM that contains unicode.
 # Note that this does not test that the print path converts unicode properly
@@ -62,6 +60,7 @@ def testCreateEmpty():
 # CHECK-LABEL: TEST: testRoundtripUnicode
 # CHECK: func private @roundtripUnicode()
 # CHECK: foo = "\F0\9F\98\8A"
+@run
 def testRoundtripUnicode():
   ctx = Context()
   module = Module.parse(r"""
@@ -69,11 +68,28 @@ def testRoundtripUnicode():
   """, ctx)
   print(str(module))
 
-run(testRoundtripUnicode)
+
+# Verify round-trip of ASM that contains unicode.
+# Note that this does not test that the print path converts unicode properly
+# because MLIR asm always normalizes it to the hex encoding.
+# CHECK-LABEL: TEST: testRoundtripBinary
+# CHECK: func private @roundtripUnicode()
+# CHECK: foo = "\F0\9F\98\8A"
+@run
+def testRoundtripBinary():
+  with Context():
+    module = Module.parse(r"""
+      func private @roundtripUnicode() attributes { foo = "😊" }
+    """)
+    binary_asm = module.operation.get_asm(binary=True)
+    assert isinstance(binary_asm, bytes)
+    module = Module.parse(binary_asm)
+    print(module)
 
 
 # Tests that module.operation works and correctly interns instances.
 # CHECK-LABEL: TEST: testModuleOperation
+@run
 def testModuleOperation():
   ctx = Context()
   module = Module.parse(r"""module @successfulParse {}""", ctx)
@@ -101,10 +117,9 @@ def testModuleOperation():
   assert ctx._get_live_operation_count() == 0
   assert ctx._get_live_module_count() == 0
 
-run(testModuleOperation)
-
 
 # CHECK-LABEL: TEST: testModuleCapsule
+@run
 def testModuleCapsule():
   ctx = Context()
   module = Module.parse(r"""module @successfulParse {}""", ctx)
@@ -122,5 +137,3 @@ def testModuleCapsule():
   gc.collect()
   assert ctx._get_live_module_count() == 0
 
-
-run(testModuleCapsule)

diff  --git a/mlir/test/python/ir/operation.py b/mlir/test/python/ir/operation.py
index 8771ca046b8b8..133edc2e1aee5 100644
--- a/mlir/test/python/ir/operation.py
+++ b/mlir/test/python/ir/operation.py
@@ -630,21 +630,50 @@ def testSingleResultProperty():
   print(module.body.operations[2])
 
 
-# CHECK-LABEL: TEST: testPrintInvalidOperation
+def create_invalid_operation():
+  # This module has two region and is invalid verify that we fallback
+  # to the generic printer for safety.
+  op = Operation.create("builtin.module", regions=2)
+  op.regions[0].blocks.append()
+  return op
+
+# CHECK-LABEL: TEST: testInvalidOperationStrSoftFails
 @run
-def testPrintInvalidOperation():
+def testInvalidOperationStrSoftFails():
   ctx = Context()
   with Location.unknown(ctx):
-    module = Operation.create("builtin.module", regions=2)
-    # This module has two region and is invalid verify that we fallback
-    # to the generic printer for safety.
-    block = module.regions[0].blocks.append()
+    invalid_op = create_invalid_operation()
+    # Verify that we fallback to the generic printer for safety.
     # CHECK: // Verification failed, printing generic form
     # CHECK: "builtin.module"() ( {
     # CHECK: }) : () -> ()
-    print(module)
+    print(invalid_op)
     # CHECK: .verify = False
-    print(f".verify = {module.operation.verify()}")
+    print(f".verify = {invalid_op.operation.verify()}")
+
+
+# CHECK-LABEL: TEST: testInvalidModuleStrSoftFails
+@run
+def testInvalidModuleStrSoftFails():
+  ctx = Context()
+  with Location.unknown(ctx):
+    module = Module.create()
+    with InsertionPoint(module.body):
+      invalid_op = create_invalid_operation()
+    # Verify that we fallback to the generic printer for safety.
+    # CHECK: // Verification failed, printing generic form
+    print(module)
+
+
+# CHECK-LABEL: TEST: testInvalidOperationGetAsmBinarySoftFails
+@run
+def testInvalidOperationGetAsmBinarySoftFails():
+  ctx = Context()
+  with Location.unknown(ctx):
+    invalid_op = create_invalid_operation()
+    # Verify that we fallback to the generic printer for safety.
+    # CHECK: b'// Verification failed, printing generic form\n
+    print(invalid_op.get_asm(binary=True))
 
 
 # CHECK-LABEL: TEST: testCreateWithInvalidAttributes


        

