[Mlir-commits] [mlir] 9a9214f - [mlir] Add C and python API for is_registered_operation.

Stella Laurenzo llvmlistbot at llvm.org
Tue Mar 30 22:57:10 PDT 2021


Author: Stella Laurenzo
Date: 2021-03-30T22:56:02-07:00
New Revision: 9a9214fa2562b397764193517fa540a3dcbfd5a1

URL: https://github.com/llvm/llvm-project/commit/9a9214fa2562b397764193517fa540a3dcbfd5a1
DIFF: https://github.com/llvm/llvm-project/commit/9a9214fa2562b397764193517fa540a3dcbfd5a1.diff

LOG: [mlir] Add C and python API for is_registered_operation.

* Suggested to be broken out of D99578

Differential Revision: https://reviews.llvm.org/D99638

Added: 
    

Modified: 
    mlir/include/mlir-c/IR.h
    mlir/lib/Bindings/Python/IRCore.cpp
    mlir/lib/CAPI/IR/IR.cpp
    mlir/test/Bindings/Python/dialects.py
    mlir/test/CAPI/ir.c

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h
index d807cd46dd58..048bd46679db 100644
--- a/mlir/include/mlir-c/IR.h
+++ b/mlir/include/mlir-c/IR.h
@@ -119,6 +119,13 @@ mlirContextGetNumLoadedDialects(MlirContext context);
 MLIR_CAPI_EXPORTED MlirDialect mlirContextGetOrLoadDialect(MlirContext context,
                                                            MlirStringRef name);
 
+/// Returns whether the given fully-qualified operation (i.e.
+/// 'dialect.operation') is registered with the context. This will return true
+/// if the dialect is loaded and the operation is registered within the
+/// dialect.
+MLIR_CAPI_EXPORTED bool mlirContextIsRegisteredOperation(MlirContext context,
+                                                         MlirStringRef name);
+
 //===----------------------------------------------------------------------===//
 // Dialect API.
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 0a4c5fcb40c3..5046eedb1194 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -1752,7 +1752,12 @@ void mlir::python::populateIRCore(py::module &m) {
           },
           [](PyMlirContext &self, bool value) {
             mlirContextSetAllowUnregisteredDialects(self.get(), value);
-          });
+          })
+      .def("is_registered_operation",
+           [](PyMlirContext &self, std::string &name) {
+             return mlirContextIsRegisteredOperation(
+                 self.get(), MlirStringRef{name.data(), name.size()});
+           });
 
   //----------------------------------------------------------------------------
   // Mapping of PyDialectDescriptor

diff  --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp
index 67032a4b5540..14cde9633f52 100644
--- a/mlir/lib/CAPI/IR/IR.cpp
+++ b/mlir/lib/CAPI/IR/IR.cpp
@@ -60,6 +60,10 @@ MlirDialect mlirContextGetOrLoadDialect(MlirContext context,
   return wrap(unwrap(context)->getOrLoadDialect(unwrap(name)));
 }
 
+bool mlirContextIsRegisteredOperation(MlirContext context, MlirStringRef name) {
+  return unwrap(context)->isOperationRegistered(unwrap(name));
+}
+
 //===----------------------------------------------------------------------===//
 // Dialect API.
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Bindings/Python/dialects.py b/mlir/test/Bindings/Python/dialects.py
index 41f4239e2b66..d5f5bee7f4b0 100644
--- a/mlir/test/Bindings/Python/dialects.py
+++ b/mlir/test/Bindings/Python/dialects.py
@@ -3,14 +3,17 @@
 import gc
 from mlir.ir import *
 
+
 def run(f):
   print("\nTEST:", f.__name__)
   f()
   gc.collect()
   assert Context._get_live_count() == 0
+  return f
 
 
 # CHECK-LABEL: TEST: testDialectDescriptor
+@run
 def testDialectDescriptor():
   ctx = Context()
   d = ctx.get_dialect_descriptor("std")
@@ -25,10 +28,9 @@ def testDialectDescriptor():
   else:
     assert False, "Expected exception"
 
-run(testDialectDescriptor)
-
 
 # CHECK-LABEL: TEST: testUserDialectClass
+@run
 def testUserDialectClass():
   ctx = Context()
   # Access using attribute.
@@ -60,14 +62,14 @@ def testUserDialectClass():
   # CHECK: <Dialect (class mlir.dialects._std_ops_gen._Dialect)>
   print(d)
 
-run(testUserDialectClass)
-
 
 # CHECK-LABEL: TEST: testCustomOpView
 # This test uses the standard dialect AddFOp as an example of a user op.
 # TODO: Op creation and access is still quite verbose: simplify this test as
 # additional capabilities come online.
+@run
 def testCustomOpView():
+
   def createInput():
     op = Operation.create("pytest_dummy.intinput", results=[f32])
     # TODO: Auto result cast from operation
@@ -95,4 +97,12 @@ def createInput():
   m.operation.print()
 
 
-run(testCustomOpView)
+# CHECK-LABEL: TEST: testIsRegisteredOperation
+@run
+def testIsRegisteredOperation():
+  ctx = Context()
+
+  # CHECK: std.cond_br: True
+  print(f"std.cond_br: {ctx.is_registered_operation('std.cond_br')}")
+  # CHECK: std.not_existing: False
+  print(f"std.not_existing: {ctx.is_registered_operation('std.not_existing')}")

diff  --git a/mlir/test/CAPI/ir.c b/mlir/test/CAPI/ir.c
index 40ef39b19d26..5ce496c8a0e2 100644
--- a/mlir/test/CAPI/ir.c
+++ b/mlir/test/CAPI/ir.c
@@ -1442,6 +1442,22 @@ int registerOnlyStd() {
   fprintf(stderr, "@registration\n");
   // CHECK-LABEL: @registration
 
+  // CHECK: std.cond_br is_registered: 1
+  fprintf(stderr, "std.cond_br is_registered: %d\n",
+          mlirContextIsRegisteredOperation(
+              ctx, mlirStringRefCreateFromCString("std.cond_br")));
+
+  // CHECK: std.not_existing_op is_registered: 0
+  fprintf(stderr, "std.not_existing_op is_registered: %d\n",
+          mlirContextIsRegisteredOperation(
+              ctx, mlirStringRefCreateFromCString("std.not_existing_op")));
+
+  // CHECK: not_existing_dialect.not_existing_op is_registered: 0
+  fprintf(stderr, "not_existing_dialect.not_existing_op is_registered: %d\n",
+          mlirContextIsRegisteredOperation(
+              ctx, mlirStringRefCreateFromCString(
+                       "not_existing_dialect.not_existing_op")));
+
   return 0;
 }
 


        


More information about the Mlir-commits mailing list