[Mlir-commits] [mlir] [TOSA] Move cond_if and while_loop operations to controlflow extension (PR #128216)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Feb 21 10:53:57 PST 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-tosa

Author: Tai Ly (Tai78641)

<details>
<summary>Changes</summary>

This commit adds the concept of a controlflow extension to the dialect and updates the validation pass to check conf_if and while_loop are supported only in the presence of the controlflow extension.


---
Full diff: https://github.com/llvm/llvm-project/pull/128216.diff


11 Files Affected:

- (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td (+10-9) 
- (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td (+4-4) 
- (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h (+1) 
- (modified) mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp (+1-1) 
- (modified) mlir/test/Dialect/Tosa/availability.mlir (+4-4) 
- (modified) mlir/test/Dialect/Tosa/invalid.mlir (+1-1) 
- (modified) mlir/test/Dialect/Tosa/invalid_extension.mlir (+35-1) 
- (modified) mlir/test/Dialect/Tosa/level_check.mlir (+1-1) 
- (modified) mlir/test/Dialect/Tosa/profile_all_unsupported.mlir (+1-1) 
- (modified) mlir/test/Dialect/Tosa/profile_pro_fp_unsupported.mlir (+1-1) 
- (modified) mlir/test/Dialect/Tosa/profile_pro_int_unsupported.mlir (+1-1) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
index 13bbba2b492fa..95c3b5c7c983d 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
@@ -241,19 +241,20 @@ def Tosa_PRO_INT   : I32EnumAttrCase<"pro_int", 1>;
 def Tosa_PRO_FP   : I32EnumAttrCase<"pro_fp", 2>;
 def Tosa_NONE : I32EnumAttrCase<"none", 3>;
 
-def Tosa_EXT_INT16    : I32EnumAttrCase<"int16", 1>;
-def Tosa_EXT_INT4     : I32EnumAttrCase<"int4", 2>;
-def Tosa_EXT_BF16     : I32EnumAttrCase<"bf16", 3>;
-def Tosa_EXT_FP8E4M3  : I32EnumAttrCase<"fp8e4m3", 4>;
-def Tosa_EXT_FP8E5M2  : I32EnumAttrCase<"fp8e5m2", 5>;
-def Tosa_EXT_FFT      : I32EnumAttrCase<"fft", 6>;
-def Tosa_EXT_VARIABLE : I32EnumAttrCase<"variable", 7>;
-def Tosa_EXT_NONE     : I32EnumAttrCase<"none", 8>;
+def Tosa_EXT_INT16        : I32EnumAttrCase<"int16", 1>;
+def Tosa_EXT_INT4         : I32EnumAttrCase<"int4", 2>;
+def Tosa_EXT_BF16         : I32EnumAttrCase<"bf16", 3>;
+def Tosa_EXT_FP8E4M3      : I32EnumAttrCase<"fp8e4m3", 4>;
+def Tosa_EXT_FP8E5M2      : I32EnumAttrCase<"fp8e5m2", 5>;
+def Tosa_EXT_FFT          : I32EnumAttrCase<"fft", 6>;
+def Tosa_EXT_VARIABLE     : I32EnumAttrCase<"variable", 7>;
+def Tosa_EXT_CONTROLFLOW  : I32EnumAttrCase<"controlflow", 8>;
+def Tosa_EXT_NONE         : I32EnumAttrCase<"none", 9>;
 
 def Tosa_ExtensionAttr
     : Tosa_I32EnumAttr<"Extension", "supported TOSA extensions", "ext", [
       Tosa_EXT_INT16, Tosa_EXT_INT4, Tosa_EXT_BF16, Tosa_EXT_FP8E4M3,
-      Tosa_EXT_FP8E5M2, Tosa_EXT_FFT, Tosa_EXT_VARIABLE, Tosa_EXT_NONE
+      Tosa_EXT_FP8E5M2, Tosa_EXT_FFT, Tosa_EXT_VARIABLE, Tosa_EXT_CONTROLFLOW, Tosa_EXT_NONE
     ]>;
 
 def Tosa_ExtensionArrayAttr
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index 69a408767b3c6..7839548f4273e 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -2390,8 +2390,8 @@ def Tosa_IfOp : Tosa_Op<"cond_if",
   );
 
   list<Availability> availability = [
-    Profile<[Tosa_PRO_INT, Tosa_PRO_FP]>,
-    Extension<[]>,
+    Profile<[]>,
+    Extension<[Tosa_EXT_CONTROLFLOW]>,
   ];
 
   let regions = (region
@@ -2431,8 +2431,8 @@ def Tosa_WhileOp : Tosa_Op<"while_loop", [
   );
 
   list<Availability> availability = [
-    Profile<[Tosa_PRO_INT, Tosa_PRO_FP]>,
-    Extension<[]>,
+    Profile<[]>,
+    Extension<[Tosa_EXT_CONTROLFLOW]>,
   ];
 
   let regions = (region
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h b/mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h
index a831bae12f3c1..57a4cd6a382ee 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h
@@ -141,6 +141,7 @@ class TosaProfileCompliance {
     case Extension::fft:
       return {Profile::pro_fp};
     case Extension::variable:
+    case Extension::controlflow:
       return {Profile::pro_fp, Profile::pro_int};
     case Extension::none:
       return {};
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
index f74a4b4c58b80..32648830bb760 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
@@ -425,7 +425,7 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
         } else {
           llvm::errs() << "unknown TOSA extension name passed in: " << ext
                        << ", supported extension are int16, int4, bf16, "
-                       << "fp8e4m3, fp8e5m2, fft, and variable\n";
+                       << "fp8e4m3, fp8e5m2, fft, variable and controlflow\n";
           return signalPassFailure();
         }
       }
diff --git a/mlir/test/Dialect/Tosa/availability.mlir b/mlir/test/Dialect/Tosa/availability.mlir
index e66ff4cacfd89..da8f9ef82c839 100644
--- a/mlir/test/Dialect/Tosa/availability.mlir
+++ b/mlir/test/Dialect/Tosa/availability.mlir
@@ -629,8 +629,8 @@ func.func @test_identity(%arg0: tensor<13x21x3xi32>) -> tensor<13x21x3xi32> {
 // -----
 // CHECK-LABEL: cond_if
 func.func @test_cond_if(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> {
-  // CHECK: profiles: [ [pro_int, pro_fp] ]
-  // CHECK: extensions: [ [bf16] ]
+  // CHECK: tosa.cond_if profiles: [ ]
+  // CHECK: tosa.cond_if extensions: [ [controlflow] ]
   %0 = tosa.cond_if %arg2 -> (tensor<f32>) {
     %1 = tosa.add %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
     tosa.yield %1 : tensor<f32>
@@ -645,8 +645,8 @@ func.func @test_cond_if(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1
 // CHECK-LABEL: while_loop
 func.func @test_while_loop(%arg0: tensor<10xi32>, %arg1: tensor<i32>) {
   %0 = "tosa.const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
-  // CHECK: profiles: [ [pro_int, pro_fp] ]
-  // CHECK: extensions: [ [bf16] ]
+  // CHECK: profiles: [ ]
+  // CHECK: extensions: [ [controlflow] ]
   %1:3 = tosa.while_loop (%arg2 = %0, %arg3 = %0, %arg4 = %arg0) : (tensor<i32>, tensor<i32>, tensor<10xi32>) -> (tensor<i32>, tensor<i32>, tensor<10xi32>) {
     %2 = tosa.greater_equal %arg3, %arg1 : (tensor<i32>, tensor<i32>) -> tensor<i1>
     %3 = tosa.logical_not %2 : (tensor<i1>) -> tensor<i1>
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index 1aa8547cb2fdb..c44a0d1c09215 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -4,7 +4,7 @@
 // validation flow.
 //--------------------------------------------------------------------------------------------------
 
-// RUN: mlir-opt %s -split-input-file -verify-diagnostics --tosa-validate="profile=pro_int,pro_fp extension=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable strict-op-spec-alignment"
+// RUN: mlir-opt %s -split-input-file -verify-diagnostics --tosa-validate="profile=pro_int,pro_fp extension=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow strict-op-spec-alignment"
 
 func.func @test_const() -> tensor<1xf32> {
   // expected-error at +1{{'tosa.const' op expected same attr/result element types}}
diff --git a/mlir/test/Dialect/Tosa/invalid_extension.mlir b/mlir/test/Dialect/Tosa/invalid_extension.mlir
index 046b9d5615074..684875f231dec 100644
--- a/mlir/test/Dialect/Tosa/invalid_extension.mlir
+++ b/mlir/test/Dialect/Tosa/invalid_extension.mlir
@@ -2,7 +2,7 @@
 // Enable all supported profiles to focus the verification of expected extension requirement errors.
 //--------------------------------------------------------------------------------------------------
 
-// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-validate="profile=pro_int,pro_fp,mt strict-op-spec-alignment"
+// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-validate="profile=pro_int,pro_fp strict-op-spec-alignment"
 
 // -----
 func.func @test_fft2d(%arg0: tensor<1x4x8xf32>, %arg1: tensor<1x4x8xf32>) -> (tensor<1x4x8xf32>, tensor<1x4x8xf32>) {
@@ -36,3 +36,37 @@ func.func @test_cast_bf16_i32(%arg0: tensor<13x21x3xbf16>) -> tensor<13x21x3xi32
   return %0 : tensor<13x21x3xi32>
 }
 
+// -----
+func.func @test_cond_if(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> {
+  // expected-error at +1 {{'tosa.cond_if' op illegal: requires [controlflow]}}
+  %0 = tosa.cond_if %arg2 -> (tensor<f32>) {
+    %1 = tosa.add %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
+    tosa.yield %1 : tensor<f32>
+  } else {
+    %1 = tosa.sub %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
+    tosa.yield %1 : tensor<f32>
+  }
+  return %0 : tensor<f32>
+}
+
+// -----
+func.func @test_while_loop(%arg0: tensor<10xi32>, %arg1: tensor<i32>) {
+  %0 = "tosa.const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
+  // expected-error at +1 {{'tosa.while_loop' op illegal: requires [controlflow]}}
+  %1:3 = tosa.while_loop (%arg2 = %0, %arg3 = %0, %arg4 = %arg0) : (tensor<i32>, tensor<i32>, tensor<10xi32>) -> (tensor<i32>, tensor<i32>, tensor<10xi32>) {
+    %2 = tosa.greater_equal %arg3, %arg1 : (tensor<i32>, tensor<i32>) -> tensor<i1>
+    %3 = tosa.logical_not %2 : (tensor<i1>) -> tensor<i1>
+    tosa.yield %3 : tensor<i1>
+  } do {
+  ^bb0(%arg2: tensor<i32>, %arg3: tensor<i32>, %arg4: tensor<10xi32>):
+    %2 = "tosa.const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
+    %3 = tosa.add %arg3, %2 : (tensor<i32>, tensor<i32>) -> tensor<i32>
+    %7 = tosa.const_shape {value = dense<[1]> : tensor<1xindex>} : () -> !tosa.shape<1>
+    %4 = tosa.reshape %2, %7 : (tensor<i32>, !tosa.shape<1>) -> tensor<1xi32>
+    %5 = tosa.add %arg4, %4 : (tensor<10xi32>, tensor<1xi32>) -> tensor<10xi32>
+    %6 = tosa.add %arg2, %2 : (tensor<i32>, tensor<i32>) -> tensor<i32>
+    tosa.yield %6, %3, %5 : tensor<i32>, tensor<i32>, tensor<10xi32>
+  }
+  return
+}
+
diff --git a/mlir/test/Dialect/Tosa/level_check.mlir b/mlir/test/Dialect/Tosa/level_check.mlir
index 90c4551564d1e..a75a6bee8e809 100644
--- a/mlir/test/Dialect/Tosa/level_check.mlir
+++ b/mlir/test/Dialect/Tosa/level_check.mlir
@@ -2,7 +2,7 @@
 // Enable all supported profiles and extensions to focus the verification of expected level errors.
 //--------------------------------------------------------------------------------------------------
 
-// RUN: mlir-opt %s -split-input-file -verify-diagnostics --tosa-validate="profile=pro_int,pro_fp,mt extension=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable"
+// RUN: mlir-opt %s -split-input-file -verify-diagnostics --tosa-validate="profile=pro_int,pro_fp extension=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow"
 
 func.func @test_argmax(%arg0: tensor<1x1x1x1x29x29x4xf32>) -> tensor<1x1x1x1x29x4xi32> {
   // expected-error at +1 {{'tosa.argmax' op failed level check: operand rank(shape) <= MAX_RANK}}
diff --git a/mlir/test/Dialect/Tosa/profile_all_unsupported.mlir b/mlir/test/Dialect/Tosa/profile_all_unsupported.mlir
index 6dddcf329d110..8183b58272e84 100644
--- a/mlir/test/Dialect/Tosa/profile_all_unsupported.mlir
+++ b/mlir/test/Dialect/Tosa/profile_all_unsupported.mlir
@@ -2,7 +2,7 @@
 // Enable all supported extensions to focus the verification of expected profile requirement errors.
 //--------------------------------------------------------------------------------------------------
 
-// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-validate="extension=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable strict-op-spec-alignment"
+// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-validate="extension=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow strict-op-spec-alignment"
 
 // -----
 func.func @test_table(%arg0 : tensor<4x5xi8>, %arg1 : tensor<513xi8>) -> () {
diff --git a/mlir/test/Dialect/Tosa/profile_pro_fp_unsupported.mlir b/mlir/test/Dialect/Tosa/profile_pro_fp_unsupported.mlir
index c46b2543fbed5..f7cbd114280dc 100644
--- a/mlir/test/Dialect/Tosa/profile_pro_fp_unsupported.mlir
+++ b/mlir/test/Dialect/Tosa/profile_pro_fp_unsupported.mlir
@@ -2,7 +2,7 @@
 // Enable all supported extensions to focus the verification of expected profile requirement errors.
 //--------------------------------------------------------------------------------------------------
 
-// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-validate="profile=pro_int extension=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable strict-op-spec-alignment"
+// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-validate="profile=pro_int extension=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow strict-op-spec-alignment"
 
 // -----
 func.func @test_conv2d(%arg0: tensor<1x4x4x4xf32>, %arg1: tensor<8x1x1x4xf32>, %arg2: tensor<8xf32>) -> tensor<1x4x4x8xf32> {
diff --git a/mlir/test/Dialect/Tosa/profile_pro_int_unsupported.mlir b/mlir/test/Dialect/Tosa/profile_pro_int_unsupported.mlir
index 479b7569f54ae..1d6d33b9a02c7 100644
--- a/mlir/test/Dialect/Tosa/profile_pro_int_unsupported.mlir
+++ b/mlir/test/Dialect/Tosa/profile_pro_int_unsupported.mlir
@@ -2,7 +2,7 @@
 // Enable all supported extensions to focus the verification of expected profile requirement errors.
 //--------------------------------------------------------------------------------------------------
 
-// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-validate="profile=pro_fp extension=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable strict-op-spec-alignment"
+// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-validate="profile=pro_fp extension=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow strict-op-spec-alignment"
 
 // -----
 func.func @test_table(%arg0 : tensor<4x5xi8>, %arg1 : tensor<513xi8>) -> () {

``````````

</details>


https://github.com/llvm/llvm-project/pull/128216


More information about the Mlir-commits mailing list