[Mlir-commits] [mlir] [TOSA] Move cond_if and while_loop operations to controlflow extension (PR #128216)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Feb 21 10:53:57 PST 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-tosa
Author: Tai Ly (Tai78641)
<details>
<summary>Changes</summary>
This commit adds the concept of a controlflow extension to the dialect and updates the validation pass to check that cond_if and while_loop are supported only in the presence of the controlflow extension.
---
Full diff: https://github.com/llvm/llvm-project/pull/128216.diff
11 Files Affected:
- (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td (+10-9)
- (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td (+4-4)
- (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h (+1)
- (modified) mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp (+1-1)
- (modified) mlir/test/Dialect/Tosa/availability.mlir (+4-4)
- (modified) mlir/test/Dialect/Tosa/invalid.mlir (+1-1)
- (modified) mlir/test/Dialect/Tosa/invalid_extension.mlir (+35-1)
- (modified) mlir/test/Dialect/Tosa/level_check.mlir (+1-1)
- (modified) mlir/test/Dialect/Tosa/profile_all_unsupported.mlir (+1-1)
- (modified) mlir/test/Dialect/Tosa/profile_pro_fp_unsupported.mlir (+1-1)
- (modified) mlir/test/Dialect/Tosa/profile_pro_int_unsupported.mlir (+1-1)
``````````diff
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
index 13bbba2b492fa..95c3b5c7c983d 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
@@ -241,19 +241,20 @@ def Tosa_PRO_INT : I32EnumAttrCase<"pro_int", 1>;
def Tosa_PRO_FP : I32EnumAttrCase<"pro_fp", 2>;
def Tosa_NONE : I32EnumAttrCase<"none", 3>;
-def Tosa_EXT_INT16 : I32EnumAttrCase<"int16", 1>;
-def Tosa_EXT_INT4 : I32EnumAttrCase<"int4", 2>;
-def Tosa_EXT_BF16 : I32EnumAttrCase<"bf16", 3>;
-def Tosa_EXT_FP8E4M3 : I32EnumAttrCase<"fp8e4m3", 4>;
-def Tosa_EXT_FP8E5M2 : I32EnumAttrCase<"fp8e5m2", 5>;
-def Tosa_EXT_FFT : I32EnumAttrCase<"fft", 6>;
-def Tosa_EXT_VARIABLE : I32EnumAttrCase<"variable", 7>;
-def Tosa_EXT_NONE : I32EnumAttrCase<"none", 8>;
+def Tosa_EXT_INT16 : I32EnumAttrCase<"int16", 1>;
+def Tosa_EXT_INT4 : I32EnumAttrCase<"int4", 2>;
+def Tosa_EXT_BF16 : I32EnumAttrCase<"bf16", 3>;
+def Tosa_EXT_FP8E4M3 : I32EnumAttrCase<"fp8e4m3", 4>;
+def Tosa_EXT_FP8E5M2 : I32EnumAttrCase<"fp8e5m2", 5>;
+def Tosa_EXT_FFT : I32EnumAttrCase<"fft", 6>;
+def Tosa_EXT_VARIABLE : I32EnumAttrCase<"variable", 7>;
+def Tosa_EXT_CONTROLFLOW : I32EnumAttrCase<"controlflow", 8>;
+def Tosa_EXT_NONE : I32EnumAttrCase<"none", 9>;
def Tosa_ExtensionAttr
: Tosa_I32EnumAttr<"Extension", "supported TOSA extensions", "ext", [
Tosa_EXT_INT16, Tosa_EXT_INT4, Tosa_EXT_BF16, Tosa_EXT_FP8E4M3,
- Tosa_EXT_FP8E5M2, Tosa_EXT_FFT, Tosa_EXT_VARIABLE, Tosa_EXT_NONE
+ Tosa_EXT_FP8E5M2, Tosa_EXT_FFT, Tosa_EXT_VARIABLE, Tosa_EXT_CONTROLFLOW, Tosa_EXT_NONE
]>;
def Tosa_ExtensionArrayAttr
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index 69a408767b3c6..7839548f4273e 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -2390,8 +2390,8 @@ def Tosa_IfOp : Tosa_Op<"cond_if",
);
list<Availability> availability = [
- Profile<[Tosa_PRO_INT, Tosa_PRO_FP]>,
- Extension<[]>,
+ Profile<[]>,
+ Extension<[Tosa_EXT_CONTROLFLOW]>,
];
let regions = (region
@@ -2431,8 +2431,8 @@ def Tosa_WhileOp : Tosa_Op<"while_loop", [
);
list<Availability> availability = [
- Profile<[Tosa_PRO_INT, Tosa_PRO_FP]>,
- Extension<[]>,
+ Profile<[]>,
+ Extension<[Tosa_EXT_CONTROLFLOW]>,
];
let regions = (region
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h b/mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h
index a831bae12f3c1..57a4cd6a382ee 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h
@@ -141,6 +141,7 @@ class TosaProfileCompliance {
case Extension::fft:
return {Profile::pro_fp};
case Extension::variable:
+ case Extension::controlflow:
return {Profile::pro_fp, Profile::pro_int};
case Extension::none:
return {};
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
index f74a4b4c58b80..32648830bb760 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
@@ -425,7 +425,7 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
} else {
llvm::errs() << "unknown TOSA extension name passed in: " << ext
<< ", supported extension are int16, int4, bf16, "
- << "fp8e4m3, fp8e5m2, fft, and variable\n";
+ << "fp8e4m3, fp8e5m2, fft, variable and controlflow\n";
return signalPassFailure();
}
}
diff --git a/mlir/test/Dialect/Tosa/availability.mlir b/mlir/test/Dialect/Tosa/availability.mlir
index e66ff4cacfd89..da8f9ef82c839 100644
--- a/mlir/test/Dialect/Tosa/availability.mlir
+++ b/mlir/test/Dialect/Tosa/availability.mlir
@@ -629,8 +629,8 @@ func.func @test_identity(%arg0: tensor<13x21x3xi32>) -> tensor<13x21x3xi32> {
// -----
// CHECK-LABEL: cond_if
func.func @test_cond_if(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> {
- // CHECK: profiles: [ [pro_int, pro_fp] ]
- // CHECK: extensions: [ [bf16] ]
+ // CHECK: tosa.cond_if profiles: [ ]
+ // CHECK: tosa.cond_if extensions: [ [controlflow] ]
%0 = tosa.cond_if %arg2 -> (tensor<f32>) {
%1 = tosa.add %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
tosa.yield %1 : tensor<f32>
@@ -645,8 +645,8 @@ func.func @test_cond_if(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1
// CHECK-LABEL: while_loop
func.func @test_while_loop(%arg0: tensor<10xi32>, %arg1: tensor<i32>) {
%0 = "tosa.const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
- // CHECK: profiles: [ [pro_int, pro_fp] ]
- // CHECK: extensions: [ [bf16] ]
+ // CHECK: profiles: [ ]
+ // CHECK: extensions: [ [controlflow] ]
%1:3 = tosa.while_loop (%arg2 = %0, %arg3 = %0, %arg4 = %arg0) : (tensor<i32>, tensor<i32>, tensor<10xi32>) -> (tensor<i32>, tensor<i32>, tensor<10xi32>) {
%2 = tosa.greater_equal %arg3, %arg1 : (tensor<i32>, tensor<i32>) -> tensor<i1>
%3 = tosa.logical_not %2 : (tensor<i1>) -> tensor<i1>
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index 1aa8547cb2fdb..c44a0d1c09215 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -4,7 +4,7 @@
// validation flow.
//--------------------------------------------------------------------------------------------------
-// RUN: mlir-opt %s -split-input-file -verify-diagnostics --tosa-validate="profile=pro_int,pro_fp extension=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable strict-op-spec-alignment"
+// RUN: mlir-opt %s -split-input-file -verify-diagnostics --tosa-validate="profile=pro_int,pro_fp extension=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow strict-op-spec-alignment"
func.func @test_const() -> tensor<1xf32> {
// expected-error at +1{{'tosa.const' op expected same attr/result element types}}
diff --git a/mlir/test/Dialect/Tosa/invalid_extension.mlir b/mlir/test/Dialect/Tosa/invalid_extension.mlir
index 046b9d5615074..684875f231dec 100644
--- a/mlir/test/Dialect/Tosa/invalid_extension.mlir
+++ b/mlir/test/Dialect/Tosa/invalid_extension.mlir
@@ -2,7 +2,7 @@
// Enable all supported profiles to focus the verification of expected extension requirement errors.
//--------------------------------------------------------------------------------------------------
-// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-validate="profile=pro_int,pro_fp,mt strict-op-spec-alignment"
+// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-validate="profile=pro_int,pro_fp strict-op-spec-alignment"
// -----
func.func @test_fft2d(%arg0: tensor<1x4x8xf32>, %arg1: tensor<1x4x8xf32>) -> (tensor<1x4x8xf32>, tensor<1x4x8xf32>) {
@@ -36,3 +36,37 @@ func.func @test_cast_bf16_i32(%arg0: tensor<13x21x3xbf16>) -> tensor<13x21x3xi32
return %0 : tensor<13x21x3xi32>
}
+// -----
+func.func @test_cond_if(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> {
+ // expected-error at +1 {{'tosa.cond_if' op illegal: requires [controlflow]}}
+ %0 = tosa.cond_if %arg2 -> (tensor<f32>) {
+ %1 = tosa.add %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
+ tosa.yield %1 : tensor<f32>
+ } else {
+ %1 = tosa.sub %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
+ tosa.yield %1 : tensor<f32>
+ }
+ return %0 : tensor<f32>
+}
+
+// -----
+func.func @test_while_loop(%arg0: tensor<10xi32>, %arg1: tensor<i32>) {
+ %0 = "tosa.const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
+ // expected-error at +1 {{'tosa.while_loop' op illegal: requires [controlflow]}}
+ %1:3 = tosa.while_loop (%arg2 = %0, %arg3 = %0, %arg4 = %arg0) : (tensor<i32>, tensor<i32>, tensor<10xi32>) -> (tensor<i32>, tensor<i32>, tensor<10xi32>) {
+ %2 = tosa.greater_equal %arg3, %arg1 : (tensor<i32>, tensor<i32>) -> tensor<i1>
+ %3 = tosa.logical_not %2 : (tensor<i1>) -> tensor<i1>
+ tosa.yield %3 : tensor<i1>
+ } do {
+ ^bb0(%arg2: tensor<i32>, %arg3: tensor<i32>, %arg4: tensor<10xi32>):
+ %2 = "tosa.const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
+ %3 = tosa.add %arg3, %2 : (tensor<i32>, tensor<i32>) -> tensor<i32>
+ %7 = tosa.const_shape {value = dense<[1]> : tensor<1xindex>} : () -> !tosa.shape<1>
+ %4 = tosa.reshape %2, %7 : (tensor<i32>, !tosa.shape<1>) -> tensor<1xi32>
+ %5 = tosa.add %arg4, %4 : (tensor<10xi32>, tensor<1xi32>) -> tensor<10xi32>
+ %6 = tosa.add %arg2, %2 : (tensor<i32>, tensor<i32>) -> tensor<i32>
+ tosa.yield %6, %3, %5 : tensor<i32>, tensor<i32>, tensor<10xi32>
+ }
+ return
+}
+
diff --git a/mlir/test/Dialect/Tosa/level_check.mlir b/mlir/test/Dialect/Tosa/level_check.mlir
index 90c4551564d1e..a75a6bee8e809 100644
--- a/mlir/test/Dialect/Tosa/level_check.mlir
+++ b/mlir/test/Dialect/Tosa/level_check.mlir
@@ -2,7 +2,7 @@
// Enable all supported profiles and extensions to focus the verification of expected level errors.
//--------------------------------------------------------------------------------------------------
-// RUN: mlir-opt %s -split-input-file -verify-diagnostics --tosa-validate="profile=pro_int,pro_fp,mt extension=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable"
+// RUN: mlir-opt %s -split-input-file -verify-diagnostics --tosa-validate="profile=pro_int,pro_fp extension=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow"
func.func @test_argmax(%arg0: tensor<1x1x1x1x29x29x4xf32>) -> tensor<1x1x1x1x29x4xi32> {
// expected-error at +1 {{'tosa.argmax' op failed level check: operand rank(shape) <= MAX_RANK}}
diff --git a/mlir/test/Dialect/Tosa/profile_all_unsupported.mlir b/mlir/test/Dialect/Tosa/profile_all_unsupported.mlir
index 6dddcf329d110..8183b58272e84 100644
--- a/mlir/test/Dialect/Tosa/profile_all_unsupported.mlir
+++ b/mlir/test/Dialect/Tosa/profile_all_unsupported.mlir
@@ -2,7 +2,7 @@
// Enable all supported extensions to focus the verification of expected profile requirement errors.
//--------------------------------------------------------------------------------------------------
-// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-validate="extension=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable strict-op-spec-alignment"
+// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-validate="extension=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow strict-op-spec-alignment"
// -----
func.func @test_table(%arg0 : tensor<4x5xi8>, %arg1 : tensor<513xi8>) -> () {
diff --git a/mlir/test/Dialect/Tosa/profile_pro_fp_unsupported.mlir b/mlir/test/Dialect/Tosa/profile_pro_fp_unsupported.mlir
index c46b2543fbed5..f7cbd114280dc 100644
--- a/mlir/test/Dialect/Tosa/profile_pro_fp_unsupported.mlir
+++ b/mlir/test/Dialect/Tosa/profile_pro_fp_unsupported.mlir
@@ -2,7 +2,7 @@
// Enable all supported extensions to focus the verification of expected profile requirement errors.
//--------------------------------------------------------------------------------------------------
-// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-validate="profile=pro_int extension=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable strict-op-spec-alignment"
+// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-validate="profile=pro_int extension=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow strict-op-spec-alignment"
// -----
func.func @test_conv2d(%arg0: tensor<1x4x4x4xf32>, %arg1: tensor<8x1x1x4xf32>, %arg2: tensor<8xf32>) -> tensor<1x4x4x8xf32> {
diff --git a/mlir/test/Dialect/Tosa/profile_pro_int_unsupported.mlir b/mlir/test/Dialect/Tosa/profile_pro_int_unsupported.mlir
index 479b7569f54ae..1d6d33b9a02c7 100644
--- a/mlir/test/Dialect/Tosa/profile_pro_int_unsupported.mlir
+++ b/mlir/test/Dialect/Tosa/profile_pro_int_unsupported.mlir
@@ -2,7 +2,7 @@
// Enable all supported extensions to focus the verification of expected profile requirement errors.
//--------------------------------------------------------------------------------------------------
-// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-validate="profile=pro_fp extension=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable strict-op-spec-alignment"
+// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-validate="profile=pro_fp extension=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow strict-op-spec-alignment"
// -----
func.func @test_table(%arg0 : tensor<4x5xi8>, %arg1 : tensor<513xi8>) -> () {
``````````
</details>
https://github.com/llvm/llvm-project/pull/128216
More information about the Mlir-commits mailing list