[Mlir-commits] [mlir] cd2776b - [mlir][OpDSL] Split arithmetic functions.

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Feb 25 07:28:59 PST 2022


Author: gysit
Date: 2022-02-25T15:27:42Z
New Revision: cd2776b0d5d77eeb166acdf5fce68db90309b403

URL: https://github.com/llvm/llvm-project/commit/cd2776b0d5d77eeb166acdf5fce68db90309b403
DIFF: https://github.com/llvm/llvm-project/commit/cd2776b0d5d77eeb166acdf5fce68db90309b403.diff

LOG: [mlir][OpDSL] Split arithmetic functions.

Split arithmetic function into unary and binary functions. The revision prepares the introduction of unary and binary function attributes that work similarly to type function attributes.

Depends On D120108

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D120109

Added: 
    

Modified: 
    mlir/docs/Dialects/Linalg/OpDSL.md
    mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
    mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
    mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py
    mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
    mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
    mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco.py
    mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml
    mlir/test/python/dialects/linalg/opdsl/assignments.py
    mlir/test/python/dialects/linalg/opdsl/emit_misc.py
    mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/docs/Dialects/Linalg/OpDSL.md b/mlir/docs/Dialects/Linalg/OpDSL.md
index 4f703b7504bc1..116057ef4ad8a 100644
--- a/mlir/docs/Dialects/Linalg/OpDSL.md
+++ b/mlir/docs/Dialects/Linalg/OpDSL.md
@@ -178,17 +178,17 @@ TODO: Introduce a directive to fix the dimension bindings.
 Reduction dimensions are inferred to be any dimensions on the RHS that are not
 on the LHS.
 
-A number of arithmetic functions are supported:
-
-*   `ArithFn.add(a, b)` (also via overloading the binary `+` operator)
-*   `ArithFn.exp(a)`
-*   `ArithFn.log(a)`
-*   `ArithFn.mul(a, b)` (also via overloading the binary `*` operator)
-*   `ArithFn.max(a, b)`
-*   `ArithFn.min(a, b)`
-*   `ArithFn.sub(a, b)` (also via overloading the binary `-` operator)
-*   `ArithFn.max_unsigned(a, b)`
-*   `ArithFn.min_unsigned(a, b)`
+A number of unary and binary arithmetic functions are supported:
+
+*   `BinaryFn.add(a, b)` (also via overloading the binary `+` operator)
+*   `BinaryFn.mul(a, b)` (also via overloading the binary `*` operator)
+*   `BinaryFn.max(a, b)`
+*   `BinaryFn.min(a, b)`
+*   `BinaryFn.sub(a, b)` (also via overloading the binary `-` operator)
+*   `BinaryFn.max_unsigned(a, b)`
+*   `BinaryFn.min_unsigned(a, b)`
+*   `UnaryFn.exp(a)`
+*   `UnaryFn.log(a)`
 
 As the integer types are signless, signedness is implement by different
 functions that treat integers as signed or unsigned values.

diff  --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
index fed9d39ed2f3a..8789185b961fa 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
@@ -46,14 +46,14 @@ structured_op: !LinalgStructuredOpConfig
     arg: C
     value: !ScalarExpression
       scalar_fn:
-        kind: arith
+        kind: binary
         fn_name: add
         operands:
         - !ScalarExpression
           scalar_arg: C
         - !ScalarExpression
           scalar_fn:
-            kind: arith
+            kind: binary
             fn_name: mul
             operands:
             - !ScalarExpression
@@ -114,14 +114,14 @@ structured_op: !LinalgStructuredOpConfig
     arg: C
     value: !ScalarExpression
       scalar_fn:
-        kind: arith
+        kind: binary
         fn_name: add
         operands:
         - !ScalarExpression
           scalar_arg: C
         - !ScalarExpression
           scalar_fn:
-            kind: arith
+            kind: binary
             fn_name: mul
             operands:
             - !ScalarExpression
@@ -192,19 +192,19 @@ structured_op: !LinalgStructuredOpConfig
     arg: C
     value: !ScalarExpression
       scalar_fn:
-        kind: arith
+        kind: binary
         fn_name: add
         operands:
         - !ScalarExpression
           scalar_arg: C
         - !ScalarExpression
           scalar_fn:
-            kind: arith
+            kind: binary
             fn_name: mul
             operands:
             - !ScalarExpression
               scalar_fn:
-                kind: arith
+                kind: binary
                 fn_name: sub
                 operands:
                 - !ScalarExpression
@@ -225,7 +225,7 @@ structured_op: !LinalgStructuredOpConfig
                       scalar_arg: AZp
             - !ScalarExpression
               scalar_fn:
-                kind: arith
+                kind: binary
                 fn_name: sub
                 operands:
                 - !ScalarExpression
@@ -297,14 +297,14 @@ structured_op: !LinalgStructuredOpConfig
     arg: accum
     value: !ScalarExpression
       scalar_fn:
-        kind: arith
+        kind: binary
         fn_name: add
         operands:
         - !ScalarExpression
           scalar_arg: accum
         - !ScalarExpression
           scalar_fn:
-            kind: arith
+            kind: binary
             fn_name: mul
             operands:
             - !ScalarExpression
@@ -366,14 +366,14 @@ structured_op: !LinalgStructuredOpConfig
     arg: C
     value: !ScalarExpression
       scalar_fn:
-        kind: arith
+        kind: binary
         fn_name: add
         operands:
         - !ScalarExpression
           scalar_arg: C
         - !ScalarExpression
           scalar_fn:
-            kind: arith
+            kind: binary
             fn_name: mul
             operands:
             - !ScalarExpression
@@ -445,19 +445,19 @@ structured_op: !LinalgStructuredOpConfig
     arg: C
     value: !ScalarExpression
       scalar_fn:
-        kind: arith
+        kind: binary
         fn_name: add
         operands:
         - !ScalarExpression
           scalar_arg: C
         - !ScalarExpression
           scalar_fn:
-            kind: arith
+            kind: binary
             fn_name: mul
             operands:
             - !ScalarExpression
               scalar_fn:
-                kind: arith
+                kind: binary
                 fn_name: sub
                 operands:
                 - !ScalarExpression
@@ -478,7 +478,7 @@ structured_op: !LinalgStructuredOpConfig
                       scalar_arg: AZp
             - !ScalarExpression
               scalar_fn:
-                kind: arith
+                kind: binary
                 fn_name: sub
                 operands:
                 - !ScalarExpression
@@ -538,14 +538,14 @@ structured_op: !LinalgStructuredOpConfig
     arg: x
     value: !ScalarExpression
       scalar_fn:
-        kind: arith
+        kind: binary
         fn_name: add
         operands:
         - !ScalarExpression
           scalar_arg: x
         - !ScalarExpression
           scalar_fn:
-            kind: arith
+            kind: binary
             fn_name: mul
             operands:
             - !ScalarExpression
@@ -605,14 +605,14 @@ structured_op: !LinalgStructuredOpConfig
     arg: x
     value: !ScalarExpression
       scalar_fn:
-        kind: arith
+        kind: binary
         fn_name: add
         operands:
         - !ScalarExpression
           scalar_arg: x
         - !ScalarExpression
           scalar_fn:
-            kind: arith
+            kind: binary
             fn_name: mul
             operands:
             - !ScalarExpression
@@ -673,14 +673,14 @@ structured_op: !LinalgStructuredOpConfig
     arg: C
     value: !ScalarExpression
       scalar_fn:
-        kind: arith
+        kind: binary
         fn_name: add
         operands:
         - !ScalarExpression
           scalar_arg: C
         - !ScalarExpression
           scalar_fn:
-            kind: arith
+            kind: binary
             fn_name: mul
             operands:
             - !ScalarExpression
@@ -739,14 +739,14 @@ structured_op: !LinalgStructuredOpConfig
     arg: C
     value: !ScalarExpression
       scalar_fn:
-        kind: arith
+        kind: binary
         fn_name: add
         operands:
         - !ScalarExpression
           scalar_arg: C
         - !ScalarExpression
           scalar_fn:
-            kind: arith
+            kind: binary
             fn_name: mul
             operands:
             - !ScalarExpression
@@ -806,14 +806,14 @@ structured_op: !LinalgStructuredOpConfig
     arg: O
     value: !ScalarExpression
       scalar_fn:
-        kind: arith
+        kind: binary
         fn_name: add
         operands:
         - !ScalarExpression
           scalar_arg: O
         - !ScalarExpression
           scalar_fn:
-            kind: arith
+            kind: binary
             fn_name: mul
             operands:
             - !ScalarExpression
@@ -875,14 +875,14 @@ structured_op: !LinalgStructuredOpConfig
     arg: O
     value: !ScalarExpression
       scalar_fn:
-        kind: arith
+        kind: binary
         fn_name: add
         operands:
         - !ScalarExpression
           scalar_arg: O
         - !ScalarExpression
           scalar_fn:
-            kind: arith
+            kind: binary
             fn_name: mul
             operands:
             - !ScalarExpression
@@ -947,14 +947,14 @@ structured_op: !LinalgStructuredOpConfig
     arg: O
     value: !ScalarExpression
       scalar_fn:
-        kind: arith
+        kind: binary
         fn_name: add
         operands:
         - !ScalarExpression
           scalar_arg: O
         - !ScalarExpression
           scalar_fn:
-            kind: arith
+            kind: binary
             fn_name: mul
             operands:
             - !ScalarExpression
@@ -1031,14 +1031,14 @@ structured_op: !LinalgStructuredOpConfig
     arg: O
     value: !ScalarExpression
       scalar_fn:
-        kind: arith
+        kind: binary
         fn_name: add
         operands:
         - !ScalarExpression
           scalar_arg: O
         - !ScalarExpression
           scalar_fn:
-            kind: arith
+            kind: binary
             fn_name: mul
             operands:
             - !ScalarExpression
@@ -1129,14 +1129,14 @@ structured_op: !LinalgStructuredOpConfig
     arg: O
     value: !ScalarExpression
       scalar_fn:
-        kind: arith
+        kind: binary
         fn_name: add
         operands:
         - !ScalarExpression
           scalar_arg: O
         - !ScalarExpression
           scalar_fn:
-            kind: arith
+            kind: binary
             fn_name: mul
             operands:
             - !ScalarExpression
@@ -1240,19 +1240,19 @@ structured_op: !LinalgStructuredOpConfig
     arg: O
     value: !ScalarExpression
       scalar_fn:
-        kind: arith
+        kind: binary
         fn_name: add
         operands:
         - !ScalarExpression
           scalar_arg: O
         - !ScalarExpression
           scalar_fn:
-            kind: arith
+            kind: binary
             fn_name: mul
             operands:
             - !ScalarExpression
               scalar_fn:
-                kind: arith
+                kind: binary
                 fn_name: sub
                 operands:
                 - !ScalarExpression
@@ -1273,7 +1273,7 @@ structured_op: !LinalgStructuredOpConfig
                       scalar_arg: IZp
             - !ScalarExpression
               scalar_fn:
-                kind: arith
+                kind: binary
                 fn_name: sub
                 operands:
                 - !ScalarExpression
@@ -1364,14 +1364,14 @@ structured_op: !LinalgStructuredOpConfig
     arg: O
     value: !ScalarExpression
       scalar_fn:
-        kind: arith
+        kind: binary
         fn_name: add
         operands:
         - !ScalarExpression
           scalar_arg: O
         - !ScalarExpression
           scalar_fn:
-            kind: arith
+            kind: binary
             fn_name: mul
             operands:
             - !ScalarExpression
@@ -1464,14 +1464,14 @@ structured_op: !LinalgStructuredOpConfig
     arg: O
     value: !ScalarExpression
       scalar_fn:
-        kind: arith
+        kind: binary
         fn_name: add
         operands:
         - !ScalarExpression
           scalar_arg: O
         - !ScalarExpression
           scalar_fn:
-            kind: arith
+            kind: binary
             fn_name: mul
             operands:
             - !ScalarExpression
@@ -1547,14 +1547,14 @@ structured_op: !LinalgStructuredOpConfig
     arg: O
     value: !ScalarExpression
       scalar_fn:
-        kind: arith
+        kind: binary
         fn_name: add
         operands:
         - !ScalarExpression
           scalar_arg: O
         - !ScalarExpression
           scalar_fn:
-            kind: arith
+            kind: binary
             fn_name: mul
             operands:
             - !ScalarExpression
@@ -1640,14 +1640,14 @@ structured_op: !LinalgStructuredOpConfig
     arg: O
     value: !ScalarExpression
       scalar_fn:
-        kind: arith
+        kind: binary
         fn_name: add
         operands:
         - !ScalarExpression
           scalar_arg: O
         - !ScalarExpression
           scalar_fn:
-            kind: arith
+            kind: binary
             fn_name: mul
             operands:
             - !ScalarExpression
@@ -1744,19 +1744,19 @@ structured_op: !LinalgStructuredOpConfig
     arg: O
     value: !ScalarExpression
       scalar_fn:
-        kind: arith
+        kind: binary
         fn_name: add
         operands:
         - !ScalarExpression
           scalar_arg: O
         - !ScalarExpression
           scalar_fn:
-            kind: arith
+            kind: binary
             fn_name: mul
             operands:
             - !ScalarExpression
               scalar_fn:
-                kind: arith
+                kind: binary
                 fn_name: sub
                 operands:
                 - !ScalarExpression
@@ -1777,7 +1777,7 @@ structured_op: !LinalgStructuredOpConfig
                       scalar_arg: IZp
             - !ScalarExpression
               scalar_fn:
-                kind: arith
+                kind: binary
                 fn_name: sub
                 operands:
                 - !ScalarExpression
@@ -1864,14 +1864,14 @@ structured_op: !LinalgStructuredOpConfig
     arg: O
     value: !ScalarExpression
       scalar_fn:
-        kind: arith
+        kind: binary
         fn_name: add
         operands:
         - !ScalarExpression
           scalar_arg: O
         - !ScalarExpression
           scalar_fn:
-            kind: arith
+            kind: binary
             fn_name: mul
             operands:
             - !ScalarExpression
@@ -1970,19 +1970,19 @@ structured_op: !LinalgStructuredOpConfig
     arg: O
     value: !ScalarExpression
       scalar_fn:
-        kind: arith
+        kind: binary
         fn_name: add
         operands:
         - !ScalarExpression
           scalar_arg: O
         - !ScalarExpression
           scalar_fn:
-            kind: arith
+            kind: binary
             fn_name: mul
             operands:
             - !ScalarExpression
               scalar_fn:
-                kind: arith
+                kind: binary
                 fn_name: sub
                 operands:
                 - !ScalarExpression
@@ -2003,7 +2003,7 @@ structured_op: !LinalgStructuredOpConfig
                       scalar_arg: IZp
             - !ScalarExpression
               scalar_fn:
-                kind: arith
+                kind: binary
                 fn_name: sub
                 operands:
                 - !ScalarExpression
@@ -2088,7 +2088,7 @@ structured_op: !LinalgStructuredOpConfig
     arg: O
     value: !ScalarExpression
       scalar_fn:
-        kind: arith
+        kind: binary
         fn_name: add
         operands:
         - !ScalarExpression
@@ -2167,7 +2167,7 @@ structured_op: !LinalgStructuredOpConfig
     arg: O
     value: !ScalarExpression
       scalar_fn:
-        kind: arith
+        kind: binary
         fn_name: max
         operands:
         - !ScalarExpression
@@ -2246,7 +2246,7 @@ structured_op: !LinalgStructuredOpConfig
     arg: O
     value: !ScalarExpression
       scalar_fn:
-        kind: arith
+        kind: binary
         fn_name: max_unsigned
         operands:
         - !ScalarExpression
@@ -2325,7 +2325,7 @@ structured_op: !LinalgStructuredOpConfig
     arg: O
     value: !ScalarExpression
       scalar_fn:
-        kind: arith
+        kind: binary
         fn_name: max
         operands:
         - !ScalarExpression
@@ -2404,7 +2404,7 @@ structured_op: !LinalgStructuredOpConfig
     arg: O
     value: !ScalarExpression
       scalar_fn:
-        kind: arith
+        kind: binary
         fn_name: min
         operands:
         - !ScalarExpression
@@ -2483,7 +2483,7 @@ structured_op: !LinalgStructuredOpConfig
     arg: O
     value: !ScalarExpression
       scalar_fn:
-        kind: arith
+        kind: binary
         fn_name: min_unsigned
         operands:
         - !ScalarExpression
@@ -2568,7 +2568,7 @@ structured_op: !LinalgStructuredOpConfig
     arg: O
     value: !ScalarExpression
       scalar_fn:
-        kind: arith
+        kind: binary
         fn_name: add
         operands:
         - !ScalarExpression
@@ -2653,7 +2653,7 @@ structured_op: !LinalgStructuredOpConfig
     arg: O
     value: !ScalarExpression
       scalar_fn:
-        kind: arith
+        kind: binary
         fn_name: max
         operands:
         - !ScalarExpression
@@ -2738,7 +2738,7 @@ structured_op: !LinalgStructuredOpConfig
     arg: O
     value: !ScalarExpression
       scalar_fn:
-        kind: arith
+        kind: binary
         fn_name: min
         operands:
         - !ScalarExpression
@@ -2841,17 +2841,17 @@ structured_op: !LinalgStructuredOpConfig
         operands:
         - !ScalarExpression
           scalar_fn:
-            kind: arith
+            kind: binary
             fn_name: add
             operands:
             - !ScalarExpression
               scalar_fn:
-                kind: arith
+                kind: binary
                 fn_name: mul
                 operands:
                 - !ScalarExpression
                   scalar_fn:
-                    kind: arith
+                    kind: binary
                     fn_name: add
                     operands:
                     - !ScalarExpression
@@ -2870,17 +2870,17 @@ structured_op: !LinalgStructuredOpConfig
                         operands:
                         - !ScalarExpression
                           scalar_fn:
-                            kind: arith
+                            kind: binary
                             fn_name: add
                             operands:
                             - !ScalarExpression
                               scalar_fn:
-                                kind: arith
+                                kind: binary
                                 fn_name: mul
                                 operands:
                                 - !ScalarExpression
                                   scalar_fn:
-                                    kind: arith
+                                    kind: binary
                                     fn_name: add
                                     operands:
                                     - !ScalarExpression
@@ -2893,17 +2893,17 @@ structured_op: !LinalgStructuredOpConfig
                                           scalar_index: 1
                                     - !ScalarExpression
                                       scalar_fn:
-                                        kind: arith
+                                        kind: binary
                                         fn_name: add
                                         operands:
                                         - !ScalarExpression
                                           scalar_fn:
-                                            kind: arith
+                                            kind: binary
                                             fn_name: mul
                                             operands:
                                             - !ScalarExpression
                                               scalar_fn:
-                                                kind: arith
+                                                kind: binary
                                                 fn_name: add
                                                 operands:
                                                 - !ScalarExpression
@@ -2950,12 +2950,12 @@ structured_op: !LinalgStructuredOpConfig
                                   scalar_const: '12345 : i64'
                 - !ScalarExpression
                   scalar_fn:
-                    kind: arith
+                    kind: binary
                     fn_name: mul
                     operands:
                     - !ScalarExpression
                       scalar_fn:
-                        kind: arith
+                        kind: binary
                         fn_name: sub
                         operands:
                         - !ScalarExpression
@@ -3005,12 +3005,12 @@ structured_op: !LinalgStructuredOpConfig
     arg: O
     value: !ScalarExpression
       scalar_fn:
-        kind: arith
+        kind: unary
         fn_name: log
         operands:
         - !ScalarExpression
           scalar_fn:
-            kind: arith
+            kind: binary
             fn_name: add
             operands:
             - !ScalarExpression
@@ -3023,7 +3023,7 @@ structured_op: !LinalgStructuredOpConfig
                   scalar_const: '1.000000e+00 : f64'
             - !ScalarExpression
               scalar_fn:
-                kind: arith
+                kind: unary
                 fn_name: exp
                 operands:
                 - !ScalarExpression

diff  --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index d41d9646da4d0..44179ebe60757 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -147,13 +147,14 @@ static LogicalResult foldMemRefCastInTiledLoopOp(TiledLoopOp op) {
 // Region builder helper.
 // TODO: Move this to a utility library.
 // The public methods on this class are referenced directly from generated code
-// and bind by name to math and type conversion functions in the DSL as:
-//   `arithfn__{fnName}`
-//   `typefn__{fnName}`
+// and bind by name to math functions in the DSL as:
+//   `unary__{fnName}`
+//   `binary__{fnName}`
 // Examples:
-//   `arithfn__add`
-//   `arithfn__mul`
-//   `typefn__cast`
+//   `binary__add`
+//   `binary__mul`
+//   `unary__exp`
+//   `unary__log`
 // The naming convention is intentional in order to match snake-cased DSL names.
 // See mlir-linalg-ods-yaml-gen.cpp for the code that mates to this class.
 //
@@ -241,7 +242,7 @@ class RegionBuilderHelper {
   }
 
   // NOLINTNEXTLINE(*-identifier-naming): externally called.
-  Value arithfn__add(Value lhs, Value rhs) {
+  Value binary__add(Value lhs, Value rhs) {
     OpBuilder builder = getBuilder();
     if (isFloatingPoint(lhs))
       return builder.create<arith::AddFOp>(lhs.getLoc(), lhs, rhs);
@@ -251,7 +252,7 @@ class RegionBuilderHelper {
   }
 
   // NOLINTNEXTLINE(*-identifier-naming): externally called.
-  Value arithfn__exp(Value x) {
+  Value unary__exp(Value x) {
     OpBuilder builder = getBuilder();
     if (isFloatingPoint(x))
       return builder.create<math::ExpOp>(x.getLoc(), x);
@@ -259,7 +260,7 @@ class RegionBuilderHelper {
   }
 
   // NOLINTNEXTLINE(*-identifier-naming): externally called.
-  Value arithfn__log(Value x) {
+  Value unary__log(Value x) {
     OpBuilder builder = getBuilder();
     if (isFloatingPoint(x))
       return builder.create<math::LogOp>(x.getLoc(), x);
@@ -267,7 +268,7 @@ class RegionBuilderHelper {
   }
 
   // NOLINTNEXTLINE(*-identifier-naming): externally called.
-  Value arithfn__sub(Value lhs, Value rhs) {
+  Value binary__sub(Value lhs, Value rhs) {
     OpBuilder builder = getBuilder();
     if (isFloatingPoint(lhs))
       return builder.create<arith::SubFOp>(lhs.getLoc(), lhs, rhs);
@@ -277,7 +278,7 @@ class RegionBuilderHelper {
   }
 
   // NOLINTNEXTLINE(*-identifier-naming): externally called.
-  Value arithfn__mul(Value lhs, Value rhs) {
+  Value binary__mul(Value lhs, Value rhs) {
     OpBuilder builder = getBuilder();
     if (isFloatingPoint(lhs))
       return builder.create<arith::MulFOp>(lhs.getLoc(), lhs, rhs);
@@ -287,7 +288,7 @@ class RegionBuilderHelper {
   }
 
   // NOLINTNEXTLINE(*-identifier-naming): externally called.
-  Value arithfn__max(Value lhs, Value rhs) {
+  Value binary__max(Value lhs, Value rhs) {
     OpBuilder builder = getBuilder();
     if (isFloatingPoint(lhs))
       return builder.create<arith::MaxFOp>(lhs.getLoc(), lhs, rhs);
@@ -297,7 +298,7 @@ class RegionBuilderHelper {
   }
 
   // NOLINTNEXTLINE(*-identifier-naming): externally called.
-  Value arithfn__max_unsigned(Value lhs, Value rhs) {
+  Value binary__max_unsigned(Value lhs, Value rhs) {
     OpBuilder builder = getBuilder();
     if (isFloatingPoint(lhs))
       return builder.create<arith::MaxFOp>(lhs.getLoc(), lhs, rhs);
@@ -307,7 +308,7 @@ class RegionBuilderHelper {
   }
 
   // NOLINTNEXTLINE(*-identifier-naming): externally called.
-  Value arithfn__min(Value lhs, Value rhs) {
+  Value binary__min(Value lhs, Value rhs) {
     OpBuilder builder = getBuilder();
     if (isFloatingPoint(lhs))
       return builder.create<arith::MinFOp>(lhs.getLoc(), lhs, rhs);
@@ -317,7 +318,7 @@ class RegionBuilderHelper {
   }
 
   // NOLINTNEXTLINE(*-identifier-naming): externally called.
-  Value arithfn__min_unsigned(Value lhs, Value rhs) {
+  Value binary__min_unsigned(Value lhs, Value rhs) {
     OpBuilder builder = getBuilder();
     if (isFloatingPoint(lhs))
       return builder.create<arith::MinFOp>(lhs.getLoc(), lhs, rhs);

diff  --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py
index d26aa077096c7..ef2ef30378211 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py
@@ -77,13 +77,13 @@ def visit_scalar_def(expr: "TensorExpression"):
     self.visit_tensor_exprs(visit_scalar_def)
 
   def __add__(self, rhs: "TensorExpression") -> "TensorExpression":
-    return ArithFn.add(self, rhs)
+    return BinaryFn.add(self, rhs)
 
   def __mul__(self, rhs) -> "TensorExpression":
-    return ArithFn.mul(self, rhs)
+    return BinaryFn.mul(self, rhs)
 
   def __sub__(self, rhs) -> "TensorExpression":
-    return ArithFn.sub(self, rhs)
+    return BinaryFn.sub(self, rhs)
 
   def __hash__(self):
     return hash(id(self))
@@ -126,7 +126,7 @@ def _compute_reduce_dims(self, rhs: TensorExpression) -> Set[DimDef]:
     return rhs_dims - lhs_dims
 
   def __iadd__(self, rhs: TensorExpression) -> "TensorReduceFn":
-    return ReduceFnUse(ArithFn.add, *self._compute_reduce_dims(rhs))(rhs)
+    return ReduceFnUse(BinaryFn.add, *self._compute_reduce_dims(rhs))(rhs)
 
   def __repr__(self):
     return (f"{self.operand_def.name}"
@@ -183,8 +183,8 @@ def to_scalar_expression(self) -> ScalarExpression:
                        f"bound to its lhs: {self}")
     full_args = [self.lhs.to_scalar_expression()
                 ] + [arg.to_scalar_expression() for arg in self.args]
-    return ScalarFn(FunctionKind.ARITH, self.reduce_use.arith_fn.fn_name, None,
-                    None, full_args).expr()
+    return ScalarFn(FunctionKind.BINARY, self.reduce_use.binary_fn.fn_name,
+                    None, None, full_args).expr()
 
   def visit_tensor_exprs(self, callback: Callable[["TensorExpression"], None]):
     for arg in self.args:
@@ -242,61 +242,54 @@ def __repr__(self):
 
 
 class FunctionKind(Enum):
-  ARITH = 0
-  TYPE = 1
+  UNARY = 0
+  BINARY = 1
+  TYPE = 2
 
 
-class TypeFnType:
-  """Type conversion function.
+class UnaryFnType:
+  """Unary function.
 
-  A type conversion function takes a target type and a tensor expression and
-  returns the casted tensor expression.
+  A unary function takes one tensor expression and returns the
+  function evaluation result.
   """
 
   def __init__(self, fn_name: str):
     self.fn_name = fn_name
 
-  def __call__(self, type_var: TypeVar, arg: TensorExpression) -> "TensorFn":
-    return TensorFn(FunctionKind.TYPE, self.fn_name, None, type_var, [arg])
+  def __call__(self, exp: TensorExpression) -> "TensorFn":
+    return TensorFn(FunctionKind.UNARY, self.fn_name, None, None, [exp])
 
   def __repr__(self):
     return f"{self.fn_name}"
 
 
-class TypeFn:
-  """Type conversion function namespace.
-
-  As the integer types are signless, signedness is implement by different cast
-  functions that treat integers as signed (`cast`) or unsigned
-  (`cast_unsigned`) values.
-
-  Examples:
-  - cast(I32 -> I64) -> `arith.ExtSIOp`
-  - cast_unsigned(I32 -> I64) -> `arith.ExtUIOp`
-  """
-  cast = TypeFnType("cast")
-  cast_unsigned = TypeFnType("cast_unsigned")
+class UnaryFn:
+  """Unary function namespace."""
+  exp = UnaryFnType("exp")
+  log = UnaryFnType("log")
 
 
-class ArithFnType:
-  """Arithmetic function.
+class BinaryFnType:
+  """Binary function.
 
-  An arithmetic function takes one ore more tensor expressions and returns the
+  A binary function takes two tensor expressions and returns the
   function evaluation result.
   """
 
   def __init__(self, fn_name: str):
     self.fn_name = fn_name
 
-  def __call__(self, *args) -> "TensorFn":
-    return TensorFn(FunctionKind.ARITH, self.fn_name, None, None, args)
+  def __call__(self, arg0: TensorExpression,
+               arg1: TensorExpression) -> "TensorFn":
+    return TensorFn(FunctionKind.BINARY, self.fn_name, None, None, [arg0, arg1])
 
   def __repr__(self):
     return f"{self.fn_name}"
 
 
-class ArithFn:
-  """Arithmetic function namespace.
+class BinaryFn:
+  """Binary function namespace.
 
   As the integer types are signless, signedness is implement by different
   functions that treat integers as signed or unsigned values.
@@ -305,15 +298,45 @@ class ArithFn:
   - max -> `arith.MaxSIOp`
   - max_unsinged -> `arith.MaxUIOp`
   """
-  add = ArithFnType("add")
-  exp = ArithFnType("exp")
-  log = ArithFnType("log")
-  mul = ArithFnType("mul")
-  max = ArithFnType("max")
-  min = ArithFnType("min")
-  sub = ArithFnType("sub")
-  max_unsigned = ArithFnType("max_unsigned")
-  min_unsigned = ArithFnType("min_unsigned")
+  add = BinaryFnType("add")
+  mul = BinaryFnType("mul")
+  max = BinaryFnType("max")
+  min = BinaryFnType("min")
+  sub = BinaryFnType("sub")
+  max_unsigned = BinaryFnType("max_unsigned")
+  min_unsigned = BinaryFnType("min_unsigned")
+
+
+class TypeFnType:
+  """Type conversion function.
+
+  A type conversion function takes a target type and a tensor expression and
+  returns the casted tensor expression.
+  """
+
+  def __init__(self, fn_name: str):
+    self.fn_name = fn_name
+
+  def __call__(self, type_var: TypeVar, arg: TensorExpression) -> "TensorFn":
+    return TensorFn(FunctionKind.TYPE, self.fn_name, None, type_var, [arg])
+
+  def __repr__(self):
+    return f"{self.fn_name}"
+
+
+class TypeFn:
+  """Type conversion function namespace.
+
+  As the integer types are signless, signedness is implement by different cast
+  functions that treat integers as signed (`cast`) or unsigned
+  (`cast_unsigned`) values.
+
+  Examples:
+  - cast(I32 -> I64) -> `arith.ExtSIOp`
+  - cast_unsigned(I32 -> I64) -> `arith.ExtUIOp`
+  """
+  cast = TypeFnType("cast")
+  cast_unsigned = TypeFnType("cast_unsigned")
 
 
 class ReduceFnUse:
@@ -322,43 +345,43 @@ class ReduceFnUse:
   A reduction use specifies the reduction function and dimensions.
   """
 
-  def __init__(self, arith_fn: ArithFnType, *reduce_dims: DimDef):
-    self.arith_fn = arith_fn
+  def __init__(self, binary_fn: BinaryFnType, *reduce_dims: DimDef):
+    self.binary_fn = binary_fn
     self.reduce_dims = reduce_dims
 
-  def __call__(self, *args: TensorExpression):
+  def __call__(self, *args: TensorExpression) -> "TensorReduceFn":
     return TensorReduceFn(self, args)
 
   def __repr__(self):
-    return (f"reduce_{self.arith_fn.fn_name}"
+    return (f"reduce_{self.binary_fn.fn_name}"
             f"({', '.join(repr(d) for d in self.reduce_dims)})")
 
 
 class ReduceFnType:
   """Reduction function.
 
-  An arithmetic function that reduces its RHS into its LHS.
+  A binary function that reduces its RHS into its LHS.
   """
 
-  def __init__(self, arith_fn: ArithFnType):
-    if not isinstance(arith_fn, ArithFnType):
-      raise ValueError(f"Reduce expected a ArithFnType but got {arith_fn}")
-    self.arith_fn = arith_fn
+  def __init__(self, binary_fn: BinaryFnType):
+    if not isinstance(binary_fn, BinaryFnType):
+      raise ValueError(f"Reduce expected a BinaryFnType but got {binary_fn}")
+    self.binary_fn = binary_fn
 
   def __getitem__(self, reduce_dims: Tuple[DimDef]) -> ReduceFnUse:
-    return ReduceFnUse(self.arith_fn, *reduce_dims)
+    return ReduceFnUse(self.binary_fn, *reduce_dims)
 
   def __repr__(self):
-    return (f"reduce_{self.arith_fn.fn_name}")
+    return (f"reduce_{self.binary_fn.fn_name}")
 
 
 class ReduceFn:
-  add = ReduceFnType(ArithFn.add)
-  mul = ReduceFnType(ArithFn.mul)
-  max = ReduceFnType(ArithFn.max)
-  min = ReduceFnType(ArithFn.min)
-  max_unsigned = ReduceFnType(ArithFn.max_unsigned)
-  min_unsigned = ReduceFnType(ArithFn.min_unsigned)
+  add = ReduceFnType(BinaryFn.add)
+  mul = ReduceFnType(BinaryFn.mul)
+  max = ReduceFnType(BinaryFn.max)
+  min = ReduceFnType(BinaryFn.min)
+  max_unsigned = ReduceFnType(BinaryFn.max_unsigned)
+  min_unsigned = ReduceFnType(BinaryFn.min_unsigned)
 
 
 ###############################################################################

diff  --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
index 07050f56fa640..df4ab2249d4d4 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
@@ -270,17 +270,19 @@ def expression(self, expr: ScalarExpression) -> Value:
       dim_attr = IntegerAttr.get(
           IntegerType.get_signless(64), expr.scalar_index.dim)
       return linalg.IndexOp(dim_attr).result
-    elif expr.scalar_fn and expr.scalar_fn.kind == FunctionKind.ARITH:
-      fn = self._get_function(f"_arithfn_{expr.scalar_fn.fn_name}")
+    elif expr.scalar_fn and expr.scalar_fn.kind is not FunctionKind.TYPE:
+      kind = expr.scalar_fn.kind.name.lower()
+      fn = self._get_function(f"_{kind}_{expr.scalar_fn.fn_name}")
       operand_values = [
           self.expression(operand) for operand in expr.scalar_fn.operands
       ]
       return fn(*operand_values)
-    elif expr.scalar_fn and expr.scalar_fn.kind == FunctionKind.TYPE:
+    elif expr.scalar_fn and expr.scalar_fn.kind is FunctionKind.TYPE:
+      kind = expr.scalar_fn.kind.name.lower()
       fn_name = expr.scalar_fn.fn_name
       if expr.scalar_fn.attr_name:
         fn_name = self.type_fn_attr_mapping[expr.scalar_fn.attr_name]
-      fn = self._get_function(f"_typefn_{fn_name}")
+      fn = self._get_function(f"_{kind}_{fn_name}")
       operand_value = self.expression(expr.scalar_fn.operands[0])
       return fn(expr.scalar_fn.type_var.name, operand_value)
     raise NotImplementedError(f"Unimplemented scalar body expression: {expr}")
@@ -356,70 +358,72 @@ def _cast_to_floating_point(self, to_type: Type, operand: Value,
     raise ValueError(f"Unable to cast body expression from {operand_type} to "
                      f"{to_type}")
 
-  def _typefn_cast(self, type_var_name: str, operand: Value) -> Value:
+  def _type_cast(self, type_var_name: str, operand: Value) -> Value:
     return self._cast(type_var_name, operand, False)
 
-  def _typefn_cast_unsigned(self, type_var_name: str, operand: Value) -> Value:
+  def _type_cast_unsigned(self, type_var_name: str, operand: Value) -> Value:
     return self._cast(type_var_name, operand, True)
 
-  def _arithfn_add(self, lhs: Value, rhs: Value) -> Value:
-    if _is_floating_point_type(lhs.type):
-      return arith.AddFOp(lhs, rhs).result
-    if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
-      return arith.AddIOp(lhs, rhs).result
-    raise NotImplementedError("Unsupported 'add' operand: {lhs}")
-
-  def _arithfn_exp(self, x: Value) -> Value:
+  def _unary_exp(self, x: Value) -> Value:
     if _is_floating_point_type(x.type):
       return math.ExpOp(x).result
-    raise NotImplementedError("Unsupported 'exp' operand: {x}")
+    raise NotImplementedError(f"Unsupported 'exp' operand: {x}")
 
-  def _arithfn_log(self, x: Value) -> Value:
+  def _unary_log(self, x: Value) -> Value:
     if _is_floating_point_type(x.type):
       return math.LogOp(x).result
-    raise NotImplementedError("Unsupported 'log' operand: {x}")
+    raise NotImplementedError(f"Unsupported 'log' operand: {x}")
 
-  def _arithfn_sub(self, lhs: Value, rhs: Value) -> Value:
+  def _binary_add(self, lhs: Value, rhs: Value) -> Value:
+    if _is_floating_point_type(lhs.type):
+      return arith.AddFOp(lhs, rhs).result
+    if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
+      return arith.AddIOp(lhs, rhs).result
+    raise NotImplementedError(f"Unsupported 'add' operands: {lhs}, {rhs}")
+
+  def _binary_sub(self, lhs: Value, rhs: Value) -> Value:
     if _is_floating_point_type(lhs.type):
       return arith.SubFOp(lhs, rhs).result
     if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
       return arith.SubIOp(lhs, rhs).result
-    raise NotImplementedError("Unsupported 'sub' operand: {lhs}")
+    raise NotImplementedError(f"Unsupported 'sub' operands: {lhs}, {rhs}")
 
-  def _arithfn_mul(self, lhs: Value, rhs: Value) -> Value:
+  def _binary_mul(self, lhs: Value, rhs: Value) -> Value:
     if _is_floating_point_type(lhs.type):
       return arith.MulFOp(lhs, rhs).result
     if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
       return arith.MulIOp(lhs, rhs).result
-    raise NotImplementedError("Unsupported 'mul' operand: {lhs}")
+    raise NotImplementedError(f"Unsupported 'mul' operands: {lhs}, {rhs}")
 
-  def _arithfn_max(self, lhs: Value, rhs: Value) -> Value:
+  def _binary_max(self, lhs: Value, rhs: Value) -> Value:
     if _is_floating_point_type(lhs.type):
       return arith.MaxFOp(lhs, rhs).result
     if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
       return arith.MaxSIOp(lhs, rhs).result
-    raise NotImplementedError("Unsupported 'max' operand: {lhs}")
+    raise NotImplementedError(f"Unsupported 'max' operands: {lhs}, {rhs}")
 
-  def _arithfn_max_unsigned(self, lhs: Value, rhs: Value) -> Value:
+  def _binary_max_unsigned(self, lhs: Value, rhs: Value) -> Value:
     if _is_floating_point_type(lhs.type):
       return arith.MaxFOp(lhs, rhs).result
     if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
       return arith.MaxUIOp(lhs, rhs).result
-    raise NotImplementedError("Unsupported 'max_unsigned' operand: {lhs}")
+    raise NotImplementedError(
+        f"Unsupported 'max_unsigned' operands: {lhs}, {rhs}")
 
-  def _arithfn_min(self, lhs: Value, rhs: Value) -> Value:
+  def _binary_min(self, lhs: Value, rhs: Value) -> Value:
     if _is_floating_point_type(lhs.type):
       return arith.MinFOp(lhs, rhs).result
     if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
       return arith.MinSIOp(lhs, rhs).result
-    raise NotImplementedError("Unsupported 'min' operand: {lhs}")
+    raise NotImplementedError(f"Unsupported 'min' operands: {lhs}, {rhs}")
 
-  def _arithfn_min_unsigned(self, lhs: Value, rhs: Value) -> Value:
+  def _binary_min_unsigned(self, lhs: Value, rhs: Value) -> Value:
     if _is_floating_point_type(lhs.type):
       return arith.MinFOp(lhs, rhs).result
     if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
       return arith.MinUIOp(lhs, rhs).result
-    raise NotImplementedError("Unsupported 'min_unsigned' operand: {lhs}")
+    raise NotImplementedError(
+        f"Unsupported 'min_unsigned' operands: {lhs}, {rhs}")
 
 
 def _infer_structured_outs(

diff  --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
index db63a0705c62b..340f4db4471bb 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
@@ -677,4 +677,4 @@ def soft_plus_2d(
   """
   domain(D.m, D.n)
   O[D.m, D.n] = \
-      ArithFn.log(TypeFn.cast(U, const(1.0)) + ArithFn.exp(TypeFn.cast(U, I[D.m, D.n])))
+      UnaryFn.log(TypeFn.cast(U, const(1.0)) + UnaryFn.exp(TypeFn.cast(U, I[D.m, D.n])))

diff  --git a/mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco.py b/mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco.py
index ed73ab13dd8d0..7b07821791186 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco.py
+++ b/mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco.py
@@ -1539,12 +1539,12 @@ def _gather_input_accesses_index_vars(
     input_accesses.append(expr)
 
 
-def _op_to_callable(op: _BinaryOp) -> lang.ArithFnType:
+def _op_to_callable(op: _BinaryOp) -> lang.BinaryFnType:
   """Returns the linalg dialect function object for the given operation."""
   op_to_callable = {
-      operator.add: lang.ArithFn.add,
-      operator.sub: lang.ArithFn.sub,
-      operator.mul: lang.ArithFn.mul,
+      operator.add: lang.BinaryFn.add,
+      operator.sub: lang.BinaryFn.sub,
+      operator.mul: lang.BinaryFn.mul,
   }
   return op_to_callable[op]
 

diff  --git a/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml b/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml
index 660637e669a67..dc1e1809eb46b 100644
--- a/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml
+++ b/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml
@@ -40,7 +40,7 @@ structured_op: !LinalgStructuredOpConfig
     arg: O
     value: !ScalarExpression
       scalar_fn:
-        kind: arith
+        kind: binary
         fn_name: add
         operands:
         - !ScalarExpression
@@ -111,7 +111,7 @@ structured_op: !LinalgStructuredOpConfig
 #   IMPL-DAG:  Value [[VAL1:[a-z0-9]+]] = helper.buildTypeFn(castVal, block.getArgument(0).getType(), [[VAL0]]);
 #   IMPL-DAG:  Value [[VAL2:[a-z0-9]+]] = helper.index(1);
 #   IMPL-DAG:  Value [[VAL3:[a-z0-9]+]] = helper.buildTypeFn(castVal, block.getArgument(0).getType(), [[VAL2]]);
-#   IMPL-DAG:  Value [[VAL4:[a-z0-9]+]] = helper.arithfn__add([[VAL1]], [[VAL3]]);
+#   IMPL-DAG:  Value [[VAL4:[a-z0-9]+]] = helper.binary__add([[VAL1]], [[VAL3]]);
 
 
 # @linalg_structured_op
@@ -254,3 +254,58 @@ structured_op: !LinalgStructuredOpConfig
 #  IMPL-NEXT:    MLIRContext *context = getContext();
 #  IMPL-NEXT:    AffineMap scalarMap = AffineMap::get(getNumParallelLoops(), 0, context);
 #  IMPL-NEXT:    AffineMap tensorMap = AffineMap::getMultiDimIdentityMap(
+
+
+# @linalg_structured_op
+# def test4(O=TensorDef(T, S.M, S.N, output=True)):
+#   """Title.
+
+#   Detailed description.
+#   """
+#   O[D.m, D.n] = BinaryFn.add(UnaryFn.exp(O[D.m, D.n]), O[D.m, D.n])
+
+--- !LinalgOpConfig
+metadata: !LinalgOpMetadata
+  name: test4
+  cpp_class_name: Test4Op
+  doc: |-
+    Title.
+
+    Detailed description.
+structured_op: !LinalgStructuredOpConfig
+  args:
+  - !LinalgOperandDefConfig
+    name: O
+    kind: output_tensor
+    type_var: T
+    shape_map: affine_map<()[s0, s1] -> (s0, s1)>
+  indexing_maps: !LinalgIndexingMapsConfig
+    static_indexing_maps:
+    - affine_map<(d0, d1)[s0, s1] -> (d0, d1)>
+  iterator_types:
+  - parallel
+  - parallel
+  assignments:
+  - !ScalarAssign
+    arg: O
+    value: !ScalarExpression
+      scalar_fn:
+        kind: binary
+        fn_name: add
+        operands:
+        - !ScalarExpression
+          scalar_fn:
+            kind: unary
+            fn_name: exp
+            operands:
+            - !ScalarExpression
+              scalar_arg: O
+        - !ScalarExpression
+          scalar_arg: O
+
+# IMPL-LABEL:  void Test4Op::regionBuilder(ImplicitLocOpBuilder &b,
+#  IMPL-NEXT:    Block &block, ArrayRef<NamedAttribute> attrs)
+
+#       IMPL:  Value [[VAL0:[a-z0-9]+]] = helper.unary__exp(block.getArgument(0))
+#  IMPL-NEXT:  Value [[VAL1:[a-z0-9]+]] = helper.binary__add([[VAL0]], block.getArgument(0))
+#  IMPL-NEXT:  yields.push_back([[VAL1]])

diff  --git a/mlir/test/python/dialects/linalg/opdsl/assignments.py b/mlir/test/python/dialects/linalg/opdsl/assignments.py
index 5b87216ca4372..f93e0704a1e36 100644
--- a/mlir/test/python/dialects/linalg/opdsl/assignments.py
+++ b/mlir/test/python/dialects/linalg/opdsl/assignments.py
@@ -42,18 +42,22 @@ def matmul(
 # CHECK:  -
 # CHECK:    arg: O
 # CHECK:      scalar_fn:
-# CHECK:        kind: arith
+# CHECK:        kind: binary
 # CHECK:        fn_name: sub
 # CHECK:        operands:
 # CHECK:          scalar_fn:
-# CHECK:            kind: arith
+# CHECK:            kind: binary
 # CHECK:            fn_name: add
 # CHECK:            operands:
 # CHECK:              scalar_fn:
-# CHECK:                kind: type
-# CHECK:                type_var: T
+# CHECK:                kind: unary
+# CHECK:                fn_name: exp
 # CHECK:                operands:
-# CHECK:                  scalar_const: '3.1415926535897931 : f64'
+# CHECK:                  scalar_fn:
+# CHECK:                    kind: type
+# CHECK:                    type_var: T
+# CHECK:                    operands:
+# CHECK:                      scalar_const: '3.1415926535897931 : f64'
 # CHECK:              scalar_fn:
 # CHECK:                kind: type
 # CHECK:                fn_name: cast
@@ -71,8 +75,7 @@ def constants(O=TensorDef(T, S.M, S.K, output=True)):
   pi = TypeFn.cast(T, const(3.1415926535897931))
   cst42 = TypeFn.cast(T, const(42))
   cst1000 = TypeFn.cast(T, const(1e+3))
-  O[D.m, D.n] = pi + cst42 - cst1000
-
+  O[D.m, D.n] = UnaryFn.exp(pi) + cst42 - cst1000
 
 # CHECK: ---
 # CHECK-LABEL: indices
@@ -80,7 +83,7 @@ def constants(O=TensorDef(T, S.M, S.K, output=True)):
 # CHECK:  -
 # CHECK:    arg: O
 # CHECK:      scalar_fn:
-# CHECK:        kind: arith
+# CHECK:        kind: binary
 # CHECK:        fn_name: add
 # CHECK:        operands:
 # CHECK:          scalar_index: 1

diff  --git a/mlir/test/python/dialects/linalg/opdsl/emit_misc.py b/mlir/test/python/dialects/linalg/opdsl/emit_misc.py
index 9f4872e33e8d3..9eab9b2a8e6ac 100644
--- a/mlir/test/python/dialects/linalg/opdsl/emit_misc.py
+++ b/mlir/test/python/dialects/linalg/opdsl/emit_misc.py
@@ -35,8 +35,8 @@ def fill_rng_poly(
 @linalg_structured_op
 def soft_plus_poly(
     I=TensorDef(T, S.M, S.N), O=TensorDef(U, S.M, S.N, output=True)):
-  O[D.m, D.n] = ArithFn.log(
-      TypeFn.cast(U, const(1.0)) + TypeFn.cast(U, ArithFn.exp(I[D.m, D.n])))
+  O[D.m, D.n] = UnaryFn.log(
+      TypeFn.cast(U, const(1.0)) + TypeFn.cast(U, UnaryFn.exp(I[D.m, D.n])))
 
 
 @linalg_structured_op(op_name="custom_op_name")

diff  --git a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
index d1fc9ac944942..7685d1a53e313 100644
--- a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
+++ b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
@@ -90,7 +90,7 @@ struct LinalgIndexingMapsConfig {
 
 struct ScalarExpression;
 
-enum class ScalarFnKind { Arith, Type };
+enum class ScalarFnKind { Unary, Binary, Type };
 
 struct ScalarFn {
   ScalarFnKind kind;
@@ -275,7 +275,8 @@ struct MappingTraits<ScalarExpression> {
 template <>
 struct ScalarEnumerationTraits<ScalarFnKind> {
   static void enumeration(IO &io, ScalarFnKind &value) {
-    io.enumCase(value, "arith", ScalarFnKind::Arith);
+    io.enumCase(value, "unary", ScalarFnKind::Unary);
+    io.enumCase(value, "binary", ScalarFnKind::Binary);
     io.enumCase(value, "type", ScalarFnKind::Type);
   }
 };
@@ -1056,7 +1057,7 @@ if ({0}Iter != attrs.end()) {{
           return cppIdent;
         }
         if (expression.scalarFn &&
-            expression.scalarFn->kind == ScalarFnKind::Arith) {
+            expression.scalarFn->kind != ScalarFnKind::Type) {
           // Apply function.
           // Recursively generate operands.
           SmallVector<std::string> operandCppValues;
@@ -1066,10 +1067,14 @@ if ({0}Iter != attrs.end()) {{
               return None;
             operandCppValues.push_back(*operandCppValue);
           }
+
+          std::string prefix = expression.scalarFn->kind == ScalarFnKind::Unary
+                                   ? "unary"
+                                   : "binary";
           std::string cppIdent = llvm::formatv("value{0}", ++localCounter);
           stmts.push_back(
-              llvm::formatv("Value {0} = helper.arithfn__{1}({2});", cppIdent,
-                            expression.scalarFn->fnName,
+              llvm::formatv("Value {0} = helper.{1}__{2}({3});", cppIdent,
+                            prefix, expression.scalarFn->fnName,
                             interleaveToString(operandCppValues, ", ")));
           return cppIdent;
         }


        


More information about the Mlir-commits mailing list