[Mlir-commits] [mlir] [MLIR][Linalg] Left over Linalg named ops from previous PR (PR #90405)
Renato Golin
llvmlistbot at llvm.org
Sun Apr 28 10:41:14 PDT 2024
https://github.com/rengolin created https://github.com/llvm/llvm-project/pull/90405
Adding `erf` as a unary op and `powf` as a binary op.
Just as `ReLU` is expressed as `max(arg, 0.0)`, `powf(arg, const)` can be represented either as a single generic (with the broadcast encoded in its affine maps) or as a pair (`linalg.broadcast` + `linalg.powf`) that is then lowered appropriately. Either way, the lower dialects need to know what kind of broadcast is involved, so no materialization of the constant tensor should remain.
I want to flush out the easy ones before we start working on type casts & softmax.
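To illustrate (a hypothetical sketch, not part of the patch; the names, shapes and the `2.0` exponent are mine), here are the two equivalent forms for `powf(arg, 2.0)` over a `4x8xf32` tensor:

```mlir
// Form 1: explicitly broadcast the scalar exponent, then apply the named op.
%e0 = tensor.empty() : tensor<4x8xf32>
%exp = linalg.broadcast ins(%cst2 : tensor<f32>)
                        outs(%e0 : tensor<4x8xf32>) dimensions = [0, 1]
%init = tensor.empty() : tensor<4x8xf32>
%r1 = linalg.powf ins(%arg, %exp : tensor<4x8xf32>, tensor<4x8xf32>)
                  outs(%init : tensor<4x8xf32>) -> tensor<4x8xf32>

// Form 2: a single generic whose affine maps encode the same broadcast,
// so no constant tensor is ever materialized.
%r2 = linalg.generic
    {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
                      affine_map<(d0, d1) -> ()>,
                      affine_map<(d0, d1) -> (d0, d1)>],
     iterator_types = ["parallel", "parallel"]}
    ins(%arg, %cst2 : tensor<4x8xf32>, tensor<f32>)
    outs(%init : tensor<4x8xf32>) {
^bb0(%a: f32, %e: f32, %o: f32):
  %p = math.powf %a, %e : f32
  linalg.yield %p : f32
} -> tensor<4x8xf32>
```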
From f68ffc043fb9c4105da9cd287282d2e9829dd8ee Mon Sep 17 00:00:00 2001
From: Renato Golin <rengolin@systemcall.eu>
Date: Sun, 28 Apr 2024 17:49:53 +0100
Subject: [PATCH 1/2] [MLIR][Linalg] Add powf operation
---
.../mlir/Dialect/Linalg/IR/LinalgEnums.td | 3 +-
.../Linalg/IR/LinalgNamedStructuredOps.yaml | 51 +++++++++++++++++++
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 3 ++
.../linalg/opdsl/lang/comprehension.py | 1 +
.../linalg/opdsl/ops/core_named_ops.py | 21 ++++++++
.../Dialect/Linalg/generalize-named-ops.mlir | 27 ++++++++++
mlir/test/Dialect/Linalg/named-ops-fail.mlir | 17 +++++++
mlir/test/Dialect/Linalg/named-ops.mlir | 34 +++++++++++++
8 files changed, 156 insertions(+), 1 deletion(-)
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgEnums.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgEnums.td
index 7a350d2c014262..6586a0f4dd7b36 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgEnums.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgEnums.td
@@ -42,7 +42,8 @@ def BinaryFn : I32EnumAttr<"BinaryFn", "", [
I32EnumAttrCase<"max_signed", 5>,
I32EnumAttrCase<"min_signed", 6>,
I32EnumAttrCase<"max_unsigned", 7>,
- I32EnumAttrCase<"min_unsigned", 8>
+ I32EnumAttrCase<"min_unsigned", 8>,
+ I32EnumAttrCase<"powf", 9>
]> {
let genSpecializedAttr = 0;
let cppNamespace = "::mlir::linalg";
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
index b7567577347587..9d60be773a0a50 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
@@ -922,6 +922,57 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarExpression
scalar_arg: rhs
--- !LinalgOpConfig
+metadata: !LinalgOpMetadata
+ name: powf
+ cpp_class_name: PowFOp
+ doc: |-
+ Computes powf(lhs, rhs) on the two inputs, elementwise. For powf(arg, 2) use `linalg.square`.
+
+ Only applies to floating point values.
+
+ The shapes and element types must be identical. The appropriate casts,
+ broadcasts and reductions should be done prior to calling this op.
+
+ This means reduction/broadcast/element cast semantics are explicit. Further
+ passes can take that into account when lowering this code. For example,
+ a `linalg.broadcast` + `linalg.powf` sequence can be lowered to a
+ `linalg.generic` with different affine maps for the two operands.
+structured_op: !LinalgStructuredOpConfig
+ args:
+ - !LinalgOperandDefConfig
+ name: lhs
+ kind: input_tensor
+ type_var: T1
+ shape_map: affine_map<() -> ()>
+ - !LinalgOperandDefConfig
+ name: rhs
+ kind: input_tensor
+ type_var: T1
+ shape_map: affine_map<() -> ()>
+ - !LinalgOperandDefConfig
+ name: O
+ kind: output_tensor
+ type_var: T1
+ shape_map: affine_map<() -> ()>
+ indexing_maps: !LinalgIndexingMapsConfig
+ static_indexing_maps:
+ - affine_map<() -> ()>
+ - affine_map<() -> ()>
+ - affine_map<() -> ()>
+ iterator_types: []
+ assignments:
+ - !ScalarAssign
+ arg: O
+ value: !ScalarExpression
+ scalar_fn:
+ kind: binary
+ fn_name: powf
+ operands:
+ - !ScalarExpression
+ scalar_arg: lhs
+ - !ScalarExpression
+ scalar_arg: rhs
+--- !LinalgOpConfig
metadata: !LinalgOpMetadata
name: matmul
cpp_class_name: MatmulOp
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 5d10b59373ad03..e935175798c2f3 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -483,6 +483,9 @@ class RegionBuilderHelper {
if (allFloatingPoint)
return builder.create<arith::MinimumFOp>(arg0.getLoc(), arg0, arg1);
return builder.create<arith::MinUIOp>(arg0.getLoc(), arg0, arg1);
+ case BinaryFn::powf:
+ assert(allFloatingPoint);
+ return builder.create<math::PowFOp>(arg0.getLoc(), arg0, arg1);
}
llvm_unreachable("unsupported binary function");
}
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py
index f7bc81bd2f6833..75a0cf0c8f661c 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py
@@ -335,6 +335,7 @@ class BinaryFn:
min_signed = BinaryFnType("min_signed")
max_unsigned = BinaryFnType("max_unsigned")
min_unsigned = BinaryFnType("min_unsigned")
+ powf = BinaryFnType("powf")
class TypeFnType:
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
index 2c8864be1107fb..89b205ca266e13 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
@@ -318,6 +318,27 @@ def min(
O[None] = BinaryFn.min_signed(lhs[None], rhs[None])
+@linalg_structured_op
+def powf(
+ lhs=TensorDef(T1),
+ rhs=TensorDef(T1),
+ O=TensorDef(T1, output=True),
+):
+ """Takes the powf(lhs, rhs) between two inputs, elementwise. For powf(arg, 2) use `linalg.square`.
+
+ Only applies to floating point values.
+
+ The shapes and element types must be identical. The appropriate casts,
+ broadcasts and reductions should be done prior to calling this op.
+
+ This means reduction/broadcast/element cast semantics are explicit. Further
+ passes can take that into account when lowering this code. For example,
+ a `linalg.broadcast` + `linalg.powf` sequence can be lowered to a
+ `linalg.generic` with different affine maps for the two operands.
+ """
+ O[None] = BinaryFn.powf(lhs[None], rhs[None])
+
+
@linalg_structured_op
def matmul(
A=TensorDef(T1, S.M, S.K),
diff --git a/mlir/test/Dialect/Linalg/generalize-named-ops.mlir b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
index add34412b92f2b..86e3cb38f788d1 100644
--- a/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
@@ -741,6 +741,33 @@ func.func @generalize_min(%lhs: memref<7x14x21xf32>, %rhs: memref<7x14x21xf32>,
// CHECK-NEXT: %[[min:.+]] = arith.minimumf %[[BBARG0]], %[[BBARG1]] : f32
// CHECK-NEXT: linalg.yield %[[min]] : f32
+
+// -----
+
+func.func @generalize_powf(%lhs: memref<7x14x21xf32>, %rhs: memref<7x14x21xf32>,
+ %out: memref<7x14x21xf32>) {
+ linalg.powf ins(%lhs, %rhs : memref<7x14x21xf32>, memref<7x14x21xf32>)
+ outs(%out : memref<7x14x21xf32>)
+ return
+}
+
+// CHECK: #[[MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+
+// CHECK: func @generalize_powf
+// CHECK-SAME: (%[[LHS:.+]]: memref<7x14x21xf32>, %[[RHS:.+]]: memref<7x14x21xf32>,
+// CHECK-SAME: %[[OUT:.+]]: memref<7x14x21xf32>)
+
+// CHECK: linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP]], #[[MAP]], #[[MAP]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]}
+// CHECK-SAME: ins(%[[LHS]], %[[RHS]] : memref<7x14x21xf32>, memref<7x14x21xf32>)
+// CHECK-SAME: outs(%[[OUT]] : memref<7x14x21xf32>)
+
+// CHECK: ^{{.+}}(%[[BBARG0:.+]]: f32, %[[BBARG1:.+]]: f32, %[[BBARG2:.+]]: f32)
+// CHECK-NEXT: %[[powf:.+]] = math.powf %[[BBARG0]], %[[BBARG1]] : f32
+// CHECK-NEXT: linalg.yield %[[powf]] : f32
+
+
// -----
diff --git a/mlir/test/Dialect/Linalg/named-ops-fail.mlir b/mlir/test/Dialect/Linalg/named-ops-fail.mlir
index f66608e71ffc64..3106f54b8d9760 100644
--- a/mlir/test/Dialect/Linalg/named-ops-fail.mlir
+++ b/mlir/test/Dialect/Linalg/named-ops-fail.mlir
@@ -301,3 +301,20 @@ func.func @min_broadcast(%arg0: memref<8x16xf32>, %arg1: memref<4x8x16xf32>, %ar
linalg.min ins(%arg0, %arg1 : memref<8x16xf32>, memref<4x8x16xf32>) outs(%arg2: memref<4x8x16xf32>)
return
}
+
+// -----
+
+func.func @powf_type_cast(%arg0: memref<4x8x16xf32>, %arg1: memref<4x8x16xf16>, %arg2: memref<4x8x16xf32>) {
+ // CHECK: op requires the same type for all operands and results
+ linalg.powf ins(%arg0, %arg1 : memref<4x8x16xf32>, memref<4x8x16xf16>) outs(%arg2: memref<4x8x16xf32>)
+ return
+}
+
+// -----
+
+func.func @powf_broadcast(%arg0: memref<8x16xf32>, %arg1: memref<4x8x16xf32>, %arg2: memref<4x8x16xf32>) {
+ // CHECK: op expected operand rank (2) to match the result rank of indexing_map #0 (3)
+ linalg.powf ins(%arg0, %arg1 : memref<8x16xf32>, memref<4x8x16xf32>) outs(%arg2: memref<4x8x16xf32>)
+ return
+}
+
diff --git a/mlir/test/Dialect/Linalg/named-ops.mlir b/mlir/test/Dialect/Linalg/named-ops.mlir
index cf59f673610013..d211286228fcff 100644
--- a/mlir/test/Dialect/Linalg/named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/named-ops.mlir
@@ -1851,6 +1851,40 @@ func.func @min_tensor(%arg0: tensor<4x8x16xf32>, %arg1: tensor<4x8x16xf32>) -> t
// -----
+// CHECK-LABEL: func @powf_dynamic
+func.func @powf_dynamic(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?x?xf32>) {
+ // CHECK: linalg.powf
+ // CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<?x?x?xf32>, memref<?x?x?xf32>)
+ // CHECK-SAME: outs(%{{.+}} : memref<?x?x?xf32>)
+ linalg.powf ins(%arg0, %arg1 : memref<?x?x?xf32>, memref<?x?x?xf32>) outs(%arg2: memref<?x?x?xf32>)
+ return
+}
+
+// -----
+
+// CHECK-LABEL: func @powf_static
+func.func @powf_static(%arg0: memref<4x8x16xf32>, %arg1: memref<4x8x16xf32>, %arg2: memref<4x8x16xf32>) {
+ // CHECK: linalg.powf
+ // CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<4x8x16xf32>, memref<4x8x16xf32>)
+ // CHECK-SAME: outs(%{{.+}} : memref<4x8x16xf32>)
+ linalg.powf ins(%arg0, %arg1 : memref<4x8x16xf32>, memref<4x8x16xf32>) outs(%arg2: memref<4x8x16xf32>)
+ return
+}
+
+// -----
+
+// CHECK-LABEL: func @powf_tensor
+func.func @powf_tensor(%arg0: tensor<4x8x16xf32>, %arg1: tensor<4x8x16xf32>) -> tensor<4x8x16xf32> {
+ %0 = tensor.empty() : tensor<4x8x16xf32>
+ // CHECK: linalg.powf
+ // CHECK-SAME: ins(%{{.+}}, %{{.+}} : tensor<4x8x16xf32>, tensor<4x8x16xf32>)
+ // CHECK-SAME: outs(%{{.+}} : tensor<4x8x16xf32>)
+ %1 = linalg.powf ins(%arg0, %arg1 : tensor<4x8x16xf32>, tensor<4x8x16xf32>) outs(%0: tensor<4x8x16xf32>) -> tensor<4x8x16xf32>
+ return %1 : tensor<4x8x16xf32>
+}
+
+// -----
+
// CHECK-LABEL: func @fill_tensor
func.func @fill_tensor(%f: f32, %v: vector<2x4xf32>) -> (tensor<f32>, tensor<vector<2x4xf32>>) {
%e0 = tensor.empty() : tensor<f32>
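As the new doc suggests, when the exponent is the literal 2 the existing unary op avoids the second operand entirely; a minimal sketch (hypothetical IR, shapes are mine, not from the patch):

```mlir
// powf(%arg, 2.0) written with the existing unary named op: no broadcast
// and no exponent tensor to materialize.
%init = tensor.empty() : tensor<4x8x16xf32>
%sq = linalg.square ins(%arg : tensor<4x8x16xf32>)
                    outs(%init : tensor<4x8x16xf32>) -> tensor<4x8x16xf32>
```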
From 587b4cb56672faed720249a40ec47d5197c98a03 Mon Sep 17 00:00:00 2001
From: Renato Golin <rengolin@systemcall.eu>
Date: Sun, 28 Apr 2024 18:30:17 +0100
Subject: [PATCH 2/2] [MLIR][Linalg] Add erf operation
---
.../mlir/Dialect/Linalg/IR/LinalgEnums.td | 3 +-
.../Linalg/IR/LinalgNamedStructuredOps.yaml | 35 +++++++++++++++++++
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 2 ++
.../linalg/opdsl/lang/comprehension.py | 1 +
.../linalg/opdsl/ops/core_named_ops.py | 12 +++++++
.../Dialect/Linalg/generalize-named-ops.mlir | 21 +++++++++++
mlir/test/Dialect/Linalg/named-ops-fail.mlir | 16 +++++++++
mlir/test/Dialect/Linalg/named-ops.mlir | 31 ++++++++++++++++
8 files changed, 120 insertions(+), 1 deletion(-)
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgEnums.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgEnums.td
index 6586a0f4dd7b36..6b4b073fc67246 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgEnums.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgEnums.td
@@ -28,7 +28,8 @@ def UnaryFn : I32EnumAttr<"UnaryFn", "", [
I32EnumAttrCase<"sqrt", 8>,
I32EnumAttrCase<"rsqrt", 9>,
I32EnumAttrCase<"square", 10>,
- I32EnumAttrCase<"tanh", 11>
+ I32EnumAttrCase<"tanh", 11>,
+ I32EnumAttrCase<"erf", 12>
]> {
let genSpecializedAttr = 0;
let cppNamespace = "::mlir::linalg";
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
index 9d60be773a0a50..584bfcd8b59dc3 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
@@ -514,6 +514,41 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarExpression
scalar_arg: I
--- !LinalgOpConfig
+metadata: !LinalgOpMetadata
+ name: erf
+ cpp_class_name: ErfOp
+ doc: |-
+ Applies erf(x) elementwise.
+
+ No numeric casting is performed on the input operand.
+structured_op: !LinalgStructuredOpConfig
+ args:
+ - !LinalgOperandDefConfig
+ name: I
+ kind: input_tensor
+ type_var: T1
+ shape_map: affine_map<() -> ()>
+ - !LinalgOperandDefConfig
+ name: O
+ kind: output_tensor
+ type_var: T1
+ shape_map: affine_map<() -> ()>
+ indexing_maps: !LinalgIndexingMapsConfig
+ static_indexing_maps:
+ - affine_map<() -> ()>
+ - affine_map<() -> ()>
+ iterator_types: []
+ assignments:
+ - !ScalarAssign
+ arg: O
+ value: !ScalarExpression
+ scalar_fn:
+ kind: unary
+ fn_name: erf
+ operands:
+ - !ScalarExpression
+ scalar_arg: I
+--- !LinalgOpConfig
metadata: !LinalgOpMetadata
name: elemwise_binary
cpp_class_name: ElemwiseBinaryOp
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index e935175798c2f3..036005ce9d9251 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -411,6 +411,8 @@ class RegionBuilderHelper {
return builder.create<arith::MulFOp>(arg.getLoc(), arg, arg);
case UnaryFn::tanh:
return builder.create<math::TanhOp>(arg.getLoc(), arg);
+ case UnaryFn::erf:
+ return builder.create<math::ErfOp>(arg.getLoc(), arg);
}
llvm_unreachable("unsupported unary function");
}
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py
index 75a0cf0c8f661c..bb43ebf2b6923a 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py
@@ -296,6 +296,7 @@ class UnaryFn:
rsqrt = UnaryFnType("rsqrt")
square = UnaryFnType("square")
tanh = UnaryFnType("tanh")
+ erf = UnaryFnType("erf")
class BinaryFnType:
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
index 89b205ca266e13..ca2bb0c5f7f8a9 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
@@ -168,6 +168,18 @@ def tanh(
O[None] = UnaryFn.tanh(I[None])
+@linalg_structured_op
+def erf(
+ I=TensorDef(T1),
+ O=TensorDef(T1, output=True),
+):
+ """Applies erf(x) elementwise.
+
+ No numeric casting is performed on the input operand.
+ """
+ O[None] = UnaryFn.erf(I[None])
+
+
@linalg_structured_op
def elemwise_binary(
lhs=TensorDef(T1),
diff --git a/mlir/test/Dialect/Linalg/generalize-named-ops.mlir b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
index 86e3cb38f788d1..667ea3c18c8ad3 100644
--- a/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
@@ -693,6 +693,27 @@ func.func @generalize_tanh(%arg: memref<7x14x21xf32>, %out: memref<7x14x21xf32>)
// -----
+func.func @generalize_erf(%arg: memref<7x14x21xf32>, %out: memref<7x14x21xf32>) {
+ linalg.erf ins(%arg : memref<7x14x21xf32>) outs(%out : memref<7x14x21xf32>)
+ return
+}
+
+// CHECK: #[[MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+
+// CHECK: func @generalize_erf
+// CHECK-SAME: (%[[ARG:.+]]: memref<7x14x21xf32>, %[[OUT:.+]]: memref<7x14x21xf32>)
+
+// CHECK: linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP]], #[[MAP]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]}
+// CHECK-SAME: ins(%[[ARG]] : memref<7x14x21xf32>) outs(%[[OUT]] : memref<7x14x21xf32>)
+
+// CHECK: ^{{.+}}(%[[BBARG0:.+]]: f32, %[[BBARG1:.+]]: f32)
+// CHECK-NEXT: %[[erf:.+]] = math.erf %[[BBARG0]] : f32
+// CHECK-NEXT: linalg.yield %[[erf]] : f32
+
+// -----
+
func.func @generalize_max(%lhs: memref<7x14x21xf32>, %rhs: memref<7x14x21xf32>,
%out: memref<7x14x21xf32>) {
linalg.max ins(%lhs, %rhs : memref<7x14x21xf32>, memref<7x14x21xf32>)
diff --git a/mlir/test/Dialect/Linalg/named-ops-fail.mlir b/mlir/test/Dialect/Linalg/named-ops-fail.mlir
index 3106f54b8d9760..e92a77aa7ad059 100644
--- a/mlir/test/Dialect/Linalg/named-ops-fail.mlir
+++ b/mlir/test/Dialect/Linalg/named-ops-fail.mlir
@@ -272,6 +272,22 @@ func.func @tanh_broadcast(%arg: memref<8x16xf32>, %out: memref<4x8x16xf32>) {
// -----
+func.func @erf_type_cast(%arg: memref<4x8x16xf16>, %out: memref<4x8x16xf32>) {
+ // CHECK: operand 1 ('f16') doesn't match the element type of the enclosing linalg.generic op ('f32')
+ linalg.erf ins(%arg : memref<4x8x16xf16>) outs(%out: memref<4x8x16xf32>)
+ return
+}
+
+// -----
+
+func.func @erf_broadcast(%arg: memref<8x16xf32>, %out: memref<4x8x16xf32>) {
+ // CHECK: op expected operand rank (2) to match the result rank of indexing_map #0 (3)
+ linalg.erf ins(%arg : memref<8x16xf32>) outs(%out: memref<4x8x16xf32>)
+ return
+}
+
+// -----
+
func.func @max_type_cast(%arg0: memref<4x8x16xf32>, %arg1: memref<4x8x16xf16>, %arg2: memref<4x8x16xf32>) {
// CHECK: op requires the same type for all operands and results
linalg.max ins(%arg0, %arg1 : memref<4x8x16xf32>, memref<4x8x16xf16>) outs(%arg2: memref<4x8x16xf32>)
diff --git a/mlir/test/Dialect/Linalg/named-ops.mlir b/mlir/test/Dialect/Linalg/named-ops.mlir
index d211286228fcff..fefe5578947f00 100644
--- a/mlir/test/Dialect/Linalg/named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/named-ops.mlir
@@ -1783,6 +1783,37 @@ func.func @tanh_tensor(%arg0: tensor<4x8x16xf32>) -> tensor<4x8x16xf32> {
// -----
+// CHECK-LABEL: func @erf_dynamic
+func.func @erf_dynamic(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?x?xf32>) {
+ // CHECK: linalg.erf
+ // CHECK-SAME: ins(%{{.+}} : memref<?x?x?xf32>) outs(%{{.+}} : memref<?x?x?xf32>)
+ linalg.erf ins(%arg0 : memref<?x?x?xf32>) outs(%arg1: memref<?x?x?xf32>)
+ return
+}
+
+// -----
+
+// CHECK-LABEL: func @erf_static
+func.func @erf_static(%arg0: memref<4x8x16xf32>, %arg1: memref<4x8x16xf32>) {
+ // CHECK: linalg.erf
+ // CHECK-SAME: ins(%{{.+}} : memref<4x8x16xf32>) outs(%{{.+}} : memref<4x8x16xf32>)
+ linalg.erf ins(%arg0 : memref<4x8x16xf32>) outs(%arg1: memref<4x8x16xf32>)
+ return
+}
+
+// -----
+
+// CHECK-LABEL: func @erf_tensor
+func.func @erf_tensor(%arg0: tensor<4x8x16xf32>) -> tensor<4x8x16xf32> {
+ %0 = tensor.empty() : tensor<4x8x16xf32>
+ // CHECK: linalg.erf
+ // CHECK-SAME: ins(%{{.+}} : tensor<4x8x16xf32>) outs(%{{.+}} : tensor<4x8x16xf32>)
+ %1 = linalg.erf ins(%arg0 : tensor<4x8x16xf32>) outs(%0: tensor<4x8x16xf32>) -> tensor<4x8x16xf32>
+ return %1 : tensor<4x8x16xf32>
+}
+
+// -----
+
// CHECK-LABEL: func @max_dynamic
func.func @max_dynamic(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?x?xf32>) {
// CHECK: linalg.max
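For context on where `erf` pays off: GELU decomposes as `0.5 * x * (1 + erf(x / sqrt(2)))`, so front ends can now express it with named ops alone. A hypothetical composition (not from this patch; it assumes the constants have already been materialized as same-shaped tensors, per the no-implicit-broadcast rule above):

```mlir
func.func @gelu(%x: tensor<4x8xf32>, %half: tensor<4x8xf32>,
                %one: tensor<4x8xf32>, %rsqrt2: tensor<4x8xf32>)
    -> tensor<4x8xf32> {
  %init = tensor.empty() : tensor<4x8xf32>
  // x / sqrt(2), as a multiply by the precomputed 1/sqrt(2) tensor.
  %scaled = linalg.mul ins(%x, %rsqrt2 : tensor<4x8xf32>, tensor<4x8xf32>)
                       outs(%init : tensor<4x8xf32>) -> tensor<4x8xf32>
  %erf = linalg.erf ins(%scaled : tensor<4x8xf32>)
                    outs(%init : tensor<4x8xf32>) -> tensor<4x8xf32>
  // 1 + erf(...)
  %p1 = linalg.add ins(%erf, %one : tensor<4x8xf32>, tensor<4x8xf32>)
                   outs(%init : tensor<4x8xf32>) -> tensor<4x8xf32>
  // 0.5 * x
  %hx = linalg.mul ins(%x, %half : tensor<4x8xf32>, tensor<4x8xf32>)
                   outs(%init : tensor<4x8xf32>) -> tensor<4x8xf32>
  %res = linalg.mul ins(%hx, %p1 : tensor<4x8xf32>, tensor<4x8xf32>)
                    outs(%init : tensor<4x8xf32>) -> tensor<4x8xf32>
  return %res : tensor<4x8xf32>
}
```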