[Mlir-commits] [mlir] cf05668 - [mlir][OpDSL] Rename `PrimFn` to `ArithFn`.
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Jan 7 04:42:10 PST 2022
Author: gysit
Date: 2022-01-07T12:38:03Z
New Revision: cf05668c17681f34ff1a5a8f9ca806b978090592
URL: https://github.com/llvm/llvm-project/commit/cf05668c17681f34ff1a5a8f9ca806b978090592
DIFF: https://github.com/llvm/llvm-project/commit/cf05668c17681f34ff1a5a8f9ca806b978090592.diff
LOG: [mlir][OpDSL] Rename `PrimFn` to `ArithFn`.
The revision renames `PrimFn` to `ArithFn`. The name resembles the newly introduced arith dialect that implements most of the arithmetic functions. Exceptions are log/exp, which are part of the math dialect.
Depends On D115239
Reviewed By: stellaraccident
Differential Revision: https://reviews.llvm.org/D115240
Added:
Modified:
mlir/docs/Dialects/Linalg/OpDSL.md
mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py
mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
mlir/python/mlir/dialects/linalg/opdsl/lang/scalar_expr.py
mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml
mlir/test/python/dialects/linalg/opdsl/assignments.py
mlir/test/python/dialects/linalg/opdsl/emit_misc.py
mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
Removed:
################################################################################
diff --git a/mlir/docs/Dialects/Linalg/OpDSL.md b/mlir/docs/Dialects/Linalg/OpDSL.md
index 0d4fabe646445..4a08f0bc4d389 100644
--- a/mlir/docs/Dialects/Linalg/OpDSL.md
+++ b/mlir/docs/Dialects/Linalg/OpDSL.md
@@ -177,14 +177,20 @@ TODO: Introduce a directive to fix the dimension bindings.
Reduction dimensions are inferred to be any dimensions on the RHS that are not
on the LHS.
-A number of arithmetic primitive functions are supported:
+A number of arithmetic functions are supported:
+
+* `ArithFn.add(a, b)` (also via overloading the binary `+` operator)
+* `ArithFn.exp(a)`
+* `ArithFn.log(a)`
+* `ArithFn.mul(a, b)` (also via overloading the binary `*` operator)
+* `ArithFn.max(a, b)`
+* `ArithFn.min(a, b)`
+* `ArithFn.sub(a, b)` (also via overloading the binary `-` operator)
+* `ArithFn.max_unsigned(a, b)`
+* `ArithFn.min_unsigned(a, b)`
-* `PrimFn.add(a, b)` (also via overloading the binary `+` operator)
-* `PrimFn.exp(a)`
-* `PrimFn.log(a)`
-* `PrimFn.mul(a, b)` (also via overloading the binary `*` operator)
-* `PrimFn.max(a, b)`
-* `PrimFn.sub(a, b)` (also via overloading the binary `-` operator)
+As the integer types are signless, signedness is implemented by different
+functions that treat integers as signed or unsigned values.
Reduction functions can appear as the outer-most function on the RHS:
@@ -233,6 +239,8 @@ The following examples illustrate the lowering of signed and unsigned functions:
* cast(F32 -> I32) -> `arith.FPToSIOp`
* cast_unsigned(I32 -> I64) -> `arith.ExtUIOp`
* cast_unsigned(F32 -> I32) -> `arith.FPToUIOp`
+* max -> `arith.MaxSIOp`
+* max_unsigned -> `arith.MaxUIOp`
Not all functions are applicable for all numeric types, and on mismatch, op
verification will fail.
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
index c298549e714e5..0dcd86a0a5e9b 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
@@ -41,13 +41,13 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarAssign
arg: C
value: !ScalarExpression
- scalar_apply:
+ arith_fn:
fn_name: add
operands:
- !ScalarExpression
scalar_arg: C
- !ScalarExpression
- scalar_apply:
+ arith_fn:
fn_name: mul
operands:
- !ScalarExpression
@@ -105,13 +105,13 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarAssign
arg: C
value: !ScalarExpression
- scalar_apply:
+ arith_fn:
fn_name: add
operands:
- !ScalarExpression
scalar_arg: C
- !ScalarExpression
- scalar_apply:
+ arith_fn:
fn_name: mul
operands:
- !ScalarExpression
@@ -179,17 +179,17 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarAssign
arg: C
value: !ScalarExpression
- scalar_apply:
+ arith_fn:
fn_name: add
operands:
- !ScalarExpression
scalar_arg: C
- !ScalarExpression
- scalar_apply:
+ arith_fn:
fn_name: mul
operands:
- !ScalarExpression
- scalar_apply:
+ arith_fn:
fn_name: sub
operands:
- !ScalarExpression
@@ -207,7 +207,7 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarExpression
scalar_arg: AZp
- !ScalarExpression
- scalar_apply:
+ arith_fn:
fn_name: sub
operands:
- !ScalarExpression
@@ -276,13 +276,13 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarAssign
arg: accum
value: !ScalarExpression
- scalar_apply:
+ arith_fn:
fn_name: add
operands:
- !ScalarExpression
scalar_arg: accum
- !ScalarExpression
- scalar_apply:
+ arith_fn:
fn_name: mul
operands:
- !ScalarExpression
@@ -341,13 +341,13 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarAssign
arg: C
value: !ScalarExpression
- scalar_apply:
+ arith_fn:
fn_name: add
operands:
- !ScalarExpression
scalar_arg: C
- !ScalarExpression
- scalar_apply:
+ arith_fn:
fn_name: mul
operands:
- !ScalarExpression
@@ -416,17 +416,17 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarAssign
arg: C
value: !ScalarExpression
- scalar_apply:
+ arith_fn:
fn_name: add
operands:
- !ScalarExpression
scalar_arg: C
- !ScalarExpression
- scalar_apply:
+ arith_fn:
fn_name: mul
operands:
- !ScalarExpression
- scalar_apply:
+ arith_fn:
fn_name: sub
operands:
- !ScalarExpression
@@ -444,7 +444,7 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarExpression
scalar_arg: AZp
- !ScalarExpression
- scalar_apply:
+ arith_fn:
fn_name: sub
operands:
- !ScalarExpression
@@ -501,13 +501,13 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarAssign
arg: x
value: !ScalarExpression
- scalar_apply:
+ arith_fn:
fn_name: add
operands:
- !ScalarExpression
scalar_arg: x
- !ScalarExpression
- scalar_apply:
+ arith_fn:
fn_name: mul
operands:
- !ScalarExpression
@@ -564,13 +564,13 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarAssign
arg: x
value: !ScalarExpression
- scalar_apply:
+ arith_fn:
fn_name: add
operands:
- !ScalarExpression
scalar_arg: x
- !ScalarExpression
- scalar_apply:
+ arith_fn:
fn_name: mul
operands:
- !ScalarExpression
@@ -628,13 +628,13 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarAssign
arg: C
value: !ScalarExpression
- scalar_apply:
+ arith_fn:
fn_name: add
operands:
- !ScalarExpression
scalar_arg: C
- !ScalarExpression
- scalar_apply:
+ arith_fn:
fn_name: mul
operands:
- !ScalarExpression
@@ -690,13 +690,13 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarAssign
arg: C
value: !ScalarExpression
- scalar_apply:
+ arith_fn:
fn_name: add
operands:
- !ScalarExpression
scalar_arg: C
- !ScalarExpression
- scalar_apply:
+ arith_fn:
fn_name: mul
operands:
- !ScalarExpression
@@ -753,13 +753,13 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarAssign
arg: O
value: !ScalarExpression
- scalar_apply:
+ arith_fn:
fn_name: add
operands:
- !ScalarExpression
scalar_arg: O
- !ScalarExpression
- scalar_apply:
+ arith_fn:
fn_name: mul
operands:
- !ScalarExpression
@@ -818,13 +818,13 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarAssign
arg: O
value: !ScalarExpression
- scalar_apply:
+ arith_fn:
fn_name: add
operands:
- !ScalarExpression
scalar_arg: O
- !ScalarExpression
- scalar_apply:
+ arith_fn:
fn_name: mul
operands:
- !ScalarExpression
@@ -886,13 +886,13 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarAssign
arg: O
value: !ScalarExpression
- scalar_apply:
+ arith_fn:
fn_name: add
operands:
- !ScalarExpression
scalar_arg: O
- !ScalarExpression
- scalar_apply:
+ arith_fn:
fn_name: mul
operands:
- !ScalarExpression
@@ -964,13 +964,13 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarAssign
arg: O
value: !ScalarExpression
- scalar_apply:
+ arith_fn:
fn_name: add
operands:
- !ScalarExpression
scalar_arg: O
- !ScalarExpression
- scalar_apply:
+ arith_fn:
fn_name: mul
operands:
- !ScalarExpression
@@ -1054,13 +1054,13 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarAssign
arg: O
value: !ScalarExpression
- scalar_apply:
+ arith_fn:
fn_name: add
operands:
- !ScalarExpression
scalar_arg: O
- !ScalarExpression
- scalar_apply:
+ arith_fn:
fn_name: mul
operands:
- !ScalarExpression
@@ -1157,17 +1157,17 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarAssign
arg: O
value: !ScalarExpression
- scalar_apply:
+ arith_fn:
fn_name: add
operands:
- !ScalarExpression
scalar_arg: O
- !ScalarExpression
- scalar_apply:
+ arith_fn:
fn_name: mul
operands:
- !ScalarExpression
- scalar_apply:
+ arith_fn:
fn_name: sub
operands:
- !ScalarExpression
@@ -1185,7 +1185,7 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarExpression
scalar_arg: IZp
- !ScalarExpression
- scalar_apply:
+ arith_fn:
fn_name: sub
operands:
- !ScalarExpression
@@ -1269,13 +1269,13 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarAssign
arg: O
value: !ScalarExpression
- scalar_apply:
+ arith_fn:
fn_name: add
operands:
- !ScalarExpression
scalar_arg: O
- !ScalarExpression
- scalar_apply:
+ arith_fn:
fn_name: mul
operands:
- !ScalarExpression
@@ -1359,13 +1359,13 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarAssign
arg: O
value: !ScalarExpression
- scalar_apply:
+ arith_fn:
fn_name: add
operands:
- !ScalarExpression
scalar_arg: O
- !ScalarExpression
- scalar_apply:
+ arith_fn:
fn_name: mul
operands:
- !ScalarExpression
@@ -1436,13 +1436,13 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarAssign
arg: O
value: !ScalarExpression
- scalar_apply:
+ arith_fn:
fn_name: add
operands:
- !ScalarExpression
scalar_arg: O
- !ScalarExpression
- scalar_apply:
+ arith_fn:
fn_name: mul
operands:
- !ScalarExpression
@@ -1519,13 +1519,13 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarAssign
arg: O
value: !ScalarExpression
- scalar_apply:
+ arith_fn:
fn_name: add
operands:
- !ScalarExpression
scalar_arg: O
- !ScalarExpression
- scalar_apply:
+ arith_fn:
fn_name: mul
operands:
- !ScalarExpression
@@ -1613,17 +1613,17 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarAssign
arg: O
value: !ScalarExpression
- scalar_apply:
+ arith_fn:
fn_name: add
operands:
- !ScalarExpression
scalar_arg: O
- !ScalarExpression
- scalar_apply:
+ arith_fn:
fn_name: mul
operands:
- !ScalarExpression
- scalar_apply:
+ arith_fn:
fn_name: sub
operands:
- !ScalarExpression
@@ -1641,7 +1641,7 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarExpression
scalar_arg: IZp
- !ScalarExpression
- scalar_apply:
+ arith_fn:
fn_name: sub
operands:
- !ScalarExpression
@@ -1721,13 +1721,13 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarAssign
arg: O
value: !ScalarExpression
- scalar_apply:
+ arith_fn:
fn_name: add
operands:
- !ScalarExpression
scalar_arg: O
- !ScalarExpression
- scalar_apply:
+ arith_fn:
fn_name: mul
operands:
- !ScalarExpression
@@ -1819,17 +1819,17 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarAssign
arg: O
value: !ScalarExpression
- scalar_apply:
+ arith_fn:
fn_name: add
operands:
- !ScalarExpression
scalar_arg: O
- !ScalarExpression
- scalar_apply:
+ arith_fn:
fn_name: mul
operands:
- !ScalarExpression
- scalar_apply:
+ arith_fn:
fn_name: sub
operands:
- !ScalarExpression
@@ -1847,7 +1847,7 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarExpression
scalar_arg: IZp
- !ScalarExpression
- scalar_apply:
+ arith_fn:
fn_name: sub
operands:
- !ScalarExpression
@@ -1923,7 +1923,7 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarAssign
arg: O
value: !ScalarExpression
- scalar_apply:
+ arith_fn:
fn_name: add
operands:
- !ScalarExpression
@@ -1994,7 +1994,7 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarAssign
arg: O
value: !ScalarExpression
- scalar_apply:
+ arith_fn:
fn_name: max
operands:
- !ScalarExpression
@@ -2065,7 +2065,7 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarAssign
arg: O
value: !ScalarExpression
- scalar_apply:
+ arith_fn:
fn_name: max_unsigned
operands:
- !ScalarExpression
@@ -2136,7 +2136,7 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarAssign
arg: O
value: !ScalarExpression
- scalar_apply:
+ arith_fn:
fn_name: max
operands:
- !ScalarExpression
@@ -2207,7 +2207,7 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarAssign
arg: O
value: !ScalarExpression
- scalar_apply:
+ arith_fn:
fn_name: min
operands:
- !ScalarExpression
@@ -2278,7 +2278,7 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarAssign
arg: O
value: !ScalarExpression
- scalar_apply:
+ arith_fn:
fn_name: min_unsigned
operands:
- !ScalarExpression
@@ -2355,7 +2355,7 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarAssign
arg: O
value: !ScalarExpression
- scalar_apply:
+ arith_fn:
fn_name: add
operands:
- !ScalarExpression
@@ -2432,7 +2432,7 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarAssign
arg: O
value: !ScalarExpression
- scalar_apply:
+ arith_fn:
fn_name: max
operands:
- !ScalarExpression
@@ -2509,7 +2509,7 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarAssign
arg: O
value: !ScalarExpression
- scalar_apply:
+ arith_fn:
fn_name: min
operands:
- !ScalarExpression
@@ -2572,15 +2572,15 @@ structured_op: !LinalgStructuredOpConfig
type_var: T
operands:
- !ScalarExpression
- scalar_apply:
+ arith_fn:
fn_name: add
operands:
- !ScalarExpression
- scalar_apply:
+ arith_fn:
fn_name: mul
operands:
- !ScalarExpression
- scalar_apply:
+ arith_fn:
fn_name: add
operands:
- !ScalarExpression
@@ -2596,15 +2596,15 @@ structured_op: !LinalgStructuredOpConfig
type_var: F64
operands:
- !ScalarExpression
- scalar_apply:
+ arith_fn:
fn_name: add
operands:
- !ScalarExpression
- scalar_apply:
+ arith_fn:
fn_name: mul
operands:
- !ScalarExpression
- scalar_apply:
+ arith_fn:
fn_name: add
operands:
- !ScalarExpression
@@ -2615,15 +2615,15 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarExpression
scalar_index: 1
- !ScalarExpression
- scalar_apply:
+ arith_fn:
fn_name: add
operands:
- !ScalarExpression
- scalar_apply:
+ arith_fn:
fn_name: mul
operands:
- !ScalarExpression
- scalar_apply:
+ arith_fn:
fn_name: add
operands:
- !ScalarExpression
@@ -2664,11 +2664,11 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarExpression
scalar_const: '12345 : i64'
- !ScalarExpression
- scalar_apply:
+ arith_fn:
fn_name: mul
operands:
- !ScalarExpression
- scalar_apply:
+ arith_fn:
fn_name: sub
operands:
- !ScalarExpression
@@ -2716,11 +2716,11 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarAssign
arg: O
value: !ScalarExpression
- scalar_apply:
+ arith_fn:
fn_name: log
operands:
- !ScalarExpression
- scalar_apply:
+ arith_fn:
fn_name: add
operands:
- !ScalarExpression
@@ -2731,7 +2731,7 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarExpression
scalar_const: '1.000000e+00 : f64'
- !ScalarExpression
- scalar_apply:
+ arith_fn:
fn_name: exp
operands:
- !ScalarExpression
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 6d2803f3f7adb..7115f9414b76a 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -148,11 +148,11 @@ static LogicalResult foldMemRefCastInTiledLoopOp(TiledLoopOp op) {
// TODO: Move this to a utility library.
// The public methods on this class are referenced directly from generated code
// and bind by name to math and type conversion functions in the DSL as:
-// `applyfn__{fnName}`
+// `arithfn__{fnName}`
// `typefn__{fnName}`
// Examples:
-// `applyfn__add`
-// `applyfn__mul`
+// `arithfn__add`
+// `arithfn__mul`
// `typefn__cast`
// The naming convention is intentional in order to match snake-cased DSL names.
// See mlir-linalg-ods-yaml-gen.cpp for the code that mates to this class.
@@ -241,7 +241,7 @@ class RegionBuilderHelper {
}
// NOLINTNEXTLINE(*-identifier-naming): externally called.
- Value applyfn__add(Value lhs, Value rhs) {
+ Value arithfn__add(Value lhs, Value rhs) {
OpBuilder builder = getBuilder();
if (isFloatingPoint(lhs))
return builder.create<arith::AddFOp>(lhs.getLoc(), lhs, rhs);
@@ -251,7 +251,7 @@ class RegionBuilderHelper {
}
// NOLINTNEXTLINE(*-identifier-naming): externally called.
- Value applyfn__exp(Value x) {
+ Value arithfn__exp(Value x) {
OpBuilder builder = getBuilder();
if (isFloatingPoint(x))
return builder.create<math::ExpOp>(x.getLoc(), x);
@@ -259,7 +259,7 @@ class RegionBuilderHelper {
}
// NOLINTNEXTLINE(*-identifier-naming): externally called.
- Value applyfn__log(Value x) {
+ Value arithfn__log(Value x) {
OpBuilder builder = getBuilder();
if (isFloatingPoint(x))
return builder.create<math::LogOp>(x.getLoc(), x);
@@ -267,7 +267,7 @@ class RegionBuilderHelper {
}
// NOLINTNEXTLINE(*-identifier-naming): externally called.
- Value applyfn__sub(Value lhs, Value rhs) {
+ Value arithfn__sub(Value lhs, Value rhs) {
OpBuilder builder = getBuilder();
if (isFloatingPoint(lhs))
return builder.create<arith::SubFOp>(lhs.getLoc(), lhs, rhs);
@@ -277,7 +277,7 @@ class RegionBuilderHelper {
}
// NOLINTNEXTLINE(*-identifier-naming): externally called.
- Value applyfn__mul(Value lhs, Value rhs) {
+ Value arithfn__mul(Value lhs, Value rhs) {
OpBuilder builder = getBuilder();
if (isFloatingPoint(lhs))
return builder.create<arith::MulFOp>(lhs.getLoc(), lhs, rhs);
@@ -287,7 +287,7 @@ class RegionBuilderHelper {
}
// NOLINTNEXTLINE(*-identifier-naming): externally called.
- Value applyfn__max(Value lhs, Value rhs) {
+ Value arithfn__max(Value lhs, Value rhs) {
OpBuilder builder = getBuilder();
if (isFloatingPoint(lhs))
return builder.create<arith::MaxFOp>(lhs.getLoc(), lhs, rhs);
@@ -297,7 +297,7 @@ class RegionBuilderHelper {
}
// NOLINTNEXTLINE(*-identifier-naming): externally called.
- Value applyfn__max_unsigned(Value lhs, Value rhs) {
+ Value arithfn__max_unsigned(Value lhs, Value rhs) {
OpBuilder builder = getBuilder();
if (isFloatingPoint(lhs))
return builder.create<arith::MaxFOp>(lhs.getLoc(), lhs, rhs);
@@ -307,7 +307,7 @@ class RegionBuilderHelper {
}
// NOLINTNEXTLINE(*-identifier-naming): externally called.
- Value applyfn__min(Value lhs, Value rhs) {
+ Value arithfn__min(Value lhs, Value rhs) {
OpBuilder builder = getBuilder();
if (isFloatingPoint(lhs))
return builder.create<arith::MinFOp>(lhs.getLoc(), lhs, rhs);
@@ -317,7 +317,7 @@ class RegionBuilderHelper {
}
// NOLINTNEXTLINE(*-identifier-naming): externally called.
- Value applyfn__min_unsigned(Value lhs, Value rhs) {
+ Value arithfn__min_unsigned(Value lhs, Value rhs) {
OpBuilder builder = getBuilder();
if (isFloatingPoint(lhs))
return builder.create<arith::MinFOp>(lhs.getLoc(), lhs, rhs);
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py
index be7fc02d04288..fd0fa72661909 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py
@@ -77,13 +77,13 @@ def visit_scalar_def(expr):
self.visit_tensor_exprs(visit_scalar_def)
def __add__(self, rhs: "TensorExpression") -> "TensorExpression":
- return PrimFn.add(self, rhs)
+ return ArithFn.add(self, rhs)
def __mul__(self, rhs) -> "TensorExpression":
- return PrimFn.mul(self, rhs)
+ return ArithFn.mul(self, rhs)
def __sub__(self, rhs) -> "TensorExpression":
- return PrimFn.sub(self, rhs)
+ return ArithFn.sub(self, rhs)
def __hash__(self):
return hash(id(self))
@@ -347,42 +347,55 @@ class TypeFn:
cast_unsigned = TypeFnType("cast_unsigned")
-class PrimFnType:
- """Primitive operations."""
+class ArithFnType:
+ """Arithmetic function.
- def __init__(self, prim_name: str):
- self.prim_name = prim_name
+ An arithmetic function takes one or more tensor expressions and returns the
+ function evaluation result.
+ """
- def __call__(self, *args):
- return PrimApply(self, args)
+ def __init__(self, fn_name: str):
+ self.fn_name = fn_name
+
+ def __call__(self, *args) -> "TensorArithFn":
+ return TensorArithFn(self, args)
def reduce(self, *reduce_dims: DimDef):
- """Shortcut to create a Reduce operation from this primitive."""
+ """Shortcut to create a Reduce operation from this function."""
return ReduceFnType(self, *reduce_dims)
def __repr__(self):
- return f"{self.prim_name}"
+ return f"{self.fn_name}"
-class PrimFn:
- add = PrimFnType("add")
- exp = PrimFnType("exp")
- log = PrimFnType("log")
- mul = PrimFnType("mul")
- max = PrimFnType("max")
- min = PrimFnType("min")
- sub = PrimFnType("sub")
- max_unsigned = PrimFnType("max_unsigned")
- min_unsigned = PrimFnType("min_unsigned")
+class ArithFn:
+ """Arithmetic function namespace.
+
+ As the integer types are signless, signedness is implemented by different
+ functions that treat integers as signed or unsigned values.
+
+ Examples:
+ - max -> `arith.MaxSIOp`
+ - max_unsigned -> `arith.MaxUIOp`
+ """
+ add = ArithFnType("add")
+ exp = ArithFnType("exp")
+ log = ArithFnType("log")
+ mul = ArithFnType("mul")
+ max = ArithFnType("max")
+ min = ArithFnType("min")
+ sub = ArithFnType("sub")
+ max_unsigned = ArithFnType("max_unsigned")
+ min_unsigned = ArithFnType("min_unsigned")
class ReduceFnType:
"""A reduction operator that reduces into its LHS from its RHS."""
- def __init__(self, operator: PrimFnType, *reduce_dims: DimDef):
- """Initializes the ReduceFn with a primitive function and dims."""
- if not isinstance(operator, PrimFnType):
- raise ValueError(f"Reduce expected a Prim operator but got {operator}")
+ def __init__(self, operator: ArithFnType, *reduce_dims: DimDef):
+ """Initializes the ReduceFn with an arithmetic function and dims."""
+ if not isinstance(operator, ArithFnType):
+ raise ValueError(f"Reduce expected a ArithFnType but got {operator}")
self.operator = operator
self.reduce_dims = tuple(reduce_dims)
@@ -390,28 +403,28 @@ def __call__(self, *args: TensorExpression):
return ReduceApply(self, args)
def __repr__(self):
- return (f"reduce_{self.operator.prim_name}"
+ return (f"reduce_{self.operator.fn_name}"
f"({', '.join(repr(d) for d in self.reduce_dims)})")
class ReduceFn:
- add = PrimFn.add.reduce
- mul = PrimFn.mul.reduce
- max = PrimFn.max.reduce
- min = PrimFn.min.reduce
- max_unsigned = PrimFn.max_unsigned.reduce
- min_unsigned = PrimFn.min_unsigned.reduce
+ add = ArithFn.add.reduce
+ mul = ArithFn.mul.reduce
+ max = ArithFn.max.reduce
+ min = ArithFn.min.reduce
+ max_unsigned = ArithFn.max_unsigned.reduce
+ min_unsigned = ArithFn.min_unsigned.reduce
-class PrimApply(TensorExpression):
- """Application of a primitive."""
+class TensorArithFn(TensorExpression):
+ """Application of an arithmetic function."""
- def __init__(self, prim: PrimFnType, args: Sequence[TensorExpression]):
- self.prim = prim
+ def __init__(self, arith_fn: ArithFnType, args: Sequence[TensorExpression]):
+ self.arith_fn = arith_fn
self.args = tuple(args)
def to_scalar_expression(self) -> ScalarExpression:
- return ScalarApplyFn(self.prim.prim_name,
+ return ScalarArithFn(self.arith_fn.fn_name,
*[arg.to_scalar_expression() for arg in self.args
]).expr()
@@ -421,7 +434,7 @@ def visit_tensor_exprs(self, callback):
arg.visit_tensor_exprs(callback)
def __repr__(self):
- return f"{repr(self.prim)}({', '.join(repr(a) for a in self.args)})"
+ return f"{repr(self.arith_fn)}({', '.join(repr(a) for a in self.args)})"
class TensorTypeFn(TensorExpression):
@@ -503,7 +516,7 @@ def to_scalar_expression(self) -> ScalarExpression:
f"bound to its lhs: {self}")
full_args = [self.lhs.to_scalar_expression()
] + [arg.to_scalar_expression() for arg in self.args]
- return ScalarApplyFn(self.reduce.operator.prim_name, *full_args).expr()
+ return ScalarArithFn(self.reduce.operator.fn_name, *full_args).expr()
def visit_tensor_exprs(self, callback):
for arg in self.args:
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
index df91b9670a44d..22568c8b67487 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
@@ -221,10 +221,10 @@ def expression(self, expr: ScalarExpression) -> Value:
dim_attr = IntegerAttr.get(
IntegerType.get_signless(64), expr.scalar_index.dim)
return linalg.IndexOp(dim_attr).result
- elif expr.scalar_apply:
- fn = self._get_function(f"_eval_{expr.scalar_apply.fn_name}")
+ elif expr.arith_fn:
+ fn = self._get_function(f"_arithfn_{expr.arith_fn.fn_name}")
operand_values = [
- self.expression(operand) for operand in expr.scalar_apply.operands
+ self.expression(operand) for operand in expr.arith_fn.operands
]
return fn(*operand_values)
elif expr.type_fn:
@@ -310,59 +310,59 @@ def _typefn_cast(self, type_var_name: str, operand: Value) -> Value:
def _typefn_cast_unsigned(self, type_var_name: str, operand: Value) -> Value:
return self._cast(type_var_name, operand, True)
- def _eval_add(self, lhs: Value, rhs: Value) -> Value:
+ def _arithfn_add(self, lhs: Value, rhs: Value) -> Value:
if _is_floating_point_type(lhs.type):
return arith.AddFOp(lhs, rhs).result
if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
return arith.AddIOp(lhs, rhs).result
raise NotImplementedError("Unsupported 'add' operand: {lhs}")
- def _eval_exp(self, x: Value) -> Value:
+ def _arithfn_exp(self, x: Value) -> Value:
if _is_floating_point_type(x.type):
return math.ExpOp(x).result
raise NotImplementedError("Unsupported 'exp' operand: {x}")
- def _eval_log(self, x: Value) -> Value:
+ def _arithfn_log(self, x: Value) -> Value:
if _is_floating_point_type(x.type):
return math.LogOp(x).result
raise NotImplementedError("Unsupported 'log' operand: {x}")
- def _eval_sub(self, lhs: Value, rhs: Value) -> Value:
+ def _arithfn_sub(self, lhs: Value, rhs: Value) -> Value:
if _is_floating_point_type(lhs.type):
return arith.SubFOp(lhs, rhs).result
if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
return arith.SubIOp(lhs, rhs).result
raise NotImplementedError("Unsupported 'sub' operand: {lhs}")
- def _eval_mul(self, lhs: Value, rhs: Value) -> Value:
+ def _arithfn_mul(self, lhs: Value, rhs: Value) -> Value:
if _is_floating_point_type(lhs.type):
return arith.MulFOp(lhs, rhs).result
if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
return arith.MulIOp(lhs, rhs).result
raise NotImplementedError("Unsupported 'mul' operand: {lhs}")
- def _eval_max(self, lhs: Value, rhs: Value) -> Value:
+ def _arithfn_max(self, lhs: Value, rhs: Value) -> Value:
if _is_floating_point_type(lhs.type):
return arith.MaxFOp(lhs, rhs).result
if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
return arith.MaxSIOp(lhs, rhs).result
raise NotImplementedError("Unsupported 'max' operand: {lhs}")
- def _eval_max_unsigned(self, lhs: Value, rhs: Value) -> Value:
+ def _arithfn_max_unsigned(self, lhs: Value, rhs: Value) -> Value:
if _is_floating_point_type(lhs.type):
return arith.MaxFOp(lhs, rhs).result
if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
return arith.MaxUIOp(lhs, rhs).result
raise NotImplementedError("Unsupported 'max_unsigned' operand: {lhs}")
- def _eval_min(self, lhs: Value, rhs: Value) -> Value:
+ def _arithfn_min(self, lhs: Value, rhs: Value) -> Value:
if _is_floating_point_type(lhs.type):
return arith.MinFOp(lhs, rhs).result
if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
return arith.MinSIOp(lhs, rhs).result
raise NotImplementedError("Unsupported 'min' operand: {lhs}")
- def _eval_min_unsigned(self, lhs: Value, rhs: Value) -> Value:
+ def _arithfn_min_unsigned(self, lhs: Value, rhs: Value) -> Value:
if _is_floating_point_type(lhs.type):
return arith.MinFOp(lhs, rhs).result
if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/scalar_expr.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/scalar_expr.py
index c6b1b3885425f..2a30e6e78df97 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/lang/scalar_expr.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/scalar_expr.py
@@ -20,7 +20,7 @@
__all__ = [
"ScalarAssign",
- "ScalarApplyFn",
+ "ScalarArithFn",
"ScalarTypeFn",
"ScalarArg",
"ScalarConst",
@@ -29,18 +29,18 @@
]
-class ScalarApplyFn:
- """A type of ScalarExpression that applies a named function to operands."""
+class ScalarArithFn:
+ """A type of ScalarExpression that applies an arithmetic function."""
def __init__(self, fn_name: str, *operands: "ScalarExpression"):
self.fn_name = fn_name
self.operands = operands
def expr(self) -> "ScalarExpression":
- return ScalarExpression(scalar_apply=self)
+ return ScalarExpression(arith_fn=self)
def __repr__(self):
- return f"ScalarApplyFn<{self.fn_name}>({', '.join(self.operands)})"
+ return f"ScalarArithFn<{self.fn_name}>({', '.join(self.operands)})"
class ScalarTypeFn:
@@ -102,7 +102,7 @@ class ScalarExpression(YAMLObject):
"""An expression on scalar values.
Can be one of:
- - ScalarApplyFn
+ - ScalarArithFn
- ScalarTypeFn
- ScalarArg
- ScalarConst
@@ -112,27 +112,27 @@ class ScalarExpression(YAMLObject):
yaml_tag = "!ScalarExpression"
def __init__(self,
- scalar_apply: Optional[ScalarApplyFn] = None,
+ arith_fn: Optional[ScalarArithFn] = None,
type_fn: Optional[ScalarTypeFn] = None,
scalar_arg: Optional[ScalarArg] = None,
scalar_const: Optional[ScalarConst] = None,
scalar_index: Optional[ScalarIndex] = None):
- if (bool(scalar_apply) + bool(type_fn) + bool(scalar_arg) +
- bool(scalar_const) + bool(scalar_index)) != 1:
- raise ValueError("One of 'scalar_apply', 'type_fn', 'scalar_arg', "
+ if (bool(arith_fn) + bool(type_fn) + bool(scalar_arg) + bool(scalar_const) +
+ bool(scalar_index)) != 1:
+ raise ValueError("One of 'arith_fn', 'type_fn', 'scalar_arg', "
"'scalar_const', 'scalar_index', must be specified")
- self.scalar_apply = scalar_apply
+ self.arith_fn = arith_fn
self.type_fn = type_fn
self.scalar_arg = scalar_arg
self.scalar_const = scalar_const
self.scalar_index = scalar_index
def to_yaml_custom_dict(self):
- if self.scalar_apply:
+ if self.arith_fn:
return dict(
- scalar_apply=dict(
- fn_name=self.scalar_apply.fn_name,
- operands=list(self.scalar_apply.operands),
+ arith_fn=dict(
+ fn_name=self.arith_fn.fn_name,
+ operands=list(self.arith_fn.operands),
))
if self.type_fn:
# Note that even though operands must be arity 1, we write it the
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
index 173af1a3fe401..afc078d509c54 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
@@ -665,4 +665,4 @@ def soft_plus_2d(
"""
domain(D.m, D.n)
O[D.m, D.n] = \
- PrimFn.log(TypeFn.cast(U, const(1.0)) + PrimFn.exp(TypeFn.cast(U, I[D.m, D.n])))
+ ArithFn.log(TypeFn.cast(U, const(1.0)) + ArithFn.exp(TypeFn.cast(U, I[D.m, D.n])))
diff --git a/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml b/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml
index a02b1da70a2d4..3634f4f83dd45 100644
--- a/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml
+++ b/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml
@@ -34,7 +34,7 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarAssign
arg: O
value: !ScalarExpression
- scalar_apply:
+ arith_fn:
fn_name: add
operands:
- !ScalarExpression
@@ -89,7 +89,7 @@ structured_op: !LinalgStructuredOpConfig
# IMPL-DAG: Value [[VAL1:[a-z0-9]+]] = helper.typefn__cast(block.getArgument(0).getType(), [[VAL0]]);
# IMPL-DAG: Value [[VAL2:[a-z0-9]+]] = helper.index(1);
# IMPL-DAG: Value [[VAL3:[a-z0-9]+]] = helper.typefn__cast_unsigned(block.getArgument(0).getType(), [[VAL2]]);
-# IMPL-DAG: Value [[VAL4:[a-z0-9]+]] = helper.applyfn__add([[VAL1]], [[VAL3]]);
+# IMPL-DAG: Value [[VAL4:[a-z0-9]+]] = helper.arithfn__add([[VAL1]], [[VAL3]]);
# @linalg_structured_op
diff --git a/mlir/test/python/dialects/linalg/opdsl/assignments.py b/mlir/test/python/dialects/linalg/opdsl/assignments.py
index 8b235dfed38df..eb720b744fed5 100644
--- a/mlir/test/python/dialects/linalg/opdsl/assignments.py
+++ b/mlir/test/python/dialects/linalg/opdsl/assignments.py
@@ -9,10 +9,10 @@
# CHECK: -
# CHECK: arg: C
# CHECK: value:
-# CHECK: scalar_apply:
+# CHECK: arith_fn:
# CHECK: fn_name: add
# CHECK: operands:
-# CHECK: scalar_apply:
+# CHECK: arith_fn:
# CHECK: fn_name: mul
# CHECK: operands:
# CHECK: type_fn:
@@ -36,10 +36,10 @@ def matmul(
# CHECK: assignments:
# CHECK: -
# CHECK: arg: O
-# CHECK: scalar_apply:
+# CHECK: arith_fn:
# CHECK: fn_name: sub
# CHECK: operands:
-# CHECK: scalar_apply:
+# CHECK: arith_fn:
# CHECK: fn_name: add
# CHECK: operands:
# CHECK: type_fn:
@@ -67,7 +67,7 @@ def constants(O=TensorDef(T, S.M, S.K, output=True)):
# CHECK: assignments:
# CHECK: -
# CHECK: arg: O
-# CHECK: scalar_apply:
+# CHECK: arith_fn:
# CHECK: fn_name: add
# CHECK: operands:
# CHECK: scalar_index: 1
diff --git a/mlir/test/python/dialects/linalg/opdsl/emit_misc.py b/mlir/test/python/dialects/linalg/opdsl/emit_misc.py
index 355d00a02ac89..9f4872e33e8d3 100644
--- a/mlir/test/python/dialects/linalg/opdsl/emit_misc.py
+++ b/mlir/test/python/dialects/linalg/opdsl/emit_misc.py
@@ -35,8 +35,8 @@ def fill_rng_poly(
@linalg_structured_op
def soft_plus_poly(
I=TensorDef(T, S.M, S.N), O=TensorDef(U, S.M, S.N, output=True)):
- O[D.m, D.n] = PrimFn.log(
- TypeFn.cast(U, const(1.0)) + TypeFn.cast(U, PrimFn.exp(I[D.m, D.n])))
+ O[D.m, D.n] = ArithFn.log(
+ TypeFn.cast(U, const(1.0)) + TypeFn.cast(U, ArithFn.exp(I[D.m, D.n])))
@linalg_structured_op(op_name="custom_op_name")
diff --git a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
index 38c1ea6b049ce..aeb469b12e30e 100644
--- a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
+++ b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
@@ -82,7 +82,7 @@ struct LinalgIndexingMapsConfig {
struct ScalarExpression;
-struct ScalarApply {
+struct ScalarArithFn {
std::string fnName;
// NOTE: Must be pure heap allocated container (not SmallVector)
// due to recursive data type.
@@ -101,7 +101,7 @@ struct ScalarExpression {
Optional<std::string> arg;
Optional<std::string> constant;
Optional<int64_t> index;
- Optional<ScalarApply> apply;
+ Optional<ScalarArithFn> arithFn;
Optional<ScalarTypeFn> typeFn;
};
@@ -245,9 +245,10 @@ struct MappingTraits<ScalarAssign> {
};
/// A scalar expression (RHS of an assignment). Must be one of:
-/// - `scalar_arg`: Name of an argument to the op.
-/// - `scalar_apply`: Result of evaluating a named function (see
-/// `ScalarApply`).
+/// - `scalar_arg`: An operation argument.
+/// - `scalar_const`: A constant definition.
+/// - `scalar_index`: An iteration index.
+/// - `arith_fn`: A named arithmetic function (see `ScalarArithFn`).
/// - `type_fn`: A named type conversion function (see `ScalarTypeFn`).
template <>
struct MappingTraits<ScalarExpression> {
@@ -255,7 +256,7 @@ struct MappingTraits<ScalarExpression> {
io.mapOptional("scalar_arg", info.arg);
io.mapOptional("scalar_const", info.constant);
io.mapOptional("scalar_index", info.index);
- io.mapOptional("scalar_apply", info.apply);
+ io.mapOptional("arith_fn", info.arithFn);
io.mapOptional("type_fn", info.typeFn);
}
};
@@ -266,8 +267,8 @@ struct MappingTraits<ScalarExpression> {
/// - `add(lhs, rhs)`
/// - `mul(lhs, rhs)`
template <>
-struct MappingTraits<ScalarApply> {
- static void mapping(IO &io, ScalarApply &info) {
+struct MappingTraits<ScalarArithFn> {
+ static void mapping(IO &io, ScalarArithFn &info) {
io.mapRequired("fn_name", info.fnName);
io.mapRequired("operands", info.operands);
}
@@ -944,11 +945,11 @@ void {0}::regionBuilder(ImplicitLocOpBuilder &b, Block &block) {{
cppIdent, *expression.index));
return cppIdent;
}
- if (expression.apply) {
+ if (expression.arithFn) {
// Apply function.
// Recursively generate operands.
SmallVector<std::string> operandCppValues;
- for (ScalarExpression &operand : expression.apply->operands) {
+ for (ScalarExpression &operand : expression.arithFn->operands) {
auto operandCppValue = generateExpression(operand);
if (!operandCppValue)
return None;
@@ -956,8 +957,8 @@ void {0}::regionBuilder(ImplicitLocOpBuilder &b, Block &block) {{
}
std::string cppIdent = llvm::formatv("value{0}", ++localCounter);
stmts.push_back(
- llvm::formatv("Value {0} = helper.applyfn__{1}({2});", cppIdent,
- expression.apply->fnName,
+ llvm::formatv("Value {0} = helper.arithfn__{1}({2});", cppIdent,
+ expression.arithFn->fnName,
interleaveToString(operandCppValues, ", ")));
return cppIdent;
}
More information about the Mlir-commits mailing list