[flang-commits] [flang] [mlir] [MLIR][OpenMP] Improve loop wrapper representation (PR #97706)

Sergio Afonso via flang-commits flang-commits at lists.llvm.org
Thu Jul 4 03:15:01 PDT 2024


https://github.com/skatrak created https://github.com/llvm/llvm-project/pull/97706

This patch replaces the `SingleBlockImplicitTerminator<"TerminatorOp">` trait of loop wrapper operations with the `SingleBlock` trait. This enables a more robust implementation of the `LoopWrapperInterface::isWrapper()` method, since it no longer has to deal with the potentially missing (implicit) terminator.

The `LoopWrapperInterface::isWrapper()` method is also extended so that it does not identify as wrappers those operations containing a nested loop wrapper operation that is not itself taking a wrapper role. This is important for cases where `omp.parallel` is nested, since it can, but is not required to, act as a loop wrapper.

Tests are updated to reflect these representation and validation changes.

>From 0e48712dfcb823032c8980911924f7288f5fa6d6 Mon Sep 17 00:00:00 2001
From: Sergio Afonso <safonsof at amd.com>
Date: Thu, 4 Jul 2024 10:56:43 +0100
Subject: [PATCH] [MLIR][OpenMP] Improve loop wrapper representation

This patch replaces the `SingleBlockImplicitTerminator<"TerminatorOp">` trait
of loop wrapper operations with the `SingleBlock` trait. This enables a more
robust implementation of the `LoopWrapperInterface::isWrapper()` method, since
it no longer has to deal with the potentially missing (implicit) terminator.

The `LoopWrapperInterface::isWrapper()` method is also extended so that it does
not identify as wrappers those operations containing a nested loop wrapper
operation that is not itself taking a wrapper role. This is important for cases
where `omp.parallel` is nested, since it can, but is not required to, act as a
loop wrapper.

Tests are updated to reflect these representation and validation changes.
---
 .../Fir/convert-to-llvm-openmp-and-fir.fir    |  4 +++
 mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td |  8 ++---
 .../Dialect/OpenMP/OpenMPOpsInterfaces.td     | 14 +++++---
 .../OpenMPToLLVM/convert-to-llvmir.mlir       |  1 +
 mlir/test/Dialect/OpenMP/invalid.mlir         | 15 +++++---
 mlir/test/Dialect/OpenMP/ops.mlir             | 35 +++++++++++++++++++
 mlir/test/Target/LLVMIR/openmp-llvm.mlir      |  6 ++++
 7 files changed, 71 insertions(+), 12 deletions(-)

diff --git a/flang/test/Fir/convert-to-llvm-openmp-and-fir.fir b/flang/test/Fir/convert-to-llvm-openmp-and-fir.fir
index 8b62787bb3094..eca762d52a724 100644
--- a/flang/test/Fir/convert-to-llvm-openmp-and-fir.fir
+++ b/flang/test/Fir/convert-to-llvm-openmp-and-fir.fir
@@ -200,6 +200,7 @@ func.func @_QPsimd1(%arg0: !fir.ref<i32> {fir.bindc_name = "n"}, %arg1: !fir.ref
         fir.store %3 to %6 : !fir.ref<i32>
         omp.yield
       }
+      omp.terminator
     }
     omp.terminator
   }
@@ -225,6 +226,7 @@ func.func @_QPsimd1(%arg0: !fir.ref<i32> {fir.bindc_name = "n"}, %arg1: !fir.ref
 // CHECK:   llvm.store %[[I1]], %[[ARR_I_REF]] : i32, !llvm.ptr
 // CHECK: omp.yield
 // CHECK: }
+// CHECK: omp.terminator
 // CHECK: }
 // CHECK: omp.terminator
 // CHECK: }
@@ -518,6 +520,7 @@ func.func @_QPsimd_with_nested_loop() {
       fir.store %7 to %3 : !fir.ref<i32>
       omp.yield
     }
+    omp.terminator
   }
   return
 }
@@ -538,6 +541,7 @@ func.func @_QPsimd_with_nested_loop() {
 // CHECK:             ^bb3:
 // CHECK:               omp.yield
 // CHECK:             }
+// CHECK:             omp.terminator
 // CHECK:           }
 // CHECK:           llvm.return
 // CHECK:         }
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
index 99e14cd1b7b48..aed0d69619db2 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -354,7 +354,7 @@ def LoopNestOp : OpenMP_Op<"loop_nest", traits = [
 
 def WsloopOp : OpenMP_Op<"wsloop", traits = [
     AttrSizedOperandSegments, DeclareOpInterfaceMethods<LoopWrapperInterface>,
-    RecursiveMemoryEffects, SingleBlockImplicitTerminator<"TerminatorOp">
+    RecursiveMemoryEffects, SingleBlock
   ], clauses = [
     // TODO: Complete clause list (allocate, private).
     // TODO: Sort clauses alphabetically.
@@ -418,7 +418,7 @@ def WsloopOp : OpenMP_Op<"wsloop", traits = [
 
 def SimdOp : OpenMP_Op<"simd", traits = [
     AttrSizedOperandSegments, DeclareOpInterfaceMethods<LoopWrapperInterface>,
-    RecursiveMemoryEffects, SingleBlockImplicitTerminator<"TerminatorOp">
+    RecursiveMemoryEffects, SingleBlock
   ], clauses = [
     // TODO: Complete clause list (linear, private, reduction).
     OpenMP_AlignedClause, OpenMP_IfClause, OpenMP_NontemporalClause,
@@ -485,7 +485,7 @@ def YieldOp : OpenMP_Op<"yield",
 //===----------------------------------------------------------------------===//
 def DistributeOp : OpenMP_Op<"distribute", traits = [
     AttrSizedOperandSegments, DeclareOpInterfaceMethods<LoopWrapperInterface>,
-    RecursiveMemoryEffects, SingleBlockImplicitTerminator<"TerminatorOp">
+    RecursiveMemoryEffects, SingleBlock
   ], clauses = [
     // TODO: Complete clause list (private).
     // TODO: Sort clauses alphabetically.
@@ -575,7 +575,7 @@ def TaskOp : OpenMP_Op<"task", traits = [
 def TaskloopOp : OpenMP_Op<"taskloop", traits = [
     AttrSizedOperandSegments, AutomaticAllocationScope,
     DeclareOpInterfaceMethods<LoopWrapperInterface>, RecursiveMemoryEffects,
-    SingleBlockImplicitTerminator<"TerminatorOp">
+    SingleBlock
   ], clauses = [
     // TODO: Complete clause list (private).
     // TODO: Sort clauses alphabetically.
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td
index 31a306072d0ec..385aa8b1b016a 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td
@@ -84,8 +84,8 @@ def LoopWrapperInterface : OpInterface<"LoopWrapperInterface"> {
       /*description=*/[{
         Tell whether the operation could be taking the role of a loop wrapper.
         That is, it has a single region with a single block in which there are
-        two operations: another wrapper or `omp.loop_nest` operation and a
-        terminator.
+        two operations: another wrapper (also taking a loop wrapper role) or
+        `omp.loop_nest` operation and a terminator.
       }],
       /*retTy=*/"bool",
       /*methodName=*/"isWrapper",
@@ -102,8 +102,14 @@ def LoopWrapperInterface : OpInterface<"LoopWrapperInterface"> {
 
         Operation &firstOp = *r.op_begin();
         Operation &secondOp = *(std::next(r.op_begin()));
-        return ::llvm::isa<LoopNestOp, LoopWrapperInterface>(firstOp) &&
-               secondOp.hasTrait<OpTrait::IsTerminator>();
+
+        if (!secondOp.hasTrait<OpTrait::IsTerminator>())
+          return false;
+
+        if (auto wrapper = ::llvm::dyn_cast<LoopWrapperInterface>(firstOp))
+          return wrapper.isWrapper();
+
+        return ::llvm::isa<LoopNestOp>(firstOp);
       }]
     >,
     InterfaceMethod<
diff --git a/mlir/test/Conversion/OpenMPToLLVM/convert-to-llvmir.mlir b/mlir/test/Conversion/OpenMPToLLVM/convert-to-llvmir.mlir
index 3aeb9e70522d5..4c9e09970279a 100644
--- a/mlir/test/Conversion/OpenMPToLLVM/convert-to-llvmir.mlir
+++ b/mlir/test/Conversion/OpenMPToLLVM/convert-to-llvmir.mlir
@@ -174,6 +174,7 @@ func.func @loop_nest_block_arg(%val : i32, %ub : i32, %i : index) {
     ^bb3:
       omp.yield
     }
+    omp.terminator
   }
   return
 }
diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir
index 2915963f704d3..91eeb0911160d 100644
--- a/mlir/test/Dialect/OpenMP/invalid.mlir
+++ b/mlir/test/Dialect/OpenMP/invalid.mlir
@@ -11,8 +11,8 @@ func.func @unknown_clause() {
 // -----
 
 func.func @not_wrapper() {
+  // expected-error@+1 {{op must be a loop wrapper}}
   omp.distribute {
-    // expected-error@+1 {{op must take a loop wrapper role if nested inside of 'omp.distribute'}}
     omp.parallel {
       %0 = arith.constant 0 : i32
       omp.terminator
@@ -383,12 +383,16 @@ func.func @omp_simd() -> () {
 
 // -----
 
-func.func @omp_simd_nested_wrapper() -> () {
+func.func @omp_simd_nested_wrapper(%lb : index, %ub : index, %step : index) -> () {
   // expected-error @below {{op must wrap an 'omp.loop_nest' directly}}
   omp.simd {
     omp.distribute {
+      omp.loop_nest (%iv) : index = (%lb) to (%ub) step (%step) {
+        omp.yield
+      }
       omp.terminator
     }
+    omp.terminator
   }
   return
 }
@@ -1960,6 +1964,7 @@ func.func @taskloop(%lb: i32, %ub: i32, %step: i32) {
       }
       omp.terminator
     }
+    omp.terminator
   }
   return
 }
@@ -2158,11 +2163,13 @@ func.func @omp_distribute_wrapper() -> () {
 
 // -----
 
-func.func @omp_distribute_nested_wrapper(%data_var : memref<i32>) -> () {
+func.func @omp_distribute_nested_wrapper(%lb: index, %ub: index, %step: index) -> () {
   // expected-error @below {{only supported nested wrappers are 'omp.parallel' and 'omp.simd'}}
   omp.distribute {
     "omp.wsloop"() ({
-      %0 = arith.constant 0 : i32
+      omp.loop_nest (%iv) : index = (%lb) to (%ub) step (%step) {
+        "omp.yield"() : () -> ()
+      }
       "omp.terminator"() : () -> ()
     }) : () -> ()
     "omp.terminator"() : () -> ()
diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir
index eb283840aa7ee..ff3b1e60f7cfe 100644
--- a/mlir/test/Dialect/OpenMP/ops.mlir
+++ b/mlir/test/Dialect/OpenMP/ops.mlir
@@ -617,6 +617,7 @@ func.func @omp_simd_pretty(%lb : index, %ub : index, %step : index) -> () {
     omp.loop_nest (%iv) : index = (%lb) to (%ub) step (%step) {
       omp.yield
     }
+    omp.terminator
   }
   return
 }
@@ -632,6 +633,7 @@ func.func @omp_simd_pretty_aligned(%lb : index, %ub : index, %step : index,
     omp.loop_nest (%iv) : index = (%lb) to (%ub) step (%step) {
       omp.yield
     }
+    omp.terminator
   }
   return
 }
@@ -643,6 +645,7 @@ func.func @omp_simd_pretty_if(%lb : index, %ub : index, %step : index, %if_cond
     omp.loop_nest (%iv): index = (%lb) to (%ub) step (%step) {
       omp.yield
     }
+    omp.terminator
   }
   return
 }
@@ -656,6 +659,7 @@ func.func @omp_simd_pretty_nontemporal(%lb : index, %ub : index, %step : index,
     omp.loop_nest (%iv) : index = (%lb) to (%ub) step (%step) {
       omp.yield
     }
+    omp.terminator
   }
   return
 }
@@ -667,18 +671,21 @@ func.func @omp_simd_pretty_order(%lb : index, %ub : index, %step : index) -> ()
     omp.loop_nest (%iv): index = (%lb) to (%ub) step (%step) {
       omp.yield
     }
+    omp.terminator
   }
   // CHECK: omp.simd order(reproducible:concurrent)
   omp.simd order(reproducible:concurrent) {
     omp.loop_nest (%iv): index = (%lb) to (%ub) step (%step) {
       omp.yield
     }
+    omp.terminator
   }
   // CHECK: omp.simd order(unconstrained:concurrent)
   omp.simd order(unconstrained:concurrent) {
     omp.loop_nest (%iv): index = (%lb) to (%ub) step (%step) {
       omp.yield
     }
+    omp.terminator
   }
   return
 }
@@ -690,6 +697,7 @@ func.func @omp_simd_pretty_simdlen(%lb : index, %ub : index, %step : index) -> (
     omp.loop_nest (%iv): index = (%lb) to (%ub) step (%step) {
       omp.yield
     }
+    omp.terminator
   }
   return
 }
@@ -701,6 +709,7 @@ func.func @omp_simd_pretty_safelen(%lb : index, %ub : index, %step : index) -> (
     omp.loop_nest (%iv): index = (%lb) to (%ub) step (%step) {
       omp.yield
     }
+    omp.terminator
   }
   return
 }
@@ -720,42 +729,49 @@ func.func @omp_distribute(%chunk_size : i32, %data_var : memref<i32>, %arg0 : i3
     omp.loop_nest (%iv) : i32 = (%arg0) to (%arg0) step (%arg0) {
       omp.yield
     }
+    omp.terminator
   }
   // CHECK: omp.distribute dist_schedule_static
   omp.distribute dist_schedule_static {
     omp.loop_nest (%iv) : i32 = (%arg0) to (%arg0) step (%arg0) {
       omp.yield
     }
+    omp.terminator
   }
   // CHECK: omp.distribute dist_schedule_static chunk_size(%{{.+}} : i32)
   omp.distribute dist_schedule_static chunk_size(%chunk_size : i32) {
     omp.loop_nest (%iv) : i32 = (%arg0) to (%arg0) step (%arg0) {
       omp.yield
     }
+    omp.terminator
   }
   // CHECK: omp.distribute order(concurrent)
   omp.distribute order(concurrent) {
     omp.loop_nest (%iv) : i32 = (%arg0) to (%arg0) step (%arg0) {
       omp.yield
     }
+    omp.terminator
   }
   // CHECK: omp.distribute order(reproducible:concurrent)
   omp.distribute order(reproducible:concurrent) {
     omp.loop_nest (%iv) : i32 = (%arg0) to (%arg0) step (%arg0) {
       omp.yield
     }
+    omp.terminator
   }
   // CHECK: omp.distribute order(unconstrained:concurrent)
   omp.distribute order(unconstrained:concurrent) {
     omp.loop_nest (%iv) : i32 = (%arg0) to (%arg0) step (%arg0) {
       omp.yield
     }
+    omp.terminator
   }
   // CHECK: omp.distribute allocate(%{{.+}} : memref<i32> -> %{{.+}} : memref<i32>)
   omp.distribute allocate(%data_var : memref<i32> -> %data_var : memref<i32>) {
     omp.loop_nest (%iv) : i32 = (%arg0) to (%arg0) step (%arg0) {
       omp.yield
     }
+    omp.terminator
   }
   // CHECK: omp.distribute
   omp.distribute {
@@ -763,7 +779,9 @@ func.func @omp_distribute(%chunk_size : i32, %data_var : memref<i32>, %arg0 : i3
       omp.loop_nest (%iv) : i32 = (%arg0) to (%arg0) step (%arg0) {
         omp.yield
       }
+      omp.terminator
     }
+    omp.terminator
   }
   return
 }
@@ -2278,6 +2296,7 @@ func.func @omp_taskloop(%lb: i32, %ub: i32, %step: i32) -> () {
       // CHECK: omp.yield
       omp.yield
     }
+    omp.terminator
   }
 
   %testbool = "test.bool"() : () -> (i1)
@@ -2288,6 +2307,7 @@ func.func @omp_taskloop(%lb: i32, %ub: i32, %step: i32) -> () {
       // CHECK: omp.yield
       omp.yield
     }
+    omp.terminator
   }
 
   // CHECK: omp.taskloop final(%{{[^)]+}}) {
@@ -2296,6 +2316,7 @@ func.func @omp_taskloop(%lb: i32, %ub: i32, %step: i32) -> () {
       // CHECK: omp.yield
       omp.yield
     }
+    omp.terminator
   }
 
   // CHECK: omp.taskloop untied {
@@ -2304,6 +2325,7 @@ func.func @omp_taskloop(%lb: i32, %ub: i32, %step: i32) -> () {
       // CHECK: omp.yield
       omp.yield
     }
+    omp.terminator
   }
 
   // CHECK: omp.taskloop mergeable {
@@ -2312,6 +2334,7 @@ func.func @omp_taskloop(%lb: i32, %ub: i32, %step: i32) -> () {
       // CHECK: omp.yield
       omp.yield
     }
+    omp.terminator
   }
 
   %testf32 = "test.f32"() : () -> (!llvm.ptr)
@@ -2322,6 +2345,7 @@ func.func @omp_taskloop(%lb: i32, %ub: i32, %step: i32) -> () {
       // CHECK: omp.yield
       omp.yield
     }
+    omp.terminator
   }
 
   // Checking byref attribute for in_reduction
@@ -2331,6 +2355,7 @@ func.func @omp_taskloop(%lb: i32, %ub: i32, %step: i32) -> () {
       // CHECK: omp.yield
       omp.yield
     }
+    omp.terminator
   }
 
   // CHECK: omp.taskloop reduction(byref @add_f32 -> %{{.+}} : !llvm.ptr, @add_f32 -> %{{.+}} : !llvm.ptr) {
@@ -2339,6 +2364,7 @@ func.func @omp_taskloop(%lb: i32, %ub: i32, %step: i32) -> () {
       // CHECK: omp.yield
       omp.yield
     }
+    omp.terminator
   }
 
   // check byref attrbute for reduction
@@ -2348,6 +2374,7 @@ func.func @omp_taskloop(%lb: i32, %ub: i32, %step: i32) -> () {
       // CHECK: omp.yield
       omp.yield
     }
+    omp.terminator
   }
 
   // CHECK: omp.taskloop in_reduction(@add_f32 -> %{{.+}} : !llvm.ptr) reduction(@add_f32 -> %{{.+}} : !llvm.ptr) {
@@ -2356,6 +2383,7 @@ func.func @omp_taskloop(%lb: i32, %ub: i32, %step: i32) -> () {
       // CHECK: omp.yield
       omp.yield
     }
+    omp.terminator
   }
 
   %testi32 = "test.i32"() : () -> (i32)
@@ -2365,6 +2393,7 @@ func.func @omp_taskloop(%lb: i32, %ub: i32, %step: i32) -> () {
       // CHECK: omp.yield
       omp.yield
     }
+    omp.terminator
   }
 
   %testmemref = "test.memref"() : () -> (memref<i32>)
@@ -2374,6 +2403,7 @@ func.func @omp_taskloop(%lb: i32, %ub: i32, %step: i32) -> () {
       // CHECK: omp.yield
       omp.yield
     }
+    omp.terminator
   }
 
   %testi64 = "test.i64"() : () -> (i64)
@@ -2383,6 +2413,7 @@ func.func @omp_taskloop(%lb: i32, %ub: i32, %step: i32) -> () {
       // CHECK: omp.yield
       omp.yield
     }
+    omp.terminator
   }
 
   // CHECK: omp.taskloop num_tasks(%{{[^:]+}}: i64) {
@@ -2391,6 +2422,7 @@ func.func @omp_taskloop(%lb: i32, %ub: i32, %step: i32) -> () {
       // CHECK: omp.yield
       omp.yield
     }
+    omp.terminator
   }
 
   // CHECK: omp.taskloop nogroup {
@@ -2399,6 +2431,7 @@ func.func @omp_taskloop(%lb: i32, %ub: i32, %step: i32) -> () {
       // CHECK: omp.yield
       omp.yield
     }
+    omp.terminator
   }
 
   // CHECK: omp.taskloop {
@@ -2408,7 +2441,9 @@ func.func @omp_taskloop(%lb: i32, %ub: i32, %step: i32) -> () {
         // CHECK: omp.yield
         omp.yield
       }
+      omp.terminator
     }
+    omp.terminator
   }
 
   // CHECK: return
diff --git a/mlir/test/Target/LLVMIR/openmp-llvm.mlir b/mlir/test/Target/LLVMIR/openmp-llvm.mlir
index 321de67aa48a1..dfeaf4be33adb 100644
--- a/mlir/test/Target/LLVMIR/openmp-llvm.mlir
+++ b/mlir/test/Target/LLVMIR/openmp-llvm.mlir
@@ -726,6 +726,7 @@ llvm.func @simd_simple_multiple(%lb1 : i64, %ub1 : i64, %step1 : i64, %lb2 : i64
       llvm.store %3, %5 : f32, !llvm.ptr
       omp.yield
     }
+    omp.terminator
   }
   llvm.return
 }
@@ -749,6 +750,7 @@ llvm.func @simd_simple_multiple_simdlen(%lb1 : i64, %ub1 : i64, %step1 : i64, %l
       llvm.store %3, %5 : f32, !llvm.ptr
       omp.yield
     }
+    omp.terminator
   }
   llvm.return
 }
@@ -769,6 +771,7 @@ llvm.func @simd_simple_multiple_safelen(%lb1 : i64, %ub1 : i64, %step1 : i64, %l
       llvm.store %3, %5 : f32, !llvm.ptr
       omp.yield
     }
+    omp.terminator
   }
   llvm.return
 }
@@ -788,6 +791,7 @@ llvm.func @simd_simple_multiple_simdlen_safelen(%lb1 : i64, %ub1 : i64, %step1 :
       llvm.store %3, %5 : f32, !llvm.ptr
       omp.yield
     }
+    omp.terminator
   }
   llvm.return
 }
@@ -816,6 +820,7 @@ llvm.func @simd_if(%arg0: !llvm.ptr {fir.bindc_name = "n"}, %arg1: !llvm.ptr {fi
       llvm.store %arg2, %1 : i32, !llvm.ptr
       omp.yield
     }
+    omp.terminator
   }
   llvm.return
 }
@@ -836,6 +841,7 @@ llvm.func @simd_order() {
       llvm.store %arg0, %2 : i64, !llvm.ptr
       omp.yield
     }
+    omp.terminator
   }
   llvm.return
 }



More information about the flang-commits mailing list