[Mlir-commits] [mlir] e9085d0 - [mlir][OpDSL] Rename function to make signedness explicit (NFC).

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Mar 1 00:19:46 PST 2022


Author: gysit
Date: 2022-03-01T08:15:53Z
New Revision: e9085d0d2558aebb735d23c47ff8595171a2ce93

URL: https://github.com/llvm/llvm-project/commit/e9085d0d2558aebb735d23c47ff8595171a2ce93
DIFF: https://github.com/llvm/llvm-project/commit/e9085d0d2558aebb735d23c47ff8595171a2ce93.diff

LOG: [mlir][OpDSL] Rename function to make signedness explicit (NFC).

The revision renames the following OpDSL functions:
```
TypeFn.cast -> TypeFn.cast_signed
BinaryFn.min -> BinaryFn.min_signed
BinaryFn.max -> BinaryFn.max_signed
```
The corresponding enum values on the C++ side are renamed accordingly:
```
#linalg.type_fn<cast> -> #linalg.type_fn<cast_signed>
#linalg.binary_fn<min> -> #linalg.binary_fn<min_signed>
#linalg.binary_fn<max> -> #linalg.binary_fn<max_signed>
```

Depends On D120110

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D120562

Added: 
    

Modified: 
    mlir/docs/Dialects/Linalg/OpDSL.md
    mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td
    mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
    mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
    mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py
    mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
    mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
    mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml
    mlir/test/python/dialects/linalg/opdsl/arguments.py
    mlir/test/python/dialects/linalg/opdsl/assignments.py
    mlir/test/python/dialects/linalg/opdsl/emit_convolution.py
    mlir/test/python/dialects/linalg/opdsl/emit_fill.py
    mlir/test/python/dialects/linalg/opdsl/emit_matmul.py
    mlir/test/python/dialects/linalg/opdsl/emit_misc.py
    mlir/test/python/dialects/linalg/opdsl/emit_pooling.py
    mlir/test/python/dialects/linalg/opdsl/interfaces.py
    mlir/test/python/dialects/linalg/opdsl/shape_maps_iteration.py
    mlir/test/python/dialects/linalg/ops.py

Removed: 
    


################################################################################
diff --git a/mlir/docs/Dialects/Linalg/OpDSL.md b/mlir/docs/Dialects/Linalg/OpDSL.md
index dd068b1f400c1..d7526bf9f3bab 100644
--- a/mlir/docs/Dialects/Linalg/OpDSL.md
+++ b/mlir/docs/Dialects/Linalg/OpDSL.md
@@ -56,7 +56,8 @@ def matmul(A=TensorDef(T1, S.M, S.K),
   """
   domain(D.m, D.n, D.k)
   implements(ContractionOpInterface)
-  C[D.m, D.n] += TypeFn.cast(U, A[D.m, D.k]) * TypeFn.cast(U, B[D.k, D.n])
+  C[D.m, D.n] += TypeFn.cast_signed(
+      U, A[D.m, D.k]) * TypeFn.cast_signed(U, B[D.k, D.n])
 ```
 
 Here we have a simple type polymorphic contraction that takes arguments `A` and
@@ -160,7 +161,7 @@ def pooling_poly(
     O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True),
     strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
     dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])):
-  O[D.n, D.oh, D.ow, D.c] += TypeFn.cast(U,
+  O[D.n, D.oh, D.ow, D.c] += TypeFn.cast_signed(U,
           I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c])
 ```
 
@@ -182,8 +183,8 @@ A number of unary and binary arithmetic functions are supported:
 
 *   `BinaryFn.add(a, b)` (also via overloading the binary `+` operator)
 *   `BinaryFn.mul(a, b)` (also via overloading the binary `*` operator)
-*   `BinaryFn.max(a, b)`
-*   `BinaryFn.min(a, b)`
+*   `BinaryFn.max_signed(a, b)`
+*   `BinaryFn.min_signed(a, b)`
 *   `BinaryFn.sub(a, b)` (also via overloading the binary `-` operator)
 *   `BinaryFn.max_unsigned(a, b)`
 *   `BinaryFn.min_unsigned(a, b)`
@@ -198,8 +199,8 @@ reduction functions can appear as the outermost function on the RHS:
 
 *   `ReduceFn.add` (also overloading the inplace `+=` on a LHS)
 *   `ReduceFn.mul`
-*   `ReduceFn.max`
-*   `ReduceFn.min`
+*   `ReduceFn.max_signed`
+*   `ReduceFn.min_signed`
 *   `ReduceFn.max_unsigned`
 *   `ReduceFn.min_unsigned`
 
@@ -208,11 +209,11 @@ functions that treat integers as signed or unsigned values.
 
 Additionally, type conversion functions cast an operand to a target type:
 
-*   `TypeFn.cast(TypeVar, operand)`
+*   `TypeFn.cast_signed(TypeVar, operand)`
 *   `TypeFn.cast_unsigned(TypeVar, operand)`
 
 As the integer types are signless, signedness is implement by different
-functions that treat integers as signed (`TypeFn.cast`) or unsigned
+functions that treat integers as signed (`TypeFn.cast_signed`) or unsigned
 (`TypeFn.cast_unsigned`) values.
 
 There are also special forms:
@@ -235,12 +236,12 @@ def elemwise_binary(
     rhs=TensorDef(T2),
     O=TensorDef(U, output=True),
     fun=BinaryFnAttrDef(default=BinaryFn.add),
-    cast=TypeFnAttrDef(default=TypeFn.cast)):
+    cast=TypeFnAttrDef(default=TypeFn.cast_signed)):
   O[None] = fun(cast(U, lhs[None]), cast(U, rhs[None]))
 ```
 
 The `fun` and `cast` function attributes by default are aliases for their
-default values `BinaryFn.add` and `TypeFn.cast`, respectively. When
+default values `BinaryFn.add` and `TypeFn.cast_signed`, respectively. When
 instantiating the operation, the function attributes may be set to other
 functions using optional named arguments:
 
@@ -265,26 +266,27 @@ output types of constructed ops. An exception are predefined types such as
 computations with a type that is independent of the input and output types. For
 example, parts of floating point computation may require double precision
 arithmetic despite all inputs and outputs being single precision values.
-Assignment expressions with no `TypeFn.cast` calls will generally require
+Assignment expressions with no `TypeFn.cast_signed` calls will generally require
 uniform types throughout and will fail to verify if violated. The presence of a
-`TypeFn.cast` or `TypeFn.cast_unsigned` allows for a limited form of numeric
-type conversion between element types that can be derived from inputs and
-outputs (and in the future, attributes). `TypeFn.cast` calls with a `TypeVar`
-first argument are emitted as `type_fn` primitives in the YAML definition.
+`TypeFn.cast_signed` or `TypeFn.cast_unsigned` allows for a limited form of
+numeric type conversion between element types that can be derived from inputs
+and outputs (and in the future, attributes). `TypeFn.cast_signed` calls with a
+`TypeVar` first argument are emitted as `type_fn` primitives in the YAML
+definition.
 
 Casting will perform `int<->float` and `index->int` type conversions and will
 perform any necessary extension or truncation within the type family. The
 integer types themselves are signless and signedness is implemented by
-functions/operations. The `TypeFn.cast` function treats all integers as signed,
-while `TypeFn.cast_unsigned` treats them as unsigned.
+functions/operations. The `TypeFn.cast_signed` function treats all integers as
+signed, while `TypeFn.cast_unsigned` treats them as unsigned.
 
 The following examples illustrate the lowering of signed and unsigned functions:
 
-*   cast(I32 -> I64) -> `arith.ExtSIOp`
-*   cast(F32 -> I32) -> `arith.FPToSIOp`
+*   cast_signed(I32 -> I64) -> `arith.ExtSIOp`
+*   cast_signed(F32 -> I32) -> `arith.FPToSIOp`
 *   cast_unsigned(I32 -> I64) -> `arith.ExtUIOp`
 *   cast_unsigned(F32 -> I32) -> `arith.FPToUIOp`
-*   max -> `arith.MaxSIOp`
+*   max_signed -> `arith.MaxSIOp`
 *   max_unsinged -> `arith.MaxUIOp`
 
 Not all functions are applicable for all numeric types, and on mismatch, op
@@ -302,7 +304,7 @@ An example for a rank polymorphic operation is `fill`:
 @linalg_structured_op
 def fill(value=ScalarDef(T1),
          O=TensorDef(U, output=True)):
-  O[None] = TypeFn.cast(U, value)
+  O[None] = TypeFn.cast_signed(U, value)
 ```
 
 The operation sets the elements of the output tensor `O` to `value`. All

diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td
index c60e1b646e533..f962eb6b2a869 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td
@@ -68,10 +68,10 @@ def UnaryFn : I32EnumAttr<"UnaryFn", "", [
 }
 def BinaryFn : I32EnumAttr<"BinaryFn", "", [
   I32EnumAttrCase<"add", 0>,
-  I32EnumAttrCase<"mul", 1>,
-  I32EnumAttrCase<"max", 2>,
-  I32EnumAttrCase<"min", 3>,
-  I32EnumAttrCase<"sub", 4>,
+  I32EnumAttrCase<"sub", 1>,
+  I32EnumAttrCase<"mul", 2>,
+  I32EnumAttrCase<"max_signed", 3>,
+  I32EnumAttrCase<"min_signed", 4>,
   I32EnumAttrCase<"max_unsigned", 5>,
   I32EnumAttrCase<"min_unsigned", 6>
 ]> {
@@ -79,7 +79,7 @@ def BinaryFn : I32EnumAttr<"BinaryFn", "", [
   let cppNamespace = "::mlir::linalg";
 }
 def TypeFn : I32EnumAttr<"TypeFn", "", [
-  I32EnumAttrCase<"cast", 0>,
+  I32EnumAttrCase<"cast_signed", 0>,
   I32EnumAttrCase<"cast_unsigned", 1>
 ]> {
   let genSpecializedAttr = 0;

diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
index 73fd114daf32e..e296004603673 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
@@ -28,7 +28,7 @@ structured_op: !LinalgStructuredOpConfig
   - !LinalgOperandDefConfig
     name: cast
     kind: type_fn_attr
-    default_fn: cast
+    default_fn: cast_signed
   indexing_maps: !LinalgIndexingMapsConfig
     static_indexing_maps:
     - affine_map<() -> ()>
@@ -83,7 +83,7 @@ structured_op: !LinalgStructuredOpConfig
   - !LinalgOperandDefConfig
     name: cast
     kind: type_fn_attr
-    default_fn: cast
+    default_fn: cast_signed
   indexing_maps: !LinalgIndexingMapsConfig
     static_indexing_maps:
     - affine_map<() -> ()>
@@ -145,7 +145,7 @@ structured_op: !LinalgStructuredOpConfig
   - !LinalgOperandDefConfig
     name: cast
     kind: type_fn_attr
-    default_fn: cast
+    default_fn: cast_signed
   indexing_maps: !LinalgIndexingMapsConfig
     static_indexing_maps:
     - affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0, d2)>
@@ -324,7 +324,7 @@ structured_op: !LinalgStructuredOpConfig
                 - !ScalarExpression
                   scalar_fn:
                     kind: type
-                    fn_name: cast
+                    fn_name: cast_signed
                     type_var: U
                     operands:
                     - !ScalarExpression
@@ -332,7 +332,7 @@ structured_op: !LinalgStructuredOpConfig
                 - !ScalarExpression
                   scalar_fn:
                     kind: type
-                    fn_name: cast
+                    fn_name: cast_signed
                     type_var: U
                     operands:
                     - !ScalarExpression
@@ -345,7 +345,7 @@ structured_op: !LinalgStructuredOpConfig
                 - !ScalarExpression
                   scalar_fn:
                     kind: type
-                    fn_name: cast
+                    fn_name: cast_signed
                     type_var: U
                     operands:
                     - !ScalarExpression
@@ -353,7 +353,7 @@ structured_op: !LinalgStructuredOpConfig
                 - !ScalarExpression
                   scalar_fn:
                     kind: type
-                    fn_name: cast
+                    fn_name: cast_signed
                     type_var: U
                     operands:
                     - !ScalarExpression
@@ -424,7 +424,7 @@ structured_op: !LinalgStructuredOpConfig
             - !ScalarExpression
               scalar_fn:
                 kind: type
-                fn_name: cast
+                fn_name: cast_signed
                 type_var: AccumType
                 operands:
                 - !ScalarExpression
@@ -432,7 +432,7 @@ structured_op: !LinalgStructuredOpConfig
             - !ScalarExpression
               scalar_fn:
                 kind: type
-                fn_name: cast
+                fn_name: cast_signed
                 type_var: AccumType
                 operands:
                 - !ScalarExpression
@@ -493,7 +493,7 @@ structured_op: !LinalgStructuredOpConfig
             - !ScalarExpression
               scalar_fn:
                 kind: type
-                fn_name: cast
+                fn_name: cast_signed
                 type_var: U
                 operands:
                 - !ScalarExpression
@@ -501,7 +501,7 @@ structured_op: !LinalgStructuredOpConfig
             - !ScalarExpression
               scalar_fn:
                 kind: type
-                fn_name: cast
+                fn_name: cast_signed
                 type_var: U
                 operands:
                 - !ScalarExpression
@@ -577,7 +577,7 @@ structured_op: !LinalgStructuredOpConfig
                 - !ScalarExpression
                   scalar_fn:
                     kind: type
-                    fn_name: cast
+                    fn_name: cast_signed
                     type_var: U
                     operands:
                     - !ScalarExpression
@@ -585,7 +585,7 @@ structured_op: !LinalgStructuredOpConfig
                 - !ScalarExpression
                   scalar_fn:
                     kind: type
-                    fn_name: cast
+                    fn_name: cast_signed
                     type_var: U
                     operands:
                     - !ScalarExpression
@@ -598,7 +598,7 @@ structured_op: !LinalgStructuredOpConfig
                 - !ScalarExpression
                   scalar_fn:
                     kind: type
-                    fn_name: cast
+                    fn_name: cast_signed
                     type_var: U
                     operands:
                     - !ScalarExpression
@@ -606,7 +606,7 @@ structured_op: !LinalgStructuredOpConfig
                 - !ScalarExpression
                   scalar_fn:
                     kind: type
-                    fn_name: cast
+                    fn_name: cast_signed
                     type_var: U
                     operands:
                     - !ScalarExpression
@@ -665,7 +665,7 @@ structured_op: !LinalgStructuredOpConfig
             - !ScalarExpression
               scalar_fn:
                 kind: type
-                fn_name: cast
+                fn_name: cast_signed
                 type_var: U
                 operands:
                 - !ScalarExpression
@@ -673,7 +673,7 @@ structured_op: !LinalgStructuredOpConfig
             - !ScalarExpression
               scalar_fn:
                 kind: type
-                fn_name: cast
+                fn_name: cast_signed
                 type_var: U
                 operands:
                 - !ScalarExpression
@@ -732,7 +732,7 @@ structured_op: !LinalgStructuredOpConfig
             - !ScalarExpression
               scalar_fn:
                 kind: type
-                fn_name: cast
+                fn_name: cast_signed
                 type_var: U
                 operands:
                 - !ScalarExpression
@@ -740,7 +740,7 @@ structured_op: !LinalgStructuredOpConfig
             - !ScalarExpression
               scalar_fn:
                 kind: type
-                fn_name: cast
+                fn_name: cast_signed
                 type_var: U
                 operands:
                 - !ScalarExpression
@@ -800,7 +800,7 @@ structured_op: !LinalgStructuredOpConfig
             - !ScalarExpression
               scalar_fn:
                 kind: type
-                fn_name: cast
+                fn_name: cast_signed
                 type_var: U
                 operands:
                 - !ScalarExpression
@@ -808,7 +808,7 @@ structured_op: !LinalgStructuredOpConfig
             - !ScalarExpression
               scalar_fn:
                 kind: type
-                fn_name: cast
+                fn_name: cast_signed
                 type_var: U
                 operands:
                 - !ScalarExpression
@@ -866,7 +866,7 @@ structured_op: !LinalgStructuredOpConfig
             - !ScalarExpression
               scalar_fn:
                 kind: type
-                fn_name: cast
+                fn_name: cast_signed
                 type_var: U
                 operands:
                 - !ScalarExpression
@@ -874,7 +874,7 @@ structured_op: !LinalgStructuredOpConfig
             - !ScalarExpression
               scalar_fn:
                 kind: type
-                fn_name: cast
+                fn_name: cast_signed
                 type_var: U
                 operands:
                 - !ScalarExpression
@@ -933,7 +933,7 @@ structured_op: !LinalgStructuredOpConfig
             - !ScalarExpression
               scalar_fn:
                 kind: type
-                fn_name: cast
+                fn_name: cast_signed
                 type_var: U
                 operands:
                 - !ScalarExpression
@@ -941,7 +941,7 @@ structured_op: !LinalgStructuredOpConfig
             - !ScalarExpression
               scalar_fn:
                 kind: type
-                fn_name: cast
+                fn_name: cast_signed
                 type_var: U
                 operands:
                 - !ScalarExpression
@@ -1002,7 +1002,7 @@ structured_op: !LinalgStructuredOpConfig
             - !ScalarExpression
               scalar_fn:
                 kind: type
-                fn_name: cast
+                fn_name: cast_signed
                 type_var: U
                 operands:
                 - !ScalarExpression
@@ -1010,7 +1010,7 @@ structured_op: !LinalgStructuredOpConfig
             - !ScalarExpression
               scalar_fn:
                 kind: type
-                fn_name: cast
+                fn_name: cast_signed
                 type_var: U
                 operands:
                 - !ScalarExpression
@@ -1074,7 +1074,7 @@ structured_op: !LinalgStructuredOpConfig
             - !ScalarExpression
               scalar_fn:
                 kind: type
-                fn_name: cast
+                fn_name: cast_signed
                 type_var: U
                 operands:
                 - !ScalarExpression
@@ -1082,7 +1082,7 @@ structured_op: !LinalgStructuredOpConfig
             - !ScalarExpression
               scalar_fn:
                 kind: type
-                fn_name: cast
+                fn_name: cast_signed
                 type_var: U
                 operands:
                 - !ScalarExpression
@@ -1158,7 +1158,7 @@ structured_op: !LinalgStructuredOpConfig
             - !ScalarExpression
               scalar_fn:
                 kind: type
-                fn_name: cast
+                fn_name: cast_signed
                 type_var: U
                 operands:
                 - !ScalarExpression
@@ -1166,7 +1166,7 @@ structured_op: !LinalgStructuredOpConfig
             - !ScalarExpression
               scalar_fn:
                 kind: type
-                fn_name: cast
+                fn_name: cast_signed
                 type_var: U
                 operands:
                 - !ScalarExpression
@@ -1256,7 +1256,7 @@ structured_op: !LinalgStructuredOpConfig
             - !ScalarExpression
               scalar_fn:
                 kind: type
-                fn_name: cast
+                fn_name: cast_signed
                 type_var: U
                 operands:
                 - !ScalarExpression
@@ -1264,7 +1264,7 @@ structured_op: !LinalgStructuredOpConfig
             - !ScalarExpression
               scalar_fn:
                 kind: type
-                fn_name: cast
+                fn_name: cast_signed
                 type_var: U
                 operands:
                 - !ScalarExpression
@@ -1372,7 +1372,7 @@ structured_op: !LinalgStructuredOpConfig
                 - !ScalarExpression
                   scalar_fn:
                     kind: type
-                    fn_name: cast
+                    fn_name: cast_signed
                     type_var: U
                     operands:
                     - !ScalarExpression
@@ -1380,7 +1380,7 @@ structured_op: !LinalgStructuredOpConfig
                 - !ScalarExpression
                   scalar_fn:
                     kind: type
-                    fn_name: cast
+                    fn_name: cast_signed
                     type_var: U
                     operands:
                     - !ScalarExpression
@@ -1393,7 +1393,7 @@ structured_op: !LinalgStructuredOpConfig
                 - !ScalarExpression
                   scalar_fn:
                     kind: type
-                    fn_name: cast
+                    fn_name: cast_signed
                     type_var: U
                     operands:
                     - !ScalarExpression
@@ -1401,7 +1401,7 @@ structured_op: !LinalgStructuredOpConfig
                 - !ScalarExpression
                   scalar_fn:
                     kind: type
-                    fn_name: cast
+                    fn_name: cast_signed
                     type_var: U
                     operands:
                     - !ScalarExpression
@@ -1491,7 +1491,7 @@ structured_op: !LinalgStructuredOpConfig
             - !ScalarExpression
               scalar_fn:
                 kind: type
-                fn_name: cast
+                fn_name: cast_signed
                 type_var: U
                 operands:
                 - !ScalarExpression
@@ -1499,7 +1499,7 @@ structured_op: !LinalgStructuredOpConfig
             - !ScalarExpression
               scalar_fn:
                 kind: type
-                fn_name: cast
+                fn_name: cast_signed
                 type_var: U
                 operands:
                 - !ScalarExpression
@@ -1591,7 +1591,7 @@ structured_op: !LinalgStructuredOpConfig
             - !ScalarExpression
               scalar_fn:
                 kind: type
-                fn_name: cast
+                fn_name: cast_signed
                 type_var: U
                 operands:
                 - !ScalarExpression
@@ -1599,7 +1599,7 @@ structured_op: !LinalgStructuredOpConfig
             - !ScalarExpression
               scalar_fn:
                 kind: type
-                fn_name: cast
+                fn_name: cast_signed
                 type_var: U
                 operands:
                 - !ScalarExpression
@@ -1674,7 +1674,7 @@ structured_op: !LinalgStructuredOpConfig
             - !ScalarExpression
               scalar_fn:
                 kind: type
-                fn_name: cast
+                fn_name: cast_signed
                 type_var: U
                 operands:
                 - !ScalarExpression
@@ -1682,7 +1682,7 @@ structured_op: !LinalgStructuredOpConfig
             - !ScalarExpression
               scalar_fn:
                 kind: type
-                fn_name: cast
+                fn_name: cast_signed
                 type_var: U
                 operands:
                 - !ScalarExpression
@@ -1767,7 +1767,7 @@ structured_op: !LinalgStructuredOpConfig
             - !ScalarExpression
               scalar_fn:
                 kind: type
-                fn_name: cast
+                fn_name: cast_signed
                 type_var: U
                 operands:
                 - !ScalarExpression
@@ -1775,7 +1775,7 @@ structured_op: !LinalgStructuredOpConfig
             - !ScalarExpression
               scalar_fn:
                 kind: type
-                fn_name: cast
+                fn_name: cast_signed
                 type_var: U
                 operands:
                 - !ScalarExpression
@@ -1876,7 +1876,7 @@ structured_op: !LinalgStructuredOpConfig
                 - !ScalarExpression
                   scalar_fn:
                     kind: type
-                    fn_name: cast
+                    fn_name: cast_signed
                     type_var: U
                     operands:
                     - !ScalarExpression
@@ -1884,7 +1884,7 @@ structured_op: !LinalgStructuredOpConfig
                 - !ScalarExpression
                   scalar_fn:
                     kind: type
-                    fn_name: cast
+                    fn_name: cast_signed
                     type_var: U
                     operands:
                     - !ScalarExpression
@@ -1897,7 +1897,7 @@ structured_op: !LinalgStructuredOpConfig
                 - !ScalarExpression
                   scalar_fn:
                     kind: type
-                    fn_name: cast
+                    fn_name: cast_signed
                     type_var: U
                     operands:
                     - !ScalarExpression
@@ -1905,7 +1905,7 @@ structured_op: !LinalgStructuredOpConfig
                 - !ScalarExpression
                   scalar_fn:
                     kind: type
-                    fn_name: cast
+                    fn_name: cast_signed
                     type_var: U
                     operands:
                     - !ScalarExpression
@@ -1991,7 +1991,7 @@ structured_op: !LinalgStructuredOpConfig
             - !ScalarExpression
               scalar_fn:
                 kind: type
-                fn_name: cast
+                fn_name: cast_signed
                 type_var: U
                 operands:
                 - !ScalarExpression
@@ -1999,7 +1999,7 @@ structured_op: !LinalgStructuredOpConfig
             - !ScalarExpression
               scalar_fn:
                 kind: type
-                fn_name: cast
+                fn_name: cast_signed
                 type_var: U
                 operands:
                 - !ScalarExpression
@@ -2102,7 +2102,7 @@ structured_op: !LinalgStructuredOpConfig
                 - !ScalarExpression
                   scalar_fn:
                     kind: type
-                    fn_name: cast
+                    fn_name: cast_signed
                     type_var: U
                     operands:
                     - !ScalarExpression
@@ -2110,7 +2110,7 @@ structured_op: !LinalgStructuredOpConfig
                 - !ScalarExpression
                   scalar_fn:
                     kind: type
-                    fn_name: cast
+                    fn_name: cast_signed
                     type_var: U
                     operands:
                     - !ScalarExpression
@@ -2123,7 +2123,7 @@ structured_op: !LinalgStructuredOpConfig
                 - !ScalarExpression
                   scalar_fn:
                     kind: type
-                    fn_name: cast
+                    fn_name: cast_signed
                     type_var: U
                     operands:
                     - !ScalarExpression
@@ -2131,7 +2131,7 @@ structured_op: !LinalgStructuredOpConfig
                 - !ScalarExpression
                   scalar_fn:
                     kind: type
-                    fn_name: cast
+                    fn_name: cast_signed
                     type_var: U
                     operands:
                     - !ScalarExpression
@@ -2210,7 +2210,7 @@ structured_op: !LinalgStructuredOpConfig
         - !ScalarExpression
           scalar_fn:
             kind: type
-            fn_name: cast
+            fn_name: cast_signed
             type_var: U
             operands:
             - !ScalarExpression
@@ -2282,14 +2282,14 @@ structured_op: !LinalgStructuredOpConfig
     value: !ScalarExpression
       scalar_fn:
         kind: binary
-        fn_name: max
+        fn_name: max_signed
         operands:
         - !ScalarExpression
           scalar_arg: O
         - !ScalarExpression
           scalar_fn:
             kind: type
-            fn_name: cast
+            fn_name: cast_signed
             type_var: U
             operands:
             - !ScalarExpression
@@ -2440,14 +2440,14 @@ structured_op: !LinalgStructuredOpConfig
     value: !ScalarExpression
       scalar_fn:
         kind: binary
-        fn_name: max
+        fn_name: max_signed
         operands:
         - !ScalarExpression
           scalar_arg: O
         - !ScalarExpression
           scalar_fn:
             kind: type
-            fn_name: cast
+            fn_name: cast_signed
             type_var: U
             operands:
             - !ScalarExpression
@@ -2519,14 +2519,14 @@ structured_op: !LinalgStructuredOpConfig
     value: !ScalarExpression
       scalar_fn:
         kind: binary
-        fn_name: min
+        fn_name: min_signed
         operands:
         - !ScalarExpression
           scalar_arg: O
         - !ScalarExpression
           scalar_fn:
             kind: type
-            fn_name: cast
+            fn_name: cast_signed
             type_var: U
             operands:
             - !ScalarExpression
@@ -2690,7 +2690,7 @@ structured_op: !LinalgStructuredOpConfig
         - !ScalarExpression
           scalar_fn:
             kind: type
-            fn_name: cast
+            fn_name: cast_signed
             type_var: U
             operands:
             - !ScalarExpression
@@ -2768,14 +2768,14 @@ structured_op: !LinalgStructuredOpConfig
     value: !ScalarExpression
       scalar_fn:
         kind: binary
-        fn_name: max
+        fn_name: max_signed
         operands:
         - !ScalarExpression
           scalar_arg: O
         - !ScalarExpression
           scalar_fn:
             kind: type
-            fn_name: cast
+            fn_name: cast_signed
             type_var: U
             operands:
             - !ScalarExpression
@@ -2853,14 +2853,14 @@ structured_op: !LinalgStructuredOpConfig
     value: !ScalarExpression
       scalar_fn:
         kind: binary
-        fn_name: min
+        fn_name: min_signed
         operands:
         - !ScalarExpression
           scalar_arg: O
         - !ScalarExpression
           scalar_fn:
             kind: type
-            fn_name: cast
+            fn_name: cast_signed
             type_var: U
             operands:
             - !ScalarExpression
@@ -2897,7 +2897,7 @@ structured_op: !LinalgStructuredOpConfig
     value: !ScalarExpression
       scalar_fn:
         kind: type
-        fn_name: cast
+        fn_name: cast_signed
         type_var: U
         operands:
         - !ScalarExpression
@@ -2950,7 +2950,7 @@ structured_op: !LinalgStructuredOpConfig
     value: !ScalarExpression
       scalar_fn:
         kind: type
-        fn_name: cast
+        fn_name: cast_signed
         type_var: T
         operands:
         - !ScalarExpression
@@ -2971,7 +2971,7 @@ structured_op: !LinalgStructuredOpConfig
                     - !ScalarExpression
                       scalar_fn:
                         kind: type
-                        fn_name: cast
+                        fn_name: cast_signed
                         type_var: F64
                         operands:
                         - !ScalarExpression
@@ -2979,7 +2979,7 @@ structured_op: !LinalgStructuredOpConfig
                     - !ScalarExpression
                       scalar_fn:
                         kind: type
-                        fn_name: cast
+                        fn_name: cast_signed
                         type_var: F64
                         operands:
                         - !ScalarExpression
@@ -3000,7 +3000,7 @@ structured_op: !LinalgStructuredOpConfig
                                     - !ScalarExpression
                                       scalar_fn:
                                         kind: type
-                                        fn_name: cast
+                                        fn_name: cast_signed
                                         type_var: I32
                                         operands:
                                         - !ScalarExpression
@@ -3023,7 +3023,7 @@ structured_op: !LinalgStructuredOpConfig
                                                 - !ScalarExpression
                                                   scalar_fn:
                                                     kind: type
-                                                    fn_name: cast
+                                                    fn_name: cast_signed
                                                     type_var: I32
                                                     operands:
                                                     - !ScalarExpression
@@ -3033,7 +3033,7 @@ structured_op: !LinalgStructuredOpConfig
                                             - !ScalarExpression
                                               scalar_fn:
                                                 kind: type
-                                                fn_name: cast
+                                                fn_name: cast_signed
                                                 type_var: I32
                                                 operands:
                                                 - !ScalarExpression
@@ -3041,7 +3041,7 @@ structured_op: !LinalgStructuredOpConfig
                                         - !ScalarExpression
                                           scalar_fn:
                                             kind: type
-                                            fn_name: cast
+                                            fn_name: cast_signed
                                             type_var: I32
                                             operands:
                                             - !ScalarExpression
@@ -3049,7 +3049,7 @@ structured_op: !LinalgStructuredOpConfig
                                 - !ScalarExpression
                                   scalar_fn:
                                     kind: type
-                                    fn_name: cast
+                                    fn_name: cast_signed
                                     type_var: I32
                                     operands:
                                     - !ScalarExpression
@@ -3057,7 +3057,7 @@ structured_op: !LinalgStructuredOpConfig
                             - !ScalarExpression
                               scalar_fn:
                                 kind: type
-                                fn_name: cast
+                                fn_name: cast_signed
                                 type_var: I32
                                 operands:
                                 - !ScalarExpression
@@ -3079,7 +3079,7 @@ structured_op: !LinalgStructuredOpConfig
                     - !ScalarExpression
                       scalar_fn:
                         kind: type
-                        fn_name: cast
+                        fn_name: cast_signed
                         type_var: F64
                         operands:
                         - !ScalarExpression
@@ -3130,7 +3130,7 @@ structured_op: !LinalgStructuredOpConfig
             - !ScalarExpression
               scalar_fn:
                 kind: type
-                fn_name: cast
+                fn_name: cast_signed
                 type_var: U
                 operands:
                 - !ScalarExpression
@@ -3143,7 +3143,7 @@ structured_op: !LinalgStructuredOpConfig
                 - !ScalarExpression
                   scalar_fn:
                     kind: type
-                    fn_name: cast
+                    fn_name: cast_signed
                     type_var: U
                     operands:
                     - !ScalarExpression

diff  --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index dfc994368e6cf..6a12c42f9bd0f 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -160,22 +160,22 @@ class RegionBuilderHelper {
       if (allFloatingPoint)
         return builder.create<arith::AddFOp>(arg0.getLoc(), arg0, arg1);
       return builder.create<arith::AddIOp>(arg0.getLoc(), arg0, arg1);
+    case BinaryFn::sub:
+      if (allFloatingPoint)
+        return builder.create<arith::SubFOp>(arg0.getLoc(), arg0, arg1);
+      return builder.create<arith::SubIOp>(arg0.getLoc(), arg0, arg1);
     case BinaryFn::mul:
       if (allFloatingPoint)
         return builder.create<arith::MulFOp>(arg0.getLoc(), arg0, arg1);
       return builder.create<arith::MulIOp>(arg0.getLoc(), arg0, arg1);
-    case BinaryFn::max:
+    case BinaryFn::max_signed:
       if (allFloatingPoint)
         return builder.create<arith::MaxFOp>(arg0.getLoc(), arg0, arg1);
       return builder.create<arith::MaxSIOp>(arg0.getLoc(), arg0, arg1);
-    case BinaryFn::min:
+    case BinaryFn::min_signed:
       if (allFloatingPoint)
         return builder.create<arith::MinFOp>(arg0.getLoc(), arg0, arg1);
       return builder.create<arith::MinSIOp>(arg0.getLoc(), arg0, arg1);
-    case BinaryFn::sub:
-      if (allFloatingPoint)
-        return builder.create<arith::SubFOp>(arg0.getLoc(), arg0, arg1);
-      return builder.create<arith::SubIOp>(arg0.getLoc(), arg0, arg1);
     case BinaryFn::max_unsigned:
       if (allFloatingPoint)
         return builder.create<arith::MaxFOp>(arg0.getLoc(), arg0, arg1);
@@ -191,7 +191,7 @@ class RegionBuilderHelper {
   // Build the type functions defined by OpDSL.
   Value buildTypeFn(TypeFn typeFn, Type toType, Value operand) {
     switch (typeFn) {
-    case TypeFn::cast:
+    case TypeFn::cast_signed:
       return cast(toType, operand, false);
     case TypeFn::cast_unsigned:
       return cast(toType, operand, true);

diff  --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py
index f6bf0ff9a50d0..7de0a76e87b7c 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py
@@ -305,10 +305,10 @@ class BinaryFn:
   - max_unsinged -> `arith.MaxUIOp`
   """
   add = BinaryFnType("add")
-  mul = BinaryFnType("mul")
-  max = BinaryFnType("max")
-  min = BinaryFnType("min")
   sub = BinaryFnType("sub")
+  mul = BinaryFnType("mul")
+  max_signed = BinaryFnType("max_signed")
+  min_signed = BinaryFnType("min_signed")
   max_unsigned = BinaryFnType("max_unsigned")
   min_unsigned = BinaryFnType("min_unsigned")
 
@@ -334,14 +334,14 @@ class TypeFn:
   """Type conversion function namespace.
 
   As the integer types are signless, signedness is implement by different cast
-  functions that treat integers as signed (`cast`) or unsigned
+  functions that treat integers as signed (`cast_signed`) or unsigned
   (`cast_unsigned`) values.
 
   Examples:
-  - cast(I32 -> I64) -> `arith.ExtSIOp`
+  - cast_signed(I32 -> I64) -> `arith.ExtSIOp`
   - cast_unsigned(I32 -> I64) -> `arith.ExtUIOp`
   """
-  cast = TypeFnType("cast")
+  cast_signed = TypeFnType("cast_signed")
   cast_unsigned = TypeFnType("cast_unsigned")
 
 
@@ -389,8 +389,8 @@ def __repr__(self):
 class ReduceFn:
   add = ReduceFnType(BinaryFn.add)
   mul = ReduceFnType(BinaryFn.mul)
-  max = ReduceFnType(BinaryFn.max)
-  min = ReduceFnType(BinaryFn.min)
+  max_signed = ReduceFnType(BinaryFn.max_signed)
+  min_signed = ReduceFnType(BinaryFn.min_signed)
   max_unsigned = ReduceFnType(BinaryFn.max_unsigned)
   min_unsigned = ReduceFnType(BinaryFn.min_unsigned)
 

diff  --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
index 79fc3f5a21904..453a3e80cd9ad 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
@@ -370,7 +370,7 @@ def _cast_to_floating_point(self, to_type: Type, operand: Value,
     raise ValueError(f"Unable to cast body expression from {operand_type} to "
                      f"{to_type}")
 
-  def _type_cast(self, type_var_name: str, operand: Value) -> Value:
+  def _type_cast_signed(self, type_var_name: str, operand: Value) -> Value:
     return self._cast(type_var_name, operand, False)
 
   def _type_cast_unsigned(self, type_var_name: str, operand: Value) -> Value:
@@ -407,7 +407,7 @@ def _binary_mul(self, lhs: Value, rhs: Value) -> Value:
       return arith.MulIOp(lhs, rhs).result
     raise NotImplementedError("Unsupported 'mul' operands: {lhs}, {rhs}")
 
-  def _binary_max(self, lhs: Value, rhs: Value) -> Value:
+  def _binary_max_signed(self, lhs: Value, rhs: Value) -> Value:
     if _is_floating_point_type(lhs.type):
       return arith.MaxFOp(lhs, rhs).result
     if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
@@ -422,7 +422,7 @@ def _binary_max_unsigned(self, lhs: Value, rhs: Value) -> Value:
     raise NotImplementedError(
         "Unsupported 'max_unsigned' operands: {lhs}, {rhs}")
 
-  def _binary_min(self, lhs: Value, rhs: Value) -> Value:
+  def _binary_min_signed(self, lhs: Value, rhs: Value) -> Value:
     if _is_floating_point_type(lhs.type):
       return arith.MinFOp(lhs, rhs).result
     if _is_integer_type(lhs.type) or _is_index_type(lhs.type):

diff  --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
index b7a827bf649b2..0ef40613a7ba9 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
@@ -11,7 +11,7 @@ def elemwise_unary(
     I=TensorDef(T1),
     O=TensorDef(U, output=True),
     fun=UnaryFnAttrDef(default=UnaryFn.exp),
-    cast=TypeFnAttrDef(default=TypeFn.cast)):
+    cast=TypeFnAttrDef(default=TypeFn.cast_signed)):
   """Applies the unary function fun elementwise.
 
   Numeric casting is performed on the input operand, promoting it to the same
@@ -26,7 +26,7 @@ def elemwise_binary(
     rhs=TensorDef(T2),
     O=TensorDef(U, output=True),
     fun=BinaryFnAttrDef(default=BinaryFn.add),
-    cast=TypeFnAttrDef(default=TypeFn.cast)):
+    cast=TypeFnAttrDef(default=TypeFn.cast_signed)):
   """Applies the binary function fun elementwise.
 
   Numeric casting is performed on the input operand, promoting it to the same
@@ -40,7 +40,7 @@ def matmul(
     A=TensorDef(T1, S.M, S.K),
     B=TensorDef(T2, S.K, S.N),
     C=TensorDef(U, S.M, S.N, output=True),
-    cast=TypeFnAttrDef(default=TypeFn.cast)):
+    cast=TypeFnAttrDef(default=TypeFn.cast_signed)):
   """Performs a matrix multiplication of two 2D inputs.
 
   Numeric casting is performed on the operands to the inner multiply, promoting
@@ -82,8 +82,9 @@ def quantized_matmul(
   matmul.
   """
   domain(D.m, D.n, D.k)
-  C[D.m, D.n] += (TypeFn.cast(U, A[D.m, D.k]) - TypeFn.cast(U, AZp)) * (
-      TypeFn.cast(U, B[D.k, D.n]) - TypeFn.cast(U, BZp))
+  C[D.m, D.n] += (
+      TypeFn.cast_signed(U, A[D.m, D.k]) - TypeFn.cast_signed(U, AZp)) * (
+          TypeFn.cast_signed(U, B[D.k, D.n]) - TypeFn.cast_signed(U, BZp))
 
 
 @linalg_structured_op
@@ -103,8 +104,8 @@ def mmt4d(
   """
   domain(D.m, D.n, D.k, D.m0, D.n0, D.k0)
   implements(ContractionOpInterface)
-  accum[D.m, D.n, D.m0, D.n0] += TypeFn.cast(
-      TV.AccumType, lhs[D.m, D.k, D.m0, D.k0]) * TypeFn.cast(
+  accum[D.m, D.n, D.m0, D.n0] += TypeFn.cast_signed(
+      TV.AccumType, lhs[D.m, D.k, D.m0, D.k0]) * TypeFn.cast_signed(
           TV.AccumType, rhs[D.n, D.k, D.n0, D.k0])
 
 
@@ -121,7 +122,8 @@ def batch_matmul(
   domain(D.b, D.m, D.n, D.k)
   implements(ContractionOpInterface)
   C[D.b, D.m,
-    D.n] += TypeFn.cast(U, A[D.b, D.m, D.k]) * TypeFn.cast(U, B[D.b, D.k, D.n])
+    D.n] += TypeFn.cast_signed(U, A[D.b, D.m, D.k]) * TypeFn.cast_signed(
+        U, B[D.b, D.k, D.n])
 
 
 @linalg_structured_op
@@ -139,9 +141,9 @@ def quantized_batch_matmul(
   matmul.
   """
   domain(D.b, D.m, D.n, D.k)
-  C[D.b, D.m,
-    D.n] += (TypeFn.cast(U, A[D.b, D.m, D.k]) - TypeFn.cast(U, AZp)) * (
-        TypeFn.cast(U, B[D.b, D.k, D.n]) - TypeFn.cast(U, BZp))
+  C[D.b, D.m, D.n] += (
+      TypeFn.cast_signed(U, A[D.b, D.m, D.k]) - TypeFn.cast_signed(U, AZp)) * (
+          TypeFn.cast_signed(U, B[D.b, D.k, D.n]) - TypeFn.cast_signed(U, BZp))
 
 
 @linalg_structured_op
@@ -156,7 +158,7 @@ def matvec(
   """
   domain(D.m, D.n)
   implements(ContractionOpInterface)
-  x[D.m] += TypeFn.cast(U, A[D.m, D.n]) * TypeFn.cast(U, y[D.n])
+  x[D.m] += TypeFn.cast_signed(U, A[D.m, D.n]) * TypeFn.cast_signed(U, y[D.n])
 
 
 @linalg_structured_op
@@ -171,7 +173,7 @@ def vecmat(
   """
   domain(D.n, D.m)
   implements(ContractionOpInterface)
-  x[D.n] += TypeFn.cast(U, y[D.m]) * TypeFn.cast(U, A[D.m, D.n])
+  x[D.n] += TypeFn.cast_signed(U, y[D.m]) * TypeFn.cast_signed(U, A[D.m, D.n])
 
 
 @linalg_structured_op
@@ -186,7 +188,8 @@ def batch_matvec(
   """
   domain(D.b, D.m, D.k)
   implements(ContractionOpInterface)
-  C[D.b, D.m] += TypeFn.cast(U, A[D.b, D.m, D.k]) * TypeFn.cast(U, B[D.b, D.k])
+  C[D.b, D.m] += TypeFn.cast_signed(U, A[D.b, D.m, D.k]) * TypeFn.cast_signed(
+      U, B[D.b, D.k])
 
 
 @linalg_structured_op
@@ -198,7 +201,7 @@ def dot(
   them to the same data type as the accumulator/output.
   """
   implements(ContractionOpInterface)
-  C[None] += TypeFn.cast(U, A[D.m]) * TypeFn.cast(U, B[D.m])
+  C[None] += TypeFn.cast_signed(U, A[D.m]) * TypeFn.cast_signed(U, B[D.m])
 
 
 @linalg_structured_op
@@ -213,7 +216,8 @@ def conv_1d(
   """
   implements(ConvolutionOpInterface)
   domain(D.ow, D.kw)
-  O[D.ow] += TypeFn.cast(U, I[D.ow + D.kw]) * TypeFn.cast(U, K[D.kw])
+  O[D.ow] += TypeFn.cast_signed(U, I[D.ow + D.kw]) * TypeFn.cast_signed(
+      U, K[D.kw])
 
 
 @linalg_structured_op
@@ -228,8 +232,8 @@ def conv_2d(
   """
   implements(ConvolutionOpInterface)
   domain(D.oh, D.ow, D.kh, D.kw)
-  O[D.oh, D.ow] += TypeFn.cast(U, I[D.oh + D.kh, D.ow + D.kw]) * TypeFn.cast(
-      U, K[D.kh, D.kw])
+  O[D.oh, D.ow] += TypeFn.cast_signed(
+      U, I[D.oh + D.kh, D.ow + D.kw]) * TypeFn.cast_signed(U, K[D.kh, D.kw])
 
 
 @linalg_structured_op
@@ -244,9 +248,9 @@ def conv_3d(
   """
   implements(ConvolutionOpInterface)
   domain(D.od, D.oh, D.ow, D.kd, D.kh, D.kw)
-  O[D.od, D.oh,
-    D.ow] += TypeFn.cast(U, I[D.od + D.kd, D.oh + D.kh, D.ow +
-                              D.kw]) * TypeFn.cast(U, K[D.kd, D.kh, D.kw])
+  O[D.od, D.oh, D.ow] += TypeFn.cast_signed(
+      U, I[D.od + D.kd, D.oh + D.kh, D.ow + D.kw]) * TypeFn.cast_signed(
+          U, K[D.kd, D.kh, D.kw])
 
 
 @linalg_structured_op
@@ -264,8 +268,8 @@ def conv_1d_nwc_wcf(
   implements(ConvolutionOpInterface)
   domain(D.n, D.ow, D.f, D.kw, D.c)
   O[D.n, D.ow,
-    D.f] += TypeFn.cast(U, I[D.n, D.ow * S.SW + D.kw * S.DW,
-                             D.c]) * TypeFn.cast(U, K[D.kw, D.c, D.f])
+    D.f] += TypeFn.cast_signed(U, I[D.n, D.ow * S.SW + D.kw * S.DW,
+                             D.c]) * TypeFn.cast_signed(U, K[D.kw, D.c, D.f])
 
 
 @linalg_structured_op
@@ -287,9 +291,9 @@ def conv_2d_nhwc_hwcf(
   """
   implements(ConvolutionOpInterface)
   domain(D.n, D.oh, D.ow, D.f, D.kh, D.kw, D.c)
-  O[D.n, D.oh, D.ow, D.f] += TypeFn.cast(
+  O[D.n, D.oh, D.ow, D.f] += TypeFn.cast_signed(
       U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW,
-           D.c]) * TypeFn.cast(U, K[D.kh, D.kw, D.c, D.f])
+           D.c]) * TypeFn.cast_signed(U, K[D.kh, D.kw, D.c, D.f])
 
 
 @linalg_structured_op
@@ -315,10 +319,11 @@ def conv_2d_nhwc_hwcf_q(
   implements(ConvolutionOpInterface)
   domain(D.n, D.oh, D.ow, D.f, D.kh, D.kw, D.c)
   O[D.n, D.oh, D.ow,
-    D.f] += (TypeFn.cast(
+    D.f] += (TypeFn.cast_signed(
         U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c]) -
-             TypeFn.cast(U, IZp)) * (
-                 TypeFn.cast(U, K[D.kh, D.kw, D.c, D.f]) - TypeFn.cast(U, KZp))
+             TypeFn.cast_signed(U, IZp)) * (
+                 TypeFn.cast_signed(U, K[D.kh, D.kw, D.c, D.f]) -
+                 TypeFn.cast_signed(U, KZp))
 
 
 @linalg_structured_op
@@ -340,9 +345,9 @@ def conv_2d_nchw_fchw(
   """
   implements(ConvolutionOpInterface)
   domain(D.n, D.f, D.oh, D.ow, D.c, D.kh, D.kw)
-  O[D.n, D.f, D.oh, D.ow] += TypeFn.cast(
-      U, I[D.n, D.c, D.oh * S.SH + D.kh * S.DH,
-           D.ow * S.SW + D.kw * S.DW]) * TypeFn.cast(U, K[D.f, D.c, D.kh, D.kw])
+  O[D.n, D.f, D.oh, D.ow] += TypeFn.cast_signed(
+      U, I[D.n, D.c, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW +
+           D.kw * S.DW]) * TypeFn.cast_signed(U, K[D.f, D.c, D.kh, D.kw])
 
 
 @linalg_structured_op
@@ -360,9 +365,9 @@ def conv_3d_ndhwc_dhwcf(
   """
   implements(ConvolutionOpInterface)
   domain(D.n, D.od, D.oh, D.ow, D.f, D.kd, D.kh, D.kw, D.c)
-  O[D.n, D.od, D.oh, D.ow, D.f] += TypeFn.cast(
+  O[D.n, D.od, D.oh, D.ow, D.f] += TypeFn.cast_signed(
       U, I[D.n, D.od * S.SD + D.kd * S.DD, D.oh * S.SH + D.kh * S.DH,
-           D.ow * S.SW + D.kw * S.DW, D.c]) * TypeFn.cast(
+           D.ow * S.SW + D.kw * S.DW, D.c]) * TypeFn.cast_signed(
                U, K[D.kd, D.kh, D.kw, D.c, D.f])
 
 
@@ -382,8 +387,8 @@ def depthwise_conv_1d_nwc_wc(
   implements(ConvolutionOpInterface)
   domain(D.n, D.ow, D.ic, D.kw)
   O[D.n, D.ow, D.ic] += \
-      TypeFn.cast(U, I[D.n, D.ow * S.SW + D.kw * S.DW, D.ic]) * \
-      TypeFn.cast(U, K[D.kw, D.ic])
+      TypeFn.cast_signed(U, I[D.n, D.ow * S.SW + D.kw * S.DW, D.ic]) * \
+      TypeFn.cast_signed(U, K[D.kw, D.ic])
 
 
 @linalg_structured_op
@@ -402,9 +407,9 @@ def depthwise_conv_2d_nhwc_hwc(
   """
   implements(ConvolutionOpInterface)
   domain(D.n, D.oh, D.ow, D.ic, D.kh, D.kw)
-  O[D.n, D.oh, D.ow, D.ic] += TypeFn.cast(
+  O[D.n, D.oh, D.ow, D.ic] += TypeFn.cast_signed(
       U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW,
-           D.ic]) * TypeFn.cast(U, K[D.kh, D.kw, D.ic])
+           D.ic]) * TypeFn.cast_signed(U, K[D.kh, D.kw, D.ic])
 
 
 @linalg_structured_op
@@ -424,11 +429,11 @@ def depthwise_conv_2d_nhwc_hwc_q(
   """
   implements(ConvolutionOpInterface)
   domain(D.n, D.oh, D.ow, D.ic, D.kh, D.kw)
-  O[D.n, D.oh, D.ow,
-    D.ic] += ((TypeFn.cast(
-        U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.ic]) -
-               TypeFn.cast(U, IZp)) *
-              (TypeFn.cast(U, K[D.kh, D.kw, D.ic]) - TypeFn.cast(U, KZp)))
+  O[D.n, D.oh, D.ow, D.ic] += ((TypeFn.cast_signed(
+      U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.ic]) -
+                                TypeFn.cast_signed(U, IZp)) *
+                               (TypeFn.cast_signed(U, K[D.kh, D.kw, D.ic]) -
+                                TypeFn.cast_signed(U, KZp)))
 
 
 @linalg_structured_op
@@ -446,9 +451,9 @@ def depthwise_conv_2d_nhwc_hwcm(
   """
   implements(ConvolutionOpInterface)
   domain(D.n, D.oh, D.ow, D.ic, D.cm, D.kh, D.kw)
-  O[D.n, D.oh, D.ow, D.ic, D.cm] += TypeFn.cast(
+  O[D.n, D.oh, D.ow, D.ic, D.cm] += TypeFn.cast_signed(
       U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW,
-           D.ic]) * TypeFn.cast(U, K[D.kh, D.kw, D.ic, D.cm])
+           D.ic]) * TypeFn.cast_signed(U, K[D.kh, D.kw, D.ic, D.cm])
 
 
 @linalg_structured_op
@@ -469,10 +474,11 @@ def depthwise_conv_2d_nhwc_hwcm_q(
   implements(ConvolutionOpInterface)
   domain(D.n, D.oh, D.ow, D.ic, D.cm, D.kh, D.kw)
   O[D.n, D.oh, D.ow, D.ic,
-    D.cm] += ((TypeFn.cast(
+    D.cm] += ((TypeFn.cast_signed(
         U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.ic]) -
-               TypeFn.cast(U, IZp)) *
-              (TypeFn.cast(U, K[D.kh, D.kw, D.ic, D.cm]) - TypeFn.cast(U, KZp)))
+               TypeFn.cast_signed(U, IZp)) *
+              (TypeFn.cast_signed(U, K[D.kh, D.kw, D.ic, D.cm]) -
+               TypeFn.cast_signed(U, KZp)))
 
 
 @linalg_structured_op
@@ -490,7 +496,7 @@ def pooling_nhwc_sum(
   """
   implements(ConvolutionOpInterface)
   domain(D.n, D.oh, D.ow, D.c, D.kh, D.kw)
-  O[D.n, D.oh, D.ow, D.c] += TypeFn.cast(
+  O[D.n, D.oh, D.ow, D.c] += TypeFn.cast_signed(
       U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c])
 
 
@@ -509,8 +515,8 @@ def pooling_nhwc_max(
   """
   implements(ConvolutionOpInterface)
   domain(D.n, D.oh, D.ow, D.c, D.kh, D.kw)
-  O[D.n, D.oh, D.ow, D.c] = ReduceFn.max[D.kh, D.kw](
-      TypeFn.cast(
+  O[D.n, D.oh, D.ow, D.c] = ReduceFn.max_signed[D.kh, D.kw](
+      TypeFn.cast_signed(
           U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c]))
 
 
@@ -549,8 +555,8 @@ def pooling_nchw_max(
   """
   implements(ConvolutionOpInterface)
   domain(D.n, D.c, D.oh, D.ow, D.kh, D.kw)
-  O[D.n, D.c, D.oh, D.ow] = ReduceFn.max[D.kh, D.kw](
-      TypeFn.cast(
+  O[D.n, D.c, D.oh, D.ow] = ReduceFn.max_signed[D.kh, D.kw](
+      TypeFn.cast_signed(
           U, I[D.n, D.c, D.oh * S.SH + D.kh * S.DH,
                D.ow * S.SW + D.kw * S.DW,]))
 
@@ -570,8 +576,8 @@ def pooling_nhwc_min(
   """
   implements(ConvolutionOpInterface)
   domain(D.n, D.oh, D.ow, D.c, D.kh, D.kw)
-  O[D.n, D.oh, D.ow, D.c] = ReduceFn.min[D.kh, D.kw](
-      TypeFn.cast(
+  O[D.n, D.oh, D.ow, D.c] = ReduceFn.min_signed[D.kh, D.kw](
+      TypeFn.cast_signed(
           U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c]))
 
 
@@ -610,7 +616,7 @@ def pooling_ndhwc_sum(
   """
   implements(ConvolutionOpInterface)
   domain(D.n, D.od, D.oh, D.ow, D.c, D.kd, D.kh, D.kw)
-  O[D.n, D.od, D.oh, D.ow, D.c] += TypeFn.cast(
+  O[D.n, D.od, D.oh, D.ow, D.c] += TypeFn.cast_signed(
       U, I[D.n, D.od * S.SD + D.kd * S.DD, D.oh * S.SH + D.kh * S.DH,
            D.ow * S.SW + D.kw * S.DW, D.c])
 
@@ -630,8 +636,8 @@ def pooling_ndhwc_max(
   """
   implements(ConvolutionOpInterface)
   domain(D.n, D.od, D.oh, D.ow, D.c, D.kd, D.kh, D.kw)
-  O[D.n, D.od, D.oh, D.ow, D.c] = ReduceFn.max[D.kd, D.kh, D.kw](
-      TypeFn.cast(
+  O[D.n, D.od, D.oh, D.ow, D.c] = ReduceFn.max_signed[D.kd, D.kh, D.kw](
+      TypeFn.cast_signed(
           U, I[D.n, D.od * S.SD + D.kd * S.DD, D.oh * S.SH + D.kh * S.DH,
                D.ow * S.SW + D.kw * S.DW, D.c]))
 
@@ -651,8 +657,8 @@ def pooling_ndhwc_min(
   """
   implements(ConvolutionOpInterface)
   domain(D.n, D.od, D.oh, D.ow, D.c, D.kd, D.kh, D.kw)
-  O[D.n, D.od, D.oh, D.ow, D.c] = ReduceFn.min[D.kd, D.kh, D.kw](
-      TypeFn.cast(
+  O[D.n, D.od, D.oh, D.ow, D.c] = ReduceFn.min_signed[D.kd, D.kh, D.kw](
+      TypeFn.cast_signed(
           U, I[D.n, D.od * S.SD + D.kd * S.DD, D.oh * S.SH + D.kh * S.DH,
                D.ow * S.SW + D.kw * S.DW, D.c]))
 
@@ -665,7 +671,7 @@ def fill_tensor(value=ScalarDef(T1), O=TensorDef(U, output=True)):
   accesses only and is thus rank polymorphic. Numeric casting is performed on
   the value operand, promoting it to the same data type as the output.
   """
-  O[None] = TypeFn.cast(U, value)
+  O[None] = TypeFn.cast_signed(U, value)
 
 
 @linalg_structured_op
@@ -685,15 +691,15 @@ def fill_rng_2d(
   the range of the generated random numbers.
   """
   domain(D.m, D.n)
-  multiplier = TypeFn.cast(I32, const(1103515245))
-  increment = TypeFn.cast(I32, const(12345))
-  rand1 = (TypeFn.cast(I32, index(D.m)) + seed) * multiplier + increment
-  rand2 = (TypeFn.cast(I32, index(D.n)) + rand1) * multiplier + increment
-  inv_range = TypeFn.cast(F64, const(2.3283064e-10))
-  offset = TypeFn.cast(F64, const(2147483647))
+  multiplier = TypeFn.cast_signed(I32, const(1103515245))
+  increment = TypeFn.cast_signed(I32, const(12345))
+  rand1 = (TypeFn.cast_signed(I32, index(D.m)) + seed) * multiplier + increment
+  rand2 = (TypeFn.cast_signed(I32, index(D.n)) + rand1) * multiplier + increment
+  inv_range = TypeFn.cast_signed(F64, const(2.3283064e-10))
+  offset = TypeFn.cast_signed(F64, const(2147483647))
   scaling = (max - min) * inv_range
-  O[D.m, D.n] = TypeFn.cast(T,
-                            (offset + TypeFn.cast(F64, rand2)) * scaling + min)
+  O[D.m, D.n] = TypeFn.cast_signed(
+      T, (offset + TypeFn.cast_signed(F64, rand2)) * scaling + min)
 
 
 @linalg_structured_op
@@ -706,4 +712,4 @@ def soft_plus_2d(
   """
   domain(D.m, D.n)
   O[D.m, D.n] = \
-      UnaryFn.log(TypeFn.cast(U, const(1.0)) + UnaryFn.exp(TypeFn.cast(U, I[D.m, D.n])))
+      UnaryFn.log(TypeFn.cast_signed(U, const(1.0)) + UnaryFn.exp(TypeFn.cast_signed(U, I[D.m, D.n])))

diff  --git a/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml b/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml
index 2defebbba781f..3f6c763470146 100644
--- a/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml
+++ b/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml
@@ -3,7 +3,7 @@
 
 # @linalg_structured_op
 # def test1(O=TensorDef(T, S.M, S.N, output=True),
-#           cast=TypeFnAttrDef(default=TypeFn.cast)):
+#           cast=TypeFnAttrDef(default=TypeFn.cast_signed)):
 #   """Title.
 
 #   Detailed description.
@@ -28,7 +28,7 @@ structured_op: !LinalgStructuredOpConfig
   - !LinalgOperandDefConfig
     name: cast
     kind: type_fn_attr
-    default_fn: cast
+    default_fn: cast_signed
   indexing_maps: !LinalgIndexingMapsConfig
     static_indexing_maps:
     - affine_map<(d0, d1)[s0, s1] -> (d0, d1)>
@@ -70,7 +70,7 @@ structured_op: !LinalgStructuredOpConfig
 #       ODS:  let arguments =
 #  ODS-NEXT:    Variadic<AnyType>:$inputs,
 #  ODS-NEXT:    Variadic<AnyShaped>:$outputs,
-#  ODS-NEXT:    DefaultValuedAttr<TypeFnAttr, "TypeFn::cast">:$cast
+#  ODS-NEXT:    DefaultValuedAttr<TypeFnAttr, "TypeFn::cast_signed">:$cast
 
 #       ODS:  let builders =
 #       ODS:  (ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
@@ -99,7 +99,7 @@ structured_op: !LinalgStructuredOpConfig
 
 # IMPL-LABEL:  void Test1Op::regionBuilder(ImplicitLocOpBuilder &b,
 #  IMPL-NEXT:    Block &block, ArrayRef<NamedAttribute> attrs)
-#       IMPL:  TypeFn castVal = TypeFn::cast;
+#       IMPL:  TypeFn castVal = TypeFn::cast_signed;
 #  IMPL-NEXT:  auto castIter = llvm::find_if(attrs, [&](const NamedAttribute &attr) {
 #  IMPL-NEXT:                                return attr.getName() == "cast"; });
 #  IMPL-NEXT:  if (castIter != attrs.end()) {
@@ -209,7 +209,7 @@ structured_op: !LinalgStructuredOpConfig
 
 #   Detailed description.
 #   """
-#   O[None] = TypeFn.cast(U, value)
+#   O[None] = TypeFn.cast_signed(U, value)
 
 --- !LinalgOpConfig
 metadata: !LinalgOpMetadata
@@ -241,7 +241,7 @@ structured_op: !LinalgStructuredOpConfig
     value: !ScalarExpression
       scalar_fn:
         kind: type
-        fn_name: cast
+        fn_name: cast_signed
         type_var: U
         operands:
         - !ScalarExpression

diff  --git a/mlir/test/python/dialects/linalg/opdsl/arguments.py b/mlir/test/python/dialects/linalg/opdsl/arguments.py
index 853627611987c..d787c5f49c441 100644
--- a/mlir/test/python/dialects/linalg/opdsl/arguments.py
+++ b/mlir/test/python/dialects/linalg/opdsl/arguments.py
@@ -26,7 +26,7 @@
 # CHECK:     default_fn: exp
 # CHECK:     name: cast
 # CHECK:     kind: type_fn_attr
-# CHECK:     default_fn: cast
+# CHECK:     default_fn: cast_signed
 @linalg_structured_op
 def matmul(
     A=TensorDef(T, S.M, S.K),
@@ -34,7 +34,7 @@ def matmul(
     C=TensorDef(U, S.M, S.N, output=True),
     bfn=BinaryFnAttrDef(default=BinaryFn.mul),
     ufn=UnaryFnAttrDef(default=UnaryFn.exp),
-    cast=TypeFnAttrDef(default=TypeFn.cast)):
+    cast=TypeFnAttrDef(default=TypeFn.cast_signed)):
   C[D.m, D.n] += bfn(cast(U, A[D.m, D.k]), cast(U, B[D.k, D.n]))
 
 

diff  --git a/mlir/test/python/dialects/linalg/opdsl/assignments.py b/mlir/test/python/dialects/linalg/opdsl/assignments.py
index d8ddc24454914..eacf43547b110 100644
--- a/mlir/test/python/dialects/linalg/opdsl/assignments.py
+++ b/mlir/test/python/dialects/linalg/opdsl/assignments.py
@@ -35,7 +35,7 @@ def matmul(
     B=TensorDef(T, S.K, S.N),
     C=TensorDef(U, S.M, S.N, output=True),
     mul=BinaryFnAttrDef(default=BinaryFn.mul),
-    cast=TypeFnAttrDef(default=TypeFn.cast)):
+    cast=TypeFnAttrDef(default=TypeFn.cast_signed)):
   C[D.m, D.n] += mul(cast(U, A[D.m, D.k]), cast(U, B[D.k, D.n]))
 
 
@@ -63,13 +63,13 @@ def matmul(
 # CHECK:                      scalar_const: '3.1415926535897931 : f64'
 # CHECK:              scalar_fn:
 # CHECK:                kind: type
-# CHECK:                fn_name: cast
+# CHECK:                fn_name: cast_signed
 # CHECK:                type_var: T
 # CHECK:                operands:
 # CHECK:                  scalar_const: '42 : i64'
 # CHECK:          scalar_fn:
 # CHECK:            kind: type
-# CHECK:            fn_name: cast
+# CHECK:            fn_name: cast_signed
 # CHECK:            type_var: T
 # CHECK:            operands:
 # CHECK:              scalar_fn:
@@ -81,9 +81,9 @@ def matmul(
 def constants(
     O=TensorDef(T, S.M, S.K, output=True),
     exp=UnaryFnAttrDef(default=UnaryFn.exp)):
-  pi = TypeFn.cast(T, const(3.1415926535897931))
-  cst42 = TypeFn.cast(T, const(42))
-  cst1000 = TypeFn.cast(T, exp(const(1e+3)))
+  pi = TypeFn.cast_signed(T, const(3.1415926535897931))
+  cst42 = TypeFn.cast_signed(T, const(42))
+  cst1000 = TypeFn.cast_signed(T, exp(const(1e+3)))
   O[D.m, D.n] = UnaryFn.exp(pi) + cst42 - cst1000
 
 

diff  --git a/mlir/test/python/dialects/linalg/opdsl/emit_convolution.py b/mlir/test/python/dialects/linalg/opdsl/emit_convolution.py
index e15424cfac4e7..25a3de05f0493 100644
--- a/mlir/test/python/dialects/linalg/opdsl/emit_convolution.py
+++ b/mlir/test/python/dialects/linalg/opdsl/emit_convolution.py
@@ -19,9 +19,9 @@ def conv_poly(
     strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
     dilations=IndexAttrDef(S.DH, S.DW, default=[1, 2])):
   domain(D.n, D.oh, D.ow, D.kh, D.kw, D.c)
-  O[D.n, D.oh, D.ow, D.c] += TypeFn.cast(
+  O[D.n, D.oh, D.ow, D.c] += TypeFn.cast_signed(
       U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW,
-           D.c]) * TypeFn.cast(U, K[D.kh, D.kw, D.c])
+           D.c]) * TypeFn.cast_signed(U, K[D.kh, D.kw, D.c])
 
 
 with Context() as ctx, Location.unknown():

diff  --git a/mlir/test/python/dialects/linalg/opdsl/emit_fill.py b/mlir/test/python/dialects/linalg/opdsl/emit_fill.py
index 75524691a4875..1a0d08a0dce79 100644
--- a/mlir/test/python/dialects/linalg/opdsl/emit_fill.py
+++ b/mlir/test/python/dialects/linalg/opdsl/emit_fill.py
@@ -13,7 +13,7 @@
 
 @linalg_structured_op
 def fill_poly(value=ScalarDef(T1), O=TensorDef(U, output=True)):
-  O[None] = TypeFn.cast(U, value)
+  O[None] = TypeFn.cast_signed(U, value)
 
 
 with Context() as ctx, Location.unknown():

diff  --git a/mlir/test/python/dialects/linalg/opdsl/emit_matmul.py b/mlir/test/python/dialects/linalg/opdsl/emit_matmul.py
index 1954d6d89ac5c..1b31ed040ad2e 100644
--- a/mlir/test/python/dialects/linalg/opdsl/emit_matmul.py
+++ b/mlir/test/python/dialects/linalg/opdsl/emit_matmul.py
@@ -25,7 +25,7 @@ def matmul_poly(
     A=TensorDef(T1, S.M, S.K),
     B=TensorDef(T2, S.K, S.N),
     C=TensorDef(U, S.M, S.N, output=True),
-    cast=TypeFnAttrDef(default=TypeFn.cast)):
+    cast=TypeFnAttrDef(default=TypeFn.cast_signed)):
   domain(D.m, D.n, D.k)
   C[D.m, D.n] += cast(U, A[D.m, D.k]) * cast(U, B[D.k, D.n])
 

diff  --git a/mlir/test/python/dialects/linalg/opdsl/emit_misc.py b/mlir/test/python/dialects/linalg/opdsl/emit_misc.py
index 9eab9b2a8e6ac..85699493dd4db 100644
--- a/mlir/test/python/dialects/linalg/opdsl/emit_misc.py
+++ b/mlir/test/python/dialects/linalg/opdsl/emit_misc.py
@@ -21,22 +21,23 @@ def fill_rng_poly(
     max=ScalarDef(F64),
     seed=ScalarDef(I32),
     O=TensorDef(T, S.M, S.N, output=True)):
-  multiplier = TypeFn.cast(I32, const(1103515245))
-  increment = TypeFn.cast(I32, const(12345))
-  rand1 = (TypeFn.cast(I32, index(D.m)) + seed) * multiplier + increment
-  rand2 = (TypeFn.cast(I32, index(D.n)) + rand1) * multiplier + increment
-  inv_range = TypeFn.cast(F64, const(2.3283064e-10))
-  offset = TypeFn.cast(F64, const(2147483647))
+  multiplier = TypeFn.cast_signed(I32, const(1103515245))
+  increment = TypeFn.cast_signed(I32, const(12345))
+  rand1 = (TypeFn.cast_signed(I32, index(D.m)) + seed) * multiplier + increment
+  rand2 = (TypeFn.cast_signed(I32, index(D.n)) + rand1) * multiplier + increment
+  inv_range = TypeFn.cast_signed(F64, const(2.3283064e-10))
+  offset = TypeFn.cast_signed(F64, const(2147483647))
   scaling = (max - min) * inv_range
-  O[D.m, D.n] = TypeFn.cast(T,
-                            (offset + TypeFn.cast(F64, rand2)) * scaling + min)
+  O[D.m, D.n] = TypeFn.cast_signed(
+      T, (offset + TypeFn.cast_signed(F64, rand2)) * scaling + min)
 
 
 @linalg_structured_op
 def soft_plus_poly(
     I=TensorDef(T, S.M, S.N), O=TensorDef(U, S.M, S.N, output=True)):
   O[D.m, D.n] = UnaryFn.log(
-      TypeFn.cast(U, const(1.0)) + TypeFn.cast(U, UnaryFn.exp(I[D.m, D.n])))
+      TypeFn.cast_signed(U, const(1.0)) +
+      TypeFn.cast_signed(U, UnaryFn.exp(I[D.m, D.n])))
 
 
 @linalg_structured_op(op_name="custom_op_name")

diff  --git a/mlir/test/python/dialects/linalg/opdsl/emit_pooling.py b/mlir/test/python/dialects/linalg/opdsl/emit_pooling.py
index ded97cd7b8220..d68b43de7535b 100644
--- a/mlir/test/python/dialects/linalg/opdsl/emit_pooling.py
+++ b/mlir/test/python/dialects/linalg/opdsl/emit_pooling.py
@@ -16,8 +16,8 @@ def pooling_poly(
     I=TensorDef(T1, S.N, S.H, S.W, S.C),
     K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]),
     O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True),
-    reduce=BinaryFnAttrDef(default=BinaryFn.max),
-    cast=TypeFnAttrDef(default=TypeFn.cast),
+    reduce=BinaryFnAttrDef(default=BinaryFn.max_signed),
+    cast=TypeFnAttrDef(default=TypeFn.cast_signed),
     strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
     dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])):
   domain(D.n, D.oh, D.ow, D.kh, D.kw, D.c)
@@ -99,7 +99,7 @@ def test_f32i32_min_pooling(input, shape, init_result):
           input,
           shape,
           outs=[init_result],
-          reduce=BinaryFn.min,
+          reduce=BinaryFn.min_signed,
           strides=[2, 4],
           dilations=[1, 2])
 
@@ -131,7 +131,7 @@ def test_f32f32_min_pooling(input, shape, init_result):
           input,
           shape,
           outs=[init_result],
-          reduce=BinaryFn.min,
+          reduce=BinaryFn.min_signed,
           strides=[2, 4],
           dilations=[1, 2])
 

diff  --git a/mlir/test/python/dialects/linalg/opdsl/interfaces.py b/mlir/test/python/dialects/linalg/opdsl/interfaces.py
index 81256e314e140..ca9bd04cd9671 100644
--- a/mlir/test/python/dialects/linalg/opdsl/interfaces.py
+++ b/mlir/test/python/dialects/linalg/opdsl/interfaces.py
@@ -13,4 +13,5 @@ def matmul(
     B=TensorDef(T, S.K, S.N),
     C=TensorDef(U, S.M, S.N, output=True)):
   implements(ContractionOpInterface)
-  C[D.m, D.n] += TypeFn.cast(U, A[D.m, D.k]) * TypeFn.cast(U, B[D.k, D.n])
+  C[D.m, D.n] += TypeFn.cast_signed(U, A[D.m, D.k]) * TypeFn.cast_signed(
+      U, B[D.k, D.n])

diff  --git a/mlir/test/python/dialects/linalg/opdsl/shape_maps_iteration.py b/mlir/test/python/dialects/linalg/opdsl/shape_maps_iteration.py
index 5817ab6d2c305..871341c835a5d 100644
--- a/mlir/test/python/dialects/linalg/opdsl/shape_maps_iteration.py
+++ b/mlir/test/python/dialects/linalg/opdsl/shape_maps_iteration.py
@@ -24,7 +24,8 @@ def matmul(
     B=TensorDef(T, S.K, S.N),
     C=TensorDef(U, S.M, S.N, output=True)):
   domain(D.m, D.n, D.k)
-  C[D.m, D.n] += TypeFn.cast(U, A[D.m, D.k]) * TypeFn.cast(U, B[D.k, D.n])
+  C[D.m, D.n] += TypeFn.cast_signed(U, A[D.m, D.k]) * TypeFn.cast_signed(
+      U, B[D.k, D.n])
 
 
 # Verifies that assignment to a scalar (represented as [None]) is represented
@@ -42,7 +43,7 @@ def matmul(
 # CHECK-NEXT: - reduction
 @linalg_structured_op
 def dot(A=TensorDef(T, S.M), B=TensorDef(T, S.M), C=TensorDef(U, output=True)):
-  C[None] += TypeFn.cast(U, A[D.m]) * TypeFn.cast(U, B[D.m])
+  C[None] += TypeFn.cast_signed(U, A[D.m]) * TypeFn.cast_signed(U, B[D.m])
 
 
 # Verifies that the index_dims of shape-only operands translate to correct
@@ -65,4 +66,4 @@ def pool(
     K=TensorDef(T, S.K, index_dims=[D.k]),
     O=TensorDef(U, S.O, output=True)):
   domain(D.o, D.k)
-  O[D.o] += TypeFn.cast(U, I[D.o * 2 + D.k])
+  O[D.o] += TypeFn.cast_signed(U, I[D.o * 2 + D.k])

diff  --git a/mlir/test/python/dialects/linalg/ops.py b/mlir/test/python/dialects/linalg/ops.py
index 00a8406373855..08be2ccada5f7 100644
--- a/mlir/test/python/dialects/linalg/ops.py
+++ b/mlir/test/python/dialects/linalg/ops.py
@@ -99,7 +99,7 @@ def named_form(lhs, rhs):
         init_result = linalg.InitTensorOp([4, 8], f32)
         # Check for the named form with custom format
         #      CHECK: linalg.elemwise_unary
-        # CHECK-SAME:    cast = #linalg.type_fn<cast>
+        # CHECK-SAME:    cast = #linalg.type_fn<cast_signed>
         # CHECK-SAME:    fun = #linalg.unary_fn<exp>
         # CHECK-SAME:    ins(%{{.*}} : tensor<4x8xf32>) outs(%{{.*}} : tensor<4x8xf32>)
         unary_result = linalg.elemwise_unary(lhs, outs=[init_result.result])
@@ -137,7 +137,7 @@ def named_form(lhs, rhs):
         # CHECK-NEXT:    arith.mulf{{.*}} (f32, f32) -> f32
         # CHECK-NEXT:    arith.addf{{.*}} (f32, f32) -> f32
         # CHECK-NEXT:    linalg.yield{{.*}} (f32) -> ()
-        # CHECK-NEXT:    cast = #linalg.type_fn<cast>
+        # CHECK-NEXT:    cast = #linalg.type_fn<cast_signed>
         # CHECK-SAME:    operand_segment_sizes = dense<[2, 1]> : vector<2xi32>
         # CHECK-SAME: (tensor<4x16xf32>, tensor<16x8xf32>, tensor<4x8xf32>) -> tensor<4x8xf32>
         return linalg.matmul(lhs, rhs, outs=[init_result.result])


        


More information about the Mlir-commits mailing list