[Mlir-commits] [mlir] [mlir][spirv] Add Cast/Rescale ops in TOSA Ext Inst Set (PR #189028)

Davide Grohmann llvmlistbot at llvm.org
Mon Mar 30 00:27:50 PDT 2026


================
@@ -2648,4 +2648,128 @@ def SPIRV_TosaResizeOp : SPIRV_TosaOpWithResult<"Resize", 63, [Pure,
 }
 
 
+def SPIRV_TosaCastOp : SPIRV_TosaOpWithResult<"Cast", 64, [Pure,
+  AllShapesMatch<["input", "output"]>]> {
+  let summary = "Cast operation.";
+
+  let description = [{
+    Casts a tensor from one data type to another.
+
+    References:
+      * https://github.khronos.org/SPIRV-Registry/extended/TOSA.001000.1.html#_cast
+      * https://www.mlplatform.org/tosa/tosa_spec_1_0_1.html#_cast
+
+    #### Example:
+    ```mlir
+    %0 = spirv.Tosa.Cast %arg0 : !spirv.arm.tensor<1x65538x1x2xi8> -> !spirv.arm.tensor<1x65538x1x2xi32>
+    %0 = spirv.Tosa.Cast %arg0 : !spirv.arm.tensor<11x5x14x4xf32> -> !spirv.arm.tensor<11x5x14x4xf16>
+    ```
+  }];
+
+  let arguments = (ins
+    SPIRV_TosaAny_TensorArm: $input
+  );
+
+  let results = (outs
+    SPIRV_TosaAny_TensorArm: $output
+  );
+
+  let assemblyFormat = [{
+    $input
+    attr-dict `:` type(operands) `->` type(results)
+  }];
+
+  let extraClassDeclaration = extraBaseClassDeclaration#[{
+    ::mlir::spirv::TensorArmType getInputType() {
+      return cast<::mlir::spirv::TensorArmType>(getInput().getType());
+    }
+  }];
+}
+
+
+def SPIRV_TosaRescaleOp : SPIRV_TosaOpWithResult<"Rescale", 65, [NoMemoryEffect,
+  AllShapesMatch<["input", "output"]>,
+  AllElementTypesMatch<["input", "input_zp"]>,
+  AllElementTypesMatch<["output", "output_zp"]>,
+  ElementTypeMatchesScale32<"multiplier">,
+  TensorLengthMatchesPerChannel<"multiplier">,
+  TensorLengthMatchesPerChannel<"shift">,
+  TypeConstraintImplicationOn<"input", I8, "output", [I8, I16, I32]>,
+  TypeConstraintImplicationOn<"input", I16, "output", [I8, I16, I32]>,
+  TypeConstraintImplicationOn<"input", I32, "output", [I8, I16, I32]>,
+  TypeConstraintImplicationOn<"input", I64, "output", [I8, I16, I32]>]> {
+  let summary = "Rescale operator.";
+
+  let description = [{
+    Rescale is defined using an integer multiply, add, and shift.
+
+    Rescale supports two precisions of multiplier: 16-bit and 32-bit. The
+    32-bit multiplier version supports two rounding modes to enable simpler
+    lowering of existing frameworks that use two stage rounding. All arithmetic
+    is designed so that it does not overflow a 64-bit accumulator and that the
+    result fits in 32 bits. In particular, a 48-bit value cannot be scaled with
----------------
davidegrohmann wrote:

Thanks for catching this. My mistake: that should have been replaced with "64-bit value", since the 48-bit type in TOSA is mapped to a 64-bit type in SPIR-V.
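For context, the sentence in question is about accumulator width. A signed a-bit value times a signed b-bit multiplier needs a+b bits, so a 48-bit value with a 32-bit multiplier would need 80 bits, which exceeds the 64-bit accumulator. The minimal Python sketch below illustrates this. It is based on my reading of the `apply_scale_32` / `apply_scale_16` helpers in the TOSA 1.0 RESCALE pseudocode. It is an illustration, not the spec text or the MLIR implementation. The helper `product_bits` is hypothetical, and zero points and per-channel handling are not modelled.

```python
"""Sketch of TOSA-style RESCALE arithmetic: integer multiply, add, shift."""

INT32_MIN, INT32_MAX = -(1 << 31), (1 << 31) - 1
INT48_MIN, INT48_MAX = -(1 << 47), (1 << 47) - 1
INT64_MIN, INT64_MAX = -(1 << 63), (1 << 63) - 1


def product_bits(value_bits: int, multiplier_bits: int) -> int:
    # Hypothetical helper: worst-case width of a signed a-bit by signed
    # b-bit product (e.g. 48 + 32 = 80 bits, too wide for 64 bits).
    return value_bits + multiplier_bits


def _finish(acc: int, shift: int) -> int:
    # The accumulator must fit in a signed 64-bit register.
    assert INT64_MIN <= acc <= INT64_MAX, "64-bit accumulator overflow"
    # Python's >> on negative ints is an arithmetic (flooring) shift,
    # like a signed right shift of an int64_t.
    result = acc >> shift
    assert INT32_MIN <= result <= INT32_MAX, "result does not fit in 32 bits"
    return result


def apply_scale_32(value: int, multiplier: int, shift: int,
                   double_round: bool = False) -> int:
    # 32-bit multiplier path: only 32-bit values are accepted.
    # (Checks use assert as a stand-in for the spec's REQUIRE.)
    assert INT32_MIN <= value <= INT32_MAX
    assert 0 <= multiplier <= INT32_MAX
    assert 2 <= shift <= 62
    # Keeps (value * multiplier) >> shift inside the 32-bit result range.
    assert -(1 << (shift - 1)) <= value < (1 << (shift - 1))
    rnd = 1 << (shift - 1)
    if double_round and shift > 31:
        # Two-stage rounding: bias the rounding constant by 2**30.
        rnd += (1 << 30) if value >= 0 else -(1 << 30)
    return _finish(value * multiplier + rnd, shift)


def apply_scale_16(value: int, multiplier: int, shift: int) -> int:
    # 16-bit multiplier path: a 48-bit value fits (48 + 16 = 64 bits).
    assert INT48_MIN <= value <= INT48_MAX
    assert 0 <= multiplier < (1 << 15)
    assert 2 <= shift <= 62
    rnd = 1 << (shift - 1)
    return _finish(value * multiplier + rnd, shift)
```

For example, `apply_scale_32(3, 1 << 30, 31)` scales 3 by 0.5 and rounds 1.5 up to 2, while `apply_scale_16(1 << 40, 1 << 14, 40)` scales a value wider than 32 bits through the 16-bit multiplier path.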

https://github.com/llvm/llvm-project/pull/189028

