[llvm] [mlir][sparse] Generates python bindings for SparseTensorTransformOps. (PR #66937)

Peiming Liu via llvm-commits llvm-commits at lists.llvm.org
Wed Sep 20 12:44:24 PDT 2023


https://github.com/PeimingLiu updated https://github.com/llvm/llvm-project/pull/66937

>From ae9dab38a7873456a89ab3d01714a91558048c9f Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Wed, 20 Sep 2023 18:20:06 +0000
Subject: [PATCH 1/2] [mlir][sparse] Generates python bindings for
 SparseTensorTransformOps.

---
 mlir/python/CMakeLists.txt                    |  9 +++++++++
 .../mlir/dialects/SparseTensorTransformOps.td | 14 +++++++++++++
 .../mlir/dialects/transform/sparse_tensor.py  |  5 +++++
 .../mlir/python/BUILD.bazel                   | 20 +++++++++++++++++++
 4 files changed, 48 insertions(+)
 create mode 100644 mlir/python/mlir/dialects/SparseTensorTransformOps.td
 create mode 100644 mlir/python/mlir/dialects/transform/sparse_tensor.py

diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt
index 5d2f233caa85d83..25be18fced0f7ac 100644
--- a/mlir/python/CMakeLists.txt
+++ b/mlir/python/CMakeLists.txt
@@ -213,6 +213,15 @@ declare_mlir_dialect_extension_python_bindings(
     "../../include/mlir/Dialect/Linalg/TransformOps/LinalgTransformEnums.td"
 )
 
+declare_mlir_dialect_extension_python_bindings(
+  ADD_TO_PARENT MLIRPythonSources.Dialects
+  ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
+  TD_FILE dialects/SparseTensorTransformOps.td
+  SOURCES dialects/transform/sparse_tensor.py
+  DIALECT_NAME transform
+  EXTENSION_NAME sparse_tensor_transform
+)
+
 declare_mlir_dialect_extension_python_bindings(
   ADD_TO_PARENT MLIRPythonSources.Dialects
   ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
diff --git a/mlir/python/mlir/dialects/SparseTensorTransformOps.td b/mlir/python/mlir/dialects/SparseTensorTransformOps.td
new file mode 100644
index 000000000000000..8b95ae01d1d517c
--- /dev/null
+++ b/mlir/python/mlir/dialects/SparseTensorTransformOps.td
@@ -0,0 +1,14 @@
+//===-- SparseTensorTransformOps.td ------------------------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef PYTHON_BINDINGS_SPARSE_TENSOR_TRANSFORM_OPS
+#define PYTHON_BINDINGS_SPARSE_TENSOR_TRANSFORM_OPS
+
+include "mlir/Dialect/SparseTensor/TransformOps/SparseTensorTransformOps.td"
+
+#endif // PYTHON_BINDINGS_SPARSE_TENSOR_TRANSFORM_OPS
diff --git a/mlir/python/mlir/dialects/transform/sparse_tensor.py b/mlir/python/mlir/dialects/transform/sparse_tensor.py
new file mode 100644
index 000000000000000..8b33270dc23a119
--- /dev/null
+++ b/mlir/python/mlir/dialects/transform/sparse_tensor.py
@@ -0,0 +1,5 @@
+#  Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+#  See https://llvm.org/LICENSE.txt for license information.
+#  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+from .._sparse_tensor_transform_ops_gen import *  # re-export tblgen-generated op bindings
diff --git a/utils/bazel/llvm-project-overlay/mlir/python/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/python/BUILD.bazel
index 1ea71cac2445e40..2e4e7892648ef35 100644
--- a/utils/bazel/llvm-project-overlay/mlir/python/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/python/BUILD.bazel
@@ -1232,6 +1232,25 @@ gentbl_filegroup(
     ],
 )
 
+gentbl_filegroup(
+    name = "SparseTensorTransformOpsPyGen",
+    tbl_outs = [
+        (
+            [
+                "-gen-python-op-bindings",
+                "-bind-dialect=transform",
+                "-dialect-extension=sparse_tensor_transform",  # same as EXTENSION_NAME in CMakeLists.txt
+            ],
+            "mlir/dialects/_sparse_tensor_transform_ops_gen.py",  # imported by dialects/transform/sparse_tensor.py
+        ),
+    ],
+    tblgen = "//mlir:mlir-tblgen",
+    td_file = "mlir/dialects/SparseTensorTransformOps.td",
+    deps = [
+        "//mlir:SparseTensorTransformOpsTdFiles",
+    ],
+)
+
 gentbl_filegroup(
     name = "TensorTransformOpsPyGen",
     tbl_outs = [
@@ -1309,6 +1328,7 @@ filegroup(
         ":LoopTransformOpsPyGen",
         ":MemRefTransformOpsPyGen",
         ":PDLTransformOpsPyGen",
+        ":SparseTensorTransformOpsPyGen",
         ":StructureTransformEnumPyGen",
         ":StructuredTransformOpsPyGen",
         ":TensorTransformOpsPyGen",

>From 01e929bdd59fadb91707d3b0356366668ca40590 Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Wed, 20 Sep 2023 19:37:59 +0000
Subject: [PATCH 2/2] Add a FileCheck test for the sparse_tensor transform Python bindings

---
 .../dialects/transform_sparse_tensor_ext.py   | 31 +++++++++++++++++++
 1 file changed, 31 insertions(+)
 create mode 100644 mlir/test/python/dialects/transform_sparse_tensor_ext.py

diff --git a/mlir/test/python/dialects/transform_sparse_tensor_ext.py b/mlir/test/python/dialects/transform_sparse_tensor_ext.py
new file mode 100644
index 000000000000000..e11cc6bf1e07426
--- /dev/null
+++ b/mlir/test/python/dialects/transform_sparse_tensor_ext.py
@@ -0,0 +1,31 @@
+# RUN: %PYTHON %s | FileCheck %s
+
+from mlir.ir import *
+from mlir.dialects import transform
+from mlir.dialects.transform import sparse_tensor
+
+
+def run(f):  # Builds f's ops inside a transform.sequence, then prints the module.
+    with Context(), Location.unknown():
+        module = Module.create()
+        with InsertionPoint(module.body):
+            sequence = transform.SequenceOp(
+                transform.FailurePropagationMode.Propagate,
+                [],
+                transform.AnyOpType.get(),
+            )
+            with InsertionPoint(sequence.body):
+                f(sequence.bodyTarget)
+                transform.YieldOp()
+        print("\nTEST:", f.__name__)
+        print(module)
+    return f
+
+
+@run
+def testMatchSparseInOut(target):
+    sparse_tensor.MatchSparseInOut(transform.AnyOpType.get(), target)
+    # CHECK-LABEL: TEST: testMatchSparseInOut
+    # CHECK:       transform.sequence
+    # CHECK-NEXT:  ^{{.*}}(%[[ARG0:.*]]: !transform.any_op):
+    # CHECK-NEXT:    transform.sparse_tensor.match.sparse_inout %[[ARG0]]



More information about the llvm-commits mailing list