[Mlir-commits] [mlir] d8dc1c2 - [MLIR][Linalg] Add max named op to linalg
Renato Golin
llvmlistbot at llvm.org
Fri Jul 7 05:40:25 PDT 2023
Author: Renato Golin
Date: 2023-07-07T13:39:12+01:00
New Revision: d8dc1c22bf926cb8c87d7ff72bae6aafe076bbc2
URL: https://github.com/llvm/llvm-project/commit/d8dc1c22bf926cb8c87d7ff72bae6aafe076bbc2
DIFF: https://github.com/llvm/llvm-project/commit/d8dc1c22bf926cb8c87d7ff72bae6aafe076bbc2.diff
LOG: [MLIR][Linalg] Add max named op to linalg
I've been trying to come up with a simple and clean implementation for
ReLU. TOSA uses `clamp`, which is probably the goal, but that means using
table-gen to make it efficient (attributes, lowering only `min` or `max`).
For now, `max` is a reasonable named op in its own right, independent of
ReLU, so we can start using it for tiling and fusion. Upon success, we can
create a more complete `clamp` op that doesn't need a whole tensor filled
with zeroes or ones to implement the different activation functions.
As with other named ops, we start by "requiring" explicit type casts,
broadcasts, and zero-filled constant tensors, leaving their recognition to a
more complex pattern-matcher. We can slowly simplify this with attributes or
structured matchers (e.g. PDL) in the future.
Differential Revision: https://reviews.llvm.org/D154703
Added:
Modified:
mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
mlir/test/Dialect/Linalg/generalize-named-ops.mlir
mlir/test/Dialect/Linalg/named-ops-fail.mlir
mlir/test/Dialect/Linalg/named-ops.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
index 11fa49ef34681d..d021376ff4cdcd 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
@@ -613,6 +613,55 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarExpression
scalar_arg: rhs
--- !LinalgOpConfig
+metadata: !LinalgOpMetadata
+ name: max
+ cpp_class_name: MaxOp
+ doc: |-
+ Takes the max (signed) between two inputs, elementwise.
+
+ The shapes and element types must be identical. The appropriate casts,
+ broadcasts and reductions should be done before calling this op.
+
+ This means reduction/broadcast/element cast semantics are explicit. Further
+ passes can take that into account when lowering this code. For example,
+ a `linalg.broadcast` + `linalg.max` sequence can be lowered to a
+ `linalg.generic` with different affine maps for the two operands.
+structured_op: !LinalgStructuredOpConfig
+ args:
+ - !LinalgOperandDefConfig
+ name: lhs
+ kind: input_tensor
+ type_var: T
+ shape_map: affine_map<() -> ()>
+ - !LinalgOperandDefConfig
+ name: rhs
+ kind: input_tensor
+ type_var: T
+ shape_map: affine_map<() -> ()>
+ - !LinalgOperandDefConfig
+ name: out
+ kind: output_tensor
+ type_var: T
+ shape_map: affine_map<() -> ()>
+ indexing_maps: !LinalgIndexingMapsConfig
+ static_indexing_maps:
+ - affine_map<() -> ()>
+ - affine_map<() -> ()>
+ - affine_map<() -> ()>
+ iterator_types: []
+ assignments:
+ - !ScalarAssign
+ arg: out
+ value: !ScalarExpression
+ scalar_fn:
+ kind: binary
+ fn_name: max_signed
+ operands:
+ - !ScalarExpression
+ scalar_arg: lhs
+ - !ScalarExpression
+ scalar_arg: rhs
+--- !LinalgOpConfig
metadata: !LinalgOpMetadata
name: matmul
cpp_class_name: MatmulOp
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
index 9cc252eb710234..e4512cd1e0573c 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
@@ -219,6 +219,25 @@ def div_unsigned(
O[None] = lhs[None] / rhs[None]
+@linalg_structured_op
+def max(
+ lhs=TensorDef(T1),
+ rhs=TensorDef(T1),
+ O=TensorDef(T1, output=True),
+):
+ """Takes the max (signed) between two inputs, elementwise.
+
+ The shapes and element types must be identical. The appropriate casts,
+ broadcasts and reductions should be done before calling this op.
+
+ This means reduction/broadcast/element cast semantics are explicit. Further
+ passes can take that into account when lowering this code. For example,
+ a `linalg.broadcast` + `linalg.max` sequence can be lowered to a
+ `linalg.generic` with different affine maps for the two operands.
+ """
+ O[None] = BinaryFn.max_signed(lhs[None], rhs[None])
+
+
@linalg_structured_op
def matmul(
A=TensorDef(T1, S.M, S.K),
diff --git a/mlir/test/Dialect/Linalg/generalize-named-ops.mlir b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
index 7e96ad2b0b2412..af616a0a7bd8da 100644
--- a/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
@@ -537,3 +537,28 @@ func.func @generalize_negf(%arg: memref<7x14x21xf32>, %out: memref<7x14x21xf32>)
// CHECK: ^{{.+}}(%[[BBARG0:.+]]: f32, %[[BBARG1:.+]]: f32)
// CHECK-NEXT: %[[negf:.+]] = arith.negf %[[BBARG0]] : f32
// CHECK-NEXT: linalg.yield %[[negf]] : f32
+
+// -----
+
+func.func @generalize_max(%lhs: memref<7x14x21xf32>, %rhs: memref<7x14x21xf32>,
+ %out: memref<7x14x21xf32>) {
+ linalg.max ins(%lhs, %rhs : memref<7x14x21xf32>, memref<7x14x21xf32>)
+ outs(%out : memref<7x14x21xf32>)
+ return
+}
+
+// CHECK: #[[MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+
+// CHECK: func @generalize_max
+// CHECK-SAME: (%[[LHS:.+]]: memref<7x14x21xf32>, %[[RHS:.+]]: memref<7x14x21xf32>,
+// CHECK-SAME: %[[OUT:.+]]: memref<7x14x21xf32>)
+
+// CHECK: linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP]], #[[MAP]], #[[MAP]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]}
+// CHECK-SAME: ins(%[[LHS]], %[[RHS]] : memref<7x14x21xf32>, memref<7x14x21xf32>)
+// CHECK-SAME: outs(%[[OUT]] : memref<7x14x21xf32>)
+
+// CHECK: ^{{.+}}(%[[BBARG0:.+]]: f32, %[[BBARG1:.+]]: f32, %[[BBARG2:.+]]: f32)
+// CHECK-NEXT: %[[max:.+]] = arith.maxf %[[BBARG0]], %[[BBARG1]] : f32
+// CHECK-NEXT: linalg.yield %[[max]] : f32
diff --git a/mlir/test/Dialect/Linalg/named-ops-fail.mlir b/mlir/test/Dialect/Linalg/named-ops-fail.mlir
index b5cd9b659dd181..c351e139a97e37 100644
--- a/mlir/test/Dialect/Linalg/named-ops-fail.mlir
+++ b/mlir/test/Dialect/Linalg/named-ops-fail.mlir
@@ -173,3 +173,19 @@ func.func @negf_broadcast(%arg: memref<8x16xf32>, %out: memref<4x8x16xf32>) {
linalg.negf ins(%arg : memref<8x16xf32>) outs(%out: memref<4x8x16xf32>)
return
}
+
+// -----
+
+func.func @max_type_cast(%arg0: memref<4x8x16xf32>, %arg1: memref<4x8x16xf16>, %arg2: memref<4x8x16xf32>) {
+ // CHECK: op requires the same type for all operands and results
+ linalg.max ins(%arg0, %arg1 : memref<4x8x16xf32>, memref<4x8x16xf16>) outs(%arg2: memref<4x8x16xf32>)
+ return
+}
+
+// -----
+
+func.func @max_broadcast(%arg0: memref<8x16xf32>, %arg1: memref<4x8x16xf32>, %arg2: memref<4x8x16xf32>) {
+ // CHECK: op expected operand rank (2) to match the result rank of indexing_map #0 (3)
+ linalg.max ins(%arg0, %arg1 : memref<8x16xf32>, memref<4x8x16xf32>) outs(%arg2: memref<4x8x16xf32>)
+ return
+}
diff --git a/mlir/test/Dialect/Linalg/named-ops.mlir b/mlir/test/Dialect/Linalg/named-ops.mlir
index ca4350e3056614..8f00d546553274 100644
--- a/mlir/test/Dialect/Linalg/named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/named-ops.mlir
@@ -1540,3 +1540,37 @@ func.func @negf_tensor(%arg0: tensor<4x8x16xf32>) -> tensor<4x8x16xf32> {
%1 = linalg.negf ins(%arg0 : tensor<4x8x16xf32>) outs(%0: tensor<4x8x16xf32>) -> tensor<4x8x16xf32>
return %1 : tensor<4x8x16xf32>
}
+
+// -----
+
+// CHECK-LABEL: func @max_dynamic
+func.func @max_dynamic(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?x?xf32>) {
+ // CHECK: linalg.max
+ // CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<?x?x?xf32>, memref<?x?x?xf32>)
+ // CHECK-SAME: outs(%{{.+}} : memref<?x?x?xf32>)
+ linalg.max ins(%arg0, %arg1 : memref<?x?x?xf32>, memref<?x?x?xf32>) outs(%arg2: memref<?x?x?xf32>)
+ return
+}
+
+// -----
+
+// CHECK-LABEL: func @max_static
+func.func @max_static(%arg0: memref<4x8x16xf32>, %arg1: memref<4x8x16xf32>, %arg2: memref<4x8x16xf32>) {
+ // CHECK: linalg.max
+ // CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<4x8x16xf32>, memref<4x8x16xf32>)
+ // CHECK-SAME: outs(%{{.+}} : memref<4x8x16xf32>)
+ linalg.max ins(%arg0, %arg1 : memref<4x8x16xf32>, memref<4x8x16xf32>) outs(%arg2: memref<4x8x16xf32>)
+ return
+}
+
+// -----
+
+// CHECK-LABEL: func @max_tensor
+func.func @max_tensor(%arg0: tensor<4x8x16xf32>, %arg1: tensor<4x8x16xf32>) -> tensor<4x8x16xf32> {
+ %0 = tensor.empty() : tensor<4x8x16xf32>
+ // CHECK: linalg.max
+ // CHECK-SAME: ins(%{{.+}}, %{{.+}} : tensor<4x8x16xf32>, tensor<4x8x16xf32>)
+ // CHECK-SAME: outs(%{{.+}} : tensor<4x8x16xf32>)
+ %1 = linalg.max ins(%arg0, %arg1 : tensor<4x8x16xf32>, tensor<4x8x16xf32>) outs(%0: tensor<4x8x16xf32>) -> tensor<4x8x16xf32>
+ return %1 : tensor<4x8x16xf32>
+}
More information about the Mlir-commits
mailing list