[Mlir-commits] [mlir] [mlir][tosa] Require signless types in validation and add corresponding conversion pass (PR #144367)
Luke Hutton
llvmlistbot at llvm.org
Mon Jun 16 08:02:12 PDT 2025
https://github.com/lhutton1 created https://github.com/llvm/llvm-project/pull/144367
Firstly, this commit requires that all types are signless in the strict mode of the validation pass. This is because signless types on operations are required by the TOSA specification. The "strict" mode in the validation pass is the final check for TOSA conformance to the specification, which can often be used for conversion to other formats.
In addition, a conversion pass `--tosa-convert-integer-type-to-signless` is provided to allow a user to convert all integer types to signless. The intention is that this pass can be run before the validation pass. Following use of this pass, input/output information should be carried independently by the user.
From 40e3f612e52b9b40ce89ab9a6d461389ecf5f8d7 Mon Sep 17 00:00:00 2001
From: Luke Hutton <luke.hutton at arm.com>
Date: Wed, 4 Jun 2025 16:10:32 +0000
Subject: [PATCH] [mlir][tosa] Require signless types in validation and add
corresponding conversion pass
Firstly, this commit requires that all types are signless in the strict
mode of the validation pass. This is because signless types on operations
are required by the TOSA specification. The "strict" mode in the
validation pass is the final check for TOSA conformance to the
specification, which can often be used for conversion to other formats.
In addition, a conversion pass `--tosa-convert-integer-type-to-signless`
is provided to allow a user to convert all integer types to signless.
The intention is that this pass can be run before the validation pass.
Following use of this pass, input/output information should be carried
independently by the user.
Change-Id: Id7aebf0071c9a7516c77f55062db82760c0da533
---
.../mlir/Dialect/Tosa/Transforms/Passes.td | 14 ++
.../Dialect/Tosa/Transforms/CMakeLists.txt | 1 +
.../TosaConvertIntegerTypeToSignless.cpp | 134 ++++++++++++++++++
.../Tosa/Transforms/TosaValidation.cpp | 9 +-
mlir/test/Dialect/Tosa/invalid.mlir | 2 +
...tosa-convert-integer-type-to-signless.mlir | 73 ++++++++++
.../Dialect/Tosa/tosa-validation-valid.mlir | 31 ++++
7 files changed, 260 insertions(+), 4 deletions(-)
create mode 100644 mlir/lib/Dialect/Tosa/Transforms/TosaConvertIntegerTypeToSignless.cpp
create mode 100644 mlir/test/Dialect/Tosa/tosa-convert-integer-type-to-signless.mlir
create mode 100644 mlir/test/Dialect/Tosa/tosa-validation-valid.mlir
diff --git a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
index d005a4cc6859c..b96682843538c 100644
--- a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
@@ -127,4 +127,18 @@ def TosaReduceTransposes : Pass<"tosa-reduce-transposes", "func::FuncOp"> {
}];
}
+def TosaConvertIntegerTypeToSignless : Pass<"tosa-convert-integer-type-to-signless", "func::FuncOp"> {
+ let summary = "Convert integer types to signless";
+ let description = [{
+ This pass converts signed or unsigned integer types to signless. It
+ currently does this greedily for all operators and can also change the
+ signature of the function. Should the signature of the entrypoint
+ function change, it will be the responsibility of the user to carry
+ signedness information of the inputs and outputs independently.
+
+ This can be a useful transformation for conversion to other formats
+ that require strict adherence to the TOSA specification.
+ }];
+}
+
#endif // MLIR_DIALECT_TOSA_TRANSFORMS_PASSES
diff --git a/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt
index bbf079faea3d0..803993bb1008d 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt
@@ -1,4 +1,5 @@
add_mlir_dialect_library(MLIRTosaTransforms
+ TosaConvertIntegerTypeToSignless.cpp
TosaDecomposeTransposeConv.cpp
TosaDecomposeDepthwise.cpp
TosaFolders.cpp
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaConvertIntegerTypeToSignless.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaConvertIntegerTypeToSignless.cpp
new file mode 100644
index 0000000000000..3085e56ceebc0
--- /dev/null
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaConvertIntegerTypeToSignless.cpp
@@ -0,0 +1,134 @@
+//===- TosaConvertIntegerTypeToSignless.cpp
+//-------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===-------------------------------------------------------------------------------===//
+
+// -----------
+// Motivation:
+// -----------
+
+// The TOSA specification uses a signless type system, which means that
+// information about signedness must be encapsulated by the operations
+// themselves. For example, tosa.rescale provides the attributes `input_unsigned`
+// and `output_unsigned` to indicate whether the input/output should be
+// interpreted as unsigned or signed.
+
+// The TOSA dialect, on the other hand, allows the use of signed or unsigned
+// types in addition to signless. As such, when converting from TOSA dialect to
+// other formats, we need to ensure that we conform to the TOSA specification.
+
+// ---------
+// Overview:
+// ---------
+
+// This pass converts signed or unsigned integer types to signless. It currently
+// does this greedily for all operators and can also change the signature of the
+// function. Should the signature of the entrypoint function change, it will be
+// the responsibility of the user to carry signedness information of the inputs
+// and outputs independently.
+
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Func/Transforms/FuncConversions.h"
+#include "mlir/Dialect/Tosa/IR/TosaOps.h"
+#include "mlir/Dialect/Tosa/Transforms/Passes.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+namespace mlir {
+namespace tosa {
+
+#define GEN_PASS_DEF_TOSACONVERTINTEGERTYPETOSIGNLESS
+#include "mlir/Dialect/Tosa/Transforms/Passes.h.inc"
+
+namespace {
+class ToSignlessTensorTypeConverter : public TypeConverter { // rewrites integer tensor types to signless
+ static Type convertType(Type type) { // the conversion callback registered by the constructor
+ const auto tensorType = dyn_cast<TensorType>(type);
+ if (!tensorType)
+ return type; // not a tensor type: handed back unchanged
+ // Tensors whose element type is not an integer, or is already signless, are also handed back unchanged.
+ const auto intType = dyn_cast<IntegerType>(tensorType.getElementType());
+ if (!intType ||
+ intType.getSignedness() == IntegerType::SignednessSemantics::Signless)
+ return type;
+ // Otherwise build a signless integer type of the same width and swap it in as the element type.
+ const auto signlessType = IntegerType::get(
+ intType.getContext(), intType.getWidth(), IntegerType::Signless);
+ return tensorType.cloneWith(std::nullopt, signlessType); // no replacement shape is passed; presumably the existing shape is kept -- confirm
+ }
+
+public:
+ explicit ToSignlessTensorTypeConverter() { addConversion(convertType); }
+};
+
+class ConvertGenericOpWithIntegerTensorType : public ConversionPattern { // op-agnostic pattern: rebuilds an op with converted result types
+public:
+ ConvertGenericOpWithIntegerTensorType(TypeConverter &typeConverter,
+ MLIRContext *context)
+ : ConversionPattern(typeConverter, MatchAnyOpTypeTag{}, 0, context) {} // MatchAnyOpTypeTag: the pattern is not bound to a single op name
+
+ LogicalResult
+ matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const final {
+ // Run every result type of `op` through the type converter; give up if any conversion fails.
+ SmallVector<Type, 4> resultTypes;
+ if (failed(typeConverter->convertTypes(op->getResultTypes(), resultTypes)))
+ return failure();
+ // Build a stand-alone copy of `op`: same location, name, attributes, properties, successors
+ // and region count, but using the operands passed to this pattern and the converted result types.
+ auto *newOp = Operation::create(
+ op->getLoc(), op->getName(), resultTypes, operands, op->getAttrs(),
+ op->getPropertiesStorage(), op->getSuccessors(), op->getNumRegions());
+ // Move each region of `op` into the matching (initially empty) region of the new op, then run
+ // convertRegionTypes on it with the same converter; needed for e.g. tosa.cond_if and tosa.while_loop.
+ for (auto regions : llvm::zip(op->getRegions(), newOp->getRegions())) {
+ Region &before = std::get<0>(regions);
+ Region &parent = std::get<1>(regions);
+ rewriter.inlineRegionBefore(before, parent, parent.end());
+ if (failed(rewriter.convertRegionTypes(&parent, *typeConverter)))
+ return failure(); // NOTE(review): newOp has not been inserted yet on this path -- confirm it is not leaked
+ }
+ // Only now hand the rebuilt op to the rewriter,
+ // and redirect all uses of the results of `op` to the results of the new op.
+ rewriter.insert(newOp);
+ rewriter.replaceOp(op, newOp->getResults());
+ return success();
+ }
+};
+
+class TosaConvertIntegerTypeToSignless // pass class; its base comes from the GEN_PASS_DEF include earlier in this file
+ : public impl::TosaConvertIntegerTypeToSignlessBase<
+ TosaConvertIntegerTypeToSignless> {
+public:
+ void runOnOperation() override {
+ MLIRContext *context = &getContext();
+ ConversionTarget target(*context);
+ ToSignlessTensorTypeConverter typeConverter;
+ // func.func counts as legal only when both its function type and its body are legal for the converter.
+ target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
+ return typeConverter.isSignatureLegal(op.getFunctionType()) &&
+ typeConverter.isLegal(&op.getBody());
+ });
+ target.markUnknownOpDynamicallyLegal([&](Operation *op) { // rule for ops with no explicit entry: operand and result types must all be legal
+ return typeConverter.isLegal(op->getOperandTypes()) &&
+ typeConverter.isLegal(op->getResultTypes());
+ });
+ // Patterns: the function-signature conversion pattern for func.func plus the generic any-op pattern defined above.
+ RewritePatternSet patterns(context);
+ populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(
+ patterns, typeConverter);
+ patterns.add<ConvertGenericOpWithIntegerTensorType>(typeConverter, context);
+ // Run a full conversion on the pass's root operation; a failure result fails the pass.
+ if (failed(
+ applyFullConversion(getOperation(), target, std::move(patterns))))
+ signalPassFailure();
+ }
+};
+
+} // namespace
+
+} // namespace tosa
+} // namespace mlir
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
index 229f42d3178b5..3f27849b8c90c 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
@@ -1320,13 +1320,14 @@ void TosaValidation::runOnOperation() {
// validate operator element types:
// - rescale operator is allowed to have ui8/ui16/ui32
- // operands/results
+ // operands/results when strictOpSpecAlignment is false
// - perform valid element type check at the beginning to
// protect rest of code against quantized element types
- const bool opIsRescale = isa<tosa::RescaleOp>(op);
+ const bool allowUnsigned =
+ !strictOpSpecAlignment && isa<tosa::RescaleOp>(op);
for (Value operand : op->getOperands()) {
auto elementTy = getElementTypeOrSelf(operand);
- if (!isValidElementType(elementTy, opIsRescale)) {
+ if (!isValidElementType(elementTy, allowUnsigned)) {
op->emitOpError() << "is not profile-aligned: element type "
<< elementTy << " is not legal";
return signalPassFailure();
@@ -1334,7 +1335,7 @@ void TosaValidation::runOnOperation() {
}
for (Type resultTy : op->getResultTypes()) {
auto elementTy = getElementTypeOrSelf(resultTy);
- if (!isValidElementType(elementTy, opIsRescale)) {
+ if (!isValidElementType(elementTy, allowUnsigned)) {
op->emitOpError() << "is not profile-aligned: element type "
<< elementTy << " is not legal";
return signalPassFailure();
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index 805522799a6d8..e25b3b7ef3e3a 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -2000,6 +2000,7 @@ func.func @test_rescale_input_unsigned(%arg0: tensor<1x1xui8>) -> (tensor<1x1xi8
%1 = "tosa.const"() <{values = dense<2> : tensor<1xi32>}> : () -> tensor<1xi32>
%2 = "tosa.const"() <{values = dense<3> : tensor<1xi8>}> : () -> tensor<1xi8>
%3 = "tosa.const"() <{values = dense<-128> : tensor<1xi8>}> : () -> tensor<1xi8>
+ // expected-error at +1 {{'tosa.rescale' op is not profile-aligned: element type 'ui8' is not legal}}
%r = tosa.rescale %arg0, %1, %0, %3, %2 {input_unsigned = true, output_unsigned = false, per_channel = false, rounding_mode = "SINGLE_ROUND", scale32 = true} : (tensor<1x1xui8>, tensor<1xi32>, tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x1xi8>
return %r : tensor<1x1xi8>
}
@@ -2012,6 +2013,7 @@ func.func @test_rescale_output_unsigned(%arg0: tensor<1x1xi8>) -> (tensor<1x1xui
%1 = "tosa.const"() <{values = dense<2> : tensor<1xi32>}> : () -> tensor<1xi32>
%2 = "tosa.const"() <{values = dense<3> : tensor<1xi8>}> : () -> tensor<1xi8>
%3 = "tosa.const"() <{values = dense<-128> : tensor<1xi8>}> : () -> tensor<1xi8>
+ // expected-error at +1 {{'tosa.rescale' op is not profile-aligned: element type 'ui8' is not legal}}
%r = tosa.rescale %arg0, %1, %0, %3, %2 {input_unsigned = false, output_unsigned = true, per_channel = false, rounding_mode = "SINGLE_ROUND", scale32 = true} : (tensor<1x1xi8>, tensor<1xi32>, tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x1xui8>
return %r : tensor<1x1xui8>
}
diff --git a/mlir/test/Dialect/Tosa/tosa-convert-integer-type-to-signless.mlir b/mlir/test/Dialect/Tosa/tosa-convert-integer-type-to-signless.mlir
new file mode 100644
index 0000000000000..38ac8d8fb66d9
--- /dev/null
+++ b/mlir/test/Dialect/Tosa/tosa-convert-integer-type-to-signless.mlir
@@ -0,0 +1,73 @@
+// RUN: mlir-opt --split-input-file --tosa-convert-integer-type-to-signless %s | FileCheck %s
+
+// -----
+
+// CHECK-LABEL: test_rescale_output_unsigned
+// CHECK: %arg0: tensor<1x1xi8>
+func.func @test_rescale_output_unsigned(%arg0: tensor<1x1xi8>) -> (tensor<1x1xui8>) {
+ %0 = "tosa.const"() <{values = dense<1> : tensor<1xi8>}> : () -> tensor<1xi8>
+ %1 = "tosa.const"() <{values = dense<2> : tensor<1xi32>}> : () -> tensor<1xi32>
+ %2 = "tosa.const"() <{values = dense<3> : tensor<1xi8>}> : () -> tensor<1xi8>
+ %3 = "tosa.const"() <{values = dense<-128> : tensor<1xi8>}> : () -> tensor<1xi8>
+ // CHECK: %[[RESCALE:.*]] = tosa.rescale %arg0, %1, %0, %3, %2 {input_unsigned = false, output_unsigned = true, per_channel = false, rounding_mode = "SINGLE_ROUND", scale32 = true} : (tensor<1x1xi8>, tensor<1xi32>, tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x1xi8>
+ %r = tosa.rescale %arg0, %1, %0, %3, %2 {input_unsigned = false, output_unsigned = true, per_channel = false, rounding_mode = "SINGLE_ROUND", scale32 = true} : (tensor<1x1xi8>, tensor<1xi32>, tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x1xui8>
+ // CHECK: return %[[RESCALE]] : tensor<1x1xi8>
+ return %r : tensor<1x1xui8>
+}
+
+// -----
+
+// CHECK-LABEL: test_rescale_input_unsigned
+// CHECK: %arg0: tensor<1x1xi16>
+func.func @test_rescale_input_unsigned(%arg0: tensor<1x1xui16>) -> (tensor<1x1xi8>) {
+ %0 = "tosa.const"() <{values = dense<1> : tensor<1xi8>}> : () -> tensor<1xi8>
+ %1 = "tosa.const"() <{values = dense<2> : tensor<1xi32>}> : () -> tensor<1xi32>
+ %2 = "tosa.const"() <{values = dense<3> : tensor<1xi8>}> : () -> tensor<1xi8>
+ %3 = "tosa.const"() <{values = dense<32768> : tensor<1xi16>}> : () -> tensor<1xi16>
+ // CHECK: %[[RESCALE:.*]] = tosa.rescale %arg0, %1, %0, %3, %2 {input_unsigned = true, output_unsigned = false, per_channel = false, rounding_mode = "SINGLE_ROUND", scale32 = true} : (tensor<1x1xi16>, tensor<1xi32>, tensor<1xi8>, tensor<1xi16>, tensor<1xi8>) -> tensor<1x1xi8>
+ %r = tosa.rescale %arg0, %1, %0, %3, %2 {input_unsigned = true, output_unsigned = false, per_channel = false, rounding_mode = "SINGLE_ROUND", scale32 = true} : (tensor<1x1xui16>, tensor<1xi32>, tensor<1xi8>, tensor<1xi16>, tensor<1xi8>) -> tensor<1x1xi8>
+ // CHECK: return %[[RESCALE]] : tensor<1x1xi8>
+ return %r : tensor<1x1xi8>
+}
+
+// -----
+
+// CHECK-LABEL: test_unsigned_function_signature
+// CHECK: %arg0: tensor<1xi8>, %arg1: tensor<1xi8>
+func.func @test_unsigned_function_signature(%arg0: tensor<1xui8>, %arg1: tensor<1xui8>) -> (tensor<1xui8>, tensor<1xui8>) {
+ // CHECK: return %arg0, %arg1 : tensor<1xi8>, tensor<1xi8>
+ return %arg0, %arg1 : tensor<1xui8>, tensor<1xui8>
+}
+
+// -----
+
+// CHECK-LABEL: test_no_change
+// CHECK: %arg0: tensor<13x21x3xi8>
+func.func @test_no_change(%arg0: tensor<13x21x3xi8>) -> tensor<13x21x3xi8> {
+ %0 = tosa.reverse %arg0 {axis = 0 : i32} : (tensor<13x21x3xi8>) -> tensor<13x21x3xi8>
+ // CHECK: return %0 : tensor<13x21x3xi8>
+ return %0 : tensor<13x21x3xi8>
+}
+
+// -----
+
+// CHECK-LABEL: test_regions
+// CHECK: %arg0: tensor<i8>, %arg1: tensor<i8>
+func.func @test_regions(%arg0: tensor<ui8>, %arg1: tensor<ui8>, %arg2: tensor<i1>) -> tensor<ui8> {
+ // CHECK: tosa.cond_if %arg2 -> (tensor<i8>)
+ %0 = "tosa.cond_if"(%arg2, %arg0, %arg1) ({
+ ^bb0(%arg3: tensor<ui8>, %arg4: tensor<ui8>):
+ // CHECK: %1 = tosa.add %arg0, %arg1 : (tensor<i8>, tensor<i8>) -> tensor<i8>
+ %1 = tosa.add %arg0, %arg1 : (tensor<ui8>, tensor<ui8>) -> tensor<ui8>
+ // CHECK: tosa.yield %1 : tensor<i8>
+ tosa.yield %1 : tensor<ui8>
+ }, {
+ ^bb0(%arg3: tensor<ui8>, %arg4: tensor<ui8>):
+ // CHECK: %1 = tosa.sub %arg0, %arg1 : (tensor<i8>, tensor<i8>) -> tensor<i8>
+ %1 = tosa.sub %arg0, %arg1 : (tensor<ui8>, tensor<ui8>) -> tensor<ui8>
+ // CHECK: tosa.yield %1 : tensor<i8>
+ tosa.yield %1 : tensor<ui8>
+ }) : (tensor<i1>, tensor<ui8>, tensor<ui8>) -> tensor<ui8>
+ // CHECK: return %0 : tensor<i8>
+ return %0 : tensor<ui8>
+}
diff --git a/mlir/test/Dialect/Tosa/tosa-validation-valid.mlir b/mlir/test/Dialect/Tosa/tosa-validation-valid.mlir
new file mode 100644
index 0000000000000..cab14201dc0ce
--- /dev/null
+++ b/mlir/test/Dialect/Tosa/tosa-validation-valid.mlir
@@ -0,0 +1,31 @@
+//--------------------------------------------------------------------------------------------------
+// Test valid IR in terms of the shape and type of tensor, and the argument type of
+// operation. Excludes the profile compliance checking since it is performed earlier in the
+// validation flow.
+//--------------------------------------------------------------------------------------------------
+
+// RUN: mlir-opt %s -split-input-file -verify-diagnostics --tosa-validate="profile=pro_int,pro_fp extension=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,doubleround,inexactround" | FileCheck %s
+
+// -----
+
+// CHECK-LABEL: test_rescale_input_unsigned
+func.func @test_rescale_input_unsigned(%arg0: tensor<1x1xui8>) -> (tensor<1x1xi8>) {
+ %0 = "tosa.const"() <{values = dense<1> : tensor<1xi8>}> : () -> tensor<1xi8>
+ %1 = "tosa.const"() <{values = dense<2> : tensor<1xi32>}> : () -> tensor<1xi32>
+ %2 = "tosa.const"() <{values = dense<3> : tensor<1xi8>}> : () -> tensor<1xi8>
+ %3 = "tosa.const"() <{values = dense<-128> : tensor<1xi8>}> : () -> tensor<1xi8>
+ %r = tosa.rescale %arg0, %1, %0, %3, %2 {input_unsigned = true, output_unsigned = false, per_channel = false, rounding_mode = "SINGLE_ROUND", scale32 = true} : (tensor<1x1xui8>, tensor<1xi32>, tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x1xi8>
+ return %r : tensor<1x1xi8>
+}
+
+// -----
+
+// CHECK-LABEL: test_rescale_output_unsigned
+func.func @test_rescale_output_unsigned(%arg0: tensor<1x1xi8>) -> (tensor<1x1xui8>) {
+ %0 = "tosa.const"() <{values = dense<1> : tensor<1xi8>}> : () -> tensor<1xi8>
+ %1 = "tosa.const"() <{values = dense<2> : tensor<1xi32>}> : () -> tensor<1xi32>
+ %2 = "tosa.const"() <{values = dense<3> : tensor<1xi8>}> : () -> tensor<1xi8>
+ %3 = "tosa.const"() <{values = dense<-128> : tensor<1xi8>}> : () -> tensor<1xi8>
+ %r = tosa.rescale %arg0, %1, %0, %3, %2 {input_unsigned = false, output_unsigned = true, per_channel = false, rounding_mode = "SINGLE_ROUND", scale32 = true} : (tensor<1x1xi8>, tensor<1xi32>, tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x1xui8>
+ return %r : tensor<1x1xui8>
+}
More information about the Mlir-commits
mailing list