[Mlir-commits] [mlir] [MLIR] Introduce new C bindings to differentiate between discardable and inherent attributes (PR #66332)

Mehdi Amini llvmlistbot at llvm.org
Sat Sep 16 20:48:53 PDT 2023


https://github.com/joker-eph updated https://github.com/llvm/llvm-project/pull/66332

>From 721724acabe1b909cca35c9f30eb6f355ad97c86 Mon Sep 17 00:00:00 2001
From: Mehdi Amini <joker.eph at gmail.com>
Date: Wed, 13 Sep 2023 23:41:02 -0700
Subject: [PATCH] [MLIR] Introduce new C bindings to differentiate between
 discardable and inherent attributes

This is part of the transition toward properly splitting the two groups.
This only introduces new C APIs, the Python bindings are unaffected.
No API is removed.
---
 mlir/include/mlir-c/IR.h         | 52 ++++++++++++++++++++++++++++++++
 mlir/include/mlir/IR/Operation.h | 17 +++++++++++
 mlir/lib/CAPI/IR/IR.cpp          | 47 +++++++++++++++++++++++++++++
 mlir/test/CAPI/ir.c              | 43 +++++++++++++-------------
 4 files changed, 137 insertions(+), 22 deletions(-)

diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h
index b5c6a3094bc67df..fb6ee2e8423c48f 100644
--- a/mlir/include/mlir-c/IR.h
+++ b/mlir/include/mlir-c/IR.h
@@ -552,25 +552,77 @@ MLIR_CAPI_EXPORTED intptr_t mlirOperationGetNumSuccessors(MlirOperation op);
 MLIR_CAPI_EXPORTED MlirBlock mlirOperationGetSuccessor(MlirOperation op,
                                                        intptr_t pos);
 
+/// Returns true if this operation defines an inherent attribute with this name.
+/// Note: the attribute can be optional, so
+/// `mlirOperationGetInherentAttributeByName` can still return a null attribute.
+MLIR_CAPI_EXPORTED bool
+mlirOperationHasInherentAttributeByName(MlirOperation op, MlirStringRef name);
+
+/// Returns an inherent attribute attached to the operation given its name.
+MLIR_CAPI_EXPORTED MlirAttribute
+mlirOperationGetInherentAttributeByName(MlirOperation op, MlirStringRef name);
+
+/// Sets an inherent attribute by name, replacing the existing if it exists.
+/// This has no effect if "name" does not match an inherent attribute.
+MLIR_CAPI_EXPORTED void
+mlirOperationSetInherentAttributeByName(MlirOperation op, MlirStringRef name,
+                                        MlirAttribute attr);
+
+/// Returns the number of discardable attributes attached to the operation.
+MLIR_CAPI_EXPORTED intptr_t
+mlirOperationGetNumDiscardableAttributes(MlirOperation op);
+
+/// Return `pos`-th discardable attribute of the operation.
+MLIR_CAPI_EXPORTED MlirNamedAttribute
+mlirOperationGetDiscardableAttribute(MlirOperation op, intptr_t pos);
+
+/// Returns a discardable attribute attached to the operation given its name.
+MLIR_CAPI_EXPORTED MlirAttribute mlirOperationGetDiscardableAttributeByName(
+    MlirOperation op, MlirStringRef name);
+
+/// Sets a discardable attribute by name, replacing the existing one if any or
+/// adding a new one otherwise. The new `attr` Attribute is not allowed to be
+/// null; use `mlirOperationRemoveDiscardableAttributeByName` to remove an
+/// Attribute instead.
+MLIR_CAPI_EXPORTED void
+mlirOperationSetDiscardableAttributeByName(MlirOperation op, MlirStringRef name,
+                                           MlirAttribute attr);
+
+/// Removes a discardable attribute by name. Returns false if the attribute was
+/// not found and true if removed.
+MLIR_CAPI_EXPORTED bool
+mlirOperationRemoveDiscardableAttributeByName(MlirOperation op,
+                                              MlirStringRef name);
+
 /// Returns the number of attributes attached to the operation.
+/// Deprecated, please use `mlirOperationGetNumDiscardableAttributes`; there is
+/// no equivalent count for inherent attributes, which are accessed by name.
 MLIR_CAPI_EXPORTED intptr_t mlirOperationGetNumAttributes(MlirOperation op);
 
 /// Return `pos`-th attribute of the operation.
+/// Deprecated, please use `mlirOperationGetDiscardableAttribute`; inherent
+/// attributes are accessed with `mlirOperationGetInherentAttributeByName`.
 MLIR_CAPI_EXPORTED MlirNamedAttribute
 mlirOperationGetAttribute(MlirOperation op, intptr_t pos);
 
 /// Returns an attribute attached to the operation given its name.
+/// Deprecated, please use `mlirOperationGetInherentAttributeByName` or
+/// `mlirOperationGetDiscardableAttributeByName`.
 MLIR_CAPI_EXPORTED MlirAttribute
 mlirOperationGetAttributeByName(MlirOperation op, MlirStringRef name);
 
 /// Sets an attribute by name, replacing the existing if it exists or
 /// adding a new one otherwise.
+/// Deprecated, please use `mlirOperationSetInherentAttributeByName` or
+/// `mlirOperationSetDiscardableAttributeByName`.
 MLIR_CAPI_EXPORTED void mlirOperationSetAttributeByName(MlirOperation op,
                                                         MlirStringRef name,
                                                         MlirAttribute attr);
 
 /// Removes an attribute by name. Returns false if the attribute was not found
 /// and true if removed.
+/// Deprecated, please use `mlirOperationRemoveDiscardableAttributeByName`;
+/// no removal function is provided for inherent attributes.
 MLIR_CAPI_EXPORTED bool mlirOperationRemoveAttributeByName(MlirOperation op,
                                                            MlirStringRef name);
 
diff --git a/mlir/include/mlir/IR/Operation.h b/mlir/include/mlir/IR/Operation.h
index b815eaf8899d6fc..35e9d31a6323173 100644
--- a/mlir/include/mlir/IR/Operation.h
+++ b/mlir/include/mlir/IR/Operation.h
@@ -457,6 +457,23 @@ class alignas(8) Operation final
     if (attributes.set(name, value) != value)
       attrs = attributes.getDictionary(getContext());
   }
+  void setDiscardableAttr(StringRef name, Attribute value) {
+    setDiscardableAttr(StringAttr::get(getContext(), name), value);
+  }
+
+  /// Remove the discardable attribute with the specified name if it exists.
+  /// Return the attribute that was erased, or nullptr if there was no attribute
+  /// with such name.
+  Attribute removeDiscardableAttr(StringAttr name) {
+    NamedAttrList attributes(attrs);
+    Attribute removedAttr = attributes.erase(name);
+    if (removedAttr)
+      attrs = attributes.getDictionary(getContext());
+    return removedAttr;
+  }
+  Attribute removeDiscardableAttr(StringRef name) {
+    return removeDiscardableAttr(StringAttr::get(getContext(), name));
+  }
 
   /// Return all of the discardable attributes on this operation.
   ArrayRef<NamedAttribute> getDiscardableAttrs() { return attrs.getValue(); }
diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp
index ef234a912490eea..377eff45877c8bb 100644
--- a/mlir/lib/CAPI/IR/IR.cpp
+++ b/mlir/lib/CAPI/IR/IR.cpp
@@ -550,6 +550,53 @@ MlirBlock mlirOperationGetSuccessor(MlirOperation op, intptr_t pos) {
   return wrap(unwrap(op)->getSuccessor(static_cast<unsigned>(pos)));
 }
 
+MLIR_CAPI_EXPORTED bool
+mlirOperationHasInherentAttributeByName(MlirOperation op, MlirStringRef name) {
+  std::optional<Attribute> attr = unwrap(op)->getInherentAttr(unwrap(name));
+  return attr.has_value();
+}
+
+MlirAttribute mlirOperationGetInherentAttributeByName(MlirOperation op,
+                                                      MlirStringRef name) {
+  std::optional<Attribute> attr = unwrap(op)->getInherentAttr(unwrap(name));
+  if (attr.has_value())
+    return wrap(*attr);
+  return {};
+}
+
+void mlirOperationSetInherentAttributeByName(MlirOperation op,
+                                             MlirStringRef name,
+                                             MlirAttribute attr) {
+  unwrap(op)->setInherentAttr(
+      StringAttr::get(unwrap(op)->getContext(), unwrap(name)), unwrap(attr));
+}
+
+intptr_t mlirOperationGetNumDiscardableAttributes(MlirOperation op) {
+  return static_cast<intptr_t>(unwrap(op)->getDiscardableAttrs().size());
+}
+
+MlirNamedAttribute mlirOperationGetDiscardableAttribute(MlirOperation op,
+                                                        intptr_t pos) {
+  NamedAttribute attr = unwrap(op)->getDiscardableAttrs()[pos];
+  return MlirNamedAttribute{wrap(attr.getName()), wrap(attr.getValue())};
+}
+
+MlirAttribute mlirOperationGetDiscardableAttributeByName(MlirOperation op,
+                                                         MlirStringRef name) {
+  return wrap(unwrap(op)->getDiscardableAttr(unwrap(name)));
+}
+
+void mlirOperationSetDiscardableAttributeByName(MlirOperation op,
+                                                MlirStringRef name,
+                                                MlirAttribute attr) {
+  unwrap(op)->setDiscardableAttr(unwrap(name), unwrap(attr));
+}
+
+bool mlirOperationRemoveDiscardableAttributeByName(MlirOperation op,
+                                                   MlirStringRef name) {
+  return !!unwrap(op)->removeDiscardableAttr(unwrap(name));
+}
+
 intptr_t mlirOperationGetNumAttributes(MlirOperation op) {
   return static_cast<intptr_t>(unwrap(op)->getAttrs().size());
 }
diff --git a/mlir/test/CAPI/ir.c b/mlir/test/CAPI/ir.c
index 5725d05a3e132f7..634f8436d1398fc 100644
--- a/mlir/test/CAPI/ir.c
+++ b/mlir/test/CAPI/ir.c
@@ -407,24 +407,23 @@ static void printFirstOfEach(MlirContext ctx, MlirOperation operation) {
   fprintf(stderr, "\n");
   // CHECK: Terminator: func.return
 
-  // Get the attribute by index.
-  MlirNamedAttribute namedAttr0 = mlirOperationGetAttribute(operation, 0);
-  fprintf(stderr, "Get attr 0: ");
-  mlirAttributePrint(namedAttr0.attribute, printToStderr, NULL);
+  // Get the attribute by name.
+  bool hasValueAttr = mlirOperationHasInherentAttributeByName(
+      operation, mlirStringRefCreateFromCString("value"));
+  if (hasValueAttr)
+    // CHECK: Has attr "value"
+    fprintf(stderr, "Has attr \"value\"");
+
+  MlirAttribute valueAttr0 = mlirOperationGetInherentAttributeByName(
+      operation, mlirStringRefCreateFromCString("value"));
+  fprintf(stderr, "Get attr \"value\": ");
+  mlirAttributePrint(valueAttr0, printToStderr, NULL);
   fprintf(stderr, "\n");
-  // CHECK: Get attr 0: 0 : index
-
-  // Now re-get the attribute by name.
-  MlirAttribute attr0ByName = mlirOperationGetAttributeByName(
-      operation, mlirIdentifierStr(namedAttr0.name));
-  fprintf(stderr, "Get attr 0 by name: ");
-  mlirAttributePrint(attr0ByName, printToStderr, NULL);
-  fprintf(stderr, "\n");
-  // CHECK: Get attr 0 by name: 0 : index
+  // CHECK: Get attr "value": 0 : index
 
   // Get a non-existing attribute and assert that it is null (sanity).
   fprintf(stderr, "does_not_exist is null: %d\n",
-          mlirAttributeIsNull(mlirOperationGetAttributeByName(
+          mlirAttributeIsNull(mlirOperationGetDiscardableAttributeByName(
               operation, mlirStringRefCreateFromCString("does_not_exist"))));
   // CHECK: does_not_exist is null: 1
 
@@ -443,10 +442,10 @@ static void printFirstOfEach(MlirContext ctx, MlirOperation operation) {
   fprintf(stderr, "\n");
   // CHECK: Result 0 type: index
 
-  // Set a custom attribute.
-  mlirOperationSetAttributeByName(operation,
-                                  mlirStringRefCreateFromCString("custom_attr"),
-                                  mlirBoolAttrGet(ctx, 1));
+  // Set a discardable attribute.
+  mlirOperationSetDiscardableAttributeByName(
+      operation, mlirStringRefCreateFromCString("custom_attr"),
+      mlirBoolAttrGet(ctx, 1));
   fprintf(stderr, "Op with set attr: ");
   mlirOperationPrint(operation, printToStderr, NULL);
   fprintf(stderr, "\n");
@@ -454,13 +453,13 @@ static void printFirstOfEach(MlirContext ctx, MlirOperation operation) {
 
   // Remove the attribute.
   fprintf(stderr, "Remove attr: %d\n",
-          mlirOperationRemoveAttributeByName(
+          mlirOperationRemoveDiscardableAttributeByName(
               operation, mlirStringRefCreateFromCString("custom_attr")));
   fprintf(stderr, "Remove attr again: %d\n",
-          mlirOperationRemoveAttributeByName(
+          mlirOperationRemoveDiscardableAttributeByName(
               operation, mlirStringRefCreateFromCString("custom_attr")));
   fprintf(stderr, "Removed attr is null: %d\n",
-          mlirAttributeIsNull(mlirOperationGetAttributeByName(
+          mlirAttributeIsNull(mlirOperationGetDiscardableAttributeByName(
               operation, mlirStringRefCreateFromCString("custom_attr"))));
   // CHECK: Remove attr: 1
   // CHECK: Remove attr again: 0
@@ -469,7 +468,7 @@ static void printFirstOfEach(MlirContext ctx, MlirOperation operation) {
   // Add a large attribute to verify printing flags.
   int64_t eltsShape[] = {4};
   int32_t eltsData[] = {1, 2, 3, 4};
-  mlirOperationSetAttributeByName(
+  mlirOperationSetDiscardableAttributeByName(
       operation, mlirStringRefCreateFromCString("elts"),
       mlirDenseElementsAttrInt32Get(
           mlirRankedTensorTypeGet(1, eltsShape, mlirIntegerTypeGet(ctx, 32),



More information about the Mlir-commits mailing list