[Mlir-commits] [mlir] [mlir][spirv] Add Cast/Rescale ops in TOSA Ext Inst Set (PR #189028)
Davide Grohmann
llvmlistbot at llvm.org
Mon Mar 30 01:46:18 PDT 2026
================
@@ -2648,4 +2648,128 @@ def SPIRV_TosaResizeOp : SPIRV_TosaOpWithResult<"Resize", 63, [Pure,
}
+def SPIRV_TosaCastOp : SPIRV_TosaOpWithResult<"Cast", 64, [Pure,
+ AllShapesMatch<["input", "output"]>]> {
+ let summary = "Cast operation.";
+
+ let description = [{
+ Casts a tensor from one data type to another.
+
+ References:
+ * https://github.khronos.org/SPIRV-Registry/extended/TOSA.001000.1.html#_cast
+ * https://www.mlplatform.org/tosa/tosa_spec_1_0_1.html#_cast
+
+ #### Example:
+ ```mlir
+ %0 = spirv.Tosa.Cast %arg0 : !spirv.arm.tensor<1x65538x1x2xi8> -> !spirv.arm.tensor<1x65538x1x2xi32>
+ %1 = spirv.Tosa.Cast %arg1 : !spirv.arm.tensor<11x5x14x4xf32> -> !spirv.arm.tensor<11x5x14x4xf16>
+ ```
+ }];
+
+ let arguments = (ins
+ SPIRV_TosaAny_TensorArm: $input
+ );
+
+ let results = (outs
+ SPIRV_TosaAny_TensorArm: $output
+ );
+
+ let assemblyFormat = [{
+ $input
+ attr-dict `:` type(operands) `->` type(results)
+ }];
+
+ let extraClassDeclaration = extraBaseClassDeclaration#[{
+ ::mlir::spirv::TensorArmType getInputType() {
+ return cast<::mlir::spirv::TensorArmType>(getInput().getType());
+ }
+ }];
+}
+
+
+def SPIRV_TosaRescaleOp : SPIRV_TosaOpWithResult<"Rescale", 65, [NoMemoryEffect,
+ AllShapesMatch<["input", "output"]>,
+ AllElementTypesMatch<["input", "input_zp"]>,
+ AllElementTypesMatch<["output", "output_zp"]>,
+ ElementTypeMatchesScale32<"multiplier">,
+ TensorLengthMatchesPerChannel<"multiplier">,
+ TensorLengthMatchesPerChannel<"shift">,
+ TypeConstraintImplicationOn<"input", I8, "output", [I8, I16, I32]>,
+ TypeConstraintImplicationOn<"input", I16, "output", [I8, I16, I32]>,
+ TypeConstraintImplicationOn<"input", I32, "output", [I8, I16, I32]>,
+ TypeConstraintImplicationOn<"input", I64, "output", [I8, I16, I32]>]> {
+ let summary = "Rescale operation.";
+
+ let description = [{
+ Rescale is defined using an integer multiply, add, and shift.
+
+ Rescale supports two precisions of multiplier: 16-bit and 32-bit. The
+ 32-bit multiplier version supports two rounding modes to enable simpler
+ lowering of existing frameworks that use two-stage rounding. All arithmetic
+ is designed so that it does not overflow a 64-bit accumulator and that the
+ result fits in 32 bits. In particular, a 48-bit value cannot be scaled with
+ the 32-bit multiplier because the accumulator would need to have 80 bits.
+
+ The shift and value range are limited to allow a variety of implementations.
+ The limit of 62 on shift allows the shift to be decomposed as two right
+ shifts of 31.
+
+ Undefined behaviour may occur if the calculated result underflows or overflows
+ its integer range.
+
+ References:
+ * https://github.khronos.org/SPIRV-Registry/extended/TOSA.001000.1.html#_rescale
+ * https://www.mlplatform.org/tosa/tosa_spec_1_0_1.html#_rescale
+
+ #### Example:
+ ```mlir
+ %9 = spirv.Tosa.Rescale scale32 = true, rounding_mode = <DoubleRound>, per_channel = false, input_unsigned = false, output_unsigned = true, %arg0, %multiplier, %shift, %input_zp, %output_zp : !spirv.arm.tensor<17x29x19xi16>, !spirv.arm.tensor<1xi32>, !spirv.arm.tensor<1xi8>, !spirv.arm.tensor<1xi16>, !spirv.arm.tensor<1xi16> -> !spirv.arm.tensor<17x29x19xi16>
+ ```
+ }];
+
+ let arguments = (ins
+ SPIRV_BoolConstAttr: $scale32,
+ SPIRV_TosaExtRoundingModeAttr: $rounding_mode,
+ SPIRV_BoolConstAttr: $per_channel,
+ SPIRV_BoolConstAttr: $input_unsigned,
----------------
davidegrohmann wrote:
Added the comment and improved the constraints around unsigned I/O.
https://github.com/llvm/llvm-project/pull/189028
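The Rescale description in the hunk explains the arithmetic only in prose: a multiply, a rounding add, and a shift, a 64-bit accumulator, a 32-bit result, and a shift limit of 62 that splits into two shifts of 31. The Python sketch below models that arithmetic for a single element. It is illustrative only, under several assumptions. The helper names `fits_signed`, `scale_value` and `rescale_element` are hypothetical. The double-rounding bias of ±2^30 for shifts above 31 reflects our reading of the linked TOSA spec. So does the flow in `rescale_element`: subtract `input_zp`, scale, add `output_zp`, then saturate to a signed output range. Check both against the spec. The sketch does not model `input_unsigned`/`output_unsigned` or per-channel indexing.

```python
# Illustrative model of the Rescale arithmetic described in the op docs.
# All helper names are hypothetical; this is not the MLIR, SPIR-V, or
# TOSA reference implementation. Python ints are unbounded, so the
# 64-bit accumulator and 32-bit result limits are checked explicitly.


def fits_signed(x: int, bits: int) -> bool:
    """Return True if x is representable as a signed `bits`-bit integer."""
    return -(1 << (bits - 1)) <= x < (1 << (bits - 1))


def scale_value(value: int, multiplier: int, shift: int,
                double_round: bool = False) -> int:
    """Integer multiply, add a rounding term, then arithmetic shift right."""
    assert multiplier >= 0 and 2 <= shift <= 62
    rounding = 1 << (shift - 1)  # round half towards +infinity
    if double_round and shift > 31:
        # Assumed two-stage rounding: bias as if the product were first
        # rounded at bit 31 and then again by the remaining shift.
        rounding += (1 << 30) if value >= 0 else -(1 << 30)
    acc = value * multiplier + rounding
    assert fits_signed(acc, 64), "accumulator must fit in 64 bits"
    result = acc >> shift  # Python's >> on ints is an arithmetic (floor) shift
    assert fits_signed(result, 32), "result must fit in 32 bits"
    return result


def rescale_element(x: int, multiplier: int, shift: int, input_zp: int,
                    output_zp: int, out_bits: int,
                    double_round: bool = False) -> int:
    """Assumed flow: remove input_zp, scale, add output_zp, saturate (signed)."""
    scaled = scale_value(x - input_zp, multiplier, shift, double_round)
    lo, hi = -(1 << (out_bits - 1)), (1 << (out_bits - 1)) - 1
    return max(lo, min(hi, scaled + output_zp))


if __name__ == "__main__":
    half = 1 << 30  # with shift=31 this multiplier scales by 0.5
    print(scale_value(100, half, 31))   # 50
    print(scale_value(3, half, 31))     # 2  (1.5 rounds up)
    print(scale_value(-3, half, 31))    # -1 (-1.5 rounds towards +inf)

    # Scale by 0.25 (shift 32): single vs. double rounding of 1 * 0.25.
    print(scale_value(1, half, 32), scale_value(1, half, 32, double_round=True))  # 0 1

    # Why scale32 cannot take 48-bit inputs: the worst-case product
    # exceeds a 64-bit accumulator, while a 16-bit multiplier fits.
    print(fits_signed(-(1 << 47) * ((1 << 31) - 1), 64))  # False
    print(fits_signed(-(1 << 47) * ((1 << 15) - 1), 64))  # True

    # A right shift by 62 equals two right shifts by 31.
    v = -(1 << 62) + 12345
    print((v >> 31) >> 31 == v >> 62)  # True

    # Signed 8-bit output: 200 * 0.5 fits, 300 * 0.5 saturates to 127.
    print(rescale_element(200, half, 31, input_zp=0, output_zp=0, out_bits=8))  # 100
    print(rescale_element(300, half, 31, input_zp=0, output_zp=0, out_bits=8))  # 127
```

The double-rounding case shows why the mode matters. The exact value 0.25 rounds to 0 in one step. With two steps it first becomes 0.5 and then rounds up to 1. Frameworks that round twice produce the second result, and the `rounding_mode` attribute lets a lowering reproduce it.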