[Mlir-commits] [mlir] [TOSA] Use attributes for unsigned rescale (PR #118075)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Nov 29 01:36:33 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-tosa

Author: Thomas Preud'homme (RoboTux)

<details>
<summary>Changes</summary>

Unsigned integer types are uncommon enough in MLIR that there is no
operation to cast a scalar from signless to unsigned and vice versa.
Currently, tosa.rescale uses builtin.unrealized_conversion_cast, which
does not lower. Instead, this commit introduces optional attributes
(input_unsigned and output_unsigned) to indicate unsigned input or
output, named similarly to those in the TOSA specification. This is more
in line with the rest of MLIR, where signedness is a property of
specific operations rather than of values.


---
Full diff: https://github.com/llvm/llvm-project/pull/118075.diff


3 Files Affected:

- (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td (+38-17) 
- (modified) mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp (+2-16) 
- (modified) mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir (+23-15) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index 9f57efff5d1fd2..54fe31bbc84e05 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -1869,22 +1869,23 @@ def Tosa_RescaleOp: Tosa_Op<"rescale", [Pure,
   let description = [{
     Rescale quantized values into a new domain. Supported rescalings are:
 
-    | Mode                   | Input | Output |
-    |------------------------|-------|--------|
-    | signed 8 to 8          | int8  | int8   |
-    | signed 8 to 16         | int8  | int16  |
-    | signed 8 to 32         | int8  | int32  |
-    | signed 16 to 8         | int16 | int8   |
-    | signed 16 to 16        | int16 | int16  |
-    | signed 16 to 32        | int16 | int32  |
-    | signed 32 to 8         | int32 | int8   |
-    | signed 32 to 16        | int32 | int16  |
-    | signed 32 to 32        | int32 | int32  |
-    | signed 48 to 8         | int48 | int8   |
-    | signed 48 to 16        | int48 | int16  |
-    | signed 48 to 32        | int48 | int32  |
-    | unsigned 8 to signed 8 | uint8 | int8   |
-    | signed 8 to unsigned 8 | int8  | uint8  |
+    | Mode                   | Input | Output | Unsigned | Unsigned |
+    |                        |       |        |  input   |  output  |
+    |------------------------|-------|--------|----------|----------|
+    | signed 8 to 8          | int8  | int8   |  false   |  false   |
+    | signed 8 to 16         | int8  | int16  |  false   |  false   |
+    | signed 8 to 32         | int8  | int32  |  false   |  false   |
+    | signed 16 to 8         | int16 | int8   |  false   |  false   |
+    | signed 16 to 16        | int16 | int16  |  false   |  false   |
+    | signed 16 to 32        | int16 | int32  |  false   |  false   |
+    | signed 32 to 8         | int32 | int8   |  false   |  false   |
+    | signed 32 to 16        | int32 | int16  |  false   |  false   |
+    | signed 32 to 32        | int32 | int32  |  false   |  false   |
+    | signed 48 to 8         | int48 | int8   |  false   |  false   |
+    | signed 48 to 16        | int48 | int16  |  false   |  false   |
+    | signed 48 to 32        | int48 | int32  |  false   |  false   |
+    | unsigned 8 to signed 8 | uint8 | int8   |  true    |  false   |
+    | signed 8 to unsigned 8 | int8  | uint8  |  false   |  true    |
   }];
 
   let arguments = (ins
@@ -1895,13 +1896,33 @@ def Tosa_RescaleOp: Tosa_Op<"rescale", [Pure,
     DenseI8ArrayAttr:$shift,
     BoolAttr:$scale32,
     BoolAttr:$double_round,
-    BoolAttr:$per_channel
+    BoolAttr:$per_channel,
+    DefaultValuedOptionalAttr<BoolAttr, "false">:$input_unsigned,
+    DefaultValuedOptionalAttr<BoolAttr, "false">:$output_unsigned
   );
 
   let results = (outs
     Tosa_Tensor:$output
   );
 
+  // Custom builder that does not require optional input_unsigned and
+  // output_unsigned.
+  let builders = [
+    OpBuilder<(ins "::mlir::Type":$output,
+                   "::mlir::Value":$input,
+                   "::mlir::IntegerAttr":$input_zp,
+                   "::mlir::IntegerAttr":$output_zp,
+                   "::mlir::DenseI32ArrayAttr":$multiplier,
+                   "::mlir::DenseI8ArrayAttr":$shift,
+                   "::mlir::BoolAttr":$scale32,
+                   "::mlir::BoolAttr":$double_round,
+                   "::mlir::BoolAttr":$per_channel), [{
+      auto FalseAttr = BoolAttr::get($_builder.getContext(), false);
+      build($_builder, $_state, output, input, input_zp, output_zp, multiplier,
+            shift, scale32, double_round, per_channel, FalseAttr, FalseAttr);
+    }]>
+  ];
+
   let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)";
 }
 
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 5291f95d371442..2e5ffe673ac1ef 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -1261,14 +1261,7 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
           Value shift = shiftConstant ? shiftConstant : blockArgs[shiftArg];
 
           if (valueTy.getIntOrFloatBitWidth() < 32) {
-            if (valueTy.isUnsignedInteger()) {
-              value = nestedBuilder
-                          .create<UnrealizedConversionCastOp>(
-                              nestedLoc,
-                              nestedBuilder.getIntegerType(
-                                  valueTy.getIntOrFloatBitWidth()),
-                              value)
-                          .getResult(0);
+            if (op.getInputUnsigned()) {
               value = nestedBuilder.create<arith::ExtUIOp>(
                   nestedLoc, nestedBuilder.getI32Type(), value);
             } else {
@@ -1297,7 +1290,7 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
           int32_t intMax = APInt::getSignedMaxValue(outBitWidth).getSExtValue();
 
           // Unsigned integers have a difference output value.
-          if (outIntType.isUnsignedInteger()) {
+          if (op.getOutputUnsigned()) {
             intMin = 0;
             intMax = APInt::getMaxValue(outBitWidth).getZExtValue();
           }
@@ -1314,13 +1307,6 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
             value = nestedBuilder.create<arith::TruncIOp>(
                 nestedLoc, rewriter.getIntegerType(outIntType.getWidth()),
                 value);
-
-            if (outIntType.isUnsignedInteger()) {
-              value = nestedBuilder
-                          .create<UnrealizedConversionCastOp>(nestedLoc,
-                                                              outIntType, value)
-                          .getResult(0);
-            }
           }
 
           nestedBuilder.create<linalg::YieldOp>(loc, value);
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index 1a29d3f9f3507c..265a75986c6c8d 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -1132,11 +1132,21 @@ func.func @rescale_i8(%arg0 : tensor<2xi8>) -> () {
   // CHECK-DAG: linalg.yield [[TRUNC]]
   %0 = tosa.rescale %arg0 {input_zp = 17 : i32, output_zp = 22 : i32, multiplier = array<i32: 19689>, shift = array<i8: 15>, scale32 = false, double_round = false, per_channel = false} : (tensor<2xi8>) -> tensor<2xi8>
 
+  // CHECK: return
+  return
+}
+
+// -----
+// CHECK: #[[$MAP0:.*]] = affine_map<(d0) -> (d0)>
+
+// CHECK-LABEL: @rescale_i8_unsigned_output
+// CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]:
+func.func @rescale_i8_unsigned_output(%arg0 : tensor<2xi8>) -> () {
   // CHECK: [[C0:%.+]] = arith.constant 19689
   // CHECK: [[C1:%.+]] = arith.constant 15
   // CHECK: [[INIT:%.+]] = tensor.empty()
-  // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]]], iterator_types = ["parallel"]} ins(%[[ARG0]] : tensor<2xi8>) outs([[INIT]] : tensor<2xui8>)
-  // CHECK: ^bb0([[IN:%.+]]: i8, [[UNUSED:%.+]]: ui8):
+  // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]]], iterator_types = ["parallel"]} ins(%[[ARG0]] : tensor<2xi8>) outs([[INIT]] : tensor<2xi8>)
+  // CHECK: ^bb0([[IN:%.+]]: i8, [[UNUSED:%.+]]: i8):
   // CHECK: [[C17:%.+]] = arith.constant 17
   // CHECK: [[C22:%.+]] = arith.constant 22
   // CHECK-DAG: [[IN32:%.+]] = arith.extsi [[IN]]
@@ -1148,9 +1158,8 @@ func.func @rescale_i8(%arg0 : tensor<2xi8>) -> () {
   // CHECK-DAG: [[LOWER:%.+]] = arith.maxsi [[CMIN]], [[SCALED_ZEROED]]
   // CHECK-DAG: [[BOUNDED:%.+]] = arith.minsi [[CMAX]], [[LOWER]]
   // CHECK-DAG: [[TRUNC:%.+]] = arith.trunci [[BOUNDED]]
-  // CHECK-DAG: [[CAST:%.+]] = builtin.unrealized_conversion_cast [[TRUNC]] : i8 to ui8
-  // CHECK: linalg.yield [[CAST]]
-  %1 = tosa.rescale %arg0 {input_zp = 17 : i32, output_zp = 22 : i32, multiplier = array<i32: 19689>, shift = array<i8: 15>, scale32 = false, double_round = false, per_channel = false} : (tensor<2xi8>) -> tensor<2xui8>
+  // CHECK: linalg.yield [[TRUNC]]
+  %1 = tosa.rescale %arg0 {input_zp = 17 : i32, output_zp = 22 : i32, multiplier = array<i32: 19689>, shift = array<i8: 15>, scale32 = false, double_round = false, per_channel = false, output_unsigned = true} : (tensor<2xi8>) -> tensor<2xi8>
 
   // CHECK: return
   return
@@ -1171,9 +1180,9 @@ func.func @rescale_i8_dyn_batch(%arg0 : tensor<?x2xi8>) -> () {
 
   // CHECK: %[[C0:.+]] = arith.constant 0
   // CHECK: %[[BATCH:.+]] = tensor.dim %[[ARG0]], %[[C0]]
-  // CHECK: %[[INIT:.+]] = tensor.empty(%[[BATCH]]) : tensor<?x2xui8>
-  // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]]], iterator_types = ["parallel", "parallel"]} ins(%[[ARG0]] : tensor<?x2xi8>) outs(%[[INIT]] : tensor<?x2xui8>)
-  %1 = tosa.rescale %arg0 {input_zp = 17 : i32, output_zp = 22 : i32, multiplier = array<i32: 19689>, shift = array<i8: 15>, scale32 = false, double_round = false, per_channel = false} : (tensor<?x2xi8>) -> tensor<?x2xui8>
+  // CHECK: %[[INIT:.+]] = tensor.empty(%[[BATCH]]) : tensor<?x2xi8>
+  // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]]], iterator_types = ["parallel", "parallel"]} ins(%[[ARG0]] : tensor<?x2xi8>) outs(%[[INIT]] : tensor<?x2xi8>)
+  %1 = tosa.rescale %arg0 {input_zp = 17 : i32, output_zp = 22 : i32, multiplier = array<i32: 19689>, shift = array<i8: 15>, scale32 = false, double_round = false, per_channel = false, output_unsigned = true} : (tensor<?x2xi8>) -> tensor<?x2xi8>
 
   return
 }
@@ -1199,18 +1208,17 @@ func.func @rescale_dyn(%arg0 : tensor<1x?x?x32xi32>) -> () {
 
 // CHECK: #[[$MAP0:.*]] = affine_map<(d0) -> (d0)>
 
-// CHECK-LABEL: @rescale_ui8
+// CHECK-LABEL: @rescale_i8_unsigned_input
 // CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]:
-func.func @rescale_ui8(%arg0 : tensor<2xui8>) -> () {
+func.func @rescale_i8_unsigned_input(%arg0 : tensor<2xi8>) -> () {
   // CHECK: [[C0:%.+]] = arith.constant 19689
   // CHECK: [[C1:%.+]] = arith.constant 15
   // CHECK: [[INIT:%.+]] = tensor.empty()
-  // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]]], iterator_types = ["parallel"]} ins(%[[ARG0]] : tensor<2xui8>) outs([[INIT]] : tensor<2xi8>)
-  // CHECK: ^bb0([[IN:%.+]]: ui8, [[UNUSED:%.+]]: i8):
+  // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]]], iterator_types = ["parallel"]} ins(%[[ARG0]] : tensor<2xi8>) outs([[INIT]] : tensor<2xi8>)
+  // CHECK: ^bb0([[IN:%.+]]: i8, [[UNUSED:%.+]]: i8):
   // CHECK: [[C17:%.+]] = arith.constant 17
   // CHECK: [[C22:%.+]] = arith.constant 22
-  // CHECK-DAG: [[CAST:%.+]] = builtin.unrealized_conversion_cast [[IN]] : ui8 to i8
-  // CHECK-DAG: [[IN32:%.+]] = arith.extui [[CAST]]
+  // CHECK-DAG: [[IN32:%.+]] = arith.extui [[IN]]
   // CHECK-DAG: [[IN_ZEROED:%.+]] = arith.subi [[IN32]], [[C17]]
   // CHECK-DAG: [[SCALED:%.+]] = tosa.apply_scale [[IN_ZEROED]], [[C0]], [[C1]] {double_round = false}
   // CHECK-DAG: [[SCALED_ZEROED:%.+]] = arith.addi [[SCALED]], [[C22]]
@@ -1220,7 +1228,7 @@ func.func @rescale_ui8(%arg0 : tensor<2xui8>) -> () {
   // CHECK-DAG: [[BOUNDED:%.+]] = arith.minsi [[CMAX]], [[LOWER]]
   // CHECK-DAG: [[TRUNC:%.+]] = arith.trunci [[BOUNDED]]
   // CHECK: linalg.yield [[TRUNC]]
-  %0 = tosa.rescale %arg0 {input_zp = 17 : i32, output_zp = 22 : i32, multiplier = array<i32: 19689>, shift = array<i8: 15>, scale32 = false, double_round = false, per_channel = false} : (tensor<2xui8>) -> tensor<2xi8>
+  %0 = tosa.rescale %arg0 {input_zp = 17 : i32, output_zp = 22 : i32, multiplier = array<i32: 19689>, shift = array<i8: 15>, scale32 = false, double_round = false, per_channel = false, input_unsigned = true} : (tensor<2xi8>) -> tensor<2xi8>
 
   return
 }

``````````

</details>


https://github.com/llvm/llvm-project/pull/118075


More information about the Mlir-commits mailing list