[flang-commits] [flang] [flang] Implement !DIR$ VECTOR ALWAYS (PR #93830)

David Truby via flang-commits flang-commits at lists.llvm.org
Thu Jun 6 08:07:43 PDT 2024


https://github.com/DavidTruby updated https://github.com/llvm/llvm-project/pull/93830

>From 86122ffa79e3485850ad1addda9c1e8ae15353d9 Mon Sep 17 00:00:00 2001
From: David Truby <david.truby at arm.com>
Date: Thu, 30 May 2024 14:25:47 +0000
Subject: [PATCH 1/3] [flang] Implement !DIR$ VECTOR ALWAYS

This patch implements support for the VECTOR ALWAYS directive, which forces
vectorization to occur when possible regardless of a decision by the cost
model. This is done by adding an attribute to the branch into the loop in LLVM
to indicate that the loop should always be vectorized.
---
 flang/include/flang/Lower/PFTBuilder.h        |  1 +
 .../include/flang/Optimizer/Dialect/FIROps.h  |  1 +
 .../include/flang/Optimizer/Dialect/FIROps.td |  3 +-
 flang/include/flang/Parser/dump-parse-tree.h  |  1 +
 flang/include/flang/Parser/parse-tree.h       |  3 +-
 flang/lib/Lower/Bridge.cpp                    | 53 +++++++++++++++++--
 .../Transforms/ControlFlowConverter.cpp       |  6 ++-
 flang/lib/Parser/Fortran-parsers.cpp          |  3 ++
 flang/lib/Parser/unparse.cpp                  |  3 ++
 flang/lib/Semantics/resolve-names.cpp         |  3 ++
 flang/test/Fir/vector-always.fir              | 42 +++++++++++++++
 flang/test/Lower/vector-always.f90            | 29 ++++++++++
 12 files changed, 141 insertions(+), 7 deletions(-)
 create mode 100644 flang/test/Fir/vector-always.fir
 create mode 100644 flang/test/Lower/vector-always.f90

diff --git a/flang/include/flang/Lower/PFTBuilder.h b/flang/include/flang/Lower/PFTBuilder.h
index 9913f584133fa..aa83f1603c2a8 100644
--- a/flang/include/flang/Lower/PFTBuilder.h
+++ b/flang/include/flang/Lower/PFTBuilder.h
@@ -347,6 +347,7 @@ struct Evaluation : EvaluationVariant {
   parser::CharBlock position{};
   std::optional<parser::Label> label{};
   std::unique_ptr<EvaluationList> evaluationList; // nested evaluations
+  llvm::SmallVector<const parser::CompilerDirective *> dirs;
   Evaluation *parentConstruct{nullptr};  // set for nodes below the top level
   Evaluation *lexicalSuccessor{nullptr}; // set for leaf nodes, some directives
   Evaluation *controlSuccessor{nullptr}; // set for some leaf nodes
diff --git a/flang/include/flang/Optimizer/Dialect/FIROps.h b/flang/include/flang/Optimizer/Dialect/FIROps.h
index 9f07364ddb627..a21f8bbe17685 100644
--- a/flang/include/flang/Optimizer/Dialect/FIROps.h
+++ b/flang/include/flang/Optimizer/Dialect/FIROps.h
@@ -16,6 +16,7 @@
 #include "flang/Optimizer/Dialect/FortranVariableInterface.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/LLVMIR/LLVMAttrs.h"
 #include "mlir/Interfaces/LoopLikeInterface.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 
diff --git a/flang/include/flang/Optimizer/Dialect/FIROps.td b/flang/include/flang/Optimizer/Dialect/FIROps.td
index d9c1149040066..7bf68908e37dd 100644
--- a/flang/include/flang/Optimizer/Dialect/FIROps.td
+++ b/flang/include/flang/Optimizer/Dialect/FIROps.td
@@ -2096,7 +2096,8 @@ def fir_DoLoopOp : region_Op<"do_loop",
     Index:$step,
     Variadic<AnyType>:$initArgs,
     OptionalAttr<UnitAttr>:$unordered,
-    OptionalAttr<UnitAttr>:$finalValue
+    OptionalAttr<UnitAttr>:$finalValue,
+    OptionalAttr<LoopAnnotationAttr>:$loop_annotation
   );
   let results = (outs Variadic<AnyType>:$results);
   let regions = (region SizedRegion<1>:$region);
diff --git a/flang/include/flang/Parser/dump-parse-tree.h b/flang/include/flang/Parser/dump-parse-tree.h
index 68ae50c312cde..8cf790650fb49 100644
--- a/flang/include/flang/Parser/dump-parse-tree.h
+++ b/flang/include/flang/Parser/dump-parse-tree.h
@@ -206,6 +206,7 @@ class ParseTreeDumper {
   NODE(CompilerDirective, IgnoreTKR)
   NODE(CompilerDirective, LoopCount)
   NODE(CompilerDirective, AssumeAligned)
+  NODE(CompilerDirective, VectorAlways)
   NODE(CompilerDirective, NameValue)
   NODE(CompilerDirective, Unrecognized)
   NODE(parser, ComplexLiteralConstant)
diff --git a/flang/include/flang/Parser/parse-tree.h b/flang/include/flang/Parser/parse-tree.h
index 0a40aa8b8f616..116e5f02ff32b 100644
--- a/flang/include/flang/Parser/parse-tree.h
+++ b/flang/include/flang/Parser/parse-tree.h
@@ -3316,6 +3316,7 @@ struct CompilerDirective {
     TUPLE_CLASS_BOILERPLATE(AssumeAligned);
     std::tuple<common::Indirection<Designator>, uint64_t> t;
   };
+  EMPTY_CLASS(VectorAlways);
   struct NameValue {
     TUPLE_CLASS_BOILERPLATE(NameValue);
     std::tuple<Name, std::optional<std::uint64_t>> t;
@@ -3323,7 +3324,7 @@ struct CompilerDirective {
   EMPTY_CLASS(Unrecognized);
   CharBlock source;
   std::variant<std::list<IgnoreTKR>, LoopCount, std::list<AssumeAligned>,
-      std::list<NameValue>, Unrecognized>
+      VectorAlways, std::list<NameValue>, Unrecognized>
       u;
 };
 
diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp
index 4e50de3e7ee9c..55d4d190d0fa4 100644
--- a/flang/lib/Lower/Bridge.cpp
+++ b/flang/lib/Lower/Bridge.cpp
@@ -1881,7 +1881,7 @@ class FirConverter : public Fortran::lower::AbstractConverter {
 
     // Increment loop begin code. (Infinite/while code was already generated.)
     if (!infiniteLoop && !whileCondition)
-      genFIRIncrementLoopBegin(incrementLoopNestInfo);
+      genFIRIncrementLoopBegin(incrementLoopNestInfo, doStmtEval.dirs);
 
     // Loop body code.
     auto iter = eval.getNestedEvaluations().begin();
@@ -1926,8 +1926,22 @@ class FirConverter : public Fortran::lower::AbstractConverter {
     return builder->createIntegerConstant(loc, controlType, 1); // step
   }
 
+  void addLoopAnnotationAttr(IncrementLoopInfo &info) {
+    mlir::BoolAttr f = mlir::BoolAttr::get(builder->getContext(), false);
+    mlir::LLVM::LoopVectorizeAttr va = mlir::LLVM::LoopVectorizeAttr::get(
+        builder->getContext(), f, {}, {}, {}, {}, {}, {});
+    mlir::LLVM::AccessGroupAttr ag =
+        mlir::LLVM::AccessGroupAttr::get(builder->getContext());
+    mlir::LLVM::LoopAnnotationAttr la = mlir::LLVM::LoopAnnotationAttr::get(
+        builder->getContext(), {}, va, {}, {}, {}, {}, {}, {}, {}, {}, {}, {},
+        {}, {}, {ag});
+    info.doLoop.setLoopAnnotationAttr(la);
+  }
+
   /// Generate FIR to begin a structured or unstructured increment loop nest.
-  void genFIRIncrementLoopBegin(IncrementLoopNestInfo &incrementLoopNestInfo) {
+  void genFIRIncrementLoopBegin(
+      IncrementLoopNestInfo &incrementLoopNestInfo,
+      llvm::SmallVectorImpl<const Fortran::parser::CompilerDirective *> &dirs) {
     assert(!incrementLoopNestInfo.empty() && "empty loop nest");
     mlir::Location loc = toLocation();
     for (IncrementLoopInfo &info : incrementLoopNestInfo) {
@@ -1978,6 +1992,15 @@ class FirConverter : public Fortran::lower::AbstractConverter {
         }
         if (info.hasLocalitySpecs())
           handleLocalitySpecs(info);
+
+        for (const auto *dir : dirs) {
+          std::visit(
+              Fortran::common::visitors{
+                  [&](const Fortran::parser::CompilerDirective::VectorAlways
+                          &d) { addLoopAnnotationAttr(info); },
+                  [&](const auto &) {}},
+              dir->u);
+        }
         continue;
       }
 
@@ -2508,8 +2531,30 @@ class FirConverter : public Fortran::lower::AbstractConverter {
     }
   }
 
-  void genFIR(const Fortran::parser::CompilerDirective &) {
-    // TODO
+  void attachLoopDirective(const Fortran::parser::CompilerDirective &dir,
+                           Fortran::lower::pft::Evaluation *e) {
+    while (e->isDirective()) {
+      e = e->lexicalSuccessor;
+    }
+
+    if (e->isA<Fortran::parser::NonLabelDoStmt>()) {
+      e->dirs.push_back(&dir);
+    } else {
+      fir::emitFatalError(toLocation(),
+                          "loop directive must appear before a loop");
+    }
+  }
+
+  void genFIR(const Fortran::parser::CompilerDirective &dir) {
+    Fortran::lower::pft::Evaluation &eval = getEval();
+
+    std::visit(
+        Fortran::common::visitors{
+            [&](const Fortran::parser::CompilerDirective::VectorAlways &) {
+              attachLoopDirective(dir, &eval);
+            },
+            [&](const auto &) {}},
+        dir.u);
   }
 
   void genFIR(const Fortran::parser::OpenACCConstruct &acc) {
diff --git a/flang/lib/Optimizer/Transforms/ControlFlowConverter.cpp b/flang/lib/Optimizer/Transforms/ControlFlowConverter.cpp
index a233e7fbdcd1e..b40c06de8787b 100644
--- a/flang/lib/Optimizer/Transforms/ControlFlowConverter.cpp
+++ b/flang/lib/Optimizer/Transforms/ControlFlowConverter.cpp
@@ -132,10 +132,14 @@ class CfgLoopConv : public mlir::OpRewritePattern<fir::DoLoopOp> {
     auto comparison = rewriter.create<mlir::arith::CmpIOp>(
         loc, arith::CmpIPredicate::sgt, itersLeft, zero);
 
-    rewriter.create<mlir::cf::CondBranchOp>(
+    auto cond = rewriter.create<mlir::cf::CondBranchOp>(
         loc, comparison, firstBlock, llvm::ArrayRef<mlir::Value>(), endBlock,
         llvm::ArrayRef<mlir::Value>());
 
+    if (auto ann = loop.getLoopAnnotation()) {
+      cond->setAttr("loop_annotation", *ann);
+    }
+
     // The result of the loop operation is the values of the condition block
     // arguments except the induction variable on the last iteration.
     auto args = loop.getFinalValue()
diff --git a/flang/lib/Parser/Fortran-parsers.cpp b/flang/lib/Parser/Fortran-parsers.cpp
index ff01974b549a1..d2241fb66a013 100644
--- a/flang/lib/Parser/Fortran-parsers.cpp
+++ b/flang/lib/Parser/Fortran-parsers.cpp
@@ -1276,10 +1276,13 @@ constexpr auto loopCount{
 constexpr auto assumeAligned{"ASSUME_ALIGNED" >>
     optionalList(construct<CompilerDirective::AssumeAligned>(
         indirect(designator), ":"_tok >> digitString64))};
+constexpr auto vectorAlways{
+    "VECTOR ALWAYS" >> construct<CompilerDirective::VectorAlways>()};
 TYPE_PARSER(beginDirective >> "DIR$ "_tok >>
     sourced((construct<CompilerDirective>(ignore_tkr) ||
                 construct<CompilerDirective>(loopCount) ||
                 construct<CompilerDirective>(assumeAligned) ||
+                construct<CompilerDirective>(vectorAlways) ||
                 construct<CompilerDirective>(
                     many(construct<CompilerDirective::NameValue>(
                         name, maybe(("="_tok || ":"_tok) >> digitString64))))) /
diff --git a/flang/lib/Parser/unparse.cpp b/flang/lib/Parser/unparse.cpp
index bdd968b19a43f..5a4c7550d64dd 100644
--- a/flang/lib/Parser/unparse.cpp
+++ b/flang/lib/Parser/unparse.cpp
@@ -1824,6 +1824,9 @@ class UnparseVisitor {
               Word("!DIR$ ASSUME_ALIGNED ");
               Walk(" ", assumeAligned, ", ");
             },
+            [&](const CompilerDirective::VectorAlways &valways) {
+              Word("!DIR$ VECTOR ALWAYS");
+            },
             [&](const std::list<CompilerDirective::NameValue> &names) {
               Walk("!DIR$ ", names, " ");
             },
diff --git a/flang/lib/Semantics/resolve-names.cpp b/flang/lib/Semantics/resolve-names.cpp
index a46c0f378d5d0..e65788a02b725 100644
--- a/flang/lib/Semantics/resolve-names.cpp
+++ b/flang/lib/Semantics/resolve-names.cpp
@@ -8854,6 +8854,9 @@ void ResolveNamesVisitor::Post(const parser::AssignedGotoStmt &x) {
 }
 
 void ResolveNamesVisitor::Post(const parser::CompilerDirective &x) {
+  if (const auto *dir{
+          std::get_if<parser::CompilerDirective::VectorAlways>(&x.u)})
+    return;
   if (const auto *tkr{
           std::get_if<std::list<parser::CompilerDirective::IgnoreTKR>>(&x.u)}) {
     if (currScope().IsTopLevel() ||
diff --git a/flang/test/Fir/vector-always.fir b/flang/test/Fir/vector-always.fir
new file mode 100644
index 0000000000000..b6dcf237ed59a
--- /dev/null
+++ b/flang/test/Fir/vector-always.fir
@@ -0,0 +1,42 @@
+// RUN: %flang_fc1 -emit-llvm -o - %s | FileCheck %s
+
+#access_group = #llvm.access_group<id = distinct[0]<>>
+#loop_vectorize = #llvm.loop_vectorize<disable = false>
+#loop_annotation = #llvm.loop_annotation<vectorize = #loop_vectorize, parallelAccesses = #access_group>
+
+// CHECK-LABEL: @vector_always_
+// CHECK:   br i1 {{.*}}, label {{.*}}, label {{.*}}, !llvm.loop ![[ANNOTATION:.*]]
+func.func @_QPvector_always() {
+    %c1 = arith.constant 1 : index
+    %c10_i32 = arith.constant 10 : i32
+    %c1_i32 = arith.constant 1 : i32
+    %c10 = arith.constant 10 : index
+    %0 = fir.alloca !fir.array<10xi32> {bindc_name = "a", uniq_name = "_QFvector_alwaysEa"}
+    %1 = fir.shape %c10 : (index) -> !fir.shape<1>
+    %2 = fir.declare %0(%1) {uniq_name = "_QFvector_alwaysEa"} : (!fir.ref<!fir.array<10xi32>>, !fir.shape<1>) -> !fir.ref<!fir.array<10xi32>>
+    %3 = fir.alloca i32 {bindc_name = "i", uniq_name = "_QFvector_alwaysEi"}
+    %4 = fir.declare %3 {uniq_name = "_QFvector_alwaysEi"} : (!fir.ref<i32>) -> !fir.ref<i32>
+    %5 = fir.convert %c1_i32 : (i32) -> index
+    %6 = fir.convert %c10_i32 : (i32) -> index
+    %7 = fir.convert %5 : (index) -> i32
+    %8:2 = fir.do_loop %arg0 = %5 to %6 step %c1 iter_args(%arg1 = %7) -> (index, i32) attributes {loop_annotation = #loop_annotation} {
+      fir.store %arg1 to %4 : !fir.ref<i32>
+      %9 = fir.load %4 : !fir.ref<i32>
+      %10 = fir.load %4 : !fir.ref<i32>
+      %11 = fir.convert %10 : (i32) -> i64
+      %12 = fir.array_coor %2(%1) %11 : (!fir.ref<!fir.array<10xi32>>, !fir.shape<1>, i64) -> !fir.ref<i32>
+      fir.store %9 to %12 : !fir.ref<i32>
+      %13 = arith.addi %arg0, %c1 : index
+      %14 = fir.convert %c1 : (index) -> i32
+      %15 = fir.load %4 : !fir.ref<i32>
+      %16 = arith.addi %15, %14 : i32
+      fir.result %13, %16 : index, i32
+    }
+    fir.store %8#1 to %4 : !fir.ref<i32>
+    return
+  }
+
+// CHECK: ![[ANNOTATION]] = distinct !{![[ANNOTATION]], ![[VECTORIZE:.*]], ![[PAR_ACCESS:.*]]}
+// CHECK: ![[VECTORIZE]] = !{!"llvm.loop.vectorize.enable", i1 true}
+// CHECK: ![[PAR_ACCESS]] = !{!"llvm.loop.parallel_accesses", ![[DISTINCT:.*]]}
+// CHECK: ![[DISTINCT]] = distinct !{}
diff --git a/flang/test/Lower/vector-always.f90 b/flang/test/Lower/vector-always.f90
new file mode 100644
index 0000000000000..1994163626f16
--- /dev/null
+++ b/flang/test/Lower/vector-always.f90
@@ -0,0 +1,29 @@
+! RUN: %flang_fc1 -emit-fir -o - %s | FileCheck %s
+
+! CHECK: #access_group = #llvm.access_group<id = distinct[0]<>>
+! CHECK: #access_group1 = #llvm.access_group<id = distinct[1]<>>
+! CHECK: #loop_vectorize = #llvm.loop_vectorize<disable = false>
+! CHECK: #loop_annotation = #llvm.loop_annotation<vectorize = #loop_vectorize, parallelAccesses = #access_group>
+! CHECK: #loop_annotation1 = #llvm.loop_annotation<vectorize = #loop_vectorize, parallelAccesses = #access_group1>
+
+! CHECK-LABEL: vector_always
+subroutine vector_always
+  integer :: a(10)
+  !dir$ vector always
+  !CHECK: fir.do_loop {{.*}} attributes {loop_annotation = #loop_annotation}
+  do i=1,10
+     a(i)=i
+  end do
+end subroutine vector_always
+
+
+! CHECK-LABEL: intermediate_directive
+subroutine intermediate_directive
+  integer :: a(10)
+  !dir$ vector always
+  !dir$ unknown
+  !CHECK: fir.do_loop {{.*}} attributes {loop_annotation = #loop_annotation1}
+  do i=1,10
+     a(i)=i
+  end do
+end subroutine intermediate_directive

>From e11c54bbe838989deae9a4a6b472f2c8e44b00e7 Mon Sep 17 00:00:00 2001
From: David Truby <david.truby at arm.com>
Date: Fri, 31 May 2024 13:25:29 +0000
Subject: [PATCH 2/3] nit fixes for review

---
 flang/include/flang/Lower/PFTBuilder.h          |  3 ++-
 flang/include/flang/Optimizer/Dialect/FIROps.td |  2 +-
 flang/lib/Lower/Bridge.cpp                      | 12 +++++-------
 flang/lib/Semantics/resolve-names.cpp           |  7 +++++--
 flang/test/Fir/vector-always.fir                |  2 +-
 flang/test/Lower/vector-always.f90              |  4 ++--
 6 files changed, 16 insertions(+), 14 deletions(-)

diff --git a/flang/include/flang/Lower/PFTBuilder.h b/flang/include/flang/Lower/PFTBuilder.h
index aa83f1603c2a8..8bc5e02a23dcd 100644
--- a/flang/include/flang/Lower/PFTBuilder.h
+++ b/flang/include/flang/Lower/PFTBuilder.h
@@ -347,7 +347,8 @@ struct Evaluation : EvaluationVariant {
   parser::CharBlock position{};
   std::optional<parser::Label> label{};
   std::unique_ptr<EvaluationList> evaluationList; // nested evaluations
-  llvm::SmallVector<const parser::CompilerDirective *> dirs;
+  // associated compiler directives
+  llvm::SmallVector<const parser::CompilerDirective *, 1> dirs;
   Evaluation *parentConstruct{nullptr};  // set for nodes below the top level
   Evaluation *lexicalSuccessor{nullptr}; // set for leaf nodes, some directives
   Evaluation *controlSuccessor{nullptr}; // set for some leaf nodes
diff --git a/flang/include/flang/Optimizer/Dialect/FIROps.td b/flang/include/flang/Optimizer/Dialect/FIROps.td
index 7bf68908e37dd..a7b50c892495b 100644
--- a/flang/include/flang/Optimizer/Dialect/FIROps.td
+++ b/flang/include/flang/Optimizer/Dialect/FIROps.td
@@ -2097,7 +2097,7 @@ def fir_DoLoopOp : region_Op<"do_loop",
     Variadic<AnyType>:$initArgs,
     OptionalAttr<UnitAttr>:$unordered,
     OptionalAttr<UnitAttr>:$finalValue,
-    OptionalAttr<LoopAnnotationAttr>:$loop_annotation
+    OptionalAttr<LoopAnnotationAttr>:$loopAnnotation
   );
   let results = (outs Variadic<AnyType>:$results);
   let regions = (region SizedRegion<1>:$region);
diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp
index 55d4d190d0fa4..3cc20a1e7e073 100644
--- a/flang/lib/Lower/Bridge.cpp
+++ b/flang/lib/Lower/Bridge.cpp
@@ -2531,18 +2531,16 @@ class FirConverter : public Fortran::lower::AbstractConverter {
     }
   }
 
-  void attachLoopDirective(const Fortran::parser::CompilerDirective &dir,
+  void attachDirectiveToLoop(const Fortran::parser::CompilerDirective &dir,
                            Fortran::lower::pft::Evaluation *e) {
-    while (e->isDirective()) {
+    while (e->isDirective())
       e = e->lexicalSuccessor;
-    }
 
-    if (e->isA<Fortran::parser::NonLabelDoStmt>()) {
+    if (e->isA<Fortran::parser::NonLabelDoStmt>())
       e->dirs.push_back(&dir);
-    } else {
+    else
       fir::emitFatalError(toLocation(),
                           "loop directive must appear before a loop");
-    }
   }
 
   void genFIR(const Fortran::parser::CompilerDirective &dir) {
@@ -2551,7 +2549,7 @@ class FirConverter : public Fortran::lower::AbstractConverter {
     std::visit(
         Fortran::common::visitors{
             [&](const Fortran::parser::CompilerDirective::VectorAlways &) {
-              attachLoopDirective(dir, &eval);
+              attachDirectiveToLoop(dir, &eval);
             },
             [&](const auto &) {}},
         dir.u);
diff --git a/flang/lib/Semantics/resolve-names.cpp b/flang/lib/Semantics/resolve-names.cpp
index e65788a02b725..637a286bb3c97 100644
--- a/flang/lib/Semantics/resolve-names.cpp
+++ b/flang/lib/Semantics/resolve-names.cpp
@@ -8854,9 +8854,12 @@ void ResolveNamesVisitor::Post(const parser::AssignedGotoStmt &x) {
 }
 
 void ResolveNamesVisitor::Post(const parser::CompilerDirective &x) {
-  if (const auto *dir{
-          std::get_if<parser::CompilerDirective::VectorAlways>(&x.u)})
+  //if (const auto *dir{
+  //        std::get_if<parser::CompilerDirective::VectorAlways>(&x.u)})
+
+  if (std::holds_alternative<parser::CompilerDirective::VectorAlways>(x.u)) {
     return;
+  }
   if (const auto *tkr{
           std::get_if<std::list<parser::CompilerDirective::IgnoreTKR>>(&x.u)}) {
     if (currScope().IsTopLevel() ||
diff --git a/flang/test/Fir/vector-always.fir b/flang/test/Fir/vector-always.fir
index b6dcf237ed59a..e7e07c41ffee9 100644
--- a/flang/test/Fir/vector-always.fir
+++ b/flang/test/Fir/vector-always.fir
@@ -19,7 +19,7 @@ func.func @_QPvector_always() {
     %5 = fir.convert %c1_i32 : (i32) -> index
     %6 = fir.convert %c10_i32 : (i32) -> index
     %7 = fir.convert %5 : (index) -> i32
-    %8:2 = fir.do_loop %arg0 = %5 to %6 step %c1 iter_args(%arg1 = %7) -> (index, i32) attributes {loop_annotation = #loop_annotation} {
+    %8:2 = fir.do_loop %arg0 = %5 to %6 step %c1 iter_args(%arg1 = %7) -> (index, i32) attributes {loopAnnotation = #loop_annotation} {
       fir.store %arg1 to %4 : !fir.ref<i32>
       %9 = fir.load %4 : !fir.ref<i32>
       %10 = fir.load %4 : !fir.ref<i32>
diff --git a/flang/test/Lower/vector-always.f90 b/flang/test/Lower/vector-always.f90
index 1994163626f16..5806054078f7f 100644
--- a/flang/test/Lower/vector-always.f90
+++ b/flang/test/Lower/vector-always.f90
@@ -10,7 +10,7 @@
 subroutine vector_always
   integer :: a(10)
   !dir$ vector always
-  !CHECK: fir.do_loop {{.*}} attributes {loop_annotation = #loop_annotation}
+  !CHECK: fir.do_loop {{.*}} attributes {loopAnnotation = #loop_annotation}
   do i=1,10
      a(i)=i
   end do
@@ -22,7 +22,7 @@ subroutine intermediate_directive
   integer :: a(10)
   !dir$ vector always
   !dir$ unknown
-  !CHECK: fir.do_loop {{.*}} attributes {loop_annotation = #loop_annotation1}
+  !CHECK: fir.do_loop {{.*}} attributes {loopAnnotation = #loop_annotation1}
   do i=1,10
      a(i)=i
   end do

>From 213df8e45b2aacdf8b4af5c926ef9aaeb79bad0b Mon Sep 17 00:00:00 2001
From: David Truby <david.truby at arm.com>
Date: Thu, 6 Jun 2024 14:28:45 +0000
Subject: [PATCH 3/3] Add check for directive location in semantics

Refactor tests
---
 flang/docs/Directives.md                      |   3 +
 flang/include/flang/Parser/dump-parse-tree.h  |   4 +-
 flang/lib/Lower/Bridge.cpp                    |   7 +-
 flang/lib/Semantics/CMakeLists.txt            |   1 +
 .../lib/Semantics/canonicalize-directives.cpp | 122 ++++++++++++++++++
 flang/lib/Semantics/canonicalize-directives.h |  22 ++++
 flang/lib/Semantics/resolve-names.cpp         |   6 +-
 flang/lib/Semantics/semantics.cpp             |   2 +
 flang/test/Fir/vector-always-cfg.fir          |  32 +++++
 flang/test/Fir/vector-always.fir              |  43 ++----
 flang/test/Integration/vector-always.f90      |  16 +++
 flang/test/Parser/compiler-directives.f90     |   9 +-
 flang/test/Semantics/loop-directives.f90      |  13 ++
 13 files changed, 238 insertions(+), 42 deletions(-)
 create mode 100644 flang/lib/Semantics/canonicalize-directives.cpp
 create mode 100644 flang/lib/Semantics/canonicalize-directives.h
 create mode 100644 flang/test/Fir/vector-always-cfg.fir
 create mode 100644 flang/test/Integration/vector-always.f90
 create mode 100644 flang/test/Semantics/loop-directives.f90

diff --git a/flang/docs/Directives.md b/flang/docs/Directives.md
index fe08b4f855f23..4bd5f39f14243 100644
--- a/flang/docs/Directives.md
+++ b/flang/docs/Directives.md
@@ -36,3 +36,6 @@ A list of non-standard directives supported by Flang
   and is limited to 256.
   [This directive is currently recognised by the parser, but not
   handled by the other parts of the compiler].
+* `!dir$ vector always` forces vectorization on the following loop regardless 
+  of cost model decisions. The loop must still be vectorizable.
+  [This directive currently only works on plain do loops without labels].
diff --git a/flang/include/flang/Parser/dump-parse-tree.h b/flang/include/flang/Parser/dump-parse-tree.h
index 8cf790650fb49..bae2a00b2e893 100644
--- a/flang/include/flang/Parser/dump-parse-tree.h
+++ b/flang/include/flang/Parser/dump-parse-tree.h
@@ -203,12 +203,12 @@ class ParseTreeDumper {
   NODE(parser, CommonStmt)
   NODE(CommonStmt, Block)
   NODE(parser, CompilerDirective)
+  NODE(CompilerDirective, AssumeAligned)
   NODE(CompilerDirective, IgnoreTKR)
   NODE(CompilerDirective, LoopCount)
-  NODE(CompilerDirective, AssumeAligned)
-  NODE(CompilerDirective, VectorAlways)
   NODE(CompilerDirective, NameValue)
   NODE(CompilerDirective, Unrecognized)
+  NODE(CompilerDirective, VectorAlways)
   NODE(parser, ComplexLiteralConstant)
   NODE(parser, ComplexPart)
   NODE(parser, ComponentArraySpec)
diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp
index 3cc20a1e7e073..ece68230ca21f 100644
--- a/flang/lib/Lower/Bridge.cpp
+++ b/flang/lib/Lower/Bridge.cpp
@@ -1929,12 +1929,13 @@ class FirConverter : public Fortran::lower::AbstractConverter {
   void addLoopAnnotationAttr(IncrementLoopInfo &info) {
     mlir::BoolAttr f = mlir::BoolAttr::get(builder->getContext(), false);
     mlir::LLVM::LoopVectorizeAttr va = mlir::LLVM::LoopVectorizeAttr::get(
-        builder->getContext(), f, {}, {}, {}, {}, {}, {});
+        builder->getContext(), /*disable=*/f, {}, {}, {}, {}, {}, {});
+    // Create distinct access group
     mlir::LLVM::AccessGroupAttr ag =
         mlir::LLVM::AccessGroupAttr::get(builder->getContext());
     mlir::LLVM::LoopAnnotationAttr la = mlir::LLVM::LoopAnnotationAttr::get(
-        builder->getContext(), {}, va, {}, {}, {}, {}, {}, {}, {}, {}, {}, {},
-        {}, {}, {ag});
+        builder->getContext(), {}, /*vectorize=*/va, {}, {}, {}, {}, {}, {}, {},
+        {}, {}, {}, {}, {}, /*parallelAccess=*/{ag});
     info.doLoop.setLoopAnnotationAttr(la);
   }
 
diff --git a/flang/lib/Semantics/CMakeLists.txt b/flang/lib/Semantics/CMakeLists.txt
index 809206565fc1c..41406ecf50e00 100644
--- a/flang/lib/Semantics/CMakeLists.txt
+++ b/flang/lib/Semantics/CMakeLists.txt
@@ -2,6 +2,7 @@ add_flang_library(FortranSemantics
   assignment.cpp
   attr.cpp
   canonicalize-acc.cpp
+  canonicalize-directives.cpp
   canonicalize-do.cpp
   canonicalize-omp.cpp
   check-acc-structure.cpp
diff --git a/flang/lib/Semantics/canonicalize-directives.cpp b/flang/lib/Semantics/canonicalize-directives.cpp
new file mode 100644
index 0000000000000..a3908127d9e8b
--- /dev/null
+++ b/flang/lib/Semantics/canonicalize-directives.cpp
@@ -0,0 +1,122 @@
+//===-- lib/Semantics/canonicalize-directives.cpp -------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "canonicalize-directives.h"
+#include "flang/Parser/parse-tree-visitor.h"
+
+namespace Fortran::semantics {
+
+using namespace parser::literals;
+
+// Check that directives are associated with the correct constructs
+class CanonicalizationOfDirectives {
+public:
+  CanonicalizationOfDirectives(parser::Messages &messages)
+      : messages_{messages} {}
+
+  template <typename T> bool Pre(T &) { return true; }
+  template <typename T> void Post(T &) {}
+
+  // Move directives that must appear in the Execution part out of the
+  // Specification part.
+  void Post(parser::SpecificationPart &spec);
+  bool Pre(parser::ExecutionPart &x);
+
+  // Ensure that directives associated with constructs appear accompanying the
+  // construct.
+  void Post(parser::Block &block);
+
+private:
+  // Ensure that loop directives appear immediately before a loop.
+  void CheckLoopDirective(parser::CompilerDirective &dir, parser::Block &block,
+      std::list<parser::ExecutionPartConstruct>::iterator it);
+
+  parser::Messages &messages_;
+
+  // Directives to be moved to the Execution part from the Specification part.
+  std::list<common::Indirection<parser::CompilerDirective>>
+      directivesToConvert_;
+};
+
+bool CanonicalizeDirectives(
+    parser::Messages &messages, parser::Program &program) {
+  CanonicalizationOfDirectives dirs{messages};
+  Walk(program, dirs);
+  return !messages.AnyFatalError();
+}
+
+static bool IsExecutionDirective(const parser::CompilerDirective &dir) {
+  return std::holds_alternative<parser::CompilerDirective::VectorAlways>(dir.u);
+}
+
+void CanonicalizationOfDirectives::Post(parser::SpecificationPart &spec) {
+  auto &list{
+      std::get<std::list<common::Indirection<parser::CompilerDirective>>>(
+          spec.t)};
+  for (auto it{list.begin()}; it != list.end();) {
+    if (IsExecutionDirective(it->value())) {
+      directivesToConvert_.emplace_back(std::move(*it));
+      it = list.erase(it);
+    } else {
+      ++it;
+    }
+  }
+}
+
+bool CanonicalizationOfDirectives::Pre(parser::ExecutionPart &x) {
+  auto origFirst{x.v.begin()};
+  for (auto &dir : directivesToConvert_) {
+    x.v.insert(origFirst,
+        parser::ExecutionPartConstruct{
+            parser::ExecutableConstruct{std::move(dir)}});
+  }
+
+  directivesToConvert_.clear();
+  return true;
+}
+
+template <typename T> T *GetConstructIf(parser::ExecutionPartConstruct &x) {
+  if (auto *y{std::get_if<parser::ExecutableConstruct>(&x.u)}) {
+    if (auto *z{std::get_if<common::Indirection<T>>(&y->u)}) {
+      return &z->value();
+    }
+  }
+  return nullptr;
+}
+
+void CanonicalizationOfDirectives::CheckLoopDirective(
+    parser::CompilerDirective &dir, parser::Block &block,
+    std::list<parser::ExecutionPartConstruct>::iterator it) {
+
+  // Skip over this and other compiler directives
+  while (GetConstructIf<parser::CompilerDirective>(*it)) {
+    ++it;
+  }
+
+  if (it == block.end() || !GetConstructIf<parser::DoConstruct>(*it)) {
+    std::string s{parser::ToUpperCaseLetters(dir.source.ToString())};
+    s.pop_back(); // Remove trailing newline from source string
+    messages_.Say(
+        dir.source, "A DO loop must follow the %s directive"_err_en_US, s);
+  }
+}
+
+void CanonicalizationOfDirectives::Post(parser::Block &block) {
+  for (auto it = block.begin(); it != block.end(); ++it) {
+    if (auto *dir{GetConstructIf<parser::CompilerDirective>(*it)}) {
+      std::visit(
+          common::visitors{[&](parser::CompilerDirective::VectorAlways &) {
+                             CheckLoopDirective(*dir, block, it);
+                           },
+              [&](auto &) {}},
+          dir->u);
+    }
+  }
+}
+
+} // namespace Fortran::semantics
diff --git a/flang/lib/Semantics/canonicalize-directives.h b/flang/lib/Semantics/canonicalize-directives.h
new file mode 100644
index 0000000000000..52c74d91dbff8
--- /dev/null
+++ b/flang/lib/Semantics/canonicalize-directives.h
@@ -0,0 +1,22 @@
+//===-- lib/Semantics/canonicalize-directives.h -----------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef FORTRAN_SEMANTICS_CHECK_DIRECTIVES_H_
+#define FORTRAN_SEMANTICS_CHECK_DIRECTIVES_H_
+
+namespace Fortran::parser {
+struct Program;
+class Messages;
+} // namespace Fortran::parser
+
+namespace Fortran::semantics {
+bool CanonicalizeDirectives(
+    parser::Messages &messages, parser::Program &program);
+}
+
+#endif // FORTRAN_SEMANTICS_CHECK_DIRECTIVES_H_
diff --git a/flang/lib/Semantics/resolve-names.cpp b/flang/lib/Semantics/resolve-names.cpp
index 637a286bb3c97..d18b429c49593 100644
--- a/flang/lib/Semantics/resolve-names.cpp
+++ b/flang/lib/Semantics/resolve-names.cpp
@@ -8854,10 +8854,8 @@ void ResolveNamesVisitor::Post(const parser::AssignedGotoStmt &x) {
 }
 
 void ResolveNamesVisitor::Post(const parser::CompilerDirective &x) {
-  //if (const auto *dir{
-  //        std::get_if<parser::CompilerDirective::VectorAlways>(&x.u)})
-
-  if (std::holds_alternative<parser::CompilerDirective::VectorAlways>(x.u)) {
+  if (const auto *dir{
+          std::get_if<parser::CompilerDirective::VectorAlways>(&x.u)}) {
     return;
   }
   if (const auto *tkr{
diff --git a/flang/lib/Semantics/semantics.cpp b/flang/lib/Semantics/semantics.cpp
index d51cc62d804e8..1bb0679b75110 100644
--- a/flang/lib/Semantics/semantics.cpp
+++ b/flang/lib/Semantics/semantics.cpp
@@ -9,6 +9,7 @@
 #include "flang/Semantics/semantics.h"
 #include "assignment.h"
 #include "canonicalize-acc.h"
+#include "canonicalize-directives.h"
 #include "canonicalize-do.h"
 #include "canonicalize-omp.h"
 #include "check-acc-structure.h"
@@ -599,6 +600,7 @@ bool Semantics::Perform() {
       CanonicalizeAcc(context_.messages(), program_) &&
       CanonicalizeOmp(context_.messages(), program_) &&
       CanonicalizeCUDA(program_) &&
+      CanonicalizeDirectives(context_.messages(), program_) &&
       PerformStatementSemantics(context_, program_) &&
       ModFileWriter{context_}.WriteAll();
 }
diff --git a/flang/test/Fir/vector-always-cfg.fir b/flang/test/Fir/vector-always-cfg.fir
new file mode 100644
index 0000000000000..45c2ea056a707
--- /dev/null
+++ b/flang/test/Fir/vector-always-cfg.fir
@@ -0,0 +1,32 @@
+// RUN: fir-opt --fir-to-llvm-ir %s | FileCheck %s
+
+#access_group = #llvm.access_group<id = distinct[0]<>>
+// CHECK: #[[ACCESS:.*]] = #llvm.access_group<id = distinct[0]<>>
+#loop_vectorize = #llvm.loop_vectorize<disable = false>
+// CHECK: #[[VECTORIZE:.*]] = #llvm.loop_vectorize<disable = false>
+#loop_annotation = #llvm.loop_annotation<vectorize = #loop_vectorize, parallelAccesses = #access_group>
+// CHECK: #[[ANNOTATION:.*]] = #llvm.loop_annotation<vectorize = #[[VECTORIZE]], parallelAccesses = #[[ACCESS]]>
+
+func.func @_QPvector_always() -> i32 {
+  %c1 = arith.constant 1 : index
+  %c10_i32 = arith.constant 10 : i32
+  %c1_i32 = arith.constant 1 : i32
+  %c10 = arith.constant 10 : index
+  %0 = arith.subi %c10, %c1 : index
+  %1 = arith.addi %0, %c1 : index
+  %2 = arith.divsi %1, %c1 : index
+  cf.br ^bb1(%c1, %c1_i32, %2 : index, i32, index)
+^bb1(%3: index, %4: i32, %5: index):  // 2 preds: ^bb0, ^bb2
+  %c0 = arith.constant 0 : index
+  %6 = arith.cmpi sgt, %5, %c0 : index
+  cf.cond_br %6, ^bb2, ^bb3 {loop_annotation = #loop_annotation}
+// CHECK:   llvm.cond_br %{{.*}}, ^{{.*}}, ^{{.*}} {loop_annotation = #[[ANNOTATION]]}
+^bb2:  // pred: ^bb1
+  %7 = arith.addi %3, %c1 : index
+  %c1_0 = arith.constant 1 : index
+  %8 = arith.subi %5, %c1_0 : index
+  cf.br ^bb1(%7, %c1_i32, %8 : index, i32, index)
+^bb3:  // pred: ^bb1
+  return %4 : i32
+}
+
diff --git a/flang/test/Fir/vector-always.fir b/flang/test/Fir/vector-always.fir
index e7e07c41ffee9..4ebb7912e137b 100644
--- a/flang/test/Fir/vector-always.fir
+++ b/flang/test/Fir/vector-always.fir
@@ -1,42 +1,21 @@
-// RUN: %flang_fc1 -emit-llvm -o - %s | FileCheck %s
+// RUN: fir-opt --cfg-conversion %s | FileCheck %s
 
 #access_group = #llvm.access_group<id = distinct[0]<>>
+// CHECK: #[[ACCESS:.*]] = #llvm.access_group<id = distinct[0]<>>
 #loop_vectorize = #llvm.loop_vectorize<disable = false>
+// CHECK: #[[VECTORIZE:.*]] = #llvm.loop_vectorize<disable = false>
 #loop_annotation = #llvm.loop_annotation<vectorize = #loop_vectorize, parallelAccesses = #access_group>
+// CHECK: #[[ANNOTATION:.*]] = #llvm.loop_annotation<vectorize = #[[VECTORIZE]], parallelAccesses = #[[ACCESS]]>
 
-// CHECK-LABEL: @vector_always_
-// CHECK:   br i1 {{.*}}, label {{.*}}, label {{.*}}, !llvm.loop ![[ANNOTATION:.*]]
-func.func @_QPvector_always() {
+// CHECK-LABEL: @_QPvector_always
+func.func @_QPvector_always() -> i32 {
     %c1 = arith.constant 1 : index
     %c10_i32 = arith.constant 10 : i32
     %c1_i32 = arith.constant 1 : i32
     %c10 = arith.constant 10 : index
-    %0 = fir.alloca !fir.array<10xi32> {bindc_name = "a", uniq_name = "_QFvector_alwaysEa"}
-    %1 = fir.shape %c10 : (index) -> !fir.shape<1>
-    %2 = fir.declare %0(%1) {uniq_name = "_QFvector_alwaysEa"} : (!fir.ref<!fir.array<10xi32>>, !fir.shape<1>) -> !fir.ref<!fir.array<10xi32>>
-    %3 = fir.alloca i32 {bindc_name = "i", uniq_name = "_QFvector_alwaysEi"}
-    %4 = fir.declare %3 {uniq_name = "_QFvector_alwaysEi"} : (!fir.ref<i32>) -> !fir.ref<i32>
-    %5 = fir.convert %c1_i32 : (i32) -> index
-    %6 = fir.convert %c10_i32 : (i32) -> index
-    %7 = fir.convert %5 : (index) -> i32
-    %8:2 = fir.do_loop %arg0 = %5 to %6 step %c1 iter_args(%arg1 = %7) -> (index, i32) attributes {loopAnnotation = #loop_annotation} {
-      fir.store %arg1 to %4 : !fir.ref<i32>
-      %9 = fir.load %4 : !fir.ref<i32>
-      %10 = fir.load %4 : !fir.ref<i32>
-      %11 = fir.convert %10 : (i32) -> i64
-      %12 = fir.array_coor %2(%1) %11 : (!fir.ref<!fir.array<10xi32>>, !fir.shape<1>, i64) -> !fir.ref<i32>
-      fir.store %9 to %12 : !fir.ref<i32>
-      %13 = arith.addi %arg0, %c1 : index
-      %14 = fir.convert %c1 : (index) -> i32
-      %15 = fir.load %4 : !fir.ref<i32>
-      %16 = arith.addi %15, %14 : i32
-      fir.result %13, %16 : index, i32
+// CHECK:   cf.cond_br %{{.*}}, ^{{.*}}, ^{{.*}} {loop_annotation = #[[ANNOTATION]]}
+    %8:2 = fir.do_loop %arg0 = %c1 to %c10 step %c1 iter_args(%arg1 = %c1_i32) -> (index, i32) attributes {loopAnnotation = #loop_annotation} {
+      fir.result %c1, %c1_i32 : index, i32
     }
-    fir.store %8#1 to %4 : !fir.ref<i32>
-    return
-  }
-
-// CHECK: ![[ANNOTATION]] = distinct !{![[ANNOTATION]], ![[VECTORIZE:.*]], ![[PAR_ACCESS:.*]]}
-// CHECK: ![[VECTORIZE]] = !{!"llvm.loop.vectorize.enable", i1 true}
-// CHECK: ![[PAR_ACCESS]] = !{!"llvm.loop.parallel_accesses", ![[DISTINCT:.*]]}
-// CHECK: ![[DISTINCT]] = distinct !{}
+    return %8#1 : i32
+  }
\ No newline at end of file
diff --git a/flang/test/Integration/vector-always.f90 b/flang/test/Integration/vector-always.f90
new file mode 100644
index 0000000000000..5aa04248886cd
--- /dev/null
+++ b/flang/test/Integration/vector-always.f90
@@ -0,0 +1,16 @@
+! RUN: %flang_fc1 -emit-llvm -o - %s | FileCheck %s
+
+! CHECK-LABEL: vector_always
+subroutine vector_always
+  integer :: a(10)
+  !dir$ vector always
+  ! CHECK:   br i1 {{.*}}, label {{.*}}, label {{.*}}, !llvm.loop ![[ANNOTATION:.*]]
+  do i=1,10
+     a(i)=i
+  end do
+end subroutine vector_always
+
+! CHECK: ![[ANNOTATION]] = distinct !{![[ANNOTATION]], ![[VECTORIZE:.*]], ![[PAR_ACCESS:.*]]}
+! CHECK: ![[VECTORIZE]] = !{!"llvm.loop.vectorize.enable", i1 true}
+! CHECK: ![[PAR_ACCESS]] = !{!"llvm.loop.parallel_accesses", ![[DISTINCT:.*]]}
+! CHECK: ![[DISTINCT]] = distinct !{}
diff --git a/flang/test/Parser/compiler-directives.f90 b/flang/test/Parser/compiler-directives.f90
index d4c99ae12f14e..246eaf985251c 100644
--- a/flang/test/Parser/compiler-directives.f90
+++ b/flang/test/Parser/compiler-directives.f90
@@ -1,4 +1,4 @@
-! RUN: %flang_fc1 -fdebug-unparse %s 2>&1
+! RUN: %flang_fc1 -fdebug-unparse %s 2>&1 | FileCheck %s
 
 ! Test that compiler directives can appear in various places.
 
@@ -28,3 +28,10 @@ module m
      !dir$  align : 1024 :: d
   end type stuff
 end
+
+subroutine vector_always
+  !dir$ vector always
+  ! CHECK: !DIR$ VECTOR ALWAYS
+  do i=1,10
+  enddo
+end subroutine
diff --git a/flang/test/Semantics/loop-directives.f90 b/flang/test/Semantics/loop-directives.f90
new file mode 100644
index 0000000000000..9c7e6dadad3bd
--- /dev/null
+++ b/flang/test/Semantics/loop-directives.f90
@@ -0,0 +1,13 @@
+! RUN: %python %S/test_errors.py %s %flang
+
+program empty
+  ! ERROR: A DO loop must follow the VECTOR ALWAYS directive
+  !dir$ vector always
+end program empty
+
+program non_do
+  ! ERROR: A DO loop must follow the VECTOR ALWAYS directive
+  !dir$ vector always
+  a = 1
+end program non_do
+



More information about the flang-commits mailing list