[Mlir-commits] [mlir] 4cec3b3 - [MLIR][Linalg] More Linalg named ops (#90236)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Sun Apr 28 07:25:27 PDT 2024
Author: Renato Golin
Date: 2024-04-28T15:25:24+01:00
New Revision: 4cec3b36f6d6c858992530fa5592824622ada9c7
URL: https://github.com/llvm/llvm-project/commit/4cec3b36f6d6c858992530fa5592824622ada9c7
DIFF: https://github.com/llvm/llvm-project/commit/4cec3b36f6d6c858992530fa5592824622ada9c7.diff
LOG: [MLIR][Linalg] More Linalg named ops (#90236)
Adding `min` that was already implemented but not exposed.
Adding a few additional unary ops:
* Reciprocal as `arith.div(1,arg)`
* Round as `math.round(arg)`
* Sqrt as `math.sqrt(arg)`
* Rsqrt as `math.rsqrt(arg)`
* Square as `arith.mulf(arg, arg)`
* TanH as `math.tanh(arg)`
All with the agreed semantics at the round table: no implicit
broadcast/type cast.
Added:
Modified:
mlir/include/mlir/Dialect/Linalg/IR/LinalgEnums.td
mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py
mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
mlir/test/Dialect/Linalg/generalize-named-ops.mlir
mlir/test/Dialect/Linalg/named-ops-fail.mlir
mlir/test/Dialect/Linalg/named-ops.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgEnums.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgEnums.td
index 59f909aed8f61a..7a350d2c014262 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgEnums.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgEnums.td
@@ -22,7 +22,13 @@ def UnaryFn : I32EnumAttr<"UnaryFn", "", [
I32EnumAttrCase<"abs", 2>,
I32EnumAttrCase<"ceil", 3>,
I32EnumAttrCase<"floor", 4>,
- I32EnumAttrCase<"negf", 5>
+ I32EnumAttrCase<"negf", 5>,
+ I32EnumAttrCase<"reciprocal", 6>,
+ I32EnumAttrCase<"round", 7>,
+ I32EnumAttrCase<"sqrt", 8>,
+ I32EnumAttrCase<"rsqrt", 9>,
+ I32EnumAttrCase<"square", 10>,
+ I32EnumAttrCase<"tanh", 11>
]> {
let genSpecializedAttr = 0;
let cppNamespace = "::mlir::linalg";
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
index 1ff6c4086cf357..b7567577347587 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
@@ -304,6 +304,216 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarExpression
scalar_arg: I
--- !LinalgOpConfig
+metadata: !LinalgOpMetadata
+ name: reciprocal
+ cpp_class_name: ReciprocalOp
+ doc: |-
+ Applies reciprocal(x) elementwise.
+
+ No numeric casting is performed on the input operand.
+structured_op: !LinalgStructuredOpConfig
+ args:
+ - !LinalgOperandDefConfig
+ name: I
+ kind: input_tensor
+ type_var: T1
+ shape_map: affine_map<() -> ()>
+ - !LinalgOperandDefConfig
+ name: O
+ kind: output_tensor
+ type_var: T1
+ shape_map: affine_map<() -> ()>
+ indexing_maps: !LinalgIndexingMapsConfig
+ static_indexing_maps:
+ - affine_map<() -> ()>
+ - affine_map<() -> ()>
+ iterator_types: []
+ assignments:
+ - !ScalarAssign
+ arg: O
+ value: !ScalarExpression
+ scalar_fn:
+ kind: unary
+ fn_name: reciprocal
+ operands:
+ - !ScalarExpression
+ scalar_arg: I
+--- !LinalgOpConfig
+metadata: !LinalgOpMetadata
+ name: round
+ cpp_class_name: RoundOp
+ doc: |-
+ Applies round(x) elementwise.
+
+ No numeric casting is performed on the input operand.
+structured_op: !LinalgStructuredOpConfig
+ args:
+ - !LinalgOperandDefConfig
+ name: I
+ kind: input_tensor
+ type_var: T1
+ shape_map: affine_map<() -> ()>
+ - !LinalgOperandDefConfig
+ name: O
+ kind: output_tensor
+ type_var: T1
+ shape_map: affine_map<() -> ()>
+ indexing_maps: !LinalgIndexingMapsConfig
+ static_indexing_maps:
+ - affine_map<() -> ()>
+ - affine_map<() -> ()>
+ iterator_types: []
+ assignments:
+ - !ScalarAssign
+ arg: O
+ value: !ScalarExpression
+ scalar_fn:
+ kind: unary
+ fn_name: round
+ operands:
+ - !ScalarExpression
+ scalar_arg: I
+--- !LinalgOpConfig
+metadata: !LinalgOpMetadata
+ name: sqrt
+ cpp_class_name: SqrtOp
+ doc: |-
+ Applies sqrt(x) elementwise.
+
+ No numeric casting is performed on the input operand.
+structured_op: !LinalgStructuredOpConfig
+ args:
+ - !LinalgOperandDefConfig
+ name: I
+ kind: input_tensor
+ type_var: T1
+ shape_map: affine_map<() -> ()>
+ - !LinalgOperandDefConfig
+ name: O
+ kind: output_tensor
+ type_var: T1
+ shape_map: affine_map<() -> ()>
+ indexing_maps: !LinalgIndexingMapsConfig
+ static_indexing_maps:
+ - affine_map<() -> ()>
+ - affine_map<() -> ()>
+ iterator_types: []
+ assignments:
+ - !ScalarAssign
+ arg: O
+ value: !ScalarExpression
+ scalar_fn:
+ kind: unary
+ fn_name: sqrt
+ operands:
+ - !ScalarExpression
+ scalar_arg: I
+--- !LinalgOpConfig
+metadata: !LinalgOpMetadata
+ name: rsqrt
+ cpp_class_name: RsqrtOp
+ doc: |-
+ Applies rsqrt(x) elementwise.
+
+ No numeric casting is performed on the input operand.
+structured_op: !LinalgStructuredOpConfig
+ args:
+ - !LinalgOperandDefConfig
+ name: I
+ kind: input_tensor
+ type_var: T1
+ shape_map: affine_map<() -> ()>
+ - !LinalgOperandDefConfig
+ name: O
+ kind: output_tensor
+ type_var: T1
+ shape_map: affine_map<() -> ()>
+ indexing_maps: !LinalgIndexingMapsConfig
+ static_indexing_maps:
+ - affine_map<() -> ()>
+ - affine_map<() -> ()>
+ iterator_types: []
+ assignments:
+ - !ScalarAssign
+ arg: O
+ value: !ScalarExpression
+ scalar_fn:
+ kind: unary
+ fn_name: rsqrt
+ operands:
+ - !ScalarExpression
+ scalar_arg: I
+--- !LinalgOpConfig
+metadata: !LinalgOpMetadata
+ name: square
+ cpp_class_name: SquareOp
+ doc: |-
+ Applies square(x) elementwise.
+
+ No numeric casting is performed on the input operand.
+structured_op: !LinalgStructuredOpConfig
+ args:
+ - !LinalgOperandDefConfig
+ name: I
+ kind: input_tensor
+ type_var: T1
+ shape_map: affine_map<() -> ()>
+ - !LinalgOperandDefConfig
+ name: O
+ kind: output_tensor
+ type_var: T1
+ shape_map: affine_map<() -> ()>
+ indexing_maps: !LinalgIndexingMapsConfig
+ static_indexing_maps:
+ - affine_map<() -> ()>
+ - affine_map<() -> ()>
+ iterator_types: []
+ assignments:
+ - !ScalarAssign
+ arg: O
+ value: !ScalarExpression
+ scalar_fn:
+ kind: unary
+ fn_name: square
+ operands:
+ - !ScalarExpression
+ scalar_arg: I
+--- !LinalgOpConfig
+metadata: !LinalgOpMetadata
+ name: tanh
+ cpp_class_name: TanhOp
+ doc: |-
+ Applies tanh(x) elementwise.
+
+ No numeric casting is performed on the input operand.
+structured_op: !LinalgStructuredOpConfig
+ args:
+ - !LinalgOperandDefConfig
+ name: I
+ kind: input_tensor
+ type_var: T1
+ shape_map: affine_map<() -> ()>
+ - !LinalgOperandDefConfig
+ name: O
+ kind: output_tensor
+ type_var: T1
+ shape_map: affine_map<() -> ()>
+ indexing_maps: !LinalgIndexingMapsConfig
+ static_indexing_maps:
+ - affine_map<() -> ()>
+ - affine_map<() -> ()>
+ iterator_types: []
+ assignments:
+ - !ScalarAssign
+ arg: O
+ value: !ScalarExpression
+ scalar_fn:
+ kind: unary
+ fn_name: tanh
+ operands:
+ - !ScalarExpression
+ scalar_arg: I
+--- !LinalgOpConfig
metadata: !LinalgOpMetadata
name: elemwise_binary
cpp_class_name: ElemwiseBinaryOp
@@ -625,7 +835,7 @@ metadata: !LinalgOpMetadata
This means reduction/broadcast/element cast semantics is explicit. Further
passes can take that into account when lowering this code. For example,
- a `linalg.broadcast` + `linalg.div` sequence can be lowered to a
+ a `linalg.broadcast` + `linalg.max` sequence can be lowered to a
`linalg.generic` with different affine maps for the two operands.
structured_op: !LinalgStructuredOpConfig
args:
@@ -663,6 +873,55 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarExpression
scalar_arg: rhs
--- !LinalgOpConfig
+metadata: !LinalgOpMetadata
+ name: min
+ cpp_class_name: MinOp
+ doc: |-
+ Takes the min (signed) between two inputs, elementwise.
+
+ The shapes and element types must be identical. The appropriate casts,
+ broadcasts and reductions should be done previously to calling this op.
+
+ This means reduction/broadcast/element cast semantics is explicit. Further
+ passes can take that into account when lowering this code. For example,
+ a `linalg.broadcast` + `linalg.min` sequence can be lowered to a
+ `linalg.generic` with different affine maps for the two operands.
+structured_op: !LinalgStructuredOpConfig
+ args:
+ - !LinalgOperandDefConfig
+ name: lhs
+ kind: input_tensor
+ type_var: T1
+ shape_map: affine_map<() -> ()>
+ - !LinalgOperandDefConfig
+ name: rhs
+ kind: input_tensor
+ type_var: T1
+ shape_map: affine_map<() -> ()>
+ - !LinalgOperandDefConfig
+ name: O
+ kind: output_tensor
+ type_var: T1
+ shape_map: affine_map<() -> ()>
+ indexing_maps: !LinalgIndexingMapsConfig
+ static_indexing_maps:
+ - affine_map<() -> ()>
+ - affine_map<() -> ()>
+ - affine_map<() -> ()>
+ iterator_types: []
+ assignments:
+ - !ScalarAssign
+ arg: O
+ value: !ScalarExpression
+ scalar_fn:
+ kind: binary
+ fn_name: min_signed
+ operands:
+ - !ScalarExpression
+ scalar_arg: lhs
+ - !ScalarExpression
+ scalar_arg: rhs
+--- !LinalgOpConfig
metadata: !LinalgOpMetadata
name: matmul
cpp_class_name: MatmulOp
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 9c5c58fa1fabfb..5d10b59373ad03 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -395,6 +395,22 @@ class RegionBuilderHelper {
return builder.create<math::FloorOp>(arg.getLoc(), arg);
case UnaryFn::negf:
return builder.create<arith::NegFOp>(arg.getLoc(), arg);
+ case UnaryFn::reciprocal: {
+ Attribute oneAttr = builder.getOneAttr(arg.getType());
+ auto one = builder.create<arith::ConstantOp>(arg.getLoc(),
+ ::cast<TypedAttr>(oneAttr));
+ return builder.create<arith::DivFOp>(arg.getLoc(), one, arg);
+ }
+ case UnaryFn::round:
+ return builder.create<math::RoundOp>(arg.getLoc(), arg);
+ case UnaryFn::sqrt:
+ return builder.create<math::SqrtOp>(arg.getLoc(), arg);
+ case UnaryFn::rsqrt:
+ return builder.create<math::RsqrtOp>(arg.getLoc(), arg);
+ case UnaryFn::square:
+ return builder.create<arith::MulFOp>(arg.getLoc(), arg, arg);
+ case UnaryFn::tanh:
+ return builder.create<math::TanhOp>(arg.getLoc(), arg);
}
llvm_unreachable("unsupported unary function");
}
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py
index 23d6d26b7e294c..f7bc81bd2f6833 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py
@@ -291,6 +291,11 @@ class UnaryFn:
ceil = UnaryFnType("ceil")
floor = UnaryFnType("floor")
negf = UnaryFnType("negf")
+ round = UnaryFnType("round")
+ sqrt = UnaryFnType("sqrt")
+ rsqrt = UnaryFnType("rsqrt")
+ square = UnaryFnType("square")
+ tanh = UnaryFnType("tanh")
class BinaryFnType:
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
index 5b05364f6d35f3..2c8864be1107fb 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
@@ -108,6 +108,66 @@ def negf(
O[None] = UnaryFn.negf(I[None])
+@linalg_structured_op
+def round(
+ I=TensorDef(T1),
+ O=TensorDef(T1, output=True),
+):
+ """Applies round(x) elementwise.
+
+ No numeric casting is performed on the input operand.
+ """
+ O[None] = UnaryFn.round(I[None])
+
+
+@linalg_structured_op
+def sqrt(
+ I=TensorDef(T1),
+ O=TensorDef(T1, output=True),
+):
+ """Applies sqrt(x) elementwise.
+
+ No numeric casting is performed on the input operand.
+ """
+ O[None] = UnaryFn.sqrt(I[None])
+
+
+@linalg_structured_op
+def rsqrt(
+ I=TensorDef(T1),
+ O=TensorDef(T1, output=True),
+):
+ """Applies rsqrt(x) elementwise.
+
+ No numeric casting is performed on the input operand.
+ """
+ O[None] = UnaryFn.rsqrt(I[None])
+
+
+@linalg_structured_op
+def square(
+ I=TensorDef(T1),
+ O=TensorDef(T1, output=True),
+):
+ """Applies square(x) elementwise.
+
+ No numeric casting is performed on the input operand.
+ """
+ O[None] = UnaryFn.square(I[None])
+
+
+@linalg_structured_op
+def tanh(
+ I=TensorDef(T1),
+ O=TensorDef(T1, output=True),
+):
+ """Applies tanh(x) elementwise.
+
+ No numeric casting is performed on the input operand.
+ """
+ O[None] = UnaryFn.tanh(I[None])
+
+
@linalg_structured_op
def elemwise_binary(
lhs=TensorDef(T1),
@@ -233,12 +293,31 @@ def max(
This means reduction/broadcast/element cast semantics is explicit. Further
passes can take that into account when lowering this code. For example,
- a `linalg.broadcast` + `linalg.div` sequence can be lowered to a
+ a `linalg.broadcast` + `linalg.max` sequence can be lowered to a
`linalg.generic` with different affine maps for the two operands.
"""
O[None] = BinaryFn.max_signed(lhs[None], rhs[None])
+@linalg_structured_op
+def min(
+ lhs=TensorDef(T1),
+ rhs=TensorDef(T1),
+ O=TensorDef(T1, output=True),
+):
+ """Takes the min (signed) between two inputs, elementwise.
+
+ The shapes and element types must be identical. The appropriate casts,
+ broadcasts and reductions should be done previously to calling this op.
+
+ This means reduction/broadcast/element cast semantics is explicit. Further
+ passes can take that into account when lowering this code. For example,
+ a `linalg.broadcast` + `linalg.min` sequence can be lowered to a
+ `linalg.generic` with different affine maps for the two operands.
+ """
+ O[None] = BinaryFn.min_signed(lhs[None], rhs[None])
+
+
@linalg_structured_op
def matmul(
A=TensorDef(T1, S.M, S.K),
diff --git a/mlir/test/Dialect/Linalg/generalize-named-ops.mlir b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
index e852824cdb7367..add34412b92f2b 100644
--- a/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
@@ -565,6 +565,134 @@ func.func @generalize_negf(%arg: memref<7x14x21xf32>, %out: memref<7x14x21xf32>)
// -----
+func.func @generalize_reciprocal(%arg: memref<7x14x21xf32>, %out: memref<7x14x21xf32>) {
+ linalg.reciprocal ins(%arg : memref<7x14x21xf32>) outs(%out : memref<7x14x21xf32>)
+ return
+}
+
+// CHECK: #[[MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+
+// CHECK: func @generalize_reciprocal
+// CHECK-SAME: (%[[ARG:.+]]: memref<7x14x21xf32>, %[[OUT:.+]]: memref<7x14x21xf32>)
+
+// CHECK: %[[one:.+]] = arith.constant 1.000000e+00 : f32
+
+// CHECK: linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP]], #[[MAP]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]}
+// CHECK-SAME: ins(%[[LHS]] : memref<7x14x21xf32>) outs(%[[OUT]] : memref<7x14x21xf32>)
+
+// CHECK: ^{{.+}}(%[[BBARG0:.+]]: f32, %[[BBARG1:.+]]: f32)
+// CHECK-NEXT: %[[reciprocal:.+]] = arith.divf %[[one]], %[[BBARG0]] : f32
+// CHECK-NEXT: linalg.yield %[[reciprocal]] : f32
+
+// -----
+
+func.func @generalize_round(%arg: memref<7x14x21xf32>, %out: memref<7x14x21xf32>) {
+ linalg.round ins(%arg : memref<7x14x21xf32>) outs(%out : memref<7x14x21xf32>)
+ return
+}
+
+// CHECK: #[[MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+
+// CHECK: func @generalize_round
+// CHECK-SAME: (%[[ARG:.+]]: memref<7x14x21xf32>, %[[OUT:.+]]: memref<7x14x21xf32>)
+
+// CHECK: linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP]], #[[MAP]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]}
+// CHECK-SAME: ins(%[[LHS]] : memref<7x14x21xf32>) outs(%[[OUT]] : memref<7x14x21xf32>)
+
+// CHECK: ^{{.+}}(%[[BBARG0:.+]]: f32, %[[BBARG1:.+]]: f32)
+// CHECK-NEXT: %[[round:.+]] = math.round %[[BBARG0]] : f32
+// CHECK-NEXT: linalg.yield %[[round]] : f32
+
+// -----
+
+func.func @generalize_sqrt(%arg: memref<7x14x21xf32>, %out: memref<7x14x21xf32>) {
+ linalg.sqrt ins(%arg : memref<7x14x21xf32>) outs(%out : memref<7x14x21xf32>)
+ return
+}
+
+// CHECK: #[[MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+
+// CHECK: func @generalize_sqrt
+// CHECK-SAME: (%[[ARG:.+]]: memref<7x14x21xf32>, %[[OUT:.+]]: memref<7x14x21xf32>)
+
+// CHECK: linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP]], #[[MAP]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]}
+// CHECK-SAME: ins(%[[LHS]] : memref<7x14x21xf32>) outs(%[[OUT]] : memref<7x14x21xf32>)
+
+// CHECK: ^{{.+}}(%[[BBARG0:.+]]: f32, %[[BBARG1:.+]]: f32)
+// CHECK-NEXT: %[[sqrt:.+]] = math.sqrt %[[BBARG0]] : f32
+// CHECK-NEXT: linalg.yield %[[sqrt]] : f32
+
+// -----
+
+func.func @generalize_rsqrt(%arg: memref<7x14x21xf32>, %out: memref<7x14x21xf32>) {
+ linalg.rsqrt ins(%arg : memref<7x14x21xf32>) outs(%out : memref<7x14x21xf32>)
+ return
+}
+
+// CHECK: #[[MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+
+// CHECK: func @generalize_rsqrt
+// CHECK-SAME: (%[[ARG:.+]]: memref<7x14x21xf32>, %[[OUT:.+]]: memref<7x14x21xf32>)
+
+// CHECK: linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP]], #[[MAP]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]}
+// CHECK-SAME: ins(%[[LHS]] : memref<7x14x21xf32>) outs(%[[OUT]] : memref<7x14x21xf32>)
+
+// CHECK: ^{{.+}}(%[[BBARG0:.+]]: f32, %[[BBARG1:.+]]: f32)
+// CHECK-NEXT: %[[rsqrt:.+]] = math.rsqrt %[[BBARG0]] : f32
+// CHECK-NEXT: linalg.yield %[[rsqrt]] : f32
+
+// -----
+
+func.func @generalize_square(%arg: memref<7x14x21xf32>, %out: memref<7x14x21xf32>) {
+ linalg.square ins(%arg : memref<7x14x21xf32>) outs(%out : memref<7x14x21xf32>)
+ return
+}
+
+// CHECK: #[[MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+
+// CHECK: func @generalize_square
+// CHECK-SAME: (%[[ARG:.+]]: memref<7x14x21xf32>, %[[OUT:.+]]: memref<7x14x21xf32>)
+
+// CHECK: linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP]], #[[MAP]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]}
+// CHECK-SAME: ins(%[[LHS]] : memref<7x14x21xf32>) outs(%[[OUT]] : memref<7x14x21xf32>)
+
+// CHECK: ^{{.+}}(%[[BBARG0:.+]]: f32, %[[BBARG1:.+]]: f32)
+// CHECK-NEXT: %[[square:.+]] = arith.mulf %[[BBARG0]], %[[BBARG0]] : f32
+// CHECK-NEXT: linalg.yield %[[square]] : f32
+
+// -----
+
+func.func @generalize_tanh(%arg: memref<7x14x21xf32>, %out: memref<7x14x21xf32>) {
+ linalg.tanh ins(%arg : memref<7x14x21xf32>) outs(%out : memref<7x14x21xf32>)
+ return
+}
+
+// CHECK: #[[MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+
+// CHECK: func @generalize_tanh
+// CHECK-SAME: (%[[ARG:.+]]: memref<7x14x21xf32>, %[[OUT:.+]]: memref<7x14x21xf32>)
+
+// CHECK: linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP]], #[[MAP]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]}
+// CHECK-SAME: ins(%[[LHS]] : memref<7x14x21xf32>) outs(%[[OUT]] : memref<7x14x21xf32>)
+
+// CHECK: ^{{.+}}(%[[BBARG0:.+]]: f32, %[[BBARG1:.+]]: f32)
+// CHECK-NEXT: %[[tanh:.+]] = math.tanh %[[BBARG0]] : f32
+// CHECK-NEXT: linalg.yield %[[tanh]] : f32
+
+// -----
+
func.func @generalize_max(%lhs: memref<7x14x21xf32>, %rhs: memref<7x14x21xf32>,
%out: memref<7x14x21xf32>) {
linalg.max ins(%lhs, %rhs : memref<7x14x21xf32>, memref<7x14x21xf32>)
@@ -590,6 +718,31 @@ func.func @generalize_max(%lhs: memref<7x14x21xf32>, %rhs: memref<7x14x21xf32>,
// -----
+func.func @generalize_min(%lhs: memref<7x14x21xf32>, %rhs: memref<7x14x21xf32>,
+ %out: memref<7x14x21xf32>) {
+ linalg.min ins(%lhs, %rhs : memref<7x14x21xf32>, memref<7x14x21xf32>)
+ outs(%out : memref<7x14x21xf32>)
+ return
+}
+
+// CHECK: #[[MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+
+// CHECK: func @generalize_min
+// CHECK-SAME: (%[[LHS:.+]]: memref<7x14x21xf32>, %[[RHS:.+]]: memref<7x14x21xf32>,
+// CHECK-SAME: %[[OUT:.+]]: memref<7x14x21xf32>)
+
+// CHECK: linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP]], #[[MAP]], #[[MAP]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]}
+// CHECK-SAME: ins(%[[LHS]], %[[RHS]] : memref<7x14x21xf32>, memref<7x14x21xf32>)
+// CHECK-SAME: outs(%[[OUT]] : memref<7x14x21xf32>)
+
+// CHECK: ^{{.+}}(%[[BBARG0:.+]]: f32, %[[BBARG1:.+]]: f32, %[[BBARG2:.+]]: f32)
+// CHECK-NEXT: %[[min:.+]] = arith.minimumf %[[BBARG0]], %[[BBARG1]] : f32
+// CHECK-NEXT: linalg.yield %[[min]] : f32
+
+// -----
+
// CHECK-LABEL: func @fill_tensor
func.func @fill_tensor(%f: f32, %v: vector<2x4xf32>) -> (tensor<f32>, tensor<vector<2x4xf32>>) {
diff --git a/mlir/test/Dialect/Linalg/named-ops-fail.mlir b/mlir/test/Dialect/Linalg/named-ops-fail.mlir
index c351e139a97e37..f66608e71ffc64 100644
--- a/mlir/test/Dialect/Linalg/named-ops-fail.mlir
+++ b/mlir/test/Dialect/Linalg/named-ops-fail.mlir
@@ -176,6 +176,102 @@ func.func @negf_broadcast(%arg: memref<8x16xf32>, %out: memref<4x8x16xf32>) {
// -----
+func.func @reciprocal_type_cast(%arg: memref<4x8x16xf16>, %out: memref<4x8x16xf32>) {
+ // CHECK: operand 1 ('f16') doesn't match the element type of the enclosing linalg.generic op ('f32')
+ linalg.reciprocal ins(%arg : memref<4x8x16xf16>) outs(%out: memref<4x8x16xf32>)
+ return
+}
+
+// -----
+
+func.func @reciprocal_broadcast(%arg: memref<8x16xf32>, %out: memref<4x8x16xf32>) {
+ // CHECK: op expected operand rank (2) to match the result rank of indexing_map #0 (3)
+ linalg.reciprocal ins(%arg : memref<8x16xf32>) outs(%out: memref<4x8x16xf32>)
+ return
+}
+
+// -----
+
+func.func @round_type_cast(%arg: memref<4x8x16xf16>, %out: memref<4x8x16xf32>) {
+ // CHECK: operand 1 ('f16') doesn't match the element type of the enclosing linalg.generic op ('f32')
+ linalg.round ins(%arg : memref<4x8x16xf16>) outs(%out: memref<4x8x16xf32>)
+ return
+}
+
+// -----
+
+func.func @round_broadcast(%arg: memref<8x16xf32>, %out: memref<4x8x16xf32>) {
+ // CHECK: op expected operand rank (2) to match the result rank of indexing_map #0 (3)
+ linalg.round ins(%arg : memref<8x16xf32>) outs(%out: memref<4x8x16xf32>)
+ return
+}
+
+// -----
+
+func.func @sqrt_type_cast(%arg: memref<4x8x16xf16>, %out: memref<4x8x16xf32>) {
+ // CHECK: operand 1 ('f16') doesn't match the element type of the enclosing linalg.generic op ('f32')
+ linalg.sqrt ins(%arg : memref<4x8x16xf16>) outs(%out: memref<4x8x16xf32>)
+ return
+}
+
+// -----
+
+func.func @sqrt_broadcast(%arg: memref<8x16xf32>, %out: memref<4x8x16xf32>) {
+ // CHECK: op expected operand rank (2) to match the result rank of indexing_map #0 (3)
+ linalg.sqrt ins(%arg : memref<8x16xf32>) outs(%out: memref<4x8x16xf32>)
+ return
+}
+
+// -----
+
+func.func @rsqrt_type_cast(%arg: memref<4x8x16xf16>, %out: memref<4x8x16xf32>) {
+ // CHECK: operand 1 ('f16') doesn't match the element type of the enclosing linalg.generic op ('f32')
+ linalg.rsqrt ins(%arg : memref<4x8x16xf16>) outs(%out: memref<4x8x16xf32>)
+ return
+}
+
+// -----
+
+func.func @rsqrt_broadcast(%arg: memref<8x16xf32>, %out: memref<4x8x16xf32>) {
+ // CHECK: op expected operand rank (2) to match the result rank of indexing_map #0 (3)
+ linalg.rsqrt ins(%arg : memref<8x16xf32>) outs(%out: memref<4x8x16xf32>)
+ return
+}
+
+// -----
+
+func.func @square_type_cast(%arg: memref<4x8x16xf16>, %out: memref<4x8x16xf32>) {
+ // CHECK: operand 1 ('f16') doesn't match the element type of the enclosing linalg.generic op ('f32')
+ linalg.square ins(%arg : memref<4x8x16xf16>) outs(%out: memref<4x8x16xf32>)
+ return
+}
+
+// -----
+
+func.func @square_broadcast(%arg: memref<8x16xf32>, %out: memref<4x8x16xf32>) {
+ // CHECK: op expected operand rank (2) to match the result rank of indexing_map #0 (3)
+ linalg.square ins(%arg : memref<8x16xf32>) outs(%out: memref<4x8x16xf32>)
+ return
+}
+
+// -----
+
+func.func @tanh_type_cast(%arg: memref<4x8x16xf16>, %out: memref<4x8x16xf32>) {
+ // CHECK: operand 1 ('f16') doesn't match the element type of the enclosing linalg.generic op ('f32')
+ linalg.tanh ins(%arg : memref<4x8x16xf16>) outs(%out: memref<4x8x16xf32>)
+ return
+}
+
+// -----
+
+func.func @tanh_broadcast(%arg: memref<8x16xf32>, %out: memref<4x8x16xf32>) {
+ // CHECK: op expected operand rank (2) to match the result rank of indexing_map #0 (3)
+ linalg.tanh ins(%arg : memref<8x16xf32>) outs(%out: memref<4x8x16xf32>)
+ return
+}
+
+// -----
+
func.func @max_type_cast(%arg0: memref<4x8x16xf32>, %arg1: memref<4x8x16xf16>, %arg2: memref<4x8x16xf32>) {
// CHECK: op requires the same type for all operands and results
linalg.max ins(%arg0, %arg1 : memref<4x8x16xf32>, memref<4x8x16xf16>) outs(%arg2: memref<4x8x16xf32>)
@@ -189,3 +285,19 @@ func.func @max_broadcast(%arg0: memref<8x16xf32>, %arg1: memref<4x8x16xf32>, %ar
linalg.max ins(%arg0, %arg1 : memref<8x16xf32>, memref<4x8x16xf32>) outs(%arg2: memref<4x8x16xf32>)
return
}
+
+// -----
+
+func.func @min_type_cast(%arg0: memref<4x8x16xf32>, %arg1: memref<4x8x16xf16>, %arg2: memref<4x8x16xf32>) {
+ // CHECK: op requires the same type for all operands and results
+ linalg.min ins(%arg0, %arg1 : memref<4x8x16xf32>, memref<4x8x16xf16>) outs(%arg2: memref<4x8x16xf32>)
+ return
+}
+
+// -----
+
+func.func @min_broadcast(%arg0: memref<8x16xf32>, %arg1: memref<4x8x16xf32>, %arg2: memref<4x8x16xf32>) {
+ // CHECK: op expected operand rank (2) to match the result rank of indexing_map #0 (3)
+ linalg.min ins(%arg0, %arg1 : memref<8x16xf32>, memref<4x8x16xf32>) outs(%arg2: memref<4x8x16xf32>)
+ return
+}
diff --git a/mlir/test/Dialect/Linalg/named-ops.mlir b/mlir/test/Dialect/Linalg/named-ops.mlir
index 7064e1b3f9dc76..cf59f673610013 100644
--- a/mlir/test/Dialect/Linalg/named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/named-ops.mlir
@@ -1597,6 +1597,192 @@ func.func @negf_tensor(%arg0: tensor<4x8x16xf32>) -> tensor<4x8x16xf32> {
// -----
+// CHECK-LABEL: func @reciprocal_dynamic
+func.func @reciprocal_dynamic(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?x?xf32>) {
+ // CHECK: linalg.reciprocal
+ // CHECK-SAME: ins(%{{.+}} : memref<?x?x?xf32>) outs(%{{.+}} : memref<?x?x?xf32>)
+ linalg.reciprocal ins(%arg0 : memref<?x?x?xf32>) outs(%arg1: memref<?x?x?xf32>)
+ return
+}
+
+// -----
+
+// CHECK-LABEL: func @reciprocal_static
+func.func @reciprocal_static(%arg0: memref<4x8x16xf32>, %arg1: memref<4x8x16xf32>) {
+ // CHECK: linalg.reciprocal
+ // CHECK-SAME: ins(%{{.+}} : memref<4x8x16xf32>) outs(%{{.+}} : memref<4x8x16xf32>)
+ linalg.reciprocal ins(%arg0 : memref<4x8x16xf32>) outs(%arg1: memref<4x8x16xf32>)
+ return
+}
+
+// -----
+
+// CHECK-LABEL: func @reciprocal_tensor
+func.func @reciprocal_tensor(%arg0: tensor<4x8x16xf32>) -> tensor<4x8x16xf32> {
+ %0 = tensor.empty() : tensor<4x8x16xf32>
+ // CHECK: linalg.reciprocal
+ // CHECK-SAME: ins(%{{.+}} : tensor<4x8x16xf32>) outs(%{{.+}} : tensor<4x8x16xf32>)
+ %1 = linalg.reciprocal ins(%arg0 : tensor<4x8x16xf32>) outs(%0: tensor<4x8x16xf32>) -> tensor<4x8x16xf32>
+ return %1 : tensor<4x8x16xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @round_dynamic
+func.func @round_dynamic(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?x?xf32>) {
+ // CHECK: linalg.round
+ // CHECK-SAME: ins(%{{.+}} : memref<?x?x?xf32>) outs(%{{.+}} : memref<?x?x?xf32>)
+ linalg.round ins(%arg0 : memref<?x?x?xf32>) outs(%arg1: memref<?x?x?xf32>)
+ return
+}
+
+// -----
+
+// CHECK-LABEL: func @round_static
+func.func @round_static(%arg0: memref<4x8x16xf32>, %arg1: memref<4x8x16xf32>) {
+ // CHECK: linalg.round
+ // CHECK-SAME: ins(%{{.+}} : memref<4x8x16xf32>) outs(%{{.+}} : memref<4x8x16xf32>)
+ linalg.round ins(%arg0 : memref<4x8x16xf32>) outs(%arg1: memref<4x8x16xf32>)
+ return
+}
+
+// -----
+
+// CHECK-LABEL: func @round_tensor
+func.func @round_tensor(%arg0: tensor<4x8x16xf32>) -> tensor<4x8x16xf32> {
+ %0 = tensor.empty() : tensor<4x8x16xf32>
+ // CHECK: linalg.round
+ // CHECK-SAME: ins(%{{.+}} : tensor<4x8x16xf32>) outs(%{{.+}} : tensor<4x8x16xf32>)
+ %1 = linalg.round ins(%arg0 : tensor<4x8x16xf32>) outs(%0: tensor<4x8x16xf32>) -> tensor<4x8x16xf32>
+ return %1 : tensor<4x8x16xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @sqrt_dynamic
+func.func @sqrt_dynamic(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?x?xf32>) {
+ // CHECK: linalg.sqrt
+ // CHECK-SAME: ins(%{{.+}} : memref<?x?x?xf32>) outs(%{{.+}} : memref<?x?x?xf32>)
+ linalg.sqrt ins(%arg0 : memref<?x?x?xf32>) outs(%arg1: memref<?x?x?xf32>)
+ return
+}
+
+// -----
+
+// CHECK-LABEL: func @sqrt_static
+func.func @sqrt_static(%arg0: memref<4x8x16xf32>, %arg1: memref<4x8x16xf32>) {
+ // CHECK: linalg.sqrt
+ // CHECK-SAME: ins(%{{.+}} : memref<4x8x16xf32>) outs(%{{.+}} : memref<4x8x16xf32>)
+ linalg.sqrt ins(%arg0 : memref<4x8x16xf32>) outs(%arg1: memref<4x8x16xf32>)
+ return
+}
+
+// -----
+
+// CHECK-LABEL: func @sqrt_tensor
+func.func @sqrt_tensor(%arg0: tensor<4x8x16xf32>) -> tensor<4x8x16xf32> {
+ %0 = tensor.empty() : tensor<4x8x16xf32>
+ // CHECK: linalg.sqrt
+ // CHECK-SAME: ins(%{{.+}} : tensor<4x8x16xf32>) outs(%{{.+}} : tensor<4x8x16xf32>)
+ %1 = linalg.sqrt ins(%arg0 : tensor<4x8x16xf32>) outs(%0: tensor<4x8x16xf32>) -> tensor<4x8x16xf32>
+ return %1 : tensor<4x8x16xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @rsqrt_dynamic
+func.func @rsqrt_dynamic(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?x?xf32>) {
+ // CHECK: linalg.rsqrt
+ // CHECK-SAME: ins(%{{.+}} : memref<?x?x?xf32>) outs(%{{.+}} : memref<?x?x?xf32>)
+ linalg.rsqrt ins(%arg0 : memref<?x?x?xf32>) outs(%arg1: memref<?x?x?xf32>)
+ return
+}
+
+// -----
+
+// CHECK-LABEL: func @rsqrt_static
+func.func @rsqrt_static(%arg0: memref<4x8x16xf32>, %arg1: memref<4x8x16xf32>) {
+ // CHECK: linalg.rsqrt
+ // CHECK-SAME: ins(%{{.+}} : memref<4x8x16xf32>) outs(%{{.+}} : memref<4x8x16xf32>)
+ linalg.rsqrt ins(%arg0 : memref<4x8x16xf32>) outs(%arg1: memref<4x8x16xf32>)
+ return
+}
+
+// -----
+
+// CHECK-LABEL: func @rsqrt_tensor
+func.func @rsqrt_tensor(%arg0: tensor<4x8x16xf32>) -> tensor<4x8x16xf32> {
+ %0 = tensor.empty() : tensor<4x8x16xf32>
+ // CHECK: linalg.rsqrt
+ // CHECK-SAME: ins(%{{.+}} : tensor<4x8x16xf32>) outs(%{{.+}} : tensor<4x8x16xf32>)
+ %1 = linalg.rsqrt ins(%arg0 : tensor<4x8x16xf32>) outs(%0: tensor<4x8x16xf32>) -> tensor<4x8x16xf32>
+ return %1 : tensor<4x8x16xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @square_dynamic
+func.func @square_dynamic(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?x?xf32>) {
+ // CHECK: linalg.square
+ // CHECK-SAME: ins(%{{.+}} : memref<?x?x?xf32>) outs(%{{.+}} : memref<?x?x?xf32>)
+ linalg.square ins(%arg0 : memref<?x?x?xf32>) outs(%arg1: memref<?x?x?xf32>)
+ return
+}
+
+// -----
+
+// CHECK-LABEL: func @square_static
+func.func @square_static(%arg0: memref<4x8x16xf32>, %arg1: memref<4x8x16xf32>) {
+ // CHECK: linalg.square
+ // CHECK-SAME: ins(%{{.+}} : memref<4x8x16xf32>) outs(%{{.+}} : memref<4x8x16xf32>)
+ linalg.square ins(%arg0 : memref<4x8x16xf32>) outs(%arg1: memref<4x8x16xf32>)
+ return
+}
+
+// -----
+
+// CHECK-LABEL: func @square_tensor
+func.func @square_tensor(%arg0: tensor<4x8x16xf32>) -> tensor<4x8x16xf32> {
+ %0 = tensor.empty() : tensor<4x8x16xf32>
+ // CHECK: linalg.square
+ // CHECK-SAME: ins(%{{.+}} : tensor<4x8x16xf32>) outs(%{{.+}} : tensor<4x8x16xf32>)
+ %1 = linalg.square ins(%arg0 : tensor<4x8x16xf32>) outs(%0: tensor<4x8x16xf32>) -> tensor<4x8x16xf32>
+ return %1 : tensor<4x8x16xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @tanh_dynamic
+func.func @tanh_dynamic(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?x?xf32>) {
+ // CHECK: linalg.tanh
+ // CHECK-SAME: ins(%{{.+}} : memref<?x?x?xf32>) outs(%{{.+}} : memref<?x?x?xf32>)
+ linalg.tanh ins(%arg0 : memref<?x?x?xf32>) outs(%arg1: memref<?x?x?xf32>)
+ return
+}
+
+// -----
+
+// CHECK-LABEL: func @tanh_static
+func.func @tanh_static(%arg0: memref<4x8x16xf32>, %arg1: memref<4x8x16xf32>) {
+ // CHECK: linalg.tanh
+ // CHECK-SAME: ins(%{{.+}} : memref<4x8x16xf32>) outs(%{{.+}} : memref<4x8x16xf32>)
+ linalg.tanh ins(%arg0 : memref<4x8x16xf32>) outs(%arg1: memref<4x8x16xf32>)
+ return
+}
+
+// -----
+
+// CHECK-LABEL: func @tanh_tensor
+func.func @tanh_tensor(%arg0: tensor<4x8x16xf32>) -> tensor<4x8x16xf32> {
+ %0 = tensor.empty() : tensor<4x8x16xf32>
+ // CHECK: linalg.tanh
+ // CHECK-SAME: ins(%{{.+}} : tensor<4x8x16xf32>) outs(%{{.+}} : tensor<4x8x16xf32>)
+ %1 = linalg.tanh ins(%arg0 : tensor<4x8x16xf32>) outs(%0: tensor<4x8x16xf32>) -> tensor<4x8x16xf32>
+ return %1 : tensor<4x8x16xf32>
+}
+
+// -----
+
// CHECK-LABEL: func @max_dynamic
func.func @max_dynamic(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?x?xf32>) {
// CHECK: linalg.max
@@ -1631,6 +1817,40 @@ func.func @max_tensor(%arg0: tensor<4x8x16xf32>, %arg1: tensor<4x8x16xf32>) -> t
// -----
+// CHECK-LABEL: func @min_dynamic
+func.func @min_dynamic(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?x?xf32>) {
+ // CHECK: linalg.min
+ // CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<?x?x?xf32>, memref<?x?x?xf32>)
+ // CHECK-SAME: outs(%{{.+}} : memref<?x?x?xf32>)
+ linalg.min ins(%arg0, %arg1 : memref<?x?x?xf32>, memref<?x?x?xf32>) outs(%arg2: memref<?x?x?xf32>)
+ return
+}
+
+// -----
+
+// CHECK-LABEL: func @min_static
+func.func @min_static(%arg0: memref<4x8x16xf32>, %arg1: memref<4x8x16xf32>, %arg2: memref<4x8x16xf32>) {
+ // CHECK: linalg.min
+ // CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<4x8x16xf32>, memref<4x8x16xf32>)
+ // CHECK-SAME: outs(%{{.+}} : memref<4x8x16xf32>)
+ linalg.min ins(%arg0, %arg1 : memref<4x8x16xf32>, memref<4x8x16xf32>) outs(%arg2: memref<4x8x16xf32>)
+ return
+}
+
+// -----
+
+// CHECK-LABEL: func @min_tensor
+func.func @min_tensor(%arg0: tensor<4x8x16xf32>, %arg1: tensor<4x8x16xf32>) -> tensor<4x8x16xf32> {
+ %0 = tensor.empty() : tensor<4x8x16xf32>
+ // CHECK: linalg.min
+ // CHECK-SAME: ins(%{{.+}}, %{{.+}} : tensor<4x8x16xf32>, tensor<4x8x16xf32>)
+ // CHECK-SAME: outs(%{{.+}} : tensor<4x8x16xf32>)
+ %1 = linalg.min ins(%arg0, %arg1 : tensor<4x8x16xf32>, tensor<4x8x16xf32>) outs(%0: tensor<4x8x16xf32>) -> tensor<4x8x16xf32>
+ return %1 : tensor<4x8x16xf32>
+}
+
+// -----
+
// CHECK-LABEL: func @fill_tensor
func.func @fill_tensor(%f: f32, %v: vector<2x4xf32>) -> (tensor<f32>, tensor<vector<2x4xf32>>) {
%e0 = tensor.empty() : tensor<f32>
More information about the Mlir-commits
mailing list