[Mlir-commits] [mlir] ec294eb - [mlir][linalg] Add an InitTensorOp python builder.

Stella Laurenzo llvmlistbot at llvm.org
Thu Mar 25 15:18:55 PDT 2021


Author: Stella Laurenzo
Date: 2021-03-25T15:17:48-07:00
New Revision: ec294eb87be24764aac15d4df046a841f77f4b48

URL: https://github.com/llvm/llvm-project/commit/ec294eb87be24764aac15d4df046a841f77f4b48
DIFF: https://github.com/llvm/llvm-project/commit/ec294eb87be24764aac15d4df046a841f77f4b48.diff

LOG: [mlir][linalg] Add an InitTensorOp python builder.

* This has the API I want but I am not thrilled with the implementation. There are various things that could be improved both about the way that Python builders are mapped and the way the Linalg ops are factored to increase code sharing between C++/Python.
* Landing this as-is since it at least makes the InitTensorOp usable with the right API. Will refactor underneath in follow-ons.

Differential Revision: https://reviews.llvm.org/D99000

Added: 
    

Modified: 
    mlir/lib/Bindings/Python/mlir/dialects/_linalg_ops_ext.py
    mlir/lib/Bindings/Python/mlir/dialects/_ods_common.py
    mlir/test/Bindings/Python/dialects/linalg/opdsl/emit_structured_generic.py
    mlir/test/Bindings/Python/dialects/linalg/ops.py

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Bindings/Python/mlir/dialects/_linalg_ops_ext.py b/mlir/lib/Bindings/Python/mlir/dialects/_linalg_ops_ext.py
index 74390d487a670..d35d10cc4b8eb 100644
--- a/mlir/lib/Bindings/Python/mlir/dialects/_linalg_ops_ext.py
+++ b/mlir/lib/Bindings/Python/mlir/dialects/_linalg_ops_ext.py
@@ -2,6 +2,48 @@
 #  See https://llvm.org/LICENSE.txt for license information.
 #  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
+from typing import Optional, Sequence, Union
+from ..ir import *
+from ._ods_common import get_default_loc_context
+
+
+class InitTensorOp:
+  """Extends the linalg.init_tensor op."""
+
+  def __init__(self,
+               sizes: Union[Sequence[int], Sequence[Value]],
+               element_type: Type,
+               *,
+               loc=None,
+               ip=None):
+    """Constructs an `init_tensor` with either static or dynamic sizes."""
+    context = get_default_loc_context(loc)
+    operands = []
+    attributes = {}
+    # TODO: Refactor the InitTensorOp to take an element type attribute and
+    # then use normal result type inference, unifying the Python and C++ side
+    # with a standard mechanism (versus stashing that in builders).
+    if sizes and isinstance(sizes[0], Value):
+      # Dynamic sizes.
+      operands.extend(sizes)
+      static_size_ints = [-1] * len(sizes)
+      result_type = RankedTensorType.get(static_size_ints, element_type)
+    else:
+      # Static sizes.
+      result_type = RankedTensorType.get(sizes, element_type)
+      static_size_ints = sizes
+
+    index_type = IndexType.get(context)
+    attributes["static_sizes"] = ArrayAttr.get(
+        [IntegerAttr.get(index_type, s) for s in static_size_ints],
+        context=context)
+    op = self.build_generic(results=[result_type],
+                            operands=operands,
+                            attributes=attributes,
+                            loc=loc,
+                            ip=ip)
+    OpView.__init__(self, op)
+
 
 class StructuredOpMixin:
   """All structured ops use the same mixin class."""

diff  --git a/mlir/lib/Bindings/Python/mlir/dialects/_ods_common.py b/mlir/lib/Bindings/Python/mlir/dialects/_ods_common.py
index 6d37700ecdc47..d030440887414 100644
--- a/mlir/lib/Bindings/Python/mlir/dialects/_ods_common.py
+++ b/mlir/lib/Bindings/Python/mlir/dialects/_ods_common.py
@@ -17,14 +17,15 @@ def extend_opview_class(ext_module):
   """Decorator to extend an OpView class from an extension module.
 
   Extension modules can expose various entry-points:
+    Stand-alone class with the same name as a parent OpView class (i.e.
+    "ReturnOp"). A name-based match is attempted first before falling back
+    to a below mechanism.
+
     def select_opview_mixin(parent_opview_cls):
       If defined, allows an appropriate mixin class to be selected dynamically
       based on the parent OpView class. Should return NotImplemented if a
       decision is not made.
 
-    Stand-alone class with the same name as a parent OpView class (i.e.
-    "ReturnOp").
-
   Args:
     ext_module: A module from which to locate extensions. Can be None if not
       available.
@@ -38,16 +39,18 @@ def class_decorator(parent_opview_cls: type):
     if ext_module is None:
       return parent_opview_cls
     mixin_cls = NotImplemented
+    # First try to resolve by name.
     try:
-      select_mixin = getattr(ext_module, "select_opview_mixin")
+      mixin_cls = getattr(ext_module, parent_opview_cls.__name__)
     except AttributeError:
-      # Try to default resolve it.
+      # Fall back to a select_opview_mixin hook.
       try:
-        mixin_cls = getattr(ext_module, parent_opview_cls.__name__)
+        select_mixin = getattr(ext_module, "select_opview_mixin")
       except AttributeError:
         pass
-    else:
-      mixin_cls = select_mixin(parent_opview_cls)
+      else:
+        mixin_cls = select_mixin(parent_opview_cls)
+
     if mixin_cls is NotImplemented or mixin_cls is None:
       return parent_opview_cls
 

diff  --git a/mlir/test/Bindings/Python/dialects/linalg/opdsl/emit_structured_generic.py b/mlir/test/Bindings/Python/dialects/linalg/opdsl/emit_structured_generic.py
index 397b32c93c22b..f27f79a4fb037 100644
--- a/mlir/test/Bindings/Python/dialects/linalg/opdsl/emit_structured_generic.py
+++ b/mlir/test/Bindings/Python/dialects/linalg/opdsl/emit_structured_generic.py
@@ -55,15 +55,7 @@ def matmul_poly(A=TensorDef(TV.T1, S.M, S.K),
     @builtin.FuncOp.from_py_func(RankedTensorType.get((4, 16), f32),
                                  RankedTensorType.get((16, 8), f32))
     def test_matmul_mono(lhs, rhs):
-      # TODO: Enable outs inference and add sugar for InitTensorOp
-      # construction.
-      init_result = linalg.InitTensorOp(result=RankedTensorType.get((4, 8),
-                                                                    f32),
-                                        static_sizes=ArrayAttr.get([
-                                            IntegerAttr.get(IndexType.get(), 4),
-                                            IntegerAttr.get(IndexType.get(), 8)
-                                        ]),
-                                        sizes=[])
+      init_result = linalg.InitTensorOp([4, 8], f32)
       return matmul_mono(lhs, rhs, outs=[init_result.result])
 
     # CHECK-LABEL: @test_i8i8i32_matmul

diff  --git a/mlir/test/Bindings/Python/dialects/linalg/ops.py b/mlir/test/Bindings/Python/dialects/linalg/ops.py
index 04a6ac8def843..22ed09e0716ce 100644
--- a/mlir/test/Bindings/Python/dialects/linalg/ops.py
+++ b/mlir/test/Bindings/Python/dialects/linalg/ops.py
@@ -9,9 +9,39 @@
 def run(f):
   print("\nTEST:", f.__name__)
   f()
+  return f
+
+
+# CHECK-LABEL: TEST: testInitTensor
+@run
+def testInitTensor():
+  with Context() as ctx, Location.unknown():
+    module = Module.create()
+    f32 = F32Type.get()
+    with InsertionPoint(module.body):
+      # CHECK-LABEL: func @static_sizes
+      # CHECK: %0 = linalg.init_tensor [3, 4] : tensor<3x4xf32>
+      @builtin.FuncOp.from_py_func()
+      def static_sizes():
+        return linalg.InitTensorOp([3, 4], f32)
+
+      # CHECK-LABEL: func @dynamic_sizes
+      # CHECK: %0 = linalg.init_tensor [%arg0, %arg1] : tensor<?x?xf32>
+      @builtin.FuncOp.from_py_func(IndexType.get(), IndexType.get())
+      def dynamic_sizes(d0, d1):
+        return linalg.InitTensorOp([d0, d1], f32)
+
+      # CHECK-LABEL: func @zero_d
+      # CHECK: %0 = linalg.init_tensor [] : tensor<f32>
+      @builtin.FuncOp.from_py_func()
+      def zero_d():
+        return linalg.InitTensorOp([], f32)
+
+  print(module)
 
 
 # CHECK-LABEL: TEST: testStructuredOpOnTensors
+@run
 def testStructuredOpOnTensors():
   with Context() as ctx, Location.unknown():
     module = Module.create()
@@ -31,10 +61,8 @@ def testStructuredOpOnTensors():
   print(module)
 
 
-run(testStructuredOpOnTensors)
-
-
 # CHECK-LABEL: TEST: testStructuredOpOnBuffers
+@run
 def testStructuredOpOnBuffers():
   with Context() as ctx, Location.unknown():
     module = Module.create()
@@ -52,6 +80,3 @@ def testStructuredOpOnBuffers():
 
   # CHECK: linalg.matmul ins(%arg0, %arg1 : memref<2x3x4xf32>, memref<2x3x4xf32>) outs(%arg2 : memref<2x3x4xf32>)
   print(module)
-
-
-run(testStructuredOpOnBuffers)


        


More information about the Mlir-commits mailing list