[Mlir-commits] [mlir] d9343e6 - [mlir][python] Function decorator for capturing a FuncOp from a python function.

Stella Laurenzo llvmlistbot at llvm.org
Fri Mar 19 18:27:38 PDT 2021


Author: Stella Laurenzo
Date: 2021-03-19T18:27:21-07:00
New Revision: d9343e61534f54665b2be6dd8bc2e051220d3beb

URL: https://github.com/llvm/llvm-project/commit/d9343e61534f54665b2be6dd8bc2e051220d3beb
DIFF: https://github.com/llvm/llvm-project/commit/d9343e61534f54665b2be6dd8bc2e051220d3beb.diff

LOG: [mlir][python] Function decorator for capturing a FuncOp from a python function.

* Moves this out of a test case where it was being developed to good effect and generalizes it.
* Having tried a number of things like this, I think this balances concerns reasonably well.

Differential Revision: https://reviews.llvm.org/D98989

Added: 
    

Modified: 
    mlir/lib/Bindings/Python/mlir/dialects/_builtin_ops_ext.py
    mlir/test/Bindings/Python/dialects/builtin.py
    mlir/test/Bindings/Python/dialects/linalg/opdsl/emit_structured_generic.py

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Bindings/Python/mlir/dialects/_builtin_ops_ext.py b/mlir/lib/Bindings/Python/mlir/dialects/_builtin_ops_ext.py
index b0789299139d..dc1d37e766d0 100644
--- a/mlir/lib/Bindings/Python/mlir/dialects/_builtin_ops_ext.py
+++ b/mlir/lib/Bindings/Python/mlir/dialects/_builtin_ops_ext.py
@@ -1,6 +1,11 @@
 #  Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 #  See https://llvm.org/LICENSE.txt for license information.
 #  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+from typing import Optional, Sequence
+
+import inspect
+
 from ..ir import *
 
 
@@ -93,3 +98,99 @@ def add_entry_block(self):
       raise IndexError('The function already has an entry block!')
     self.body.blocks.append(*self.type.inputs)
     return self.body.blocks[0]
+
+  @classmethod
+  def from_py_func(FuncOp,
+                   *inputs: Type,
+                   results: Optional[Sequence[Type]] = None,
+                   name: Optional[str] = None):
+    """Decorator to define an MLIR FuncOp specified as a python function.
+
+    Requires that an `mlir.ir.InsertionPoint` and `mlir.ir.Location` are
+    active for the current thread (i.e. established in a `with` block).
+
+    When applied as a decorator to a Python function, an entry block will
+    be constructed for the FuncOp with types as specified in `*inputs`. The
+    block arguments will be passed positionally to the Python function. In
+    addition, if the Python function accepts keyword arguments generally or
+    has a corresponding keyword argument, the following will be passed:
+      * `func_op`: The `func` op being defined.
+
+    By default, the function name will be the Python function `__name__`. This
+    can be overriden by passing the `name` argument to the decorator.
+
+    If `results` is not specified, then the decorator will implicitly
+    insert a `ReturnOp` with the `Value`'s returned from the decorated
+    function. It will also set the `FuncOp` type with the actual return
+    value types. If `results` is specified, then the decorated function
+    must return `None` and no implicit `ReturnOp` is added (nor are the result
+    types updated). The implicit behavior is intended for simple, single-block
+    cases, and users should specify result types explicitly for any complicated
+    cases.
+
+    The decorated function can further be called from Python and will insert
+    a `CallOp` at the then-current insertion point, returning either None (
+    if no return values), a unary Value (for one result), or a list of Values).
+    This mechanism cannot be used to emit recursive calls (by construction).
+    """
+
+    def decorator(f):
+      from . import std
+      # Introspect the callable for optional features.
+      sig = inspect.signature(f)
+      has_arg_func_op = False
+      for param in sig.parameters.values():
+        if param.kind == param.VAR_KEYWORD:
+          has_arg_func_op = True
+        if param.name == "func_op" and (param.kind
+                                        == param.POSITIONAL_OR_KEYWORD or
+                                        param.kind == param.KEYWORD_ONLY):
+          has_arg_func_op = True
+
+      # Emit the FuncOp.
+      implicit_return = results is None
+      symbol_name = name or f.__name__
+      function_type = FunctionType.get(
+          inputs=inputs, results=[] if implicit_return else results)
+      func_op = FuncOp(name=symbol_name, type=function_type)
+      with InsertionPoint(func_op.add_entry_block()):
+        func_args = func_op.entry_block.arguments
+        func_kwargs = {}
+        if has_arg_func_op:
+          func_kwargs["func_op"] = func_op
+        return_values = f(*func_args, **func_kwargs)
+        if not implicit_return:
+          return_types = list(results)
+          assert return_values is None, (
+              "Capturing a python function with explicit `results=` "
+              "requires that the wrapped function returns None.")
+        else:
+          # Coerce return values, add ReturnOp and rewrite func type.
+          if return_values is None:
+            return_values = []
+          elif isinstance(return_values, Value):
+            return_values = [return_values]
+          else:
+            return_values = list(return_values)
+          std.ReturnOp(return_values)
+          # Recompute the function type.
+          return_types = [v.type for v in return_values]
+          function_type = FunctionType.get(inputs=inputs, results=return_types)
+          func_op.attributes["type"] = TypeAttr.get(function_type)
+
+      def emit_call_op(*call_args):
+        call_op = std.CallOp(return_types, FlatSymbolRefAttr.get(symbol_name),
+                             call_args)
+        if return_types is None:
+          return None
+        elif len(return_types) == 1:
+          return call_op.result
+        else:
+          return call_op.results
+
+      wrapped = emit_call_op
+      wrapped.__name__ = f.__name__
+      wrapped.func_op = func_op
+      return wrapped
+
+    return decorator

diff  --git a/mlir/test/Bindings/Python/dialects/builtin.py b/mlir/test/Bindings/Python/dialects/builtin.py
index 447a255f6021..80dea68bae36 100644
--- a/mlir/test/Bindings/Python/dialects/builtin.py
+++ b/mlir/test/Bindings/Python/dialects/builtin.py
@@ -8,9 +8,106 @@
 def run(f):
   print("\nTEST:", f.__name__)
   f()
+  return f
+
+
+# CHECK-LABEL: TEST: testFromPyFunc
+@run
+def testFromPyFunc():
+  with Context() as ctx, Location.unknown() as loc:
+    m = builtin.ModuleOp()
+    f32 = F32Type.get()
+    f64 = F64Type.get()
+    with InsertionPoint.at_block_terminator(m.body):
+      # CHECK-LABEL: func @unary_return(%arg0: f64) -> f64
+      # CHECK: return %arg0 : f64
+      @builtin.FuncOp.from_py_func(f64)
+      def unary_return(a):
+        return a
+
+      # CHECK-LABEL: func @binary_return(%arg0: f32, %arg1: f64) -> (f32, f64)
+      # CHECK: return %arg0, %arg1 : f32, f64
+      @builtin.FuncOp.from_py_func(f32, f64)
+      def binary_return(a, b):
+        return a, b
+
+      # CHECK-LABEL: func @none_return(%arg0: f32, %arg1: f64)
+      # CHECK: return
+      @builtin.FuncOp.from_py_func(f32, f64)
+      def none_return(a, b):
+        pass
+
+      # CHECK-LABEL: func @call_unary
+      # CHECK: %0 = call @unary_return(%arg0) : (f64) -> f64
+      # CHECK: return %0 : f64
+      @builtin.FuncOp.from_py_func(f64)
+      def call_unary(a):
+        return unary_return(a)
+
+      # CHECK-LABEL: func @call_binary
+      # CHECK: %0:2 = call @binary_return(%arg0, %arg1) : (f32, f64) -> (f32, f64)
+      # CHECK: return %0#0, %0#1 : f32, f64
+      @builtin.FuncOp.from_py_func(f32, f64)
+      def call_binary(a, b):
+        return binary_return(a, b)
+
+      # CHECK-LABEL: func @call_none
+      # CHECK: call @none_return(%arg0, %arg1) : (f32, f64) -> ()
+      # CHECK: return
+      @builtin.FuncOp.from_py_func(f32, f64)
+      def call_none(a, b):
+        return none_return(a, b)
+
+      ## Variants and optional feature tests.
+      # CHECK-LABEL: func @from_name_arg
+      @builtin.FuncOp.from_py_func(f32, f64, name="from_name_arg")
+      def explicit_name(a, b):
+        return b
+
+      @builtin.FuncOp.from_py_func(f32, f64)
+      def positional_func_op(a, b, func_op):
+        assert isinstance(func_op, builtin.FuncOp)
+        return b
+
+      @builtin.FuncOp.from_py_func(f32, f64)
+      def kw_func_op(a, b=None, func_op=None):
+        assert isinstance(func_op, builtin.FuncOp)
+        return b
+
+      @builtin.FuncOp.from_py_func(f32, f64)
+      def kwargs_func_op(a, b=None, **kwargs):
+        assert isinstance(kwargs["func_op"], builtin.FuncOp)
+        return b
+
+      # CHECK-LABEL: func @explicit_results(%arg0: f32, %arg1: f64) -> f64
+      # CHECK: return %arg1 : f64
+      @builtin.FuncOp.from_py_func(f32, f64, results=[f64])
+      def explicit_results(a, b):
+        std.ReturnOp([b])
+
+  print(m)
+
+
+# CHECK-LABEL: TEST: testFromPyFuncErrors
+@run
+def testFromPyFuncErrors():
+  with Context() as ctx, Location.unknown() as loc:
+    m = builtin.ModuleOp()
+    f32 = F32Type.get()
+    f64 = F64Type.get()
+    with InsertionPoint.at_block_terminator(m.body):
+      try:
+
+        @builtin.FuncOp.from_py_func(f64, results=[f64])
+        def unary_return(a):
+          return a
+      except AssertionError as e:
+        # CHECK: Capturing a python function with explicit `results=` requires that the wrapped function returns None.
+        print(e)
 
 
 # CHECK-LABEL: TEST: testBuildFuncOp
+@run
 def testBuildFuncOp():
   ctx = Context()
   with Location.unknown(ctx) as loc:
@@ -64,6 +161,3 @@ def testBuildFuncOp():
   # CHECK:   return %arg0 : tensor<2x3x4xf32>
   # CHECK:  }
   print(m)
-
-
-run(testBuildFuncOp)

diff  --git a/mlir/test/Bindings/Python/dialects/linalg/opdsl/emit_structured_generic.py b/mlir/test/Bindings/Python/dialects/linalg/opdsl/emit_structured_generic.py
index 7f8c11679457..573999c97525 100644
--- a/mlir/test/Bindings/Python/dialects/linalg/opdsl/emit_structured_generic.py
+++ b/mlir/test/Bindings/Python/dialects/linalg/opdsl/emit_structured_generic.py
@@ -10,46 +10,6 @@
 from mlir.dialects.linalg.opdsl.lang import *
 
 
-# TODO: Find a home for this quality of life helper.
-def build_function(*inputs: Type, results: Optional[Sequence[Type]] = None):
-  """Decorator that emits a function in a more pythonic way.
-
-  If result types are not specified, they are inferred from the function
-  returns. The `ReturnOp` is implicitly added upon the wrapped function return.
-  """
-
-  def decorator(f):
-    return_types = results
-    symbol_name = f.__name__
-    function_type = FunctionType.get(inputs=inputs, results=results or [])
-    func_op = builtin.FuncOp(name=symbol_name, type=function_type)
-    with InsertionPoint(func_op.add_entry_block()):
-      func_args = func_op.entry_block.arguments
-      return_values = f(*func_args)
-      if return_values is None:
-        return_values = []
-      elif isinstance(return_values, Value):
-        return_values = [return_values]
-      else:
-        return_values = list(return_values)
-      std.ReturnOp(return_values)
-      if return_types is None:
-        # Recompute the function type.
-        return_types = [v.type for v in return_values]
-        function_type = FunctionType.get(inputs=inputs, results=return_types)
-        # TODO: Have an API or a setter for this.
-        func_op.attributes["type"] = TypeAttr.get(function_type)
-
-    # TODO: When turning this into a real facility, return a function that emits
-    # a `call` to the function instead of doing nothing.
-    wrapped = lambda: None
-    wrapped.__name__ = symbol_name
-    wrapped.func_op = func_op
-    return wrapped
-
-  return decorator
-
-
 @linalg_structured_op
 def matmul_mono(A=TensorDef(T, S.M, S.K),
                 B=TensorDef(T, S.K, S.N),
@@ -92,8 +52,8 @@ def matmul_poly(A=TensorDef(TV.T1, S.M, S.K),
     # CHECK-SAME: ins(%[[A]], %[[B]]
     # CHECK-SAME: outs(%[[INITC]]
 
-    @build_function(RankedTensorType.get((4, 16), f32),
-                    RankedTensorType.get((16, 8), f32))
+    @builtin.FuncOp.from_py_func(RankedTensorType.get((4, 16), f32),
+                                 RankedTensorType.get((16, 8), f32))
     def test_matmul_mono(lhs, rhs):
       # TODO: Enable outs inference and add sugar for InitTensorOp
       # construction.
@@ -114,9 +74,9 @@ def test_matmul_mono(lhs, rhs):
     # CHECK-NEXT:   %[[ADD:.+]] = addi %[[C_ARG]], %[[MUL]] : i32
     # CHECK-NEXT:   linalg.yield %[[ADD]] : i32
     # CHECK-NEXT: -> tensor<4x8xi32>
-    @build_function(RankedTensorType.get((4, 16), i8),
-                    RankedTensorType.get((16, 8), i8),
-                    RankedTensorType.get((4, 8), i32))
+    @builtin.FuncOp.from_py_func(RankedTensorType.get((4, 16), i8),
+                                 RankedTensorType.get((16, 8), i8),
+                                 RankedTensorType.get((4, 8), i32))
     def test_i8i8i32_matmul(lhs, rhs, init_result):
       return matmul_poly(lhs, rhs, outs=[init_result])
 
@@ -128,9 +88,9 @@ def test_i8i8i32_matmul(lhs, rhs, init_result):
     # CHECK-NEXT:   %[[ADD:.+]] = addi %[[C_ARG]], %[[MUL]] : i32
     # CHECK-NEXT:   linalg.yield %[[ADD]] : i32
     # CHECK-NEXT: -> tensor<4x8xi32>
-    @build_function(RankedTensorType.get((4, 16), i8),
-                    RankedTensorType.get((16, 8), i16),
-                    RankedTensorType.get((4, 8), i32))
+    @builtin.FuncOp.from_py_func(RankedTensorType.get((4, 16), i8),
+                                 RankedTensorType.get((16, 8), i16),
+                                 RankedTensorType.get((4, 8), i32))
     def test_i8i16i32_matmul(lhs, rhs, init_result):
       return matmul_poly(lhs, rhs, outs=[init_result])
 
@@ -142,9 +102,9 @@ def test_i8i16i32_matmul(lhs, rhs, init_result):
     # CHECK-NEXT:   %[[ADD:.+]] = addi %[[C_ARG]], %[[MUL]] : i16
     # CHECK-NEXT:   linalg.yield %[[ADD]] : i16
     # CHECK-NEXT: -> tensor<4x8xi16>
-    @build_function(RankedTensorType.get((4, 16), i32),
-                    RankedTensorType.get((16, 8), i32),
-                    RankedTensorType.get((4, 8), i16))
+    @builtin.FuncOp.from_py_func(RankedTensorType.get((4, 16), i32),
+                                 RankedTensorType.get((16, 8), i32),
+                                 RankedTensorType.get((4, 8), i16))
     def test_i32i32i16_matmul(lhs, rhs, init_result):
       return matmul_poly(lhs, rhs, outs=[init_result])
 
@@ -156,9 +116,9 @@ def test_i32i32i16_matmul(lhs, rhs, init_result):
     # CHECK-NEXT:   %[[ADD:.+]] = addf %[[C_ARG]], %[[MUL]] : f32
     # CHECK-NEXT:   linalg.yield %[[ADD]] : f32
     # CHECK-NEXT: -> tensor<4x8xf32>
-    @build_function(RankedTensorType.get((4, 16), i8),
-                    RankedTensorType.get((16, 8), i8),
-                    RankedTensorType.get((4, 8), f32))
+    @builtin.FuncOp.from_py_func(RankedTensorType.get((4, 16), i8),
+                                 RankedTensorType.get((16, 8), i8),
+                                 RankedTensorType.get((4, 8), f32))
     def test_i8i8f32_matmul(lhs, rhs, init_result):
       return matmul_poly(lhs, rhs, outs=[init_result])
 
@@ -170,9 +130,9 @@ def test_i8i8f32_matmul(lhs, rhs, init_result):
     # CHECK-NEXT:   %[[ADD:.+]] = addf %[[C_ARG]], %[[MUL]] : f32
     # CHECK-NEXT:   linalg.yield %[[ADD]] : f32
     # CHECK-NEXT: -> tensor<4x8xf32>
-    @build_function(RankedTensorType.get((4, 16), f16),
-                    RankedTensorType.get((16, 8), f16),
-                    RankedTensorType.get((4, 8), f32))
+    @builtin.FuncOp.from_py_func(RankedTensorType.get((4, 16), f16),
+                                 RankedTensorType.get((16, 8), f16),
+                                 RankedTensorType.get((4, 8), f32))
     def test_f16f16f32_matmul(lhs, rhs, init_result):
       return matmul_poly(lhs, rhs, outs=[init_result])
 
@@ -184,9 +144,9 @@ def test_f16f16f32_matmul(lhs, rhs, init_result):
     # CHECK-NEXT:   %[[ADD:.+]] = addf %[[C_ARG]], %[[MUL]] : f32
     # CHECK-NEXT:   linalg.yield %[[ADD]] : f32
     # CHECK-NEXT: -> tensor<4x8xf32>
-    @build_function(RankedTensorType.get((4, 16), f64),
-                    RankedTensorType.get((16, 8), f64),
-                    RankedTensorType.get((4, 8), f32))
+    @builtin.FuncOp.from_py_func(RankedTensorType.get((4, 16), f64),
+                                 RankedTensorType.get((16, 8), f64),
+                                 RankedTensorType.get((4, 8), f32))
     def test_f64f64f32_matmul(lhs, rhs, init_result):
       return matmul_poly(lhs, rhs, outs=[init_result])
 


        


More information about the Mlir-commits mailing list