[Mlir-commits] [mlir] [MLIR][Linalg] Ternary Op & Linalg select (PR #91461)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed May 8 04:36:44 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-core
Author: Petr Kurapov (kurapov-peter)
<details>
<summary>Changes</summary>
Following #90236, this adds a `select` op to linalg that maps to `arith.select`. No implicit type casting is performed.
OpDSL doesn't expose a type restriction for bool, but I saw no reason to add one; instead, the condition uses a separate symbolic type and its semantics are checked in the builder.
---
Full diff: https://github.com/llvm/llvm-project/pull/91461.diff
11 Files Affected:
- (modified) mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td (+3)
- (modified) mlir/include/mlir/Dialect/Linalg/IR/LinalgEnums.td (+6)
- (modified) mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml (+57)
- (modified) mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp (+19)
- (modified) mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py (+54-2)
- (modified) mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py (+7)
- (modified) mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py (+20)
- (modified) mlir/test/Dialect/Linalg/generalize-named-ops.mlir (+25)
- (modified) mlir/test/Dialect/Linalg/named-ops-fail.mlir (+16)
- (modified) mlir/test/Dialect/Linalg/named-ops.mlir (+22)
- (modified) mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp (+9-1)
``````````diff
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td
index e87e8b5600107..73f984dc072d3 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td
@@ -68,6 +68,9 @@ def UnaryFnAttr : EnumAttr<Linalg_Dialect, UnaryFn, "unary_fn"> {
def BinaryFnAttr : EnumAttr<Linalg_Dialect, BinaryFn, "binary_fn"> {
let assemblyFormat = "`<` $value `>`";
}
+def TernaryFnAttr : EnumAttr<Linalg_Dialect, TernaryFn, "ternary_fn"> {
+ let assemblyFormat = "`<` $value `>`";
+}
def TypeFnAttr : EnumAttr<Linalg_Dialect, TypeFn, "type_fn"> {
let assemblyFormat = "`<` $value `>`";
}
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgEnums.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgEnums.td
index 6b4b073fc6724..e615876a95d05 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgEnums.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgEnums.td
@@ -49,6 +49,12 @@ def BinaryFn : I32EnumAttr<"BinaryFn", "", [
let genSpecializedAttr = 0;
let cppNamespace = "::mlir::linalg";
}
+def TernaryFn : I32EnumAttr<"TernaryFn", "", [
+ I32EnumAttrCase<"select", 0>
+]> {
+ let genSpecializedAttr = 0;
+ let cppNamespace = "::mlir::linalg";
+}
def TypeFn : I32EnumAttr<"TypeFn", "", [
I32EnumAttrCase<"cast_signed", 0>,
I32EnumAttrCase<"cast_unsigned", 1>
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
index 584bfcd8b59dc..eb7dd37010a67 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
@@ -1008,6 +1008,63 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarExpression
scalar_arg: rhs
--- !LinalgOpConfig
+metadata: !LinalgOpMetadata
+ name: select
+ cpp_class_name: SelectOp
+ doc: |-
+ Chooses one value based on a binary condition supplied as its first operand.
+
+ The shapes and element types must be identical. The appropriate casts,
+ broadcasts and reductions should be done previously to calling this op.
+
+ This means reduction/broadcast/element cast semantics is explicit. Further
+ passes can take that into account when lowering this code. For example,
+ a `linalg.broadcast` + `linalg.select` sequence can be lowered to a
+ `linalg.generic` with different affine maps for the two operands.
+structured_op: !LinalgStructuredOpConfig
+ args:
+ - !LinalgOperandDefConfig
+ name: cond
+ kind: input_tensor
+ type_var: U
+ shape_map: affine_map<() -> ()>
+ - !LinalgOperandDefConfig
+ name: lhs
+ kind: input_tensor
+ type_var: T1
+ shape_map: affine_map<() -> ()>
+ - !LinalgOperandDefConfig
+ name: rhs
+ kind: input_tensor
+ type_var: T1
+ shape_map: affine_map<() -> ()>
+ - !LinalgOperandDefConfig
+ name: O
+ kind: output_tensor
+ type_var: T1
+ shape_map: affine_map<() -> ()>
+ indexing_maps: !LinalgIndexingMapsConfig
+ static_indexing_maps:
+ - affine_map<() -> ()>
+ - affine_map<() -> ()>
+ - affine_map<() -> ()>
+ - affine_map<() -> ()>
+ iterator_types: []
+ assignments:
+ - !ScalarAssign
+ arg: O
+ value: !ScalarExpression
+ scalar_fn:
+ kind: ternary
+ fn_name: select
+ operands:
+ - !ScalarExpression
+ scalar_arg: cond
+ - !ScalarExpression
+ scalar_arg: lhs
+ - !ScalarExpression
+ scalar_arg: rhs
+--- !LinalgOpConfig
metadata: !LinalgOpMetadata
name: matmul
cpp_class_name: MatmulOp
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index e5f83331baf81..6a5f25a7605f1 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -492,6 +492,25 @@ class RegionBuilderHelper {
llvm_unreachable("unsupported binary function");
}
+ // Build the ternary functions defined by OpDSL (arg0 is the select condition).
+ Value buildTernaryFn(TernaryFn ternaryFn, Value arg0, Value arg1,
+ Value arg2) {
+ bool headBool =
+ isInteger(arg0) && arg0.getType().getIntOrFloatBitWidth() == 1;
+ // "tail" checks cover only the selected values (arg1/arg2), not the condition.
+ bool tailFloatingPoint = isFloatingPoint(arg1) && isFloatingPoint(arg2);
+ bool tailInteger = isInteger(arg1) && isInteger(arg2);
+ OpBuilder::InsertionGuard g(builder);
+ builder.setInsertionPointToEnd(&block);
+ switch (ternaryFn) {
+ case TernaryFn::select:
+ if (!headBool && !(tailFloatingPoint || tailInteger))
+ llvm_unreachable("unsupported non numeric type");
+ return builder.create<arith::SelectOp>(arg0.getLoc(), arg0, arg1, arg2);
+ }
+ llvm_unreachable("unsupported ternary function");
+ }
+
// Build the type functions defined by OpDSL.
Value buildTypeFn(TypeFn typeFn, Type toType, Value operand) {
switch (typeFn) {
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py
index bb43ebf2b6923..880dcb7250b96 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py
@@ -262,7 +262,8 @@ def __repr__(self):
class FunctionKind(Enum):
UNARY = 0
BINARY = 1
- TYPE = 2
+ TERNARY = 2
+ TYPE = 3
class UnaryFnType:
@@ -339,6 +340,30 @@ class BinaryFn:
powf = BinaryFnType("powf")
+class TernaryFnType:
+ """Ternary function.
+
+ A ternary function takes three tensor expressions and returns the
+ function evaluation result.
+ """
+
+ def __init__(self, fn_name: str):
+ self.fn_name = fn_name
+
+ def __call__(self, arg0: TensorExpression, arg1: TensorExpression, arg2: TensorExpression) -> "TensorFn":
+ return TensorFn(FunctionKind.TERNARY, self.fn_name, None, None, [arg0, arg1, arg2])
+
+ def __repr__(self):
+ return f"{self.fn_name}"
+
+
+class TernaryFn:
+ """Ternary function namespace.
+ """
+
+ select = TernaryFnType("select")
+
+
class TypeFnType:
"""Type conversion function.
@@ -437,7 +462,8 @@ class OperandKind(Enum):
INDEX_ATTR = 3
UNARY_FN_ATTR = 4
BINARY_FN_ATTR = 5
- TYPE_FN_ATTR = 6
+ TERNARY_FN_ATTR = 6
+ TYPE_FN_ATTR = 7
class OperandDef:
@@ -489,6 +515,7 @@ def is_attribute(self) -> bool:
self.kind == OperandKind.INDEX_ATTR
or self.kind == OperandKind.UNARY_FN_ATTR
or self.kind == OperandKind.BINARY_FN_ATTR
+ or self.kind == OperandKind.TERNARY_FN_ATTR
or self.kind == OperandKind.TYPE_FN_ATTR
)
@@ -670,6 +697,31 @@ def __getitem__(self, reduce_dims: Tuple[DimDef]) -> ReduceFnUse:
return ReduceFnUse(None, self, *reduce_dims)
+class TernaryFnAttrDef:
+ """Ternary function attribute definition.
+
+ Ternary function attributes provide a way to make the arithmetic computation
+ parametrizable. Every attribute specifies a default ternary function
+ that may be overwritten at operation instantiation time.
+ """
+
+ def __init__(self, default: "TernaryFnType"):
+ if not isinstance(default, TernaryFnType):
+ raise ValueError(
+ f"TernaryFnAttrDef requires default of type TernaryFnType "
+ f"but got {default}"
+ )
+ self.operand_def = OperandDef(
+ OperandKind.TERNARY_FN_ATTR, default_fn=default.fn_name
+ )
+
+ def __call__(self, arg0: TensorExpression, arg1: TensorExpression, arg2: TensorExpression) -> TensorFn:
+ return TensorFn(FunctionKind.TERNARY, None, self.operand_def, None, [arg0, arg1, arg2])
+
+ def __getitem__(self, reduce_dims: Tuple[DimDef]) -> ReduceFnUse:
+ return ReduceFnUse(None, self, *reduce_dims)
+
+
class TypeFnAttrDef:
"""Type conversion function attribute definition.
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
index f91fc8b716008..845b533db52a9 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
@@ -60,6 +60,7 @@ def prepare_common_structured_op(
in [
OperandKind.UNARY_FN_ATTR,
OperandKind.BINARY_FN_ATTR,
+ OperandKind.TERNARY_FN_ATTR,
OperandKind.TYPE_FN_ATTR,
]
]
@@ -180,6 +181,12 @@ def prepare_common_structured_op(
f"Attribute {fn_attr.name} needs to be of type "
f"BinaryFnType but got {type(attr_val)}"
)
+ elif attr_kind == OperandKind.TERNARY_FN_ATTR:
+ if not isinstance(fn, TernaryFnType):
+ raise ValueError(
+ f"Attribute {fn_attr.name} needs to be of type "
+ f"TernaryFnType but got {type(attr_val)}"
+ )
else:
if not isinstance(fn, TypeFnType):
raise ValueError(
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
index ca2bb0c5f7f8a..d73428a0f4df3 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
@@ -351,6 +351,26 @@ def powf(
O[None] = BinaryFn.powf(lhs[None], rhs[None])
+@linalg_structured_op
+def select(
+ cond=TensorDef(U),
+ lhs=TensorDef(T1),
+ rhs=TensorDef(T1),
+ O=TensorDef(T1, output=True),
+):
+ """Chooses one value based on a binary condition supplied as its first operand.
+
+ The shapes and element types must be identical. The appropriate casts,
+ broadcasts and reductions should be done previously to calling this op.
+
+ This means reduction/broadcast/element cast semantics is explicit. Further
+ passes can take that into account when lowering this code. For example,
+ a `linalg.broadcast` + `linalg.select` sequence can be lowered to a
+ `linalg.generic` with different affine maps for the two operands.
+ """
+ O[None] = TernaryFn.select(cond[None], lhs[None], rhs[None])
+
+
@linalg_structured_op
def matmul(
A=TensorDef(T1, S.M, S.K),
diff --git a/mlir/test/Dialect/Linalg/generalize-named-ops.mlir b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
index 667ea3c18c8ad..4f43ec2c9e1ce 100644
--- a/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
@@ -791,6 +791,31 @@ func.func @generalize_powf(%lhs: memref<7x14x21xf32>, %rhs: memref<7x14x21xf32>,
// -----
+func.func @generalize_select(%cond: memref<7x14x21xi1>, %lhs: memref<7x14x21xf32>, %rhs: memref<7x14x21xf32>,
+ %out: memref<7x14x21xf32>) {
+ linalg.select ins(%cond, %lhs, %rhs: memref<7x14x21xi1>, memref<7x14x21xf32>, memref<7x14x21xf32>)
+ outs(%out: memref<7x14x21xf32>)
+ return
+}
+
+// CHECK: #[[MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+
+// CHECK: func @generalize_select
+// CHECK-SAME: (%[[COND:.+]]: memref<7x14x21xi1>, %[[LHS:.+]]: memref<7x14x21xf32>, %[[RHS:.+]]: memref<7x14x21xf32>,
+// CHECK-SAME: %[[OUT:.+]]: memref<7x14x21xf32>)
+
+// CHECK: linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP]], #[[MAP]], #[[MAP]], #[[MAP]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]}
+// CHECK-SAME: ins(%[[COND]], %[[LHS]], %[[RHS]] : memref<7x14x21xi1>, memref<7x14x21xf32>, memref<7x14x21xf32>)
+// CHECK-SAME: outs(%[[OUT]] : memref<7x14x21xf32>)
+
+// CHECK: ^{{.+}}(%[[BBARG0:.+]]: i1, %[[BBARG1:.+]]: f32, %[[BBARG2:.+]]: f32, %[[BBARG3:.+]]: f32)
+// CHECK-NEXT: %[[select:.+]] = arith.select %[[BBARG0]], %[[BBARG1]], %[[BBARG2]] : f32
+// CHECK-NEXT: linalg.yield %[[select]] : f32
+
+
+// -----
// CHECK-LABEL: func @fill_tensor
func.func @fill_tensor(%f: f32, %v: vector<2x4xf32>) -> (tensor<f32>, tensor<vector<2x4xf32>>) {
diff --git a/mlir/test/Dialect/Linalg/named-ops-fail.mlir b/mlir/test/Dialect/Linalg/named-ops-fail.mlir
index e92a77aa7ad05..552a0abaa797c 100644
--- a/mlir/test/Dialect/Linalg/named-ops-fail.mlir
+++ b/mlir/test/Dialect/Linalg/named-ops-fail.mlir
@@ -334,3 +334,19 @@ func.func @powf_broadcast(%arg0: memref<8x16xf32>, %arg1: memref<4x8x16xf32>, %a
return
}
+// -----
+
+func.func @select_type_cast(%arg0: memref<4x8x16xi1>, %arg1: memref<4x8x16xf16>, %arg2: memref<4x8x16xf32>, %arg3: memref<4x8x16xf32>) {
+ // CHECK: op failed to verify that all of {true_value, false_value, result} have same type
+ linalg.select ins(%arg0, %arg1, %arg2 : memref<4x8x16xi1>, memref<4x8x16xf16>, memref<4x8x16xf32>) outs(%arg3: memref<4x8x16xf32>)
+ return
+}
+
+// -----
+
+func.func @select_wrong_condition_type(%arg0: memref<4x8x16xf32>, %arg1: memref<4x8x16xf32>, %arg2: memref<4x8x16xf32>, %arg3: memref<4x8x16xf32>) {
+ // CHECK: op operand #0 must be bool-like, but got 'f32'
+ linalg.select ins(%arg0, %arg1, %arg2 : memref<4x8x16xf32>, memref<4x8x16xf32>, memref<4x8x16xf32>) outs(%arg3: memref<4x8x16xf32>)
+ return
+}
+
diff --git a/mlir/test/Dialect/Linalg/named-ops.mlir b/mlir/test/Dialect/Linalg/named-ops.mlir
index fefe5578947f0..cecd0033b7765 100644
--- a/mlir/test/Dialect/Linalg/named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/named-ops.mlir
@@ -1924,3 +1924,25 @@ func.func @fill_tensor(%f: f32, %v: vector<2x4xf32>) -> (tensor<f32>, tensor<vec
%1 = linalg.fill ins(%v : vector<2x4xf32>) outs(%e1 : tensor<vector<2x4xf32>>) -> tensor<vector<2x4xf32>>
return %0, %1: tensor<f32>, tensor<vector<2x4xf32>>
}
+
+// -----
+
+// CHECK-LABEL: func @select_dynamic
+func.func @select_dynamic(%arg0: memref<?x?x?xi1>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?x?xf32>, %arg3: memref<?x?x?xf32>) {
+ // CHECK: linalg.select
+ // CHECK-SAME: ins(%{{.+}}, %{{.+}}, %{{.+}} : memref<?x?x?xi1>, memref<?x?x?xf32>, memref<?x?x?xf32>)
+ // CHECK-SAME: outs(%{{.+}} : memref<?x?x?xf32>)
+ linalg.select ins(%arg0, %arg1, %arg2 : memref<?x?x?xi1>, memref<?x?x?xf32>, memref<?x?x?xf32>) outs(%arg3: memref<?x?x?xf32>)
+ return
+}
+
+// -----
+
+// CHECK-LABEL: func @select_static
+func.func @select_static(%arg0: memref<4x8x16xi1>, %arg1: memref<4x8x16xf32>, %arg2: memref<4x8x16xf32>, %arg3: memref<4x8x16xf32>) {
+ // CHECK: linalg.select
+ // CHECK-SAME: ins(%{{.+}}, %{{.+}}, %{{.+}} : memref<4x8x16xi1>, memref<4x8x16xf32>, memref<4x8x16xf32>)
+ // CHECK-SAME: outs(%{{.+}} : memref<4x8x16xf32>)
+ linalg.select ins(%arg0, %arg1, %arg2 : memref<4x8x16xi1>, memref<4x8x16xf32>, memref<4x8x16xf32>) outs(%arg3: memref<4x8x16xf32>)
+ return
+}
diff --git a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
index fe6ad15041126..37240164c377e 100644
--- a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
+++ b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
@@ -70,6 +70,7 @@ enum class LinalgOperandDefKind {
IndexAttr,
UnaryFnAttr,
BinaryFnAttr,
+ TernaryFnAttr,
TypeFnAttr
};
@@ -94,7 +95,7 @@ struct LinalgIndexingMapsConfig {
struct ScalarExpression;
-enum class ScalarFnKind { Unary, Binary, Type };
+enum class ScalarFnKind { Unary, Binary, Ternary, Type };
struct ScalarFn {
ScalarFnKind kind;
@@ -214,6 +215,7 @@ struct ScalarEnumerationTraits<LinalgOperandDefKind> {
io.enumCase(value, "index_attr", LinalgOperandDefKind::IndexAttr);
io.enumCase(value, "unary_fn_attr", LinalgOperandDefKind::UnaryFnAttr);
io.enumCase(value, "binary_fn_attr", LinalgOperandDefKind::BinaryFnAttr);
+ io.enumCase(value, "ternary_fn_attr", LinalgOperandDefKind::TernaryFnAttr);
io.enumCase(value, "type_fn_attr", LinalgOperandDefKind::TypeFnAttr);
}
};
@@ -284,6 +286,7 @@ struct ScalarEnumerationTraits<ScalarFnKind> {
static void enumeration(IO &io, ScalarFnKind &value) {
io.enumCase(value, "unary", ScalarFnKind::Unary);
io.enumCase(value, "binary", ScalarFnKind::Binary);
+ io.enumCase(value, "ternary", ScalarFnKind::Ternary);
io.enumCase(value, "type", ScalarFnKind::Type);
}
};
@@ -441,6 +444,7 @@ static ScalarAssign *findAssignment(StringRef name,
static bool isFunctionAttribute(LinalgOperandDefKind kind) {
return kind == LinalgOperandDefKind::UnaryFnAttr ||
kind == LinalgOperandDefKind::BinaryFnAttr ||
+ kind == LinalgOperandDefKind::TernaryFnAttr ||
kind == LinalgOperandDefKind::TypeFnAttr;
}
@@ -456,6 +460,8 @@ std::string convertOperandKindToEnumName(LinalgOperandDefKind kind) {
return std::string("UnaryFn");
case LinalgOperandDefKind::BinaryFnAttr:
return std::string("BinaryFn");
+ case LinalgOperandDefKind::TernaryFnAttr:
+ return std::string("TernaryFn");
case LinalgOperandDefKind::TypeFnAttr:
return std::string("TypeFn");
default:
@@ -471,6 +477,8 @@ std::string convertFunctionKindToEnumName(ScalarFnKind kind) {
return std::string("UnaryFn");
case ScalarFnKind::Binary:
return std::string("BinaryFn");
+ case ScalarFnKind::Ternary:
+ return std::string("TernaryFn");
case ScalarFnKind::Type:
return std::string("TypeFn");
}
``````````
</details>
https://github.com/llvm/llvm-project/pull/91461
More information about the Mlir-commits
mailing list