[llvm-branch-commits] [mlir] 76753a5 - Add FunctionType to MLIR C and Python bindings.
Stella Laurenzo via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Mon Sep 28 09:57:22 PDT 2020
Author: Stella Laurenzo
Date: 2020-09-28T09:56:48-07:00
New Revision: 76753a597b5d9bf4addf19399ae30c4b3870a4a6
URL: https://github.com/llvm/llvm-project/commit/76753a597b5d9bf4addf19399ae30c4b3870a4a6
DIFF: https://github.com/llvm/llvm-project/commit/76753a597b5d9bf4addf19399ae30c4b3870a4a6.diff
LOG: Add FunctionType to MLIR C and Python bindings.
Differential Revision: https://reviews.llvm.org/D88416
Added:
Modified:
mlir/include/mlir-c/StandardTypes.h
mlir/lib/Bindings/Python/IRModules.cpp
mlir/lib/CAPI/IR/StandardTypes.cpp
mlir/test/Bindings/Python/ir_types.py
mlir/test/CAPI/ir.c
Removed:
################################################################################
diff --git a/mlir/include/mlir-c/StandardTypes.h b/mlir/include/mlir-c/StandardTypes.h
index 4bbcb23b8002e..3b667c9e0a1b0 100644
--- a/mlir/include/mlir-c/StandardTypes.h
+++ b/mlir/include/mlir-c/StandardTypes.h
@@ -270,6 +270,30 @@ intptr_t mlirTupleTypeGetNumTypes(MlirType type);
/** Returns the pos-th type in the tuple type. */
MlirType mlirTupleTypeGetType(MlirType type, intptr_t pos);
+/*============================================================================*/
+/* Function type.                                                             */
+/*============================================================================*/
+
+/** Checks whether the given type is a function type. */
+int mlirTypeIsAFunction(MlirType type);
+
+/** Creates a function type, mapping a list of input types to result types.
+ * `inputs` points to `numInputs` types and `results` to `numResults` types;
+ * the elements are copied, so the caller retains ownership of both arrays. */
+MlirType mlirFunctionTypeGet(MlirContext ctx, intptr_t numInputs,
+                             MlirType *inputs, intptr_t numResults,
+                             MlirType *results);
+
+/** Returns the number of input types. `type` must be a function type. */
+intptr_t mlirFunctionTypeGetNumInputs(MlirType type);
+
+/** Returns the number of result types. `type` must be a function type. */
+intptr_t mlirFunctionTypeGetNumResults(MlirType type);
+
+/** Returns the pos-th input type of the function type `type`; `pos` must be
+ * in [0, mlirFunctionTypeGetNumInputs(type)). */
+MlirType mlirFunctionTypeGetInput(MlirType type, intptr_t pos);
+
+/** Returns the pos-th result type of the function type `type`; `pos` must be
+ * in [0, mlirFunctionTypeGetNumResults(type)). */
+MlirType mlirFunctionTypeGetResult(MlirType type, intptr_t pos);
+
#ifdef __cplusplus
}
#endif
diff --git a/mlir/lib/Bindings/Python/IRModules.cpp b/mlir/lib/Bindings/Python/IRModules.cpp
index 99b00b96b974c..f3bd96856d090 100644
--- a/mlir/lib/Bindings/Python/IRModules.cpp
+++ b/mlir/lib/Bindings/Python/IRModules.cpp
@@ -1278,6 +1278,56 @@ class PyTupleType : public PyConcreteType<PyTupleType> {
}
};
+/// Function type.
+class PyFunctionType : public PyConcreteType<PyFunctionType> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFunction;
+  static constexpr const char *pyClassName = "FunctionType";
+  using PyConcreteType::PyConcreteType;
+
+  static void bindDerived(ClassTy &c) {
+    c.def_static(
+        "get",
+        [](PyMlirContext &context, std::vector<PyType> inputs,
+           std::vector<PyType> results) {
+          // Flatten the Python-side wrappers into raw C handles for the C API.
+          SmallVector<MlirType, 4> inputsRaw(inputs.begin(), inputs.end());
+          SmallVector<MlirType, 4> resultsRaw(results.begin(), results.end());
+          MlirType t = mlirFunctionTypeGet(context.get(), inputsRaw.size(),
+                                           inputsRaw.data(), resultsRaw.size(),
+                                           resultsRaw.data());
+          return PyFunctionType(context.getRef(), t);
+        },
+        py::arg("context"), py::arg("inputs"), py::arg("results"),
+        "Gets a FunctionType from a list of input and result types");
+    c.def_property_readonly(
+        "inputs",
+        [](PyFunctionType &self) {
+          return getTypeList(self, mlirFunctionTypeGetNumInputs,
+                             mlirFunctionTypeGetInput);
+        },
+        "Returns the list of input types in the FunctionType.");
+    c.def_property_readonly(
+        "results",
+        [](PyFunctionType &self) {
+          return getTypeList(self, mlirFunctionTypeGetNumResults,
+                             mlirFunctionTypeGetResult);
+        },
+        "Returns the list of result types in the FunctionType.");
+  }
+
+private:
+  /// Returns a Python list of the `getCount(self.type)` types produced by
+  /// `getType(self.type, i)`, each wrapped as a PyType in `self`'s context.
+  /// Shared by the `inputs` and `results` properties.
+  template <typename CountFnTy, typename GetFnTy>
+  static py::list getTypeList(PyFunctionType &self, CountFnTy getCount,
+                              GetFnTy getType) {
+    MlirType t = self.type;
+    auto contextRef = self.getContext();
+    py::list types;
+    for (intptr_t i = 0, e = getCount(t); i < e; ++i)
+      types.append(PyType(contextRef, getType(t, i)));
+    return types;
+  }
+};
+
+
} // namespace
//------------------------------------------------------------------------------
@@ -1613,6 +1663,7 @@ void mlir::python::populateIRSubmodule(py::module &m) {
PyMemRefType::bind(m);
PyUnrankedMemRefType::bind(m);
PyTupleType::bind(m);
+ PyFunctionType::bind(m);
// Container bindings.
PyBlockIterator::bind(m);
diff --git a/mlir/lib/CAPI/IR/StandardTypes.cpp b/mlir/lib/CAPI/IR/StandardTypes.cpp
index ddd3a5e93147a..b4e37d38ace31 100644
--- a/mlir/lib/CAPI/IR/StandardTypes.cpp
+++ b/mlir/lib/CAPI/IR/StandardTypes.cpp
@@ -13,6 +13,7 @@
#include "mlir/CAPI/IR.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/StandardTypes.h"
+#include "mlir/IR/Types.h"
using namespace mlir;
@@ -297,3 +298,41 @@ intptr_t mlirTupleTypeGetNumTypes(MlirType type) {
MlirType mlirTupleTypeGetType(MlirType type, intptr_t pos) {
return wrap(unwrap(type).cast<TupleType>().getType(static_cast<size_t>(pos)));
}
+
+/*============================================================================*/
+/* Function type. */
+/*============================================================================*/
+
+/// Returns 1 when `type` is a FunctionType and 0 otherwise.
+int mlirTypeIsAFunction(MlirType type) {
+  Type cppType = unwrap(type);
+  return cppType.isa<FunctionType>() ? 1 : 0;
+}
+
+/// Builds a FunctionType in `ctx` from `numInputs` input types and
+/// `numResults` result types.
+MlirType mlirFunctionTypeGet(MlirContext ctx, intptr_t numInputs,
+                             MlirType *inputs, intptr_t numResults,
+                             MlirType *results) {
+  auto *context = unwrap(ctx);
+  // Copy both C arrays into local C++ Type storage before building the type.
+  SmallVector<Type, 4> argTypes;
+  (void)unwrapList(numInputs, inputs, argTypes);
+  SmallVector<Type, 4> resTypes;
+  (void)unwrapList(numResults, results, resTypes);
+  return wrap(FunctionType::get(argTypes, resTypes, context));
+}
+
+/// Returns how many inputs the function type `type` takes.
+intptr_t mlirFunctionTypeGetNumInputs(MlirType type) {
+  FunctionType funcType = unwrap(type).cast<FunctionType>();
+  return static_cast<intptr_t>(funcType.getNumInputs());
+}
+
+/// Returns how many results the function type `type` produces.
+intptr_t mlirFunctionTypeGetNumResults(MlirType type) {
+  FunctionType funcType = unwrap(type).cast<FunctionType>();
+  return static_cast<intptr_t>(funcType.getNumResults());
+}
+
+/// Returns the pos-th input type of the function type `type`.
+/// `pos` must be in [0, number of inputs).
+MlirType mlirFunctionTypeGetInput(MlirType type, intptr_t pos) {
+  FunctionType funcType = unwrap(type).cast<FunctionType>();
+  // Range-check in intptr_t before narrowing to unsigned: an oversized `pos`
+  // would otherwise wrap around and silently select the wrong input.
+  assert(pos >= 0 && pos < static_cast<intptr_t>(funcType.getNumInputs()) &&
+         "pos out of range for function type inputs");
+  return wrap(funcType.getInput(static_cast<unsigned>(pos)));
+}
+
+/// Returns the pos-th result type of the function type `type`.
+/// `pos` must be in [0, number of results).
+MlirType mlirFunctionTypeGetResult(MlirType type, intptr_t pos) {
+  FunctionType funcType = unwrap(type).cast<FunctionType>();
+  // Range-check in intptr_t before narrowing to unsigned: an oversized `pos`
+  // would otherwise wrap around and silently select the wrong result.
+  assert(pos >= 0 && pos < static_cast<intptr_t>(funcType.getNumResults()) &&
+         "pos out of range for function type results");
+  return wrap(funcType.getResult(static_cast<unsigned>(pos)));
+}
diff --git a/mlir/test/Bindings/Python/ir_types.py b/mlir/test/Bindings/Python/ir_types.py
index b80cbebb10e24..d8ae77f1f092d 100644
--- a/mlir/test/Bindings/Python/ir_types.py
+++ b/mlir/test/Bindings/Python/ir_types.py
@@ -392,3 +392,19 @@ def testTupleType():
print("pos-th type in the tuple type:", tuple_type.get_type(1))
run(testTupleType)
+
+
+# CHECK-LABEL: TEST: testFunctionType
+def testFunctionType():
+  ctx = mlir.ir.Context()
+  input_types = [mlir.ir.IntegerType.get_signless(ctx, 32),
+                 mlir.ir.IntegerType.get_signless(ctx, 16)]
+  result_types = [mlir.ir.IndexType(ctx)]
+  func = mlir.ir.FunctionType.get(ctx, input_types, result_types)
+  # CHECK: INPUTS: [Type(i32), Type(i16)]
+  print("INPUTS:", func.inputs)
+  # CHECK: RESULTS: [Type(index)]
+  print("RESULTS:", func.results)
+  # CHECK: FUNC: (i32, i16) -> index
+  print("FUNC:", func)
+
+  # A function type with no inputs and no results is also valid.
+  empty = mlir.ir.FunctionType.get(ctx, [], [])
+  # CHECK: EMPTY INPUTS: []
+  print("EMPTY INPUTS:", empty.inputs)
+  # CHECK: EMPTY RESULTS: []
+  print("EMPTY RESULTS:", empty.results)
+  # CHECK: EMPTY: () -> ()
+  print("EMPTY:", empty)
+
+
+run(testFunctionType)
diff --git a/mlir/test/CAPI/ir.c b/mlir/test/CAPI/ir.c
index 4849111986cda..909929647a84a 100644
--- a/mlir/test/CAPI/ir.c
+++ b/mlir/test/CAPI/ir.c
@@ -10,8 +10,8 @@
/* RUN: mlir-capi-ir-test 2>&1 | FileCheck %s
*/
-#include "mlir-c/AffineMap.h"
#include "mlir-c/IR.h"
+#include "mlir-c/AffineMap.h"
#include "mlir-c/Registration.h"
#include "mlir-c/StandardAttributes.h"
#include "mlir-c/StandardTypes.h"
@@ -443,6 +443,26 @@ static int printStandardTypes(MlirContext ctx) {
mlirTypeDump(tuple);
fprintf(stderr, "\n");
+ // Function type.
+ MlirType funcInputs[2] = {mlirIndexTypeGet(ctx), mlirIntegerTypeGet(ctx, 1)};
+ MlirType funcResults[3] = {mlirIntegerTypeGet(ctx, 16),
+ mlirIntegerTypeGet(ctx, 32),
+ mlirIntegerTypeGet(ctx, 64)};
+ MlirType funcType = mlirFunctionTypeGet(ctx, 2, funcInputs, 3, funcResults);
+ if (mlirFunctionTypeGetNumInputs(funcType) != 2)
+ return 21;
+ if (mlirFunctionTypeGetNumResults(funcType) != 3)
+ return 22;
+ if (!mlirTypeEqual(funcInputs[0], mlirFunctionTypeGetInput(funcType, 0)) ||
+ !mlirTypeEqual(funcInputs[1], mlirFunctionTypeGetInput(funcType, 1)))
+ return 23;
+ if (!mlirTypeEqual(funcResults[0], mlirFunctionTypeGetResult(funcType, 0)) ||
+ !mlirTypeEqual(funcResults[1], mlirFunctionTypeGetResult(funcType, 1)) ||
+ !mlirTypeEqual(funcResults[2], mlirFunctionTypeGetResult(funcType, 2)))
+ return 24;
+ mlirTypeDump(funcType);
+ fprintf(stderr, "\n");
+
return 0;
}
@@ -691,8 +711,7 @@ int printAffineMap(MlirContext ctx) {
return 2;
if (!mlirAffineMapIsEmpty(emptyAffineMap) ||
- mlirAffineMapIsEmpty(affineMap) ||
- mlirAffineMapIsEmpty(constAffineMap) ||
+ mlirAffineMapIsEmpty(affineMap) || mlirAffineMapIsEmpty(constAffineMap) ||
mlirAffineMapIsEmpty(multiDimIdentityAffineMap) ||
mlirAffineMapIsEmpty(minorIdentityAffineMap) ||
mlirAffineMapIsEmpty(permutationAffineMap))
@@ -859,6 +878,7 @@ int main() {
// CHECK: memref<2x3xf32, 2>
// CHECK: memref<*xf32, 4>
// CHECK: tuple<memref<*xf32, 4>, f32>
+ // CHECK: (index, i1) -> (i16, i32, i64)
// CHECK: 0
// clang-format on
fprintf(stderr, "@types\n");
More information about the llvm-branch-commits mailing list