[flang-commits] [flang] [flang] Implement !DIR$ UNROLL [N] (PR #123331)
Jean-Didier PAILLEUX via flang-commits
flang-commits at lists.llvm.org
Mon Jan 20 02:15:04 PST 2025
https://github.com/JDPailleux updated https://github.com/llvm/llvm-project/pull/123331
>From 7d0cf854c0c55962385602aeaadc10db57d6808e Mon Sep 17 00:00:00 2001
From: Jean-Didier PAILLEUX <jean-didier.pailleux at sipearl.com>
Date: Mon, 20 Jan 2025 11:24:07 +0100
Subject: [PATCH] [flang] Implement !DIR$ UNROLL [N]
---
flang/include/flang/Parser/dump-parse-tree.h | 1 +
flang/include/flang/Parser/parse-tree.h | 5 +-
flang/lib/Lower/Bridge.cpp | 52 ++++++++++++++-----
flang/lib/Parser/Fortran-parsers.cpp | 4 ++
flang/lib/Parser/unparse.cpp | 4 ++
.../lib/Semantics/canonicalize-directives.cpp | 7 ++-
flang/lib/Semantics/resolve-names.cpp | 3 +-
flang/test/Integration/unroll.f90 | 16 ++++++
flang/test/Lower/unroll.f90 | 27 ++++++++++
flang/test/Parser/compiler-directives.f90 | 11 ++++
flang/test/Semantics/loop-directives.f90 | 16 ++++++
11 files changed, 129 insertions(+), 17 deletions(-)
create mode 100644 flang/test/Integration/unroll.f90
create mode 100644 flang/test/Lower/unroll.f90
diff --git a/flang/include/flang/Parser/dump-parse-tree.h b/flang/include/flang/Parser/dump-parse-tree.h
index 11725991e9c9a9..8ed44f93a4e8f3 100644
--- a/flang/include/flang/Parser/dump-parse-tree.h
+++ b/flang/include/flang/Parser/dump-parse-tree.h
@@ -208,6 +208,7 @@ class ParseTreeDumper {
NODE(CompilerDirective, NameValue)
NODE(CompilerDirective, Unrecognized)
NODE(CompilerDirective, VectorAlways)
+ NODE(CompilerDirective, Unroll)
NODE(parser, ComplexLiteralConstant)
NODE(parser, ComplexPart)
NODE(parser, ComponentArraySpec)
diff --git a/flang/include/flang/Parser/parse-tree.h b/flang/include/flang/Parser/parse-tree.h
index 00d85aa05fb3a5..2a8020a5f9525f 100644
--- a/flang/include/flang/Parser/parse-tree.h
+++ b/flang/include/flang/Parser/parse-tree.h
@@ -3368,10 +3368,13 @@ struct CompilerDirective {
TUPLE_CLASS_BOILERPLATE(NameValue);
std::tuple<Name, std::optional<std::uint64_t>> t;
};
+ struct Unroll { // !DIR$ UNROLL [n]: v holds the optional count n
+ WRAPPER_CLASS_BOILERPLATE(Unroll, std::optional<std::uint64_t>);
+ };
EMPTY_CLASS(Unrecognized);
CharBlock source;
std::variant<std::list<IgnoreTKR>, LoopCount, std::list<AssumeAligned>,
- VectorAlways, std::list<NameValue>, Unrecognized>
+ VectorAlways, std::list<NameValue>, Unroll, Unrecognized>
u;
};
diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp
index 37f51d74d23f8f..91306f70118bfd 100644
--- a/flang/lib/Lower/Bridge.cpp
+++ b/flang/lib/Lower/Bridge.cpp
@@ -2153,14 +2153,42 @@ class FirConverter : public Fortran::lower::AbstractConverter {
return builder->createIntegerConstant(loc, controlType, 1); // step
}
- void addLoopAnnotationAttr(IncrementLoopInfo &info) {
+ /// Attach VECTOR ALWAYS / UNROLL directives in \p dirs to info.doLoop.
+ void addLoopAnnotationAttr(IncrementLoopInfo &info,
+ llvm::SmallVectorImpl<const Fortran::parser::CompilerDirective *> &dirs) {
mlir::BoolAttr f = mlir::BoolAttr::get(builder->getContext(), false);
- mlir::LLVM::LoopVectorizeAttr va = mlir::LLVM::LoopVectorizeAttr::get(
- builder->getContext(), /*disable=*/f, {}, {}, {}, {}, {}, {});
+ mlir::BoolAttr t = mlir::BoolAttr::get(builder->getContext(), true);
+ mlir::LLVM::LoopVectorizeAttr va;
+ mlir::LLVM::LoopUnrollAttr ua;
+ bool hasAttrs = false;
+ for (const auto *dir : dirs) {
+ Fortran::common::visit(
+ Fortran::common::visitors{
+ [&](const Fortran::parser::CompilerDirective::VectorAlways &) {
+ va = mlir::LLVM::LoopVectorizeAttr::get(builder->getContext(),
+ /*disable=*/f, {}, {}, {}, {}, {}, {});
+ hasAttrs = true;
+ },
+ [&](const Fortran::parser::CompilerDirective::Unroll &u) {
+ // No count: unroll fully. 0 or 1: do not unroll. N >= 2: unroll by N.
+ bool disable = u.v.has_value() && *u.v <= 1;
+ mlir::IntegerAttr countAttr;
+ if (u.v.has_value() && !disable)
+ countAttr = builder->getIntegerAttr(builder->getI64Type(), *u.v);
+ ua = mlir::LLVM::LoopUnrollAttr::get(builder->getContext(),
+ /*disable=*/disable ? t : f, /*count=*/countAttr, {},
+ /*full=*/u.v.has_value() ? mlir::BoolAttr{} : t, {}, {},
+ {});
+ hasAttrs = true;
+ },
+ [&](const auto &) {}},
+ dir->u);
+ }
mlir::LLVM::LoopAnnotationAttr la = mlir::LLVM::LoopAnnotationAttr::get(
- builder->getContext(), {}, /*vectorize=*/va, {}, {}, {}, {}, {}, {}, {},
- {}, {}, {}, {}, {}, {});
- info.doLoop.setLoopAnnotationAttr(la);
+ builder->getContext(), {}, /*vectorize=*/va, {}, /*unroll=*/ua, {}, {},
+ {}, {}, {}, {}, {}, {}, {}, {}, {});
+ if (hasAttrs)
+ info.doLoop.setLoopAnnotationAttr(la);
}
/// Generate FIR to begin a structured or unstructured increment loop nest.
@@ -2259,14 +2287,7 @@ class FirConverter : public Fortran::lower::AbstractConverter {
if (info.hasLocalitySpecs())
handleLocalitySpecs(info);
- for (const auto *dir : dirs) {
- Fortran::common::visit(
- Fortran::common::visitors{
- [&](const Fortran::parser::CompilerDirective::VectorAlways
- &d) { addLoopAnnotationAttr(info); },
- [&](const auto &) {}},
- dir->u);
- }
+ addLoopAnnotationAttr(info, dirs);
continue;
}
@@ -2818,6 +2839,9 @@ class FirConverter : public Fortran::lower::AbstractConverter {
[&](const Fortran::parser::CompilerDirective::VectorAlways &) {
attachDirectiveToLoop(dir, &eval);
},
+ [&](const Fortran::parser::CompilerDirective::Unroll &) {
+ attachDirectiveToLoop(dir, &eval);
+ },
[&](const auto &) {}},
dir.u);
}
diff --git a/flang/lib/Parser/Fortran-parsers.cpp b/flang/lib/Parser/Fortran-parsers.cpp
index 7cb35c1f173bb6..b5bcb53a127613 100644
--- a/flang/lib/Parser/Fortran-parsers.cpp
+++ b/flang/lib/Parser/Fortran-parsers.cpp
@@ -1293,6 +1293,7 @@ TYPE_PARSER(construct<StatOrErrmsg>("STAT =" >> statVariable) ||
// !DIR$ IGNORE_TKR [ [(tkrdmac...)] name ]...
// !DIR$ LOOP COUNT (n1[, n2]...)
// !DIR$ name[=value] [, name[=value]]...
+// !DIR$ UNROLL [n]
// !DIR$ <anything else>
constexpr auto ignore_tkr{
"IGNORE_TKR" >> optionalList(construct<CompilerDirective::IgnoreTKR>(
@@ -1305,11 +1306,14 @@ constexpr auto assumeAligned{"ASSUME_ALIGNED" >>
indirect(designator), ":"_tok >> digitString64))};
constexpr auto vectorAlways{
"VECTOR ALWAYS" >> construct<CompilerDirective::VectorAlways>()};
+constexpr auto unroll{ // UNROLL, then an optional unroll count (digit string)
+ "UNROLL" >> construct<CompilerDirective::Unroll>(maybe(digitString64))};
TYPE_PARSER(beginDirective >> "DIR$ "_tok >>
sourced((construct<CompilerDirective>(ignore_tkr) ||
construct<CompilerDirective>(loopCount) ||
construct<CompilerDirective>(assumeAligned) ||
construct<CompilerDirective>(vectorAlways) ||
+ construct<CompilerDirective>(unroll) ||
construct<CompilerDirective>(
many(construct<CompilerDirective::NameValue>(
name, maybe(("="_tok || ":"_tok) >> digitString64))))) /
diff --git a/flang/lib/Parser/unparse.cpp b/flang/lib/Parser/unparse.cpp
index 7bf404bba2c3e4..8c70564de16650 100644
--- a/flang/lib/Parser/unparse.cpp
+++ b/flang/lib/Parser/unparse.cpp
@@ -1847,6 +1847,10 @@ class UnparseVisitor {
[&](const std::list<CompilerDirective::NameValue> &names) {
Walk("!DIR$ ", names, " ");
},
+ [&](const CompilerDirective::Unroll &unroll) {
+ Word("!DIR$ UNROLL");
+ Walk(" ", unroll.v);
+ },
[&](const CompilerDirective::Unrecognized &) {
Word("!DIR$ ");
Word(x.source.ToString());
diff --git a/flang/lib/Semantics/canonicalize-directives.cpp b/flang/lib/Semantics/canonicalize-directives.cpp
index 739bc3c1992ba6..b27a27618808bc 100644
--- a/flang/lib/Semantics/canonicalize-directives.cpp
+++ b/flang/lib/Semantics/canonicalize-directives.cpp
@@ -54,7 +54,9 @@ bool CanonicalizeDirectives(
}
static bool IsExecutionDirective(const parser::CompilerDirective &dir) {
- return std::holds_alternative<parser::CompilerDirective::VectorAlways>(dir.u);
+ using Dir = parser::CompilerDirective; // VECTOR ALWAYS and UNROLL qualify
+ return std::holds_alternative<Dir::VectorAlways>(dir.u) ||
+ std::holds_alternative<Dir::Unroll>(dir.u);
}
void CanonicalizationOfDirectives::Post(parser::SpecificationPart &spec) {
@@ -110,6 +112,9 @@ void CanonicalizationOfDirectives::Post(parser::Block &block) {
common::visitors{[&](parser::CompilerDirective::VectorAlways &) {
CheckLoopDirective(*dir, block, it);
},
+ [&](parser::CompilerDirective::Unroll &) {
+ CheckLoopDirective(*dir, block, it);
+ },
[&](auto &) {}},
dir->u);
}
diff --git a/flang/lib/Semantics/resolve-names.cpp b/flang/lib/Semantics/resolve-names.cpp
index f3c2a5bf094d04..705a6e0df04d99 100644
--- a/flang/lib/Semantics/resolve-names.cpp
+++ b/flang/lib/Semantics/resolve-names.cpp
@@ -9245,7 +9245,8 @@ void ResolveNamesVisitor::Post(const parser::AssignedGotoStmt &x) {
}
void ResolveNamesVisitor::Post(const parser::CompilerDirective &x) {
- if (std::holds_alternative<parser::CompilerDirective::VectorAlways>(x.u)) {
+ if (std::holds_alternative<parser::CompilerDirective::VectorAlways>(x.u) ||
+ std::holds_alternative<parser::CompilerDirective::Unroll>(x.u)) {
return;
}
if (const auto *tkr{
diff --git a/flang/test/Integration/unroll.f90 b/flang/test/Integration/unroll.f90
new file mode 100644
index 00000000000000..9d69605e10d1b3
--- /dev/null
+++ b/flang/test/Integration/unroll.f90
@@ -0,0 +1,16 @@
+! RUN: %flang_fc1 -emit-llvm -o - %s | FileCheck %s
+
+! CHECK-LABEL: unroll_dir
+subroutine unroll_dir ! bare UNROLL: expect unroll.enable and unroll.full loop metadata
+ integer :: a(10)
+ !dir$ unroll
+ ! CHECK: br i1 {{.*}}, label {{.*}}, label {{.*}}, !llvm.loop ![[ANNOTATION:.*]]
+ do i=1,10
+ a(i)=i
+ end do
+end subroutine unroll_dir
+
+! CHECK: ![[ANNOTATION]] = distinct !{![[ANNOTATION]], ![[UNROLL:.*]], ![[UNROLL_FULL:.*]]}
+! CHECK: ![[UNROLL]] = !{!"llvm.loop.unroll.enable"}
+! CHECK: ![[UNROLL_FULL]] = !{!"llvm.loop.unroll.full"}
+
diff --git a/flang/test/Lower/unroll.f90 b/flang/test/Lower/unroll.f90
new file mode 100644
index 00000000000000..229755200fd8d8
--- /dev/null
+++ b/flang/test/Lower/unroll.f90
@@ -0,0 +1,27 @@
+! RUN: %flang_fc1 -emit-hlfir -o - %s | FileCheck %s
+
+! CHECK: #loop_unroll = #llvm.loop_unroll<disable = false, full = true>
+! CHECK: #loop_annotation = #llvm.loop_annotation<unroll = #loop_unroll>
+
+! CHECK-LABEL: unroll_dir
+subroutine unroll_dir ! bare UNROLL directly before a DO loop
+ integer :: a(10)
+ !dir$ unroll
+ !CHECK: fir.do_loop {{.*}} attributes {loopAnnotation = #loop_annotation}
+ do i=1,10
+ a(i)=i
+ end do
+end subroutine unroll_dir
+
+
+! CHECK-LABEL: intermediate_directive
+subroutine intermediate_directive ! an unrecognized directive sits between UNROLL and the loop
+ integer :: a(10)
+ !dir$ unroll
+ !dir$ unknown
+ !CHECK: fir.do_loop {{.*}} attributes {loopAnnotation = #loop_annotation}
+ do i=1,10
+ a(i)=i
+ end do
+end subroutine intermediate_directive
+
diff --git a/flang/test/Parser/compiler-directives.f90 b/flang/test/Parser/compiler-directives.f90
index 246eaf985251c6..f372a9f533a355 100644
--- a/flang/test/Parser/compiler-directives.f90
+++ b/flang/test/Parser/compiler-directives.f90
@@ -35,3 +35,14 @@ subroutine vector_always
do i=1,10
enddo
end subroutine
+
+subroutine unroll ! UNROLL without a count, then with a count of 2
+ !dir$ unroll
+ ! CHECK: !DIR$ UNROLL
+ do i=1,10
+ enddo
+ !dir$ unroll 2
+ ! CHECK: !DIR$ UNROLL 2
+ do i=1,10
+ enddo
+end subroutine
diff --git a/flang/test/Semantics/loop-directives.f90 b/flang/test/Semantics/loop-directives.f90
index 58fb9b8082bc1a..e20c0c9d042dbf 100644
--- a/flang/test/Semantics/loop-directives.f90
+++ b/flang/test/Semantics/loop-directives.f90
@@ -4,11 +4,15 @@
subroutine empty
! WARNING: A DO loop must follow the VECTOR ALWAYS directive
!dir$ vector always
+ ! WARNING: A DO loop must follow the UNROLL directive
+ !dir$ unroll
end subroutine empty
subroutine non_do
! WARNING: A DO loop must follow the VECTOR ALWAYS directive
!dir$ vector always
+ ! WARNING: A DO loop must follow the UNROLL directive
+ !dir$ unroll
a = 1
end subroutine non_do
@@ -16,6 +20,8 @@ subroutine execution_part
do i=1,10
! WARNING: A DO loop must follow the VECTOR ALWAYS directive
!dir$ vector always
+ ! WARNING: A DO loop must follow the UNROLL directive
+ !dir$ unroll
end do
end subroutine execution_part
@@ -28,3 +34,13 @@ subroutine test_vector_always_before_acc(a, b, c)
a(i) = b(i) + c(i)
enddo
end subroutine
+
+! OK
+subroutine test_unroll_before_acc(a, b, c) ! UNROLL precedes an OpenACC loop directive; no diagnostic expected
+ real, dimension(10) :: a,b,c
+ !dir$ unroll
+ !$acc loop
+ do i=1,N
+ a(i) = b(i) + c(i)
+ enddo
+end subroutine
More information about the flang-commits
mailing list