[Mlir-commits] [mlir] eb6a3c0 - [mlir][Linalg] Add a polymorphic linalg.copy operation

Nicolas Vasilache llvmlistbot at llvm.org
Tue Mar 8 09:52:56 PST 2022


Author: Nicolas Vasilache
Date: 2022-03-08T12:52:51-05:00
New Revision: eb6a3c0c0c71ab44141e71112ecd0a2ae2848037

URL: https://github.com/llvm/llvm-project/commit/eb6a3c0c0c71ab44141e71112ecd0a2ae2848037
DIFF: https://github.com/llvm/llvm-project/commit/eb6a3c0c0c71ab44141e71112ecd0a2ae2848037.diff

LOG: [mlir][Linalg] Add a polymorphic linalg.copy operation

With the recent improvements to OpDSL it is cheap to reintroduce a linalg.copy operation.

This operation is needed in at least 2 cases:
  1. for copies that may want to change the elemental type (e.g. cast, truncate, quantize, etc)
  2. to specify new tensors that should bufferize to a copy operation. The linalg.generic form
     always folds away which is not always the right call.

Differential Revision: https://reviews.llvm.org/D121230

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
    mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
    mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
index 2138822b7cfee..d249a8f3b9d32 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
@@ -1,6 +1,47 @@
 ### AUTOGENERATED from core_named_ops.py
 ### To regenerate, run: bin/update_core_linalg_named_ops.sh
 --- !LinalgOpConfig
+metadata: !LinalgOpMetadata
+  name: copy
+  cpp_class_name: CopyOp
+  doc: |-
+    Copies the tensor elementwise.
+
+    Numeric casting is performed on the input operand, promoting it to the same
+    data type as the accumulator/output.
+structured_op: !LinalgStructuredOpConfig
+  args:
+  - !LinalgOperandDefConfig
+    name: I
+    kind: input_tensor
+    type_var: T1
+    shape_map: affine_map<() -> ()>
+  - !LinalgOperandDefConfig
+    name: O
+    kind: output_tensor
+    type_var: U
+    shape_map: affine_map<() -> ()>
+  - !LinalgOperandDefConfig
+    name: cast
+    kind: type_fn_attr
+    default_fn: cast_signed
+  indexing_maps: !LinalgIndexingMapsConfig
+    static_indexing_maps:
+    - affine_map<() -> ()>
+    - affine_map<() -> ()>
+  iterator_types: []
+  assignments:
+  - !ScalarAssign
+    arg: O
+    value: !ScalarExpression
+      scalar_fn:
+        kind: type
+        attr_name: cast
+        type_var: U
+        operands:
+        - !ScalarExpression
+          scalar_arg: I
+--- !LinalgOpConfig
 metadata: !LinalgOpMetadata
   name: elemwise_unary
   cpp_class_name: ElemwiseUnaryOp

diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
index 2e1424a932a58..5774cbc6c7677 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
@@ -7,11 +7,22 @@
 
 
 @linalg_structured_op
-def elemwise_unary(
-    I=TensorDef(T1),
-    O=TensorDef(U, output=True),
-    fun=UnaryFnAttrDef(default=UnaryFn.exp),
-    cast=TypeFnAttrDef(default=TypeFn.cast_signed)):
+def copy(I=TensorDef(T1),
+         O=TensorDef(U, output=True),
+         cast=TypeFnAttrDef(default=TypeFn.cast_signed)):
+  """Copies the tensor elementwise.
+
+  Numeric casting is performed on the input operand, promoting it to the same
+  data type as the accumulator/output.
+  """
+  O[None] = cast(U, I[None])
+
+
+@linalg_structured_op
+def elemwise_unary(I=TensorDef(T1),
+                   O=TensorDef(U, output=True),
+                   fun=UnaryFnAttrDef(default=UnaryFn.exp),
+                   cast=TypeFnAttrDef(default=TypeFn.cast_signed)):
   """Applies the unary function fun elementwise.
 
   Numeric casting is performed on the input operand, promoting it to the same
@@ -21,12 +32,11 @@ def elemwise_unary(
 
 
 @linalg_structured_op
-def elemwise_binary(
-    lhs=TensorDef(T1),
-    rhs=TensorDef(T2),
-    O=TensorDef(U, output=True),
-    fun=BinaryFnAttrDef(default=BinaryFn.add),
-    cast=TypeFnAttrDef(default=TypeFn.cast_signed)):
+def elemwise_binary(lhs=TensorDef(T1),
+                    rhs=TensorDef(T2),
+                    O=TensorDef(U, output=True),
+                    fun=BinaryFnAttrDef(default=BinaryFn.add),
+                    cast=TypeFnAttrDef(default=TypeFn.cast_signed)):
   """Applies the binary function fun elementwise.
 
   Numeric casting is performed on the input operand, promoting it to the same
@@ -36,11 +46,10 @@ def elemwise_binary(
 
 
 @linalg_structured_op
-def matmul(
-    A=TensorDef(T1, S.M, S.K),
-    B=TensorDef(T2, S.K, S.N),
-    C=TensorDef(U, S.M, S.N, output=True),
-    cast=TypeFnAttrDef(default=TypeFn.cast_signed)):
+def matmul(A=TensorDef(T1, S.M, S.K),
+           B=TensorDef(T2, S.K, S.N),
+           C=TensorDef(U, S.M, S.N, output=True),
+           cast=TypeFnAttrDef(default=TypeFn.cast_signed)):
   """Performs a matrix multiplication of two 2D inputs.
 
   Numeric casting is performed on the operands to the inner multiply, promoting
@@ -52,10 +61,9 @@ def matmul(
 
 
 @linalg_structured_op
-def matmul_unsigned(
-    A=TensorDef(T1, S.M, S.K),
-    B=TensorDef(T2, S.K, S.N),
-    C=TensorDef(U, S.M, S.N, output=True)):
+def matmul_unsigned(A=TensorDef(T1, S.M, S.K),
+                    B=TensorDef(T2, S.K, S.N),
+                    C=TensorDef(U, S.M, S.N, output=True)):
   """Performs an unsigned matrix multiplication of two 2D inputs.
 
   Numeric casting is performed on the operands to the inner multiply, promoting
@@ -68,12 +76,11 @@ def matmul_unsigned(
 
 
 @linalg_structured_op
-def quantized_matmul(
-    A=TensorDef(T1, S.M, S.K),
-    B=TensorDef(T2, S.K, S.N),
-    AZp=ScalarDef(I32),
-    BZp=ScalarDef(I32),
-    C=TensorDef(U, S.M, S.N, output=True)):
+def quantized_matmul(A=TensorDef(T1, S.M, S.K),
+                     B=TensorDef(T2, S.K, S.N),
+                     AZp=ScalarDef(I32),
+                     BZp=ScalarDef(I32),
+                     C=TensorDef(U, S.M, S.N, output=True)):
   """Performs a matrix multiplication of two 2D inputs.
 
   Numeric casting is performed on the operands to the inner multiply, promoting
@@ -82,16 +89,16 @@ def quantized_matmul(
   matmul.
   """
   domain(D.m, D.n, D.k)
-  C[D.m, D.n] += (
-      TypeFn.cast_signed(U, A[D.m, D.k]) - TypeFn.cast_signed(U, AZp)) * (
-          TypeFn.cast_signed(U, B[D.k, D.n]) - TypeFn.cast_signed(U, BZp))
+  C[D.m,
+    D.n] += (TypeFn.cast_signed(U, A[D.m, D.k]) -
+             TypeFn.cast_signed(U, AZp)) * (TypeFn.cast_signed(U, B[D.k, D.n]) -
+                                            TypeFn.cast_signed(U, BZp))
 
 
 @linalg_structured_op
-def mmt4d(
-    lhs=TensorDef(TV.LhsType, S.M, S.K, S.M0, S.K0),
-    rhs=TensorDef(TV.RhsType, S.N, S.K, S.N0, S.K0),
-    accum=TensorDef(TV.AccumType, S.M, S.N, S.M0, S.N0, output=True)):
+def mmt4d(lhs=TensorDef(TV.LhsType, S.M, S.K, S.M0, S.K0),
+          rhs=TensorDef(TV.RhsType, S.N, S.K, S.N0, S.K0),
+          accum=TensorDef(TV.AccumType, S.M, S.N, S.M0, S.N0, output=True)):
   """Performs a matrix-matrix-transpose multiplication of two 4D inputs.
 
     Differences from linalg.matmul:
@@ -110,10 +117,9 @@ def mmt4d(
 
 
 @linalg_structured_op
-def batch_matmul(
-    A=TensorDef(T1, Batch, S.M, S.K),
-    B=TensorDef(T2, Batch, S.K, S.N),
-    C=TensorDef(U, Batch, S.M, S.N, output=True)):
+def batch_matmul(A=TensorDef(T1, Batch, S.M, S.K),
+                 B=TensorDef(T2, Batch, S.K, S.N),
+                 C=TensorDef(U, Batch, S.M, S.N, output=True)):
   """Performs a batched matrix multiplication of two 3D inputs.
 
   Numeric casting is performed on the operands to the inner multiply, promoting
@@ -127,12 +133,11 @@ def batch_matmul(
 
 
 @linalg_structured_op
-def quantized_batch_matmul(
-    A=TensorDef(T1, Batch, S.M, S.K),
-    B=TensorDef(T2, Batch, S.K, S.N),
-    AZp=ScalarDef(I32),
-    BZp=ScalarDef(I32),
-    C=TensorDef(U, Batch, S.M, S.N, output=True)):
+def quantized_batch_matmul(A=TensorDef(T1, Batch, S.M, S.K),
+                           B=TensorDef(T2, Batch, S.K, S.N),
+                           AZp=ScalarDef(I32),
+                           BZp=ScalarDef(I32),
+                           C=TensorDef(U, Batch, S.M, S.N, output=True)):
   """Performs a batched matrix multiplication of two 3D inputs.
 
   Numeric casting is performed on the operands to the inner multiply, promoting
@@ -141,16 +146,15 @@ def quantized_batch_matmul(
   matmul.
   """
   domain(D.b, D.m, D.n, D.k)
-  C[D.b, D.m, D.n] += (
-      TypeFn.cast_signed(U, A[D.b, D.m, D.k]) - TypeFn.cast_signed(U, AZp)) * (
-          TypeFn.cast_signed(U, B[D.b, D.k, D.n]) - TypeFn.cast_signed(U, BZp))
+  C[D.b, D.m, D.n] += (TypeFn.cast_signed(U, A[D.b, D.m, D.k]) -
+                       TypeFn.cast_signed(U, AZp)) * (TypeFn.cast_signed(
+                           U, B[D.b, D.k, D.n]) - TypeFn.cast_signed(U, BZp))
 
 
 @linalg_structured_op
-def matvec(
-    A=TensorDef(T1, S.M, S.N),
-    y=TensorDef(T2, S.N),
-    x=TensorDef(U, S.M, output=True)):
+def matvec(A=TensorDef(T1, S.M, S.N),
+           y=TensorDef(T2, S.N),
+           x=TensorDef(U, S.M, output=True)):
   """Performs a matrix-vector multiplication.
 
   Numeric casting is performed on the operands to the inner multiply, promoting
@@ -162,10 +166,9 @@ def matvec(
 
 
 @linalg_structured_op
-def vecmat(
-    y=TensorDef(T1, S.M),
-    A=TensorDef(T2, S.M, S.N),
-    x=TensorDef(U, S.N, output=True)):
+def vecmat(y=TensorDef(T1, S.M),
+           A=TensorDef(T2, S.M, S.N),
+           x=TensorDef(U, S.N, output=True)):
   """Performs a vector-matrix multiplication.
 
   Numeric casting is performed on the operands to the inner multiply, promoting
@@ -177,10 +180,9 @@ def vecmat(
 
 
 @linalg_structured_op
-def batch_matvec(
-    A=TensorDef(T1, Batch, S.M, S.K),
-    B=TensorDef(T2, Batch, S.K),
-    C=TensorDef(U, Batch, S.M, output=True)):
+def batch_matvec(A=TensorDef(T1, Batch, S.M, S.K),
+                 B=TensorDef(T2, Batch, S.K),
+                 C=TensorDef(U, Batch, S.M, output=True)):
   """Performs a batched matrix-vector multiplication.
 
   Numeric casting is performed on the operands to the inner multiply, promoting
@@ -193,8 +195,8 @@ def batch_matvec(
 
 
 @linalg_structured_op
-def dot(
-    A=TensorDef(T1, S.M), B=TensorDef(T2, S.M), C=TensorDef(U, output=True)):
+def dot(A=TensorDef(T1, S.M), B=TensorDef(T2, S.M), C=TensorDef(U,
+                                                                output=True)):
   """Performs a dot product of two vectors to a scalar result.
 
   Numeric casting is performed on the operands to the inner multiply, promoting
@@ -205,10 +207,9 @@ def dot(
 
 
 @linalg_structured_op
-def conv_1d(
-    I=TensorDef(T1, S.OW + S.KW),
-    K=TensorDef(T2, S.KW),
-    O=TensorDef(U, S.OW, output=True)):
+def conv_1d(I=TensorDef(T1, S.OW + S.KW),
+            K=TensorDef(T2, S.KW),
+            O=TensorDef(U, S.OW, output=True)):
   """Performs 1-D convolution with no channels.
 
   Numeric casting is performed on the operands to the inner multiply, promoting
@@ -221,10 +222,9 @@ def conv_1d(
 
 
 @linalg_structured_op
-def conv_2d(
-    I=TensorDef(T1, S.OH + S.KH, S.OW + S.KW),
-    K=TensorDef(T2, S.KH, S.KW),
-    O=TensorDef(U, S.OH, S.OW, output=True)):
+def conv_2d(I=TensorDef(T1, S.OH + S.KH, S.OW + S.KW),
+            K=TensorDef(T2, S.KH, S.KW),
+            O=TensorDef(U, S.OH, S.OW, output=True)):
   """Performs 2-D convolution with no channels.
 
   Numeric casting is performed on the operands to the inner multiply, promoting
@@ -237,10 +237,9 @@ def conv_2d(
 
 
 @linalg_structured_op
-def conv_3d(
-    I=TensorDef(T1, S.OD + S.KD, S.OH + S.KH, S.OW + S.KW),
-    K=TensorDef(T2, S.KD, S.KH, S.KW),
-    O=TensorDef(U, S.OD, S.OH, S.OW, output=True)):
+def conv_3d(I=TensorDef(T1, S.OD + S.KD, S.OH + S.KH, S.OW + S.KW),
+            K=TensorDef(T2, S.KD, S.KH, S.KW),
+            O=TensorDef(U, S.OD, S.OH, S.OW, output=True)):
   """Performs 3-D convolution with no channels.
 
   Numeric casting is performed on the operands to the inner multiply, promoting
@@ -254,12 +253,11 @@ def conv_3d(
 
 
 @linalg_structured_op
-def conv_1d_nwc_wcf(
-    I=TensorDef(T1, S.N, S.OW * S.SW + S.KW * S.DW, S.C),
-    K=TensorDef(T2, S.KW, S.C, S.F),
-    O=TensorDef(U, S.N, S.OW, S.F, output=True),
-    strides=IndexAttrDef(S.SW, default=[1]),
-    dilations=IndexAttrDef(S.DW, default=[1])):
+def conv_1d_nwc_wcf(I=TensorDef(T1, S.N, S.OW * S.SW + S.KW * S.DW, S.C),
+                    K=TensorDef(T2, S.KW, S.C, S.F),
+                    O=TensorDef(U, S.N, S.OW, S.F, output=True),
+                    strides=IndexAttrDef(S.SW, default=[1]),
+                    dilations=IndexAttrDef(S.DW, default=[1])):
   """Performs 1-D convolution.
 
   Numeric casting is performed on the operands to the inner multiply, promoting
@@ -267,19 +265,18 @@ def conv_1d_nwc_wcf(
   """
   implements(ConvolutionOpInterface)
   domain(D.n, D.ow, D.f, D.kw, D.c)
-  O[D.n, D.ow,
-    D.f] += TypeFn.cast_signed(U, I[D.n, D.ow * S.SW + D.kw * S.DW,
-                             D.c]) * TypeFn.cast_signed(U, K[D.kw, D.c, D.f])
+  O[D.n, D.ow, D.f] += TypeFn.cast_signed(
+      U, I[D.n, D.ow * S.SW + D.kw * S.DW, D.c]) * TypeFn.cast_signed(
+          U, K[D.kw, D.c, D.f])
 
 
 @linalg_structured_op
-def conv_2d_nhwc_hwcf(
-    I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW,
-                S.C),
-    K=TensorDef(T2, S.KH, S.KW, S.C, S.F),
-    O=TensorDef(U, S.N, S.OH, S.OW, S.F, output=True),
-    strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
-    dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])):
+def conv_2d_nhwc_hwcf(I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH,
+                                  S.OW * S.SW + S.KW * S.DW, S.C),
+                      K=TensorDef(T2, S.KH, S.KW, S.C, S.F),
+                      O=TensorDef(U, S.N, S.OH, S.OW, S.F, output=True),
+                      strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
+                      dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])):
   """Performs 2-D convolution.
 
   Layout:
@@ -297,15 +294,14 @@ def conv_2d_nhwc_hwcf(
 
 
 @linalg_structured_op
-def conv_2d_nhwc_hwcf_q(
-    I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW,
-                S.C),
-    K=TensorDef(T2, S.KH, S.KW, S.C, S.F),
-    IZp=ScalarDef(I32),
-    KZp=ScalarDef(I32),
-    O=TensorDef(U, S.N, S.OH, S.OW, S.F, output=True),
-    strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
-    dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])):
+def conv_2d_nhwc_hwcf_q(I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH,
+                                    S.OW * S.SW + S.KW * S.DW, S.C),
+                        K=TensorDef(T2, S.KH, S.KW, S.C, S.F),
+                        IZp=ScalarDef(I32),
+                        KZp=ScalarDef(I32),
+                        O=TensorDef(U, S.N, S.OH, S.OW, S.F, output=True),
+                        strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
+                        dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])):
   """Performs 2-D convolution with zero point offsets.
 
   Layout:
@@ -321,19 +317,17 @@ def conv_2d_nhwc_hwcf_q(
   O[D.n, D.oh, D.ow,
     D.f] += (TypeFn.cast_signed(
         U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c]) -
-             TypeFn.cast_signed(U, IZp)) * (
-                 TypeFn.cast_signed(U, K[D.kh, D.kw, D.c, D.f]) -
-                 TypeFn.cast_signed(U, KZp))
+             TypeFn.cast_signed(U, IZp)) * (TypeFn.cast_signed(
+                 U, K[D.kh, D.kw, D.c, D.f]) - TypeFn.cast_signed(U, KZp))
 
 
 @linalg_structured_op
-def conv_2d_nchw_fchw(
-    I=TensorDef(T1, S.N, S.C, S.OH * S.SH + S.KH * S.DH,
-                S.OW * S.SW + S.KW * S.DW),
-    K=TensorDef(T2, S.F, S.C, S.KH, S.KW),
-    O=TensorDef(U, S.N, S.F, S.OH, S.OW, output=True),
-    strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
-    dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])):
+def conv_2d_nchw_fchw(I=TensorDef(T1, S.N, S.C, S.OH * S.SH + S.KH * S.DH,
+                                  S.OW * S.SW + S.KW * S.DW),
+                      K=TensorDef(T2, S.F, S.C, S.KH, S.KW),
+                      O=TensorDef(U, S.N, S.F, S.OH, S.OW, output=True),
+                      strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
+                      dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])):
   """Performs 2-D convolution.
 
   Layout:
@@ -351,13 +345,19 @@ def conv_2d_nchw_fchw(
 
 
 @linalg_structured_op
-def conv_3d_ndhwc_dhwcf(
-    I=TensorDef(T1, S.N, S.OD * S.SD + S.KD * S.DD, S.OH * S.SH + S.KH * S.DH,
-                S.OW * S.SW + S.KW * S.DW, S.C),
-    K=TensorDef(T2, S.KD, S.KH, S.KW, S.C, S.F),
-    O=TensorDef(U, S.N, S.OD, S.OH, S.OW, S.F, output=True),
-    strides=IndexAttrDef(S.SD, S.SH, S.SW, default=[1, 1, 1]),
-    dilations=IndexAttrDef(S.DD, S.DH, S.DW, default=[1, 1, 1])):
+def conv_3d_ndhwc_dhwcf(I=TensorDef(T1, S.N, S.OD * S.SD + S.KD * S.DD,
+                                    S.OH * S.SH + S.KH * S.DH,
+                                    S.OW * S.SW + S.KW * S.DW, S.C),
+                        K=TensorDef(T2, S.KD, S.KH, S.KW, S.C, S.F),
+                        O=TensorDef(U, S.N, S.OD, S.OH, S.OW, S.F, output=True),
+                        strides=IndexAttrDef(S.SD,
+                                             S.SH,
+                                             S.SW,
+                                             default=[1, 1, 1]),
+                        dilations=IndexAttrDef(S.DD,
+                                               S.DH,
+                                               S.DW,
+                                               default=[1, 1, 1])):
   """Performs 3-D convolution.
 
   Numeric casting is performed on the operands to the inner multiply, promoting
@@ -372,12 +372,12 @@ def conv_3d_ndhwc_dhwcf(
 
 
 @linalg_structured_op
-def depthwise_conv_1d_nwc_wc(
-    I=TensorDef(T1, S.N, S.OW * S.SW + S.KW * S.DW, S.IC),
-    K=TensorDef(T2, S.KW, S.IC),
-    O=TensorDef(U, S.N, S.OW, S.IC, output=True),
-    strides=IndexAttrDef(S.SW, default=[1]),
-    dilations=IndexAttrDef(S.DW, default=[1])):
+def depthwise_conv_1d_nwc_wc(I=TensorDef(T1, S.N, S.OW * S.SW + S.KW * S.DW,
+                                         S.IC),
+                             K=TensorDef(T2, S.KW, S.IC),
+                             O=TensorDef(U, S.N, S.OW, S.IC, output=True),
+                             strides=IndexAttrDef(S.SW, default=[1]),
+                             dilations=IndexAttrDef(S.DW, default=[1])):
   """Performs depth-wise 1-D convolution.
 
   Numeric casting is performed on the operands to the inner multiply, promoting
@@ -392,13 +392,19 @@ def depthwise_conv_1d_nwc_wc(
 
 
 @linalg_structured_op
-def depthwise_conv_2d_nhwc_hwc(
-    I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW,
-                S.IC),
-    K=TensorDef(T2, S.KH, S.KW, S.IC),
-    O=TensorDef(U, S.N, S.OH, S.OW, S.IC, output=True),
-    strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
-    dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])):
+def depthwise_conv_2d_nhwc_hwc(I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH,
+                                           S.OW * S.SW + S.KW * S.DW, S.IC),
+                               K=TensorDef(T2, S.KH, S.KW, S.IC),
+                               O=TensorDef(U,
+                                           S.N,
+                                           S.OH,
+                                           S.OW,
+                                           S.IC,
+                                           output=True),
+                               strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
+                               dilations=IndexAttrDef(S.DH,
+                                                      S.DW,
+                                                      default=[1, 1])):
   """Performs depth-wise 2-D convolution.
 
   Numeric casting is performed on the operands to the inner multiply, promoting
@@ -413,15 +419,23 @@ def depthwise_conv_2d_nhwc_hwc(
 
 
 @linalg_structured_op
-def depthwise_conv_2d_nhwc_hwc_q(
-    I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW,
-                S.IC),
-    K=TensorDef(T2, S.KH, S.KW, S.IC),
-    IZp=ScalarDef(I32),
-    KZp=ScalarDef(I32),
-    O=TensorDef(U, S.N, S.OH, S.OW, S.IC, output=True),
-    strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
-    dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])):
+def depthwise_conv_2d_nhwc_hwc_q(I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH,
+                                             S.OW * S.SW + S.KW * S.DW, S.IC),
+                                 K=TensorDef(T2, S.KH, S.KW, S.IC),
+                                 IZp=ScalarDef(I32),
+                                 KZp=ScalarDef(I32),
+                                 O=TensorDef(U,
+                                             S.N,
+                                             S.OH,
+                                             S.OW,
+                                             S.IC,
+                                             output=True),
+                                 strides=IndexAttrDef(S.SH,
+                                                      S.SW,
+                                                      default=[1, 1]),
+                                 dilations=IndexAttrDef(S.DH,
+                                                        S.DW,
+                                                        default=[1, 1])):
   """Performs depth-wise 2-D convolution.
 
   Numeric casting is performed on the operands to the inner multiply, promoting
@@ -437,13 +451,21 @@ def depthwise_conv_2d_nhwc_hwc_q(
 
 
 @linalg_structured_op
-def depthwise_conv_2d_nhwc_hwcm(
-    I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW,
-                S.IC),
-    K=TensorDef(T2, S.KH, S.KW, S.IC, S.CM),
-    O=TensorDef(U, S.N, S.OH, S.OW, S.IC, S.CM, output=True),
-    strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
-    dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])):
+def depthwise_conv_2d_nhwc_hwcm(I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH,
+                                            S.OW * S.SW + S.KW * S.DW, S.IC),
+                                K=TensorDef(T2, S.KH, S.KW, S.IC, S.CM),
+                                O=TensorDef(U,
+                                            S.N,
+                                            S.OH,
+                                            S.OW,
+                                            S.IC,
+                                            S.CM,
+                                            output=True),
+                                strides=IndexAttrDef(S.SH, S.SW, default=[1,
+                                                                          1]),
+                                dilations=IndexAttrDef(S.DH,
+                                                       S.DW,
+                                                       default=[1, 1])):
   """Performs depth-wise 2-D convolution.
 
   Numeric casting is performed on the operands to the inner multiply, promoting
@@ -457,15 +479,25 @@ def depthwise_conv_2d_nhwc_hwcm(
 
 
 @linalg_structured_op
-def depthwise_conv_2d_nhwc_hwcm_q(
-    I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW,
-                S.IC),
-    K=TensorDef(T2, S.KH, S.KW, S.IC, S.CM),
-    IZp=ScalarDef(I32),
-    KZp=ScalarDef(I32),
-    O=TensorDef(U, S.N, S.OH, S.OW, S.IC, S.CM, output=True),
-    strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
-    dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])):
+def depthwise_conv_2d_nhwc_hwcm_q(I=TensorDef(T1, S.N,
+                                              S.OH * S.SH + S.KH * S.DH,
+                                              S.OW * S.SW + S.KW * S.DW, S.IC),
+                                  K=TensorDef(T2, S.KH, S.KW, S.IC, S.CM),
+                                  IZp=ScalarDef(I32),
+                                  KZp=ScalarDef(I32),
+                                  O=TensorDef(U,
+                                              S.N,
+                                              S.OH,
+                                              S.OW,
+                                              S.IC,
+                                              S.CM,
+                                              output=True),
+                                  strides=IndexAttrDef(S.SH,
+                                                       S.SW,
+                                                       default=[1, 1]),
+                                  dilations=IndexAttrDef(S.DH,
+                                                         S.DW,
+                                                         default=[1, 1])):
   """Performs depth-wise 2-D convolution.
 
   Numeric casting is performed on the operands to the inner multiply, promoting
@@ -482,13 +514,12 @@ def depthwise_conv_2d_nhwc_hwcm_q(
 
 
 @linalg_structured_op
-def pooling_nhwc_sum(
-    I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW,
-                S.C),
-    K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]),
-    O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True),
-    strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
-    dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])):
+def pooling_nhwc_sum(I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH,
+                                 S.OW * S.SW + S.KW * S.DW, S.C),
+                     K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]),
+                     O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True),
+                     strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
+                     dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])):
   """Performs sum pooling.
 
   Numeric casting is performed on the input operand, promoting it to the same
@@ -501,13 +532,12 @@ def pooling_nhwc_sum(
 
 
 @linalg_structured_op
-def pooling_nhwc_max(
-    I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW,
-                S.C),
-    K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]),
-    O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True),
-    strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
-    dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])):
+def pooling_nhwc_max(I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH,
+                                 S.OW * S.SW + S.KW * S.DW, S.C),
+                     K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]),
+                     O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True),
+                     strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
+                     dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])):
   """Performs max pooling.
 
   Numeric casting is performed on the input operand, promoting it to the same
@@ -515,19 +545,21 @@ def pooling_nhwc_max(
   """
   implements(ConvolutionOpInterface)
   domain(D.n, D.oh, D.ow, D.c, D.kh, D.kw)
-  O[D.n, D.oh, D.ow, D.c] = ReduceFn.max_signed[D.kh, D.kw](
-      TypeFn.cast_signed(
-          U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c]))
+  O[D.n, D.oh, D.ow, D.c] = ReduceFn.max_signed[D.kh, D.kw](TypeFn.cast_signed(
+      U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c]))
 
 
 @linalg_structured_op
-def pooling_nhwc_max_unsigned(
-    I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW,
-                S.C),
-    K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]),
-    O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True),
-    strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
-    dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])):
+def pooling_nhwc_max_unsigned(I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH,
+                                          S.OW * S.SW + S.KW * S.DW, S.C),
+                              K=TensorDef(T2,
+                                          S.KH,
+                                          S.KW,
+                                          index_dims=[D.kh, D.kw]),
+                              O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True),
+                              strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
+                              dilations=IndexAttrDef(S.DH, S.DW, default=[1,
+                                                                          1])):
   """Performs unsigned max pooling.
 
   Numeric casting is performed on the input operand, promoting it to the same
@@ -535,19 +567,18 @@ def pooling_nhwc_max_unsigned(
   """
   implements(ConvolutionOpInterface)
   domain(D.n, D.oh, D.ow, D.c, D.kh, D.kw)
-  O[D.n, D.oh, D.ow, D.c] = ReduceFn.max_unsigned[D.kh, D.kw](
-      TypeFn.cast_unsigned(
-          U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c]))
+  O[D.n, D.oh, D.ow,
+    D.c] = ReduceFn.max_unsigned[D.kh, D.kw](TypeFn.cast_unsigned(
+        U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c]))
 
 
 @linalg_structured_op
-def pooling_nchw_max(
-    I=TensorDef(T1, S.N, S.C, S.OH * S.SH + S.KH * S.DH,
-                S.OW * S.SW + S.KW * S.DW),
-    K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]),
-    O=TensorDef(U, S.N, S.C, S.OH, S.OW, output=True),
-    strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
-    dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])):
+def pooling_nchw_max(I=TensorDef(T1, S.N, S.C, S.OH * S.SH + S.KH * S.DH,
+                                 S.OW * S.SW + S.KW * S.DW),
+                     K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]),
+                     O=TensorDef(U, S.N, S.C, S.OH, S.OW, output=True),
+                     strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
+                     dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])):
   """Performs max pooling.
 
   Numeric casting is performed on the input operand, promoting it to the same
@@ -555,20 +586,17 @@ def pooling_nchw_max(
   """
   implements(ConvolutionOpInterface)
   domain(D.n, D.c, D.oh, D.ow, D.kh, D.kw)
-  O[D.n, D.c, D.oh, D.ow] = ReduceFn.max_signed[D.kh, D.kw](
-      TypeFn.cast_signed(
-          U, I[D.n, D.c, D.oh * S.SH + D.kh * S.DH,
-               D.ow * S.SW + D.kw * S.DW,]))
+  O[D.n, D.c, D.oh, D.ow] = ReduceFn.max_signed[D.kh, D.kw](TypeFn.cast_signed(
+      U, I[D.n, D.c, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW,]))
 
 
 @linalg_structured_op
-def pooling_nhwc_min(
-    I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW,
-                S.C),
-    K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]),
-    O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True),
-    strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
-    dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])):
+def pooling_nhwc_min(I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH,
+                                 S.OW * S.SW + S.KW * S.DW, S.C),
+                     K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]),
+                     O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True),
+                     strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
+                     dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])):
   """Performs min pooling.
 
   Numeric casting is performed on the input operand, promoting it to the same
@@ -576,19 +604,21 @@ def pooling_nhwc_min(
   """
   implements(ConvolutionOpInterface)
   domain(D.n, D.oh, D.ow, D.c, D.kh, D.kw)
-  O[D.n, D.oh, D.ow, D.c] = ReduceFn.min_signed[D.kh, D.kw](
-      TypeFn.cast_signed(
-          U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c]))
+  O[D.n, D.oh, D.ow, D.c] = ReduceFn.min_signed[D.kh, D.kw](TypeFn.cast_signed(
+      U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c]))
 
 
 @linalg_structured_op
-def pooling_nhwc_min_unsigned(
-    I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW,
-                S.C),
-    K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]),
-    O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True),
-    strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
-    dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])):
+def pooling_nhwc_min_unsigned(I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH,
+                                          S.OW * S.SW + S.KW * S.DW, S.C),
+                              K=TensorDef(T2,
+                                          S.KH,
+                                          S.KW,
+                                          index_dims=[D.kh, D.kw]),
+                              O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True),
+                              strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
+                              dilations=IndexAttrDef(S.DH, S.DW, default=[1,
+                                                                          1])):
   """Performs unsigned min pooling.
 
   Numeric casting is performed on the input operand, promoting it to the same
@@ -596,19 +626,26 @@ def pooling_nhwc_min_unsigned(
   """
   implements(ConvolutionOpInterface)
   domain(D.n, D.oh, D.ow, D.c, D.kh, D.kw)
-  O[D.n, D.oh, D.ow, D.c] = ReduceFn.min_unsigned[D.kh, D.kw](
-      TypeFn.cast_unsigned(
-          U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c]))
-
-
-@linalg_structured_op
-def pooling_ndhwc_sum(
-    I=TensorDef(T1, S.N, S.OD * S.SD + S.KD * S.DD, S.OH * S.SH + S.KH * S.DH,
-                S.OW * S.SW + S.KW * S.DW, S.C),
-    K=TensorDef(T2, S.KD, S.KH, S.KW, index_dims=[D.kd, D.kh, D.kw]),
-    O=TensorDef(U, S.N, S.OD, S.OH, S.OW, S.C, output=True),
-    strides=IndexAttrDef(S.SD, S.SH, S.SW, default=[1, 1, 1]),
-    dilations=IndexAttrDef(S.DD, S.DH, S.DW, default=[1, 1, 1])):
+  O[D.n, D.oh, D.ow,
+    D.c] = ReduceFn.min_unsigned[D.kh, D.kw](TypeFn.cast_unsigned(
+        U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c]))
+
+
+@linalg_structured_op
+def pooling_ndhwc_sum(I=TensorDef(T1, S.N, S.OD * S.SD + S.KD * S.DD,
+                                  S.OH * S.SH + S.KH * S.DH,
+                                  S.OW * S.SW + S.KW * S.DW, S.C),
+                      K=TensorDef(T2,
+                                  S.KD,
+                                  S.KH,
+                                  S.KW,
+                                  index_dims=[D.kd, D.kh, D.kw]),
+                      O=TensorDef(U, S.N, S.OD, S.OH, S.OW, S.C, output=True),
+                      strides=IndexAttrDef(S.SD, S.SH, S.SW, default=[1, 1, 1]),
+                      dilations=IndexAttrDef(S.DD,
+                                             S.DH,
+                                             S.DW,
+                                             default=[1, 1, 1])):
   """Performs 3D sum pooling.
 
   Numeric casting is performed on the input operand, promoting it to the same
@@ -622,13 +659,20 @@ def pooling_ndhwc_sum(
 
 
 @linalg_structured_op
-def pooling_ndhwc_max(
-    I=TensorDef(T1, S.N, S.OD * S.SD + S.KD * S.DD, S.OH * S.SH + S.KH * S.DH,
-                S.OW * S.SW + S.KW * S.DW, S.C),
-    K=TensorDef(T2, S.KD, S.KH, S.KW, index_dims=[D.kd, D.kh, D.kw]),
-    O=TensorDef(U, S.N, S.OD, S.OH, S.OW, S.C, output=True),
-    strides=IndexAttrDef(S.SD, S.SH, S.SW, default=[1, 1, 1]),
-    dilations=IndexAttrDef(S.DD, S.DH, S.DW, default=[1, 1, 1])):
+def pooling_ndhwc_max(I=TensorDef(T1, S.N, S.OD * S.SD + S.KD * S.DD,
+                                  S.OH * S.SH + S.KH * S.DH,
+                                  S.OW * S.SW + S.KW * S.DW, S.C),
+                      K=TensorDef(T2,
+                                  S.KD,
+                                  S.KH,
+                                  S.KW,
+                                  index_dims=[D.kd, D.kh, D.kw]),
+                      O=TensorDef(U, S.N, S.OD, S.OH, S.OW, S.C, output=True),
+                      strides=IndexAttrDef(S.SD, S.SH, S.SW, default=[1, 1, 1]),
+                      dilations=IndexAttrDef(S.DD,
+                                             S.DH,
+                                             S.DW,
+                                             default=[1, 1, 1])):
   """Performs 3D max pooling.
 
   Numeric casting is performed on the input operand, promoting it to the same
@@ -636,20 +680,27 @@ def pooling_ndhwc_max(
   """
   implements(ConvolutionOpInterface)
   domain(D.n, D.od, D.oh, D.ow, D.c, D.kd, D.kh, D.kw)
-  O[D.n, D.od, D.oh, D.ow, D.c] = ReduceFn.max_signed[D.kd, D.kh, D.kw](
-      TypeFn.cast_signed(
-          U, I[D.n, D.od * S.SD + D.kd * S.DD, D.oh * S.SH + D.kh * S.DH,
-               D.ow * S.SW + D.kw * S.DW, D.c]))
-
-
-@linalg_structured_op
-def pooling_ndhwc_min(
-    I=TensorDef(T1, S.N, S.OD * S.SD + S.KD * S.DD, S.OH * S.SH + S.KH * S.DH,
-                S.OW * S.SW + S.KW * S.DW, S.C),
-    K=TensorDef(T2, S.KD, S.KH, S.KW, index_dims=[D.kd, D.kh, D.kw]),
-    O=TensorDef(U, S.N, S.OD, S.OH, S.OW, S.C, output=True),
-    strides=IndexAttrDef(S.SD, S.SH, S.SW, default=[1, 1, 1]),
-    dilations=IndexAttrDef(S.DD, S.DH, S.DW, default=[1, 1, 1])):
+  O[D.n, D.od, D.oh, D.ow,
+    D.c] = ReduceFn.max_signed[D.kd, D.kh, D.kw](TypeFn.cast_signed(
+        U, I[D.n, D.od * S.SD + D.kd * S.DD, D.oh * S.SH + D.kh * S.DH,
+             D.ow * S.SW + D.kw * S.DW, D.c]))
+
+
+@linalg_structured_op
+def pooling_ndhwc_min(I=TensorDef(T1, S.N, S.OD * S.SD + S.KD * S.DD,
+                                  S.OH * S.SH + S.KH * S.DH,
+                                  S.OW * S.SW + S.KW * S.DW, S.C),
+                      K=TensorDef(T2,
+                                  S.KD,
+                                  S.KH,
+                                  S.KW,
+                                  index_dims=[D.kd, D.kh, D.kw]),
+                      O=TensorDef(U, S.N, S.OD, S.OH, S.OW, S.C, output=True),
+                      strides=IndexAttrDef(S.SD, S.SH, S.SW, default=[1, 1, 1]),
+                      dilations=IndexAttrDef(S.DD,
+                                             S.DH,
+                                             S.DW,
+                                             default=[1, 1, 1])):
   """Performs 3D min pooling.
 
   Numeric casting is performed on the input operand, promoting it to the same
@@ -657,10 +708,10 @@ def pooling_ndhwc_min(
   """
   implements(ConvolutionOpInterface)
   domain(D.n, D.od, D.oh, D.ow, D.c, D.kd, D.kh, D.kw)
-  O[D.n, D.od, D.oh, D.ow, D.c] = ReduceFn.min_signed[D.kd, D.kh, D.kw](
-      TypeFn.cast_signed(
-          U, I[D.n, D.od * S.SD + D.kd * S.DD, D.oh * S.SH + D.kh * S.DH,
-               D.ow * S.SW + D.kw * S.DW, D.c]))
+  O[D.n, D.od, D.oh, D.ow,
+    D.c] = ReduceFn.min_signed[D.kd, D.kh, D.kw](TypeFn.cast_signed(
+        U, I[D.n, D.od * S.SD + D.kd * S.DD, D.oh * S.SH + D.kh * S.DH,
+             D.ow * S.SW + D.kw * S.DW, D.c]))
 
 
 @linalg_structured_op
@@ -677,11 +728,10 @@ def fill_tensor(value=ScalarDef(T1), O=TensorDef(U, output=True)):
 
 
 @linalg_structured_op
-def fill_rng_2d(
-    min=ScalarDef(F64),
-    max=ScalarDef(F64),
-    seed=ScalarDef(I32),
-    O=TensorDef(T, S.M, S.N, output=True)):
+def fill_rng_2d(min=ScalarDef(F64),
+                max=ScalarDef(F64),
+                seed=ScalarDef(I32),
+                O=TensorDef(T, S.M, S.N, output=True)):
   """Fills the output tensor with pseudo random numbers.
 
   The operation generations pseudo random numbers using a linear congruential

diff --git a/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir b/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir
index 0c98629041c4e..ebb9a87696ad0 100644
--- a/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir
+++ b/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir
@@ -335,3 +335,16 @@ func @generalize_elemwise_rank_zero(%lhs : tensor<f32>, %rhs : tensor<f32>, %out
 // CHECK:       linalg.generic
 // CHECK-SAME:  iterator_types = ["parallel", "parallel"]
 // CHECK:        = arith.subf
+
+// -----
+
+// Verifies the fun attribute controls the binary function used.
+func @generalize_copy(%lhs : tensor<4x8xf32>, %output : tensor<4x8xf32>) -> tensor<4x8xf32> {
+  %0 = linalg.copy ins(%lhs: tensor<4x8xf32>) outs(%output: tensor<4x8xf32>) -> tensor<4x8xf32>
+  return %0: tensor<4x8xf32>
+}
+
+// CHECK-LABEL: @generalize_copy
+//       CHECK:   linalg.generic
+//  CHECK-NEXT:   ^bb0(%[[I:[0-9a-zA-Z]*]]: f32
+//  CHECK-NEXT:   linalg.yield %[[I]]


        


More information about the Mlir-commits mailing list