[Mlir-commits] [mlir] f5c7c03 - [mlir] Add C API for IntegerSet
Alex Zinenko
llvmlistbot at llvm.org
Mon Jan 25 11:16:30 PST 2021
Author: Alex Zinenko
Date: 2021-01-25T20:16:22+01:00
New Revision: f5c7c031e2493168b3c2cfea3219e2131cc01483
URL: https://github.com/llvm/llvm-project/commit/f5c7c031e2493168b3c2cfea3219e2131cc01483
DIFF: https://github.com/llvm/llvm-project/commit/f5c7c031e2493168b3c2cfea3219e2131cc01483.diff
LOG: [mlir] Add C API for IntegerSet
Depends On D95357
Reviewed By: stellaraccident
Differential Revision: https://reviews.llvm.org/D95368
Added:
mlir/include/mlir-c/IntegerSet.h
mlir/include/mlir/CAPI/IntegerSet.h
mlir/lib/CAPI/IR/IntegerSet.cpp
Modified:
mlir/include/mlir/IR/IntegerSet.h
mlir/lib/CAPI/IR/CMakeLists.txt
mlir/test/CAPI/ir.c
Removed:
################################################################################
diff --git a/mlir/include/mlir-c/IntegerSet.h b/mlir/include/mlir-c/IntegerSet.h
new file mode 100644
index 000000000000..058be414d2ee
--- /dev/null
+++ b/mlir/include/mlir-c/IntegerSet.h
@@ -0,0 +1,131 @@
+//===-- mlir-c/IntegerSet.h - C API for MLIR Integer Sets ---------*- C -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM
+// Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_C_INTEGERSET_H
+#define MLIR_C_INTEGERSET_H
+
+#include "mlir-c/AffineExpr.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+//===----------------------------------------------------------------------===//
+// Opaque type declarations.
+//
+// Types are exposed to C bindings as structs containing opaque pointers. They
+// are not supposed to be inspected from C. This allows the underlying
+// representation to change without affecting the API users. The use of structs
+// instead of typedefs enables some type safety as structs are not implicitly
+// convertible to each other.
+//
+// Instances of these types may or may not own the underlying object. The
+// ownership semantics is defined by how an instance of the type was obtained.
+//===----------------------------------------------------------------------===//
+
+#define DEFINE_C_API_STRUCT(name, storage) \
+ struct name { \
+ storage *ptr; \
+ }; \
+ typedef struct name name
+
+DEFINE_C_API_STRUCT(MlirIntegerSet, const void);
+
+#undef DEFINE_C_API_STRUCT
+
+/// Gets the context in which the given integer set lives.
+MLIR_CAPI_EXPORTED MlirContext mlirIntegerSetGetContext(MlirIntegerSet set);
+
+/// Returns true if `set` does not refer to any integer set, i.e. its opaque
+/// pointer is null.
+static inline bool mlirIntegerSetIsNull(MlirIntegerSet set) {
+  return set.ptr == NULL;
+}
+
+/// Checks if two integer set objects are equal. This is a "shallow" comparison
+/// of two objects. Only the sets with some small number of constraints are
+/// uniqued and compare equal here. Set objects that represent the same integer
+/// set with different constraints may be considered non-equal by this check.
+/// Set difference followed by an (expensive) emptiness check should be used to
+/// check equivalence of the underlying integer sets.
+MLIR_CAPI_EXPORTED bool mlirIntegerSetEqual(MlirIntegerSet s1,
+ MlirIntegerSet s2);
+
+/// Prints an integer set by sending chunks of the string representation and
+/// forwarding `userData` to `callback`. Note that the callback may be called
+/// several times with consecutive chunks of the string.
+MLIR_CAPI_EXPORTED void mlirIntegerSetPrint(MlirIntegerSet set,
+ MlirStringCallback callback,
+ void *userData);
+
+/// Prints an integer set to the standard error stream.
+MLIR_CAPI_EXPORTED void mlirIntegerSetDump(MlirIntegerSet set);
+
+/// Gets or creates a new canonically empty integer set with the given number
+/// of dimensions and symbols in the given context.
+MLIR_CAPI_EXPORTED MlirIntegerSet mlirIntegerSetEmptyGet(MlirContext context,
+ intptr_t numDims,
+ intptr_t numSymbols);
+
+/// Gets or creates a new integer set in the given context. The set is defined
+/// by a list of affine constraints, with the given number of input dimensions
+/// and symbols, which are treated as either equalities (eqFlags is 1) or
+/// inequalities (eqFlags is 0). Both `constraints` and `eqFlags` are expected
+/// to point to at least `numConstraints` consecutive values.
+MLIR_CAPI_EXPORTED MlirIntegerSet
+mlirIntegerSetGet(MlirContext context, intptr_t numDims, intptr_t numSymbols,
+ intptr_t numConstraints, const MlirAffineExpr *constraints,
+ const bool *eqFlags);
+
+/// Gets or creates a new integer set in which the dimensions and symbols of the
+/// given set are replaced with the given affine expressions. `dimReplacements`
+/// and `symbolReplacements` are expected to point to at least as many
+/// consecutive expressions as the given set has dimensions and symbols,
+/// respectively. The new set will have `numResultDims` and `numResultSymbols`
+/// dimensions and symbols, respectively.
+MLIR_CAPI_EXPORTED MlirIntegerSet mlirIntegerSetReplaceGet(
+ MlirIntegerSet set, const MlirAffineExpr *dimReplacements,
+ const MlirAffineExpr *symbolReplacements, intptr_t numResultDims,
+ intptr_t numResultSymbols);
+
+/// Checks whether the given set is a canonical empty set, i.e., the set
+/// returned by mlirIntegerSetEmptyGet.
+MLIR_CAPI_EXPORTED bool mlirIntegerSetIsCanonicalEmpty(MlirIntegerSet set);
+
+/// Returns the number of dimensions in the given set.
+MLIR_CAPI_EXPORTED intptr_t mlirIntegerSetGetNumDims(MlirIntegerSet set);
+
+/// Returns the number of symbols in the given set.
+MLIR_CAPI_EXPORTED intptr_t mlirIntegerSetGetNumSymbols(MlirIntegerSet set);
+
+/// Returns the number of inputs (dimensions + symbols) in the given set.
+MLIR_CAPI_EXPORTED intptr_t mlirIntegerSetGetNumInputs(MlirIntegerSet set);
+
+/// Returns the number of constraints (equalities + inequalities) in the given
+/// set.
+MLIR_CAPI_EXPORTED intptr_t mlirIntegerSetGetNumConstraints(MlirIntegerSet set);
+
+/// Returns the number of equalities in the given set.
+MLIR_CAPI_EXPORTED intptr_t mlirIntegerSetGetNumEqualities(MlirIntegerSet set);
+
+/// Returns the number of inequalities in the given set.
+MLIR_CAPI_EXPORTED intptr_t
+mlirIntegerSetGetNumInequalities(MlirIntegerSet set);
+
+/// Returns `pos`-th constraint of the set.
+MLIR_CAPI_EXPORTED MlirAffineExpr
+mlirIntegerSetGetConstraint(MlirIntegerSet set, intptr_t pos);
+
+/// Returns `true` if the `pos`-th constraint of the set is an equality
+/// constraint, `false` otherwise.
+MLIR_CAPI_EXPORTED bool mlirIntegerSetIsConstraintEq(MlirIntegerSet set,
+ intptr_t pos);
+
+#ifdef __cplusplus
+}
+#endif
+
+#endif // MLIR_C_INTEGERSET_H
diff --git a/mlir/include/mlir/CAPI/IntegerSet.h b/mlir/include/mlir/CAPI/IntegerSet.h
new file mode 100644
index 000000000000..465b1f9cda00
--- /dev/null
+++ b/mlir/include/mlir/CAPI/IntegerSet.h
@@ -0,0 +1,24 @@
+//===- IntegerSet.h - C API Utils for Integer Sets --------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains declarations of implementation details of the C API for
+// MLIR IntegerSets. This file should not be included from C++ code other than C
+// API implementation nor from C code.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_CAPI_INTEGERSET_H
+#define MLIR_CAPI_INTEGERSET_H
+
+#include "mlir-c/IntegerSet.h"
+#include "mlir/CAPI/Wrap.h"
+#include "mlir/IR/IntegerSet.h"
+
+// Instantiates the C API conversion helpers for integer sets: the `wrap()` /
+// `unwrap()` overloads used in lib/CAPI/IR/IntegerSet.cpp to convert between
+// the C handle `MlirIntegerSet` and `mlir::IntegerSet`. The macro itself lives
+// in mlir/CAPI/Wrap.h; presumably it relies on
+// IntegerSet::getAsOpaquePointer/getFromOpaquePointer (TODO confirm).
+DEFINE_C_API_METHODS(MlirIntegerSet, mlir::IntegerSet);
+
+#endif // MLIR_CAPI_INTEGERSET_H
diff --git a/mlir/include/mlir/IR/IntegerSet.h b/mlir/include/mlir/IR/IntegerSet.h
index 47ea7fc72007..c752822fc453 100644
--- a/mlir/include/mlir/IR/IntegerSet.h
+++ b/mlir/include/mlir/IR/IntegerSet.h
@@ -104,6 +104,15 @@ class IntegerSet {
friend ::llvm::hash_code hash_value(IntegerSet arg);
+  /// Methods supporting C API.
+  /// Exposes the underlying storage pointer as an opaque, const-qualified
+  /// handle suitable for storing in a C API struct.
+  const void *getAsOpaquePointer() const { return set; }
+  /// Rebuilds an IntegerSet from a pointer produced by getAsOpaquePointer().
+  /// The const qualifier is stripped because the IntegerSet constructor used
+  /// here takes a non-const ImplType pointer.
+  static IntegerSet getFromOpaquePointer(const void *pointer) {
+    void *mutablePointer = const_cast<void *>(pointer);
+    return IntegerSet(static_cast<ImplType *>(mutablePointer));
+  }
+
private:
ImplType *set;
/// Sets with constraints fewer than kUniquingThreshold are uniqued.
diff --git a/mlir/lib/CAPI/IR/CMakeLists.txt b/mlir/lib/CAPI/IR/CMakeLists.txt
index 411e0582bf4d..893ccb6721fb 100644
--- a/mlir/lib/CAPI/IR/CMakeLists.txt
+++ b/mlir/lib/CAPI/IR/CMakeLists.txt
@@ -5,6 +5,7 @@ add_mlir_public_c_api_library(MLIRCAPIIR
BuiltinAttributes.cpp
BuiltinTypes.cpp
Diagnostics.cpp
+ IntegerSet.cpp
IR.cpp
Pass.cpp
Support.cpp
diff --git a/mlir/lib/CAPI/IR/IntegerSet.cpp b/mlir/lib/CAPI/IR/IntegerSet.cpp
new file mode 100644
index 000000000000..701d70353614
--- /dev/null
+++ b/mlir/lib/CAPI/IR/IntegerSet.cpp
@@ -0,0 +1,103 @@
+//===- IntegerSet.cpp - C API for MLIR Integer Sets -----------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir-c/IntegerSet.h"
+#include "mlir-c/AffineExpr.h"
+#include "mlir/CAPI/AffineExpr.h"
+#include "mlir/CAPI/IR.h"
+#include "mlir/CAPI/IntegerSet.h"
+#include "mlir/CAPI/Utils.h"
+#include "mlir/IR/IntegerSet.h"
+
+using namespace mlir;
+
+// Returns the context the wrapped integer set belongs to.
+MlirContext mlirIntegerSetGetContext(MlirIntegerSet set) {
+  IntegerSet integerSet = unwrap(set);
+  return wrap(integerSet.getContext());
+}
+
+// Shallow comparison of the two wrapped sets via IntegerSet::operator==.
+bool mlirIntegerSetEqual(MlirIntegerSet s1, MlirIntegerSet s2) {
+  IntegerSet lhs = unwrap(s1);
+  IntegerSet rhs = unwrap(s2);
+  return lhs == rhs;
+}
+
+// Prints the set through a stream that forwards its output to `callback`
+// together with `userData`.
+void mlirIntegerSetPrint(MlirIntegerSet set, MlirStringCallback callback,
+                         void *userData) {
+  mlir::detail::CallbackOstream os(callback, userData);
+  IntegerSet integerSet = unwrap(set);
+  integerSet.print(os);
+}
+
+// Prints the set via IntegerSet::dump().
+void mlirIntegerSetDump(MlirIntegerSet set) {
+  IntegerSet integerSet = unwrap(set);
+  integerSet.dump();
+}
+
+// Returns the canonical empty set with the requested number of dimensions and
+// symbols. The intptr_t sizes of the C API are narrowed to unsigned.
+MlirIntegerSet mlirIntegerSetEmptyGet(MlirContext context, intptr_t numDims,
+                                      intptr_t numSymbols) {
+  MLIRContext *mlirCtx = unwrap(context);
+  unsigned dimCount = static_cast<unsigned>(numDims);
+  unsigned symbolCount = static_cast<unsigned>(numSymbols);
+  return wrap(IntegerSet::getEmptySet(dimCount, symbolCount, mlirCtx));
+}
+
+// Builds an integer set from `numConstraints` affine constraints. The
+// parallel `eqFlags` array marks each constraint as an equality (true) or an
+// inequality (false).
+MlirIntegerSet mlirIntegerSetGet(MlirContext context, intptr_t numDims,
+                                 intptr_t numSymbols, intptr_t numConstraints,
+                                 const MlirAffineExpr *constraints,
+                                 const bool *eqFlags) {
+  size_t constraintCount = static_cast<size_t>(numConstraints);
+
+  SmallVector<AffineExpr> exprs;
+  (void)unwrapList(constraintCount, constraints, exprs);
+  llvm::ArrayRef<bool> equalityFlags =
+      llvm::makeArrayRef(eqFlags, constraintCount);
+
+  unsigned dimCount = static_cast<unsigned>(numDims);
+  unsigned symbolCount = static_cast<unsigned>(numSymbols);
+  return wrap(IntegerSet::get(dimCount, symbolCount, exprs, equalityFlags));
+}
+
+// Substitutes the dimensions and symbols of `set`. The replacement arrays hold
+// as many expressions as `set` has dimensions and symbols, respectively.
+MlirIntegerSet
+mlirIntegerSetReplaceGet(MlirIntegerSet set,
+                         const MlirAffineExpr *dimReplacements,
+                         const MlirAffineExpr *symbolReplacements,
+                         intptr_t numResultDims, intptr_t numResultSymbols) {
+  IntegerSet source = unwrap(set);
+
+  SmallVector<AffineExpr> dimExprs;
+  (void)unwrapList(source.getNumDims(), dimReplacements, dimExprs);
+  SmallVector<AffineExpr> symbolExprs;
+  (void)unwrapList(source.getNumSymbols(), symbolReplacements, symbolExprs);
+
+  unsigned resultDims = static_cast<unsigned>(numResultDims);
+  unsigned resultSymbols = static_cast<unsigned>(numResultSymbols);
+  return wrap(source.replaceDimsAndSymbols(dimExprs, symbolExprs, resultDims,
+                                           resultSymbols));
+}
+
+// Simple queries on the wrapped set. The counts are widened to intptr_t, the
+// integer type this C API uses for sizes.
+
+bool mlirIntegerSetIsCanonicalEmpty(MlirIntegerSet set) {
+  IntegerSet integerSet = unwrap(set);
+  return integerSet.isEmptyIntegerSet();
+}
+
+intptr_t mlirIntegerSetGetNumDims(MlirIntegerSet set) {
+  IntegerSet integerSet = unwrap(set);
+  return static_cast<intptr_t>(integerSet.getNumDims());
+}
+
+intptr_t mlirIntegerSetGetNumSymbols(MlirIntegerSet set) {
+  IntegerSet integerSet = unwrap(set);
+  return static_cast<intptr_t>(integerSet.getNumSymbols());
+}
+
+intptr_t mlirIntegerSetGetNumInputs(MlirIntegerSet set) {
+  IntegerSet integerSet = unwrap(set);
+  return static_cast<intptr_t>(integerSet.getNumInputs());
+}
+
+intptr_t mlirIntegerSetGetNumConstraints(MlirIntegerSet set) {
+  IntegerSet integerSet = unwrap(set);
+  return static_cast<intptr_t>(integerSet.getNumConstraints());
+}
+
+intptr_t mlirIntegerSetGetNumEqualities(MlirIntegerSet set) {
+  IntegerSet integerSet = unwrap(set);
+  return static_cast<intptr_t>(integerSet.getNumEqualities());
+}
+
+intptr_t mlirIntegerSetGetNumInequalities(MlirIntegerSet set) {
+  IntegerSet integerSet = unwrap(set);
+  return static_cast<intptr_t>(integerSet.getNumInequalities());
+}
+
+// Returns the `pos`-th constraint of the set.
+MlirAffineExpr mlirIntegerSetGetConstraint(MlirIntegerSet set, intptr_t pos) {
+  return wrap(unwrap(set).getConstraint(static_cast<unsigned>(pos)));
+}
+
+// Returns true if the `pos`-th constraint of the set is an equality.
+bool mlirIntegerSetIsConstraintEq(MlirIntegerSet set, intptr_t pos) {
+  // Narrow the index explicitly, as mlirIntegerSetGetConstraint does. isEq
+  // presumably takes an `unsigned` index like getConstraint; an implicit
+  // intptr_t -> unsigned conversion trips -Wshorten-64-to-32 / -Wconversion
+  // on 64-bit hosts.
+  return unwrap(set).isEq(static_cast<unsigned>(pos));
+}
diff --git a/mlir/test/CAPI/ir.c b/mlir/test/CAPI/ir.c
index 015738384424..d19ab47971cf 100644
--- a/mlir/test/CAPI/ir.c
+++ b/mlir/test/CAPI/ir.c
@@ -17,6 +17,7 @@
#include "mlir-c/BuiltinTypes.h"
#include "mlir-c/Diagnostics.h"
#include "mlir-c/Dialect/Standard.h"
+#include "mlir-c/IntegerSet.h"
#include "mlir-c/Registration.h"
#include <assert.h>
@@ -1325,6 +1326,85 @@ int affineMapFromExprs(MlirContext ctx) {
return 0;
}
+// Exercises the IntegerSet C API. It covers:
+//   - the canonical empty set,
+//   - construction from explicit constraints,
+//   - dimension/symbol replacement,
+//   - the accessors.
+// Returns 0 on success, or a distinct non-zero code identifying the first
+// failed check. Printed output is verified by the FileCheck directives below.
+int printIntegerSet(MlirContext ctx) {
+  MlirIntegerSet emptySet = mlirIntegerSetEmptyGet(ctx, 2, 1);
+
+  // CHECK-LABEL: @printIntegerSet
+  fprintf(stderr, "@printIntegerSet");
+
+  // CHECK: (d0, d1)[s0] : (1 == 0)
+  mlirIntegerSetDump(emptySet);
+
+  if (!mlirIntegerSetIsCanonicalEmpty(emptySet))
+    return 1;
+
+  MlirIntegerSet anotherEmptySet = mlirIntegerSetEmptyGet(ctx, 2, 1);
+  if (!mlirIntegerSetEqual(emptySet, anotherEmptySet))
+    return 2;
+
+  // Construct a set constrained by:
+  //   d0 - s0 == 0,
+  //   d1 - 42 >= 0.
+  MlirAffineExpr negOne = mlirAffineConstantExprGet(ctx, -1);
+  MlirAffineExpr negFortyTwo = mlirAffineConstantExprGet(ctx, -42);
+  MlirAffineExpr d0 = mlirAffineDimExprGet(ctx, 0);
+  MlirAffineExpr d1 = mlirAffineDimExprGet(ctx, 1);
+  MlirAffineExpr s0 = mlirAffineSymbolExprGet(ctx, 0);
+  MlirAffineExpr negS0 = mlirAffineMulExprGet(negOne, s0);
+  MlirAffineExpr d0minusS0 = mlirAffineAddExprGet(d0, negS0);
+  MlirAffineExpr d1minus42 = mlirAffineAddExprGet(d1, negFortyTwo);
+  MlirAffineExpr constraints[] = {d0minusS0, d1minus42};
+  bool flags[] = {true, false};
+
+  MlirIntegerSet set = mlirIntegerSetGet(ctx, 2, 1, 2, constraints, flags);
+  // CHECK: (d0, d1)[s0] : (
+  // CHECK-DAG: d0 - s0 == 0
+  // CHECK-DAG: d1 - 42 >= 0
+  mlirIntegerSetDump(set);
+
+  // Replace d1 with a new symbol s1 and keep d0 and s0 unchanged. The result
+  // has one dimension and two symbols.
+  MlirAffineExpr s1 = mlirAffineSymbolExprGet(ctx, 1);
+  MlirAffineExpr repl[] = {d0, s1};
+  MlirIntegerSet replaced = mlirIntegerSetReplaceGet(set, repl, &s0, 1, 2);
+  // CHECK: (d0)[s0, s1] : (
+  // CHECK-DAG: d0 - s0 == 0
+  // CHECK-DAG: s1 - 42 >= 0
+  mlirIntegerSetDump(replaced);
+
+  if (mlirIntegerSetGetNumDims(set) != 2)
+    return 3;
+  if (mlirIntegerSetGetNumDims(replaced) != 1)
+    return 4;
+
+  if (mlirIntegerSetGetNumSymbols(set) != 1)
+    return 5;
+  if (mlirIntegerSetGetNumSymbols(replaced) != 2)
+    return 6;
+
+  if (mlirIntegerSetGetNumInputs(set) != 3)
+    return 7;
+
+  if (mlirIntegerSetGetNumConstraints(set) != 2)
+    return 8;
+
+  if (mlirIntegerSetGetNumEqualities(set) != 1)
+    return 9;
+
+  if (mlirIntegerSetGetNumInequalities(set) != 1)
+    return 10;
+
+  // Constraint order is not guaranteed, so match each constraint against the
+  // expression implied by its equality flag.
+  MlirAffineExpr cstr1 = mlirIntegerSetGetConstraint(set, 0);
+  MlirAffineExpr cstr2 = mlirIntegerSetGetConstraint(set, 1);
+  bool isEq1 = mlirIntegerSetIsConstraintEq(set, 0);
+  bool isEq2 = mlirIntegerSetIsConstraintEq(set, 1);
+  if (!mlirAffineExprEqual(cstr1, isEq1 ? d0minusS0 : d1minus42))
+    return 11;
+  if (!mlirAffineExprEqual(cstr2, isEq2 ? d0minusS0 : d1minus42))
+    return 12;
+
+  // Cover the remaining API entry points that the checks above do not touch.
+  if (mlirIntegerSetIsNull(set) || mlirIntegerSetIsNull(replaced))
+    return 13;
+  if (mlirIntegerSetIsCanonicalEmpty(set))
+    return 14;
+  if (!mlirContextEqual(mlirIntegerSetGetContext(set), ctx))
+    return 15;
+  if (mlirIntegerSetGetNumInputs(replaced) != 3)
+    return 16;
+
+  return 0;
+}
+
int registerOnlyStd() {
MlirContext ctx = mlirContextCreate();
// The built-in dialect is always loaded.
@@ -1429,8 +1509,10 @@ int main() {
return 6;
if (affineMapFromExprs(ctx))
return 7;
- if (registerOnlyStd())
+ if (printIntegerSet(ctx))
return 8;
+ if (registerOnlyStd())
+ return 9;
mlirContextDestroy(ctx);
More information about the Mlir-commits
mailing list