[llvm-branch-commits] [mlir] 52586c4 - [mlir][CAPI] Add result type inference to the CAPI.

Stella Laurenzo via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Sat Jan 23 14:38:32 PST 2021


Author: Stella Laurenzo
Date: 2021-01-23T14:30:51-08:00
New Revision: 52586c46b0883600a332fd64731dc5287981f980

URL: https://github.com/llvm/llvm-project/commit/52586c46b0883600a332fd64731dc5287981f980
DIFF: https://github.com/llvm/llvm-project/commit/52586c46b0883600a332fd64731dc5287981f980.diff

LOG: [mlir][CAPI] Add result type inference to the CAPI.

* Adds a flag to MlirOperationState to enable result type inference using the InferTypeOpInterface.
* I chose this level of implementation for a couple of reasons:
  a) The creation flow is naturally where generated and custom builder code will be invoking such a thing
  b) it is a bit more efficient to share the data structure and unpacking vs having a standalone entry-point
  c) we can always decide to expose more of these interfaces with first-class APIs, but that would not preclude us from always wanting to use this one in this way (and less API surface area for common things is better for API stability and evolution).
* I struggled to find an appropriate way to test it since we don't currently link the test dialect into anything accessible from the CAPI. I opted instead for one of the simplest ops I found in a regular dialect that implements the interface.
* This does not do any trait-based type selection. That will be left to generated tablegen wrappers.

Differential Revision: https://reviews.llvm.org/D95283

Added: 
    

Modified: 
    mlir/include/mlir-c/IR.h
    mlir/lib/CAPI/IR/IR.cpp
    mlir/test/CAPI/ir.c

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h
index 13e32be049e4..39f42e4bf497 100644
--- a/mlir/include/mlir-c/IR.h
+++ b/mlir/include/mlir-c/IR.h
@@ -225,6 +225,7 @@ struct MlirOperationState {
   MlirBlock *successors;
   intptr_t nAttributes;
   MlirNamedAttribute *attributes;
+  bool enableResultTypeInference;
 };
 typedef struct MlirOperationState MlirOperationState;
 
@@ -249,6 +250,14 @@ MLIR_CAPI_EXPORTED void
 mlirOperationStateAddAttributes(MlirOperationState *state, intptr_t n,
                                 MlirNamedAttribute const *attributes);
 
+/// Enables result type inference for the operation under construction. If
+/// enabled, then the caller must not have called
+/// mlirOperationStateAddResults(). Note that if enabled, the
+/// mlirOperationCreate() call is failable: it will return a null operation
+/// on inference failure and will emit diagnostics.
+MLIR_CAPI_EXPORTED void
+mlirOperationStateEnableResultTypeInference(MlirOperationState *state);
+
 //===----------------------------------------------------------------------===//
 // Op Printing flags API.
 // While many of these are simple settings that could be represented in a
@@ -293,8 +302,14 @@ mlirOpPrintingFlagsUseLocalScope(MlirOpPrintingFlags flags);
 //===----------------------------------------------------------------------===//
 
 /// Creates an operation and transfers ownership to the caller.
-MLIR_CAPI_EXPORTED MlirOperation
-mlirOperationCreate(const MlirOperationState *state);
+/// Note that caller owned child objects are transferred in this call and must
+/// not be further used. Particularly, this applies to any regions added to
+/// the state (the implementation may invalidate any such pointers).
+///
+/// This call can fail under the following conditions, in which case, it will
+/// return a null operation and emit diagnostics:
+///   - Result type inference is enabled and cannot be performed.
+MLIR_CAPI_EXPORTED MlirOperation mlirOperationCreate(MlirOperationState *state);
 
 /// Takes an operation owned by the caller and destroys it.
 MLIR_CAPI_EXPORTED void mlirOperationDestroy(MlirOperation op);

diff  --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp
index 30d4c8c41835..bf240046be25 100644
--- a/mlir/lib/CAPI/IR/IR.cpp
+++ b/mlir/lib/CAPI/IR/IR.cpp
@@ -18,6 +18,7 @@
 #include "mlir/IR/Operation.h"
 #include "mlir/IR/Types.h"
 #include "mlir/IR/Verifier.h"
+#include "mlir/Interfaces/InferTypeOpInterface.h"
 #include "mlir/Parser.h"
 
 using namespace mlir;
@@ -188,6 +189,7 @@ MlirOperationState mlirOperationStateGet(MlirStringRef name, MlirLocation loc) {
   state.successors = nullptr;
   state.nAttributes = 0;
   state.attributes = nullptr;
+  state.enableResultTypeInference = false;
   return state;
 }
 
@@ -219,11 +221,47 @@ void mlirOperationStateAddAttributes(MlirOperationState *state, intptr_t n,
   APPEND_ELEMS(MlirNamedAttribute, nAttributes, attributes);
 }
 
+void mlirOperationStateEnableResultTypeInference(MlirOperationState *state) {
+  state->enableResultTypeInference = true;
+}
+
 //===----------------------------------------------------------------------===//
 // Operation API.
 //===----------------------------------------------------------------------===//
 
-MlirOperation mlirOperationCreate(const MlirOperationState *state) {
+static LogicalResult inferOperationTypes(OperationState &state) {
+  MLIRContext *context = state.getContext();
+  const AbstractOperation *abstractOp =
+      AbstractOperation::lookup(state.name.getStringRef(), context);
+  if (!abstractOp) {
+    emitError(state.location)
+        << "type inference was requested for the operation " << state.name
+        << ", but the operation was not registered. Ensure that the dialect "
+           "containing the operation is linked into MLIR and registered with "
+           "the context";
+    return failure();
+  }
+
+  // Fallback to inference via an op interface.
+  auto *inferInterface = abstractOp->getInterface<InferTypeOpInterface>();
+  if (!inferInterface) {
+    emitError(state.location)
+        << "type inference was requested for the operation " << state.name
+        << ", but the operation does not support type inference. Result "
+           "types must be specified explicitly.";
+    return failure();
+  }
+
+  if (succeeded(inferInterface->inferReturnTypes(
+          context, state.location, state.operands,
+          state.attributes.getDictionary(context), state.regions, state.types)))
+    return success();
+
+  // Diagnostic emitted by interface.
+  return failure();
+}
+
+MlirOperation mlirOperationCreate(MlirOperationState *state) {
   assert(state);
   OperationState cppState(unwrap(state->location), unwrap(state->name));
   SmallVector<Type, 4> resultStorage;
@@ -243,12 +281,21 @@ MlirOperation mlirOperationCreate(const MlirOperationState *state) {
   for (intptr_t i = 0; i < state->nRegions; ++i)
     cppState.addRegion(std::unique_ptr<Region>(unwrap(state->regions[i])));
 
-  MlirOperation result = wrap(Operation::create(cppState));
   free(state->results);
   free(state->operands);
   free(state->successors);
   free(state->regions);
   free(state->attributes);
+
+  // Infer result types.
+  if (state->enableResultTypeInference) {
+    assert(cppState.types.empty() &&
+           "result type inference enabled and result types provided");
+    if (failed(inferOperationTypes(cppState)))
+      return {nullptr};
+  }
+
+  MlirOperation result = wrap(Operation::create(cppState));
   return result;
 }
 

diff  --git a/mlir/test/CAPI/ir.c b/mlir/test/CAPI/ir.c
index 5a222cef0bac..015738384424 100644
--- a/mlir/test/CAPI/ir.c
+++ b/mlir/test/CAPI/ir.c
@@ -553,6 +553,35 @@ static void buildWithInsertionsAndPrint(MlirContext ctx) {
   // clang-format on
 }
 
+/// Creates operations with type inference and tests various failure modes.
+static int createOperationWithTypeInference(MlirContext ctx) {
+  MlirLocation loc = mlirLocationUnknownGet(ctx);
+  MlirAttribute iAttr = mlirIntegerAttrGet(mlirIntegerTypeGet(ctx, 32), 4);
+
+  // The shape.const_size op implements result type inference and is only used
+  // for that reason.
+  MlirOperationState state = mlirOperationStateGet(
+      mlirStringRefCreateFromCString("shape.const_size"), loc);
+  MlirNamedAttribute valueAttr = mlirNamedAttributeGet(
+      mlirIdentifierGet(ctx, mlirStringRefCreateFromCString("value")), iAttr);
+  mlirOperationStateAddAttributes(&state, 1, &valueAttr);
+  mlirOperationStateEnableResultTypeInference(&state);
+
+  // Expect result type inference to succeed.
+  MlirOperation op = mlirOperationCreate(&state);
+  if (mlirOperationIsNull(op)) {
+    fprintf(stderr, "ERROR: Result type inference unexpectedly failed");
+    return 1;
+  }
+
+  // CHECK: RESULT_TYPE_INFERENCE: !shape.size
+  fprintf(stderr, "RESULT_TYPE_INFERENCE: ");
+  mlirTypeDump(mlirValueGetType(mlirOperationGetResult(op, 0)));
+  fprintf(stderr, "\n");
+  mlirOperationDestroy(op);
+  return 0;
+}
+
 /// Dumps instances of all builtin types to check that C API works correctly.
 /// Additionally, performs simple identity checks that a builtin type
 /// constructed with C API can be inspected and has the expected type. The
@@ -957,14 +986,12 @@ int printBuiltinAttributes(MlirContext ctx) {
       (uint64_t *)mlirDenseElementsAttrGetRawData(uint64Elements);
   int64_t *int64RawData =
       (int64_t *)mlirDenseElementsAttrGetRawData(int64Elements);
-  float *floatRawData =
-      (float *)mlirDenseElementsAttrGetRawData(floatElements);
+  float *floatRawData = (float *)mlirDenseElementsAttrGetRawData(floatElements);
   double *doubleRawData =
       (double *)mlirDenseElementsAttrGetRawData(doubleElements);
   if (uint32RawData[0] != 0u || uint32RawData[1] != 1u ||
-      int32RawData[0] != 0 || int32RawData[1] != 1 ||
-      uint64RawData[0] != 0u || uint64RawData[1] != 1u ||
-      int64RawData[0] != 0 || int64RawData[1] != 1 ||
+      int32RawData[0] != 0 || int32RawData[1] != 1 || uint64RawData[0] != 0u ||
+      uint64RawData[1] != 1u || int64RawData[0] != 0 || int64RawData[1] != 1 ||
       floatRawData[0] != 0.0f || floatRawData[1] != 1.0f ||
       doubleRawData[0] != 0.0 || doubleRawData[1] != 1.0)
     return 18;
@@ -1389,19 +1416,21 @@ int main() {
   if (constructAndTraverseIr(ctx))
     return 1;
   buildWithInsertionsAndPrint(ctx);
+  if (createOperationWithTypeInference(ctx))
+    return 2;
 
   if (printBuiltinTypes(ctx))
-    return 2;
-  if (printBuiltinAttributes(ctx))
     return 3;
-  if (printAffineMap(ctx))
+  if (printBuiltinAttributes(ctx))
     return 4;
-  if (printAffineExpr(ctx))
+  if (printAffineMap(ctx))
     return 5;
-  if (affineMapFromExprs(ctx))
+  if (printAffineExpr(ctx))
     return 6;
-  if (registerOnlyStd())
+  if (affineMapFromExprs(ctx))
     return 7;
+  if (registerOnlyStd())
+    return 8;
 
   mlirContextDestroy(ctx);
 


        


More information about the llvm-branch-commits mailing list