[Mlir-commits] [mlir] 7675f54 - [MLIR] Introduce new C bindings to differentiate between discardable and inherent attributes (#66332)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Sep 26 01:53:20 PDT 2023


Author: Mehdi Amini
Date: 2023-09-26T01:53:17-07:00
New Revision: 7675f541f75baa20e8ec007cd625a837e89fc01f

URL: https://github.com/llvm/llvm-project/commit/7675f541f75baa20e8ec007cd625a837e89fc01f
DIFF: https://github.com/llvm/llvm-project/commit/7675f541f75baa20e8ec007cd625a837e89fc01f.diff

LOG: [MLIR] Introduce new C bindings to differentiate between discardable and inherent attributes (#66332)

This is part of the transition toward properly splitting the two groups.
This only introduces new C APIs; the Python bindings are unaffected. No
API is removed.

Added: 
    

Modified: 
    mlir/include/mlir-c/IR.h
    mlir/include/mlir/IR/Operation.h
    mlir/lib/CAPI/IR/IR.cpp
    mlir/test/CAPI/ir.c

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h
index 68eccab6dbacaef..a6408317db69e61 100644
--- a/mlir/include/mlir-c/IR.h
+++ b/mlir/include/mlir-c/IR.h
@@ -576,25 +576,81 @@ MLIR_CAPI_EXPORTED intptr_t mlirOperationGetNumSuccessors(MlirOperation op);
 MLIR_CAPI_EXPORTED MlirBlock mlirOperationGetSuccessor(MlirOperation op,
                                                        intptr_t pos);
 
+/// Returns true if this operation defines an inherent attribute with this name.
+/// Note: the attribute can be optional, so
+/// `mlirOperationGetInherentAttributeByName` can still return a null attribute.
+MLIR_CAPI_EXPORTED bool
+mlirOperationHasInherentAttributeByName(MlirOperation op, MlirStringRef name);
+
+/// Returns an inherent attribute attached to the operation given its name.
+/// Returns a null attribute if `name` is not an inherent attribute of the
+/// operation, or if it is an optional inherent attribute that is not set.
+MLIR_CAPI_EXPORTED MlirAttribute
+mlirOperationGetInherentAttributeByName(MlirOperation op, MlirStringRef name);
+
+/// Sets an inherent attribute by name, replacing the existing if it exists.
+/// This has no effect if "name" does not match an inherent attribute.
+MLIR_CAPI_EXPORTED void
+mlirOperationSetInherentAttributeByName(MlirOperation op, MlirStringRef name,
+                                        MlirAttribute attr);
+
+/// Returns the number of discardable attributes attached to the operation.
+MLIR_CAPI_EXPORTED intptr_t
+mlirOperationGetNumDiscardableAttributes(MlirOperation op);
+
+/// Return `pos`-th discardable attribute of the operation.
+/// `pos` must be in `[0, mlirOperationGetNumDiscardableAttributes(op))`.
+MLIR_CAPI_EXPORTED MlirNamedAttribute
+mlirOperationGetDiscardableAttribute(MlirOperation op, intptr_t pos);
+
+/// Returns a discardable attribute attached to the operation given its name.
+/// Returns a null attribute if no discardable attribute has this name.
+MLIR_CAPI_EXPORTED MlirAttribute mlirOperationGetDiscardableAttributeByName(
+    MlirOperation op, MlirStringRef name);
+
+/// Sets a discardable attribute by name, replacing the existing if it exists or
+/// adding a new one otherwise. The new `attr` Attribute is not allowed to be
+/// null, use `mlirOperationRemoveDiscardableAttributeByName` to remove an
+/// Attribute instead.
+MLIR_CAPI_EXPORTED void
+mlirOperationSetDiscardableAttributeByName(MlirOperation op, MlirStringRef name,
+                                           MlirAttribute attr);
+
+/// Removes a discardable attribute by name. Returns false if the attribute was
+/// not found and true if removed.
+MLIR_CAPI_EXPORTED bool
+mlirOperationRemoveDiscardableAttributeByName(MlirOperation op,
+                                              MlirStringRef name);
+
 /// Returns the number of attributes attached to the operation.
+/// Deprecated, please use `mlirOperationGetNumDiscardableAttributes`; inherent
+/// attributes are queried by name with the `*InherentAttributeByName` APIs.
 MLIR_CAPI_EXPORTED intptr_t mlirOperationGetNumAttributes(MlirOperation op);
 
 /// Return `pos`-th attribute of the operation.
+/// Deprecated, please use `mlirOperationGetInherentAttributeByName` or
+/// `mlirOperationGetDiscardableAttribute`.
 MLIR_CAPI_EXPORTED MlirNamedAttribute
 mlirOperationGetAttribute(MlirOperation op, intptr_t pos);
 
 /// Returns an attribute attached to the operation given its name.
+/// Deprecated, please use `mlirOperationGetInherentAttributeByName` or
+/// `mlirOperationGetDiscardableAttributeByName`.
 MLIR_CAPI_EXPORTED MlirAttribute
 mlirOperationGetAttributeByName(MlirOperation op, MlirStringRef name);
 
 /// Sets an attribute by name, replacing the existing if it exists or
 /// adding a new one otherwise.
+/// Deprecated, please use `mlirOperationSetInherentAttributeByName` or
+/// `mlirOperationSetDiscardableAttributeByName`.
 MLIR_CAPI_EXPORTED void mlirOperationSetAttributeByName(MlirOperation op,
                                                         MlirStringRef name,
                                                         MlirAttribute attr);
 
 /// Removes an attribute by name. Returns false if the attribute was not found
 /// and true if removed.
+/// Deprecated, please use `mlirOperationRemoveDiscardableAttributeByName`;
+/// there is no counterpart for inherent attributes.
 MLIR_CAPI_EXPORTED bool mlirOperationRemoveAttributeByName(MlirOperation op,
                                                            MlirStringRef name);
 

diff  --git a/mlir/include/mlir/IR/Operation.h b/mlir/include/mlir/IR/Operation.h
index b815eaf8899d6fc..35e9d31a6323173 100644
--- a/mlir/include/mlir/IR/Operation.h
+++ b/mlir/include/mlir/IR/Operation.h
@@ -457,6 +457,27 @@ class alignas(8) Operation final
     if (attributes.set(name, value) != value)
       attrs = attributes.getDictionary(getContext());
   }
+  /// Set the discardable attribute with the specified name, given as a string
+  /// that is converted to a StringAttr in this operation's context.
+  void setDiscardableAttr(StringRef name, Attribute value) {
+    setDiscardableAttr(StringAttr::get(getContext(), name), value);
+  }
+
+  /// Remove the discardable attribute with the specified name if it exists.
+  /// Return the attribute that was erased, or nullptr if there was no attribute
+  /// with such name.
+  Attribute removeDiscardableAttr(StringAttr name) {
+    NamedAttrList attributes(attrs);
+    Attribute removedAttr = attributes.erase(name);
+    // Rebuild the dictionary only if an attribute was actually erased.
+    if (removedAttr)
+      attrs = attributes.getDictionary(getContext());
+    return removedAttr;
+  }
+  /// Overload of removeDiscardableAttr(StringAttr) taking the name as a string.
+  Attribute removeDiscardableAttr(StringRef name) {
+    return removeDiscardableAttr(StringAttr::get(getContext(), name));
+  }
 
   /// Return all of the discardable attributes on this operation.
   ArrayRef<NamedAttribute> getDiscardableAttrs() { return attrs.getValue(); }

diff  --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp
index 7f5c2aaee67382b..04b386b8268e8d4 100644
--- a/mlir/lib/CAPI/IR/IR.cpp
+++ b/mlir/lib/CAPI/IR/IR.cpp
@@ -595,6 +595,59 @@ MlirBlock mlirOperationGetSuccessor(MlirOperation op, intptr_t pos) {
   return wrap(unwrap(op)->getSuccessor(static_cast<unsigned>(pos)));
 }
 
+bool mlirOperationHasInherentAttributeByName(MlirOperation op,
+                                             MlirStringRef name) {
+  // `getInherentAttr` returns std::nullopt only when `name` is not an inherent
+  // attribute of the op; an optional, unset one yields a null Attribute.
+  return unwrap(op)->getInherentAttr(unwrap(name)).has_value();
+}
+
+MlirAttribute mlirOperationGetInherentAttributeByName(MlirOperation op,
+                                                      MlirStringRef name) {
+  std::optional<Attribute> attr = unwrap(op)->getInherentAttr(unwrap(name));
+  if (attr.has_value())
+    return wrap(*attr);
+  // Not an inherent attribute name: report it as a null attribute.
+  return {};
+}
+
+void mlirOperationSetInherentAttributeByName(MlirOperation op,
+                                             MlirStringRef name,
+                                             MlirAttribute attr) {
+  Operation *operation = unwrap(op);
+  operation->setInherentAttr(
+      StringAttr::get(operation->getContext(), unwrap(name)), unwrap(attr));
+}
+
+intptr_t mlirOperationGetNumDiscardableAttributes(MlirOperation op) {
+  return static_cast<intptr_t>(unwrap(op)->getDiscardableAttrs().size());
+}
+
+MlirNamedAttribute mlirOperationGetDiscardableAttribute(MlirOperation op,
+                                                        intptr_t pos) {
+  // `pos` is a signed C index; convert explicitly for ArrayRef indexing.
+  NamedAttribute attr =
+      unwrap(op)->getDiscardableAttrs()[static_cast<size_t>(pos)];
+  return MlirNamedAttribute{wrap(attr.getName()), wrap(attr.getValue())};
+}
+
+MlirAttribute mlirOperationGetDiscardableAttributeByName(MlirOperation op,
+                                                         MlirStringRef name) {
+  return wrap(unwrap(op)->getDiscardableAttr(unwrap(name)));
+}
+
+void mlirOperationSetDiscardableAttributeByName(MlirOperation op,
+                                                MlirStringRef name,
+                                                MlirAttribute attr) {
+  unwrap(op)->setDiscardableAttr(unwrap(name), unwrap(attr));
+}
+
+bool mlirOperationRemoveDiscardableAttributeByName(MlirOperation op,
+                                                   MlirStringRef name) {
+  // A null Attribute from removeDiscardableAttr means nothing was removed.
+  return !!unwrap(op)->removeDiscardableAttr(unwrap(name));
+}
+
 intptr_t mlirOperationGetNumAttributes(MlirOperation op) {
   return static_cast<intptr_t>(unwrap(op)->getAttrs().size());
 }

diff  --git a/mlir/test/CAPI/ir.c b/mlir/test/CAPI/ir.c
index c031e61945d03b6..a181332e219db8a 100644
--- a/mlir/test/CAPI/ir.c
+++ b/mlir/test/CAPI/ir.c
@@ -407,24 +407,23 @@ static void printFirstOfEach(MlirContext ctx, MlirOperation operation) {
   fprintf(stderr, "\n");
   // CHECK: Terminator: func.return
 
-  // Get the attribute by index.
-  MlirNamedAttribute namedAttr0 = mlirOperationGetAttribute(operation, 0);
-  fprintf(stderr, "Get attr 0: ");
-  mlirAttributePrint(namedAttr0.attribute, printToStderr, NULL);
+  // Get the attribute by name.
+  bool hasValueAttr = mlirOperationHasInherentAttributeByName(
+      operation, mlirStringRefCreateFromCString("value"));
+  if (hasValueAttr)
+    // CHECK: Has attr "value"
+    fprintf(stderr, "Has attr \"value\"\n");
+
+  MlirAttribute valueAttr0 = mlirOperationGetInherentAttributeByName(
+      operation, mlirStringRefCreateFromCString("value"));
+  fprintf(stderr, "Get attr \"value\": ");
+  mlirAttributePrint(valueAttr0, printToStderr, NULL);
   fprintf(stderr, "\n");
-  // CHECK: Get attr 0: 0 : index
-
-  // Now re-get the attribute by name.
-  MlirAttribute attr0ByName = mlirOperationGetAttributeByName(
-      operation, mlirIdentifierStr(namedAttr0.name));
-  fprintf(stderr, "Get attr 0 by name: ");
-  mlirAttributePrint(attr0ByName, printToStderr, NULL);
-  fprintf(stderr, "\n");
-  // CHECK: Get attr 0 by name: 0 : index
+  // CHECK: Get attr "value": 0 : index
 
   // Get a non-existing attribute and assert that it is null (sanity).
   fprintf(stderr, "does_not_exist is null: %d\n",
-          mlirAttributeIsNull(mlirOperationGetAttributeByName(
+          mlirAttributeIsNull(mlirOperationGetDiscardableAttributeByName(
               operation, mlirStringRefCreateFromCString("does_not_exist"))));
   // CHECK: does_not_exist is null: 1
 
@@ -443,10 +442,10 @@ static void printFirstOfEach(MlirContext ctx, MlirOperation operation) {
   fprintf(stderr, "\n");
   // CHECK: Result 0 type: index
 
-  // Set a custom attribute.
-  mlirOperationSetAttributeByName(operation,
-                                  mlirStringRefCreateFromCString("custom_attr"),
-                                  mlirBoolAttrGet(ctx, 1));
+  // Set a discardable attribute.
+  mlirOperationSetDiscardableAttributeByName(
+      operation, mlirStringRefCreateFromCString("custom_attr"),
+      mlirBoolAttrGet(ctx, 1));
   fprintf(stderr, "Op with set attr: ");
   mlirOperationPrint(operation, printToStderr, NULL);
   fprintf(stderr, "\n");
@@ -454,13 +453,13 @@ static void printFirstOfEach(MlirContext ctx, MlirOperation operation) {
 
   // Remove the attribute.
   fprintf(stderr, "Remove attr: %d\n",
-          mlirOperationRemoveAttributeByName(
+          mlirOperationRemoveDiscardableAttributeByName(
               operation, mlirStringRefCreateFromCString("custom_attr")));
   fprintf(stderr, "Remove attr again: %d\n",
-          mlirOperationRemoveAttributeByName(
+          mlirOperationRemoveDiscardableAttributeByName(
               operation, mlirStringRefCreateFromCString("custom_attr")));
   fprintf(stderr, "Removed attr is null: %d\n",
-          mlirAttributeIsNull(mlirOperationGetAttributeByName(
+          mlirAttributeIsNull(mlirOperationGetDiscardableAttributeByName(
               operation, mlirStringRefCreateFromCString("custom_attr"))));
   // CHECK: Remove attr: 1
   // CHECK: Remove attr again: 0
@@ -469,7 +468,7 @@ static void printFirstOfEach(MlirContext ctx, MlirOperation operation) {
   // Add a large attribute to verify printing flags.
   int64_t eltsShape[] = {4};
   int32_t eltsData[] = {1, 2, 3, 4};
-  mlirOperationSetAttributeByName(
+  mlirOperationSetDiscardableAttributeByName(
       operation, mlirStringRefCreateFromCString("elts"),
       mlirDenseElementsAttrInt32Get(
           mlirRankedTensorTypeGet(1, eltsShape, mlirIntegerTypeGet(ctx, 32),


        


More information about the Mlir-commits mailing list