[Mlir-commits] [mlir] [MLIR][Linalg] Ternary Op & Linalg select (PR #91461)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed May 8 04:36:44 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-core

Author: Petr Kurapov (kurapov-peter)

<details>
<summary>Changes</summary>

Following #<!-- -->90236, adding `select` to linalg as `arith.select`. No implicit type casting.
OpDSL doesn't expose a type restriction for bool, and I saw no reason to add one; instead, the condition uses a separate symbolic type and its semantics are checked in the builder.

---
Full diff: https://github.com/llvm/llvm-project/pull/91461.diff


11 Files Affected:

- (modified) mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td (+3) 
- (modified) mlir/include/mlir/Dialect/Linalg/IR/LinalgEnums.td (+6) 
- (modified) mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml (+57) 
- (modified) mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp (+19) 
- (modified) mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py (+54-2) 
- (modified) mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py (+7) 
- (modified) mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py (+20) 
- (modified) mlir/test/Dialect/Linalg/generalize-named-ops.mlir (+25) 
- (modified) mlir/test/Dialect/Linalg/named-ops-fail.mlir (+16) 
- (modified) mlir/test/Dialect/Linalg/named-ops.mlir (+22) 
- (modified) mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp (+9-1) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td
index e87e8b5600107..73f984dc072d3 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td
@@ -68,6 +68,9 @@ def UnaryFnAttr : EnumAttr<Linalg_Dialect, UnaryFn, "unary_fn"> {
 def BinaryFnAttr : EnumAttr<Linalg_Dialect, BinaryFn, "binary_fn"> {
   let assemblyFormat = "`<` $value `>`";
 }
+def TernaryFnAttr : EnumAttr<Linalg_Dialect, TernaryFn, "ternary_fn"> {
+  let assemblyFormat = "`<` $value `>`";
+}
 def TypeFnAttr : EnumAttr<Linalg_Dialect, TypeFn, "type_fn"> {
   let assemblyFormat = "`<` $value `>`";
 }
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgEnums.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgEnums.td
index 6b4b073fc6724..e615876a95d05 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgEnums.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgEnums.td
@@ -49,6 +49,12 @@ def BinaryFn : I32EnumAttr<"BinaryFn", "", [
   let genSpecializedAttr = 0;
   let cppNamespace = "::mlir::linalg";
 }
+def TernaryFn : I32EnumAttr<"TernaryFn", "", [
+  I32EnumAttrCase<"select", 0>
+]> {
+  let genSpecializedAttr = 0;
+  let cppNamespace = "::mlir::linalg";
+}
 def TypeFn : I32EnumAttr<"TypeFn", "", [
   I32EnumAttrCase<"cast_signed", 0>,
   I32EnumAttrCase<"cast_unsigned", 1>
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
index 584bfcd8b59dc..eb7dd37010a67 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
@@ -1008,6 +1008,63 @@ structured_op: !LinalgStructuredOpConfig
         - !ScalarExpression
           scalar_arg: rhs
 --- !LinalgOpConfig
+metadata: !LinalgOpMetadata
+  name: select
+  cpp_class_name: SelectOp
+  doc: |-
+    Chooses one value based on a binary condition supplied as its first operand.
+
+    The shapes and element types must be identical. The appropriate casts,
+    broadcasts and reductions should be done previously to calling this op.
+
+    This means reduction/broadcast/element cast semantics is explicit. Further
+    passes can take that into account when lowering this code. For example,
+    a `linalg.broadcast` + `linalg.select` sequence can be lowered to a
+    `linalg.generic` with different affine maps for the two operands.
+structured_op: !LinalgStructuredOpConfig
+  args:
+  - !LinalgOperandDefConfig
+    name: cond
+    kind: input_tensor
+    type_var: U
+    shape_map: affine_map<() -> ()>
+  - !LinalgOperandDefConfig
+    name: lhs
+    kind: input_tensor
+    type_var: T1
+    shape_map: affine_map<() -> ()>
+  - !LinalgOperandDefConfig
+    name: rhs
+    kind: input_tensor
+    type_var: T1
+    shape_map: affine_map<() -> ()>
+  - !LinalgOperandDefConfig
+    name: O
+    kind: output_tensor
+    type_var: T1
+    shape_map: affine_map<() -> ()>
+  indexing_maps: !LinalgIndexingMapsConfig
+    static_indexing_maps:
+    - affine_map<() -> ()>
+    - affine_map<() -> ()>
+    - affine_map<() -> ()>
+    - affine_map<() -> ()>
+  iterator_types: []
+  assignments:
+  - !ScalarAssign
+    arg: O
+    value: !ScalarExpression
+      scalar_fn:
+        kind: ternary
+        fn_name: select
+        operands:
+        - !ScalarExpression
+          scalar_arg: cond
+        - !ScalarExpression
+          scalar_arg: lhs
+        - !ScalarExpression
+          scalar_arg: rhs
+--- !LinalgOpConfig
 metadata: !LinalgOpMetadata
   name: matmul
   cpp_class_name: MatmulOp
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index e5f83331baf81..6a5f25a7605f1 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -492,6 +492,25 @@ class RegionBuilderHelper {
     llvm_unreachable("unsupported binary function");
   }
 
+  // Build the ternary functions defined by OpDSL (select: arg0 = condition).
+  Value buildTernaryFn(TernaryFn ternaryFn, Value arg0, Value arg1,
+                       Value arg2) {
+    bool headBool =
+        isInteger(arg0) && arg0.getType().getIntOrFloatBitWidth() == 1;
+    bool tailFloatingPoint =
+        isFloatingPoint(arg0) && isFloatingPoint(arg1) && isFloatingPoint(arg2);
+    bool tailInteger = isInteger(arg0) && isInteger(arg1) && isInteger(arg2);
+    OpBuilder::InsertionGuard g(builder);
+    builder.setInsertionPointToEnd(&block);
+    switch (ternaryFn) {
+    case TernaryFn::select:
+      if (!headBool && !(tailFloatingPoint || tailInteger))
+        llvm_unreachable("unsupported non numeric type");
+      return builder.create<arith::SelectOp>(arg0.getLoc(), arg0, arg1, arg2);
+    }
+    llvm_unreachable("unsupported ternary function");
+  }
+
   // Build the type functions defined by OpDSL.
   Value buildTypeFn(TypeFn typeFn, Type toType, Value operand) {
     switch (typeFn) {
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py
index bb43ebf2b6923..880dcb7250b96 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py
@@ -262,7 +262,8 @@ def __repr__(self):
 class FunctionKind(Enum):
     UNARY = 0
     BINARY = 1
-    TYPE = 2
+    TERNARY = 2
+    TYPE = 3
 
 
 class UnaryFnType:
@@ -339,6 +340,30 @@ class BinaryFn:
     powf = BinaryFnType("powf")
 
 
+class TernaryFnType:
+    """Ternary function.
+
+    A ternary function takes three tensor expressions and returns the
+    function evaluation result.
+    """
+
+    def __init__(self, fn_name: str):
+        self.fn_name = fn_name
+
+    def __call__(self, arg0: TensorExpression, arg1: TensorExpression, arg2: TensorExpression) -> "TensorFn":
+        return TensorFn(FunctionKind.TERNARY, self.fn_name, None, None, [arg0, arg1, arg2])
+
+    def __repr__(self):
+        return f"{self.fn_name}"
+
+
+class TernaryFn:
+    """Ternary function namespace.
+    """
+
+    select = TernaryFnType("select")
+
+
 class TypeFnType:
     """Type conversion function.
 
@@ -437,7 +462,8 @@ class OperandKind(Enum):
     INDEX_ATTR = 3
     UNARY_FN_ATTR = 4
     BINARY_FN_ATTR = 5
-    TYPE_FN_ATTR = 6
+    TERNARY_FN_ATTR = 6
+    TYPE_FN_ATTR = 7
 
 
 class OperandDef:
@@ -489,6 +515,7 @@ def is_attribute(self) -> bool:
             self.kind == OperandKind.INDEX_ATTR
             or self.kind == OperandKind.UNARY_FN_ATTR
             or self.kind == OperandKind.BINARY_FN_ATTR
+            or self.kind == OperandKind.TERNARY_FN_ATTR
             or self.kind == OperandKind.TYPE_FN_ATTR
         )
 
@@ -670,6 +697,31 @@ def __getitem__(self, reduce_dims: Tuple[DimDef]) -> ReduceFnUse:
         return ReduceFnUse(None, self, *reduce_dims)
 
 
+class TernaryFnAttrDef:
+    """Ternary function attribute definition.
+
+    Ternary function attributes provide a way to make the arithmetic computation
+    parametrizable. Every attribute specifies a default Ternary function
+    that may be overwritten at operation instantiation time.
+    """
+
+    def __init__(self, default: "TernaryFnType"):
+        if not isinstance(default, TernaryFnType):
+            raise ValueError(
+                f"TernaryFnAttrDef requires default of type TernaryFnType "
+                f"but got {default}"
+            )
+        self.operand_def = OperandDef(
+            OperandKind.TERNARY_FN_ATTR, default_fn=default.fn_name
+        )
+
+    def __call__(self, arg0: TensorExpression, arg1: TensorExpression) -> TensorFn:
+        return TensorFn(FunctionKind.TERNARY, None, self.operand_def, None, [arg0, arg1])
+
+    def __getitem__(self, reduce_dims: Tuple[DimDef]) -> ReduceFnUse:
+        return ReduceFnUse(None, self, *reduce_dims)
+
+
 class TypeFnAttrDef:
     """Type conversion function attribute definition.
 
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
index f91fc8b716008..845b533db52a9 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
@@ -60,6 +60,7 @@ def prepare_common_structured_op(
         in [
             OperandKind.UNARY_FN_ATTR,
             OperandKind.BINARY_FN_ATTR,
+            OperandKind.TERNARY_FN_ATTR,
             OperandKind.TYPE_FN_ATTR,
         ]
     ]
@@ -180,6 +181,12 @@ def prepare_common_structured_op(
                         f"Attribute {fn_attr.name} needs to be of type "
                         f"BinaryFnType but got {type(attr_val)}"
                     )
+            elif attr_kind == OperandKind.TERNARY_FN_ATTR:
+                if not isinstance(fn, TernaryFnType):
+                    raise ValueError(
+                        f"Attribute {fn_attr.name} needs to be of type "
+                        f"TernaryFnType but got {type(attr_val)}"
+                    )
             else:
                 if not isinstance(fn, TypeFnType):
                     raise ValueError(
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
index ca2bb0c5f7f8a..d73428a0f4df3 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
@@ -351,6 +351,26 @@ def powf(
     O[None] = BinaryFn.powf(lhs[None], rhs[None])
 
 
+@linalg_structured_op
+def select(
+    cond=TensorDef(U),
+    lhs=TensorDef(T1),
+    rhs=TensorDef(T1),
+    O=TensorDef(T1, output=True),
+):
+    """Chooses one value based on a binary condition supplied as its first operand.
+
+    The shapes and element types must be identical. The appropriate casts,
+    broadcasts and reductions should be done previously to calling this op.
+
+    This means reduction/broadcast/element cast semantics is explicit. Further
+    passes can take that into account when lowering this code. For example,
+    a `linalg.broadcast` + `linalg.select` sequence can be lowered to a
+    `linalg.generic` with different affine maps for the two operands.
+    """
+    O[None] = TernaryFn.select(cond[None], lhs[None], rhs[None])
+
+
 @linalg_structured_op
 def matmul(
     A=TensorDef(T1, S.M, S.K),
diff --git a/mlir/test/Dialect/Linalg/generalize-named-ops.mlir b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
index 667ea3c18c8ad..4f43ec2c9e1ce 100644
--- a/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
@@ -791,6 +791,31 @@ func.func @generalize_powf(%lhs: memref<7x14x21xf32>, %rhs: memref<7x14x21xf32>,
 
 // -----
 
+func.func @generalize_select(%cond: memref<7x14x21xi1>, %lhs: memref<7x14x21xf32>, %rhs: memref<7x14x21xf32>,
+                              %out: memref<7x14x21xf32>) {
+  linalg.select ins(%cond, %lhs, %rhs: memref<7x14x21xi1>, memref<7x14x21xf32>, memref<7x14x21xf32>)
+                outs(%out: memref<7x14x21xf32>)
+  return
+}
+
+// CHECK: #[[MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+
+// CHECK: func @generalize_select
+// CHECK-SAME: (%[[COND:.+]]: memref<7x14x21xi1>, %[[LHS:.+]]: memref<7x14x21xf32>, %[[RHS:.+]]: memref<7x14x21xf32>,
+// CHECK-SAME:  %[[OUT:.+]]: memref<7x14x21xf32>)
+
+// CHECK: linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP]], #[[MAP]], #[[MAP]], #[[MAP]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]}
+// CHECK-SAME:  ins(%[[COND]], %[[LHS]], %[[RHS]] : memref<7x14x21xi1>, memref<7x14x21xf32>, memref<7x14x21xf32>)
+// CHECK-SAME: outs(%[[OUT]] : memref<7x14x21xf32>)
+
+// CHECK:         ^{{.+}}(%[[BBARG0:.+]]: i1, %[[BBARG1:.+]]: f32, %[[BBARG2:.+]]: f32, %[[BBARG3:.+]]: f32)
+// CHECK-NEXT:      %[[select:.+]] = arith.select %[[BBARG0]], %[[BBARG1]], %[[BBARG2]] : f32
+// CHECK-NEXT:      linalg.yield %[[select]] : f32
+
+
+// -----
 
 // CHECK-LABEL: func @fill_tensor
 func.func @fill_tensor(%f: f32, %v: vector<2x4xf32>) -> (tensor<f32>, tensor<vector<2x4xf32>>) {
diff --git a/mlir/test/Dialect/Linalg/named-ops-fail.mlir b/mlir/test/Dialect/Linalg/named-ops-fail.mlir
index e92a77aa7ad05..552a0abaa797c 100644
--- a/mlir/test/Dialect/Linalg/named-ops-fail.mlir
+++ b/mlir/test/Dialect/Linalg/named-ops-fail.mlir
@@ -334,3 +334,19 @@ func.func @powf_broadcast(%arg0: memref<8x16xf32>, %arg1: memref<4x8x16xf32>, %a
   return
 }
 
+// -----
+
+func.func @select_type_cast(%arg0: memref<4x8x16xi1>, %arg1: memref<4x8x16xf16>, %arg2: memref<4x8x16xf32>, %arg3: memref<4x8x16xf32>) {
+  // CHECK: op failed to verify that all of {true_value, false_value, result} have same type
+  linalg.select ins(%arg0, %arg1, %arg2 : memref<4x8x16xi1>, memref<4x8x16xf16>, memref<4x8x16xf32>) outs(%arg3: memref<4x8x16xf32>)
+  return
+}
+
+// -----
+
+func.func @select_wrong_condition_type(%arg0: memref<4x8x16xf32>, %arg1: memref<4x8x16xf32>, %arg2: memref<4x8x16xf32>, %arg3: memref<4x8x16xf32>) {
+  // CHECK: op operand #0 must be bool-like, but got 'f32'
+  linalg.select ins(%arg0, %arg1, %arg2 : memref<4x8x16xf32>, memref<4x8x16xf32>, memref<4x8x16xf32>) outs(%arg3: memref<4x8x16xf32>)
+  return
+}
+
diff --git a/mlir/test/Dialect/Linalg/named-ops.mlir b/mlir/test/Dialect/Linalg/named-ops.mlir
index fefe5578947f0..cecd0033b7765 100644
--- a/mlir/test/Dialect/Linalg/named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/named-ops.mlir
@@ -1924,3 +1924,25 @@ func.func @fill_tensor(%f: f32, %v: vector<2x4xf32>) -> (tensor<f32>, tensor<vec
   %1 = linalg.fill ins(%v : vector<2x4xf32>) outs(%e1 : tensor<vector<2x4xf32>>) -> tensor<vector<2x4xf32>>
   return %0, %1: tensor<f32>, tensor<vector<2x4xf32>>
 }
+
+// -----
+
+// CHECK-LABEL: func @select_dynamic
+func.func @select_dynamic(%arg0: memref<?x?x?xi1>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?x?xf32>, %arg3: memref<?x?x?xf32>) {
+  // CHECK: linalg.select
+  // CHECK-SAME: ins(%{{.+}}, %{{.+}}, %{{.+}} : memref<?x?x?xi1>, memref<?x?x?xf32>, memref<?x?x?xf32>)
+  // CHECK-SAME: outs(%{{.+}} : memref<?x?x?xf32>)
+  linalg.select ins(%arg0, %arg1, %arg2 : memref<?x?x?xi1>, memref<?x?x?xf32>, memref<?x?x?xf32>) outs(%arg3: memref<?x?x?xf32>)
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func @select_static
+func.func @select_static(%arg0: memref<4x8x16xi1>, %arg1: memref<4x8x16xf32>, %arg2: memref<4x8x16xf32>, %arg3: memref<4x8x16xf32>) {
+  // CHECK: linalg.select
+  // CHECK-SAME: ins(%{{.+}}, %{{.+}}, %{{.+}} : memref<4x8x16xi1>, memref<4x8x16xf32>, memref<4x8x16xf32>)
+  // CHECK-SAME: outs(%{{.+}} : memref<4x8x16xf32>)
+  linalg.select ins(%arg0, %arg1, %arg2 : memref<4x8x16xi1>, memref<4x8x16xf32>, memref<4x8x16xf32>) outs(%arg3: memref<4x8x16xf32>)
+  return
+}
diff --git a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
index fe6ad15041126..37240164c377e 100644
--- a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
+++ b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
@@ -70,6 +70,7 @@ enum class LinalgOperandDefKind {
   IndexAttr,
   UnaryFnAttr,
   BinaryFnAttr,
+  TernaryFnAttr,
   TypeFnAttr
 };
 
@@ -94,7 +95,7 @@ struct LinalgIndexingMapsConfig {
 
 struct ScalarExpression;
 
-enum class ScalarFnKind { Unary, Binary, Type };
+enum class ScalarFnKind { Unary, Binary, Ternary, Type };
 
 struct ScalarFn {
   ScalarFnKind kind;
@@ -214,6 +215,7 @@ struct ScalarEnumerationTraits<LinalgOperandDefKind> {
     io.enumCase(value, "index_attr", LinalgOperandDefKind::IndexAttr);
     io.enumCase(value, "unary_fn_attr", LinalgOperandDefKind::UnaryFnAttr);
     io.enumCase(value, "binary_fn_attr", LinalgOperandDefKind::BinaryFnAttr);
+    io.enumCase(value, "ternary_fn_attr", LinalgOperandDefKind::TernaryFnAttr);
     io.enumCase(value, "type_fn_attr", LinalgOperandDefKind::TypeFnAttr);
   }
 };
@@ -284,6 +286,7 @@ struct ScalarEnumerationTraits<ScalarFnKind> {
   static void enumeration(IO &io, ScalarFnKind &value) {
     io.enumCase(value, "unary", ScalarFnKind::Unary);
     io.enumCase(value, "binary", ScalarFnKind::Binary);
+    io.enumCase(value, "ternary", ScalarFnKind::Ternary);
     io.enumCase(value, "type", ScalarFnKind::Type);
   }
 };
@@ -441,6 +444,7 @@ static ScalarAssign *findAssignment(StringRef name,
 static bool isFunctionAttribute(LinalgOperandDefKind kind) {
   return kind == LinalgOperandDefKind::UnaryFnAttr ||
          kind == LinalgOperandDefKind::BinaryFnAttr ||
+         kind == LinalgOperandDefKind::TernaryFnAttr ||
          kind == LinalgOperandDefKind::TypeFnAttr;
 }
 
@@ -456,6 +460,8 @@ std::string convertOperandKindToEnumName(LinalgOperandDefKind kind) {
     return std::string("UnaryFn");
   case LinalgOperandDefKind::BinaryFnAttr:
     return std::string("BinaryFn");
+  case LinalgOperandDefKind::TernaryFnAttr:
+    return std::string("TernaryFn");
   case LinalgOperandDefKind::TypeFnAttr:
     return std::string("TypeFn");
   default:
@@ -471,6 +477,8 @@ std::string convertFunctionKindToEnumName(ScalarFnKind kind) {
     return std::string("UnaryFn");
   case ScalarFnKind::Binary:
     return std::string("BinaryFn");
+  case ScalarFnKind::Ternary:
+    return std::string("TernaryFn");
   case ScalarFnKind::Type:
     return std::string("TypeFn");
   }

``````````

</details>


https://github.com/llvm/llvm-project/pull/91461


More information about the Mlir-commits mailing list