[Mlir-commits] [mlir] [mlir][spirv] Add Cast/Rescale ops in TOSA Ext Inst Set (PR #189028)

Davide Grohmann llvmlistbot at llvm.org
Mon Mar 30 01:46:18 PDT 2026


================
@@ -2648,4 +2648,128 @@ def SPIRV_TosaResizeOp : SPIRV_TosaOpWithResult<"Resize", 63, [Pure,
 }
 
 
+def SPIRV_TosaCastOp : SPIRV_TosaOpWithResult<"Cast", 64, [Pure,
+  AllShapesMatch<["input", "output"]>]> {
+  let summary = "Cast operation.";
+
+  let description = [{
+    Casts each element of a tensor to another data type; the shape is unchanged.
+
+    References:
+      * https://github.khronos.org/SPIRV-Registry/extended/TOSA.001000.1.html#_cast
+      * https://www.mlplatform.org/tosa/tosa_spec_1_0_1.html#_cast
+
+    #### Example:
+    ```mlir
+    %0 = spirv.Tosa.Cast %arg0 : !spirv.arm.tensor<1x65538x1x2xi8> -> !spirv.arm.tensor<1x65538x1x2xi32>
+    %1 = spirv.Tosa.Cast %arg0 : !spirv.arm.tensor<11x5x14x4xf32> -> !spirv.arm.tensor<11x5x14x4xf16>
+    ```
+  }];
+
+  let arguments = (ins
+    SPIRV_TosaAny_TensorArm: $input
+  );
+
+  let results = (outs
+    SPIRV_TosaAny_TensorArm: $output
+  );
+
+  let assemblyFormat = [{
+    $input
+    attr-dict `:` type(operands) `->` type(results)
+  }];
+
+  let extraClassDeclaration = extraBaseClassDeclaration#[{
+    ::mlir::spirv::TensorArmType getInputType() {
+      return cast<::mlir::spirv::TensorArmType>(getInput().getType());
+    }
+  }];
+}
+
+
+def SPIRV_TosaRescaleOp : SPIRV_TosaOpWithResult<"Rescale", 65, [NoMemoryEffect,
+  AllShapesMatch<["input", "output"]>,
+  AllElementTypesMatch<["input", "input_zp"]>,
+  AllElementTypesMatch<["output", "output_zp"]>,
+  ElementTypeMatchesScale32<"multiplier">,
+  TensorLengthMatchesPerChannel<"multiplier">,
+  TensorLengthMatchesPerChannel<"shift">,
+  TypeConstraintImplicationOn<"input", I8, "output", [I8, I16, I32]>,
+  TypeConstraintImplicationOn<"input", I16, "output", [I8, I16, I32]>,
+  TypeConstraintImplicationOn<"input", I32, "output", [I8, I16, I32]>,
+  TypeConstraintImplicationOn<"input", I64, "output", [I8, I16, I32]>]> {
+  let summary = "Rescale operator.";
+
+  let description = [{
+    Rescale is defined using an integer multiply, add, and shift.
+
+    Rescale supports two precisions of multiplier: 16-bit and 32-bit. The
+    32-bit multiplier version supports two rounding modes to enable simpler
+    lowering of existing frameworks that use two-stage rounding. All arithmetic
+    is designed so that it does not overflow a 64-bit accumulator and so that
+    the result fits in 32 bits. In particular, a 48-bit value cannot be scaled
+    with the 32-bit multiplier: the accumulator would need 80 bits (48 + 32).
+
+    The shift and value range are limited to allow a variety of implementations.
+    The limit of 62 on shift allows the shift to be decomposed as two right
+    shifts of 31.
+
+    Undefined behaviour may occur if the calculated result underflows or
+    overflows its integer range.
+
+    References:
+      * https://github.khronos.org/SPIRV-Registry/extended/TOSA.001000.1.html#_rescale
+      * https://www.mlplatform.org/tosa/tosa_spec_1_0_1.html#_rescale
+
+    #### Example:
+    ```mlir
+    %9 = spirv.Tosa.Rescale scale32 = true, rounding_mode = <DoubleRound>, per_channel = false, input_unsigned = false, output_unsigned = true, %arg0, %multiplier, %shift, %input_zp, %output_zp : !spirv.arm.tensor<17x29x19xi16>, !spirv.arm.tensor<1xi32>, !spirv.arm.tensor<1xi8>, !spirv.arm.tensor<1xi16>, !spirv.arm.tensor<1xi16> -> !spirv.arm.tensor<17x29x19xi16>
+    ```
+  }];
+
+  let arguments = (ins
+    SPIRV_BoolConstAttr: $scale32,
+    SPIRV_TosaExtRoundingModeAttr: $rounding_mode,
+    SPIRV_BoolConstAttr: $per_channel,
+    SPIRV_BoolConstAttr: $input_unsigned,
----------------
davidegrohmann wrote:

Added the comment and improved the constraints around unsigned I/O.
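For reference, the multiply, add, and shift arithmetic described in the Rescale documentation can be sketched in C++ as follows. This is a minimal sketch modelled on the TOSA specification pseudocode (the `apply_scale_32` and `apply_scale_16` names are borrowed from there), not code from this PR or from MLIR. Reporting the spec's `REQUIRE` preconditions through `std::optional`, the hypothetical `rescale_examples` helper, and the exact input-range checks are assumptions of the sketch and should be checked against the specification.

```cpp
#include <cassert>
#include <cstdint>
#include <limits>
#include <optional>

// Minimal sketch of the RESCALE scaling helpers, modelled on the TOSA
// specification pseudocode. Assumption of this sketch: the spec's REQUIRE(...)
// preconditions are reported by returning std::nullopt.

// 32-bit multiplier: (value * multiplier + round) >> shift in a 64-bit
// accumulator. A 32-bit value times a 32-bit multiplier needs at most 64 bits.
std::optional<int32_t> apply_scale_32(int32_t value, int32_t multiplier,
                                      int8_t shift, bool double_round) {
  if (multiplier < 0 || shift < 2 || shift > 62) return std::nullopt;
  // Value range limit (our reading of the spec): -2^(shift-1) <= value <
  // 2^(shift-1), which keeps the shifted result inside 32 bits.
  const int64_t half = int64_t{1} << (shift - 1);
  if (value < -half || value >= half) return std::nullopt;
  int64_t round = half;  // single rounding: add 0.5 in fixed point
  if (double_round && shift > 31) {
    // Double rounding: nudge the rounding term by +/-2^30, as the spec
    // pseudocode does when shift > 31.
    round += (value >= 0) ? (int64_t{1} << 30) : -(int64_t{1} << 30);
  }
  const int64_t acc = static_cast<int64_t>(value) * multiplier + round;
  // Arithmetic right shift of a signed 64-bit value (guaranteed since C++20).
  return static_cast<int32_t>(acc >> shift);
}

// 16-bit multiplier: the input may be a 48-bit value, because 48 + 16 = 64
// bits still fit the accumulator (48 + 32 = 80 bits would not).
std::optional<int32_t> apply_scale_16(int64_t value, int16_t multiplier,
                                      int8_t shift) {
  if (multiplier < 0 || shift < 2 || shift > 62) return std::nullopt;
  const int64_t min48 = -(int64_t{1} << 47);
  const int64_t max48 = (int64_t{1} << 47) - 1;
  if (value < min48 || value > max48) return std::nullopt;
  const int64_t acc = value * multiplier + (int64_t{1} << (shift - 1));
  const int64_t result = acc >> shift;
  // Here the range check happens on the result instead of the input.
  if (result < std::numeric_limits<int32_t>::min() ||
      result > std::numeric_limits<int32_t>::max())
    return std::nullopt;
  return static_cast<int32_t>(result);
}

// Hypothetical usage helper: multiplier 2^30 with shift 31 scales by 0.5,
// rounding half up.
void rescale_examples() {
  assert(apply_scale_32(100, 1 << 30, 31, false) == 50);
  assert(apply_scale_32(3, 1 << 30, 31, false) == 2);    // 1.5 -> 2
  assert(apply_scale_32(-3, 1 << 30, 31, false) == -1);  // -1.5 -> -1
  // 2^30 / 2^32 = 0.25: single rounding gives 0, double rounding gives 1.
  assert(apply_scale_32(1 << 30, 1, 32, false) == 0);
  assert(apply_scale_32(1 << 30, 1, 32, true) == 1);
  // A shift below the minimum of 2 violates a precondition.
  assert(!apply_scale_32(5, 1, 1, false).has_value());
}
```

The sketch applies the shift as a single 64-bit arithmetic right shift; the limit of 62 is what would let an implementation split it into two right shifts of at most 31 instead. Zero-point handling (`input_zp`, `output_zp`) and the final clamp to the output type's range are deliberately left out.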

https://github.com/llvm/llvm-project/pull/189028

