[Mlir-commits] [mlir] d8dc1c2 - [MLIR][Linalg] Add max named op to linalg

Renato Golin llvmlistbot at llvm.org
Fri Jul 7 05:40:25 PDT 2023

Author: Renato Golin
Date: 2023-07-07T13:39:12+01:00
New Revision: d8dc1c22bf926cb8c87d7ff72bae6aafe076bbc2

URL: https://github.com/llvm/llvm-project/commit/d8dc1c22bf926cb8c87d7ff72bae6aafe076bbc2
DIFF: https://github.com/llvm/llvm-project/commit/d8dc1c22bf926cb8c87d7ff72bae6aafe076bbc2.diff

LOG: [MLIR][Linalg] Add max named op to linalg

I've been trying to come up with a simple and clean implementation for
ReLU. TOSA uses `clamp`, which is probably the goal, but that requires
table-gen to make it efficient (attributes, lowering only `min` or `max`).

For now, `max` is a reasonable named op in its own right, independent of
ReLU, so we can start using it for tiling and fusion. Once that works, we
can create a more complete `clamp` op that doesn't need a whole tensor
filled with zeroes or ones to implement the different activation functions.

As with other named ops, we start by "requiring" explicit type casts and
broadcasts, and zero-filled constant tensors, leaving their recognition to
a more complex pattern-matcher. We can slowly simplify this with attributes
or structured matchers (e.g. PDL) in the future.

Differential Revision: https://reviews.llvm.org/D154703




diff  --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
index 11fa49ef34681d..d021376ff4cdcd 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
@@ -613,6 +613,55 @@ structured_op: !LinalgStructuredOpConfig
         - !ScalarExpression
           scalar_arg: rhs
 --- !LinalgOpConfig
+metadata: !LinalgOpMetadata
+  name: max
+  cpp_class_name: MaxOp
+  doc: |-
+    Takes the max (signed) between two inputs, elementwise.
+    The shapes and element types must be identical. The appropriate casts,
+    broadcasts and reductions should be done previously to calling this op.
+    This means reduction/broadcast/element cast semantics is explicit. Further
+    passes can take that into account when lowering this code. For example,
+    a `linalg.broadcast` + `linalg.max` sequence can be lowered to a
+    `linalg.generic` with different affine maps for the two operands.
+structured_op: !LinalgStructuredOpConfig
+  args:
+  - !LinalgOperandDefConfig
+    name: lhs
+    kind: input_tensor
+    type_var: T
+    shape_map: affine_map<() -> ()>
+  - !LinalgOperandDefConfig
+    name: rhs
+    kind: input_tensor
+    type_var: T
+    shape_map: affine_map<() -> ()>
+  - !LinalgOperandDefConfig
+    name: out
+    kind: output_tensor
+    type_var: T
+    shape_map: affine_map<() -> ()>
+  indexing_maps: !LinalgIndexingMapsConfig
+    static_indexing_maps:
+    - affine_map<() -> ()>
+    - affine_map<() -> ()>
+    - affine_map<() -> ()>
+  iterator_types: []
+  assignments:
+  - !ScalarAssign
+    arg: out
+    value: !ScalarExpression
+      scalar_fn:
+        kind: binary
+        fn_name: max_signed
+        operands:
+        - !ScalarExpression
+          scalar_arg: lhs
+        - !ScalarExpression
+          scalar_arg: rhs
+--- !LinalgOpConfig
 metadata: !LinalgOpMetadata
   name: matmul
   cpp_class_name: MatmulOp

diff  --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
index 9cc252eb710234..e4512cd1e0573c 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
@@ -219,6 +219,25 @@ def div_unsigned(
     O[None] = lhs[None] / rhs[None]
+@linalg_structured_op
+def max(
+    lhs=TensorDef(T1),
+    rhs=TensorDef(T1),
+    O=TensorDef(T1, output=True),
+):
+    """Takes the max (signed) between two inputs, elementwise.
+    The shapes and element types must be identical. The appropriate casts,
+    broadcasts and reductions should be done previously to calling this op.
+    This means reduction/broadcast/element cast semantics is explicit. Further
+    passes can take that into account when lowering this code. For example,
+    a `linalg.broadcast` + `linalg.max` sequence can be lowered to a
+    `linalg.generic` with different affine maps for the two operands.
+    """
+    O[None] = BinaryFn.max_signed(lhs[None], rhs[None])
 def matmul(
     A=TensorDef(T1, S.M, S.K),

diff  --git a/mlir/test/Dialect/Linalg/generalize-named-ops.mlir b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
index 7e96ad2b0b2412..af616a0a7bd8da 100644
--- a/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
@@ -537,3 +537,28 @@ func.func @generalize_negf(%arg: memref<7x14x21xf32>, %out: memref<7x14x21xf32>)
 // CHECK:         ^{{.+}}(%[[BBARG0:.+]]: f32, %[[BBARG1:.+]]: f32)
 // CHECK-NEXT:      %[[negf:.+]] = arith.negf %[[BBARG0]] : f32
 // CHECK-NEXT:      linalg.yield %[[negf]] : f32
+// -----
+func.func @generalize_max(%lhs: memref<7x14x21xf32>, %rhs: memref<7x14x21xf32>,
+                          %out: memref<7x14x21xf32>) {
+  linalg.max ins(%lhs, %rhs : memref<7x14x21xf32>, memref<7x14x21xf32>)
+             outs(%out : memref<7x14x21xf32>)
+  return
+}
+// CHECK: #[[MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK: func @generalize_max
+// CHECK-SAME: (%[[LHS:.+]]: memref<7x14x21xf32>, %[[RHS:.+]]: memref<7x14x21xf32>,
+// CHECK-SAME:  %[[OUT:.+]]: memref<7x14x21xf32>)
+// CHECK: linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP]], #[[MAP]], #[[MAP]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]}
+// CHECK-SAME:  ins(%[[LHS]], %[[RHS]] : memref<7x14x21xf32>, memref<7x14x21xf32>)
+// CHECK-SAME: outs(%[[OUT]] : memref<7x14x21xf32>)
+// CHECK:         ^{{.+}}(%[[BBARG0:.+]]: f32, %[[BBARG1:.+]]: f32, %[[BBARG2:.+]]: f32)
+// CHECK-NEXT:      %[[max:.+]] = arith.maxf %[[BBARG0]], %[[BBARG1]] : f32
+// CHECK-NEXT:      linalg.yield %[[max]] : f32

diff  --git a/mlir/test/Dialect/Linalg/named-ops-fail.mlir b/mlir/test/Dialect/Linalg/named-ops-fail.mlir
index b5cd9b659dd181..c351e139a97e37 100644
--- a/mlir/test/Dialect/Linalg/named-ops-fail.mlir
+++ b/mlir/test/Dialect/Linalg/named-ops-fail.mlir
@@ -173,3 +173,19 @@ func.func @negf_broadcast(%arg: memref<8x16xf32>, %out: memref<4x8x16xf32>) {
   linalg.negf ins(%arg : memref<8x16xf32>) outs(%out: memref<4x8x16xf32>)
+// -----
+func.func @max_type_cast(%arg0: memref<4x8x16xf32>, %arg1: memref<4x8x16xf16>, %arg2: memref<4x8x16xf32>) {
+  // CHECK: op requires the same type for all operands and results
+  linalg.max ins(%arg0, %arg1 : memref<4x8x16xf32>, memref<4x8x16xf16>) outs(%arg2: memref<4x8x16xf32>)
+  return
+}
+// -----
+func.func @max_broadcast(%arg0: memref<8x16xf32>, %arg1: memref<4x8x16xf32>, %arg2: memref<4x8x16xf32>) {
+  // CHECK: op expected operand rank (2) to match the result rank of indexing_map #0 (3)
+  linalg.max ins(%arg0, %arg1 : memref<8x16xf32>, memref<4x8x16xf32>) outs(%arg2: memref<4x8x16xf32>)
+  return
+}

diff  --git a/mlir/test/Dialect/Linalg/named-ops.mlir b/mlir/test/Dialect/Linalg/named-ops.mlir
index ca4350e3056614..8f00d546553274 100644
--- a/mlir/test/Dialect/Linalg/named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/named-ops.mlir
@@ -1540,3 +1540,37 @@ func.func @negf_tensor(%arg0: tensor<4x8x16xf32>) -> tensor<4x8x16xf32> {
   %1 = linalg.negf ins(%arg0 : tensor<4x8x16xf32>) outs(%0: tensor<4x8x16xf32>) -> tensor<4x8x16xf32>
   return %1 : tensor<4x8x16xf32>
+// -----
+// CHECK-LABEL: func @max_dynamic
+func.func @max_dynamic(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?x?xf32>) {
+  // CHECK: linalg.max
+  // CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<?x?x?xf32>, memref<?x?x?xf32>)
+  // CHECK-SAME: outs(%{{.+}} : memref<?x?x?xf32>)
+  linalg.max ins(%arg0, %arg1 : memref<?x?x?xf32>, memref<?x?x?xf32>) outs(%arg2: memref<?x?x?xf32>)
+  return
+}
+// -----
+// CHECK-LABEL: func @max_static
+func.func @max_static(%arg0: memref<4x8x16xf32>, %arg1: memref<4x8x16xf32>, %arg2: memref<4x8x16xf32>) {
+  // CHECK: linalg.max
+  // CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<4x8x16xf32>, memref<4x8x16xf32>)
+  // CHECK-SAME: outs(%{{.+}} : memref<4x8x16xf32>)
+  linalg.max ins(%arg0, %arg1 : memref<4x8x16xf32>, memref<4x8x16xf32>) outs(%arg2: memref<4x8x16xf32>)
+  return
+}
+// -----
+// CHECK-LABEL: func @max_tensor
+func.func @max_tensor(%arg0: tensor<4x8x16xf32>, %arg1: tensor<4x8x16xf32>) -> tensor<4x8x16xf32> {
+  %0 = tensor.empty() : tensor<4x8x16xf32>
+  // CHECK: linalg.max
+  // CHECK-SAME: ins(%{{.+}}, %{{.+}} : tensor<4x8x16xf32>, tensor<4x8x16xf32>)
+  // CHECK-SAME: outs(%{{.+}} : tensor<4x8x16xf32>)
+  %1 = linalg.max ins(%arg0, %arg1 : tensor<4x8x16xf32>, tensor<4x8x16xf32>) outs(%0: tensor<4x8x16xf32>) -> tensor<4x8x16xf32>
+  return %1 : tensor<4x8x16xf32>
+}


More information about the Mlir-commits mailing list