[Mlir-commits] [mlir] [TOSA] Move cond_if and while_loop operations to controlflow extension (PR #128216)

Tai Ly llvmlistbot at llvm.org
Fri Feb 21 10:53:23 PST 2025


https://github.com/Tai78641 created https://github.com/llvm/llvm-project/pull/128216

This commit adds the concept of a controlflow extension to the dialect and updates the validation pass to check cond_if and while_loop are supported only in the presence of the controlflow extension.


>From 514e745d467bfa3f8def080acd8da3b633aec48b Mon Sep 17 00:00:00 2001
From: Luke Hutton <luke.hutton at arm.com>
Date: Wed, 23 Oct 2024 20:24:27 +0000
Subject: [PATCH] [TOSA] Move cond_if and while_loop operations to controlflow
 extension

This commit adds the concept of a controlflow extension to the dialect
and updates the validation pass to check cond_if and while_loop are
supported only in the presence of the controlflow extension.

Change-Id: Ia2304baebd372d85f7e4f31e82d94ab85679e660
Signed-off-by: Luke Hutton <luke.hutton at arm.com>
---
 .../mlir/Dialect/Tosa/IR/TosaOpBase.td        | 19 +++++-----
 mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td  |  8 ++---
 .../Dialect/Tosa/IR/TosaProfileCompliance.h   |  1 +
 .../Tosa/Transforms/TosaValidation.cpp        |  2 +-
 mlir/test/Dialect/Tosa/availability.mlir      |  8 ++---
 mlir/test/Dialect/Tosa/invalid.mlir           |  2 +-
 mlir/test/Dialect/Tosa/invalid_extension.mlir | 36 ++++++++++++++++++-
 mlir/test/Dialect/Tosa/level_check.mlir       |  2 +-
 .../Dialect/Tosa/profile_all_unsupported.mlir |  2 +-
 .../Tosa/profile_pro_fp_unsupported.mlir      |  2 +-
 .../Tosa/profile_pro_int_unsupported.mlir     |  2 +-
 11 files changed, 60 insertions(+), 24 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
index 13bbba2b492fa..95c3b5c7c983d 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
@@ -241,19 +241,20 @@ def Tosa_PRO_INT   : I32EnumAttrCase<"pro_int", 1>;
 def Tosa_PRO_FP   : I32EnumAttrCase<"pro_fp", 2>;
 def Tosa_NONE : I32EnumAttrCase<"none", 3>;
 
-def Tosa_EXT_INT16    : I32EnumAttrCase<"int16", 1>;
-def Tosa_EXT_INT4     : I32EnumAttrCase<"int4", 2>;
-def Tosa_EXT_BF16     : I32EnumAttrCase<"bf16", 3>;
-def Tosa_EXT_FP8E4M3  : I32EnumAttrCase<"fp8e4m3", 4>;
-def Tosa_EXT_FP8E5M2  : I32EnumAttrCase<"fp8e5m2", 5>;
-def Tosa_EXT_FFT      : I32EnumAttrCase<"fft", 6>;
-def Tosa_EXT_VARIABLE : I32EnumAttrCase<"variable", 7>;
-def Tosa_EXT_NONE     : I32EnumAttrCase<"none", 8>;
+def Tosa_EXT_INT16        : I32EnumAttrCase<"int16", 1>;
+def Tosa_EXT_INT4         : I32EnumAttrCase<"int4", 2>;
+def Tosa_EXT_BF16         : I32EnumAttrCase<"bf16", 3>;
+def Tosa_EXT_FP8E4M3      : I32EnumAttrCase<"fp8e4m3", 4>;
+def Tosa_EXT_FP8E5M2      : I32EnumAttrCase<"fp8e5m2", 5>;
+def Tosa_EXT_FFT          : I32EnumAttrCase<"fft", 6>;
+def Tosa_EXT_VARIABLE     : I32EnumAttrCase<"variable", 7>;
+def Tosa_EXT_CONTROLFLOW  : I32EnumAttrCase<"controlflow", 8>;
+def Tosa_EXT_NONE         : I32EnumAttrCase<"none", 9>;
 
 def Tosa_ExtensionAttr
     : Tosa_I32EnumAttr<"Extension", "supported TOSA extensions", "ext", [
       Tosa_EXT_INT16, Tosa_EXT_INT4, Tosa_EXT_BF16, Tosa_EXT_FP8E4M3,
-      Tosa_EXT_FP8E5M2, Tosa_EXT_FFT, Tosa_EXT_VARIABLE, Tosa_EXT_NONE
+      Tosa_EXT_FP8E5M2, Tosa_EXT_FFT, Tosa_EXT_VARIABLE, Tosa_EXT_CONTROLFLOW, Tosa_EXT_NONE
     ]>;
 
 def Tosa_ExtensionArrayAttr
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index 69a408767b3c6..7839548f4273e 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -2390,8 +2390,8 @@ def Tosa_IfOp : Tosa_Op<"cond_if",
   );
 
   list<Availability> availability = [
-    Profile<[Tosa_PRO_INT, Tosa_PRO_FP]>,
-    Extension<[]>,
+    Profile<[]>,
+    Extension<[Tosa_EXT_CONTROLFLOW]>,
   ];
 
   let regions = (region
@@ -2431,8 +2431,8 @@ def Tosa_WhileOp : Tosa_Op<"while_loop", [
   );
 
   list<Availability> availability = [
-    Profile<[Tosa_PRO_INT, Tosa_PRO_FP]>,
-    Extension<[]>,
+    Profile<[]>,
+    Extension<[Tosa_EXT_CONTROLFLOW]>,
   ];
 
   let regions = (region
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h b/mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h
index a831bae12f3c1..57a4cd6a382ee 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h
@@ -141,6 +141,7 @@ class TosaProfileCompliance {
     case Extension::fft:
       return {Profile::pro_fp};
     case Extension::variable:
+    case Extension::controlflow:
       return {Profile::pro_fp, Profile::pro_int};
     case Extension::none:
       return {};
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
index f74a4b4c58b80..32648830bb760 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
@@ -425,7 +425,7 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
         } else {
           llvm::errs() << "unknown TOSA extension name passed in: " << ext
                        << ", supported extension are int16, int4, bf16, "
-                       << "fp8e4m3, fp8e5m2, fft, and variable\n";
+                       << "fp8e4m3, fp8e5m2, fft, variable and controlflow\n";
           return signalPassFailure();
         }
       }
diff --git a/mlir/test/Dialect/Tosa/availability.mlir b/mlir/test/Dialect/Tosa/availability.mlir
index e66ff4cacfd89..da8f9ef82c839 100644
--- a/mlir/test/Dialect/Tosa/availability.mlir
+++ b/mlir/test/Dialect/Tosa/availability.mlir
@@ -629,8 +629,8 @@ func.func @test_identity(%arg0: tensor<13x21x3xi32>) -> tensor<13x21x3xi32> {
 // -----
 // CHECK-LABEL: cond_if
 func.func @test_cond_if(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> {
-  // CHECK: profiles: [ [pro_int, pro_fp] ]
-  // CHECK: extensions: [ [bf16] ]
+  // CHECK: tosa.cond_if profiles: [ ]
+  // CHECK: tosa.cond_if extensions: [ [controlflow] ]
   %0 = tosa.cond_if %arg2 -> (tensor<f32>) {
     %1 = tosa.add %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
     tosa.yield %1 : tensor<f32>
@@ -645,8 +645,8 @@ func.func @test_cond_if(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1
 // CHECK-LABEL: while_loop
 func.func @test_while_loop(%arg0: tensor<10xi32>, %arg1: tensor<i32>) {
   %0 = "tosa.const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
-  // CHECK: profiles: [ [pro_int, pro_fp] ]
-  // CHECK: extensions: [ [bf16] ]
+  // CHECK: profiles: [ ]
+  // CHECK: extensions: [ [controlflow] ]
   %1:3 = tosa.while_loop (%arg2 = %0, %arg3 = %0, %arg4 = %arg0) : (tensor<i32>, tensor<i32>, tensor<10xi32>) -> (tensor<i32>, tensor<i32>, tensor<10xi32>) {
     %2 = tosa.greater_equal %arg3, %arg1 : (tensor<i32>, tensor<i32>) -> tensor<i1>
     %3 = tosa.logical_not %2 : (tensor<i1>) -> tensor<i1>
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index 1aa8547cb2fdb..c44a0d1c09215 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -4,7 +4,7 @@
 // validation flow.
 //--------------------------------------------------------------------------------------------------
 
-// RUN: mlir-opt %s -split-input-file -verify-diagnostics --tosa-validate="profile=pro_int,pro_fp extension=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable strict-op-spec-alignment"
+// RUN: mlir-opt %s -split-input-file -verify-diagnostics --tosa-validate="profile=pro_int,pro_fp extension=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow strict-op-spec-alignment"
 
 func.func @test_const() -> tensor<1xf32> {
   // expected-error at +1{{'tosa.const' op expected same attr/result element types}}
diff --git a/mlir/test/Dialect/Tosa/invalid_extension.mlir b/mlir/test/Dialect/Tosa/invalid_extension.mlir
index 046b9d5615074..684875f231dec 100644
--- a/mlir/test/Dialect/Tosa/invalid_extension.mlir
+++ b/mlir/test/Dialect/Tosa/invalid_extension.mlir
@@ -2,7 +2,7 @@
 // Enable all supported profiles to focus the verification of expected extension requirement errors.
 //--------------------------------------------------------------------------------------------------
 
-// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-validate="profile=pro_int,pro_fp,mt strict-op-spec-alignment"
+// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-validate="profile=pro_int,pro_fp strict-op-spec-alignment"
 
 // -----
 func.func @test_fft2d(%arg0: tensor<1x4x8xf32>, %arg1: tensor<1x4x8xf32>) -> (tensor<1x4x8xf32>, tensor<1x4x8xf32>) {
@@ -36,3 +36,37 @@ func.func @test_cast_bf16_i32(%arg0: tensor<13x21x3xbf16>) -> tensor<13x21x3xi32
   return %0 : tensor<13x21x3xi32>
 }
 
+// -----
+func.func @test_cond_if(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> {
+  // expected-error at +1 {{'tosa.cond_if' op illegal: requires [controlflow]}}
+  %0 = tosa.cond_if %arg2 -> (tensor<f32>) {
+    %1 = tosa.add %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
+    tosa.yield %1 : tensor<f32>
+  } else {
+    %1 = tosa.sub %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
+    tosa.yield %1 : tensor<f32>
+  }
+  return %0 : tensor<f32>
+}
+
+// -----
+func.func @test_while_loop(%arg0: tensor<10xi32>, %arg1: tensor<i32>) {
+  %0 = "tosa.const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
+  // expected-error at +1 {{'tosa.while_loop' op illegal: requires [controlflow]}}
+  %1:3 = tosa.while_loop (%arg2 = %0, %arg3 = %0, %arg4 = %arg0) : (tensor<i32>, tensor<i32>, tensor<10xi32>) -> (tensor<i32>, tensor<i32>, tensor<10xi32>) {
+    %2 = tosa.greater_equal %arg3, %arg1 : (tensor<i32>, tensor<i32>) -> tensor<i1>
+    %3 = tosa.logical_not %2 : (tensor<i1>) -> tensor<i1>
+    tosa.yield %3 : tensor<i1>
+  } do {
+  ^bb0(%arg2: tensor<i32>, %arg3: tensor<i32>, %arg4: tensor<10xi32>):
+    %2 = "tosa.const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
+    %3 = tosa.add %arg3, %2 : (tensor<i32>, tensor<i32>) -> tensor<i32>
+    %7 = tosa.const_shape {value = dense<[1]> : tensor<1xindex>} : () -> !tosa.shape<1>
+    %4 = tosa.reshape %2, %7 : (tensor<i32>, !tosa.shape<1>) -> tensor<1xi32>
+    %5 = tosa.add %arg4, %4 : (tensor<10xi32>, tensor<1xi32>) -> tensor<10xi32>
+    %6 = tosa.add %arg2, %2 : (tensor<i32>, tensor<i32>) -> tensor<i32>
+    tosa.yield %6, %3, %5 : tensor<i32>, tensor<i32>, tensor<10xi32>
+  }
+  return
+}
+
diff --git a/mlir/test/Dialect/Tosa/level_check.mlir b/mlir/test/Dialect/Tosa/level_check.mlir
index 90c4551564d1e..a75a6bee8e809 100644
--- a/mlir/test/Dialect/Tosa/level_check.mlir
+++ b/mlir/test/Dialect/Tosa/level_check.mlir
@@ -2,7 +2,7 @@
 // Enable all supported profiles and extensions to focus the verification of expected level errors.
 //--------------------------------------------------------------------------------------------------
 
-// RUN: mlir-opt %s -split-input-file -verify-diagnostics --tosa-validate="profile=pro_int,pro_fp,mt extension=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable"
+// RUN: mlir-opt %s -split-input-file -verify-diagnostics --tosa-validate="profile=pro_int,pro_fp extension=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow"
 
 func.func @test_argmax(%arg0: tensor<1x1x1x1x29x29x4xf32>) -> tensor<1x1x1x1x29x4xi32> {
   // expected-error at +1 {{'tosa.argmax' op failed level check: operand rank(shape) <= MAX_RANK}}
diff --git a/mlir/test/Dialect/Tosa/profile_all_unsupported.mlir b/mlir/test/Dialect/Tosa/profile_all_unsupported.mlir
index 6dddcf329d110..8183b58272e84 100644
--- a/mlir/test/Dialect/Tosa/profile_all_unsupported.mlir
+++ b/mlir/test/Dialect/Tosa/profile_all_unsupported.mlir
@@ -2,7 +2,7 @@
 // Enable all supported extensions to focus the verification of expected profile requirement errors.
 //--------------------------------------------------------------------------------------------------
 
-// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-validate="extension=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable strict-op-spec-alignment"
+// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-validate="extension=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow strict-op-spec-alignment"
 
 // -----
 func.func @test_table(%arg0 : tensor<4x5xi8>, %arg1 : tensor<513xi8>) -> () {
diff --git a/mlir/test/Dialect/Tosa/profile_pro_fp_unsupported.mlir b/mlir/test/Dialect/Tosa/profile_pro_fp_unsupported.mlir
index c46b2543fbed5..f7cbd114280dc 100644
--- a/mlir/test/Dialect/Tosa/profile_pro_fp_unsupported.mlir
+++ b/mlir/test/Dialect/Tosa/profile_pro_fp_unsupported.mlir
@@ -2,7 +2,7 @@
 // Enable all supported extensions to focus the verification of expected profile requirement errors.
 //--------------------------------------------------------------------------------------------------
 
-// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-validate="profile=pro_int extension=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable strict-op-spec-alignment"
+// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-validate="profile=pro_int extension=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow strict-op-spec-alignment"
 
 // -----
 func.func @test_conv2d(%arg0: tensor<1x4x4x4xf32>, %arg1: tensor<8x1x1x4xf32>, %arg2: tensor<8xf32>) -> tensor<1x4x4x8xf32> {
diff --git a/mlir/test/Dialect/Tosa/profile_pro_int_unsupported.mlir b/mlir/test/Dialect/Tosa/profile_pro_int_unsupported.mlir
index 479b7569f54ae..1d6d33b9a02c7 100644
--- a/mlir/test/Dialect/Tosa/profile_pro_int_unsupported.mlir
+++ b/mlir/test/Dialect/Tosa/profile_pro_int_unsupported.mlir
@@ -2,7 +2,7 @@
 // Enable all supported extensions to focus the verification of expected profile requirement errors.
 //--------------------------------------------------------------------------------------------------
 
-// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-validate="profile=pro_fp extension=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable strict-op-spec-alignment"
+// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-validate="profile=pro_fp extension=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow strict-op-spec-alignment"
 
 // -----
 func.func @test_table(%arg0 : tensor<4x5xi8>, %arg1 : tensor<513xi8>) -> () {



More information about the Mlir-commits mailing list