[flang-commits] [flang] [flang] Implement !DIR$ UNROLL [N] (PR #123331)

Jean-Didier PAILLEUX via flang-commits flang-commits at lists.llvm.org
Mon Jan 20 02:15:04 PST 2025


https://github.com/JDPailleux updated https://github.com/llvm/llvm-project/pull/123331

>From 7d0cf854c0c55962385602aeaadc10db57d6808e Mon Sep 17 00:00:00 2001
From: Jean-Didier PAILLEUX <jean-didier.pailleux at sipearl.com>
Date: Mon, 20 Jan 2025 11:24:07 +0100
Subject: [PATCH] [flang] Implement !DIR$ UNROLL [N]

---
 flang/include/flang/Parser/dump-parse-tree.h  |  1 +
 flang/include/flang/Parser/parse-tree.h       |  5 +-
 flang/lib/Lower/Bridge.cpp                    | 52 ++++++++++++++-----
 flang/lib/Parser/Fortran-parsers.cpp          |  4 ++
 flang/lib/Parser/unparse.cpp                  |  4 ++
 .../lib/Semantics/canonicalize-directives.cpp |  7 ++-
 flang/lib/Semantics/resolve-names.cpp         |  3 +-
 flang/test/Integration/unroll.f90             | 16 ++++++
 flang/test/Lower/unroll.f90                   | 27 ++++++++++
 flang/test/Parser/compiler-directives.f90     | 11 ++++
 flang/test/Semantics/loop-directives.f90      | 16 ++++++
 11 files changed, 129 insertions(+), 17 deletions(-)
 create mode 100644 flang/test/Integration/unroll.f90
 create mode 100644 flang/test/Lower/unroll.f90

diff --git a/flang/include/flang/Parser/dump-parse-tree.h b/flang/include/flang/Parser/dump-parse-tree.h
index 11725991e9c9a9..8ed44f93a4e8f3 100644
--- a/flang/include/flang/Parser/dump-parse-tree.h
+++ b/flang/include/flang/Parser/dump-parse-tree.h
@@ -208,6 +208,7 @@ class ParseTreeDumper {
   NODE(CompilerDirective, NameValue)
   NODE(CompilerDirective, Unrecognized)
+  NODE(CompilerDirective, Unroll)
   NODE(CompilerDirective, VectorAlways)
   NODE(parser, ComplexLiteralConstant)
   NODE(parser, ComplexPart)
   NODE(parser, ComponentArraySpec)
diff --git a/flang/include/flang/Parser/parse-tree.h b/flang/include/flang/Parser/parse-tree.h
index 00d85aa05fb3a5..2a8020a5f9525f 100644
--- a/flang/include/flang/Parser/parse-tree.h
+++ b/flang/include/flang/Parser/parse-tree.h
@@ -3368,10 +3368,14 @@ struct CompilerDirective {
     TUPLE_CLASS_BOILERPLATE(NameValue);
     std::tuple<Name, std::optional<std::uint64_t>> t;
   };
+  // !DIR$ UNROLL [n], where n is an optional unroll factor
+  struct Unroll {
+    WRAPPER_CLASS_BOILERPLATE(Unroll, std::optional<std::uint64_t>);
+  };
   EMPTY_CLASS(Unrecognized);
   CharBlock source;
   std::variant<std::list<IgnoreTKR>, LoopCount, std::list<AssumeAligned>,
-      VectorAlways, std::list<NameValue>, Unrecognized>
+      VectorAlways, std::list<NameValue>, Unroll, Unrecognized>
       u;
 };
 
diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp
index 37f51d74d23f8f..91306f70118bfd 100644
--- a/flang/lib/Lower/Bridge.cpp
+++ b/flang/lib/Lower/Bridge.cpp
@@ -2153,14 +2153,64 @@ class FirConverter : public Fortran::lower::AbstractConverter {
     return builder->createIntegerConstant(loc, controlType, 1); // step
   }

-  void addLoopAnnotationAttr(IncrementLoopInfo &info) {
-    mlir::BoolAttr f = mlir::BoolAttr::get(builder->getContext(), false);
-    mlir::LLVM::LoopVectorizeAttr va = mlir::LLVM::LoopVectorizeAttr::get(
-        builder->getContext(), /*disable=*/f, {}, {}, {}, {}, {}, {});
-    mlir::LLVM::LoopAnnotationAttr la = mlir::LLVM::LoopAnnotationAttr::get(
-        builder->getContext(), {}, /*vectorize=*/va, {}, {}, {}, {}, {}, {}, {},
-        {}, {}, {}, {}, {}, {});
-    info.doLoop.setLoopAnnotationAttr(la);
+  /// Build the loop-unroll attribute for `!DIR$ UNROLL [n]`.
+  /// Without `n`, a full unroll is requested. An explicit factor of 0 or 1
+  /// disables unrolling; any other factor is passed on as the unroll count.
+  mlir::LLVM::LoopUnrollAttr
+  genLoopUnrollAttr(std::optional<std::uint64_t> directiveArg) {
+    mlir::BoolAttr falseAttr =
+        mlir::BoolAttr::get(builder->getContext(), false);
+    mlir::BoolAttr trueAttr = mlir::BoolAttr::get(builder->getContext(), true);
+    mlir::IntegerAttr countAttr;
+    bool shouldUnroll = true;
+    if (directiveArg.has_value()) {
+      std::uint64_t unrollingFactor = *directiveArg;
+      if (unrollingFactor == 0 || unrollingFactor == 1)
+        shouldUnroll = false;
+      else
+        countAttr =
+            builder->getIntegerAttr(builder->getI64Type(), unrollingFactor);
+    }
+    return mlir::LLVM::LoopUnrollAttr::get(
+        builder->getContext(), /*disable=*/shouldUnroll ? falseAttr : trueAttr,
+        /*count=*/countAttr, {},
+        /*full=*/directiveArg.has_value() ? falseAttr : trueAttr, {}, {}, {});
+  }
+
+  /// Attach an LLVM loop annotation to `info.doLoop` built from the
+  /// VECTOR ALWAYS and UNROLL directives in `dirs`. Other directives are
+  /// ignored, and the loop is left unannotated when neither is present.
+  void addLoopAnnotationAttr(
+      IncrementLoopInfo &info,
+      llvm::ArrayRef<const Fortran::parser::CompilerDirective *> dirs) {
+    mlir::BoolAttr falseAttr =
+        mlir::BoolAttr::get(builder->getContext(), false);
+    mlir::LLVM::LoopVectorizeAttr vectorizeAttr;
+    mlir::LLVM::LoopUnrollAttr unrollAttr;
+    bool hasAttrs = false;
+    for (const auto *dir : dirs) {
+      Fortran::common::visit(
+          Fortran::common::visitors{
+              [&](const Fortran::parser::CompilerDirective::VectorAlways &) {
+                vectorizeAttr = mlir::LLVM::LoopVectorizeAttr::get(
+                    builder->getContext(), /*disable=*/falseAttr, {}, {}, {},
+                    {}, {}, {});
+                hasAttrs = true;
+              },
+              [&](const Fortran::parser::CompilerDirective::Unroll &u) {
+                unrollAttr = genLoopUnrollAttr(u.v);
+                hasAttrs = true;
+              },
+              [&](const auto &) {}},
+          dir->u);
+    }
+    if (!hasAttrs)
+      return;
+    mlir::LLVM::LoopAnnotationAttr annotationAttr =
+        mlir::LLVM::LoopAnnotationAttr::get(
+            builder->getContext(), {}, /*vectorize=*/vectorizeAttr, {},
+            /*unroll=*/unrollAttr, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {});
+    info.doLoop.setLoopAnnotationAttr(annotationAttr);
   }

   /// Generate FIR to begin a structured or unstructured increment loop nest.
@@ -2259,14 +2309,7 @@ class FirConverter : public Fortran::lower::AbstractConverter {
         if (info.hasLocalitySpecs())
           handleLocalitySpecs(info);

-        for (const auto *dir : dirs) {
-          Fortran::common::visit(
-              Fortran::common::visitors{
-                  [&](const Fortran::parser::CompilerDirective::VectorAlways
-                          &d) { addLoopAnnotationAttr(info); },
-                  [&](const auto &) {}},
-              dir->u);
-        }
+        addLoopAnnotationAttr(info, dirs);
         continue;
       }

@@ -2818,6 +2861,9 @@ class FirConverter : public Fortran::lower::AbstractConverter {
             [&](const Fortran::parser::CompilerDirective::VectorAlways &) {
               attachDirectiveToLoop(dir, &eval);
             },
+            [&](const Fortran::parser::CompilerDirective::Unroll &) {
+              attachDirectiveToLoop(dir, &eval);
+            },
             [&](const auto &) {}},
         dir.u);
   }
diff --git a/flang/lib/Parser/Fortran-parsers.cpp b/flang/lib/Parser/Fortran-parsers.cpp
index 7cb35c1f173bb6..b5bcb53a127613 100644
--- a/flang/lib/Parser/Fortran-parsers.cpp
+++ b/flang/lib/Parser/Fortran-parsers.cpp
@@ -1293,6 +1293,7 @@ TYPE_PARSER(construct<StatOrErrmsg>("STAT =" >> statVariable) ||
 // !DIR$ IGNORE_TKR [ [(tkrdmac...)] name ]...
 // !DIR$ LOOP COUNT (n1[, n2]...)
 // !DIR$ name[=value] [, name[=value]]...
+// !DIR$ UNROLL [n]
 // !DIR$ <anything else>
 constexpr auto ignore_tkr{
     "IGNORE_TKR" >> optionalList(construct<CompilerDirective::IgnoreTKR>(
@@ -1305,11 +1306,14 @@ constexpr auto assumeAligned{"ASSUME_ALIGNED" >>
         indirect(designator), ":"_tok >> digitString64))};
 constexpr auto vectorAlways{
     "VECTOR ALWAYS" >> construct<CompilerDirective::VectorAlways>()};
+constexpr auto unroll{
+    "UNROLL" >> construct<CompilerDirective::Unroll>(maybe(digitString64))};
 TYPE_PARSER(beginDirective >> "DIR$ "_tok >>
     sourced((construct<CompilerDirective>(ignore_tkr) ||
                 construct<CompilerDirective>(loopCount) ||
                 construct<CompilerDirective>(assumeAligned) ||
                 construct<CompilerDirective>(vectorAlways) ||
+                construct<CompilerDirective>(unroll) ||
                 construct<CompilerDirective>(
                     many(construct<CompilerDirective::NameValue>(
                         name, maybe(("="_tok || ":"_tok) >> digitString64))))) /
diff --git a/flang/lib/Parser/unparse.cpp b/flang/lib/Parser/unparse.cpp
index 7bf404bba2c3e4..8c70564de16650 100644
--- a/flang/lib/Parser/unparse.cpp
+++ b/flang/lib/Parser/unparse.cpp
@@ -1847,6 +1847,10 @@ class UnparseVisitor {
             [&](const std::list<CompilerDirective::NameValue> &names) {
               Walk("!DIR$ ", names, " ");
             },
+            [&](const CompilerDirective::Unroll &unrollDir) {
+              Word("!DIR$ UNROLL");
+              Walk(" ", unrollDir.v);
+            },
             [&](const CompilerDirective::Unrecognized &) {
               Word("!DIR$ ");
               Word(x.source.ToString());
diff --git a/flang/lib/Semantics/canonicalize-directives.cpp b/flang/lib/Semantics/canonicalize-directives.cpp
index 739bc3c1992ba6..b27a27618808bc 100644
--- a/flang/lib/Semantics/canonicalize-directives.cpp
+++ b/flang/lib/Semantics/canonicalize-directives.cpp
@@ -54,7 +54,9 @@ bool CanonicalizeDirectives(
 }

 static bool IsExecutionDirective(const parser::CompilerDirective &dir) {
-  return std::holds_alternative<parser::CompilerDirective::VectorAlways>(dir.u);
+  using parser::CompilerDirective;
+  return std::holds_alternative<CompilerDirective::VectorAlways>(dir.u) ||
+      std::holds_alternative<CompilerDirective::Unroll>(dir.u);
 }

 void CanonicalizationOfDirectives::Post(parser::SpecificationPart &spec) {
@@ -110,6 +112,9 @@ void CanonicalizationOfDirectives::Post(parser::Block &block) {
           common::visitors{[&](parser::CompilerDirective::VectorAlways &) {
                              CheckLoopDirective(*dir, block, it);
                            },
+              [&](parser::CompilerDirective::Unroll &) {
+                CheckLoopDirective(*dir, block, it);
+              },
               [&](auto &) {}},
           dir->u);
       }
diff --git a/flang/lib/Semantics/resolve-names.cpp b/flang/lib/Semantics/resolve-names.cpp
index f3c2a5bf094d04..705a6e0df04d99 100644
--- a/flang/lib/Semantics/resolve-names.cpp
+++ b/flang/lib/Semantics/resolve-names.cpp
@@ -9245,7 +9245,9 @@ void ResolveNamesVisitor::Post(const parser::AssignedGotoStmt &x) {
 }

 void ResolveNamesVisitor::Post(const parser::CompilerDirective &x) {
-  if (std::holds_alternative<parser::CompilerDirective::VectorAlways>(x.u)) {
+  // VECTOR ALWAYS and UNROLL carry no names, so there is nothing to resolve.
+  if (std::holds_alternative<parser::CompilerDirective::VectorAlways>(x.u) ||
+      std::holds_alternative<parser::CompilerDirective::Unroll>(x.u)) {
     return;
   }
   if (const auto *tkr{
diff --git a/flang/test/Integration/unroll.f90 b/flang/test/Integration/unroll.f90
new file mode 100644
index 00000000000000..9d69605e10d1b3
--- /dev/null
+++ b/flang/test/Integration/unroll.f90
@@ -0,0 +1,16 @@
+! RUN: %flang_fc1 -emit-llvm -o - %s | FileCheck %s
+
+! CHECK-LABEL: unroll_dir
+subroutine unroll_dir
+  integer :: a(10)
+  !dir$ unroll
+  ! CHECK:   br i1 {{.*}}, label {{.*}}, label {{.*}}, !llvm.loop ![[ANNOTATION:.*]]
+  do i=1,10
+     a(i)=i
+  end do
+end subroutine unroll_dir
+
+! CHECK: ![[ANNOTATION]] = distinct !{![[ANNOTATION]], ![[UNROLL:.*]], ![[UNROLL_FULL:.*]]}
+! CHECK: ![[UNROLL]] = !{!"llvm.loop.unroll.enable"}
+! CHECK: ![[UNROLL_FULL]] = !{!"llvm.loop.unroll.full"}
+
diff --git a/flang/test/Lower/unroll.f90 b/flang/test/Lower/unroll.f90
new file mode 100644
index 00000000000000..229755200fd8d8
--- /dev/null
+++ b/flang/test/Lower/unroll.f90
@@ -0,0 +1,27 @@
+! RUN: %flang_fc1 -emit-hlfir -o - %s | FileCheck %s
+
+! CHECK: #loop_unroll = #llvm.loop_unroll<disable = false, full = true>
+! CHECK: #loop_annotation = #llvm.loop_annotation<unroll = #loop_unroll>
+
+! CHECK-LABEL: unroll_dir
+subroutine unroll_dir
+  integer :: a(10)
+  !dir$ unroll
+  !CHECK: fir.do_loop {{.*}} attributes {loopAnnotation = #loop_annotation}
+  do i=1,10
+     a(i)=i
+  end do
+end subroutine unroll_dir
+
+
+! CHECK-LABEL: intermediate_directive
+subroutine intermediate_directive
+  integer :: a(10)
+  !dir$ unroll
+  !dir$ unknown
+  !CHECK: fir.do_loop {{.*}} attributes {loopAnnotation = #loop_annotation}
+  do i=1,10
+     a(i)=i
+  end do
+end subroutine intermediate_directive
+
diff --git a/flang/test/Parser/compiler-directives.f90 b/flang/test/Parser/compiler-directives.f90
index 246eaf985251c6..f372a9f533a355 100644
--- a/flang/test/Parser/compiler-directives.f90
+++ b/flang/test/Parser/compiler-directives.f90
@@ -35,3 +35,14 @@ subroutine vector_always
   do i=1,10
   enddo
 end subroutine
+
+subroutine unroll
+  !dir$ unroll
+  ! CHECK: !DIR$ UNROLL
+  do i=1,10
+  enddo
+  !dir$ unroll 2
+  ! CHECK: !DIR$ UNROLL 2
+  do i=1,10
+  enddo
+end subroutine
diff --git a/flang/test/Semantics/loop-directives.f90 b/flang/test/Semantics/loop-directives.f90
index 58fb9b8082bc1a..e20c0c9d042dbf 100644
--- a/flang/test/Semantics/loop-directives.f90
+++ b/flang/test/Semantics/loop-directives.f90
@@ -4,11 +4,15 @@
 subroutine empty
   ! WARNING: A DO loop must follow the VECTOR ALWAYS directive
   !dir$ vector always
+  ! WARNING: A DO loop must follow the UNROLL directive
+  !dir$ unroll
 end subroutine empty

 subroutine non_do
   ! WARNING: A DO loop must follow the VECTOR ALWAYS directive
   !dir$ vector always
+  ! WARNING: A DO loop must follow the UNROLL directive
+  !dir$ unroll
   a = 1
 end subroutine non_do

@@ -16,6 +20,8 @@ subroutine execution_part
   do i=1,10
   ! WARNING: A DO loop must follow the VECTOR ALWAYS directive
   !dir$ vector always
+  ! WARNING: A DO loop must follow the UNROLL directive
+  !dir$ unroll
   end do
 end subroutine execution_part

@@ -28,3 +34,13 @@ subroutine test_vector_always_before_acc(a, b, c)
     a(i) = b(i) + c(i)
   enddo
 end subroutine
+
+! OK
+subroutine test_unroll_before_acc(a, b, c)
+  real, dimension(10) :: a,b,c
+  !dir$ unroll
+  !$acc loop
+  do i=1,N
+    a(i) = b(i) + c(i)
+  enddo
+end subroutine



More information about the flang-commits mailing list