[Mlir-commits] [mlir] 76753a5 - Add FunctionType to MLIR C and Python bindings.

Stella Laurenzo llvmlistbot at llvm.org
Mon Sep 28 09:57:32 PDT 2020


Author: Stella Laurenzo
Date: 2020-09-28T09:56:48-07:00
New Revision: 76753a597b5d9bf4addf19399ae30c4b3870a4a6

URL: https://github.com/llvm/llvm-project/commit/76753a597b5d9bf4addf19399ae30c4b3870a4a6
DIFF: https://github.com/llvm/llvm-project/commit/76753a597b5d9bf4addf19399ae30c4b3870a4a6.diff

LOG: Add FunctionType to MLIR C and Python bindings.

Differential Revision: https://reviews.llvm.org/D88416

Added: 
    

Modified: 
    mlir/include/mlir-c/StandardTypes.h
    mlir/lib/Bindings/Python/IRModules.cpp
    mlir/lib/CAPI/IR/StandardTypes.cpp
    mlir/test/Bindings/Python/ir_types.py
    mlir/test/CAPI/ir.c

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir-c/StandardTypes.h b/mlir/include/mlir-c/StandardTypes.h
index 4bbcb23b8002e..3b667c9e0a1b0 100644
--- a/mlir/include/mlir-c/StandardTypes.h
+++ b/mlir/include/mlir-c/StandardTypes.h
@@ -270,6 +270,30 @@ intptr_t mlirTupleTypeGetNumTypes(MlirType type);
 /** Returns the pos-th type in the tuple type. */
 MlirType mlirTupleTypeGetType(MlirType type, intptr_t pos);
 
+/*============================================================================*/
+/* Function type.                                                             */
+/*============================================================================*/
+
+/** Returns non-zero if the given type is a function type, zero otherwise. */
+int mlirTypeIsAFunction(MlirType type);
+
+/** Creates a function type in `ctx` mapping `inputs` to `results`. */
+MlirType mlirFunctionTypeGet(MlirContext ctx, intptr_t numInputs,
+                             MlirType *inputs, intptr_t numResults,
+                             MlirType *results);
+
+/** Returns the number of input types of the given function type. */
+intptr_t mlirFunctionTypeGetNumInputs(MlirType type);
+
+/** Returns the number of result types of the given function type. */
+intptr_t mlirFunctionTypeGetNumResults(MlirType type);
+
+/** Returns the pos-th input type, counting from zero; `pos` must be >= 0. */
+MlirType mlirFunctionTypeGetInput(MlirType type, intptr_t pos);
+
+/** Returns the pos-th result type, counting from zero; `pos` must be >= 0. */
+MlirType mlirFunctionTypeGetResult(MlirType type, intptr_t pos);
+
 #ifdef __cplusplus
 }
 #endif

diff  --git a/mlir/lib/Bindings/Python/IRModules.cpp b/mlir/lib/Bindings/Python/IRModules.cpp
index 99b00b96b974c..f3bd96856d090 100644
--- a/mlir/lib/Bindings/Python/IRModules.cpp
+++ b/mlir/lib/Bindings/Python/IRModules.cpp
@@ -1278,6 +1278,58 @@ class PyTupleType : public PyConcreteType<PyTupleType> {
   }
 };
 
+/// Function type: maps a list of input types to a list of result types.
+class PyFunctionType : public PyConcreteType<PyFunctionType> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFunction;
+  static constexpr const char *pyClassName = "FunctionType";
+  using PyConcreteType::PyConcreteType;
+
+  static void bindDerived(ClassTy &c) {
+    c.def_static(
+        "get",
+        [](PyMlirContext &context, std::vector<PyType> inputs,
+           std::vector<PyType> results) {
+          // The C API takes (count, pointer) pairs of raw MlirType handles,
+          // so lower the Python-side wrappers into contiguous storage first.
+          SmallVector<MlirType, 4> inputsRaw(inputs.begin(), inputs.end());
+          SmallVector<MlirType, 4> resultsRaw(results.begin(), results.end());
+          MlirType t = mlirFunctionTypeGet(context.get(), inputsRaw.size(),
+                                           inputsRaw.data(), resultsRaw.size(),
+                                           resultsRaw.data());
+          return PyFunctionType(context.getRef(), t);
+        },
+        py::arg("context"), py::arg("inputs"), py::arg("results"),
+        "Gets a FunctionType from a list of input and result types.");
+    c.def_property_readonly(
+        "inputs",
+        [](PyFunctionType &self) {
+          MlirType t = self.type;
+          auto contextRef = self.getContext();
+          py::list types;
+          for (intptr_t i = 0, e = mlirFunctionTypeGetNumInputs(t); i < e;
+               ++i) {
+            types.append(PyType(contextRef, mlirFunctionTypeGetInput(t, i)));
+          }
+          return types;
+        },
+        "Returns the list of input types in the FunctionType.");
+    c.def_property_readonly(
+        "results",
+        [](PyFunctionType &self) {
+          MlirType t = self.type;
+          auto contextRef = self.getContext();
+          py::list types;
+          for (intptr_t i = 0, e = mlirFunctionTypeGetNumResults(t); i < e;
+               ++i) {
+            types.append(PyType(contextRef, mlirFunctionTypeGetResult(t, i)));
+          }
+          return types;
+        },
+        "Returns the list of result types in the FunctionType.");
+  }
+};
+
 } // namespace
 
//------------------------------------------------------------------------------
@@ -1613,6 +1665,7 @@ void mlir::python::populateIRSubmodule(py::module &m) {
   PyMemRefType::bind(m);
   PyUnrankedMemRefType::bind(m);
   PyTupleType::bind(m);
+  PyFunctionType::bind(m);
 
   // Container bindings.
   PyBlockIterator::bind(m);

diff  --git a/mlir/lib/CAPI/IR/StandardTypes.cpp b/mlir/lib/CAPI/IR/StandardTypes.cpp
index ddd3a5e93147a..b4e37d38ace31 100644
--- a/mlir/lib/CAPI/IR/StandardTypes.cpp
+++ b/mlir/lib/CAPI/IR/StandardTypes.cpp
@@ -13,6 +13,7 @@
 #include "mlir/CAPI/IR.h"
 #include "mlir/IR/AffineMap.h"
 #include "mlir/IR/StandardTypes.h"
+#include "mlir/IR/Types.h"
 
 using namespace mlir;
 
@@ -297,3 +298,42 @@ intptr_t mlirTupleTypeGetNumTypes(MlirType type) {
 MlirType mlirTupleTypeGetType(MlirType type, intptr_t pos) {
   return wrap(unwrap(type).cast<TupleType>().getType(static_cast<size_t>(pos)));
 }
+
+/*============================================================================*/
+/* Function type.                                                             */
+/*============================================================================*/
+
+int mlirTypeIsAFunction(MlirType type) {
+  return unwrap(type).isa<FunctionType>();
+}
+
+MlirType mlirFunctionTypeGet(MlirContext ctx, intptr_t numInputs,
+                             MlirType *inputs, intptr_t numResults,
+                             MlirType *results) {
+  SmallVector<Type, 4> inputsList;
+  SmallVector<Type, 4> resultsList;
+  // Only the in-place fill of the SmallVectors is needed; discard the result.
+  (void)unwrapList(numInputs, inputs, inputsList);
+  (void)unwrapList(numResults, results, resultsList);
+  return wrap(FunctionType::get(inputsList, resultsList, unwrap(ctx)));
+}
+
+intptr_t mlirFunctionTypeGetNumInputs(MlirType type) {
+  return unwrap(type).cast<FunctionType>().getNumInputs();
+}
+
+intptr_t mlirFunctionTypeGetNumResults(MlirType type) {
+  return unwrap(type).cast<FunctionType>().getNumResults();
+}
+
+MlirType mlirFunctionTypeGetInput(MlirType type, intptr_t pos) {
+  assert(pos >= 0 && "pos in array must be non-negative");
+  return wrap(
+      unwrap(type).cast<FunctionType>().getInput(static_cast<unsigned>(pos)));
+}
+
+MlirType mlirFunctionTypeGetResult(MlirType type, intptr_t pos) {
+  assert(pos >= 0 && "pos in array must be non-negative");
+  return wrap(
+      unwrap(type).cast<FunctionType>().getResult(static_cast<unsigned>(pos)));
+}

diff  --git a/mlir/test/Bindings/Python/ir_types.py b/mlir/test/Bindings/Python/ir_types.py
index b80cbebb10e24..d8ae77f1f092d 100644
--- a/mlir/test/Bindings/Python/ir_types.py
+++ b/mlir/test/Bindings/Python/ir_types.py
@@ -392,3 +392,19 @@ def testTupleType():
   print("pos-th type in the tuple type:", tuple_type.get_type(1))
 
 run(testTupleType)
+
+
+# CHECK-LABEL: TEST: testFunctionType
+def testFunctionType():
+  context = mlir.ir.Context()
+  i32 = mlir.ir.IntegerType.get_signless(context, 32)
+  i16 = mlir.ir.IntegerType.get_signless(context, 16)
+  index = mlir.ir.IndexType(context)
+  fn_type = mlir.ir.FunctionType.get(context, [i32, i16], [index])
+  # CHECK: INPUTS: [Type(i32), Type(i16)]
+  print("INPUTS:", fn_type.inputs)
+  # CHECK: RESULTS: [Type(index)]
+  print("RESULTS:", fn_type.results)
+
+
+run(testFunctionType)

diff  --git a/mlir/test/CAPI/ir.c b/mlir/test/CAPI/ir.c
index 4849111986cda..909929647a84a 100644
--- a/mlir/test/CAPI/ir.c
+++ b/mlir/test/CAPI/ir.c
@@ -10,8 +10,8 @@
 /* RUN: mlir-capi-ir-test 2>&1 | FileCheck %s
  */
 
-#include "mlir-c/AffineMap.h"
 #include "mlir-c/IR.h"
+#include "mlir-c/AffineMap.h"
 #include "mlir-c/Registration.h"
 #include "mlir-c/StandardAttributes.h"
 #include "mlir-c/StandardTypes.h"
@@ -443,6 +443,28 @@ static int printStandardTypes(MlirContext ctx) {
   mlirTypeDump(tuple);
   fprintf(stderr, "\n");
 
+  // Function type: (index, i1) -> (i16, i32, i64).
+  MlirType funcInputs[2] = {mlirIndexTypeGet(ctx), mlirIntegerTypeGet(ctx, 1)};
+  MlirType funcResults[3] = {mlirIntegerTypeGet(ctx, 16),
+                             mlirIntegerTypeGet(ctx, 32),
+                             mlirIntegerTypeGet(ctx, 64)};
+  MlirType funcType = mlirFunctionTypeGet(ctx, 2, funcInputs, 3, funcResults);
+  if (!mlirTypeIsAFunction(funcType))
+    return 21;
+  if (mlirFunctionTypeGetNumInputs(funcType) != 2)
+    return 22;
+  if (mlirFunctionTypeGetNumResults(funcType) != 3)
+    return 23;
+  if (!mlirTypeEqual(funcInputs[0], mlirFunctionTypeGetInput(funcType, 0)) ||
+      !mlirTypeEqual(funcInputs[1], mlirFunctionTypeGetInput(funcType, 1)))
+    return 24;
+  if (!mlirTypeEqual(funcResults[0], mlirFunctionTypeGetResult(funcType, 0)) ||
+      !mlirTypeEqual(funcResults[1], mlirFunctionTypeGetResult(funcType, 1)) ||
+      !mlirTypeEqual(funcResults[2], mlirFunctionTypeGetResult(funcType, 2)))
+    return 25;
+  mlirTypeDump(funcType);
+  fprintf(stderr, "\n");
+
   return 0;
 }
 
@@ -691,8 +713,7 @@ int printAffineMap(MlirContext ctx) {
     return 2;
 
   if (!mlirAffineMapIsEmpty(emptyAffineMap) ||
-      mlirAffineMapIsEmpty(affineMap) ||
-      mlirAffineMapIsEmpty(constAffineMap) ||
+      mlirAffineMapIsEmpty(affineMap) || mlirAffineMapIsEmpty(constAffineMap) ||
       mlirAffineMapIsEmpty(multiDimIdentityAffineMap) ||
       mlirAffineMapIsEmpty(minorIdentityAffineMap) ||
       mlirAffineMapIsEmpty(permutationAffineMap))
@@ -859,6 +880,7 @@ int main() {
   // CHECK: memref<2x3xf32, 2>
   // CHECK: memref<*xf32, 4>
   // CHECK: tuple<memref<*xf32, 4>, f32>
+  // CHECK: (index, i1) -> (i16, i32, i64)
   // CHECK: 0
   // clang-format on
   fprintf(stderr, "@types\n");


        


More information about the Mlir-commits mailing list