[Mlir-commits] [mlir] 4aa2171 - [mlir][CAPI] Attribute set/remove on operations.
Stella Laurenzo
llvmlistbot at llvm.org
Wed Oct 7 10:06:45 PDT 2020
Author: Stella Laurenzo
Date: 2020-10-07T10:03:23-07:00
New Revision: 4aa217160e5f06a96c6effc4950c3b402374de58
URL: https://github.com/llvm/llvm-project/commit/4aa217160e5f06a96c6effc4950c3b402374de58
DIFF: https://github.com/llvm/llvm-project/commit/4aa217160e5f06a96c6effc4950c3b402374de58.diff
LOG: [mlir][CAPI] Attribute set/remove on operations.
* New functions: mlirOperationSetAttributeByName, mlirOperationRemoveAttributeByName
* Also adds some *IsNull checks and standardizes the rest to use "static inline" form, which makes them all non-opaque and not part of the ABI (which is desirable).
* Changes needed to resolve TODOs in npcomp PyTorch capture.
Differential Revision: https://reviews.llvm.org/D88946
Added:
Modified:
mlir/include/mlir-c/IR.h
mlir/lib/CAPI/IR/IR.cpp
mlir/test/CAPI/ir.c
Removed:
################################################################################
diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h
index c751da804097..b2a17869e2b3 100644
--- a/mlir/include/mlir-c/IR.h
+++ b/mlir/include/mlir-c/IR.h
@@ -92,7 +92,9 @@ MlirContext mlirContextCreate();
int mlirContextEqual(MlirContext ctx1, MlirContext ctx2);
/** Checks whether a context is null. */
-inline int mlirContextIsNull(MlirContext context) { return !context.ptr; }
+static inline int mlirContextIsNull(MlirContext context) {
+ return !context.ptr;
+}
/** Takes an MLIR context owned by the caller and destroys it. */
void mlirContextDestroy(MlirContext context);
@@ -127,7 +129,9 @@ MlirDialect mlirContextGetOrLoadDialect(MlirContext context,
MlirContext mlirDialectGetContext(MlirDialect dialect);
/** Checks if the dialect is null. */
-int mlirDialectIsNull(MlirDialect dialect);
+static inline int mlirDialectIsNull(MlirDialect dialect) {
+ return !dialect.ptr;
+}
/** Checks if two dialects that belong to the same context are equal. Dialects
 * from different contexts will not compare equal. */
@@ -171,7 +175,7 @@ MlirModule mlirModuleCreateParse(MlirContext context, const char *module);
MlirContext mlirModuleGetContext(MlirModule module);
/** Checks whether a module is null. */
-inline int mlirModuleIsNull(MlirModule module) { return !module.ptr; }
+static inline int mlirModuleIsNull(MlirModule module) { return !module.ptr; }
/** Takes a module owned by the caller and deletes it. */
void mlirModuleDestroy(MlirModule module);
@@ -235,7 +239,7 @@ MlirOperation mlirOperationCreate(const MlirOperationState *state);
void mlirOperationDestroy(MlirOperation op);
/** Checks whether the underlying operation is null. */
-int mlirOperationIsNull(MlirOperation op);
+static inline int mlirOperationIsNull(MlirOperation op) { return !op.ptr; }
/** Returns the number of regions attached to the given operation. */
intptr_t mlirOperationGetNumRegions(MlirOperation op);
@@ -275,6 +279,15 @@ MlirNamedAttribute mlirOperationGetAttribute(MlirOperation op, intptr_t pos);
MlirAttribute mlirOperationGetAttributeByName(MlirOperation op,
const char *name);
+/** Sets an attribute by name, replacing the existing if it exists or
+ * adding a new one otherwise. */
+void mlirOperationSetAttributeByName(MlirOperation op, const char *name,
+ MlirAttribute attr);
+
+/** Removes an attribute by name. Returns 0 if the attribute was not found
+ * and !0 if removed. */
+int mlirOperationRemoveAttributeByName(MlirOperation op, const char *name);
+
/** Prints an operation by sending chunks of the string representation and
* forwarding `userData to `callback`. Note that the callback may be called
* several times with consecutive chunks of the string. */
@@ -295,7 +308,7 @@ MlirRegion mlirRegionCreate();
void mlirRegionDestroy(MlirRegion region);
/** Checks whether a region is null. */
-int mlirRegionIsNull(MlirRegion region);
+static inline int mlirRegionIsNull(MlirRegion region) { return !region.ptr; }
/** Gets the first block in the region. */
MlirBlock mlirRegionGetFirstBlock(MlirRegion region);
@@ -333,7 +346,7 @@ MlirBlock mlirBlockCreate(intptr_t nArgs, MlirType *args);
void mlirBlockDestroy(MlirBlock block);
/** Checks whether a block is null. */
-int mlirBlockIsNull(MlirBlock block);
+static inline int mlirBlockIsNull(MlirBlock block) { return !block.ptr; }
/** Returns the block immediately following the given block in its parent
* region. */
@@ -381,6 +394,9 @@ void mlirBlockPrint(MlirBlock block, MlirStringCallback callback,
/* Value API. */
/*============================================================================*/
+/** Returns whether the value is null. */
+static inline int mlirValueIsNull(MlirValue value) { return !value.ptr; }
+
/** Returns the type of the value. */
MlirType mlirValueGetType(MlirValue value);
@@ -401,7 +417,7 @@ MlirType mlirTypeParseGet(MlirContext context, const char *type);
MlirContext mlirTypeGetContext(MlirType type);
/** Checks whether a type is null. */
-inline int mlirTypeIsNull(MlirType type) { return !type.ptr; }
+static inline int mlirTypeIsNull(MlirType type) { return !type.ptr; }
/** Checks if two types are equal. */
int mlirTypeEqual(MlirType t1, MlirType t2);
@@ -425,7 +441,7 @@ MlirAttribute mlirAttributeParseGet(MlirContext context, const char *attr);
MlirContext mlirAttributeGetContext(MlirAttribute attribute);
/** Checks whether an attribute is null. */
-inline int mlirAttributeIsNull(MlirAttribute attr) { return !attr.ptr; }
+static inline int mlirAttributeIsNull(MlirAttribute attr) { return !attr.ptr; }
/** Checks if two attributes are equal. */
int mlirAttributeEqual(MlirAttribute a1, MlirAttribute a2);
diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp
index 359ee69708eb..45cd009bddc0 100644
--- a/mlir/lib/CAPI/IR/IR.cpp
+++ b/mlir/lib/CAPI/IR/IR.cpp
@@ -66,10 +66,6 @@ MlirContext mlirDialectGetContext(MlirDialect dialect) {
return wrap(unwrap(dialect)->getContext());
}
-int mlirDialectIsNull(MlirDialect dialect) {
- return unwrap(dialect) == nullptr;
-}
-
int mlirDialectEqual(MlirDialect dialect1, MlirDialect dialect2) {
return unwrap(dialect1) == unwrap(dialect2);
}
@@ -215,8 +211,6 @@ MlirOperation mlirOperationCreate(const MlirOperationState *state) {
void mlirOperationDestroy(MlirOperation op) { unwrap(op)->erase(); }
-int mlirOperationIsNull(MlirOperation op) { return unwrap(op) == nullptr; }
-
intptr_t mlirOperationGetNumRegions(MlirOperation op) {
return static_cast<intptr_t>(unwrap(op)->getNumRegions());
}
@@ -267,6 +261,16 @@ MlirAttribute mlirOperationGetAttributeByName(MlirOperation op,
return wrap(unwrap(op)->getAttr(name));
}
+void mlirOperationSetAttributeByName(MlirOperation op, const char *name,
+ MlirAttribute attr) {
+ unwrap(op)->setAttr(name, unwrap(attr));
+}
+
+int mlirOperationRemoveAttributeByName(MlirOperation op, const char *name) {
+ auto removeResult = unwrap(op)->removeAttr(name);
+ return removeResult == MutableDictionaryAttr::RemoveResult::Removed;
+}
+
void mlirOperationPrint(MlirOperation op, MlirStringCallback callback,
void *userData) {
detail::CallbackOstream stream(callback, userData);
@@ -328,8 +332,6 @@ void mlirRegionDestroy(MlirRegion region) {
delete static_cast<Region *>(region.ptr);
}
-int mlirRegionIsNull(MlirRegion region) { return unwrap(region) == nullptr; }
-
/* ========================================================================== */
/* Block API. */
/* ========================================================================== */
@@ -391,8 +393,6 @@ void mlirBlockInsertOwnedOperationBefore(MlirBlock block,
void mlirBlockDestroy(MlirBlock block) { delete unwrap(block); }
-int mlirBlockIsNull(MlirBlock block) { return unwrap(block) == nullptr; }
-
intptr_t mlirBlockGetNumArguments(MlirBlock block) {
return static_cast<intptr_t>(unwrap(block)->getNumArguments());
}
diff --git a/mlir/test/CAPI/ir.c b/mlir/test/CAPI/ir.c
index 18c4e8b08559..8eab4ebb3858 100644
--- a/mlir/test/CAPI/ir.c
+++ b/mlir/test/CAPI/ir.c
@@ -216,7 +216,7 @@ static void printToStderr(const char *str, intptr_t len, void *userData) {
fwrite(str, 1, len, stderr);
}
-static void printFirstOfEach(MlirOperation operation) {
+static void printFirstOfEach(MlirContext ctx, MlirOperation operation) {
// Assuming we are given a module, go to the first operation of the first
// function.
MlirRegion region = mlirOperationGetRegion(operation, 0);
@@ -227,24 +227,59 @@ static void printFirstOfEach(MlirOperation operation) {
operation = mlirBlockGetFirstOperation(block);
// In the module we created, the first operation of the first function is an
- // "std.dim", which has an attribute an a single result that we can use to
+ // "std.dim", which has an attribute and a single result that we can use to
// test the printing mechanism.
mlirBlockPrint(block, printToStderr, NULL);
fprintf(stderr, "\n");
+ fprintf(stderr, "First operation: ");
mlirOperationPrint(operation, printToStderr, NULL);
fprintf(stderr, "\n");
- MlirNamedAttribute namedAttr = mlirOperationGetAttribute(operation, 0);
- mlirAttributePrint(namedAttr.attribute, printToStderr, NULL);
+ // Get the attribute by index.
+ MlirNamedAttribute namedAttr0 = mlirOperationGetAttribute(operation, 0);
+ fprintf(stderr, "Get attr 0: ");
+ mlirAttributePrint(namedAttr0.attribute, printToStderr, NULL);
fprintf(stderr, "\n");
+ // Now re-get the attribute by name.
+ MlirAttribute attr0ByName =
+ mlirOperationGetAttributeByName(operation, namedAttr0.name);
+ fprintf(stderr, "Get attr 0 by name: ");
+ mlirAttributePrint(attr0ByName, printToStderr, NULL);
+ fprintf(stderr, "\n");
+
+ // Get a non-existing attribute and assert that it is null (sanity).
+ fprintf(stderr, "does_not_exist is null: %d\n",
+ mlirAttributeIsNull(
+ mlirOperationGetAttributeByName(operation, "does_not_exist")));
+
+ // Get result 0 and its type.
MlirValue value = mlirOperationGetResult(operation, 0);
+ fprintf(stderr, "Result 0: ");
mlirValuePrint(value, printToStderr, NULL);
fprintf(stderr, "\n");
+ fprintf(stderr, "Value is null: %d\n", mlirValueIsNull(value));
MlirType type = mlirValueGetType(value);
+ fprintf(stderr, "Result 0 type: ");
mlirTypePrint(type, printToStderr, NULL);
fprintf(stderr, "\n");
+
+ // Set a custom attribute.
+ mlirOperationSetAttributeByName(operation, "custom_attr",
+ mlirBoolAttrGet(ctx, 1));
+ fprintf(stderr, "Op with set attr: ");
+ mlirOperationPrint(operation, printToStderr, NULL);
+ fprintf(stderr, "\n");
+
+ // Remove the attribute.
+ fprintf(stderr, "Remove attr: %d\n",
+ mlirOperationRemoveAttributeByName(operation, "custom_attr"));
+ fprintf(stderr, "Remove attr again: %d\n",
+ mlirOperationRemoveAttributeByName(operation, "custom_attr"));
+ fprintf(stderr, "Removed attr is null: %d\n",
+ mlirAttributeIsNull(
+ mlirOperationGetAttributeByName(operation, "custom_attr")));
}
/// Creates an operation with a region containing multiple blocks with
@@ -884,7 +919,7 @@ int main() {
// CHECK: Number of values: 9
// clang-format on
- printFirstOfEach(module);
+ printFirstOfEach(ctx, module);
// clang-format off
// CHECK: %[[C0:.*]] = constant 0 : index
// CHECK: %[[DIM:.*]] = dim %{{.*}}, %[[C0]] : memref<?xf32>
@@ -896,10 +931,17 @@ int main() {
// CHECK: store %[[SUM]], %{{.*}}[%[[I]]] : memref<?xf32>
// CHECK: }
// CHECK: return
- // CHECK: constant 0 : index
- // CHECK: 0 : index
- // CHECK: constant 0 : index
- // CHECK: index
+ // CHECK: First operation: {{.*}} = constant 0 : index
+ // CHECK: Get attr 0: 0 : index
+ // CHECK: Get attr 0 by name: 0 : index
+ // CHECK: does_not_exist is null: 1
+ // CHECK: Result 0: {{.*}} = constant 0 : index
+ // CHECK: Value is null: 0
+ // CHECK: Result 0 type: index
+ // CHECK: Op with set attr: {{.*}} {custom_attr = true}
+ // CHECK: Remove attr: 1
+ // CHECK: Remove attr again: 0
+ // CHECK: Removed attr is null: 1
// clang-format on
mlirModuleDestroy(moduleOp);
More information about the Mlir-commits
mailing list