[flang] [llvm] [mlir] [MLIR][OpenMP] Lowering nontemporal clause to LLVM IR for SIMD directive (PR #118751)

Kaviya Rajendiran via llvm-commits llvm-commits at lists.llvm.org
Tue Apr 22 23:56:57 PDT 2025


https://github.com/kaviya2510 updated https://github.com/llvm/llvm-project/pull/118751

>From a7a485940e3582e5375a39599a8676d2fabb3388 Mon Sep 17 00:00:00 2001
From: Kaviya Rajendiran <kaviyara2000 at gmail.com>
Date: Fri, 6 Dec 2024 14:49:39 +0530
Subject: [PATCH 1/5] [MLIR][OpenMP] Lowering nontemporal clause to LLVM IR for
 SIMD directive

---
 .../llvm/Frontend/OpenMP/OMPIRBuilder.h       |  2 +-
 llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp     | 85 +++++++++++++++-
 .../OpenMP/OpenMPToLLVMIRTranslation.cpp      | 15 +--
 .../Target/LLVMIR/openmp-nontemporal.mlir     | 96 +++++++++++++++++++
 mlir/test/Target/LLVMIR/openmp-todo.mlir      | 13 ---
 5 files changed, 190 insertions(+), 21 deletions(-)
 create mode 100644 mlir/test/Target/LLVMIR/openmp-nontemporal.mlir

diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
index 80b4aa2bd2855..fc726ec6cf4b4 100644
--- a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
+++ b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
@@ -1245,7 +1245,7 @@ class OpenMPIRBuilder {
   void applySimd(CanonicalLoopInfo *Loop,
                  MapVector<Value *, Value *> AlignedVars, Value *IfCond,
                  omp::OrderKind Order, ConstantInt *Simdlen,
-                 ConstantInt *Safelen);
+                 ConstantInt *Safelen, ArrayRef<Value *> NontempralVars = {});
 
   /// Generator for '#omp flush'
   ///
diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
index e34e93442ff85..bff201a1377c8 100644
--- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
+++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
@@ -5385,10 +5385,86 @@ OpenMPIRBuilder::getOpenMPDefaultSimdAlign(const Triple &TargetTriple,
   return 0;
 }
 
+static void appendNontemporalVars(BasicBlock *Block,
+                                  SmallVectorImpl<Value *> &NontemporalVars) {
+  for (Instruction &I : *Block) {
+    if (const CallInst *CI = dyn_cast<CallInst>(&I)) {
+      if (CI->getIntrinsicID() == Intrinsic::memcpy) {
+        llvm::Value *DestPtr = CI->getArgOperand(0);
+        llvm::Value *SrcPtr = CI->getArgOperand(1);
+        for (const llvm::Value *Var : NontemporalVars) {
+          if (Var == SrcPtr) {
+            NontemporalVars.push_back(DestPtr);
+            break;
+          }
+        }
+      }
+    }
+  }
+}
+
+/** Attach nontemporal metadata to the load/store instructions of nontemporal
+ *  variables of \p Block
+ * Nontemporal variables may be a scalar, fixed size or allocatable
+ * or pointer array
+ *
+ * Example scenarios for nontemporal variables:
+ * Case 1: Scalar variable
+ *    If the nontemporal variable is a scalar, it is allocated on stack.Load and
+ * store instructions directly access the alloca pointer of the scalar
+ * variable for fetching information about scalar variable or  writing
+ * into the scalar variable. Mark those load and store instructions as
+ * non-temporal.
+ *
+ * Case 2: Fixed Size array
+ *    If the nontemporal variable is a fixed-size array, it is allocated
+ * as a contiguous block of memory. It uses one GEP instruction, to compute the
+ * address of each individual array elements and perform load or store
+ * operation on it. Mark those load and store instructions as non-temporal.
+ *
+ * Case 3: Allocatable array
+ *     For an allocatable array, which might involve runtime type descriptor,
+ * needs to navigate through descriptors using two or more GEP and load
+ * instructions to compute the address of each individual element in an array.
+ * Mark those load or store which access the individual array elements as
+ * non-temporal.
+ */
+static void addNonTemporalMetadata(BasicBlock *Block, MDNode *Nontemporal,
+                                   SmallVectorImpl<Value *> &NontemporalVars) {
+  appendNontemporalVars(Block, NontemporalVars);
+  for (Instruction &I : *Block) {
+    llvm::Value *mem_ptr = nullptr;
+    bool MetadataFlag = true;
+    if (llvm::LoadInst *li = dyn_cast<llvm::LoadInst>(&I)) {
+      if (!(li->getType()->isPointerTy()))
+        mem_ptr = li->getPointerOperand();
+    } else if (llvm::StoreInst *si = dyn_cast<llvm::StoreInst>(&I))
+      mem_ptr = si->getPointerOperand();
+    if (mem_ptr) {
+      while (mem_ptr && !(isa<llvm::AllocaInst>(mem_ptr))) {
+        if (llvm::GetElementPtrInst *gep =
+                dyn_cast<llvm::GetElementPtrInst>(mem_ptr)) {
+          llvm::Type *sourceType = gep->getSourceElementType();
+          if (sourceType->isStructTy() && gep->getNumIndices() >= 2 &&
+              !(gep->hasAllZeroIndices())) {
+            MetadataFlag = false;
+            break;
+          }
+          mem_ptr = gep->getPointerOperand();
+        } else if (llvm::LoadInst *li = dyn_cast<llvm::LoadInst>(mem_ptr))
+          mem_ptr = li->getPointerOperand();
+      }
+      if (MetadataFlag && is_contained(NontemporalVars, mem_ptr))
+        I.setMetadata(LLVMContext::MD_nontemporal, Nontemporal);
+    }
+  }
+}
+
 void OpenMPIRBuilder::applySimd(CanonicalLoopInfo *CanonicalLoop,
                                 MapVector<Value *, Value *> AlignedVars,
                                 Value *IfCond, OrderKind Order,
-                                ConstantInt *Simdlen, ConstantInt *Safelen) {
+                                ConstantInt *Simdlen, ConstantInt *Safelen,
+                                ArrayRef<Value *> NontemporalVarsIn) {
   LLVMContext &Ctx = Builder.getContext();
 
   Function *F = CanonicalLoop->getFunction();
@@ -5486,6 +5562,13 @@ void OpenMPIRBuilder::applySimd(CanonicalLoopInfo *CanonicalLoop,
   }
 
   addLoopMetadata(CanonicalLoop, LoopMDList);
+  SmallVector<Value *> NontemporalVars{NontemporalVarsIn};
+  // Set nontemporal metadata to load and stores of nontemporal values
+  if (NontemporalVars.size()) {
+    MDNode *NontemporalNode = MDNode::getDistinct(Ctx, {});
+    for (BasicBlock *BB : Reachable)
+      addNonTemporalMetadata(BB, NontemporalNode, NontemporalVars);
+  }
 }
 
 /// Create the TargetMachine object to query the backend for optimization
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 32c7c501d03c3..1c9690a1c7b68 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -189,10 +189,6 @@ static LogicalResult checkImplementationStatus(Operation &op) {
     if (!op.getLinearVars().empty() || !op.getLinearStepVars().empty())
       result = todo("linear");
   };
-  auto checkNontemporal = [&todo](auto op, LogicalResult &result) {
-    if (!op.getNontemporalVars().empty())
-      result = todo("nontemporal");
-  };
   auto checkNowait = [&todo](auto op, LogicalResult &result) {
     if (op.getNowait())
       result = todo("nowait");
@@ -300,7 +296,6 @@ static LogicalResult checkImplementationStatus(Operation &op) {
       })
       .Case([&](omp::SimdOp op) {
         checkLinear(op, result);
-        checkNontemporal(op, result);
         checkReduction(op, result);
       })
       .Case<omp::AtomicReadOp, omp::AtomicWriteOp, omp::AtomicUpdateOp,
@@ -2527,6 +2522,14 @@ convertOmpSimd(Operation &opInst, llvm::IRBuilderBase &builder,
 
   llvm::MapVector<llvm::Value *, llvm::Value *> alignedVars;
   llvm::omp::OrderKind order = convertOrderKind(simdOp.getOrder());
+
+  llvm::SmallVector<llvm::Value *> nontemporalVars;
+  mlir::OperandRange nontemporals = simdOp.getNontemporalVars();
+  for (mlir::Value nontemporal : nontemporals) {
+    llvm::Value *nt = moduleTranslation.lookupValue(nontemporal);
+    nontemporalVars.push_back(nt);
+  }
+
   llvm::BasicBlock *sourceBlock = builder.GetInsertBlock();
   std::optional<ArrayAttr> alignmentValues = simdOp.getAlignments();
   mlir::OperandRange operands = simdOp.getAlignedVars();
@@ -2558,7 +2561,7 @@ convertOmpSimd(Operation &opInst, llvm::IRBuilderBase &builder,
                         simdOp.getIfExpr()
                             ? moduleTranslation.lookupValue(simdOp.getIfExpr())
                             : nullptr,
-                        order, simdlen, safelen);
+                        order, simdlen, safelen, nontemporalVars);
 
   return cleanupPrivateVars(builder, moduleTranslation, simdOp.getLoc(),
                             llvmPrivateVars, privateDecls);
diff --git a/mlir/test/Target/LLVMIR/openmp-nontemporal.mlir b/mlir/test/Target/LLVMIR/openmp-nontemporal.mlir
new file mode 100644
index 0000000000000..f8cee94be4ff7
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/openmp-nontemporal.mlir
@@ -0,0 +1,96 @@
+// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
+
+// -----
+// CHECK-LABEL: @simd_nontemporal
+llvm.func @simd_nontemporal() {
+  %0 = llvm.mlir.constant(10 : i64) : i64
+  %1 = llvm.mlir.constant(1 : i64) : i64
+  %2 = llvm.alloca %1 x i64 : (i64) -> !llvm.ptr
+  %3 = llvm.alloca %1 x i64 : (i64) -> !llvm.ptr
+  //CHECK: %[[A_ADDR:.*]] = alloca i64, i64 1, align 8
+  //CHECK: %[[B_ADDR:.*]] = alloca i64, i64 1, align 8
+  //CHECK: %[[B:.*]] = load i64, ptr %[[B_ADDR]], align 4, !nontemporal !1, !llvm.access.group !2
+  //CHECK: store i64 %[[B]], ptr %[[A_ADDR]], align 4, !nontemporal !1, !llvm.access.group !2
+  omp.simd nontemporal(%2, %3 : !llvm.ptr, !llvm.ptr) {
+    omp.loop_nest (%arg0) : i64 = (%1) to (%0) inclusive step (%1) {
+      %4 = llvm.load %3 : !llvm.ptr -> i64
+      llvm.store %4, %2 : i64, !llvm.ptr
+      omp.yield
+    }
+  }
+  llvm.return
+}
+
+// -----
+
+//CHECK-LABEL:  define void @_QPtest(ptr %0, ptr %1) {
+llvm.func @_QPtest(%arg0: !llvm.ptr {fir.bindc_name = "n"}, %arg1: !llvm.ptr {fir.bindc_name = "a"}) {
+    %0 = llvm.mlir.constant(1 : i32) : i32
+    %1 = llvm.alloca %0 x !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>)> {alignment = 8 : i64} : (i32) -> !llvm.ptr
+    %2 = llvm.alloca %0 x !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>)> {alignment = 8 : i64} : (i32) -> !llvm.ptr
+    %3 = llvm.mlir.constant(1 : i64) : i64
+    %4 = llvm.alloca %3 x i32 {bindc_name = "i", pinned} : (i64) -> !llvm.ptr
+    %6 = llvm.load %arg0 : !llvm.ptr -> i32
+    //CHECK:  %[[A_VAL1:.*]] = alloca { ptr, i64, i32, i8, i8, i8, i8, [1 x [3 x i64]] }, align 8
+    //CHECK:  %[[A_VAL2:.*]] = alloca { ptr, i64, i32, i8, i8, i8, i8, [1 x [3 x i64]] }, align 8
+    omp.simd nontemporal(%arg1 : !llvm.ptr) {
+      omp.loop_nest (%arg2) : i32 = (%0) to (%6) inclusive step (%0) {
+        llvm.store %arg2, %4 : i32, !llvm.ptr
+        //CHECK:  call void @llvm.memcpy.p0.p0.i32(ptr %[[A_VAL2]], ptr %1, i32 48, i1 false)
+        %7 = llvm.mlir.constant(48 : i32) : i32
+        "llvm.intr.memcpy"(%2, %arg1, %7) <{isVolatile = false}> : (!llvm.ptr, !llvm.ptr, i32) -> ()
+        %8 = llvm.load %4 : !llvm.ptr -> i32
+        %9 = llvm.sext %8 : i32 to i64
+        %10 = llvm.getelementptr %2[0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>)>
+        %11 = llvm.load %10 : !llvm.ptr -> !llvm.ptr
+        %12 = llvm.mlir.constant(0 : index) : i64
+        %13 = llvm.getelementptr %2[0, 7, %12, 0] : (!llvm.ptr, i64) -> !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>)>
+        %14 = llvm.load %13 : !llvm.ptr -> i64
+        %15 = llvm.getelementptr %2[0, 7, %12, 1] : (!llvm.ptr, i64) -> !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>)>
+        %16 = llvm.load %15 : !llvm.ptr -> i64
+        %17 = llvm.getelementptr %2[0, 7, %12, 2] : (!llvm.ptr, i64) -> !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>)>
+        %18 = llvm.load %17 : !llvm.ptr -> i64
+        %19 = llvm.mlir.constant(0 : i64) : i64
+        %20 = llvm.sub %9, %14 overflow<nsw> : i64
+        %21 = llvm.mul %20, %3 overflow<nsw> : i64
+        %22 = llvm.mul %21, %3 overflow<nsw> : i64
+        %23 = llvm.add %22,%19 overflow<nsw> : i64
+        %24 = llvm.mul %3, %16 overflow<nsw> : i64
+        //CHECK:  %[[VAL1:.*]] = getelementptr float, ptr {{.*}}, i64 %{{.*}}
+        //CHECK:  %[[LOAD_A:.*]] = load float, ptr %[[VAL1]], align 4, !nontemporal 
+        //CHECK:  %[[RES:.*]] = fadd contract float %[[LOAD_A]], 2.000000e+01
+        %25 = llvm.getelementptr %11[%23] : (!llvm.ptr, i64) -> !llvm.ptr, f32
+        %26 = llvm.load %25 : !llvm.ptr -> f32
+        %27 = llvm.mlir.constant(2.000000e+01 : f32) : f32
+        %28 = llvm.fadd %26, %27 {fastmathFlags = #llvm.fastmath<contract>} : f32
+        //CHECK:  call void @llvm.memcpy.p0.p0.i32(ptr %[[A_VAL1]], ptr %1, i32 48, i1 false)
+        %29 = llvm.mlir.constant(48 : i32) : i32
+        "llvm.intr.memcpy"(%1, %arg1, %29) <{isVolatile = false}> : (!llvm.ptr, !llvm.ptr, i32) -> ()
+        %30 = llvm.load %4 : !llvm.ptr -> i32
+        %31 = llvm.sext %30 : i32 to i64
+        %32 = llvm.getelementptr %1[0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>)>
+        %33 = llvm.load %32 : !llvm.ptr -> !llvm.ptr
+        %34 = llvm.mlir.constant(0 : index) : i64
+        %35 = llvm.getelementptr %1[0, 7, %34, 0] : (!llvm.ptr, i64) -> !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>)>
+        %36 = llvm.load %35 : !llvm.ptr -> i64
+        %37 = llvm.getelementptr %1[0, 7, %34, 1] : (!llvm.ptr, i64) -> !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>)>
+        %38 = llvm.load %37 : !llvm.ptr -> i64
+        %39 = llvm.getelementptr %1[0, 7, %34, 2] : (!llvm.ptr, i64) -> !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>)>
+        %40 = llvm.load %39 : !llvm.ptr -> i64
+        %41 = llvm.sub %31, %36 overflow<nsw> : i64
+        %42 = llvm.mul %41, %3 overflow<nsw> : i64
+        %43 = llvm.mul %42, %3 overflow<nsw> : i64
+        %44 = llvm.add %43,%19 overflow<nsw> : i64
+        %45 = llvm.mul %3, %38 overflow<nsw> : i64
+        //CHECK:  %[[VAL2:.*]] = getelementptr float, ptr %{{.*}}, i64 %{{.*}}
+        //CHECK:  store float %[[RES]], ptr %[[VAL2]], align 4, !nontemporal 
+        %46 = llvm.getelementptr %33[%44] : (!llvm.ptr, i64) -> !llvm.ptr, f32
+        llvm.store %28, %46 : f32, !llvm.ptr
+        omp.yield
+      }
+    }
+    llvm.return
+  }
+// -----
+ 
+
diff --git a/mlir/test/Target/LLVMIR/openmp-todo.mlir b/mlir/test/Target/LLVMIR/openmp-todo.mlir
index f907bb3f94a2a..a6ed73b158bdb 100644
--- a/mlir/test/Target/LLVMIR/openmp-todo.mlir
+++ b/mlir/test/Target/LLVMIR/openmp-todo.mlir
@@ -184,19 +184,6 @@ llvm.func @simd_linear(%lb : i32, %ub : i32, %step : i32, %x : !llvm.ptr) {
 
 // -----
 
-llvm.func @simd_nontemporal(%lb : i32, %ub : i32, %step : i32, %x : !llvm.ptr) {
-  // expected-error at below {{not yet implemented: Unhandled clause nontemporal in omp.simd operation}}
-  // expected-error at below {{LLVM Translation failed for operation: omp.simd}}
-  omp.simd nontemporal(%x : !llvm.ptr) {
-    omp.loop_nest (%iv) : i32 = (%lb) to (%ub) step (%step) {
-      omp.yield
-    }
-  }
-  llvm.return
-}
-
-// -----
-
 omp.declare_reduction @add_f32 : f32
 init {
 ^bb0(%arg: f32):

>From 53cc22910ede9382ecb2258f390515a530edc140 Mon Sep 17 00:00:00 2001
From: Kaviya Rajendiran <kaviyara2000 at gmail.com>
Date: Wed, 5 Mar 2025 12:20:45 +0530
Subject: [PATCH 2/5] [MLIR][OpenMP] Create a callback function for adding
 nontemporal metadata and handle the translation in
 OpenMPToLLVMIRTranslation.cpp

---
 .../llvm/Frontend/OpenMP/OMPIRBuilder.h       | 16 +++-
 llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp     | 91 ++-----------------
 .../OpenMP/OpenMPToLLVMIRTranslation.cpp      | 91 +++++++++++++++++--
 3 files changed, 104 insertions(+), 94 deletions(-)

diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
index fc726ec6cf4b4..60c64e369ca54 100644
--- a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
+++ b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
@@ -1229,6 +1229,9 @@ class OpenMPIRBuilder {
   void unrollLoopPartial(DebugLoc DL, CanonicalLoopInfo *Loop, int32_t Factor,
                          CanonicalLoopInfo **UnrolledCLI);
 
+  using NonTemporalBodyGenCallbackTy =
+      function_ref<void(llvm::BasicBlock *BB, MDNode *NontemporalNode)>;
+
   /// Add metadata to simd-ize a loop. If IfCond is not nullptr, the loop
   /// is cloned. The metadata which prevents vectorization is added to
   /// to the cloned loop. The cloned loop is executed when ifCond is evaluated
@@ -1242,10 +1245,15 @@ class OpenMPIRBuilder {
   /// \param Order       The enum to map order clause.
   /// \param Simdlen     The Simdlen length to apply to the simd loop.
   /// \param Safelen     The Safelen length to apply to the simd loop.
-  void applySimd(CanonicalLoopInfo *Loop,
-                 MapVector<Value *, Value *> AlignedVars, Value *IfCond,
-                 omp::OrderKind Order, ConstantInt *Simdlen,
-                 ConstantInt *Safelen, ArrayRef<Value *> NontempralVars = {});
+  /// \param NontemporalCBFunc    Call back function for nontemporal.
+  /// \param NontemporalVars      Array of nontemporal vars.
+  void applySimd(
+      CanonicalLoopInfo *Loop, MapVector<Value *, Value *> AlignedVars,
+      Value *IfCond, omp::OrderKind Order, ConstantInt *Simdlen,
+      ConstantInt *Safelen,
+      NonTemporalBodyGenCallbackTy NontemporalCBFunc = [](BasicBlock *,
+                                                          MDNode *) {},
+      ArrayRef<Value *> NontempralVars = {});
 
   /// Generator for '#omp flush'
   ///
diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
index bff201a1377c8..5973c0d63aa5d 100644
--- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
+++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
@@ -5385,86 +5385,11 @@ OpenMPIRBuilder::getOpenMPDefaultSimdAlign(const Triple &TargetTriple,
   return 0;
 }
 
-static void appendNontemporalVars(BasicBlock *Block,
-                                  SmallVectorImpl<Value *> &NontemporalVars) {
-  for (Instruction &I : *Block) {
-    if (const CallInst *CI = dyn_cast<CallInst>(&I)) {
-      if (CI->getIntrinsicID() == Intrinsic::memcpy) {
-        llvm::Value *DestPtr = CI->getArgOperand(0);
-        llvm::Value *SrcPtr = CI->getArgOperand(1);
-        for (const llvm::Value *Var : NontemporalVars) {
-          if (Var == SrcPtr) {
-            NontemporalVars.push_back(DestPtr);
-            break;
-          }
-        }
-      }
-    }
-  }
-}
-
-/** Attach nontemporal metadata to the load/store instructions of nontemporal
- *  variables of \p Block
- * Nontemporal variables may be a scalar, fixed size or allocatable
- * or pointer array
- *
- * Example scenarios for nontemporal variables:
- * Case 1: Scalar variable
- *    If the nontemporal variable is a scalar, it is allocated on stack.Load and
- * store instructions directly access the alloca pointer of the scalar
- * variable for fetching information about scalar variable or  writing
- * into the scalar variable. Mark those load and store instructions as
- * non-temporal.
- *
- * Case 2: Fixed Size array
- *    If the nontemporal variable is a fixed-size array, it is allocated
- * as a contiguous block of memory. It uses one GEP instruction, to compute the
- * address of each individual array elements and perform load or store
- * operation on it. Mark those load and store instructions as non-temporal.
- *
- * Case 3: Allocatable array
- *     For an allocatable array, which might involve runtime type descriptor,
- * needs to navigate through descriptors using two or more GEP and load
- * instructions to compute the address of each individual element in an array.
- * Mark those load or store which access the individual array elements as
- * non-temporal.
- */
-static void addNonTemporalMetadata(BasicBlock *Block, MDNode *Nontemporal,
-                                   SmallVectorImpl<Value *> &NontemporalVars) {
-  appendNontemporalVars(Block, NontemporalVars);
-  for (Instruction &I : *Block) {
-    llvm::Value *mem_ptr = nullptr;
-    bool MetadataFlag = true;
-    if (llvm::LoadInst *li = dyn_cast<llvm::LoadInst>(&I)) {
-      if (!(li->getType()->isPointerTy()))
-        mem_ptr = li->getPointerOperand();
-    } else if (llvm::StoreInst *si = dyn_cast<llvm::StoreInst>(&I))
-      mem_ptr = si->getPointerOperand();
-    if (mem_ptr) {
-      while (mem_ptr && !(isa<llvm::AllocaInst>(mem_ptr))) {
-        if (llvm::GetElementPtrInst *gep =
-                dyn_cast<llvm::GetElementPtrInst>(mem_ptr)) {
-          llvm::Type *sourceType = gep->getSourceElementType();
-          if (sourceType->isStructTy() && gep->getNumIndices() >= 2 &&
-              !(gep->hasAllZeroIndices())) {
-            MetadataFlag = false;
-            break;
-          }
-          mem_ptr = gep->getPointerOperand();
-        } else if (llvm::LoadInst *li = dyn_cast<llvm::LoadInst>(mem_ptr))
-          mem_ptr = li->getPointerOperand();
-      }
-      if (MetadataFlag && is_contained(NontemporalVars, mem_ptr))
-        I.setMetadata(LLVMContext::MD_nontemporal, Nontemporal);
-    }
-  }
-}
-
-void OpenMPIRBuilder::applySimd(CanonicalLoopInfo *CanonicalLoop,
-                                MapVector<Value *, Value *> AlignedVars,
-                                Value *IfCond, OrderKind Order,
-                                ConstantInt *Simdlen, ConstantInt *Safelen,
-                                ArrayRef<Value *> NontemporalVarsIn) {
+void OpenMPIRBuilder::applySimd(
+    CanonicalLoopInfo *CanonicalLoop, MapVector<Value *, Value *> AlignedVars,
+    Value *IfCond, OrderKind Order, ConstantInt *Simdlen, ConstantInt *Safelen,
+    OpenMPIRBuilder::NonTemporalBodyGenCallbackTy NontemporalCBFunc,
+    ArrayRef<Value *> NontemporalVarsIn) {
   LLVMContext &Ctx = Builder.getContext();
 
   Function *F = CanonicalLoop->getFunction();
@@ -5562,12 +5487,12 @@ void OpenMPIRBuilder::applySimd(CanonicalLoopInfo *CanonicalLoop,
   }
 
   addLoopMetadata(CanonicalLoop, LoopMDList);
-  SmallVector<Value *> NontemporalVars{NontemporalVarsIn};
+
   // Set nontemporal metadata to load and stores of nontemporal values
-  if (NontemporalVars.size()) {
+  if (NontemporalVarsIn.size()) {
     MDNode *NontemporalNode = MDNode::getDistinct(Ctx, {});
     for (BasicBlock *BB : Reachable)
-      addNonTemporalMetadata(BB, NontemporalNode, NontemporalVars);
+      NontemporalCBFunc(BB, NontemporalNode);
   }
 }
 
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 1c9690a1c7b68..ef313388df8e0 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -2462,6 +2462,25 @@ convertOrderKind(std::optional<omp::ClauseOrderKind> o) {
   llvm_unreachable("Unknown ClauseOrderKind kind");
 }
 
+static void
+appendNontemporalVars(llvm::BasicBlock *Block,
+                      SmallVectorImpl<llvm::Value *> &NontemporalVars) {
+  for (llvm::Instruction &I : *Block) {
+    if (const llvm::CallInst *CI = dyn_cast<llvm::CallInst>(&I)) {
+      if (CI->getIntrinsicID() == llvm::Intrinsic::memcpy) {
+        llvm::Value *DestPtr = CI->getArgOperand(0);
+        llvm::Value *SrcPtr = CI->getArgOperand(1);
+        for (const llvm::Value *Var : NontemporalVars) {
+          if (Var == SrcPtr) {
+            NontemporalVars.push_back(DestPtr);
+            break;
+          }
+        }
+      }
+    }
+  }
+}
+
 /// Converts an OpenMP simd loop into LLVM IR using OpenMPIRBuilder.
 static LogicalResult
 convertOmpSimd(Operation &opInst, llvm::IRBuilderBase &builder,
@@ -2523,13 +2542,71 @@ convertOmpSimd(Operation &opInst, llvm::IRBuilderBase &builder,
   llvm::MapVector<llvm::Value *, llvm::Value *> alignedVars;
   llvm::omp::OrderKind order = convertOrderKind(simdOp.getOrder());
 
-  llvm::SmallVector<llvm::Value *> nontemporalVars;
+  llvm::SmallVector<llvm::Value *> nontemporalOrigVars;
   mlir::OperandRange nontemporals = simdOp.getNontemporalVars();
   for (mlir::Value nontemporal : nontemporals) {
     llvm::Value *nt = moduleTranslation.lookupValue(nontemporal);
-    nontemporalVars.push_back(nt);
+    nontemporalOrigVars.push_back(nt);
   }
 
+  /** Call back function to attach nontemporal metadata to the load/store
+   * instructions of nontemporal variables of Block.
+   * Nontemporal variables may be a scalar, fixed size or allocatable
+   * or pointer array
+   *
+   * Example scenarios for nontemporal variables:
+   * Case 1: Scalar variable
+   *    If the nontemporal variable is a scalar, it is allocated on stack.Load
+   * and store instructions directly access the alloca pointer of the scalar
+   * variable for fetching information about scalar variable or  writing
+   * into the scalar variable. Mark those load and store instructions as
+   * non-temporal.
+   *
+   * Case 2: Fixed Size array
+   *    If the nontemporal variable is a fixed-size array, it is allocated
+   * as a contiguous block of memory. It uses one GEP instruction, to compute
+   * the address of each individual array elements and perform load or store
+   * operation on it. Mark those load and store instructions as non-temporal.
+   *
+   * Case 3: Allocatable array
+   *     For an allocatable array, which might involve runtime type descriptor,
+   * needs to navigate through descriptors using two or more GEP and load
+   * instructions to compute the address of each individual element in an array.
+   * Mark those load or store which access the individual array elements as
+   * non-temporal.
+   */
+  auto addNonTemporalMetadataCB = [&](llvm::BasicBlock *Block,
+                                      llvm::MDNode *Nontemporal) {
+    SmallVector<llvm::Value *> NontemporalVars{nontemporalOrigVars};
+    appendNontemporalVars(Block, NontemporalVars);
+    for (llvm::Instruction &I : *Block) {
+      llvm::Value *mem_ptr = nullptr;
+      bool MetadataFlag = true;
+      if (llvm::LoadInst *li = dyn_cast<llvm::LoadInst>(&I)) {
+        if (!(li->getType()->isPointerTy()))
+          mem_ptr = li->getPointerOperand();
+      } else if (llvm::StoreInst *si = dyn_cast<llvm::StoreInst>(&I))
+        mem_ptr = si->getPointerOperand();
+      if (mem_ptr) {
+        while (mem_ptr && !(isa<llvm::AllocaInst>(mem_ptr))) {
+          if (llvm::GetElementPtrInst *gep =
+                  dyn_cast<llvm::GetElementPtrInst>(mem_ptr)) {
+            llvm::Type *sourceType = gep->getSourceElementType();
+            if (sourceType->isStructTy() && gep->getNumIndices() >= 2 &&
+                !(gep->hasAllZeroIndices())) {
+              MetadataFlag = false;
+              break;
+            }
+            mem_ptr = gep->getPointerOperand();
+          } else if (llvm::LoadInst *li = dyn_cast<llvm::LoadInst>(mem_ptr))
+            mem_ptr = li->getPointerOperand();
+        }
+        if (MetadataFlag && is_contained(NontemporalVars, mem_ptr))
+          I.setMetadata(llvm::LLVMContext::MD_nontemporal, Nontemporal);
+      }
+    }
+  };
+
   llvm::BasicBlock *sourceBlock = builder.GetInsertBlock();
   std::optional<ArrayAttr> alignmentValues = simdOp.getAlignments();
   mlir::OperandRange operands = simdOp.getAlignedVars();
@@ -2557,11 +2634,11 @@ convertOmpSimd(Operation &opInst, llvm::IRBuilderBase &builder,
 
   builder.SetInsertPoint(*regionBlock, (*regionBlock)->begin());
   llvm::CanonicalLoopInfo *loopInfo = findCurrentLoopInfo(moduleTranslation);
-  ompBuilder->applySimd(loopInfo, alignedVars,
-                        simdOp.getIfExpr()
-                            ? moduleTranslation.lookupValue(simdOp.getIfExpr())
-                            : nullptr,
-                        order, simdlen, safelen, nontemporalVars);
+  ompBuilder->applySimd(
+      loopInfo, alignedVars,
+      simdOp.getIfExpr() ? moduleTranslation.lookupValue(simdOp.getIfExpr())
+                         : nullptr,
+      order, simdlen, safelen, addNonTemporalMetadataCB, nontemporalOrigVars);
 
   return cleanupPrivateVars(builder, moduleTranslation, simdOp.getLoc(),
                             llvmPrivateVars, privateDecls);

>From a95b791e164a8ccd3eeb74a807ac90545aa28718 Mon Sep 17 00:00:00 2001
From: Kaviya Rajendiran <kaviyara2000 at gmail.com>
Date: Fri, 4 Apr 2025 16:21:45 +0530
Subject: [PATCH 3/5] Lowering nontemporal clause to LLVM IR: add a Flang
 pass, lower-nontemporal, which adds the nontemporal attribute to loads and
 stores of nontemporal variables during FIR conversion

---
 .../include/flang/Optimizer/OpenMP/Passes.td  |  7 ++
 flang/lib/Optimizer/CodeGen/CodeGen.cpp       |  8 +-
 flang/lib/Optimizer/OpenMP/CMakeLists.txt     |  1 +
 .../lib/Optimizer/OpenMP/LowerNontemporal.cpp | 71 ++++++++++++++
 flang/lib/Optimizer/Passes/Pipelines.cpp      |  4 +-
 flang/test/Fir/basic-program.fir              |  2 +
 flang/test/Lower/OpenMP/simd-nontemporal.f90  | 67 +++++++++++++
 .../llvm/Frontend/OpenMP/OMPIRBuilder.h       | 16 +---
 llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp     | 16 +---
 .../OpenMP/OpenMPToLLVMIRTranslation.cpp      | 94 +------------------
 .../Target/LLVMIR/openmp-nontemporal.mlir     | 28 +++---
 11 files changed, 185 insertions(+), 129 deletions(-)
 create mode 100644 flang/lib/Optimizer/OpenMP/LowerNontemporal.cpp
 create mode 100644 flang/test/Lower/OpenMP/simd-nontemporal.f90

diff --git a/flang/include/flang/Optimizer/OpenMP/Passes.td b/flang/include/flang/Optimizer/OpenMP/Passes.td
index fcc7a4ca31fef..704faf0ccd856 100644
--- a/flang/include/flang/Optimizer/OpenMP/Passes.td
+++ b/flang/include/flang/Optimizer/OpenMP/Passes.td
@@ -81,6 +81,13 @@ def DoConcurrentConversionPass : Pass<"omp-do-concurrent-conversion", "mlir::fun
   ];
 }
 
+def LowerNontemporalPass : Pass<"lower-nontemporal", "mlir::func::FuncOp"> {
+  let summary =
+      "Adds nontemporal attribute to loads and stores performed on "
+      "the list items specified in the nontemporal clause of omp.simd.";
+  let dependentDialects = ["mlir::omp::OpenMPDialect"];
+}
+
 // Needs to be scheduled on Module as we create functions in it
 def LowerWorkshare : Pass<"lower-workshare", "::mlir::ModuleOp"> {
   let summary = "Lower workshare construct";
diff --git a/flang/lib/Optimizer/CodeGen/CodeGen.cpp b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
index b54b497ee4ba1..2f5bdd65117b2 100644
--- a/flang/lib/Optimizer/CodeGen/CodeGen.cpp
+++ b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
@@ -3550,7 +3550,13 @@ struct StoreOpConversion : public fir::FIROpConversion<fir::StoreOp> {
       newOp = rewriter.create<mlir::LLVM::MemcpyOp>(
           loc, llvmMemref, llvmValue, boxSize, /*isVolatile=*/false);
     } else {
-      newOp = rewriter.create<mlir::LLVM::StoreOp>(loc, llvmValue, llvmMemref);
+      unsigned alignment =
+          store->getAttrOfType<mlir::IntegerAttr>("alignment")
+              ? store->getAttrOfType<mlir::IntegerAttr>("alignment").getInt()
+              : 0;
+      newOp = rewriter.create<mlir::LLVM::StoreOp>(
+          loc, llvmValue, llvmMemref, alignment, store->hasAttr("volatile"),
+          store->hasAttr("nontemporal"));
     }
     if (std::optional<mlir::ArrayAttr> optionalTag = store.getTbaa())
       newOp.setTBAATags(*optionalTag);
diff --git a/flang/lib/Optimizer/OpenMP/CMakeLists.txt b/flang/lib/Optimizer/OpenMP/CMakeLists.txt
index 3acf143594356..ad89a0a606e10 100644
--- a/flang/lib/Optimizer/OpenMP/CMakeLists.txt
+++ b/flang/lib/Optimizer/OpenMP/CMakeLists.txt
@@ -8,6 +8,7 @@ add_flang_library(FlangOpenMPTransforms
   MapInfoFinalization.cpp
   MarkDeclareTarget.cpp
   LowerWorkshare.cpp
+  LowerNontemporal.cpp
 
   DEPENDS
   FIRDialect
diff --git a/flang/lib/Optimizer/OpenMP/LowerNontemporal.cpp b/flang/lib/Optimizer/OpenMP/LowerNontemporal.cpp
new file mode 100644
index 0000000000000..2284a42950a05
--- /dev/null
+++ b/flang/lib/Optimizer/OpenMP/LowerNontemporal.cpp
@@ -0,0 +1,71 @@
+//===- LowerNontemporal.cpp -------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Add nontemporal attributes to load and stores of variables marked as
+// nontemporal.
+//
+//===----------------------------------------------------------------------===//
+#include "flang/Optimizer/Dialect/FIROpsSupport.h"
+#include "flang/Optimizer/OpenMP/Passes.h"
+#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
+using namespace mlir;
+namespace flangomp {
+#define GEN_PASS_DEF_LOWERNONTEMPORALPASS
+#include "flang/Optimizer/OpenMP/Passes.h.inc"
+} // namespace flangomp
+namespace {
+class LowerNontemporalPass
+    : public flangomp::impl::LowerNontemporalPassBase<LowerNontemporalPass> {
+  void addNonTemporalAttr(omp::SimdOp simdOp) {
+    if (!simdOp.getNontemporalVars().empty()) {
+      llvm::SmallVector<mlir::Value> nontemporalOrigVars;
+      mlir::OperandRange nontemporals = simdOp.getNontemporalVars();
+      for (mlir::Value nontemporal : nontemporals) {
+        nontemporalOrigVars.push_back(nontemporal);
+      }
+      std::function<mlir::Value(mlir::Value)> getBaseOperand =
+          [&](mlir::Value operand) -> mlir::Value {
+        if (mlir::isa<fir::DeclareOp>(operand.getDefiningOp()))
+          return operand;
+        else if (auto arrayCoorOp = llvm::dyn_cast<fir::ArrayCoorOp>(
+                     operand.getDefiningOp())) {
+          return getBaseOperand(arrayCoorOp.getMemref());
+        } else if (auto boxAddrOp = llvm::dyn_cast<fir::BoxAddrOp>(
+                       operand.getDefiningOp())) {
+          return getBaseOperand(boxAddrOp.getVal());
+        } else if (auto loadOp =
+                       llvm::dyn_cast<fir::LoadOp>(operand.getDefiningOp())) {
+          return getBaseOperand(loadOp.getMemref());
+        } else {
+          return operand;
+        }
+      };
+      simdOp->walk([&](Operation *op) {
+        mlir::Value Operand = nullptr;
+        if (auto loadOp = llvm::dyn_cast<fir::LoadOp>(op)) {
+          Operand = loadOp.getMemref();
+        } else if (auto storeOp = llvm::dyn_cast<fir::StoreOp>(op)) {
+          Operand = storeOp.getMemref();
+        }
+        if (Operand && !(fir::isAllocatableType(Operand.getType()) ||
+                         fir::isPointerType((Operand.getType())))) {
+          Operand = getBaseOperand(Operand);
+          if (is_contained(nontemporalOrigVars, Operand)) {
+            // Set the attribute
+            op->setAttr("nontemporal", UnitAttr::get(op->getContext()));
+          }
+        }
+      });
+    }
+  }
+  void runOnOperation() override {
+    Operation *op = getOperation();
+    op->walk([&](omp::SimdOp simdOp) { addNonTemporalAttr(simdOp); });
+  }
+};
+} // namespace
\ No newline at end of file
diff --git a/flang/lib/Optimizer/Passes/Pipelines.cpp b/flang/lib/Optimizer/Passes/Pipelines.cpp
index 81ff6bf9b2c6a..0d96ef4ace3d7 100644
--- a/flang/lib/Optimizer/Passes/Pipelines.cpp
+++ b/flang/lib/Optimizer/Passes/Pipelines.cpp
@@ -274,8 +274,10 @@ void createHLFIRToFIRPassPipeline(mlir::PassManager &pm, bool enableOpenMP,
     addNestedPassToAllTopLevelOperations<PassConstructor>(
         pm, hlfir::createInlineHLFIRAssign);
   pm.addPass(hlfir::createConvertHLFIRtoFIR());
-  if (enableOpenMP)
+  if (enableOpenMP) {
     pm.addPass(flangomp::createLowerWorkshare());
+    pm.addPass(flangomp::createLowerNontemporalPass());
+  }
 }
 
 /// Create a pass pipeline for handling certain OpenMP transformations needed
diff --git a/flang/test/Fir/basic-program.fir b/flang/test/Fir/basic-program.fir
index ded42886aad44..9330373c0d4b6 100644
--- a/flang/test/Fir/basic-program.fir
+++ b/flang/test/Fir/basic-program.fir
@@ -65,6 +65,8 @@ func.func @_QQmain() {
 // PASSES-NEXT:     InlineHLFIRAssign
 // PASSES-NEXT:   ConvertHLFIRtoFIR
 // PASSES-NEXT:   LowerWorkshare
+// PASSES-NEXT:   'func.func' Pipeline
+// PASSES-NEXT:   LowerNontemporalPass
 // PASSES-NEXT:   CSE
 // PASSES-NEXT:   (S) 0 num-cse'd - Number of operations CSE'd
 // PASSES-NEXT:   (S) 0 num-dce'd - Number of operations DCE'd
diff --git a/flang/test/Lower/OpenMP/simd-nontemporal.f90 b/flang/test/Lower/OpenMP/simd-nontemporal.f90
new file mode 100644
index 0000000000000..c25aec3347bf1
--- /dev/null
+++ b/flang/test/Lower/OpenMP/simd-nontemporal.f90
@@ -0,0 +1,67 @@
+! Test nontemporal clause
+! RUN: %flang_fc1 -emit-fir -fopenmp -fopenmp-version=50 %s -o - | FileCheck %s
+! RUN: bbc -emit-fir -fopenmp -fopenmp-version=50 %s -o - | FileCheck %s
+
+
+! CHECK-LABEL: func @_QPsimd_with_nontemporal_clause
+subroutine simd_with_nontemporal_clause(n)
+  ! CHECK: %[[A_DECL:.*]] = fir.declare %{{.*}} {uniq_name = "_QFsimd_with_nontemporal_clauseEa"} : (!fir.ref<i32>) -> !fir.ref<i32>
+  ! CHECK: %[[C_DECL:.*]] = fir.declare %{{.*}} {uniq_name = "_QFsimd_with_nontemporal_clauseEc"} : (!fir.ref<i32>) -> !fir.ref<i32>
+  integer :: i, n
+  integer :: A, B, C
+  ! CHECK: omp.simd nontemporal(%[[A_DECL]], %[[C_DECL]] : !fir.ref<i32>, !fir.ref<i32>) private(@_QFsimd_with_nontemporal_clauseEi_private_i32 %8 -> %arg1 : !fir.ref<i32>) {
+  ! CHECK-NEXT: omp.loop_nest (%{{.*}}) : i32 = (%{{.*}}) to (%{{.*}}) inclusive step (%{{.*}}) {
+  !$OMP SIMD NONTEMPORAL(A, C)
+  do i = 1, n
+    ! CHECK:  %[[LOAD:.*]] = fir.load %[[A_DECL]] {nontemporal} : !fir.ref<i32>
+    C = A + B
+    ! CHECK: %[[ADD_VAL:.*]] = arith.addi %{{.*}}, %{{.*}} : i32
+    ! CHECK: fir.store %[[ADD_VAL]] to %[[C_DECL]] {nontemporal} : !fir.ref<i32>
+  end do
+  !$OMP END SIMD
+end subroutine
+
+! CHECK-LABEL:  func.func @_QPsimd_nontemporal_allocatable
+subroutine simd_nontemporal_allocatable(x, y)
+  integer, allocatable :: x(:)
+  integer :: y
+  allocate(x(100))
+  ! CHECK:  %[[X_DECL:.*]] = fir.declare %{{.*}} dummy_scope %{{.*}} {fortran_attrs = #fir.var_attrs<allocatable>, 
+  ! CHECK-SAME: uniq_name = "_QFsimd_nontemporal_allocatableEx"} : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>, !fir.dscope) -> !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>
+  ! CHECK:  omp.simd nontemporal(%[[X_DECL]] : !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>) private(@_QFsimd_nontemporal_allocatableEi_private_i32 %2 -> %arg2 : !fir.ref<i32>) {
+  ! CHECK:   omp.loop_nest (%{{.*}}) : i32 = (%{{.*}}) to (%{{.*}}) inclusive step (%{{.*}}) {
+  !$omp simd nontemporal(x)
+  do i=1,100
+    ! CHECK:  %[[VAL1:.*]] = fir.load %[[X_DECL]] : !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>
+    ! CHECK:  %[[BOX_ADDR:.*]] = fir.box_addr %[[VAL1]] : (!fir.box<!fir.heap<!fir.array<?xi32>>>) -> !fir.heap<!fir.array<?xi32>>
+    ! CHECK:  %[[ARR_COOR:.*]] = fir.array_coor %[[BOX_ADDR]](%{{.*}}) %{{.*}} : (!fir.heap<!fir.array<?xi32>>, !fir.shapeshift<1>, i64) -> !fir.ref<i32>
+    ! CHECK:  %[[VAL2:.*]] = fir.load %[[ARR_COOR]] {nontemporal} : !fir.ref<i32>
+  x(i) = x(i) + y
+    ! CHECK:  fir.store %{{.*}} to %{{.*}} {nontemporal} : !fir.ref<i32>
+  end do
+  !$omp end simd
+end subroutine
+
+! CHECK-LABEL:  func.func @_QPsimd_nontemporal_pointers
+subroutine simd_nontemporal_pointers(a, c)
+   integer :: b, i
+   integer :: n
+   integer, pointer, intent(in):: a(:)
+   integer, pointer, intent(out) :: c(:)
+   ! CHECK:  %[[A_DECL:.*]] = fir.declare  %{{.*}} dummy_scope %{{.*}} {fortran_attrs = #fir.var_attrs<intent_in, pointer>, 
+   ! CHECK-SAME: uniq_name = "_QFsimd_nontemporal_pointersEa"} : (!fir.ref<!fir.box<!fir.ptr<!fir.array<?xi32>>>>, !fir.dscope) -> !fir.ref<!fir.box<!fir.ptr<!fir.array<?xi32>>>>
+   ! CHECK:  %[[C_DECL:.*]] = fir.declare %{{.*}} dummy_scope %{{.*}} {fortran_attrs = #fir.var_attrs<intent_out, pointer>, 
+   ! CHECK-SAME: uniq_name = "_QFsimd_nontemporal_pointersEc"} : (!fir.ref<!fir.box<!fir.ptr<!fir.array<?xi32>>>>, !fir.dscope) -> !fir.ref<!fir.box<!fir.ptr<!fir.array<?xi32>>>>
+   !$OMP SIMD NONTEMPORAL(a,c)
+   do i = 1, n
+      ! CHECK: %[[VAL1:.*]] = fir.load %[[A_DECL]] : !fir.ref<!fir.box<!fir.ptr<!fir.array<?xi32>>>>
+      ! CHECK: %[[VAL2:.*]] = fir.array_coor %[[VAL1]](%{{.*}}) %{{.*}} : (!fir.box<!fir.ptr<!fir.array<?xi32>>>, !fir.shift<1>, i64) -> !fir.ref<i32>
+      ! CHECK: %[[VAL3:.*]] = fir.load %[[VAL2]] {nontemporal} : !fir.ref<i32>
+      c(i) = a(i) + b
+      ! CHECK: %[[VAL4:.*]] = fir.load %[[C_DECL]] : !fir.ref<!fir.box<!fir.ptr<!fir.array<?xi32>>>>
+      ! CHECK: %[[VAL5:.*]] = fir.array_coor %[[VAL4]](%{{.*}}) %{{.*}} : (!fir.box<!fir.ptr<!fir.array<?xi32>>>, !fir.shift<1>, i64) -> !fir.ref<i32>
+      ! CHECK: fir.store %{{.*}} to %[[VAL5]] {nontemporal} : !fir.ref<i32>
+   end do
+   !$OMP END SIMD
+end subroutine
+
diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
index b85775c9cb7a0..ec013d1822439 100644
--- a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
+++ b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
@@ -1231,9 +1231,6 @@ class OpenMPIRBuilder {
   void unrollLoopPartial(DebugLoc DL, CanonicalLoopInfo *Loop, int32_t Factor,
                          CanonicalLoopInfo **UnrolledCLI);
 
-  using NonTemporalBodyGenCallbackTy =
-      function_ref<void(llvm::BasicBlock *BB, MDNode *NontemporalNode)>;
-
   /// Add metadata to simd-ize a loop. If IfCond is not nullptr, the loop
   /// is cloned. The metadata which prevents vectorization is added to
   /// to the cloned loop. The cloned loop is executed when ifCond is evaluated
@@ -1247,15 +1244,10 @@ class OpenMPIRBuilder {
   /// \param Order       The enum to map order clause.
   /// \param Simdlen     The Simdlen length to apply to the simd loop.
   /// \param Safelen     The Safelen length to apply to the simd loop.
-  /// \param NontemporalCBFunc    Call back function for nontemporal.
-  /// \param NontemporalVars      Array of nontemporal vars.
-  void applySimd(
-      CanonicalLoopInfo *Loop, MapVector<Value *, Value *> AlignedVars,
-      Value *IfCond, omp::OrderKind Order, ConstantInt *Simdlen,
-      ConstantInt *Safelen,
-      NonTemporalBodyGenCallbackTy NontemporalCBFunc = [](BasicBlock *,
-                                                          MDNode *) {},
-      ArrayRef<Value *> NontempralVars = {});
+  void applySimd(CanonicalLoopInfo *Loop,
+                 MapVector<Value *, Value *> AlignedVars, Value *IfCond,
+                 omp::OrderKind Order, ConstantInt *Simdlen,
+                 ConstantInt *Safelen);
 
   /// Generator for '#omp flush'
   ///
diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
index 15b25a8d10994..68b1fa42934ad 100644
--- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
+++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
@@ -5341,11 +5341,10 @@ OpenMPIRBuilder::getOpenMPDefaultSimdAlign(const Triple &TargetTriple,
   return 0;
 }
 
-void OpenMPIRBuilder::applySimd(
-    CanonicalLoopInfo *CanonicalLoop, MapVector<Value *, Value *> AlignedVars,
-    Value *IfCond, OrderKind Order, ConstantInt *Simdlen, ConstantInt *Safelen,
-    OpenMPIRBuilder::NonTemporalBodyGenCallbackTy NontemporalCBFunc,
-    ArrayRef<Value *> NontemporalVarsIn) {
+void OpenMPIRBuilder::applySimd(CanonicalLoopInfo *CanonicalLoop,
+                                MapVector<Value *, Value *> AlignedVars,
+                                Value *IfCond, OrderKind Order,
+                                ConstantInt *Simdlen, ConstantInt *Safelen) {
   LLVMContext &Ctx = Builder.getContext();
 
   Function *F = CanonicalLoop->getFunction();
@@ -5443,13 +5442,6 @@ void OpenMPIRBuilder::applySimd(
   }
 
   addLoopMetadata(CanonicalLoop, LoopMDList);
-
-  // Set nontemporal metadata to load and stores of nontemporal values
-  if (NontemporalVarsIn.size()) {
-    MDNode *NontemporalNode = MDNode::getDistinct(Ctx, {});
-    for (BasicBlock *BB : Reachable)
-      NontemporalCBFunc(BB, NontemporalNode);
-  }
 }
 
 /// Create the TargetMachine object to query the backend for optimization
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index a17c4f22c9fa8..71d50a60e727c 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -2462,25 +2462,6 @@ convertOrderKind(std::optional<omp::ClauseOrderKind> o) {
   llvm_unreachable("Unknown ClauseOrderKind kind");
 }
 
-static void
-appendNontemporalVars(llvm::BasicBlock *Block,
-                      SmallVectorImpl<llvm::Value *> &NontemporalVars) {
-  for (llvm::Instruction &I : *Block) {
-    if (const llvm::CallInst *CI = dyn_cast<llvm::CallInst>(&I)) {
-      if (CI->getIntrinsicID() == llvm::Intrinsic::memcpy) {
-        llvm::Value *DestPtr = CI->getArgOperand(0);
-        llvm::Value *SrcPtr = CI->getArgOperand(1);
-        for (const llvm::Value *Var : NontemporalVars) {
-          if (Var == SrcPtr) {
-            NontemporalVars.push_back(DestPtr);
-            break;
-          }
-        }
-      }
-    }
-  }
-}
-
 /// Converts an OpenMP simd loop into LLVM IR using OpenMPIRBuilder.
 static LogicalResult
 convertOmpSimd(Operation &opInst, llvm::IRBuilderBase &builder,
@@ -2529,71 +2510,6 @@ convertOmpSimd(Operation &opInst, llvm::IRBuilderBase &builder,
   llvm::MapVector<llvm::Value *, llvm::Value *> alignedVars;
   llvm::omp::OrderKind order = convertOrderKind(simdOp.getOrder());
 
-  llvm::SmallVector<llvm::Value *> nontemporalOrigVars;
-  mlir::OperandRange nontemporals = simdOp.getNontemporalVars();
-  for (mlir::Value nontemporal : nontemporals) {
-    llvm::Value *nt = moduleTranslation.lookupValue(nontemporal);
-    nontemporalOrigVars.push_back(nt);
-  }
-
-  /** Call back function to attach nontemporal metadata to the load/store
-   * instructions of nontemporal variables of Block.
-   * Nontemporal variables may be a scalar, fixed size or allocatable
-   * or pointer array
-   *
-   * Example scenarios for nontemporal variables:
-   * Case 1: Scalar variable
-   *    If the nontemporal variable is a scalar, it is allocated on stack.Load
-   * and store instructions directly access the alloca pointer of the scalar
-   * variable for fetching information about scalar variable or  writing
-   * into the scalar variable. Mark those load and store instructions as
-   * non-temporal.
-   *
-   * Case 2: Fixed Size array
-   *    If the nontemporal variable is a fixed-size array, it is allocated
-   * as a contiguous block of memory. It uses one GEP instruction, to compute
-   * the address of each individual array elements and perform load or store
-   * operation on it. Mark those load and store instructions as non-temporal.
-   *
-   * Case 3: Allocatable array
-   *     For an allocatable array, which might involve runtime type descriptor,
-   * needs to navigate through descriptors using two or more GEP and load
-   * instructions to compute the address of each individual element in an array.
-   * Mark those load or store which access the individual array elements as
-   * non-temporal.
-   */
-  auto addNonTemporalMetadataCB = [&](llvm::BasicBlock *Block,
-                                      llvm::MDNode *Nontemporal) {
-    SmallVector<llvm::Value *> NontemporalVars{nontemporalOrigVars};
-    appendNontemporalVars(Block, NontemporalVars);
-    for (llvm::Instruction &I : *Block) {
-      llvm::Value *mem_ptr = nullptr;
-      bool MetadataFlag = true;
-      if (llvm::LoadInst *li = dyn_cast<llvm::LoadInst>(&I)) {
-        if (!(li->getType()->isPointerTy()))
-          mem_ptr = li->getPointerOperand();
-      } else if (llvm::StoreInst *si = dyn_cast<llvm::StoreInst>(&I))
-        mem_ptr = si->getPointerOperand();
-      if (mem_ptr) {
-        while (mem_ptr && !(isa<llvm::AllocaInst>(mem_ptr))) {
-          if (llvm::GetElementPtrInst *gep =
-                  dyn_cast<llvm::GetElementPtrInst>(mem_ptr)) {
-            llvm::Type *sourceType = gep->getSourceElementType();
-            if (sourceType->isStructTy() && gep->getNumIndices() >= 2 &&
-                !(gep->hasAllZeroIndices())) {
-              MetadataFlag = false;
-              break;
-            }
-            mem_ptr = gep->getPointerOperand();
-          } else if (llvm::LoadInst *li = dyn_cast<llvm::LoadInst>(mem_ptr))
-            mem_ptr = li->getPointerOperand();
-        }
-        if (MetadataFlag && is_contained(NontemporalVars, mem_ptr))
-          I.setMetadata(llvm::LLVMContext::MD_nontemporal, Nontemporal);
-      }
-    }
-  };
-
   llvm::BasicBlock *sourceBlock = builder.GetInsertBlock();
   std::optional<ArrayAttr> alignmentValues = simdOp.getAlignments();
   mlir::OperandRange operands = simdOp.getAlignedVars();
@@ -2621,11 +2537,11 @@ convertOmpSimd(Operation &opInst, llvm::IRBuilderBase &builder,
 
   builder.SetInsertPoint(*regionBlock, (*regionBlock)->begin());
   llvm::CanonicalLoopInfo *loopInfo = findCurrentLoopInfo(moduleTranslation);
-  ompBuilder->applySimd(
-      loopInfo, alignedVars,
-      simdOp.getIfExpr() ? moduleTranslation.lookupValue(simdOp.getIfExpr())
-                         : nullptr,
-      order, simdlen, safelen, addNonTemporalMetadataCB, nontemporalOrigVars);
+  ompBuilder->applySimd(loopInfo, alignedVars,
+                        simdOp.getIfExpr()
+                            ? moduleTranslation.lookupValue(simdOp.getIfExpr())
+                            : nullptr,
+                        order, simdlen, safelen);
 
   return cleanupPrivateVars(builder, moduleTranslation, simdOp.getLoc(),
                             privateVarsInfo.llvmVars,
diff --git a/mlir/test/Target/LLVMIR/openmp-nontemporal.mlir b/mlir/test/Target/LLVMIR/openmp-nontemporal.mlir
index f8cee94be4ff7..974cf674d547d 100644
--- a/mlir/test/Target/LLVMIR/openmp-nontemporal.mlir
+++ b/mlir/test/Target/LLVMIR/openmp-nontemporal.mlir
@@ -13,8 +13,8 @@ llvm.func @simd_nontemporal() {
   //CHECK: store i64 %[[B]], ptr %[[A_ADDR]], align 4, !nontemporal !1, !llvm.access.group !2
   omp.simd nontemporal(%2, %3 : !llvm.ptr, !llvm.ptr) {
     omp.loop_nest (%arg0) : i64 = (%1) to (%0) inclusive step (%1) {
-      %4 = llvm.load %3 : !llvm.ptr -> i64
-      llvm.store %4, %2 : i64, !llvm.ptr
+      %4 = llvm.load %3 {nontemporal}: !llvm.ptr -> i64
+      llvm.store %4, %2 {nontemporal} : i64, !llvm.ptr
       omp.yield
     }
   }
@@ -31,12 +31,12 @@ llvm.func @_QPtest(%arg0: !llvm.ptr {fir.bindc_name = "n"}, %arg1: !llvm.ptr {fi
     %3 = llvm.mlir.constant(1 : i64) : i64
     %4 = llvm.alloca %3 x i32 {bindc_name = "i", pinned} : (i64) -> !llvm.ptr
     %6 = llvm.load %arg0 : !llvm.ptr -> i32
-    //CHECK:  %[[A_VAL1:.*]] = alloca { ptr, i64, i32, i8, i8, i8, i8, [1 x [3 x i64]] }, align 8
-    //CHECK:  %[[A_VAL2:.*]] = alloca { ptr, i64, i32, i8, i8, i8, i8, [1 x [3 x i64]] }, align 8
+    // CHECK:  %[[A_VAL1:.*]] = alloca { ptr, i64, i32, i8, i8, i8, i8, [1 x [3 x i64]] }, align 8
+    // CHECK:  %[[A_VAL2:.*]] = alloca { ptr, i64, i32, i8, i8, i8, i8, [1 x [3 x i64]] }, align 8
     omp.simd nontemporal(%arg1 : !llvm.ptr) {
       omp.loop_nest (%arg2) : i32 = (%0) to (%6) inclusive step (%0) {
         llvm.store %arg2, %4 : i32, !llvm.ptr
-        //CHECK:  call void @llvm.memcpy.p0.p0.i32(ptr %[[A_VAL2]], ptr %1, i32 48, i1 false)
+        // CHECK:  call void @llvm.memcpy.p0.p0.i32(ptr %[[A_VAL2]], ptr %1, i32 48, i1 false)
         %7 = llvm.mlir.constant(48 : i32) : i32
         "llvm.intr.memcpy"(%2, %arg1, %7) <{isVolatile = false}> : (!llvm.ptr, !llvm.ptr, i32) -> ()
         %8 = llvm.load %4 : !llvm.ptr -> i32
@@ -56,14 +56,14 @@ llvm.func @_QPtest(%arg0: !llvm.ptr {fir.bindc_name = "n"}, %arg1: !llvm.ptr {fi
         %22 = llvm.mul %21, %3 overflow<nsw> : i64
         %23 = llvm.add %22,%19 overflow<nsw> : i64
         %24 = llvm.mul %3, %16 overflow<nsw> : i64
-        //CHECK:  %[[VAL1:.*]] = getelementptr float, ptr {{.*}}, i64 %{{.*}}
-        //CHECK:  %[[LOAD_A:.*]] = load float, ptr %[[VAL1]], align 4, !nontemporal 
-        //CHECK:  %[[RES:.*]] = fadd contract float %[[LOAD_A]], 2.000000e+01
+        // CHECK:  %[[VAL1:.*]] = getelementptr float, ptr {{.*}}, i64 %{{.*}}
+        // CHECK:  %[[LOAD_A:.*]] = load float, ptr %[[VAL1]], align 4, !nontemporal 
+        // CHECK:  %[[RES:.*]] = fadd contract float %[[LOAD_A]], 2.000000e+01
         %25 = llvm.getelementptr %11[%23] : (!llvm.ptr, i64) -> !llvm.ptr, f32
-        %26 = llvm.load %25 : !llvm.ptr -> f32
+        %26 = llvm.load %25 {nontemporal} : !llvm.ptr -> f32
         %27 = llvm.mlir.constant(2.000000e+01 : f32) : f32
         %28 = llvm.fadd %26, %27 {fastmathFlags = #llvm.fastmath<contract>} : f32
-        //CHECK:  call void @llvm.memcpy.p0.p0.i32(ptr %[[A_VAL1]], ptr %1, i32 48, i1 false)
+        // CHECK:  call void @llvm.memcpy.p0.p0.i32(ptr %[[A_VAL1]], ptr %1, i32 48, i1 false)
         %29 = llvm.mlir.constant(48 : i32) : i32
         "llvm.intr.memcpy"(%1, %arg1, %29) <{isVolatile = false}> : (!llvm.ptr, !llvm.ptr, i32) -> ()
         %30 = llvm.load %4 : !llvm.ptr -> i32
@@ -82,15 +82,15 @@ llvm.func @_QPtest(%arg0: !llvm.ptr {fir.bindc_name = "n"}, %arg1: !llvm.ptr {fi
         %43 = llvm.mul %42, %3 overflow<nsw> : i64
         %44 = llvm.add %43,%19 overflow<nsw> : i64
         %45 = llvm.mul %3, %38 overflow<nsw> : i64
-        //CHECK:  %[[VAL2:.*]] = getelementptr float, ptr %{{.*}}, i64 %{{.*}}
-        //CHECK:  store float %[[RES]], ptr %[[VAL2]], align 4, !nontemporal 
+        // CHECK:  %[[VAL2:.*]] = getelementptr float, ptr %{{.*}}, i64 %{{.*}}
+        // CHECK:  store float %[[RES]], ptr %[[VAL2]], align 4, !nontemporal 
         %46 = llvm.getelementptr %33[%44] : (!llvm.ptr, i64) -> !llvm.ptr, f32
-        llvm.store %28, %46 : f32, !llvm.ptr
+        llvm.store %28, %46 {nontemporal} : f32, !llvm.ptr
         omp.yield
       }
     }
     llvm.return
   }
+
 // -----
- 
 

>From 37fc0bc55799dc6c8adb39375df9c42c6aa77026 Mon Sep 17 00:00:00 2001
From: Kaviya Rajendiran <kaviyara2000 at gmail.com>
Date: Mon, 21 Apr 2025 19:44:36 +0530
Subject: [PATCH 4/5] Addressed review comments: Added nontemporal attribute
 to fir.load and fir.store and used it to mark the operations as nontemporal

---
 .../include/flang/Optimizer/Dialect/FIROps.td |   7 +-
 flang/lib/Optimizer/CodeGen/CodeGen.cpp       |  14 ++-
 .../lib/Optimizer/OpenMP/LowerNontemporal.cpp |  83 +++++++-------
 flang/lib/Optimizer/Passes/Pipelines.cpp      |   8 +-
 flang/test/Fir/basic-program.fir              |   3 +-
 flang/test/Lower/OpenMP/simd-nontemporal.f90  |  67 ------------
 flang/test/Lower/OpenMP/simd-nontemporal.mlir | 101 ++++++++++++++++++
 7 files changed, 160 insertions(+), 123 deletions(-)
 delete mode 100644 flang/test/Lower/OpenMP/simd-nontemporal.f90
 create mode 100644 flang/test/Lower/OpenMP/simd-nontemporal.mlir

diff --git a/flang/include/flang/Optimizer/Dialect/FIROps.td b/flang/include/flang/Optimizer/Dialect/FIROps.td
index f9dc2e51a396c..cd5aa139b7391 100644
--- a/flang/include/flang/Optimizer/Dialect/FIROps.td
+++ b/flang/include/flang/Optimizer/Dialect/FIROps.td
@@ -305,7 +305,7 @@ def fir_LoadOp : fir_OneResultOp<"load", [FirAliasTagOpInterface,
   }];
 
   let arguments = (ins AnyReferenceLike:$memref,
-                  OptionalAttr<LLVM_TBAATagArrayAttr>:$tbaa);
+      OptionalAttr<LLVM_TBAATagArrayAttr>:$tbaa, UnitAttr:$nontemporal);
 
   let builders = [OpBuilder<(ins "mlir::Value":$refVal)>,
                   OpBuilder<(ins "mlir::Type":$resTy, "mlir::Value":$refVal)>];
@@ -337,9 +337,8 @@ def fir_StoreOp : fir_Op<"store", [FirAliasTagOpInterface,
     `%p`, is undefined or null.
   }];
 
-  let arguments = (ins AnyType:$value,
-                   AnyReferenceLike:$memref,
-                   OptionalAttr<LLVM_TBAATagArrayAttr>:$tbaa);
+  let arguments = (ins AnyType:$value, AnyReferenceLike:$memref,
+      OptionalAttr<LLVM_TBAATagArrayAttr>:$tbaa, UnitAttr:$nontemporal);
 
   let builders = [OpBuilder<(ins "mlir::Value":$value, "mlir::Value":$memref)>];
 
diff --git a/flang/lib/Optimizer/CodeGen/CodeGen.cpp b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
index 0014de3df4752..662ec8e30a56c 100644
--- a/flang/lib/Optimizer/CodeGen/CodeGen.cpp
+++ b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
@@ -3567,17 +3567,15 @@ struct StoreOpConversion : public fir::FIROpConversion<fir::StoreOp> {
       newOp = rewriter.create<mlir::LLVM::MemcpyOp>(loc, llvmMemref, llvmValue,
                                                     boxSize, isVolatile);
     } else {
-      unsigned alignment =
-          store->getAttrOfType<mlir::IntegerAttr>("alignment")
-              ? store->getAttrOfType<mlir::IntegerAttr>("alignment").getInt()
-              : 0;
-
-      mlir::LLVM::StoreOp storeOp = rewriter.create<mlir::LLVM::StoreOp>(
-          loc, llvmValue, llvmMemref, alignment, store->hasAttr("volatile"),
-          store->hasAttr("nontemporal"));
+      mlir::LLVM::StoreOp storeOp =
+          rewriter.create<mlir::LLVM::StoreOp>(loc, llvmValue, llvmMemref);
 
       if (isVolatile)
         storeOp.setVolatile_(true);
+
+      if (store.getNontemporal())
+        storeOp.setNontemporal(true);
+
       newOp = storeOp;
     }
     if (std::optional<mlir::ArrayAttr> optionalTag = store.getTbaa())
diff --git a/flang/lib/Optimizer/OpenMP/LowerNontemporal.cpp b/flang/lib/Optimizer/OpenMP/LowerNontemporal.cpp
index 2284a42950a05..0b0dcac6f4e6f 100644
--- a/flang/lib/Optimizer/OpenMP/LowerNontemporal.cpp
+++ b/flang/lib/Optimizer/OpenMP/LowerNontemporal.cpp
@@ -10,62 +10,67 @@
 // nontemporal.
 //
 //===----------------------------------------------------------------------===//
+
 #include "flang/Optimizer/Dialect/FIROpsSupport.h"
 #include "flang/Optimizer/OpenMP/Passes.h"
 #include "mlir/Dialect/OpenMP/OpenMPDialect.h"
+
 using namespace mlir;
+
 namespace flangomp {
 #define GEN_PASS_DEF_LOWERNONTEMPORALPASS
 #include "flang/Optimizer/OpenMP/Passes.h.inc"
 } // namespace flangomp
+
 namespace {
 class LowerNontemporalPass
     : public flangomp::impl::LowerNontemporalPassBase<LowerNontemporalPass> {
   void addNonTemporalAttr(omp::SimdOp simdOp) {
-    if (!simdOp.getNontemporalVars().empty()) {
-      llvm::SmallVector<mlir::Value> nontemporalOrigVars;
-      mlir::OperandRange nontemporals = simdOp.getNontemporalVars();
-      for (mlir::Value nontemporal : nontemporals) {
-        nontemporalOrigVars.push_back(nontemporal);
-      }
-      std::function<mlir::Value(mlir::Value)> getBaseOperand =
-          [&](mlir::Value operand) -> mlir::Value {
-        if (mlir::isa<fir::DeclareOp>(operand.getDefiningOp()))
-          return operand;
-        else if (auto arrayCoorOp = llvm::dyn_cast<fir::ArrayCoorOp>(
-                     operand.getDefiningOp())) {
-          return getBaseOperand(arrayCoorOp.getMemref());
-        } else if (auto boxAddrOp = llvm::dyn_cast<fir::BoxAddrOp>(
-                       operand.getDefiningOp())) {
-          return getBaseOperand(boxAddrOp.getVal());
-        } else if (auto loadOp =
-                       llvm::dyn_cast<fir::LoadOp>(operand.getDefiningOp())) {
-          return getBaseOperand(loadOp.getMemref());
-        } else {
-          return operand;
-        }
-      };
-      simdOp->walk([&](Operation *op) {
-        mlir::Value Operand = nullptr;
-        if (auto loadOp = llvm::dyn_cast<fir::LoadOp>(op)) {
-          Operand = loadOp.getMemref();
-        } else if (auto storeOp = llvm::dyn_cast<fir::StoreOp>(op)) {
-          Operand = storeOp.getMemref();
+    if (simdOp.getNontemporalVars().empty())
+      return;
+
+    std::function<mlir::Value(mlir::Value)> getBaseOperand =
+        [&](mlir::Value operand) -> mlir::Value {
+      if (mlir::isa<mlir::BlockArgument>(operand) ||
+          (mlir::isa<fir::AllocaOp>(operand.getDefiningOp())) ||
+          (mlir::isa<fir::DeclareOp>(operand.getDefiningOp())))
+        return operand;
+
+      Operation *definingOp = operand.getDefiningOp();
+      if (definingOp) {
+        for (Value srcOp : definingOp->getOperands()) {
+          return getBaseOperand(srcOp);
         }
-        if (Operand && !(fir::isAllocatableType(Operand.getType()) ||
-                         fir::isPointerType((Operand.getType())))) {
-          Operand = getBaseOperand(Operand);
-          if (is_contained(nontemporalOrigVars, Operand)) {
-            // Set the attribute
-            op->setAttr("nontemporal", UnitAttr::get(op->getContext()));
-          }
+      }
+      return operand;
+    };
+
+    // walk through the operations and mark the load and store as nontemporal
+    simdOp->walk([&](Operation *op) {
+      mlir::Value operand = nullptr;
+
+      if (auto loadOp = llvm::dyn_cast<fir::LoadOp>(op))
+        operand = loadOp.getMemref();
+      else if (auto storeOp = llvm::dyn_cast<fir::StoreOp>(op))
+        operand = storeOp.getMemref();
+
+      if (operand && !(fir::isAllocatableType(operand.getType()) ||
+                       fir::isPointerType((operand.getType())))) {
+        operand = getBaseOperand(operand);
+
+        if (llvm::is_contained(simdOp.getNontemporalVars(), operand)) {
+          if (auto loadOp = llvm::dyn_cast<fir::LoadOp>(op))
+            loadOp.setNontemporal(true);
+          else if (auto storeOp = llvm::dyn_cast<fir::StoreOp>(op))
+            storeOp.setNontemporal(true);
         }
-      });
-    }
+      }
+    });
   }
+
   void runOnOperation() override {
     Operation *op = getOperation();
     op->walk([&](omp::SimdOp simdOp) { addNonTemporalAttr(simdOp); });
   }
 };
-} // namespace
\ No newline at end of file
+} // namespace
diff --git a/flang/lib/Optimizer/Passes/Pipelines.cpp b/flang/lib/Optimizer/Passes/Pipelines.cpp
index 0d96ef4ace3d7..24e943760c7a5 100644
--- a/flang/lib/Optimizer/Passes/Pipelines.cpp
+++ b/flang/lib/Optimizer/Passes/Pipelines.cpp
@@ -274,10 +274,8 @@ void createHLFIRToFIRPassPipeline(mlir::PassManager &pm, bool enableOpenMP,
     addNestedPassToAllTopLevelOperations<PassConstructor>(
         pm, hlfir::createInlineHLFIRAssign);
   pm.addPass(hlfir::createConvertHLFIRtoFIR());
-  if (enableOpenMP) {
+  if (enableOpenMP)
     pm.addPass(flangomp::createLowerWorkshare());
-    pm.addPass(flangomp::createLowerNontemporalPass());
-  }
 }
 
 /// Create a pass pipeline for handling certain OpenMP transformations needed
@@ -347,6 +345,10 @@ void createDefaultFIRCodeGenPassPipeline(mlir::PassManager &pm,
        config.ApproxFuncFPMath, config.NoSignedZerosFPMath, config.UnsafeFPMath,
        ""}));
 
+  if (config.EnableOpenMP)
+    pm.addNestedPass<mlir::func::FuncOp>(
+        flangomp::createLowerNontemporalPass());
+
   fir::addFIRToLLVMPass(pm, config);
 }
 
diff --git a/flang/test/Fir/basic-program.fir b/flang/test/Fir/basic-program.fir
index 9330373c0d4b6..5a02dd46c6031 100644
--- a/flang/test/Fir/basic-program.fir
+++ b/flang/test/Fir/basic-program.fir
@@ -65,8 +65,6 @@ func.func @_QQmain() {
 // PASSES-NEXT:     InlineHLFIRAssign
 // PASSES-NEXT:   ConvertHLFIRtoFIR
 // PASSES-NEXT:   LowerWorkshare
-// PASSES-NEXT:   'func.func' Pipeline
-// PASSES-NEXT:   LowerNontemporalPass
 // PASSES-NEXT:   CSE
 // PASSES-NEXT:   (S) 0 num-cse'd - Number of operations CSE'd
 // PASSES-NEXT:   (S) 0 num-dce'd - Number of operations DCE'd
@@ -151,6 +149,7 @@ func.func @_QQmain() {
 // PASSES-NEXT: CompilerGeneratedNamesConversion
 // PASSES-NEXT: 'func.func' Pipeline
 // PASSES-NEXT:  FunctionAttr
+// PASSES-NEXT:  LowerNontemporalPass
 // PASSES-NEXT: FIRToLLVMLowering
 // PASSES-NEXT: ReconcileUnrealizedCasts
 // PASSES-NEXT: LLVMIRLoweringPass
diff --git a/flang/test/Lower/OpenMP/simd-nontemporal.f90 b/flang/test/Lower/OpenMP/simd-nontemporal.f90
deleted file mode 100644
index c25aec3347bf1..0000000000000
--- a/flang/test/Lower/OpenMP/simd-nontemporal.f90
+++ /dev/null
@@ -1,67 +0,0 @@
-! Test nontemporal clause
-! RUN: %flang_fc1 -emit-fir -fopenmp -fopenmp-version=50 %s -o - | FileCheck %s
-! RUN: bbc -emit-fir -fopenmp -fopenmp-version=50 %s -o - | FileCheck %s
-
-
-! CHECK-LABEL: func @_QPsimd_with_nontemporal_clause
-subroutine simd_with_nontemporal_clause(n)
-  ! CHECK: %[[A_DECL:.*]] = fir.declare %{{.*}} {uniq_name = "_QFsimd_with_nontemporal_clauseEa"} : (!fir.ref<i32>) -> !fir.ref<i32>
-  ! CHECK: %[[C_DECL:.*]] = fir.declare %{{.*}} {uniq_name = "_QFsimd_with_nontemporal_clauseEc"} : (!fir.ref<i32>) -> !fir.ref<i32>
-  integer :: i, n
-  integer :: A, B, C
-  ! CHECK: omp.simd nontemporal(%[[A_DECL]], %[[C_DECL]] : !fir.ref<i32>, !fir.ref<i32>) private(@_QFsimd_with_nontemporal_clauseEi_private_i32 %8 -> %arg1 : !fir.ref<i32>) {
-  ! CHECK-NEXT: omp.loop_nest (%{{.*}}) : i32 = (%{{.*}}) to (%{{.*}}) inclusive step (%{{.*}}) {
-  !$OMP SIMD NONTEMPORAL(A, C)
-  do i = 1, n
-    ! CHECK:  %[[LOAD:.*]] = fir.load %[[A_DECL]] {nontemporal} : !fir.ref<i32>
-    C = A + B
-    ! CHECK: %[[ADD_VAL:.*]] = arith.addi %{{.*}}, %{{.*}} : i32
-    ! CHECK: fir.store %[[ADD_VAL]] to %[[C_DECL]] {nontemporal} : !fir.ref<i32>
-  end do
-  !$OMP END SIMD
-end subroutine
-
-! CHECK-LABEL:  func.func @_QPsimd_nontemporal_allocatable
-subroutine simd_nontemporal_allocatable(x, y)
-  integer, allocatable :: x(:)
-  integer :: y
-  allocate(x(100))
-  ! CHECK:  %[[X_DECL:.*]] = fir.declare %{{.*}} dummy_scope %{{.*}} {fortran_attrs = #fir.var_attrs<allocatable>, 
-  ! CHECK-SAME: uniq_name = "_QFsimd_nontemporal_allocatableEx"} : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>, !fir.dscope) -> !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>
-  ! CHECK:  omp.simd nontemporal(%[[X_DECL]] : !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>) private(@_QFsimd_nontemporal_allocatableEi_private_i32 %2 -> %arg2 : !fir.ref<i32>) {
-  ! CHECK:   omp.loop_nest (%{{.*}}) : i32 = (%{{.*}}) to (%{{.*}}) inclusive step (%{{.*}}) {
-  !$omp simd nontemporal(x)
-  do i=1,100
-    ! CHECK:  %[[VAL1:.*]] = fir.load %[[X_DECL]] : !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>
-    ! CHECK:  %[[BOX_ADDR:.*]] = fir.box_addr %[[VAL1]] : (!fir.box<!fir.heap<!fir.array<?xi32>>>) -> !fir.heap<!fir.array<?xi32>>
-    ! CHECK:  %[[ARR_COOR:.*]] = fir.array_coor %[[BOX_ADDR]](%{{.*}}) %{{.*}} : (!fir.heap<!fir.array<?xi32>>, !fir.shapeshift<1>, i64) -> !fir.ref<i32>
-    ! CHECK:  %[[VAL2:.*]] = fir.load %[[ARR_COOR]] {nontemporal} : !fir.ref<i32>
-  x(i) = x(i) + y
-    ! CHECK:  fir.store %{{.*}} to %{{.*}} {nontemporal} : !fir.ref<i32>
-  end do
-  !$omp end simd
-end subroutine
-
-! CHECK-LABEL:  func.func @_QPsimd_nontemporal_pointers
-subroutine simd_nontemporal_pointers(a, c)
-   integer :: b, i
-   integer :: n
-   integer, pointer, intent(in):: a(:)
-   integer, pointer, intent(out) :: c(:)
-   ! CHECK:  %[[A_DECL:.*]] = fir.declare  %{{.*}} dummy_scope %{{.*}} {fortran_attrs = #fir.var_attrs<intent_in, pointer>, 
-   ! CHECK-SAME: uniq_name = "_QFsimd_nontemporal_pointersEa"} : (!fir.ref<!fir.box<!fir.ptr<!fir.array<?xi32>>>>, !fir.dscope) -> !fir.ref<!fir.box<!fir.ptr<!fir.array<?xi32>>>>
-   ! CHECK:  %[[C_DECL:.*]] = fir.declare %{{.*}} dummy_scope %{{.*}} {fortran_attrs = #fir.var_attrs<intent_out, pointer>, 
-   ! CHECK-SAME: uniq_name = "_QFsimd_nontemporal_pointersEc"} : (!fir.ref<!fir.box<!fir.ptr<!fir.array<?xi32>>>>, !fir.dscope) -> !fir.ref<!fir.box<!fir.ptr<!fir.array<?xi32>>>>
-   !$OMP SIMD NONTEMPORAL(a,c)
-   do i = 1, n
-      ! CHECK: %[[VAL1:.*]] = fir.load %[[A_DECL]] : !fir.ref<!fir.box<!fir.ptr<!fir.array<?xi32>>>>
-      ! CHECK: %[[VAL2:.*]] = fir.array_coor %[[VAL1]](%{{.*}}) %{{.*}} : (!fir.box<!fir.ptr<!fir.array<?xi32>>>, !fir.shift<1>, i64) -> !fir.ref<i32>
-      ! CHECK: %[[VAL3:.*]] = fir.load %[[VAL2]] {nontemporal} : !fir.ref<i32>
-      c(i) = a(i) + b
-      ! CHECK: %[[VAL4:.*]] = fir.load %[[C_DECL]] : !fir.ref<!fir.box<!fir.ptr<!fir.array<?xi32>>>>
-      ! CHECK: %[[VAL5:.*]] = fir.array_coor %[[VAL4]](%{{.*}}) %{{.*}} : (!fir.box<!fir.ptr<!fir.array<?xi32>>>, !fir.shift<1>, i64) -> !fir.ref<i32>
-      ! CHECK: fir.store %{{.*}} to %[[VAL5]] {nontemporal} : !fir.ref<i32>
-   end do
-   !$OMP END SIMD
-end subroutine
-
diff --git a/flang/test/Lower/OpenMP/simd-nontemporal.mlir b/flang/test/Lower/OpenMP/simd-nontemporal.mlir
new file mode 100644
index 0000000000000..e43564c2bdb3c
--- /dev/null
+++ b/flang/test/Lower/OpenMP/simd-nontemporal.mlir
@@ -0,0 +1,101 @@
+// Test lower-nontemporal pass
+// RUN: fir-opt --lower-nontemporal %s | FileCheck %s
+
+// CHECK-LABEL: func @_QPsimd_with_nontemporal_clause
+func.func @_QPsimd_with_nontemporal_clause(%arg0: !fir.ref<i32> {fir.bindc_name = "n"}) {
+    %c1_i32 = arith.constant 1 : i32
+    %0 = fir.dummy_scope : !fir.dscope
+    %1 = fir.alloca i32 {bindc_name = "a", uniq_name = "_QFsimd_with_nontemporal_clauseEa"}
+    // CHECK: %[[A_DECL:.*]] = fir.declare %{{.*}} {uniq_name = "_QFsimd_with_nontemporal_clauseEa"} : (!fir.ref<i32>) -> !fir.ref<i32>
+    // CHECK: %[[C_DECL:.*]] = fir.declare %{{.*}} {uniq_name = "_QFsimd_with_nontemporal_clauseEc"} : (!fir.ref<i32>) -> !fir.ref<i32>
+    %2 = fir.declare %1 {uniq_name = "_QFsimd_with_nontemporal_clauseEa"} : (!fir.ref<i32>) -> !fir.ref<i32>
+    %3 = fir.alloca i32 {bindc_name = "b", uniq_name = "_QFsimd_with_nontemporal_clauseEb"}
+    %4 = fir.declare %3 {uniq_name = "_QFsimd_with_nontemporal_clauseEb"} : (!fir.ref<i32>) -> !fir.ref<i32>
+    %5 = fir.alloca i32 {bindc_name = "c", uniq_name = "_QFsimd_with_nontemporal_clauseEc"}
+    %6 = fir.declare %5 {uniq_name = "_QFsimd_with_nontemporal_clauseEc"} : (!fir.ref<i32>) -> !fir.ref<i32>
+    %7 = fir.alloca i32 {bindc_name = "i", uniq_name = "_QFsimd_with_nontemporal_clauseEi"}
+    %8 = fir.declare %7 {uniq_name = "_QFsimd_with_nontemporal_clauseEi"} : (!fir.ref<i32>) -> !fir.ref<i32>
+    %9 = fir.declare %arg0 dummy_scope %0 {uniq_name = "_QFsimd_with_nontemporal_clauseEn"} : (!fir.ref<i32>, !fir.dscope) -> !fir.ref<i32>
+    %10 = fir.load %9 : !fir.ref<i32>
+    // CHECK: omp.simd nontemporal(%[[A_DECL]], %[[C_DECL]] : !fir.ref<i32>, !fir.ref<i32>) private(@_QFsimd_with_nontemporal_clauseEi_private_i32 %8 -> %arg1 : !fir.ref<i32>) {
+    // CHECK-NEXT: omp.loop_nest (%{{.*}}) : i32 = (%{{.*}}) to (%{{.*}}) inclusive step (%{{.*}}) {
+    omp.simd nontemporal(%2, %6 : !fir.ref<i32>, !fir.ref<i32>) private(@_QFsimd_with_nontemporal_clauseEi_private_i32 %8 -> %arg1 : !fir.ref<i32>) {
+      omp.loop_nest (%arg2) : i32 = (%c1_i32) to (%10) inclusive step (%c1_i32) {
+        %11 = fir.declare %arg1 {uniq_name = "_QFsimd_with_nontemporal_clauseEi"} : (!fir.ref<i32>) -> !fir.ref<i32>
+        fir.store %arg2 to %11 : !fir.ref<i32>
+        // CHECK:  %[[LOAD:.*]] = fir.load %[[A_DECL]] {nontemporal} : !fir.ref<i32>
+        %12 = fir.load %2 : !fir.ref<i32>
+        %13 = fir.load %4 : !fir.ref<i32>
+        %14 = arith.addi %12, %13 : i32
+        // CHECK: %[[ADD_VAL:.*]] = arith.addi %{{.*}}, %{{.*}} : i32
+        // CHECK: fir.store %[[ADD_VAL]] to %[[C_DECL]] {nontemporal} : !fir.ref<i32>
+        fir.store %14 to %6 : !fir.ref<i32>
+        omp.yield
+      }
+    }
+    return
+  }
+
+//  CHECK-LABEL:  func.func @_QPsimd_nontemporal_allocatable
+func.func @_QPsimd_nontemporal_allocatable(%arg0: !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>> {fir.bindc_name = "x"}, %arg1: !fir.ref<i32> {fir.bindc_name = "y"}) {
+    %c1_i32 = arith.constant 1 : i32
+    %c0 = arith.constant 0 : index
+    %c100_i32 = arith.constant 100 : i32
+    %0 = fir.dummy_scope : !fir.dscope
+    %1 = fir.alloca i32 {bindc_name = "i", uniq_name = "_QFsimd_nontemporal_allocatableEi"}
+    %2 = fir.declare %1 {uniq_name = "_QFsimd_nontemporal_allocatableEi"} : (!fir.ref<i32>) -> !fir.ref<i32>
+    // CHECK:  %[[X_DECL:.*]] = fir.declare %{{.*}} dummy_scope %{{.*}} {fortran_attrs = #fir.var_attrs<allocatable>, 
+    // CHECK-SAME: uniq_name = "_QFsimd_nontemporal_allocatableEx"} : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>, !fir.dscope) -> !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>
+    %3 = fir.declare %arg0 dummy_scope %0 {fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFsimd_nontemporal_allocatableEx"} : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>, !fir.dscope) -> !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>
+    %4 = fir.declare %arg1 dummy_scope %0 {uniq_name = "_QFsimd_nontemporal_allocatableEy"} : (!fir.ref<i32>, !fir.dscope) -> !fir.ref<i32>
+    %5 = fir.convert %c100_i32 : (i32) -> index
+    %6 = arith.cmpi sgt, %5, %c0 : index
+    %7 = arith.select %6, %5, %c0 : index
+    %8 = fir.allocmem !fir.array<?xi32>, %7 {fir.must_be_heap = true, uniq_name = "_QFsimd_nontemporal_allocatableEx.alloc"}
+    %9 = fir.shape %7 : (index) -> !fir.shape<1>
+    %10 = fir.embox %8(%9) : (!fir.heap<!fir.array<?xi32>>, !fir.shape<1>) -> !fir.box<!fir.heap<!fir.array<?xi32>>>
+    fir.store %10 to %3 : !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>
+    // CHECK:  omp.simd nontemporal(%[[X_DECL]] : !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>) private(@_QFsimd_nontemporal_allocatableEi_private_i32 %2 -> %arg2 : !fir.ref<i32>) {
+    // CHECK:   omp.loop_nest (%{{.*}}) : i32 = (%{{.*}}) to (%{{.*}}) inclusive step (%{{.*}}) {
+    omp.simd nontemporal(%3 : !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>) private(@_QFsimd_nontemporal_allocatableEi_private_i32 %2 -> %arg2 : !fir.ref<i32>) {
+      omp.loop_nest (%arg3) : i32 = (%c1_i32) to (%c100_i32) inclusive step (%c1_i32) {
+        %16 = fir.declare %arg2 {uniq_name = "_QFsimd_nontemporal_allocatableEi"} : (!fir.ref<i32>) -> !fir.ref<i32>
+        fir.store %arg3 to %16 : !fir.ref<i32>
+        // CHECK:  %[[VAL1:.*]] = fir.load %[[X_DECL]] : !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>
+        %17 = fir.load %3 : !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>
+        %18 = fir.load %16 : !fir.ref<i32>
+        %19 = fir.convert %18 : (i32) -> i64
+        // CHECK:  %[[BOX_ADDR:.*]] = fir.box_addr %[[VAL1]] : (!fir.box<!fir.heap<!fir.array<?xi32>>>) -> !fir.heap<!fir.array<?xi32>>
+        %20 = fir.box_addr %17 : (!fir.box<!fir.heap<!fir.array<?xi32>>>) -> !fir.heap<!fir.array<?xi32>>
+        %c0_0 = arith.constant 0 : index
+        %21:3 = fir.box_dims %17, %c0_0 : (!fir.box<!fir.heap<!fir.array<?xi32>>>, index) -> (index, index, index)
+        %22 = fir.shape_shift %21#0, %21#1 : (index, index) -> !fir.shapeshift<1>
+        // CHECK:  %[[ARR_COOR:.*]] = fir.array_coor %[[BOX_ADDR]](%{{.*}}) %{{.*}} : (!fir.heap<!fir.array<?xi32>>, !fir.shapeshift<1>, i64) -> !fir.ref<i32>
+        %23 = fir.array_coor %20(%22) %19 : (!fir.heap<!fir.array<?xi32>>, !fir.shapeshift<1>, i64) -> !fir.ref<i32>
+        // CHECK:  %[[VAL2:.*]] = fir.load %[[ARR_COOR]] {nontemporal} : !fir.ref<i32>
+        %24 = fir.load %23 : !fir.ref<i32>
+        %25 = fir.load %4 : !fir.ref<i32>
+        %26 = arith.addi %24, %25 : i32
+        %27 = fir.load %3 : !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>
+        %28 = fir.load %16 : !fir.ref<i32>
+        %29 = fir.convert %28 : (i32) -> i64
+        %30 = fir.box_addr %27 : (!fir.box<!fir.heap<!fir.array<?xi32>>>) -> !fir.heap<!fir.array<?xi32>>
+        %c0_1 = arith.constant 0 : index
+        %31:3 = fir.box_dims %27, %c0_1 : (!fir.box<!fir.heap<!fir.array<?xi32>>>, index) -> (index, index, index)
+        %32 = fir.shape_shift %31#0, %31#1 : (index, index) -> !fir.shapeshift<1>
+        %33 = fir.array_coor %30(%32) %29 : (!fir.heap<!fir.array<?xi32>>, !fir.shapeshift<1>, i64) -> !fir.ref<i32>
+        // CHECK:  fir.store %{{.*}} to %{{.*}} {nontemporal} : !fir.ref<i32>
+        fir.store %26 to %33 : !fir.ref<i32>
+        omp.yield
+      }
+    }
+    %11 = fir.load %3 : !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>
+    %12 = fir.box_addr %11 : (!fir.box<!fir.heap<!fir.array<?xi32>>>) -> !fir.heap<!fir.array<?xi32>>
+    fir.freemem %12 : !fir.heap<!fir.array<?xi32>>
+    %13 = fir.zero_bits !fir.heap<!fir.array<?xi32>>
+    %14 = fir.shape %c0 : (index) -> !fir.shape<1>
+    %15 = fir.embox %13(%14) : (!fir.heap<!fir.array<?xi32>>, !fir.shape<1>) -> !fir.box<!fir.heap<!fir.array<?xi32>>>
+    fir.store %15 to %3 : !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>
+    return
+  }
+

>From 1802fcd4ede70e41d7d5e0e7945a8731cd95ab6c Mon Sep 17 00:00:00 2001
From: Kaviya Rajendiran <kaviyara2000 at gmail.com>
Date: Wed, 23 Apr 2025 12:20:42 +0530
Subject: [PATCH 5/5] Addressed review comments: Modified
 getBaseOperand() function and added a test case
 'convert-nontemporal-to-llvm.fir'

---
 .../lib/Optimizer/OpenMP/LowerNontemporal.cpp |  25 ++--
 flang/lib/Optimizer/Passes/Pipelines.cpp      |   3 +-
 .../test/Fir/convert-nontemporal-to-llvm.fir  | 111 ++++++++++++++++++
 .../simd-nontemporal.fir}                     |   4 +-
 4 files changed, 131 insertions(+), 12 deletions(-)
 create mode 100644 flang/test/Fir/convert-nontemporal-to-llvm.fir
 rename flang/test/{Lower/OpenMP/simd-nontemporal.mlir => Fir/simd-nontemporal.fir} (97%)

diff --git a/flang/lib/Optimizer/OpenMP/LowerNontemporal.cpp b/flang/lib/Optimizer/OpenMP/LowerNontemporal.cpp
index 0b0dcac6f4e6f..e83680dc36237 100644
--- a/flang/lib/Optimizer/OpenMP/LowerNontemporal.cpp
+++ b/flang/lib/Optimizer/OpenMP/LowerNontemporal.cpp
@@ -11,9 +11,11 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "flang/Optimizer/Dialect/FIRCG/CGOps.h"
 #include "flang/Optimizer/Dialect/FIROpsSupport.h"
 #include "flang/Optimizer/OpenMP/Passes.h"
 #include "mlir/Dialect/OpenMP/OpenMPDialect.h"
+#include "llvm/ADT/TypeSwitch.h"
 
 using namespace mlir;
 
@@ -31,16 +33,21 @@ class LowerNontemporalPass
 
-    std::function<mlir::Value(mlir::Value)> getBaseOperand =
-        [&](mlir::Value operand) -> mlir::Value {
-      if (mlir::isa<mlir::BlockArgument>(operand) ||
-          (mlir::isa<fir::AllocaOp>(operand.getDefiningOp())) ||
-          (mlir::isa<fir::DeclareOp>(operand.getDefiningOp())))
-        return operand;
-
-      Operation *definingOp = operand.getDefiningOp();
-      if (definingOp) {
-        for (Value srcOp : definingOp->getOperands()) {
-          return getBaseOperand(srcOp);
-        }
+    // Walk a memory reference back through array_coor/ext_array_coor,
+    // box_addr and load operations to the value it is derived from; any other
+    // defining op, or a block argument, stops the walk.
+    auto getBaseOperand = [](mlir::Value operand) -> mlir::Value {
+      auto *defOp = operand.getDefiningOp();
+      while (defOp) {
+        llvm::TypeSwitch<Operation *>(defOp)
+            .Case<fir::ArrayCoorOp, fir::cg::XArrayCoorOp, fir::LoadOp>(
+                [&](auto op) {
+                  operand = op.getMemref();
+                  defOp = operand.getDefiningOp();
+                })
+            .Case<fir::BoxAddrOp>([&](auto op) {
+              operand = op.getVal();
+              defOp = operand.getDefiningOp();
+            })
+            .Default([&](Operation *) { defOp = nullptr; });
       }
       return operand;
     };
diff --git a/flang/lib/Optimizer/Passes/Pipelines.cpp b/flang/lib/Optimizer/Passes/Pipelines.cpp
index 24e943760c7a5..11697defa1e48 100644
--- a/flang/lib/Optimizer/Passes/Pipelines.cpp
+++ b/flang/lib/Optimizer/Passes/Pipelines.cpp
@@ -345,9 +345,10 @@ void createDefaultFIRCodeGenPassPipeline(mlir::PassManager &pm,
        config.ApproxFuncFPMath, config.NoSignedZerosFPMath, config.UnsafeFPMath,
        ""}));
 
-  if (config.EnableOpenMP)
+  if (config.EnableOpenMP) {
     pm.addNestedPass<mlir::func::FuncOp>(
         flangomp::createLowerNontemporalPass());
+  }
 
   fir::addFIRToLLVMPass(pm, config);
 }
diff --git a/flang/test/Fir/convert-nontemporal-to-llvm.fir b/flang/test/Fir/convert-nontemporal-to-llvm.fir
new file mode 100644
index 0000000000000..6200ef1c621d7
--- /dev/null
+++ b/flang/test/Fir/convert-nontemporal-to-llvm.fir
@@ -0,0 +1,111 @@
+// Test conversion of nontemporal fir.load/fir.store to the LLVM dialect
+// RUN: fir-opt --fir-to-llvm-ir %s | FileCheck %s
+
+// CHECK-LABEL:  llvm.func @_QPtest()
+// CHECK:    %[[CONST_VAL:.*]] = llvm.mlir.constant(1 : i64) : i64
+// CHECK:    %[[VAL1:.*]] = llvm.alloca %[[CONST_VAL]] x i32 {bindc_name = "n"} : (i64) -> !llvm.ptr
+// CHECK:    %[[CONST_VAL1:.*]] = llvm.mlir.constant(1 : i64) : i64
+// CHECK:    %[[VAL2:.*]] = llvm.alloca %[[CONST_VAL1]] x i32 {bindc_name = "i"} : (i64) -> !llvm.ptr
+// CHECK:    %[[CONST_VAL2:.*]] = llvm.mlir.constant(1 : i64) : i64
+// CHECK:    %[[VAL3:.*]] = llvm.alloca %[[CONST_VAL2]] x i32 {bindc_name = "c"} : (i64) -> !llvm.ptr
+// CHECK:    %[[CONST_VAL3:.*]] = llvm.mlir.constant(1 : i64) : i64
+// CHECK:    %[[VAL4:.*]] = llvm.alloca %[[CONST_VAL3]] x i32 {bindc_name = "b"} : (i64) -> !llvm.ptr
+// CHECK:    %[[CONST_VAL4:.*]] = llvm.mlir.constant(1 : i64) : i64
+// CHECK:    %[[VAL5:.*]] = llvm.alloca %[[CONST_VAL4]] x i32 {bindc_name = "a"} : (i64) -> !llvm.ptr
+// CHECK:    %[[CONST_VAL5:.*]] = llvm.mlir.constant(1 : i32) : i32
+// CHECK:    %[[VAL6:.*]] = llvm.load %[[VAL1]] : !llvm.ptr -> i32
+// CHECK:    omp.simd nontemporal(%[[VAL5]], %[[VAL3]] : !llvm.ptr, !llvm.ptr) private(@_QFtestEi_private_i32 %[[VAL2]] -> %arg0 : !llvm.ptr) {
+// CHECK:      omp.loop_nest (%{{.*}}) : i32 = (%[[CONST_VAL5]]) to (%[[VAL6]]) inclusive step (%[[CONST_VAL5]]) {
+// CHECK:        llvm.store %{{.*}}, %{{.*}} : i32, !llvm.ptr
+// CHECK:        %[[VAL8:.*]] = llvm.load %[[VAL5]] {nontemporal} : !llvm.ptr -> i32
+// CHECK:        %[[VAL9:.*]] = llvm.load %[[VAL4]] : !llvm.ptr -> i32
+// CHECK:        %[[VAL10:.*]] = llvm.add %[[VAL8]], %[[VAL9]] : i32
+// CHECK:        llvm.store %[[VAL10]], %[[VAL3]] {nontemporal} : i32, !llvm.ptr
+// CHECK:        omp.yield
+// CHECK:      }
+// CHECK:    }
+
+ func.func @_QPtest() {
+    %c1_i32 = arith.constant 1 : i32
+    %0 = fir.alloca i32 {bindc_name = "a", uniq_name = "_QFtestEa"}
+    %1 = fir.alloca i32 {bindc_name = "b", uniq_name = "_QFtestEb"}
+    %2 = fir.alloca i32 {bindc_name = "c", uniq_name = "_QFtestEc"}
+    %3 = fir.alloca i32 {bindc_name = "i", uniq_name = "_QFtestEi"}
+    %4 = fir.alloca i32 {bindc_name = "n", uniq_name = "_QFtestEn"}
+    %5 = fir.load %4 : !fir.ref<i32>
+    omp.simd nontemporal(%0, %2 : !fir.ref<i32>, !fir.ref<i32>) private(@_QFtestEi_private_i32 %3 -> %arg0 : !fir.ref<i32>) {
+      omp.loop_nest (%arg1) : i32 = (%c1_i32) to (%5) inclusive step (%c1_i32) {
+        fir.store %arg1 to %arg0 : !fir.ref<i32>
+        %6 = fir.load %0 {nontemporal}: !fir.ref<i32>
+        %7 = fir.load %1 : !fir.ref<i32>
+        %8 = arith.addi %6, %7 : i32
+        fir.store %8 to %2 {nontemporal} : !fir.ref<i32>
+        omp.yield
+      }
+    }
+    return
+  }
+
+// CHECK-LABEL:  llvm.func @_QPsimd_nontemporal_allocatable
+// CHECK:    %[[CONST_VAL:.*]] = llvm.mlir.constant(1 : i64) : i64
+// CHECK:    %[[ALLOCA2:.*]] = llvm.alloca %[[CONST_VAL]] x i32 {bindc_name = "i"} : (i64) -> !llvm.ptr
+// CHECK:    %[[IDX_VAL:.*]] = llvm.mlir.constant(1 : i32) : i32
+// CHECK:    %[[CONST_VAL1:.*]] = llvm.mlir.constant(0 : index) : i64
+// CHECK:    %[[END_IDX:.*]] = llvm.mlir.constant(100 : i32) : i32
+// CHECK:    omp.simd nontemporal(%[[ARG0:.*]] : !llvm.ptr) private(@_QFsimd_nontemporal_allocatableEi_private_i32 %[[ALLOCA2]] -> %[[ARG2:.*]] : !llvm.ptr) {
+// CHECK:      omp.loop_nest (%[[ARG3:.*]]) : i32 = (%[[IDX_VAL]]) to (%[[END_IDX]]) inclusive step (%[[IDX_VAL]]) {
+// CHECK:        llvm.store %[[ARG3]], %[[ARG2]] : i32, !llvm.ptr
+// CHECK:        %[[CONST_VAL2:.*]] = llvm.mlir.constant(48 : i32) : i32
+// CHECK:        "llvm.intr.memcpy"(%[[ALLOCA1:.*]], %[[ARG0]], %[[CONST_VAL2]]) <{isVolatile = false}> : (!llvm.ptr, !llvm.ptr, i32) -> ()
+// CHECK:        %[[VAL1:.*]] = llvm.load %[[ARG2]] : !llvm.ptr -> i32
+// CHECK:        %[[VAL2:.*]] = llvm.sext %[[VAL1]] : i32 to i64
+// CHECK:        %[[VAL3:.*]] = llvm.getelementptr %[[ALLOCA1]][0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>)>
+// CHECK:        %[[VAL4:.*]] = llvm.load %[[VAL3]] : !llvm.ptr -> !llvm.ptr
+// CHECK:        %[[VAL5:.*]] = llvm.getelementptr %[[ALLOCA1]][0, 7, %[[CONST_VAL1]], 0] : (!llvm.ptr, i64) -> !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>)>
+// CHECK:        %[[VAL6:.*]] = llvm.load %[[VAL5]] : !llvm.ptr -> i64
+// CHECK:        %[[VAL7:.*]] = llvm.getelementptr %[[ALLOCA1]][0, 7, %[[CONST_VAL1]], 1] : (!llvm.ptr, i64) -> !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>)>
+// CHECK:        %[[VAL8:.*]] = llvm.load %[[VAL7]] : !llvm.ptr -> i64
+// CHECK:        %[[VAL10:.*]] = llvm.mlir.constant(1 : i64) : i64
+// CHECK:        %[[VAL11:.*]] = llvm.mlir.constant(0 : i64) : i64
+// CHECK:        %[[VAL12:.*]] = llvm.sub %[[VAL2]], %[[VAL6]] overflow<nsw> : i64
+// CHECK:        %[[VAL13:.*]] = llvm.mul %[[VAL12]], %[[VAL10]] overflow<nsw> : i64
+// CHECK:        %[[VAL14:.*]] = llvm.mul %[[VAL13]], %[[VAL10]] overflow<nsw> : i64
+// CHECK:        %[[VAL15:.*]] = llvm.add %[[VAL14]], %[[VAL11]] overflow<nsw> : i64
+// CHECK:        %[[VAL16:.*]] = llvm.mul %[[VAL10]], %[[VAL8]] overflow<nsw> : i64
+// CHECK:        %[[VAL17:.*]] = llvm.getelementptr %[[VAL4]][%[[VAL15]]] : (!llvm.ptr, i64) -> !llvm.ptr, i32
+// CHECK:        %[[VAL18:.*]] = llvm.load %[[VAL17]] {nontemporal} : !llvm.ptr -> i32
+// CHECK:        %[[VAL19:.*]] = llvm.load %{{.*}} : !llvm.ptr -> i32
+// CHECK:        %[[VAL20:.*]] = llvm.add %[[VAL18]], %[[VAL19]] : i32
+// CHECK:        llvm.store %[[VAL20]], %[[VAL17]] {nontemporal} : i32, !llvm.ptr
+// CHECK:        omp.yield
+// CHECK:      }
+// CHECK:    }
+// CHECK:    llvm.return
+
+  func.func @_QPsimd_nontemporal_allocatable(%arg0: !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>> {fir.bindc_name = "x"}, %arg1: !fir.ref<i32> {fir.bindc_name = "y"}) {
+   %c100 = arith.constant 100 : index
+   %c1_i32 = arith.constant 1 : i32
+    %c0 = arith.constant 0 : index
+    %c100_i32 = arith.constant 100 : i32
+    %0 = fir.alloca i32 {bindc_name = "i", uniq_name = "_QFsimd_nontemporal_allocatableEi"}
+    %1 = fir.allocmem !fir.array<?xi32>, %c100 {fir.must_be_heap = true, uniq_name = "_QFsimd_nontemporal_allocatableEx.alloc"}
+    %2 = fircg.ext_embox %1(%c100) : (!fir.heap<!fir.array<?xi32>>, index) -> !fir.box<!fir.heap<!fir.array<?xi32>>>
+    fir.store %2 to %arg0 : !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>
+    omp.simd nontemporal(%arg0 : !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>) private(@_QFsimd_nontemporal_allocatableEi_private_i32 %0 -> %arg2 : !fir.ref<i32>) {
+      omp.loop_nest (%arg3) : i32 = (%c1_i32) to (%c100_i32) inclusive step (%c1_i32) {
+        fir.store %arg3 to %arg2 : !fir.ref<i32>
+        %7 = fir.load %arg0 : !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>
+        %8 = fir.load %arg2 : !fir.ref<i32>
+        %9 = fir.convert %8 : (i32) -> i64
+        %10 = fir.box_addr %7 : (!fir.box<!fir.heap<!fir.array<?xi32>>>) -> !fir.heap<!fir.array<?xi32>>
+        %11:3 = fir.box_dims %7, %c0 : (!fir.box<!fir.heap<!fir.array<?xi32>>>, index) -> (index, index, index)
+        %12 = fircg.ext_array_coor %10(%11#1) origin %11#0<%9> : (!fir.heap<!fir.array<?xi32>>, index, index, i64) -> !fir.ref<i32>
+        %13 = fir.load %12 {nontemporal} : !fir.ref<i32> 
+        %14 = fir.load %arg1 : !fir.ref<i32>
+        %15 = arith.addi %13, %14 : i32
+        fir.store %15 to %12 {nontemporal} : !fir.ref<i32>
+        omp.yield
+      }
+    }
+    return
+  }
diff --git a/flang/test/Lower/OpenMP/simd-nontemporal.mlir b/flang/test/Fir/simd-nontemporal.fir
similarity index 97%
rename from flang/test/Lower/OpenMP/simd-nontemporal.mlir
rename to flang/test/Fir/simd-nontemporal.fir
index e43564c2bdb3c..31051ff52f9bd 100644
--- a/flang/test/Lower/OpenMP/simd-nontemporal.mlir
+++ b/flang/test/Fir/simd-nontemporal.fir
@@ -25,6 +25,7 @@ func.func @_QPsimd_with_nontemporal_clause(%arg0: !fir.ref<i32> {fir.bindc_name
         fir.store %arg2 to %11 : !fir.ref<i32>
         // CHECK:  %[[LOAD:.*]] = fir.load %[[A_DECL]] {nontemporal} : !fir.ref<i32>
         %12 = fir.load %2 : !fir.ref<i32>
+        // CHECK:  %[[LOAD1:.*]] = fir.load %{{[^ ]+}} : !fir.ref<i32>
         %13 = fir.load %4 : !fir.ref<i32>
         %14 = arith.addi %12, %13 : i32
         // CHECK: %[[ADD_VAL:.*]] = arith.addi %{{.*}}, %{{.*}} : i32
@@ -63,6 +64,7 @@ func.func @_QPsimd_nontemporal_allocatable(%arg0: !fir.ref<!fir.box<!fir.heap<!f
         fir.store %arg3 to %16 : !fir.ref<i32>
         // CHECK:  %[[VAL1:.*]] = fir.load %[[X_DECL]] : !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>
         %17 = fir.load %3 : !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>
+        // CHECK: %[[VAL2:.*]] = fir.load %{{[^ ]+}} : !fir.ref<i32>
         %18 = fir.load %16 : !fir.ref<i32>
         %19 = fir.convert %18 : (i32) -> i64
         // CHECK:  %[[BOX_ADDR:.*]] = fir.box_addr %[[VAL1]] : (!fir.box<!fir.heap<!fir.array<?xi32>>>) -> !fir.heap<!fir.array<?xi32>>
@@ -72,7 +74,7 @@ func.func @_QPsimd_nontemporal_allocatable(%arg0: !fir.ref<!fir.box<!fir.heap<!f
         %22 = fir.shape_shift %21#0, %21#1 : (index, index) -> !fir.shapeshift<1>
         // CHECK:  %[[ARR_COOR:.*]] = fir.array_coor %[[BOX_ADDR]](%{{.*}}) %{{.*}} : (!fir.heap<!fir.array<?xi32>>, !fir.shapeshift<1>, i64) -> !fir.ref<i32>
         %23 = fir.array_coor %20(%22) %19 : (!fir.heap<!fir.array<?xi32>>, !fir.shapeshift<1>, i64) -> !fir.ref<i32>
-        // CHECK:  %[[VAL2:.*]] = fir.load %[[ARR_COOR]] {nontemporal} : !fir.ref<i32>
+        // CHECK:  %[[VAL3:.*]] = fir.load %[[ARR_COOR]] {nontemporal} : !fir.ref<i32>
         %24 = fir.load %23 : !fir.ref<i32>
         %25 = fir.load %4 : !fir.ref<i32>
         %26 = arith.addi %24, %25 : i32



More information about the llvm-commits mailing list