[Mlir-commits] [mlir] 24357fe - [mlir][OpDSL] Add arithmetic function attributes.
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Feb 28 23:46:15 PST 2022
Author: gysit
Date: 2022-03-01T07:45:47Z
New Revision: 24357fec8d706b1ef3b0a34432f7e10a8e6b2151
URL: https://github.com/llvm/llvm-project/commit/24357fec8d706b1ef3b0a34432f7e10a8e6b2151
DIFF: https://github.com/llvm/llvm-project/commit/24357fec8d706b1ef3b0a34432f7e10a8e6b2151.diff
LOG: [mlir][OpDSL] Add arithmetic function attributes.
The revision extends OpDSL with unary and binary function attributes. A function attribute makes the operations used in the body of a structured operation configurable. For example, a pooling operation may take an aggregation function attribute that specifies if the op shall implement a min or a max pooling. The goal of this revision is to define fewer, more flexible operations.
We may thus, for example, define an elementwise op:
```
linalg.elem(lhs, rhs, outs=[out], op=BinaryFn.mul)
```
If the op argument is not set, the default operation is used.
Depends On D120109
Reviewed By: nicolasvasilache, aartbik
Differential Revision: https://reviews.llvm.org/D120110
Added:
Modified:
mlir/docs/Dialects/Linalg/OpDSL.md
mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td
mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py
mlir/python/mlir/dialects/linalg/opdsl/lang/config.py
mlir/python/mlir/dialects/linalg/opdsl/lang/dsl.py
mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir
mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml
mlir/test/python/dialects/linalg/opdsl/arguments.py
mlir/test/python/dialects/linalg/opdsl/assignments.py
mlir/test/python/dialects/linalg/opdsl/emit_pooling.py
mlir/test/python/dialects/linalg/ops.py
mlir/test/python/integration/dialects/linalg/opsrun.py
mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
Removed:
################################################################################
diff --git a/mlir/docs/Dialects/Linalg/OpDSL.md b/mlir/docs/Dialects/Linalg/OpDSL.md
index 116057ef4ad8a..dd068b1f400c1 100644
--- a/mlir/docs/Dialects/Linalg/OpDSL.md
+++ b/mlir/docs/Dialects/Linalg/OpDSL.md
@@ -107,12 +107,12 @@ copy_and_scale(val, in_tensor, outs=[out_tensor])
## Index Attributes
-Attributes are compile-time constant parameters only accessible in index
+Index attributes are compile-time constant parameters only accessible in index
expressions. They can be used to parameterize the access pattern of a structured
operation, for example, by setting its strides. They cannot take part in the
actual computation.
-The following example demonstrates the use of attributes:
+The following example demonstrates the use of index attributes:
```python
@linalg_structured_op
@@ -136,9 +136,9 @@ The `strides` vector elements substitute the symbols `S.SH` and `S.SW` in the
index expressions of the operation instance. If no strides are provided the
`default` vector elements are used instead.
-Attributes are currently limited to integer vectors and only accessible in index
-expressions. An operation may have multiple attributes all of them placed at the
-end of the parameter list after the output tensors.
+Index attributes are currently limited to integer vectors and only accessible in
+index expressions. An operation may have multiple index attributes, all of them
+placed at the end of the parameter list after the output tensors.
## Shape-Only Tensors
@@ -220,6 +220,43 @@ There are also special forms:
* `const(value)` returns a constant value.
* `index(dim)` returns the iteration index in the given dimension `dim`.
+## Function Attributes
+
+Function attributes are compile-time constant function parameters. They can be
+used to parameterize the computation performed by a structured operation, for
+example, to support signed and unsigned computations.
+
+The following example demonstrates the use of function attributes:
+
+```python
+@linalg_structured_op
+def elemwise_binary(
+ lhs=TensorDef(T1),
+ rhs=TensorDef(T2),
+ O=TensorDef(U, output=True),
+ fun=BinaryFnAttrDef(default=BinaryFn.add),
+ cast=TypeFnAttrDef(default=TypeFn.cast)):
+ O[None] = fun(cast(U, lhs[None]), cast(U, rhs[None]))
+```
+
+The `fun` and `cast` function attributes by default are aliases for their
+default values `BinaryFn.add` and `TypeFn.cast`, respectively. When
+instantiating the operation, the function attributes may be set to other
+functions using optional named arguments:
+
+```python
+elemwise_binary(lhs, rhs, outs=[out_tensor],
+ fun=BinaryFn.mul, cast=TypeFn.cast_unsigned)
+```
+
+In the example, the `fun` and `cast` arguments adapt the body of the operation
+to implement multiplication and unsigned casts instead of addition and signed
+casts.
+
+OpDSL supports unary, binary, and type conversion function attributes. An
+operation can take multiple attributes of different kinds placed at the end of
+the parameter list.
+
## Types
All types in assignment expressions are late bound based on actual input and
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td
index 6090c7055ef3e..c60e1b646e533 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td
@@ -58,7 +58,26 @@ def Linalg_Dialect : Dialect {
}];
}
-// Define a TypeFn enum matching the OpDSL TypeFn class.
+// Define the function attribute enums matching the OpDSL functions.
+def UnaryFn : I32EnumAttr<"UnaryFn", "", [
+ I32EnumAttrCase<"exp", 0>,
+ I32EnumAttrCase<"log", 1>
+]> {
+ let genSpecializedAttr = 0;
+ let cppNamespace = "::mlir::linalg";
+}
+def BinaryFn : I32EnumAttr<"BinaryFn", "", [
+ I32EnumAttrCase<"add", 0>,
+ I32EnumAttrCase<"mul", 1>,
+ I32EnumAttrCase<"max", 2>,
+ I32EnumAttrCase<"min", 3>,
+ I32EnumAttrCase<"sub", 4>,
+ I32EnumAttrCase<"max_unsigned", 5>,
+ I32EnumAttrCase<"min_unsigned", 6>
+]> {
+ let genSpecializedAttr = 0;
+ let cppNamespace = "::mlir::linalg";
+}
def TypeFn : I32EnumAttr<"TypeFn", "", [
I32EnumAttrCase<"cast", 0>,
I32EnumAttrCase<"cast_unsigned", 1>
@@ -67,6 +86,12 @@ def TypeFn : I32EnumAttr<"TypeFn", "", [
let cppNamespace = "::mlir::linalg";
}
+def UnaryFnAttr : EnumAttr<Linalg_Dialect, UnaryFn, "unary_fn"> {
+ let assemblyFormat = "`<` $value `>`";
+}
+def BinaryFnAttr : EnumAttr<Linalg_Dialect, BinaryFn, "binary_fn"> {
+ let assemblyFormat = "`<` $value `>`";
+}
def TypeFnAttr : EnumAttr<Linalg_Dialect, TypeFn, "type_fn"> {
let assemblyFormat = "`<` $value `>`";
}
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
index 8789185b961fa..73fd114daf32e 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
@@ -1,6 +1,120 @@
### AUTOGENERATED from core_named_ops.py
### To regenerate, run: bin/update_core_linalg_named_ops.sh
--- !LinalgOpConfig
+metadata: !LinalgOpMetadata
+ name: elemwise_unary
+ cpp_class_name: ElemwiseUnaryOp
+ doc: |-
+ Applies the unary function fun elementwise.
+
+ Numeric casting is performed on the input operand, promoting it to the same
+ data type as the accumulator/output.
+structured_op: !LinalgStructuredOpConfig
+ args:
+ - !LinalgOperandDefConfig
+ name: I
+ kind: input_tensor
+ type_var: T1
+ shape_map: affine_map<() -> ()>
+ - !LinalgOperandDefConfig
+ name: O
+ kind: output_tensor
+ type_var: U
+ shape_map: affine_map<() -> ()>
+ - !LinalgOperandDefConfig
+ name: fun
+ kind: unary_fn_attr
+ default_fn: exp
+ - !LinalgOperandDefConfig
+ name: cast
+ kind: type_fn_attr
+ default_fn: cast
+ indexing_maps: !LinalgIndexingMapsConfig
+ static_indexing_maps:
+ - affine_map<() -> ()>
+ - affine_map<() -> ()>
+ iterator_types: []
+ assignments:
+ - !ScalarAssign
+ arg: O
+ value: !ScalarExpression
+ scalar_fn:
+ kind: unary
+ attr_name: fun
+ operands:
+ - !ScalarExpression
+ scalar_fn:
+ kind: type
+ attr_name: cast
+ type_var: U
+ operands:
+ - !ScalarExpression
+ scalar_arg: I
+--- !LinalgOpConfig
+metadata: !LinalgOpMetadata
+ name: elemwise_binary
+ cpp_class_name: ElemwiseBinaryOp
+ doc: |-
+ Applies the binary function fun elementwise.
+
+ Numeric casting is performed on the input operand, promoting it to the same
+ data type as the accumulator/output.
+structured_op: !LinalgStructuredOpConfig
+ args:
+ - !LinalgOperandDefConfig
+ name: lhs
+ kind: input_tensor
+ type_var: T1
+ shape_map: affine_map<() -> ()>
+ - !LinalgOperandDefConfig
+ name: rhs
+ kind: input_tensor
+ type_var: T2
+ shape_map: affine_map<() -> ()>
+ - !LinalgOperandDefConfig
+ name: O
+ kind: output_tensor
+ type_var: U
+ shape_map: affine_map<() -> ()>
+ - !LinalgOperandDefConfig
+ name: fun
+ kind: binary_fn_attr
+ default_fn: add
+ - !LinalgOperandDefConfig
+ name: cast
+ kind: type_fn_attr
+ default_fn: cast
+ indexing_maps: !LinalgIndexingMapsConfig
+ static_indexing_maps:
+ - affine_map<() -> ()>
+ - affine_map<() -> ()>
+ - affine_map<() -> ()>
+ iterator_types: []
+ assignments:
+ - !ScalarAssign
+ arg: O
+ value: !ScalarExpression
+ scalar_fn:
+ kind: binary
+ attr_name: fun
+ operands:
+ - !ScalarExpression
+ scalar_fn:
+ kind: type
+ attr_name: cast
+ type_var: U
+ operands:
+ - !ScalarExpression
+ scalar_arg: lhs
+ - !ScalarExpression
+ scalar_fn:
+ kind: type
+ attr_name: cast
+ type_var: U
+ operands:
+ - !ScalarExpression
+ scalar_arg: rhs
+--- !LinalgOpConfig
metadata: !LinalgOpMetadata
name: matmul
cpp_class_name: MatmulOp
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 24f492aed936a..dfc994368e6cf 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -108,17 +108,9 @@ static LogicalResult foldMemRefCast(Operation *op) {
//===----------------------------------------------------------------------===//
// Region builder helper.
// TODO: Move this to a utility library.
-// The public methods on this class are referenced directly from generated code
-// and bind by name to math functions in the DSL as:
-// `unary__{fnName}`
-// `binary__{fnName}`
-// Examples:
-// `binary__add`
-// `binary__mul`
-// `unary__exp`
-// `unary__log`
-// The naming convention is intentional in order to match snake-cased DSL names.
-// See mlir-linalg-ods-yaml-gen.cpp for the code that mates to this class.
+// The public methods on this class are referenced directly from generated code.
+// They build the unary, binary, and type conversion functions defined by the
+// DSL. See mlir-linalg-ods-yaml-gen.cpp for the code that uses this class.
//
// Implementations of the math functions must be polymorphic over numeric types,
// internally performing necessary casts. If the function application makes no
@@ -142,6 +134,98 @@ class RegionBuilderHelper {
RegionBuilderHelper(MLIRContext *context, Block &block)
: context(context), block(block) {}
+ // Build the unary functions defined by OpDSL.
+ Value buildUnaryFn(UnaryFn unaryFn, Value arg) {
+ if (!isFloatingPoint(arg))
+ llvm_unreachable("unsupported non numeric type");
+ OpBuilder builder = getBuilder();
+ switch (unaryFn) {
+ case UnaryFn::exp:
+ return builder.create<math::ExpOp>(arg.getLoc(), arg);
+ case UnaryFn::log:
+ return builder.create<math::LogOp>(arg.getLoc(), arg);
+ }
+ llvm_unreachable("unsupported unary function");
+ }
+
+ // Build the binary functions defined by OpDSL.
+ Value buildBinaryFn(BinaryFn binaryFn, Value arg0, Value arg1) {
+ bool allFloatingPoint = isFloatingPoint(arg0) && isFloatingPoint(arg1);
+ bool allInteger = isInteger(arg0) && isInteger(arg1);
+ if (!allFloatingPoint && !allInteger)
+ llvm_unreachable("unsupported non numeric type");
+ OpBuilder builder = getBuilder();
+ switch (binaryFn) {
+ case BinaryFn::add:
+ if (allFloatingPoint)
+ return builder.create<arith::AddFOp>(arg0.getLoc(), arg0, arg1);
+ return builder.create<arith::AddIOp>(arg0.getLoc(), arg0, arg1);
+ case BinaryFn::mul:
+ if (allFloatingPoint)
+ return builder.create<arith::MulFOp>(arg0.getLoc(), arg0, arg1);
+ return builder.create<arith::MulIOp>(arg0.getLoc(), arg0, arg1);
+ case BinaryFn::max:
+ if (allFloatingPoint)
+ return builder.create<arith::MaxFOp>(arg0.getLoc(), arg0, arg1);
+ return builder.create<arith::MaxSIOp>(arg0.getLoc(), arg0, arg1);
+ case BinaryFn::min:
+ if (allFloatingPoint)
+ return builder.create<arith::MinFOp>(arg0.getLoc(), arg0, arg1);
+ return builder.create<arith::MinSIOp>(arg0.getLoc(), arg0, arg1);
+ case BinaryFn::sub:
+ if (allFloatingPoint)
+ return builder.create<arith::SubFOp>(arg0.getLoc(), arg0, arg1);
+ return builder.create<arith::SubIOp>(arg0.getLoc(), arg0, arg1);
+ case BinaryFn::max_unsigned:
+ if (allFloatingPoint)
+ return builder.create<arith::MaxFOp>(arg0.getLoc(), arg0, arg1);
+ return builder.create<arith::MaxUIOp>(arg0.getLoc(), arg0, arg1);
+ case BinaryFn::min_unsigned:
+ if (allFloatingPoint)
+ return builder.create<arith::MinFOp>(arg0.getLoc(), arg0, arg1);
+ return builder.create<arith::MinUIOp>(arg0.getLoc(), arg0, arg1);
+ }
+ llvm_unreachable("unsupported binary function");
+ }
+
+ // Build the type functions defined by OpDSL.
+ Value buildTypeFn(TypeFn typeFn, Type toType, Value operand) {
+ switch (typeFn) {
+ case TypeFn::cast:
+ return cast(toType, operand, false);
+ case TypeFn::cast_unsigned:
+ return cast(toType, operand, true);
+ }
+ llvm_unreachable("unsupported type conversion function");
+ }
+
+ void yieldOutputs(ValueRange values) {
+ OpBuilder builder = getBuilder();
+ Location loc = builder.getUnknownLoc();
+ builder.create<YieldOp>(loc, values);
+ }
+
+ Value constant(const std::string &value) {
+ OpBuilder builder = getBuilder();
+ Location loc = builder.getUnknownLoc();
+ Attribute valueAttr = parseAttribute(value, builder.getContext());
+ return builder.create<arith::ConstantOp>(loc, valueAttr.getType(),
+ valueAttr);
+ }
+
+ Value index(int64_t dim) {
+ OpBuilder builder = getBuilder();
+ return builder.create<IndexOp>(builder.getUnknownLoc(), dim);
+ }
+
+ Type getIntegerType(unsigned width) {
+ return IntegerType::get(context, width);
+ }
+
+ Type getFloat32Type() { return Float32Type::get(context); }
+ Type getFloat64Type() { return Float64Type::get(context); }
+
+private:
// Generates operations to cast the given operand to a specified type.
// If the cast cannot be performed, a warning will be issued and the
// operand returned as-is (which will presumably yield a verification
@@ -193,136 +277,6 @@ class RegionBuilderHelper {
return operand;
}
- Value buildTypeFn(TypeFn typeFn, Type toType, Value operand) {
- switch (typeFn) {
- case TypeFn::cast:
- return cast(toType, operand, false);
- case TypeFn::cast_unsigned:
- return cast(toType, operand, true);
- }
- llvm_unreachable("unsupported type conversion function");
- }
-
- // NOLINTNEXTLINE(*-identifier-naming): externally called.
- Value binary__add(Value lhs, Value rhs) {
- OpBuilder builder = getBuilder();
- if (isFloatingPoint(lhs))
- return builder.create<arith::AddFOp>(lhs.getLoc(), lhs, rhs);
- if (isInteger(lhs))
- return builder.create<arith::AddIOp>(lhs.getLoc(), lhs, rhs);
- llvm_unreachable("unsupported non numeric type");
- }
-
- // NOLINTNEXTLINE(*-identifier-naming): externally called.
- Value unary__exp(Value x) {
- OpBuilder builder = getBuilder();
- if (isFloatingPoint(x))
- return builder.create<math::ExpOp>(x.getLoc(), x);
- llvm_unreachable("unsupported non numeric type");
- }
-
- // NOLINTNEXTLINE(*-identifier-naming): externally called.
- Value unary__log(Value x) {
- OpBuilder builder = getBuilder();
- if (isFloatingPoint(x))
- return builder.create<math::LogOp>(x.getLoc(), x);
- llvm_unreachable("unsupported non numeric type");
- }
-
- // NOLINTNEXTLINE(*-identifier-naming): externally called.
- Value binary__sub(Value lhs, Value rhs) {
- OpBuilder builder = getBuilder();
- if (isFloatingPoint(lhs))
- return builder.create<arith::SubFOp>(lhs.getLoc(), lhs, rhs);
- if (isInteger(lhs))
- return builder.create<arith::SubIOp>(lhs.getLoc(), lhs, rhs);
- llvm_unreachable("unsupported non numeric type");
- }
-
- // NOLINTNEXTLINE(*-identifier-naming): externally called.
- Value binary__mul(Value lhs, Value rhs) {
- OpBuilder builder = getBuilder();
- if (isFloatingPoint(lhs))
- return builder.create<arith::MulFOp>(lhs.getLoc(), lhs, rhs);
- if (isInteger(lhs))
- return builder.create<arith::MulIOp>(lhs.getLoc(), lhs, rhs);
- llvm_unreachable("unsupported non numeric type");
- }
-
- // NOLINTNEXTLINE(*-identifier-naming): externally called.
- Value binary__max(Value lhs, Value rhs) {
- OpBuilder builder = getBuilder();
- if (isFloatingPoint(lhs))
- return builder.create<arith::MaxFOp>(lhs.getLoc(), lhs, rhs);
- if (isInteger(lhs))
- return builder.create<arith::MaxSIOp>(lhs.getLoc(), lhs, rhs);
- llvm_unreachable("unsupported non numeric type");
- }
-
- // NOLINTNEXTLINE(*-identifier-naming): externally called.
- Value binary__max_unsigned(Value lhs, Value rhs) {
- OpBuilder builder = getBuilder();
- if (isFloatingPoint(lhs))
- return builder.create<arith::MaxFOp>(lhs.getLoc(), lhs, rhs);
- if (isInteger(lhs))
- return builder.create<arith::MaxUIOp>(lhs.getLoc(), lhs, rhs);
- llvm_unreachable("unsupported non numeric type");
- }
-
- // NOLINTNEXTLINE(*-identifier-naming): externally called.
- Value binary__min(Value lhs, Value rhs) {
- OpBuilder builder = getBuilder();
- if (isFloatingPoint(lhs))
- return builder.create<arith::MinFOp>(lhs.getLoc(), lhs, rhs);
- if (isInteger(lhs))
- return builder.create<arith::MinSIOp>(lhs.getLoc(), lhs, rhs);
- llvm_unreachable("unsupported non numeric type");
- }
-
- // NOLINTNEXTLINE(*-identifier-naming): externally called.
- Value binary__min_unsigned(Value lhs, Value rhs) {
- OpBuilder builder = getBuilder();
- if (isFloatingPoint(lhs))
- return builder.create<arith::MinFOp>(lhs.getLoc(), lhs, rhs);
- if (isInteger(lhs))
- return builder.create<arith::MinUIOp>(lhs.getLoc(), lhs, rhs);
- llvm_unreachable("unsupported non numeric type");
- }
-
- void yieldOutputs(ValueRange values) {
- assert(!values.empty() && "linalg ops must yield outputs");
- if (values.empty())
- return;
- Value first = values.front();
- OpBuilder builder = getBuilder();
- builder.create<YieldOp>(first.getLoc(), values);
- }
-
- Value constant(const std::string &value) {
- OpBuilder builder = getBuilder();
- Location loc = builder.getUnknownLoc();
- Attribute valueAttr = parseAttribute(value, builder.getContext());
- return builder.create<arith::ConstantOp>(loc, valueAttr.getType(),
- valueAttr);
- }
-
- Value index(int64_t dim) {
- OpBuilder builder = getBuilder();
- return builder.create<IndexOp>(builder.getUnknownLoc(), dim);
- }
-
- Type getIntegerType(unsigned width) {
- return IntegerType::get(context, width);
- }
-
- Type getFloat32Type() { return Float32Type::get(context); }
-
- Type getFloat64Type() { return Float64Type::get(context); }
-
-private:
- MLIRContext *context;
- Block █
-
bool isFloatingPoint(Value value) { return value.getType().isa<FloatType>(); }
bool isInteger(Value value) { return value.getType().isa<IntegerType>(); }
@@ -331,6 +285,9 @@ class RegionBuilderHelper {
builder.setInsertionPointToEnd(&block);
return builder;
}
+
+ MLIRContext *context;
+ Block █
};
} // namespace
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py
index ef2ef30378211..f6bf0ff9a50d0 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py
@@ -126,7 +126,7 @@ def _compute_reduce_dims(self, rhs: TensorExpression) -> Set[DimDef]:
return rhs_dims - lhs_dims
def __iadd__(self, rhs: TensorExpression) -> "TensorReduceFn":
- return ReduceFnUse(BinaryFn.add, *self._compute_reduce_dims(rhs))(rhs)
+ return ReduceFnUse(BinaryFn.add, None, *self._compute_reduce_dims(rhs))(rhs)
def __repr__(self):
return (f"{self.operand_def.name}"
@@ -183,8 +183,14 @@ def to_scalar_expression(self) -> ScalarExpression:
f"bound to its lhs: {self}")
full_args = [self.lhs.to_scalar_expression()
] + [arg.to_scalar_expression() for arg in self.args]
- return ScalarFn(FunctionKind.BINARY, self.reduce_use.binary_fn.fn_name,
- None, None, full_args).expr()
+ fn_name = None
+ attr_name = None
+ if self.reduce_use.binary_fn:
+ fn_name = self.reduce_use.binary_fn.fn_name
+ if self.reduce_use.binary_attr:
+ attr_name = self.reduce_use.binary_attr.operand_def.name
+ return ScalarFn(FunctionKind.BINARY, fn_name, attr_name, None,
+ full_args).expr()
def visit_tensor_exprs(self, callback: Callable[["TensorExpression"], None]):
for arg in self.args:
@@ -257,8 +263,8 @@ class UnaryFnType:
def __init__(self, fn_name: str):
self.fn_name = fn_name
- def __call__(self, exp: TensorExpression) -> "TensorFn":
- return TensorFn(FunctionKind.UNARY, self.fn_name, None, None, [exp])
+ def __call__(self, arg: TensorExpression) -> "TensorFn":
+ return TensorFn(FunctionKind.UNARY, self.fn_name, None, None, [arg])
def __repr__(self):
return f"{self.fn_name}"
@@ -345,16 +351,21 @@ class ReduceFnUse:
A reduction use specifies the reduction function and dimensions.
"""
- def __init__(self, binary_fn: BinaryFnType, *reduce_dims: DimDef):
+ def __init__(self, binary_fn: Optional[BinaryFnType],
+ binary_attr: Optional["BinaryFnAttrDef"], *reduce_dims: DimDef):
+ if bool(binary_fn) + bool(binary_attr) != 1:
+ raise ValueError("One of 'binary_fn', 'binary_attr' must be specified")
self.binary_fn = binary_fn
+ self.binary_attr = binary_attr
self.reduce_dims = reduce_dims
def __call__(self, *args: TensorExpression) -> "TensorReduceFn":
return TensorReduceFn(self, args)
def __repr__(self):
- return (f"reduce_{self.binary_fn.fn_name}"
- f"({', '.join(repr(d) for d in self.reduce_dims)})")
+ fn = self.binary_fn if self.binary_fn else self.binary_attr
+ return (
+ f"reduce_{repr(fn)}({', '.join(repr(d) for d in self.reduce_dims)})")
class ReduceFnType:
@@ -369,10 +380,10 @@ def __init__(self, binary_fn: BinaryFnType):
self.binary_fn = binary_fn
def __getitem__(self, reduce_dims: Tuple[DimDef]) -> ReduceFnUse:
- return ReduceFnUse(self.binary_fn, *reduce_dims)
+ return ReduceFnUse(self.binary_fn, None, *reduce_dims)
def __repr__(self):
- return (f"reduce_{self.binary_fn.fn_name}")
+ return f"reduce_{repr(self.binary_fn)}"
class ReduceFn:
@@ -394,7 +405,9 @@ class OperandKind(Enum):
SCALAR = 1
OUTPUT_TENSOR = 2
INDEX_ATTR = 3
- TYPE_FN_ATTR = 4
+ UNARY_FN_ATTR = 4
+ BINARY_FN_ATTR = 5
+ TYPE_FN_ATTR = 6
class OperandDef:
@@ -441,6 +454,8 @@ def is_tensor(self) -> bool:
def is_attribute(self) -> bool:
return (self.kind == OperandKind.INDEX_ATTR or
+ self.kind == OperandKind.UNARY_FN_ATTR or
+ self.kind == OperandKind.BINARY_FN_ATTR or
self.kind == OperandKind.TYPE_FN_ATTR)
def __hash__(self):
@@ -557,6 +572,49 @@ def __init__(self, *sizes: SymbolDef, default: Sequence[int]):
OperandKind.INDEX_ATTR, size_exprs=sizes, default_indices=default)
+class UnaryFnAttrDef:
+ """Unary function attribute definition.
+
+ Unary function attributes provide a way to make the arithmetic computation
+ parametrizable. Every attribute specifies a default unary function
+ that may be overwritten at operation instantiation time.
+ """
+
+ def __init__(self, default: "UnaryFnType"):
+ if not isinstance(default, UnaryFnType):
+ raise ValueError(f"UnaryFnAttrDef requires default of type UnaryFnType "
+ f"but got {default}")
+ self.operand_def = OperandDef(
+ OperandKind.UNARY_FN_ATTR, default_fn=default.fn_name)
+
+ def __call__(self, arg: TensorExpression) -> TensorFn:
+ return TensorFn(FunctionKind.UNARY, None, self.operand_def, None, [arg])
+
+
+class BinaryFnAttrDef:
+ """Binary function attribute definition.
+
+ Binary function attributes provide a way to make the arithmetic computation
+ parametrizable. Every attribute specifies a default binary function
+ that may be overwritten at operation instantiation time.
+ """
+
+ def __init__(self, default: "BinaryFnType"):
+ if not isinstance(default, BinaryFnType):
+ raise ValueError(f"BinaryFnAttrDef requires default of type BinaryFnType "
+ f"but got {default}")
+ self.operand_def = OperandDef(
+ OperandKind.BINARY_FN_ATTR, default_fn=default.fn_name)
+
+ def __call__(self, arg0: TensorExpression,
+ arg1: TensorExpression) -> TensorFn:
+ return TensorFn(FunctionKind.BINARY, None, self.operand_def, None,
+ [arg0, arg1])
+
+ def __getitem__(self, reduce_dims: Tuple[DimDef]) -> ReduceFnUse:
+ return ReduceFnUse(None, self, *reduce_dims)
+
+
class TypeFnAttrDef:
"""Type conversion function attribute definition.
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/config.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/config.py
index 12b168de184aa..ed30b8e5fc9a0 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/lang/config.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/config.py
@@ -309,8 +309,8 @@ def get_type(symbolic_name, position):
def add_operand(self, operand_def: OperandDef):
if operand_def in self.operands:
return
- if (operand_def.kind == OperandKind.SCALAR or
- operand_def.kind == OperandKind.TYPE_FN_ATTR):
+ if not (operand_def.is_tensor() or
+ operand_def.kind == OperandKind.INDEX_ATTR):
self.operands[operand_def] = OperandDefConfig(operand_def)
return
with self.context:
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/dsl.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/dsl.py
index 99ce713668d1d..bd9042ac0aacb 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/lang/dsl.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/dsl.py
@@ -130,7 +130,8 @@ def linalg_structured_op(dsl_func=None,
for param_name, param in sig.parameters.items():
param_default = param.default
if isinstance(param_default,
- (TensorDef, ScalarDef, IndexAttrDef, TypeFnAttrDef)):
+ (TensorDef, ScalarDef, IndexAttrDef, UnaryFnAttrDef,
+ BinaryFnAttrDef, TypeFnAttrDef)):
op_def.add_operand(param_name, param_default.operand_def)
else:
raise ValueError(
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
index df4ab2249d4d4..79fc3f5a21904 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
@@ -41,7 +41,7 @@ def prepare_common_structured_op(op_config: LinalgStructuredOpConfig,
all_arg_defs = op_config.ordered_operands
in_arg_defs = [
d for d in all_arg_defs
- if d.kind == OperandKind.SCALAR or d.kind == OperandKind.INPUT_TENSOR
+ if d.kind in [OperandKind.SCALAR, OperandKind.INPUT_TENSOR]
]
out_arg_defs = [
d for d in all_arg_defs if d.kind == OperandKind.OUTPUT_TENSOR
@@ -49,8 +49,11 @@ def prepare_common_structured_op(op_config: LinalgStructuredOpConfig,
index_attr_arg_defs = [
d for d in all_arg_defs if d.kind == OperandKind.INDEX_ATTR
]
- type_fn_attr_arg_defs = [
- d for d in all_arg_defs if d.kind == OperandKind.TYPE_FN_ATTR
+ fn_attr_arg_defs = [
+ d for d in all_arg_defs if d.kind in [
+ OperandKind.UNARY_FN_ATTR, OperandKind.BINARY_FN_ATTR,
+ OperandKind.TYPE_FN_ATTR
+ ]
]
# Verify outs is a sequence or a list of results.
@@ -135,28 +138,38 @@ def prepare_common_structured_op(op_config: LinalgStructuredOpConfig,
array = np.array(index_attr_vals, dtype=np.int64)
index_attrs[index_attr.name] = DenseElementsAttr.get(array)
- # Compute the type function attribute mapping.
- type_fn_attr_mapping = {}
- for type_fn_attr in type_fn_attr_arg_defs:
- attr_val = type_fn_attr.operand_def.default_fn
- if type_fn_attr.name in attrs:
- type_fn = attrs.get(type_fn_attr.name)
- if not isinstance(type_fn, TypeFnType):
- raise ValueError(f"Attribute {type_fn_attr.name} needs to be of type "
- f"TypeFnType but got {type(attr_val)}")
- attr_val = type_fn.fn_name
- assert attr_val, "Type function attribute has no value"
- type_fn_attr_mapping[type_fn_attr.name] = attr_val
+ # Compute the function attribute mapping.
+ fn_attr_mapping = {}
+ for fn_attr in fn_attr_arg_defs:
+ attr_val = fn_attr.operand_def.default_fn
+ attr_kind = fn_attr.kind
+ if fn_attr.name in attrs:
+ fn = attrs.get(fn_attr.name)
+ if attr_kind == OperandKind.UNARY_FN_ATTR:
+ if not isinstance(fn, UnaryFnType):
+ raise ValueError(f"Attribute {fn_attr.name} needs to be of type "
+ f"UnaryFnType but got {type(fn)}")
+ elif attr_kind == OperandKind.BINARY_FN_ATTR:
+ if not isinstance(fn, BinaryFnType):
+ raise ValueError(f"Attribute {fn_attr.name} needs to be of type "
+ f"BinaryFnType but got {type(fn)}")
+ else:
+ if not isinstance(fn, TypeFnType):
+ raise ValueError(f"Attribute {fn_attr.name} needs to be of type "
+ f"TypeFnType but got {type(fn)}")
+ attr_val = fn.fn_name
+ assert attr_val, "Function attribute has no value"
+ fn_attr_mapping[fn_attr.name] = (attr_val, attr_kind)
return (all_arg_defs, in_arg_defs, out_arg_defs, outs, result_types,
type_mapping, indexing_maps_attr, iterator_types_attr, index_attrs,
- type_fn_attr_mapping, block_arg_types)
+ fn_attr_mapping, block_arg_types)
def emit_generic_structured_op(op_config: LinalgStructuredOpConfig, *ins: Value,
outs: ValueList, **attrs: Sequence[int]):
all_arg_defs, in_arg_defs, out_arg_defs, outs, result_types, type_mapping, \
- indexing_maps_attr, iterator_types_attr, index_attrs, type_fn_attr_mapping, \
+ indexing_maps_attr, iterator_types_attr, index_attrs, fn_attr_mapping, \
block_arg_types = \
prepare_common_structured_op(op_config, *ins, outs = outs, **attrs)
@@ -193,7 +206,7 @@ def emit_generic_structured_op(op_config: LinalgStructuredOpConfig, *ins: Value,
block_arg_mapping = dict(zip(block_arg_names, block.arguments))
with InsertionPoint(block):
body_builder = _BodyBuilder(type_mapping, block_arg_mapping,
- type_fn_attr_mapping)
+ fn_attr_mapping)
for assignment in op_config.assignments:
body_builder.assign(assignment)
body_builder.yield_outputs(*_get_operand_def_names(*out_arg_defs))
@@ -208,7 +221,7 @@ def emit_named_structured_op(op_config: LinalgStructuredOpConfig, op_name: str,
op_class_name: str, *ins: Value, outs: ValueList,
**attrs: Sequence[int]):
all_arg_defs, in_arg_defs, out_arg_defs, outs, result_types, type_mapping, \
- indexing_maps_attr, iterator_types_attr, index_attrs, type_fn_attr_mapping, \
+ indexing_maps_attr, iterator_types_attr, index_attrs, fn_attr_mapping, \
block_arg_types = \
prepare_common_structured_op(op_config, *ins, outs = outs, **attrs)
@@ -225,10 +238,12 @@ def emit_named_structured_op(op_config: LinalgStructuredOpConfig, op_name: str,
for name, value in index_attrs.items():
named_op.operation.attributes[name] = value
- # Set the type function attributes.
- for name, value in type_fn_attr_mapping.items():
+ # Compute the function attributes by combining operand kind and function name.
+ for name, (fn_name, kind) in fn_attr_mapping.items():
+ assert kind.name.lower().endswith("_attr")
+ enum_name = kind.name.lower()[:-5]
named_op.operation.attributes[name] = Attribute.parse(
- f"#linalg.type_fn<{value}>")
+ f"#linalg.{enum_name}<{fn_name}>")
linalg.fill_builtin_region(named_op.operation)
@@ -242,11 +257,11 @@ class _BodyBuilder:
"""Constructs a structured op body by evaluating assignments."""
def __init__(self, type_mapping: Dict[str, Type],
- block_arg_mapping: Dict[str, Value],
- type_fn_attr_mapping: Dict[str, str]):
+ block_arg_mapping: Dict[str, Value], fn_attr_mapping: Dict[str,
+ str]):
self.type_mapping = type_mapping
self.block_arg_mapping = block_arg_mapping
- self.type_fn_attr_mapping = type_fn_attr_mapping
+ self.fn_attr_mapping = fn_attr_mapping
self.yield_mapping = dict() # type: Dict[str, Value]
def assign(self, assignment: ScalarAssign):
@@ -270,21 +285,18 @@ def expression(self, expr: ScalarExpression) -> Value:
dim_attr = IntegerAttr.get(
IntegerType.get_signless(64), expr.scalar_index.dim)
return linalg.IndexOp(dim_attr).result
- elif expr.scalar_fn and expr.scalar_fn.kind is not FunctionKind.TYPE:
+ elif expr.scalar_fn:
kind = expr.scalar_fn.kind.name.lower()
- fn = self._get_function(f"_{kind}_{expr.scalar_fn.fn_name}")
+ fn_name = expr.scalar_fn.fn_name
+ if expr.scalar_fn.attr_name:
+ fn_name, _ = self.fn_attr_mapping[expr.scalar_fn.attr_name]
+ fn = self._get_function(f"_{kind}_{fn_name}")
operand_values = [
self.expression(operand) for operand in expr.scalar_fn.operands
]
+ if expr.scalar_fn.kind == FunctionKind.TYPE:
+ operand_values = [expr.scalar_fn.type_var.name] + operand_values
return fn(*operand_values)
- elif expr.scalar_fn and expr.scalar_fn.kind is FunctionKind.TYPE:
- kind = expr.scalar_fn.kind.name.lower()
- fn_name = expr.scalar_fn.fn_name
- if expr.scalar_fn.attr_name:
- fn_name = self.type_fn_attr_mapping[expr.scalar_fn.attr_name]
- fn = self._get_function(f"_{kind}_{fn_name}")
- operand_value = self.expression(expr.scalar_fn.operands[0])
- return fn(expr.scalar_fn.type_var.name, operand_value)
raise NotImplementedError(f"Unimplemented scalar body expression: {expr}")
def yield_outputs(self, *output_names: str):
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
index 340f4db4471bb..b7a827bf649b2 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
@@ -6,6 +6,35 @@
Batch = S.Batch
+ at linalg_structured_op
+def elemwise_unary(
+ I=TensorDef(T1),
+ O=TensorDef(U, output=True),
+ fun=UnaryFnAttrDef(default=UnaryFn.exp),
+ cast=TypeFnAttrDef(default=TypeFn.cast)):
+ """Applies the unary function fun elementwise.
+
+ Numeric casting is performed on the input operand, promoting it to the same
+ data type as the accumulator/output.
+ """
+ O[None] = fun(cast(U, I[None]))
+
+
+ at linalg_structured_op
+def elemwise_binary(
+ lhs=TensorDef(T1),
+ rhs=TensorDef(T2),
+ O=TensorDef(U, output=True),
+ fun=BinaryFnAttrDef(default=BinaryFn.add),
+ cast=TypeFnAttrDef(default=TypeFn.cast)):
+ """Applies the binary function fun elementwise.
+
+ Numeric casting is performed on the input operand, promoting it to the same
+ data type as the accumulator/output.
+ """
+ O[None] = fun(cast(U, lhs[None]), cast(U, rhs[None]))
+
+
@linalg_structured_op
def matmul(
A=TensorDef(T1, S.M, S.K),
diff --git a/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir b/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir
index 3778315ff5a61..21f00268a9c12 100644
--- a/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir
+++ b/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir
@@ -292,16 +292,48 @@ func @generalize_fill_rng_2d_i32(%min: f64, %max: f64, %seed: i32, %O: tensor<16
// -----
-func @generalize_soft_plus_2d_f32(%input: tensor<16x32xf32>, %output: tensor<16x32xf32>) -> tensor<16x32xf32> {
- %0 = linalg.soft_plus_2d ins(%input: tensor<16x32xf32>) outs(%output: tensor<16x32xf32>) -> tensor<16x32xf32>
- return %0: tensor<16x32xf32>
+// Verifies the default value of the fun attribute is an exp op.
+func @generalize_elemwise_exp(%lhs : tensor<4x8xf32>, %output : tensor<4x8xf32>) -> tensor<4x8xf32> {
+ %0 = linalg.elemwise_unary ins(%lhs: tensor<4x8xf32>) outs(%output: tensor<4x8xf32>) -> tensor<4x8xf32>
+ return %0: tensor<4x8xf32>
}
-// CHECK-LABEL: @generalize_soft_plus_2d_f32
-// CHECK: %[[C1:.+]] = arith.constant 1.000000e+00 : f32
-// CHECK: ^{{.*}}(%[[IN:.+]]: f32, %[[OUT:.+]]: f32
-// CHECK-NEXT: %[[EXP:.+]] = math.exp %[[IN]] : f32
-// CHECK-NEXT: %[[SUM:.+]] = arith.addf %[[EXP]], %[[C1]] : f32
-// CHECK-NEXT: %[[LOG:.+]] = math.log %[[SUM]] : f32
-// CHECK-NEXT: linalg.yield %[[LOG]] : f32
-// CHECK-NEXT: -> tensor<16x32xf32>
+// CHECK-LABEL: @generalize_elemwise_exp
+// CHECK: = math.exp
+
+// -----
+
+// Verifies the fun attribute controls the unary function used.
+func @generalize_elemwise_log(%lhs : tensor<4x8xf32>, %output : tensor<4x8xf32>) -> tensor<4x8xf32> {
+ %0 = linalg.elemwise_unary {fun = #linalg.unary_fn<log>}
+ ins(%lhs: tensor<4x8xf32>) outs(%output: tensor<4x8xf32>) -> tensor<4x8xf32>
+ return %0: tensor<4x8xf32>
+}
+
+// CHECK-LABEL: @generalize_elemwise_log
+// CHECK: = math.log
+
+// -----
+
+// Verifies the default value of the fun attribute is an add op.
+func @generalize_elemwise_add(%lhs : tensor<4x8xf32>, %rhs : tensor<4x8xf32>, %output : tensor<4x8xf32>) -> tensor<4x8xf32> {
+ %0 = linalg.elemwise_binary ins(%lhs, %rhs: tensor<4x8xf32>, tensor<4x8xf32>)
+ outs(%output: tensor<4x8xf32>) -> tensor<4x8xf32>
+ return %0: tensor<4x8xf32>
+}
+
+// CHECK-LABEL: @generalize_elemwise_add
+// CHECK: = arith.addf
+
+// -----
+
+// Verifies the fun attribute controls the binary function used.
+func @generalize_elemwise_mul(%lhs : tensor<4x8xf32>, %rhs : tensor<4x8xf32>, %output : tensor<4x8xf32>) -> tensor<4x8xf32> {
+ %0 = linalg.elemwise_binary {fun = #linalg.binary_fn<mul>}
+ ins(%lhs, %rhs: tensor<4x8xf32>, tensor<4x8xf32>)
+ outs(%output: tensor<4x8xf32>) -> tensor<4x8xf32>
+ return %0: tensor<4x8xf32>
+}
+
+// CHECK-LABEL: @generalize_elemwise_mul
+// CHECK: = arith.mulf
diff --git a/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml b/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml
index dc1e1809eb46b..2defebbba781f 100644
--- a/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml
+++ b/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml
@@ -111,7 +111,7 @@ structured_op: !LinalgStructuredOpConfig
# IMPL-DAG: Value [[VAL1:[a-z0-9]+]] = helper.buildTypeFn(castVal, block.getArgument(0).getType(), [[VAL0]]);
# IMPL-DAG: Value [[VAL2:[a-z0-9]+]] = helper.index(1);
# IMPL-DAG: Value [[VAL3:[a-z0-9]+]] = helper.buildTypeFn(castVal, block.getArgument(0).getType(), [[VAL2]]);
-# IMPL-DAG: Value [[VAL4:[a-z0-9]+]] = helper.binary__add([[VAL1]], [[VAL3]]);
+# IMPL-DAG: Value [[VAL4:[a-z0-9]+]] = helper.buildBinaryFn(BinaryFn::add, [[VAL1]], [[VAL3]]);
# @linalg_structured_op
@@ -255,14 +255,15 @@ structured_op: !LinalgStructuredOpConfig
# IMPL-NEXT: AffineMap scalarMap = AffineMap::get(getNumParallelLoops(), 0, context);
# IMPL-NEXT: AffineMap tensorMap = AffineMap::getMultiDimIdentityMap(
-
# @linalg_structured_op
-# def test4(O=TensorDef(T, S.M, S.N, output=True)):
+# def test4(O=TensorDef(T, S.M, S.N, output=True),
+# unary_fun=UnaryFnAttrDef(default=UnaryFn.exp),
+# binary_fun=BinaryFnAttrDef(default=BinaryFn.add)):
# """Title.
# Detailed description.
# """
-# O[D.m, D.n] = BinaryFn.add(UnaryFn.exp(O[D.m, D.n]), O[D.m, D.n])
+# O[D.m, D.n] = binary_fun(unary_fun(O[D.m, D.n]), O[D.m, D.n])
--- !LinalgOpConfig
metadata: !LinalgOpMetadata
@@ -279,6 +280,14 @@ structured_op: !LinalgStructuredOpConfig
kind: output_tensor
type_var: T
shape_map: affine_map<()[s0, s1] -> (s0, s1)>
+ - !LinalgOperandDefConfig
+ name: unary_fun
+ kind: unary_fn_attr
+ default_fn: exp
+ - !LinalgOperandDefConfig
+ name: binary_fun
+ kind: binary_fn_attr
+ default_fn: add
indexing_maps: !LinalgIndexingMapsConfig
static_indexing_maps:
- affine_map<(d0, d1)[s0, s1] -> (d0, d1)>
@@ -291,21 +300,36 @@ structured_op: !LinalgStructuredOpConfig
value: !ScalarExpression
scalar_fn:
kind: binary
- fn_name: add
+ attr_name: binary_fun
operands:
- !ScalarExpression
scalar_fn:
kind: unary
- fn_name: exp
+ attr_name: unary_fun
operands:
- !ScalarExpression
scalar_arg: O
- !ScalarExpression
scalar_arg: O
+# ODS-LABEL: def Test4Op : LinalgStructuredBase_Op<"test4"
+
+# ODS: let arguments =
+# ODS-NEXT: Variadic<AnyType>:$inputs,
+# ODS-NEXT: Variadic<AnyShaped>:$outputs,
+# ODS-NEXT: DefaultValuedAttr<UnaryFnAttr, "UnaryFn::exp">:$unary_fun,
+# ODS-NEXT: DefaultValuedAttr<BinaryFnAttr, "BinaryFn::add">:$binary_fun
+
+# ODS: "Attribute":$unary_fun, "Attribute":$binary_fun,
+
+# ODS: $_state.addAttribute("unary_fun", unary_fun)
+# ODS-NEXT: $_state.addAttribute("binary_fun", binary_fun)
+
# IMPL-LABEL: void Test4Op::regionBuilder(ImplicitLocOpBuilder &b,
# IMPL-NEXT: Block &block, ArrayRef<NamedAttribute> attrs)
+# IMPL: UnaryFn unary_funVal = UnaryFn::exp
+# IMPL: BinaryFn binary_funVal = BinaryFn::add
-# IMPL: Value [[VAL0:[a-z0-9]+]] = helper.unary__exp(block.getArgument(0))
-# IMPL-NEXT: Value [[VAL1:[a-z0-9]+]] = helper.binary__add([[VAL0]], block.getArgument(0))
+# IMPL: Value [[VAL0:[a-z0-9]+]] = helper.buildUnaryFn(unary_funVal, block.getArgument(0))
+# IMPL-NEXT: Value [[VAL1:[a-z0-9]+]] = helper.buildBinaryFn(binary_funVal, [[VAL0]], block.getArgument(0))
# IMPL-NEXT: yields.push_back([[VAL1]])
diff --git a/mlir/test/python/dialects/linalg/opdsl/arguments.py b/mlir/test/python/dialects/linalg/opdsl/arguments.py
index c4f3d91ff10ab..853627611987c 100644
--- a/mlir/test/python/dialects/linalg/opdsl/arguments.py
+++ b/mlir/test/python/dialects/linalg/opdsl/arguments.py
@@ -18,6 +18,12 @@
# CHECK: kind: output_tensor
# CHECK: type_var: U
# CHECK: shape_map: affine_map<()[s0, s1, s2] -> (s0, s2)>
+# CHECK: name: bfn
+# CHECK: kind: binary_fn_attr
+# CHECK: default_fn: mul
+# CHECK: name: ufn
+# CHECK: kind: unary_fn_attr
+# CHECK: default_fn: exp
# CHECK: name: cast
# CHECK: kind: type_fn_attr
# CHECK: default_fn: cast
@@ -26,8 +32,10 @@ def matmul(
A=TensorDef(T, S.M, S.K),
B=TensorDef(T, S.K, S.N),
C=TensorDef(U, S.M, S.N, output=True),
+ bfn=BinaryFnAttrDef(default=BinaryFn.mul),
+ ufn=UnaryFnAttrDef(default=UnaryFn.exp),
cast=TypeFnAttrDef(default=TypeFn.cast)):
- C[D.m, D.n] += cast(U, A[D.m, D.k]) * cast(U, B[D.k, D.n])
+ C[D.m, D.n] += bfn(cast(U, A[D.m, D.k]), cast(U, B[D.k, D.n]))
# CHECK: ---
diff --git a/mlir/test/python/dialects/linalg/opdsl/assignments.py b/mlir/test/python/dialects/linalg/opdsl/assignments.py
index f93e0704a1e36..d8ddc24454914 100644
--- a/mlir/test/python/dialects/linalg/opdsl/assignments.py
+++ b/mlir/test/python/dialects/linalg/opdsl/assignments.py
@@ -10,10 +10,12 @@
# CHECK: arg: C
# CHECK: value:
# CHECK: scalar_fn:
+# CHECK: kind: binary
# CHECK: fn_name: add
# CHECK: operands:
# CHECK: scalar_fn:
-# CHECK: fn_name: mul
+# CHECK: kind: binary
+# CHECK: attr_name: mul
# CHECK: operands:
# CHECK: scalar_fn:
# CHECK: kind: type
@@ -32,8 +34,9 @@ def matmul(
A=TensorDef(T, S.M, S.K),
B=TensorDef(T, S.K, S.N),
C=TensorDef(U, S.M, S.N, output=True),
+ mul=BinaryFnAttrDef(default=BinaryFn.mul),
cast=TypeFnAttrDef(default=TypeFn.cast)):
- C[D.m, D.n] += cast(U, A[D.m, D.k]) * cast(U, B[D.k, D.n])
+ C[D.m, D.n] += mul(cast(U, A[D.m, D.k]), cast(U, B[D.k, D.n]))
# CHECK: ---
@@ -69,14 +72,21 @@ def matmul(
# CHECK: fn_name: cast
# CHECK: type_var: T
# CHECK: operands:
-# CHECK: scalar_const: '1.{{[0]*}}e+03 : f64'
+# CHECK: scalar_fn:
+# CHECK: kind: unary
+# CHECK: attr_name: exp
+# CHECK: operands:
+# CHECK: scalar_const: '1.{{[0]*}}e+03 : f64'
@linalg_structured_op
-def constants(O=TensorDef(T, S.M, S.K, output=True)):
+def constants(
+ O=TensorDef(T, S.M, S.K, output=True),
+ exp=UnaryFnAttrDef(default=UnaryFn.exp)):
pi = TypeFn.cast(T, const(3.1415926535897931))
cst42 = TypeFn.cast(T, const(42))
- cst1000 = TypeFn.cast(T, const(1e+3))
+ cst1000 = TypeFn.cast(T, exp(const(1e+3)))
O[D.m, D.n] = UnaryFn.exp(pi) + cst42 - cst1000
+
# CHECK: ---
# CHECK-LABEL: indices
# CHECK: assignments:
diff --git a/mlir/test/python/dialects/linalg/opdsl/emit_pooling.py b/mlir/test/python/dialects/linalg/opdsl/emit_pooling.py
index 35ec8540cb4b2..ded97cd7b8220 100644
--- a/mlir/test/python/dialects/linalg/opdsl/emit_pooling.py
+++ b/mlir/test/python/dialects/linalg/opdsl/emit_pooling.py
@@ -12,55 +12,18 @@
@linalg_structured_op
-def pooling_max_poly(
+def pooling_poly(
I=TensorDef(T1, S.N, S.H, S.W, S.C),
K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]),
O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True),
+ reduce=BinaryFnAttrDef(default=BinaryFn.max),
+ cast=TypeFnAttrDef(default=TypeFn.cast),
strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])):
domain(D.n, D.oh, D.ow, D.kh, D.kw, D.c)
- O[D.n, D.oh, D.ow, D.c] = ReduceFn.max[D.kh, D.kw](
- TypeFn.cast(
- U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c]))
-
-
- at linalg_structured_op
-def pooling_max_unsigned_poly(
- I=TensorDef(T1, S.N, S.H, S.W, S.C),
- K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]),
- O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True),
- strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
- dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])):
- domain(D.n, D.oh, D.ow, D.kh, D.kw, D.c)
- O[D.n, D.oh, D.ow, D.c] = ReduceFn.max_unsigned[D.kh, D.kw](
- TypeFn.cast_unsigned(
- U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c]))
-
-
- at linalg_structured_op
-def pooling_min_poly(
- I=TensorDef(T1, S.N, S.H, S.W, S.C),
- K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]),
- O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True),
- strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
- dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])):
- domain(D.n, D.oh, D.ow, D.kh, D.kw, D.c)
- O[D.n, D.oh, D.ow, D.c] = ReduceFn.min[D.kh, D.kw](
- TypeFn.cast(
- U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c]))
-
-
- at linalg_structured_op
-def pooling_min_unsigned_poly(
- I=TensorDef(T1, S.N, S.H, S.W, S.C),
- K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]),
- O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True),
- strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
- dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])):
- domain(D.n, D.oh, D.ow, D.kh, D.kw, D.c)
- O[D.n, D.oh, D.ow, D.c] = ReduceFn.min_unsigned[D.kh, D.kw](
- TypeFn.cast_unsigned(
- U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c]))
+ O[D.n, D.oh, D.ow, D.c] = reduce[D.kh, D.kw](
+ cast(U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW,
+ D.c]))
with Context() as ctx, Location.unknown():
@@ -88,7 +51,7 @@ def pooling_min_unsigned_poly(
RankedTensorType.get((2, 2), f32),
RankedTensorType.get((1, 2, 4, 1), i32))
def test_f32i32_max_pooling(input, shape, init_result):
- return pooling_max_poly(
+ return pooling_poly(
input, shape, outs=[init_result], strides=[2, 4], dilations=[1, 2])
# CHECK-LABEL: @test_f32i32_max_unsigned_pooling
@@ -99,8 +62,14 @@ def test_f32i32_max_pooling(input, shape, init_result):
RankedTensorType.get((2, 2), f32),
RankedTensorType.get((1, 2, 4, 1), i32))
def test_f32i32_max_unsigned_pooling(input, shape, init_result):
- return pooling_max_unsigned_poly(
- input, shape, outs=[init_result], strides=[2, 4], dilations=[1, 2])
+ return pooling_poly(
+ input,
+ shape,
+ outs=[init_result],
+ reduce=BinaryFn.max_unsigned,
+ cast=TypeFn.cast_unsigned,
+ strides=[2, 4],
+ dilations=[1, 2])
# CHECK-LABEL: @test_f32f32_max_pooling
# CHECK: linalg.generic
@@ -115,7 +84,7 @@ def test_f32i32_max_unsigned_pooling(input, shape, init_result):
RankedTensorType.get((2, 2), f32),
RankedTensorType.get((1, 2, 4, 1), f32))
def test_f32f32_max_pooling(input, shape, init_result):
- return pooling_max_poly(
+ return pooling_poly(
input, shape, outs=[init_result], strides=[2, 4], dilations=[1, 2])
# CHECK-LABEL: @test_f32i32_min_pooling
@@ -126,8 +95,13 @@ def test_f32f32_max_pooling(input, shape, init_result):
RankedTensorType.get((2, 2), f32),
RankedTensorType.get((1, 2, 4, 1), i32))
def test_f32i32_min_pooling(input, shape, init_result):
- return pooling_min_poly(
- input, shape, outs=[init_result], strides=[2, 4], dilations=[1, 2])
+ return pooling_poly(
+ input,
+ shape,
+ outs=[init_result],
+ reduce=BinaryFn.min,
+ strides=[2, 4],
+ dilations=[1, 2])
# CHECK-LABEL: @test_f32i32_min_unsigned_pooling
# CHECK: = arith.fptoui
@@ -137,8 +111,14 @@ def test_f32i32_min_pooling(input, shape, init_result):
RankedTensorType.get((2, 2), f32),
RankedTensorType.get((1, 2, 4, 1), i32))
def test_f32i32_min_unsigned_pooling(input, shape, init_result):
- return pooling_min_unsigned_poly(
- input, shape, outs=[init_result], strides=[2, 4], dilations=[1, 2])
+ return pooling_poly(
+ input,
+ shape,
+ outs=[init_result],
+ reduce=BinaryFn.min_unsigned,
+ cast=TypeFn.cast_unsigned,
+ strides=[2, 4],
+ dilations=[1, 2])
# CHECK-LABEL: @test_f32f32_min_pooling
# CHECK: = arith.minf
@@ -147,8 +127,13 @@ def test_f32i32_min_unsigned_pooling(input, shape, init_result):
RankedTensorType.get((2, 2), f32),
RankedTensorType.get((1, 2, 4, 1), f32))
def test_f32f32_min_pooling(input, shape, init_result):
- return pooling_min_poly(
- input, shape, outs=[init_result], strides=[2, 4], dilations=[1, 2])
+ return pooling_poly(
+ input,
+ shape,
+ outs=[init_result],
+ reduce=BinaryFn.min,
+ strides=[2, 4],
+ dilations=[1, 2])
print(module)
diff --git a/mlir/test/python/dialects/linalg/ops.py b/mlir/test/python/dialects/linalg/ops.py
index 8b686e69cb138..00a8406373855 100644
--- a/mlir/test/python/dialects/linalg/ops.py
+++ b/mlir/test/python/dialects/linalg/ops.py
@@ -94,20 +94,27 @@ def testNamedStructuredOpCustomForm():
with InsertionPoint(module.body):
@builtin.FuncOp.from_py_func(
- RankedTensorType.get((4, 16), f32), RankedTensorType.get((16, 8),
- f32))
+ RankedTensorType.get((4, 8), f32), RankedTensorType.get((4, 8), f32))
def named_form(lhs, rhs):
init_result = linalg.InitTensorOp([4, 8], f32)
- # First check the named form with custom format
- # CHECK: linalg.matmul
- # CHECK: cast = #linalg.type_fn<cast_unsigned>
- # CHECK-NOT: linalg.memoized_indexing_maps
- # CHECK-SAME: ins(%{{.*}} : tensor<4x16xf32>, tensor<16x8xf32>)
- # CHECK-SAME: outs(%{{.*}} : tensor<4x8xf32>)
- # CHECK-SAME: -> tensor<4x8xf32>
- # CHECK-NEXT: return
- return linalg.matmul(
- lhs, rhs, outs=[init_result.result], cast=TypeFn.cast_unsigned)
+ # Check for the named form with custom format
+ # CHECK: linalg.elemwise_unary
+ # CHECK-SAME: cast = #linalg.type_fn<cast>
+ # CHECK-SAME: fun = #linalg.unary_fn<exp>
+ # CHECK-SAME: ins(%{{.*}} : tensor<4x8xf32>) outs(%{{.*}} : tensor<4x8xf32>)
+ unary_result = linalg.elemwise_unary(lhs, outs=[init_result.result])
+ # CHECK: linalg.elemwise_binary
+ # CHECK-SAME: cast = #linalg.type_fn<cast_unsigned>
+ # CHECK-SAME: fun = #linalg.binary_fn<mul>
+ # CHECK-SAME: ins(%{{.*}}, %{{.*}} : tensor<4x8xf32>, tensor<4x8xf32>) outs(%{{.*}} : tensor<4x8xf32>)
+ # CHECK: return
+ binary_result = linalg.elemwise_binary(
+ lhs,
+ rhs,
+ outs=[init_result.result],
+ fun=BinaryFn.mul,
+ cast=TypeFn.cast_unsigned)
+ return unary_result, binary_result
print(module)
@@ -130,7 +137,8 @@ def named_form(lhs, rhs):
# CHECK-NEXT: arith.mulf{{.*}} (f32, f32) -> f32
# CHECK-NEXT: arith.addf{{.*}} (f32, f32) -> f32
# CHECK-NEXT: linalg.yield{{.*}} (f32) -> ()
- # CHECK-NEXT: operand_segment_sizes = dense<[2, 1]> : vector<2xi32>
+ # CHECK-NEXT: cast = #linalg.type_fn<cast>
+ # CHECK-SAME: operand_segment_sizes = dense<[2, 1]> : vector<2xi32>
# CHECK-SAME: (tensor<4x16xf32>, tensor<16x8xf32>, tensor<4x8xf32>) -> tensor<4x8xf32>
return linalg.matmul(lhs, rhs, outs=[init_result.result])
diff --git a/mlir/test/python/integration/dialects/linalg/opsrun.py b/mlir/test/python/integration/dialects/linalg/opsrun.py
index 1416aabc81133..c4e580f89c94d 100644
--- a/mlir/test/python/integration/dialects/linalg/opsrun.py
+++ b/mlir/test/python/integration/dialects/linalg/opsrun.py
@@ -19,6 +19,37 @@ def log(*args):
sys.stderr.flush()
+elemwise_boiler = """
+func @main() -> f32 attributes {llvm.emit_c_interface} {
+ %v0 = arith.constant 0.0 : f32
+ %v1 = arith.constant 1.0 : f32
+ %v2 = arith.constant 2.0 : f32
+
+ %lhs = memref.alloc() : memref<4x8xf32>
+ %rhs = memref.alloc() : memref<4x8xf32>
+ %O0 = memref.alloc() : memref<4x8xf32>
+ %O1 = memref.alloc() : memref<4x8xf32>
+ linalg.fill(%v1, %lhs) : f32, memref<4x8xf32>
+ linalg.fill(%v2, %rhs) : f32, memref<4x8xf32>
+ linalg.fill(%v0, %O0) : f32, memref<4x8xf32>
+ linalg.fill(%v0, %O1) : f32, memref<4x8xf32>
+
+ call @elemwise_exp_add_on_buffers(%lhs, %rhs, %O0) :
+ (memref<4x8xf32>, memref<4x8xf32>, memref<4x8xf32>) -> ()
+ call @elemwise_log_mul_on_buffers(%lhs, %rhs, %O1) :
+ (memref<4x8xf32>, memref<4x8xf32>, memref<4x8xf32>) -> ()
+
+ %c0 = arith.constant 0 : index
+ %res0 = memref.load %O0[%c0, %c0] : memref<4x8xf32>
+ %res1 = memref.load %O1[%c0, %c0] : memref<4x8xf32>
+
+ %0 = arith.addf %res0, %res1 : f32
+
+ // TODO: FFI-based solution to allow testing and printing with python code.
+ return %0 : f32
+}
+"""
+
matmul_boiler = """
func @main() -> f32 attributes {llvm.emit_c_interface} {
%v0 = arith.constant 0.0 : f32
@@ -166,13 +197,93 @@ def transform(module, boilerplate):
pm = PassManager.parse(
"builtin.func(convert-linalg-to-loops, lower-affine, " +
- "convert-scf-to-cf, arith-expand, memref-expand), convert-vector-to-llvm,"
- + "convert-memref-to-llvm, convert-std-to-llvm," +
+ "convert-math-to-llvm, convert-scf-to-cf, arith-expand, memref-expand), "
+ + "convert-vector-to-llvm, convert-memref-to-llvm, convert-std-to-llvm," +
"reconcile-unrealized-casts")
pm.run(mod)
return mod
+def test_elemwise_builtin():
+ with Context() as ctx, Location.unknown():
+ module = Module.create()
+ f32 = F32Type.get()
+ i8 = IntegerType.get_signless(8)
+ with InsertionPoint(module.body):
+
+ @builtin.FuncOp.from_py_func(
+ MemRefType.get((4, 8), f32), MemRefType.get((4, 8), f32),
+ MemRefType.get((4, 8), f32))
+ def elemwise_exp_add_on_buffers(lhs, rhs, out):
+ linalg.elemwise_unary(lhs, outs=[out])
+ linalg.elemwise_binary(out, rhs, outs=[out])
+
+ @builtin.FuncOp.from_py_func(
+ MemRefType.get((4, 8), f32), MemRefType.get((4, 8), f32),
+ MemRefType.get((4, 8), f32))
+ def elemwise_log_mul_on_buffers(lhs, rhs, out):
+ linalg.elemwise_unary(lhs, outs=[out], fun=UnaryFn.log)
+ linalg.elemwise_binary(out, rhs, outs=[out], fun=BinaryFn.mul)
+
+ execution_engine = ExecutionEngine(transform(module, elemwise_boiler))
+
+ # TODO: FFI-based solution to allow testing and printing with python code.
+ # Prepare arguments: one result f32.
+ # Arguments must be passed as pointers.
+ c_float_p = ctypes.c_float * 1
+ res = c_float_p(-1.)
+ execution_engine.invoke("main", res)
+
+ log("RESULT: ", res[0])
+ # elemwise_exp_add_on_buffers: exp(1.0) + 2.0 = 4.71828182846
+ # elemwise_log_mul_on_buffers: log(1.0) * 2.0 = 0.0
+ # CHECK: RESULT: 4.71828
+
+
+test_elemwise_builtin()
+
+
+def test_elemwise_generic():
+ with Context() as ctx, Location.unknown():
+ module = Module.create()
+ f32 = F32Type.get()
+ i8 = IntegerType.get_signless(8)
+ with InsertionPoint(module.body):
+
+ @builtin.FuncOp.from_py_func(
+ MemRefType.get((4, 8), f32), MemRefType.get((4, 8), f32),
+ MemRefType.get((4, 8), f32))
+ def elemwise_exp_add_on_buffers(lhs, rhs, out):
+ linalg.elemwise_unary(lhs, outs=[out], emit_generic=True)
+ linalg.elemwise_binary(out, rhs, outs=[out], emit_generic=True)
+
+ @builtin.FuncOp.from_py_func(
+ MemRefType.get((4, 8), f32), MemRefType.get((4, 8), f32),
+ MemRefType.get((4, 8), f32))
+ def elemwise_log_mul_on_buffers(lhs, rhs, out):
+ linalg.elemwise_unary(
+ lhs, outs=[out], fun=UnaryFn.log, emit_generic=True)
+ linalg.elemwise_binary(
+ out, rhs, outs=[out], fun=BinaryFn.mul, emit_generic=True)
+
+ execution_engine = ExecutionEngine(transform(module, elemwise_boiler))
+
+ # TODO: FFI-based solution to allow testing and printing with python code.
+ # Prepare arguments: one result f32.
+ # Arguments must be passed as pointers.
+ c_float_p = ctypes.c_float * 1
+ res = c_float_p(-1.)
+ execution_engine.invoke("main", res)
+
+ log("RESULT: ", res[0])
+ # elemwise_exp_add_on_buffers: exp(1.0) + 2.0 = 4.71828182846
+ # elemwise_log_mul_on_buffers: log(1.0) * 2.0 = 0.0
+ # CHECK: RESULT: 4.71828
+
+
+test_elemwise_generic()
+
+
def test_matmul_builtin():
with Context() as ctx, Location.unknown():
module = Module.create()
diff --git a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
index 7685d1a53e313..a535ad7654809 100644
--- a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
+++ b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
@@ -66,6 +66,8 @@ enum class LinalgOperandDefKind {
Scalar,
OutputTensor,
IndexAttr,
+ UnaryFnAttr,
+ BinaryFnAttr,
TypeFnAttr
};
@@ -208,6 +210,8 @@ struct ScalarEnumerationTraits<LinalgOperandDefKind> {
io.enumCase(value, "scalar", LinalgOperandDefKind::Scalar);
io.enumCase(value, "output_tensor", LinalgOperandDefKind::OutputTensor);
io.enumCase(value, "index_attr", LinalgOperandDefKind::IndexAttr);
+ io.enumCase(value, "unary_fn_attr", LinalgOperandDefKind::UnaryFnAttr);
+ io.enumCase(value, "binary_fn_attr", LinalgOperandDefKind::BinaryFnAttr);
io.enumCase(value, "type_fn_attr", LinalgOperandDefKind::TypeFnAttr);
}
};
@@ -430,6 +434,45 @@ static ScalarAssign *findAssignment(StringRef name,
return nullptr;
}
+// Return true if the operand is a function attribute.
+static bool isFunctionAttribute(LinalgOperandDefKind kind) {
+ return kind == LinalgOperandDefKind::UnaryFnAttr ||
+ kind == LinalgOperandDefKind::BinaryFnAttr ||
+ kind == LinalgOperandDefKind::TypeFnAttr;
+}
+
+// Return true if the operand is an attribute.
+static bool isAttribute(LinalgOperandDefKind kind) {
+ return kind == LinalgOperandDefKind::IndexAttr || isFunctionAttribute(kind);
+}
+
+// Get the enum name for the given operand kind.
+std::string convertOperandKindToEnumName(LinalgOperandDefKind kind) {
+ switch (kind) {
+ case LinalgOperandDefKind::UnaryFnAttr:
+ return std::string("UnaryFn");
+ case LinalgOperandDefKind::BinaryFnAttr:
+ return std::string("BinaryFn");
+ case LinalgOperandDefKind::TypeFnAttr:
+ return std::string("TypeFn");
+ default:
+ break;
+ }
+ llvm_unreachable("unsupported function attribute kind");
+}
+
+// Get the enum name for the given function kind.
+std::string convertFunctionKindToEnumName(ScalarFnKind kind) {
+ switch (kind) {
+ case ScalarFnKind::Unary:
+ return std::string("UnaryFn");
+ case ScalarFnKind::Binary:
+ return std::string("BinaryFn");
+ case ScalarFnKind::Type:
+ return std::string("TypeFn");
+ }
+}
+
//===----------------------------------------------------------------------===//
// Templates
//===----------------------------------------------------------------------===//
@@ -693,8 +736,7 @@ static LogicalResult generateNamedGenericOpOds(LinalgOpConfig &opConfig,
interfaceNameList = interleaveToString(opConfig.metadata->implements, ", ");
if (llvm::any_of(opConfig.structuredOp->args, [](LinalgOperandDef &arg) {
- return arg.kind == LinalgOperandDefKind::IndexAttr ||
- arg.kind == LinalgOperandDefKind::TypeFnAttr;
+ return isAttribute(arg.kind);
})) {
SmallVector<std::string> attrDefs;
SmallVector<std::string> attrParams;
@@ -703,13 +745,14 @@ static LogicalResult generateNamedGenericOpOds(LinalgOpConfig &opConfig,
static const char paramFmt[] = "\"Attribute\":${0}";
static const char stmtFmt[] = "$_state.addAttribute(\"{0}\", {0});";
// Add the type conversion attributes to the op definition and builders.
- if (arg.kind == LinalgOperandDefKind::TypeFnAttr) {
+ if (isFunctionAttribute(arg.kind)) {
assert(arg.defaultFn.hasValue());
- static const char typeFmt[] = "TypeFn::{0}";
+ std::string enumName = convertOperandKindToEnumName(arg.kind);
+ static const char typeFmt[] = "{0}::{1}";
static const char defFmt[] = "DefaultValuedAttr<{0}, \"{1}\">:${2}";
- attrDefs.push_back(llvm::formatv(defFmt, "TypeFnAttr",
- llvm::formatv(typeFmt, arg.defaultFn),
- arg.name));
+ attrDefs.push_back(llvm::formatv(
+ defFmt, llvm::formatv("{0}Attr", enumName),
+ llvm::formatv(typeFmt, enumName, arg.defaultFn), arg.name));
attrParams.push_back(llvm::formatv(paramFmt, arg.name));
attrStmts.push_back(llvm::formatv(stmtFmt, arg.name));
}
@@ -1000,21 +1043,24 @@ void {0}::regionBuilder(ImplicitLocOpBuilder &b,
SmallVector<std::string> attrs;
SmallVector<std::string> stmts;
for (LinalgOperandDef &arg : args) {
- if (arg.kind != LinalgOperandDefKind::TypeFnAttr)
+ if (!isFunctionAttribute(arg.kind))
continue;
// Obtain the type function attribute values. Parameters.
- // {0}: attribute name
- // {1}: default type function name
+ // {0}: enum name
+ // {1}: attribute name
+ // {2}: default type function name
static const char attrDef[] = R"FMT(
-TypeFn {0}Val = TypeFn::{1};
-auto {0}Iter = llvm::find_if(attrs, [&](const NamedAttribute &attr) {{
- return attr.getName() == "{0}"; });
-if ({0}Iter != attrs.end()) {{
- if (auto attr = {0}Iter->getValue().dyn_cast<TypeFnAttr>())
- {0}Val = attr.getValue();
+{0} {1}Val = {0}::{2};
+auto {1}Iter = llvm::find_if(attrs, [&](const NamedAttribute &attr) {{
+ return attr.getName() == "{1}"; });
+if ({1}Iter != attrs.end()) {{
+ if (auto attr = {1}Iter->getValue().dyn_cast<{0}Attr>())
+ {1}Val = attr.getValue();
}
)FMT";
- attrs.push_back(llvm::formatv(attrDef, arg.name, arg.defaultFn));
+ std::string enumName = convertOperandKindToEnumName(arg.kind);
+ attrs.push_back(
+ llvm::formatv(attrDef, enumName, arg.name, arg.defaultFn));
}
for (LinalgOperandDef &arg : args) {
if (arg.kind != LinalgOperandDefKind::OutputTensor)
@@ -1056,71 +1102,59 @@ if ({0}Iter != attrs.end()) {{
cppIdent, *expression.index));
return cppIdent;
}
- if (expression.scalarFn &&
- expression.scalarFn->kind != ScalarFnKind::Type) {
- // Apply function.
- // Recursively generate operands.
- SmallVector<std::string> operandCppValues;
- for (ScalarExpression &operand : expression.scalarFn->operands) {
- auto operandCppValue = generateExpression(operand);
- if (!operandCppValue)
- return None;
- operandCppValues.push_back(*operandCppValue);
- }
-
- std::string prefix = expression.scalarFn->kind == ScalarFnKind::Unary
- ? "unary"
- : "binary";
- std::string cppIdent = llvm::formatv("value{0}", ++localCounter);
- stmts.push_back(
- llvm::formatv("Value {0} = helper.{1}__{2}({3});", cppIdent,
- prefix, expression.scalarFn->fnName,
- interleaveToString(operandCppValues, ", ")));
- return cppIdent;
- }
- if (expression.scalarFn &&
- expression.scalarFn->kind == ScalarFnKind::Type) {
- // Symbolic cast.
- // Operands must be arity 1.
- if (expression.scalarFn->operands.size() != 1) {
- emitError(genContext.getLoc())
- << "type conversion operand arity must be 1";
- return None;
+ if (expression.scalarFn) {
+ std::string enumName =
+ convertFunctionKindToEnumName(expression.scalarFn->kind);
+
+ // Get the function or attribute name.
+ assert(expression.scalarFn->fnName || expression.scalarFn->attrName);
+ std::string funcType;
+ if (expression.scalarFn->fnName) {
+ funcType = llvm::formatv("{0}::{1}", enumName,
+ *expression.scalarFn->fnName);
}
- Optional<std::string> operandCppValue =
- generateExpression(expression.scalarFn->operands[0]);
- if (!operandCppValue)
- return None;
-
- assert(expression.scalarFn->typeVar.hasValue());
- Optional<std::string> typeCppValue =
- findTypeValue(expression.scalarFn->typeVar.getValue(), args);
- if (!typeCppValue) {
- emitError(genContext.getLoc())
- << "type variable " << expression.scalarFn->typeVar.getValue()
- << ", used in a type conversion, must map to a predefined or "
- << "an argument type but it does not";
- return None;
- }
-
- // Use the function name or the attribute to build the type function.
- std::string typeFunc = llvm::formatv(
- "TypeFn::{0}", expression.scalarFn->fnName.getValueOr(""));
if (expression.scalarFn->attrName) {
if (llvm::none_of(args, [&](LinalgOperandDef &arg) {
- return arg.kind == LinalgOperandDefKind::TypeFnAttr &&
+ return isFunctionAttribute(arg.kind) &&
arg.name == expression.scalarFn->attrName.getValue();
})) {
emitError(genContext.getLoc())
- << "missing type function attribute "
+ << "missing function attribute "
<< expression.scalarFn->attrName.getValue();
}
- typeFunc = llvm::formatv("{0}Val", *expression.scalarFn->attrName);
+ funcType = llvm::formatv("{0}Val", *expression.scalarFn->attrName);
+ }
+ assert(!funcType.empty());
+
+ // Add the optional type parameter to the operands.
+ SmallVector<std::string> operandCppValues;
+ if (expression.scalarFn->kind == ScalarFnKind::Type) {
+ assert(expression.scalarFn->typeVar.hasValue());
+ Optional<std::string> typeCppValue =
+ findTypeValue(expression.scalarFn->typeVar.getValue(), args);
+ if (!typeCppValue) {
+ emitError(genContext.getLoc())
+ << "type variable " << expression.scalarFn->typeVar.getValue()
+ << ", used in a type conversion, must map to a predefined or "
+ << "an argument type but it does not";
+ return None;
+ }
+ operandCppValues.push_back(typeCppValue.getValue());
}
+
+ // Collect the scalar operands.
+ for (ScalarExpression &operand : expression.scalarFn->operands) {
+ auto operandCppValue = generateExpression(operand);
+ if (!operandCppValue)
+ return None;
+ operandCppValues.push_back(*operandCppValue);
+ }
+
+ // Call the function builder.
std::string cppIdent = llvm::formatv("value{0}", ++localCounter);
stmts.push_back(llvm::formatv(
- "Value {0} = helper.buildTypeFn({1}, {2}, {3});", cppIdent,
- typeFunc, typeCppValue.getValue(), *operandCppValue));
+ "Value {0} = helper.build{1}({2}, {3});", cppIdent, enumName,
+ funcType, interleaveToString(operandCppValues, ", ")));
return cppIdent;
}
emitError(genContext.getLoc()) << "unknown ScalarExpression type";
More information about the Mlir-commits
mailing list