[Mlir-commits] [mlir] [TOSA] Move cond_if and while_loop operations to controlflow extension (PR #128216)
Tai Ly
llvmlistbot at llvm.org
Fri Feb 21 10:53:23 PST 2025
https://github.com/Tai78641 created https://github.com/llvm/llvm-project/pull/128216
This commit adds the concept of a controlflow extension to the dialect and updates the validation pass to check conf_if and while_loop are supported only in the presence of the controlflow extension.
>From 514e745d467bfa3f8def080acd8da3b633aec48b Mon Sep 17 00:00:00 2001
From: Luke Hutton <luke.hutton at arm.com>
Date: Wed, 23 Oct 2024 20:24:27 +0000
Subject: [PATCH] [TOSA] Move cond_if and while_loop operations to controlflow
extension
This commit adds the concept of a controlflow extension to the dialect
and updates the validation pass to check conf_if and while_loop are
supported only in the presence of the controlflow extension.
Change-Id: Ia2304baebd372d85f7e4f31e82d94ab85679e660
Signed-off-by: Luke Hutton <luke.hutton at arm.com>
---
.../mlir/Dialect/Tosa/IR/TosaOpBase.td | 19 +++++-----
mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td | 8 ++---
.../Dialect/Tosa/IR/TosaProfileCompliance.h | 1 +
.../Tosa/Transforms/TosaValidation.cpp | 2 +-
mlir/test/Dialect/Tosa/availability.mlir | 8 ++---
mlir/test/Dialect/Tosa/invalid.mlir | 2 +-
mlir/test/Dialect/Tosa/invalid_extension.mlir | 36 ++++++++++++++++++-
mlir/test/Dialect/Tosa/level_check.mlir | 2 +-
.../Dialect/Tosa/profile_all_unsupported.mlir | 2 +-
.../Tosa/profile_pro_fp_unsupported.mlir | 2 +-
.../Tosa/profile_pro_int_unsupported.mlir | 2 +-
11 files changed, 60 insertions(+), 24 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
index 13bbba2b492fa..95c3b5c7c983d 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
@@ -241,19 +241,20 @@ def Tosa_PRO_INT : I32EnumAttrCase<"pro_int", 1>;
def Tosa_PRO_FP : I32EnumAttrCase<"pro_fp", 2>;
def Tosa_NONE : I32EnumAttrCase<"none", 3>;
-def Tosa_EXT_INT16 : I32EnumAttrCase<"int16", 1>;
-def Tosa_EXT_INT4 : I32EnumAttrCase<"int4", 2>;
-def Tosa_EXT_BF16 : I32EnumAttrCase<"bf16", 3>;
-def Tosa_EXT_FP8E4M3 : I32EnumAttrCase<"fp8e4m3", 4>;
-def Tosa_EXT_FP8E5M2 : I32EnumAttrCase<"fp8e5m2", 5>;
-def Tosa_EXT_FFT : I32EnumAttrCase<"fft", 6>;
-def Tosa_EXT_VARIABLE : I32EnumAttrCase<"variable", 7>;
-def Tosa_EXT_NONE : I32EnumAttrCase<"none", 8>;
+def Tosa_EXT_INT16 : I32EnumAttrCase<"int16", 1>;
+def Tosa_EXT_INT4 : I32EnumAttrCase<"int4", 2>;
+def Tosa_EXT_BF16 : I32EnumAttrCase<"bf16", 3>;
+def Tosa_EXT_FP8E4M3 : I32EnumAttrCase<"fp8e4m3", 4>;
+def Tosa_EXT_FP8E5M2 : I32EnumAttrCase<"fp8e5m2", 5>;
+def Tosa_EXT_FFT : I32EnumAttrCase<"fft", 6>;
+def Tosa_EXT_VARIABLE : I32EnumAttrCase<"variable", 7>;
+def Tosa_EXT_CONTROLFLOW : I32EnumAttrCase<"controlflow", 8>;
+def Tosa_EXT_NONE : I32EnumAttrCase<"none", 9>;
def Tosa_ExtensionAttr
: Tosa_I32EnumAttr<"Extension", "supported TOSA extensions", "ext", [
Tosa_EXT_INT16, Tosa_EXT_INT4, Tosa_EXT_BF16, Tosa_EXT_FP8E4M3,
- Tosa_EXT_FP8E5M2, Tosa_EXT_FFT, Tosa_EXT_VARIABLE, Tosa_EXT_NONE
+ Tosa_EXT_FP8E5M2, Tosa_EXT_FFT, Tosa_EXT_VARIABLE, Tosa_EXT_CONTROLFLOW, Tosa_EXT_NONE
]>;
def Tosa_ExtensionArrayAttr
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index 69a408767b3c6..7839548f4273e 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -2390,8 +2390,8 @@ def Tosa_IfOp : Tosa_Op<"cond_if",
);
list<Availability> availability = [
- Profile<[Tosa_PRO_INT, Tosa_PRO_FP]>,
- Extension<[]>,
+ Profile<[]>,
+ Extension<[Tosa_EXT_CONTROLFLOW]>,
];
let regions = (region
@@ -2431,8 +2431,8 @@ def Tosa_WhileOp : Tosa_Op<"while_loop", [
);
list<Availability> availability = [
- Profile<[Tosa_PRO_INT, Tosa_PRO_FP]>,
- Extension<[]>,
+ Profile<[]>,
+ Extension<[Tosa_EXT_CONTROLFLOW]>,
];
let regions = (region
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h b/mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h
index a831bae12f3c1..57a4cd6a382ee 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h
@@ -141,6 +141,7 @@ class TosaProfileCompliance {
case Extension::fft:
return {Profile::pro_fp};
case Extension::variable:
+ case Extension::controlflow:
return {Profile::pro_fp, Profile::pro_int};
case Extension::none:
return {};
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
index f74a4b4c58b80..32648830bb760 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
@@ -425,7 +425,7 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
} else {
llvm::errs() << "unknown TOSA extension name passed in: " << ext
<< ", supported extension are int16, int4, bf16, "
- << "fp8e4m3, fp8e5m2, fft, and variable\n";
+ << "fp8e4m3, fp8e5m2, fft, variable and controlflow\n";
return signalPassFailure();
}
}
diff --git a/mlir/test/Dialect/Tosa/availability.mlir b/mlir/test/Dialect/Tosa/availability.mlir
index e66ff4cacfd89..da8f9ef82c839 100644
--- a/mlir/test/Dialect/Tosa/availability.mlir
+++ b/mlir/test/Dialect/Tosa/availability.mlir
@@ -629,8 +629,8 @@ func.func @test_identity(%arg0: tensor<13x21x3xi32>) -> tensor<13x21x3xi32> {
// -----
// CHECK-LABEL: cond_if
func.func @test_cond_if(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> {
- // CHECK: profiles: [ [pro_int, pro_fp] ]
- // CHECK: extensions: [ [bf16] ]
+ // CHECK: tosa.cond_if profiles: [ ]
+ // CHECK: tosa.cond_if extensions: [ [controlflow] ]
%0 = tosa.cond_if %arg2 -> (tensor<f32>) {
%1 = tosa.add %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
tosa.yield %1 : tensor<f32>
@@ -645,8 +645,8 @@ func.func @test_cond_if(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1
// CHECK-LABEL: while_loop
func.func @test_while_loop(%arg0: tensor<10xi32>, %arg1: tensor<i32>) {
%0 = "tosa.const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
- // CHECK: profiles: [ [pro_int, pro_fp] ]
- // CHECK: extensions: [ [bf16] ]
+ // CHECK: profiles: [ ]
+ // CHECK: extensions: [ [controlflow] ]
%1:3 = tosa.while_loop (%arg2 = %0, %arg3 = %0, %arg4 = %arg0) : (tensor<i32>, tensor<i32>, tensor<10xi32>) -> (tensor<i32>, tensor<i32>, tensor<10xi32>) {
%2 = tosa.greater_equal %arg3, %arg1 : (tensor<i32>, tensor<i32>) -> tensor<i1>
%3 = tosa.logical_not %2 : (tensor<i1>) -> tensor<i1>
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index 1aa8547cb2fdb..c44a0d1c09215 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -4,7 +4,7 @@
// validation flow.
//--------------------------------------------------------------------------------------------------
-// RUN: mlir-opt %s -split-input-file -verify-diagnostics --tosa-validate="profile=pro_int,pro_fp extension=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable strict-op-spec-alignment"
+// RUN: mlir-opt %s -split-input-file -verify-diagnostics --tosa-validate="profile=pro_int,pro_fp extension=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow strict-op-spec-alignment"
func.func @test_const() -> tensor<1xf32> {
// expected-error at +1{{'tosa.const' op expected same attr/result element types}}
diff --git a/mlir/test/Dialect/Tosa/invalid_extension.mlir b/mlir/test/Dialect/Tosa/invalid_extension.mlir
index 046b9d5615074..684875f231dec 100644
--- a/mlir/test/Dialect/Tosa/invalid_extension.mlir
+++ b/mlir/test/Dialect/Tosa/invalid_extension.mlir
@@ -2,7 +2,7 @@
// Enable all supported profiles to focus the verification of expected extension requirement errors.
//--------------------------------------------------------------------------------------------------
-// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-validate="profile=pro_int,pro_fp,mt strict-op-spec-alignment"
+// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-validate="profile=pro_int,pro_fp strict-op-spec-alignment"
// -----
func.func @test_fft2d(%arg0: tensor<1x4x8xf32>, %arg1: tensor<1x4x8xf32>) -> (tensor<1x4x8xf32>, tensor<1x4x8xf32>) {
@@ -36,3 +36,37 @@ func.func @test_cast_bf16_i32(%arg0: tensor<13x21x3xbf16>) -> tensor<13x21x3xi32
return %0 : tensor<13x21x3xi32>
}
+// -----
+func.func @test_cond_if(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> {
+ // expected-error at +1 {{'tosa.cond_if' op illegal: requires [controlflow]}}
+ %0 = tosa.cond_if %arg2 -> (tensor<f32>) {
+ %1 = tosa.add %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
+ tosa.yield %1 : tensor<f32>
+ } else {
+ %1 = tosa.sub %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
+ tosa.yield %1 : tensor<f32>
+ }
+ return %0 : tensor<f32>
+}
+
+// -----
+func.func @test_while_loop(%arg0: tensor<10xi32>, %arg1: tensor<i32>) {
+ %0 = "tosa.const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
+ // expected-error at +1 {{'tosa.while_loop' op illegal: requires [controlflow]}}
+ %1:3 = tosa.while_loop (%arg2 = %0, %arg3 = %0, %arg4 = %arg0) : (tensor<i32>, tensor<i32>, tensor<10xi32>) -> (tensor<i32>, tensor<i32>, tensor<10xi32>) {
+ %2 = tosa.greater_equal %arg3, %arg1 : (tensor<i32>, tensor<i32>) -> tensor<i1>
+ %3 = tosa.logical_not %2 : (tensor<i1>) -> tensor<i1>
+ tosa.yield %3 : tensor<i1>
+ } do {
+ ^bb0(%arg2: tensor<i32>, %arg3: tensor<i32>, %arg4: tensor<10xi32>):
+ %2 = "tosa.const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
+ %3 = tosa.add %arg3, %2 : (tensor<i32>, tensor<i32>) -> tensor<i32>
+ %7 = tosa.const_shape {value = dense<[1]> : tensor<1xindex>} : () -> !tosa.shape<1>
+ %4 = tosa.reshape %2, %7 : (tensor<i32>, !tosa.shape<1>) -> tensor<1xi32>
+ %5 = tosa.add %arg4, %4 : (tensor<10xi32>, tensor<1xi32>) -> tensor<10xi32>
+ %6 = tosa.add %arg2, %2 : (tensor<i32>, tensor<i32>) -> tensor<i32>
+ tosa.yield %6, %3, %5 : tensor<i32>, tensor<i32>, tensor<10xi32>
+ }
+ return
+}
+
diff --git a/mlir/test/Dialect/Tosa/level_check.mlir b/mlir/test/Dialect/Tosa/level_check.mlir
index 90c4551564d1e..a75a6bee8e809 100644
--- a/mlir/test/Dialect/Tosa/level_check.mlir
+++ b/mlir/test/Dialect/Tosa/level_check.mlir
@@ -2,7 +2,7 @@
// Enable all supported profiles and extensions to focus the verification of expected level errors.
//--------------------------------------------------------------------------------------------------
-// RUN: mlir-opt %s -split-input-file -verify-diagnostics --tosa-validate="profile=pro_int,pro_fp,mt extension=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable"
+// RUN: mlir-opt %s -split-input-file -verify-diagnostics --tosa-validate="profile=pro_int,pro_fp extension=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow"
func.func @test_argmax(%arg0: tensor<1x1x1x1x29x29x4xf32>) -> tensor<1x1x1x1x29x4xi32> {
// expected-error at +1 {{'tosa.argmax' op failed level check: operand rank(shape) <= MAX_RANK}}
diff --git a/mlir/test/Dialect/Tosa/profile_all_unsupported.mlir b/mlir/test/Dialect/Tosa/profile_all_unsupported.mlir
index 6dddcf329d110..8183b58272e84 100644
--- a/mlir/test/Dialect/Tosa/profile_all_unsupported.mlir
+++ b/mlir/test/Dialect/Tosa/profile_all_unsupported.mlir
@@ -2,7 +2,7 @@
// Enable all supported extensions to focus the verification of expected profile requirement errors.
//--------------------------------------------------------------------------------------------------
-// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-validate="extension=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable strict-op-spec-alignment"
+// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-validate="extension=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow strict-op-spec-alignment"
// -----
func.func @test_table(%arg0 : tensor<4x5xi8>, %arg1 : tensor<513xi8>) -> () {
diff --git a/mlir/test/Dialect/Tosa/profile_pro_fp_unsupported.mlir b/mlir/test/Dialect/Tosa/profile_pro_fp_unsupported.mlir
index c46b2543fbed5..f7cbd114280dc 100644
--- a/mlir/test/Dialect/Tosa/profile_pro_fp_unsupported.mlir
+++ b/mlir/test/Dialect/Tosa/profile_pro_fp_unsupported.mlir
@@ -2,7 +2,7 @@
// Enable all supported extensions to focus the verification of expected profile requirement errors.
//--------------------------------------------------------------------------------------------------
-// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-validate="profile=pro_int extension=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable strict-op-spec-alignment"
+// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-validate="profile=pro_int extension=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow strict-op-spec-alignment"
// -----
func.func @test_conv2d(%arg0: tensor<1x4x4x4xf32>, %arg1: tensor<8x1x1x4xf32>, %arg2: tensor<8xf32>) -> tensor<1x4x4x8xf32> {
diff --git a/mlir/test/Dialect/Tosa/profile_pro_int_unsupported.mlir b/mlir/test/Dialect/Tosa/profile_pro_int_unsupported.mlir
index 479b7569f54ae..1d6d33b9a02c7 100644
--- a/mlir/test/Dialect/Tosa/profile_pro_int_unsupported.mlir
+++ b/mlir/test/Dialect/Tosa/profile_pro_int_unsupported.mlir
@@ -2,7 +2,7 @@
// Enable all supported extensions to focus the verification of expected profile requirement errors.
//--------------------------------------------------------------------------------------------------
-// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-validate="profile=pro_fp extension=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable strict-op-spec-alignment"
+// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-validate="profile=pro_fp extension=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow strict-op-spec-alignment"
// -----
func.func @test_table(%arg0 : tensor<4x5xi8>, %arg1 : tensor<513xi8>) -> () {
More information about the Mlir-commits
mailing list