[Mlir-commits] [mlir] [mlir][python] python binding for the affine.store op (PR #68816)
Amy Wang
llvmlistbot at llvm.org
Wed Oct 11 09:13:21 PDT 2023
https://github.com/kaitingwang created https://github.com/llvm/llvm-project/pull/68816
This PR creates the necessary files to support Python bindings for operations in the affine dialect, starting with affine.store.
This is the first of several PRs that will progressively introduce affine.load, affine.for, and other operations. I would like to
acknowledge the work of @makslevental, author of Nelli (https://github.com/makslevental/nelli/blob/main/nelli/mlir/affine/affine.py), which jump-started this effort.
>From 8b255a3db41e9d7e1afcf5680bc2eb7b940dcb0e Mon Sep 17 00:00:00 2001
From: Amy Wang <kai.ting.wang at huawei.com>
Date: Wed, 11 Oct 2023 12:01:20 -0400
Subject: [PATCH] [mlir][python] python binding for the affine.store op
This PR creates the necessary files to support bindings for
operations in the affine dialect.
---
mlir/python/CMakeLists.txt | 10 ++++
mlir/python/mlir/dialects/AffineOps.td | 6 +++
mlir/python/mlir/dialects/_affine_ops_ext.py | 54 ++++++++++++++++++++
mlir/python/mlir/dialects/affine.py | 1 +
mlir/test/python/dialects/affine.py | 44 ++++++++++++++++
5 files changed, 115 insertions(+)
create mode 100644 mlir/python/mlir/dialects/AffineOps.td
create mode 100644 mlir/python/mlir/dialects/_affine_ops_ext.py
create mode 100644 mlir/python/mlir/dialects/affine.py
create mode 100644 mlir/test/python/dialects/affine.py
diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt
index 088d9a765b97730..c7b3c283a6b6dc1 100644
--- a/mlir/python/CMakeLists.txt
+++ b/mlir/python/CMakeLists.txt
@@ -46,6 +46,16 @@ declare_mlir_python_sources(MLIRPythonCAPI.HeaderSources
# Dialect bindings
################################################################################
+# Python bindings for the affine dialect. TD_FILE is the tablegen entry point
+# for the generated op module that dialects/affine.py re-exports
+# (`_affine_ops_gen`); _affine_ops_ext.py holds hand-written builders such as
+# AffineStoreOp.
+declare_mlir_dialect_python_bindings(
+ ADD_TO_PARENT MLIRPythonSources.Dialects
+ ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
+ TD_FILE dialects/AffineOps.td
+ SOURCES
+ dialects/affine.py
+ dialects/_affine_ops_ext.py
+ DIALECT_NAME affine
+ # NOTE(review): requests enum binding generation; confirm the affine dialect
+ # actually defines enums worth exposing, otherwise this flag can be dropped.
+ GEN_ENUM_BINDINGS)
+
declare_mlir_dialect_python_bindings(
ADD_TO_PARENT MLIRPythonSources.Dialects
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
diff --git a/mlir/python/mlir/dialects/AffineOps.td b/mlir/python/mlir/dialects/AffineOps.td
new file mode 100644
index 000000000000000..067823c9b0247cc
--- /dev/null
+++ b/mlir/python/mlir/dialects/AffineOps.td
@@ -0,0 +1,6 @@
+#ifndef PYTHON_BINDINGS_AFFINE_OPS
+#define PYTHON_BINDINGS_AFFINE_OPS
+
+include "mlir/Dialect/Affine/IR/AffineOps.td"
+
+#endif // PYTHON_BINDINGS_AFFINE_OPS
\ No newline at end of file
diff --git a/mlir/python/mlir/dialects/_affine_ops_ext.py b/mlir/python/mlir/dialects/_affine_ops_ext.py
new file mode 100644
index 000000000000000..db789654fbc4776
--- /dev/null
+++ b/mlir/python/mlir/dialects/_affine_ops_ext.py
@@ -0,0 +1,54 @@
+# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+try:
+ from ..ir import *
+ from ._ods_common import get_op_result_or_value as _get_op_result_or_value
+ from ._ods_common import get_op_results_or_values as _get_op_results_or_values
+except ImportError as e:
+ raise RuntimeError("Error loading imports from extension module") from e
+
+from typing import Optional, Sequence, Union
+
+
class AffineStoreOp:
    """Specialization for the Affine store operation."""

    def __init__(
        self,
        value: Union[Operation, OpView, Value],
        memref: Union[Operation, OpView, Value],
        map: Union[AffineMap, AffineMapAttr],
        *,
        map_operands=None,
        loc=None,
        ip=None,
    ):
        """Creates an affine store operation.

        - `value`: the value to store into the memref. Either a `Value` or an
          `Operation`/`OpView` whose result is used (resolved through
          `_get_op_result_or_value`).
        - `memref`: the buffer to store into.
        - `map`: the affine map that maps the map_operands to the index of the
          memref. Either an `AffineMap`, which is wrapped into an
          `AffineMapAttr`, or an already-built `AffineMapAttr`.
        - `map_operands`: the list of arguments to substitute the dimensions,
          then symbols in the affine map, in increasing order. Defaults to no
          operands.
        - `loc`, `ip`: optional location and insertion point, forwarded to
          `build_generic`.
        """
        # Use None rather than a shared mutable list as the default.
        if map_operands is None:
            map_operands = []

        # Operand order is: stored value, memref, then the map operands.
        operands = [
            _get_op_result_or_value(value),
            _get_op_result_or_value(memref),
            *(_get_op_result_or_value(op) for op in map_operands),
        ]

        # Pass a prebuilt AffineMapAttr through unchanged; wrap anything else.
        map_attr = map if isinstance(map, AffineMapAttr) else AffineMapAttr.get(map)

        # The op is built with no results, no regions and no successors.
        super().__init__(
            self.build_generic(
                attributes={"map": map_attr},
                results=[],
                operands=operands,
                successors=None,
                regions=None,
                loc=loc,
                ip=ip,
            )
        )
diff --git a/mlir/python/mlir/dialects/affine.py b/mlir/python/mlir/dialects/affine.py
new file mode 100644
index 000000000000000..d9d28bdb3923949
--- /dev/null
+++ b/mlir/python/mlir/dialects/affine.py
@@ -0,0 +1 @@
+from ._affine_ops_gen import *
\ No newline at end of file
diff --git a/mlir/test/python/dialects/affine.py b/mlir/test/python/dialects/affine.py
new file mode 100644
index 000000000000000..d2e664d4653420f
--- /dev/null
+++ b/mlir/test/python/dialects/affine.py
@@ -0,0 +1,44 @@
+# RUN: %PYTHON %s | FileCheck %s
+
+from mlir.ir import *
+import mlir.dialects.func as func
+import mlir.dialects.arith as arith
+import mlir.dialects.affine as affine
+import mlir.dialects.memref as memref
+
+
def run(f):
    """Print a FileCheck header for ``f``, invoke it once, and return it.

    Intended for use as a decorator: the test body executes when the module is
    loaded, and the decorated name stays bound to the original function.
    """
    label = f.__name__
    print("\nTEST:", label)
    f()
    return f
+
+
# CHECK-LABEL: TEST: testAffineStoreOp
@run
def testAffineStoreOp():
    """Build a func storing a constant through an affine map and print the module."""
    with Context(), Location.unknown():
        module = Module.create()
        with InsertionPoint(module.body):
            f32 = F32Type.get()
            index_type = IndexType.get()
            memref_type_out = MemRefType.get([12, 12], f32)

            # CHECK: func.func @affine_store_test(%[[ARG0:.*]]: index) -> memref<12x12xf32> {
            @func.FuncOp.from_py_func(index_type)
            def affine_store_test(arg0):
                # CHECK: %[[O_VAR:.*]] = memref.alloc() : memref<12x12xf32>
                mem = memref.AllocOp(memref_type_out, [], []).result

                # (d0)[s0] -> (s0 * 3, d0 + s0 + 1): one dim and one symbol,
                # both bound to arg0 below.
                d0 = AffineDimExpr.get(0)
                s0 = AffineSymbolExpr.get(0)
                affine_map = AffineMap.get(1, 1, [s0 * 3, d0 + s0 + 1])

                # CHECK: %[[A1:.*]] = arith.constant 2.100000e+00 : f32
                a1 = arith.ConstantOp(f32, 2.1)

                # Match the memref through the captured O_VAR rather than a
                # hard-coded SSA name.
                # CHECK: affine.store %[[A1]], %[[O_VAR]][symbol(%[[ARG0]]) * 3, %[[ARG0]] + symbol(%[[ARG0]]) + 1] : memref<12x12xf32>
                affine.AffineStoreOp(a1, mem, affine_map, map_operands=[arg0, arg0])

                return mem

        print(module)
More information about the Mlir-commits
mailing list