[Mlir-commits] [mlir] 6771b98 - [mlir][CAPI] Add mlirAttributeGetType function.
Stella Laurenzo
llvmlistbot at llvm.org
Thu Oct 15 18:34:55 PDT 2020
Author: Stella Laurenzo
Date: 2020-10-15T18:33:50-07:00
New Revision: 6771b98c4e4d5c0bd0a78a876bd212a76ec80a24
URL: https://github.com/llvm/llvm-project/commit/6771b98c4e4d5c0bd0a78a876bd212a76ec80a24
DIFF: https://github.com/llvm/llvm-project/commit/6771b98c4e4d5c0bd0a78a876bd212a76ec80a24.diff
LOG: [mlir][CAPI] Add mlirAttributeGetType function.
* Also fixes the const-ness of the element-pointer parameters of the various DenseElementsAttr construction functions.
* Both issues were identified when trying to use the DenseElementsAttr functions.
Differential Revision: https://reviews.llvm.org/D89517
Added:
Modified:
mlir/include/mlir-c/IR.h
mlir/include/mlir-c/StandardAttributes.h
mlir/lib/CAPI/IR/IR.cpp
mlir/lib/CAPI/IR/StandardAttributes.cpp
mlir/test/CAPI/ir.c
Removed:
################################################################################
diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h
index b2a17869e2b3..a00b96119298 100644
--- a/mlir/include/mlir-c/IR.h
+++ b/mlir/include/mlir-c/IR.h
@@ -440,6 +440,9 @@ MlirAttribute mlirAttributeParseGet(MlirContext context, const char *attr);
/** Gets the context that an attribute was created with. */
MlirContext mlirAttributeGetContext(MlirAttribute attribute);
+/** Gets the type of this attribute. */
+MlirType mlirAttributeGetType(MlirAttribute attribute);
+
/** Checks whether an attribute is null. */
static inline int mlirAttributeIsNull(MlirAttribute attr) { return !attr.ptr; }
diff --git a/mlir/include/mlir-c/StandardAttributes.h b/mlir/include/mlir-c/StandardAttributes.h
index 2fc2ecc9ee1d..6227c8ae89ed 100644
--- a/mlir/include/mlir-c/StandardAttributes.h
+++ b/mlir/include/mlir-c/StandardAttributes.h
@@ -319,25 +319,26 @@ MlirAttribute mlirDenseElementsAttrDoubleSplatGet(MlirType shapedType,
* of a specific type. Expects the element type of the shaped type to match the
* data element type. */
MlirAttribute mlirDenseElementsAttrBoolGet(MlirType shapedType,
- intptr_t numElements, int *elements);
+ intptr_t numElements,
+ const int *elements);
MlirAttribute mlirDenseElementsAttrUInt32Get(MlirType shapedType,
intptr_t numElements,
- uint32_t *elements);
+ const uint32_t *elements);
MlirAttribute mlirDenseElementsAttrInt32Get(MlirType shapedType,
intptr_t numElements,
- int32_t *elements);
+ const int32_t *elements);
MlirAttribute mlirDenseElementsAttrUInt64Get(MlirType shapedType,
intptr_t numElements,
- uint64_t *elements);
+ const uint64_t *elements);
MlirAttribute mlirDenseElementsAttrInt64Get(MlirType shapedType,
intptr_t numElements,
- int64_t *elements);
+ const int64_t *elements);
MlirAttribute mlirDenseElementsAttrFloatGet(MlirType shapedType,
intptr_t numElements,
- float *elements);
+ const float *elements);
MlirAttribute mlirDenseElementsAttrDoubleGet(MlirType shapedType,
intptr_t numElements,
- double *elements);
+ const double *elements);
/** Creates a dense elements attribute with the given shaped type from string
* elements. The strings need not be null-terminated and their lengths are
diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp
index 45cd009bddc0..8226e4e552f7 100644
--- a/mlir/lib/CAPI/IR/IR.cpp
+++ b/mlir/lib/CAPI/IR/IR.cpp
@@ -457,6 +457,10 @@ MlirContext mlirAttributeGetContext(MlirAttribute attribute) {
return wrap(unwrap(attribute).getContext());
}
+MlirType mlirAttributeGetType(MlirAttribute attribute) {
+ return wrap(unwrap(attribute).getType());
+}
+
int mlirAttributeEqual(MlirAttribute a1, MlirAttribute a2) {
return unwrap(a1) == unwrap(a2);
}
diff --git a/mlir/lib/CAPI/IR/StandardAttributes.cpp b/mlir/lib/CAPI/IR/StandardAttributes.cpp
index 1277d2b041ac..7443816b76aa 100644
--- a/mlir/lib/CAPI/IR/StandardAttributes.cpp
+++ b/mlir/lib/CAPI/IR/StandardAttributes.cpp
@@ -374,7 +374,7 @@ MlirAttribute mlirDenseElementsAttrDoubleSplatGet(MlirType shapedType,
MlirAttribute mlirDenseElementsAttrBoolGet(MlirType shapedType,
intptr_t numElements,
- int *elements) {
+ const int *elements) {
SmallVector<bool, 8> values(elements, elements + numElements);
return wrap(
DenseElementsAttr::get(unwrap(shapedType).cast<ShapedType>(), values));
@@ -383,7 +383,8 @@ MlirAttribute mlirDenseElementsAttrBoolGet(MlirType shapedType,
/// Creates a dense attribute with elements of the type deduced by templates.
template <typename T>
static MlirAttribute getDenseAttribute(MlirType shapedType,
- intptr_t numElements, T *elements) {
+ intptr_t numElements,
+ const T *elements) {
return wrap(
DenseElementsAttr::get(unwrap(shapedType).cast<ShapedType>(),
llvm::makeArrayRef(elements, numElements)));
@@ -391,32 +392,32 @@ static MlirAttribute getDenseAttribute(MlirType shapedType,
MlirAttribute mlirDenseElementsAttrUInt32Get(MlirType shapedType,
intptr_t numElements,
- uint32_t *elements) {
+ const uint32_t *elements) {
return getDenseAttribute(shapedType, numElements, elements);
}
MlirAttribute mlirDenseElementsAttrInt32Get(MlirType shapedType,
intptr_t numElements,
- int32_t *elements) {
+ const int32_t *elements) {
return getDenseAttribute(shapedType, numElements, elements);
}
MlirAttribute mlirDenseElementsAttrUInt64Get(MlirType shapedType,
intptr_t numElements,
- uint64_t *elements) {
+ const uint64_t *elements) {
return getDenseAttribute(shapedType, numElements, elements);
}
MlirAttribute mlirDenseElementsAttrInt64Get(MlirType shapedType,
intptr_t numElements,
- int64_t *elements) {
+ const int64_t *elements) {
return getDenseAttribute(shapedType, numElements, elements);
}
MlirAttribute mlirDenseElementsAttrFloatGet(MlirType shapedType,
intptr_t numElements,
- float *elements) {
+ const float *elements) {
return getDenseAttribute(shapedType, numElements, elements);
}
MlirAttribute mlirDenseElementsAttrDoubleGet(MlirType shapedType,
intptr_t numElements,
- double *elements) {
+ const double *elements) {
return getDenseAttribute(shapedType, numElements, elements);
}
diff --git a/mlir/test/CAPI/ir.c b/mlir/test/CAPI/ir.c
index 8eab4ebb3858..0c427c77bdb8 100644
--- a/mlir/test/CAPI/ir.c
+++ b/mlir/test/CAPI/ir.c
@@ -516,6 +516,10 @@ int printStandardAttributes(MlirContext ctx) {
return 1;
mlirAttributeDump(floating);
+ // Exercise mlirAttributeGetType() just for the first one.
+ MlirType floatingType = mlirAttributeGetType(floating);
+ mlirTypeDump(floatingType);
+
MlirAttribute integer = mlirIntegerAttrGet(mlirIntegerTypeGet(ctx, 32), 42);
if (!mlirAttributeIsAInteger(integer) ||
mlirIntegerAttrGetValueInt(integer) != 42)
@@ -990,6 +994,7 @@ int main() {
// clang-format off
// CHECK-LABEL: @attrs
// CHECK: 2.000000e+00 : f64
+ // CHECK: f64
// CHECK: 42 : i32
// CHECK: true
// CHECK: #std.abc
More information about the Mlir-commits
mailing list