[Mlir-commits] [mlir] [MLIR][Linalg] More Linalg named ops (PR #90236)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Apr 26 10:13:49 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-core

Author: Renato Golin (rengolin)

<details>
<summary>Changes</summary>

Adding `min` that was already implemented but not exposed.

Adding a few additional unary ops:
* Reciprocal as `arith.divf(1, arg)`
* Round as `math.round(arg)`
* Sqrt as `math.sqrt(arg)`
* Rsqrt as `math.rsqrt(arg)`
* Square as `math.powf(arg, 2)`
* TanH as `math.tanh(arg)`

All of these follow the semantics agreed at the round table: no implicit broadcasts or type casts.

There is also a small update to the builder: a new `getNumberAttr` helper returns a numeric attribute of a given type, and `getZeroAttr`/`getOneAttr` now use it.

This is just the first step. Soon after this we'll add explicit type casts, select, clamp, square difference, generic pow.

After that, we'll discuss the semantics of `linalg.softmax` and how it should be lowered to existing named ops.

---

Patch is 35.95 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/90236.diff


10 Files Affected:

- (modified) mlir/include/mlir/Dialect/Linalg/IR/LinalgEnums.td (+7-1) 
- (modified) mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml (+260-1) 
- (modified) mlir/include/mlir/IR/Builders.h (+5-1) 
- (modified) mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp (+20) 
- (modified) mlir/lib/IR/Builders.cpp (+9-19) 
- (modified) mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py (+5) 
- (modified) mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py (+79) 
- (modified) mlir/test/Dialect/Linalg/generalize-named-ops.mlir (+155) 
- (modified) mlir/test/Dialect/Linalg/named-ops-fail.mlir (+112) 
- (modified) mlir/test/Dialect/Linalg/named-ops.mlir (+220) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgEnums.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgEnums.td
index 59f909aed8f61a..7a350d2c014262 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgEnums.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgEnums.td
@@ -22,7 +22,13 @@ def UnaryFn : I32EnumAttr<"UnaryFn", "", [
   I32EnumAttrCase<"abs", 2>,
   I32EnumAttrCase<"ceil", 3>,
   I32EnumAttrCase<"floor", 4>,
-  I32EnumAttrCase<"negf", 5>
+  I32EnumAttrCase<"negf", 5>,
+  I32EnumAttrCase<"reciprocal", 6>,
+  I32EnumAttrCase<"round", 7>,
+  I32EnumAttrCase<"sqrt", 8>,
+  I32EnumAttrCase<"rsqrt", 9>,
+  I32EnumAttrCase<"square", 10>,
+  I32EnumAttrCase<"tanh", 11>
 ]> {
   let genSpecializedAttr = 0;
   let cppNamespace = "::mlir::linalg";
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
index 1ff6c4086cf357..b7567577347587 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
@@ -304,6 +304,216 @@ structured_op: !LinalgStructuredOpConfig
         - !ScalarExpression
           scalar_arg: I
 --- !LinalgOpConfig
+metadata: !LinalgOpMetadata
+  name: reciprocal
+  cpp_class_name: ReciprocalOp
+  doc: |-
+    Applies reciprocal(x) elementwise.
+
+    No numeric casting is performed on the input operand.
+structured_op: !LinalgStructuredOpConfig
+  args:
+  - !LinalgOperandDefConfig
+    name: I
+    kind: input_tensor
+    type_var: T1
+    shape_map: affine_map<() -> ()>
+  - !LinalgOperandDefConfig
+    name: O
+    kind: output_tensor
+    type_var: T1
+    shape_map: affine_map<() -> ()>
+  indexing_maps: !LinalgIndexingMapsConfig
+    static_indexing_maps:
+    - affine_map<() -> ()>
+    - affine_map<() -> ()>
+  iterator_types: []
+  assignments:
+  - !ScalarAssign
+    arg: O
+    value: !ScalarExpression
+      scalar_fn:
+        kind: unary
+        fn_name: reciprocal
+        operands:
+        - !ScalarExpression
+          scalar_arg: I
+--- !LinalgOpConfig
+metadata: !LinalgOpMetadata
+  name: round
+  cpp_class_name: RoundOp
+  doc: |-
+    Applies round(x) elementwise.
+
+    No numeric casting is performed on the input operand.
+structured_op: !LinalgStructuredOpConfig
+  args:
+  - !LinalgOperandDefConfig
+    name: I
+    kind: input_tensor
+    type_var: T1
+    shape_map: affine_map<() -> ()>
+  - !LinalgOperandDefConfig
+    name: O
+    kind: output_tensor
+    type_var: T1
+    shape_map: affine_map<() -> ()>
+  indexing_maps: !LinalgIndexingMapsConfig
+    static_indexing_maps:
+    - affine_map<() -> ()>
+    - affine_map<() -> ()>
+  iterator_types: []
+  assignments:
+  - !ScalarAssign
+    arg: O
+    value: !ScalarExpression
+      scalar_fn:
+        kind: unary
+        fn_name: round
+        operands:
+        - !ScalarExpression
+          scalar_arg: I
+--- !LinalgOpConfig
+metadata: !LinalgOpMetadata
+  name: sqrt
+  cpp_class_name: SqrtOp
+  doc: |-
+    Applies sqrt(x) elementwise.
+
+    No numeric casting is performed on the input operand.
+structured_op: !LinalgStructuredOpConfig
+  args:
+  - !LinalgOperandDefConfig
+    name: I
+    kind: input_tensor
+    type_var: T1
+    shape_map: affine_map<() -> ()>
+  - !LinalgOperandDefConfig
+    name: O
+    kind: output_tensor
+    type_var: T1
+    shape_map: affine_map<() -> ()>
+  indexing_maps: !LinalgIndexingMapsConfig
+    static_indexing_maps:
+    - affine_map<() -> ()>
+    - affine_map<() -> ()>
+  iterator_types: []
+  assignments:
+  - !ScalarAssign
+    arg: O
+    value: !ScalarExpression
+      scalar_fn:
+        kind: unary
+        fn_name: sqrt
+        operands:
+        - !ScalarExpression
+          scalar_arg: I
+--- !LinalgOpConfig
+metadata: !LinalgOpMetadata
+  name: rsqrt
+  cpp_class_name: RsqrtOp
+  doc: |-
+    Applies rsqrt(x) elementwise.
+
+    No numeric casting is performed on the input operand.
+structured_op: !LinalgStructuredOpConfig
+  args:
+  - !LinalgOperandDefConfig
+    name: I
+    kind: input_tensor
+    type_var: T1
+    shape_map: affine_map<() -> ()>
+  - !LinalgOperandDefConfig
+    name: O
+    kind: output_tensor
+    type_var: T1
+    shape_map: affine_map<() -> ()>
+  indexing_maps: !LinalgIndexingMapsConfig
+    static_indexing_maps:
+    - affine_map<() -> ()>
+    - affine_map<() -> ()>
+  iterator_types: []
+  assignments:
+  - !ScalarAssign
+    arg: O
+    value: !ScalarExpression
+      scalar_fn:
+        kind: unary
+        fn_name: rsqrt
+        operands:
+        - !ScalarExpression
+          scalar_arg: I
+--- !LinalgOpConfig
+metadata: !LinalgOpMetadata
+  name: square
+  cpp_class_name: SquareOp
+  doc: |-
+    Applies square(x) elementwise.
+
+    No numeric casting is performed on the input operand.
+structured_op: !LinalgStructuredOpConfig
+  args:
+  - !LinalgOperandDefConfig
+    name: I
+    kind: input_tensor
+    type_var: T1
+    shape_map: affine_map<() -> ()>
+  - !LinalgOperandDefConfig
+    name: O
+    kind: output_tensor
+    type_var: T1
+    shape_map: affine_map<() -> ()>
+  indexing_maps: !LinalgIndexingMapsConfig
+    static_indexing_maps:
+    - affine_map<() -> ()>
+    - affine_map<() -> ()>
+  iterator_types: []
+  assignments:
+  - !ScalarAssign
+    arg: O
+    value: !ScalarExpression
+      scalar_fn:
+        kind: unary
+        fn_name: square
+        operands:
+        - !ScalarExpression
+          scalar_arg: I
+--- !LinalgOpConfig
+metadata: !LinalgOpMetadata
+  name: tanh
+  cpp_class_name: TanhOp
+  doc: |-
+    Applies tanh(x) elementwise.
+
+    No numeric casting is performed on the input operand.
+structured_op: !LinalgStructuredOpConfig
+  args:
+  - !LinalgOperandDefConfig
+    name: I
+    kind: input_tensor
+    type_var: T1
+    shape_map: affine_map<() -> ()>
+  - !LinalgOperandDefConfig
+    name: O
+    kind: output_tensor
+    type_var: T1
+    shape_map: affine_map<() -> ()>
+  indexing_maps: !LinalgIndexingMapsConfig
+    static_indexing_maps:
+    - affine_map<() -> ()>
+    - affine_map<() -> ()>
+  iterator_types: []
+  assignments:
+  - !ScalarAssign
+    arg: O
+    value: !ScalarExpression
+      scalar_fn:
+        kind: unary
+        fn_name: tanh
+        operands:
+        - !ScalarExpression
+          scalar_arg: I
+--- !LinalgOpConfig
 metadata: !LinalgOpMetadata
   name: elemwise_binary
   cpp_class_name: ElemwiseBinaryOp
@@ -625,7 +835,7 @@ metadata: !LinalgOpMetadata
 
     This means reduction/broadcast/element cast semantics is explicit. Further
     passes can take that into account when lowering this code. For example,
-    a `linalg.broadcast` + `linalg.div` sequence can be lowered to a
+    a `linalg.broadcast` + `linalg.max` sequence can be lowered to a
     `linalg.generic` with different affine maps for the two operands.
 structured_op: !LinalgStructuredOpConfig
   args:
@@ -663,6 +873,55 @@ structured_op: !LinalgStructuredOpConfig
         - !ScalarExpression
           scalar_arg: rhs
 --- !LinalgOpConfig
+metadata: !LinalgOpMetadata
+  name: min
+  cpp_class_name: MinOp
+  doc: |-
+    Takes the min (signed) between two inputs, elementwise.
+
+    The shapes and element types must be identical. The appropriate casts,
+    broadcasts and reductions should be done previously to calling this op.
+
+    This means reduction/broadcast/element cast semantics is explicit. Further
+    passes can take that into account when lowering this code. For example,
+    a `linalg.broadcast` + `linalg.min` sequence can be lowered to a
+    `linalg.generic` with different affine maps for the two operands.
+structured_op: !LinalgStructuredOpConfig
+  args:
+  - !LinalgOperandDefConfig
+    name: lhs
+    kind: input_tensor
+    type_var: T1
+    shape_map: affine_map<() -> ()>
+  - !LinalgOperandDefConfig
+    name: rhs
+    kind: input_tensor
+    type_var: T1
+    shape_map: affine_map<() -> ()>
+  - !LinalgOperandDefConfig
+    name: O
+    kind: output_tensor
+    type_var: T1
+    shape_map: affine_map<() -> ()>
+  indexing_maps: !LinalgIndexingMapsConfig
+    static_indexing_maps:
+    - affine_map<() -> ()>
+    - affine_map<() -> ()>
+    - affine_map<() -> ()>
+  iterator_types: []
+  assignments:
+  - !ScalarAssign
+    arg: O
+    value: !ScalarExpression
+      scalar_fn:
+        kind: binary
+        fn_name: min_signed
+        operands:
+        - !ScalarExpression
+          scalar_arg: lhs
+        - !ScalarExpression
+          scalar_arg: rhs
+--- !LinalgOpConfig
 metadata: !LinalgOpMetadata
   name: matmul
   cpp_class_name: MatmulOp
diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h
index 0d5fa719d0dee2..308c2cd38196cb 100644
--- a/mlir/include/mlir/IR/Builders.h
+++ b/mlir/include/mlir/IR/Builders.h
@@ -115,12 +115,16 @@ class Builder {
   ArrayAttr getArrayAttr(ArrayRef<Attribute> value);
 
   // Returns a 0-valued attribute of the given `type`. This function only
-  // supports boolean, integer, and 16-/32-/64-bit float types, and vector or
+  // supports integer, and 16-/32-/64-bit float types, and vector or
   // ranked tensor of them. Returns null attribute otherwise.
   TypedAttr getZeroAttr(Type type);
   // Returns a 1-valued attribute of the given `type`.
   // Type constraints are the same as `getZeroAttr`.
   TypedAttr getOneAttr(Type type);
+  // Returns a numeric attribute of the given `type`.
+  // Type constraints are the same as `getZeroAttr`.
+  // Non float types are converted before returning the attribute.
+  TypedAttr getNumberAttr(double value, Type type);
 
   // Convenience methods for fixed types.
   FloatAttr getF16FloatAttr(float value);
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 9c5c58fa1fabfb..6ee5864656cf64 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -395,6 +395,26 @@ class RegionBuilderHelper {
       return builder.create<math::FloorOp>(arg.getLoc(), arg);
     case UnaryFn::negf:
       return builder.create<arith::NegFOp>(arg.getLoc(), arg);
+    case UnaryFn::reciprocal:
+    {
+      Attribute oneAttr = builder.getNumberAttr(1.0, arg.getType());
+      auto one = builder.create<arith::ConstantOp>(arg.getLoc(), ::cast<TypedAttr>(oneAttr));
+      return builder.create<arith::DivFOp>(arg.getLoc(), one, arg);
+    }
+    case UnaryFn::round:
+      return builder.create<math::RoundOp>(arg.getLoc(), arg);
+    case UnaryFn::sqrt:
+      return builder.create<math::SqrtOp>(arg.getLoc(), arg);
+    case UnaryFn::rsqrt:
+      return builder.create<math::RsqrtOp>(arg.getLoc(), arg);
+    case UnaryFn::square:
+    {
+      Attribute twoAttr = builder.getNumberAttr(2.0, arg.getType());
+      auto two = builder.create<arith::ConstantOp>(arg.getLoc(), ::cast<TypedAttr>(twoAttr));
+      return builder.create<math::PowFOp>(arg.getLoc(), arg, two);
+    }
+    case UnaryFn::tanh:
+      return builder.create<math::TanhOp>(arg.getLoc(), arg);
     }
     llvm_unreachable("unsupported unary function");
   }
diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp
index d49f69a7b7ae6b..ae8348c8454e9e 100644
--- a/mlir/lib/IR/Builders.cpp
+++ b/mlir/lib/IR/Builders.cpp
@@ -329,34 +329,24 @@ ArrayAttr Builder::getAffineMapArrayAttr(ArrayRef<AffineMap> values) {
 }
 
 TypedAttr Builder::getZeroAttr(Type type) {
-  if (llvm::isa<FloatType>(type))
-    return getFloatAttr(type, 0.0);
-  if (llvm::isa<IndexType>(type))
-    return getIndexAttr(0);
-  if (llvm::dyn_cast<IntegerType>(type))
-    return getIntegerAttr(type,
-                          APInt(llvm::cast<IntegerType>(type).getWidth(), 0));
-  if (llvm::isa<RankedTensorType, VectorType>(type)) {
-    auto vtType = llvm::cast<ShapedType>(type);
-    auto element = getZeroAttr(vtType.getElementType());
-    if (!element)
-      return {};
-    return DenseElementsAttr::get(vtType, element);
-  }
-  return {};
+  return getNumberAttr(0.0, type);
 }
 
 TypedAttr Builder::getOneAttr(Type type) {
+  return getNumberAttr(1.0, type);
+}
+
+TypedAttr Builder::getNumberAttr(double value, Type type) {
   if (llvm::isa<FloatType>(type))
-    return getFloatAttr(type, 1.0);
+    return getFloatAttr(type, value);
   if (llvm::isa<IndexType>(type))
-    return getIndexAttr(1);
+    return getIndexAttr(static_cast<int64_t>(value));
   if (llvm::dyn_cast<IntegerType>(type))
     return getIntegerAttr(type,
-                          APInt(llvm::cast<IntegerType>(type).getWidth(), 1));
+                          APInt(llvm::cast<IntegerType>(type).getWidth(), static_cast<int64_t>(value)));
   if (llvm::isa<RankedTensorType, VectorType>(type)) {
     auto vtType = llvm::cast<ShapedType>(type);
-    auto element = getOneAttr(vtType.getElementType());
+    auto element = getNumberAttr(value, vtType.getElementType());
     if (!element)
       return {};
     return DenseElementsAttr::get(vtType, element);
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py
index 23d6d26b7e294c..f7bc81bd2f6833 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py
@@ -291,6 +291,11 @@ class UnaryFn:
     ceil = UnaryFnType("ceil")
     floor = UnaryFnType("floor")
     negf = UnaryFnType("negf")
+    round = UnaryFnType("round")
+    sqrt = UnaryFnType("sqrt")
+    rsqrt = UnaryFnType("rsqrt")
+    square = UnaryFnType("square")
+    tanh = UnaryFnType("tanh")
 
 
 class BinaryFnType:
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
index 5b05364f6d35f3..c97ac448006505 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
@@ -108,6 +108,66 @@ def negf(
     O[None] = UnaryFn.negf(I[None])
 
 
+@linalg_structured_op
+def round(
+    I=TensorDef(T1),
+    O=TensorDef(T1, output=True),
+):
+    """Applies round(x) elementwise.
+
+    No numeric casting is performed on the input operand.
+    """
+    O[None] = UnaryFn.round(I[None])
+
+
+@linalg_structured_op
+def sqrt(
+    I=TensorDef(T1),
+    O=TensorDef(T1, output=True),
+):
+    """Applies sqrt(x) elementwise.
+
+    No numeric casting is performed on the input operand.
+    """
+    O[None] = UnaryFn.sqrt(I[None])
+
+
+@linalg_structured_op
+def rsqrt(
+    I=TensorDef(T1),
+    O=TensorDef(T1, output=True),
+):
+    """Applies rsqrt(x) elementwise.
+
+    No numeric casting is performed on the input operand.
+    """
+    O[None] = UnaryFn.rsqrt(I[None])
+
+
+@linalg_structured_op
+def square(
+    I=TensorDef(T1),
+    O=TensorDef(T1, output=True),
+):
+    """Applies square(x) elementwise.
+
+    No numeric casting is performed on the input operand.
+    """
+    O[None] = UnaryFn.square(I[None])
+
+
+@linalg_structured_op
+def tanh(
+    I=TensorDef(T1),
+    O=TensorDef(T1, output=True),
+):
+    """Applies tanh(x) elementwise.
+
+    No numeric casting is performed on the input operand.
+    """
+    O[None] = UnaryFn.tanh(I[None])
+
+
 @linalg_structured_op
 def elemwise_binary(
     lhs=TensorDef(T1),
@@ -239,6 +299,25 @@ def max(
     O[None] = BinaryFn.max_signed(lhs[None], rhs[None])
 
 
+@linalg_structured_op
+def min(
+    lhs=TensorDef(T1),
+    rhs=TensorDef(T1),
+    O=TensorDef(T1, output=True),
+):
+    """Takes the min (signed) between two inputs, elementwise.
+
+    The shapes and element types must be identical. The appropriate casts,
+    broadcasts and reductions should be done previously to calling this op.
+
+    This means reduction/broadcast/element cast semantics is explicit. Further
+    passes can take that into account when lowering this code. For example,
+    a `linalg.broadcast` + `linalg.div` sequence can be lowered to a
+    `linalg.generic` with different affine maps for the two operands.
+    """
+    O[None] = BinaryFn.min_signed(lhs[None], rhs[None])
+
+
 @linalg_structured_op
 def matmul(
     A=TensorDef(T1, S.M, S.K),
diff --git a/mlir/test/Dialect/Linalg/generalize-named-ops.mlir b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
index e852824cdb7367..160d1e9536bb5b 100644
--- a/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
@@ -565,6 +565,136 @@ func.func @generalize_negf(%arg: memref<7x14x21xf32>, %out: memref<7x14x21xf32>)
 
 // -----
 
+func.func @generalize_reciprocal(%arg: memref<7x14x21xf32>, %out: memref<7x14x21xf32>) {
+  linalg.reciprocal ins(%arg : memref<7x14x21xf32>) outs(%out : memref<7x14x21xf32>)
+  return
+}
+
+// CHECK: #[[MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+
+// CHECK: func @generalize_reciprocal
+// CHECK-SAME: (%[[ARG:.+]]: memref<7x14x21xf32>, %[[OUT:.+]]: memref<7x14x21xf32>)
+
+// CHECK: %[[one:.+]] = arith.constant 1.000000e+00 : f32
+
+// CHECK: linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP]], #[[MAP]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]}
+// CHECK-SAME:  ins(%[[LHS]] : memref<7x14x21xf32>) outs(%[[OUT]] : memref<7x14x21xf32>)
+
+// CHECK:         ^{{.+}}(%[[BBARG0:.+]]: f32, %[[BBARG1:.+]]: f32)
+// CHECK-NEXT:      %[[reciprocal:.+]] = arith.divf %[[one]], %[[BBARG0]] : f32
+// CHECK-NEXT:      linalg.yield %[[reciprocal]] : f32
+
+// -----
+
+func.func @generalize_round(%arg: memref<7x14x21xf32>, %out: memref<7x14x21xf32>) {
+  linalg.round ins(%arg : memref<7x14x21xf32>) outs(%out : memref<7x14x21xf32>)
+  return
+}
+
+// CHECK: #[[MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+
+// CHECK: func @generalize_round
+// CHECK-SAME: (%[[ARG:.+]]: memref<7x14x21xf32>, %[[OUT:.+]]: memref<7x14x21xf32>)
+
+// CHECK: linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP]], #[[MAP]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]}
+// CHECK-SAME:  ins(%[[LHS]] : memref<7x14x21xf32>) outs(%[[OUT]] : memref<7x14x21xf32>)
+
+// CHECK:         ^{{.+}}(%[[BBARG0:.+]]: f32, %[[BBARG1:.+]]: f32)
+// CHECK-NEXT:      %[[round:.+]] = math.round %[[BBARG0]] : f32
+// CHECK-NEXT:      linalg.yield %[[round]] : f32
+
+// -----
+
+func.func @generalize_sqrt(%arg: memref<7x14x21xf32>, %out: memref<7x14x21xf32>) {
+  linalg.sqrt ins(%arg : memref<7x14x21xf32>) outs(%out : memref<7x14x21xf32>)
+  return
+}
+
+// CHECK: #[[MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+
+// CHECK: func @generalize_sqrt
+// CHECK-SAME: (%[[ARG:.+]]: memref<7x14x21xf32>, %[[OUT:.+]]: memref<7x14x21xf32>)
+
+// CHECK: linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP]], #[[MAP]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]}
+// CHECK-SAME:  ins(%[[LHS]] : memref<7x14x21xf32>) outs(%[[OUT]] : memref<7x14x21xf32>)
+
+// CHECK:         ^{{.+}}(%[[BBARG0:.+]]: f32, %[[BBARG1:.+]]: f32)
+// CHECK-NEXT:      %[[sqrt:.+]] = math.sqrt %[[BBARG0]] : f32
+// CHECK-NEXT:      linalg.yield %[[sqrt]] : f32
+
+// -----
+
+func.func @generalize_rsqrt(%arg: memref<7x14x21xf32>, %out: memref<7x14x21xf32>) {
+  linalg.rsqrt ins(%arg : memref<7x14x21xf32>) outs(%out : memref<7x14x21xf32>)
+  return
+}
+
+// CHECK: #[[MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+
+// CHECK: func @generalize_rsqrt
+// CHECK-SAME: (%[[ARG:.+]]: memref<7x14x21xf32>, %[[OUT:.+]]: memref<7x14x21xf32>)
+
+// CHECK: linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP]], #[[MAP]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]}
+// CHECK-SAME:  ins(%[[LHS]] : memref<7x14x21xf32>) outs(%[[OUT]] : memref<7x14x21xf32>)
+
+// CHECK:         ^{{.+}}(%[[BBARG0:.+]]: f32, %[[BBARG1:.+]]: f32)
+// CHECK-NEXT:      %[[rsqrt:.+]] = math.rsqrt %[[BBARG0]] : f32
+// CHECK-NEXT:      linalg.yield %[[rsqrt]] : f32
+
+// -----
+
+func.func @generalize_square(%arg: memref<7x14x21xf32>, %out: memref<7x14x21xf32>) {
+  linalg.square ins(%arg : memref<7x14x21xf32>) outs(%out : memref<7x14x21xf32>)
+  return
+}
+
+// CHECK: #[[MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+
+// CHECK: func @generalize_square
+// CHECK-SAME: (%[[ARG:.+]]: memref<7x14x21xf32>, %[[OUT:.+]]: memref<7x14x21xf32>)
+
+// CHECK: %[[two:.+]] = arith.constant 2.000000e+00 : f32
+
+// CHECK: linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP]], #[[MAP]]]
+// CHECK-SAME: iterator_types = ["...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/90236


More information about the Mlir-commits mailing list