[Mlir-commits] [mlir] [MLIR][Linalg] More Linalg named ops (PR #90236)

Renato Golin llvmlistbot at llvm.org
Fri Apr 26 10:21:29 PDT 2024


https://github.com/rengolin updated https://github.com/llvm/llvm-project/pull/90236

>From 680124b093fc3139904a67d551a3f8c46a57d2df Mon Sep 17 00:00:00 2001
From: Renato Golin <rengolin at systemcall.eu>
Date: Fri, 26 Apr 2024 12:35:31 +0100
Subject: [PATCH 1/3] [MLIR][Linalg] Add (signed) min binary op

---
 .../Linalg/IR/LinalgNamedStructuredOps.yaml   | 51 ++++++++++++++++++-
 .../linalg/opdsl/ops/core_named_ops.py        | 19 +++++++
 .../Dialect/Linalg/generalize-named-ops.mlir  | 25 +++++++++
 mlir/test/Dialect/Linalg/named-ops-fail.mlir  | 16 ++++++
 mlir/test/Dialect/Linalg/named-ops.mlir       | 34 +++++++++++++
 5 files changed, 144 insertions(+), 1 deletion(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
index 1ff6c4086cf357..32312d3671c726 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
@@ -625,7 +625,7 @@ metadata: !LinalgOpMetadata
 
     This means reduction/broadcast/element cast semantics is explicit. Further
     passes can take that into account when lowering this code. For example,
-    a `linalg.broadcast` + `linalg.div` sequence can be lowered to a
+    a `linalg.broadcast` + `linalg.max` sequence can be lowered to a
     `linalg.generic` with different affine maps for the two operands.
 structured_op: !LinalgStructuredOpConfig
   args:
@@ -663,6 +663,55 @@ structured_op: !LinalgStructuredOpConfig
         - !ScalarExpression
           scalar_arg: rhs
 --- !LinalgOpConfig
+metadata: !LinalgOpMetadata
+  name: min
+  cpp_class_name: MinOp
+  doc: |-
+    Takes the min (signed) between two inputs, elementwise.
+
+    The shapes and element types must be identical. The appropriate casts,
+    broadcasts and reductions should be done previously to calling this op.
+
+    This means reduction/broadcast/element cast semantics is explicit. Further
+    passes can take that into account when lowering this code. For example,
+    a `linalg.broadcast` + `linalg.min` sequence can be lowered to a
+    `linalg.generic` with different affine maps for the two operands.
+structured_op: !LinalgStructuredOpConfig
+  args:
+  - !LinalgOperandDefConfig
+    name: lhs
+    kind: input_tensor
+    type_var: T1
+    shape_map: affine_map<() -> ()>
+  - !LinalgOperandDefConfig
+    name: rhs
+    kind: input_tensor
+    type_var: T1
+    shape_map: affine_map<() -> ()>
+  - !LinalgOperandDefConfig
+    name: O
+    kind: output_tensor
+    type_var: T1
+    shape_map: affine_map<() -> ()>
+  indexing_maps: !LinalgIndexingMapsConfig
+    static_indexing_maps:
+    - affine_map<() -> ()>
+    - affine_map<() -> ()>
+    - affine_map<() -> ()>
+  iterator_types: []
+  assignments:
+  - !ScalarAssign
+    arg: O
+    value: !ScalarExpression
+      scalar_fn:
+        kind: binary
+        fn_name: min_signed
+        operands:
+        - !ScalarExpression
+          scalar_arg: lhs
+        - !ScalarExpression
+          scalar_arg: rhs
+--- !LinalgOpConfig
 metadata: !LinalgOpMetadata
   name: matmul
   cpp_class_name: MatmulOp
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
index 5b05364f6d35f3..64f566ce5f8ba6 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
@@ -239,6 +239,25 @@ def max(
     O[None] = BinaryFn.max_signed(lhs[None], rhs[None])
 
 
+@linalg_structured_op
+def min(
+    lhs=TensorDef(T1),
+    rhs=TensorDef(T1),
+    O=TensorDef(T1, output=True),
+):
+    """Takes the min (signed) between two inputs, elementwise.
+
+    The shapes and element types must be identical. The appropriate casts,
+    broadcasts and reductions should be done previously to calling this op.
+
+    This means reduction/broadcast/element cast semantics is explicit. Further
+    passes can take that into account when lowering this code. For example,
+    a `linalg.broadcast` + `linalg.min` sequence can be lowered to a
+    `linalg.generic` with different affine maps for the two operands.
+    """
+    O[None] = BinaryFn.min_signed(lhs[None], rhs[None])
+
+
 @linalg_structured_op
 def matmul(
     A=TensorDef(T1, S.M, S.K),
diff --git a/mlir/test/Dialect/Linalg/generalize-named-ops.mlir b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
index e852824cdb7367..17264ac54ab158 100644
--- a/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
@@ -590,6 +590,31 @@ func.func @generalize_max(%lhs: memref<7x14x21xf32>, %rhs: memref<7x14x21xf32>,
 
 // -----
 
+func.func @generalize_min(%lhs: memref<7x14x21xf32>, %rhs: memref<7x14x21xf32>,
+                          %out: memref<7x14x21xf32>) {
+  linalg.min ins(%lhs, %rhs : memref<7x14x21xf32>, memref<7x14x21xf32>)
+             outs(%out : memref<7x14x21xf32>)
+  return
+}
+
+// CHECK: #[[MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+
+// CHECK: func @generalize_min
+// CHECK-SAME: (%[[LHS:.+]]: memref<7x14x21xf32>, %[[RHS:.+]]: memref<7x14x21xf32>,
+// CHECK-SAME:  %[[OUT:.+]]: memref<7x14x21xf32>)
+
+// CHECK: linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP]], #[[MAP]], #[[MAP]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]}
+// CHECK-SAME:  ins(%[[LHS]], %[[RHS]] : memref<7x14x21xf32>, memref<7x14x21xf32>)
+// CHECK-SAME: outs(%[[OUT]] : memref<7x14x21xf32>)
+
+// CHECK:         ^{{.+}}(%[[BBARG0:.+]]: f32, %[[BBARG1:.+]]: f32, %[[BBARG2:.+]]: f32)
+// CHECK-NEXT:      %[[min:.+]] = arith.minimumf %[[BBARG0]], %[[BBARG1]] : f32
+// CHECK-NEXT:      linalg.yield %[[min]] : f32
+
+// -----
+
 
 // CHECK-LABEL: func @fill_tensor
 func.func @fill_tensor(%f: f32, %v: vector<2x4xf32>) -> (tensor<f32>, tensor<vector<2x4xf32>>) {
diff --git a/mlir/test/Dialect/Linalg/named-ops-fail.mlir b/mlir/test/Dialect/Linalg/named-ops-fail.mlir
index c351e139a97e37..8f4c3e90b303ac 100644
--- a/mlir/test/Dialect/Linalg/named-ops-fail.mlir
+++ b/mlir/test/Dialect/Linalg/named-ops-fail.mlir
@@ -189,3 +189,19 @@ func.func @max_broadcast(%arg0: memref<8x16xf32>, %arg1: memref<4x8x16xf32>, %ar
   linalg.max ins(%arg0, %arg1 : memref<8x16xf32>, memref<4x8x16xf32>) outs(%arg2: memref<4x8x16xf32>)
   return
 }
+
+// -----
+
+func.func @min_type_cast(%arg0: memref<4x8x16xf32>, %arg1: memref<4x8x16xf16>, %arg2: memref<4x8x16xf32>) {
+  // CHECK: op requires the same type for all operands and results
+  linalg.min ins(%arg0, %arg1 : memref<4x8x16xf32>, memref<4x8x16xf16>) outs(%arg2: memref<4x8x16xf32>)
+  return
+}
+
+// -----
+
+func.func @min_broadcast(%arg0: memref<8x16xf32>, %arg1: memref<4x8x16xf32>, %arg2: memref<4x8x16xf32>) {
+  // CHECK: op expected operand rank (2) to match the result rank of indexing_map #0 (3)
+  linalg.min ins(%arg0, %arg1 : memref<8x16xf32>, memref<4x8x16xf32>) outs(%arg2: memref<4x8x16xf32>)
+  return
+}
diff --git a/mlir/test/Dialect/Linalg/named-ops.mlir b/mlir/test/Dialect/Linalg/named-ops.mlir
index 7064e1b3f9dc76..c7149e1a00d779 100644
--- a/mlir/test/Dialect/Linalg/named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/named-ops.mlir
@@ -1631,6 +1631,40 @@ func.func @max_tensor(%arg0: tensor<4x8x16xf32>, %arg1: tensor<4x8x16xf32>) -> t
 
 // -----
 
+// CHECK-LABEL: func @min_dynamic
+func.func @min_dynamic(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?x?xf32>) {
+  // CHECK: linalg.min
+  // CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<?x?x?xf32>, memref<?x?x?xf32>)
+  // CHECK-SAME: outs(%{{.+}} : memref<?x?x?xf32>)
+  linalg.min ins(%arg0, %arg1 : memref<?x?x?xf32>, memref<?x?x?xf32>) outs(%arg2: memref<?x?x?xf32>)
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func @min_static
+func.func @min_static(%arg0: memref<4x8x16xf32>, %arg1: memref<4x8x16xf32>, %arg2: memref<4x8x16xf32>) {
+  // CHECK: linalg.min
+  // CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<4x8x16xf32>, memref<4x8x16xf32>)
+  // CHECK-SAME: outs(%{{.+}} : memref<4x8x16xf32>)
+  linalg.min ins(%arg0, %arg1 : memref<4x8x16xf32>, memref<4x8x16xf32>) outs(%arg2: memref<4x8x16xf32>)
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func @min_tensor
+func.func @min_tensor(%arg0: tensor<4x8x16xf32>, %arg1: tensor<4x8x16xf32>) -> tensor<4x8x16xf32> {
+  %0 = tensor.empty() : tensor<4x8x16xf32>
+  // CHECK: linalg.min
+  // CHECK-SAME: ins(%{{.+}}, %{{.+}} : tensor<4x8x16xf32>, tensor<4x8x16xf32>)
+  // CHECK-SAME: outs(%{{.+}} : tensor<4x8x16xf32>)
+  %1 = linalg.min ins(%arg0, %arg1 : tensor<4x8x16xf32>, tensor<4x8x16xf32>) outs(%0: tensor<4x8x16xf32>) -> tensor<4x8x16xf32>
+  return %1 : tensor<4x8x16xf32>
+}
+
+// -----
+
 // CHECK-LABEL: func @fill_tensor
 func.func @fill_tensor(%f: f32, %v: vector<2x4xf32>) -> (tensor<f32>, tensor<vector<2x4xf32>>) {
   %e0 = tensor.empty() : tensor<f32>

>From aba65e39b0b0fa84ef34023a1b2a9d934762c1a7 Mon Sep 17 00:00:00 2001
From: Renato Golin <rengolin at systemcall.eu>
Date: Fri, 26 Apr 2024 13:47:03 +0100
Subject: [PATCH 2/3] [MLIR][Linalg] Adding more unary named ops

---
 .../mlir/Dialect/Linalg/IR/LinalgEnums.td     |   8 +-
 .../Linalg/IR/LinalgNamedStructuredOps.yaml   | 210 ++++++++++++++++++
 mlir/include/mlir/IR/Builders.h               |   6 +-
 mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp      |  20 ++
 mlir/lib/IR/Builders.cpp                      |  28 +--
 .../linalg/opdsl/lang/comprehension.py        |   5 +
 .../linalg/opdsl/ops/core_named_ops.py        |  60 +++++
 .../Dialect/Linalg/generalize-named-ops.mlir  | 130 +++++++++++
 mlir/test/Dialect/Linalg/named-ops-fail.mlir  |  96 ++++++++
 mlir/test/Dialect/Linalg/named-ops.mlir       | 186 ++++++++++++++++
 10 files changed, 728 insertions(+), 21 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgEnums.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgEnums.td
index 59f909aed8f61a..7a350d2c014262 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgEnums.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgEnums.td
@@ -22,7 +22,13 @@ def UnaryFn : I32EnumAttr<"UnaryFn", "", [
   I32EnumAttrCase<"abs", 2>,
   I32EnumAttrCase<"ceil", 3>,
   I32EnumAttrCase<"floor", 4>,
-  I32EnumAttrCase<"negf", 5>
+  I32EnumAttrCase<"negf", 5>,
+  I32EnumAttrCase<"reciprocal", 6>,
+  I32EnumAttrCase<"round", 7>,
+  I32EnumAttrCase<"sqrt", 8>,
+  I32EnumAttrCase<"rsqrt", 9>,
+  I32EnumAttrCase<"square", 10>,
+  I32EnumAttrCase<"tanh", 11>
 ]> {
   let genSpecializedAttr = 0;
   let cppNamespace = "::mlir::linalg";
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
index 32312d3671c726..b7567577347587 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
@@ -304,6 +304,216 @@ structured_op: !LinalgStructuredOpConfig
         - !ScalarExpression
           scalar_arg: I
 --- !LinalgOpConfig
+metadata: !LinalgOpMetadata
+  name: reciprocal
+  cpp_class_name: ReciprocalOp
+  doc: |-
+    Applies reciprocal(x) elementwise.
+
+    No numeric casting is performed on the input operand.
+structured_op: !LinalgStructuredOpConfig
+  args:
+  - !LinalgOperandDefConfig
+    name: I
+    kind: input_tensor
+    type_var: T1
+    shape_map: affine_map<() -> ()>
+  - !LinalgOperandDefConfig
+    name: O
+    kind: output_tensor
+    type_var: T1
+    shape_map: affine_map<() -> ()>
+  indexing_maps: !LinalgIndexingMapsConfig
+    static_indexing_maps:
+    - affine_map<() -> ()>
+    - affine_map<() -> ()>
+  iterator_types: []
+  assignments:
+  - !ScalarAssign
+    arg: O
+    value: !ScalarExpression
+      scalar_fn:
+        kind: unary
+        fn_name: reciprocal
+        operands:
+        - !ScalarExpression
+          scalar_arg: I
+--- !LinalgOpConfig
+metadata: !LinalgOpMetadata
+  name: round
+  cpp_class_name: RoundOp
+  doc: |-
+    Applies round(x) elementwise.
+
+    No numeric casting is performed on the input operand.
+structured_op: !LinalgStructuredOpConfig
+  args:
+  - !LinalgOperandDefConfig
+    name: I
+    kind: input_tensor
+    type_var: T1
+    shape_map: affine_map<() -> ()>
+  - !LinalgOperandDefConfig
+    name: O
+    kind: output_tensor
+    type_var: T1
+    shape_map: affine_map<() -> ()>
+  indexing_maps: !LinalgIndexingMapsConfig
+    static_indexing_maps:
+    - affine_map<() -> ()>
+    - affine_map<() -> ()>
+  iterator_types: []
+  assignments:
+  - !ScalarAssign
+    arg: O
+    value: !ScalarExpression
+      scalar_fn:
+        kind: unary
+        fn_name: round
+        operands:
+        - !ScalarExpression
+          scalar_arg: I
+--- !LinalgOpConfig
+metadata: !LinalgOpMetadata
+  name: sqrt
+  cpp_class_name: SqrtOp
+  doc: |-
+    Applies sqrt(x) elementwise.
+
+    No numeric casting is performed on the input operand.
+structured_op: !LinalgStructuredOpConfig
+  args:
+  - !LinalgOperandDefConfig
+    name: I
+    kind: input_tensor
+    type_var: T1
+    shape_map: affine_map<() -> ()>
+  - !LinalgOperandDefConfig
+    name: O
+    kind: output_tensor
+    type_var: T1
+    shape_map: affine_map<() -> ()>
+  indexing_maps: !LinalgIndexingMapsConfig
+    static_indexing_maps:
+    - affine_map<() -> ()>
+    - affine_map<() -> ()>
+  iterator_types: []
+  assignments:
+  - !ScalarAssign
+    arg: O
+    value: !ScalarExpression
+      scalar_fn:
+        kind: unary
+        fn_name: sqrt
+        operands:
+        - !ScalarExpression
+          scalar_arg: I
+--- !LinalgOpConfig
+metadata: !LinalgOpMetadata
+  name: rsqrt
+  cpp_class_name: RsqrtOp
+  doc: |-
+    Applies rsqrt(x) elementwise.
+
+    No numeric casting is performed on the input operand.
+structured_op: !LinalgStructuredOpConfig
+  args:
+  - !LinalgOperandDefConfig
+    name: I
+    kind: input_tensor
+    type_var: T1
+    shape_map: affine_map<() -> ()>
+  - !LinalgOperandDefConfig
+    name: O
+    kind: output_tensor
+    type_var: T1
+    shape_map: affine_map<() -> ()>
+  indexing_maps: !LinalgIndexingMapsConfig
+    static_indexing_maps:
+    - affine_map<() -> ()>
+    - affine_map<() -> ()>
+  iterator_types: []
+  assignments:
+  - !ScalarAssign
+    arg: O
+    value: !ScalarExpression
+      scalar_fn:
+        kind: unary
+        fn_name: rsqrt
+        operands:
+        - !ScalarExpression
+          scalar_arg: I
+--- !LinalgOpConfig
+metadata: !LinalgOpMetadata
+  name: square
+  cpp_class_name: SquareOp
+  doc: |-
+    Applies square(x) elementwise.
+
+    No numeric casting is performed on the input operand.
+structured_op: !LinalgStructuredOpConfig
+  args:
+  - !LinalgOperandDefConfig
+    name: I
+    kind: input_tensor
+    type_var: T1
+    shape_map: affine_map<() -> ()>
+  - !LinalgOperandDefConfig
+    name: O
+    kind: output_tensor
+    type_var: T1
+    shape_map: affine_map<() -> ()>
+  indexing_maps: !LinalgIndexingMapsConfig
+    static_indexing_maps:
+    - affine_map<() -> ()>
+    - affine_map<() -> ()>
+  iterator_types: []
+  assignments:
+  - !ScalarAssign
+    arg: O
+    value: !ScalarExpression
+      scalar_fn:
+        kind: unary
+        fn_name: square
+        operands:
+        - !ScalarExpression
+          scalar_arg: I
+--- !LinalgOpConfig
+metadata: !LinalgOpMetadata
+  name: tanh
+  cpp_class_name: TanhOp
+  doc: |-
+    Applies tanh(x) elementwise.
+
+    No numeric casting is performed on the input operand.
+structured_op: !LinalgStructuredOpConfig
+  args:
+  - !LinalgOperandDefConfig
+    name: I
+    kind: input_tensor
+    type_var: T1
+    shape_map: affine_map<() -> ()>
+  - !LinalgOperandDefConfig
+    name: O
+    kind: output_tensor
+    type_var: T1
+    shape_map: affine_map<() -> ()>
+  indexing_maps: !LinalgIndexingMapsConfig
+    static_indexing_maps:
+    - affine_map<() -> ()>
+    - affine_map<() -> ()>
+  iterator_types: []
+  assignments:
+  - !ScalarAssign
+    arg: O
+    value: !ScalarExpression
+      scalar_fn:
+        kind: unary
+        fn_name: tanh
+        operands:
+        - !ScalarExpression
+          scalar_arg: I
+--- !LinalgOpConfig
 metadata: !LinalgOpMetadata
   name: elemwise_binary
   cpp_class_name: ElemwiseBinaryOp
diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h
index 0d5fa719d0dee2..308c2cd38196cb 100644
--- a/mlir/include/mlir/IR/Builders.h
+++ b/mlir/include/mlir/IR/Builders.h
@@ -115,12 +115,16 @@ class Builder {
   ArrayAttr getArrayAttr(ArrayRef<Attribute> value);
 
   // Returns a 0-valued attribute of the given `type`. This function only
-  // supports boolean, integer, and 16-/32-/64-bit float types, and vector or
+  // supports integer and 16-/32-/64-bit float types, and vector or
   // ranked tensor of them. Returns null attribute otherwise.
   TypedAttr getZeroAttr(Type type);
   // Returns a 1-valued attribute of the given `type`.
   // Type constraints are the same as `getZeroAttr`.
   TypedAttr getOneAttr(Type type);
+  // Returns a numeric attribute of the given `type`.
+  // Type constraints are the same as `getZeroAttr`.
+  // For non-float types, `value` is first converted to an integer.
+  TypedAttr getNumberAttr(double value, Type type);
 
   // Convenience methods for fixed types.
   FloatAttr getF16FloatAttr(float value);
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 9c5c58fa1fabfb..6ee5864656cf64 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -395,6 +395,26 @@ class RegionBuilderHelper {
       return builder.create<math::FloorOp>(arg.getLoc(), arg);
     case UnaryFn::negf:
       return builder.create<arith::NegFOp>(arg.getLoc(), arg);
+    case UnaryFn::reciprocal:
+    {
+      Attribute oneAttr = builder.getNumberAttr(1.0, arg.getType());
+      auto one = builder.create<arith::ConstantOp>(arg.getLoc(), ::cast<TypedAttr>(oneAttr));
+      return builder.create<arith::DivFOp>(arg.getLoc(), one, arg);
+    }
+    case UnaryFn::round:
+      return builder.create<math::RoundOp>(arg.getLoc(), arg);
+    case UnaryFn::sqrt:
+      return builder.create<math::SqrtOp>(arg.getLoc(), arg);
+    case UnaryFn::rsqrt:
+      return builder.create<math::RsqrtOp>(arg.getLoc(), arg);
+    case UnaryFn::square:
+    {
+      Attribute twoAttr = builder.getNumberAttr(2.0, arg.getType());
+      auto two = builder.create<arith::ConstantOp>(arg.getLoc(), ::cast<TypedAttr>(twoAttr));
+      return builder.create<math::PowFOp>(arg.getLoc(), arg, two);
+    }
+    case UnaryFn::tanh:
+      return builder.create<math::TanhOp>(arg.getLoc(), arg);
     }
     llvm_unreachable("unsupported unary function");
   }
diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp
index d49f69a7b7ae6b..ae8348c8454e9e 100644
--- a/mlir/lib/IR/Builders.cpp
+++ b/mlir/lib/IR/Builders.cpp
@@ -329,34 +329,24 @@ ArrayAttr Builder::getAffineMapArrayAttr(ArrayRef<AffineMap> values) {
 }
 
 TypedAttr Builder::getZeroAttr(Type type) {
-  if (llvm::isa<FloatType>(type))
-    return getFloatAttr(type, 0.0);
-  if (llvm::isa<IndexType>(type))
-    return getIndexAttr(0);
-  if (llvm::dyn_cast<IntegerType>(type))
-    return getIntegerAttr(type,
-                          APInt(llvm::cast<IntegerType>(type).getWidth(), 0));
-  if (llvm::isa<RankedTensorType, VectorType>(type)) {
-    auto vtType = llvm::cast<ShapedType>(type);
-    auto element = getZeroAttr(vtType.getElementType());
-    if (!element)
-      return {};
-    return DenseElementsAttr::get(vtType, element);
-  }
-  return {};
+  return getNumberAttr(0.0, type);
 }
 
 TypedAttr Builder::getOneAttr(Type type) {
+  return getNumberAttr(1.0, type);
+}
+
+TypedAttr Builder::getNumberAttr(double value, Type type) {
   if (llvm::isa<FloatType>(type))
-    return getFloatAttr(type, 1.0);
+    return getFloatAttr(type, value);
   if (llvm::isa<IndexType>(type))
-    return getIndexAttr(1);
+    return getIndexAttr(static_cast<int64_t>(value));
   if (llvm::dyn_cast<IntegerType>(type))
     return getIntegerAttr(type,
-                          APInt(llvm::cast<IntegerType>(type).getWidth(), 1));
+                          APInt(llvm::cast<IntegerType>(type).getWidth(), static_cast<int64_t>(value)));
   if (llvm::isa<RankedTensorType, VectorType>(type)) {
     auto vtType = llvm::cast<ShapedType>(type);
-    auto element = getOneAttr(vtType.getElementType());
+    auto element = getNumberAttr(value, vtType.getElementType());
     if (!element)
       return {};
     return DenseElementsAttr::get(vtType, element);
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py
index 23d6d26b7e294c..f7bc81bd2f6833 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py
@@ -291,6 +291,11 @@ class UnaryFn:
     ceil = UnaryFnType("ceil")
     floor = UnaryFnType("floor")
     negf = UnaryFnType("negf")
+    round = UnaryFnType("round")
+    sqrt = UnaryFnType("sqrt")
+    rsqrt = UnaryFnType("rsqrt")
+    square = UnaryFnType("square")
+    tanh = UnaryFnType("tanh")
 
 
 class BinaryFnType:
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
index 64f566ce5f8ba6..c97ac448006505 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
@@ -108,6 +108,66 @@ def negf(
     O[None] = UnaryFn.negf(I[None])
 
 
+@linalg_structured_op
+def round(
+    I=TensorDef(T1),
+    O=TensorDef(T1, output=True),
+):
+    """Applies round(x) elementwise.
+
+    No numeric casting is performed on the input operand.
+    """
+    O[None] = UnaryFn.round(I[None])
+
+
+@linalg_structured_op
+def sqrt(
+    I=TensorDef(T1),
+    O=TensorDef(T1, output=True),
+):
+    """Applies sqrt(x) elementwise.
+
+    No numeric casting is performed on the input operand.
+    """
+    O[None] = UnaryFn.sqrt(I[None])
+
+
+@linalg_structured_op
+def rsqrt(
+    I=TensorDef(T1),
+    O=TensorDef(T1, output=True),
+):
+    """Applies rsqrt(x) elementwise.
+
+    No numeric casting is performed on the input operand.
+    """
+    O[None] = UnaryFn.rsqrt(I[None])
+
+
+@linalg_structured_op
+def square(
+    I=TensorDef(T1),
+    O=TensorDef(T1, output=True),
+):
+    """Applies square(x) elementwise.
+
+    No numeric casting is performed on the input operand.
+    """
+    O[None] = UnaryFn.square(I[None])
+
+
+@linalg_structured_op
+def tanh(
+    I=TensorDef(T1),
+    O=TensorDef(T1, output=True),
+):
+    """Applies tanh(x) elementwise.
+
+    No numeric casting is performed on the input operand.
+    """
+    O[None] = UnaryFn.tanh(I[None])
+
+
 @linalg_structured_op
 def elemwise_binary(
     lhs=TensorDef(T1),
diff --git a/mlir/test/Dialect/Linalg/generalize-named-ops.mlir b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
index 17264ac54ab158..160d1e9536bb5b 100644
--- a/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
@@ -565,6 +565,136 @@ func.func @generalize_negf(%arg: memref<7x14x21xf32>, %out: memref<7x14x21xf32>)
 
 // -----
 
+func.func @generalize_reciprocal(%arg: memref<7x14x21xf32>, %out: memref<7x14x21xf32>) {
+  linalg.reciprocal ins(%arg : memref<7x14x21xf32>) outs(%out : memref<7x14x21xf32>)
+  return
+}
+
+// CHECK: #[[MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+
+// CHECK: func @generalize_reciprocal
+// CHECK-SAME: (%[[ARG:.+]]: memref<7x14x21xf32>, %[[OUT:.+]]: memref<7x14x21xf32>)
+
+// CHECK: %[[one:.+]] = arith.constant 1.000000e+00 : f32
+
+// CHECK: linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP]], #[[MAP]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]}
+// CHECK-SAME:  ins(%[[LHS]] : memref<7x14x21xf32>) outs(%[[OUT]] : memref<7x14x21xf32>)
+
+// CHECK:         ^{{.+}}(%[[BBARG0:.+]]: f32, %[[BBARG1:.+]]: f32)
+// CHECK-NEXT:      %[[reciprocal:.+]] = arith.divf %[[one]], %[[BBARG0]] : f32
+// CHECK-NEXT:      linalg.yield %[[reciprocal]] : f32
+
+// -----
+
+func.func @generalize_round(%arg: memref<7x14x21xf32>, %out: memref<7x14x21xf32>) {
+  linalg.round ins(%arg : memref<7x14x21xf32>) outs(%out : memref<7x14x21xf32>)
+  return
+}
+
+// CHECK: #[[MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+
+// CHECK: func @generalize_round
+// CHECK-SAME: (%[[ARG:.+]]: memref<7x14x21xf32>, %[[OUT:.+]]: memref<7x14x21xf32>)
+
+// CHECK: linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP]], #[[MAP]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]}
+// CHECK-SAME:  ins(%[[LHS]] : memref<7x14x21xf32>) outs(%[[OUT]] : memref<7x14x21xf32>)
+
+// CHECK:         ^{{.+}}(%[[BBARG0:.+]]: f32, %[[BBARG1:.+]]: f32)
+// CHECK-NEXT:      %[[round:.+]] = math.round %[[BBARG0]] : f32
+// CHECK-NEXT:      linalg.yield %[[round]] : f32
+
+// -----
+
+func.func @generalize_sqrt(%arg: memref<7x14x21xf32>, %out: memref<7x14x21xf32>) {
+  linalg.sqrt ins(%arg : memref<7x14x21xf32>) outs(%out : memref<7x14x21xf32>)
+  return
+}
+
+// CHECK: #[[MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+
+// CHECK: func @generalize_sqrt
+// CHECK-SAME: (%[[ARG:.+]]: memref<7x14x21xf32>, %[[OUT:.+]]: memref<7x14x21xf32>)
+
+// CHECK: linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP]], #[[MAP]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]}
+// CHECK-SAME:  ins(%[[LHS]] : memref<7x14x21xf32>) outs(%[[OUT]] : memref<7x14x21xf32>)
+
+// CHECK:         ^{{.+}}(%[[BBARG0:.+]]: f32, %[[BBARG1:.+]]: f32)
+// CHECK-NEXT:      %[[sqrt:.+]] = math.sqrt %[[BBARG0]] : f32
+// CHECK-NEXT:      linalg.yield %[[sqrt]] : f32
+
+// -----
+
+func.func @generalize_rsqrt(%arg: memref<7x14x21xf32>, %out: memref<7x14x21xf32>) {
+  linalg.rsqrt ins(%arg : memref<7x14x21xf32>) outs(%out : memref<7x14x21xf32>)
+  return
+}
+
+// CHECK: #[[MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+
+// CHECK: func @generalize_rsqrt
+// CHECK-SAME: (%[[ARG:.+]]: memref<7x14x21xf32>, %[[OUT:.+]]: memref<7x14x21xf32>)
+
+// CHECK: linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP]], #[[MAP]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]}
+// CHECK-SAME:  ins(%[[LHS]] : memref<7x14x21xf32>) outs(%[[OUT]] : memref<7x14x21xf32>)
+
+// CHECK:         ^{{.+}}(%[[BBARG0:.+]]: f32, %[[BBARG1:.+]]: f32)
+// CHECK-NEXT:      %[[rsqrt:.+]] = math.rsqrt %[[BBARG0]] : f32
+// CHECK-NEXT:      linalg.yield %[[rsqrt]] : f32
+
+// -----
+
+func.func @generalize_square(%arg: memref<7x14x21xf32>, %out: memref<7x14x21xf32>) {
+  linalg.square ins(%arg : memref<7x14x21xf32>) outs(%out : memref<7x14x21xf32>)
+  return
+}
+
+// CHECK: #[[MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+
+// CHECK: func @generalize_square
+// CHECK-SAME: (%[[ARG:.+]]: memref<7x14x21xf32>, %[[OUT:.+]]: memref<7x14x21xf32>)
+
+// CHECK: %[[two:.+]] = arith.constant 2.000000e+00 : f32
+
+// CHECK: linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP]], #[[MAP]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]}
+// CHECK-SAME:  ins(%[[LHS]] : memref<7x14x21xf32>) outs(%[[OUT]] : memref<7x14x21xf32>)
+
+// CHECK:         ^{{.+}}(%[[BBARG0:.+]]: f32, %[[BBARG1:.+]]: f32)
+// CHECK-NEXT:      %[[square:.+]] = math.powf %[[BBARG0]], %[[two]] : f32
+// CHECK-NEXT:      linalg.yield %[[square]] : f32
+
+// -----
+
+func.func @generalize_tanh(%arg: memref<7x14x21xf32>, %out: memref<7x14x21xf32>) {
+  linalg.tanh ins(%arg : memref<7x14x21xf32>) outs(%out : memref<7x14x21xf32>)
+  return
+}
+
+// CHECK: #[[MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+
+// CHECK: func @generalize_tanh
+// CHECK-SAME: (%[[ARG:.+]]: memref<7x14x21xf32>, %[[OUT:.+]]: memref<7x14x21xf32>)
+
+// CHECK: linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP]], #[[MAP]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]}
+// CHECK-SAME:  ins(%[[LHS]] : memref<7x14x21xf32>) outs(%[[OUT]] : memref<7x14x21xf32>)
+
+// CHECK:         ^{{.+}}(%[[BBARG0:.+]]: f32, %[[BBARG1:.+]]: f32)
+// CHECK-NEXT:      %[[tanh:.+]] = math.tanh %[[BBARG0]] : f32
+// CHECK-NEXT:      linalg.yield %[[tanh]] : f32
+
+// -----
+
 func.func @generalize_max(%lhs: memref<7x14x21xf32>, %rhs: memref<7x14x21xf32>,
                           %out: memref<7x14x21xf32>) {
   linalg.max ins(%lhs, %rhs : memref<7x14x21xf32>, memref<7x14x21xf32>)
diff --git a/mlir/test/Dialect/Linalg/named-ops-fail.mlir b/mlir/test/Dialect/Linalg/named-ops-fail.mlir
index 8f4c3e90b303ac..f66608e71ffc64 100644
--- a/mlir/test/Dialect/Linalg/named-ops-fail.mlir
+++ b/mlir/test/Dialect/Linalg/named-ops-fail.mlir
@@ -176,6 +176,102 @@ func.func @negf_broadcast(%arg: memref<8x16xf32>, %out: memref<4x8x16xf32>) {
 
 // -----
 
+func.func @reciprocal_type_cast(%arg: memref<4x8x16xf16>, %out: memref<4x8x16xf32>) {
+  // CHECK: operand 1 ('f16') doesn't match the element type of the enclosing linalg.generic op ('f32')
+  linalg.reciprocal ins(%arg : memref<4x8x16xf16>) outs(%out: memref<4x8x16xf32>)
+  return
+}
+
+// -----
+
+func.func @reciprocal_broadcast(%arg: memref<8x16xf32>, %out: memref<4x8x16xf32>) {
+  // CHECK: op expected operand rank (2) to match the result rank of indexing_map #0 (3)
+  linalg.reciprocal ins(%arg : memref<8x16xf32>) outs(%out: memref<4x8x16xf32>)
+  return
+}
+
+// -----
+
+func.func @round_type_cast(%arg: memref<4x8x16xf16>, %out: memref<4x8x16xf32>) {
+  // CHECK: operand 1 ('f16') doesn't match the element type of the enclosing linalg.generic op ('f32')
+  linalg.round ins(%arg : memref<4x8x16xf16>) outs(%out: memref<4x8x16xf32>)
+  return
+}
+
+// -----
+
+func.func @round_broadcast(%arg: memref<8x16xf32>, %out: memref<4x8x16xf32>) {
+  // CHECK: op expected operand rank (2) to match the result rank of indexing_map #0 (3)
+  linalg.round ins(%arg : memref<8x16xf32>) outs(%out: memref<4x8x16xf32>)
+  return
+}
+
+// -----
+
+func.func @sqrt_type_cast(%arg: memref<4x8x16xf16>, %out: memref<4x8x16xf32>) {
+  // CHECK: operand 1 ('f16') doesn't match the element type of the enclosing linalg.generic op ('f32')
+  linalg.sqrt ins(%arg : memref<4x8x16xf16>) outs(%out: memref<4x8x16xf32>)
+  return
+}
+
+// -----
+
+func.func @sqrt_broadcast(%arg: memref<8x16xf32>, %out: memref<4x8x16xf32>) {
+  // CHECK: op expected operand rank (2) to match the result rank of indexing_map #0 (3)
+  linalg.sqrt ins(%arg : memref<8x16xf32>) outs(%out: memref<4x8x16xf32>)
+  return
+}
+
+// -----
+
+func.func @rsqrt_type_cast(%arg: memref<4x8x16xf16>, %out: memref<4x8x16xf32>) {
+  // CHECK: operand 1 ('f16') doesn't match the element type of the enclosing linalg.generic op ('f32')
+  linalg.rsqrt ins(%arg : memref<4x8x16xf16>) outs(%out: memref<4x8x16xf32>)
+  return
+}
+
+// -----
+
+func.func @rsqrt_broadcast(%arg: memref<8x16xf32>, %out: memref<4x8x16xf32>) {
+  // CHECK: op expected operand rank (2) to match the result rank of indexing_map #0 (3)
+  linalg.rsqrt ins(%arg : memref<8x16xf32>) outs(%out: memref<4x8x16xf32>)
+  return
+}
+
+// -----
+
+func.func @square_type_cast(%arg: memref<4x8x16xf16>, %out: memref<4x8x16xf32>) {
+  // CHECK: operand 1 ('f16') doesn't match the element type of the enclosing linalg.generic op ('f32')
+  linalg.square ins(%arg : memref<4x8x16xf16>) outs(%out: memref<4x8x16xf32>)
+  return
+}
+
+// -----
+
+func.func @square_broadcast(%arg: memref<8x16xf32>, %out: memref<4x8x16xf32>) {
+  // CHECK: op expected operand rank (2) to match the result rank of indexing_map #0 (3)
+  linalg.square ins(%arg : memref<8x16xf32>) outs(%out: memref<4x8x16xf32>)
+  return
+}
+
+// -----
+
+func.func @tanh_type_cast(%arg: memref<4x8x16xf16>, %out: memref<4x8x16xf32>) {
+  // CHECK: operand 1 ('f16') doesn't match the element type of the enclosing linalg.generic op ('f32')
+  linalg.tanh ins(%arg : memref<4x8x16xf16>) outs(%out: memref<4x8x16xf32>)
+  return
+}
+
+// -----
+
+func.func @tanh_broadcast(%arg: memref<8x16xf32>, %out: memref<4x8x16xf32>) {
+  // CHECK: op expected operand rank (2) to match the result rank of indexing_map #0 (3)
+  linalg.tanh ins(%arg : memref<8x16xf32>) outs(%out: memref<4x8x16xf32>)
+  return
+}
+
+// -----
+
 func.func @max_type_cast(%arg0: memref<4x8x16xf32>, %arg1: memref<4x8x16xf16>, %arg2: memref<4x8x16xf32>) {
   // CHECK: op requires the same type for all operands and results
   linalg.max ins(%arg0, %arg1 : memref<4x8x16xf32>, memref<4x8x16xf16>) outs(%arg2: memref<4x8x16xf32>)
diff --git a/mlir/test/Dialect/Linalg/named-ops.mlir b/mlir/test/Dialect/Linalg/named-ops.mlir
index c7149e1a00d779..cf59f673610013 100644
--- a/mlir/test/Dialect/Linalg/named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/named-ops.mlir
@@ -1597,6 +1597,192 @@ func.func @negf_tensor(%arg0: tensor<4x8x16xf32>) -> tensor<4x8x16xf32> {
 
 // -----
 
+// CHECK-LABEL: func @reciprocal_dynamic
+func.func @reciprocal_dynamic(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?x?xf32>) {
+  // CHECK: linalg.reciprocal
+  // CHECK-SAME: ins(%{{.+}} : memref<?x?x?xf32>) outs(%{{.+}} : memref<?x?x?xf32>)
+  linalg.reciprocal ins(%arg0 : memref<?x?x?xf32>) outs(%arg1: memref<?x?x?xf32>)
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func @reciprocal_static
+func.func @reciprocal_static(%arg0: memref<4x8x16xf32>, %arg1: memref<4x8x16xf32>) {
+  // CHECK: linalg.reciprocal
+  // CHECK-SAME: ins(%{{.+}} : memref<4x8x16xf32>) outs(%{{.+}} : memref<4x8x16xf32>)
+  linalg.reciprocal ins(%arg0 : memref<4x8x16xf32>) outs(%arg1: memref<4x8x16xf32>)
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func @reciprocal_tensor
+func.func @reciprocal_tensor(%arg0: tensor<4x8x16xf32>) -> tensor<4x8x16xf32> {
+  %0 = tensor.empty() : tensor<4x8x16xf32>
+  // CHECK: linalg.reciprocal
+  // CHECK-SAME: ins(%{{.+}} : tensor<4x8x16xf32>) outs(%{{.+}} : tensor<4x8x16xf32>)
+  %1 = linalg.reciprocal ins(%arg0 : tensor<4x8x16xf32>) outs(%0: tensor<4x8x16xf32>) -> tensor<4x8x16xf32>
+  return %1 : tensor<4x8x16xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @round_dynamic
+func.func @round_dynamic(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?x?xf32>) {
+  // CHECK: linalg.round
+  // CHECK-SAME: ins(%{{.+}} : memref<?x?x?xf32>) outs(%{{.+}} : memref<?x?x?xf32>)
+  linalg.round ins(%arg0 : memref<?x?x?xf32>) outs(%arg1: memref<?x?x?xf32>)
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func @round_static
+func.func @round_static(%arg0: memref<4x8x16xf32>, %arg1: memref<4x8x16xf32>) {
+  // CHECK: linalg.round
+  // CHECK-SAME: ins(%{{.+}} : memref<4x8x16xf32>) outs(%{{.+}} : memref<4x8x16xf32>)
+  linalg.round ins(%arg0 : memref<4x8x16xf32>) outs(%arg1: memref<4x8x16xf32>)
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func @round_tensor
+func.func @round_tensor(%arg0: tensor<4x8x16xf32>) -> tensor<4x8x16xf32> {
+  %0 = tensor.empty() : tensor<4x8x16xf32>
+  // CHECK: linalg.round
+  // CHECK-SAME: ins(%{{.+}} : tensor<4x8x16xf32>) outs(%{{.+}} : tensor<4x8x16xf32>)
+  %1 = linalg.round ins(%arg0 : tensor<4x8x16xf32>) outs(%0: tensor<4x8x16xf32>) -> tensor<4x8x16xf32>
+  return %1 : tensor<4x8x16xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @sqrt_dynamic
+func.func @sqrt_dynamic(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?x?xf32>) {
+  // CHECK: linalg.sqrt
+  // CHECK-SAME: ins(%{{.+}} : memref<?x?x?xf32>) outs(%{{.+}} : memref<?x?x?xf32>)
+  linalg.sqrt ins(%arg0 : memref<?x?x?xf32>) outs(%arg1: memref<?x?x?xf32>)
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func @sqrt_static
+func.func @sqrt_static(%arg0: memref<4x8x16xf32>, %arg1: memref<4x8x16xf32>) {
+  // CHECK: linalg.sqrt
+  // CHECK-SAME: ins(%{{.+}} : memref<4x8x16xf32>) outs(%{{.+}} : memref<4x8x16xf32>)
+  linalg.sqrt ins(%arg0 : memref<4x8x16xf32>) outs(%arg1: memref<4x8x16xf32>)
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func @sqrt_tensor
+func.func @sqrt_tensor(%arg0: tensor<4x8x16xf32>) -> tensor<4x8x16xf32> {
+  %0 = tensor.empty() : tensor<4x8x16xf32>
+  // CHECK: linalg.sqrt
+  // CHECK-SAME: ins(%{{.+}} : tensor<4x8x16xf32>) outs(%{{.+}} : tensor<4x8x16xf32>)
+  %1 = linalg.sqrt ins(%arg0 : tensor<4x8x16xf32>) outs(%0: tensor<4x8x16xf32>) -> tensor<4x8x16xf32>
+  return %1 : tensor<4x8x16xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @rsqrt_dynamic
+func.func @rsqrt_dynamic(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?x?xf32>) {
+  // CHECK: linalg.rsqrt
+  // CHECK-SAME: ins(%{{.+}} : memref<?x?x?xf32>) outs(%{{.+}} : memref<?x?x?xf32>)
+  linalg.rsqrt ins(%arg0 : memref<?x?x?xf32>) outs(%arg1: memref<?x?x?xf32>)
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func @rsqrt_static
+func.func @rsqrt_static(%arg0: memref<4x8x16xf32>, %arg1: memref<4x8x16xf32>) {
+  // CHECK: linalg.rsqrt
+  // CHECK-SAME: ins(%{{.+}} : memref<4x8x16xf32>) outs(%{{.+}} : memref<4x8x16xf32>)
+  linalg.rsqrt ins(%arg0 : memref<4x8x16xf32>) outs(%arg1: memref<4x8x16xf32>)
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func @rsqrt_tensor
+func.func @rsqrt_tensor(%arg0: tensor<4x8x16xf32>) -> tensor<4x8x16xf32> {
+  %0 = tensor.empty() : tensor<4x8x16xf32>
+  // CHECK: linalg.rsqrt
+  // CHECK-SAME: ins(%{{.+}} : tensor<4x8x16xf32>) outs(%{{.+}} : tensor<4x8x16xf32>)
+  %1 = linalg.rsqrt ins(%arg0 : tensor<4x8x16xf32>) outs(%0: tensor<4x8x16xf32>) -> tensor<4x8x16xf32>
+  return %1 : tensor<4x8x16xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @square_dynamic
+func.func @square_dynamic(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?x?xf32>) {
+  // CHECK: linalg.square
+  // CHECK-SAME: ins(%{{.+}} : memref<?x?x?xf32>) outs(%{{.+}} : memref<?x?x?xf32>)
+  linalg.square ins(%arg0 : memref<?x?x?xf32>) outs(%arg1: memref<?x?x?xf32>)
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func @square_static
+func.func @square_static(%arg0: memref<4x8x16xf32>, %arg1: memref<4x8x16xf32>) {
+  // CHECK: linalg.square
+  // CHECK-SAME: ins(%{{.+}} : memref<4x8x16xf32>) outs(%{{.+}} : memref<4x8x16xf32>)
+  linalg.square ins(%arg0 : memref<4x8x16xf32>) outs(%arg1: memref<4x8x16xf32>)
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func @square_tensor
+func.func @square_tensor(%arg0: tensor<4x8x16xf32>) -> tensor<4x8x16xf32> {
+  %0 = tensor.empty() : tensor<4x8x16xf32>
+  // CHECK: linalg.square
+  // CHECK-SAME: ins(%{{.+}} : tensor<4x8x16xf32>) outs(%{{.+}} : tensor<4x8x16xf32>)
+  %1 = linalg.square ins(%arg0 : tensor<4x8x16xf32>) outs(%0: tensor<4x8x16xf32>) -> tensor<4x8x16xf32>
+  return %1 : tensor<4x8x16xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @tanh_dynamic
+func.func @tanh_dynamic(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?x?xf32>) {
+  // CHECK: linalg.tanh
+  // CHECK-SAME: ins(%{{.+}} : memref<?x?x?xf32>) outs(%{{.+}} : memref<?x?x?xf32>)
+  linalg.tanh ins(%arg0 : memref<?x?x?xf32>) outs(%arg1: memref<?x?x?xf32>)
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func @tanh_static
+func.func @tanh_static(%arg0: memref<4x8x16xf32>, %arg1: memref<4x8x16xf32>) {
+  // CHECK: linalg.tanh
+  // CHECK-SAME: ins(%{{.+}} : memref<4x8x16xf32>) outs(%{{.+}} : memref<4x8x16xf32>)
+  linalg.tanh ins(%arg0 : memref<4x8x16xf32>) outs(%arg1: memref<4x8x16xf32>)
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func @tanh_tensor
+func.func @tanh_tensor(%arg0: tensor<4x8x16xf32>) -> tensor<4x8x16xf32> {
+  %0 = tensor.empty() : tensor<4x8x16xf32>
+  // CHECK: linalg.tanh
+  // CHECK-SAME: ins(%{{.+}} : tensor<4x8x16xf32>) outs(%{{.+}} : tensor<4x8x16xf32>)
+  %1 = linalg.tanh ins(%arg0 : tensor<4x8x16xf32>) outs(%0: tensor<4x8x16xf32>) -> tensor<4x8x16xf32>
+  return %1 : tensor<4x8x16xf32>
+}
+
+// -----
+
 // CHECK-LABEL: func @max_dynamic
 func.func @max_dynamic(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?x?xf32>) {
   // CHECK: linalg.max

>From 1b178d1e313b281733c5707d70a9248a37c80fbd Mon Sep 17 00:00:00 2001
From: Renato Golin <rengolin at systemcall.eu>
Date: Fri, 26 Apr 2024 18:21:03 +0100
Subject: [PATCH 3/3] clang-format

---
 mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 12 ++++++------
 mlir/lib/IR/Builders.cpp                 | 12 ++++--------
 2 files changed, 10 insertions(+), 14 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 6ee5864656cf64..f648165727c86f 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -395,10 +395,10 @@ class RegionBuilderHelper {
       return builder.create<math::FloorOp>(arg.getLoc(), arg);
     case UnaryFn::negf:
       return builder.create<arith::NegFOp>(arg.getLoc(), arg);
-    case UnaryFn::reciprocal:
-    {
+    case UnaryFn::reciprocal: {
       Attribute oneAttr = builder.getNumberAttr(1.0, arg.getType());
-      auto one = builder.create<arith::ConstantOp>(arg.getLoc(), ::cast<TypedAttr>(oneAttr));
+      auto one = builder.create<arith::ConstantOp>(arg.getLoc(),
+                                                   ::cast<TypedAttr>(oneAttr));
       return builder.create<arith::DivFOp>(arg.getLoc(), one, arg);
     }
     case UnaryFn::round:
@@ -407,10 +407,10 @@ class RegionBuilderHelper {
       return builder.create<math::SqrtOp>(arg.getLoc(), arg);
     case UnaryFn::rsqrt:
       return builder.create<math::RsqrtOp>(arg.getLoc(), arg);
-    case UnaryFn::square:
-    {
+    case UnaryFn::square: {
       Attribute twoAttr = builder.getNumberAttr(2.0, arg.getType());
-      auto two = builder.create<arith::ConstantOp>(arg.getLoc(), ::cast<TypedAttr>(twoAttr));
+      auto two = builder.create<arith::ConstantOp>(arg.getLoc(),
+                                                   ::cast<TypedAttr>(twoAttr));
       return builder.create<math::PowFOp>(arg.getLoc(), arg, two);
     }
     case UnaryFn::tanh:
diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp
index ae8348c8454e9e..8c31da284c97bf 100644
--- a/mlir/lib/IR/Builders.cpp
+++ b/mlir/lib/IR/Builders.cpp
@@ -328,13 +328,9 @@ ArrayAttr Builder::getAffineMapArrayAttr(ArrayRef<AffineMap> values) {
   return getArrayAttr(attrs);
 }
 
-TypedAttr Builder::getZeroAttr(Type type) {
-  return getNumberAttr(0.0, type);
-}
+TypedAttr Builder::getZeroAttr(Type type) { return getNumberAttr(0.0, type); }
 
-TypedAttr Builder::getOneAttr(Type type) {
-  return getNumberAttr(1.0, type);
-}
+TypedAttr Builder::getOneAttr(Type type) { return getNumberAttr(1.0, type); }
 
 TypedAttr Builder::getNumberAttr(double value, Type type) {
   if (llvm::isa<FloatType>(type))
@@ -342,8 +338,8 @@ TypedAttr Builder::getNumberAttr(double value, Type type) {
   if (llvm::isa<IndexType>(type))
     return getIndexAttr(static_cast<int64_t>(value));
   if (llvm::dyn_cast<IntegerType>(type))
-    return getIntegerAttr(type,
-                          APInt(llvm::cast<IntegerType>(type).getWidth(), static_cast<int64_t>(value)));
+    return getIntegerAttr(type, APInt(llvm::cast<IntegerType>(type).getWidth(),
+                                      static_cast<int64_t>(value)));
   if (llvm::isa<RankedTensorType, VectorType>(type)) {
     auto vtType = llvm::cast<ShapedType>(type);
     auto element = getNumberAttr(value, vtType.getElementType());



More information about the Mlir-commits mailing list