[Mlir-commits] [clang] [llvm] [mlir] [OpenMP] Migrate GPU Reductions CodeGen from Clang to OMPIRBuilder (PR #80343)

Akash Banerjee llvmlistbot at llvm.org
Thu Feb 15 07:59:09 PST 2024


================
@@ -2051,36 +2057,1424 @@ OpenMPIRBuilder::createSection(const LocationDescription &Loc,
                               /*IsCancellable*/ true);
 }
 
-/// Create a function with a unique name and a "void (i8*, i8*)" signature in
-/// the given module and return it.
-Function *getFreshReductionFunc(Module &M) {
+static OpenMPIRBuilder::InsertPointTy getInsertPointAfterInstr(Instruction *I) {
+  BasicBlock::iterator IT(I);
+  IT++;
+  return OpenMPIRBuilder::InsertPointTy(I->getParent(), IT);
+}
+
+void OpenMPIRBuilder::emitUsed(StringRef Name,
+                               std::vector<WeakTrackingVH> &List) {
+  if (List.empty())
+    return;
+
+  // Convert List to what ConstantArray needs.
+  SmallVector<Constant *, 8> UsedArray;
+  UsedArray.resize(List.size());
+  for (unsigned I = 0, E = List.size(); I != E; ++I)
+    UsedArray[I] = ConstantExpr::getPointerBitCastOrAddrSpaceCast(
+        cast<Constant>(&*List[I]), Builder.getPtrTy());
+
+  if (UsedArray.empty())
+    return;
+  ArrayType *ATy = ArrayType::get(Builder.getPtrTy(), UsedArray.size());
+
+  auto *GV = new GlobalVariable(M, ATy, false, GlobalValue::AppendingLinkage,
+                                ConstantArray::get(ATy, UsedArray), Name);
+
+  GV->setSection("llvm.metadata");
+}
+
+Value *OpenMPIRBuilder::getGPUThreadID() {
+  return Builder.CreateCall(
+      getOrCreateRuntimeFunction(M,
+                                 OMPRTL___kmpc_get_hardware_thread_id_in_block),
+      {});
+}
+
+Value *OpenMPIRBuilder::getGPUWarpSize() {
+  return Builder.CreateCall(
+      getOrCreateRuntimeFunction(M, OMPRTL___kmpc_get_warp_size), {});
+}
+
+Value *OpenMPIRBuilder::getNVPTXWarpID() {
+  unsigned LaneIDBits = Log2_32(Config.getGridValue().GV_Warp_Size);
+  return Builder.CreateAShr(getGPUThreadID(), LaneIDBits, "nvptx_warp_id");
+}
+
+Value *OpenMPIRBuilder::getNVPTXLaneID() {
+  unsigned LaneIDBits = Log2_32(Config.getGridValue().GV_Warp_Size);
+  assert(LaneIDBits < 32 && "Invalid LaneIDBits size in NVPTX device.");
+  unsigned LaneIDMask = ~0u >> (32u - LaneIDBits);
+  return Builder.CreateAnd(getGPUThreadID(), Builder.getInt32(LaneIDMask),
+                           "nvptx_lane_id");
+}
+
+Value *OpenMPIRBuilder::castValueToType(InsertPointTy AllocaIP, Value *From,
+                                        Type *ToType) {
+  Type *FromType = From->getType();
+  uint64_t FromSize = M.getDataLayout().getTypeStoreSize(FromType);
+  uint64_t ToSize = M.getDataLayout().getTypeStoreSize(ToType);
+  assert(FromSize > 0 && "From size must be greater than zero");
+  assert(ToSize > 0 && "To size must be greater than zero");
+  if (FromType == ToType)
+    return From;
+  if (FromSize == ToSize)
+    return Builder.CreateBitCast(From, ToType);
+  if (ToType->isIntegerTy() && FromType->isIntegerTy())
+    return Builder.CreateIntCast(From, ToType, /*isSigned*/ true);
+  InsertPointTy SaveIP = Builder.saveIP();
+  Builder.restoreIP(AllocaIP);
+  Value *CastItem = Builder.CreateAlloca(ToType);
+  Builder.restoreIP(SaveIP);
+
+  Value *ValCastItem = Builder.CreatePointerBitCastOrAddrSpaceCast(
+      CastItem, FromType->getPointerTo());
+  Builder.CreateStore(From, ValCastItem);
+  return Builder.CreateLoad(ToType, CastItem);
+}
+
+Value *OpenMPIRBuilder::createRuntimeShuffleFunction(InsertPointTy AllocaIP,
+                                                     Value *Element,
+                                                     Type *ElementType,
+                                                     Value *Offset) {
+  uint64_t Size = M.getDataLayout().getTypeStoreSize(ElementType);
+  assert(Size <= 8 && "Unsupported bitwidth in shuffle instruction");
+
+  // Cast all types to 32- or 64-bit values before calling shuffle routines.
+  Type *CastTy = Builder.getIntNTy(Size <= 4 ? 32 : 64);
+  Value *ElemCast = castValueToType(AllocaIP, Element, CastTy);
+  Value *WarpSize =
+      Builder.CreateIntCast(getGPUWarpSize(), Builder.getInt16Ty(), true);
+  Function *ShuffleFunc = getOrCreateRuntimeFunctionPtr(
+      Size <= 4 ? RuntimeFunction::OMPRTL___kmpc_shuffle_int32
+                : RuntimeFunction::OMPRTL___kmpc_shuffle_int64);
+  Value *WarpSizeCast =
+      Builder.CreateIntCast(WarpSize, Builder.getInt16Ty(), /*isSigned=*/true);
+  Value *ShuffleCall =
+      Builder.CreateCall(ShuffleFunc, {ElemCast, Offset, WarpSizeCast});
+  return castValueToType(AllocaIP, ShuffleCall, CastTy);
+}
+
+void OpenMPIRBuilder::shuffleAndStore(InsertPointTy AllocaIP, Value *SrcAddr,
+                                      Value *DstAddr, Type *ElemType,
+                                      Value *Offset, Type *ReductionArrayTy) {
+  uint64_t Size = M.getDataLayout().getTypeStoreSize(ElemType);
+  // Create the loop over the big sized data.
+  // ptr = (void*)Elem;
+  // ptrEnd = (void*) Elem + 1;
+  // Step = 8;
+  // while (ptr + Step < ptrEnd)
+  //   shuffle((int64_t)*ptr);
+  // Step = 4;
+  // while (ptr + Step < ptrEnd)
+  //   shuffle((int32_t)*ptr);
+  // ...
+  Type *IndexTy = Builder.getIndexTy(
+      M.getDataLayout(), M.getDataLayout().getDefaultGlobalsAddressSpace());
+  Value *ElemPtr = DstAddr;
+  Value *Ptr = SrcAddr;
+  for (unsigned IntSize = 8; IntSize >= 1; IntSize /= 2) {
+    if (Size < IntSize)
+      continue;
+    Type *IntType = Builder.getIntNTy(IntSize * 8);
+    Ptr = Builder.CreatePointerBitCastOrAddrSpaceCast(
+        Ptr, IntType->getPointerTo(), Ptr->getName() + ".ascast");
+    Value *SrcAddrGEP =
+        Builder.CreateGEP(ElemType, SrcAddr, {ConstantInt::get(IndexTy, 1)});
+    ElemPtr = Builder.CreatePointerBitCastOrAddrSpaceCast(
+        ElemPtr, IntType->getPointerTo(), ElemPtr->getName() + ".ascast");
+
+    Function *CurFunc = Builder.GetInsertBlock()->getParent();
+    if ((Size / IntSize) > 1) {
+      Value *PtrEnd = Builder.CreatePointerBitCastOrAddrSpaceCast(
+          SrcAddrGEP, Builder.getPtrTy());
+      BasicBlock *PreCondBB =
+          BasicBlock::Create(M.getContext(), ".shuffle.pre_cond");
+      BasicBlock *ThenBB = BasicBlock::Create(M.getContext(), ".shuffle.then");
+      BasicBlock *ExitBB = BasicBlock::Create(M.getContext(), ".shuffle.exit");
+      BasicBlock *CurrentBB = Builder.GetInsertBlock();
+      emitBlock(PreCondBB, CurFunc);
+      PHINode *PhiSrc =
+          Builder.CreatePHI(Ptr->getType(), /*NumReservedValues=*/2);
+      PhiSrc->addIncoming(Ptr, CurrentBB);
+      PHINode *PhiDest =
+          Builder.CreatePHI(ElemPtr->getType(), /*NumReservedValues=*/2);
+      PhiDest->addIncoming(ElemPtr, CurrentBB);
+      Ptr = PhiSrc;
+      ElemPtr = PhiDest;
+      Value *PtrDiff = Builder.CreatePtrDiff(
+          Builder.getInt8Ty(), PtrEnd,
+          Builder.CreatePointerBitCastOrAddrSpaceCast(Ptr, Builder.getPtrTy()));
+      Builder.CreateCondBr(
+          Builder.CreateICmpSGT(PtrDiff, Builder.getInt64(IntSize - 1)), ThenBB,
+          ExitBB);
+      emitBlock(ThenBB, CurFunc);
+      Value *Res = createRuntimeShuffleFunction(
+          AllocaIP,
+          Builder.CreateAlignedLoad(
+              IntType, Ptr, M.getDataLayout().getPrefTypeAlign(ElemType)),
+          IntType, Offset);
+      Builder.CreateAlignedStore(Res, ElemPtr,
+                                 M.getDataLayout().getPrefTypeAlign(ElemType));
+      Value *LocalPtr =
+          Builder.CreateGEP(IntType, Ptr, {ConstantInt::get(IndexTy, 1)});
+      Value *LocalElemPtr =
+          Builder.CreateGEP(IntType, ElemPtr, {ConstantInt::get(IndexTy, 1)});
+      PhiSrc->addIncoming(LocalPtr, ThenBB);
+      PhiDest->addIncoming(LocalElemPtr, ThenBB);
+      emitBranch(PreCondBB);
+      emitBlock(ExitBB, CurFunc);
+    } else {
+      Value *Res = createRuntimeShuffleFunction(
+          AllocaIP, Builder.CreateLoad(IntType, Ptr), IntType, Offset);
+      if (ElemType->isIntegerTy() && ElemType->getScalarSizeInBits() <
+                                         Res->getType()->getScalarSizeInBits())
+        Res = Builder.CreateTrunc(Res, ElemType);
+      Builder.CreateStore(Res, ElemPtr);
+      Ptr = Builder.CreateGEP(IntType, Ptr, {ConstantInt::get(IndexTy, 1)});
+      ElemPtr =
+          Builder.CreateGEP(IntType, ElemPtr, {ConstantInt::get(IndexTy, 1)});
+    }
+    Size = Size % IntSize;
+  }
+}
+
+void OpenMPIRBuilder::emitReductionListCopy(
+    InsertPointTy AllocaIP, CopyAction Action, Type *ReductionArrayTy,
+    ArrayRef<ReductionInfo> ReductionInfos, Value *SrcBase, Value *DestBase,
+    CopyOptionsTy CopyOptions) {
+  Type *IndexTy = Builder.getIndexTy(
+      M.getDataLayout(), M.getDataLayout().getDefaultGlobalsAddressSpace());
+  Value *RemoteLaneOffset = CopyOptions.RemoteLaneOffset;
+
+  // Iterates, element-by-element, through the source Reduce list and
+  // make a copy.
+  for (auto En : enumerate(ReductionInfos)) {
+    const ReductionInfo &RI = En.value();
+    Value *SrcElementAddr = nullptr;
+    Value *DestElementAddr = nullptr;
+    Value *DestElementPtrAddr = nullptr;
+    // Should we shuffle in an element from a remote lane?
+    bool ShuffleInElement = false;
+    // Set to true to update the pointer in the dest Reduce list to a
+    // newly created element.
+    bool UpdateDestListPtr = false;
+
+    // Step 1.1: Get the address for the src element in the Reduce list.
+    Value *SrcElementPtrAddr = Builder.CreateInBoundsGEP(
+        ReductionArrayTy, SrcBase,
+        {ConstantInt::get(IndexTy, 0), ConstantInt::get(IndexTy, En.index())});
+    SrcElementAddr = Builder.CreateLoad(Builder.getPtrTy(), SrcElementPtrAddr);
+
+    // Step 1.2: Create a temporary to store the element in the destination
+    // Reduce list.
+    DestElementPtrAddr = Builder.CreateInBoundsGEP(
+        ReductionArrayTy, DestBase,
+        {ConstantInt::get(IndexTy, 0), ConstantInt::get(IndexTy, En.index())});
+    switch (Action) {
+    case CopyAction::RemoteLaneToThread: {
+      InsertPointTy CurIP = Builder.saveIP();
+      Builder.restoreIP(AllocaIP);
+      AllocaInst *DestAlloca = Builder.CreateAlloca(RI.ElementType, nullptr,
+                                                    ".omp.reduction.element");
+      DestAlloca->setAlignment(
+          M.getDataLayout().getPrefTypeAlign(RI.ElementType));
+      DestElementAddr = DestAlloca;
+      DestElementAddr =
+          Builder.CreateAddrSpaceCast(DestElementAddr, Builder.getPtrTy(),
+                                      DestElementAddr->getName() + ".ascast");
+      Builder.restoreIP(CurIP);
+      ShuffleInElement = true;
+      UpdateDestListPtr = true;
+      break;
+    }
+    case CopyAction::ThreadCopy: {
+      DestElementAddr =
+          Builder.CreateLoad(Builder.getPtrTy(), DestElementPtrAddr);
+      break;
+    }
+    }
+
+    // Now that all active lanes have read the element in the
+    // Reduce list, shuffle over the value from the remote lane.
+    if (ShuffleInElement) {
+      shuffleAndStore(AllocaIP, SrcElementAddr, DestElementAddr, RI.ElementType,
+                      RemoteLaneOffset, ReductionArrayTy);
+    } else {
+      switch (RI.EvaluationKind) {
+      case EvaluationKindTy::Scalar: {
+        Value *Elem = Builder.CreateLoad(RI.ElementType, SrcElementAddr);
+        // Store the source element value to the dest element address.
+        Builder.CreateStore(Elem, DestElementAddr);
+        break;
+      }
+      case EvaluationKindTy::Complex: {
+        assert(false && "Complex data type not handled");
+        break;
+      }
----------------
TIFitis wrote:

@jdoerfert Currently there are no test cases for the Complex kind. I was trying to come up with test cases that might fall under this case but I wasn't able to.

Please let me know if this case is infeasible, and we can leave it as an assertion failure.

Here's my attempt at creating a test case, but it still falls under the aggregate kind:

```
int foo() {
  int i;
  int j;
  std::complex<int> sum[10][10];
  std::complex<int> res;

  #pragma omp target teams loop reduction(+:sum)
  for(i=0; i<10; i++)
    for(j=0; j<10; j++)
      res += sum[i][j];

  return 0;
}
```

https://github.com/llvm/llvm-project/pull/80343


More information about the Mlir-commits mailing list