[Mlir-commits] [mlir] [MLIR][Linalg] Ternary Op & Linalg select (PR #91461)

Petr Kurapov llvmlistbot at llvm.org
Fri May 10 05:19:47 PDT 2024


https://github.com/kurapov-peter updated https://github.com/llvm/llvm-project/pull/91461

>From f3eb7a92f58074565833266598e8f13438f97c17 Mon Sep 17 00:00:00 2001
From: Renato Golin <rengolin at systemcall.eu>
Date: Mon, 29 Apr 2024 23:41:29 +0100
Subject: [PATCH 1/6] [MLIR][Linalg] Add ternary select named op (in progress)

---
 .../mlir/Dialect/Linalg/IR/LinalgBase.td      |  3 +
 .../mlir/Dialect/Linalg/IR/LinalgEnums.td     |  6 ++
 .../Linalg/IR/LinalgNamedStructuredOps.yaml   | 57 +++++++++++++++++++
 mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp      | 16 ++++++
 .../linalg/opdsl/lang/comprehension.py        | 56 +++++++++++++++++-
 .../dialects/linalg/opdsl/lang/emitter.py     |  7 +++
 .../linalg/opdsl/ops/core_named_ops.py        | 20 +++++++
 7 files changed, 163 insertions(+), 2 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td
index e87e8b5600107..73f984dc072d3 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td
@@ -68,6 +68,9 @@ def UnaryFnAttr : EnumAttr<Linalg_Dialect, UnaryFn, "unary_fn"> {
 def BinaryFnAttr : EnumAttr<Linalg_Dialect, BinaryFn, "binary_fn"> {
   let assemblyFormat = "`<` $value `>`";
 }
+def TernaryFnAttr : EnumAttr<Linalg_Dialect, TernaryFn, "ternary_fn"> {
+  let assemblyFormat = "`<` $value `>`";
+}
 def TypeFnAttr : EnumAttr<Linalg_Dialect, TypeFn, "type_fn"> {
   let assemblyFormat = "`<` $value `>`";
 }
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgEnums.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgEnums.td
index 6b4b073fc6724..e615876a95d05 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgEnums.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgEnums.td
@@ -49,6 +49,12 @@ def BinaryFn : I32EnumAttr<"BinaryFn", "", [
   let genSpecializedAttr = 0;
   let cppNamespace = "::mlir::linalg";
 }
+def TernaryFn : I32EnumAttr<"TernaryFn", "", [
+  I32EnumAttrCase<"select", 0>
+]> {
+  let genSpecializedAttr = 0;
+  let cppNamespace = "::mlir::linalg";
+}
 def TypeFn : I32EnumAttr<"TypeFn", "", [
   I32EnumAttrCase<"cast_signed", 0>,
   I32EnumAttrCase<"cast_unsigned", 1>
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
index 584bfcd8b59dc..7350271aa3829 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
@@ -1008,6 +1008,63 @@ structured_op: !LinalgStructuredOpConfig
         - !ScalarExpression
           scalar_arg: rhs
 --- !LinalgOpConfig
+metadata: !LinalgOpMetadata
+  name: select
+  cpp_class_name: SelectOp
+  doc: |-
+    Chooses one value based on a binary condition supplied as its first operand.
+
+    The shapes and element types must be identical. The appropriate casts,
+    broadcasts and reductions should be done previously to calling this op.
+
+    This means reduction/broadcast/element cast semantics is explicit. Further
+    passes can take that into account when lowering this code. For example,
+    a `linalg.broadcast` + `linalg.select` sequence can be lowered to a
+    `linalg.generic` with different affine maps for the two operands.
+structured_op: !LinalgStructuredOpConfig
+  args:
+  - !LinalgOperandDefConfig
+    name: cond
+    kind: input_tensor
+    type_var: T1
+    shape_map: affine_map<() -> ()>
+  - !LinalgOperandDefConfig
+    name: lhs
+    kind: input_tensor
+    type_var: T1
+    shape_map: affine_map<() -> ()>
+  - !LinalgOperandDefConfig
+    name: rhs
+    kind: input_tensor
+    type_var: T1
+    shape_map: affine_map<() -> ()>
+  - !LinalgOperandDefConfig
+    name: O
+    kind: output_tensor
+    type_var: T1
+    shape_map: affine_map<() -> ()>
+  indexing_maps: !LinalgIndexingMapsConfig
+    static_indexing_maps:
+    - affine_map<() -> ()>
+    - affine_map<() -> ()>
+    - affine_map<() -> ()>
+    - affine_map<() -> ()>
+  iterator_types: []
+  assignments:
+  - !ScalarAssign
+    arg: O
+    value: !ScalarExpression
+      scalar_fn:
+        kind: ternary
+        fn_name: select
+        operands:
+        - !ScalarExpression
+          scalar_arg: cond
+        - !ScalarExpression
+          scalar_arg: lhs
+        - !ScalarExpression
+          scalar_arg: rhs
+--- !LinalgOpConfig
 metadata: !LinalgOpMetadata
   name: matmul
   cpp_class_name: MatmulOp
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index e5f83331baf81..4933d71ac7fb0 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -492,6 +492,22 @@ class RegionBuilderHelper {
     llvm_unreachable("unsupported binary function");
   }
 
+  // Build the ternary functions defined by OpDSL.
+  Value buildTernaryFn(TernaryFn ternaryFn, Value arg0, Value arg1, Value arg2) {
+    bool headBool = isInteger(arg0) && arg0.getType().getIntOrFloatBitWidth() == 1;
+    bool tailFloatingPoint = isFloatingPoint(arg0) && isFloatingPoint(arg1) && isFloatingPoint(arg2);
+    bool tailnteger = isInteger(arg0) && isInteger(arg1) && isInteger(arg1);
+    OpBuilder::InsertionGuard g(builder);
+    builder.setInsertionPointToEnd(&block);
+    switch (ternaryFn) {
+    case TernaryFn::select:
+      if (!headBool && !(tailFloatingPoint || tailInteger))
+        llvm_unreachable("unsupported non numeric type");
+      return builder.create<arith::SelectOp>(arg0.getLoc(), arg0, arg1, arg2);
+    }
+    llvm_unreachable("unsupported ternary function");
+  }
+
   // Build the type functions defined by OpDSL.
   Value buildTypeFn(TypeFn typeFn, Type toType, Value operand) {
     switch (typeFn) {
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py
index bb43ebf2b6923..880dcb7250b96 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py
@@ -262,7 +262,8 @@ def __repr__(self):
 class FunctionKind(Enum):
     UNARY = 0
     BINARY = 1
-    TYPE = 2
+    TERNARY = 2
+    TYPE = 3
 
 
 class UnaryFnType:
@@ -339,6 +340,30 @@ class BinaryFn:
     powf = BinaryFnType("powf")
 
 
+class TernaryFnType:
+    """Ternary function.
+
+    A bterary function takes three tensor expressions and returns the
+    function evaluation result.
+    """
+
+    def __init__(self, fn_name: str):
+        self.fn_name = fn_name
+
+    def __call__(self, arg0: TensorExpression, arg1: TensorExpression, arg2: TensorExpression) -> "TensorFn":
+        return TensorFn(FunctionKind.TERNARY, self.fn_name, None, None, [arg0, arg1, arg2])
+
+    def __repr__(self):
+        return f"{self.fn_name}"
+
+
+class TernaryFn:
+    """Ternary function namespace.
+    """
+
+    select = TernaryFnType("select")
+
+
 class TypeFnType:
     """Type conversion function.
 
@@ -437,7 +462,8 @@ class OperandKind(Enum):
     INDEX_ATTR = 3
     UNARY_FN_ATTR = 4
     BINARY_FN_ATTR = 5
-    TYPE_FN_ATTR = 6
+    TERNARY_FN_ATTR = 6
+    TYPE_FN_ATTR = 7
 
 
 class OperandDef:
@@ -489,6 +515,7 @@ def is_attribute(self) -> bool:
             self.kind == OperandKind.INDEX_ATTR
             or self.kind == OperandKind.UNARY_FN_ATTR
             or self.kind == OperandKind.BINARY_FN_ATTR
+            or self.kind == OperandKind.TERNARY_FN_ATTR
             or self.kind == OperandKind.TYPE_FN_ATTR
         )
 
@@ -670,6 +697,31 @@ def __getitem__(self, reduce_dims: Tuple[DimDef]) -> ReduceFnUse:
         return ReduceFnUse(None, self, *reduce_dims)
 
 
+class TernaryFnAttrDef:
+    """Ternary function attribute definition.
+
+    Ternary function attributes provide a way to make the arithmetic computation
+    parametrizable. Every attribute specifies a default Ternary function
+    that may be overwritten at operation instantiation time.
+    """
+
+    def __init__(self, default: "TernaryFnType"):
+        if not isinstance(default, TernaryFnType):
+            raise ValueError(
+                f"TernaryFnAttrDef requires default of type TernaryFnType "
+                f"but got {default}"
+            )
+        self.operand_def = OperandDef(
+            OperandKind.TERNARY_FN_ATTR, default_fn=default.fn_name
+        )
+
+    def __call__(self, arg0: TensorExpression, arg1: TensorExpression) -> TensorFn:
+        return TensorFn(FunctionKind.TERNARY, None, self.operand_def, None, [arg0, arg1])
+
+    def __getitem__(self, reduce_dims: Tuple[DimDef]) -> ReduceFnUse:
+        return ReduceFnUse(None, self, *reduce_dims)
+
+
 class TypeFnAttrDef:
     """Type conversion function attribute definition.
 
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
index f91fc8b716008..845b533db52a9 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
@@ -60,6 +60,7 @@ def prepare_common_structured_op(
         in [
             OperandKind.UNARY_FN_ATTR,
             OperandKind.BINARY_FN_ATTR,
+            OperandKind.TERNARY_FN_ATTR,
             OperandKind.TYPE_FN_ATTR,
         ]
     ]
@@ -180,6 +181,12 @@ def prepare_common_structured_op(
                         f"Attribute {fn_attr.name} needs to be of type "
                         f"BinaryFnType but got {type(attr_val)}"
                     )
+            elif attr_kind == OperandKind.TERNARY_FN_ATTR:
+                if not isinstance(fn, TernaryFnType):
+                    raise ValueError(
+                        f"Attribute {fn_attr.name} needs to be of type "
+                        f"TernaryFnType but got {type(attr_val)}"
+                    )
             else:
                 if not isinstance(fn, TypeFnType):
                     raise ValueError(
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
index ca2bb0c5f7f8a..3cf3e9de0136f 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
@@ -351,6 +351,26 @@ def powf(
     O[None] = BinaryFn.powf(lhs[None], rhs[None])
 
 
+ at linalg_structured_op
+def selectf(
+    cond=TensorDef(bool),
+    lhs=TensorDef(T1),
+    rhs=TensorDef(T1),
+    O=TensorDef(T1, output=True),
+):
+    """Chooses one value based on a binary condition supplied as its first operand.
+
+    The shapes and element types must be identical. The appropriate casts,
+    broadcasts and reductions should be done previously to calling this op.
+
+    This means reduction/broadcast/element cast semantics is explicit. Further
+    passes can take that into account when lowering this code. For example,
+    a `linalg.broadcast` + `linalg.select` sequence can be lowered to a
+    `linalg.generic` with different affine maps for the two operands.
+    """
+    O[None] = TernaryFn.select(cond[None], lhs[None], rhs[None])
+
+
 @linalg_structured_op
 def matmul(
     A=TensorDef(T1, S.M, S.K),

>From 2c13925e0be5d3572aa1100d463fa80d691d8c29 Mon Sep 17 00:00:00 2001
From: Petr Kurapov <petr.a.kurapov at intel.com>
Date: Tue, 7 May 2024 18:15:06 +0000
Subject: [PATCH 2/6] [MLIR][Linalg] Fix the build & add ternary into
 appropriate enums

---
 .../Linalg/IR/LinalgNamedStructuredOps.yaml   | 45 +++----------------
 mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp      | 11 +++--
 .../linalg/opdsl/ops/core_named_ops.py        |  2 +-
 .../mlir-linalg-ods-yaml-gen.cpp              | 10 ++++-
 4 files changed, 22 insertions(+), 46 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
index 7350271aa3829..2e4af321d1363 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
@@ -304,41 +304,6 @@ structured_op: !LinalgStructuredOpConfig
         - !ScalarExpression
           scalar_arg: I
 --- !LinalgOpConfig
-metadata: !LinalgOpMetadata
-  name: reciprocal
-  cpp_class_name: ReciprocalOp
-  doc: |-
-    Applies reciprocal(x) elementwise.
-
-    No numeric casting is performed on the input operand.
-structured_op: !LinalgStructuredOpConfig
-  args:
-  - !LinalgOperandDefConfig
-    name: I
-    kind: input_tensor
-    type_var: T1
-    shape_map: affine_map<() -> ()>
-  - !LinalgOperandDefConfig
-    name: O
-    kind: output_tensor
-    type_var: T1
-    shape_map: affine_map<() -> ()>
-  indexing_maps: !LinalgIndexingMapsConfig
-    static_indexing_maps:
-    - affine_map<() -> ()>
-    - affine_map<() -> ()>
-  iterator_types: []
-  assignments:
-  - !ScalarAssign
-    arg: O
-    value: !ScalarExpression
-      scalar_fn:
-        kind: unary
-        fn_name: reciprocal
-        operands:
-        - !ScalarExpression
-          scalar_arg: I
---- !LinalgOpConfig
 metadata: !LinalgOpMetadata
   name: round
   cpp_class_name: RoundOp
@@ -516,7 +481,7 @@ structured_op: !LinalgStructuredOpConfig
 --- !LinalgOpConfig
 metadata: !LinalgOpMetadata
   name: erf
-  cpp_class_name: erfOp
+  cpp_class_name: ErfOp
   doc: |-
     Applies erf(x) elementwise.
 
@@ -959,7 +924,7 @@ structured_op: !LinalgStructuredOpConfig
 --- !LinalgOpConfig
 metadata: !LinalgOpMetadata
   name: powf
-  cpp_class_name: PowFOp
+  cpp_class_name: PowfOp
   doc: |-
     Takes the powf(lhs, rhs) between two inputs, elementwise. For powf(arg, 2) use `linalg.square`.
 
@@ -1009,8 +974,8 @@ structured_op: !LinalgStructuredOpConfig
           scalar_arg: rhs
 --- !LinalgOpConfig
 metadata: !LinalgOpMetadata
-  name: select
-  cpp_class_name: SelectOp
+  name: selectf
+  cpp_class_name: SelectfOp
   doc: |-
     Chooses one value based on a binary condition supplied as its first operand.
 
@@ -1026,7 +991,7 @@ structured_op: !LinalgStructuredOpConfig
   - !LinalgOperandDefConfig
     name: cond
     kind: input_tensor
-    type_var: T1
+    type_var: U
     shape_map: affine_map<() -> ()>
   - !LinalgOperandDefConfig
     name: lhs
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 4933d71ac7fb0..6a5f25a7605f1 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -493,10 +493,13 @@ class RegionBuilderHelper {
   }
 
   // Build the ternary functions defined by OpDSL.
-  Value buildTernaryFn(TernaryFn ternaryFn, Value arg0, Value arg1, Value arg2) {
-    bool headBool = isInteger(arg0) && arg0.getType().getIntOrFloatBitWidth() == 1;
-    bool tailFloatingPoint = isFloatingPoint(arg0) && isFloatingPoint(arg1) && isFloatingPoint(arg2);
-    bool tailnteger = isInteger(arg0) && isInteger(arg1) && isInteger(arg1);
+  Value buildTernaryFn(TernaryFn ternaryFn, Value arg0, Value arg1,
+                       Value arg2) {
+    bool headBool =
+        isInteger(arg0) && arg0.getType().getIntOrFloatBitWidth() == 1;
+    bool tailFloatingPoint =
+        isFloatingPoint(arg0) && isFloatingPoint(arg1) && isFloatingPoint(arg2);
+    bool tailInteger = isInteger(arg0) && isInteger(arg1) && isInteger(arg2);
     OpBuilder::InsertionGuard g(builder);
     builder.setInsertionPointToEnd(&block);
     switch (ternaryFn) {
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
index 3cf3e9de0136f..04d10e7e9e7f5 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
@@ -353,7 +353,7 @@ def powf(
 
 @linalg_structured_op
 def selectf(
-    cond=TensorDef(bool),
+    cond=TensorDef(U),
     lhs=TensorDef(T1),
     rhs=TensorDef(T1),
     O=TensorDef(T1, output=True),
diff --git a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
index fe6ad15041126..37240164c377e 100644
--- a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
+++ b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
@@ -70,6 +70,7 @@ enum class LinalgOperandDefKind {
   IndexAttr,
   UnaryFnAttr,
   BinaryFnAttr,
+  TernaryFnAttr,
   TypeFnAttr
 };
 
@@ -94,7 +95,7 @@ struct LinalgIndexingMapsConfig {
 
 struct ScalarExpression;
 
-enum class ScalarFnKind { Unary, Binary, Type };
+enum class ScalarFnKind { Unary, Binary, Ternary, Type };
 
 struct ScalarFn {
   ScalarFnKind kind;
@@ -214,6 +215,7 @@ struct ScalarEnumerationTraits<LinalgOperandDefKind> {
     io.enumCase(value, "index_attr", LinalgOperandDefKind::IndexAttr);
     io.enumCase(value, "unary_fn_attr", LinalgOperandDefKind::UnaryFnAttr);
     io.enumCase(value, "binary_fn_attr", LinalgOperandDefKind::BinaryFnAttr);
+    io.enumCase(value, "ternary_fn_attr", LinalgOperandDefKind::TernaryFnAttr);
     io.enumCase(value, "type_fn_attr", LinalgOperandDefKind::TypeFnAttr);
   }
 };
@@ -284,6 +286,7 @@ struct ScalarEnumerationTraits<ScalarFnKind> {
   static void enumeration(IO &io, ScalarFnKind &value) {
     io.enumCase(value, "unary", ScalarFnKind::Unary);
     io.enumCase(value, "binary", ScalarFnKind::Binary);
+    io.enumCase(value, "ternary", ScalarFnKind::Ternary);
     io.enumCase(value, "type", ScalarFnKind::Type);
   }
 };
@@ -441,6 +444,7 @@ static ScalarAssign *findAssignment(StringRef name,
 static bool isFunctionAttribute(LinalgOperandDefKind kind) {
   return kind == LinalgOperandDefKind::UnaryFnAttr ||
          kind == LinalgOperandDefKind::BinaryFnAttr ||
+         kind == LinalgOperandDefKind::TernaryFnAttr ||
          kind == LinalgOperandDefKind::TypeFnAttr;
 }
 
@@ -456,6 +460,8 @@ std::string convertOperandKindToEnumName(LinalgOperandDefKind kind) {
     return std::string("UnaryFn");
   case LinalgOperandDefKind::BinaryFnAttr:
     return std::string("BinaryFn");
+  case LinalgOperandDefKind::TernaryFnAttr:
+    return std::string("TernaryFn");
   case LinalgOperandDefKind::TypeFnAttr:
     return std::string("TypeFn");
   default:
@@ -471,6 +477,8 @@ std::string convertFunctionKindToEnumName(ScalarFnKind kind) {
     return std::string("UnaryFn");
   case ScalarFnKind::Binary:
     return std::string("BinaryFn");
+  case ScalarFnKind::Ternary:
+    return std::string("TernaryFn");
   case ScalarFnKind::Type:
     return std::string("TypeFn");
   }

>From 334e31622bea42341420bb59c6e0cbb2d71887af Mon Sep 17 00:00:00 2001
From: Petr Kurapov <petr.a.kurapov at intel.com>
Date: Tue, 7 May 2024 19:24:08 +0000
Subject: [PATCH 3/6] [MLIR][Linalg] add basic select test

---
 .../Linalg/IR/LinalgNamedStructuredOps.yaml   | 43 +++++++++++++++++--
 .../linalg/opdsl/ops/core_named_ops.py        |  2 +-
 .../Dialect/Linalg/generalize-named-ops.mlir  | 25 +++++++++++
 mlir/test/Dialect/Linalg/named-ops-fail.mlir  | 16 +++++++
 mlir/test/Dialect/Linalg/named-ops.mlir       | 22 ++++++++++
 5 files changed, 103 insertions(+), 5 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
index 2e4af321d1363..eb7dd37010a67 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
@@ -304,6 +304,41 @@ structured_op: !LinalgStructuredOpConfig
         - !ScalarExpression
           scalar_arg: I
 --- !LinalgOpConfig
+metadata: !LinalgOpMetadata
+  name: reciprocal
+  cpp_class_name: ReciprocalOp
+  doc: |-
+    Applies reciprocal(x) elementwise.
+
+    No numeric casting is performed on the input operand.
+structured_op: !LinalgStructuredOpConfig
+  args:
+  - !LinalgOperandDefConfig
+    name: I
+    kind: input_tensor
+    type_var: T1
+    shape_map: affine_map<() -> ()>
+  - !LinalgOperandDefConfig
+    name: O
+    kind: output_tensor
+    type_var: T1
+    shape_map: affine_map<() -> ()>
+  indexing_maps: !LinalgIndexingMapsConfig
+    static_indexing_maps:
+    - affine_map<() -> ()>
+    - affine_map<() -> ()>
+  iterator_types: []
+  assignments:
+  - !ScalarAssign
+    arg: O
+    value: !ScalarExpression
+      scalar_fn:
+        kind: unary
+        fn_name: reciprocal
+        operands:
+        - !ScalarExpression
+          scalar_arg: I
+--- !LinalgOpConfig
 metadata: !LinalgOpMetadata
   name: round
   cpp_class_name: RoundOp
@@ -481,7 +516,7 @@ structured_op: !LinalgStructuredOpConfig
 --- !LinalgOpConfig
 metadata: !LinalgOpMetadata
   name: erf
-  cpp_class_name: ErfOp
+  cpp_class_name: erfOp
   doc: |-
     Applies erf(x) elementwise.
 
@@ -924,7 +959,7 @@ structured_op: !LinalgStructuredOpConfig
 --- !LinalgOpConfig
 metadata: !LinalgOpMetadata
   name: powf
-  cpp_class_name: PowfOp
+  cpp_class_name: PowFOp
   doc: |-
     Takes the powf(lhs, rhs) between two inputs, elementwise. For powf(arg, 2) use `linalg.square`.
 
@@ -974,8 +1009,8 @@ structured_op: !LinalgStructuredOpConfig
           scalar_arg: rhs
 --- !LinalgOpConfig
 metadata: !LinalgOpMetadata
-  name: selectf
-  cpp_class_name: SelectfOp
+  name: select
+  cpp_class_name: SelectOp
   doc: |-
     Chooses one value based on a binary condition supplied as its first operand.
 
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
index 04d10e7e9e7f5..d73428a0f4df3 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
@@ -352,7 +352,7 @@ def powf(
 
 
 @linalg_structured_op
-def selectf(
+def select(
     cond=TensorDef(U),
     lhs=TensorDef(T1),
     rhs=TensorDef(T1),
diff --git a/mlir/test/Dialect/Linalg/generalize-named-ops.mlir b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
index 667ea3c18c8ad..4f43ec2c9e1ce 100644
--- a/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
@@ -791,6 +791,31 @@ func.func @generalize_powf(%lhs: memref<7x14x21xf32>, %rhs: memref<7x14x21xf32>,
 
 // -----
 
+func.func @generalize_select(%cond: memref<7x14x21xi1>, %lhs: memref<7x14x21xf32>, %rhs: memref<7x14x21xf32>,
+                              %out: memref<7x14x21xf32>) {
+  linalg.select ins(%cond, %lhs, %rhs: memref<7x14x21xi1>, memref<7x14x21xf32>, memref<7x14x21xf32>)
+                outs(%out: memref<7x14x21xf32>)
+  return
+}
+
+// CHECK: #[[MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+
+// CHECK: func @generalize_select
+// CHECK-SAME: (%[[COND:.+]]: memref<7x14x21xi1>, %[[LHS:.+]]: memref<7x14x21xf32>, %[[RHS:.+]]: memref<7x14x21xf32>,
+// CHECK-SAME:  %[[OUT:.+]]: memref<7x14x21xf32>)
+
+// CHECK: linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP]], #[[MAP]], #[[MAP]], #[[MAP]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]}
+// CHECK-SAME:  ins(%[[COND]], %[[LHS]], %[[RHS]] : memref<7x14x21xi1>, memref<7x14x21xf32>, memref<7x14x21xf32>)
+// CHECK-SAME: outs(%[[OUT]] : memref<7x14x21xf32>)
+
+// CHECK:         ^{{.+}}(%[[BBARG0:.+]]: i1, %[[BBARG1:.+]]: f32, %[[BBARG2:.+]]: f32, %[[BBARG3:.+]]: f32)
+// CHECK-NEXT:      %[[select:.+]] = arith.select %[[BBARG0]], %[[BBARG1]], %[[BBARG2]] : f32
+// CHECK-NEXT:      linalg.yield %[[select]] : f32
+
+
+// -----
 
 // CHECK-LABEL: func @fill_tensor
 func.func @fill_tensor(%f: f32, %v: vector<2x4xf32>) -> (tensor<f32>, tensor<vector<2x4xf32>>) {
diff --git a/mlir/test/Dialect/Linalg/named-ops-fail.mlir b/mlir/test/Dialect/Linalg/named-ops-fail.mlir
index e92a77aa7ad05..552a0abaa797c 100644
--- a/mlir/test/Dialect/Linalg/named-ops-fail.mlir
+++ b/mlir/test/Dialect/Linalg/named-ops-fail.mlir
@@ -334,3 +334,19 @@ func.func @powf_broadcast(%arg0: memref<8x16xf32>, %arg1: memref<4x8x16xf32>, %a
   return
 }
 
+// -----
+
+func.func @select_type_cast(%arg0: memref<4x8x16xi1>, %arg1: memref<4x8x16xf16>, %arg2: memref<4x8x16xf32>, %arg3: memref<4x8x16xf32>) {
+  // CHECK: op failed to verify that all of {true_value, false_value, result} have same type
+  linalg.select ins(%arg0, %arg1, %arg2 : memref<4x8x16xi1>, memref<4x8x16xf16>, memref<4x8x16xf32>) outs(%arg3: memref<4x8x16xf32>)
+  return
+}
+
+// -----
+
+func.func @select_wrong_condition_type(%arg0: memref<4x8x16xf32>, %arg1: memref<4x8x16xf32>, %arg2: memref<4x8x16xf32>, %arg3: memref<4x8x16xf32>) {
+  // CHECK: op operand #0 must be bool-like, but got 'f32'
+  linalg.select ins(%arg0, %arg1, %arg2 : memref<4x8x16xf32>, memref<4x8x16xf32>, memref<4x8x16xf32>) outs(%arg3: memref<4x8x16xf32>)
+  return
+}
+
diff --git a/mlir/test/Dialect/Linalg/named-ops.mlir b/mlir/test/Dialect/Linalg/named-ops.mlir
index fefe5578947f0..cecd0033b7765 100644
--- a/mlir/test/Dialect/Linalg/named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/named-ops.mlir
@@ -1924,3 +1924,25 @@ func.func @fill_tensor(%f: f32, %v: vector<2x4xf32>) -> (tensor<f32>, tensor<vec
   %1 = linalg.fill ins(%v : vector<2x4xf32>) outs(%e1 : tensor<vector<2x4xf32>>) -> tensor<vector<2x4xf32>>
   return %0, %1: tensor<f32>, tensor<vector<2x4xf32>>
 }
+
+// -----
+
+// CHECK-LABEL: func @select_dynamic
+func.func @select_dynamic(%arg0: memref<?x?x?xi1>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?x?xf32>, %arg3: memref<?x?x?xf32>) {
+  // CHECK: linalg.select
+  // CHECK-SAME: ins(%{{.+}}, %{{.+}}, %{{.+}} : memref<?x?x?xi1>, memref<?x?x?xf32>, memref<?x?x?xf32>)
+  // CHECK-SAME: outs(%{{.+}} : memref<?x?x?xf32>)
+  linalg.select ins(%arg0, %arg1, %arg2 : memref<?x?x?xi1>, memref<?x?x?xf32>, memref<?x?x?xf32>) outs(%arg3: memref<?x?x?xf32>)
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func @select_static
+func.func @select_static(%arg0: memref<4x8x16xi1>, %arg1: memref<4x8x16xf32>, %arg2: memref<4x8x16xf32>, %arg3: memref<4x8x16xf32>) {
+  // CHECK: linalg.select
+  // CHECK-SAME: ins(%{{.+}}, %{{.+}}, %{{.+}} : memref<4x8x16xi1>, memref<4x8x16xf32>, memref<4x8x16xf32>)
+  // CHECK-SAME: outs(%{{.+}} : memref<4x8x16xf32>)
+  linalg.select ins(%arg0, %arg1, %arg2 : memref<4x8x16xi1>, memref<4x8x16xf32>, memref<4x8x16xf32>) outs(%arg3: memref<4x8x16xf32>)
+  return
+}

>From 40d9d3d3fb31fc3441a3ceabc6d4fa49b6f64a83 Mon Sep 17 00:00:00 2001
From: Petr Kurapov <petr.kurapov at gmail.com>
Date: Wed, 8 May 2024 14:57:48 +0200
Subject: [PATCH 4/6] Update
 mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py

Co-authored-by: Maksim Levental <maksim.levental at gmail.com>
---
 mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py
index 880dcb7250b96..34eea0d0fc8a0 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py
@@ -343,7 +343,7 @@ class BinaryFn:
 class TernaryFnType:
     """Ternary function.
 
-    A bterary function takes three tensor expressions and returns the
+    A ternary function takes three tensor expressions and returns the
     function evaluation result.
     """
 

>From b56a72fd026c201c8a76e430e5f716041d7c19af Mon Sep 17 00:00:00 2001
From: Petr Kurapov <petr.a.kurapov at intel.com>
Date: Wed, 8 May 2024 13:07:01 +0000
Subject: [PATCH 5/6] Fix Python formatting

---
 .../dialects/linalg/opdsl/lang/comprehension.py   | 15 ++++++++++-----
 1 file changed, 10 insertions(+), 5 deletions(-)

diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py
index 34eea0d0fc8a0..1a198fc5ec6f9 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py
@@ -350,16 +350,19 @@ class TernaryFnType:
     def __init__(self, fn_name: str):
         self.fn_name = fn_name
 
-    def __call__(self, arg0: TensorExpression, arg1: TensorExpression, arg2: TensorExpression) -> "TensorFn":
-        return TensorFn(FunctionKind.TERNARY, self.fn_name, None, None, [arg0, arg1, arg2])
+    def __call__(
+        self, arg0: TensorExpression, arg1: TensorExpression, arg2: TensorExpression
+    ) -> "TensorFn":
+        return TensorFn(
+            FunctionKind.TERNARY, self.fn_name, None, None, [arg0, arg1, arg2]
+        )
 
     def __repr__(self):
         return f"{self.fn_name}"
 
 
 class TernaryFn:
-    """Ternary function namespace.
-    """
+    """Ternary function namespace."""
 
     select = TernaryFnType("select")
 
@@ -716,7 +719,9 @@ def __init__(self, default: "TernaryFnType"):
         )
 
     def __call__(self, arg0: TensorExpression, arg1: TensorExpression) -> TensorFn:
-        return TensorFn(FunctionKind.TERNARY, None, self.operand_def, None, [arg0, arg1])
+        return TensorFn(
+            FunctionKind.TERNARY, None, self.operand_def, None, [arg0, arg1]
+        )
 
     def __getitem__(self, reduce_dims: Tuple[DimDef]) -> ReduceFnUse:
         return ReduceFnUse(None, self, *reduce_dims)

>From dfb0a5fd1f198e7a25d889b460a2a6ce911f91fd Mon Sep 17 00:00:00 2001
From: Petr Kurapov <petr.a.kurapov at intel.com>
Date: Fri, 10 May 2024 12:19:30 +0000
Subject: [PATCH 6/6] Add select on tensor test

---
 mlir/test/Dialect/Linalg/named-ops.mlir | 12 ++++++++++++
 1 file changed, 12 insertions(+)

diff --git a/mlir/test/Dialect/Linalg/named-ops.mlir b/mlir/test/Dialect/Linalg/named-ops.mlir
index cecd0033b7765..051054e67edf0 100644
--- a/mlir/test/Dialect/Linalg/named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/named-ops.mlir
@@ -1946,3 +1946,15 @@ func.func @select_static(%arg0: memref<4x8x16xi1>, %arg1: memref<4x8x16xf32>, %a
   linalg.select ins(%arg0, %arg1, %arg2 : memref<4x8x16xi1>, memref<4x8x16xf32>, memref<4x8x16xf32>) outs(%arg3: memref<4x8x16xf32>)
   return
 }
+
+// -----
+
+// CHECK-LABEL: func @select_tensor
+func.func @select_tensor(%arg0: tensor<4x8x16xi1>, %arg1: tensor<4x8x16xf32>, %arg2: tensor<4x8x16xf32>) -> tensor<4x8x16xf32> {
+  %0 = tensor.empty() : tensor<4x8x16xf32>
+  // CHECK: linalg.select
+  // CHECK-SAME: ins(%{{.+}}, %{{.+}}, %{{.+}} : tensor<4x8x16xi1>, tensor<4x8x16xf32>, tensor<4x8x16xf32>)
+  // CHECK-SAME: outs(%{{.+}} : tensor<4x8x16xf32>)
+  %1 = linalg.select ins(%arg0, %arg1, %arg2 : tensor<4x8x16xi1>, tensor<4x8x16xf32>, tensor<4x8x16xf32>) outs(%0: tensor<4x8x16xf32>) -> tensor<4x8x16xf32>
+  return %1 : tensor<4x8x16xf32>
+}



More information about the Mlir-commits mailing list