[Mlir-commits] [mlir] 39613c2 - [mlir] Expose Value hierarchy to C API

Alex Zinenko llvmlistbot at llvm.org
Tue Oct 20 00:39:18 PDT 2020


Author: Alex Zinenko
Date: 2020-10-20T09:39:08+02:00
New Revision: 39613c2cbc8f11ff6246211385134f0a548b5b57

URL: https://github.com/llvm/llvm-project/commit/39613c2cbc8f11ff6246211385134f0a548b5b57
DIFF: https://github.com/llvm/llvm-project/commit/39613c2cbc8f11ff6246211385134f0a548b5b57.diff

LOG: [mlir] Expose Value hierarchy to C API

The Value hierarchy consists of BlockArgument and OpResult, both of which
derive from Value. Introduce IsA functions and functions specific to each class,
similarly to other class hierarchies. Also, introduce functions for
pointer-comparison of Block and Operation that are necessary for testing and
are generally useful.

Reviewed By: stellaraccident, mehdi_amini

Differential Revision: https://reviews.llvm.org/D89714

Added: 
    

Modified: 
    mlir/include/mlir-c/IR.h
    mlir/lib/CAPI/IR/IR.cpp
    mlir/test/CAPI/ir.c

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h
index a00b96119298..816123472647 100644
--- a/mlir/include/mlir-c/IR.h
+++ b/mlir/include/mlir-c/IR.h
@@ -241,6 +241,10 @@ void mlirOperationDestroy(MlirOperation op);
 /** Checks whether the underlying operation is null. */
 static inline int mlirOperationIsNull(MlirOperation op) { return !op.ptr; }
 
+/** Checks whether two operation handles point to the same operation. This does
+ * not perform deep comparison. */
+int mlirOperationEqual(MlirOperation op, MlirOperation other);
+
 /** Returns the number of regions attached to the given operation. */
 intptr_t mlirOperationGetNumRegions(MlirOperation op);
 
@@ -348,6 +352,10 @@ void mlirBlockDestroy(MlirBlock block);
 /** Checks whether a block is null. */
 static inline int mlirBlockIsNull(MlirBlock block) { return !block.ptr; }
 
+/** Checks whether two block handles point to the same block. This does not
+ * perform deep comparison. */
+int mlirBlockEqual(MlirBlock block, MlirBlock other);
+
 /** Returns the block immediately following the given block in its parent
  * region. */
 MlirBlock mlirBlockGetNextInRegion(MlirBlock block);
@@ -397,6 +405,30 @@ void mlirBlockPrint(MlirBlock block, MlirStringCallback callback,
 /** Returns whether the value is null. */
 static inline int mlirValueIsNull(MlirValue value) { return !value.ptr; }
 
+/** Returns 1 if the value is a block argument, 0 otherwise. */
+int mlirValueIsABlockArgument(MlirValue value);
+
+/** Returns 1 if the value is an operation result, 0 otherwise. */
+int mlirValueIsAOpResult(MlirValue value);
+
+/** Returns the block in which this value is defined as an argument. Asserts if
+ * the value is not a block argument. */
+MlirBlock mlirBlockArgumentGetOwner(MlirValue value);
+
+/** Returns the position of the value in the argument list of its block. */
+intptr_t mlirBlockArgumentGetArgNumber(MlirValue value);
+
+/** Sets the type of the block argument to the given type. */
+void mlirBlockArgumentSetType(MlirValue value, MlirType type);
+
+/** Returns an operation that produced this value as its result. Asserts if the
+ * value is not an op result. */
+MlirOperation mlirOpResultGetOwner(MlirValue value);
+
+/** Returns the position of the value in the list of results of the operation
+ * that produced it. */
+intptr_t mlirOpResultGetResultNumber(MlirValue value);
+
 /** Returns the type of the value. */
 MlirType mlirValueGetType(MlirValue value);
 

diff  --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp
index 8226e4e552f7..4bae43c424fd 100644
--- a/mlir/lib/CAPI/IR/IR.cpp
+++ b/mlir/lib/CAPI/IR/IR.cpp
@@ -211,6 +211,10 @@ MlirOperation mlirOperationCreate(const MlirOperationState *state) {
 
 void mlirOperationDestroy(MlirOperation op) { unwrap(op)->erase(); }
 
+int mlirOperationEqual(MlirOperation op, MlirOperation other) {
+  return unwrap(op) == unwrap(other);
+}
+
 intptr_t mlirOperationGetNumRegions(MlirOperation op) {
   return static_cast<intptr_t>(unwrap(op)->getNumRegions());
 }
@@ -343,6 +347,10 @@ MlirBlock mlirBlockCreate(intptr_t nArgs, MlirType *args) {
   return wrap(b);
 }
 
+int mlirBlockEqual(MlirBlock block, MlirBlock other) {
+  return unwrap(block) == unwrap(other);
+}
+
 MlirBlock mlirBlockGetNextInRegion(MlirBlock block) {
   return wrap(unwrap(block)->getNextNode());
 }
@@ -412,6 +420,36 @@ void mlirBlockPrint(MlirBlock block, MlirStringCallback callback,
 /* Value API.                                                                 */
 /* ========================================================================== */
 
+int mlirValueIsABlockArgument(MlirValue value) {
+  return unwrap(value).isa<BlockArgument>();
+}
+
+int mlirValueIsAOpResult(MlirValue value) {
+  return unwrap(value).isa<OpResult>();
+}
+
+MlirBlock mlirBlockArgumentGetOwner(MlirValue value) {
+  return wrap(unwrap(value).cast<BlockArgument>().getOwner());
+}
+
+intptr_t mlirBlockArgumentGetArgNumber(MlirValue value) {
+  return static_cast<intptr_t>(
+      unwrap(value).cast<BlockArgument>().getArgNumber());
+}
+
+void mlirBlockArgumentSetType(MlirValue value, MlirType type) {
+  unwrap(value).cast<BlockArgument>().setType(unwrap(type));
+}
+
+MlirOperation mlirOpResultGetOwner(MlirValue value) {
+  return wrap(unwrap(value).cast<OpResult>().getOwner());
+}
+
+intptr_t mlirOpResultGetResultNumber(MlirValue value) {
+  return static_cast<intptr_t>(
+      unwrap(value).cast<OpResult>().getResultNumber());
+}
+
 MlirType mlirValueGetType(MlirValue value) {
   return wrap(unwrap(value).getType());
 }

diff  --git a/mlir/test/CAPI/ir.c b/mlir/test/CAPI/ir.c
index 0c427c77bdb8..7c86f403b339 100644
--- a/mlir/test/CAPI/ir.c
+++ b/mlir/test/CAPI/ir.c
@@ -153,10 +153,12 @@ struct ModuleStats {
   unsigned numBlocks;
   unsigned numRegions;
   unsigned numValues;
+  unsigned numBlockArguments;
+  unsigned numOpResults;
 };
 typedef struct ModuleStats ModuleStats;
 
-void collectStatsSingle(OpListNode *head, ModuleStats *stats) {
+int collectStatsSingle(OpListNode *head, ModuleStats *stats) {
   MlirOperation operation = head->op;
   stats->numOperations += 1;
   stats->numValues += mlirOperationGetNumResults(operation);
@@ -166,12 +168,39 @@ void collectStatsSingle(OpListNode *head, ModuleStats *stats) {
 
   stats->numRegions += numRegions;
 
+  intptr_t numResults = mlirOperationGetNumResults(operation);
+  for (intptr_t i = 0; i < numResults; ++i) {
+    MlirValue result = mlirOperationGetResult(operation, i);
+    if (!mlirValueIsAOpResult(result))
+      return 1;
+    if (mlirValueIsABlockArgument(result))
+      return 2;
+    if (!mlirOperationEqual(operation, mlirOpResultGetOwner(result)))
+      return 3;
+    if (i != mlirOpResultGetResultNumber(result))
+      return 4;
+    ++stats->numOpResults;
+  }
+
   for (unsigned i = 0; i < numRegions; ++i) {
     MlirRegion region = mlirOperationGetRegion(operation, i);
     for (MlirBlock block = mlirRegionGetFirstBlock(region);
          !mlirBlockIsNull(block); block = mlirBlockGetNextInRegion(block)) {
       ++stats->numBlocks;
-      stats->numValues += mlirBlockGetNumArguments(block);
+      intptr_t numArgs = mlirBlockGetNumArguments(block);
+      stats->numValues += numArgs;
+      for (intptr_t j = 0; j < numArgs; ++j) {
+        MlirValue arg = mlirBlockGetArgument(block, j);
+        if (!mlirValueIsABlockArgument(arg))
+          return 5;
+        if (mlirValueIsAOpResult(arg))
+          return 6;
+        if (!mlirBlockEqual(block, mlirBlockArgumentGetOwner(arg)))
+          return 7;
+        if (j != mlirBlockArgumentGetArgNumber(arg))
+          return 8;
+        ++stats->numBlockArguments;
+      }
 
       for (MlirOperation child = mlirBlockGetFirstOperation(block);
            !mlirOperationIsNull(child);
@@ -183,9 +212,10 @@ void collectStatsSingle(OpListNode *head, ModuleStats *stats) {
       }
     }
   }
+  return 0;
 }
 
-void collectStats(MlirOperation operation) {
+int collectStats(MlirOperation operation) {
   OpListNode *head = malloc(sizeof(OpListNode));
   head->op = operation;
   head->next = NULL;
@@ -196,9 +226,13 @@ void collectStats(MlirOperation operation) {
   stats.numBlocks = 0;
   stats.numRegions = 0;
   stats.numValues = 0;
+  stats.numBlockArguments = 0;
+  stats.numOpResults = 0;
 
   do {
-    collectStatsSingle(head, &stats);
+    int retval = collectStatsSingle(head, &stats);
+    if (retval)
+      return retval;
     OpListNode *next = head->next;
     free(head);
     head = next;
@@ -209,6 +243,11 @@ void collectStats(MlirOperation operation) {
   fprintf(stderr, "Number of blocks: %u\n", stats.numBlocks);
   fprintf(stderr, "Number of regions: %u\n", stats.numRegions);
   fprintf(stderr, "Number of values: %u\n", stats.numValues);
+  fprintf(stderr, "Number of block arguments: %u\n", stats.numBlockArguments);
+  fprintf(stderr, "Number of op results: %u\n", stats.numOpResults);
+  if (stats.numValues != stats.numBlockArguments + stats.numOpResults)
+    return 100;
+  return 0;
 }
 
 static void printToStderr(const char *str, intptr_t len, void *userData) {
@@ -914,13 +953,19 @@ int main() {
   // CHECK: }
   // clang-format on
 
-  collectStats(module);
+  fprintf(stderr, "@stats\n");
+  int errcode = collectStats(module);
+  fprintf(stderr, "%d\n", errcode);
   // clang-format off
+  // CHECK-LABEL: @stats
   // CHECK: Number of operations: 13
   // CHECK: Number of attributes: 4
   // CHECK: Number of blocks: 3
   // CHECK: Number of regions: 3
   // CHECK: Number of values: 9
+  // CHECK: Number of block arguments: 3
+  // CHECK: Number of op results: 6
+  // CHECK: 0
   // clang-format on
 
   printFirstOfEach(ctx, module);
@@ -988,7 +1033,7 @@ int main() {
   // CHECK: 0
   // clang-format on
   fprintf(stderr, "@types\n");
-  int errcode = printStandardTypes(ctx);
+  errcode = printStandardTypes(ctx);
   fprintf(stderr, "%d\n", errcode);
 
   // clang-format off


        


More information about the Mlir-commits mailing list