[Mlir-commits] [mlir] [MLIR][Linalg] More Linalg named ops (PR #90236)

Renato Golin llvmlistbot at llvm.org
Fri Apr 26 10:13:19 PDT 2024


https://github.com/rengolin created https://github.com/llvm/llvm-project/pull/90236

Adding `linalg.min` as a named op: the underlying `min_signed` scalar function was already implemented but not exposed.

Adding a few additional unary ops:
* Reciprocal as `arith.divf(1, arg)`
* Round as `math.round(arg)`
* Sqrt as `math.sqrt(arg)`
* Rsqrt as `math.rsqrt(arg)`
* Square as `math.powf(arg, 2)`
* TanH as `math.tanh(arg)`

All of them follow the semantics agreed at the round table: no implicit broadcast or type cast.

Small update to the builder: add `getNumberAttr` to create a numeric attribute of a given type, and update `getZeroAttr`/`getOneAttr` to use it.

This is just the first step. Soon after this we'll add explicit type casts, select, clamp, square difference, generic pow.

After that, we'll discuss the semantics of `linalg.softmax` and how it should be lowered to existing named ops.

>From 680124b093fc3139904a67d551a3f8c46a57d2df Mon Sep 17 00:00:00 2001
From: Renato Golin <rengolin at systemcall.eu>
Date: Fri, 26 Apr 2024 12:35:31 +0100
Subject: [PATCH 1/2] [MLIR][Linalg] Add (signed) min binary op

---
 .../Linalg/IR/LinalgNamedStructuredOps.yaml   | 51 ++++++++++++++++++-
 .../linalg/opdsl/ops/core_named_ops.py        | 19 +++++++
 .../Dialect/Linalg/generalize-named-ops.mlir  | 25 +++++++++
 mlir/test/Dialect/Linalg/named-ops-fail.mlir  | 16 ++++++
 mlir/test/Dialect/Linalg/named-ops.mlir       | 34 +++++++++++++
 5 files changed, 144 insertions(+), 1 deletion(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
index 1ff6c4086cf357..32312d3671c726 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
@@ -625,7 +625,7 @@ metadata: !LinalgOpMetadata
 
     This means reduction/broadcast/element cast semantics is explicit. Further
     passes can take that into account when lowering this code. For example,
-    a `linalg.broadcast` + `linalg.div` sequence can be lowered to a
+    a `linalg.broadcast` + `linalg.max` sequence can be lowered to a
     `linalg.generic` with different affine maps for the two operands.
 structured_op: !LinalgStructuredOpConfig
   args:
@@ -663,6 +663,55 @@ structured_op: !LinalgStructuredOpConfig
         - !ScalarExpression
           scalar_arg: rhs
 --- !LinalgOpConfig
+metadata: !LinalgOpMetadata
+  name: min
+  cpp_class_name: MinOp
+  doc: |-
+    Takes the min (signed) between two inputs, elementwise.
+
+    The shapes and element types must be identical. The appropriate casts,
+    broadcasts and reductions should be done previously to calling this op.
+
+    This means reduction/broadcast/element cast semantics is explicit. Further
+    passes can take that into account when lowering this code. For example,
+    a `linalg.broadcast` + `linalg.min` sequence can be lowered to a
+    `linalg.generic` with different affine maps for the two operands.
+structured_op: !LinalgStructuredOpConfig
+  args:
+  - !LinalgOperandDefConfig
+    name: lhs
+    kind: input_tensor
+    type_var: T1
+    shape_map: affine_map<() -> ()>
+  - !LinalgOperandDefConfig
+    name: rhs
+    kind: input_tensor
+    type_var: T1
+    shape_map: affine_map<() -> ()>
+  - !LinalgOperandDefConfig
+    name: O
+    kind: output_tensor
+    type_var: T1
+    shape_map: affine_map<() -> ()>
+  indexing_maps: !LinalgIndexingMapsConfig
+    static_indexing_maps:
+    - affine_map<() -> ()>
+    - affine_map<() -> ()>
+    - affine_map<() -> ()>
+  iterator_types: []
+  assignments:
+  - !ScalarAssign
+    arg: O
+    value: !ScalarExpression
+      scalar_fn:
+        kind: binary
+        fn_name: min_signed
+        operands:
+        - !ScalarExpression
+          scalar_arg: lhs
+        - !ScalarExpression
+          scalar_arg: rhs
+--- !LinalgOpConfig
 metadata: !LinalgOpMetadata
   name: matmul
   cpp_class_name: MatmulOp
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
index 5b05364f6d35f3..64f566ce5f8ba6 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
@@ -239,6 +239,25 @@ def max(
     O[None] = BinaryFn.max_signed(lhs[None], rhs[None])
 
 
+@linalg_structured_op
+def min(
+    lhs=TensorDef(T1),
+    rhs=TensorDef(T1),
+    O=TensorDef(T1, output=True),
+):
+    """Takes the min (signed) between two inputs, elementwise.
+
+    The shapes and element types must be identical. The appropriate casts,
+    broadcasts and reductions should be done previously to calling this op.
+
+    This means reduction/broadcast/element cast semantics is explicit. Further
+    passes can take that into account when lowering this code. For example,
+    a `linalg.broadcast` + `linalg.div` sequence can be lowered to a
+    `linalg.generic` with different affine maps for the two operands.
+    """
+    O[None] = BinaryFn.min_signed(lhs[None], rhs[None])
+
+
 @linalg_structured_op
 def matmul(
     A=TensorDef(T1, S.M, S.K),
diff --git a/mlir/test/Dialect/Linalg/generalize-named-ops.mlir b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
index e852824cdb7367..17264ac54ab158 100644
--- a/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
@@ -590,6 +590,31 @@ func.func @generalize_max(%lhs: memref<7x14x21xf32>, %rhs: memref<7x14x21xf32>,
 
 // -----
 
+func.func @generalize_min(%lhs: memref<7x14x21xf32>, %rhs: memref<7x14x21xf32>,
+                          %out: memref<7x14x21xf32>) {
+  linalg.min ins(%lhs, %rhs : memref<7x14x21xf32>, memref<7x14x21xf32>)
+             outs(%out : memref<7x14x21xf32>)
+  return
+}
+
+// CHECK: #[[MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+
+// CHECK: func @generalize_min
+// CHECK-SAME: (%[[LHS:.+]]: memref<7x14x21xf32>, %[[RHS:.+]]: memref<7x14x21xf32>,
+// CHECK-SAME:  %[[OUT:.+]]: memref<7x14x21xf32>)
+
+// CHECK: linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP]], #[[MAP]], #[[MAP]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]}
+// CHECK-SAME:  ins(%[[LHS]], %[[RHS]] : memref<7x14x21xf32>, memref<7x14x21xf32>)
+// CHECK-SAME: outs(%[[OUT]] : memref<7x14x21xf32>)
+
+// CHECK:         ^{{.+}}(%[[BBARG0:.+]]: f32, %[[BBARG1:.+]]: f32, %[[BBARG2:.+]]: f32)
+// CHECK-NEXT:      %[[min:.+]] = arith.minimumf %[[BBARG0]], %[[BBARG1]] : f32
+// CHECK-NEXT:      linalg.yield %[[min]] : f32
+
+// -----
+
 
 // CHECK-LABEL: func @fill_tensor
 func.func @fill_tensor(%f: f32, %v: vector<2x4xf32>) -> (tensor<f32>, tensor<vector<2x4xf32>>) {
diff --git a/mlir/test/Dialect/Linalg/named-ops-fail.mlir b/mlir/test/Dialect/Linalg/named-ops-fail.mlir
index c351e139a97e37..8f4c3e90b303ac 100644
--- a/mlir/test/Dialect/Linalg/named-ops-fail.mlir
+++ b/mlir/test/Dialect/Linalg/named-ops-fail.mlir
@@ -189,3 +189,19 @@ func.func @max_broadcast(%arg0: memref<8x16xf32>, %arg1: memref<4x8x16xf32>, %ar
   linalg.max ins(%arg0, %arg1 : memref<8x16xf32>, memref<4x8x16xf32>) outs(%arg2: memref<4x8x16xf32>)
   return
 }
+
+// -----
+
+func.func @min_type_cast(%arg0: memref<4x8x16xf32>, %arg1: memref<4x8x16xf16>, %arg2: memref<4x8x16xf32>) {
+  // CHECK: op requires the same type for all operands and results
+  linalg.min ins(%arg0, %arg1 : memref<4x8x16xf32>, memref<4x8x16xf16>) outs(%arg2: memref<4x8x16xf32>)
+  return
+}
+
+// -----
+
+func.func @min_broadcast(%arg0: memref<8x16xf32>, %arg1: memref<4x8x16xf32>, %arg2: memref<4x8x16xf32>) {
+  // CHECK: op expected operand rank (2) to match the result rank of indexing_map #0 (3)
+  linalg.min ins(%arg0, %arg1 : memref<8x16xf32>, memref<4x8x16xf32>) outs(%arg2: memref<4x8x16xf32>)
+  return
+}
diff --git a/mlir/test/Dialect/Linalg/named-ops.mlir b/mlir/test/Dialect/Linalg/named-ops.mlir
index 7064e1b3f9dc76..c7149e1a00d779 100644
--- a/mlir/test/Dialect/Linalg/named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/named-ops.mlir
@@ -1631,6 +1631,40 @@ func.func @max_tensor(%arg0: tensor<4x8x16xf32>, %arg1: tensor<4x8x16xf32>) -> t
 
 // -----
 
+// CHECK-LABEL: func @min_dynamic
+func.func @min_dynamic(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?x?xf32>) {
+  // CHECK: linalg.min
+  // CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<?x?x?xf32>, memref<?x?x?xf32>)
+  // CHECK-SAME: outs(%{{.+}} : memref<?x?x?xf32>)
+  linalg.min ins(%arg0, %arg1 : memref<?x?x?xf32>, memref<?x?x?xf32>) outs(%arg2: memref<?x?x?xf32>)
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func @min_static
+func.func @min_static(%arg0: memref<4x8x16xf32>, %arg1: memref<4x8x16xf32>, %arg2: memref<4x8x16xf32>) {
+  // CHECK: linalg.min
+  // CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<4x8x16xf32>, memref<4x8x16xf32>)
+  // CHECK-SAME: outs(%{{.+}} : memref<4x8x16xf32>)
+  linalg.min ins(%arg0, %arg1 : memref<4x8x16xf32>, memref<4x8x16xf32>) outs(%arg2: memref<4x8x16xf32>)
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func @min_tensor
+func.func @min_tensor(%arg0: tensor<4x8x16xf32>, %arg1: tensor<4x8x16xf32>) -> tensor<4x8x16xf32> {
+  %0 = tensor.empty() : tensor<4x8x16xf32>
+  // CHECK: linalg.min
+  // CHECK-SAME: ins(%{{.+}}, %{{.+}} : tensor<4x8x16xf32>, tensor<4x8x16xf32>)
+  // CHECK-SAME: outs(%{{.+}} : tensor<4x8x16xf32>)
+  %1 = linalg.min ins(%arg0, %arg1 : tensor<4x8x16xf32>, tensor<4x8x16xf32>) outs(%0: tensor<4x8x16xf32>) -> tensor<4x8x16xf32>
+  return %1 : tensor<4x8x16xf32>
+}
+
+// -----
+
 // CHECK-LABEL: func @fill_tensor
 func.func @fill_tensor(%f: f32, %v: vector<2x4xf32>) -> (tensor<f32>, tensor<vector<2x4xf32>>) {
   %e0 = tensor.empty() : tensor<f32>

>From aba65e39b0b0fa84ef34023a1b2a9d934762c1a7 Mon Sep 17 00:00:00 2001
From: Renato Golin <rengolin at systemcall.eu>
Date: Fri, 26 Apr 2024 13:47:03 +0100
Subject: [PATCH 2/2] [MLIR][Linalg] Adding more unary named ops

---
 .../mlir/Dialect/Linalg/IR/LinalgEnums.td     |   8 +-
 .../Linalg/IR/LinalgNamedStructuredOps.yaml   | 210 ++++++++++++++++++
 mlir/include/mlir/IR/Builders.h               |   6 +-
 mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp      |  20 ++
 mlir/lib/IR/Builders.cpp                      |  28 +--
 .../linalg/opdsl/lang/comprehension.py        |   5 +
 .../linalg/opdsl/ops/core_named_ops.py        |  60 +++++
 .../Dialect/Linalg/generalize-named-ops.mlir  | 130 +++++++++++
 mlir/test/Dialect/Linalg/named-ops-fail.mlir  |  96 ++++++++
 mlir/test/Dialect/Linalg/named-ops.mlir       | 186 ++++++++++++++++
 10 files changed, 728 insertions(+), 21 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgEnums.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgEnums.td
index 59f909aed8f61a..7a350d2c014262 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgEnums.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgEnums.td
@@ -22,7 +22,13 @@ def UnaryFn : I32EnumAttr<"UnaryFn", "", [
   I32EnumAttrCase<"abs", 2>,
   I32EnumAttrCase<"ceil", 3>,
   I32EnumAttrCase<"floor", 4>,
-  I32EnumAttrCase<"negf", 5>
+  I32EnumAttrCase<"negf", 5>,
+  I32EnumAttrCase<"reciprocal", 6>,
+  I32EnumAttrCase<"round", 7>,
+  I32EnumAttrCase<"sqrt", 8>,
+  I32EnumAttrCase<"rsqrt", 9>,
+  I32EnumAttrCase<"square", 10>,
+  I32EnumAttrCase<"tanh", 11>
 ]> {
   let genSpecializedAttr = 0;
   let cppNamespace = "::mlir::linalg";
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
index 32312d3671c726..b7567577347587 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
@@ -304,6 +304,216 @@ structured_op: !LinalgStructuredOpConfig
         - !ScalarExpression
           scalar_arg: I
 --- !LinalgOpConfig
+metadata: !LinalgOpMetadata
+  name: reciprocal
+  cpp_class_name: ReciprocalOp
+  doc: |-
+    Applies reciprocal(x) elementwise.
+
+    No numeric casting is performed on the input operand.
+structured_op: !LinalgStructuredOpConfig
+  args:
+  - !LinalgOperandDefConfig
+    name: I
+    kind: input_tensor
+    type_var: T1
+    shape_map: affine_map<() -> ()>
+  - !LinalgOperandDefConfig
+    name: O
+    kind: output_tensor
+    type_var: T1
+    shape_map: affine_map<() -> ()>
+  indexing_maps: !LinalgIndexingMapsConfig
+    static_indexing_maps:
+    - affine_map<() -> ()>
+    - affine_map<() -> ()>
+  iterator_types: []
+  assignments:
+  - !ScalarAssign
+    arg: O
+    value: !ScalarExpression
+      scalar_fn:
+        kind: unary
+        fn_name: reciprocal
+        operands:
+        - !ScalarExpression
+          scalar_arg: I
+--- !LinalgOpConfig
+metadata: !LinalgOpMetadata
+  name: round
+  cpp_class_name: RoundOp
+  doc: |-
+    Applies round(x) elementwise.
+
+    No numeric casting is performed on the input operand.
+structured_op: !LinalgStructuredOpConfig
+  args:
+  - !LinalgOperandDefConfig
+    name: I
+    kind: input_tensor
+    type_var: T1
+    shape_map: affine_map<() -> ()>
+  - !LinalgOperandDefConfig
+    name: O
+    kind: output_tensor
+    type_var: T1
+    shape_map: affine_map<() -> ()>
+  indexing_maps: !LinalgIndexingMapsConfig
+    static_indexing_maps:
+    - affine_map<() -> ()>
+    - affine_map<() -> ()>
+  iterator_types: []
+  assignments:
+  - !ScalarAssign
+    arg: O
+    value: !ScalarExpression
+      scalar_fn:
+        kind: unary
+        fn_name: round
+        operands:
+        - !ScalarExpression
+          scalar_arg: I
+--- !LinalgOpConfig
+metadata: !LinalgOpMetadata
+  name: sqrt
+  cpp_class_name: SqrtOp
+  doc: |-
+    Applies sqrt(x) elementwise.
+
+    No numeric casting is performed on the input operand.
+structured_op: !LinalgStructuredOpConfig
+  args:
+  - !LinalgOperandDefConfig
+    name: I
+    kind: input_tensor
+    type_var: T1
+    shape_map: affine_map<() -> ()>
+  - !LinalgOperandDefConfig
+    name: O
+    kind: output_tensor
+    type_var: T1
+    shape_map: affine_map<() -> ()>
+  indexing_maps: !LinalgIndexingMapsConfig
+    static_indexing_maps:
+    - affine_map<() -> ()>
+    - affine_map<() -> ()>
+  iterator_types: []
+  assignments:
+  - !ScalarAssign
+    arg: O
+    value: !ScalarExpression
+      scalar_fn:
+        kind: unary
+        fn_name: sqrt
+        operands:
+        - !ScalarExpression
+          scalar_arg: I
+--- !LinalgOpConfig
+metadata: !LinalgOpMetadata
+  name: rsqrt
+  cpp_class_name: RsqrtOp
+  doc: |-
+    Applies rsqrt(x) elementwise.
+
+    No numeric casting is performed on the input operand.
+structured_op: !LinalgStructuredOpConfig
+  args:
+  - !LinalgOperandDefConfig
+    name: I
+    kind: input_tensor
+    type_var: T1
+    shape_map: affine_map<() -> ()>
+  - !LinalgOperandDefConfig
+    name: O
+    kind: output_tensor
+    type_var: T1
+    shape_map: affine_map<() -> ()>
+  indexing_maps: !LinalgIndexingMapsConfig
+    static_indexing_maps:
+    - affine_map<() -> ()>
+    - affine_map<() -> ()>
+  iterator_types: []
+  assignments:
+  - !ScalarAssign
+    arg: O
+    value: !ScalarExpression
+      scalar_fn:
+        kind: unary
+        fn_name: rsqrt
+        operands:
+        - !ScalarExpression
+          scalar_arg: I
+--- !LinalgOpConfig
+metadata: !LinalgOpMetadata
+  name: square
+  cpp_class_name: SquareOp
+  doc: |-
+    Applies square(x) elementwise.
+
+    No numeric casting is performed on the input operand.
+structured_op: !LinalgStructuredOpConfig
+  args:
+  - !LinalgOperandDefConfig
+    name: I
+    kind: input_tensor
+    type_var: T1
+    shape_map: affine_map<() -> ()>
+  - !LinalgOperandDefConfig
+    name: O
+    kind: output_tensor
+    type_var: T1
+    shape_map: affine_map<() -> ()>
+  indexing_maps: !LinalgIndexingMapsConfig
+    static_indexing_maps:
+    - affine_map<() -> ()>
+    - affine_map<() -> ()>
+  iterator_types: []
+  assignments:
+  - !ScalarAssign
+    arg: O
+    value: !ScalarExpression
+      scalar_fn:
+        kind: unary
+        fn_name: square
+        operands:
+        - !ScalarExpression
+          scalar_arg: I
+--- !LinalgOpConfig
+metadata: !LinalgOpMetadata
+  name: tanh
+  cpp_class_name: TanhOp
+  doc: |-
+    Applies tanh(x) elementwise.
+
+    No numeric casting is performed on the input operand.
+structured_op: !LinalgStructuredOpConfig
+  args:
+  - !LinalgOperandDefConfig
+    name: I
+    kind: input_tensor
+    type_var: T1
+    shape_map: affine_map<() -> ()>
+  - !LinalgOperandDefConfig
+    name: O
+    kind: output_tensor
+    type_var: T1
+    shape_map: affine_map<() -> ()>
+  indexing_maps: !LinalgIndexingMapsConfig
+    static_indexing_maps:
+    - affine_map<() -> ()>
+    - affine_map<() -> ()>
+  iterator_types: []
+  assignments:
+  - !ScalarAssign
+    arg: O
+    value: !ScalarExpression
+      scalar_fn:
+        kind: unary
+        fn_name: tanh
+        operands:
+        - !ScalarExpression
+          scalar_arg: I
+--- !LinalgOpConfig
 metadata: !LinalgOpMetadata
   name: elemwise_binary
   cpp_class_name: ElemwiseBinaryOp
diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h
index 0d5fa719d0dee2..308c2cd38196cb 100644
--- a/mlir/include/mlir/IR/Builders.h
+++ b/mlir/include/mlir/IR/Builders.h
@@ -115,12 +115,16 @@ class Builder {
   ArrayAttr getArrayAttr(ArrayRef<Attribute> value);
 
   // Returns a 0-valued attribute of the given `type`. This function only
-  // supports boolean, integer, and 16-/32-/64-bit float types, and vector or
+  // supports integer, and 16-/32-/64-bit float types, and vector or
   // ranked tensor of them. Returns null attribute otherwise.
   TypedAttr getZeroAttr(Type type);
   // Returns a 1-valued attribute of the given `type`.
   // Type constraints are the same as `getZeroAttr`.
   TypedAttr getOneAttr(Type type);
+  // Returns a numeric attribute of the given `type`.
+  // Type constraints are the same as `getZeroAttr`.
+  // Non float types are converted before returning the attribute.
+  TypedAttr getNumberAttr(double value, Type type);
 
   // Convenience methods for fixed types.
   FloatAttr getF16FloatAttr(float value);
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 9c5c58fa1fabfb..6ee5864656cf64 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -395,6 +395,26 @@ class RegionBuilderHelper {
       return builder.create<math::FloorOp>(arg.getLoc(), arg);
     case UnaryFn::negf:
       return builder.create<arith::NegFOp>(arg.getLoc(), arg);
+    case UnaryFn::reciprocal:
+    {
+      Attribute oneAttr = builder.getNumberAttr(1.0, arg.getType());
+      auto one = builder.create<arith::ConstantOp>(arg.getLoc(), ::cast<TypedAttr>(oneAttr));
+      return builder.create<arith::DivFOp>(arg.getLoc(), one, arg);
+    }
+    case UnaryFn::round:
+      return builder.create<math::RoundOp>(arg.getLoc(), arg);
+    case UnaryFn::sqrt:
+      return builder.create<math::SqrtOp>(arg.getLoc(), arg);
+    case UnaryFn::rsqrt:
+      return builder.create<math::RsqrtOp>(arg.getLoc(), arg);
+    case UnaryFn::square:
+    {
+      Attribute twoAttr = builder.getNumberAttr(2.0, arg.getType());
+      auto two = builder.create<arith::ConstantOp>(arg.getLoc(), ::cast<TypedAttr>(twoAttr));
+      return builder.create<math::PowFOp>(arg.getLoc(), arg, two);
+    }
+    case UnaryFn::tanh:
+      return builder.create<math::TanhOp>(arg.getLoc(), arg);
     }
     llvm_unreachable("unsupported unary function");
   }
diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp
index d49f69a7b7ae6b..ae8348c8454e9e 100644
--- a/mlir/lib/IR/Builders.cpp
+++ b/mlir/lib/IR/Builders.cpp
@@ -329,34 +329,24 @@ ArrayAttr Builder::getAffineMapArrayAttr(ArrayRef<AffineMap> values) {
 }
 
 TypedAttr Builder::getZeroAttr(Type type) {
-  if (llvm::isa<FloatType>(type))
-    return getFloatAttr(type, 0.0);
-  if (llvm::isa<IndexType>(type))
-    return getIndexAttr(0);
-  if (llvm::dyn_cast<IntegerType>(type))
-    return getIntegerAttr(type,
-                          APInt(llvm::cast<IntegerType>(type).getWidth(), 0));
-  if (llvm::isa<RankedTensorType, VectorType>(type)) {
-    auto vtType = llvm::cast<ShapedType>(type);
-    auto element = getZeroAttr(vtType.getElementType());
-    if (!element)
-      return {};
-    return DenseElementsAttr::get(vtType, element);
-  }
-  return {};
+  return getNumberAttr(0.0, type);
 }
 
 TypedAttr Builder::getOneAttr(Type type) {
+  return getNumberAttr(1.0, type);
+}
+
+TypedAttr Builder::getNumberAttr(double value, Type type) {
   if (llvm::isa<FloatType>(type))
-    return getFloatAttr(type, 1.0);
+    return getFloatAttr(type, value);
   if (llvm::isa<IndexType>(type))
-    return getIndexAttr(1);
+    return getIndexAttr(static_cast<int64_t>(value));
   if (llvm::dyn_cast<IntegerType>(type))
     return getIntegerAttr(type,
-                          APInt(llvm::cast<IntegerType>(type).getWidth(), 1));
+                          APInt(llvm::cast<IntegerType>(type).getWidth(), static_cast<int64_t>(value)));
   if (llvm::isa<RankedTensorType, VectorType>(type)) {
     auto vtType = llvm::cast<ShapedType>(type);
-    auto element = getOneAttr(vtType.getElementType());
+    auto element = getNumberAttr(value, vtType.getElementType());
     if (!element)
       return {};
     return DenseElementsAttr::get(vtType, element);
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py
index 23d6d26b7e294c..f7bc81bd2f6833 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py
@@ -291,6 +291,11 @@ class UnaryFn:
     ceil = UnaryFnType("ceil")
     floor = UnaryFnType("floor")
     negf = UnaryFnType("negf")
+    round = UnaryFnType("round")
+    sqrt = UnaryFnType("sqrt")
+    rsqrt = UnaryFnType("rsqrt")
+    square = UnaryFnType("square")
+    tanh = UnaryFnType("tanh")
 
 
 class BinaryFnType:
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
index 64f566ce5f8ba6..c97ac448006505 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
@@ -108,6 +108,66 @@ def negf(
     O[None] = UnaryFn.negf(I[None])
 
 
+@linalg_structured_op
+def round(
+    I=TensorDef(T1),
+    O=TensorDef(T1, output=True),
+):
+    """Applies round(x) elementwise.
+
+    No numeric casting is performed on the input operand.
+    """
+    O[None] = UnaryFn.round(I[None])
+
+
+@linalg_structured_op
+def sqrt(
+    I=TensorDef(T1),
+    O=TensorDef(T1, output=True),
+):
+    """Applies sqrt(x) elementwise.
+
+    No numeric casting is performed on the input operand.
+    """
+    O[None] = UnaryFn.sqrt(I[None])
+
+
+@linalg_structured_op
+def rsqrt(
+    I=TensorDef(T1),
+    O=TensorDef(T1, output=True),
+):
+    """Applies rsqrt(x) elementwise.
+
+    No numeric casting is performed on the input operand.
+    """
+    O[None] = UnaryFn.rsqrt(I[None])
+
+
+@linalg_structured_op
+def square(
+    I=TensorDef(T1),
+    O=TensorDef(T1, output=True),
+):
+    """Applies square(x) elementwise.
+
+    No numeric casting is performed on the input operand.
+    """
+    O[None] = UnaryFn.square(I[None])
+
+
+@linalg_structured_op
+def tanh(
+    I=TensorDef(T1),
+    O=TensorDef(T1, output=True),
+):
+    """Applies tanh(x) elementwise.
+
+    No numeric casting is performed on the input operand.
+    """
+    O[None] = UnaryFn.tanh(I[None])
+
+
 @linalg_structured_op
 def elemwise_binary(
     lhs=TensorDef(T1),
diff --git a/mlir/test/Dialect/Linalg/generalize-named-ops.mlir b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
index 17264ac54ab158..160d1e9536bb5b 100644
--- a/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
@@ -565,6 +565,136 @@ func.func @generalize_negf(%arg: memref<7x14x21xf32>, %out: memref<7x14x21xf32>)
 
 // -----
 
+func.func @generalize_reciprocal(%arg: memref<7x14x21xf32>, %out: memref<7x14x21xf32>) {
+  linalg.reciprocal ins(%arg : memref<7x14x21xf32>) outs(%out : memref<7x14x21xf32>)
+  return
+}
+
+// CHECK: #[[MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+
+// CHECK: func @generalize_reciprocal
+// CHECK-SAME: (%[[ARG:.+]]: memref<7x14x21xf32>, %[[OUT:.+]]: memref<7x14x21xf32>)
+
+// CHECK: %[[one:.+]] = arith.constant 1.000000e+00 : f32
+
+// CHECK: linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP]], #[[MAP]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]}
+// CHECK-SAME:  ins(%[[LHS]] : memref<7x14x21xf32>) outs(%[[OUT]] : memref<7x14x21xf32>)
+
+// CHECK:         ^{{.+}}(%[[BBARG0:.+]]: f32, %[[BBARG1:.+]]: f32)
+// CHECK-NEXT:      %[[reciprocal:.+]] = arith.divf %[[one]], %[[BBARG0]] : f32
+// CHECK-NEXT:      linalg.yield %[[reciprocal]] : f32
+
+// -----
+
+func.func @generalize_round(%arg: memref<7x14x21xf32>, %out: memref<7x14x21xf32>) {
+  linalg.round ins(%arg : memref<7x14x21xf32>) outs(%out : memref<7x14x21xf32>)
+  return
+}
+
+// CHECK: #[[MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+
+// CHECK: func @generalize_round
+// CHECK-SAME: (%[[ARG:.+]]: memref<7x14x21xf32>, %[[OUT:.+]]: memref<7x14x21xf32>)
+
+// CHECK: linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP]], #[[MAP]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]}
+// CHECK-SAME:  ins(%[[LHS]] : memref<7x14x21xf32>) outs(%[[OUT]] : memref<7x14x21xf32>)
+
+// CHECK:         ^{{.+}}(%[[BBARG0:.+]]: f32, %[[BBARG1:.+]]: f32)
+// CHECK-NEXT:      %[[round:.+]] = math.round %[[BBARG0]] : f32
+// CHECK-NEXT:      linalg.yield %[[round]] : f32
+
+// -----
+
+func.func @generalize_sqrt(%arg: memref<7x14x21xf32>, %out: memref<7x14x21xf32>) {
+  linalg.sqrt ins(%arg : memref<7x14x21xf32>) outs(%out : memref<7x14x21xf32>)
+  return
+}
+
+// CHECK: #[[MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+
+// CHECK: func @generalize_sqrt
+// CHECK-SAME: (%[[ARG:.+]]: memref<7x14x21xf32>, %[[OUT:.+]]: memref<7x14x21xf32>)
+
+// CHECK: linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP]], #[[MAP]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]}
+// CHECK-SAME:  ins(%[[LHS]] : memref<7x14x21xf32>) outs(%[[OUT]] : memref<7x14x21xf32>)
+
+// CHECK:         ^{{.+}}(%[[BBARG0:.+]]: f32, %[[BBARG1:.+]]: f32)
+// CHECK-NEXT:      %[[sqrt:.+]] = math.sqrt %[[BBARG0]] : f32
+// CHECK-NEXT:      linalg.yield %[[sqrt]] : f32
+
+// -----
+
+func.func @generalize_rsqrt(%arg: memref<7x14x21xf32>, %out: memref<7x14x21xf32>) {
+  linalg.rsqrt ins(%arg : memref<7x14x21xf32>) outs(%out : memref<7x14x21xf32>)
+  return
+}
+
+// CHECK: #[[MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+
+// CHECK: func @generalize_rsqrt
+// CHECK-SAME: (%[[ARG:.+]]: memref<7x14x21xf32>, %[[OUT:.+]]: memref<7x14x21xf32>)
+
+// CHECK: linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP]], #[[MAP]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]}
+// CHECK-SAME:  ins(%[[LHS]] : memref<7x14x21xf32>) outs(%[[OUT]] : memref<7x14x21xf32>)
+
+// CHECK:         ^{{.+}}(%[[BBARG0:.+]]: f32, %[[BBARG1:.+]]: f32)
+// CHECK-NEXT:      %[[rsqrt:.+]] = math.rsqrt %[[BBARG0]] : f32
+// CHECK-NEXT:      linalg.yield %[[rsqrt]] : f32
+
+// -----
+
+func.func @generalize_square(%arg: memref<7x14x21xf32>, %out: memref<7x14x21xf32>) {
+  linalg.square ins(%arg : memref<7x14x21xf32>) outs(%out : memref<7x14x21xf32>)
+  return
+}
+
+// CHECK: #[[MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+
+// CHECK: func @generalize_square
+// CHECK-SAME: (%[[ARG:.+]]: memref<7x14x21xf32>, %[[OUT:.+]]: memref<7x14x21xf32>)
+
+// CHECK: %[[two:.+]] = arith.constant 2.000000e+00 : f32
+
+// CHECK: linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP]], #[[MAP]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]}
+// CHECK-SAME:  ins(%[[LHS]] : memref<7x14x21xf32>) outs(%[[OUT]] : memref<7x14x21xf32>)
+
+// CHECK:         ^{{.+}}(%[[BBARG0:.+]]: f32, %[[BBARG1:.+]]: f32)
+// CHECK-NEXT:      %[[square:.+]] = math.powf %[[BBARG0]], %[[two]] : f32
+// CHECK-NEXT:      linalg.yield %[[square]] : f32
+
+// -----
+
+func.func @generalize_tanh(%arg: memref<7x14x21xf32>, %out: memref<7x14x21xf32>) {
+  linalg.tanh ins(%arg : memref<7x14x21xf32>) outs(%out : memref<7x14x21xf32>)
+  return
+}
+
+// CHECK: #[[MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+
+// CHECK: func @generalize_tanh
+// CHECK-SAME: (%[[ARG:.+]]: memref<7x14x21xf32>, %[[OUT:.+]]: memref<7x14x21xf32>)
+
+// CHECK: linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP]], #[[MAP]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]}
+// CHECK-SAME:  ins(%[[LHS]] : memref<7x14x21xf32>) outs(%[[OUT]] : memref<7x14x21xf32>)
+
+// CHECK:         ^{{.+}}(%[[BBARG0:.+]]: f32, %[[BBARG1:.+]]: f32)
+// CHECK-NEXT:      %[[tanh:.+]] = math.tanh %[[BBARG0]] : f32
+// CHECK-NEXT:      linalg.yield %[[tanh]] : f32
+
+// -----
+
 func.func @generalize_max(%lhs: memref<7x14x21xf32>, %rhs: memref<7x14x21xf32>,
                           %out: memref<7x14x21xf32>) {
   linalg.max ins(%lhs, %rhs : memref<7x14x21xf32>, memref<7x14x21xf32>)
diff --git a/mlir/test/Dialect/Linalg/named-ops-fail.mlir b/mlir/test/Dialect/Linalg/named-ops-fail.mlir
index 8f4c3e90b303ac..f66608e71ffc64 100644
--- a/mlir/test/Dialect/Linalg/named-ops-fail.mlir
+++ b/mlir/test/Dialect/Linalg/named-ops-fail.mlir
@@ -176,6 +176,102 @@ func.func @negf_broadcast(%arg: memref<8x16xf32>, %out: memref<4x8x16xf32>) {
 
 // -----
 
+func.func @reciprocal_type_cast(%arg: memref<4x8x16xf16>, %out: memref<4x8x16xf32>) {
+  // CHECK: operand 1 ('f16') doesn't match the element type of the enclosing linalg.generic op ('f32')
+  linalg.reciprocal ins(%arg : memref<4x8x16xf16>) outs(%out: memref<4x8x16xf32>)
+  return
+}
+
+// -----
+
+func.func @reciprocal_broadcast(%arg: memref<8x16xf32>, %out: memref<4x8x16xf32>) {
+  // CHECK: op expected operand rank (2) to match the result rank of indexing_map #0 (3)
+  linalg.reciprocal ins(%arg : memref<8x16xf32>) outs(%out: memref<4x8x16xf32>)
+  return
+}
+
+// -----
+
+func.func @round_type_cast(%arg: memref<4x8x16xf16>, %out: memref<4x8x16xf32>) {
+  // CHECK: operand 1 ('f16') doesn't match the element type of the enclosing linalg.generic op ('f32')
+  linalg.round ins(%arg : memref<4x8x16xf16>) outs(%out: memref<4x8x16xf32>)
+  return
+}
+
+// -----
+
+func.func @round_broadcast(%arg: memref<8x16xf32>, %out: memref<4x8x16xf32>) {
+  // CHECK: op expected operand rank (2) to match the result rank of indexing_map #0 (3)
+  linalg.round ins(%arg : memref<8x16xf32>) outs(%out: memref<4x8x16xf32>)
+  return
+}
+
+// -----
+
+func.func @sqrt_type_cast(%arg: memref<4x8x16xf16>, %out: memref<4x8x16xf32>) {
+  // CHECK: operand 1 ('f16') doesn't match the element type of the enclosing linalg.generic op ('f32')
+  linalg.sqrt ins(%arg : memref<4x8x16xf16>) outs(%out: memref<4x8x16xf32>)
+  return
+}
+
+// -----
+
+func.func @sqrt_broadcast(%arg: memref<8x16xf32>, %out: memref<4x8x16xf32>) {
+  // CHECK: op expected operand rank (2) to match the result rank of indexing_map #0 (3)
+  linalg.sqrt ins(%arg : memref<8x16xf32>) outs(%out: memref<4x8x16xf32>)
+  return
+}
+
+// -----
+
+func.func @rsqrt_type_cast(%arg: memref<4x8x16xf16>, %out: memref<4x8x16xf32>) {
+  // CHECK: operand 1 ('f16') doesn't match the element type of the enclosing linalg.generic op ('f32')
+  linalg.rsqrt ins(%arg : memref<4x8x16xf16>) outs(%out: memref<4x8x16xf32>)
+  return
+}
+
+// -----
+
+func.func @rsqrt_broadcast(%arg: memref<8x16xf32>, %out: memref<4x8x16xf32>) {
+  // CHECK: op expected operand rank (2) to match the result rank of indexing_map #0 (3)
+  linalg.rsqrt ins(%arg : memref<8x16xf32>) outs(%out: memref<4x8x16xf32>)
+  return
+}
+
+// -----
+
+func.func @square_type_cast(%arg: memref<4x8x16xf16>, %out: memref<4x8x16xf32>) {
+  // CHECK: operand 1 ('f16') doesn't match the element type of the enclosing linalg.generic op ('f32')
+  linalg.square ins(%arg : memref<4x8x16xf16>) outs(%out: memref<4x8x16xf32>)
+  return
+}
+
+// -----
+
+func.func @square_broadcast(%arg: memref<8x16xf32>, %out: memref<4x8x16xf32>) {
+  // CHECK: op expected operand rank (2) to match the result rank of indexing_map #0 (3)
+  linalg.square ins(%arg : memref<8x16xf32>) outs(%out: memref<4x8x16xf32>)
+  return
+}
+
+// -----
+
+func.func @tanh_type_cast(%arg: memref<4x8x16xf16>, %out: memref<4x8x16xf32>) {
+  // CHECK: operand 1 ('f16') doesn't match the element type of the enclosing linalg.generic op ('f32')
+  linalg.tanh ins(%arg : memref<4x8x16xf16>) outs(%out: memref<4x8x16xf32>)
+  return
+}
+
+// -----
+
+func.func @tanh_broadcast(%arg: memref<8x16xf32>, %out: memref<4x8x16xf32>) {
+  // CHECK: op expected operand rank (2) to match the result rank of indexing_map #0 (3)
+  linalg.tanh ins(%arg : memref<8x16xf32>) outs(%out: memref<4x8x16xf32>)
+  return
+}
+
+// -----
+
 func.func @max_type_cast(%arg0: memref<4x8x16xf32>, %arg1: memref<4x8x16xf16>, %arg2: memref<4x8x16xf32>) {
   // CHECK: op requires the same type for all operands and results
   linalg.max ins(%arg0, %arg1 : memref<4x8x16xf32>, memref<4x8x16xf16>) outs(%arg2: memref<4x8x16xf32>)
diff --git a/mlir/test/Dialect/Linalg/named-ops.mlir b/mlir/test/Dialect/Linalg/named-ops.mlir
index c7149e1a00d779..cf59f673610013 100644
--- a/mlir/test/Dialect/Linalg/named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/named-ops.mlir
@@ -1597,6 +1597,192 @@ func.func @negf_tensor(%arg0: tensor<4x8x16xf32>) -> tensor<4x8x16xf32> {
 
 // -----
 
+// CHECK-LABEL: func @reciprocal_dynamic
+func.func @reciprocal_dynamic(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?x?xf32>) {
+  // CHECK: linalg.reciprocal
+  // CHECK-SAME: ins(%{{.+}} : memref<?x?x?xf32>) outs(%{{.+}} : memref<?x?x?xf32>)
+  linalg.reciprocal ins(%arg0 : memref<?x?x?xf32>) outs(%arg1: memref<?x?x?xf32>)
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func @reciprocal_static
+func.func @reciprocal_static(%arg0: memref<4x8x16xf32>, %arg1: memref<4x8x16xf32>) {
+  // CHECK: linalg.reciprocal
+  // CHECK-SAME: ins(%{{.+}} : memref<4x8x16xf32>) outs(%{{.+}} : memref<4x8x16xf32>)
+  linalg.reciprocal ins(%arg0 : memref<4x8x16xf32>) outs(%arg1: memref<4x8x16xf32>)
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func @reciprocal_tensor
+func.func @reciprocal_tensor(%arg0: tensor<4x8x16xf32>) -> tensor<4x8x16xf32> {
+  %0 = tensor.empty() : tensor<4x8x16xf32>
+  // CHECK: linalg.reciprocal
+  // CHECK-SAME: ins(%{{.+}} : tensor<4x8x16xf32>) outs(%{{.+}} : tensor<4x8x16xf32>)
+  %1 = linalg.reciprocal ins(%arg0 : tensor<4x8x16xf32>) outs(%0: tensor<4x8x16xf32>) -> tensor<4x8x16xf32>
+  return %1 : tensor<4x8x16xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @round_dynamic
+func.func @round_dynamic(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?x?xf32>) {
+  // CHECK: linalg.round
+  // CHECK-SAME: ins(%{{.+}} : memref<?x?x?xf32>) outs(%{{.+}} : memref<?x?x?xf32>)
+  linalg.round ins(%arg0 : memref<?x?x?xf32>) outs(%arg1: memref<?x?x?xf32>)
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func @round_static
+func.func @round_static(%arg0: memref<4x8x16xf32>, %arg1: memref<4x8x16xf32>) {
+  // CHECK: linalg.round
+  // CHECK-SAME: ins(%{{.+}} : memref<4x8x16xf32>) outs(%{{.+}} : memref<4x8x16xf32>)
+  linalg.round ins(%arg0 : memref<4x8x16xf32>) outs(%arg1: memref<4x8x16xf32>)
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func @round_tensor
+func.func @round_tensor(%arg0: tensor<4x8x16xf32>) -> tensor<4x8x16xf32> {
+  %0 = tensor.empty() : tensor<4x8x16xf32>
+  // CHECK: linalg.round
+  // CHECK-SAME: ins(%{{.+}} : tensor<4x8x16xf32>) outs(%{{.+}} : tensor<4x8x16xf32>)
+  %1 = linalg.round ins(%arg0 : tensor<4x8x16xf32>) outs(%0: tensor<4x8x16xf32>) -> tensor<4x8x16xf32>
+  return %1 : tensor<4x8x16xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @sqrt_dynamic
+func.func @sqrt_dynamic(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?x?xf32>) {
+  // CHECK: linalg.sqrt
+  // CHECK-SAME: ins(%{{.+}} : memref<?x?x?xf32>) outs(%{{.+}} : memref<?x?x?xf32>)
+  linalg.sqrt ins(%arg0 : memref<?x?x?xf32>) outs(%arg1: memref<?x?x?xf32>)
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func @sqrt_static
+func.func @sqrt_static(%arg0: memref<4x8x16xf32>, %arg1: memref<4x8x16xf32>) {
+  // CHECK: linalg.sqrt
+  // CHECK-SAME: ins(%{{.+}} : memref<4x8x16xf32>) outs(%{{.+}} : memref<4x8x16xf32>)
+  linalg.sqrt ins(%arg0 : memref<4x8x16xf32>) outs(%arg1: memref<4x8x16xf32>)
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func @sqrt_tensor
+func.func @sqrt_tensor(%arg0: tensor<4x8x16xf32>) -> tensor<4x8x16xf32> {
+  %0 = tensor.empty() : tensor<4x8x16xf32>
+  // CHECK: linalg.sqrt
+  // CHECK-SAME: ins(%{{.+}} : tensor<4x8x16xf32>) outs(%{{.+}} : tensor<4x8x16xf32>)
+  %1 = linalg.sqrt ins(%arg0 : tensor<4x8x16xf32>) outs(%0: tensor<4x8x16xf32>) -> tensor<4x8x16xf32>
+  return %1 : tensor<4x8x16xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @rsqrt_dynamic
+func.func @rsqrt_dynamic(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?x?xf32>) {
+  // CHECK: linalg.rsqrt
+  // CHECK-SAME: ins(%{{.+}} : memref<?x?x?xf32>) outs(%{{.+}} : memref<?x?x?xf32>)
+  linalg.rsqrt ins(%arg0 : memref<?x?x?xf32>) outs(%arg1: memref<?x?x?xf32>)
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func @rsqrt_static
+func.func @rsqrt_static(%arg0: memref<4x8x16xf32>, %arg1: memref<4x8x16xf32>) {
+  // CHECK: linalg.rsqrt
+  // CHECK-SAME: ins(%{{.+}} : memref<4x8x16xf32>) outs(%{{.+}} : memref<4x8x16xf32>)
+  linalg.rsqrt ins(%arg0 : memref<4x8x16xf32>) outs(%arg1: memref<4x8x16xf32>)
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func @rsqrt_tensor
+func.func @rsqrt_tensor(%arg0: tensor<4x8x16xf32>) -> tensor<4x8x16xf32> {
+  %0 = tensor.empty() : tensor<4x8x16xf32>
+  // CHECK: linalg.rsqrt
+  // CHECK-SAME: ins(%{{.+}} : tensor<4x8x16xf32>) outs(%{{.+}} : tensor<4x8x16xf32>)
+  %1 = linalg.rsqrt ins(%arg0 : tensor<4x8x16xf32>) outs(%0: tensor<4x8x16xf32>) -> tensor<4x8x16xf32>
+  return %1 : tensor<4x8x16xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @square_dynamic
+func.func @square_dynamic(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?x?xf32>) {
+  // CHECK: linalg.square
+  // CHECK-SAME: ins(%{{.+}} : memref<?x?x?xf32>) outs(%{{.+}} : memref<?x?x?xf32>)
+  linalg.square ins(%arg0 : memref<?x?x?xf32>) outs(%arg1: memref<?x?x?xf32>)
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func @square_static
+func.func @square_static(%arg0: memref<4x8x16xf32>, %arg1: memref<4x8x16xf32>) {
+  // CHECK: linalg.square
+  // CHECK-SAME: ins(%{{.+}} : memref<4x8x16xf32>) outs(%{{.+}} : memref<4x8x16xf32>)
+  linalg.square ins(%arg0 : memref<4x8x16xf32>) outs(%arg1: memref<4x8x16xf32>)
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func @square_tensor
+func.func @square_tensor(%arg0: tensor<4x8x16xf32>) -> tensor<4x8x16xf32> {
+  %0 = tensor.empty() : tensor<4x8x16xf32>
+  // CHECK: linalg.square
+  // CHECK-SAME: ins(%{{.+}} : tensor<4x8x16xf32>) outs(%{{.+}} : tensor<4x8x16xf32>)
+  %1 = linalg.square ins(%arg0 : tensor<4x8x16xf32>) outs(%0: tensor<4x8x16xf32>) -> tensor<4x8x16xf32>
+  return %1 : tensor<4x8x16xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @tanh_dynamic
+func.func @tanh_dynamic(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?x?xf32>) {
+  // CHECK: linalg.tanh
+  // CHECK-SAME: ins(%{{.+}} : memref<?x?x?xf32>) outs(%{{.+}} : memref<?x?x?xf32>)
+  linalg.tanh ins(%arg0 : memref<?x?x?xf32>) outs(%arg1: memref<?x?x?xf32>)
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func @tanh_static
+func.func @tanh_static(%arg0: memref<4x8x16xf32>, %arg1: memref<4x8x16xf32>) {
+  // CHECK: linalg.tanh
+  // CHECK-SAME: ins(%{{.+}} : memref<4x8x16xf32>) outs(%{{.+}} : memref<4x8x16xf32>)
+  linalg.tanh ins(%arg0 : memref<4x8x16xf32>) outs(%arg1: memref<4x8x16xf32>)
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func @tanh_tensor
+func.func @tanh_tensor(%arg0: tensor<4x8x16xf32>) -> tensor<4x8x16xf32> {
+  %0 = tensor.empty() : tensor<4x8x16xf32>
+  // CHECK: linalg.tanh
+  // CHECK-SAME: ins(%{{.+}} : tensor<4x8x16xf32>) outs(%{{.+}} : tensor<4x8x16xf32>)
+  %1 = linalg.tanh ins(%arg0 : tensor<4x8x16xf32>) outs(%0: tensor<4x8x16xf32>) -> tensor<4x8x16xf32>
+  return %1 : tensor<4x8x16xf32>
+}
+
+// -----
+
 // CHECK-LABEL: func @max_dynamic
 func.func @max_dynamic(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?x?xf32>) {
   // CHECK: linalg.max



More information about the Mlir-commits mailing list