[Mlir-commits] [mlir] bca003d - [mlir] Fix wrong variable name in Linalg OpDSL

Alex Zinenko llvmlistbot at llvm.org
Wed Nov 17 13:56:06 PST 2021


Author: Alex Zinenko
Date: 2021-11-17T22:55:35+01:00
New Revision: bca003dea8df9d87ce3cf17defb4e89b3166462d

URL: https://github.com/llvm/llvm-project/commit/bca003dea8df9d87ce3cf17defb4e89b3166462d
DIFF: https://github.com/llvm/llvm-project/commit/bca003dea8df9d87ce3cf17defb4e89b3166462d.diff

LOG: [mlir] Fix wrong variable name in Linalg OpDSL

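When the decorator is used with keyword arguments only (e.g.
`@linalg_structured_op(op_name=...)`), it curried itself via the stale name
`tc_def_op`. This name seems to have been left over from a renaming effort on
an unexercised code path, and errors of this kind are difficult to catch in
Python. Fix it and add a test that exercises the code path.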

Reviewed By: gysit

Differential Revision: https://reviews.llvm.org/D114004

Added: 
    

Modified: 
    mlir/python/mlir/dialects/linalg/opdsl/lang/dsl.py
    mlir/test/python/dialects/linalg/opdsl/emit_structured_generic.py

Removed: 
    


################################################################################
diff  --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/dsl.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/dsl.py
index 1acae7a7a389..a65350ccd430 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/lang/dsl.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/dsl.py
@@ -112,7 +112,7 @@ def linalg_structured_op(dsl_func=None,
   if dsl_func is None:
     # Curry the keyword args in for delayed application.
     return functools.partial(
-        tc_def_op, op_name=op_name, op_class_name=op_class_name)
+        linalg_structured_op, op_name=op_name, op_class_name=op_class_name)
   # Determine default names by introspecting the function.
   if op_name is None:
     op_name = dsl_func.__name__
@@ -131,9 +131,10 @@ def linalg_structured_op(dsl_func=None,
     if isinstance(param_default, (TensorDef, ScalarDef, AttributeDef)):
       tc_model.add_operand(param_name, param_default.operand_def)
     else:
-      raise ValueError(f"@tc_def_op function parameters must be defaulted as "
-                       f"TensorDef(...), ScalarDef(...), or AttributeDef(...): "
-                       f"Found {param_name}: {param_default}")
+      raise ValueError(
+          f"@linalg_structured_op function parameters must be defaulted as "
+          f"TensorDef(...), ScalarDef(...), or AttributeDef(...): "
+          f"Found {param_name}: {param_default}")
     dsl_func_args.append(param_default)
 
   # Invoke the DSL func to finish populating the model.

diff  --git a/mlir/test/python/dialects/linalg/opdsl/emit_structured_generic.py b/mlir/test/python/dialects/linalg/opdsl/emit_structured_generic.py
index 71dc8a5474aa..d0c74270950e 100644
--- a/mlir/test/python/dialects/linalg/opdsl/emit_structured_generic.py
+++ b/mlir/test/python/dialects/linalg/opdsl/emit_structured_generic.py
@@ -126,6 +126,11 @@ def soft_plus_poly(
       PrimFn.log(cast(U, const(1.0)) + cast(U, PrimFn.exp(I[D.m, D.n])))
 
 
+@linalg_structured_op(op_name="custom_op_name")
+def non_default_op_name(I=TensorDef(T, S.N), O=TensorDef(T, S.N, output=True)):
+  O[D.n] = I[D.n]
+
+
 with Context() as ctx, Location.unknown():
   module = Module.create()
   f16 = F16Type.get()
@@ -392,5 +397,12 @@ def test_i32_fill_rng(min, max, seed, init_result):
     def test_f32_soft_plus(input, init_result):
       return soft_plus_poly(input, outs=[init_result])
 
+    # Just check that we don't assert out on name mismatch.
+    # CHECK-LABEL: @test_non_default_op_name
+    @builtin.FuncOp.from_py_func(
+        RankedTensorType.get((42,), f32), RankedTensorType.get((42,), f32))
+    def test_non_default_op_name(input, init_result):
+      return non_default_op_name(input, outs=[init_result])
+
 
 print(module)


        


More information about the Mlir-commits mailing list