[llvm-branch-commits] [mlir] f5c7c03 - [mlir] Add C API for IntegerSet

Alex Zinenko via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Mon Jan 25 11:21:13 PST 2021


Author: Alex Zinenko
Date: 2021-01-25T20:16:22+01:00
New Revision: f5c7c031e2493168b3c2cfea3219e2131cc01483

URL: https://github.com/llvm/llvm-project/commit/f5c7c031e2493168b3c2cfea3219e2131cc01483
DIFF: https://github.com/llvm/llvm-project/commit/f5c7c031e2493168b3c2cfea3219e2131cc01483.diff

LOG: [mlir] Add C API for IntegerSet

Depends On D95357

Reviewed By: stellaraccident

Differential Revision: https://reviews.llvm.org/D95368

Added: 
    mlir/include/mlir-c/IntegerSet.h
    mlir/include/mlir/CAPI/IntegerSet.h
    mlir/lib/CAPI/IR/IntegerSet.cpp

Modified: 
    mlir/include/mlir/IR/IntegerSet.h
    mlir/lib/CAPI/IR/CMakeLists.txt
    mlir/test/CAPI/ir.c

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir-c/IntegerSet.h b/mlir/include/mlir-c/IntegerSet.h
new file mode 100644
index 000000000000..058be414d2ee
--- /dev/null
+++ b/mlir/include/mlir-c/IntegerSet.h
@@ -0,0 +1,131 @@
+//===-- mlir-c/IntegerSet.h - C API for MLIR Integer Sets ---------*- C -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM
+// Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_C_INTEGERSET_H
+#define MLIR_C_INTEGERSET_H
+
+#include "mlir-c/AffineExpr.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+//===----------------------------------------------------------------------===//
+// Opaque type declarations.
+//
+// Types are exposed to C bindings as structs containing opaque pointers. They
+// are not supposed to be inspected from C. This allows the underlying
+// representation to change without affecting the API users. The use of structs
+// instead of typedefs enables some type safety as structs are not implicitly
+// convertible to each other.
+//
+// Instances of these types may or may not own the underlying object. The
+// ownership semantics is defined by how an instance of the type was obtained.
+//===----------------------------------------------------------------------===//
+
+/// Handle to an integer set. `ptr` is a type-erased pointer to the underlying
+/// storage; a null `ptr` denotes a null set (see mlirIntegerSetIsNull).
+typedef struct MlirIntegerSet {
+  const void *ptr;
+} MlirIntegerSet;
+
+/// Gets the context in which the given integer set lives.
+MLIR_CAPI_EXPORTED MlirContext mlirIntegerSetGetContext(MlirIntegerSet set);
+
+/// Checks whether an integer set is a null object.
+static inline bool mlirIntegerSetIsNull(MlirIntegerSet set) { return !set.ptr; }
+
+/// Checks if two integer set objects are equal. This is a "shallow" comparison
+/// of two objects. Only the sets with some small number of constraints are
+/// uniqued and compare equal here. Set objects that represent the same integer
+/// set with different constraints may be considered non-equal by this check.
+/// Set difference followed by an (expensive) emptiness check should be used to
+/// check equivalence of the underlying integer sets.
+MLIR_CAPI_EXPORTED bool mlirIntegerSetEqual(MlirIntegerSet s1,
+                                            MlirIntegerSet s2);
+
+/// Prints an integer set by sending chunks of the string representation and
+/// forwarding `userData` to `callback`. Note that the callback may be called
+/// several times with consecutive chunks of the string.
+MLIR_CAPI_EXPORTED void mlirIntegerSetPrint(MlirIntegerSet set,
+                                            MlirStringCallback callback,
+                                            void *userData);
+
+/// Prints an integer set to the standard error stream.
+MLIR_CAPI_EXPORTED void mlirIntegerSetDump(MlirIntegerSet set);
+
+/// Gets or creates a new canonically empty integer set with the given number
+/// of dimensions and symbols in the given context.
+MLIR_CAPI_EXPORTED MlirIntegerSet mlirIntegerSetEmptyGet(MlirContext context,
+                                                         intptr_t numDims,
+                                                         intptr_t numSymbols);
+
+/// Gets or creates a new integer set in the given context with the given
+/// number of input dimensions and symbols. Each of the `numConstraints` affine
+/// expressions `c` in `constraints` is treated as an equality `c == 0` if the
+/// matching `eqFlags` entry is true, or as an inequality `c >= 0` otherwise.
+/// Both arrays are expected to hold at least `numConstraints` elements.
+MLIR_CAPI_EXPORTED MlirIntegerSet
+mlirIntegerSetGet(MlirContext context, intptr_t numDims, intptr_t numSymbols,
+                  intptr_t numConstraints, const MlirAffineExpr *constraints,
+                  const bool *eqFlags);
+
+/// Gets or creates a new integer set in which the dimensions and symbols of
+/// the given set are replaced with the given affine expressions.
+/// `dimReplacements` and `symbolReplacements` are expected to point to at
+/// least as many consecutive expressions as the given set has dimensions and
+/// symbols, respectively. The new set will have `numResultDims` and
+/// `numResultSymbols` dimensions and symbols, respectively.
+MLIR_CAPI_EXPORTED MlirIntegerSet mlirIntegerSetReplaceGet(
+    MlirIntegerSet set, const MlirAffineExpr *dimReplacements,
+    const MlirAffineExpr *symbolReplacements, intptr_t numResultDims,
+    intptr_t numResultSymbols);
+
+/// Checks whether the given set is a canonical empty set, i.e., the set
+/// returned by mlirIntegerSetEmptyGet.
+MLIR_CAPI_EXPORTED bool mlirIntegerSetIsCanonicalEmpty(MlirIntegerSet set);
+
+/// Returns the number of dimensions in the given set.
+MLIR_CAPI_EXPORTED intptr_t mlirIntegerSetGetNumDims(MlirIntegerSet set);
+
+/// Returns the number of symbols in the given set.
+MLIR_CAPI_EXPORTED intptr_t mlirIntegerSetGetNumSymbols(MlirIntegerSet set);
+
+/// Returns the number of inputs (dimensions + symbols) in the given set.
+MLIR_CAPI_EXPORTED intptr_t mlirIntegerSetGetNumInputs(MlirIntegerSet set);
+
+/// Returns the number of constraints (equalities + inequalities) in the given
+/// set.
+MLIR_CAPI_EXPORTED intptr_t mlirIntegerSetGetNumConstraints(MlirIntegerSet set);
+
+/// Returns the number of equalities in the given set.
+MLIR_CAPI_EXPORTED intptr_t mlirIntegerSetGetNumEqualities(MlirIntegerSet set);
+
+/// Returns the number of inequalities in the given set.
+MLIR_CAPI_EXPORTED intptr_t
+mlirIntegerSetGetNumInequalities(MlirIntegerSet set);
+
+/// Returns the `pos`-th constraint of the set.
+MLIR_CAPI_EXPORTED MlirAffineExpr
+mlirIntegerSetGetConstraint(MlirIntegerSet set, intptr_t pos);
+
+/// Returns `true` if the `pos`-th constraint of the set is an equality
+/// constraint, `false` otherwise.
+MLIR_CAPI_EXPORTED bool mlirIntegerSetIsConstraintEq(MlirIntegerSet set,
+                                                     intptr_t pos);
+
+#ifdef __cplusplus
+}
+#endif
+
+#endif // MLIR_C_INTEGERSET_H

diff  --git a/mlir/include/mlir/CAPI/IntegerSet.h b/mlir/include/mlir/CAPI/IntegerSet.h
new file mode 100644
index 000000000000..465b1f9cda00
--- /dev/null
+++ b/mlir/include/mlir/CAPI/IntegerSet.h
@@ -0,0 +1,24 @@
+//===- IntegerSet.h - C API Utils for Integer Sets --------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains declarations of implementation details of the C API for
+// MLIR IntegerSets. This file should not be included from C++ code other than C
+// API implementation nor from C code.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_CAPI_INTEGERSET_H
+#define MLIR_CAPI_INTEGERSET_H
+
+#include "mlir-c/IntegerSet.h"
+#include "mlir/CAPI/Wrap.h"
+#include "mlir/IR/IntegerSet.h"
+
+// Defines the wrap()/unwrap() conversions between MlirIntegerSet and
+// mlir::IntegerSet used by the C API implementation.
+DEFINE_C_API_METHODS(MlirIntegerSet, mlir::IntegerSet);
+
+#endif // MLIR_CAPI_INTEGERSET_H

diff  --git a/mlir/include/mlir/IR/IntegerSet.h b/mlir/include/mlir/IR/IntegerSet.h
index 47ea7fc72007..c752822fc453 100644
--- a/mlir/include/mlir/IR/IntegerSet.h
+++ b/mlir/include/mlir/IR/IntegerSet.h
@@ -104,6 +104,15 @@ class IntegerSet {
 
   friend ::llvm::hash_code hash_value(IntegerSet arg);
 
+  /// Methods supporting C API.
+  /// Returns the underlying storage pointer, type-erased, for use as an opaque
+  /// C handle.
+  const void *getAsOpaquePointer() const { return set; }
+  /// Reconstructs an IntegerSet from a pointer previously obtained via
+  /// getAsOpaquePointer(). static_cast is the named cast for converting a
+  /// void pointer back to its object type; reinterpret_cast is not needed.
+  static IntegerSet getFromOpaquePointer(const void *pointer) {
+    return IntegerSet(static_cast<ImplType *>(const_cast<void *>(pointer)));
+  }
+
 private:
   ImplType *set;
   /// Sets with constraints fewer than kUniquingThreshold are uniqued.

diff  --git a/mlir/lib/CAPI/IR/CMakeLists.txt b/mlir/lib/CAPI/IR/CMakeLists.txt
index 411e0582bf4d..893ccb6721fb 100644
--- a/mlir/lib/CAPI/IR/CMakeLists.txt
+++ b/mlir/lib/CAPI/IR/CMakeLists.txt
@@ -5,6 +5,7 @@ add_mlir_public_c_api_library(MLIRCAPIIR
   BuiltinAttributes.cpp
   BuiltinTypes.cpp
   Diagnostics.cpp
+  IntegerSet.cpp
   IR.cpp
   Pass.cpp
   Support.cpp

diff  --git a/mlir/lib/CAPI/IR/IntegerSet.cpp b/mlir/lib/CAPI/IR/IntegerSet.cpp
new file mode 100644
index 000000000000..701d70353614
--- /dev/null
+++ b/mlir/lib/CAPI/IR/IntegerSet.cpp
@@ -0,0 +1,103 @@
+//===- IntegerSet.cpp - C API for MLIR Integer Sets -----------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir-c/IntegerSet.h"
+#include "mlir-c/AffineExpr.h"
+#include "mlir/CAPI/AffineExpr.h"
+#include "mlir/CAPI/IR.h"
+#include "mlir/CAPI/IntegerSet.h"
+#include "mlir/CAPI/Utils.h"
+#include "mlir/IR/IntegerSet.h"
+
+using namespace mlir;
+
+MlirContext mlirIntegerSetGetContext(MlirIntegerSet set) {
+  auto *context = unwrap(set).getContext();
+  return wrap(context);
+}
+
+bool mlirIntegerSetEqual(MlirIntegerSet s1, MlirIntegerSet s2) {
+  // Compares the underlying IntegerSet objects; see the header for why this
+  // is only a shallow comparison.
+  IntegerSet lhs = unwrap(s1);
+  IntegerSet rhs = unwrap(s2);
+  return lhs == rhs;
+}
+
+void mlirIntegerSetPrint(MlirIntegerSet set, MlirStringCallback callback,
+                         void *userData) {
+  // Route the textual form through an ostream that forwards each chunk to the
+  // user callback together with `userData`.
+  mlir::detail::CallbackOstream os(callback, userData);
+  unwrap(set).print(os);
+}
+
+void mlirIntegerSetDump(MlirIntegerSet set) {
+  IntegerSet cppSet = unwrap(set);
+  cppSet.dump();
+}
+
+MlirIntegerSet mlirIntegerSetEmptyGet(MlirContext context, intptr_t numDims,
+                                      intptr_t numSymbols) {
+  unsigned dimCount = static_cast<unsigned>(numDims);
+  unsigned symbolCount = static_cast<unsigned>(numSymbols);
+  return wrap(IntegerSet::getEmptySet(dimCount, symbolCount, unwrap(context)));
+}
+
+MlirIntegerSet mlirIntegerSetGet(MlirContext context, intptr_t numDims,
+                                 intptr_t numSymbols, intptr_t numConstraints,
+                                 const MlirAffineExpr *constraints,
+                                 const bool *eqFlags) {
+  // `context` is not forwarded: IntegerSet::get takes no context argument.
+  // NOTE(review): it presumably derives the context from the constraint
+  // expressions, so numConstraints == 0 is likely unsupported -- confirm
+  // against IntegerSet::get.
+  size_t count = static_cast<size_t>(numConstraints);
+  SmallVector<AffineExpr> exprs;
+  (void)unwrapList(count, constraints, exprs);
+  return wrap(IntegerSet::get(static_cast<unsigned>(numDims),
+                              static_cast<unsigned>(numSymbols), exprs,
+                              llvm::makeArrayRef(eqFlags, count)));
+}
+
+MlirIntegerSet
+mlirIntegerSetReplaceGet(MlirIntegerSet set,
+                         const MlirAffineExpr *dimReplacements,
+                         const MlirAffineExpr *symbolReplacements,
+                         intptr_t numResultDims, intptr_t numResultSymbols) {
+  IntegerSet original = unwrap(set);
+  // One replacement expression is read per dimension and per symbol of the
+  // original set.
+  SmallVector<AffineExpr> newDims;
+  (void)unwrapList(original.getNumDims(), dimReplacements, newDims);
+  SmallVector<AffineExpr> newSymbols;
+  (void)unwrapList(original.getNumSymbols(), symbolReplacements, newSymbols);
+  return wrap(original.replaceDimsAndSymbols(
+      newDims, newSymbols, static_cast<unsigned>(numResultDims),
+      static_cast<unsigned>(numResultSymbols)));
+}
+
+bool mlirIntegerSetIsCanonicalEmpty(MlirIntegerSet set) {
+  return unwrap(set).isEmptyIntegerSet();
+}
+
+// Size queries convert IntegerSet's counts to the C API's intptr_t.
+intptr_t mlirIntegerSetGetNumDims(MlirIntegerSet set) {
+  return static_cast<intptr_t>(unwrap(set).getNumDims());
+}
+
+intptr_t mlirIntegerSetGetNumSymbols(MlirIntegerSet set) {
+  return static_cast<intptr_t>(unwrap(set).getNumSymbols());
+}
+
+intptr_t mlirIntegerSetGetNumInputs(MlirIntegerSet set) {
+  return static_cast<intptr_t>(unwrap(set).getNumInputs());
+}
+
+intptr_t mlirIntegerSetGetNumConstraints(MlirIntegerSet set) {
+  return static_cast<intptr_t>(unwrap(set).getNumConstraints());
+}
+
+intptr_t mlirIntegerSetGetNumEqualities(MlirIntegerSet set) {
+  return static_cast<intptr_t>(unwrap(set).getNumEqualities());
+}
+
+intptr_t mlirIntegerSetGetNumInequalities(MlirIntegerSet set) {
+  return static_cast<intptr_t>(unwrap(set).getNumInequalities());
+}
+
+MlirAffineExpr mlirIntegerSetGetConstraint(MlirIntegerSet set, intptr_t pos) {
+  return wrap(unwrap(set).getConstraint(static_cast<unsigned>(pos)));
+}
+
+bool mlirIntegerSetIsConstraintEq(MlirIntegerSet set, intptr_t pos) {
+  // Narrow explicitly, as mlirIntegerSetGetConstraint does, instead of relying
+  // on an implicit (warning-prone) intptr_t -> unsigned conversion.
+  return unwrap(set).isEq(static_cast<unsigned>(pos));
+}

diff  --git a/mlir/test/CAPI/ir.c b/mlir/test/CAPI/ir.c
index 015738384424..d19ab47971cf 100644
--- a/mlir/test/CAPI/ir.c
+++ b/mlir/test/CAPI/ir.c
@@ -17,6 +17,7 @@
 #include "mlir-c/BuiltinTypes.h"
 #include "mlir-c/Diagnostics.h"
 #include "mlir-c/Dialect/Standard.h"
+#include "mlir-c/IntegerSet.h"
 #include "mlir-c/Registration.h"
 
 #include <assert.h>
@@ -1325,6 +1326,85 @@ int affineMapFromExprs(MlirContext ctx) {
   return 0;
 }
 
+/// Exercises the IntegerSet C API. Returns 0 on success, or a distinct nonzero
+/// code identifying the first failed verification.
+int printIntegerSet(MlirContext ctx) {
+  MlirIntegerSet emptySet = mlirIntegerSetEmptyGet(ctx, 2, 1);
+
+  // CHECK-LABEL: @printIntegerSet
+  fprintf(stderr, "@printIntegerSet");
+
+  // CHECK: (d0, d1)[s0] : (1 == 0)
+  mlirIntegerSetDump(emptySet);
+
+  if (!mlirIntegerSetIsCanonicalEmpty(emptySet))
+    return 1;
+
+  // Small sets are uniqued, so requesting the same empty set again must yield
+  // an equal object.
+  MlirIntegerSet anotherEmptySet = mlirIntegerSetEmptyGet(ctx, 2, 1);
+  if (!mlirIntegerSetEqual(emptySet, anotherEmptySet))
+    return 2;
+
+  // Construct a set constrained by:
+  //   d0 - s0 == 0,
+  //   d1 - 42 >= 0.
+  MlirAffineExpr negOne = mlirAffineConstantExprGet(ctx, -1);
+  MlirAffineExpr negFortyTwo = mlirAffineConstantExprGet(ctx, -42);
+  MlirAffineExpr d0 = mlirAffineDimExprGet(ctx, 0);
+  MlirAffineExpr d1 = mlirAffineDimExprGet(ctx, 1);
+  MlirAffineExpr s0 = mlirAffineSymbolExprGet(ctx, 0);
+  MlirAffineExpr negS0 = mlirAffineMulExprGet(negOne, s0);
+  MlirAffineExpr d0minusS0 = mlirAffineAddExprGet(d0, negS0);
+  MlirAffineExpr d1minus42 = mlirAffineAddExprGet(d1, negFortyTwo);
+  MlirAffineExpr constraints[] = {d0minusS0, d1minus42};
+  bool flags[] = {true, false};
+
+  MlirIntegerSet set = mlirIntegerSetGet(ctx, 2, 1, 2, constraints, flags);
+  // CHECK: (d0, d1)[s0] : (
+  // CHECK-DAG: d0 - s0 == 0
+  // CHECK-DAG: d1 - 42 >= 0
+  mlirIntegerSetDump(set);
+
+  // Replace d1 with a new symbol s1 (d0 and s0 are kept as is), producing a
+  // set with one dimension and two symbols.
+  MlirAffineExpr s1 = mlirAffineSymbolExprGet(ctx, 1);
+  MlirAffineExpr repl[] = {d0, s1};
+  MlirIntegerSet replaced = mlirIntegerSetReplaceGet(set, repl, &s0, 1, 2);
+  // CHECK: (d0)[s0, s1] : (
+  // CHECK-DAG: d0 - s0 == 0
+  // CHECK-DAG: s1 - 42 >= 0
+  mlirIntegerSetDump(replaced);
+
+  if (mlirIntegerSetGetNumDims(set) != 2)
+    return 3;
+  if (mlirIntegerSetGetNumDims(replaced) != 1)
+    return 4;
+
+  if (mlirIntegerSetGetNumSymbols(set) != 1)
+    return 5;
+  if (mlirIntegerSetGetNumSymbols(replaced) != 2)
+    return 6;
+
+  if (mlirIntegerSetGetNumInputs(set) != 3)
+    return 7;
+
+  if (mlirIntegerSetGetNumConstraints(set) != 2)
+    return 8;
+
+  if (mlirIntegerSetGetNumEqualities(set) != 1)
+    return 9;
+
+  if (mlirIntegerSetGetNumInequalities(set) != 1)
+    return 10;
+
+  // The storage order of constraints is not relied upon: each constraint is
+  // matched against the expected expression for its kind.
+  MlirAffineExpr cstr1 = mlirIntegerSetGetConstraint(set, 0);
+  MlirAffineExpr cstr2 = mlirIntegerSetGetConstraint(set, 1);
+  bool isEq1 = mlirIntegerSetIsConstraintEq(set, 0);
+  bool isEq2 = mlirIntegerSetIsConstraintEq(set, 1);
+  if (!mlirAffineExprEqual(cstr1, isEq1 ? d0minusS0 : d1minus42))
+    return 11;
+  if (!mlirAffineExprEqual(cstr2, isEq2 ? d0minusS0 : d1minus42))
+    return 12;
+
+  // Cover the remaining accessors: a constructed set is non-null and lives in
+  // the context it was created in.
+  if (mlirIntegerSetIsNull(set))
+    return 13;
+  if (!mlirContextEqual(mlirIntegerSetGetContext(set), ctx))
+    return 14;
+
+  return 0;
+}
+
 int registerOnlyStd() {
   MlirContext ctx = mlirContextCreate();
   // The built-in dialect is always loaded.
@@ -1429,8 +1509,10 @@ int main() {
     return 6;
   if (affineMapFromExprs(ctx))
     return 7;
-  if (registerOnlyStd())
+  if (printIntegerSet(ctx))
     return 8;
+  if (registerOnlyStd())
+    return 9;
 
   mlirContextDestroy(ctx);
 


        


More information about the llvm-branch-commits mailing list