[Mlir-commits] [mlir] [mlir][x86vector] Python bindings for x86vector dialect (PR #179958)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Feb 5 07:15:53 PST 2026


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-vector

Author: Adam Siemieniuk (adam-smnk)

<details>
<summary>Changes</summary>

Registers Python bindings for the x86vector dialect and its transform ops.

---
Full diff: https://github.com/llvm/llvm-project/pull/179958.diff


8 Files Affected:

- (modified) mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.td (-1) 
- (modified) mlir/python/CMakeLists.txt (+16) 
- (added) mlir/python/mlir/dialects/X86Vector.td (+14) 
- (added) mlir/python/mlir/dialects/X86VectorTransformOps.td (+14) 
- (added) mlir/python/mlir/dialects/transform/x86vector.py (+5) 
- (added) mlir/python/mlir/dialects/x86vector.py (+6) 
- (added) mlir/test/python/dialects/transform_x86vector_ext.py (+40) 
- (added) mlir/test/python/dialects/x86vector.py (+72) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.td b/mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.td
index 891829fca017f..d57ed1f1cd171 100644
--- a/mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.td
+++ b/mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.td
@@ -73,4 +73,3 @@ def ApplyShuffleVectorFMAOpsPatternsOp : Op<Transform_Dialect,
 
 
 #endif // X86VECTOR_TRANSFORM_OPS
-
diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt
index 8ab145ada85dd..50143f700f5a1 100644
--- a/mlir/python/CMakeLists.txt
+++ b/mlir/python/CMakeLists.txt
@@ -325,6 +325,15 @@ declare_mlir_dialect_extension_python_bindings(
     "../../include/mlir/Dialect/Vector/Transforms/VectorTransformsBase.td"
 )
 
+declare_mlir_dialect_extension_python_bindings(
+  ADD_TO_PARENT MLIRPythonSources.Dialects
+  ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
+  TD_FILE dialects/X86VectorTransformOps.td  # x86vector transform-op bindings
+  SOURCES
+    dialects/transform/x86vector.py  # re-exports .._x86vector_transform_ops_gen
+  DIALECT_NAME transform
+  EXTENSION_NAME x86vector_transform)
+
 declare_mlir_dialect_extension_python_bindings(
   ADD_TO_PARENT MLIRPythonSources.Dialects
   ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
@@ -510,6 +519,13 @@ declare_mlir_dialect_python_bindings(
   GEN_ENUM_BINDINGS_TD_FILE
     "dialects/VectorAttributes.td")
 
+declare_mlir_dialect_python_bindings(
+  ADD_TO_PARENT MLIRPythonSources.Dialects
+  ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
+  TD_FILE dialects/X86Vector.td  # includes mlir/Dialect/X86Vector/X86Vector.td
+  SOURCES dialects/x86vector.py  # re-exports ._x86vector_ops_gen
+  DIALECT_NAME x86vector)
+
 declare_mlir_dialect_python_bindings(
   ADD_TO_PARENT MLIRPythonSources.Dialects
   ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
diff --git a/mlir/python/mlir/dialects/X86Vector.td b/mlir/python/mlir/dialects/X86Vector.td
new file mode 100644
index 0000000000000..d8a846bf9e905
--- /dev/null
+++ b/mlir/python/mlir/dialects/X86Vector.td
@@ -0,0 +1,14 @@
+//===-- X86Vector.td - Entry point for x86vector bindings --*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef PYTHON_BINDINGS_X86VECTOR
+#define PYTHON_BINDINGS_X86VECTOR
+
+include "mlir/Dialect/X86Vector/X86Vector.td"
+
+#endif // PYTHON_BINDINGS_X86VECTOR
diff --git a/mlir/python/mlir/dialects/X86VectorTransformOps.td b/mlir/python/mlir/dialects/X86VectorTransformOps.td
new file mode 100644
index 0000000000000..ad6a693923703
--- /dev/null
+++ b/mlir/python/mlir/dialects/X86VectorTransformOps.td
@@ -0,0 +1,14 @@
+//===-- X86VectorTransformOps.td ---------------------------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef PYTHON_BINDINGS_X86VECTORTRANSFORMOPS
+#define PYTHON_BINDINGS_X86VECTORTRANSFORMOPS
+
+include "mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.td"
+
+#endif // PYTHON_BINDINGS_X86VECTORTRANSFORMOPS
diff --git a/mlir/python/mlir/dialects/transform/x86vector.py b/mlir/python/mlir/dialects/transform/x86vector.py
new file mode 100644
index 0000000000000..cccd300522797
--- /dev/null
+++ b/mlir/python/mlir/dialects/transform/x86vector.py
@@ -0,0 +1,5 @@
+#  Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+#  See https://llvm.org/LICENSE.txt for license information.
+#  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+from .._x86vector_transform_ops_gen import *
diff --git a/mlir/python/mlir/dialects/x86vector.py b/mlir/python/mlir/dialects/x86vector.py
new file mode 100644
index 0000000000000..eddc93dbe6460
--- /dev/null
+++ b/mlir/python/mlir/dialects/x86vector.py
@@ -0,0 +1,6 @@
+#  Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+#  See https://llvm.org/LICENSE.txt for license information.
+#  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+from ._x86vector_ops_gen import *
+from ._x86vector_ops_gen import _Dialect
diff --git a/mlir/test/python/dialects/transform_x86vector_ext.py b/mlir/test/python/dialects/transform_x86vector_ext.py
new file mode 100644
index 0000000000000..ad8dab8175ef2
--- /dev/null
+++ b/mlir/test/python/dialects/transform_x86vector_ext.py
@@ -0,0 +1,40 @@
+# RUN: %PYTHON %s | FileCheck %s
+
+from mlir.ir import *
+from mlir.dialects import transform
+from mlir.dialects.transform import x86vector
+
+
+def run_apply_patterns(f):  # Decorator: emit f's pattern ops, print the module.
+    with Context(), Location.unknown():
+        module = Module.create()
+        with InsertionPoint(module.body):
+            sequence = transform.SequenceOp(  # top-level sequence over any op
+                transform.FailurePropagationMode.Propagate,
+                [],
+                transform.AnyOpType.get(),
+            )
+            with InsertionPoint(sequence.body):
+                apply = transform.ApplyPatternsOp(sequence.bodyTarget)
+                with InsertionPoint(apply.patterns):
+                    f()  # f populates the apply_patterns region
+                transform.YieldOp()  # terminate the sequence body
+        print("\nTEST:", f.__name__)  # label matched by CHECK-LABEL
+        print(module)
+    return f  # return f unchanged so the name stays bound
+
+
+@run_apply_patterns
+def non_configurable_patterns():  # x86vector pattern ops built with no arguments
+    # CHECK-LABEL: TEST: non_configurable_patterns
+    # CHECK: apply_patterns
+    # CHECK: transform.apply_patterns.x86vector.vector_contract_to_fma
+    x86vector.ApplyVectorContractToFMAPatternsOp()
+    # CHECK: transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product
+    x86vector.ApplyVectorContractToPackedTypeDotProductPatternsOp()
+    # CHECK: transform.apply_patterns.x86vector.vector_contract_bf16_to_fma
+    x86vector.ApplyVectorContractBF16ToFMAPatternsOp()
+    # CHECK: transform.apply_patterns.x86vector.sink_vector_producer_ops
+    x86vector.ApplySinkVectorProducerOpsPatternsOp()
+    # CHECK: transform.apply_patterns.x86vector.shuffle_vector_fma_ops
+    x86vector.ApplyShuffleVectorFMAOpsPatternsOp()
diff --git a/mlir/test/python/dialects/x86vector.py b/mlir/test/python/dialects/x86vector.py
new file mode 100644
index 0000000000000..c7d680792fb66
--- /dev/null
+++ b/mlir/test/python/dialects/x86vector.py
@@ -0,0 +1,72 @@
+# RUN: %PYTHON %s | FileCheck %s
+
+from mlir.ir import *
+import mlir.dialects.builtin as builtin
+import mlir.dialects.func as func
+import mlir.dialects.x86vector as x86vector
+
+
+def run(f):  # Decorator: print the TEST label, then run f in a fresh Context.
+    print("\nTEST:", f.__name__)
+    with Context(), Location.unknown():
+        f()  # each test builds and prints its own module
+    return f
+
+
+# CHECK-LABEL: TEST: testAvxOp
+@run
+def testAvxOp():  # memref<1xbf16> -> vector<8xf32> via an x86vector.avx op
+    module = Module.create()
+    with InsertionPoint(module.body):
+
+        @func.FuncOp.from_py_func(MemRefType.get((1,), BF16Type.get()))
+        def avx_op(arg):
+            return x86vector.BcstToPackedF32Op(a=arg, dst=VectorType.get((8,), F32Type.get()))
+
+    # CHECK-LABEL: func @avx_op(
+    # CHECK-SAME:      %[[ARG:.+]]: memref<1xbf16>) -> vector<8xf32> {
+    #       CHECK:   %[[VAL:.+]] = x86vector.avx.bcst_to_f32.packed %[[ARG]]
+    #       CHECK:   return %[[VAL]] : vector<8xf32>
+    #       CHECK: }
+    print(module)
+
+# CHECK-LABEL: TEST: testAvx512Op
+@run
+def testAvx512Op():  # vector<8xf32> -> vector<8xbf16> via an x86vector.avx512 op
+    module = Module.create()
+    with InsertionPoint(module.body):
+
+        @func.FuncOp.from_py_func(VectorType.get((8,), F32Type.get()))
+        def avx512_op(arg):
+            return x86vector.CvtPackedF32ToBF16Op(a=arg, dst=VectorType.get((8,), BF16Type.get()))
+
+    # CHECK-LABEL: func @avx512_op(
+    # CHECK-SAME:      %[[ARG:.+]]: vector<8xf32>) -> vector<8xbf16> {
+    #       CHECK:   %[[VAL:.+]] = x86vector.avx512.cvt.packed.f32_to_bf16 %[[ARG]]
+    #       CHECK:   return %[[VAL]] : vector<8xbf16>
+    #       CHECK: }
+    print(module)
+
+# CHECK-LABEL: TEST: testAvx10Op
+@run
+def testAvx10Op():  # three-operand x86vector.avx10 op; no result type is passed
+    module = Module.create()
+    with InsertionPoint(module.body):
+
+        @func.FuncOp.from_py_func(
+            VectorType.get((16,), IntegerType.get(32)),
+            VectorType.get((64,), IntegerType.get(8)),
+            VectorType.get((64,), IntegerType.get(8)),
+        )
+        def avx10_op(*args):
+            return x86vector.AVX10DotInt8Op(
+                w=args[0], a=args[1], b=args[2]
+            )
+
+    # CHECK-LABEL: func @avx10_op(
+    # CHECK-SAME:      %[[W:.+]]: vector<16xi32>, %[[A:.+]]: vector<64xi8>,
+    # CHECK-SAME:      %[[B:.+]]: vector<64xi8>) -> vector<16xi32> {
+    #       CHECK:   %[[VAL:.+]] = x86vector.avx10.dot.i8 %[[W]], %[[A]], %[[B]]
+    #       CHECK:   return %[[VAL]] : vector<16xi32>
+    #       CHECK: }
+    print(module)

``````````

</details>


https://github.com/llvm/llvm-project/pull/179958


More information about the Mlir-commits mailing list