[Mlir-commits] [mlir] 9a9214f - [mlir] Add C and python API for is_registered_operation.
Stella Laurenzo
llvmlistbot at llvm.org
Tue Mar 30 22:57:10 PDT 2021
Author: Stella Laurenzo
Date: 2021-03-30T22:56:02-07:00
New Revision: 9a9214fa2562b397764193517fa540a3dcbfd5a1
URL: https://github.com/llvm/llvm-project/commit/9a9214fa2562b397764193517fa540a3dcbfd5a1
DIFF: https://github.com/llvm/llvm-project/commit/9a9214fa2562b397764193517fa540a3dcbfd5a1.diff
LOG: [mlir] Add C and python API for is_registered_operation.
* Suggested to be broken out of D99578
Differential Revision: https://reviews.llvm.org/D99638
Added:
Modified:
mlir/include/mlir-c/IR.h
mlir/lib/Bindings/Python/IRCore.cpp
mlir/lib/CAPI/IR/IR.cpp
mlir/test/Bindings/Python/dialects.py
mlir/test/CAPI/ir.c
Removed:
################################################################################
diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h
index d807cd46dd58..048bd46679db 100644
--- a/mlir/include/mlir-c/IR.h
+++ b/mlir/include/mlir-c/IR.h
@@ -119,6 +119,13 @@ mlirContextGetNumLoadedDialects(MlirContext context);
MLIR_CAPI_EXPORTED MlirDialect mlirContextGetOrLoadDialect(MlirContext context,
MlirStringRef name);
+/// Returns whether the given fully-qualified operation (i.e.
+/// 'dialect.operation') is registered with the context. This will return true
+/// if the dialect is loaded and the operation is registered within the
+/// dialect.
+MLIR_CAPI_EXPORTED bool mlirContextIsRegisteredOperation(MlirContext context,
+ MlirStringRef name);
+
//===----------------------------------------------------------------------===//
// Dialect API.
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 0a4c5fcb40c3..5046eedb1194 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -1752,7 +1752,12 @@ void mlir::python::populateIRCore(py::module &m) {
},
[](PyMlirContext &self, bool value) {
mlirContextSetAllowUnregisteredDialects(self.get(), value);
- });
+ })
+ .def("is_registered_operation",
+ [](PyMlirContext &self, std::string &name) {
+ return mlirContextIsRegisteredOperation(
+ self.get(), MlirStringRef{name.data(), name.size()});
+ });
//----------------------------------------------------------------------------
// Mapping of PyDialectDescriptor
diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp
index 67032a4b5540..14cde9633f52 100644
--- a/mlir/lib/CAPI/IR/IR.cpp
+++ b/mlir/lib/CAPI/IR/IR.cpp
@@ -60,6 +60,10 @@ MlirDialect mlirContextGetOrLoadDialect(MlirContext context,
return wrap(unwrap(context)->getOrLoadDialect(unwrap(name)));
}
+bool mlirContextIsRegisteredOperation(MlirContext context, MlirStringRef name) {
+ return unwrap(context)->isOperationRegistered(unwrap(name));
+}
+
//===----------------------------------------------------------------------===//
// Dialect API.
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Bindings/Python/dialects.py b/mlir/test/Bindings/Python/dialects.py
index 41f4239e2b66..d5f5bee7f4b0 100644
--- a/mlir/test/Bindings/Python/dialects.py
+++ b/mlir/test/Bindings/Python/dialects.py
@@ -3,14 +3,17 @@
import gc
from mlir.ir import *
+
def run(f):
print("\nTEST:", f.__name__)
f()
gc.collect()
assert Context._get_live_count() == 0
+ return f
# CHECK-LABEL: TEST: testDialectDescriptor
+@run
def testDialectDescriptor():
ctx = Context()
d = ctx.get_dialect_descriptor("std")
@@ -25,10 +28,9 @@ def testDialectDescriptor():
else:
assert False, "Expected exception"
-run(testDialectDescriptor)
-
# CHECK-LABEL: TEST: testUserDialectClass
+@run
def testUserDialectClass():
ctx = Context()
# Access using attribute.
@@ -60,14 +62,14 @@ def testUserDialectClass():
# CHECK: <Dialect (class mlir.dialects._std_ops_gen._Dialect)>
print(d)
-run(testUserDialectClass)
-
# CHECK-LABEL: TEST: testCustomOpView
# This test uses the standard dialect AddFOp as an example of a user op.
# TODO: Op creation and access is still quite verbose: simplify this test as
# additional capabilities come online.
+@run
def testCustomOpView():
+
def createInput():
op = Operation.create("pytest_dummy.intinput", results=[f32])
# TODO: Auto result cast from operation
@@ -95,4 +97,12 @@ def createInput():
m.operation.print()
-run(testCustomOpView)
+# CHECK-LABEL: TEST: testIsRegisteredOperation
+@run
+def testIsRegisteredOperation():
+ ctx = Context()
+
+ # CHECK: std.cond_br: True
+ print(f"std.cond_br: {ctx.is_registered_operation('std.cond_br')}")
+ # CHECK: std.not_existing: False
+ print(f"std.not_existing: {ctx.is_registered_operation('std.not_existing')}")
diff --git a/mlir/test/CAPI/ir.c b/mlir/test/CAPI/ir.c
index 40ef39b19d26..5ce496c8a0e2 100644
--- a/mlir/test/CAPI/ir.c
+++ b/mlir/test/CAPI/ir.c
@@ -1442,6 +1442,22 @@ int registerOnlyStd() {
fprintf(stderr, "@registration\n");
// CHECK-LABEL: @registration
+ // CHECK: std.cond_br is_registered: 1
+ fprintf(stderr, "std.cond_br is_registered: %d\n",
+ mlirContextIsRegisteredOperation(
+ ctx, mlirStringRefCreateFromCString("std.cond_br")));
+
+ // CHECK: std.not_existing_op is_registered: 0
+ fprintf(stderr, "std.not_existing_op is_registered: %d\n",
+ mlirContextIsRegisteredOperation(
+ ctx, mlirStringRefCreateFromCString("std.not_existing_op")));
+
+ // CHECK: not_existing_dialect.not_existing_op is_registered: 0
+ fprintf(stderr, "not_existing_dialect.not_existing_op is_registered: %d\n",
+ mlirContextIsRegisteredOperation(
+ ctx, mlirStringRefCreateFromCString(
+ "not_existing_dialect.not_existing_op")));
+
return 0;
}
More information about the Mlir-commits
mailing list