[Mlir-commits] [mlir] c4c1030 - [mlir] support collapsed loops in OpenMP-to-LLVM translation
Alex Zinenko
llvmlistbot at llvm.org
Fri Aug 6 08:13:20 PDT 2021
Author: Alex Zinenko
Date: 2021-08-06T17:13:12+02:00
New Revision: c4c103097660b7b130eaf134919516726d7bd9e6
URL: https://github.com/llvm/llvm-project/commit/c4c103097660b7b130eaf134919516726d7bd9e6
DIFF: https://github.com/llvm/llvm-project/commit/c4c103097660b7b130eaf134919516726d7bd9e6.diff
LOG: [mlir] support collapsed loops in OpenMP-to-LLVM translation
Reviewed By: Meinersbur
Differential Revision: https://reviews.llvm.org/D105706
Added:
Modified:
mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
mlir/test/Target/LLVMIR/openmp-llvm.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 686386a3542a9..31477ac4c0376 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -252,25 +252,12 @@ convertOmpWsLoop(Operation &opInst, llvm::IRBuilderBase &builder,
if (loop.lowerBound().empty())
return failure();
- if (loop.getNumLoops() != 1)
- return opInst.emitOpError("collapsed loops not yet supported");
-
// Static is the default.
omp::ClauseScheduleKind schedule = omp::ClauseScheduleKind::Static;
if (loop.schedule_val().hasValue())
schedule =
*omp::symbolizeClauseScheduleKind(loop.schedule_val().getValue());
- // Find the loop configuration.
- llvm::Value *lowerBound = moduleTranslation.lookupValue(loop.lowerBound()[0]);
- llvm::Value *upperBound = moduleTranslation.lookupValue(loop.upperBound()[0]);
- llvm::Value *step = moduleTranslation.lookupValue(loop.step()[0]);
- llvm::Type *ivType = step->getType();
- llvm::Value *chunk =
- loop.schedule_chunk_var()
- ? moduleTranslation.lookupValue(loop.schedule_chunk_var())
- : llvm::ConstantInt::get(ivType, 1);
-
// Set up the source location value for OpenMP runtime.
llvm::DISubprogram *subprogram =
builder.GetInsertBlock()->getParent()->getSubprogram();
@@ -279,22 +266,29 @@ convertOmpWsLoop(Operation &opInst, llvm::IRBuilderBase &builder,
llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder.saveIP(),
llvm::DebugLoc(diLoc));
- // Generator of the canonical loop body. Produces an SESE region of basic
- // blocks.
+ // Generator of the canonical loop body.
// TODO: support error propagation in OpenMPIRBuilder and use it instead of
// relying on captured variables.
+ SmallVector<llvm::CanonicalLoopInfo *> loopInfos;
+ SmallVector<llvm::OpenMPIRBuilder::InsertPointTy> bodyInsertPoints;
LogicalResult bodyGenStatus = success();
auto bodyGen = [&](llvm::OpenMPIRBuilder::InsertPointTy ip, llvm::Value *iv) {
- llvm::IRBuilder<>::InsertPointGuard guard(builder);
-
// Make sure further conversions know about the induction variable.
- moduleTranslation.mapValue(loop.getRegion().front().getArgument(0), iv);
+ moduleTranslation.mapValue(
+ loop.getRegion().front().getArgument(loopInfos.size()), iv);
+
+ // Capture the body insertion point for use in nested loops. BodyIP of the
+ // CanonicalLoopInfo always points to the beginning of the entry block of
+ // the body.
+ bodyInsertPoints.push_back(ip);
+
+ if (loopInfos.size() != loop.getNumLoops() - 1)
+ return;
+ // Convert the body of the loop.
llvm::BasicBlock *entryBlock = ip.getBlock();
llvm::BasicBlock *exitBlock =
entryBlock->splitBasicBlock(ip.getPoint(), "omp.wsloop.exit");
-
- // Convert the body of the loop.
convertOmpOpRegions(loop.region(), "omp.wsloop.region", *entryBlock,
*exitBlock, builder, moduleTranslation, bodyGenStatus);
};
@@ -303,21 +297,49 @@ convertOmpWsLoop(Operation &opInst, llvm::IRBuilderBase &builder,
// TODO: this currently assumes WsLoop is semantically similar to SCF loop,
// i.e. it has a positive step, uses signed integer semantics. Reconsider
// this code when WsLoop clearly supports more cases.
+ llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
+ for (unsigned i = 0, e = loop.getNumLoops(); i < e; ++i) {
+ llvm::Value *lowerBound =
+ moduleTranslation.lookupValue(loop.lowerBound()[i]);
+ llvm::Value *upperBound =
+ moduleTranslation.lookupValue(loop.upperBound()[i]);
+ llvm::Value *step = moduleTranslation.lookupValue(loop.step()[i]);
+
+ // Make sure the loop trip counts are emitted in the preheader of the
+ // outermost loop at the latest, so that they are all available for the new
+ // collapsed loop that will be created below.
+ llvm::OpenMPIRBuilder::LocationDescription loc = ompLoc;
+ llvm::OpenMPIRBuilder::InsertPointTy computeIP = ompLoc.IP;
+ if (i != 0) {
+ loc = llvm::OpenMPIRBuilder::LocationDescription(bodyInsertPoints.back(),
+ llvm::DebugLoc(diLoc));
+ computeIP = loopInfos.front()->getPreheaderIP();
+ }
+ loopInfos.push_back(ompBuilder->createCanonicalLoop(
+ loc, bodyGen, lowerBound, upperBound, step,
+ /*IsSigned=*/true, loop.inclusive(), computeIP));
+
+ if (failed(bodyGenStatus))
+ return failure();
+ }
+
+ // Collapse loops. Store the insertion point because LoopInfos may get
+ // invalidated.
+ llvm::IRBuilderBase::InsertPoint afterIP = loopInfos.front()->getAfterIP();
llvm::CanonicalLoopInfo *loopInfo =
- moduleTranslation.getOpenMPBuilder()->createCanonicalLoop(
- ompLoc, bodyGen, lowerBound, upperBound, step, /*IsSigned=*/true,
- /*InclusiveStop=*/loop.inclusive());
- if (failed(bodyGenStatus))
- return failure();
+ ompBuilder->collapseLoops(diLoc, loopInfos, {});
+ // Find the loop configuration.
+ llvm::Type *ivType = loopInfo->getIndVar()->getType();
+ llvm::Value *chunk =
+ loop.schedule_chunk_var()
+ ? moduleTranslation.lookupValue(loop.schedule_chunk_var())
+ : llvm::ConstantInt::get(ivType, 1);
llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
findAllocaInsertPoint(builder, moduleTranslation);
- llvm::OpenMPIRBuilder::InsertPointTy afterIP;
- llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
if (schedule == omp::ClauseScheduleKind::Static) {
- loopInfo = ompBuilder->createStaticWorkshareLoop(ompLoc, loopInfo, allocaIP,
- !loop.nowait(), chunk);
- afterIP = loopInfo->getAfterIP();
+ ompBuilder->createStaticWorkshareLoop(ompLoc, loopInfo, allocaIP,
+ !loop.nowait(), chunk);
} else {
llvm::omp::OMPScheduleType schedType;
switch (schedule) {
@@ -338,11 +360,14 @@ convertOmpWsLoop(Operation &opInst, llvm::IRBuilderBase &builder,
break;
}
- afterIP = ompBuilder->createDynamicWorkshareLoop(
- ompLoc, loopInfo, allocaIP, schedType, !loop.nowait(), chunk);
+ ompBuilder->createDynamicWorkshareLoop(ompLoc, loopInfo, allocaIP,
+ schedType, !loop.nowait(), chunk);
}
- // Continue building IR after the loop.
+ // Continue building IR after the loop. Note that the LoopInfo returned by
+ // `collapseLoops` points inside the outermost loop and is intended for
+ // potential further loop transformations. Use the insertion point stored
+ // before collapsing loops instead.
builder.restoreIP(afterIP);
return success();
}
diff --git a/mlir/test/Target/LLVMIR/openmp-llvm.mlir b/mlir/test/Target/LLVMIR/openmp-llvm.mlir
index 51cefc1ae2ebb..544c85e865feb 100644
--- a/mlir/test/Target/LLVMIR/openmp-llvm.mlir
+++ b/mlir/test/Target/LLVMIR/openmp-llvm.mlir
@@ -467,6 +467,8 @@ llvm.func @test_omp_wsloop_guided(%lb : i64, %ub : i64, %step : i64) -> () {
llvm.return
}
+// -----
+
// CHECK-LABEL: @omp_critical
llvm.func @omp_critical(%x : !llvm.ptr<i32>, %xval : i32) -> () {
// CHECK: call void @__kmpc_critical_with_hint({{.*}}critical_user_.var{{.*}}, i32 0)
@@ -488,6 +490,65 @@ llvm.func @omp_critical(%x : !llvm.ptr<i32>, %xval : i32) -> () {
omp.terminator
}
// CHECK: call void @__kmpc_end_critical({{.*}}critical_user_mutex.var{{.*}})
+ llvm.return
+}
+
+// -----
+// Check that the loop bounds are emitted in the correct location in case of
+// collapse. This only checks the overall shape of the IR, detailed checking
+// is done by the OpenMPIRBuilder.
+
+// CHECK-LABEL: @collapse_wsloop
+// CHECK: i32* noalias %[[TIDADDR:[0-9A-Za-z.]*]]
+// CHECK: load i32, i32* %[[TIDADDR]]
+// CHECK: store
+// CHECK: load
+// CHECK: %[[LB0:.*]] = load i32
+// CHECK: %[[UB0:.*]] = load i32
+// CHECK: %[[STEP0:.*]] = load i32
+// CHECK: %[[LB1:.*]] = load i32
+// CHECK: %[[UB1:.*]] = load i32
+// CHECK: %[[STEP1:.*]] = load i32
+// CHECK: %[[LB2:.*]] = load i32
+// CHECK: %[[UB2:.*]] = load i32
+// CHECK: %[[STEP2:.*]] = load i32
+llvm.func @collapse_wsloop(
+ %0: i32, %1: i32, %2: i32,
+ %3: i32, %4: i32, %5: i32,
+ %6: i32, %7: i32, %8: i32,
+ %20: !llvm.ptr<i32>) {
+ omp.parallel {
+ // CHECK: icmp slt i32 %[[LB0]], 0
+ // CHECK-COUNT-4: select
+ // CHECK: %[[TRIPCOUNT0:.*]] = select
+ // CHECK: br label %[[PREHEADER:.*]]
+ //
+ // CHECK: [[PREHEADER]]:
+ // CHECK: icmp slt i32 %[[LB1]], 0
+ // CHECK-COUNT-4: select
+ // CHECK: %[[TRIPCOUNT1:.*]] = select
+ // CHECK: icmp slt i32 %[[LB2]], 0
+ // CHECK-COUNT-4: select
+ // CHECK: %[[TRIPCOUNT2:.*]] = select
+ // CHECK: %[[PROD:.*]] = mul nuw i32 %[[TRIPCOUNT0]], %[[TRIPCOUNT1]]
+ // CHECK: %[[TOTAL:.*]] = mul nuw i32 %[[PROD]], %[[TRIPCOUNT2]]
+ // CHECK: br label %[[COLLAPSED_PREHEADER:.*]]
+ //
+ // CHECK: [[COLLAPSED_PREHEADER]]:
+ // CHECK: store i32 0, i32*
+ // CHECK: %[[TOTAL_SUB_1:.*]] = sub i32 %[[TOTAL]], 1
+ // CHECK: store i32 %[[TOTAL_SUB_1]], i32*
+ // CHECK: call void @__kmpc_for_static_init_4u
+ omp.wsloop (%arg0, %arg1, %arg2) : i32 = (%0, %1, %2) to (%3, %4, %5) step (%6, %7, %8) collapse(3) {
+ %31 = llvm.load %20 : !llvm.ptr<i32>
+ %32 = llvm.add %31, %arg0 : i32
+ %33 = llvm.add %32, %arg1 : i32
+ %34 = llvm.add %33, %arg2 : i32
+ llvm.store %34, %20 : !llvm.ptr<i32>
+ omp.yield
+ }
+ omp.terminator
+ }
llvm.return
}
More information about the Mlir-commits
mailing list