[Mlir-commits] [mlir] df9ae59 - Use MlirStringRef throughout the C API

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Nov 23 14:09:14 PST 2020


Author: George
Date: 2020-11-23T14:07:30-08:00
New Revision: df9ae5992889560a8f3c6760b54d5051b47c7bf5

URL: https://github.com/llvm/llvm-project/commit/df9ae5992889560a8f3c6760b54d5051b47c7bf5
DIFF: https://github.com/llvm/llvm-project/commit/df9ae5992889560a8f3c6760b54d5051b47c7bf5.diff

LOG: Use MlirStringRef throughout the C API

While this makes the unit tests a bit more verbose, this simplifies the creation of bindings because only the bidirectional mapping between the host language's string type and MlirStringRef needs to be implemented.

Reviewed By: mehdi_amini

Differential Revision: https://reviews.llvm.org/D91905

Added: 
    

Modified: 
    mlir/include/mlir-c/IR.h
    mlir/include/mlir-c/Support.h
    mlir/include/mlir/CAPI/Utils.h
    mlir/lib/CAPI/IR/IR.cpp
    mlir/lib/CAPI/IR/StandardAttributes.cpp
    mlir/test/CAPI/ir.c
    mlir/test/CAPI/pass.c

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h
index 6c9394c38b17..2ca5b80b825a 100644
--- a/mlir/include/mlir-c/IR.h
+++ b/mlir/include/mlir-c/IR.h
@@ -69,7 +69,7 @@ DEFINE_C_API_STRUCT(MlirValue, const void);
  * a string.
  */
 struct MlirNamedAttribute {
-  const char *name;
+  MlirStringRef name;
   MlirAttribute attribute;
 };
 typedef struct MlirNamedAttribute MlirNamedAttribute;
@@ -143,10 +143,8 @@ MLIR_CAPI_EXPORTED MlirStringRef mlirDialectGetNamespace(MlirDialect dialect);
 //===----------------------------------------------------------------------===//
 
 /// Creates an File/Line/Column location owned by the given context.
-MLIR_CAPI_EXPORTED MlirLocation mlirLocationFileLineColGet(MlirContext context,
-                                                           const char *filename,
-                                                           unsigned line,
-                                                           unsigned col);
+MLIR_CAPI_EXPORTED MlirLocation mlirLocationFileLineColGet(
+    MlirContext context, MlirStringRef filename, unsigned line, unsigned col);
 
 /// Creates a location with unknown position owned by the given context.
 MLIR_CAPI_EXPORTED MlirLocation mlirLocationUnknownGet(MlirContext context);
@@ -170,7 +168,7 @@ MLIR_CAPI_EXPORTED MlirModule mlirModuleCreateEmpty(MlirLocation location);
 
 /// Parses a module from the string and transfers ownership to the caller.
 MLIR_CAPI_EXPORTED MlirModule mlirModuleCreateParse(MlirContext context,
-                                                    const char *module);
+                                                    MlirStringRef module);
 
 /// Gets the context that a module was created with.
 MLIR_CAPI_EXPORTED MlirContext mlirModuleGetContext(MlirModule module);
@@ -202,7 +200,7 @@ MLIR_CAPI_EXPORTED MlirOperation mlirModuleGetOperation(MlirModule module);
  * mlirOperationState* functions instead.
  */
 struct MlirOperationState {
-  const char *name;
+  MlirStringRef name;
   MlirLocation location;
   intptr_t nResults;
   MlirType *results;
@@ -218,16 +216,16 @@ struct MlirOperationState {
 typedef struct MlirOperationState MlirOperationState;
 
 /// Constructs an operation state from a name and a location.
-MLIR_CAPI_EXPORTED MlirOperationState mlirOperationStateGet(const char *name,
+MLIR_CAPI_EXPORTED MlirOperationState mlirOperationStateGet(MlirStringRef name,
                                                             MlirLocation loc);
 
 /// Adds a list of components to the operation state.
 MLIR_CAPI_EXPORTED void mlirOperationStateAddResults(MlirOperationState *state,
                                                      intptr_t n,
                                                      MlirType const *results);
-MLIR_CAPI_EXPORTED void mlirOperationStateAddOperands(MlirOperationState *state,
-                                                      intptr_t n,
-                                                      MlirValue const *operands);
+MLIR_CAPI_EXPORTED void
+mlirOperationStateAddOperands(MlirOperationState *state, intptr_t n,
+                              MlirValue const *operands);
 MLIR_CAPI_EXPORTED void
 mlirOperationStateAddOwnedRegions(MlirOperationState *state, intptr_t n,
                                   MlirRegion const *regions);
@@ -349,18 +347,18 @@ mlirOperationGetAttribute(MlirOperation op, intptr_t pos);
 
 /// Returns an attribute attached to the operation given its name.
 MLIR_CAPI_EXPORTED MlirAttribute
-mlirOperationGetAttributeByName(MlirOperation op, const char *name);
+mlirOperationGetAttributeByName(MlirOperation op, MlirStringRef name);
 
 /** Sets an attribute by name, replacing the existing if it exists or
  * adding a new one otherwise. */
 MLIR_CAPI_EXPORTED void mlirOperationSetAttributeByName(MlirOperation op,
-                                                        const char *name,
+                                                        MlirStringRef name,
                                                         MlirAttribute attr);
 
 /** Removes an attribute by name. Returns 0 if the attribute was not found
  * and !0 if removed. */
 MLIR_CAPI_EXPORTED int mlirOperationRemoveAttributeByName(MlirOperation op,
-                                                          const char *name);
+                                                          MlirStringRef name);
 
 /** Prints an operation by sending chunks of the string representation and
  * forwarding `userData to `callback`. Note that the callback may be called
@@ -425,7 +423,8 @@ MLIR_CAPI_EXPORTED void mlirRegionInsertOwnedBlockBefore(MlirRegion region,
 
 /** Creates a new empty block with the given argument types and transfers
  * ownership to the caller. */
-MLIR_CAPI_EXPORTED MlirBlock mlirBlockCreate(intptr_t nArgs, MlirType const *args);
+MLIR_CAPI_EXPORTED MlirBlock mlirBlockCreate(intptr_t nArgs,
+                                             MlirType const *args);
 
 /// Takes a block owned by the caller and destroys it.
 MLIR_CAPI_EXPORTED void mlirBlockDestroy(MlirBlock block);
@@ -538,7 +537,7 @@ mlirValuePrint(MlirValue value, MlirStringCallback callback, void *userData);
 
 /// Parses a type. The type is owned by the context.
 MLIR_CAPI_EXPORTED MlirType mlirTypeParseGet(MlirContext context,
-                                             const char *type);
+                                             MlirStringRef type);
 
 /// Gets the context that a type was created with.
 MLIR_CAPI_EXPORTED MlirContext mlirTypeGetContext(MlirType type);
@@ -564,7 +563,7 @@ MLIR_CAPI_EXPORTED void mlirTypeDump(MlirType type);
 
 /// Parses an attribute. The attribute is owned by the context.
 MLIR_CAPI_EXPORTED MlirAttribute mlirAttributeParseGet(MlirContext context,
-                                                       const char *attr);
+                                                       MlirStringRef attr);
 
 /// Gets the context that an attribute was created with.
 MLIR_CAPI_EXPORTED MlirContext mlirAttributeGetContext(MlirAttribute attribute);
@@ -589,7 +588,7 @@ MLIR_CAPI_EXPORTED void mlirAttributePrint(MlirAttribute attr,
 MLIR_CAPI_EXPORTED void mlirAttributeDump(MlirAttribute attr);
 
 /// Associates an attribute with the name. Takes ownership of neither.
-MLIR_CAPI_EXPORTED MlirNamedAttribute mlirNamedAttributeGet(const char *name,
+MLIR_CAPI_EXPORTED MlirNamedAttribute mlirNamedAttributeGet(MlirStringRef name,
                                                             MlirAttribute attr);
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/include/mlir-c/Support.h b/mlir/include/mlir-c/Support.h
index 717ec41a805d..afa094a41fcb 100644
--- a/mlir/include/mlir-c/Support.h
+++ b/mlir/include/mlir-c/Support.h
@@ -75,12 +75,10 @@ mlirStringRefCreateFromCString(const char *str);
  *
  * This function is called back by the functions that need to return a reference
  * to the portion of the string with the following arguments:
- *   - a pointer to the beginning of a string;
- *   - the length of the string (the pointer may point to a larger buffer, not
- *     necessarily null-terminated);
+ *   - an MlirStringRef representing the current portion of the string
  *   - a pointer to user data forwarded from the printing call.
  */
-typedef void (*MlirStringCallback)(const char *, intptr_t, void *);
+typedef void (*MlirStringCallback)(MlirStringRef, void *);
 
 //===----------------------------------------------------------------------===//
 // MlirLogicalResult.

diff  --git a/mlir/include/mlir/CAPI/Utils.h b/mlir/include/mlir/CAPI/Utils.h
index 7307f303868e..c2e43850c2b6 100644
--- a/mlir/include/mlir/CAPI/Utils.h
+++ b/mlir/include/mlir/CAPI/Utils.h
@@ -14,6 +14,7 @@
 #ifndef MLIR_CAPI_UTILS_H
 #define MLIR_CAPI_UTILS_H
 
+#include "mlir-c/Support.h"
 #include "llvm/Support/raw_ostream.h"
 
 //===----------------------------------------------------------------------===//
@@ -26,20 +27,21 @@ namespace detail {
 /// user-supplied callback together with opaque user-supplied data.
 class CallbackOstream : public llvm::raw_ostream {
 public:
-  CallbackOstream(std::function<void(const char *, intptr_t, void *)> callback,
+  CallbackOstream(std::function<void(MlirStringRef, void *)> callback,
                   void *opaqueData)
       : raw_ostream(/*unbuffered=*/true), callback(callback),
         opaqueData(opaqueData), pos(0u) {}
 
   void write_impl(const char *ptr, size_t size) override {
-    callback(ptr, size, opaqueData);
+    MlirStringRef string = mlirStringRefCreate(ptr, size);
+    callback(string, opaqueData);
     pos += size;
   }
 
   uint64_t current_pos() const override { return pos; }
 
 private:
-  std::function<void(const char *, intptr_t, void *)> callback;
+  std::function<void(MlirStringRef, void *)> callback;
   void *opaqueData;
   uint64_t pos;
 };

diff  --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp
index f2329afce7e0..cf76811f6c12 100644
--- a/mlir/lib/CAPI/IR/IR.cpp
+++ b/mlir/lib/CAPI/IR/IR.cpp
@@ -109,9 +109,10 @@ void mlirOpPrintingFlagsUseLocalScope(MlirOpPrintingFlags flags) {
 //===----------------------------------------------------------------------===//
 
 MlirLocation mlirLocationFileLineColGet(MlirContext context,
-                                        const char *filename, unsigned line,
+                                        MlirStringRef filename, unsigned line,
                                         unsigned col) {
-  return wrap(FileLineColLoc::get(filename, line, col, unwrap(context)));
+  return wrap(
+      FileLineColLoc::get(unwrap(filename), line, col, unwrap(context)));
 }
 
 MlirLocation mlirLocationUnknownGet(MlirContext context) {
@@ -136,8 +137,8 @@ MlirModule mlirModuleCreateEmpty(MlirLocation location) {
   return wrap(ModuleOp::create(unwrap(location)));
 }
 
-MlirModule mlirModuleCreateParse(MlirContext context, const char *module) {
-  OwningModuleRef owning = parseSourceString(module, unwrap(context));
+MlirModule mlirModuleCreateParse(MlirContext context, MlirStringRef module) {
+  OwningModuleRef owning = parseSourceString(unwrap(module), unwrap(context));
   if (!owning)
     return MlirModule{nullptr};
   return MlirModule{owning.release().getOperation()};
@@ -164,7 +165,7 @@ MlirOperation mlirModuleGetOperation(MlirModule module) {
 // Operation state API.
 //===----------------------------------------------------------------------===//
 
-MlirOperationState mlirOperationStateGet(const char *name, MlirLocation loc) {
+MlirOperationState mlirOperationStateGet(MlirStringRef name, MlirLocation loc) {
   MlirOperationState state;
   state.name = name;
   state.location = loc;
@@ -215,7 +216,7 @@ void mlirOperationStateAddAttributes(MlirOperationState *state, intptr_t n,
 
 MlirOperation mlirOperationCreate(const MlirOperationState *state) {
   assert(state);
-  OperationState cppState(unwrap(state->location), state->name);
+  OperationState cppState(unwrap(state->location), unwrap(state->name));
   SmallVector<Type, 4> resultStorage;
   SmallVector<Value, 8> operandStorage;
   SmallVector<Block *, 2> successorStorage;
@@ -227,7 +228,7 @@ MlirOperation mlirOperationCreate(const MlirOperationState *state) {
 
   cppState.attributes.reserve(state->nAttributes);
   for (intptr_t i = 0; i < state->nAttributes; ++i)
-    cppState.addAttribute(state->attributes[i].name,
+    cppState.addAttribute(unwrap(state->attributes[i].name),
                           unwrap(state->attributes[i].attribute));
 
   for (intptr_t i = 0; i < state->nRegions; ++i)
@@ -302,21 +303,21 @@ intptr_t mlirOperationGetNumAttributes(MlirOperation op) {
 
 MlirNamedAttribute mlirOperationGetAttribute(MlirOperation op, intptr_t pos) {
   NamedAttribute attr = unwrap(op)->getAttrs()[pos];
-  return MlirNamedAttribute{attr.first.c_str(), wrap(attr.second)};
+  return MlirNamedAttribute{wrap(attr.first.strref()), wrap(attr.second)};
 }
 
 MlirAttribute mlirOperationGetAttributeByName(MlirOperation op,
-                                              const char *name) {
-  return wrap(unwrap(op)->getAttr(name));
+                                              MlirStringRef name) {
+  return wrap(unwrap(op)->getAttr(unwrap(name)));
 }
 
-void mlirOperationSetAttributeByName(MlirOperation op, const char *name,
+void mlirOperationSetAttributeByName(MlirOperation op, MlirStringRef name,
                                      MlirAttribute attr) {
-  unwrap(op)->setAttr(name, unwrap(attr));
+  unwrap(op)->setAttr(unwrap(name), unwrap(attr));
 }
 
-int mlirOperationRemoveAttributeByName(MlirOperation op, const char *name) {
-  auto removeResult = unwrap(op)->removeAttr(name);
+int mlirOperationRemoveAttributeByName(MlirOperation op, MlirStringRef name) {
+  auto removeResult = unwrap(op)->removeAttr(unwrap(name));
   return removeResult == MutableDictionaryAttr::RemoveResult::Removed;
 }
 
@@ -529,8 +530,8 @@ void mlirValuePrint(MlirValue value, MlirStringCallback callback,
 // Type API.
 //===----------------------------------------------------------------------===//
 
-MlirType mlirTypeParseGet(MlirContext context, const char *type) {
-  return wrap(mlir::parseType(type, unwrap(context)));
+MlirType mlirTypeParseGet(MlirContext context, MlirStringRef type) {
+  return wrap(mlir::parseType(unwrap(type), unwrap(context)));
 }
 
 MlirContext mlirTypeGetContext(MlirType type) {
@@ -550,8 +551,8 @@ void mlirTypeDump(MlirType type) { unwrap(type).dump(); }
 // Attribute API.
 //===----------------------------------------------------------------------===//
 
-MlirAttribute mlirAttributeParseGet(MlirContext context, const char *attr) {
-  return wrap(mlir::parseAttribute(attr, unwrap(context)));
+MlirAttribute mlirAttributeParseGet(MlirContext context, MlirStringRef attr) {
+  return wrap(mlir::parseAttribute(unwrap(attr), unwrap(context)));
 }
 
 MlirContext mlirAttributeGetContext(MlirAttribute attribute) {
@@ -574,7 +575,8 @@ void mlirAttributePrint(MlirAttribute attr, MlirStringCallback callback,
 
 void mlirAttributeDump(MlirAttribute attr) { unwrap(attr).dump(); }
 
-MlirNamedAttribute mlirNamedAttributeGet(const char *name, MlirAttribute attr) {
+MlirNamedAttribute mlirNamedAttributeGet(MlirStringRef name,
+                                         MlirAttribute attr) {
   return MlirNamedAttribute{name, attr};
 }
 

diff  --git a/mlir/lib/CAPI/IR/StandardAttributes.cpp b/mlir/lib/CAPI/IR/StandardAttributes.cpp
index c23383fd5085..784c11aec740 100644
--- a/mlir/lib/CAPI/IR/StandardAttributes.cpp
+++ b/mlir/lib/CAPI/IR/StandardAttributes.cpp
@@ -68,8 +68,9 @@ MlirAttribute mlirDictionaryAttrGet(MlirContext ctx, intptr_t numElements,
   SmallVector<NamedAttribute, 8> attributes;
   attributes.reserve(numElements);
   for (intptr_t i = 0; i < numElements; ++i)
-    attributes.emplace_back(Identifier::get(elements[i].name, unwrap(ctx)),
-                            unwrap(elements[i].attribute));
+    attributes.emplace_back(
+        Identifier::get(unwrap(elements[i].name), unwrap(ctx)),
+        unwrap(elements[i].attribute));
   return wrap(DictionaryAttr::get(attributes, unwrap(ctx)));
 }
 
@@ -81,7 +82,7 @@ MlirNamedAttribute mlirDictionaryAttrGetElement(MlirAttribute attr,
                                                 intptr_t pos) {
   NamedAttribute attribute =
       unwrap(attr).cast<DictionaryAttr>().getValue()[pos];
-  return {attribute.first.c_str(), wrap(attribute.second)};
+  return {wrap(attribute.first.strref()), wrap(attribute.second)};
 }
 
 MlirAttribute mlirDictionaryAttrGetElementByName(MlirAttribute attr,

diff  --git a/mlir/test/CAPI/ir.c b/mlir/test/CAPI/ir.c
index 821ead52c166..afada8ade70d 100644
--- a/mlir/test/CAPI/ir.c
+++ b/mlir/test/CAPI/ir.c
@@ -30,23 +30,27 @@ void populateLoopBody(MlirContext ctx, MlirBlock loopBody,
   MlirValue iv = mlirBlockGetArgument(loopBody, 0);
   MlirValue funcArg0 = mlirBlockGetArgument(funcBody, 0);
   MlirValue funcArg1 = mlirBlockGetArgument(funcBody, 1);
-  MlirType f32Type = mlirTypeParseGet(ctx, "f32");
+  MlirType f32Type =
+      mlirTypeParseGet(ctx, mlirStringRefCreateFromCString("f32"));
 
-  MlirOperationState loadLHSState = mlirOperationStateGet("std.load", location);
+  MlirOperationState loadLHSState = mlirOperationStateGet(
+      mlirStringRefCreateFromCString("std.load"), location);
   MlirValue loadLHSOperands[] = {funcArg0, iv};
   mlirOperationStateAddOperands(&loadLHSState, 2, loadLHSOperands);
   mlirOperationStateAddResults(&loadLHSState, 1, &f32Type);
   MlirOperation loadLHS = mlirOperationCreate(&loadLHSState);
   mlirBlockAppendOwnedOperation(loopBody, loadLHS);
 
-  MlirOperationState loadRHSState = mlirOperationStateGet("std.load", location);
+  MlirOperationState loadRHSState = mlirOperationStateGet(
+      mlirStringRefCreateFromCString("std.load"), location);
   MlirValue loadRHSOperands[] = {funcArg1, iv};
   mlirOperationStateAddOperands(&loadRHSState, 2, loadRHSOperands);
   mlirOperationStateAddResults(&loadRHSState, 1, &f32Type);
   MlirOperation loadRHS = mlirOperationCreate(&loadRHSState);
   mlirBlockAppendOwnedOperation(loopBody, loadRHS);
 
-  MlirOperationState addState = mlirOperationStateGet("std.addf", location);
+  MlirOperationState addState = mlirOperationStateGet(
+      mlirStringRefCreateFromCString("std.addf"), location);
   MlirValue addOperands[] = {mlirOperationGetResult(loadLHS, 0),
                              mlirOperationGetResult(loadRHS, 0)};
   mlirOperationStateAddOperands(&addState, 2, addOperands);
@@ -54,13 +58,15 @@ void populateLoopBody(MlirContext ctx, MlirBlock loopBody,
   MlirOperation add = mlirOperationCreate(&addState);
   mlirBlockAppendOwnedOperation(loopBody, add);
 
-  MlirOperationState storeState = mlirOperationStateGet("std.store", location);
+  MlirOperationState storeState = mlirOperationStateGet(
+      mlirStringRefCreateFromCString("std.store"), location);
   MlirValue storeOperands[] = {mlirOperationGetResult(add, 0), funcArg0, iv};
   mlirOperationStateAddOperands(&storeState, 3, storeOperands);
   MlirOperation store = mlirOperationCreate(&storeState);
   mlirBlockAppendOwnedOperation(loopBody, store);
 
-  MlirOperationState yieldState = mlirOperationStateGet("scf.yield", location);
+  MlirOperationState yieldState = mlirOperationStateGet(
+      mlirStringRefCreateFromCString("scf.yield"), location);
   MlirOperation yield = mlirOperationCreate(&yieldState);
   mlirBlockAppendOwnedOperation(loopBody, yield);
 }
@@ -69,31 +75,39 @@ MlirModule makeAndDumpAdd(MlirContext ctx, MlirLocation location) {
   MlirModule moduleOp = mlirModuleCreateEmpty(location);
   MlirBlock moduleBody = mlirModuleGetBody(moduleOp);
 
-  MlirType memrefType = mlirTypeParseGet(ctx, "memref<?xf32>");
+  MlirType memrefType =
+      mlirTypeParseGet(ctx, mlirStringRefCreateFromCString("memref<?xf32>"));
   MlirType funcBodyArgTypes[] = {memrefType, memrefType};
   MlirRegion funcBodyRegion = mlirRegionCreate();
   MlirBlock funcBody = mlirBlockCreate(
       sizeof(funcBodyArgTypes) / sizeof(MlirType), funcBodyArgTypes);
   mlirRegionAppendOwnedBlock(funcBodyRegion, funcBody);
 
-  MlirAttribute funcTypeAttr =
-      mlirAttributeParseGet(ctx, "(memref<?xf32>, memref<?xf32>) -> ()");
-  MlirAttribute funcNameAttr = mlirAttributeParseGet(ctx, "\"add\"");
+  MlirAttribute funcTypeAttr = mlirAttributeParseGet(
+      ctx,
+      mlirStringRefCreateFromCString("(memref<?xf32>, memref<?xf32>) -> ()"));
+  MlirAttribute funcNameAttr =
+      mlirAttributeParseGet(ctx, mlirStringRefCreateFromCString("\"add\""));
   MlirNamedAttribute funcAttrs[] = {
-      mlirNamedAttributeGet("type", funcTypeAttr),
-      mlirNamedAttributeGet("sym_name", funcNameAttr)};
-  MlirOperationState funcState = mlirOperationStateGet("func", location);
+      mlirNamedAttributeGet(mlirStringRefCreateFromCString("type"),
+                            funcTypeAttr),
+      mlirNamedAttributeGet(mlirStringRefCreateFromCString("sym_name"),
+                            funcNameAttr)};
+  MlirOperationState funcState =
+      mlirOperationStateGet(mlirStringRefCreateFromCString("func"), location);
   mlirOperationStateAddAttributes(&funcState, 2, funcAttrs);
   mlirOperationStateAddOwnedRegions(&funcState, 1, &funcBodyRegion);
   MlirOperation func = mlirOperationCreate(&funcState);
   mlirBlockInsertOwnedOperation(moduleBody, 0, func);
 
-  MlirType indexType = mlirTypeParseGet(ctx, "index");
-  MlirAttribute indexZeroLiteral = mlirAttributeParseGet(ctx, "0 : index");
-  MlirNamedAttribute indexZeroValueAttr =
-      mlirNamedAttributeGet("value", indexZeroLiteral);
-  MlirOperationState constZeroState =
-      mlirOperationStateGet("std.constant", location);
+  MlirType indexType =
+      mlirTypeParseGet(ctx, mlirStringRefCreateFromCString("index"));
+  MlirAttribute indexZeroLiteral =
+      mlirAttributeParseGet(ctx, mlirStringRefCreateFromCString("0 : index"));
+  MlirNamedAttribute indexZeroValueAttr = mlirNamedAttributeGet(
+      mlirStringRefCreateFromCString("value"), indexZeroLiteral);
+  MlirOperationState constZeroState = mlirOperationStateGet(
+      mlirStringRefCreateFromCString("std.constant"), location);
   mlirOperationStateAddResults(&constZeroState, 1, &indexType);
   mlirOperationStateAddAttributes(&constZeroState, 1, &indexZeroValueAttr);
   MlirOperation constZero = mlirOperationCreate(&constZeroState);
@@ -102,7 +116,8 @@ MlirModule makeAndDumpAdd(MlirContext ctx, MlirLocation location) {
   MlirValue funcArg0 = mlirBlockGetArgument(funcBody, 0);
   MlirValue constZeroValue = mlirOperationGetResult(constZero, 0);
   MlirValue dimOperands[] = {funcArg0, constZeroValue};
-  MlirOperationState dimState = mlirOperationStateGet("std.dim", location);
+  MlirOperationState dimState = mlirOperationStateGet(
+      mlirStringRefCreateFromCString("std.dim"), location);
   mlirOperationStateAddOperands(&dimState, 2, dimOperands);
   mlirOperationStateAddResults(&dimState, 1, &indexType);
   MlirOperation dim = mlirOperationCreate(&dimState);
@@ -112,11 +127,12 @@ MlirModule makeAndDumpAdd(MlirContext ctx, MlirLocation location) {
   MlirBlock loopBody = mlirBlockCreate(/*nArgs=*/1, &indexType);
   mlirRegionAppendOwnedBlock(loopBodyRegion, loopBody);
 
-  MlirAttribute indexOneLiteral = mlirAttributeParseGet(ctx, "1 : index");
-  MlirNamedAttribute indexOneValueAttr =
-      mlirNamedAttributeGet("value", indexOneLiteral);
-  MlirOperationState constOneState =
-      mlirOperationStateGet("std.constant", location);
+  MlirAttribute indexOneLiteral =
+      mlirAttributeParseGet(ctx, mlirStringRefCreateFromCString("1 : index"));
+  MlirNamedAttribute indexOneValueAttr = mlirNamedAttributeGet(
+      mlirStringRefCreateFromCString("value"), indexOneLiteral);
+  MlirOperationState constOneState = mlirOperationStateGet(
+      mlirStringRefCreateFromCString("std.constant"), location);
   mlirOperationStateAddResults(&constOneState, 1, &indexType);
   mlirOperationStateAddAttributes(&constOneState, 1, &indexOneValueAttr);
   MlirOperation constOne = mlirOperationCreate(&constOneState);
@@ -125,7 +141,8 @@ MlirModule makeAndDumpAdd(MlirContext ctx, MlirLocation location) {
   MlirValue dimValue = mlirOperationGetResult(dim, 0);
   MlirValue constOneValue = mlirOperationGetResult(constOne, 0);
   MlirValue loopOperands[] = {constZeroValue, dimValue, constOneValue};
-  MlirOperationState loopState = mlirOperationStateGet("scf.for", location);
+  MlirOperationState loopState = mlirOperationStateGet(
+      mlirStringRefCreateFromCString("scf.for"), location);
   mlirOperationStateAddOperands(&loopState, 3, loopOperands);
   mlirOperationStateAddOwnedRegions(&loopState, 1, &loopBodyRegion);
   MlirOperation loop = mlirOperationCreate(&loopState);
@@ -133,7 +150,8 @@ MlirModule makeAndDumpAdd(MlirContext ctx, MlirLocation location) {
 
   populateLoopBody(ctx, loopBody, location, funcBody);
 
-  MlirOperationState retState = mlirOperationStateGet("std.return", location);
+  MlirOperationState retState = mlirOperationStateGet(
+      mlirStringRefCreateFromCString("std.return"), location);
   MlirOperation ret = mlirOperationCreate(&retState);
   mlirBlockAppendOwnedOperation(funcBody, ret);
 
@@ -280,9 +298,9 @@ int collectStats(MlirOperation operation) {
   return 0;
 }
 
-static void printToStderr(const char *str, intptr_t len, void *userData) {
+static void printToStderr(MlirStringRef str, void *userData) {
   (void)userData;
-  fwrite(str, 1, len, stderr);
+  fwrite(str.data, 1, str.length, stderr);
 }
 
 static void printFirstOfEach(MlirContext ctx, MlirOperation operation) {
@@ -366,8 +384,8 @@ static void printFirstOfEach(MlirContext ctx, MlirOperation operation) {
 
   // Get a non-existing attribute and assert that it is null (sanity).
   fprintf(stderr, "does_not_exist is null: %d\n",
-          mlirAttributeIsNull(
-              mlirOperationGetAttributeByName(operation, "does_not_exist")));
+          mlirAttributeIsNull(mlirOperationGetAttributeByName(
+              operation, mlirStringRefCreateFromCString("does_not_exist"))));
   // CHECK: does_not_exist is null: 1
 
   // Get result 0 and its type.
@@ -386,7 +404,8 @@ static void printFirstOfEach(MlirContext ctx, MlirOperation operation) {
   // CHECK: Result 0 type: index
 
   // Set a custom attribute.
-  mlirOperationSetAttributeByName(operation, "custom_attr",
+  mlirOperationSetAttributeByName(operation,
+                                  mlirStringRefCreateFromCString("custom_attr"),
                                   mlirBoolAttrGet(ctx, 1));
   fprintf(stderr, "Op with set attr: ");
   mlirOperationPrint(operation, printToStderr, NULL);
@@ -395,12 +414,14 @@ static void printFirstOfEach(MlirContext ctx, MlirOperation operation) {
 
   // Remove the attribute.
   fprintf(stderr, "Remove attr: %d\n",
-          mlirOperationRemoveAttributeByName(operation, "custom_attr"));
+          mlirOperationRemoveAttributeByName(
+              operation, mlirStringRefCreateFromCString("custom_attr")));
   fprintf(stderr, "Remove attr again: %d\n",
-          mlirOperationRemoveAttributeByName(operation, "custom_attr"));
+          mlirOperationRemoveAttributeByName(
+              operation, mlirStringRefCreateFromCString("custom_attr")));
   fprintf(stderr, "Removed attr is null: %d\n",
-          mlirAttributeIsNull(
-              mlirOperationGetAttributeByName(operation, "custom_attr")));
+          mlirAttributeIsNull(mlirOperationGetAttributeByName(
+              operation, mlirStringRefCreateFromCString("custom_attr"))));
   // CHECK: Remove attr: 1
   // CHECK: Remove attr again: 0
   // CHECK: Removed attr is null: 1
@@ -409,7 +430,7 @@ static void printFirstOfEach(MlirContext ctx, MlirOperation operation) {
   int64_t eltsShape[] = {4};
   int32_t eltsData[] = {1, 2, 3, 4};
   mlirOperationSetAttributeByName(
-      operation, "elts",
+      operation, mlirStringRefCreateFromCString("elts"),
       mlirDenseElementsAttrInt32Get(
           mlirRankedTensorTypeGet(1, eltsShape, mlirIntegerTypeGet(ctx, 32)), 4,
           eltsData));
@@ -421,7 +442,9 @@ static void printFirstOfEach(MlirContext ctx, MlirOperation operation) {
   fprintf(stderr, "Op print with all flags: ");
   mlirOperationPrintWithFlags(operation, flags, printToStderr, NULL);
   fprintf(stderr, "\n");
+  // clang-format off
   // CHECK: Op print with all flags: %{{.*}} = "std.constant"() {elts = opaque<"", "0xDEADBEEF"> : tensor<4xi32>, value = 0 : index} : () -> index loc(unknown)
+  // clang-format on
 
   mlirOpPrintingFlagsDestroy(flags);
 }
@@ -450,7 +473,8 @@ static void buildWithInsertionsAndPrint(MlirContext ctx) {
 
   MlirRegion owningRegion = mlirRegionCreate();
   MlirBlock nullBlock = mlirRegionGetFirstBlock(owningRegion);
-  MlirOperationState state = mlirOperationStateGet("insertion.order.test", loc);
+  MlirOperationState state = mlirOperationStateGet(
+      mlirStringRefCreateFromCString("insertion.order.test"), loc);
   mlirOperationStateAddOwnedRegions(&state, 1, &owningRegion);
   MlirOperation op = mlirOperationCreate(&state);
   MlirRegion region = mlirOperationGetRegion(op, 0);
@@ -471,13 +495,20 @@ static void buildWithInsertionsAndPrint(MlirContext ctx) {
   mlirRegionInsertOwnedBlockAfter(region, nullBlock, block1);
   mlirRegionInsertOwnedBlockAfter(region, block3, block4);
 
-  MlirOperationState op1State = mlirOperationStateGet("dummy.op1", loc);
-  MlirOperationState op2State = mlirOperationStateGet("dummy.op2", loc);
-  MlirOperationState op3State = mlirOperationStateGet("dummy.op3", loc);
-  MlirOperationState op4State = mlirOperationStateGet("dummy.op4", loc);
-  MlirOperationState op5State = mlirOperationStateGet("dummy.op5", loc);
-  MlirOperationState op6State = mlirOperationStateGet("dummy.op6", loc);
-  MlirOperationState op7State = mlirOperationStateGet("dummy.op7", loc);
+  MlirOperationState op1State =
+      mlirOperationStateGet(mlirStringRefCreateFromCString("dummy.op1"), loc);
+  MlirOperationState op2State =
+      mlirOperationStateGet(mlirStringRefCreateFromCString("dummy.op2"), loc);
+  MlirOperationState op3State =
+      mlirOperationStateGet(mlirStringRefCreateFromCString("dummy.op3"), loc);
+  MlirOperationState op4State =
+      mlirOperationStateGet(mlirStringRefCreateFromCString("dummy.op4"), loc);
+  MlirOperationState op5State =
+      mlirOperationStateGet(mlirStringRefCreateFromCString("dummy.op5"), loc);
+  MlirOperationState op6State =
+      mlirOperationStateGet(mlirStringRefCreateFromCString("dummy.op6"), loc);
+  MlirOperationState op7State =
+      mlirOperationStateGet(mlirStringRefCreateFromCString("dummy.op7"), loc);
   MlirOperation op1 = mlirOperationCreate(&op1State);
   MlirOperation op2 = mlirOperationCreate(&op2State);
   MlirOperation op3 = mlirOperationCreate(&op3State);

diff  --git a/mlir/test/CAPI/pass.c b/mlir/test/CAPI/pass.c
index 58fb54627a2f..b7b9e373feb2 100644
--- a/mlir/test/CAPI/pass.c
+++ b/mlir/test/CAPI/pass.c
@@ -25,13 +25,14 @@ void testRunPassOnModule() {
   MlirContext ctx = mlirContextCreate();
   mlirRegisterAllDialects(ctx);
 
-  MlirModule module =
-      mlirModuleCreateParse(ctx,
-                            // clang-format off
+  MlirModule module = mlirModuleCreateParse(
+      ctx,
+      // clang-format off
+                            mlirStringRefCreateFromCString(
 "func @foo(%arg0 : i32) -> i32 {                                            \n"
 "  %res = addi %arg0, %arg0 : i32                                           \n"
 "  return %res : i32                                                        \n"
-"}");
+"}"));
   // clang-format on
   if (mlirModuleIsNull(module)) {
     fprintf(stderr, "Unexpected failure parsing module.\n");
@@ -63,9 +64,10 @@ void testRunPassOnNestedModule() {
   MlirContext ctx = mlirContextCreate();
   mlirRegisterAllDialects(ctx);
 
-  MlirModule module =
-      mlirModuleCreateParse(ctx,
-                            // clang-format off
+  MlirModule module = mlirModuleCreateParse(
+      ctx,
+      // clang-format off
+                            mlirStringRefCreateFromCString(
 "func @foo(%arg0 : i32) -> i32 {                                            \n"
 "  %res = addi %arg0, %arg0 : i32                                           \n"
 "  return %res : i32                                                        \n"
@@ -75,7 +77,7 @@ void testRunPassOnNestedModule() {
 "    %res = addf %arg0, %arg0 : f32                                         \n"
 "    return %res : f32                                                      \n"
 "  }                                                                        \n"
-"}");
+"}"));
   // clang-format on
   if (mlirModuleIsNull(module))
     exit(1);
@@ -121,9 +123,9 @@ void testRunPassOnNestedModule() {
   mlirContextDestroy(ctx);
 }
 
-static void printToStderr(const char *str, intptr_t len, void *userData) {
+static void printToStderr(MlirStringRef str, void *userData) {
   (void)userData;
-  fwrite(str, 1, len, stderr);
+  fwrite(str.data, 1, str.length, stderr);
 }
 
 void testPrintPassPipeline() {


        


More information about the Mlir-commits mailing list