[Mlir-commits] [mlir] d9343e6 - [mlir][python] Function decorator for capturing a FuncOp from a python function.
Stella Laurenzo
llvmlistbot at llvm.org
Fri Mar 19 18:27:38 PDT 2021
Author: Stella Laurenzo
Date: 2021-03-19T18:27:21-07:00
New Revision: d9343e61534f54665b2be6dd8bc2e051220d3beb
URL: https://github.com/llvm/llvm-project/commit/d9343e61534f54665b2be6dd8bc2e051220d3beb
DIFF: https://github.com/llvm/llvm-project/commit/d9343e61534f54665b2be6dd8bc2e051220d3beb.diff
LOG: [mlir][python] Function decorator for capturing a FuncOp from a python function.
* Moves this out of a test case where it was being developed to good effect and generalizes it.
* Having tried a number of things like this, I think this balances concerns reasonably well.
Differential Revision: https://reviews.llvm.org/D98989
Added:
Modified:
mlir/lib/Bindings/Python/mlir/dialects/_builtin_ops_ext.py
mlir/test/Bindings/Python/dialects/builtin.py
mlir/test/Bindings/Python/dialects/linalg/opdsl/emit_structured_generic.py
Removed:
################################################################################
diff --git a/mlir/lib/Bindings/Python/mlir/dialects/_builtin_ops_ext.py b/mlir/lib/Bindings/Python/mlir/dialects/_builtin_ops_ext.py
index b0789299139d..dc1d37e766d0 100644
--- a/mlir/lib/Bindings/Python/mlir/dialects/_builtin_ops_ext.py
+++ b/mlir/lib/Bindings/Python/mlir/dialects/_builtin_ops_ext.py
@@ -1,6 +1,11 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+from typing import Optional, Sequence
+
+import inspect
+
from ..ir import *
@@ -93,3 +98,99 @@ def add_entry_block(self):
raise IndexError('The function already has an entry block!')
self.body.blocks.append(*self.type.inputs)
return self.body.blocks[0]
+
+ @classmethod
+ def from_py_func(FuncOp,
+ *inputs: Type,
+ results: Optional[Sequence[Type]] = None,
+ name: Optional[str] = None):
+ """Decorator to define an MLIR FuncOp specified as a python function.
+
+ Requires that an `mlir.ir.InsertionPoint` and `mlir.ir.Location` are
+ active for the current thread (i.e. established in a `with` block).
+
+ When applied as a decorator to a Python function, an entry block will
+ be constructed for the FuncOp with types as specified in `*inputs`. The
+ block arguments will be passed positionally to the Python function. In
+ addition, if the Python function accepts keyword arguments generally or
+ has a corresponding keyword argument, the following will be passed:
+ * `func_op`: The `func` op being defined.
+
+ By default, the function name will be the Python function `__name__`. This
+ can be overriden by passing the `name` argument to the decorator.
+
+ If `results` is not specified, then the decorator will implicitly
+ insert a `ReturnOp` with the `Value`'s returned from the decorated
+ function. It will also set the `FuncOp` type with the actual return
+ value types. If `results` is specified, then the decorated function
+ must return `None` and no implicit `ReturnOp` is added (nor are the result
+ types updated). The implicit behavior is intended for simple, single-block
+ cases, and users should specify result types explicitly for any complicated
+ cases.
+
+ The decorated function can further be called from Python and will insert
+ a `CallOp` at the then-current insertion point, returning either None (
+ if no return values), a unary Value (for one result), or a list of Values).
+ This mechanism cannot be used to emit recursive calls (by construction).
+ """
+
+ def decorator(f):
+ from . import std
+ # Introspect the callable for optional features.
+ sig = inspect.signature(f)
+ has_arg_func_op = False
+ for param in sig.parameters.values():
+ if param.kind == param.VAR_KEYWORD:
+ has_arg_func_op = True
+ if param.name == "func_op" and (param.kind
+ == param.POSITIONAL_OR_KEYWORD or
+ param.kind == param.KEYWORD_ONLY):
+ has_arg_func_op = True
+
+ # Emit the FuncOp.
+ implicit_return = results is None
+ symbol_name = name or f.__name__
+ function_type = FunctionType.get(
+ inputs=inputs, results=[] if implicit_return else results)
+ func_op = FuncOp(name=symbol_name, type=function_type)
+ with InsertionPoint(func_op.add_entry_block()):
+ func_args = func_op.entry_block.arguments
+ func_kwargs = {}
+ if has_arg_func_op:
+ func_kwargs["func_op"] = func_op
+ return_values = f(*func_args, **func_kwargs)
+ if not implicit_return:
+ return_types = list(results)
+ assert return_values is None, (
+ "Capturing a python function with explicit `results=` "
+ "requires that the wrapped function returns None.")
+ else:
+ # Coerce return values, add ReturnOp and rewrite func type.
+ if return_values is None:
+ return_values = []
+ elif isinstance(return_values, Value):
+ return_values = [return_values]
+ else:
+ return_values = list(return_values)
+ std.ReturnOp(return_values)
+ # Recompute the function type.
+ return_types = [v.type for v in return_values]
+ function_type = FunctionType.get(inputs=inputs, results=return_types)
+ func_op.attributes["type"] = TypeAttr.get(function_type)
+
+ def emit_call_op(*call_args):
+ call_op = std.CallOp(return_types, FlatSymbolRefAttr.get(symbol_name),
+ call_args)
+ if return_types is None:
+ return None
+ elif len(return_types) == 1:
+ return call_op.result
+ else:
+ return call_op.results
+
+ wrapped = emit_call_op
+ wrapped.__name__ = f.__name__
+ wrapped.func_op = func_op
+ return wrapped
+
+ return decorator
diff --git a/mlir/test/Bindings/Python/dialects/builtin.py b/mlir/test/Bindings/Python/dialects/builtin.py
index 447a255f6021..80dea68bae36 100644
--- a/mlir/test/Bindings/Python/dialects/builtin.py
+++ b/mlir/test/Bindings/Python/dialects/builtin.py
@@ -8,9 +8,106 @@
+# Test driver: prints a "TEST:" banner with the function name, runs the
+# function, and returns it unchanged so it can be applied as a decorator.
def run(f):
print("\nTEST:", f.__name__)
f()
+ # Returning f keeps the decorated test name bound to the function.
+ return f
+
+
+# CHECK-LABEL: TEST: testFromPyFunc
+ at run
+def testFromPyFunc():
+ with Context() as ctx, Location.unknown() as loc:
+ m = builtin.ModuleOp()
+ f32 = F32Type.get()
+ f64 = F64Type.get()
+ with InsertionPoint.at_block_terminator(m.body):
+ # CHECK-LABEL: func @unary_return(%arg0: f64) -> f64
+ # CHECK: return %arg0 : f64
+ @builtin.FuncOp.from_py_func(f64)
+ def unary_return(a):
+ return a
+
+ # CHECK-LABEL: func @binary_return(%arg0: f32, %arg1: f64) -> (f32, f64)
+ # CHECK: return %arg0, %arg1 : f32, f64
+ @builtin.FuncOp.from_py_func(f32, f64)
+ def binary_return(a, b):
+ return a, b
+
+ # CHECK-LABEL: func @none_return(%arg0: f32, %arg1: f64)
+ # CHECK: return
+ @builtin.FuncOp.from_py_func(f32, f64)
+ def none_return(a, b):
+ pass
+
+ # CHECK-LABEL: func @call_unary
+ # CHECK: %0 = call @unary_return(%arg0) : (f64) -> f64
+ # CHECK: return %0 : f64
+ @builtin.FuncOp.from_py_func(f64)
+ def call_unary(a):
+ return unary_return(a)
+
+ # CHECK-LABEL: func @call_binary
+ # CHECK: %0:2 = call @binary_return(%arg0, %arg1) : (f32, f64) -> (f32, f64)
+ # CHECK: return %0#0, %0#1 : f32, f64
+ @builtin.FuncOp.from_py_func(f32, f64)
+ def call_binary(a, b):
+ return binary_return(a, b)
+
+ # CHECK-LABEL: func @call_none
+ # CHECK: call @none_return(%arg0, %arg1) : (f32, f64) -> ()
+ # CHECK: return
+ @builtin.FuncOp.from_py_func(f32, f64)
+ def call_none(a, b):
+ return none_return(a, b)
+
+ ## Variants and optional feature tests.
+ # CHECK-LABEL: func @from_name_arg
+ @builtin.FuncOp.from_py_func(f32, f64, name="from_name_arg")
+ def explicit_name(a, b):
+ return b
+
+ @builtin.FuncOp.from_py_func(f32, f64)
+ def positional_func_op(a, b, func_op):
+ assert isinstance(func_op, builtin.FuncOp)
+ return b
+
+ @builtin.FuncOp.from_py_func(f32, f64)
+ def kw_func_op(a, b=None, func_op=None):
+ assert isinstance(func_op, builtin.FuncOp)
+ return b
+
+ @builtin.FuncOp.from_py_func(f32, f64)
+ def kwargs_func_op(a, b=None, **kwargs):
+ assert isinstance(kwargs["func_op"], builtin.FuncOp)
+ return b
+
+ # CHECK-LABEL: func @explicit_results(%arg0: f32, %arg1: f64) -> f64
+ # CHECK: return %arg1 : f64
+ @builtin.FuncOp.from_py_func(f32, f64, results=[f64])
+ def explicit_results(a, b):
+ std.ReturnOp([b])
+
+ print(m)
+
+
+# CHECK-LABEL: TEST: testFromPyFuncErrors
+ at run
+def testFromPyFuncErrors():
+ with Context() as ctx, Location.unknown() as loc:
+ m = builtin.ModuleOp()
+ f32 = F32Type.get()
+ f64 = F64Type.get()
+ with InsertionPoint.at_block_terminator(m.body):
+ try:
+
+ @builtin.FuncOp.from_py_func(f64, results=[f64])
+ def unary_return(a):
+ return a
+ except AssertionError as e:
+ # CHECK: Capturing a python function with explicit `results=` requires that the wrapped function returns None.
+ print(e)
# CHECK-LABEL: TEST: testBuildFuncOp
+ at run
def testBuildFuncOp():
ctx = Context()
with Location.unknown(ctx) as loc:
@@ -64,6 +161,3 @@ def testBuildFuncOp():
# CHECK: return %arg0 : tensor<2x3x4xf32>
# CHECK: }
print(m)
-
-
-run(testBuildFuncOp)
diff --git a/mlir/test/Bindings/Python/dialects/linalg/opdsl/emit_structured_generic.py b/mlir/test/Bindings/Python/dialects/linalg/opdsl/emit_structured_generic.py
index 7f8c11679457..573999c97525 100644
--- a/mlir/test/Bindings/Python/dialects/linalg/opdsl/emit_structured_generic.py
+++ b/mlir/test/Bindings/Python/dialects/linalg/opdsl/emit_structured_generic.py
@@ -10,46 +10,6 @@
from mlir.dialects.linalg.opdsl.lang import *
-# TODO: Find a home for this quality of life helper.
-def build_function(*inputs: Type, results: Optional[Sequence[Type]] = None):
- """Decorator that emits a function in a more pythonic way.
-
- If result types are not specified, they are inferred from the function
- returns. The `ReturnOp` is implicitly added upon the wrapped function return.
- """
-
- def decorator(f):
- return_types = results
- symbol_name = f.__name__
- function_type = FunctionType.get(inputs=inputs, results=results or [])
- func_op = builtin.FuncOp(name=symbol_name, type=function_type)
- with InsertionPoint(func_op.add_entry_block()):
- func_args = func_op.entry_block.arguments
- return_values = f(*func_args)
- if return_values is None:
- return_values = []
- elif isinstance(return_values, Value):
- return_values = [return_values]
- else:
- return_values = list(return_values)
- std.ReturnOp(return_values)
- if return_types is None:
- # Recompute the function type.
- return_types = [v.type for v in return_values]
- function_type = FunctionType.get(inputs=inputs, results=return_types)
- # TODO: Have an API or a setter for this.
- func_op.attributes["type"] = TypeAttr.get(function_type)
-
- # TODO: When turning this into a real facility, return a function that emits
- # a `call` to the function instead of doing nothing.
- wrapped = lambda: None
- wrapped.__name__ = symbol_name
- wrapped.func_op = func_op
- return wrapped
-
- return decorator
-
-
@linalg_structured_op
def matmul_mono(A=TensorDef(T, S.M, S.K),
B=TensorDef(T, S.K, S.N),
@@ -92,8 +52,8 @@ def matmul_poly(A=TensorDef(TV.T1, S.M, S.K),
# CHECK-SAME: ins(%[[A]], %[[B]]
# CHECK-SAME: outs(%[[INITC]]
- @build_function(RankedTensorType.get((4, 16), f32),
- RankedTensorType.get((16, 8), f32))
+ @builtin.FuncOp.from_py_func(RankedTensorType.get((4, 16), f32),
+ RankedTensorType.get((16, 8), f32))
def test_matmul_mono(lhs, rhs):
# TODO: Enable outs inference and add sugar for InitTensorOp
# construction.
@@ -114,9 +74,9 @@ def test_matmul_mono(lhs, rhs):
# CHECK-NEXT: %[[ADD:.+]] = addi %[[C_ARG]], %[[MUL]] : i32
# CHECK-NEXT: linalg.yield %[[ADD]] : i32
# CHECK-NEXT: -> tensor<4x8xi32>
- @build_function(RankedTensorType.get((4, 16), i8),
- RankedTensorType.get((16, 8), i8),
- RankedTensorType.get((4, 8), i32))
+ @builtin.FuncOp.from_py_func(RankedTensorType.get((4, 16), i8),
+ RankedTensorType.get((16, 8), i8),
+ RankedTensorType.get((4, 8), i32))
def test_i8i8i32_matmul(lhs, rhs, init_result):
return matmul_poly(lhs, rhs, outs=[init_result])
@@ -128,9 +88,9 @@ def test_i8i8i32_matmul(lhs, rhs, init_result):
# CHECK-NEXT: %[[ADD:.+]] = addi %[[C_ARG]], %[[MUL]] : i32
# CHECK-NEXT: linalg.yield %[[ADD]] : i32
# CHECK-NEXT: -> tensor<4x8xi32>
- @build_function(RankedTensorType.get((4, 16), i8),
- RankedTensorType.get((16, 8), i16),
- RankedTensorType.get((4, 8), i32))
+ @builtin.FuncOp.from_py_func(RankedTensorType.get((4, 16), i8),
+ RankedTensorType.get((16, 8), i16),
+ RankedTensorType.get((4, 8), i32))
def test_i8i16i32_matmul(lhs, rhs, init_result):
return matmul_poly(lhs, rhs, outs=[init_result])
@@ -142,9 +102,9 @@ def test_i8i16i32_matmul(lhs, rhs, init_result):
# CHECK-NEXT: %[[ADD:.+]] = addi %[[C_ARG]], %[[MUL]] : i16
# CHECK-NEXT: linalg.yield %[[ADD]] : i16
# CHECK-NEXT: -> tensor<4x8xi16>
- @build_function(RankedTensorType.get((4, 16), i32),
- RankedTensorType.get((16, 8), i32),
- RankedTensorType.get((4, 8), i16))
+ @builtin.FuncOp.from_py_func(RankedTensorType.get((4, 16), i32),
+ RankedTensorType.get((16, 8), i32),
+ RankedTensorType.get((4, 8), i16))
def test_i32i32i16_matmul(lhs, rhs, init_result):
return matmul_poly(lhs, rhs, outs=[init_result])
@@ -156,9 +116,9 @@ def test_i32i32i16_matmul(lhs, rhs, init_result):
# CHECK-NEXT: %[[ADD:.+]] = addf %[[C_ARG]], %[[MUL]] : f32
# CHECK-NEXT: linalg.yield %[[ADD]] : f32
# CHECK-NEXT: -> tensor<4x8xf32>
- @build_function(RankedTensorType.get((4, 16), i8),
- RankedTensorType.get((16, 8), i8),
- RankedTensorType.get((4, 8), f32))
+ @builtin.FuncOp.from_py_func(RankedTensorType.get((4, 16), i8),
+ RankedTensorType.get((16, 8), i8),
+ RankedTensorType.get((4, 8), f32))
def test_i8i8f32_matmul(lhs, rhs, init_result):
return matmul_poly(lhs, rhs, outs=[init_result])
@@ -170,9 +130,9 @@ def test_i8i8f32_matmul(lhs, rhs, init_result):
# CHECK-NEXT: %[[ADD:.+]] = addf %[[C_ARG]], %[[MUL]] : f32
# CHECK-NEXT: linalg.yield %[[ADD]] : f32
# CHECK-NEXT: -> tensor<4x8xf32>
- @build_function(RankedTensorType.get((4, 16), f16),
- RankedTensorType.get((16, 8), f16),
- RankedTensorType.get((4, 8), f32))
+ @builtin.FuncOp.from_py_func(RankedTensorType.get((4, 16), f16),
+ RankedTensorType.get((16, 8), f16),
+ RankedTensorType.get((4, 8), f32))
def test_f16f16f32_matmul(lhs, rhs, init_result):
return matmul_poly(lhs, rhs, outs=[init_result])
@@ -184,9 +144,9 @@ def test_f16f16f32_matmul(lhs, rhs, init_result):
# CHECK-NEXT: %[[ADD:.+]] = addf %[[C_ARG]], %[[MUL]] : f32
# CHECK-NEXT: linalg.yield %[[ADD]] : f32
# CHECK-NEXT: -> tensor<4x8xf32>
- @build_function(RankedTensorType.get((4, 16), f64),
- RankedTensorType.get((16, 8), f64),
- RankedTensorType.get((4, 8), f32))
+ @builtin.FuncOp.from_py_func(RankedTensorType.get((4, 16), f64),
+ RankedTensorType.get((16, 8), f64),
+ RankedTensorType.get((4, 8), f32))
def test_f64f64f32_matmul(lhs, rhs, init_result):
return matmul_poly(lhs, rhs, outs=[init_result])
More information about the Mlir-commits
mailing list