[Mlir-commits] [mlir] 255a690 - [mlir][python] Provide more convenient constructors for std.CallOp

Alex Zinenko llvmlistbot at llvm.org
Mon Oct 4 02:45:38 PDT 2021


Author: Alex Zinenko
Date: 2021-10-04T11:45:29+02:00
New Revision: 255a690971cb51c838623e8bb5b72b7415b454b5

URL: https://github.com/llvm/llvm-project/commit/255a690971cb51c838623e8bb5b72b7415b454b5
DIFF: https://github.com/llvm/llvm-project/commit/255a690971cb51c838623e8bb5b72b7415b454b5.diff

LOG: [mlir][python] Provide more convenient constructors for std.CallOp

The new constructor relies on type-based dynamic dispatch and allows one to
construct call operations given an object representing a FuncOp or its name as
a string, as opposed to requiring an explicitly constructed attribute.

Depends On D110947

Reviewed By: stellaraccident

Differential Revision: https://reviews.llvm.org/D110948

Added: 
    

Modified: 
    mlir/python/mlir/dialects/_builtin_ops_ext.py
    mlir/python/mlir/dialects/_std_ops_ext.py
    mlir/test/python/dialects/builtin.py
    mlir/test/python/dialects/std.py

Removed: 
    


################################################################################
diff  --git a/mlir/python/mlir/dialects/_builtin_ops_ext.py b/mlir/python/mlir/dialects/_builtin_ops_ext.py
index d464819f2dc8..462850d63156 100644
--- a/mlir/python/mlir/dialects/_builtin_ops_ext.py
+++ b/mlir/python/mlir/dialects/_builtin_ops_ext.py
@@ -3,7 +3,7 @@
 #  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
 try:
-  from typing import Optional, Sequence
+  from typing import Optional, Sequence, Union
 
   import inspect
 
@@ -82,8 +82,8 @@ def visibility(self):
     return self.attributes["sym_visibility"]
 
   @property
-  def name(self):
-    return self.attributes["sym_name"]
+  def name(self) -> StringAttr:
+    return StringAttr(self.attributes["sym_name"])
 
   @property
   def entry_block(self):
@@ -104,11 +104,15 @@ def add_entry_block(self):
 
   @property
   def arg_attrs(self):
-    return self.attributes[ARGUMENT_ATTRIBUTE_NAME]
+    return ArrayAttr(self.attributes[ARGUMENT_ATTRIBUTE_NAME])
 
   @arg_attrs.setter
-  def arg_attrs(self, attribute: ArrayAttr):
-    self.attributes[ARGUMENT_ATTRIBUTE_NAME] = attribute
+  def arg_attrs(self, attribute: Union[ArrayAttr, list]):
+    if isinstance(attribute, ArrayAttr):
+      self.attributes[ARGUMENT_ATTRIBUTE_NAME] = attribute
+    else:
+      self.attributes[ARGUMENT_ATTRIBUTE_NAME] = ArrayAttr.get(
+          attribute, context=self.context)
 
   @property
   def arguments(self):

diff  --git a/mlir/python/mlir/dialects/_std_ops_ext.py b/mlir/python/mlir/dialects/_std_ops_ext.py
index bb67fe44d3b5..39da1c83f449 100644
--- a/mlir/python/mlir/dialects/_std_ops_ext.py
+++ b/mlir/python/mlir/dialects/_std_ops_ext.py
@@ -69,3 +69,73 @@ def literal_value(self) -> Union[int, float]:
       return FloatAttr(self.value).value
     else:
       raise ValueError("only integer and float constants have literal values")
+
+
+class CallOp:
+  """Specialization for the call op class."""
+
+  def __init__(self,
+               calleeOrResults: Union[FuncOp, List[Type]],
+               argumentsOrCallee: Union[List, FlatSymbolRefAttr, str],
+               arguments: Optional[List] = None,
+               *,
+               loc=None,
+               ip=None):
+    """Creates an call operation.
+
+    The constructor accepts three different forms:
+
+      1. A function op to be called followed by a list of arguments.
+      2. A list of result types, followed by the name of the function to be
+         called as string, following by a list of arguments.
+      3. A list of result types, followed by the name of the function to be
+         called as symbol reference attribute, followed by a list of arguments.
+
+    For example
+
+        f = builtin.FuncOp("foo", ...)
+        std.CallOp(f, [args])
+        std.CallOp([result_types], "foo", [args])
+
+    In all cases, the location and insertion point may be specified as keyword
+    arguments if not provided by the surrounding context managers.
+    """
+
+    # TODO: consider supporting constructor "overloads", e.g., through a custom
+    # or pybind-provided metaclass.
+    if isinstance(calleeOrResults, FuncOp):
+      if not isinstance(argumentsOrCallee, list):
+        raise ValueError(
+            "when constructing a call to a function, expected " +
+            "the second argument to be a list of call arguments, " +
+            f"got {type(argumentsOrCallee)}")
+      if arguments is not None:
+        raise ValueError("unexpected third argument when constructing a call" +
+                         "to a function")
+
+      super().__init__(
+          calleeOrResults.type.results,
+          FlatSymbolRefAttr.get(
+              calleeOrResults.name.value,
+              context=_get_default_loc_context(loc)),
+          argumentsOrCallee,
+          loc=loc,
+          ip=ip)
+      return
+
+    if isinstance(argumentsOrCallee, list):
+      raise ValueError("when constructing a call to a function by name, " +
+                       "expected the second argument to be a string or a " +
+                       f"FlatSymbolRefAttr, got {type(argumentsOrCallee)}")
+
+    if isinstance(argumentsOrCallee, FlatSymbolRefAttr):
+      super().__init__(
+          calleeOrResults, argumentsOrCallee, arguments, loc=loc, ip=ip)
+    elif isinstance(argumentsOrCallee, str):
+      super().__init__(
+          calleeOrResults,
+          FlatSymbolRefAttr.get(
+              argumentsOrCallee, context=_get_default_loc_context(loc)),
+          arguments,
+          loc=loc,
+          ip=ip)

diff  --git a/mlir/test/python/dialects/builtin.py b/mlir/test/python/dialects/builtin.py
index b87eabb72b94..73f2b5bc9cf9 100644
--- a/mlir/test/python/dialects/builtin.py
+++ b/mlir/test/python/dialects/builtin.py
@@ -171,7 +171,7 @@ def testFuncArgumentAccess():
     f32 = F32Type.get()
     f64 = F64Type.get()
     with InsertionPoint(module.body):
-      func = builtin.FuncOp("some_func", ([f32, f32], [f64, f64]))
+      func = builtin.FuncOp("some_func", ([f32, f32], [f32, f32]))
       with InsertionPoint(func.add_entry_block()):
         std.ReturnOp(func.arguments)
       func.arg_attrs = ArrayAttr.get([
@@ -186,6 +186,14 @@ def testFuncArgumentAccess():
           DictAttr.get({"res2": FloatAttr.get(f64, 256.0)})
       ])
 
+      other = builtin.FuncOp("other_func", ([f32, f32], []))
+      with InsertionPoint(other.add_entry_block()):
+        std.ReturnOp([])
+      other.arg_attrs = [
+          DictAttr.get({"foo": StringAttr.get("qux")}),
+          DictAttr.get()
+      ]
+
   # CHECK: [{baz, foo = "bar"}, {qux = []}]
   print(func.arg_attrs)
 
@@ -195,7 +203,11 @@ def testFuncArgumentAccess():
   # CHECK: func @some_func(
   # CHECK: %[[ARG0:.*]]: f32 {baz, foo = "bar"},
   # CHECK: %[[ARG1:.*]]: f32 {qux = []}) ->
-  # CHECK: f64 {res1 = 4.200000e+01 : f32},
-  # CHECK: f64 {res2 = 2.560000e+02 : f64})
+  # CHECK: f32 {res1 = 4.200000e+01 : f32},
+  # CHECK: f32 {res2 = 2.560000e+02 : f64})
   # CHECK: return %[[ARG0]], %[[ARG1]] : f32, f32
+  #
+  # CHECK: func @other_func(
+  # CHECK: %{{.*}}: f32 {foo = "qux"},
+  # CHECK: %{{.*}}: f32)
   print(module)

diff  --git a/mlir/test/python/dialects/std.py b/mlir/test/python/dialects/std.py
index ed507d664de6..2a4b269e9973 100644
--- a/mlir/test/python/dialects/std.py
+++ b/mlir/test/python/dialects/std.py
@@ -1,6 +1,7 @@
 # RUN: %PYTHON %s | FileCheck %s
 
 from mlir.ir import *
+from mlir.dialects import builtin
 from mlir.dialects import std
 
 
@@ -62,3 +63,27 @@ def testConstantIndexOp():
   print(c1.literal_value)
 
 # CHECK: = constant 10 : index
+
+# CHECK-LABEL: TEST: testFunctionCalls
+@constructAndPrintInModule
+def testFunctionCalls():
+  foo = builtin.FuncOp("foo", ([], []))
+  bar = builtin.FuncOp("bar", ([], [IndexType.get()]))
+  qux = builtin.FuncOp("qux", ([], [F32Type.get()]))
+
+  with InsertionPoint(builtin.FuncOp("caller", ([], [])).add_entry_block()):
+    std.CallOp(foo, [])
+    std.CallOp([IndexType.get()], "bar", [])
+    std.CallOp([F32Type.get()], FlatSymbolRefAttr.get("qux"), [])
+    std.ReturnOp([])
+
+# CHECK: func @foo()
+# CHECK: func @bar() -> index
+# CHECK: func @qux() -> f32
+# CHECK: func @caller() {
+# CHECK:   call @foo() : () -> ()
+# CHECK:   %0 = call @bar() : () -> index
+# CHECK:   %1 = call @qux() : () -> f32
+# CHECK:   return
+# CHECK: }
+


        


More information about the Mlir-commits mailing list