[Mlir-commits] [mlir] [mlir][python] fix value-builder generation for snake_case ops (PR #135302)

Maksim Levental llvmlistbot at llvm.org
Thu Apr 10 20:59:37 PDT 2025


https://github.com/makslevental updated https://github.com/llvm/llvm-project/pull/135302

>From 5b250acb82415901877db2232f198fef35702af8 Mon Sep 17 00:00:00 2001
From: Maksim Levental <maksim.levental at gmail.com>
Date: Thu, 10 Apr 2025 22:46:39 -0400
Subject: [PATCH] [mlir][python] fix value-builder generation for snake_case
 ops

---
 mlir/test/mlir-tblgen/op-python-bindings.td   |  6 +++++
 mlir/test/python/dialects/rocdl.py            | 25 ++++++++++++++++---
 mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp |  2 ++
 3 files changed, 30 insertions(+), 3 deletions(-)

diff --git a/mlir/test/mlir-tblgen/op-python-bindings.td b/mlir/test/mlir-tblgen/op-python-bindings.td
index 72963cac64d54..c2bd86819666b 100644
--- a/mlir/test/mlir-tblgen/op-python-bindings.td
+++ b/mlir/test/mlir-tblgen/op-python-bindings.td
@@ -654,3 +654,9 @@ def WithSuccessorsOp : TestOp<"with_successors"> {
 
 // CHECK: def with_successors(successor, successors, *, loc=None, ip=None)
 // CHECK:   return WithSuccessorsOp(successor=successor, successors=successors, loc=loc, ip=ip)
+
+// CHECK: class snake_case(_ods_ir.OpView):
+// CHECK-LABEL: OPERATION_NAME = "test.snake_case"
+def already_snake_case : TestOp<"snake_case"> {}
+// CHECK: def snake_case_(*, loc=None, ip=None)
+// CHECK:   return snake_case(loc=loc, ip=ip)
diff --git a/mlir/test/python/dialects/rocdl.py b/mlir/test/python/dialects/rocdl.py
index a4eca2766899b..a4a50afa966c7 100644
--- a/mlir/test/python/dialects/rocdl.py
+++ b/mlir/test/python/dialects/rocdl.py
@@ -1,8 +1,10 @@
 # RUN: %PYTHON %s | FileCheck %s
 # This is just a smoke test that the dialect is functional.
+from array import array
 
 from mlir.ir import *
-from mlir.dialects import rocdl
+from mlir.dialects import rocdl, arith
+from mlir.extras import types as T
 
 
 def constructAndPrintInModule(f):
@@ -18,5 +20,22 @@ def constructAndPrintInModule(f):
 # CHECK-LABEL: testSmoke
 @constructAndPrintInModule
 def testSmoke():
-    # CHECK: rocdl.barrier
-    rocdl.BarrierOp()
+    v_len = 16
+    f32 = F32Type.get()
+    # Note: this isn't actually the right type for the intrinsic (should be f16)
+    # but array doesn't support f16.
+    v16f32 = T.vector(v_len, f32)
+    f32_array = array("f", [0.0] * v_len)
+    a_frag = arith.constant(v16f32, f32_array)
+    b_frag = arith.constant(v16f32, f32_array)
+    c_frag = arith.constant(v16f32, f32_array)
+    false = arith.constant(T.bool(), False)
+
+    c_frag = rocdl.wmma_f16_16x16x16_f16(v16f32, [a_frag, b_frag, c_frag, false])
+    # CHECK: %{{.*}} = rocdl.wmma.f16.16x16x16.f16
+    print(c_frag)
+    assert isinstance(c_frag, OpView)
+    # CHECK: Value(%{{.*}} = rocdl.wmma.f16.16x16x16.f16
+    c_frag = rocdl.wmma_f16_16x16x16_f16_(v16f32, [a_frag, b_frag, c_frag, false])
+    print(c_frag)
+    assert isinstance(c_frag, Value)
diff --git a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
index 604d2376052a8..d2e38e9d23198 100644
--- a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
@@ -1000,6 +1000,8 @@ static void emitValueBuilder(const Operator &op,
       });
   std::string nameWithoutDialect = sanitizeName(
       op.getOperationName().substr(op.getOperationName().find('.') + 1));
+  if (nameWithoutDialect == op.getCppClassName())
+    nameWithoutDialect += "_";
   std::string params = llvm::join(valueBuilderParams, ", ");
   std::string args = llvm::join(opBuilderArgs, ", ");
   const char *type =



More information about the Mlir-commits mailing list