[Mlir-commits] [mlir] c12cb0c - [mlir][python] fix value-builder generation for snake_case ops (#135302)
llvmlistbot at llvm.org
Fri Apr 11 05:55:42 PDT 2025
Author: Maksim Levental
Date: 2025-04-11T08:55:38-04:00
New Revision: c12cb0ccbb408c5e65801a6aa5a8e53b8b36c646
URL: https://github.com/llvm/llvm-project/commit/c12cb0ccbb408c5e65801a6aa5a8e53b8b36c646
DIFF: https://github.com/llvm/llvm-project/commit/c12cb0ccbb408c5e65801a6aa5a8e53b8b36c646.diff
LOG: [mlir][python] fix value-builder generation for snake_case ops (#135302)
Ops that are already snake case (like [`ROCDL_wmma_*`
ops](https://github.com/makslevental/llvm-project/blob/66b0b0466bbd995146aadaf2cd18de5476c19941/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td#L411))
produce Python "value builders" whose names collide with the class names:
```python
class wmma_bf16_16x16x16_bf16(_ods_ir.OpView):
    OPERATION_NAME = "rocdl.wmma.bf16.16x16x16.bf16"
    ...
def wmma_bf16_16x16x16_bf16(res, args, *, loc=None, ip=None) -> _ods_ir.Value:
    return wmma_bf16_16x16x16_bf16(res=res, args=args, loc=loc, ip=ip).result
```
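The value builder shadows the class of the same name, so the call inside it
recurses into the builder itself instead of constructing the op; such
builders therefore cannot be emitted.
This PR fixes that by appending `_` to the value-builder name whenever it
matches the class name.
I would have preferred to just rename the ops, but that would be a breaking
change :shrug:.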
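To make the collision concrete, here is a minimal sketch in plain Python with no MLIR dependency. `FakeOpView` and `snake_op` are hypothetical stand-ins for `_ods_ir.OpView` and a generated op class; only the name-shadowing behaviour is meant to match the real bindings.

```python
class FakeOpView:
    """Hypothetical stand-in for _ods_ir.OpView (assumption, not the real API)."""

    def __init__(self, *, loc=None, ip=None):
        self.result = "fake-result"


class snake_op(FakeOpView):
    OPERATION_NAME = "test.snake_op"


# Keep a handle on the class, because the next `def` rebinds the same name.
_snake_op_class = snake_op


def snake_op(*, loc=None, ip=None):
    # `snake_op` now resolves to this function, not the class,
    # so the call below recurses until Python gives up.
    return snake_op(loc=loc, ip=ip).result


try:
    snake_op()
    collided = False
except RecursionError:
    collided = True

# Restore the class binding and use a trailing underscore for the builder,
# mirroring the naming scheme this commit introduces.
snake_op = _snake_op_class


def snake_op_(*, loc=None, ip=None):
    return snake_op(loc=loc, ip=ip).result


fixed_result = snake_op_()
```

With the trailing underscore, the class and the value builder live under distinct names, so both remain usable from the same module.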
Added:
Modified:
mlir/test/mlir-tblgen/op-python-bindings.td
mlir/test/python/dialects/rocdl.py
mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
Removed:
################################################################################
diff --git a/mlir/test/mlir-tblgen/op-python-bindings.td b/mlir/test/mlir-tblgen/op-python-bindings.td
index 72963cac64d54..c2bd86819666b 100644
--- a/mlir/test/mlir-tblgen/op-python-bindings.td
+++ b/mlir/test/mlir-tblgen/op-python-bindings.td
@@ -654,3 +654,9 @@ def WithSuccessorsOp : TestOp<"with_successors"> {
// CHECK: def with_successors(successor, successors, *, loc=None, ip=None)
// CHECK: return WithSuccessorsOp(successor=successor, successors=successors, loc=loc, ip=ip)
+
+// CHECK: class snake_case(_ods_ir.OpView):
+// CHECK-LABEL: OPERATION_NAME = "test.snake_case"
+def already_snake_case : TestOp<"snake_case"> {}
+// CHECK: def snake_case_(*, loc=None, ip=None)
+// CHECK: return snake_case(loc=loc, ip=ip)
diff --git a/mlir/test/python/dialects/rocdl.py b/mlir/test/python/dialects/rocdl.py
index a4eca2766899b..a4a50afa966c7 100644
--- a/mlir/test/python/dialects/rocdl.py
+++ b/mlir/test/python/dialects/rocdl.py
@@ -1,8 +1,10 @@
# RUN: %PYTHON %s | FileCheck %s
# This is just a smoke test that the dialect is functional.
+from array import array
from mlir.ir import *
-from mlir.dialects import rocdl
+from mlir.dialects import rocdl, arith
+from mlir.extras import types as T
def constructAndPrintInModule(f):
@@ -18,5 +20,22 @@ def constructAndPrintInModule(f):
# CHECK-LABEL: testSmoke
@constructAndPrintInModule
def testSmoke():
-    # CHECK: rocdl.barrier
-    rocdl.BarrierOp()
+    v_len = 16
+    f32 = F32Type.get()
+    # Note: this isn't actually the right type for the intrinsic (should be f16)
+    # but array doesn't support f16.
+    v16f32 = T.vector(v_len, f32)
+    f32_array = array("f", [0.0] * v_len)
+    a_frag = arith.constant(v16f32, f32_array)
+    b_frag = arith.constant(v16f32, f32_array)
+    c_frag = arith.constant(v16f32, f32_array)
+    false = arith.constant(T.bool(), False)
+
+    c_frag = rocdl.wmma_f16_16x16x16_f16(v16f32, [a_frag, b_frag, c_frag, false])
+    # CHECK: %{{.*}} = rocdl.wmma.f16.16x16x16.f16
+    print(c_frag)
+    assert isinstance(c_frag, OpView)
+    # CHECK: Value(%{{.*}} = rocdl.wmma.f16.16x16x16.f16
+    c_frag = rocdl.wmma_f16_16x16x16_f16_(v16f32, [a_frag, b_frag, c_frag, false])
+    print(c_frag)
+    assert isinstance(c_frag, Value)
diff --git a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
index 604d2376052a8..d2e38e9d23198 100644
--- a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
@@ -1000,6 +1000,8 @@ static void emitValueBuilder(const Operator &op,
});
   std::string nameWithoutDialect = sanitizeName(
       op.getOperationName().substr(op.getOperationName().find('.') + 1));
+  if (nameWithoutDialect == op.getCppClassName())
+    nameWithoutDialect += "_";
   std::string params = llvm::join(valueBuilderParams, ", ");
   std::string args = llvm::join(opBuilderArgs, ", ");
   const char *type =
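
The naming rule added in this hunk can be sketched in Python as follows. This is an illustration, not the generator's code: `sanitize_name` is a simplified, hypothetical stand-in for `sanitizeName` that only maps non-alphanumeric characters to `_`, and `value_builder_name` is a name invented for this sketch.

```python
import re


def sanitize_name(name: str) -> str:
    # Simplified stand-in (assumption): map every non-alphanumeric
    # character to an underscore; other handling is omitted here.
    return re.sub(r"[^0-9A-Za-z]", "_", name)


def value_builder_name(operation_name: str, class_name: str) -> str:
    # Drop the dialect prefix (everything up to the first '.').
    name = sanitize_name(operation_name[operation_name.find(".") + 1:])
    # If the builder would shadow the op class, append an underscore.
    if name == class_name:
        name += "_"
    return name


wmma = value_builder_name("rocdl.wmma.f16.16x16x16.f16", "wmma_f16_16x16x16_f16")
snake = value_builder_name("test.snake_case", "snake_case")
camel = value_builder_name("test.with_successors", "WithSuccessorsOp")
```

Ops whose class names are CamelCase, such as `WithSuccessorsOp`, keep their existing builder names, so only the already-snake_case ops gain the suffix.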