[Mlir-commits] [llvm] [mlir] [MLIR][OpenMP] Lowering nontemporal clause to LLVM IR for SIMD directive (PR #118751)
Kaviya Rajendiran
llvmlistbot at llvm.org
Tue Mar 4 22:51:05 PST 2025
https://github.com/kaviya2510 updated https://github.com/llvm/llvm-project/pull/118751
>From a7a485940e3582e5375a39599a8676d2fabb3388 Mon Sep 17 00:00:00 2001
From: Kaviya Rajendiran <kaviyara2000 at gmail.com>
Date: Fri, 6 Dec 2024 14:49:39 +0530
Subject: [PATCH 1/2] [MLIR][OpenMP] Lower the nontemporal clause to LLVM IR for
the SIMD directive
---
.../llvm/Frontend/OpenMP/OMPIRBuilder.h | 2 +-
llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp | 85 +++++++++++++++-
.../OpenMP/OpenMPToLLVMIRTranslation.cpp | 15 +--
.../Target/LLVMIR/openmp-nontemporal.mlir | 96 +++++++++++++++++++
mlir/test/Target/LLVMIR/openmp-todo.mlir | 13 ---
5 files changed, 190 insertions(+), 21 deletions(-)
create mode 100644 mlir/test/Target/LLVMIR/openmp-nontemporal.mlir
diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
index 80b4aa2bd2855..fc726ec6cf4b4 100644
--- a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
+++ b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
@@ -1245,7 +1245,7 @@ class OpenMPIRBuilder {
void applySimd(CanonicalLoopInfo *Loop,
MapVector<Value *, Value *> AlignedVars, Value *IfCond,
omp::OrderKind Order, ConstantInt *Simdlen,
- ConstantInt *Safelen);
+ ConstantInt *Safelen, ArrayRef<Value *> NontempralVars = {});
/// Generator for '#omp flush'
///
diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
index e34e93442ff85..bff201a1377c8 100644
--- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
+++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
@@ -5385,10 +5385,86 @@ OpenMPIRBuilder::getOpenMPDefaultSimdAlign(const Triple &TargetTriple,
return 0;
}
+static void appendNontemporalVars(BasicBlock *Block,
+ SmallVectorImpl<Value *> &NontemporalVars) {
+ for (Instruction &I : *Block) {
+ if (const CallInst *CI = dyn_cast<CallInst>(&I)) {
+ if (CI->getIntrinsicID() == Intrinsic::memcpy) {
+ llvm::Value *DestPtr = CI->getArgOperand(0);
+ llvm::Value *SrcPtr = CI->getArgOperand(1);
+ for (const llvm::Value *Var : NontemporalVars) {
+ if (Var == SrcPtr) {
+ NontemporalVars.push_back(DestPtr);
+ break;
+ }
+ }
+ }
+ }
+ }
+}
+
+/** Attach nontemporal metadata to the load/store instructions of nontemporal
+ * variables of \p Block
+ * Nontemporal variables may be a scalar, fixed size or allocatable
+ * or pointer array
+ *
+ * Example scenarios for nontemporal variables:
+ * Case 1: Scalar variable
+ * If the nontemporal variable is a scalar, it is allocated on stack.Load and
+ * store instructions directly access the alloca pointer of the scalar
+ * variable for fetching information about scalar variable or writing
+ * into the scalar variable. Mark those load and store instructions as
+ * non-temporal.
+ *
+ * Case 2: Fixed Size array
+ * If the nontemporal variable is a fixed-size array, it is allocated
+ * as a contiguous block of memory. It uses one GEP instruction, to compute the
+ * address of each individual array elements and perform load or store
+ * operation on it. Mark those load and store instructions as non-temporal.
+ *
+ * Case 3: Allocatable array
+ * For an allocatable array, which might involve runtime type descriptor,
+ * needs to navigate through descriptors using two or more GEP and load
+ * instructions to compute the address of each individual element in an array.
+ * Mark those load or store which access the individual array elements as
+ * non-temporal.
+ */
+static void addNonTemporalMetadata(BasicBlock *Block, MDNode *Nontemporal,
+ SmallVectorImpl<Value *> &NontemporalVars) {
+ appendNontemporalVars(Block, NontemporalVars);
+ for (Instruction &I : *Block) {
+ llvm::Value *mem_ptr = nullptr;
+ bool MetadataFlag = true;
+ if (llvm::LoadInst *li = dyn_cast<llvm::LoadInst>(&I)) {
+ if (!(li->getType()->isPointerTy()))
+ mem_ptr = li->getPointerOperand();
+ } else if (llvm::StoreInst *si = dyn_cast<llvm::StoreInst>(&I))
+ mem_ptr = si->getPointerOperand();
+ if (mem_ptr) {
+ while (mem_ptr && !(isa<llvm::AllocaInst>(mem_ptr))) {
+ if (llvm::GetElementPtrInst *gep =
+ dyn_cast<llvm::GetElementPtrInst>(mem_ptr)) {
+ llvm::Type *sourceType = gep->getSourceElementType();
+ if (sourceType->isStructTy() && gep->getNumIndices() >= 2 &&
+ !(gep->hasAllZeroIndices())) {
+ MetadataFlag = false;
+ break;
+ }
+ mem_ptr = gep->getPointerOperand();
+ } else if (llvm::LoadInst *li = dyn_cast<llvm::LoadInst>(mem_ptr))
+ mem_ptr = li->getPointerOperand();
+ }
+ if (MetadataFlag && is_contained(NontemporalVars, mem_ptr))
+ I.setMetadata(LLVMContext::MD_nontemporal, Nontemporal);
+ }
+ }
+}
+
void OpenMPIRBuilder::applySimd(CanonicalLoopInfo *CanonicalLoop,
MapVector<Value *, Value *> AlignedVars,
Value *IfCond, OrderKind Order,
- ConstantInt *Simdlen, ConstantInt *Safelen) {
+ ConstantInt *Simdlen, ConstantInt *Safelen,
+ ArrayRef<Value *> NontemporalVarsIn) {
LLVMContext &Ctx = Builder.getContext();
Function *F = CanonicalLoop->getFunction();
@@ -5486,6 +5562,13 @@ void OpenMPIRBuilder::applySimd(CanonicalLoopInfo *CanonicalLoop,
}
addLoopMetadata(CanonicalLoop, LoopMDList);
+ SmallVector<Value *> NontemporalVars{NontemporalVarsIn};
+ // Set nontemporal metadata to load and stores of nontemporal values
+ if (NontemporalVars.size()) {
+ MDNode *NontemporalNode = MDNode::getDistinct(Ctx, {});
+ for (BasicBlock *BB : Reachable)
+ addNonTemporalMetadata(BB, NontemporalNode, NontemporalVars);
+ }
}
/// Create the TargetMachine object to query the backend for optimization
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 32c7c501d03c3..1c9690a1c7b68 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -189,10 +189,6 @@ static LogicalResult checkImplementationStatus(Operation &op) {
if (!op.getLinearVars().empty() || !op.getLinearStepVars().empty())
result = todo("linear");
};
- auto checkNontemporal = [&todo](auto op, LogicalResult &result) {
- if (!op.getNontemporalVars().empty())
- result = todo("nontemporal");
- };
auto checkNowait = [&todo](auto op, LogicalResult &result) {
if (op.getNowait())
result = todo("nowait");
@@ -300,7 +296,6 @@ static LogicalResult checkImplementationStatus(Operation &op) {
})
.Case([&](omp::SimdOp op) {
checkLinear(op, result);
- checkNontemporal(op, result);
checkReduction(op, result);
})
.Case<omp::AtomicReadOp, omp::AtomicWriteOp, omp::AtomicUpdateOp,
@@ -2527,6 +2522,14 @@ convertOmpSimd(Operation &opInst, llvm::IRBuilderBase &builder,
llvm::MapVector<llvm::Value *, llvm::Value *> alignedVars;
llvm::omp::OrderKind order = convertOrderKind(simdOp.getOrder());
+
+ llvm::SmallVector<llvm::Value *> nontemporalVars;
+ mlir::OperandRange nontemporals = simdOp.getNontemporalVars();
+ for (mlir::Value nontemporal : nontemporals) {
+ llvm::Value *nt = moduleTranslation.lookupValue(nontemporal);
+ nontemporalVars.push_back(nt);
+ }
+
llvm::BasicBlock *sourceBlock = builder.GetInsertBlock();
std::optional<ArrayAttr> alignmentValues = simdOp.getAlignments();
mlir::OperandRange operands = simdOp.getAlignedVars();
@@ -2558,7 +2561,7 @@ convertOmpSimd(Operation &opInst, llvm::IRBuilderBase &builder,
simdOp.getIfExpr()
? moduleTranslation.lookupValue(simdOp.getIfExpr())
: nullptr,
- order, simdlen, safelen);
+ order, simdlen, safelen, nontemporalVars);
return cleanupPrivateVars(builder, moduleTranslation, simdOp.getLoc(),
llvmPrivateVars, privateDecls);
diff --git a/mlir/test/Target/LLVMIR/openmp-nontemporal.mlir b/mlir/test/Target/LLVMIR/openmp-nontemporal.mlir
new file mode 100644
index 0000000000000..f8cee94be4ff7
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/openmp-nontemporal.mlir
@@ -0,0 +1,92 @@
+// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
+
+// CHECK-LABEL: @simd_nontemporal
+llvm.func @simd_nontemporal() {
+  %0 = llvm.mlir.constant(10 : i64) : i64
+  %1 = llvm.mlir.constant(1 : i64) : i64
+  %2 = llvm.alloca %1 x i64 : (i64) -> !llvm.ptr
+  %3 = llvm.alloca %1 x i64 : (i64) -> !llvm.ptr
+  //CHECK: %[[A_ADDR:.*]] = alloca i64, i64 1, align 8
+  //CHECK: %[[B_ADDR:.*]] = alloca i64, i64 1, align 8
+  //CHECK: %[[B:.*]] = load i64, ptr %[[B_ADDR]], align 4, !nontemporal ![[NT:[0-9]+]], !llvm.access.group ![[AG:[0-9]+]]
+  //CHECK: store i64 %[[B]], ptr %[[A_ADDR]], align 4, !nontemporal ![[NT]], !llvm.access.group ![[AG]]
+  omp.simd nontemporal(%2, %3 : !llvm.ptr, !llvm.ptr) {
+    omp.loop_nest (%arg0) : i64 = (%1) to (%0) inclusive step (%1) {
+      %4 = llvm.load %3 : !llvm.ptr -> i64
+      llvm.store %4, %2 : i64, !llvm.ptr
+      omp.yield
+    }
+  }
+  llvm.return
+}
+
+//CHECK-LABEL: define void @_QPtest(ptr %0, ptr %1) {
+llvm.func @_QPtest(%arg0: !llvm.ptr {fir.bindc_name = "n"}, %arg1: !llvm.ptr {fir.bindc_name = "a"}) {
+  %0 = llvm.mlir.constant(1 : i32) : i32
+  %1 = llvm.alloca %0 x !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>)> {alignment = 8 : i64} : (i32) -> !llvm.ptr
+  %2 = llvm.alloca %0 x !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>)> {alignment = 8 : i64} : (i32) -> !llvm.ptr
+  %3 = llvm.mlir.constant(1 : i64) : i64
+  %4 = llvm.alloca %3 x i32 {bindc_name = "i", pinned} : (i64) -> !llvm.ptr
+  %6 = llvm.load %arg0 : !llvm.ptr -> i32
+  //CHECK: %[[A_VAL1:.*]] = alloca { ptr, i64, i32, i8, i8, i8, i8, [1 x [3 x i64]] }, align 8
+  //CHECK: %[[A_VAL2:.*]] = alloca { ptr, i64, i32, i8, i8, i8, i8, [1 x [3 x i64]] }, align 8
+  omp.simd nontemporal(%arg1 : !llvm.ptr) {
+    omp.loop_nest (%arg2) : i32 = (%0) to (%6) inclusive step (%0) {
+      llvm.store %arg2, %4 : i32, !llvm.ptr
+      //CHECK: call void @llvm.memcpy.p0.p0.i32(ptr %[[A_VAL2]], ptr %1, i32 48, i1 false)
+      //CHECK-NOT: !nontemporal
+      %7 = llvm.mlir.constant(48 : i32) : i32
+      "llvm.intr.memcpy"(%2, %arg1, %7) <{isVolatile = false}> : (!llvm.ptr, !llvm.ptr, i32) -> ()
+      %8 = llvm.load %4 : !llvm.ptr -> i32
+      %9 = llvm.sext %8 : i32 to i64
+      %10 = llvm.getelementptr %2[0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>)>
+      %11 = llvm.load %10 : !llvm.ptr -> !llvm.ptr
+      %12 = llvm.mlir.constant(0 : index) : i64
+      %13 = llvm.getelementptr %2[0, 7, %12, 0] : (!llvm.ptr, i64) -> !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>)>
+      %14 = llvm.load %13 : !llvm.ptr -> i64
+      %15 = llvm.getelementptr %2[0, 7, %12, 1] : (!llvm.ptr, i64) -> !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>)>
+      %16 = llvm.load %15 : !llvm.ptr -> i64
+      %17 = llvm.getelementptr %2[0, 7, %12, 2] : (!llvm.ptr, i64) -> !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>)>
+      %18 = llvm.load %17 : !llvm.ptr -> i64
+      %19 = llvm.mlir.constant(0 : i64) : i64
+      %20 = llvm.sub %9, %14 overflow<nsw> : i64
+      %21 = llvm.mul %20, %3 overflow<nsw> : i64
+      %22 = llvm.mul %21, %3 overflow<nsw> : i64
+      %23 = llvm.add %22, %19 overflow<nsw> : i64
+      %24 = llvm.mul %3, %16 overflow<nsw> : i64
+      //CHECK: %[[VAL1:.*]] = getelementptr float, ptr {{.*}}, i64 %{{.*}}
+      //CHECK: %[[LOAD_A:.*]] = load float, ptr %[[VAL1]], align 4, !nontemporal
+      //CHECK: %[[RES:.*]] = fadd contract float %[[LOAD_A]], 2.000000e+01
+      %25 = llvm.getelementptr %11[%23] : (!llvm.ptr, i64) -> !llvm.ptr, f32
+      %26 = llvm.load %25 : !llvm.ptr -> f32
+      %27 = llvm.mlir.constant(2.000000e+01 : f32) : f32
+      %28 = llvm.fadd %26, %27 {fastmathFlags = #llvm.fastmath<contract>} : f32
+      //CHECK: call void @llvm.memcpy.p0.p0.i32(ptr %[[A_VAL1]], ptr %1, i32 48, i1 false)
+      //CHECK-NOT: !nontemporal
+      %29 = llvm.mlir.constant(48 : i32) : i32
+      "llvm.intr.memcpy"(%1, %arg1, %29) <{isVolatile = false}> : (!llvm.ptr, !llvm.ptr, i32) -> ()
+      %30 = llvm.load %4 : !llvm.ptr -> i32
+      %31 = llvm.sext %30 : i32 to i64
+      %32 = llvm.getelementptr %1[0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>)>
+      %33 = llvm.load %32 : !llvm.ptr -> !llvm.ptr
+      %34 = llvm.mlir.constant(0 : index) : i64
+      %35 = llvm.getelementptr %1[0, 7, %34, 0] : (!llvm.ptr, i64) -> !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>)>
+      %36 = llvm.load %35 : !llvm.ptr -> i64
+      %37 = llvm.getelementptr %1[0, 7, %34, 1] : (!llvm.ptr, i64) -> !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>)>
+      %38 = llvm.load %37 : !llvm.ptr -> i64
+      %39 = llvm.getelementptr %1[0, 7, %34, 2] : (!llvm.ptr, i64) -> !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>)>
+      %40 = llvm.load %39 : !llvm.ptr -> i64
+      %41 = llvm.sub %31, %36 overflow<nsw> : i64
+      %42 = llvm.mul %41, %3 overflow<nsw> : i64
+      %43 = llvm.mul %42, %3 overflow<nsw> : i64
+      %44 = llvm.add %43, %19 overflow<nsw> : i64
+      %45 = llvm.mul %3, %38 overflow<nsw> : i64
+      //CHECK: %[[VAL2:.*]] = getelementptr float, ptr %{{.*}}, i64 %{{.*}}
+      //CHECK: store float %[[RES]], ptr %[[VAL2]], align 4, !nontemporal
+      %46 = llvm.getelementptr %33[%44] : (!llvm.ptr, i64) -> !llvm.ptr, f32
+      llvm.store %28, %46 : f32, !llvm.ptr
+      omp.yield
+    }
+  }
+  llvm.return
+}
diff --git a/mlir/test/Target/LLVMIR/openmp-todo.mlir b/mlir/test/Target/LLVMIR/openmp-todo.mlir
index f907bb3f94a2a..a6ed73b158bdb 100644
--- a/mlir/test/Target/LLVMIR/openmp-todo.mlir
+++ b/mlir/test/Target/LLVMIR/openmp-todo.mlir
@@ -184,19 +184,6 @@ llvm.func @simd_linear(%lb : i32, %ub : i32, %step : i32, %x : !llvm.ptr) {
// -----
-llvm.func @simd_nontemporal(%lb : i32, %ub : i32, %step : i32, %x : !llvm.ptr) {
- // expected-error at below {{not yet implemented: Unhandled clause nontemporal in omp.simd operation}}
- // expected-error at below {{LLVM Translation failed for operation: omp.simd}}
- omp.simd nontemporal(%x : !llvm.ptr) {
- omp.loop_nest (%iv) : i32 = (%lb) to (%ub) step (%step) {
- omp.yield
- }
- }
- llvm.return
-}
-
-// -----
-
omp.declare_reduction @add_f32 : f32
init {
^bb0(%arg: f32):
>From 53cc22910ede9382ecb2258f390515a530edc140 Mon Sep 17 00:00:00 2001
From: Kaviya Rajendiran <kaviyara2000 at gmail.com>
Date: Wed, 5 Mar 2025 12:20:45 +0530
Subject: [PATCH 2/2] [MLIR][OpenMP] Move nontemporal metadata attachment into
a callback and implement it in OpenMPToLLVMIRTranslation.cpp
---
.../llvm/Frontend/OpenMP/OMPIRBuilder.h | 16 +++-
llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp | 91 ++-----------------
.../OpenMP/OpenMPToLLVMIRTranslation.cpp | 91 +++++++++++++++++--
3 files changed, 104 insertions(+), 94 deletions(-)
diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
index fc726ec6cf4b4..60c64e369ca54 100644
--- a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
+++ b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
@@ -1229,6 +1229,11 @@ class OpenMPIRBuilder {
   void unrollLoopPartial(DebugLoc DL, CanonicalLoopInfo *Loop, int32_t Factor,
                          CanonicalLoopInfo **UnrolledCLI);
 
+  /// Non-owning callback used by applySimd to attach the nontemporal metadata
+  /// node \p NontemporalNode to the memory accesses it selects in \p BB.
+  using NonTemporalBodyGenCallbackTy =
+      function_ref<void(llvm::BasicBlock *BB, MDNode *NontemporalNode)>;
+
   /// Add metadata to simd-ize a loop. If IfCond is not nullptr, the loop
   /// is cloned. The metadata which prevents vectorization is added to
   /// to the cloned loop. The cloned loop is executed when ifCond is evaluated
@@ -1242,10 +1247,17 @@ class OpenMPIRBuilder {
   /// \param Order The enum to map order clause.
   /// \param Simdlen The Simdlen length to apply to the simd loop.
   /// \param Safelen The Safelen length to apply to the simd loop.
-  void applySimd(CanonicalLoopInfo *Loop,
-                 MapVector<Value *, Value *> AlignedVars, Value *IfCond,
-                 omp::OrderKind Order, ConstantInt *Simdlen,
-                 ConstantInt *Safelen, ArrayRef<Value *> NontempralVars = {});
+  /// \param NontemporalCBFunc Callback invoked for each basic block of the
+  ///        simd loop that applySimd visits, with the metadata node to
+  ///        attach; it is only called when \p NontemporalVars is non-empty.
+  /// \param NontemporalVars Values listed in the nontemporal clause.
+  void applySimd(
+      CanonicalLoopInfo *Loop, MapVector<Value *, Value *> AlignedVars,
+      Value *IfCond, omp::OrderKind Order, ConstantInt *Simdlen,
+      ConstantInt *Safelen,
+      NonTemporalBodyGenCallbackTy NontemporalCBFunc = [](BasicBlock *,
+                                                          MDNode *) {},
+      ArrayRef<Value *> NontemporalVars = {});
 
/// Generator for '#omp flush'
///
diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
index bff201a1377c8..5973c0d63aa5d 100644
--- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
+++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
@@ -5385,86 +5385,11 @@ OpenMPIRBuilder::getOpenMPDefaultSimdAlign(const Triple &TargetTriple,
return 0;
}
-static void appendNontemporalVars(BasicBlock *Block,
- SmallVectorImpl<Value *> &NontemporalVars) {
- for (Instruction &I : *Block) {
- if (const CallInst *CI = dyn_cast<CallInst>(&I)) {
- if (CI->getIntrinsicID() == Intrinsic::memcpy) {
- llvm::Value *DestPtr = CI->getArgOperand(0);
- llvm::Value *SrcPtr = CI->getArgOperand(1);
- for (const llvm::Value *Var : NontemporalVars) {
- if (Var == SrcPtr) {
- NontemporalVars.push_back(DestPtr);
- break;
- }
- }
- }
- }
- }
-}
-
-/** Attach nontemporal metadata to the load/store instructions of nontemporal
- * variables of \p Block
- * Nontemporal variables may be a scalar, fixed size or allocatable
- * or pointer array
- *
- * Example scenarios for nontemporal variables:
- * Case 1: Scalar variable
- * If the nontemporal variable is a scalar, it is allocated on stack.Load and
- * store instructions directly access the alloca pointer of the scalar
- * variable for fetching information about scalar variable or writing
- * into the scalar variable. Mark those load and store instructions as
- * non-temporal.
- *
- * Case 2: Fixed Size array
- * If the nontemporal variable is a fixed-size array, it is allocated
- * as a contiguous block of memory. It uses one GEP instruction, to compute the
- * address of each individual array elements and perform load or store
- * operation on it. Mark those load and store instructions as non-temporal.
- *
- * Case 3: Allocatable array
- * For an allocatable array, which might involve runtime type descriptor,
- * needs to navigate through descriptors using two or more GEP and load
- * instructions to compute the address of each individual element in an array.
- * Mark those load or store which access the individual array elements as
- * non-temporal.
- */
-static void addNonTemporalMetadata(BasicBlock *Block, MDNode *Nontemporal,
- SmallVectorImpl<Value *> &NontemporalVars) {
- appendNontemporalVars(Block, NontemporalVars);
- for (Instruction &I : *Block) {
- llvm::Value *mem_ptr = nullptr;
- bool MetadataFlag = true;
- if (llvm::LoadInst *li = dyn_cast<llvm::LoadInst>(&I)) {
- if (!(li->getType()->isPointerTy()))
- mem_ptr = li->getPointerOperand();
- } else if (llvm::StoreInst *si = dyn_cast<llvm::StoreInst>(&I))
- mem_ptr = si->getPointerOperand();
- if (mem_ptr) {
- while (mem_ptr && !(isa<llvm::AllocaInst>(mem_ptr))) {
- if (llvm::GetElementPtrInst *gep =
- dyn_cast<llvm::GetElementPtrInst>(mem_ptr)) {
- llvm::Type *sourceType = gep->getSourceElementType();
- if (sourceType->isStructTy() && gep->getNumIndices() >= 2 &&
- !(gep->hasAllZeroIndices())) {
- MetadataFlag = false;
- break;
- }
- mem_ptr = gep->getPointerOperand();
- } else if (llvm::LoadInst *li = dyn_cast<llvm::LoadInst>(mem_ptr))
- mem_ptr = li->getPointerOperand();
- }
- if (MetadataFlag && is_contained(NontemporalVars, mem_ptr))
- I.setMetadata(LLVMContext::MD_nontemporal, Nontemporal);
- }
- }
-}
-
-void OpenMPIRBuilder::applySimd(CanonicalLoopInfo *CanonicalLoop,
- MapVector<Value *, Value *> AlignedVars,
- Value *IfCond, OrderKind Order,
- ConstantInt *Simdlen, ConstantInt *Safelen,
- ArrayRef<Value *> NontemporalVarsIn) {
+void OpenMPIRBuilder::applySimd(
+    CanonicalLoopInfo *CanonicalLoop, MapVector<Value *, Value *> AlignedVars,
+    Value *IfCond, OrderKind Order, ConstantInt *Simdlen, ConstantInt *Safelen,
+    NonTemporalBodyGenCallbackTy NontemporalCBFunc,
+    ArrayRef<Value *> NontemporalVarsIn) {
   LLVMContext &Ctx = Builder.getContext();
   Function *F = CanonicalLoop->getFunction();
@@ -5562,12 +5487,14 @@ void OpenMPIRBuilder::applySimd(CanonicalLoopInfo *CanonicalLoop,
   }
 
   addLoopMetadata(CanonicalLoop, LoopMDList);
-  SmallVector<Value *> NontemporalVars{NontemporalVarsIn};
+
   // Set nontemporal metadata to load and stores of nontemporal values
-  if (NontemporalVars.size()) {
-    MDNode *NontemporalNode = MDNode::getDistinct(Ctx, {});
+  if (NontemporalCBFunc && !NontemporalVarsIn.empty()) {
+    // LangRef requires the !nontemporal node to hold a single `i32 1` entry.
+    MDNode *NontemporalNode =
+        MDNode::get(Ctx, ConstantAsMetadata::get(Builder.getInt32(1)));
     for (BasicBlock *BB : Reachable)
-      addNonTemporalMetadata(BB, NontemporalNode, NontemporalVars);
+      NontemporalCBFunc(BB, NontemporalNode);
   }
 }
 
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 1c9690a1c7b68..ef313388df8e0 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -2462,6 +2462,25 @@ convertOrderKind(std::optional<omp::ClauseOrderKind> o) {
llvm_unreachable("Unknown ClauseOrderKind kind");
}
+static void
+appendNontemporalVars(llvm::BasicBlock *Block,
+                      SmallVectorImpl<llvm::Value *> &NontemporalVars) {
+  // A memcpy whose source is a nontemporal variable creates a copy (e.g. a
+  // descriptor copied into a temporary) whose accesses must be treated as
+  // nontemporal too. Instructions are scanned in order, so copies of copies
+  // are found as well; a destination is recorded at most once.
+  for (llvm::Instruction &I : *Block) {
+    const auto *CI = dyn_cast<llvm::CallInst>(&I);
+    if (!CI || CI->getIntrinsicID() != llvm::Intrinsic::memcpy)
+      continue;
+    llvm::Value *DestPtr = CI->getArgOperand(0);
+    llvm::Value *SrcPtr = CI->getArgOperand(1);
+    if (llvm::is_contained(NontemporalVars, SrcPtr) &&
+        !llvm::is_contained(NontemporalVars, DestPtr))
+      NontemporalVars.push_back(DestPtr);
+  }
+}
+
/// Converts an OpenMP simd loop into LLVM IR using OpenMPIRBuilder.
static LogicalResult
convertOmpSimd(Operation &opInst, llvm::IRBuilderBase &builder,
@@ -2523,13 +2542,71 @@ convertOmpSimd(Operation &opInst, llvm::IRBuilderBase &builder,
llvm::MapVector<llvm::Value *, llvm::Value *> alignedVars;
llvm::omp::OrderKind order = convertOrderKind(simdOp.getOrder());
- llvm::SmallVector<llvm::Value *> nontemporalVars;
+ llvm::SmallVector<llvm::Value *> nontemporalOrigVars;
mlir::OperandRange nontemporals = simdOp.getNontemporalVars();
for (mlir::Value nontemporal : nontemporals) {
llvm::Value *nt = moduleTranslation.lookupValue(nontemporal);
- nontemporalVars.push_back(nt);
+ nontemporalOrigVars.push_back(nt);
}
+  // Loads/stores through these values get !nontemporal: the clause operands
+  // plus destinations of memcpys from listed values. The list persists across
+  // callback invocations so copies are also recognized in later blocks.
+  llvm::SmallVector<llvm::Value *> nontemporalVars(nontemporalOrigVars);
+
+  // Callback for OpenMPIRBuilder::applySimd: attaches `nontemporalNode` as
+  // !nontemporal metadata to the loads/stores in `block` that access a
+  // nontemporal variable. A nontemporal variable may be:
+  //  - a scalar allocated on the stack, whose loads and stores access its
+  //    alloca directly;
+  //  - a fixed-size array allocated as one contiguous block, where a single
+  //    GEP computes the address of each element;
+  //  - an allocatable or pointer array, possibly behind a runtime type
+  //    descriptor, where the element address is computed through two or more
+  //    GEPs and loads.
+  // Each access is traced back through GEPs and loads to its root object and
+  // is marked only if that root is a nontemporal variable. Loads producing a
+  // pointer (e.g. a descriptor's base address) and GEPs selecting a non-zero
+  // struct field (e.g. descriptor bounds) are not element accesses and stay
+  // unmarked.
+  auto addNonTemporalMetadataCB = [&](llvm::BasicBlock *block,
+                                      llvm::MDNode *nontemporalNode) {
+    appendNontemporalVars(block, nontemporalVars);
+    for (llvm::Instruction &inst : *block) {
+      llvm::Value *memPtr = nullptr;
+      if (auto *load = dyn_cast<llvm::LoadInst>(&inst)) {
+        if (!load->getType()->isPointerTy())
+          memPtr = load->getPointerOperand();
+      } else if (auto *store = dyn_cast<llvm::StoreInst>(&inst)) {
+        memPtr = store->getPointerOperand();
+      }
+      if (!memPtr)
+        continue;
+
+      // Walk to the root object. Anything other than a GEP or a load
+      // (alloca, function argument, global, cast, phi, call, ...) is a root
+      // and ends the walk; stopping only at allocas would spin forever on,
+      // e.g., an access through a function argument.
+      bool isElementAccess = true;
+      while (true) {
+        if (auto *gep = dyn_cast<llvm::GetElementPtrInst>(memPtr)) {
+          if (gep->getSourceElementType()->isStructTy() &&
+              gep->getNumIndices() >= 2 && !gep->hasAllZeroIndices()) {
+            isElementAccess = false;
+            break;
+          }
+          memPtr = gep->getPointerOperand();
+        } else if (auto *ptrLoad = dyn_cast<llvm::LoadInst>(memPtr)) {
+          memPtr = ptrLoad->getPointerOperand();
+        } else {
+          break;
+        }
+      }
+      if (isElementAccess && llvm::is_contained(nontemporalVars, memPtr))
+        inst.setMetadata(llvm::LLVMContext::MD_nontemporal, nontemporalNode);
+    }
+  };
+
llvm::BasicBlock *sourceBlock = builder.GetInsertBlock();
std::optional<ArrayAttr> alignmentValues = simdOp.getAlignments();
mlir::OperandRange operands = simdOp.getAlignedVars();
@@ -2557,11 +2634,11 @@ convertOmpSimd(Operation &opInst, llvm::IRBuilderBase &builder,
builder.SetInsertPoint(*regionBlock, (*regionBlock)->begin());
llvm::CanonicalLoopInfo *loopInfo = findCurrentLoopInfo(moduleTranslation);
- ompBuilder->applySimd(loopInfo, alignedVars,
- simdOp.getIfExpr()
- ? moduleTranslation.lookupValue(simdOp.getIfExpr())
- : nullptr,
- order, simdlen, safelen, nontemporalVars);
+ ompBuilder->applySimd(
+ loopInfo, alignedVars,
+ simdOp.getIfExpr() ? moduleTranslation.lookupValue(simdOp.getIfExpr())
+ : nullptr,
+ order, simdlen, safelen, addNonTemporalMetadataCB, nontemporalOrigVars);
return cleanupPrivateVars(builder, moduleTranslation, simdOp.getLoc(),
llvmPrivateVars, privateDecls);
More information about the Mlir-commits
mailing list