[Mlir-commits] [mlir] 24357fe - [mlir][OpDSL] Add arithmetic function attributes.

llvmlistbot at llvm.org
Mon Feb 28 23:46:15 PST 2022


Author: gysit
Date: 2022-03-01T07:45:47Z
New Revision: 24357fec8d706b1ef3b0a34432f7e10a8e6b2151

URL: https://github.com/llvm/llvm-project/commit/24357fec8d706b1ef3b0a34432f7e10a8e6b2151
DIFF: https://github.com/llvm/llvm-project/commit/24357fec8d706b1ef3b0a34432f7e10a8e6b2151.diff

LOG: [mlir][OpDSL] Add arithmetic function attributes.

The revision extends OpDSL with unary and binary function attributes. A function attribute makes the operations used in the body of a structured operation configurable. For example, a pooling operation may take an aggregation function attribute that specifies whether the op implements min or max pooling. The goal of this revision is to define fewer and more flexible operations.

We may thus, for example, define an elementwise op:
```
linalg.elem(lhs, rhs, outs=[out], op=BinaryFn.mul)
```
If the op argument is not set, the default operation is used.
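
As a rough illustration of the default-versus-override behavior described above, here is a minimal pure-Python sketch. It does not use the MLIR Python bindings: the `BinaryFn` enum, the `_IMPLS` table, and the list-based `elemwise_binary` function are hypothetical stand-ins that only mimic how a function attribute selects the operation applied in the body.

```python
from enum import Enum
from typing import Callable, Dict, List


class BinaryFn(Enum):
    """Hypothetical stand-in for OpDSL's BinaryFn (only three functions modeled)."""
    add = "add"
    mul = "mul"
    max = "max"


# Assumed mapping from attribute value to a scalar implementation.
_IMPLS: Dict[BinaryFn, Callable[[float, float], float]] = {
    BinaryFn.add: lambda a, b: a + b,
    BinaryFn.mul: lambda a, b: a * b,
    BinaryFn.max: lambda a, b: a if a >= b else b,
}


def elemwise_binary(lhs: List[float], rhs: List[float],
                    fun: BinaryFn = BinaryFn.add) -> List[float]:
    """Applies `fun` elementwise; `fun` falls back to its default when unset."""
    if not isinstance(fun, BinaryFn):
        raise ValueError(f"fun needs to be of type BinaryFn but got {type(fun)}")
    if len(lhs) != len(rhs):
        raise ValueError("lhs and rhs must have the same length")
    impl = _IMPLS[fun]
    return [impl(a, b) for a, b in zip(lhs, rhs)]


print(elemwise_binary([1, 2, 3], [4, 5, 6]))                    # [5, 7, 9]
print(elemwise_binary([1, 2, 3], [4, 5, 6], fun=BinaryFn.mul))  # [4, 10, 18]
```

The first call leaves the attribute unset and therefore uses the default (`add`), while the second overrides it with `mul`. A single definition thus covers several elementwise operations.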

Depends On D120109

Reviewed By: nicolasvasilache, aartbik

Differential Revision: https://reviews.llvm.org/D120110

Added: 
    

Modified: 
    mlir/docs/Dialects/Linalg/OpDSL.md
    mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td
    mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
    mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
    mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py
    mlir/python/mlir/dialects/linalg/opdsl/lang/config.py
    mlir/python/mlir/dialects/linalg/opdsl/lang/dsl.py
    mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
    mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
    mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir
    mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml
    mlir/test/python/dialects/linalg/opdsl/arguments.py
    mlir/test/python/dialects/linalg/opdsl/assignments.py
    mlir/test/python/dialects/linalg/opdsl/emit_pooling.py
    mlir/test/python/dialects/linalg/ops.py
    mlir/test/python/integration/dialects/linalg/opsrun.py
    mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp

Removed: 
    


################################################################################
diff --git a/mlir/docs/Dialects/Linalg/OpDSL.md b/mlir/docs/Dialects/Linalg/OpDSL.md
index 116057ef4ad8a..dd068b1f400c1 100644
--- a/mlir/docs/Dialects/Linalg/OpDSL.md
+++ b/mlir/docs/Dialects/Linalg/OpDSL.md
@@ -107,12 +107,12 @@ copy_and_scale(val, in_tensor, outs=[out_tensor])
 
 ## Index Attributes
 
-Attributes are compile-time constant parameters only accessible in index
+Index attributes are compile-time constant parameters only accessible in index
 expressions. They can be used to parameterize the access pattern of a structured
 operation, for example, by setting its strides. They cannot take part in the
 actual computation.
 
-The following example demonstrates the use of attributes:
+The following example demonstrates the use of index attributes:
 
 ```python
 @linalg_structured_op
@@ -136,9 +136,9 @@ The `strides` vector elements substitute the symbols `S.SH` and `S.SW` in the
 index expressions of the operation instance. If no strides are provided the
 `default` vector elements are used instead.
 
-Attributes are currently limited to integer vectors and only accessible in index
-expressions. An operation may have multiple attributes all of them placed at the
-end of the parameter list after the output tensors.
+Index attributes are currently limited to integer vectors and only accessible in
+index expressions. An operation may have multiple attributes all of them placed
+at the end of the parameter list after the output tensors.
 
 ## Shape-Only Tensors
 
@@ -220,6 +220,43 @@ There are also special forms:
 *   `const(value)` returns a constant value.
 *   `index(dim)` returns the iteration index in the given dimension `dim`.
 
+## Function Attributes
+
+Function attributes are compile-time constant function parameters. They can be
+used to parameterize the computation performed by a structured operation, for
+example, to support signed and unsigned computations.
+
+The following example demonstrates the use of function attributes:
+
+```python
+@linalg_structured_op
+def elemwise_binary(
+    lhs=TensorDef(T1),
+    rhs=TensorDef(T2),
+    O=TensorDef(U, output=True),
+    fun=BinaryFnAttrDef(default=BinaryFn.add),
+    cast=TypeFnAttrDef(default=TypeFn.cast)):
+  O[None] = fun(cast(U, lhs[None]), cast(U, rhs[None]))
+```
+
+The `fun` and `cast` function attributes by default are aliases for their
+default values `BinaryFn.add` and `TypeFn.cast`, respectively. When
+instantiating the operation, the function attributes may be set to other
+functions using optional named arguments:
+
+```python
+elemwise_binary(lhs, rhs, outs=[out_tensor],
+                fun=BinaryFn.mul, cast=TypeFn.cast_unsigned)
+```
+
+In the example, the `fun` and `cast` arguments adapt the body of the operation
+to implement multiplication and unsigned casts instead of addition and signed
+casts.
+
+OpDSL supports unary, binary, and type conversion function attributes. An
+operation can take multiple attributes of different kinds placed at the end of
+the parameter list.
+
 ## Types
 
 All types in assignment expressions are late bound based on actual input and

diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td
index 6090c7055ef3e..c60e1b646e533 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td
@@ -58,7 +58,26 @@ def Linalg_Dialect : Dialect {
   }];
 }
 
-// Define a TypeFn enum matching the OpDSL TypeFn class.
+// Define the function attribute enums matching the OpDSL functions.
+def UnaryFn : I32EnumAttr<"UnaryFn", "", [
+  I32EnumAttrCase<"exp", 0>,
+  I32EnumAttrCase<"log", 1>
+]> {
+  let genSpecializedAttr = 0;
+  let cppNamespace = "::mlir::linalg";
+}
+def BinaryFn : I32EnumAttr<"BinaryFn", "", [
+  I32EnumAttrCase<"add", 0>,
+  I32EnumAttrCase<"mul", 1>,
+  I32EnumAttrCase<"max", 2>,
+  I32EnumAttrCase<"min", 3>,
+  I32EnumAttrCase<"sub", 4>,
+  I32EnumAttrCase<"max_unsigned", 5>,
+  I32EnumAttrCase<"min_unsigned", 6>
+]> {
+  let genSpecializedAttr = 0;
+  let cppNamespace = "::mlir::linalg";
+}
 def TypeFn : I32EnumAttr<"TypeFn", "", [
   I32EnumAttrCase<"cast", 0>,
   I32EnumAttrCase<"cast_unsigned", 1>
@@ -67,6 +86,12 @@ def TypeFn : I32EnumAttr<"TypeFn", "", [
   let cppNamespace = "::mlir::linalg";
 }
 
+def UnaryFnAttr : EnumAttr<Linalg_Dialect, UnaryFn, "unary_fn"> {
+  let assemblyFormat = "`<` $value `>`";
+}
+def BinaryFnAttr : EnumAttr<Linalg_Dialect, BinaryFn, "binary_fn"> {
+  let assemblyFormat = "`<` $value `>`";
+}
 def TypeFnAttr : EnumAttr<Linalg_Dialect, TypeFn, "type_fn"> {
   let assemblyFormat = "`<` $value `>`";
 }

diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
index 8789185b961fa..73fd114daf32e 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
@@ -1,6 +1,120 @@
 ### AUTOGENERATED from core_named_ops.py
 ### To regenerate, run: bin/update_core_linalg_named_ops.sh
 --- !LinalgOpConfig
+metadata: !LinalgOpMetadata
+  name: elemwise_unary
+  cpp_class_name: ElemwiseUnaryOp
+  doc: |-
+    Applies the unary function fun elementwise.
+
+    Numeric casting is performed on the input operand, promoting it to the same
+    data type as the accumulator/output.
+structured_op: !LinalgStructuredOpConfig
+  args:
+  - !LinalgOperandDefConfig
+    name: I
+    kind: input_tensor
+    type_var: T1
+    shape_map: affine_map<() -> ()>
+  - !LinalgOperandDefConfig
+    name: O
+    kind: output_tensor
+    type_var: U
+    shape_map: affine_map<() -> ()>
+  - !LinalgOperandDefConfig
+    name: fun
+    kind: unary_fn_attr
+    default_fn: exp
+  - !LinalgOperandDefConfig
+    name: cast
+    kind: type_fn_attr
+    default_fn: cast
+  indexing_maps: !LinalgIndexingMapsConfig
+    static_indexing_maps:
+    - affine_map<() -> ()>
+    - affine_map<() -> ()>
+  iterator_types: []
+  assignments:
+  - !ScalarAssign
+    arg: O
+    value: !ScalarExpression
+      scalar_fn:
+        kind: unary
+        attr_name: fun
+        operands:
+        - !ScalarExpression
+          scalar_fn:
+            kind: type
+            attr_name: cast
+            type_var: U
+            operands:
+            - !ScalarExpression
+              scalar_arg: I
+--- !LinalgOpConfig
+metadata: !LinalgOpMetadata
+  name: elemwise_binary
+  cpp_class_name: ElemwiseBinaryOp
+  doc: |-
+    Applies the binary function fun elementwise.
+
+    Numeric casting is performed on the input operand, promoting it to the same
+    data type as the accumulator/output.
+structured_op: !LinalgStructuredOpConfig
+  args:
+  - !LinalgOperandDefConfig
+    name: lhs
+    kind: input_tensor
+    type_var: T1
+    shape_map: affine_map<() -> ()>
+  - !LinalgOperandDefConfig
+    name: rhs
+    kind: input_tensor
+    type_var: T2
+    shape_map: affine_map<() -> ()>
+  - !LinalgOperandDefConfig
+    name: O
+    kind: output_tensor
+    type_var: U
+    shape_map: affine_map<() -> ()>
+  - !LinalgOperandDefConfig
+    name: fun
+    kind: binary_fn_attr
+    default_fn: add
+  - !LinalgOperandDefConfig
+    name: cast
+    kind: type_fn_attr
+    default_fn: cast
+  indexing_maps: !LinalgIndexingMapsConfig
+    static_indexing_maps:
+    - affine_map<() -> ()>
+    - affine_map<() -> ()>
+    - affine_map<() -> ()>
+  iterator_types: []
+  assignments:
+  - !ScalarAssign
+    arg: O
+    value: !ScalarExpression
+      scalar_fn:
+        kind: binary
+        attr_name: fun
+        operands:
+        - !ScalarExpression
+          scalar_fn:
+            kind: type
+            attr_name: cast
+            type_var: U
+            operands:
+            - !ScalarExpression
+              scalar_arg: lhs
+        - !ScalarExpression
+          scalar_fn:
+            kind: type
+            attr_name: cast
+            type_var: U
+            operands:
+            - !ScalarExpression
+              scalar_arg: rhs
+--- !LinalgOpConfig
 metadata: !LinalgOpMetadata
   name: matmul
   cpp_class_name: MatmulOp

diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 24f492aed936a..dfc994368e6cf 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -108,17 +108,9 @@ static LogicalResult foldMemRefCast(Operation *op) {
 //===----------------------------------------------------------------------===//
 // Region builder helper.
 // TODO: Move this to a utility library.
-// The public methods on this class are referenced directly from generated code
-// and bind by name to math functions in the DSL as:
-//   `unary__{fnName}`
-//   `binary__{fnName}`
-// Examples:
-//   `binary__add`
-//   `binary__mul`
-//   `unary__exp`
-//   `unary__log`
-// The naming convention is intentional in order to match snake-cased DSL names.
-// See mlir-linalg-ods-yaml-gen.cpp for the code that mates to this class.
+// The public methods on this class are referenced directly from generated code.
+// Helper build the unary, binary, and type conversion functions defined by the
+// DSL. See mlir-linalg-ods-yaml-gen.cpp for the code that uses this class.
 //
 // Implementations of the math functions must be polymorphic over numeric types,
 // internally performing necessary casts. If the function application makes no
@@ -142,6 +134,98 @@ class RegionBuilderHelper {
   RegionBuilderHelper(MLIRContext *context, Block &block)
       : context(context), block(block) {}
 
+  // Build the unary functions defined by OpDSL.
+  Value buildUnaryFn(UnaryFn unaryFn, Value arg) {
+    if (!isFloatingPoint(arg))
+      llvm_unreachable("unsupported non numeric type");
+    OpBuilder builder = getBuilder();
+    switch (unaryFn) {
+    case UnaryFn::exp:
+      return builder.create<math::ExpOp>(arg.getLoc(), arg);
+    case UnaryFn::log:
+      return builder.create<math::LogOp>(arg.getLoc(), arg);
+    }
+    llvm_unreachable("unsupported unary function");
+  }
+
+  // Build the binary functions defined by OpDSL.
+  Value buildBinaryFn(BinaryFn binaryFn, Value arg0, Value arg1) {
+    bool allFloatingPoint = isFloatingPoint(arg0) && isFloatingPoint(arg1);
+    bool allInteger = isInteger(arg0) && isInteger(arg1);
+    if (!allFloatingPoint && !allInteger)
+      llvm_unreachable("unsupported non numeric type");
+    OpBuilder builder = getBuilder();
+    switch (binaryFn) {
+    case BinaryFn::add:
+      if (allFloatingPoint)
+        return builder.create<arith::AddFOp>(arg0.getLoc(), arg0, arg1);
+      return builder.create<arith::AddIOp>(arg0.getLoc(), arg0, arg1);
+    case BinaryFn::mul:
+      if (allFloatingPoint)
+        return builder.create<arith::MulFOp>(arg0.getLoc(), arg0, arg1);
+      return builder.create<arith::MulIOp>(arg0.getLoc(), arg0, arg1);
+    case BinaryFn::max:
+      if (allFloatingPoint)
+        return builder.create<arith::MaxFOp>(arg0.getLoc(), arg0, arg1);
+      return builder.create<arith::MaxSIOp>(arg0.getLoc(), arg0, arg1);
+    case BinaryFn::min:
+      if (allFloatingPoint)
+        return builder.create<arith::MinFOp>(arg0.getLoc(), arg0, arg1);
+      return builder.create<arith::MinSIOp>(arg0.getLoc(), arg0, arg1);
+    case BinaryFn::sub:
+      if (allFloatingPoint)
+        return builder.create<arith::SubFOp>(arg0.getLoc(), arg0, arg1);
+      return builder.create<arith::SubIOp>(arg0.getLoc(), arg0, arg1);
+    case BinaryFn::max_unsigned:
+      if (allFloatingPoint)
+        return builder.create<arith::MaxFOp>(arg0.getLoc(), arg0, arg1);
+      return builder.create<arith::MaxUIOp>(arg0.getLoc(), arg0, arg1);
+    case BinaryFn::min_unsigned:
+      if (allFloatingPoint)
+        return builder.create<arith::MinFOp>(arg0.getLoc(), arg0, arg1);
+      return builder.create<arith::MinUIOp>(arg0.getLoc(), arg0, arg1);
+    }
+    llvm_unreachable("unsupported binary function");
+  }
+
+  // Build the type functions defined by OpDSL.
+  Value buildTypeFn(TypeFn typeFn, Type toType, Value operand) {
+    switch (typeFn) {
+    case TypeFn::cast:
+      return cast(toType, operand, false);
+    case TypeFn::cast_unsigned:
+      return cast(toType, operand, true);
+    }
+    llvm_unreachable("unsupported type conversion function");
+  }
+
+  void yieldOutputs(ValueRange values) {
+    OpBuilder builder = getBuilder();
+    Location loc = builder.getUnknownLoc();
+    builder.create<YieldOp>(loc, values);
+  }
+
+  Value constant(const std::string &value) {
+    OpBuilder builder = getBuilder();
+    Location loc = builder.getUnknownLoc();
+    Attribute valueAttr = parseAttribute(value, builder.getContext());
+    return builder.create<arith::ConstantOp>(loc, valueAttr.getType(),
+                                             valueAttr);
+  }
+
+  Value index(int64_t dim) {
+    OpBuilder builder = getBuilder();
+    return builder.create<IndexOp>(builder.getUnknownLoc(), dim);
+  }
+
+  Type getIntegerType(unsigned width) {
+    return IntegerType::get(context, width);
+  }
+
+  Type getFloat32Type() { return Float32Type::get(context); }
+  Type getFloat64Type() { return Float64Type::get(context); }
+
+private:
   // Generates operations to cast the given operand to a specified type.
   // If the cast cannot be performed, a warning will be issued and the
   // operand returned as-is (which will presumably yield a verification
@@ -193,136 +277,6 @@ class RegionBuilderHelper {
     return operand;
   }
 
-  Value buildTypeFn(TypeFn typeFn, Type toType, Value operand) {
-    switch (typeFn) {
-    case TypeFn::cast:
-      return cast(toType, operand, false);
-    case TypeFn::cast_unsigned:
-      return cast(toType, operand, true);
-    }
-    llvm_unreachable("unsupported type conversion function");
-  }
-
-  // NOLINTNEXTLINE(*-identifier-naming): externally called.
-  Value binary__add(Value lhs, Value rhs) {
-    OpBuilder builder = getBuilder();
-    if (isFloatingPoint(lhs))
-      return builder.create<arith::AddFOp>(lhs.getLoc(), lhs, rhs);
-    if (isInteger(lhs))
-      return builder.create<arith::AddIOp>(lhs.getLoc(), lhs, rhs);
-    llvm_unreachable("unsupported non numeric type");
-  }
-
-  // NOLINTNEXTLINE(*-identifier-naming): externally called.
-  Value unary__exp(Value x) {
-    OpBuilder builder = getBuilder();
-    if (isFloatingPoint(x))
-      return builder.create<math::ExpOp>(x.getLoc(), x);
-    llvm_unreachable("unsupported non numeric type");
-  }
-
-  // NOLINTNEXTLINE(*-identifier-naming): externally called.
-  Value unary__log(Value x) {
-    OpBuilder builder = getBuilder();
-    if (isFloatingPoint(x))
-      return builder.create<math::LogOp>(x.getLoc(), x);
-    llvm_unreachable("unsupported non numeric type");
-  }
-
-  // NOLINTNEXTLINE(*-identifier-naming): externally called.
-  Value binary__sub(Value lhs, Value rhs) {
-    OpBuilder builder = getBuilder();
-    if (isFloatingPoint(lhs))
-      return builder.create<arith::SubFOp>(lhs.getLoc(), lhs, rhs);
-    if (isInteger(lhs))
-      return builder.create<arith::SubIOp>(lhs.getLoc(), lhs, rhs);
-    llvm_unreachable("unsupported non numeric type");
-  }
-
-  // NOLINTNEXTLINE(*-identifier-naming): externally called.
-  Value binary__mul(Value lhs, Value rhs) {
-    OpBuilder builder = getBuilder();
-    if (isFloatingPoint(lhs))
-      return builder.create<arith::MulFOp>(lhs.getLoc(), lhs, rhs);
-    if (isInteger(lhs))
-      return builder.create<arith::MulIOp>(lhs.getLoc(), lhs, rhs);
-    llvm_unreachable("unsupported non numeric type");
-  }
-
-  // NOLINTNEXTLINE(*-identifier-naming): externally called.
-  Value binary__max(Value lhs, Value rhs) {
-    OpBuilder builder = getBuilder();
-    if (isFloatingPoint(lhs))
-      return builder.create<arith::MaxFOp>(lhs.getLoc(), lhs, rhs);
-    if (isInteger(lhs))
-      return builder.create<arith::MaxSIOp>(lhs.getLoc(), lhs, rhs);
-    llvm_unreachable("unsupported non numeric type");
-  }
-
-  // NOLINTNEXTLINE(*-identifier-naming): externally called.
-  Value binary__max_unsigned(Value lhs, Value rhs) {
-    OpBuilder builder = getBuilder();
-    if (isFloatingPoint(lhs))
-      return builder.create<arith::MaxFOp>(lhs.getLoc(), lhs, rhs);
-    if (isInteger(lhs))
-      return builder.create<arith::MaxUIOp>(lhs.getLoc(), lhs, rhs);
-    llvm_unreachable("unsupported non numeric type");
-  }
-
-  // NOLINTNEXTLINE(*-identifier-naming): externally called.
-  Value binary__min(Value lhs, Value rhs) {
-    OpBuilder builder = getBuilder();
-    if (isFloatingPoint(lhs))
-      return builder.create<arith::MinFOp>(lhs.getLoc(), lhs, rhs);
-    if (isInteger(lhs))
-      return builder.create<arith::MinSIOp>(lhs.getLoc(), lhs, rhs);
-    llvm_unreachable("unsupported non numeric type");
-  }
-
-  // NOLINTNEXTLINE(*-identifier-naming): externally called.
-  Value binary__min_unsigned(Value lhs, Value rhs) {
-    OpBuilder builder = getBuilder();
-    if (isFloatingPoint(lhs))
-      return builder.create<arith::MinFOp>(lhs.getLoc(), lhs, rhs);
-    if (isInteger(lhs))
-      return builder.create<arith::MinUIOp>(lhs.getLoc(), lhs, rhs);
-    llvm_unreachable("unsupported non numeric type");
-  }
-
-  void yieldOutputs(ValueRange values) {
-    assert(!values.empty() && "linalg ops must yield outputs");
-    if (values.empty())
-      return;
-    Value first = values.front();
-    OpBuilder builder = getBuilder();
-    builder.create<YieldOp>(first.getLoc(), values);
-  }
-
-  Value constant(const std::string &value) {
-    OpBuilder builder = getBuilder();
-    Location loc = builder.getUnknownLoc();
-    Attribute valueAttr = parseAttribute(value, builder.getContext());
-    return builder.create<arith::ConstantOp>(loc, valueAttr.getType(),
-                                             valueAttr);
-  }
-
-  Value index(int64_t dim) {
-    OpBuilder builder = getBuilder();
-    return builder.create<IndexOp>(builder.getUnknownLoc(), dim);
-  }
-
-  Type getIntegerType(unsigned width) {
-    return IntegerType::get(context, width);
-  }
-
-  Type getFloat32Type() { return Float32Type::get(context); }
-
-  Type getFloat64Type() { return Float64Type::get(context); }
-
-private:
-  MLIRContext *context;
-  Block &block;
-
   bool isFloatingPoint(Value value) { return value.getType().isa<FloatType>(); }
   bool isInteger(Value value) { return value.getType().isa<IntegerType>(); }
 
@@ -331,6 +285,9 @@ class RegionBuilderHelper {
     builder.setInsertionPointToEnd(&block);
     return builder;
   }
+
+  MLIRContext *context;
+  Block &block;
 };
 
 } // namespace

diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py
index ef2ef30378211..f6bf0ff9a50d0 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py
@@ -126,7 +126,7 @@ def _compute_reduce_dims(self, rhs: TensorExpression) -> Set[DimDef]:
     return rhs_dims - lhs_dims
 
   def __iadd__(self, rhs: TensorExpression) -> "TensorReduceFn":
-    return ReduceFnUse(BinaryFn.add, *self._compute_reduce_dims(rhs))(rhs)
+    return ReduceFnUse(BinaryFn.add, None, *self._compute_reduce_dims(rhs))(rhs)
 
   def __repr__(self):
     return (f"{self.operand_def.name}"
@@ -183,8 +183,14 @@ def to_scalar_expression(self) -> ScalarExpression:
                        f"bound to its lhs: {self}")
     full_args = [self.lhs.to_scalar_expression()
                 ] + [arg.to_scalar_expression() for arg in self.args]
-    return ScalarFn(FunctionKind.BINARY, self.reduce_use.binary_fn.fn_name,
-                    None, None, full_args).expr()
+    fn_name = None
+    attr_name = None
+    if self.reduce_use.binary_fn:
+      fn_name = self.reduce_use.binary_fn.fn_name
+    if self.reduce_use.binary_attr:
+      attr_name = self.reduce_use.binary_attr.operand_def.name
+    return ScalarFn(FunctionKind.BINARY, fn_name, attr_name, None,
+                    full_args).expr()
 
   def visit_tensor_exprs(self, callback: Callable[["TensorExpression"], None]):
     for arg in self.args:
@@ -257,8 +263,8 @@ class UnaryFnType:
   def __init__(self, fn_name: str):
     self.fn_name = fn_name
 
-  def __call__(self, exp: TensorExpression) -> "TensorFn":
-    return TensorFn(FunctionKind.UNARY, self.fn_name, None, None, [exp])
+  def __call__(self, arg: TensorExpression) -> "TensorFn":
+    return TensorFn(FunctionKind.UNARY, self.fn_name, None, None, [arg])
 
   def __repr__(self):
     return f"{self.fn_name}"
@@ -345,16 +351,21 @@ class ReduceFnUse:
   A reduction use specifies the reduction function and dimensions.
   """
 
-  def __init__(self, binary_fn: BinaryFnType, *reduce_dims: DimDef):
+  def __init__(self, binary_fn: Optional[BinaryFnType],
+               binary_attr: Optional["BinaryFnAttrDef"], *reduce_dims: DimDef):
+    if bool(binary_fn) + bool(binary_attr) != 1:
+      raise ValueError("One of 'binary_fn', 'binary_attr' must be specified")
     self.binary_fn = binary_fn
+    self.binary_attr = binary_attr
     self.reduce_dims = reduce_dims
 
   def __call__(self, *args: TensorExpression) -> "TensorReduceFn":
     return TensorReduceFn(self, args)
 
   def __repr__(self):
-    return (f"reduce_{self.binary_fn.fn_name}"
-            f"({', '.join(repr(d) for d in self.reduce_dims)})")
+    fn = self.binary_fn if self.binary_fn else self.binary_attr
+    return (
+        f"reduce_{repr(fn)}({', '.join(repr(d) for d in self.reduce_dims)})")
 
 
 class ReduceFnType:
@@ -369,10 +380,10 @@ def __init__(self, binary_fn: BinaryFnType):
     self.binary_fn = binary_fn
 
   def __getitem__(self, reduce_dims: Tuple[DimDef]) -> ReduceFnUse:
-    return ReduceFnUse(self.binary_fn, *reduce_dims)
+    return ReduceFnUse(self.binary_fn, None, *reduce_dims)
 
   def __repr__(self):
-    return (f"reduce_{self.binary_fn.fn_name}")
+    return f"reduce_{repr(self.binary_fn)}"
 
 
 class ReduceFn:
@@ -394,7 +405,9 @@ class OperandKind(Enum):
   SCALAR = 1
   OUTPUT_TENSOR = 2
   INDEX_ATTR = 3
-  TYPE_FN_ATTR = 4
+  UNARY_FN_ATTR = 4
+  BINARY_FN_ATTR = 5
+  TYPE_FN_ATTR = 6
 
 
 class OperandDef:
@@ -441,6 +454,8 @@ def is_tensor(self) -> bool:
 
   def is_attribute(self) -> bool:
     return (self.kind == OperandKind.INDEX_ATTR or
+            self.kind == OperandKind.UNARY_FN_ATTR or
+            self.kind == OperandKind.BINARY_FN_ATTR or
             self.kind == OperandKind.TYPE_FN_ATTR)
 
   def __hash__(self):
@@ -557,6 +572,49 @@ def __init__(self, *sizes: SymbolDef, default: Sequence[int]):
         OperandKind.INDEX_ATTR, size_exprs=sizes, default_indices=default)
 
 
+class UnaryFnAttrDef:
+  """Unary function attribute definition.
+
+  Unary function attributes provide a way to make the arithmetic computation
+  parametrizable. Every attribute specifies a default unary function
+  that may be overwritten at operation instantiation time.
+  """
+
+  def __init__(self, default: "UnaryFnType"):
+    if not isinstance(default, UnaryFnType):
+      raise ValueError(f"UnaryFnAttrDef requires default of type UnaryFnType "
+                       f"but got {default}")
+    self.operand_def = OperandDef(
+        OperandKind.UNARY_FN_ATTR, default_fn=default.fn_name)
+
+  def __call__(self, arg: TensorExpression) -> TensorFn:
+    return TensorFn(FunctionKind.UNARY, None, self.operand_def, None, [arg])
+
+
+class BinaryFnAttrDef:
+  """Binary function attribute definition.
+
+  Binary function attributes provide a way to make the arithmetic computation
+  parametrizable. Every attribute specifies a default binary function
+  that may be overwritten at operation instantiation time.
+  """
+
+  def __init__(self, default: "BinaryFnType"):
+    if not isinstance(default, BinaryFnType):
+      raise ValueError(f"BinaryFnAttrDef requires default of type BinaryFnType "
+                       f"but got {default}")
+    self.operand_def = OperandDef(
+        OperandKind.BINARY_FN_ATTR, default_fn=default.fn_name)
+
+  def __call__(self, arg0: TensorExpression,
+               arg1: TensorExpression) -> TensorFn:
+    return TensorFn(FunctionKind.BINARY, None, self.operand_def, None,
+                    [arg0, arg1])
+
+  def __getitem__(self, reduce_dims: Tuple[DimDef]) -> ReduceFnUse:
+    return ReduceFnUse(None, self, *reduce_dims)
+
+
 class TypeFnAttrDef:
   """Type conversion function attribute definition.
 

diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/config.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/config.py
index 12b168de184aa..ed30b8e5fc9a0 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/lang/config.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/config.py
@@ -309,8 +309,8 @@ def get_type(symbolic_name, position):
   def add_operand(self, operand_def: OperandDef):
     if operand_def in self.operands:
       return
-    if (operand_def.kind == OperandKind.SCALAR or
-        operand_def.kind == OperandKind.TYPE_FN_ATTR):
+    if not (operand_def.is_tensor() or
+            operand_def.kind == OperandKind.INDEX_ATTR):
       self.operands[operand_def] = OperandDefConfig(operand_def)
       return
     with self.context:

diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/dsl.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/dsl.py
index 99ce713668d1d..bd9042ac0aacb 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/lang/dsl.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/dsl.py
@@ -130,7 +130,8 @@ def linalg_structured_op(dsl_func=None,
   for param_name, param in sig.parameters.items():
     param_default = param.default
     if isinstance(param_default,
-                  (TensorDef, ScalarDef, IndexAttrDef, TypeFnAttrDef)):
+                  (TensorDef, ScalarDef, IndexAttrDef, UnaryFnAttrDef,
+                   BinaryFnAttrDef, TypeFnAttrDef)):
       op_def.add_operand(param_name, param_default.operand_def)
     else:
       raise ValueError(

diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
index df4ab2249d4d4..79fc3f5a21904 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
@@ -41,7 +41,7 @@ def prepare_common_structured_op(op_config: LinalgStructuredOpConfig,
   all_arg_defs = op_config.ordered_operands
   in_arg_defs = [
       d for d in all_arg_defs
-      if d.kind == OperandKind.SCALAR or d.kind == OperandKind.INPUT_TENSOR
+      if d.kind in [OperandKind.SCALAR, OperandKind.INPUT_TENSOR]
   ]
   out_arg_defs = [
       d for d in all_arg_defs if d.kind == OperandKind.OUTPUT_TENSOR
@@ -49,8 +49,11 @@ def prepare_common_structured_op(op_config: LinalgStructuredOpConfig,
   index_attr_arg_defs = [
       d for d in all_arg_defs if d.kind == OperandKind.INDEX_ATTR
   ]
-  type_fn_attr_arg_defs = [
-      d for d in all_arg_defs if d.kind == OperandKind.TYPE_FN_ATTR
+  fn_attr_arg_defs = [
+      d for d in all_arg_defs if d.kind in [
+          OperandKind.UNARY_FN_ATTR, OperandKind.BINARY_FN_ATTR,
+          OperandKind.TYPE_FN_ATTR
+      ]
   ]
 
   # Verify outs is a sequence or a list of results.
@@ -135,28 +138,38 @@ def prepare_common_structured_op(op_config: LinalgStructuredOpConfig,
       array = np.array(index_attr_vals, dtype=np.int64)
       index_attrs[index_attr.name] = DenseElementsAttr.get(array)
 
-  # Compute the type function attribute mapping.
-  type_fn_attr_mapping = {}
-  for type_fn_attr in type_fn_attr_arg_defs:
-    attr_val = type_fn_attr.operand_def.default_fn
-    if type_fn_attr.name in attrs:
-      type_fn = attrs.get(type_fn_attr.name)
-      if not isinstance(type_fn, TypeFnType):
-        raise ValueError(f"Attribute {type_fn_attr.name} needs to be of type "
-                         f"TypeFnType but got {type(attr_val)}")
-      attr_val = type_fn.fn_name
-    assert attr_val, "Type function attribute has no value"
-    type_fn_attr_mapping[type_fn_attr.name] = attr_val
+  # Compute the function attribute mapping.
+  fn_attr_mapping = {}
+  for fn_attr in fn_attr_arg_defs:
+    attr_val = fn_attr.operand_def.default_fn
+    attr_kind = fn_attr.kind
+    if fn_attr.name in attrs:
+      fn = attrs.get(fn_attr.name)
+      if attr_kind == OperandKind.UNARY_FN_ATTR:
+        if not isinstance(fn, UnaryFnType):
+          raise ValueError(f"Attribute {fn_attr.name} needs to be of type "
+                           f"UnaryFnType but got {type(attr_val)}")
+      elif attr_kind == OperandKind.BINARY_FN_ATTR:
+        if not isinstance(fn, BinaryFnType):
+          raise ValueError(f"Attribute {fn_attr.name} needs to be of type "
+                           f"BinaryFnType but got {type(attr_val)}")
+      else:
+        if not isinstance(fn, TypeFnType):
+          raise ValueError(f"Attribute {fn_attr.name} needs to be of type "
+                           f"TypeFnType but got {type(attr_val)}")
+      attr_val = fn.fn_name
+    assert attr_val, "Function attribute has no value"
+    fn_attr_mapping[fn_attr.name] = (attr_val, attr_kind)
 
   return (all_arg_defs, in_arg_defs, out_arg_defs, outs, result_types,
           type_mapping, indexing_maps_attr, iterator_types_attr, index_attrs,
-          type_fn_attr_mapping, block_arg_types)
+          fn_attr_mapping, block_arg_types)
 
 
 def emit_generic_structured_op(op_config: LinalgStructuredOpConfig, *ins: Value,
                                outs: ValueList, **attrs: Sequence[int]):
   all_arg_defs, in_arg_defs, out_arg_defs, outs, result_types, type_mapping, \
-  indexing_maps_attr, iterator_types_attr, index_attrs, type_fn_attr_mapping, \
+  indexing_maps_attr, iterator_types_attr, index_attrs, fn_attr_mapping, \
   block_arg_types = \
      prepare_common_structured_op(op_config, *ins, outs = outs, **attrs)
 
@@ -193,7 +206,7 @@ def emit_generic_structured_op(op_config: LinalgStructuredOpConfig, *ins: Value,
   block_arg_mapping = dict(zip(block_arg_names, block.arguments))
   with InsertionPoint(block):
     body_builder = _BodyBuilder(type_mapping, block_arg_mapping,
-                                type_fn_attr_mapping)
+                                fn_attr_mapping)
     for assignment in op_config.assignments:
       body_builder.assign(assignment)
     body_builder.yield_outputs(*_get_operand_def_names(*out_arg_defs))
@@ -208,7 +221,7 @@ def emit_named_structured_op(op_config: LinalgStructuredOpConfig, op_name: str,
                              op_class_name: str, *ins: Value, outs: ValueList,
                              **attrs: Sequence[int]):
   all_arg_defs, in_arg_defs, out_arg_defs, outs, result_types, type_mapping, \
-  indexing_maps_attr, iterator_types_attr, index_attrs, type_fn_attr_mapping, \
+  indexing_maps_attr, iterator_types_attr, index_attrs, fn_attr_mapping, \
   block_arg_types = \
      prepare_common_structured_op(op_config, *ins, outs = outs, **attrs)
 
@@ -225,10 +238,12 @@ def emit_named_structured_op(op_config: LinalgStructuredOpConfig, op_name: str,
   for name, value in index_attrs.items():
     named_op.operation.attributes[name] = value
 
-  # Set the type function attributes.
-  for name, value in type_fn_attr_mapping.items():
+  # Compute the function attributes by combining operand kind and function name.
+  for name, (fn_name, kind) in fn_attr_mapping.items():
+    assert kind.name.lower().endswith("_attr")
+    enum_name = kind.name.lower()[:-5]
     named_op.operation.attributes[name] = Attribute.parse(
-        f"#linalg.type_fn<{value}>")
+        f"#linalg.{enum_name}<{fn_name}>")
 
   linalg.fill_builtin_region(named_op.operation)
 
@@ -242,11 +257,11 @@ class _BodyBuilder:
   """Constructs a structured op body by evaluating assignments."""
 
   def __init__(self, type_mapping: Dict[str, Type],
-               block_arg_mapping: Dict[str, Value],
-               type_fn_attr_mapping: Dict[str, str]):
+               block_arg_mapping: Dict[str, Value], fn_attr_mapping: Dict[str,
+                                                                          str]):
     self.type_mapping = type_mapping
     self.block_arg_mapping = block_arg_mapping
-    self.type_fn_attr_mapping = type_fn_attr_mapping
+    self.fn_attr_mapping = fn_attr_mapping
     self.yield_mapping = dict()  # type: Dict[str, Value]
 
   def assign(self, assignment: ScalarAssign):
@@ -270,21 +285,18 @@ def expression(self, expr: ScalarExpression) -> Value:
       dim_attr = IntegerAttr.get(
           IntegerType.get_signless(64), expr.scalar_index.dim)
       return linalg.IndexOp(dim_attr).result
-    elif expr.scalar_fn and expr.scalar_fn.kind is not FunctionKind.TYPE:
+    elif expr.scalar_fn:
       kind = expr.scalar_fn.kind.name.lower()
-      fn = self._get_function(f"_{kind}_{expr.scalar_fn.fn_name}")
+      fn_name = expr.scalar_fn.fn_name
+      if expr.scalar_fn.attr_name:
+        fn_name, _ = self.fn_attr_mapping[expr.scalar_fn.attr_name]
+      fn = self._get_function(f"_{kind}_{fn_name}")
       operand_values = [
           self.expression(operand) for operand in expr.scalar_fn.operands
       ]
+      if expr.scalar_fn.kind == FunctionKind.TYPE:
+        operand_values = [expr.scalar_fn.type_var.name] + operand_values
       return fn(*operand_values)
-    elif expr.scalar_fn and expr.scalar_fn.kind is FunctionKind.TYPE:
-      kind = expr.scalar_fn.kind.name.lower()
-      fn_name = expr.scalar_fn.fn_name
-      if expr.scalar_fn.attr_name:
-        fn_name = self.type_fn_attr_mapping[expr.scalar_fn.attr_name]
-      fn = self._get_function(f"_{kind}_{fn_name}")
-      operand_value = self.expression(expr.scalar_fn.operands[0])
-      return fn(expr.scalar_fn.type_var.name, operand_value)
     raise NotImplementedError(f"Unimplemented scalar body expression: {expr}")
 
   def yield_outputs(self, *output_names: str):
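
The emitter hunks above resolve each function attribute to a (function name, operand kind) pair and derive the attribute text from the kind name. A minimal standalone sketch of that logic, assuming a stand-in `OperandKind` enum and hypothetical helpers `resolve_fn_attrs` and `attribute_text` (neither is part of OpDSL; the real code works on operand definitions and calls `Attribute.parse`):

```python
from enum import Enum


class OperandKind(Enum):
  # Hypothetical stand-in for the OpDSL OperandKind enum; only the
  # function attribute kinds are modeled here.
  UNARY_FN_ATTR = 0
  BINARY_FN_ATTR = 1
  TYPE_FN_ATTR = 2


def resolve_fn_attrs(defaults, overrides):
  """Maps attribute name -> (function name, kind), preferring overrides.

  `defaults` maps name -> (default function name, kind); `overrides` maps
  name -> function name. Both shapes are simplifications for this sketch.
  """
  mapping = {}
  for name, (default_fn, kind) in defaults.items():
    fn_name = overrides.get(name, default_fn)
    assert fn_name, "Function attribute has no value"
    mapping[name] = (fn_name, kind)
  return mapping


def attribute_text(fn_name, kind):
  """Builds the attribute string, e.g. #linalg.binary_fn<mul>."""
  lowered = kind.name.lower()
  assert lowered.endswith("_attr")
  # Strip the "_attr" suffix: "binary_fn_attr" -> "binary_fn".
  enum_name = lowered[:-len("_attr")]
  return f"#linalg.{enum_name}<{fn_name}>"


defaults = {
    "fun": ("add", OperandKind.BINARY_FN_ATTR),
    "cast": ("cast", OperandKind.TYPE_FN_ATTR),
}
# Only "fun" is overridden; "cast" falls back to its default.
mapping = resolve_fn_attrs(defaults, {"fun": "mul"})
attrs = {name: attribute_text(*value) for name, value in mapping.items()}
print(attrs)
```

The sketch omits the isinstance checks against UnaryFnType, BinaryFnType and TypeFnType that the real code performs before accepting an override.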

diff  --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
index 340f4db4471bb..b7a827bf649b2 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
@@ -6,6 +6,35 @@
 Batch = S.Batch
 
 
+@linalg_structured_op
+def elemwise_unary(
+    I=TensorDef(T1),
+    O=TensorDef(U, output=True),
+    fun=UnaryFnAttrDef(default=UnaryFn.exp),
+    cast=TypeFnAttrDef(default=TypeFn.cast)):
+  """Applies the unary function fun elementwise.
+
+  Numeric casting is performed on the input operand, promoting it to the same
+  data type as the accumulator/output.
+  """
+  O[None] = fun(cast(U, I[None]))
+
+
+@linalg_structured_op
+def elemwise_binary(
+    lhs=TensorDef(T1),
+    rhs=TensorDef(T2),
+    O=TensorDef(U, output=True),
+    fun=BinaryFnAttrDef(default=BinaryFn.add),
+    cast=TypeFnAttrDef(default=TypeFn.cast)):
+  """Applies the binary function fun elementwise.
+
+  Numeric casting is performed on the input operand, promoting it to the same
+  data type as the accumulator/output.
+  """
+  O[None] = fun(cast(U, lhs[None]), cast(U, rhs[None]))
+
+
 @linalg_structured_op
 def matmul(
     A=TensorDef(T1, S.M, S.K),

diff  --git a/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir b/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir
index 3778315ff5a61..21f00268a9c12 100644
--- a/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir
+++ b/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir
@@ -292,16 +292,48 @@ func @generalize_fill_rng_2d_i32(%min: f64, %max: f64, %seed: i32, %O: tensor<16
 
 // -----
 
-func @generalize_soft_plus_2d_f32(%input: tensor<16x32xf32>, %output: tensor<16x32xf32>) -> tensor<16x32xf32> {
-  %0 = linalg.soft_plus_2d ins(%input: tensor<16x32xf32>) outs(%output: tensor<16x32xf32>) -> tensor<16x32xf32>
-  return %0: tensor<16x32xf32>
+// Verifies the default value of the fun attribute is an exp op.
+func @generalize_elemwise_exp(%lhs : tensor<4x8xf32>, %output : tensor<4x8xf32>) -> tensor<4x8xf32> {
+  %0 = linalg.elemwise_unary ins(%lhs: tensor<4x8xf32>) outs(%output: tensor<4x8xf32>) -> tensor<4x8xf32>
+  return %0: tensor<4x8xf32>
 }
 
-// CHECK-LABEL: @generalize_soft_plus_2d_f32
-//      CHECK: %[[C1:.+]] = arith.constant 1.000000e+00 : f32
-//      CHECK: ^{{.*}}(%[[IN:.+]]: f32, %[[OUT:.+]]: f32
-// CHECK-NEXT:   %[[EXP:.+]] = math.exp %[[IN]] : f32
-// CHECK-NEXT:   %[[SUM:.+]] = arith.addf %[[EXP]], %[[C1]] : f32
-// CHECK-NEXT:   %[[LOG:.+]] = math.log %[[SUM]] : f32
-// CHECK-NEXT:   linalg.yield %[[LOG]] : f32
-// CHECK-NEXT: -> tensor<16x32xf32>
+// CHECK-LABEL: @generalize_elemwise_exp
+// CHECK:        = math.exp
+
+// -----
+
+// Verifies the fun attribute controls the unary function used.
+func @generalize_elemwise_log(%lhs : tensor<4x8xf32>, %output : tensor<4x8xf32>) -> tensor<4x8xf32> {
+  %0 = linalg.elemwise_unary {fun = #linalg.unary_fn<log>}
+                              ins(%lhs: tensor<4x8xf32>) outs(%output: tensor<4x8xf32>) -> tensor<4x8xf32>
+  return %0: tensor<4x8xf32>
+}
+
+// CHECK-LABEL: @generalize_elemwise_log
+// CHECK:        = math.log
+
+// -----
+
+// Verifies the default value of the fun attribute is an add op.
+func @generalize_elemwise_add(%lhs : tensor<4x8xf32>, %rhs : tensor<4x8xf32>, %output : tensor<4x8xf32>) -> tensor<4x8xf32> {
+  %0 = linalg.elemwise_binary ins(%lhs, %rhs: tensor<4x8xf32>, tensor<4x8xf32>)
+                              outs(%output: tensor<4x8xf32>) -> tensor<4x8xf32>
+  return %0: tensor<4x8xf32>
+}
+
+// CHECK-LABEL: @generalize_elemwise_add
+// CHECK:        = arith.addf
+
+// -----
+
+// Verifies the fun attribute controls the binary function used.
+func @generalize_elemwise_mul(%lhs : tensor<4x8xf32>, %rhs : tensor<4x8xf32>, %output : tensor<4x8xf32>) -> tensor<4x8xf32> {
+  %0 = linalg.elemwise_binary {fun = #linalg.binary_fn<mul>}
+                              ins(%lhs, %rhs: tensor<4x8xf32>, tensor<4x8xf32>)
+                              outs(%output: tensor<4x8xf32>) -> tensor<4x8xf32>
+  return %0: tensor<4x8xf32>
+}
+
+// CHECK-LABEL: @generalize_elemwise_mul
+// CHECK:        = arith.mulf

diff  --git a/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml b/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml
index dc1e1809eb46b..2defebbba781f 100644
--- a/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml
+++ b/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml
@@ -111,7 +111,7 @@ structured_op: !LinalgStructuredOpConfig
 #   IMPL-DAG:  Value [[VAL1:[a-z0-9]+]] = helper.buildTypeFn(castVal, block.getArgument(0).getType(), [[VAL0]]);
 #   IMPL-DAG:  Value [[VAL2:[a-z0-9]+]] = helper.index(1);
 #   IMPL-DAG:  Value [[VAL3:[a-z0-9]+]] = helper.buildTypeFn(castVal, block.getArgument(0).getType(), [[VAL2]]);
-#   IMPL-DAG:  Value [[VAL4:[a-z0-9]+]] = helper.binary__add([[VAL1]], [[VAL3]]);
+#   IMPL-DAG:  Value [[VAL4:[a-z0-9]+]] = helper.buildBinaryFn(BinaryFn::add, [[VAL1]], [[VAL3]]);
 
 
 # @linalg_structured_op
@@ -255,14 +255,15 @@ structured_op: !LinalgStructuredOpConfig
 #  IMPL-NEXT:    AffineMap scalarMap = AffineMap::get(getNumParallelLoops(), 0, context);
 #  IMPL-NEXT:    AffineMap tensorMap = AffineMap::getMultiDimIdentityMap(
 
-
 # @linalg_structured_op
-# def test4(O=TensorDef(T, S.M, S.N, output=True)):
+# def test4(O=TensorDef(T, S.M, S.N, output=True),
+#           unary_fun=UnaryFnAttrDef(default=UnaryFn.exp),
+#           binary_fun=BinaryFnAttrDef(default=BinaryFn.add)):
 #   """Title.
 
 #   Detailed description.
 #   """
-#   O[D.m, D.n] = BinaryFn.add(UnaryFn.exp(O[D.m, D.n]), O[D.m, D.n])
+#   O[D.m, D.n] = binary_fun(unary_fun(O[D.m, D.n]), O[D.m, D.n])
 
 --- !LinalgOpConfig
 metadata: !LinalgOpMetadata
@@ -279,6 +280,14 @@ structured_op: !LinalgStructuredOpConfig
     kind: output_tensor
     type_var: T
     shape_map: affine_map<()[s0, s1] -> (s0, s1)>
+  - !LinalgOperandDefConfig
+    name: unary_fun
+    kind: unary_fn_attr
+    default_fn: exp
+  - !LinalgOperandDefConfig
+    name: binary_fun
+    kind: binary_fn_attr
+    default_fn: add
   indexing_maps: !LinalgIndexingMapsConfig
     static_indexing_maps:
     - affine_map<(d0, d1)[s0, s1] -> (d0, d1)>
@@ -291,21 +300,36 @@ structured_op: !LinalgStructuredOpConfig
     value: !ScalarExpression
       scalar_fn:
         kind: binary
-        fn_name: add
+        attr_name: binary_fun
         operands:
         - !ScalarExpression
           scalar_fn:
             kind: unary
-            fn_name: exp
+            attr_name: unary_fun
             operands:
             - !ScalarExpression
               scalar_arg: O
         - !ScalarExpression
           scalar_arg: O
 
+# ODS-LABEL:  def Test4Op : LinalgStructuredBase_Op<"test4"
+
+#       ODS:  let arguments =
+#  ODS-NEXT:    Variadic<AnyType>:$inputs,
+#  ODS-NEXT:    Variadic<AnyShaped>:$outputs,
+#  ODS-NEXT:    DefaultValuedAttr<UnaryFnAttr, "UnaryFn::exp">:$unary_fun,
+#  ODS-NEXT:    DefaultValuedAttr<BinaryFnAttr, "BinaryFn::add">:$binary_fun
+
+#       ODS:    "Attribute":$unary_fun, "Attribute":$binary_fun,
+
+#       ODS:    $_state.addAttribute("unary_fun", unary_fun)
+#  ODS-NEXT:    $_state.addAttribute("binary_fun", binary_fun)
+
 # IMPL-LABEL:  void Test4Op::regionBuilder(ImplicitLocOpBuilder &b,
 #  IMPL-NEXT:    Block &block, ArrayRef<NamedAttribute> attrs)
+#       IMPL:  UnaryFn unary_funVal = UnaryFn::exp
+#       IMPL:  BinaryFn binary_funVal = BinaryFn::add
 
-#       IMPL:  Value [[VAL0:[a-z0-9]+]] = helper.unary__exp(block.getArgument(0))
-#  IMPL-NEXT:  Value [[VAL1:[a-z0-9]+]] = helper.binary__add([[VAL0]], block.getArgument(0))
+#       IMPL:  Value [[VAL0:[a-z0-9]+]] = helper.buildUnaryFn(unary_funVal, block.getArgument(0))
+#  IMPL-NEXT:  Value [[VAL1:[a-z0-9]+]] = helper.buildBinaryFn(binary_funVal, [[VAL0]], block.getArgument(0))
 #  IMPL-NEXT:  yields.push_back([[VAL1]])

diff  --git a/mlir/test/python/dialects/linalg/opdsl/arguments.py b/mlir/test/python/dialects/linalg/opdsl/arguments.py
index c4f3d91ff10ab..853627611987c 100644
--- a/mlir/test/python/dialects/linalg/opdsl/arguments.py
+++ b/mlir/test/python/dialects/linalg/opdsl/arguments.py
@@ -18,6 +18,12 @@
 # CHECK:     kind: output_tensor
 # CHECK:     type_var: U
 # CHECK:     shape_map: affine_map<()[s0, s1, s2] -> (s0, s2)>
+# CHECK:     name: bfn
+# CHECK:     kind: binary_fn_attr
+# CHECK:     default_fn: mul
+# CHECK:     name: ufn
+# CHECK:     kind: unary_fn_attr
+# CHECK:     default_fn: exp
 # CHECK:     name: cast
 # CHECK:     kind: type_fn_attr
 # CHECK:     default_fn: cast
@@ -26,8 +32,10 @@ def matmul(
     A=TensorDef(T, S.M, S.K),
     B=TensorDef(T, S.K, S.N),
     C=TensorDef(U, S.M, S.N, output=True),
+    bfn=BinaryFnAttrDef(default=BinaryFn.mul),
+    ufn=UnaryFnAttrDef(default=UnaryFn.exp),
     cast=TypeFnAttrDef(default=TypeFn.cast)):
-  C[D.m, D.n] += cast(U, A[D.m, D.k]) * cast(U, B[D.k, D.n])
+  C[D.m, D.n] += bfn(cast(U, A[D.m, D.k]), cast(U, B[D.k, D.n]))
 
 
 # CHECK: ---

diff  --git a/mlir/test/python/dialects/linalg/opdsl/assignments.py b/mlir/test/python/dialects/linalg/opdsl/assignments.py
index f93e0704a1e36..d8ddc24454914 100644
--- a/mlir/test/python/dialects/linalg/opdsl/assignments.py
+++ b/mlir/test/python/dialects/linalg/opdsl/assignments.py
@@ -10,10 +10,12 @@
 # CHECK:    arg: C
 # CHECK:    value:
 # CHECK:      scalar_fn:
+# CHECK:        kind: binary
 # CHECK:        fn_name: add
 # CHECK:        operands:
 # CHECK:          scalar_fn:
-# CHECK:            fn_name: mul
+# CHECK:            kind: binary
+# CHECK:            attr_name: mul
 # CHECK:            operands:
 # CHECK:              scalar_fn:
 # CHECK:                kind: type
@@ -32,8 +34,9 @@ def matmul(
     A=TensorDef(T, S.M, S.K),
     B=TensorDef(T, S.K, S.N),
     C=TensorDef(U, S.M, S.N, output=True),
+    mul=BinaryFnAttrDef(default=BinaryFn.mul),
     cast=TypeFnAttrDef(default=TypeFn.cast)):
-  C[D.m, D.n] += cast(U, A[D.m, D.k]) * cast(U, B[D.k, D.n])
+  C[D.m, D.n] += mul(cast(U, A[D.m, D.k]), cast(U, B[D.k, D.n]))
 
 
 # CHECK: ---
@@ -69,14 +72,21 @@ def matmul(
 # CHECK:            fn_name: cast
 # CHECK:            type_var: T
 # CHECK:            operands:
-# CHECK:              scalar_const: '1.{{[0]*}}e+03 : f64'
+# CHECK:              scalar_fn:
+# CHECK:                kind: unary
+# CHECK:                attr_name: exp
+# CHECK:                operands:
+# CHECK:                  scalar_const: '1.{{[0]*}}e+03 : f64'
 @linalg_structured_op
-def constants(O=TensorDef(T, S.M, S.K, output=True)):
+def constants(
+    O=TensorDef(T, S.M, S.K, output=True),
+    exp=UnaryFnAttrDef(default=UnaryFn.exp)):
   pi = TypeFn.cast(T, const(3.1415926535897931))
   cst42 = TypeFn.cast(T, const(42))
-  cst1000 = TypeFn.cast(T, const(1e+3))
+  cst1000 = TypeFn.cast(T, exp(const(1e+3)))
   O[D.m, D.n] = UnaryFn.exp(pi) + cst42 - cst1000
 
+
 # CHECK: ---
 # CHECK-LABEL: indices
 # CHECK: assignments:

diff  --git a/mlir/test/python/dialects/linalg/opdsl/emit_pooling.py b/mlir/test/python/dialects/linalg/opdsl/emit_pooling.py
index 35ec8540cb4b2..ded97cd7b8220 100644
--- a/mlir/test/python/dialects/linalg/opdsl/emit_pooling.py
+++ b/mlir/test/python/dialects/linalg/opdsl/emit_pooling.py
@@ -12,55 +12,18 @@
 
 
 @linalg_structured_op
-def pooling_max_poly(
+def pooling_poly(
     I=TensorDef(T1, S.N, S.H, S.W, S.C),
     K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]),
     O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True),
+    reduce=BinaryFnAttrDef(default=BinaryFn.max),
+    cast=TypeFnAttrDef(default=TypeFn.cast),
     strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
     dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])):
   domain(D.n, D.oh, D.ow, D.kh, D.kw, D.c)
-  O[D.n, D.oh, D.ow, D.c] = ReduceFn.max[D.kh, D.kw](
-      TypeFn.cast(
-          U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c]))
-
-
-@linalg_structured_op
-def pooling_max_unsigned_poly(
-    I=TensorDef(T1, S.N, S.H, S.W, S.C),
-    K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]),
-    O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True),
-    strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
-    dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])):
-  domain(D.n, D.oh, D.ow, D.kh, D.kw, D.c)
-  O[D.n, D.oh, D.ow, D.c] = ReduceFn.max_unsigned[D.kh, D.kw](
-      TypeFn.cast_unsigned(
-          U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c]))
-
-
-@linalg_structured_op
-def pooling_min_poly(
-    I=TensorDef(T1, S.N, S.H, S.W, S.C),
-    K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]),
-    O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True),
-    strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
-    dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])):
-  domain(D.n, D.oh, D.ow, D.kh, D.kw, D.c)
-  O[D.n, D.oh, D.ow, D.c] = ReduceFn.min[D.kh, D.kw](
-      TypeFn.cast(
-          U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c]))
-
-
-@linalg_structured_op
-def pooling_min_unsigned_poly(
-    I=TensorDef(T1, S.N, S.H, S.W, S.C),
-    K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]),
-    O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True),
-    strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
-    dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])):
-  domain(D.n, D.oh, D.ow, D.kh, D.kw, D.c)
-  O[D.n, D.oh, D.ow, D.c] = ReduceFn.min_unsigned[D.kh, D.kw](
-      TypeFn.cast_unsigned(
-          U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c]))
+  O[D.n, D.oh, D.ow, D.c] = reduce[D.kh, D.kw](
+      cast(U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW,
+                D.c]))
 
 
 with Context() as ctx, Location.unknown():
@@ -88,7 +51,7 @@ def pooling_min_unsigned_poly(
         RankedTensorType.get((2, 2), f32),
         RankedTensorType.get((1, 2, 4, 1), i32))
     def test_f32i32_max_pooling(input, shape, init_result):
-      return pooling_max_poly(
+      return pooling_poly(
           input, shape, outs=[init_result], strides=[2, 4], dilations=[1, 2])
 
     # CHECK-LABEL: @test_f32i32_max_unsigned_pooling
@@ -99,8 +62,14 @@ def test_f32i32_max_pooling(input, shape, init_result):
         RankedTensorType.get((2, 2), f32),
         RankedTensorType.get((1, 2, 4, 1), i32))
     def test_f32i32_max_unsigned_pooling(input, shape, init_result):
-      return pooling_max_unsigned_poly(
-          input, shape, outs=[init_result], strides=[2, 4], dilations=[1, 2])
+      return pooling_poly(
+          input,
+          shape,
+          outs=[init_result],
+          reduce=BinaryFn.max_unsigned,
+          cast=TypeFn.cast_unsigned,
+          strides=[2, 4],
+          dilations=[1, 2])
 
     # CHECK-LABEL: @test_f32f32_max_pooling
     # CHECK: linalg.generic
@@ -115,7 +84,7 @@ def test_f32i32_max_unsigned_pooling(input, shape, init_result):
         RankedTensorType.get((2, 2), f32),
         RankedTensorType.get((1, 2, 4, 1), f32))
     def test_f32f32_max_pooling(input, shape, init_result):
-      return pooling_max_poly(
+      return pooling_poly(
           input, shape, outs=[init_result], strides=[2, 4], dilations=[1, 2])
 
     # CHECK-LABEL: @test_f32i32_min_pooling
@@ -126,8 +95,13 @@ def test_f32f32_max_pooling(input, shape, init_result):
         RankedTensorType.get((2, 2), f32),
         RankedTensorType.get((1, 2, 4, 1), i32))
     def test_f32i32_min_pooling(input, shape, init_result):
-      return pooling_min_poly(
-          input, shape, outs=[init_result], strides=[2, 4], dilations=[1, 2])
+      return pooling_poly(
+          input,
+          shape,
+          outs=[init_result],
+          reduce=BinaryFn.min,
+          strides=[2, 4],
+          dilations=[1, 2])
 
     # CHECK-LABEL: @test_f32i32_min_unsigned_pooling
     # CHECK:   = arith.fptoui
@@ -137,8 +111,14 @@ def test_f32i32_min_pooling(input, shape, init_result):
         RankedTensorType.get((2, 2), f32),
         RankedTensorType.get((1, 2, 4, 1), i32))
     def test_f32i32_min_unsigned_pooling(input, shape, init_result):
-      return pooling_min_unsigned_poly(
-          input, shape, outs=[init_result], strides=[2, 4], dilations=[1, 2])
+      return pooling_poly(
+          input,
+          shape,
+          outs=[init_result],
+          reduce=BinaryFn.min_unsigned,
+          cast=TypeFn.cast_unsigned,
+          strides=[2, 4],
+          dilations=[1, 2])
 
     # CHECK-LABEL: @test_f32f32_min_pooling
     # CHECK:   = arith.minf
@@ -147,8 +127,13 @@ def test_f32i32_min_unsigned_pooling(input, shape, init_result):
         RankedTensorType.get((2, 2), f32),
         RankedTensorType.get((1, 2, 4, 1), f32))
     def test_f32f32_min_pooling(input, shape, init_result):
-      return pooling_min_poly(
-          input, shape, outs=[init_result], strides=[2, 4], dilations=[1, 2])
+      return pooling_poly(
+          input,
+          shape,
+          outs=[init_result],
+          reduce=BinaryFn.min,
+          strides=[2, 4],
+          dilations=[1, 2])
 
 
 print(module)

diff  --git a/mlir/test/python/dialects/linalg/ops.py b/mlir/test/python/dialects/linalg/ops.py
index 8b686e69cb138..00a8406373855 100644
--- a/mlir/test/python/dialects/linalg/ops.py
+++ b/mlir/test/python/dialects/linalg/ops.py
@@ -94,20 +94,27 @@ def testNamedStructuredOpCustomForm():
     with InsertionPoint(module.body):
 
       @builtin.FuncOp.from_py_func(
-          RankedTensorType.get((4, 16), f32), RankedTensorType.get((16, 8),
-                                                                   f32))
+          RankedTensorType.get((4, 8), f32), RankedTensorType.get((4, 8), f32))
       def named_form(lhs, rhs):
         init_result = linalg.InitTensorOp([4, 8], f32)
-        # First check the named form with custom format
-        #      CHECK: linalg.matmul
-        #      CHECK: cast = #linalg.type_fn<cast_unsigned>
-        #  CHECK-NOT: linalg.memoized_indexing_maps
-        # CHECK-SAME:    ins(%{{.*}} : tensor<4x16xf32>, tensor<16x8xf32>)
-        # CHECK-SAME:   outs(%{{.*}} : tensor<4x8xf32>)
-        # CHECK-SAME:   -> tensor<4x8xf32>
-        # CHECK-NEXT: return
-        return linalg.matmul(
-            lhs, rhs, outs=[init_result.result], cast=TypeFn.cast_unsigned)
+        # Check for the named form with custom format
+        #      CHECK: linalg.elemwise_unary
+        # CHECK-SAME:    cast = #linalg.type_fn<cast>
+        # CHECK-SAME:    fun = #linalg.unary_fn<exp>
+        # CHECK-SAME:    ins(%{{.*}} : tensor<4x8xf32>) outs(%{{.*}} : tensor<4x8xf32>)
+        unary_result = linalg.elemwise_unary(lhs, outs=[init_result.result])
+        #      CHECK: linalg.elemwise_binary
+        # CHECK-SAME:    cast = #linalg.type_fn<cast_unsigned>
+        # CHECK-SAME:    fun = #linalg.binary_fn<mul>
+        # CHECK-SAME:    ins(%{{.*}}, %{{.*}} : tensor<4x8xf32>, tensor<4x8xf32>) outs(%{{.*}} : tensor<4x8xf32>)
+        #      CHECK: return
+        binary_result = linalg.elemwise_binary(
+            lhs,
+            rhs,
+            outs=[init_result.result],
+            fun=BinaryFn.mul,
+            cast=TypeFn.cast_unsigned)
+        return unary_result, binary_result
 
   print(module)
 
@@ -130,7 +137,8 @@ def named_form(lhs, rhs):
         # CHECK-NEXT:    arith.mulf{{.*}} (f32, f32) -> f32
         # CHECK-NEXT:    arith.addf{{.*}} (f32, f32) -> f32
         # CHECK-NEXT:    linalg.yield{{.*}} (f32) -> ()
-        # CHECK-NEXT:    operand_segment_sizes = dense<[2, 1]> : vector<2xi32>
+        # CHECK-NEXT:    cast = #linalg.type_fn<cast>
+        # CHECK-SAME:    operand_segment_sizes = dense<[2, 1]> : vector<2xi32>
         # CHECK-SAME: (tensor<4x16xf32>, tensor<16x8xf32>, tensor<4x8xf32>) -> tensor<4x8xf32>
         return linalg.matmul(lhs, rhs, outs=[init_result.result])
 

diff  --git a/mlir/test/python/integration/dialects/linalg/opsrun.py b/mlir/test/python/integration/dialects/linalg/opsrun.py
index 1416aabc81133..c4e580f89c94d 100644
--- a/mlir/test/python/integration/dialects/linalg/opsrun.py
+++ b/mlir/test/python/integration/dialects/linalg/opsrun.py
@@ -19,6 +19,37 @@ def log(*args):
   sys.stderr.flush()
 
 
+elemwise_boiler = """
+func @main() -> f32 attributes {llvm.emit_c_interface} {
+  %v0 = arith.constant 0.0 : f32
+  %v1 = arith.constant 1.0 : f32
+  %v2 = arith.constant 2.0 : f32
+
+  %lhs = memref.alloc() : memref<4x8xf32>
+  %rhs = memref.alloc() : memref<4x8xf32>
+  %O0 = memref.alloc() : memref<4x8xf32>
+  %O1 = memref.alloc() : memref<4x8xf32>
+  linalg.fill(%v1, %lhs) : f32, memref<4x8xf32>
+  linalg.fill(%v2, %rhs) : f32, memref<4x8xf32>
+  linalg.fill(%v0, %O0) : f32, memref<4x8xf32>
+  linalg.fill(%v0, %O1) : f32, memref<4x8xf32>
+
+  call @elemwise_exp_add_on_buffers(%lhs, %rhs, %O0) :
+    (memref<4x8xf32>, memref<4x8xf32>, memref<4x8xf32>) -> ()
+  call @elemwise_log_mul_on_buffers(%lhs, %rhs, %O1) :
+    (memref<4x8xf32>, memref<4x8xf32>, memref<4x8xf32>) -> ()
+
+  %c0 = arith.constant 0 : index
+  %res0 = memref.load %O0[%c0, %c0] : memref<4x8xf32>
+  %res1 = memref.load %O1[%c0, %c0] : memref<4x8xf32>
+
+  %0 = arith.addf %res0, %res1 : f32
+
+  // TODO: FFI-based solution to allow testing and printing with python code.
+  return %0 : f32
+}
+"""
+
 matmul_boiler = """
 func @main() -> f32 attributes {llvm.emit_c_interface} {
   %v0 = arith.constant 0.0 : f32
@@ -166,13 +197,93 @@ def transform(module, boilerplate):
 
   pm = PassManager.parse(
       "builtin.func(convert-linalg-to-loops, lower-affine, " +
-      "convert-scf-to-cf, arith-expand, memref-expand), convert-vector-to-llvm,"
-      + "convert-memref-to-llvm, convert-std-to-llvm," +
+      "convert-math-to-llvm, convert-scf-to-cf, arith-expand, memref-expand), "
+      + "convert-vector-to-llvm, convert-memref-to-llvm, convert-std-to-llvm," +
       "reconcile-unrealized-casts")
   pm.run(mod)
   return mod
 
 
+def test_elemwise_builtin():
+  with Context() as ctx, Location.unknown():
+    module = Module.create()
+    f32 = F32Type.get()
+    i8 = IntegerType.get_signless(8)
+    with InsertionPoint(module.body):
+
+      @builtin.FuncOp.from_py_func(
+          MemRefType.get((4, 8), f32), MemRefType.get((4, 8), f32),
+          MemRefType.get((4, 8), f32))
+      def elemwise_exp_add_on_buffers(lhs, rhs, out):
+        linalg.elemwise_unary(lhs, outs=[out])
+        linalg.elemwise_binary(out, rhs, outs=[out])
+
+      @builtin.FuncOp.from_py_func(
+          MemRefType.get((4, 8), f32), MemRefType.get((4, 8), f32),
+          MemRefType.get((4, 8), f32))
+      def elemwise_log_mul_on_buffers(lhs, rhs, out):
+        linalg.elemwise_unary(lhs, outs=[out], fun=UnaryFn.log)
+        linalg.elemwise_binary(out, rhs, outs=[out], fun=BinaryFn.mul)
+
+    execution_engine = ExecutionEngine(transform(module, elemwise_boiler))
+
+    # TODO: FFI-based solution to allow testing and printing with python code.
+    # Prepare arguments: one result f32.
+    # Arguments must be passed as pointers.
+    c_float_p = ctypes.c_float * 1
+    res = c_float_p(-1.)
+    execution_engine.invoke("main", res)
+
+    log("RESULT: ", res[0])
+    # elemwise_exp_add_on_buffers: exp(1.0) + 2.0 = 4.71828182846
+    # elemwise_log_mul_on_buffers: log(1.0) * 2.0 = 0.0
+    # CHECK: RESULT: 4.71828
+
+
+test_elemwise_builtin()
+
+
+def test_elemwise_generic():
+  with Context() as ctx, Location.unknown():
+    module = Module.create()
+    f32 = F32Type.get()
+    i8 = IntegerType.get_signless(8)
+    with InsertionPoint(module.body):
+
+      @builtin.FuncOp.from_py_func(
+          MemRefType.get((4, 8), f32), MemRefType.get((4, 8), f32),
+          MemRefType.get((4, 8), f32))
+      def elemwise_exp_add_on_buffers(lhs, rhs, out):
+        linalg.elemwise_unary(lhs, outs=[out], emit_generic=True)
+        linalg.elemwise_binary(out, rhs, outs=[out], emit_generic=True)
+
+      @builtin.FuncOp.from_py_func(
+          MemRefType.get((4, 8), f32), MemRefType.get((4, 8), f32),
+          MemRefType.get((4, 8), f32))
+      def elemwise_log_mul_on_buffers(lhs, rhs, out):
+        linalg.elemwise_unary(
+            lhs, outs=[out], fun=UnaryFn.log, emit_generic=True)
+        linalg.elemwise_binary(
+            out, rhs, outs=[out], fun=BinaryFn.mul, emit_generic=True)
+
+    execution_engine = ExecutionEngine(transform(module, elemwise_boiler))
+
+    # TODO: FFI-based solution to allow testing and printing with python code.
+    # Prepare arguments: one result f32.
+    # Arguments must be passed as pointers.
+    c_float_p = ctypes.c_float * 1
+    res = c_float_p(-1.)
+    execution_engine.invoke("main", res)
+
+    log("RESULT: ", res[0])
+    # elemwise_exp_add_on_buffers: exp(1.0) + 2.0 = 4.71828182846
+    # elemwise_log_mul_on_buffers: log(1.0) * 2.0 = 0.0
+    # CHECK: RESULT: 4.71828
+
+
+test_elemwise_generic()
+
+
 def test_matmul_builtin():
   with Context() as ctx, Location.unknown():
     module = Module.create()
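
The expected value in the two elemwise integration tests can be checked by hand: `main` in `elemwise_boiler` fills `lhs` with 1.0 and `rhs` with 2.0, runs both functions, and returns the sum of the first element of each output. A small sketch of that arithmetic in plain Python, using double precision rather than the f32 the test runs in (the variable names are illustrative only):

```python
import math

lhs, rhs = 1.0, 2.0

# elemwise_exp_add_on_buffers: default fun values (exp, then add).
res0 = math.exp(lhs) + rhs
# elemwise_log_mul_on_buffers: fun=UnaryFn.log, then fun=BinaryFn.mul.
res1 = math.log(lhs) * rhs

# main returns res0 + res1; the test compares a prefix of the printed value.
total = res0 + res1
print(f"{total:.5f}")
```

Since log(1.0) is 0.0, the second function contributes nothing and the result is exp(1.0) + 2.0, which matches the `RESULT: 4.71828` prefix the tests check.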

diff  --git a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
index 7685d1a53e313..a535ad7654809 100644
--- a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
+++ b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
@@ -66,6 +66,8 @@ enum class LinalgOperandDefKind {
   Scalar,
   OutputTensor,
   IndexAttr,
+  UnaryFnAttr,
+  BinaryFnAttr,
   TypeFnAttr
 };
 
@@ -208,6 +210,8 @@ struct ScalarEnumerationTraits<LinalgOperandDefKind> {
     io.enumCase(value, "scalar", LinalgOperandDefKind::Scalar);
     io.enumCase(value, "output_tensor", LinalgOperandDefKind::OutputTensor);
     io.enumCase(value, "index_attr", LinalgOperandDefKind::IndexAttr);
+    io.enumCase(value, "unary_fn_attr", LinalgOperandDefKind::UnaryFnAttr);
+    io.enumCase(value, "binary_fn_attr", LinalgOperandDefKind::BinaryFnAttr);
     io.enumCase(value, "type_fn_attr", LinalgOperandDefKind::TypeFnAttr);
   }
 };
@@ -430,6 +434,45 @@ static ScalarAssign *findAssignment(StringRef name,
   return nullptr;
 }
 
+// Return true if the operand is a function attribute.
+static bool isFunctionAttribute(LinalgOperandDefKind kind) {
+  return kind == LinalgOperandDefKind::UnaryFnAttr ||
+         kind == LinalgOperandDefKind::BinaryFnAttr ||
+         kind == LinalgOperandDefKind::TypeFnAttr;
+}
+
+// Return true if the operand is an attribute.
+static bool isAttribute(LinalgOperandDefKind kind) {
+  return kind == LinalgOperandDefKind::IndexAttr || isFunctionAttribute(kind);
+}
+
+// Get the enum name for the given operand kind.
+std::string convertOperandKindToEnumName(LinalgOperandDefKind kind) {
+  switch (kind) {
+  case LinalgOperandDefKind::UnaryFnAttr:
+    return std::string("UnaryFn");
+  case LinalgOperandDefKind::BinaryFnAttr:
+    return std::string("BinaryFn");
+  case LinalgOperandDefKind::TypeFnAttr:
+    return std::string("TypeFn");
+  default:
+    break;
+  }
+  llvm_unreachable("unsupported function attribute kind");
+}
+
+// Get the enum name for the given function kind.
+std::string convertFunctionKindToEnumName(ScalarFnKind kind) {
+  switch (kind) {
+  case ScalarFnKind::Unary:
+    return std::string("UnaryFn");
+  case ScalarFnKind::Binary:
+    return std::string("BinaryFn");
+  case ScalarFnKind::Type:
+    return std::string("TypeFn");
+  }
+}
+
 //===----------------------------------------------------------------------===//
 // Templates
 //===----------------------------------------------------------------------===//
@@ -693,8 +736,7 @@ static LogicalResult generateNamedGenericOpOds(LinalgOpConfig &opConfig,
   interfaceNameList = interleaveToString(opConfig.metadata->implements, ", ");
 
   if (llvm::any_of(opConfig.structuredOp->args, [](LinalgOperandDef &arg) {
-        return arg.kind == LinalgOperandDefKind::IndexAttr ||
-               arg.kind == LinalgOperandDefKind::TypeFnAttr;
+        return isAttribute(arg.kind);
       })) {
     SmallVector<std::string> attrDefs;
     SmallVector<std::string> attrParams;
@@ -703,13 +745,14 @@ static LogicalResult generateNamedGenericOpOds(LinalgOpConfig &opConfig,
       static const char paramFmt[] = "\"Attribute\":${0}";
       static const char stmtFmt[] = "$_state.addAttribute(\"{0}\", {0});";
       // Add the type conversion attributes to the op definition and builders.
-      if (arg.kind == LinalgOperandDefKind::TypeFnAttr) {
+      if (isFunctionAttribute(arg.kind)) {
         assert(arg.defaultFn.hasValue());
-        static const char typeFmt[] = "TypeFn::{0}";
+        std::string enumName = convertOperandKindToEnumName(arg.kind);
+        static const char typeFmt[] = "{0}::{1}";
         static const char defFmt[] = "DefaultValuedAttr<{0}, \"{1}\">:${2}";
-        attrDefs.push_back(llvm::formatv(defFmt, "TypeFnAttr",
-                                         llvm::formatv(typeFmt, arg.defaultFn),
-                                         arg.name));
+        attrDefs.push_back(llvm::formatv(
+            defFmt, llvm::formatv("{0}Attr", enumName),
+            llvm::formatv(typeFmt, enumName, arg.defaultFn), arg.name));
         attrParams.push_back(llvm::formatv(paramFmt, arg.name));
         attrStmts.push_back(llvm::formatv(stmtFmt, arg.name));
       }
@@ -1000,21 +1043,24 @@ void {0}::regionBuilder(ImplicitLocOpBuilder &b,
     SmallVector<std::string> attrs;
     SmallVector<std::string> stmts;
     for (LinalgOperandDef &arg : args) {
-      if (arg.kind != LinalgOperandDefKind::TypeFnAttr)
+      if (!isFunctionAttribute(arg.kind))
         continue;
       // Obtain the type function attribute values. Parameters.
-      // {0}: attribute name
-      // {1}: default type function name
+      // {0}: enum name
+      // {1}: attribute name
+      // {2}: default type function name
       static const char attrDef[] = R"FMT(
-TypeFn {0}Val = TypeFn::{1};
-auto {0}Iter = llvm::find_if(attrs, [&](const NamedAttribute &attr) {{
-                              return attr.getName() == "{0}"; });
-if ({0}Iter != attrs.end()) {{
-  if (auto attr = {0}Iter->getValue().dyn_cast<TypeFnAttr>())
-    {0}Val = attr.getValue();
+{0} {1}Val = {0}::{2};
+auto {1}Iter = llvm::find_if(attrs, [&](const NamedAttribute &attr) {{
+                              return attr.getName() == "{1}"; });
+if ({1}Iter != attrs.end()) {{
+  if (auto attr = {1}Iter->getValue().dyn_cast<{0}Attr>())
+    {1}Val = attr.getValue();
 }
 )FMT";
-      attrs.push_back(llvm::formatv(attrDef, arg.name, arg.defaultFn));
+      std::string enumName = convertOperandKindToEnumName(arg.kind);
+      attrs.push_back(
+          llvm::formatv(attrDef, enumName, arg.name, arg.defaultFn));
     }
     for (LinalgOperandDef &arg : args) {
       if (arg.kind != LinalgOperandDefKind::OutputTensor)
@@ -1056,71 +1102,59 @@ if ({0}Iter != attrs.end()) {{
                                         cppIdent, *expression.index));
           return cppIdent;
         }
-        if (expression.scalarFn &&
-            expression.scalarFn->kind != ScalarFnKind::Type) {
-          // Apply function.
-          // Recursively generate operands.
-          SmallVector<std::string> operandCppValues;
-          for (ScalarExpression &operand : expression.scalarFn->operands) {
-            auto operandCppValue = generateExpression(operand);
-            if (!operandCppValue)
-              return None;
-            operandCppValues.push_back(*operandCppValue);
-          }
-
-          std::string prefix = expression.scalarFn->kind == ScalarFnKind::Unary
-                                   ? "unary"
-                                   : "binary";
-          std::string cppIdent = llvm::formatv("value{0}", ++localCounter);
-          stmts.push_back(
-              llvm::formatv("Value {0} = helper.{1}__{2}({3});", cppIdent,
-                            prefix, expression.scalarFn->fnName,
-                            interleaveToString(operandCppValues, ", ")));
-          return cppIdent;
-        }
-        if (expression.scalarFn &&
-            expression.scalarFn->kind == ScalarFnKind::Type) {
-          // Symbolic cast.
-          // Operands must be arity 1.
-          if (expression.scalarFn->operands.size() != 1) {
-            emitError(genContext.getLoc())
-                << "type conversion operand arity must be 1";
-            return None;
+        if (expression.scalarFn) {
+          std::string enumName =
+              convertFunctionKindToEnumName(expression.scalarFn->kind);
+
+          // Get the function or attribute name.
+          assert(expression.scalarFn->fnName || expression.scalarFn->attrName);
+          std::string funcType;
+          if (expression.scalarFn->fnName) {
+            funcType = llvm::formatv("{0}::{1}", enumName,
+                                     *expression.scalarFn->fnName);
           }
-          Optional<std::string> operandCppValue =
-              generateExpression(expression.scalarFn->operands[0]);
-          if (!operandCppValue)
-            return None;
-
-          assert(expression.scalarFn->typeVar.hasValue());
-          Optional<std::string> typeCppValue =
-              findTypeValue(expression.scalarFn->typeVar.getValue(), args);
-          if (!typeCppValue) {
-            emitError(genContext.getLoc())
-                << "type variable " << expression.scalarFn->typeVar.getValue()
-                << ", used in a type conversion, must map to a predefined or "
-                << "an argument type but it does not";
-            return None;
-          }
-
-          // Use the function name or the attribute to build the type function.
-          std::string typeFunc = llvm::formatv(
-              "TypeFn::{0}", expression.scalarFn->fnName.getValueOr(""));
           if (expression.scalarFn->attrName) {
             if (llvm::none_of(args, [&](LinalgOperandDef &arg) {
-                  return arg.kind == LinalgOperandDefKind::TypeFnAttr &&
+                  return isFunctionAttribute(arg.kind) &&
                          arg.name == expression.scalarFn->attrName.getValue();
                 })) {
               emitError(genContext.getLoc())
-                  << "missing type function attribute "
+                  << "missing function attribute "
                   << expression.scalarFn->attrName.getValue();
             }
-            typeFunc = llvm::formatv("{0}Val", *expression.scalarFn->attrName);
+            funcType = llvm::formatv("{0}Val", *expression.scalarFn->attrName);
+          }
+          assert(!funcType.empty());
+
+          // Add the optional type parameter to the operands.
+          SmallVector<std::string> operandCppValues;
+          if (expression.scalarFn->kind == ScalarFnKind::Type) {
+            assert(expression.scalarFn->typeVar.hasValue());
+            Optional<std::string> typeCppValue =
+                findTypeValue(expression.scalarFn->typeVar.getValue(), args);
+            if (!typeCppValue) {
+              emitError(genContext.getLoc())
+                  << "type variable " << expression.scalarFn->typeVar.getValue()
+                  << ", used in a type conversion, must map to a predefined or "
+                  << "an argument type but it does not";
+              return None;
+            }
+            operandCppValues.push_back(typeCppValue.getValue());
           }
+
+          // Collect the scalar operands.
+          for (ScalarExpression &operand : expression.scalarFn->operands) {
+            auto operandCppValue = generateExpression(operand);
+            if (!operandCppValue)
+              return None;
+            operandCppValues.push_back(*operandCppValue);
+          }
+
+          // Call the function builder.
           std::string cppIdent = llvm::formatv("value{0}", ++localCounter);
           stmts.push_back(llvm::formatv(
-              "Value {0} = helper.buildTypeFn({1}, {2}, {3});", cppIdent,
-              typeFunc, typeCppValue.getValue(), *operandCppValue));
+              "Value {0} = helper.build{1}({2}, {3});", cppIdent, enumName,
+              funcType, interleaveToString(operandCppValues, ", ")));
           return cppIdent;
         }
         emitError(genContext.getLoc()) << "unknown ScalarExpression type";
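
For readers following the hunk above, here is a minimal standalone sketch of the statement string that the unified `expression.scalarFn` branch appears to emit. It does not use LLVM. The names `enumNameFor`, `join`, and `emitCall` are hypothetical stand-ins for `convertFunctionKindToEnumName`, `interleaveToString`, and the `llvm::formatv` call in the diff, and the `ScalarFnKind` enum is redeclared locally for illustration only.

```cpp
#include <string>
#include <vector>

// Local stand-in for the ScalarFnKind enum used in the diff (assumption).
enum class ScalarFnKind { Unary, Binary, Type };

// Mirrors the mapping in convertFunctionKindToEnumName.
std::string enumNameFor(ScalarFnKind kind) {
  switch (kind) {
  case ScalarFnKind::Unary:
    return "UnaryFn";
  case ScalarFnKind::Binary:
    return "BinaryFn";
  case ScalarFnKind::Type:
    return "TypeFn";
  }
  return ""; // Not reached for valid enum values.
}

// Hypothetical replacement for interleaveToString: joins parts with sep.
std::string join(const std::vector<std::string> &parts,
                 const std::string &sep) {
  std::string result;
  for (size_t i = 0; i < parts.size(); ++i) {
    if (i != 0)
      result += sep;
    result += parts[i];
  }
  return result;
}

// Builds the text "Value {0} = helper.build{1}({2}, {3});" where
// {2} is either "<Enum>::<fnName>" or "<attrName>Val", and {3} is the
// comma-separated operand list (with the type value first for type functions).
std::string emitCall(const std::string &cppIdent, ScalarFnKind kind,
                     const std::string &funcType,
                     const std::vector<std::string> &operands) {
  return "Value " + cppIdent + " = helper.build" + enumNameFor(kind) + "(" +
         funcType + ", " + join(operands, ", ") + ");";
}
```

Under these assumptions, `emitCall("value1", ScalarFnKind::Binary, "BinaryFn::mul", {"value2", "value3"})` yields `Value value1 = helper.buildBinaryFn(BinaryFn::mul, value2, value3);`. When the function comes from an attribute, `funcType` would instead be the attribute name with a `Val` suffix, matching the local variable declared by the `attrDef` template.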
