[Mlir-commits] [mlir] 5801885 - Defend early against operation created without a registered dialect

Mehdi Amini llvmlistbot at llvm.org
Wed Jul 14 20:03:18 PDT 2021


Author: Mehdi Amini
Date: 2021-07-15T03:02:52Z
New Revision: 58018858e887320e2432e2e00ace13273b8a1f29

URL: https://github.com/llvm/llvm-project/commit/58018858e887320e2432e2e00ace13273b8a1f29
DIFF: https://github.com/llvm/llvm-project/commit/58018858e887320e2432e2e00ace13273b8a1f29.diff

LOG: Defend early against operation created without a registered dialect

Reviewed By: rriddle

Differential Revision: https://reviews.llvm.org/D105961

Added: 
    

Modified: 
    mlir/lib/IR/BuiltinAttributes.cpp
    mlir/lib/IR/BuiltinTypes.cpp
    mlir/lib/IR/Operation.cpp
    mlir/lib/IR/Verifier.cpp
    mlir/lib/Parser/Parser.cpp
    mlir/test/CAPI/ir.c
    mlir/test/Dialect/PDL/invalid.mlir
    mlir/test/IR/invalid-module-op.mlir
    mlir/test/IR/invalid-unregistered.mlir
    mlir/test/lit.cfg.py
    mlir/unittests/Pass/PassManagerTest.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/IR/BuiltinAttributes.cpp b/mlir/lib/IR/BuiltinAttributes.cpp
index 763ab803e6042..85914be243b92 100644
--- a/mlir/lib/IR/BuiltinAttributes.cpp
+++ b/mlir/lib/IR/BuiltinAttributes.cpp
@@ -360,7 +360,7 @@ LogicalResult OpaqueAttr::verify(function_ref<InFlightDiagnostic()> emitError,
            << " attribute created with unregistered dialect. If this is "
               "intended, please call allowUnregisteredDialects() on the "
               "MLIRContext, or use -allow-unregistered-dialect with "
-              "mlir-opt";
+              "the MLIR opt tool used";
   }
 
   return success();

diff  --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp
index f350596384a90..5dfb14b60abf8 100644
--- a/mlir/lib/IR/BuiltinTypes.cpp
+++ b/mlir/lib/IR/BuiltinTypes.cpp
@@ -264,7 +264,7 @@ LogicalResult OpaqueType::verify(function_ref<InFlightDiagnostic()> emitError,
            << "` type created with unregistered dialect. If this is "
               "intended, please call allowUnregisteredDialects() on the "
               "MLIRContext, or use -allow-unregistered-dialect with "
-              "mlir-opt";
+              "the MLIR opt tool used";
   }
 
   return success();

diff  --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp
index cff686574efb9..97059ba39bafb 100644
--- a/mlir/lib/IR/Operation.cpp
+++ b/mlir/lib/IR/Operation.cpp
@@ -176,6 +176,14 @@ Operation::Operation(Location location, OperationName name, unsigned numResults,
       numRegions(numRegions), hasOperandStorage(hasOperandStorage), name(name),
       attrs(attributes) {
   assert(attributes && "unexpected null attribute dictionary");
+#ifndef NDEBUG
+  if (!getDialect() && !getContext()->allowsUnregisteredDialects())
+    llvm::report_fatal_error(
+        name.getStringRef() +
+        " created with unregistered dialect. If this is intended, please call "
+        "allowUnregisteredDialects() on the MLIRContext, or use "
+        "-allow-unregistered-dialect with the MLIR opt tool used");
+#endif
 }
 
 // Operations are deleted through the destroy() member because they are

diff  --git a/mlir/lib/IR/Verifier.cpp b/mlir/lib/IR/Verifier.cpp
index 42068677c1ced..084fd8c1fd569 100644
--- a/mlir/lib/IR/Verifier.cpp
+++ b/mlir/lib/IR/Verifier.cpp
@@ -230,7 +230,7 @@ LogicalResult OperationVerifier::verifyOperation(
              << "created with unregistered dialect. If this is "
                 "intended, please call allowUnregisteredDialects() on the "
                 "MLIRContext, or use -allow-unregistered-dialect with "
-                "mlir-opt";
+                "the MLIR opt tool used";
     }
     return success();
   }

diff  --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp
index 08d8dd9c8ac2d..3bf226013cc44 100644
--- a/mlir/lib/Parser/Parser.cpp
+++ b/mlir/lib/Parser/Parser.cpp
@@ -653,7 +653,7 @@ Value OperationParser::resolveSSAUse(SSAUseInfo useInfo, Type type) {
 
   // Otherwise, this is a forward reference.  Create a placeholder and remember
   // that we did so.
-  auto result = createForwardRefPlaceholder(useInfo.loc, type);
+  Value result = createForwardRefPlaceholder(useInfo.loc, type);
   entries[useInfo.number] = {result, useInfo.loc};
   return maybeRecordUse(result);
 }
@@ -730,7 +730,7 @@ Value OperationParser::createForwardRefPlaceholder(SMLoc loc, Type type) {
   // We create these placeholders as having an empty name, which we know
   // cannot be created through normal user input, allowing us to distinguish
   // them.
-  auto name = OperationName("placeholder", getContext());
+  auto name = OperationName("unrealized_conversion_cast", getContext());
   auto *op = Operation::create(
       getEncodedSourceLocation(loc), name, type, /*operands=*/{},
       /*attributes=*/llvm::None, /*successors=*/{}, /*numRegions=*/0);

diff  --git a/mlir/test/CAPI/ir.c b/mlir/test/CAPI/ir.c
index ac8a0dc266870..3bff3b8e5d752 100644
--- a/mlir/test/CAPI/ir.c
+++ b/mlir/test/CAPI/ir.c
@@ -479,6 +479,7 @@ static int constructAndTraverseIr(MlirContext ctx) {
 /// block/operation-relative API and their final order is checked.
 static void buildWithInsertionsAndPrint(MlirContext ctx) {
   MlirLocation loc = mlirLocationUnknownGet(ctx);
+  mlirContextSetAllowUnregisteredDialects(ctx, true);
 
   MlirRegion owningRegion = mlirRegionCreate();
   MlirBlock nullBlock = mlirRegionGetFirstBlock(owningRegion);
@@ -542,6 +543,7 @@ static void buildWithInsertionsAndPrint(MlirContext ctx) {
 
   mlirOperationDump(op);
   mlirOperationDestroy(op);
+  mlirContextSetAllowUnregisteredDialects(ctx, false);
   // clang-format off
   // CHECK-LABEL:  "insertion.order.test"
   // CHECK:      ^{{.*}}(%{{.*}}: i1
@@ -1561,6 +1563,8 @@ int testOperands() {
   // CHECK-LABEL: @testOperands
 
   MlirContext ctx = mlirContextCreate();
+  mlirRegisterAllDialects(ctx);
+  mlirContextGetOrLoadDialect(ctx, mlirStringRefCreateFromCString("test"));
   MlirLocation loc = mlirLocationUnknownGet(ctx);
   MlirType indexType = mlirIndexTypeGet(ctx);
 
@@ -1590,6 +1594,7 @@ int testOperands() {
   MlirValue constOneValue = mlirOperationGetResult(constOne, 0);
 
   // Create the operation under test.
+  mlirContextSetAllowUnregisteredDialects(ctx, true);
   MlirOperationState opState =
       mlirOperationStateGet(mlirStringRefCreateFromCString("dummy.op"), loc);
   MlirValue initialOperands[] = {constZeroValue};
@@ -1604,13 +1609,13 @@ int testOperands() {
   MlirValue opOperand = mlirOperationGetOperand(op, 0);
   fprintf(stderr, "Original operand: ");
   mlirValuePrint(opOperand, printToStderr, NULL);
-  // CHECK: Original operand: {{.+}} {value = 0 : index}
+  // CHECK: Original operand: {{.+}} constant 0 : index
 
   mlirOperationSetOperand(op, 0, constOneValue);
   opOperand = mlirOperationGetOperand(op, 0);
   fprintf(stderr, "Updated operand: ");
   mlirValuePrint(opOperand, printToStderr, NULL);
-  // CHECK: Updated operand: {{.+}} {value = 1 : index}
+  // CHECK: Updated operand: {{.+}} constant 1 : index
 
   mlirOperationDestroy(op);
   mlirOperationDestroy(constZero);
@@ -1626,6 +1631,8 @@ int testClone() {
   // CHECK-LABEL: @testClone
 
   MlirContext ctx = mlirContextCreate();
+  mlirRegisterAllDialects(ctx);
+  mlirContextGetOrLoadDialect(ctx, mlirStringRefCreateFromCString("std"));
   MlirLocation loc = mlirLocationUnknownGet(ctx);
   MlirType indexType = mlirIndexTypeGet(ctx);
   MlirStringRef valueStringRef =  mlirStringRefCreateFromCString("value");
@@ -1646,8 +1653,8 @@ int testClone() {
 
   mlirOperationPrint(constZero, printToStderr, NULL);
   mlirOperationPrint(constOne, printToStderr, NULL);
-  // CHECK: %0 = "std.constant"() {value = 0 : index} : () -> index
-  // CHECK: %0 = "std.constant"() {value = 1 : index} : () -> index
+  // CHECK: constant 0 : index
+  // CHECK: constant 1 : index
 
   return 0;
 }

diff  --git a/mlir/test/Dialect/PDL/invalid.mlir b/mlir/test/Dialect/PDL/invalid.mlir
index e371d84086704..009b622ef3d68 100644
--- a/mlir/test/Dialect/PDL/invalid.mlir
+++ b/mlir/test/Dialect/PDL/invalid.mlir
@@ -167,7 +167,7 @@ pdl.pattern : benefit(1) {
 // expected-error at below {{expected only `pdl` operations within the pattern body}}
 pdl.pattern : benefit(1) {
   // expected-note at below {{see non-`pdl` operation defined here}}
-  "foo.other_op"() : () -> ()
+  "test.foo.other_op"() : () -> ()
 
   %root = pdl.operation "foo.op"
   pdl.rewrite %root with "foo"

diff  --git a/mlir/test/IR/invalid-module-op.mlir b/mlir/test/IR/invalid-module-op.mlir
index 741a3a9b2dc94..d0de3db74d4f5 100644
--- a/mlir/test/IR/invalid-module-op.mlir
+++ b/mlir/test/IR/invalid-module-op.mlir
@@ -6,9 +6,9 @@ func @module_op() {
   // expected-error at +1 {{Operations with a 'SymbolTable' must have exactly one block}}
   module {
   ^bb1:
-    "module_terminator"() : () -> ()
+    "test.dummy"() : () -> ()
   ^bb2:
-    "module_terminator"() : () -> ()
+    "test.dummy"() : () -> ()
   }
   return
 }

diff  --git a/mlir/test/IR/invalid-unregistered.mlir b/mlir/test/IR/invalid-unregistered.mlir
index 37ac45ef6d2a0..30d4b7af88ed4 100644
--- a/mlir/test/IR/invalid-unregistered.mlir
+++ b/mlir/test/IR/invalid-unregistered.mlir
@@ -1,5 +1,7 @@
 // RUN: mlir-opt %s -split-input-file -verify-diagnostics
 
+// REQUIRES: noasserts
+
 // expected-error @below {{op created with unregistered dialect}}
 "unregistered_dialect.op"() : () -> ()
 

diff  --git a/mlir/test/lit.cfg.py b/mlir/test/lit.cfg.py
index dd38b8fec864f..a362645bb75c8 100644
--- a/mlir/test/lit.cfg.py
+++ b/mlir/test/lit.cfg.py
@@ -111,3 +111,8 @@
         # lib/Bindings/Python/CMakeLists.txt for where this is set up.
         os.path.join(config.llvm_obj_root, 'python'),
     ], append_path=True)
+
+if config.enable_assertions:
+    config.available_features.add('asserts')
+else:
+    config.available_features.add('noasserts')

diff  --git a/mlir/unittests/Pass/PassManagerTest.cpp b/mlir/unittests/Pass/PassManagerTest.cpp
index 6e4283d3a3e42..8db7e5e66cd8e 100644
--- a/mlir/unittests/Pass/PassManagerTest.cpp
+++ b/mlir/unittests/Pass/PassManagerTest.cpp
@@ -91,6 +91,7 @@ struct InvalidPass : Pass {
 
 TEST(PassManagerTest, InvalidPass) {
   MLIRContext context;
+  context.allowUnregisteredDialects();
 
   // Create a module
   OwningModuleRef module(ModuleOp::create(UnknownLoc::get(&context)));


        


More information about the Mlir-commits mailing list