[llvm] [mlir] [MLIR][OpenMP] Lowering nontemporal clause to LLVM IR for SIMD directive (PR #118751)

Kaviya Rajendiran via llvm-commits llvm-commits at lists.llvm.org
Fri Dec 6 02:36:15 PST 2024


https://github.com/kaviya2510 updated https://github.com/llvm/llvm-project/pull/118751

>From 792b974f435e6f0f1c2bc602c6a1d056be85a531 Mon Sep 17 00:00:00 2001
From: Kaviya Rajendiran <kaviyara2000 at gmail.com>
Date: Fri, 6 Dec 2024 14:49:39 +0530
Subject: [PATCH] [MLIR][OpenMP] Lowering nontemporal clause to LLVM IR for
 SIMD directive

---
 .../llvm/Frontend/OpenMP/OMPIRBuilder.h       |  2 +-
 llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp     | 86 ++++++++++++++++-
 .../OpenMP/OpenMPToLLVMIRTranslation.cpp      | 16 ++--
 .../Target/LLVMIR/openmp-nontemporal.mlir     | 96 +++++++++++++++++++
 mlir/test/Target/LLVMIR/openmp-todo.mlir      | 13 ---
 5 files changed, 192 insertions(+), 21 deletions(-)
 create mode 100644 mlir/test/Target/LLVMIR/openmp-nontemporal.mlir

diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
index a97deafa3683cf..15af2bbc9c949f 100644
--- a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
+++ b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
@@ -1220,7 +1220,7 @@ class OpenMPIRBuilder {
   void applySimd(CanonicalLoopInfo *Loop,
                  MapVector<Value *, Value *> AlignedVars, Value *IfCond,
                  omp::OrderKind Order, ConstantInt *Simdlen,
-                 ConstantInt *Safelen);
+                 ConstantInt *Safelen, ArrayRef<Value *> NontemporalVars = {});
 
   /// Generator for '#omp flush'
   ///
diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
index 1fae138b449ed5..ea30022760f2ee 100644
--- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
+++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
@@ -5265,10 +5265,87 @@ OpenMPIRBuilder::getOpenMPDefaultSimdAlign(const Triple &TargetTriple,
   return 0;
 }
 
+/// Add to \p NontemporalVars the destination of every memcpy in \p Block
+/// whose source is already nontemporal, so accesses via the copy are marked.
+static void appendNontemporalVars(BasicBlock *Block,
+                                  SmallVectorImpl<Value *> &NontemporalVars) {
+  for (Instruction &I : *Block) {
+    const auto *CI = dyn_cast<CallInst>(&I);
+    if (!CI || CI->getIntrinsicID() != Intrinsic::memcpy)
+      continue;
+    Value *DestPtr = CI->getArgOperand(0);
+    Value *SrcPtr = CI->getArgOperand(1);
+    // Query first, then push: the vector is never grown while being
+    // iterated, and a destination is recorded at most once.
+    if (is_contained(NontemporalVars, SrcPtr) &&
+        !is_contained(NontemporalVars, DestPtr))
+      NontemporalVars.push_back(DestPtr);
+  }
+}
+
+/** Attach nontemporal metadata to the loads/stores in \p Block that access a
+ *  nontemporal variable: a scalar, a fixed-size array, or an allocatable or
+ *  pointer array whose descriptor may be copied with memcpy first.
+ *
+ * !$omp simd nontemporal(a)                ;; where a is a scalar
+ *  %a = alloca i32, i64 1                  ;; (allocate a)
+ *  %1 = load i32, ptr %a                   ;; (mark LOAD as nontemporal)
+ *  store i32 11, ptr %a                    ;; (mark STORE as nontemporal)
+ *
+ * !$omp simd nontemporal(a)                ;; where a is a fixed-size array
+ *  %a = alloca [20 x i32], i64 1           ;; (allocate a)
+ *  %2 = getelementptr i32, ptr %a, i64 %1  ;; (address of an element)
+ *  %3 = load i32, ptr %2                   ;; (mark LOAD as nontemporal)
+ *
+ * !$omp simd nontemporal(a)                ;; where a is an allocatable
+ *  %struct.a = { ptr, i64, i32, i8, i8, i8, i8, [1 x [3 x i64]] }
+ *  %a = alloca %struct.a                   ;; (allocate a)
+ *  %a_copy = alloca %struct.a              ;; (copies of a are tracked too)
+ *  call void @llvm.memcpy.p0.p0.i32(ptr %a_copy, ptr %a, i32 48, i1 false)
+ *  %1 = getelementptr %struct.a, ptr %a_copy, i32 0, i32 0
+ *  %2 = load ptr, ptr %1, align 8          ;; (pointer load: not marked)
+ *  %3 = getelementptr i32, ptr %2, i64 %idx
+ *  %4 = load i32, ptr %3                   ;; (mark LOAD as nontemporal)
+ */
+static void addNonTemporalMetadata(BasicBlock *Block, MDNode *Nontemporal,
+                                   SmallVectorImpl<Value *> &NontemporalVars) {
+  appendNontemporalVars(Block, NontemporalVars);
+  for (Instruction &I : *Block) {
+    Value *MemPtr = nullptr;
+    if (auto *LI = dyn_cast<LoadInst>(&I)) {
+      if (!LI->getType()->isPointerTy())
+        MemPtr = LI->getPointerOperand();
+    } else if (auto *SI = dyn_cast<StoreInst>(&I))
+      MemPtr = SI->getPointerOperand();
+    if (!MemPtr)
+      continue;
+    // Walk GEPs/loads to the base; any other value (argument, global) ends it.
+    bool MetadataFlag = true;
+    while (!isa<AllocaInst>(MemPtr)) {
+      if (auto *GEP = dyn_cast<GetElementPtrInst>(MemPtr)) {
+        // Non-zero struct indices: a field other than the data pointer.
+        if (GEP->getSourceElementType()->isStructTy() &&
+            GEP->getNumIndices() >= 2 && !GEP->hasAllZeroIndices()) {
+          MetadataFlag = false;
+          break;
+        }
+        MemPtr = GEP->getPointerOperand();
+      } else if (auto *LI = dyn_cast<LoadInst>(MemPtr)) {
+        MemPtr = LI->getPointerOperand();
+      } else {
+        break;
+      }
+    }
+    if (MetadataFlag && is_contained(NontemporalVars, MemPtr))
+      I.setMetadata(LLVMContext::MD_nontemporal, Nontemporal);
+  }
+}
+
 void OpenMPIRBuilder::applySimd(CanonicalLoopInfo *CanonicalLoop,
                                 MapVector<Value *, Value *> AlignedVars,
                                 Value *IfCond, OrderKind Order,
-                                ConstantInt *Simdlen, ConstantInt *Safelen) {
+                                ConstantInt *Simdlen, ConstantInt *Safelen,
+                                ArrayRef<Value *> NontemporalVarsIn) {
   LLVMContext &Ctx = Builder.getContext();
 
   Function *F = CanonicalLoop->getFunction();
@@ -5365,6 +5442,13 @@ void OpenMPIRBuilder::applySimd(CanonicalLoopInfo *CanonicalLoop,
   }
 
   addLoopMetadata(CanonicalLoop, LoopMDList);
+  SmallVector<Value *> NontemporalVars{NontemporalVarsIn};
+  // Set nontemporal metadata to load and stores of nontemporal values
+  if (NontemporalVars.size()) {
+    MDNode *NontemporalNode = MDNode::getDistinct(Ctx, {});
+    for (BasicBlock *BB : Reachable)
+      addNonTemporalMetadata(BB, NontemporalNode, NontemporalVars);
+  }
 }
 
 /// Create the TargetMachine object to query the backend for optimization
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 063055f8015b81..c313b7355b2ea5 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -191,10 +191,7 @@ static LogicalResult checkImplementationStatus(Operation &op) {
     if (!op.getLinearVars().empty() || !op.getLinearStepVars().empty())
       result = todo("linear");
   };
-  auto checkNontemporal = [&todo](auto op, LogicalResult &result) {
-    if (!op.getNontemporalVars().empty())
-      result = todo("nontemporal");
-  };
+
   auto checkNowait = [&todo](auto op, LogicalResult &result) {
     if (op.getNowait())
       result = todo("nowait");
@@ -274,7 +271,6 @@ static LogicalResult checkImplementationStatus(Operation &op) {
       .Case([&](omp::SimdOp op) {
         checkAligned(op, result);
         checkLinear(op, result);
-        checkNontemporal(op, result);
         checkPrivate(op, result);
         checkReduction(op, result);
       })
@@ -2230,11 +2226,19 @@ convertOmpSimd(Operation &opInst, llvm::IRBuilderBase &builder,
 
   llvm::MapVector<llvm::Value *, llvm::Value *> alignedVars;
   llvm::omp::OrderKind order = convertOrderKind(simdOp.getOrder());
+
+  llvm::SmallVector<llvm::Value *> nontemporalVars;
+  mlir::OperandRange nontemporals = simdOp.getNontemporalVars();
+  for (mlir::Value nontemporal : nontemporals) {
+    llvm::Value *nt = moduleTranslation.lookupValue(nontemporal);
+    nontemporalVars.push_back(nt);
+  }
+
   ompBuilder->applySimd(loopInfo, alignedVars,
                         simdOp.getIfExpr()
                             ? moduleTranslation.lookupValue(simdOp.getIfExpr())
                             : nullptr,
-                        order, simdlen, safelen);
+                        order, simdlen, safelen, nontemporalVars);
 
   builder.restoreIP(afterIP);
   return success();
diff --git a/mlir/test/Target/LLVMIR/openmp-nontemporal.mlir b/mlir/test/Target/LLVMIR/openmp-nontemporal.mlir
new file mode 100644
index 00000000000000..f8cee94be4ff7a
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/openmp-nontemporal.mlir
@@ -0,0 +1,96 @@
+// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
+
+// -----
+// CHECK-LABEL: @simd_nontemporal
+llvm.func @simd_nontemporal() {
+  %0 = llvm.mlir.constant(10 : i64) : i64
+  %1 = llvm.mlir.constant(1 : i64) : i64
+  %2 = llvm.alloca %1 x i64 : (i64) -> !llvm.ptr
+  %3 = llvm.alloca %1 x i64 : (i64) -> !llvm.ptr
+  // CHECK: %[[A_ADDR:.*]] = alloca i64, i64 1, align 8
+  // CHECK: %[[B_ADDR:.*]] = alloca i64, i64 1, align 8
+  // CHECK: %[[B:.*]] = load i64, ptr %[[B_ADDR]], align 4, !nontemporal ![[NT:[0-9]+]], !llvm.access.group ![[AG:[0-9]+]]
+  // CHECK: store i64 %[[B]], ptr %[[A_ADDR]], align 4, !nontemporal ![[NT]], !llvm.access.group ![[AG]]
+  omp.simd nontemporal(%2, %3 : !llvm.ptr, !llvm.ptr) { // load+store share one node
+    omp.loop_nest (%arg0) : i64 = (%1) to (%0) inclusive step (%1) {
+      %4 = llvm.load %3 : !llvm.ptr -> i64
+      llvm.store %4, %2 : i64, !llvm.ptr
+      omp.yield
+    }
+  }
+  llvm.return
+}
+
+// -----
+
+//CHECK-LABEL:  define void @_QPtest(ptr %0, ptr %1) {
+llvm.func @_QPtest(%arg0: !llvm.ptr {fir.bindc_name = "n"}, %arg1: !llvm.ptr {fir.bindc_name = "a"}) {
+    %0 = llvm.mlir.constant(1 : i32) : i32
+    %1 = llvm.alloca %0 x !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>)> {alignment = 8 : i64} : (i32) -> !llvm.ptr
+    %2 = llvm.alloca %0 x !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>)> {alignment = 8 : i64} : (i32) -> !llvm.ptr
+    %3 = llvm.mlir.constant(1 : i64) : i64
+    %4 = llvm.alloca %3 x i32 {bindc_name = "i", pinned} : (i64) -> !llvm.ptr
+    %6 = llvm.load %arg0 : !llvm.ptr -> i32
+    //CHECK:  %[[A_VAL1:.*]] = alloca { ptr, i64, i32, i8, i8, i8, i8, [1 x [3 x i64]] }, align 8
+    //CHECK:  %[[A_VAL2:.*]] = alloca { ptr, i64, i32, i8, i8, i8, i8, [1 x [3 x i64]] }, align 8
+    omp.simd nontemporal(%arg1 : !llvm.ptr) { // %1 and %2 receive memcpy copies of %arg1
+      omp.loop_nest (%arg2) : i32 = (%0) to (%6) inclusive step (%0) {
+        llvm.store %arg2, %4 : i32, !llvm.ptr
+        //CHECK:  call void @llvm.memcpy.p0.p0.i32(ptr %[[A_VAL2]], ptr %1, i32 48, i1 false)
+        %7 = llvm.mlir.constant(48 : i32) : i32
+        "llvm.intr.memcpy"(%2, %arg1, %7) <{isVolatile = false}> : (!llvm.ptr, !llvm.ptr, i32) -> ()
+        %8 = llvm.load %4 : !llvm.ptr -> i32
+        %9 = llvm.sext %8 : i32 to i64
+        %10 = llvm.getelementptr %2[0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>)>
+        %11 = llvm.load %10 : !llvm.ptr -> !llvm.ptr
+        %12 = llvm.mlir.constant(0 : index) : i64
+        %13 = llvm.getelementptr %2[0, 7, %12, 0] : (!llvm.ptr, i64) -> !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>)>
+        %14 = llvm.load %13 : !llvm.ptr -> i64
+        %15 = llvm.getelementptr %2[0, 7, %12, 1] : (!llvm.ptr, i64) -> !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>)>
+        %16 = llvm.load %15 : !llvm.ptr -> i64
+        %17 = llvm.getelementptr %2[0, 7, %12, 2] : (!llvm.ptr, i64) -> !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>)>
+        %18 = llvm.load %17 : !llvm.ptr -> i64
+        %19 = llvm.mlir.constant(0 : i64) : i64
+        %20 = llvm.sub %9, %14 overflow<nsw> : i64
+        %21 = llvm.mul %20, %3 overflow<nsw> : i64
+        %22 = llvm.mul %21, %3 overflow<nsw> : i64
+        %23 = llvm.add %22, %19 overflow<nsw> : i64
+        %24 = llvm.mul %3, %16 overflow<nsw> : i64
+        //CHECK:  %[[VAL1:.*]] = getelementptr float, ptr {{.*}}, i64 %{{.*}}
+        //CHECK:  %[[LOAD_A:.*]] = load float, ptr %[[VAL1]], align 4, !nontemporal ![[NT_A:[0-9]+]]
+        //CHECK:  %[[RES:.*]] = fadd contract float %[[LOAD_A]], 2.000000e+01
+        %25 = llvm.getelementptr %11[%23] : (!llvm.ptr, i64) -> !llvm.ptr, f32
+        %26 = llvm.load %25 : !llvm.ptr -> f32
+        %27 = llvm.mlir.constant(2.000000e+01 : f32) : f32
+        %28 = llvm.fadd %26, %27 {fastmathFlags = #llvm.fastmath<contract>} : f32
+        //CHECK:  call void @llvm.memcpy.p0.p0.i32(ptr %[[A_VAL1]], ptr %1, i32 48, i1 false)
+        %29 = llvm.mlir.constant(48 : i32) : i32
+        "llvm.intr.memcpy"(%1, %arg1, %29) <{isVolatile = false}> : (!llvm.ptr, !llvm.ptr, i32) -> ()
+        %30 = llvm.load %4 : !llvm.ptr -> i32
+        %31 = llvm.sext %30 : i32 to i64
+        %32 = llvm.getelementptr %1[0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>)>
+        %33 = llvm.load %32 : !llvm.ptr -> !llvm.ptr
+        %34 = llvm.mlir.constant(0 : index) : i64
+        %35 = llvm.getelementptr %1[0, 7, %34, 0] : (!llvm.ptr, i64) -> !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>)>
+        %36 = llvm.load %35 : !llvm.ptr -> i64
+        %37 = llvm.getelementptr %1[0, 7, %34, 1] : (!llvm.ptr, i64) -> !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>)>
+        %38 = llvm.load %37 : !llvm.ptr -> i64
+        %39 = llvm.getelementptr %1[0, 7, %34, 2] : (!llvm.ptr, i64) -> !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>)>
+        %40 = llvm.load %39 : !llvm.ptr -> i64
+        %41 = llvm.sub %31, %36 overflow<nsw> : i64
+        %42 = llvm.mul %41, %3 overflow<nsw> : i64
+        %43 = llvm.mul %42, %3 overflow<nsw> : i64
+        %44 = llvm.add %43, %19 overflow<nsw> : i64
+        %45 = llvm.mul %3, %38 overflow<nsw> : i64
+        //CHECK:  %[[VAL2:.*]] = getelementptr float, ptr %{{.*}}, i64 %{{.*}}
+        //CHECK:  store float %[[RES]], ptr %[[VAL2]], align 4, !nontemporal ![[NT_A]]
+        %46 = llvm.getelementptr %33[%44] : (!llvm.ptr, i64) -> !llvm.ptr, f32
+        llvm.store %28, %46 : f32, !llvm.ptr
+        omp.yield
+      }
+    }
+    llvm.return
+  }
+// -----
+ 
+
diff --git a/mlir/test/Target/LLVMIR/openmp-todo.mlir b/mlir/test/Target/LLVMIR/openmp-todo.mlir
index de797ea2aa3649..21582ec8a8715f 100644
--- a/mlir/test/Target/LLVMIR/openmp-todo.mlir
+++ b/mlir/test/Target/LLVMIR/openmp-todo.mlir
@@ -155,19 +155,6 @@ llvm.func @simd_linear(%lb : i32, %ub : i32, %step : i32, %x : !llvm.ptr) {
 
 // -----
 
-llvm.func @simd_nontemporal(%lb : i32, %ub : i32, %step : i32, %x : !llvm.ptr) {
-  // expected-error at below {{not yet implemented: Unhandled clause nontemporal in omp.simd operation}}
-  // expected-error at below {{LLVM Translation failed for operation: omp.simd}}
-  omp.simd nontemporal(%x : !llvm.ptr) {
-    omp.loop_nest (%iv) : i32 = (%lb) to (%ub) step (%step) {
-      omp.yield
-    }
-  }
-  llvm.return
-}
-
-// -----
-
 omp.private {type = private} @x.privatizer : !llvm.ptr alloc {
 ^bb0(%arg0: !llvm.ptr):
   %0 = llvm.mlir.constant(1 : i32) : i32



More information about the llvm-commits mailing list