[Mlir-commits] [mlir] [mlir][x86vector] Python bindings for x86vector dialect (PR #179958)
Adam Siemieniuk
llvmlistbot at llvm.org
Thu Feb 5 07:22:53 PST 2026
https://github.com/adam-smnk updated https://github.com/llvm/llvm-project/pull/179958
>From 924c93f1dff372360bb09e99c55e6695324a9ab5 Mon Sep 17 00:00:00 2001
From: Adam Siemieniuk <adam.siemieniuk at intel.com>
Date: Thu, 5 Feb 2026 16:06:35 +0100
Subject: [PATCH 1/2] [mlir][x86vector] Python bindings for x86vector dialect
Registers python bindings for x86vector dialect and transform ops.
---
.../TransformOps/X86VectorTransformOps.td | 1 -
mlir/python/CMakeLists.txt | 16 +++++
mlir/python/mlir/dialects/X86Vector.td | 14 ++++
.../mlir/dialects/X86VectorTransformOps.td | 14 ++++
.../mlir/dialects/transform/x86vector.py | 5 ++
mlir/python/mlir/dialects/x86vector.py | 6 ++
.../dialects/transform_x86vector_ext.py | 40 +++++++++++
mlir/test/python/dialects/x86vector.py | 72 +++++++++++++++++++
8 files changed, 167 insertions(+), 1 deletion(-)
create mode 100644 mlir/python/mlir/dialects/X86Vector.td
create mode 100644 mlir/python/mlir/dialects/X86VectorTransformOps.td
create mode 100644 mlir/python/mlir/dialects/transform/x86vector.py
create mode 100644 mlir/python/mlir/dialects/x86vector.py
create mode 100644 mlir/test/python/dialects/transform_x86vector_ext.py
create mode 100644 mlir/test/python/dialects/x86vector.py
diff --git a/mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.td b/mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.td
index 891829fca017f..d57ed1f1cd171 100644
--- a/mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.td
+++ b/mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.td
@@ -73,4 +73,3 @@ def ApplyShuffleVectorFMAOpsPatternsOp : Op<Transform_Dialect,
#endif // X86VECTOR_TRANSFORM_OPS
-
diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt
index 8ab145ada85dd..50143f700f5a1 100644
--- a/mlir/python/CMakeLists.txt
+++ b/mlir/python/CMakeLists.txt
@@ -325,6 +325,15 @@ declare_mlir_dialect_extension_python_bindings(
"../../include/mlir/Dialect/Vector/Transforms/VectorTransformsBase.td"
)
+declare_mlir_dialect_extension_python_bindings(
+ ADD_TO_PARENT MLIRPythonSources.Dialects
+ ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
+ TD_FILE dialects/X86VectorTransformOps.td
+ SOURCES
+ dialects/transform/x86vector.py
+ DIALECT_NAME transform
+ EXTENSION_NAME x86vector_transform)
+
declare_mlir_dialect_extension_python_bindings(
ADD_TO_PARENT MLIRPythonSources.Dialects
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
@@ -510,6 +519,13 @@ declare_mlir_dialect_python_bindings(
GEN_ENUM_BINDINGS_TD_FILE
"dialects/VectorAttributes.td")
+declare_mlir_dialect_python_bindings(
+ ADD_TO_PARENT MLIRPythonSources.Dialects
+ ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
+ TD_FILE dialects/X86Vector.td
+ SOURCES dialects/x86vector.py
+ DIALECT_NAME x86vector)
+
declare_mlir_dialect_python_bindings(
ADD_TO_PARENT MLIRPythonSources.Dialects
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
diff --git a/mlir/python/mlir/dialects/X86Vector.td b/mlir/python/mlir/dialects/X86Vector.td
new file mode 100644
index 0000000000000..d8a846bf9e905
--- /dev/null
+++ b/mlir/python/mlir/dialects/X86Vector.td
@@ -0,0 +1,14 @@
+//===-- X86Vector.td - Entry point for x86vector bindings --*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef PYTHON_BINDINGS_X86VECTOR
+#define PYTHON_BINDINGS_X86VECTOR
+
+include "mlir/Dialect/X86Vector/X86Vector.td"
+
+#endif // PYTHON_BINDINGS_X86VECTOR
diff --git a/mlir/python/mlir/dialects/X86VectorTransformOps.td b/mlir/python/mlir/dialects/X86VectorTransformOps.td
new file mode 100644
index 0000000000000..ad6a693923703
--- /dev/null
+++ b/mlir/python/mlir/dialects/X86VectorTransformOps.td
@@ -0,0 +1,14 @@
+//===-- X86VectorTransformOps.td ---------------------------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef PYTHON_BINDINGS_X86VECTORTRANSFORMOPS
+#define PYTHON_BINDINGS_X86VECTORTRANSFORMOPS
+
+include "mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.td"
+
+#endif // PYTHON_BINDINGS_X86VECTORTRANSFORMOPS
diff --git a/mlir/python/mlir/dialects/transform/x86vector.py b/mlir/python/mlir/dialects/transform/x86vector.py
new file mode 100644
index 0000000000000..cccd300522797
--- /dev/null
+++ b/mlir/python/mlir/dialects/transform/x86vector.py
@@ -0,0 +1,5 @@
+# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+from .._x86vector_transform_ops_gen import *
diff --git a/mlir/python/mlir/dialects/x86vector.py b/mlir/python/mlir/dialects/x86vector.py
new file mode 100644
index 0000000000000..eddc93dbe6460
--- /dev/null
+++ b/mlir/python/mlir/dialects/x86vector.py
@@ -0,0 +1,6 @@
+# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+from ._x86vector_ops_gen import *
+from ._x86vector_ops_gen import _Dialect
diff --git a/mlir/test/python/dialects/transform_x86vector_ext.py b/mlir/test/python/dialects/transform_x86vector_ext.py
new file mode 100644
index 0000000000000..ad8dab8175ef2
--- /dev/null
+++ b/mlir/test/python/dialects/transform_x86vector_ext.py
@@ -0,0 +1,40 @@
+# RUN: %PYTHON %s | FileCheck %s
+
+from mlir.ir import *
+from mlir.dialects import transform
+from mlir.dialects.transform import x86vector
+
+
+def run_apply_patterns(f):
+    with Context(), Location.unknown():
+        module = Module.create()
+        with InsertionPoint(module.body):
+            sequence = transform.SequenceOp(
+                transform.FailurePropagationMode.Propagate,
+                [],
+                transform.AnyOpType.get(),
+            )
+            with InsertionPoint(sequence.body):
+                apply = transform.ApplyPatternsOp(sequence.bodyTarget)
+                with InsertionPoint(apply.patterns):
+                    f()
+                transform.YieldOp()
+        print("\nTEST:", f.__name__)
+        print(module)
+    return f
+
+
+@run_apply_patterns
+def non_configurable_patterns():
+    # CHECK-LABEL: TEST: non_configurable_patterns
+    # CHECK: apply_patterns
+    # CHECK: transform.apply_patterns.x86vector.vector_contract_to_fma
+    x86vector.ApplyVectorContractToFMAPatternsOp()
+    # CHECK: transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product
+    x86vector.ApplyVectorContractToPackedTypeDotProductPatternsOp()
+    # CHECK: transform.apply_patterns.x86vector.vector_contract_bf16_to_fma
+    x86vector.ApplyVectorContractBF16ToFMAPatternsOp()
+    # CHECK: transform.apply_patterns.x86vector.sink_vector_producer_ops
+    x86vector.ApplySinkVectorProducerOpsPatternsOp()
+    # CHECK: transform.apply_patterns.x86vector.shuffle_vector_fma_ops
+    x86vector.ApplyShuffleVectorFMAOpsPatternsOp()
diff --git a/mlir/test/python/dialects/x86vector.py b/mlir/test/python/dialects/x86vector.py
new file mode 100644
index 0000000000000..c7d680792fb66
--- /dev/null
+++ b/mlir/test/python/dialects/x86vector.py
@@ -0,0 +1,72 @@
+# RUN: %PYTHON %s | FileCheck %s
+
+from mlir.ir import *
+import mlir.dialects.builtin as builtin
+import mlir.dialects.func as func
+import mlir.dialects.x86vector as x86vector
+
+
+def run(f):
+    print("\nTEST:", f.__name__)
+    with Context(), Location.unknown():
+        f()
+    return f
+
+
+# CHECK-LABEL: TEST: testAvxOp
+@run
+def testAvxOp():
+    module = Module.create()
+    with InsertionPoint(module.body):
+
+        @func.FuncOp.from_py_func(MemRefType.get((1,), BF16Type.get()))
+        def avx_op(arg):
+            return x86vector.BcstToPackedF32Op(a=arg, dst=VectorType.get((8,), F32Type.get()))
+
+    # CHECK-LABEL: func @avx_op(
+    # CHECK-SAME: %[[ARG:.+]]: memref<1xbf16>) -> vector<8xf32> {
+    # CHECK: %[[VAL:.+]] = x86vector.avx.bcst_to_f32.packed %[[ARG]]
+    # CHECK: return %[[VAL]] : vector<8xf32>
+    # CHECK: }
+    print(module)
+
+# CHECK-LABEL: TEST: testAvx512Op
+@run
+def testAvx512Op():
+    module = Module.create()
+    with InsertionPoint(module.body):
+
+        @func.FuncOp.from_py_func(VectorType.get((8,), F32Type.get()))
+        def avx512_op(arg):
+            return x86vector.CvtPackedF32ToBF16Op(a=arg, dst=VectorType.get((8,), BF16Type.get()))
+
+    # CHECK-LABEL: func @avx512_op(
+    # CHECK-SAME: %[[ARG:.+]]: vector<8xf32>) -> vector<8xbf16> {
+    # CHECK: %[[VAL:.+]] = x86vector.avx512.cvt.packed.f32_to_bf16 %[[ARG]]
+    # CHECK: return %[[VAL]] : vector<8xbf16>
+    # CHECK: }
+    print(module)
+
+# CHECK-LABEL: TEST: testAvx10Op
+@run
+def testAvx10Op():
+    module = Module.create()
+    with InsertionPoint(module.body):
+
+        @func.FuncOp.from_py_func(
+            VectorType.get((16,), IntegerType.get(32)),
+            VectorType.get((64,), IntegerType.get(8)),
+            VectorType.get((64,), IntegerType.get(8)),
+        )
+        def avx10_op(*args):
+            return x86vector.AVX10DotInt8Op(
+                w=args[0], a=args[1], b=args[2]
+            )
+
+    # CHECK-LABEL: func @avx10_op(
+    # CHECK-SAME: %[[W:.+]]: vector<16xi32>, %[[A:.+]]: vector<64xi8>,
+    # CHECK-SAME: %[[B:.+]]: vector<64xi8>) -> vector<16xi32> {
+    # CHECK: %[[VAL:.+]] = x86vector.avx10.dot.i8 %[[W]], %[[A]], %[[B]]
+    # CHECK: return %[[VAL]] : vector<16xi32>
+    # CHECK: }
+    print(module)
>From 74b38bc5445a9cb3abf07b5d23bacb08b778dc69 Mon Sep 17 00:00:00 2001
From: Adam Siemieniuk <adam.siemieniuk at intel.com>
Date: Thu, 5 Feb 2026 16:21:51 +0100
Subject: [PATCH 2/2] Formatting
---
mlir/test/python/dialects/x86vector.py | 14 +++++++++-----
1 file changed, 9 insertions(+), 5 deletions(-)
diff --git a/mlir/test/python/dialects/x86vector.py b/mlir/test/python/dialects/x86vector.py
index c7d680792fb66..c270727078a20 100644
--- a/mlir/test/python/dialects/x86vector.py
+++ b/mlir/test/python/dialects/x86vector.py
@@ -21,7 +21,9 @@ def testAvxOp():
@func.FuncOp.from_py_func(MemRefType.get((1,), BF16Type.get()))
def avx_op(arg):
- return x86vector.BcstToPackedF32Op(a=arg, dst=VectorType.get((8,), F32Type.get()))
+ return x86vector.BcstToPackedF32Op(
+ a=arg, dst=VectorType.get((8,), F32Type.get())
+ )
# CHECK-LABEL: func @avx_op(
# CHECK-SAME: %[[ARG:.+]]: memref<1xbf16>) -> vector<8xf32> {
@@ -30,6 +32,7 @@ def avx_op(arg):
# CHECK: }
print(module)
+
# CHECK-LABEL: TEST: testAvx512Op
@run
def testAvx512Op():
@@ -38,7 +41,9 @@ def testAvx512Op():
@func.FuncOp.from_py_func(VectorType.get((8,), F32Type.get()))
def avx512_op(arg):
- return x86vector.CvtPackedF32ToBF16Op(a=arg, dst=VectorType.get((8,), BF16Type.get()))
+ return x86vector.CvtPackedF32ToBF16Op(
+ a=arg, dst=VectorType.get((8,), BF16Type.get())
+ )
# CHECK-LABEL: func @avx512_op(
# CHECK-SAME: %[[ARG:.+]]: vector<8xf32>) -> vector<8xbf16> {
@@ -47,6 +52,7 @@ def avx512_op(arg):
# CHECK: }
print(module)
+
# CHECK-LABEL: TEST: testAvx10Op
@run
def testAvx10Op():
@@ -59,9 +65,7 @@ def testAvx10Op():
VectorType.get((64,), IntegerType.get(8)),
)
def avx10_op(*args):
- return x86vector.AVX10DotInt8Op(
- w=args[0], a=args[1], b=args[2]
- )
+ return x86vector.AVX10DotInt8Op(w=args[0], a=args[1], b=args[2])
# CHECK-LABEL: func @avx10_op(
# CHECK-SAME: %[[W:.+]]: vector<16xi32>, %[[A:.+]]: vector<64xi8>,
More information about the Mlir-commits
mailing list