[Mlir-commits] [mlir] 3fe7fe4 - [mlir][linalg] Add unsigned min/max/cast function to OpDSL.

Tobias Gysi llvmlistbot at llvm.org
Wed Oct 6 23:59:55 PDT 2021


Author: Tobias Gysi
Date: 2021-10-07T06:27:20Z
New Revision: 3fe7fe44249b0c640031a09800f3485a06a61d2d

URL: https://github.com/llvm/llvm-project/commit/3fe7fe44249b0c640031a09800f3485a06a61d2d
DIFF: https://github.com/llvm/llvm-project/commit/3fe7fe44249b0c640031a09800f3485a06a61d2d.diff

LOG: [mlir][linalg] Add unsigned min/max/cast function to OpDSL.

Update OpDSL to support unsigned integers by adding unsigned min/max/cast signatures. Add tests in OpDSL and on the C++ side to verify the proper signed and unsigned operations are emitted.

The patch addresses an issue brought up in https://reviews.llvm.org/D111170.

Reviewed By: rsuderman

Differential Revision: https://reviews.llvm.org/D111230

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
    mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
    mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py
    mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
    mlir/python/mlir/dialects/linalg/opdsl/lang/scalar_expr.py
    mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
    mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir
    mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml
    mlir/test/python/dialects/linalg/opdsl/emit_structured_generic.py
    mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
index 1a6c5cdee2c8f..2d8b02bb8ee57 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
@@ -56,12 +56,78 @@ structured_op: !LinalgStructuredOpConfig
                 operands:
                 - !ScalarExpression
                   scalar_arg: A
+                is_unsigned_cast: false
             - !ScalarExpression
               symbolic_cast:
                 type_var: U
                 operands:
                 - !ScalarExpression
                   scalar_arg: B
+                is_unsigned_cast: false
+--- !LinalgOpConfig
+metadata: !LinalgOpMetadata
+  name: matmul_unsigned
+  cpp_class_name: MatmulUnsignedOp
+  doc: |-
+    Performs a unsigned matrix multiplication of two 2D inputs.
+
+    Numeric casting is performed on the operands to the inner multiply, promoting
+    them to the same data type as the accumulator/output.
+  implements:
+  - LinalgContractionOpInterface
+structured_op: !LinalgStructuredOpConfig
+  args:
+  - !LinalgOperandDefConfig
+    name: A
+    usage: InputOperand
+    type_var: T1
+    shape_map: affine_map<()[s0, s1, s2] -> (s0, s1)>
+  - !LinalgOperandDefConfig
+    name: B
+    usage: InputOperand
+    type_var: T2
+    shape_map: affine_map<()[s0, s1, s2] -> (s1, s2)>
+  - !LinalgOperandDefConfig
+    name: C
+    usage: OutputOperand
+    type_var: U
+    shape_map: affine_map<()[s0, s1, s2] -> (s0, s2)>
+  indexing_maps: !LinalgIndexingMapsConfig
+    static_indexing_maps:
+    - affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0, d2)>
+    - affine_map<(d0, d1, d2)[s0, s1, s2] -> (d2, d1)>
+    - affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0, d1)>
+  iterator_types:
+  - parallel
+  - parallel
+  - reduction
+  assignments:
+  - !ScalarAssign
+    arg: C
+    value: !ScalarExpression
+      scalar_apply:
+        fn_name: add
+        operands:
+        - !ScalarExpression
+          scalar_arg: C
+        - !ScalarExpression
+          scalar_apply:
+            fn_name: mul
+            operands:
+            - !ScalarExpression
+              symbolic_cast:
+                type_var: U
+                operands:
+                - !ScalarExpression
+                  scalar_arg: A
+                is_unsigned_cast: true
+            - !ScalarExpression
+              symbolic_cast:
+                type_var: U
+                operands:
+                - !ScalarExpression
+                  scalar_arg: B
+                is_unsigned_cast: true
 --- !LinalgOpConfig
 metadata: !LinalgOpMetadata
   name: quantized_matmul
@@ -132,12 +198,14 @@ structured_op: !LinalgStructuredOpConfig
                     operands:
                     - !ScalarExpression
                       scalar_arg: A
+                    is_unsigned_cast: false
                 - !ScalarExpression
                   symbolic_cast:
                     type_var: U
                     operands:
                     - !ScalarExpression
                       scalar_arg: AZp
+                    is_unsigned_cast: false
             - !ScalarExpression
               scalar_apply:
                 fn_name: sub
@@ -148,12 +216,14 @@ structured_op: !LinalgStructuredOpConfig
                     operands:
                     - !ScalarExpression
                       scalar_arg: B
+                    is_unsigned_cast: false
                 - !ScalarExpression
                   symbolic_cast:
                     type_var: U
                     operands:
                     - !ScalarExpression
                       scalar_arg: BZp
+                    is_unsigned_cast: false
 --- !LinalgOpConfig
 metadata: !LinalgOpMetadata
   name: mmt4d
@@ -221,12 +291,14 @@ structured_op: !LinalgStructuredOpConfig
                 operands:
                 - !ScalarExpression
                   scalar_arg: lhs
+                is_unsigned_cast: false
             - !ScalarExpression
               symbolic_cast:
                 type_var: AccumType
                 operands:
                 - !ScalarExpression
                   scalar_arg: rhs
+                is_unsigned_cast: false
 --- !LinalgOpConfig
 metadata: !LinalgOpMetadata
   name: batch_matmul
@@ -284,12 +356,14 @@ structured_op: !LinalgStructuredOpConfig
                 operands:
                 - !ScalarExpression
                   scalar_arg: A
+                is_unsigned_cast: false
             - !ScalarExpression
               symbolic_cast:
                 type_var: U
                 operands:
                 - !ScalarExpression
                   scalar_arg: B
+                is_unsigned_cast: false
 --- !LinalgOpConfig
 metadata: !LinalgOpMetadata
   name: quantized_batch_matmul
@@ -361,12 +435,14 @@ structured_op: !LinalgStructuredOpConfig
                     operands:
                     - !ScalarExpression
                       scalar_arg: A
+                    is_unsigned_cast: false
                 - !ScalarExpression
                   symbolic_cast:
                     type_var: U
                     operands:
                     - !ScalarExpression
                       scalar_arg: AZp
+                    is_unsigned_cast: false
             - !ScalarExpression
               scalar_apply:
                 fn_name: sub
@@ -377,12 +453,14 @@ structured_op: !LinalgStructuredOpConfig
                     operands:
                     - !ScalarExpression
                       scalar_arg: B
+                    is_unsigned_cast: false
                 - !ScalarExpression
                   symbolic_cast:
                     type_var: U
                     operands:
                     - !ScalarExpression
                       scalar_arg: BZp
+                    is_unsigned_cast: false
 --- !LinalgOpConfig
 metadata: !LinalgOpMetadata
   name: matvec
@@ -438,12 +516,14 @@ structured_op: !LinalgStructuredOpConfig
                 operands:
                 - !ScalarExpression
                   scalar_arg: A
+                is_unsigned_cast: false
             - !ScalarExpression
               symbolic_cast:
                 type_var: U
                 operands:
                 - !ScalarExpression
                   scalar_arg: y
+                is_unsigned_cast: false
 --- !LinalgOpConfig
 metadata: !LinalgOpMetadata
   name: vecmat
@@ -499,12 +579,14 @@ structured_op: !LinalgStructuredOpConfig
                 operands:
                 - !ScalarExpression
                   scalar_arg: y
+                is_unsigned_cast: false
             - !ScalarExpression
               symbolic_cast:
                 type_var: U
                 operands:
                 - !ScalarExpression
                   scalar_arg: A
+                is_unsigned_cast: false
 --- !LinalgOpConfig
 metadata: !LinalgOpMetadata
   name: batch_matvec
@@ -561,12 +643,14 @@ structured_op: !LinalgStructuredOpConfig
                 operands:
                 - !ScalarExpression
                   scalar_arg: A
+                is_unsigned_cast: false
             - !ScalarExpression
               symbolic_cast:
                 type_var: U
                 operands:
                 - !ScalarExpression
                   scalar_arg: B
+                is_unsigned_cast: false
 --- !LinalgOpConfig
 metadata: !LinalgOpMetadata
   name: dot
@@ -621,12 +705,14 @@ structured_op: !LinalgStructuredOpConfig
                 operands:
                 - !ScalarExpression
                   scalar_arg: A
+                is_unsigned_cast: false
             - !ScalarExpression
               symbolic_cast:
                 type_var: U
                 operands:
                 - !ScalarExpression
                   scalar_arg: B
+                is_unsigned_cast: false
 --- !LinalgOpConfig
 metadata: !LinalgOpMetadata
   name: conv_1d
@@ -682,12 +768,14 @@ structured_op: !LinalgStructuredOpConfig
                 operands:
                 - !ScalarExpression
                   scalar_arg: I
+                is_unsigned_cast: false
             - !ScalarExpression
               symbolic_cast:
                 type_var: U
                 operands:
                 - !ScalarExpression
                   scalar_arg: K
+                is_unsigned_cast: false
 --- !LinalgOpConfig
 metadata: !LinalgOpMetadata
   name: conv_2d
@@ -745,12 +833,14 @@ structured_op: !LinalgStructuredOpConfig
                 operands:
                 - !ScalarExpression
                   scalar_arg: I
+                is_unsigned_cast: false
             - !ScalarExpression
               symbolic_cast:
                 type_var: U
                 operands:
                 - !ScalarExpression
                   scalar_arg: K
+                is_unsigned_cast: false
 --- !LinalgOpConfig
 metadata: !LinalgOpMetadata
   name: conv_3d
@@ -811,12 +901,14 @@ structured_op: !LinalgStructuredOpConfig
                 operands:
                 - !ScalarExpression
                   scalar_arg: I
+                is_unsigned_cast: false
             - !ScalarExpression
               symbolic_cast:
                 type_var: U
                 operands:
                 - !ScalarExpression
                   scalar_arg: K
+                is_unsigned_cast: false
 --- !LinalgOpConfig
 metadata: !LinalgOpMetadata
   name: conv_1d_nwc_wcf
@@ -887,12 +979,14 @@ structured_op: !LinalgStructuredOpConfig
                 operands:
                 - !ScalarExpression
                   scalar_arg: I
+                is_unsigned_cast: false
             - !ScalarExpression
               symbolic_cast:
                 type_var: U
                 operands:
                 - !ScalarExpression
                   scalar_arg: K
+                is_unsigned_cast: false
 --- !LinalgOpConfig
 metadata: !LinalgOpMetadata
   name: conv_2d_nhwc_hwcf
@@ -975,12 +1069,14 @@ structured_op: !LinalgStructuredOpConfig
                 operands:
                 - !ScalarExpression
                   scalar_arg: I
+                is_unsigned_cast: false
             - !ScalarExpression
               symbolic_cast:
                 type_var: U
                 operands:
                 - !ScalarExpression
                   scalar_arg: K
+                is_unsigned_cast: false
 --- !LinalgOpConfig
 metadata: !LinalgOpMetadata
   name: conv_2d_nhwc_hwcf_q
@@ -1080,12 +1176,14 @@ structured_op: !LinalgStructuredOpConfig
                     operands:
                     - !ScalarExpression
                       scalar_arg: I
+                    is_unsigned_cast: false
                 - !ScalarExpression
                   symbolic_cast:
                     type_var: U
                     operands:
                     - !ScalarExpression
                       scalar_arg: IZp
+                    is_unsigned_cast: false
             - !ScalarExpression
               scalar_apply:
                 fn_name: sub
@@ -1096,12 +1194,14 @@ structured_op: !LinalgStructuredOpConfig
                     operands:
                     - !ScalarExpression
                       scalar_arg: K
+                    is_unsigned_cast: false
                 - !ScalarExpression
                   symbolic_cast:
                     type_var: U
                     operands:
                     - !ScalarExpression
                       scalar_arg: KZp
+                    is_unsigned_cast: false
 --- !LinalgOpConfig
 metadata: !LinalgOpMetadata
   name: conv_2d_nchw_fchw
@@ -1184,12 +1284,14 @@ structured_op: !LinalgStructuredOpConfig
                 operands:
                 - !ScalarExpression
                   scalar_arg: I
+                is_unsigned_cast: false
             - !ScalarExpression
               symbolic_cast:
                 type_var: U
                 operands:
                 - !ScalarExpression
                   scalar_arg: K
+                is_unsigned_cast: false
 --- !LinalgOpConfig
 metadata: !LinalgOpMetadata
   name: conv_3d_ndhwc_dhwcf
@@ -1272,12 +1374,14 @@ structured_op: !LinalgStructuredOpConfig
                 operands:
                 - !ScalarExpression
                   scalar_arg: I
+                is_unsigned_cast: false
             - !ScalarExpression
               symbolic_cast:
                 type_var: U
                 operands:
                 - !ScalarExpression
                   scalar_arg: K
+                is_unsigned_cast: false
 --- !LinalgOpConfig
 metadata: !LinalgOpMetadata
   name: depthwise_conv2D_nhw
@@ -1353,12 +1457,14 @@ structured_op: !LinalgStructuredOpConfig
                 operands:
                 - !ScalarExpression
                   scalar_arg: I
+                is_unsigned_cast: false
             - !ScalarExpression
               symbolic_cast:
                 type_var: U
                 operands:
                 - !ScalarExpression
                   scalar_arg: K
+                is_unsigned_cast: false
 --- !LinalgOpConfig
 metadata: !LinalgOpMetadata
   name: depthwise_conv2D_nhw_q
@@ -1449,12 +1555,14 @@ structured_op: !LinalgStructuredOpConfig
                     operands:
                     - !ScalarExpression
                       scalar_arg: I
+                    is_unsigned_cast: false
                 - !ScalarExpression
                   symbolic_cast:
                     type_var: U
                     operands:
                     - !ScalarExpression
                       scalar_arg: IZp
+                    is_unsigned_cast: false
             - !ScalarExpression
               scalar_apply:
                 fn_name: sub
@@ -1465,12 +1573,14 @@ structured_op: !LinalgStructuredOpConfig
                     operands:
                     - !ScalarExpression
                       scalar_arg: K
+                    is_unsigned_cast: false
                 - !ScalarExpression
                   symbolic_cast:
                     type_var: U
                     operands:
                     - !ScalarExpression
                       scalar_arg: KZp
+                    is_unsigned_cast: false
 --- !LinalgOpConfig
 metadata: !LinalgOpMetadata
   name: depthwise_conv2D_nhwc
@@ -1549,12 +1659,14 @@ structured_op: !LinalgStructuredOpConfig
                 operands:
                 - !ScalarExpression
                   scalar_arg: I
+                is_unsigned_cast: false
             - !ScalarExpression
               symbolic_cast:
                 type_var: U
                 operands:
                 - !ScalarExpression
                   scalar_arg: K
+                is_unsigned_cast: false
 --- !LinalgOpConfig
 metadata: !LinalgOpMetadata
   name: depthwise_conv2D_nhwc_q
@@ -1649,12 +1761,14 @@ structured_op: !LinalgStructuredOpConfig
                     operands:
                     - !ScalarExpression
                       scalar_arg: I
+                    is_unsigned_cast: false
                 - !ScalarExpression
                   symbolic_cast:
                     type_var: U
                     operands:
                     - !ScalarExpression
                       scalar_arg: IZp
+                    is_unsigned_cast: false
             - !ScalarExpression
               scalar_apply:
                 fn_name: sub
@@ -1665,12 +1779,14 @@ structured_op: !LinalgStructuredOpConfig
                     operands:
                     - !ScalarExpression
                       scalar_arg: K
+                    is_unsigned_cast: false
                 - !ScalarExpression
                   symbolic_cast:
                     type_var: U
                     operands:
                     - !ScalarExpression
                       scalar_arg: KZp
+                    is_unsigned_cast: false
 --- !LinalgOpConfig
 metadata: !LinalgOpMetadata
   name: pooling_nhwc_sum
@@ -1741,6 +1857,7 @@ structured_op: !LinalgStructuredOpConfig
             operands:
             - !ScalarExpression
               scalar_arg: I
+            is_unsigned_cast: false
 --- !LinalgOpConfig
 metadata: !LinalgOpMetadata
   name: pooling_nhwc_max
@@ -1811,6 +1928,78 @@ structured_op: !LinalgStructuredOpConfig
             operands:
             - !ScalarExpression
               scalar_arg: I
+            is_unsigned_cast: false
+--- !LinalgOpConfig
+metadata: !LinalgOpMetadata
+  name: pooling_nhwc_max_unsigned
+  cpp_class_name: PoolingNhwcMaxUnsignedOp
+  doc: |-
+    Performs unsigned max pooling.
+
+    Numeric casting is performed on the input operand, promoting it to the same
+    data type as the accumulator/output.
+  implements:
+  - LinalgConvolutionOpInterface
+structured_op: !LinalgStructuredOpConfig
+  args:
+  - !LinalgOperandDefConfig
+    name: I
+    usage: InputOperand
+    type_var: T1
+    shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (s0, s1 *
+      s2 + s3 * s4, s5 * s6 + s7 * s8, s9)>
+  - !LinalgOperandDefConfig
+    name: K
+    usage: InputOperand
+    type_var: T2
+    shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (s3, s7)>
+  - !LinalgOperandDefConfig
+    name: O
+    usage: OutputOperand
+    type_var: U
+    shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (s0, s1, s5,
+      s9)>
+  - !LinalgOperandDefConfig
+    name: strides
+    usage: IndexAttribute
+    type_var: I64
+    attribute_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (s2, s6)>
+  - !LinalgOperandDefConfig
+    name: dilations
+    usage: IndexAttribute
+    type_var: I64
+    attribute_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (s4, s8)>
+  indexing_maps: !LinalgIndexingMapsConfig
+    static_indexing_maps:
+    - affine_map<(d0, d1, d2, d3, d4, d5)[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9]
+      -> (d0, d1 * s2 + d3 * s4, d2 * s6 + d4 * s8, d5)>
+    - affine_map<(d0, d1, d2, d3, d4, d5)[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9]
+      -> (d3, d4)>
+    - affine_map<(d0, d1, d2, d3, d4, d5)[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9]
+      -> (d0, d1, d2, d5)>
+  iterator_types:
+  - parallel
+  - parallel
+  - parallel
+  - reduction
+  - reduction
+  - parallel
+  assignments:
+  - !ScalarAssign
+    arg: O
+    value: !ScalarExpression
+      scalar_apply:
+        fn_name: max_unsigned
+        operands:
+        - !ScalarExpression
+          scalar_arg: O
+        - !ScalarExpression
+          symbolic_cast:
+            type_var: U
+            operands:
+            - !ScalarExpression
+              scalar_arg: I
+            is_unsigned_cast: true
 --- !LinalgOpConfig
 metadata: !LinalgOpMetadata
   name: pooling_nchw_max
@@ -1881,6 +2070,7 @@ structured_op: !LinalgStructuredOpConfig
             operands:
             - !ScalarExpression
               scalar_arg: I
+            is_unsigned_cast: false
 --- !LinalgOpConfig
 metadata: !LinalgOpMetadata
   name: pooling_nhwc_min
@@ -1951,6 +2141,78 @@ structured_op: !LinalgStructuredOpConfig
             operands:
             - !ScalarExpression
               scalar_arg: I
+            is_unsigned_cast: false
+--- !LinalgOpConfig
+metadata: !LinalgOpMetadata
+  name: pooling_nhwc_min_unsigned
+  cpp_class_name: PoolingNhwcMinUnsignedOp
+  doc: |-
+    Performs unsigned min pooling.
+
+    Numeric casting is performed on the input operand, promoting it to the same
+    data type as the accumulator/output.
+  implements:
+  - LinalgConvolutionOpInterface
+structured_op: !LinalgStructuredOpConfig
+  args:
+  - !LinalgOperandDefConfig
+    name: I
+    usage: InputOperand
+    type_var: T1
+    shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (s0, s1 *
+      s2 + s3 * s4, s5 * s6 + s7 * s8, s9)>
+  - !LinalgOperandDefConfig
+    name: K
+    usage: InputOperand
+    type_var: T2
+    shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (s3, s7)>
+  - !LinalgOperandDefConfig
+    name: O
+    usage: OutputOperand
+    type_var: U
+    shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (s0, s1, s5,
+      s9)>
+  - !LinalgOperandDefConfig
+    name: strides
+    usage: IndexAttribute
+    type_var: I64
+    attribute_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (s2, s6)>
+  - !LinalgOperandDefConfig
+    name: dilations
+    usage: IndexAttribute
+    type_var: I64
+    attribute_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (s4, s8)>
+  indexing_maps: !LinalgIndexingMapsConfig
+    static_indexing_maps:
+    - affine_map<(d0, d1, d2, d3, d4, d5)[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9]
+      -> (d0, d1 * s2 + d3 * s4, d2 * s6 + d4 * s8, d5)>
+    - affine_map<(d0, d1, d2, d3, d4, d5)[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9]
+      -> (d3, d4)>
+    - affine_map<(d0, d1, d2, d3, d4, d5)[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9]
+      -> (d0, d1, d2, d5)>
+  iterator_types:
+  - parallel
+  - parallel
+  - parallel
+  - reduction
+  - reduction
+  - parallel
+  assignments:
+  - !ScalarAssign
+    arg: O
+    value: !ScalarExpression
+      scalar_apply:
+        fn_name: min_unsigned
+        operands:
+        - !ScalarExpression
+          scalar_arg: O
+        - !ScalarExpression
+          symbolic_cast:
+            type_var: U
+            operands:
+            - !ScalarExpression
+              scalar_arg: I
+            is_unsigned_cast: true
 --- !LinalgOpConfig
 metadata: !LinalgOpMetadata
   name: pooling_ndhwc_sum
@@ -2027,6 +2289,7 @@ structured_op: !LinalgStructuredOpConfig
             operands:
             - !ScalarExpression
               scalar_arg: I
+            is_unsigned_cast: false
 --- !LinalgOpConfig
 metadata: !LinalgOpMetadata
   name: pooling_ndhwc_max
@@ -2103,6 +2366,7 @@ structured_op: !LinalgStructuredOpConfig
             operands:
             - !ScalarExpression
               scalar_arg: I
+            is_unsigned_cast: false
 --- !LinalgOpConfig
 metadata: !LinalgOpMetadata
   name: pooling_ndhwc_min
@@ -2179,6 +2443,7 @@ structured_op: !LinalgStructuredOpConfig
             operands:
             - !ScalarExpression
               scalar_arg: I
+            is_unsigned_cast: false
 --- !LinalgOpConfig
 metadata: !LinalgOpMetadata
   name: fill_rng_2d
@@ -2246,6 +2511,7 @@ structured_op: !LinalgStructuredOpConfig
                         operands:
                         - !ScalarExpression
                           scalar_const: '2147483647 : i64'
+                        is_unsigned_cast: false
                     - !ScalarExpression
                       symbolic_cast:
                         type_var: F64
@@ -2268,6 +2534,7 @@ structured_op: !LinalgStructuredOpConfig
                                         operands:
                                         - !ScalarExpression
                                           scalar_index: 1
+                                        is_unsigned_cast: false
                                     - !ScalarExpression
                                       scalar_apply:
                                         fn_name: add
@@ -2286,6 +2553,7 @@ structured_op: !LinalgStructuredOpConfig
                                                     operands:
                                                     - !ScalarExpression
                                                       scalar_index: 0
+                                                    is_unsigned_cast: false
                                                 - !ScalarExpression
                                                   scalar_arg: seed
                                             - !ScalarExpression
@@ -2294,24 +2562,29 @@ structured_op: !LinalgStructuredOpConfig
                                                 operands:
                                                 - !ScalarExpression
                                                   scalar_const: '1103515245 : i64'
+                                                is_unsigned_cast: false
                                         - !ScalarExpression
                                           symbolic_cast:
                                             type_var: I32
                                             operands:
                                             - !ScalarExpression
                                               scalar_const: '12345 : i64'
+                                            is_unsigned_cast: false
                                 - !ScalarExpression
                                   symbolic_cast:
                                     type_var: I32
                                     operands:
                                     - !ScalarExpression
                                       scalar_const: '1103515245 : i64'
+                                    is_unsigned_cast: false
                             - !ScalarExpression
                               symbolic_cast:
                                 type_var: I32
                                 operands:
                                 - !ScalarExpression
                                   scalar_const: '12345 : i64'
+                                is_unsigned_cast: false
+                        is_unsigned_cast: false
                 - !ScalarExpression
                   scalar_apply:
                     fn_name: mul
@@ -2330,8 +2603,10 @@ structured_op: !LinalgStructuredOpConfig
                         operands:
                         - !ScalarExpression
                           scalar_const: '2.3283063999999999E-10 : f64'
+                        is_unsigned_cast: false
             - !ScalarExpression
               scalar_arg: min
+        is_unsigned_cast: false
 --- !LinalgOpConfig
 metadata: !LinalgOpMetadata
   name: soft_plus_2d
@@ -2377,6 +2652,7 @@ structured_op: !LinalgStructuredOpConfig
                 operands:
                 - !ScalarExpression
                   scalar_const: '1.000000e+00 : f64'
+                is_unsigned_cast: false
             - !ScalarExpression
               scalar_apply:
                 fn_name: exp
@@ -2387,3 +2663,4 @@ structured_op: !LinalgStructuredOpConfig
                     operands:
                     - !ScalarExpression
                       scalar_arg: I
+                    is_unsigned_cast: false

diff  --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 69cd9e25e5d94..2cba281acf8f0 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -196,7 +196,7 @@ class RegionBuilderHelper {
   // If the cast cannot be performed, a warning will be issued and the
   // operand returned as-is (which will presumably yield a verification
   // issue downstream).
-  Value cast(Type toType, Value operand) {
+  Value cast(Type toType, Value operand, bool isUnsignedCast) {
     OpBuilder builder = getBuilder();
     auto loc = operand.getLoc();
 
@@ -204,23 +204,32 @@ class RegionBuilderHelper {
       return operand;
     if (auto toIntType = toType.dyn_cast<IntegerType>()) {
       // If operand is floating point, cast directly to the int type.
-      if (operand.getType().isa<FloatType>())
+      if (operand.getType().isa<FloatType>()) {
+        if (isUnsignedCast)
+          return builder.create<FPToUIOp>(loc, toType, operand);
         return builder.create<FPToSIOp>(loc, toType, operand);
+      }
       // Cast index operands directly to the int type.
       if (operand.getType().isIndex())
         return builder.create<IndexCastOp>(loc, toType, operand);
       if (auto fromIntType = operand.getType().dyn_cast<IntegerType>()) {
-        // Either sign extend or truncate.
-        if (toIntType.getWidth() > fromIntType.getWidth())
+        // Either extend or truncate.
+        if (toIntType.getWidth() > fromIntType.getWidth()) {
+          if (isUnsignedCast)
+            return builder.create<ZeroExtendIOp>(loc, toType, operand);
           return builder.create<SignExtendIOp>(loc, toType, operand);
+        }
         if (toIntType.getWidth() < fromIntType.getWidth())
           return builder.create<TruncateIOp>(loc, toType, operand);
       }
     } else if (auto toFloatType = toType.dyn_cast<FloatType>()) {
       // If operand is integer, cast directly to the float type.
       // Note that it is unclear how to cast from BF16<->FP16.
-      if (operand.getType().isa<IntegerType>())
+      if (operand.getType().isa<IntegerType>()) {
+        if (isUnsignedCast)
+          return builder.create<UIToFPOp>(loc, toFloatType, operand);
         return builder.create<SIToFPOp>(loc, toFloatType, operand);
+      }
       if (auto fromFloatType = operand.getType().dyn_cast<FloatType>()) {
         if (toFloatType.getWidth() > fromFloatType.getWidth())
           return builder.create<FPExtOp>(loc, toFloatType, operand);
@@ -284,6 +293,15 @@ class RegionBuilderHelper {
     llvm_unreachable("unsupported non numeric type");
   }
 
+  Value applyfn__max_unsigned(Value lhs, Value rhs) {
+    OpBuilder builder = getBuilder();
+    if (isFloatingPoint(lhs))
+      return builder.create<MaxFOp>(lhs.getLoc(), lhs, rhs);
+    if (isInteger(lhs))
+      return builder.create<MaxUIOp>(lhs.getLoc(), lhs, rhs);
+    llvm_unreachable("unsupported non numeric type");
+  }
+
   Value applyfn__min(Value lhs, Value rhs) {
     OpBuilder builder = getBuilder();
     if (isFloatingPoint(lhs))
@@ -293,6 +311,15 @@ class RegionBuilderHelper {
     llvm_unreachable("unsupported non numeric type");
   }
 
+  Value applyfn__min_unsigned(Value lhs, Value rhs) {
+    OpBuilder builder = getBuilder();
+    if (isFloatingPoint(lhs))
+      return builder.create<MinFOp>(lhs.getLoc(), lhs, rhs);
+    if (isInteger(lhs))
+      return builder.create<MinUIOp>(lhs.getLoc(), lhs, rhs);
+    llvm_unreachable("unsupported non numeric type");
+  }
+
   void yieldOutputs(ValueRange values) {
     assert(!values.empty() && "linalg ops must yield outputs");
     if (values.empty())

diff  --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py
index c3894002914fa..732cacfff63ab 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py
@@ -340,6 +340,8 @@ class PrimFn:
   max = PrimFnType("max")
   min = PrimFnType("min")
   sub = PrimFnType("sub")
+  max_unsigned = PrimFnType("max_unsigned")
+  min_unsigned = PrimFnType("min_unsigned")
 
 
 class ReduceFnType:
@@ -365,6 +367,8 @@ class ReduceFn:
   mul = PrimFn.mul.reduce
   max = PrimFn.max.reduce
   min = PrimFn.min.reduce
+  max_unsigned = PrimFn.max_unsigned.reduce
+  min_unsigned = PrimFn.min_unsigned.reduce
 
 
 class PrimApply(TensorExpression):
@@ -438,8 +442,8 @@ def __init__(self, to_type: TypeVar, operand: TensorExpression):
     self.operand = operand
 
   def to_scalar_expression(self) -> ScalarExpression:
-    return ScalarSymbolicCast(self.to_type,
-                              self.operand.to_scalar_expression()).expr()
+    return ScalarSymbolicCast(self.to_type, self.operand.to_scalar_expression(),
+                              False).expr()
 
   def visit_tensor_exprs(self, callback):
     super().visit_tensor_exprs(callback)
@@ -449,6 +453,17 @@ def __repr__(self):
     return f"cast({self.to_type}, {repr(self.operand)})"
 
 
+class cast_unsigned(cast):
+  """Casts the element type to an unsigned type (typically symbolic TypeVar)."""
+
+  def to_scalar_expression(self) -> ScalarExpression:
+    return ScalarSymbolicCast(self.to_type, self.operand.to_scalar_expression(),
+                              True).expr()
+
+  def __repr__(self):
+    return f"cast_unsigned({self.to_type}, {repr(self.operand)})"
+
+
 class ReduceApply(TensorExpression):
   """Application of a reduction.
 

diff  --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
index 4a883e79037b5..7feea040aa77c 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
@@ -230,10 +230,12 @@ def expression(self, expr: ScalarExpression) -> Value:
       return fn(*operand_values)
     elif expr.symbolic_cast:
       operand_value = self.expression(expr.symbolic_cast.operand)
-      return self.cast(expr.symbolic_cast.to_type.name, operand_value)
+      return self.cast(expr.symbolic_cast.to_type.name, operand_value,
+                       expr.symbolic_cast.is_unsigned_cast)
     raise NotImplementedError(f"Unimplemented scalar body expression: {expr}")
 
-  def cast(self, type_var_name: str, operand: Value) -> Value:
+  def cast(self, type_var_name: str, operand: Value,
+           is_unsigned_cast: bool) -> Value:
     try:
       to_type = self.type_mapping[type_var_name]
     except KeyError:
@@ -242,29 +244,37 @@ def cast(self, type_var_name: str, operand: Value) -> Value:
     if operand.type == to_type:
       return operand
     if _is_integer_type(to_type):
-      return self._cast_to_integer(to_type, operand)
+      return self._cast_to_integer(to_type, operand, is_unsigned_cast)
     elif _is_floating_point_type(to_type):
-      return self._cast_to_floating_point(to_type, operand)
+      return self._cast_to_floating_point(to_type, operand, is_unsigned_cast)
 
-  def _cast_to_integer(self, to_type: Type, operand: Value) -> Value:
+  def _cast_to_integer(self, to_type: Type, operand: Value,
+                       is_unsigned_cast: bool) -> Value:
     to_width = IntegerType(to_type).width
     operand_type = operand.type
     if _is_floating_point_type(operand_type):
+      if is_unsigned_cast:
+        return std.FPToUIOp(to_type, operand).result
       return std.FPToSIOp(to_type, operand).result
     if _is_index_type(operand_type):
       return std.IndexCastOp(to_type, operand).result
     # Assume integer.
     from_width = IntegerType(operand_type).width
     if to_width > from_width:
+      if is_unsigned_cast:
+        return std.ZeroExtendIOp(to_type, operand).result
       return std.SignExtendIOp(to_type, operand).result
     elif to_width < from_width:
       return std.TruncateIOp(to_type, operand).result
     raise ValueError(f"Unable to cast body expression from {operand_type} to "
                      f"{to_type}")
 
-  def _cast_to_floating_point(self, to_type: Type, operand: Value) -> Value:
+  def _cast_to_floating_point(self, to_type: Type, operand: Value,
+                              is_unsigned_cast: bool) -> Value:
     operand_type = operand.type
     if _is_integer_type(operand_type):
+      if is_unsigned_cast:
+        return std.UIToFPOp(to_type, operand).result
       return std.SIToFPOp(to_type, operand).result
     # Assume FloatType.
     to_width = _get_floating_point_width(to_type)
@@ -324,6 +334,13 @@ def _eval_max(self, lhs: Value, rhs: Value) -> Value:
       return std.MaxSIOp(lhs.type, lhs, rhs).result
     raise NotImplementedError("Unsupported 'max' operand: {lhs}")
 
+  def _eval_max_unsigned(self, lhs: Value, rhs: Value) -> Value:
+    if _is_floating_point_type(lhs.type):
+      return std.MaxFOp(lhs.type, lhs, rhs).result
+    if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
+      return std.MaxUIOp(lhs.type, lhs, rhs).result
+    raise NotImplementedError(f"Unsupported 'max_unsigned' operand: {lhs}")
+
   def _eval_min(self, lhs: Value, rhs: Value) -> Value:
     if _is_floating_point_type(lhs.type):
       return std.MinFOp(lhs.type, lhs, rhs).result
@@ -331,6 +348,12 @@ def _eval_min(self, lhs: Value, rhs: Value) -> Value:
       return std.MinSIOp(lhs.type, lhs, rhs).result
     raise NotImplementedError("Unsupported 'min' operand: {lhs}")
 
+  def _eval_min_unsigned(self, lhs: Value, rhs: Value) -> Value:
+    if _is_floating_point_type(lhs.type):
+      return std.MinFOp(lhs.type, lhs, rhs).result
+    if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
+      return std.MinUIOp(lhs.type, lhs, rhs).result
+    raise NotImplementedError(f"Unsupported 'min_unsigned' operand: {lhs}")
 
 def _infer_structured_outs(op_config: LinalgStructuredOpConfig,
                            in_arg_defs: Sequence[OperandDefConfig],

diff  --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/scalar_expr.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/scalar_expr.py
index 48627bfab544c..6de3333fbf200 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/lang/scalar_expr.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/scalar_expr.py
@@ -85,15 +85,17 @@ def __repr__(self):
 class ScalarSymbolicCast:
   """A type of ScalarExpression that symbolically casts an operand to a TypeVar."""
 
-  def __init__(self, to_type: TypeVar, operand: "ScalarExpression"):
+  def __init__(self, to_type: TypeVar, operand: "ScalarExpression",
+               is_unsigned_cast: bool):
     self.to_type = to_type
     self.operand = operand
+    self.is_unsigned_cast = is_unsigned_cast
 
   def expr(self) -> "ScalarExpression":
     return ScalarExpression(symbolic_cast=self)
 
   def __repr__(self):
-    return f"ScalarSymbolicCast({self.to_type}, {self.operand})"
+    return f"ScalarSymbolicCast({self.to_type}, {self.operand}, {self.is_unsigned_cast})"
 
 
 class ScalarExpression(YAMLObject):
@@ -144,7 +146,8 @@ def to_yaml_custom_dict(self):
       return dict(
           symbolic_cast=dict(
               type_var=self.symbolic_cast.to_type.name,
-              operands=[self.symbolic_cast.operand]))
+              operands=[self.symbolic_cast.operand],
+              is_unsigned_cast=self.symbolic_cast.is_unsigned_cast))
     else:
       raise ValueError(f"Unexpected ScalarExpression type: {self}")
 

diff  --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
index b78a2179737f5..9f5b27ea000eb 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
@@ -20,6 +20,20 @@ def matmul(
   implements(ContractionOpInterface)
   C[D.m, D.n] += cast(U, A[D.m, D.k]) * cast(U, B[D.k, D.n])
 
+@linalg_structured_op
+def matmul_unsigned(
+    A=TensorDef(T1, S.M, S.K),
+    B=TensorDef(T2, S.K, S.N),
+    C=TensorDef(U, S.M, S.N, output=True)):
+  """Performs an unsigned matrix multiplication of two 2D inputs.
+
+  Numeric casting is performed on the operands to the inner multiply, promoting
+  them to the same data type as the accumulator/output.
+  """
+  domain(D.m, D.n, D.k)
+  implements(ContractionOpInterface)
+  C[D.m, D.n] += cast_unsigned(U, A[D.m, D.k]) * cast_unsigned(U, B[D.k, D.n])
+
 @linalg_structured_op
 def quantized_matmul(
     A=TensorDef(T1, S.M, S.K),
@@ -411,6 +425,24 @@ def pooling_nhwc_max(
       cast(U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW,
                 D.c]))
 
+@linalg_structured_op
+def pooling_nhwc_max_unsigned(
+    I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW, S.C),
+    K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]),
+    O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True),
+    strides=AttributeDef(S.SH, S.SW),
+    dilations=AttributeDef(S.DH, S.DW)):
+  """Performs unsigned max pooling.
+
+  Numeric casting is performed on the input operand, promoting it to the same
+  data type as the accumulator/output.
+  """
+  implements(ConvolutionOpInterface)
+  domain(D.n, D.oh, D.ow, D.kh, D.kw, D.c)
+  O[D.n, D.oh, D.ow, D.c] = ReduceFn.max_unsigned(D.kh, D.kw)(
+      cast_unsigned(
+          U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c]))
+
 @linalg_structured_op
 def pooling_nchw_max(
     I=TensorDef(T1, S.N, S.C, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW),
@@ -447,6 +479,23 @@ def pooling_nhwc_min(
       cast(U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW,
                 D.c]))
 
+@linalg_structured_op
+def pooling_nhwc_min_unsigned(
+    I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW, S.C),
+    K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]),
+    O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True),
+    strides=AttributeDef(S.SH, S.SW),
+    dilations=AttributeDef(S.DH, S.DW)):
+  """Performs unsigned min pooling.
+
+  Numeric casting is performed on the input operand, promoting it to the same
+  data type as the accumulator/output.
+  """
+  implements(ConvolutionOpInterface)
+  domain(D.n, D.oh, D.ow, D.kh, D.kw, D.c)
+  O[D.n, D.oh, D.ow, D.c] = ReduceFn.min_unsigned(D.kh, D.kw)(
+      cast_unsigned(
+          U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c]))
 
 @linalg_structured_op
 def pooling_ndhwc_sum(

diff  --git a/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir b/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir
index 89fd83e585eef..5d330a8e42721 100644
--- a/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir
+++ b/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir
@@ -1,35 +1,108 @@
 // RUN: mlir-opt %s -split-input-file -linalg-generalize-named-ops | FileCheck %s
 
-func @generalize_matmul_tensor_f32(%A : tensor<16x8xf32>, %B: tensor<8x32xf32>, %C: tensor<16x32xf32>) -> tensor<16x32xf32> {
-  %0 = linalg.matmul ins(%A, %B: tensor<16x8xf32>, tensor<8x32xf32>)
+// Verifies that different argument types are legal.
+func @generalize_matmul_tensor_f16f64f32(%A : tensor<16x8xf16>, %B: tensor<8x32xf64>, %C: tensor<16x32xf32>) -> tensor<16x32xf32> {
+  %0 = linalg.matmul ins(%A, %B: tensor<16x8xf16>, tensor<8x32xf64>)
                           outs(%C: tensor<16x32xf32>) -> tensor<16x32xf32>
   return %0: tensor<16x32xf32>
 }
 
-// CHECK-LABEL: @generalize_matmul_tensor_f32
-// CHECK:      ^{{.*}}(%[[A_ARG:.+]]: f32, %[[B_ARG:.+]]: f32, %[[C_ARG:.+]]: f32)
-// CHECK-NEXT:   %[[MUL:.+]] = mulf %[[A_ARG]], %[[B_ARG]] : f32
+// CHECK-LABEL: @generalize_matmul_tensor_f16f64f32
+// CHECK:      ^{{.*}}(%[[A_ARG:.+]]: f16, %[[B_ARG:.+]]: f64, %[[C_ARG:.+]]: f32)
+// Verify floating point extension and truncation.
+// CHECK-NEXT:   %[[A_CAST:.+]] = fpext %[[A_ARG]] : f16 to f32
+// CHECK-NEXT:   %[[B_CAST:.+]] = fptrunc %[[B_ARG]] : f64 to f32
+// CHECK-NEXT:   %[[MUL:.+]] = mulf %[[A_CAST]], %[[B_CAST]] : f32
 // CHECK-NEXT:   %[[ADD:.+]] = addf %[[C_ARG]], %[[MUL]] : f32
 // CHECK-NEXT:   linalg.yield %[[ADD]] : f32
 // CHECK-NEXT: -> tensor<16x32xf32>
 
 // -----
 
-func @generalize_matmul_tensor_i32(%A : tensor<16x8xi32>, %B: tensor<8x32xi32>, %C: tensor<16x32xi32>) -> tensor<16x32xi32> {
-  %0 = linalg.matmul ins(%A, %B: tensor<16x8xi32>, tensor<8x32xi32>)
+// Verifies that different argument types are legal.
+func @generalize_matmul_tensor_i16i64i32(%A : tensor<16x8xi16>, %B: tensor<8x32xi64>, %C: tensor<16x32xi32>) -> tensor<16x32xi32> {
+  %0 = linalg.matmul ins(%A, %B: tensor<16x8xi16>, tensor<8x32xi64>)
                           outs(%C: tensor<16x32xi32>) -> tensor<16x32xi32>
   return %0: tensor<16x32xi32>
 }
 
-// CHECK-LABEL: @generalize_matmul_tensor_i32
-// CHECK:      ^{{.*}}(%[[A_ARG:.+]]: i32, %[[B_ARG:.+]]: i32, %[[C_ARG:.+]]: i32)
-// CHECK-NEXT:   %[[MUL:.+]] = muli %[[A_ARG]], %[[B_ARG]] : i32
+// CHECK-LABEL: @generalize_matmul_tensor_i16i64i32
+// CHECK:      ^{{.*}}(%[[A_ARG:.+]]: i16, %[[B_ARG:.+]]: i64, %[[C_ARG:.+]]: i32)
+// Verify signed integer extension and truncation.
+// CHECK-NEXT:   %[[A_CAST:.+]] = sexti %[[A_ARG]] : i16 to i32
+// CHECK-NEXT:   %[[B_CAST:.+]] = trunci %[[B_ARG]] : i64 to i32
+// CHECK-NEXT:   %[[MUL:.+]] = muli %[[A_CAST]], %[[B_CAST]] : i32
 // CHECK-NEXT:   %[[ADD:.+]] = addi %[[C_ARG]], %[[MUL]] : i32
 // CHECK-NEXT:   linalg.yield %[[ADD]] : i32
 // CHECK-NEXT: -> tensor<16x32xi32>
 
 // -----
 
+func @generalize_matmul_tensor_i16i64f32(%A : tensor<16x8xi16>, %B: tensor<8x32xi64>, %C: tensor<16x32xf32>) -> tensor<16x32xf32> {
+  %0 = linalg.matmul ins(%A, %B: tensor<16x8xi16>, tensor<8x32xi64>)
+                     outs(%C: tensor<16x32xf32>) -> tensor<16x32xf32>
+  return %0: tensor<16x32xf32>
+}
+
+// CHECK-LABEL: @generalize_matmul_tensor_i16i64f32
+// Verify signed integer to floating point cast.
+// CHECK:        = sitofp
+// CHECK:        = sitofp
+
+// -----
+
+func @generalize_matmul_tensor_f16f64i32(%A : tensor<16x8xf16>, %B: tensor<8x32xf64>, %C: tensor<16x32xi32>) -> tensor<16x32xi32> {
+  %0 = linalg.matmul ins(%A, %B: tensor<16x8xf16>, tensor<8x32xf64>)
+                              outs(%C: tensor<16x32xi32>) -> tensor<16x32xi32>
+  return %0: tensor<16x32xi32>
+}
+
+// CHECK-LABEL: @generalize_matmul_tensor_f16f64i32
+// Verify floating point to signed integer cast.
+// CHECK:        = fptosi
+// CHECK:        = fptosi
+
+// -----
+
+func @generalize_matmul_unsigned_tensor_i16i64i32(%A : tensor<16x8xi16>, %B: tensor<8x32xi64>, %C: tensor<16x32xi32>) -> tensor<16x32xi32> {
+  %0 = linalg.matmul_unsigned ins(%A, %B: tensor<16x8xi16>, tensor<8x32xi64>)
+                              outs(%C: tensor<16x32xi32>) -> tensor<16x32xi32>
+  return %0: tensor<16x32xi32>
+}
+
+// CHECK-LABEL: @generalize_matmul_unsigned_tensor_i16i64i32
+// Verify unsigned integer extension and truncation.
+// CHECK:        = zexti
+// CHECK:        = trunci
+
+// -----
+
+func @generalize_matmul_unsigned_tensor_i16i64f32(%A : tensor<16x8xi16>, %B: tensor<8x32xi64>, %C: tensor<16x32xf32>) -> tensor<16x32xf32> {
+  %0 = linalg.matmul_unsigned ins(%A, %B: tensor<16x8xi16>, tensor<8x32xi64>)
+                              outs(%C: tensor<16x32xf32>) -> tensor<16x32xf32>
+  return %0: tensor<16x32xf32>
+}
+
+// CHECK-LABEL: @generalize_matmul_unsigned_tensor_i16i64f32
+// Verify unsigned integer to floating point cast.
+// CHECK:        = uitofp
+// CHECK:        = uitofp
+
+// -----
+
+func @generalize_matmul_unsigned_tensor_f16f64i32(%A : tensor<16x8xf16>, %B: tensor<8x32xf64>, %C: tensor<16x32xi32>) -> tensor<16x32xi32> {
+  %0 = linalg.matmul_unsigned ins(%A, %B: tensor<16x8xf16>, tensor<8x32xf64>)
+                              outs(%C: tensor<16x32xi32>) -> tensor<16x32xi32>
+  return %0: tensor<16x32xi32>
+}
+
+// CHECK-LABEL: @generalize_matmul_unsigned_tensor_f16f64i32
+// Verify floating point to unsigned integer cast.
+// CHECK:        = fptoui
+// CHECK:        = fptoui
+
+// -----
+
 func @generalize_pooling_nhwc_max_f32(%input : tensor<1x4x16x1xf32>, %shape: tensor<2x2xf32>, %output: tensor<1x2x4x1xf32>) -> tensor<1x2x4x1xf32> {
   %0 = linalg.pooling_nhwc_max {dilations = dense<[1, 2]> : tensor<2xi64>, strides = dense<[2, 4]> : tensor<2xi64>}
     ins(%input, %shape : tensor<1x4x16x1xf32>, tensor<2x2xf32>) outs(%output : tensor<1x2x4x1xf32>) -> tensor<1x2x4x1xf32>
@@ -51,10 +124,20 @@ func @generalize_pooling_nhwc_max_i32(%input : tensor<1x4x16x1xi32>, %shape: ten
 }
 
 // CHECK-LABEL: @generalize_pooling_nhwc_max_i32
-// CHECK:      ^{{.*}}(%[[IN_ARG:.+]]: i32, %[[SHAPE_ARG:.+]]: i32, %[[OUT_ARG:.+]]: i32)
-// CHECK-NEXT:   %[[MAX:.+]] = maxsi %[[OUT_ARG]], %[[IN_ARG]] : i32
-// CHECK-NEXT:   linalg.yield %[[MAX]] : i32
-// CHECK-NEXT: -> tensor<1x2x4x1xi32>
+// Verify signed integer maximum.
+// CHECK:        = maxsi
+
+// -----
+
+func @generalize_pooling_nhwc_max_unsigned_i32(%input : tensor<1x4x16x1xi32>, %shape: tensor<2x2xi32>, %output: tensor<1x2x4x1xi32>) -> tensor<1x2x4x1xi32> {
+  %0 = linalg.pooling_nhwc_max_unsigned {dilations = dense<[1, 2]> : tensor<2xi64>, strides = dense<[2, 4]> : tensor<2xi64>}
+    ins(%input, %shape : tensor<1x4x16x1xi32>, tensor<2x2xi32>) outs(%output : tensor<1x2x4x1xi32>) -> tensor<1x2x4x1xi32>
+  return %0: tensor<1x2x4x1xi32>
+}
+
+// CHECK-LABEL: @generalize_pooling_nhwc_max_unsigned_i32
+// Verify unsigned integer maximum.
+// CHECK:        = maxui
 
 // -----
 
@@ -79,10 +162,20 @@ func @generalize_pooling_nhwc_min_i32(%input : tensor<1x4x16x1xi32>, %shape: ten
 }
 
 // CHECK-LABEL: @generalize_pooling_nhwc_min_i32
-// CHECK:      ^{{.*}}(%[[IN_ARG:.+]]: i32, %[[SHAPE_ARG:.+]]: i32, %[[OUT_ARG:.+]]: i32)
-// CHECK-NEXT:   %[[MIN:.+]] = minsi %[[OUT_ARG]], %[[IN_ARG]] : i32
-// CHECK-NEXT:   linalg.yield %[[MIN]] : i32
-// CHECK-NEXT: -> tensor<1x2x4x1xi32>
+// Verify signed integer minimum.
+// CHECK:        = minsi
+
+// -----
+
+func @generalize_pooling_nhwc_min_unsigned_i32(%input : tensor<1x4x16x1xi32>, %shape: tensor<2x2xi32>, %output: tensor<1x2x4x1xi32>) -> tensor<1x2x4x1xi32> {
+  %0 = linalg.pooling_nhwc_min_unsigned {dilations = dense<[1, 2]> : tensor<2xi64>, strides = dense<[2, 4]> : tensor<2xi64>}
+    ins(%input, %shape : tensor<1x4x16x1xi32>, tensor<2x2xi32>) outs(%output : tensor<1x2x4x1xi32>) -> tensor<1x2x4x1xi32>
+  return %0: tensor<1x2x4x1xi32>
+}
+
+// CHECK-LABEL: @generalize_pooling_nhwc_min_unsigned_i32
+// Verify unsigned integer minimum.
+// CHECK:        = minui
 
 // -----
 
@@ -169,122 +262,3 @@ func @generalize_soft_plus_2d_f32(%input: tensor<16x32xf32>, %output: tensor<16x
 // CHECK-NEXT:   %[[LOG:.+]] = math.log %[[SUM]] : f32
 // CHECK-NEXT:   linalg.yield %[[LOG]] : f32
 // CHECK-NEXT: -> tensor<16x32xf32>
-
-// -----
-// Verifies floating point to integer cast.
-func @generalize_matmul_tensor_f32_f32_i16(%A : tensor<16x8xf32>, %B: tensor<8x32xf32>, %C: tensor<16x32xi16>) -> tensor<16x32xi16> {
-  %0 = linalg.matmul ins(%A, %B: tensor<16x8xf32>, tensor<8x32xf32>)
-                          outs(%C: tensor<16x32xi16>) -> tensor<16x32xi16>
-  return %0: tensor<16x32xi16>
-}
-
-// CHECK-LABEL: @generalize_matmul_tensor_f32_f32_i16
-// CHECK:      ^{{.*}}(%[[A_ARG:.+]]: f32, %[[B_ARG:.+]]: f32, %[[C_ARG:.+]]: i16)
-// CHECK-NEXT:   %[[A_CAST:.+]] = fptosi %[[A_ARG]] : f32 to i16
-// CHECK-NEXT:   %[[B_CAST:.+]] = fptosi %[[B_ARG]] : f32 to i16
-// CHECK-NEXT:   %[[MUL:.+]] = muli %[[A_CAST]], %[[B_CAST]] : i16
-// CHECK-NEXT:   %[[ADD:.+]] = addi %[[C_ARG]], %[[MUL]] : i16
-// CHECK-NEXT:   linalg.yield %[[ADD]] : i16
-// CHECK-NEXT: -> tensor<16x32xi16>
-
-// -----
-// Verifies sign extension cast.
-func @generalize_matmul_tensor_i8_i8_i32(%A : tensor<16x8xi8>, %B: tensor<8x32xi8>, %C: tensor<16x32xi32>) -> tensor<16x32xi32> {
-  %0 = linalg.matmul ins(%A, %B: tensor<16x8xi8>, tensor<8x32xi8>)
-                          outs(%C: tensor<16x32xi32>) -> tensor<16x32xi32>
-  return %0: tensor<16x32xi32>
-}
-
-// CHECK-LABEL: @generalize_matmul_tensor_i8_i8_i32
-// CHECK:      ^{{.*}}(%[[A_ARG:.+]]: i8, %[[B_ARG:.+]]: i8, %[[C_ARG:.+]]: i32)
-// CHECK-NEXT:   %[[A_CAST:.+]] = sexti %[[A_ARG]] : i8 to i32
-// CHECK-NEXT:   %[[B_CAST:.+]] = sexti %[[B_ARG]] : i8 to i32
-// CHECK-NEXT:   %[[MUL:.+]] = muli %[[A_CAST]], %[[B_CAST]] : i32
-// CHECK-NEXT:   %[[ADD:.+]] = addi %[[C_ARG]], %[[MUL]] : i32
-// CHECK-NEXT:   linalg.yield %[[ADD]] : i32
-// CHECK-NEXT: -> tensor<16x32xi32>
-
-// -----
-// Verifies that different argument types is legal.
-func @generalize_matmul_tensor_i8_i16_i32(%A : tensor<16x8xi8>, %B: tensor<8x32xi16>, %C: tensor<16x32xi32>) -> tensor<16x32xi32> {
-  %0 = linalg.matmul ins(%A, %B: tensor<16x8xi8>, tensor<8x32xi16>)
-                          outs(%C: tensor<16x32xi32>) -> tensor<16x32xi32>
-  return %0: tensor<16x32xi32>
-}
-
-// CHECK-LABEL: @generalize_matmul_tensor_i8_i16_i32
-// CHECK:      ^{{.*}}(%[[A_ARG:.+]]: i8, %[[B_ARG:.+]]: i16, %[[C_ARG:.+]]: i32)
-// CHECK-NEXT:   %[[A_CAST:.+]] = sexti %[[A_ARG]] : i8 to i32
-// CHECK-NEXT:   %[[B_CAST:.+]] = sexti %[[B_ARG]] : i16 to i32
-// CHECK-NEXT:   %[[MUL:.+]] = muli %[[A_CAST]], %[[B_CAST]] : i32
-// CHECK-NEXT:   %[[ADD:.+]] = addi %[[C_ARG]], %[[MUL]] : i32
-// CHECK-NEXT:   linalg.yield %[[ADD]] : i32
-// CHECK-NEXT: -> tensor<16x32xi32>
-
-// -----
-// Somewhat non-sensical but checks integer truncation cast.
-func @generalize_matmul_tensor_i32_i32_i16(%A : tensor<16x8xi32>, %B: tensor<8x32xi32>, %C: tensor<16x32xi16>) -> tensor<16x32xi16> {
-  %0 = linalg.matmul ins(%A, %B: tensor<16x8xi32>, tensor<8x32xi32>)
-                          outs(%C: tensor<16x32xi16>) -> tensor<16x32xi16>
-  return %0: tensor<16x32xi16>
-}
-
-// CHECK-LABEL: @generalize_matmul_tensor_i32_i32_i16
-// CHECK:      ^{{.*}}(%[[A_ARG:.+]]: i32, %[[B_ARG:.+]]: i32, %[[C_ARG:.+]]: i16)
-// CHECK-NEXT:   %[[A_CAST:.+]] = trunci %[[A_ARG]] : i32 to i16
-// CHECK-NEXT:   %[[B_CAST:.+]] = trunci %[[B_ARG]] : i32 to i16
-// CHECK-NEXT:   %[[MUL:.+]] = muli %[[A_CAST]], %[[B_CAST]] : i16
-// CHECK-NEXT:   %[[ADD:.+]] = addi %[[C_ARG]], %[[MUL]] : i16
-// CHECK-NEXT:   linalg.yield %[[ADD]] : i16
-// CHECK-NEXT: -> tensor<16x32xi16>
-
-// -----
-// Verifies integer to floating point cast.
-func @generalize_matmul_tensor_i8_i8_f32(%A : tensor<16x8xi8>, %B: tensor<8x32xi8>, %C: tensor<16x32xf32>) -> tensor<16x32xf32> {
-  %0 = linalg.matmul ins(%A, %B: tensor<16x8xi8>, tensor<8x32xi8>)
-                          outs(%C: tensor<16x32xf32>) -> tensor<16x32xf32>
-  return %0: tensor<16x32xf32>
-}
-
-// CHECK-LABEL: @generalize_matmul_tensor_i8_i8_f32
-// CHECK:      ^{{.*}}(%[[A_ARG:.+]]: i8, %[[B_ARG:.+]]: i8, %[[C_ARG:.+]]: f32)
-// CHECK-NEXT:   %[[A_CAST:.+]] = sitofp %[[A_ARG]] : i8 to f32
-// CHECK-NEXT:   %[[B_CAST:.+]] = sitofp %[[B_ARG]] : i8 to f32
-// CHECK-NEXT:   %[[MUL:.+]] = mulf %[[A_CAST]], %[[B_CAST]] : f32
-// CHECK-NEXT:   %[[ADD:.+]] = addf %[[C_ARG]], %[[MUL]] : f32
-// CHECK-NEXT:   linalg.yield %[[ADD]] : f32
-// CHECK-NEXT: -> tensor<16x32xf32>
-
-// -----
-// Verifies floating point extension cast.
-func @generalize_matmul_tensor_f16_f16_f32(%A : tensor<16x8xf16>, %B: tensor<8x32xf16>, %C: tensor<16x32xf32>) -> tensor<16x32xf32> {
-  %0 = linalg.matmul ins(%A, %B: tensor<16x8xf16>, tensor<8x32xf16>)
-                          outs(%C: tensor<16x32xf32>) -> tensor<16x32xf32>
-  return %0: tensor<16x32xf32>
-}
-
-// CHECK-LABEL: @generalize_matmul_tensor_f16_f16_f32
-// CHECK:      ^{{.*}}(%[[A_ARG:.+]]: f16, %[[B_ARG:.+]]: f16, %[[C_ARG:.+]]: f32)
-// CHECK-NEXT:   %[[A_CAST:.+]] = fpext %[[A_ARG]] : f16 to f32
-// CHECK-NEXT:   %[[B_CAST:.+]] = fpext %[[B_ARG]] : f16 to f32
-// CHECK-NEXT:   %[[MUL:.+]] = mulf %[[A_CAST]], %[[B_CAST]] : f32
-// CHECK-NEXT:   %[[ADD:.+]] = addf %[[C_ARG]], %[[MUL]] : f32
-// CHECK-NEXT:   linalg.yield %[[ADD]] : f32
-// CHECK-NEXT: -> tensor<16x32xf32>
-
-// -----
-// Verifies floating point truncation.
-func @generalize_matmul_tensor_f64_f64_f32(%A : tensor<16x8xf64>, %B: tensor<8x32xf64>, %C: tensor<16x32xf32>) -> tensor<16x32xf32> {
-  %0 = linalg.matmul ins(%A, %B: tensor<16x8xf64>, tensor<8x32xf64>)
-                          outs(%C: tensor<16x32xf32>) -> tensor<16x32xf32>
-  return %0: tensor<16x32xf32>
-}
-
-// CHECK-LABEL: @generalize_matmul_tensor_f64_f64_f32
-// CHECK:      ^{{.*}}(%[[A_ARG:.+]]: f64, %[[B_ARG:.+]]: f64, %[[C_ARG:.+]]: f32)
-// CHECK-NEXT:   %[[A_CAST:.+]] = fptrunc %[[A_ARG]] : f64 to f32
-// CHECK-NEXT:   %[[B_CAST:.+]] = fptrunc %[[B_ARG]] : f64 to f32
-// CHECK-NEXT:   %[[MUL:.+]] = mulf %[[A_CAST]], %[[B_CAST]] : f32
-// CHECK-NEXT:   %[[ADD:.+]] = addf %[[C_ARG]], %[[MUL]] : f32
-// CHECK-NEXT:   linalg.yield %[[ADD]] : f32
-// CHECK-NEXT: -> tensor<16x32xf32>

diff  --git a/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml b/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml
index 6613ab2a006f1..b1b8077016767 100644
--- a/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml
+++ b/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml
@@ -43,12 +43,14 @@ structured_op: !LinalgStructuredOpConfig
             operands:
             - !ScalarExpression
               scalar_const: '42 : i64'
+            is_unsigned_cast: false
         - !ScalarExpression
           symbolic_cast:
             type_var: T
             operands:
             - !ScalarExpression
               scalar_index: 1
+            is_unsigned_cast: true
 
 # ODS-LABEL:  def Test1Op : LinalgStructuredBase_Op<"test1"
 
@@ -84,9 +86,9 @@ structured_op: !LinalgStructuredOpConfig
 # IMPL-LABEL:  void Test1Op::regionBuilder(
 #       IMPL:    ImplicitLocOpBuilder &b, Block &block)
 #       IMPL:  Value [[VAL0:[a-z0-9]+]] = helper.constant("42 : i64");
-#   IMPL-DAG:  Value [[VAL1:[a-z0-9]+]] = helper.cast(block.getArgument(0).getType(), [[VAL0]]);
+#   IMPL-DAG:  Value [[VAL1:[a-z0-9]+]] = helper.cast(block.getArgument(0).getType(), [[VAL0]], false);
 #   IMPL-DAG:  Value [[VAL2:[a-z0-9]+]] = helper.index(1);
-#   IMPL-DAG:  Value [[VAL3:[a-z0-9]+]] = helper.cast(block.getArgument(0).getType(), [[VAL2]]);
+#   IMPL-DAG:  Value [[VAL3:[a-z0-9]+]] = helper.cast(block.getArgument(0).getType(), [[VAL2]], true);
 #   IMPL-DAG:  Value [[VAL4:[a-z0-9]+]] = helper.applyfn__add([[VAL1]], [[VAL3]]);
 
 

diff  --git a/mlir/test/python/dialects/linalg/opdsl/emit_structured_generic.py b/mlir/test/python/dialects/linalg/opdsl/emit_structured_generic.py
index 16a82f63dbc83..971c0aaac2926 100644
--- a/mlir/test/python/dialects/linalg/opdsl/emit_structured_generic.py
+++ b/mlir/test/python/dialects/linalg/opdsl/emit_structured_generic.py
@@ -29,6 +29,15 @@ def matmul_poly(
   C[D.m, D.n] += cast(U, A[D.m, D.k]) * cast(U, B[D.k, D.n])
 
 
+@linalg_structured_op
+def matmul_unsigned_poly(
+    A=TensorDef(T1, S.M, S.K),
+    B=TensorDef(T2, S.K, S.N),
+    C=TensorDef(U, S.M, S.N, output=True)):
+  domain(D.m, D.n, D.k)
+  C[D.m, D.n] += cast_unsigned(U, A[D.m, D.k]) * cast_unsigned(U, B[D.k, D.n])
+
+
 @linalg_structured_op
 def conv_poly(
     I=TensorDef(T1, S.N, S.IH, S.IW, S.C),
@@ -54,6 +63,17 @@ def pooling_max_poly(
       cast(U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW,
                 D.c]))
 
+@linalg_structured_op
+def pooling_max_unsigned_poly(
+    I=TensorDef(T1, S.N, S.H, S.W, S.C),
+    K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]),
+    O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True),
+    strides=AttributeDef(S.SH, S.SW),
+    dilations=AttributeDef(S.DH, S.DW)):
+  domain(D.n, D.oh, D.ow, D.kh, D.kw, D.c)
+  O[D.n, D.oh, D.ow, D.c] = ReduceFn.max_unsigned(D.kh, D.kw)(
+      cast_unsigned(
+          U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c]))
 
 @linalg_structured_op
 def pooling_min_poly(
@@ -67,6 +87,17 @@ def pooling_min_poly(
       cast(U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW,
                 D.c]))
 
+@linalg_structured_op
+def pooling_min_unsigned_poly(
+    I=TensorDef(T1, S.N, S.H, S.W, S.C),
+    K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]),
+    O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True),
+    strides=AttributeDef(S.SH, S.SW),
+    dilations=AttributeDef(S.DH, S.DW)):
+  domain(D.n, D.oh, D.ow, D.kh, D.kw, D.c)
+  O[D.n, D.oh, D.ow, D.c] = ReduceFn.min_unsigned(D.kh, D.kw)(
+      cast_unsigned(
+          U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c]))
 
 @linalg_structured_op
 def fill_rng_poly(
@@ -147,6 +178,15 @@ def test_matmul_mono(lhs, rhs):
     def test_i8i8i32_matmul(lhs, rhs, init_result):
       return matmul_poly(lhs, rhs, outs=[init_result])
 
+    # CHECK-LABEL: @test_i8i8i32_matmul_unsigned
+    # CHECK:   = zexti
+    # CHECK:   = zexti
+    @builtin.FuncOp.from_py_func(
+        RankedTensorType.get((4, 16), i8), RankedTensorType.get((16, 8), i8),
+        RankedTensorType.get((4, 8), i32))
+    def test_i8i8i32_matmul_unsigned(lhs, rhs, init_result):
+      return matmul_unsigned_poly(lhs, rhs, outs=[init_result])
+
     # CHECK-LABEL: @test_i8i16i32_matmul
     # CHECK:      ^{{.*}}(%[[A_ARG:.+]]: i8, %[[B_ARG:.+]]: i16, %[[C_ARG:.+]]: i32)
     # CHECK-NEXT:   %[[A_CAST:.+]] = sexti %[[A_ARG]] : i8 to i32
@@ -189,6 +229,15 @@ def test_i32i32i16_matmul(lhs, rhs, init_result):
     def test_i8i8f32_matmul(lhs, rhs, init_result):
       return matmul_poly(lhs, rhs, outs=[init_result])
 
+    # CHECK-LABEL: @test_i8i8f32_matmul_unsigned
+    # CHECK:   = uitofp
+    # CHECK:   = uitofp
+    @builtin.FuncOp.from_py_func(
+        RankedTensorType.get((4, 16), i8), RankedTensorType.get((16, 8), i8),
+        RankedTensorType.get((4, 8), f32))
+    def test_i8i8f32_matmul_unsigned(lhs, rhs, init_result):
+      return matmul_unsigned_poly(lhs, rhs, outs=[init_result])
+
     # CHECK-LABEL: @test_f16f16f32_matmul
     # CHECK:      ^{{.*}}(%[[A_ARG:.+]]: f16, %[[B_ARG:.+]]: f16, %[[C_ARG:.+]]: f32)
     # CHECK-NEXT:   %[[A_CAST:.+]] = fpext %[[A_ARG]] : f16 to f32
@@ -252,6 +301,16 @@ def test_f32i32_max_pooling(input, shape, init_result):
       return pooling_max_poly(
           input, shape, outs=[init_result], strides=[2, 4], dilations=[1, 2])
 
+    # CHECK-LABEL: @test_f32i32_max_unsigned_pooling
+    # CHECK:   = fptoui
+    # CHECK:   = maxui
+    @builtin.FuncOp.from_py_func(
+        RankedTensorType.get((4, 16), f32), RankedTensorType.get((2, 2), f32),
+        RankedTensorType.get((2, 4), i32))
+    def test_f32i32_max_unsigned_pooling(input, shape, init_result):
+      return pooling_max_unsigned_poly(
+          input, shape, outs=[init_result], strides=[2, 4], dilations=[1, 2])
+
     # CHECK-LABEL: @test_f32f32_max_pooling
     # CHECK: linalg.generic
     # CHECK-SAME: indexing_maps = [#[[$CONV_MAP_I]], #[[$POOL_MAP_K]], #[[$CONV_MAP_O]]]
@@ -268,6 +327,7 @@ def test_f32f32_max_pooling(input, shape, init_result):
           input, shape, outs=[init_result], strides=[2, 4], dilations=[1, 2])
 
     # CHECK-LABEL: @test_f32i32_min_pooling
+    # CHECK:   = fptosi
     # CHECK:   = minsi
     @builtin.FuncOp.from_py_func(
         RankedTensorType.get((4, 16), f32), RankedTensorType.get((2, 2), f32),
@@ -276,6 +336,16 @@ def test_f32i32_min_pooling(input, shape, init_result):
       return pooling_min_poly(
           input, shape, outs=[init_result], strides=[2, 4], dilations=[1, 2])
 
+    # CHECK-LABEL: @test_f32i32_min_unsigned_pooling
+    # CHECK:   = fptoui
+    # CHECK:   = minui
+    @builtin.FuncOp.from_py_func(
+        RankedTensorType.get((4, 16), f32), RankedTensorType.get((2, 2), f32),
+        RankedTensorType.get((2, 4), i32))
+    def test_f32i32_min_unsigned_pooling(input, shape, init_result):
+      return pooling_min_unsigned_poly(
+          input, shape, outs=[init_result], strides=[2, 4], dilations=[1, 2])
+
     # CHECK-LABEL: @test_f32f32_min_pooling
     # CHECK:   = minf
     @builtin.FuncOp.from_py_func(

diff  --git a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
index 98e90b69d631d..44eb34a36b499 100644
--- a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
+++ b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
@@ -95,6 +95,7 @@ struct ScalarSymbolicCast {
   // NOTE: This must be of arity 1, but to break the self-referential cycle,
   // we use a heap allocated vector.
   std::vector<ScalarExpression> operands;
+  bool isUnsignedCast;
 };
 
 struct ScalarExpression {
@@ -278,6 +279,7 @@ struct MappingTraits<ScalarSymbolicCast> {
   static void mapping(IO &io, ScalarSymbolicCast &info) {
     io.mapRequired("type_var", info.typeVar);
     io.mapRequired("operands", info.operands);
+    io.mapRequired("is_unsigned_cast", info.isUnsignedCast);
   }
 };
 
@@ -986,9 +988,10 @@ void {0}::regionBuilder(ImplicitLocOpBuilder &b, Block &block) {{
             return None;
           }
           std::string cppIdent = llvm::formatv("value{0}", ++localCounter);
-          stmts.push_back(llvm::formatv("Value {0} = helper.cast({1}, {2});",
-                                        cppIdent, typeCppValue.getValue(),
-                                        *operandCppValue));
+          stmts.push_back(
+              llvm::formatv("Value {0} = helper.cast({1}, {2}, {3});", cppIdent,
+                            typeCppValue.getValue(), *operandCppValue,
+                            expression.symbolicCast->isUnsignedCast));
           return cppIdent;
         }
         emitError(genContext.getLoc()) << "unknown ScalarExpression type";


        


More information about the Mlir-commits mailing list