[Mlir-commits] [mlir] [MLIR][Linalg] Ternary Op & Linalg select (PR #91461)
Petr Kurapov
llvmlistbot at llvm.org
Wed May 8 06:07:21 PDT 2024
https://github.com/kurapov-peter updated https://github.com/llvm/llvm-project/pull/91461
>From f3eb7a92f58074565833266598e8f13438f97c17 Mon Sep 17 00:00:00 2001
From: Renato Golin <rengolin at systemcall.eu>
Date: Mon, 29 Apr 2024 23:41:29 +0100
Subject: [PATCH 1/5] [MLIR][Linalg] Add ternary select named op (in progress)
---
.../mlir/Dialect/Linalg/IR/LinalgBase.td | 3 +
.../mlir/Dialect/Linalg/IR/LinalgEnums.td | 6 ++
.../Linalg/IR/LinalgNamedStructuredOps.yaml | 57 +++++++++++++++++++
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 16 ++++++
.../linalg/opdsl/lang/comprehension.py | 56 +++++++++++++++++-
.../dialects/linalg/opdsl/lang/emitter.py | 7 +++
.../linalg/opdsl/ops/core_named_ops.py | 20 +++++++
7 files changed, 163 insertions(+), 2 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td
index e87e8b5600107..73f984dc072d3 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td
@@ -68,6 +68,9 @@ def UnaryFnAttr : EnumAttr<Linalg_Dialect, UnaryFn, "unary_fn"> {
def BinaryFnAttr : EnumAttr<Linalg_Dialect, BinaryFn, "binary_fn"> {
let assemblyFormat = "`<` $value `>`";
}
+def TernaryFnAttr : EnumAttr<Linalg_Dialect, TernaryFn, "ternary_fn"> {
+ let assemblyFormat = "`<` $value `>`";
+}
def TypeFnAttr : EnumAttr<Linalg_Dialect, TypeFn, "type_fn"> {
let assemblyFormat = "`<` $value `>`";
}
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgEnums.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgEnums.td
index 6b4b073fc6724..e615876a95d05 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgEnums.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgEnums.td
@@ -49,6 +49,12 @@ def BinaryFn : I32EnumAttr<"BinaryFn", "", [
let genSpecializedAttr = 0;
let cppNamespace = "::mlir::linalg";
}
+def TernaryFn : I32EnumAttr<"TernaryFn", "", [
+ I32EnumAttrCase<"select", 0>
+]> {
+ let genSpecializedAttr = 0;
+ let cppNamespace = "::mlir::linalg";
+}
def TypeFn : I32EnumAttr<"TypeFn", "", [
I32EnumAttrCase<"cast_signed", 0>,
I32EnumAttrCase<"cast_unsigned", 1>
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
index 584bfcd8b59dc..7350271aa3829 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
@@ -1008,6 +1008,63 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarExpression
scalar_arg: rhs
--- !LinalgOpConfig
+metadata: !LinalgOpMetadata
+ name: select
+ cpp_class_name: SelectOp
+ doc: |-
+ Chooses one value based on a binary condition supplied as its first operand.
+
+ The shapes and element types must be identical. The appropriate casts,
+ broadcasts and reductions should be done previously to calling this op.
+
+ This means reduction/broadcast/element cast semantics is explicit. Further
+ passes can take that into account when lowering this code. For example,
+ a `linalg.broadcast` + `linalg.select` sequence can be lowered to a
+ `linalg.generic` with different affine maps for the two operands.
+structured_op: !LinalgStructuredOpConfig
+ args:
+ - !LinalgOperandDefConfig
+ name: cond
+ kind: input_tensor
+ type_var: T1
+ shape_map: affine_map<() -> ()>
+ - !LinalgOperandDefConfig
+ name: lhs
+ kind: input_tensor
+ type_var: T1
+ shape_map: affine_map<() -> ()>
+ - !LinalgOperandDefConfig
+ name: rhs
+ kind: input_tensor
+ type_var: T1
+ shape_map: affine_map<() -> ()>
+ - !LinalgOperandDefConfig
+ name: O
+ kind: output_tensor
+ type_var: T1
+ shape_map: affine_map<() -> ()>
+ indexing_maps: !LinalgIndexingMapsConfig
+ static_indexing_maps:
+ - affine_map<() -> ()>
+ - affine_map<() -> ()>
+ - affine_map<() -> ()>
+ - affine_map<() -> ()>
+ iterator_types: []
+ assignments:
+ - !ScalarAssign
+ arg: O
+ value: !ScalarExpression
+ scalar_fn:
+ kind: ternary
+ fn_name: select
+ operands:
+ - !ScalarExpression
+ scalar_arg: cond
+ - !ScalarExpression
+ scalar_arg: lhs
+ - !ScalarExpression
+ scalar_arg: rhs
+--- !LinalgOpConfig
metadata: !LinalgOpMetadata
name: matmul
cpp_class_name: MatmulOp
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index e5f83331baf81..4933d71ac7fb0 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -492,6 +492,22 @@ class RegionBuilderHelper {
llvm_unreachable("unsupported binary function");
}
+ // Build the ternary functions defined by OpDSL.
+ Value buildTernaryFn(TernaryFn ternaryFn, Value arg0, Value arg1, Value arg2) {
+ bool headBool = isInteger(arg0) && arg0.getType().getIntOrFloatBitWidth() == 1;
+ bool tailFloatingPoint = isFloatingPoint(arg0) && isFloatingPoint(arg1) && isFloatingPoint(arg2);
+ bool tailnteger = isInteger(arg0) && isInteger(arg1) && isInteger(arg1);
+ OpBuilder::InsertionGuard g(builder);
+ builder.setInsertionPointToEnd(&block);
+ switch (ternaryFn) {
+ case TernaryFn::select:
+ if (!headBool && !(tailFloatingPoint || tailInteger))
+ llvm_unreachable("unsupported non numeric type");
+ return builder.create<arith::SelectOp>(arg0.getLoc(), arg0, arg1, arg2);
+ }
+ llvm_unreachable("unsupported ternary function");
+ }
+
// Build the type functions defined by OpDSL.
Value buildTypeFn(TypeFn typeFn, Type toType, Value operand) {
switch (typeFn) {
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py
index bb43ebf2b6923..880dcb7250b96 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py
@@ -262,7 +262,8 @@ def __repr__(self):
class FunctionKind(Enum):
UNARY = 0
BINARY = 1
- TYPE = 2
+ TERNARY = 2
+ TYPE = 3
class UnaryFnType:
@@ -339,6 +340,30 @@ class BinaryFn:
powf = BinaryFnType("powf")
+class TernaryFnType:
+ """Ternary function.
+
+ A bterary function takes three tensor expressions and returns the
+ function evaluation result.
+ """
+
+ def __init__(self, fn_name: str):
+ self.fn_name = fn_name
+
+ def __call__(self, arg0: TensorExpression, arg1: TensorExpression, arg2: TensorExpression) -> "TensorFn":
+ return TensorFn(FunctionKind.TERNARY, self.fn_name, None, None, [arg0, arg1, arg2])
+
+ def __repr__(self):
+ return f"{self.fn_name}"
+
+
+class TernaryFn:
+ """Ternary function namespace.
+ """
+
+ select = TernaryFnType("select")
+
+
class TypeFnType:
"""Type conversion function.
@@ -437,7 +462,8 @@ class OperandKind(Enum):
INDEX_ATTR = 3
UNARY_FN_ATTR = 4
BINARY_FN_ATTR = 5
- TYPE_FN_ATTR = 6
+ TERNARY_FN_ATTR = 6
+ TYPE_FN_ATTR = 7
class OperandDef:
@@ -489,6 +515,7 @@ def is_attribute(self) -> bool:
self.kind == OperandKind.INDEX_ATTR
or self.kind == OperandKind.UNARY_FN_ATTR
or self.kind == OperandKind.BINARY_FN_ATTR
+ or self.kind == OperandKind.TERNARY_FN_ATTR
or self.kind == OperandKind.TYPE_FN_ATTR
)
@@ -670,6 +697,31 @@ def __getitem__(self, reduce_dims: Tuple[DimDef]) -> ReduceFnUse:
return ReduceFnUse(None, self, *reduce_dims)
+class TernaryFnAttrDef:
+ """Ternary function attribute definition.
+
+ Ternary function attributes provide a way to make the arithmetic computation
+ parametrizable. Every attribute specifies a default Ternary function
+ that may be overwritten at operation instantiation time.
+ """
+
+ def __init__(self, default: "TernaryFnType"):
+ if not isinstance(default, TernaryFnType):
+ raise ValueError(
+ f"TernaryFnAttrDef requires default of type TernaryFnType "
+ f"but got {default}"
+ )
+ self.operand_def = OperandDef(
+ OperandKind.TERNARY_FN_ATTR, default_fn=default.fn_name
+ )
+
+ def __call__(self, arg0: TensorExpression, arg1: TensorExpression) -> TensorFn:
+ return TensorFn(FunctionKind.TERNARY, None, self.operand_def, None, [arg0, arg1])
+
+ def __getitem__(self, reduce_dims: Tuple[DimDef]) -> ReduceFnUse:
+ return ReduceFnUse(None, self, *reduce_dims)
+
+
class TypeFnAttrDef:
"""Type conversion function attribute definition.
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
index f91fc8b716008..845b533db52a9 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
@@ -60,6 +60,7 @@ def prepare_common_structured_op(
in [
OperandKind.UNARY_FN_ATTR,
OperandKind.BINARY_FN_ATTR,
+ OperandKind.TERNARY_FN_ATTR,
OperandKind.TYPE_FN_ATTR,
]
]
@@ -180,6 +181,12 @@ def prepare_common_structured_op(
f"Attribute {fn_attr.name} needs to be of type "
f"BinaryFnType but got {type(attr_val)}"
)
+ elif attr_kind == OperandKind.TERNARY_FN_ATTR:
+ if not isinstance(fn, TernaryFnType):
+ raise ValueError(
+ f"Attribute {fn_attr.name} needs to be of type "
+ f"TernaryFnType but got {type(attr_val)}"
+ )
else:
if not isinstance(fn, TypeFnType):
raise ValueError(
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
index ca2bb0c5f7f8a..3cf3e9de0136f 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
@@ -351,6 +351,26 @@ def powf(
O[None] = BinaryFn.powf(lhs[None], rhs[None])
+@linalg_structured_op
+def selectf(
+ cond=TensorDef(bool),
+ lhs=TensorDef(T1),
+ rhs=TensorDef(T1),
+ O=TensorDef(T1, output=True),
+):
+ """Chooses one value based on a binary condition supplied as its first operand.
+
+ The shapes and element types must be identical. The appropriate casts,
+ broadcasts and reductions should be done previously to calling this op.
+
+ This means reduction/broadcast/element cast semantics is explicit. Further
+ passes can take that into account when lowering this code. For example,
+ a `linalg.broadcast` + `linalg.select` sequence can be lowered to a
+ `linalg.generic` with different affine maps for the two operands.
+ """
+ O[None] = TernaryFn.select(cond[None], lhs[None], rhs[None])
+
+
@linalg_structured_op
def matmul(
A=TensorDef(T1, S.M, S.K),
>From 2c13925e0be5d3572aa1100d463fa80d691d8c29 Mon Sep 17 00:00:00 2001
From: Petr Kurapov <petr.a.kurapov at intel.com>
Date: Tue, 7 May 2024 18:15:06 +0000
Subject: [PATCH 2/5] [MLIR][Linalg] Fix the build & add ternary into
appropriate enums
---
.../Linalg/IR/LinalgNamedStructuredOps.yaml | 45 +++----------------
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 11 +++--
.../linalg/opdsl/ops/core_named_ops.py | 2 +-
.../mlir-linalg-ods-yaml-gen.cpp | 10 ++++-
4 files changed, 22 insertions(+), 46 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
index 7350271aa3829..2e4af321d1363 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
@@ -304,41 +304,6 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarExpression
scalar_arg: I
--- !LinalgOpConfig
-metadata: !LinalgOpMetadata
- name: reciprocal
- cpp_class_name: ReciprocalOp
- doc: |-
- Applies reciprocal(x) elementwise.
-
- No numeric casting is performed on the input operand.
-structured_op: !LinalgStructuredOpConfig
- args:
- - !LinalgOperandDefConfig
- name: I
- kind: input_tensor
- type_var: T1
- shape_map: affine_map<() -> ()>
- - !LinalgOperandDefConfig
- name: O
- kind: output_tensor
- type_var: T1
- shape_map: affine_map<() -> ()>
- indexing_maps: !LinalgIndexingMapsConfig
- static_indexing_maps:
- - affine_map<() -> ()>
- - affine_map<() -> ()>
- iterator_types: []
- assignments:
- - !ScalarAssign
- arg: O
- value: !ScalarExpression
- scalar_fn:
- kind: unary
- fn_name: reciprocal
- operands:
- - !ScalarExpression
- scalar_arg: I
---- !LinalgOpConfig
metadata: !LinalgOpMetadata
name: round
cpp_class_name: RoundOp
@@ -516,7 +481,7 @@ structured_op: !LinalgStructuredOpConfig
--- !LinalgOpConfig
metadata: !LinalgOpMetadata
name: erf
- cpp_class_name: erfOp
+ cpp_class_name: ErfOp
doc: |-
Applies erf(x) elementwise.
@@ -959,7 +924,7 @@ structured_op: !LinalgStructuredOpConfig
--- !LinalgOpConfig
metadata: !LinalgOpMetadata
name: powf
- cpp_class_name: PowFOp
+ cpp_class_name: PowfOp
doc: |-
Takes the powf(lhs, rhs) between two inputs, elementwise. For powf(arg, 2) use `linalg.square`.
@@ -1009,8 +974,8 @@ structured_op: !LinalgStructuredOpConfig
scalar_arg: rhs
--- !LinalgOpConfig
metadata: !LinalgOpMetadata
- name: select
- cpp_class_name: SelectOp
+ name: selectf
+ cpp_class_name: SelectfOp
doc: |-
Chooses one value based on a binary condition supplied as its first operand.
@@ -1026,7 +991,7 @@ structured_op: !LinalgStructuredOpConfig
- !LinalgOperandDefConfig
name: cond
kind: input_tensor
- type_var: T1
+ type_var: U
shape_map: affine_map<() -> ()>
- !LinalgOperandDefConfig
name: lhs
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 4933d71ac7fb0..6a5f25a7605f1 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -493,10 +493,13 @@ class RegionBuilderHelper {
}
// Build the ternary functions defined by OpDSL.
- Value buildTernaryFn(TernaryFn ternaryFn, Value arg0, Value arg1, Value arg2) {
- bool headBool = isInteger(arg0) && arg0.getType().getIntOrFloatBitWidth() == 1;
- bool tailFloatingPoint = isFloatingPoint(arg0) && isFloatingPoint(arg1) && isFloatingPoint(arg2);
- bool tailnteger = isInteger(arg0) && isInteger(arg1) && isInteger(arg1);
+ Value buildTernaryFn(TernaryFn ternaryFn, Value arg0, Value arg1,
+ Value arg2) {
+ bool headBool =
+ isInteger(arg0) && arg0.getType().getIntOrFloatBitWidth() == 1;
+ bool tailFloatingPoint =
+ isFloatingPoint(arg0) && isFloatingPoint(arg1) && isFloatingPoint(arg2);
+ bool tailInteger = isInteger(arg0) && isInteger(arg1) && isInteger(arg1);
OpBuilder::InsertionGuard g(builder);
builder.setInsertionPointToEnd(&block);
switch (ternaryFn) {
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
index 3cf3e9de0136f..04d10e7e9e7f5 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
@@ -353,7 +353,7 @@ def powf(
@linalg_structured_op
def selectf(
- cond=TensorDef(bool),
+ cond=TensorDef(U),
lhs=TensorDef(T1),
rhs=TensorDef(T1),
O=TensorDef(T1, output=True),
diff --git a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
index fe6ad15041126..37240164c377e 100644
--- a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
+++ b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
@@ -70,6 +70,7 @@ enum class LinalgOperandDefKind {
IndexAttr,
UnaryFnAttr,
BinaryFnAttr,
+ TernaryFnAttr,
TypeFnAttr
};
@@ -94,7 +95,7 @@ struct LinalgIndexingMapsConfig {
struct ScalarExpression;
-enum class ScalarFnKind { Unary, Binary, Type };
+enum class ScalarFnKind { Unary, Binary, Ternary, Type };
struct ScalarFn {
ScalarFnKind kind;
@@ -214,6 +215,7 @@ struct ScalarEnumerationTraits<LinalgOperandDefKind> {
io.enumCase(value, "index_attr", LinalgOperandDefKind::IndexAttr);
io.enumCase(value, "unary_fn_attr", LinalgOperandDefKind::UnaryFnAttr);
io.enumCase(value, "binary_fn_attr", LinalgOperandDefKind::BinaryFnAttr);
+ io.enumCase(value, "ternary_fn_attr", LinalgOperandDefKind::TernaryFnAttr);
io.enumCase(value, "type_fn_attr", LinalgOperandDefKind::TypeFnAttr);
}
};
@@ -284,6 +286,7 @@ struct ScalarEnumerationTraits<ScalarFnKind> {
static void enumeration(IO &io, ScalarFnKind &value) {
io.enumCase(value, "unary", ScalarFnKind::Unary);
io.enumCase(value, "binary", ScalarFnKind::Binary);
+ io.enumCase(value, "ternary", ScalarFnKind::Ternary);
io.enumCase(value, "type", ScalarFnKind::Type);
}
};
@@ -441,6 +444,7 @@ static ScalarAssign *findAssignment(StringRef name,
static bool isFunctionAttribute(LinalgOperandDefKind kind) {
return kind == LinalgOperandDefKind::UnaryFnAttr ||
kind == LinalgOperandDefKind::BinaryFnAttr ||
+ kind == LinalgOperandDefKind::TernaryFnAttr ||
kind == LinalgOperandDefKind::TypeFnAttr;
}
@@ -456,6 +460,8 @@ std::string convertOperandKindToEnumName(LinalgOperandDefKind kind) {
return std::string("UnaryFn");
case LinalgOperandDefKind::BinaryFnAttr:
return std::string("BinaryFn");
+ case LinalgOperandDefKind::TernaryFnAttr:
+ return std::string("TernaryFn");
case LinalgOperandDefKind::TypeFnAttr:
return std::string("TypeFn");
default:
@@ -471,6 +477,8 @@ std::string convertFunctionKindToEnumName(ScalarFnKind kind) {
return std::string("UnaryFn");
case ScalarFnKind::Binary:
return std::string("BinaryFn");
+ case ScalarFnKind::Ternary:
+ return std::string("TernaryFn");
case ScalarFnKind::Type:
return std::string("TypeFn");
}
>From 334e31622bea42341420bb59c6e0cbb2d71887af Mon Sep 17 00:00:00 2001
From: Petr Kurapov <petr.a.kurapov at intel.com>
Date: Tue, 7 May 2024 19:24:08 +0000
Subject: [PATCH 3/5] [MLIR][Linalg] add basic select test
---
.../Linalg/IR/LinalgNamedStructuredOps.yaml | 43 +++++++++++++++++--
.../linalg/opdsl/ops/core_named_ops.py | 2 +-
.../Dialect/Linalg/generalize-named-ops.mlir | 25 +++++++++++
mlir/test/Dialect/Linalg/named-ops-fail.mlir | 16 +++++++
mlir/test/Dialect/Linalg/named-ops.mlir | 22 ++++++++++
5 files changed, 103 insertions(+), 5 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
index 2e4af321d1363..eb7dd37010a67 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
@@ -304,6 +304,41 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarExpression
scalar_arg: I
--- !LinalgOpConfig
+metadata: !LinalgOpMetadata
+ name: reciprocal
+ cpp_class_name: ReciprocalOp
+ doc: |-
+ Applies reciprocal(x) elementwise.
+
+ No numeric casting is performed on the input operand.
+structured_op: !LinalgStructuredOpConfig
+ args:
+ - !LinalgOperandDefConfig
+ name: I
+ kind: input_tensor
+ type_var: T1
+ shape_map: affine_map<() -> ()>
+ - !LinalgOperandDefConfig
+ name: O
+ kind: output_tensor
+ type_var: T1
+ shape_map: affine_map<() -> ()>
+ indexing_maps: !LinalgIndexingMapsConfig
+ static_indexing_maps:
+ - affine_map<() -> ()>
+ - affine_map<() -> ()>
+ iterator_types: []
+ assignments:
+ - !ScalarAssign
+ arg: O
+ value: !ScalarExpression
+ scalar_fn:
+ kind: unary
+ fn_name: reciprocal
+ operands:
+ - !ScalarExpression
+ scalar_arg: I
+--- !LinalgOpConfig
metadata: !LinalgOpMetadata
name: round
cpp_class_name: RoundOp
@@ -481,7 +516,7 @@ structured_op: !LinalgStructuredOpConfig
--- !LinalgOpConfig
metadata: !LinalgOpMetadata
name: erf
- cpp_class_name: ErfOp
+ cpp_class_name: erfOp
doc: |-
Applies erf(x) elementwise.
@@ -924,7 +959,7 @@ structured_op: !LinalgStructuredOpConfig
--- !LinalgOpConfig
metadata: !LinalgOpMetadata
name: powf
- cpp_class_name: PowfOp
+ cpp_class_name: PowFOp
doc: |-
Takes the powf(lhs, rhs) between two inputs, elementwise. For powf(arg, 2) use `linalg.square`.
@@ -974,8 +1009,8 @@ structured_op: !LinalgStructuredOpConfig
scalar_arg: rhs
--- !LinalgOpConfig
metadata: !LinalgOpMetadata
- name: selectf
- cpp_class_name: SelectfOp
+ name: select
+ cpp_class_name: SelectOp
doc: |-
Chooses one value based on a binary condition supplied as its first operand.
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
index 04d10e7e9e7f5..d73428a0f4df3 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
@@ -352,7 +352,7 @@ def powf(
@linalg_structured_op
-def selectf(
+def select(
cond=TensorDef(U),
lhs=TensorDef(T1),
rhs=TensorDef(T1),
diff --git a/mlir/test/Dialect/Linalg/generalize-named-ops.mlir b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
index 667ea3c18c8ad..4f43ec2c9e1ce 100644
--- a/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
@@ -791,6 +791,31 @@ func.func @generalize_powf(%lhs: memref<7x14x21xf32>, %rhs: memref<7x14x21xf32>,
// -----
+func.func @generalize_select(%cond: memref<7x14x21xi1>, %lhs: memref<7x14x21xf32>, %rhs: memref<7x14x21xf32>,
+ %out: memref<7x14x21xf32>) {
+ linalg.select ins(%cond, %lhs, %rhs: memref<7x14x21xi1>, memref<7x14x21xf32>, memref<7x14x21xf32>)
+ outs(%out: memref<7x14x21xf32>)
+ return
+}
+
+// CHECK: #[[MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+
+// CHECK: func @generalize_select
+// CHECK-SAME: (%[[COND:.+]]: memref<7x14x21xi1>, %[[LHS:.+]]: memref<7x14x21xf32>, %[[RHS:.+]]: memref<7x14x21xf32>,
+// CHECK-SAME: %[[OUT:.+]]: memref<7x14x21xf32>)
+
+// CHECK: linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP]], #[[MAP]], #[[MAP]], #[[MAP]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]}
+// CHECK-SAME: ins(%[[COND]], %[[LHS]], %[[RHS]] : memref<7x14x21xi1>, memref<7x14x21xf32>, memref<7x14x21xf32>)
+// CHECK-SAME: outs(%[[OUT]] : memref<7x14x21xf32>)
+
+// CHECK: ^{{.+}}(%[[BBARG0:.+]]: i1, %[[BBARG1:.+]]: f32, %[[BBARG2:.+]]: f32, %[[BBARG3:.+]]: f32)
+// CHECK-NEXT: %[[select:.+]] = arith.select %[[BBARG0]], %[[BBARG1]], %[[BBARG2]] : f32
+// CHECK-NEXT: linalg.yield %[[select]] : f32
+
+
+// -----
// CHECK-LABEL: func @fill_tensor
func.func @fill_tensor(%f: f32, %v: vector<2x4xf32>) -> (tensor<f32>, tensor<vector<2x4xf32>>) {
diff --git a/mlir/test/Dialect/Linalg/named-ops-fail.mlir b/mlir/test/Dialect/Linalg/named-ops-fail.mlir
index e92a77aa7ad05..552a0abaa797c 100644
--- a/mlir/test/Dialect/Linalg/named-ops-fail.mlir
+++ b/mlir/test/Dialect/Linalg/named-ops-fail.mlir
@@ -334,3 +334,19 @@ func.func @powf_broadcast(%arg0: memref<8x16xf32>, %arg1: memref<4x8x16xf32>, %a
return
}
+// -----
+
+func.func @select_type_cast(%arg0: memref<4x8x16xi1>, %arg1: memref<4x8x16xf16>, %arg2: memref<4x8x16xf32>, %arg3: memref<4x8x16xf32>) {
+ // CHECK: op failed to verify that all of {true_value, false_value, result} have same type
+ linalg.select ins(%arg0, %arg1, %arg2 : memref<4x8x16xi1>, memref<4x8x16xf16>, memref<4x8x16xf32>) outs(%arg3: memref<4x8x16xf32>)
+ return
+}
+
+// -----
+
+func.func @select_wrong_condition_type(%arg0: memref<4x8x16xf32>, %arg1: memref<4x8x16xf32>, %arg2: memref<4x8x16xf32>, %arg3: memref<4x8x16xf32>) {
+ // CHECK: op operand #0 must be bool-like, but got 'f32'
+ linalg.select ins(%arg0, %arg1, %arg2 : memref<4x8x16xf32>, memref<4x8x16xf32>, memref<4x8x16xf32>) outs(%arg3: memref<4x8x16xf32>)
+ return
+}
+
diff --git a/mlir/test/Dialect/Linalg/named-ops.mlir b/mlir/test/Dialect/Linalg/named-ops.mlir
index fefe5578947f0..cecd0033b7765 100644
--- a/mlir/test/Dialect/Linalg/named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/named-ops.mlir
@@ -1924,3 +1924,25 @@ func.func @fill_tensor(%f: f32, %v: vector<2x4xf32>) -> (tensor<f32>, tensor<vec
%1 = linalg.fill ins(%v : vector<2x4xf32>) outs(%e1 : tensor<vector<2x4xf32>>) -> tensor<vector<2x4xf32>>
return %0, %1: tensor<f32>, tensor<vector<2x4xf32>>
}
+
+// -----
+
+// CHECK-LABEL: func @select_dynamic
+func.func @select_dynamic(%arg0: memref<?x?x?xi1>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?x?xf32>, %arg3: memref<?x?x?xf32>) {
+ // CHECK: linalg.select
+ // CHECK-SAME: ins(%{{.+}}, %{{.+}}, %{{.+}} : memref<?x?x?xi1>, memref<?x?x?xf32>, memref<?x?x?xf32>)
+ // CHECK-SAME: outs(%{{.+}} : memref<?x?x?xf32>)
+ linalg.select ins(%arg0, %arg1, %arg2 : memref<?x?x?xi1>, memref<?x?x?xf32>, memref<?x?x?xf32>) outs(%arg3: memref<?x?x?xf32>)
+ return
+}
+
+// -----
+
+// CHECK-LABEL: func @select_static
+func.func @select_static(%arg0: memref<4x8x16xi1>, %arg1: memref<4x8x16xf32>, %arg2: memref<4x8x16xf32>, %arg3: memref<4x8x16xf32>) {
+ // CHECK: linalg.select
+ // CHECK-SAME: ins(%{{.+}}, %{{.+}}, %{{.+}} : memref<4x8x16xi1>, memref<4x8x16xf32>, memref<4x8x16xf32>)
+ // CHECK-SAME: outs(%{{.+}} : memref<4x8x16xf32>)
+ linalg.select ins(%arg0, %arg1, %arg2 : memref<4x8x16xi1>, memref<4x8x16xf32>, memref<4x8x16xf32>) outs(%arg3: memref<4x8x16xf32>)
+ return
+}
>From 40d9d3d3fb31fc3441a3ceabc6d4fa49b6f64a83 Mon Sep 17 00:00:00 2001
From: Petr Kurapov <petr.kurapov at gmail.com>
Date: Wed, 8 May 2024 14:57:48 +0200
Subject: [PATCH 4/5] Update
mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py
Co-authored-by: Maksim Levental <maksim.levental at gmail.com>
---
mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py
index 880dcb7250b96..34eea0d0fc8a0 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py
@@ -343,7 +343,7 @@ class BinaryFn:
class TernaryFnType:
"""Ternary function.
- A bterary function takes three tensor expressions and returns the
+ A ternary function takes three tensor expressions and returns the
function evaluation result.
"""
>From b56a72fd026c201c8a76e430e5f716041d7c19af Mon Sep 17 00:00:00 2001
From: Petr Kurapov <petr.a.kurapov at intel.com>
Date: Wed, 8 May 2024 13:07:01 +0000
Subject: [PATCH 5/5] Fix Python formatting
---
.../dialects/linalg/opdsl/lang/comprehension.py | 15 ++++++++++-----
1 file changed, 10 insertions(+), 5 deletions(-)
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py
index 34eea0d0fc8a0..1a198fc5ec6f9 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py
@@ -350,16 +350,19 @@ class TernaryFnType:
def __init__(self, fn_name: str):
self.fn_name = fn_name
- def __call__(self, arg0: TensorExpression, arg1: TensorExpression, arg2: TensorExpression) -> "TensorFn":
- return TensorFn(FunctionKind.TERNARY, self.fn_name, None, None, [arg0, arg1, arg2])
+ def __call__(
+ self, arg0: TensorExpression, arg1: TensorExpression, arg2: TensorExpression
+ ) -> "TensorFn":
+ return TensorFn(
+ FunctionKind.TERNARY, self.fn_name, None, None, [arg0, arg1, arg2]
+ )
def __repr__(self):
return f"{self.fn_name}"
class TernaryFn:
- """Ternary function namespace.
- """
+ """Ternary function namespace."""
select = TernaryFnType("select")
@@ -716,7 +719,9 @@ def __init__(self, default: "TernaryFnType"):
)
def __call__(self, arg0: TensorExpression, arg1: TensorExpression) -> TensorFn:
- return TensorFn(FunctionKind.TERNARY, None, self.operand_def, None, [arg0, arg1])
+ return TensorFn(
+ FunctionKind.TERNARY, None, self.operand_def, None, [arg0, arg1]
+ )
def __getitem__(self, reduce_dims: Tuple[DimDef]) -> ReduceFnUse:
return ReduceFnUse(None, self, *reduce_dims)
More information about the Mlir-commits
mailing list