[Mlir-commits] [mlir] [mlir][x86vector] Python bindings for x86vector dialect (PR #179958)

Adam Siemieniuk llvmlistbot at llvm.org
Thu Feb 5 07:15:15 PST 2026


https://github.com/adam-smnk created https://github.com/llvm/llvm-project/pull/179958

Registers Python bindings for the x86vector dialect and its transform ops.

>From 924c93f1dff372360bb09e99c55e6695324a9ab5 Mon Sep 17 00:00:00 2001
From: Adam Siemieniuk <adam.siemieniuk at intel.com>
Date: Thu, 5 Feb 2026 16:06:35 +0100
Subject: [PATCH] [mlir][x86vector] Python bindings for x86vector dialect

Registers python bindings for x86vector dialect and transform ops.
---
 .../TransformOps/X86VectorTransformOps.td     |  1 -
 mlir/python/CMakeLists.txt                    | 16 +++++
 mlir/python/mlir/dialects/X86Vector.td        | 14 ++++
 .../mlir/dialects/X86VectorTransformOps.td    | 14 ++++
 .../mlir/dialects/transform/x86vector.py      |  5 ++
 mlir/python/mlir/dialects/x86vector.py        |  6 ++
 .../dialects/transform_x86vector_ext.py       | 40 +++++++++++
 mlir/test/python/dialects/x86vector.py        | 72 +++++++++++++++++++
 8 files changed, 167 insertions(+), 1 deletion(-)
 create mode 100644 mlir/python/mlir/dialects/X86Vector.td
 create mode 100644 mlir/python/mlir/dialects/X86VectorTransformOps.td
 create mode 100644 mlir/python/mlir/dialects/transform/x86vector.py
 create mode 100644 mlir/python/mlir/dialects/x86vector.py
 create mode 100644 mlir/test/python/dialects/transform_x86vector_ext.py
 create mode 100644 mlir/test/python/dialects/x86vector.py

diff --git a/mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.td b/mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.td
index 891829fca017f..d57ed1f1cd171 100644
--- a/mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.td
+++ b/mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.td
@@ -73,4 +73,3 @@ def ApplyShuffleVectorFMAOpsPatternsOp : Op<Transform_Dialect,
 
 
 #endif // X86VECTOR_TRANSFORM_OPS
-
diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt
index 8ab145ada85dd..50143f700f5a1 100644
--- a/mlir/python/CMakeLists.txt
+++ b/mlir/python/CMakeLists.txt
@@ -325,6 +325,15 @@ declare_mlir_dialect_extension_python_bindings(
     "../../include/mlir/Dialect/Vector/Transforms/VectorTransformsBase.td"
 )
 
+declare_mlir_dialect_extension_python_bindings(  # x86vector transform-dialect extension ops (X86VectorTransformOps.td)
+  ADD_TO_PARENT MLIRPythonSources.Dialects
+  ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
+  TD_FILE dialects/X86VectorTransformOps.td
+  SOURCES
+    dialects/transform/x86vector.py
+  DIALECT_NAME transform
+  EXTENSION_NAME x86vector_transform)
+
 declare_mlir_dialect_extension_python_bindings(
   ADD_TO_PARENT MLIRPythonSources.Dialects
   ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
@@ -510,6 +519,13 @@ declare_mlir_dialect_python_bindings(
   GEN_ENUM_BINDINGS_TD_FILE
     "dialects/VectorAttributes.td")
 
+declare_mlir_dialect_python_bindings(  # x86vector dialect ops (X86Vector.td), exposed via dialects/x86vector.py
+  ADD_TO_PARENT MLIRPythonSources.Dialects
+  ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
+  TD_FILE dialects/X86Vector.td
+  SOURCES dialects/x86vector.py
+  DIALECT_NAME x86vector)
+
 declare_mlir_dialect_python_bindings(
   ADD_TO_PARENT MLIRPythonSources.Dialects
   ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
diff --git a/mlir/python/mlir/dialects/X86Vector.td b/mlir/python/mlir/dialects/X86Vector.td
new file mode 100644
index 0000000000000..d8a846bf9e905
--- /dev/null
+++ b/mlir/python/mlir/dialects/X86Vector.td
@@ -0,0 +1,14 @@
+//===-- X86Vector.td - Entry point for x86vector bindings --*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef PYTHON_BINDINGS_X86VECTOR
+#define PYTHON_BINDINGS_X86VECTOR
+
+include "mlir/Dialect/X86Vector/X86Vector.td"
+
+#endif // PYTHON_BINDINGS_X86VECTOR
diff --git a/mlir/python/mlir/dialects/X86VectorTransformOps.td b/mlir/python/mlir/dialects/X86VectorTransformOps.td
new file mode 100644
index 0000000000000..ad6a693923703
--- /dev/null
+++ b/mlir/python/mlir/dialects/X86VectorTransformOps.td
@@ -0,0 +1,14 @@
+//===-- X86VectorTransformOps.td ---------------------------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef PYTHON_BINDINGS_X86VECTORTRANSFORMOPS
+#define PYTHON_BINDINGS_X86VECTORTRANSFORMOPS
+
+include "mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.td"
+
+#endif // PYTHON_BINDINGS_X86VECTORTRANSFORMOPS
diff --git a/mlir/python/mlir/dialects/transform/x86vector.py b/mlir/python/mlir/dialects/transform/x86vector.py
new file mode 100644
index 0000000000000..cccd300522797
--- /dev/null
+++ b/mlir/python/mlir/dialects/transform/x86vector.py
@@ -0,0 +1,5 @@
+#  Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+#  See https://llvm.org/LICENSE.txt for license information.
+#  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+from .._x86vector_transform_ops_gen import *
diff --git a/mlir/python/mlir/dialects/x86vector.py b/mlir/python/mlir/dialects/x86vector.py
new file mode 100644
index 0000000000000..eddc93dbe6460
--- /dev/null
+++ b/mlir/python/mlir/dialects/x86vector.py
@@ -0,0 +1,6 @@
+#  Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+#  See https://llvm.org/LICENSE.txt for license information.
+#  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+from ._x86vector_ops_gen import *
+from ._x86vector_ops_gen import _Dialect
diff --git a/mlir/test/python/dialects/transform_x86vector_ext.py b/mlir/test/python/dialects/transform_x86vector_ext.py
new file mode 100644
index 0000000000000..ad8dab8175ef2
--- /dev/null
+++ b/mlir/test/python/dialects/transform_x86vector_ext.py
@@ -0,0 +1,48 @@
+# RUN: %PYTHON %s | FileCheck %s
+
+from mlir.ir import *
+from mlir.dialects import transform
+from mlir.dialects.transform import x86vector
+
+
+def run_apply_patterns(f):
+    """Build the pattern ops created by `f` inside an apply_patterns op.
+
+    The ops emitted by `f` are placed in the `patterns` region of a
+    `transform.ApplyPatternsOp` nested in a `transform.SequenceOp`. The module
+    is printed after a `TEST: <name>` header for FileCheck, and `f` is
+    returned unchanged so this can be used as a decorator.
+    """
+    with Context(), Location.unknown():
+        module = Module.create()
+        with InsertionPoint(module.body):
+            sequence = transform.SequenceOp(
+                transform.FailurePropagationMode.Propagate,
+                [],
+                transform.AnyOpType.get(),
+            )
+            with InsertionPoint(sequence.body):
+                apply = transform.ApplyPatternsOp(sequence.bodyTarget)
+                with InsertionPoint(apply.patterns):
+                    f()
+                transform.YieldOp()
+        print("\nTEST:", f.__name__)
+        print(module)
+    return f
+
+
+@run_apply_patterns
+def non_configurable_patterns():
+    """Create every x86vector pattern op that is built without arguments."""
+    # CHECK-LABEL: TEST: non_configurable_patterns
+    # CHECK: apply_patterns
+    # CHECK: transform.apply_patterns.x86vector.vector_contract_to_fma
+    x86vector.ApplyVectorContractToFMAPatternsOp()
+    # CHECK: transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product
+    x86vector.ApplyVectorContractToPackedTypeDotProductPatternsOp()
+    # CHECK: transform.apply_patterns.x86vector.vector_contract_bf16_to_fma
+    x86vector.ApplyVectorContractBF16ToFMAPatternsOp()
+    # CHECK: transform.apply_patterns.x86vector.sink_vector_producer_ops
+    x86vector.ApplySinkVectorProducerOpsPatternsOp()
+    # CHECK: transform.apply_patterns.x86vector.shuffle_vector_fma_ops
+    x86vector.ApplyShuffleVectorFMAOpsPatternsOp()
diff --git a/mlir/test/python/dialects/x86vector.py b/mlir/test/python/dialects/x86vector.py
new file mode 100644
index 0000000000000..c7d680792fb66
--- /dev/null
+++ b/mlir/test/python/dialects/x86vector.py
@@ -0,0 +1,83 @@
+# RUN: %PYTHON %s | FileCheck %s
+
+from mlir.ir import *
+import mlir.dialects.func as func
+import mlir.dialects.x86vector as x86vector
+
+
+def run(f):
+    """Print a `TEST: <name>` header, then call `f` in a fresh Context.
+
+    `f` runs with an unknown default Location. It is returned unchanged so
+    `run` can be used as a decorator.
+    """
+    print("\nTEST:", f.__name__)
+    with Context(), Location.unknown():
+        f()
+    return f
+
+
+# CHECK-LABEL: TEST: testAvxOp
+@run
+def testAvxOp():
+    """Wrap x86vector.avx.bcst_to_f32.packed in a func and print it."""
+    module = Module.create()
+    with InsertionPoint(module.body):
+
+        @func.FuncOp.from_py_func(MemRefType.get((1,), BF16Type.get()))
+        def avx_op(arg):
+            return x86vector.BcstToPackedF32Op(
+                a=arg, dst=VectorType.get((8,), F32Type.get())
+            )
+
+    # CHECK-LABEL: func @avx_op(
+    # CHECK-SAME:      %[[ARG:.+]]: memref<1xbf16>) -> vector<8xf32> {
+    #       CHECK:   %[[VAL:.+]] = x86vector.avx.bcst_to_f32.packed %[[ARG]]
+    #       CHECK:   return %[[VAL]] : vector<8xf32>
+    #       CHECK: }
+    print(module)
+
+
+# CHECK-LABEL: TEST: testAvx512Op
+@run
+def testAvx512Op():
+    """Wrap x86vector.avx512.cvt.packed.f32_to_bf16 in a func and print it."""
+    module = Module.create()
+    with InsertionPoint(module.body):
+
+        @func.FuncOp.from_py_func(VectorType.get((8,), F32Type.get()))
+        def avx512_op(arg):
+            return x86vector.CvtPackedF32ToBF16Op(
+                a=arg, dst=VectorType.get((8,), BF16Type.get())
+            )
+
+    # CHECK-LABEL: func @avx512_op(
+    # CHECK-SAME:      %[[ARG:.+]]: vector<8xf32>) -> vector<8xbf16> {
+    #       CHECK:   %[[VAL:.+]] = x86vector.avx512.cvt.packed.f32_to_bf16 %[[ARG]]
+    #       CHECK:   return %[[VAL]] : vector<8xbf16>
+    #       CHECK: }
+    print(module)
+
+
+# CHECK-LABEL: TEST: testAvx10Op
+@run
+def testAvx10Op():
+    """Wrap x86vector.avx10.dot.i8 in a func and print it."""
+    module = Module.create()
+    with InsertionPoint(module.body):
+
+        @func.FuncOp.from_py_func(
+            VectorType.get((16,), IntegerType.get(32)),
+            VectorType.get((64,), IntegerType.get(8)),
+            VectorType.get((64,), IntegerType.get(8)),
+        )
+        def avx10_op(*args):
+            return x86vector.AVX10DotInt8Op(w=args[0], a=args[1], b=args[2])
+
+    # CHECK-LABEL: func @avx10_op(
+    # CHECK-SAME:      %[[W:.+]]: vector<16xi32>, %[[A:.+]]: vector<64xi8>,
+    # CHECK-SAME:      %[[B:.+]]: vector<64xi8>) -> vector<16xi32> {
+    #       CHECK:   %[[VAL:.+]] = x86vector.avx10.dot.i8 %[[W]], %[[A]], %[[B]]
+    #       CHECK:   return %[[VAL]] : vector<16xi32>
+    #       CHECK: }
+    print(module)



More information about the Mlir-commits mailing list