[Mlir-commits] [mlir] [mlir][spirv] Add Cast/Rescale ops in TOSA Ext Inst Set (PR #189028)
Davide Grohmann
llvmlistbot at llvm.org
Mon Mar 30 00:27:50 PDT 2026
================
@@ -2648,4 +2648,128 @@ def SPIRV_TosaResizeOp : SPIRV_TosaOpWithResult<"Resize", 63, [Pure,
}
+def SPIRV_TosaCastOp : SPIRV_TosaOpWithResult<"Cast", 64, [Pure,
+ AllShapesMatch<["input", "output"]>]> {
+ let summary = "Cast operation.";
+
+ let description = [{
+ Casts a tensor from one data type to another.
+
+ References:
+ * https://github.khronos.org/SPIRV-Registry/extended/TOSA.001000.1.html#_cast
+ * https://www.mlplatform.org/tosa/tosa_spec_1_0_1.html#_cast
+
+ #### Example:
+ ```mlir
+ %0 = spirv.Tosa.Cast %arg0 : !spirv.arm.tensor<1x65538x1x2xi8> -> !spirv.arm.tensor<1x65538x1x2xi32>
+ %0 = spirv.Tosa.Cast %arg0 : !spirv.arm.tensor<11x5x14x4xf32> -> !spirv.arm.tensor<11x5x14x4xf16>
+ ```
+ }];
+
+ let arguments = (ins
+ SPIRV_TosaAny_TensorArm: $input
+ );
+
+ let results = (outs
+ SPIRV_TosaAny_TensorArm: $output
+ );
+
+ let assemblyFormat = [{
+ $input
+ attr-dict `:` type(operands) `->` type(results)
+ }];
+
+ let extraClassDeclaration = extraBaseClassDeclaration#[{
+ ::mlir::spirv::TensorArmType getInputType() {
+ return cast<::mlir::spirv::TensorArmType>(getInput().getType());
+ }
+ }];
+}
+
+
+def SPIRV_TosaRescaleOp : SPIRV_TosaOpWithResult<"Rescale", 65, [NoMemoryEffect,
+ AllShapesMatch<["input", "output"]>,
+ AllElementTypesMatch<["input", "input_zp"]>,
+ AllElementTypesMatch<["output", "output_zp"]>,
+ ElementTypeMatchesScale32<"multiplier">,
+ TensorLengthMatchesPerChannel<"multiplier">,
+ TensorLengthMatchesPerChannel<"shift">,
+ TypeConstraintImplicationOn<"input", I8, "output", [I8, I16, I32]>,
+ TypeConstraintImplicationOn<"input", I16, "output", [I8, I16, I32]>,
+ TypeConstraintImplicationOn<"input", I32, "output", [I8, I16, I32]>,
+ TypeConstraintImplicationOn<"input", I64, "output", [I8, I16, I32]>]> {
+ let summary = "Rescale operator.";
+
+ let description = [{
+ Rescale is defined using an integer multiply, add, and shift.
+
+ Rescale supports two precisions of multiplier: 16-bit and 32-bit. The
+ 32-bit multiplier version supports two rounding modes to enable simpler
+ lowering of existing frameworks that use two stage rounding. All arithmetic
+ is designed so that it does not overflow a 64-bit accumulator and that the
+ result fits in 32 bits. In particular, a 48-bit value cannot be scaled with
----------------
davidegrohmann wrote:
Thanks for catching this. My mistake: that should have been replaced by "64-bit value", since the 48-bit type in TOSA is mapped to 64-bit in SPIR-V.
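
For context, the "integer multiply, add, and shift" in the Rescale description can be illustrated with a minimal sketch. It is modeled on my reading of the 32-bit scaling pseudocode in the TOSA specification, covers only the single-rounding case, and is not part of this patch. The name `apply_scale_32` and the range checks are assumptions made for illustration.

```python
def apply_scale_32(value: int, multiplier: int, shift: int) -> int:
    """Illustrative helper (hypothetical, not from the patch).

    Scales `value` by multiplier / 2**shift, rounding halves upward.
    """
    # Assumed preconditions: 32-bit signed value, non-negative 32-bit multiplier.
    assert -(1 << 31) <= value < (1 << 31)
    assert 0 <= multiplier < (1 << 31)
    assert 2 <= shift <= 62
    # Restricting the input range keeps the shifted result within 32 bits.
    assert -(1 << (shift - 1)) <= value < (1 << (shift - 1))

    rounding = 1 << (shift - 1)            # the "add" step: half of 2**shift
    acc = value * multiplier + rounding    # the "multiply" step

    # Python integers are unbounded, so check the 64-bit accumulator explicitly.
    assert -(1 << 63) <= acc < (1 << 63)

    return acc >> shift                    # the "shift" step (floors toward -inf)


if __name__ == "__main__":
    half = 1 << 30                         # multiplier / 2**31 == 0.5
    print(apply_scale_32(100, half, 31))   # 50
    print(apply_scale_32(3, half, 31))     # 2  (1.5 rounds up)
    print(apply_scale_32(-3, half, 31))    # -1 (-1.5 rounds up)
```

The accumulator check is where the bit-width concern shows up: the product of an n-bit and an m-bit signed value can need up to n + m bits, so a 32-bit value with a 32-bit multiplier stays within 64 bits, while a wider input combined with the 32-bit multiplier would not.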
https://github.com/llvm/llvm-project/pull/189028