[Mlir-commits] [mlir] 4aa2171 - [mlir][CAPI] Attribute set/remove on operations.

Stella Laurenzo llvmlistbot at llvm.org
Wed Oct 7 10:06:45 PDT 2020


Author: Stella Laurenzo
Date: 2020-10-07T10:03:23-07:00
New Revision: 4aa217160e5f06a96c6effc4950c3b402374de58

URL: https://github.com/llvm/llvm-project/commit/4aa217160e5f06a96c6effc4950c3b402374de58
DIFF: https://github.com/llvm/llvm-project/commit/4aa217160e5f06a96c6effc4950c3b402374de58.diff

LOG: [mlir][CAPI] Attribute set/remove on operations.

* New functions: mlirOperationSetAttributeByName, mlirOperationRemoveAttributeByName
* Also adds the missing *IsNull checks and converts all of them to the "static inline" form, which makes them transparent (not opaque) to callers and keeps them out of the ABI (which is desirable).
* Changes needed to resolve TODOs in npcomp PyTorch capture.

Differential Revision: https://reviews.llvm.org/D88946

Added: 
    

Modified: 
    mlir/include/mlir-c/IR.h
    mlir/lib/CAPI/IR/IR.cpp
    mlir/test/CAPI/ir.c

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h
index c751da804097..b2a17869e2b3 100644
--- a/mlir/include/mlir-c/IR.h
+++ b/mlir/include/mlir-c/IR.h
@@ -92,7 +92,9 @@ MlirContext mlirContextCreate();
 int mlirContextEqual(MlirContext ctx1, MlirContext ctx2);
 
 /** Checks whether a context is null. */
-inline int mlirContextIsNull(MlirContext context) { return !context.ptr; }
+static inline int mlirContextIsNull(MlirContext context) {
+  return !context.ptr;
+}
 
 /** Takes an MLIR context owned by the caller and destroys it. */
 void mlirContextDestroy(MlirContext context);
@@ -127,7 +129,9 @@ MlirDialect mlirContextGetOrLoadDialect(MlirContext context,
 MlirContext mlirDialectGetContext(MlirDialect dialect);
 
 /** Checks if the dialect is null. */
-int mlirDialectIsNull(MlirDialect dialect);
+static inline int mlirDialectIsNull(MlirDialect dialect) {
+  return !dialect.ptr;
+}
 
 /** Checks if two dialects that belong to the same context are equal. Dialects
  * from different contexts will not compare equal. */
@@ -171,7 +175,7 @@ MlirModule mlirModuleCreateParse(MlirContext context, const char *module);
 MlirContext mlirModuleGetContext(MlirModule module);
 
 /** Checks whether a module is null. */
-inline int mlirModuleIsNull(MlirModule module) { return !module.ptr; }
+static inline int mlirModuleIsNull(MlirModule module) { return !module.ptr; }
 
 /** Takes a module owned by the caller and deletes it. */
 void mlirModuleDestroy(MlirModule module);
@@ -235,7 +239,7 @@ MlirOperation mlirOperationCreate(const MlirOperationState *state);
 void mlirOperationDestroy(MlirOperation op);
 
 /** Checks whether the underlying operation is null. */
-int mlirOperationIsNull(MlirOperation op);
+static inline int mlirOperationIsNull(MlirOperation op) { return !op.ptr; }
 
 /** Returns the number of regions attached to the given operation. */
 intptr_t mlirOperationGetNumRegions(MlirOperation op);
@@ -275,6 +279,15 @@ MlirNamedAttribute mlirOperationGetAttribute(MlirOperation op, intptr_t pos);
 MlirAttribute mlirOperationGetAttributeByName(MlirOperation op,
                                               const char *name);
 
+/** Sets an attribute by name, replacing the existing if it exists or
+ * adding a new one otherwise. */
+void mlirOperationSetAttributeByName(MlirOperation op, const char *name,
+                                     MlirAttribute attr);
+
+/** Removes an attribute by name. Returns 0 if the attribute was not found
+ * and !0 if removed. */
+int mlirOperationRemoveAttributeByName(MlirOperation op, const char *name);
+
 /** Prints an operation by sending chunks of the string representation and
  * forwarding `userData to `callback`. Note that the callback may be called
  * several times with consecutive chunks of the string. */
@@ -295,7 +308,7 @@ MlirRegion mlirRegionCreate();
 void mlirRegionDestroy(MlirRegion region);
 
 /** Checks whether a region is null. */
-int mlirRegionIsNull(MlirRegion region);
+static inline int mlirRegionIsNull(MlirRegion region) { return !region.ptr; }
 
 /** Gets the first block in the region. */
 MlirBlock mlirRegionGetFirstBlock(MlirRegion region);
@@ -333,7 +346,7 @@ MlirBlock mlirBlockCreate(intptr_t nArgs, MlirType *args);
 void mlirBlockDestroy(MlirBlock block);
 
 /** Checks whether a block is null. */
-int mlirBlockIsNull(MlirBlock block);
+static inline int mlirBlockIsNull(MlirBlock block) { return !block.ptr; }
 
 /** Returns the block immediately following the given block in its parent
  * region. */
@@ -381,6 +394,9 @@ void mlirBlockPrint(MlirBlock block, MlirStringCallback callback,
 /* Value API.                                                                 */
 /*============================================================================*/
 
+/** Returns whether the value is null. */
+static inline int mlirValueIsNull(MlirValue value) { return !value.ptr; }
+
 /** Returns the type of the value. */
 MlirType mlirValueGetType(MlirValue value);
 
@@ -401,7 +417,7 @@ MlirType mlirTypeParseGet(MlirContext context, const char *type);
 MlirContext mlirTypeGetContext(MlirType type);
 
 /** Checks whether a type is null. */
-inline int mlirTypeIsNull(MlirType type) { return !type.ptr; }
+static inline int mlirTypeIsNull(MlirType type) { return !type.ptr; }
 
 /** Checks if two types are equal. */
 int mlirTypeEqual(MlirType t1, MlirType t2);
@@ -425,7 +441,7 @@ MlirAttribute mlirAttributeParseGet(MlirContext context, const char *attr);
 MlirContext mlirAttributeGetContext(MlirAttribute attribute);
 
 /** Checks whether an attribute is null. */
-inline int mlirAttributeIsNull(MlirAttribute attr) { return !attr.ptr; }
+static inline int mlirAttributeIsNull(MlirAttribute attr) { return !attr.ptr; }
 
 /** Checks if two attributes are equal. */
 int mlirAttributeEqual(MlirAttribute a1, MlirAttribute a2);

diff  --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp
index 359ee69708eb..45cd009bddc0 100644
--- a/mlir/lib/CAPI/IR/IR.cpp
+++ b/mlir/lib/CAPI/IR/IR.cpp
@@ -66,10 +66,6 @@ MlirContext mlirDialectGetContext(MlirDialect dialect) {
   return wrap(unwrap(dialect)->getContext());
 }
 
-int mlirDialectIsNull(MlirDialect dialect) {
-  return unwrap(dialect) == nullptr;
-}
-
 int mlirDialectEqual(MlirDialect dialect1, MlirDialect dialect2) {
   return unwrap(dialect1) == unwrap(dialect2);
 }
@@ -215,8 +211,6 @@ MlirOperation mlirOperationCreate(const MlirOperationState *state) {
 
 void mlirOperationDestroy(MlirOperation op) { unwrap(op)->erase(); }
 
-int mlirOperationIsNull(MlirOperation op) { return unwrap(op) == nullptr; }
-
 intptr_t mlirOperationGetNumRegions(MlirOperation op) {
   return static_cast<intptr_t>(unwrap(op)->getNumRegions());
 }
@@ -267,6 +261,16 @@ MlirAttribute mlirOperationGetAttributeByName(MlirOperation op,
   return wrap(unwrap(op)->getAttr(name));
 }
 
+void mlirOperationSetAttributeByName(MlirOperation op, const char *name,
+                                     MlirAttribute attr) {
+  unwrap(op)->setAttr(name, unwrap(attr));
+}
+
+int mlirOperationRemoveAttributeByName(MlirOperation op, const char *name) {
+  auto removeResult = unwrap(op)->removeAttr(name);
+  return removeResult == MutableDictionaryAttr::RemoveResult::Removed;
+}
+
 void mlirOperationPrint(MlirOperation op, MlirStringCallback callback,
                         void *userData) {
   detail::CallbackOstream stream(callback, userData);
@@ -328,8 +332,6 @@ void mlirRegionDestroy(MlirRegion region) {
   delete static_cast<Region *>(region.ptr);
 }
 
-int mlirRegionIsNull(MlirRegion region) { return unwrap(region) == nullptr; }
-
 /* ========================================================================== */
 /* Block API.                                                                 */
 /* ========================================================================== */
@@ -391,8 +393,6 @@ void mlirBlockInsertOwnedOperationBefore(MlirBlock block,
 
 void mlirBlockDestroy(MlirBlock block) { delete unwrap(block); }
 
-int mlirBlockIsNull(MlirBlock block) { return unwrap(block) == nullptr; }
-
 intptr_t mlirBlockGetNumArguments(MlirBlock block) {
   return static_cast<intptr_t>(unwrap(block)->getNumArguments());
 }

diff  --git a/mlir/test/CAPI/ir.c b/mlir/test/CAPI/ir.c
index 18c4e8b08559..8eab4ebb3858 100644
--- a/mlir/test/CAPI/ir.c
+++ b/mlir/test/CAPI/ir.c
@@ -216,7 +216,7 @@ static void printToStderr(const char *str, intptr_t len, void *userData) {
   fwrite(str, 1, len, stderr);
 }
 
-static void printFirstOfEach(MlirOperation operation) {
+static void printFirstOfEach(MlirContext ctx, MlirOperation operation) {
   // Assuming we are given a module, go to the first operation of the first
   // function.
   MlirRegion region = mlirOperationGetRegion(operation, 0);
@@ -227,24 +227,59 @@ static void printFirstOfEach(MlirOperation operation) {
   operation = mlirBlockGetFirstOperation(block);
 
   // In the module we created, the first operation of the first function is an
-  // "std.dim", which has an attribute an a single result that we can use to
+  // "std.dim", which has an attribute and a single result that we can use to
   // test the printing mechanism.
   mlirBlockPrint(block, printToStderr, NULL);
   fprintf(stderr, "\n");
+  fprintf(stderr, "First operation: ");
   mlirOperationPrint(operation, printToStderr, NULL);
   fprintf(stderr, "\n");
 
-  MlirNamedAttribute namedAttr = mlirOperationGetAttribute(operation, 0);
-  mlirAttributePrint(namedAttr.attribute, printToStderr, NULL);
+  // Get the attribute by index.
+  MlirNamedAttribute namedAttr0 = mlirOperationGetAttribute(operation, 0);
+  fprintf(stderr, "Get attr 0: ");
+  mlirAttributePrint(namedAttr0.attribute, printToStderr, NULL);
   fprintf(stderr, "\n");
 
+  // Now re-get the attribute by name.
+  MlirAttribute attr0ByName =
+      mlirOperationGetAttributeByName(operation, namedAttr0.name);
+  fprintf(stderr, "Get attr 0 by name: ");
+  mlirAttributePrint(attr0ByName, printToStderr, NULL);
+  fprintf(stderr, "\n");
+
+  // Get a non-existing attribute and assert that it is null (sanity).
+  fprintf(stderr, "does_not_exist is null: %d\n",
+          mlirAttributeIsNull(
+              mlirOperationGetAttributeByName(operation, "does_not_exist")));
+
+  // Get result 0 and its type.
   MlirValue value = mlirOperationGetResult(operation, 0);
+  fprintf(stderr, "Result 0: ");
   mlirValuePrint(value, printToStderr, NULL);
   fprintf(stderr, "\n");
+  fprintf(stderr, "Value is null: %d\n", mlirValueIsNull(value));
 
   MlirType type = mlirValueGetType(value);
+  fprintf(stderr, "Result 0 type: ");
   mlirTypePrint(type, printToStderr, NULL);
   fprintf(stderr, "\n");
+
+  // Set a custom attribute.
+  mlirOperationSetAttributeByName(operation, "custom_attr",
+                                  mlirBoolAttrGet(ctx, 1));
+  fprintf(stderr, "Op with set attr: ");
+  mlirOperationPrint(operation, printToStderr, NULL);
+  fprintf(stderr, "\n");
+
+  // Remove the attribute.
+  fprintf(stderr, "Remove attr: %d\n",
+          mlirOperationRemoveAttributeByName(operation, "custom_attr"));
+  fprintf(stderr, "Remove attr again: %d\n",
+          mlirOperationRemoveAttributeByName(operation, "custom_attr"));
+  fprintf(stderr, "Removed attr is null: %d\n",
+          mlirAttributeIsNull(
+              mlirOperationGetAttributeByName(operation, "custom_attr")));
 }
 
 /// Creates an operation with a region containing multiple blocks with
@@ -884,7 +919,7 @@ int main() {
   // CHECK: Number of values: 9
   // clang-format on
 
-  printFirstOfEach(module);
+  printFirstOfEach(ctx, module);
   // clang-format off
   // CHECK:   %[[C0:.*]] = constant 0 : index
   // CHECK:   %[[DIM:.*]] = dim %{{.*}}, %[[C0]] : memref<?xf32>
@@ -896,10 +931,17 @@ int main() {
   // CHECK:     store %[[SUM]], %{{.*}}[%[[I]]] : memref<?xf32>
   // CHECK:   }
   // CHECK: return
-  // CHECK: constant 0 : index
-  // CHECK: 0 : index
-  // CHECK: constant 0 : index
-  // CHECK: index
+  // CHECK: First operation: {{.*}} = constant 0 : index
+  // CHECK: Get attr 0: 0 : index
+  // CHECK: Get attr 0 by name: 0 : index
+  // CHECK: does_not_exist is null: 1
+  // CHECK: Result 0: {{.*}} = constant 0 : index
+  // CHECK: Value is null: 0
+  // CHECK: Result 0 type: index
+  // CHECK: Op with set attr: {{.*}} {custom_attr = true}
+  // CHECK: Remove attr: 1
+  // CHECK: Remove attr again: 0
+  // CHECK: Removed attr is null: 1
   // clang-format on
 
   mlirModuleDestroy(moduleOp);


        


More information about the Mlir-commits mailing list