[Mlir-commits] [mlir] 8c1b785 - [mlir][python] provide bindings for the SCF dialect

Alex Zinenko llvmlistbot at llvm.org
Thu Sep 30 00:38:24 PDT 2021


Author: Alex Zinenko
Date: 2021-09-30T09:38:15+02:00
New Revision: 8c1b785ce110b754c2112906021a929ddd32f587

URL: https://github.com/llvm/llvm-project/commit/8c1b785ce110b754c2112906021a929ddd32f587
DIFF: https://github.com/llvm/llvm-project/commit/8c1b785ce110b754c2112906021a929ddd32f587.diff

LOG: [mlir][python] provide bindings for the SCF dialect

This is an important core dialect that has not been exposed previously. Set up
the default bindings generation and provide a nicer wrapper for the `for` loop
with access to the loop configuration and body.

Depends On D110758

Reviewed By: stellaraccident

Differential Revision: https://reviews.llvm.org/D110759

Added: 
    mlir/python/mlir/dialects/SCFOps.td
    mlir/python/mlir/dialects/_scf_ops_ext.py
    mlir/python/mlir/dialects/scf.py
    mlir/test/python/dialects/scf.py

Modified: 
    mlir/python/CMakeLists.txt

Removed: 
    


################################################################################
diff  --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt
index 506d8ead221dd..2ab3a9af12a93 100644
--- a/mlir/python/CMakeLists.txt
+++ b/mlir/python/CMakeLists.txt
@@ -109,6 +109,15 @@ declare_mlir_dialect_python_bindings(
   SOURCES dialects/python_test.py
   DIALECT_NAME python_test)
 
+# Python bindings for the `scf` dialect, declared from dialects/SCFOps.td.
+# dialects/scf.py and dialects/_scf_ops_ext.py are the accompanying Python
+# sources, all resolved relative to ROOT_DIR.
+declare_mlir_dialect_python_bindings(
+  ADD_TO_PARENT MLIRPythonSources.Dialects
+  ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
+  TD_FILE dialects/SCFOps.td
+  SOURCES
+    dialects/scf.py
+    dialects/_scf_ops_ext.py
+  DIALECT_NAME scf)
+
 declare_mlir_dialect_python_bindings(
   ADD_TO_PARENT MLIRPythonSources.Dialects
   ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"

diff  --git a/mlir/python/mlir/dialects/SCFOps.td b/mlir/python/mlir/dialects/SCFOps.td
new file mode 100644
index 0000000000000..855482d4a76a9
--- /dev/null
+++ b/mlir/python/mlir/dialects/SCFOps.td
@@ -0,0 +1,15 @@
+//===-- SCFOps.td - Entry point for SCF dialect bindings ---*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef PYTHON_BINDINGS_SCF_OPS
+#define PYTHON_BINDINGS_SCF_OPS
+
+include "mlir/Bindings/Python/Attributes.td"
+include "mlir/Dialect/SCF/SCFOps.td"
+
+#endif

diff  --git a/mlir/python/mlir/dialects/_scf_ops_ext.py b/mlir/python/mlir/dialects/_scf_ops_ext.py
new file mode 100644
index 0000000000000..c6532a75632b5
--- /dev/null
+++ b/mlir/python/mlir/dialects/_scf_ops_ext.py
@@ -0,0 +1,57 @@
+#  Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+#  See https://llvm.org/LICENSE.txt for license information.
+#  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+try:
+  from ..ir import *
+except ImportError as e:
+  raise RuntimeError("Error loading imports from extension module") from e
+
+from typing import Any, Sequence
+
+
+class ForOp:
+  """Specialization for the SCF for op class.
+
+  Adds a convenience constructor that takes the loop bounds, the step and the
+  loop-carried values directly, plus accessors for the loop body block and its
+  arguments.
+  """
+
+  def __init__(self,
+               lower_bound,
+               upper_bound,
+               step,
+               iter_args: Sequence[Any] = (),
+               *,
+               loc=None,
+               ip=None):
+    """Creates an SCF `for` operation.
+
+    - `lower_bound` is the value to use as lower bound of the loop.
+    - `upper_bound` is the value to use as upper bound of the loop.
+    - `step` is the value to use as loop step.
+    - `iter_args` is a sequence of additional loop-carried arguments; each
+      must expose a `type` attribute, which is used as the type of the
+      matching op result and body block argument.
+    - `loc` and `ip` are forwarded unchanged to `build_generic`.
+    """
+    # Materialize once: `iter_args` is read twice below (result types, then
+    # operands), so a one-shot iterable would otherwise contribute result
+    # types but no operands.
+    iter_args = list(iter_args)
+    results = [arg.type for arg in iter_args]
+    super().__init__(
+        self.build_generic(
+            regions=1,
+            results=results,
+            operands=[lower_bound, upper_bound, step] + iter_args,
+            loc=loc,
+            ip=ip))
+    # The body block takes an index-typed argument first (exposed as
+    # `induction_variable`), then one argument per loop-carried value.
+    self.regions[0].blocks.append(IndexType.get(), *results)
+
+  @property
+  def body(self):
+    """Returns the body (block) of the loop, i.e. block 0 of region 0."""
+    return self.regions[0].blocks[0]
+
+  @property
+  def induction_variable(self):
+    """Returns the induction variable of the loop (first body argument)."""
+    return self.body.arguments[0]
+
+  @property
+  def inner_iter_args(self):
+    """Returns the loop-carried arguments usable within the loop.
+
+    These are the body block arguments after the induction variable.
+    To obtain the loop-carried operands, use `iter_args`.
+    """
+    return self.body.arguments[1:]

diff  --git a/mlir/python/mlir/dialects/scf.py b/mlir/python/mlir/dialects/scf.py
new file mode 100644
index 0000000000000..302a49d56c211
--- /dev/null
+++ b/mlir/python/mlir/dialects/scf.py
@@ -0,0 +1,5 @@
+#  Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+#  See https://llvm.org/LICENSE.txt for license information.
+#  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+from ._scf_ops_gen import *

diff  --git a/mlir/test/python/dialects/scf.py b/mlir/test/python/dialects/scf.py
new file mode 100644
index 0000000000000..d604913b1c4cb
--- /dev/null
+++ b/mlir/test/python/dialects/scf.py
@@ -0,0 +1,54 @@
+# RUN: %PYTHON %s | FileCheck %s
+
+from mlir.ir import *
+from mlir.dialects import scf
+from mlir.dialects import builtin
+
+
+def run(f):
+  """Announces and immediately executes the test function `f`.
+
+  Prints a blank line followed by `TEST: <function name>`, calls `f` with no
+  arguments, and returns `f` itself so that `run` can be used as a decorator.
+  """
+  test_name = f.__name__
+  print("\nTEST:", test_name)
+  f()
+  return f
+
+
+# CHECK-LABEL: TEST: testSimpleLoop
+@run
+def testSimpleLoop():
+  """Builds an `scf.for` with two loop-carried values and prints the module.
+
+  Both loop-carried values are initialized with the lower bound, and the loop
+  body yields its own loop-carried block arguments unchanged. The test body is
+  executed at definition time by the `run` decorator.
+  """
+  with Context(), Location.unknown():
+    module = Module.create()
+    index_type = IndexType.get()
+    with InsertionPoint(module.body):
+
+      @builtin.FuncOp.from_py_func(index_type, index_type, index_type)
+      def simple_loop(lb, ub, step):
+        loop = scf.ForOp(lb, ub, step, [lb, lb])
+        with InsertionPoint(loop.body):
+          scf.YieldOp(loop.inner_iter_args)
+        return
+
+  # CHECK: func @simple_loop(%[[ARG0:.*]]: index, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index)
+  # CHECK: scf.for %{{.*}} = %[[ARG0]] to %[[ARG1]] step %[[ARG2]]
+  # CHECK: iter_args(%[[I1:.*]] = %[[ARG0]], %[[I2:.*]] = %[[ARG0]])
+  # CHECK: scf.yield %[[I1]], %[[I2]]
+  print(module)
+
+
+# CHECK-LABEL: TEST: testInductionVar
+@run
+def testInductionVar():
+  """Builds an `scf.for` whose body yields the induction variable.
+
+  The loop has one loop-carried value, initialized with the lower bound, and
+  the body yields `loop.induction_variable` for it. The test body is executed
+  at definition time by the `run` decorator.
+  """
+  with Context(), Location.unknown():
+    module = Module.create()
+    index_type = IndexType.get()
+    with InsertionPoint(module.body):
+
+      @builtin.FuncOp.from_py_func(index_type, index_type, index_type)
+      def induction_var(lb, ub, step):
+        loop = scf.ForOp(lb, ub, step, [lb])
+        with InsertionPoint(loop.body):
+          scf.YieldOp([loop.induction_variable])
+        return
+
+  # CHECK: func @induction_var(%[[ARG0:.*]]: index, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index)
+  # CHECK: scf.for %[[IV:.*]] = %[[ARG0]] to %[[ARG1]] step %[[ARG2]]
+  # CHECK: scf.yield %[[IV]]
+  print(module)


        


More information about the Mlir-commits mailing list