[Mlir-commits] [mlir] 8c1b785 - [mlir][python] provide bindings for the SCF dialect
Alex Zinenko
llvmlistbot at llvm.org
Thu Sep 30 00:38:24 PDT 2021
Author: Alex Zinenko
Date: 2021-09-30T09:38:15+02:00
New Revision: 8c1b785ce110b754c2112906021a929ddd32f587
URL: https://github.com/llvm/llvm-project/commit/8c1b785ce110b754c2112906021a929ddd32f587
DIFF: https://github.com/llvm/llvm-project/commit/8c1b785ce110b754c2112906021a929ddd32f587.diff
LOG: [mlir][python] provide bindings for the SCF dialect
This is an important core dialect that has not been exposed previously. Set up
the default bindings generation and provide a nicer wrapper for the `for` loop
with access to the loop configuration and body.
Depends On D110758
Reviewed By: stellaraccident
Differential Revision: https://reviews.llvm.org/D110759
Added:
mlir/python/mlir/dialects/SCFOps.td
mlir/python/mlir/dialects/_scf_ops_ext.py
mlir/python/mlir/dialects/scf.py
mlir/test/python/dialects/scf.py
Modified:
mlir/python/CMakeLists.txt
Removed:
################################################################################
diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt
index 506d8ead221dd..2ab3a9af12a93 100644
--- a/mlir/python/CMakeLists.txt
+++ b/mlir/python/CMakeLists.txt
@@ -109,6 +109,15 @@ declare_mlir_dialect_python_bindings(
SOURCES dialects/python_test.py
DIALECT_NAME python_test)
+# Python bindings for the SCF dialect: op classes are generated from
+# dialects/SCFOps.td (imported by dialects/scf.py as `_scf_ops_gen`), and
+# dialects/_scf_ops_ext.py carries hand-written specializations (e.g. ForOp).
+declare_mlir_dialect_python_bindings(
+ ADD_TO_PARENT MLIRPythonSources.Dialects
+ ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
+ TD_FILE dialects/SCFOps.td
+ SOURCES
+ dialects/scf.py
+ dialects/_scf_ops_ext.py
+ DIALECT_NAME scf)
+
declare_mlir_dialect_python_bindings(
ADD_TO_PARENT MLIRPythonSources.Dialects
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
diff --git a/mlir/python/mlir/dialects/SCFOps.td b/mlir/python/mlir/dialects/SCFOps.td
new file mode 100644
index 0000000000000..855482d4a76a9
--- /dev/null
+++ b/mlir/python/mlir/dialects/SCFOps.td
@@ -0,0 +1,15 @@
+//===-- SCFOps.td - Entry point for SCF dialect bindings ---*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef PYTHON_BINDINGS_SCF_OPS
+#define PYTHON_BINDINGS_SCF_OPS
+
+// Shared definitions for the Python op-binding generator (going by the name,
+// presumably the attribute-to-Python mappings — verify).
+include "mlir/Bindings/Python/Attributes.td"
+// SCF op definitions from which the Python op classes are generated.
+include "mlir/Dialect/SCF/SCFOps.td"
+
+#endif // PYTHON_BINDINGS_SCF_OPS
diff --git a/mlir/python/mlir/dialects/_scf_ops_ext.py b/mlir/python/mlir/dialects/_scf_ops_ext.py
new file mode 100644
index 0000000000000..c6532a75632b5
--- /dev/null
+++ b/mlir/python/mlir/dialects/_scf_ops_ext.py
@@ -0,0 +1,57 @@
+# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+try:
+ from ..ir import *
+except ImportError as e:
+ raise RuntimeError("Error loading imports from extension module") from e
+
+from typing import Any, Sequence
+
+
+class ForOp:
+ """Specialization for the SCF for op class."""
+
+ def __init__(self,
+ lower_bound,
+ upper_bound,
+ step,
+ iter_args: Sequence[Any] = [],
+ *,
+ loc=None,
+ ip=None):
+ """Creates an SCF `for` operation.
+
+ - `lower_bound` is the value to use as lower bound of the loop.
+ - `upper_bound` is the value to use as upper bound of the loop.
+ - `step` is the value to use as loop step.
+ - `iter_args` is a list of additional loop-carried arguments.
+ """
+ results = [arg.type for arg in iter_args]
+ super().__init__(
+ self.build_generic(
+ regions=1,
+ results=results,
+ operands=[lower_bound, upper_bound, step] + list(iter_args),
+ loc=loc,
+ ip=ip))
+ self.regions[0].blocks.append(IndexType.get(), *results)
+
+ @property
+ def body(self):
+ """Returns the body (block) of the loop."""
+ return self.regions[0].blocks[0]
+
+ @property
+ def induction_variable(self):
+ """Returns the induction variable of the loop."""
+ return self.body.arguments[0]
+
+ @property
+ def inner_iter_args(self):
+ """Returns the loop-carried arguments usable within the loop.
+
+ To obtain the loop-carried operands, use `iter_args`.
+ """
+ return self.body.arguments[1:]
diff --git a/mlir/python/mlir/dialects/scf.py b/mlir/python/mlir/dialects/scf.py
new file mode 100644
index 0000000000000..302a49d56c211
--- /dev/null
+++ b/mlir/python/mlir/dialects/scf.py
@@ -0,0 +1,5 @@
+# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+from ._scf_ops_gen import *
diff --git a/mlir/test/python/dialects/scf.py b/mlir/test/python/dialects/scf.py
new file mode 100644
index 0000000000000..d604913b1c4cb
--- /dev/null
+++ b/mlir/test/python/dialects/scf.py
@@ -0,0 +1,54 @@
+# RUN: %PYTHON %s | FileCheck %s
+
+from mlir.ir import *
+from mlir.dialects import scf
+from mlir.dialects import builtin
+
+
+def run(f):
+ print("\nTEST:", f.__name__)
+ f()
+ return f
+
+
+# CHECK-LABEL: TEST: testSimpleLoop
+ at run
+def testSimpleLoop():
+ with Context(), Location.unknown():
+ module = Module.create()
+ index_type = IndexType.get()
+ with InsertionPoint(module.body):
+
+ @builtin.FuncOp.from_py_func(index_type, index_type, index_type)
+ def simple_loop(lb, ub, step):
+ loop = scf.ForOp(lb, ub, step, [lb, lb])
+ with InsertionPoint(loop.body):
+ scf.YieldOp(loop.inner_iter_args)
+ return
+
+ # CHECK: func @simple_loop(%[[ARG0:.*]]: index, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index)
+ # CHECK: scf.for %{{.*}} = %[[ARG0]] to %[[ARG1]] step %[[ARG2]]
+ # CHECK: iter_args(%[[I1:.*]] = %[[ARG0]], %[[I2:.*]] = %[[ARG0]])
+ # CHECK: scf.yield %[[I1]], %[[I2]]
+ print(module)
+
+
+# CHECK-LABEL: TEST: testInductionVar
+ at run
+def testInductionVar():
+ with Context(), Location.unknown():
+ module = Module.create()
+ index_type = IndexType.get()
+ with InsertionPoint(module.body):
+
+ @builtin.FuncOp.from_py_func(index_type, index_type, index_type)
+ def induction_var(lb, ub, step):
+ loop = scf.ForOp(lb, ub, step, [lb])
+ with InsertionPoint(loop.body):
+ scf.YieldOp([loop.induction_variable])
+ return
+
+ # CHECK: func @induction_var(%[[ARG0:.*]]: index, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index)
+ # CHECK: scf.for %[[IV:.*]] = %[[ARG0]] to %[[ARG1]] step %[[ARG2]]
+ # CHECK: scf.yield %[[IV]]
+ print(module)
More information about the Mlir-commits
mailing list