[Mlir-commits] [mlir] [TOSA] Allow all integer types in most ops (PR #86509)
Matthias Gehre
llvmlistbot at llvm.org
Tue Mar 26 13:58:47 PDT 2024
https://github.com/mgehre-amd updated https://github.com/llvm/llvm-project/pull/86509
>From 9820fd93bb0c71d458a77fa93d302d145e69e078 Mon Sep 17 00:00:00 2001
From: Matthias Gehre <matthias.gehre at amd.com>
Date: Fri, 22 Mar 2024 15:18:04 +0100
Subject: [PATCH 1/2] [TOSA] Allow all integer types in most ops
As discussed in one of the previous TOSA community meetings,
we would like to allow more integer types (signless and unsigned
integers of any width) in the TOSA dialect to enable more use cases.
For strict conformance to the TOSA standard, the TosaValidation pass can be used.
Follow-up PRs will extend conversions from TOSA where needed.
---
mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td | 2 +-
.../mlir/Dialect/Tosa/IR/TosaTypesBase.td | 27 +++-------
.../Tosa/Transforms/TosaValidation.cpp | 51 +++++++++++++++++--
mlir/test/Dialect/Tosa/level_check.mlir | 16 ++++++
4 files changed, 71 insertions(+), 25 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index 0ecded75c5d8bc..306e4a43952088 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -1942,7 +1942,7 @@ def Tosa_ConstOp : Tosa_Op<"const", [ConstantLike, Pure,
);
let results = (outs
- TensorOf<[AnyTypeOf<[Tosa_AnyNumber_Plus_F64, Tosa_Int4]>]>:$output
+ TensorOf<[AnyTypeOf<[Tosa_AnyNumber_Plus_F64]>]>:$output
);
let hasFolder = 1;
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
index 5a4d6ff464f19e..cff3de0a69af95 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
@@ -38,29 +38,17 @@ class Tosa_QuantizedType<string n, list<int> params, bit signed>
// Used to express accumulator results or compare results.
//===----------------------------------------------------------------------===//
-def Tosa_UInt8 : UI<8>;
-def Tosa_UInt16 : UI<16>;
-
def Tosa_Int4 : I<4>;
def Tosa_Int8 : I<8>;
-def Tosa_Int16 : I<16>;
def Tosa_Int32 : I<32>;
-def Tosa_Int48 : I<48>;
def Tosa_Int64 : I<64>;
-def Tosa_SignedInt : AnyTypeOf<[Tosa_Int8,
- Tosa_Int16,
- Tosa_Int32,
- Tosa_Int48,
- Tosa_Int64]>;
-
-def Tosa_Bool : I<1>;
-
-// No unsigned unquantized int types.
-def Tosa_Int : AnyTypeOf<[Tosa_Bool,
- Tosa_UInt8,
- Tosa_UInt16,
- Tosa_SignedInt]>;
+// The TOSA dialect permits more types than the TOSA standard, to enable
+// experimentation. For historical reasons, signless integers are used in
+// place of signed integers.
+// The TosaValidation pass can be used to check for standard conformance.
+def Tosa_Int : AnyTypeOf<[AnyUnsignedInteger,
+ AnySignlessInteger]>;
def Tosa_Int32Or64 : AnyTypeOf<[Tosa_Int32,
Tosa_Int64]>;
@@ -172,9 +160,6 @@ class Tosa_TypeLike<list<Type> types, string description = ""> : TypeConstraint<
def Tosa_IntLike : Tosa_TypeLike<[Tosa_Int], "signless-integer-like">;
def Tosa_Int8Like : Tosa_TypeLike<[Tosa_Int8], "signless-integer-8-bit-like">;
-def Tosa_Int16Like : Tosa_TypeLike<[Tosa_Int16], "signless-integer-16-bit-like">;
-def Tosa_Int32Like : Tosa_TypeLike<[Tosa_Int32], "signless-integer-32-bit-like">;
-def Tosa_Int64Like : Tosa_TypeLike<[Tosa_Int64], "signless-integer-64-bit-like">;
//===----------------------------------------------------------------------===//
// Attribute predicates and classes.
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
index 967775281ad91f..b669b7362e9432 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
@@ -410,6 +410,8 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
bool CheckVariable(Operation *op);
bool CheckVariableReadOrWrite(Operation *op);
+ bool isValidElementType(Type type);
+
SmallVector<std::function<LogicalResult(Operation *)>> constCheckers;
TosaLevel tosaLevel;
DenseMap<StringAttr, mlir::Type> variablesMap;
@@ -503,15 +505,58 @@ LogicalResult TosaValidation::applyVariableCheck(Operation *op) {
return success();
}
+bool TosaValidation::isValidElementType(Type type) {
+ if ((profile == TosaProfileEnum::BaseInference) && isa<FloatType>(type)) {
+ return false;
+ }
+ if (type.isF64()) {
+ return false;
+ }
+ if (auto intTy = dyn_cast<IntegerType>(type)) {
+ if (intTy.isUnsigned()) {
+ switch (intTy.getWidth()) {
+ case 8:
+ case 16:
+ return true;
+ default:
+ return false;
+ }
+ } else {
+ // Signless - treated as signed.
+ switch (intTy.getWidth()) {
+ case 1:
+ case 4:
+ case 8:
+ case 16:
+ case 32:
+ case 48:
+ case 64:
+ return true;
+ default:
+ return false;
+ }
+ }
+ return false;
+ }
+ return true;
+}
+
void TosaValidation::runOnOperation() {
configLevelAndProfile();
getOperation().walk([&](Operation *op) {
for (Value operand : op->getOperands()) {
- if ((profile == TosaProfileEnum::BaseInference) &&
- isa<FloatType>(getElementTypeOrSelf(operand))) {
+ auto elementTy = getElementTypeOrSelf(operand);
+ if (!isValidElementType(elementTy)) {
+ op->emitOpError() << "failed level check: element type " << elementTy
+ << " is not legal";
return signalPassFailure();
}
- if (getElementTypeOrSelf(operand).isF64()) {
+ }
+ for (Type resultTy : op->getResultTypes()) {
+ auto elementTy = getElementTypeOrSelf(resultTy);
+ if (!isValidElementType(elementTy)) {
+ op->emitOpError() << "failed level check: element type " << elementTy
+ << " is not legal";
return signalPassFailure();
}
}
diff --git a/mlir/test/Dialect/Tosa/level_check.mlir b/mlir/test/Dialect/Tosa/level_check.mlir
index 35ecbcc799e3df..1d3ef282836705 100644
--- a/mlir/test/Dialect/Tosa/level_check.mlir
+++ b/mlir/test/Dialect/Tosa/level_check.mlir
@@ -115,6 +115,22 @@ func.func @test_const(%arg0 : tensor<1x1xi32>) -> tensor<1x1x1x1x1x1x1xi32> {
// -----
+func.func @test_const_i2(%arg0 : tensor<1xi2>) {
+ // expected-error@+1 {{'tosa.const' op failed level check: element type 'i2' is not legal}}
+ %0 = "tosa.const"() {value = dense<0> : tensor<1xi2>} : () -> tensor<1xi2>
+ return
+}
+
+// -----
+
+func.func @test_const_ui32(%arg0 : tensor<1xui32>) {
+ // expected-error@+1 {{'tosa.const' op failed level check: element type 'ui32' is not legal}}
+ %0 = "tosa.const"() {value = dense<0> : tensor<1xui32>} : () -> tensor<1xui32>
+ return
+}
+
+// -----
+
func.func @test_avgpool2d_kernel_y(%arg0: tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32> {
// expected-error@+1 {{'tosa.avg_pool2d' op failed level check: kernel <= MAX_KERNEL}}
%0 = "tosa.avg_pool2d"(%arg0) {kernel = array<i64: 8193, 1>, pad = array<i64: 4, 4, 4, 4>, stride = array<i64: 1, 1>, acc_type = f32} :
>From ad0cf1b99ce03f941afd0817b31aa3e86580d21d Mon Sep 17 00:00:00 2001
From: Matthias Gehre <matthias.gehre at amd.com>
Date: Tue, 26 Mar 2024 21:53:21 +0100
Subject: [PATCH 2/2] Say "'tosa.const' op is not profile-aligned: element type
'ui32' is not legal"
---
mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp | 8 ++++----
mlir/test/Dialect/Tosa/level_check.mlir | 4 ++--
2 files changed, 6 insertions(+), 6 deletions(-)
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
index b669b7362e9432..74ef6381f3d701 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
@@ -547,16 +547,16 @@ void TosaValidation::runOnOperation() {
for (Value operand : op->getOperands()) {
auto elementTy = getElementTypeOrSelf(operand);
if (!isValidElementType(elementTy)) {
- op->emitOpError() << "failed level check: element type " << elementTy
- << " is not legal";
+ op->emitOpError() << "is not profile-aligned: element type "
+ << elementTy << " is not legal";
return signalPassFailure();
}
}
for (Type resultTy : op->getResultTypes()) {
auto elementTy = getElementTypeOrSelf(resultTy);
if (!isValidElementType(elementTy)) {
- op->emitOpError() << "failed level check: element type " << elementTy
- << " is not legal";
+ op->emitOpError() << "is not profile-aligned: element type "
+ << elementTy << " is not legal";
return signalPassFailure();
}
}
diff --git a/mlir/test/Dialect/Tosa/level_check.mlir b/mlir/test/Dialect/Tosa/level_check.mlir
index 1d3ef282836705..d8dd878051f18d 100644
--- a/mlir/test/Dialect/Tosa/level_check.mlir
+++ b/mlir/test/Dialect/Tosa/level_check.mlir
@@ -116,7 +116,7 @@ func.func @test_const(%arg0 : tensor<1x1xi32>) -> tensor<1x1x1x1x1x1x1xi32> {
// -----
func.func @test_const_i2(%arg0 : tensor<1xi2>) {
- // expected-error@+1 {{'tosa.const' op failed level check: element type 'i2' is not legal}}
+ // expected-error@+1 {{'tosa.const' op is not profile-aligned: element type 'i2' is not legal}}
%0 = "tosa.const"() {value = dense<0> : tensor<1xi2>} : () -> tensor<1xi2>
return
}
@@ -124,7 +124,7 @@ func.func @test_const_i2(%arg0 : tensor<1xi2>) {
// -----
func.func @test_const_ui32(%arg0 : tensor<1xui32>) {
- // expected-error@+1 {{'tosa.const' op failed level check: element type 'ui32' is not legal}}
+ // expected-error@+1 {{'tosa.const' op is not profile-aligned: element type 'ui32' is not legal}}
%0 = "tosa.const"() {value = dense<0> : tensor<1xui32>} : () -> tensor<1xui32>
return
}
More information about the Mlir-commits
mailing list