[flang-commits] [flang] [mlir] [MLIR][OpenMP] Improve loop wrapper representation (PR #97706)
Sergio Afonso via flang-commits
flang-commits at lists.llvm.org
Thu Jul 4 03:15:01 PDT 2024
https://github.com/skatrak created https://github.com/llvm/llvm-project/pull/97706
This patch replaces the `SingleBlockImplicitTerminator<"TerminatorOp">` trait of loop wrapper operations with the `SingleBlock` trait. This enables a more robust implementation of the `LoopWrapperInterface::isWrapper()` method, since it no longer has to deal with the potentially missing (implicit) terminator.
The `LoopWrapperInterface::isWrapper()` method is also extended so that it no longer identifies as a wrapper any operation whose nested loop wrapper operation is not itself taking a wrapper role. This matters when `omp.parallel` is nested, because it can act as a loop wrapper but is not required to.
Tests are updated to reflect these representation and validation changes.
>From 0e48712dfcb823032c8980911924f7288f5fa6d6 Mon Sep 17 00:00:00 2001
From: Sergio Afonso <safonsof at amd.com>
Date: Thu, 4 Jul 2024 10:56:43 +0100
Subject: [PATCH] [MLIR][OpenMP] Improve loop wrapper representation
This patch replaces the `SingleBlockImplicitTerminator<"TerminatorOp">` trait
of loop wrapper operations with the `SingleBlock` trait. This enables a more
robust implementation of the `LoopWrapperInterface::isWrapper()` method, since
it no longer has to deal with the potentially missing (implicit)
terminator.
The `LoopWrapperInterface::isWrapper()` method is also extended so that it no
longer identifies as a wrapper any operation whose nested loop wrapper
operation is not itself taking a wrapper role. This matters when `omp.parallel`
is nested, because it can act as a loop wrapper but is not required to.
Tests are updated to reflect these representation and validation changes.
---
.../Fir/convert-to-llvm-openmp-and-fir.fir | 4 +++
mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td | 8 ++---
.../Dialect/OpenMP/OpenMPOpsInterfaces.td | 14 +++++---
.../OpenMPToLLVM/convert-to-llvmir.mlir | 1 +
mlir/test/Dialect/OpenMP/invalid.mlir | 15 +++++---
mlir/test/Dialect/OpenMP/ops.mlir | 35 +++++++++++++++++++
mlir/test/Target/LLVMIR/openmp-llvm.mlir | 6 ++++
7 files changed, 71 insertions(+), 12 deletions(-)
diff --git a/flang/test/Fir/convert-to-llvm-openmp-and-fir.fir b/flang/test/Fir/convert-to-llvm-openmp-and-fir.fir
index 8b62787bb3094..eca762d52a724 100644
--- a/flang/test/Fir/convert-to-llvm-openmp-and-fir.fir
+++ b/flang/test/Fir/convert-to-llvm-openmp-and-fir.fir
@@ -200,6 +200,7 @@ func.func @_QPsimd1(%arg0: !fir.ref<i32> {fir.bindc_name = "n"}, %arg1: !fir.ref
fir.store %3 to %6 : !fir.ref<i32>
omp.yield
}
+ omp.terminator
}
omp.terminator
}
@@ -225,6 +226,7 @@ func.func @_QPsimd1(%arg0: !fir.ref<i32> {fir.bindc_name = "n"}, %arg1: !fir.ref
// CHECK: llvm.store %[[I1]], %[[ARR_I_REF]] : i32, !llvm.ptr
// CHECK: omp.yield
// CHECK: }
+// CHECK: omp.terminator
// CHECK: }
// CHECK: omp.terminator
// CHECK: }
@@ -518,6 +520,7 @@ func.func @_QPsimd_with_nested_loop() {
fir.store %7 to %3 : !fir.ref<i32>
omp.yield
}
+ omp.terminator
}
return
}
@@ -538,6 +541,7 @@ func.func @_QPsimd_with_nested_loop() {
// CHECK: ^bb3:
// CHECK: omp.yield
// CHECK: }
+// CHECK: omp.terminator
// CHECK: }
// CHECK: llvm.return
// CHECK: }
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
index 99e14cd1b7b48..aed0d69619db2 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -354,7 +354,7 @@ def LoopNestOp : OpenMP_Op<"loop_nest", traits = [
def WsloopOp : OpenMP_Op<"wsloop", traits = [
AttrSizedOperandSegments, DeclareOpInterfaceMethods<LoopWrapperInterface>,
- RecursiveMemoryEffects, SingleBlockImplicitTerminator<"TerminatorOp">
+ RecursiveMemoryEffects, SingleBlock
], clauses = [
// TODO: Complete clause list (allocate, private).
// TODO: Sort clauses alphabetically.
@@ -418,7 +418,7 @@ def WsloopOp : OpenMP_Op<"wsloop", traits = [
def SimdOp : OpenMP_Op<"simd", traits = [
AttrSizedOperandSegments, DeclareOpInterfaceMethods<LoopWrapperInterface>,
- RecursiveMemoryEffects, SingleBlockImplicitTerminator<"TerminatorOp">
+ RecursiveMemoryEffects, SingleBlock
], clauses = [
// TODO: Complete clause list (linear, private, reduction).
OpenMP_AlignedClause, OpenMP_IfClause, OpenMP_NontemporalClause,
@@ -485,7 +485,7 @@ def YieldOp : OpenMP_Op<"yield",
//===----------------------------------------------------------------------===//
def DistributeOp : OpenMP_Op<"distribute", traits = [
AttrSizedOperandSegments, DeclareOpInterfaceMethods<LoopWrapperInterface>,
- RecursiveMemoryEffects, SingleBlockImplicitTerminator<"TerminatorOp">
+ RecursiveMemoryEffects, SingleBlock
], clauses = [
// TODO: Complete clause list (private).
// TODO: Sort clauses alphabetically.
@@ -575,7 +575,7 @@ def TaskOp : OpenMP_Op<"task", traits = [
def TaskloopOp : OpenMP_Op<"taskloop", traits = [
AttrSizedOperandSegments, AutomaticAllocationScope,
DeclareOpInterfaceMethods<LoopWrapperInterface>, RecursiveMemoryEffects,
- SingleBlockImplicitTerminator<"TerminatorOp">
+ SingleBlock
], clauses = [
// TODO: Complete clause list (private).
// TODO: Sort clauses alphabetically.
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td
index 31a306072d0ec..385aa8b1b016a 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td
@@ -84,8 +84,8 @@ def LoopWrapperInterface : OpInterface<"LoopWrapperInterface"> {
/*description=*/[{
Tell whether the operation could be taking the role of a loop wrapper.
That is, it has a single region with a single block in which there are
- two operations: another wrapper or `omp.loop_nest` operation and a
- terminator.
+ two operations: another wrapper (also taking a loop wrapper role) or
+ `omp.loop_nest` operation and a terminator.
}],
/*retTy=*/"bool",
/*methodName=*/"isWrapper",
@@ -102,8 +102,14 @@ def LoopWrapperInterface : OpInterface<"LoopWrapperInterface"> {
Operation &firstOp = *r.op_begin();
Operation &secondOp = *(std::next(r.op_begin()));
- return ::llvm::isa<LoopNestOp, LoopWrapperInterface>(firstOp) &&
- secondOp.hasTrait<OpTrait::IsTerminator>();
+
+ if (!secondOp.hasTrait<OpTrait::IsTerminator>())
+ return false;
+
+ if (auto wrapper = ::llvm::dyn_cast<LoopWrapperInterface>(firstOp))
+ return wrapper.isWrapper();
+
+ return ::llvm::isa<LoopNestOp>(firstOp);
}]
>,
InterfaceMethod<
diff --git a/mlir/test/Conversion/OpenMPToLLVM/convert-to-llvmir.mlir b/mlir/test/Conversion/OpenMPToLLVM/convert-to-llvmir.mlir
index 3aeb9e70522d5..4c9e09970279a 100644
--- a/mlir/test/Conversion/OpenMPToLLVM/convert-to-llvmir.mlir
+++ b/mlir/test/Conversion/OpenMPToLLVM/convert-to-llvmir.mlir
@@ -174,6 +174,7 @@ func.func @loop_nest_block_arg(%val : i32, %ub : i32, %i : index) {
^bb3:
omp.yield
}
+ omp.terminator
}
return
}
diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir
index 2915963f704d3..91eeb0911160d 100644
--- a/mlir/test/Dialect/OpenMP/invalid.mlir
+++ b/mlir/test/Dialect/OpenMP/invalid.mlir
@@ -11,8 +11,8 @@ func.func @unknown_clause() {
// -----
func.func @not_wrapper() {
+ // expected-error@+1 {{op must be a loop wrapper}}
omp.distribute {
- // expected-error@+1 {{op must take a loop wrapper role if nested inside of 'omp.distribute'}}
omp.parallel {
%0 = arith.constant 0 : i32
omp.terminator
@@ -383,12 +383,16 @@ func.func @omp_simd() -> () {
// -----
-func.func @omp_simd_nested_wrapper() -> () {
+func.func @omp_simd_nested_wrapper(%lb : index, %ub : index, %step : index) -> () {
// expected-error @below {{op must wrap an 'omp.loop_nest' directly}}
omp.simd {
omp.distribute {
+ omp.loop_nest (%iv) : index = (%lb) to (%ub) step (%step) {
+ omp.yield
+ }
omp.terminator
}
+ omp.terminator
}
return
}
@@ -1960,6 +1964,7 @@ func.func @taskloop(%lb: i32, %ub: i32, %step: i32) {
}
omp.terminator
}
+ omp.terminator
}
return
}
@@ -2158,11 +2163,13 @@ func.func @omp_distribute_wrapper() -> () {
// -----
-func.func @omp_distribute_nested_wrapper(%data_var : memref<i32>) -> () {
+func.func @omp_distribute_nested_wrapper(%lb: index, %ub: index, %step: index) -> () {
// expected-error @below {{only supported nested wrappers are 'omp.parallel' and 'omp.simd'}}
omp.distribute {
"omp.wsloop"() ({
- %0 = arith.constant 0 : i32
+ omp.loop_nest (%iv) : index = (%lb) to (%ub) step (%step) {
+ "omp.yield"() : () -> ()
+ }
"omp.terminator"() : () -> ()
}) : () -> ()
"omp.terminator"() : () -> ()
diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir
index eb283840aa7ee..ff3b1e60f7cfe 100644
--- a/mlir/test/Dialect/OpenMP/ops.mlir
+++ b/mlir/test/Dialect/OpenMP/ops.mlir
@@ -617,6 +617,7 @@ func.func @omp_simd_pretty(%lb : index, %ub : index, %step : index) -> () {
omp.loop_nest (%iv) : index = (%lb) to (%ub) step (%step) {
omp.yield
}
+ omp.terminator
}
return
}
@@ -632,6 +633,7 @@ func.func @omp_simd_pretty_aligned(%lb : index, %ub : index, %step : index,
omp.loop_nest (%iv) : index = (%lb) to (%ub) step (%step) {
omp.yield
}
+ omp.terminator
}
return
}
@@ -643,6 +645,7 @@ func.func @omp_simd_pretty_if(%lb : index, %ub : index, %step : index, %if_cond
omp.loop_nest (%iv): index = (%lb) to (%ub) step (%step) {
omp.yield
}
+ omp.terminator
}
return
}
@@ -656,6 +659,7 @@ func.func @omp_simd_pretty_nontemporal(%lb : index, %ub : index, %step : index,
omp.loop_nest (%iv) : index = (%lb) to (%ub) step (%step) {
omp.yield
}
+ omp.terminator
}
return
}
@@ -667,18 +671,21 @@ func.func @omp_simd_pretty_order(%lb : index, %ub : index, %step : index) -> ()
omp.loop_nest (%iv): index = (%lb) to (%ub) step (%step) {
omp.yield
}
+ omp.terminator
}
// CHECK: omp.simd order(reproducible:concurrent)
omp.simd order(reproducible:concurrent) {
omp.loop_nest (%iv): index = (%lb) to (%ub) step (%step) {
omp.yield
}
+ omp.terminator
}
// CHECK: omp.simd order(unconstrained:concurrent)
omp.simd order(unconstrained:concurrent) {
omp.loop_nest (%iv): index = (%lb) to (%ub) step (%step) {
omp.yield
}
+ omp.terminator
}
return
}
@@ -690,6 +697,7 @@ func.func @omp_simd_pretty_simdlen(%lb : index, %ub : index, %step : index) -> (
omp.loop_nest (%iv): index = (%lb) to (%ub) step (%step) {
omp.yield
}
+ omp.terminator
}
return
}
@@ -701,6 +709,7 @@ func.func @omp_simd_pretty_safelen(%lb : index, %ub : index, %step : index) -> (
omp.loop_nest (%iv): index = (%lb) to (%ub) step (%step) {
omp.yield
}
+ omp.terminator
}
return
}
@@ -720,42 +729,49 @@ func.func @omp_distribute(%chunk_size : i32, %data_var : memref<i32>, %arg0 : i3
omp.loop_nest (%iv) : i32 = (%arg0) to (%arg0) step (%arg0) {
omp.yield
}
+ omp.terminator
}
// CHECK: omp.distribute dist_schedule_static
omp.distribute dist_schedule_static {
omp.loop_nest (%iv) : i32 = (%arg0) to (%arg0) step (%arg0) {
omp.yield
}
+ omp.terminator
}
// CHECK: omp.distribute dist_schedule_static chunk_size(%{{.+}} : i32)
omp.distribute dist_schedule_static chunk_size(%chunk_size : i32) {
omp.loop_nest (%iv) : i32 = (%arg0) to (%arg0) step (%arg0) {
omp.yield
}
+ omp.terminator
}
// CHECK: omp.distribute order(concurrent)
omp.distribute order(concurrent) {
omp.loop_nest (%iv) : i32 = (%arg0) to (%arg0) step (%arg0) {
omp.yield
}
+ omp.terminator
}
// CHECK: omp.distribute order(reproducible:concurrent)
omp.distribute order(reproducible:concurrent) {
omp.loop_nest (%iv) : i32 = (%arg0) to (%arg0) step (%arg0) {
omp.yield
}
+ omp.terminator
}
// CHECK: omp.distribute order(unconstrained:concurrent)
omp.distribute order(unconstrained:concurrent) {
omp.loop_nest (%iv) : i32 = (%arg0) to (%arg0) step (%arg0) {
omp.yield
}
+ omp.terminator
}
// CHECK: omp.distribute allocate(%{{.+}} : memref<i32> -> %{{.+}} : memref<i32>)
omp.distribute allocate(%data_var : memref<i32> -> %data_var : memref<i32>) {
omp.loop_nest (%iv) : i32 = (%arg0) to (%arg0) step (%arg0) {
omp.yield
}
+ omp.terminator
}
// CHECK: omp.distribute
omp.distribute {
@@ -763,7 +779,9 @@ func.func @omp_distribute(%chunk_size : i32, %data_var : memref<i32>, %arg0 : i3
omp.loop_nest (%iv) : i32 = (%arg0) to (%arg0) step (%arg0) {
omp.yield
}
+ omp.terminator
}
+ omp.terminator
}
return
}
@@ -2278,6 +2296,7 @@ func.func @omp_taskloop(%lb: i32, %ub: i32, %step: i32) -> () {
// CHECK: omp.yield
omp.yield
}
+ omp.terminator
}
%testbool = "test.bool"() : () -> (i1)
@@ -2288,6 +2307,7 @@ func.func @omp_taskloop(%lb: i32, %ub: i32, %step: i32) -> () {
// CHECK: omp.yield
omp.yield
}
+ omp.terminator
}
// CHECK: omp.taskloop final(%{{[^)]+}}) {
@@ -2296,6 +2316,7 @@ func.func @omp_taskloop(%lb: i32, %ub: i32, %step: i32) -> () {
// CHECK: omp.yield
omp.yield
}
+ omp.terminator
}
// CHECK: omp.taskloop untied {
@@ -2304,6 +2325,7 @@ func.func @omp_taskloop(%lb: i32, %ub: i32, %step: i32) -> () {
// CHECK: omp.yield
omp.yield
}
+ omp.terminator
}
// CHECK: omp.taskloop mergeable {
@@ -2312,6 +2334,7 @@ func.func @omp_taskloop(%lb: i32, %ub: i32, %step: i32) -> () {
// CHECK: omp.yield
omp.yield
}
+ omp.terminator
}
%testf32 = "test.f32"() : () -> (!llvm.ptr)
@@ -2322,6 +2345,7 @@ func.func @omp_taskloop(%lb: i32, %ub: i32, %step: i32) -> () {
// CHECK: omp.yield
omp.yield
}
+ omp.terminator
}
// Checking byref attribute for in_reduction
@@ -2331,6 +2355,7 @@ func.func @omp_taskloop(%lb: i32, %ub: i32, %step: i32) -> () {
// CHECK: omp.yield
omp.yield
}
+ omp.terminator
}
// CHECK: omp.taskloop reduction(byref @add_f32 -> %{{.+}} : !llvm.ptr, @add_f32 -> %{{.+}} : !llvm.ptr) {
@@ -2339,6 +2364,7 @@ func.func @omp_taskloop(%lb: i32, %ub: i32, %step: i32) -> () {
// CHECK: omp.yield
omp.yield
}
+ omp.terminator
}
// check byref attrbute for reduction
@@ -2348,6 +2374,7 @@ func.func @omp_taskloop(%lb: i32, %ub: i32, %step: i32) -> () {
// CHECK: omp.yield
omp.yield
}
+ omp.terminator
}
// CHECK: omp.taskloop in_reduction(@add_f32 -> %{{.+}} : !llvm.ptr) reduction(@add_f32 -> %{{.+}} : !llvm.ptr) {
@@ -2356,6 +2383,7 @@ func.func @omp_taskloop(%lb: i32, %ub: i32, %step: i32) -> () {
// CHECK: omp.yield
omp.yield
}
+ omp.terminator
}
%testi32 = "test.i32"() : () -> (i32)
@@ -2365,6 +2393,7 @@ func.func @omp_taskloop(%lb: i32, %ub: i32, %step: i32) -> () {
// CHECK: omp.yield
omp.yield
}
+ omp.terminator
}
%testmemref = "test.memref"() : () -> (memref<i32>)
@@ -2374,6 +2403,7 @@ func.func @omp_taskloop(%lb: i32, %ub: i32, %step: i32) -> () {
// CHECK: omp.yield
omp.yield
}
+ omp.terminator
}
%testi64 = "test.i64"() : () -> (i64)
@@ -2383,6 +2413,7 @@ func.func @omp_taskloop(%lb: i32, %ub: i32, %step: i32) -> () {
// CHECK: omp.yield
omp.yield
}
+ omp.terminator
}
// CHECK: omp.taskloop num_tasks(%{{[^:]+}}: i64) {
@@ -2391,6 +2422,7 @@ func.func @omp_taskloop(%lb: i32, %ub: i32, %step: i32) -> () {
// CHECK: omp.yield
omp.yield
}
+ omp.terminator
}
// CHECK: omp.taskloop nogroup {
@@ -2399,6 +2431,7 @@ func.func @omp_taskloop(%lb: i32, %ub: i32, %step: i32) -> () {
// CHECK: omp.yield
omp.yield
}
+ omp.terminator
}
// CHECK: omp.taskloop {
@@ -2408,7 +2441,9 @@ func.func @omp_taskloop(%lb: i32, %ub: i32, %step: i32) -> () {
// CHECK: omp.yield
omp.yield
}
+ omp.terminator
}
+ omp.terminator
}
// CHECK: return
diff --git a/mlir/test/Target/LLVMIR/openmp-llvm.mlir b/mlir/test/Target/LLVMIR/openmp-llvm.mlir
index 321de67aa48a1..dfeaf4be33adb 100644
--- a/mlir/test/Target/LLVMIR/openmp-llvm.mlir
+++ b/mlir/test/Target/LLVMIR/openmp-llvm.mlir
@@ -726,6 +726,7 @@ llvm.func @simd_simple_multiple(%lb1 : i64, %ub1 : i64, %step1 : i64, %lb2 : i64
llvm.store %3, %5 : f32, !llvm.ptr
omp.yield
}
+ omp.terminator
}
llvm.return
}
@@ -749,6 +750,7 @@ llvm.func @simd_simple_multiple_simdlen(%lb1 : i64, %ub1 : i64, %step1 : i64, %l
llvm.store %3, %5 : f32, !llvm.ptr
omp.yield
}
+ omp.terminator
}
llvm.return
}
@@ -769,6 +771,7 @@ llvm.func @simd_simple_multiple_safelen(%lb1 : i64, %ub1 : i64, %step1 : i64, %l
llvm.store %3, %5 : f32, !llvm.ptr
omp.yield
}
+ omp.terminator
}
llvm.return
}
@@ -788,6 +791,7 @@ llvm.func @simd_simple_multiple_simdlen_safelen(%lb1 : i64, %ub1 : i64, %step1 :
llvm.store %3, %5 : f32, !llvm.ptr
omp.yield
}
+ omp.terminator
}
llvm.return
}
@@ -816,6 +820,7 @@ llvm.func @simd_if(%arg0: !llvm.ptr {fir.bindc_name = "n"}, %arg1: !llvm.ptr {fi
llvm.store %arg2, %1 : i32, !llvm.ptr
omp.yield
}
+ omp.terminator
}
llvm.return
}
@@ -836,6 +841,7 @@ llvm.func @simd_order() {
llvm.store %arg0, %2 : i64, !llvm.ptr
omp.yield
}
+ omp.terminator
}
llvm.return
}
More information about the flang-commits
mailing list