[Mlir-commits] [mlir] [MLIR] Introduce new C bindings to differentiate between discardable and inherent attributes (PR #66332)
Mehdi Amini
llvmlistbot at llvm.org
Sat Sep 16 20:48:53 PDT 2023
https://github.com/joker-eph updated https://github.com/llvm/llvm-project/pull/66332
>From 721724acabe1b909cca35c9f30eb6f355ad97c86 Mon Sep 17 00:00:00 2001
From: Mehdi Amini <joker.eph at gmail.com>
Date: Wed, 13 Sep 2023 23:41:02 -0700
Subject: [PATCH] [MLIR] Introduce new C bindings to differentiate between
discardable and inherent attributes
This is part of the transition toward properly splitting the two groups.
This only introduces new C APIs, the Python bindings are unaffected.
No API is removed.
---
mlir/include/mlir-c/IR.h | 52 ++++++++++++++++++++++++++++++++
mlir/include/mlir/IR/Operation.h | 17 +++++++++++
mlir/lib/CAPI/IR/IR.cpp | 47 +++++++++++++++++++++++++++++
mlir/test/CAPI/ir.c | 43 +++++++++++++-------------
4 files changed, 137 insertions(+), 22 deletions(-)
diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h
index b5c6a3094bc67df..fb6ee2e8423c48f 100644
--- a/mlir/include/mlir-c/IR.h
+++ b/mlir/include/mlir-c/IR.h
@@ -552,25 +552,77 @@ MLIR_CAPI_EXPORTED intptr_t mlirOperationGetNumSuccessors(MlirOperation op);
MLIR_CAPI_EXPORTED MlirBlock mlirOperationGetSuccessor(MlirOperation op,
intptr_t pos);
+/// Returns true if this operation defines an inherent attribute with this name.
+/// Note: the attribute can be optional, so
+/// `mlirOperationGetInherentAttributeByName` can still return a null attribute.
+MLIR_CAPI_EXPORTED bool
+mlirOperationHasInherentAttributeByName(MlirOperation op, MlirStringRef name);
+
+/// Returns an inherent attribute attached to the operation given its name.
+MLIR_CAPI_EXPORTED MlirAttribute
+mlirOperationGetInherentAttributeByName(MlirOperation op, MlirStringRef name);
+
+/// Sets an inherent attribute by name, replacing the existing if it exists.
+/// This has no effect if "name" does not match an inherent attribute.
+MLIR_CAPI_EXPORTED void
+mlirOperationSetInherentAttributeByName(MlirOperation op, MlirStringRef name,
+ MlirAttribute attr);
+
+/// Returns the number of discardable attributes attached to the operation.
+MLIR_CAPI_EXPORTED intptr_t
+mlirOperationGetNumDiscardableAttributes(MlirOperation op);
+
+/// Return `pos`-th discardable attribute of the operation.
+MLIR_CAPI_EXPORTED MlirNamedAttribute
+mlirOperationGetDiscardableAttribute(MlirOperation op, intptr_t pos);
+
+/// Returns a discardable attribute attached to the operation given its name.
+MLIR_CAPI_EXPORTED MlirAttribute mlirOperationGetDiscardableAttributeByName(
+ MlirOperation op, MlirStringRef name);
+
+/// Sets a discardable attribute by name, replacing the existing if it exists or
+/// adding a new one otherwise. The new `attr` Attribute is not allowed to be
+/// null, use `mlirOperationRemoveDiscardableAttributeByName` to remove an
+/// Attribute instead.
+MLIR_CAPI_EXPORTED void
+mlirOperationSetDiscardableAttributeByName(MlirOperation op, MlirStringRef name,
+ MlirAttribute attr);
+
+/// Removes a discardable attribute by name. Returns false if the attribute was
+/// not found and true if removed.
+MLIR_CAPI_EXPORTED bool
+mlirOperationRemoveDiscardableAttributeByName(MlirOperation op,
+ MlirStringRef name);
+
/// Returns the number of attributes attached to the operation.
+/// Deprecated, please use `mlirOperationGetNumDiscardableAttributes`; inherent
+/// attributes are accessed by name and have no counting equivalent.
MLIR_CAPI_EXPORTED intptr_t mlirOperationGetNumAttributes(MlirOperation op);
/// Return `pos`-th attribute of the operation.
+/// Deprecated, please use `mlirOperationGetDiscardableAttribute`; inherent
+/// attributes are retrieved with `mlirOperationGetInherentAttributeByName`.
MLIR_CAPI_EXPORTED MlirNamedAttribute
mlirOperationGetAttribute(MlirOperation op, intptr_t pos);
/// Returns an attribute attached to the operation given its name.
+/// Deprecated, please use `mlirOperationGetInherentAttributeByName` or
+/// `mlirOperationGetDiscardableAttributeByName`.
MLIR_CAPI_EXPORTED MlirAttribute
mlirOperationGetAttributeByName(MlirOperation op, MlirStringRef name);
/// Sets an attribute by name, replacing the existing if it exists or
/// adding a new one otherwise.
+/// Deprecated, please use `mlirOperationSetInherentAttributeByName` or
+/// `mlirOperationSetDiscardableAttributeByName`.
MLIR_CAPI_EXPORTED void mlirOperationSetAttributeByName(MlirOperation op,
MlirStringRef name,
MlirAttribute attr);
/// Removes an attribute by name. Returns false if the attribute was not found
/// and true if removed.
+/// Deprecated, please use `mlirOperationRemoveDiscardableAttributeByName`;
+/// there is no removal API for inherent attributes.
MLIR_CAPI_EXPORTED bool mlirOperationRemoveAttributeByName(MlirOperation op,
MlirStringRef name);
diff --git a/mlir/include/mlir/IR/Operation.h b/mlir/include/mlir/IR/Operation.h
index b815eaf8899d6fc..35e9d31a6323173 100644
--- a/mlir/include/mlir/IR/Operation.h
+++ b/mlir/include/mlir/IR/Operation.h
@@ -457,6 +457,23 @@ class alignas(8) Operation final
if (attributes.set(name, value) != value)
attrs = attributes.getDictionary(getContext());
}
+ void setDiscardableAttr(StringRef name, Attribute value) {
+ setDiscardableAttr(StringAttr::get(getContext(), name), value);
+ }
+
+ /// Remove the discardable attribute with the specified name if it exists.
+ /// Return the attribute that was erased, or nullptr if there was no attribute
+ /// with such name.
+ Attribute removeDiscardableAttr(StringAttr name) {
+ NamedAttrList attributes(attrs);
+ Attribute removedAttr = attributes.erase(name);
+ if (removedAttr)
+ attrs = attributes.getDictionary(getContext());
+ return removedAttr;
+ }
+ Attribute removeDiscardableAttr(StringRef name) {
+ return removeDiscardableAttr(StringAttr::get(getContext(), name));
+ }
/// Return all of the discardable attributes on this operation.
ArrayRef<NamedAttribute> getDiscardableAttrs() { return attrs.getValue(); }
diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp
index ef234a912490eea..377eff45877c8bb 100644
--- a/mlir/lib/CAPI/IR/IR.cpp
+++ b/mlir/lib/CAPI/IR/IR.cpp
@@ -550,6 +550,53 @@ MlirBlock mlirOperationGetSuccessor(MlirOperation op, intptr_t pos) {
return wrap(unwrap(op)->getSuccessor(static_cast<unsigned>(pos)));
}
+MLIR_CAPI_EXPORTED bool
+mlirOperationHasInherentAttributeByName(MlirOperation op, MlirStringRef name) {
+ std::optional<Attribute> attr = unwrap(op)->getInherentAttr(unwrap(name));
+ return attr.has_value();
+}
+
+MlirAttribute mlirOperationGetInherentAttributeByName(MlirOperation op,
+ MlirStringRef name) {
+ std::optional<Attribute> attr = unwrap(op)->getInherentAttr(unwrap(name));
+ if (attr.has_value())
+ return wrap(*attr);
+ return {};
+}
+
+void mlirOperationSetInherentAttributeByName(MlirOperation op,
+ MlirStringRef name,
+ MlirAttribute attr) {
+ unwrap(op)->setInherentAttr(
+ StringAttr::get(unwrap(op)->getContext(), unwrap(name)), unwrap(attr));
+}
+
+intptr_t mlirOperationGetNumDiscardableAttributes(MlirOperation op) {
+ return static_cast<intptr_t>(unwrap(op)->getDiscardableAttrs().size());
+}
+
+MlirNamedAttribute mlirOperationGetDiscardableAttribute(MlirOperation op,
+ intptr_t pos) {
+ NamedAttribute attr = unwrap(op)->getDiscardableAttrs()[pos];
+ return MlirNamedAttribute{wrap(attr.getName()), wrap(attr.getValue())};
+}
+
+MlirAttribute mlirOperationGetDiscardableAttributeByName(MlirOperation op,
+ MlirStringRef name) {
+ return wrap(unwrap(op)->getDiscardableAttr(unwrap(name)));
+}
+
+void mlirOperationSetDiscardableAttributeByName(MlirOperation op,
+ MlirStringRef name,
+ MlirAttribute attr) {
+ unwrap(op)->setDiscardableAttr(unwrap(name), unwrap(attr));
+}
+
+bool mlirOperationRemoveDiscardableAttributeByName(MlirOperation op,
+ MlirStringRef name) {
+ return !!unwrap(op)->removeDiscardableAttr(unwrap(name));
+}
+
intptr_t mlirOperationGetNumAttributes(MlirOperation op) {
return static_cast<intptr_t>(unwrap(op)->getAttrs().size());
}
diff --git a/mlir/test/CAPI/ir.c b/mlir/test/CAPI/ir.c
index 5725d05a3e132f7..634f8436d1398fc 100644
--- a/mlir/test/CAPI/ir.c
+++ b/mlir/test/CAPI/ir.c
@@ -407,24 +407,23 @@ static void printFirstOfEach(MlirContext ctx, MlirOperation operation) {
fprintf(stderr, "\n");
// CHECK: Terminator: func.return
- // Get the attribute by index.
- MlirNamedAttribute namedAttr0 = mlirOperationGetAttribute(operation, 0);
- fprintf(stderr, "Get attr 0: ");
- mlirAttributePrint(namedAttr0.attribute, printToStderr, NULL);
+ // Get the attribute by name.
+ bool hasValueAttr = mlirOperationHasInherentAttributeByName(
+ operation, mlirStringRefCreateFromCString("value"));
+ if (hasValueAttr)
+ // CHECK: Has attr "value"
+ fprintf(stderr, "Has attr \"value\"");
+
+ MlirAttribute valueAttr0 = mlirOperationGetInherentAttributeByName(
+ operation, mlirStringRefCreateFromCString("value"));
+ fprintf(stderr, "Get attr \"value\": ");
+ mlirAttributePrint(valueAttr0, printToStderr, NULL);
fprintf(stderr, "\n");
- // CHECK: Get attr 0: 0 : index
-
- // Now re-get the attribute by name.
- MlirAttribute attr0ByName = mlirOperationGetAttributeByName(
- operation, mlirIdentifierStr(namedAttr0.name));
- fprintf(stderr, "Get attr 0 by name: ");
- mlirAttributePrint(attr0ByName, printToStderr, NULL);
- fprintf(stderr, "\n");
- // CHECK: Get attr 0 by name: 0 : index
+ // CHECK: Get attr "value": 0 : index
// Get a non-existing attribute and assert that it is null (sanity).
fprintf(stderr, "does_not_exist is null: %d\n",
- mlirAttributeIsNull(mlirOperationGetAttributeByName(
+ mlirAttributeIsNull(mlirOperationGetDiscardableAttributeByName(
operation, mlirStringRefCreateFromCString("does_not_exist"))));
// CHECK: does_not_exist is null: 1
@@ -443,10 +442,10 @@ static void printFirstOfEach(MlirContext ctx, MlirOperation operation) {
fprintf(stderr, "\n");
// CHECK: Result 0 type: index
- // Set a custom attribute.
- mlirOperationSetAttributeByName(operation,
- mlirStringRefCreateFromCString("custom_attr"),
- mlirBoolAttrGet(ctx, 1));
+ // Set a discardable attribute.
+ mlirOperationSetDiscardableAttributeByName(
+ operation, mlirStringRefCreateFromCString("custom_attr"),
+ mlirBoolAttrGet(ctx, 1));
fprintf(stderr, "Op with set attr: ");
mlirOperationPrint(operation, printToStderr, NULL);
fprintf(stderr, "\n");
@@ -454,13 +453,13 @@ static void printFirstOfEach(MlirContext ctx, MlirOperation operation) {
// Remove the attribute.
fprintf(stderr, "Remove attr: %d\n",
- mlirOperationRemoveAttributeByName(
+ mlirOperationRemoveDiscardableAttributeByName(
operation, mlirStringRefCreateFromCString("custom_attr")));
fprintf(stderr, "Remove attr again: %d\n",
- mlirOperationRemoveAttributeByName(
+ mlirOperationRemoveDiscardableAttributeByName(
operation, mlirStringRefCreateFromCString("custom_attr")));
fprintf(stderr, "Removed attr is null: %d\n",
- mlirAttributeIsNull(mlirOperationGetAttributeByName(
+ mlirAttributeIsNull(mlirOperationGetDiscardableAttributeByName(
operation, mlirStringRefCreateFromCString("custom_attr"))));
// CHECK: Remove attr: 1
// CHECK: Remove attr again: 0
@@ -469,7 +468,7 @@ static void printFirstOfEach(MlirContext ctx, MlirOperation operation) {
// Add a large attribute to verify printing flags.
int64_t eltsShape[] = {4};
int32_t eltsData[] = {1, 2, 3, 4};
- mlirOperationSetAttributeByName(
+ mlirOperationSetDiscardableAttributeByName(
operation, mlirStringRefCreateFromCString("elts"),
mlirDenseElementsAttrInt32Get(
mlirRankedTensorTypeGet(1, eltsShape, mlirIntegerTypeGet(ctx, 32),
More information about the Mlir-commits
mailing list