[Mlir-commits] [mlir] c4c1030 - [mlir] support collapsed loops in OpenMP-to-LLVM translation

Alex Zinenko llvmlistbot at llvm.org
Fri Aug 6 08:13:20 PDT 2021


Author: Alex Zinenko
Date: 2021-08-06T17:13:12+02:00
New Revision: c4c103097660b7b130eaf134919516726d7bd9e6

URL: https://github.com/llvm/llvm-project/commit/c4c103097660b7b130eaf134919516726d7bd9e6
DIFF: https://github.com/llvm/llvm-project/commit/c4c103097660b7b130eaf134919516726d7bd9e6.diff

LOG: [mlir] support collapsed loops in OpenMP-to-LLVM translation

Reviewed By: Meinersbur

Differential Revision: https://reviews.llvm.org/D105706

Added: 
    

Modified: 
    mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
    mlir/test/Target/LLVMIR/openmp-llvm.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 686386a3542a9..31477ac4c0376 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -252,25 +252,12 @@ convertOmpWsLoop(Operation &opInst, llvm::IRBuilderBase &builder,
   if (loop.lowerBound().empty())
     return failure();
 
-  if (loop.getNumLoops() != 1)
-    return opInst.emitOpError("collapsed loops not yet supported");
-
   // Static is the default.
   omp::ClauseScheduleKind schedule = omp::ClauseScheduleKind::Static;
   if (loop.schedule_val().hasValue())
     schedule =
         *omp::symbolizeClauseScheduleKind(loop.schedule_val().getValue());
 
-  // Find the loop configuration.
-  llvm::Value *lowerBound = moduleTranslation.lookupValue(loop.lowerBound()[0]);
-  llvm::Value *upperBound = moduleTranslation.lookupValue(loop.upperBound()[0]);
-  llvm::Value *step = moduleTranslation.lookupValue(loop.step()[0]);
-  llvm::Type *ivType = step->getType();
-  llvm::Value *chunk =
-      loop.schedule_chunk_var()
-          ? moduleTranslation.lookupValue(loop.schedule_chunk_var())
-          : llvm::ConstantInt::get(ivType, 1);
-
   // Set up the source location value for OpenMP runtime.
   llvm::DISubprogram *subprogram =
       builder.GetInsertBlock()->getParent()->getSubprogram();
@@ -279,22 +266,29 @@ convertOmpWsLoop(Operation &opInst, llvm::IRBuilderBase &builder,
   llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder.saveIP(),
                                                     llvm::DebugLoc(diLoc));
 
-  // Generator of the canonical loop body. Produces an SESE region of basic
-  // blocks.
+  // Generator of the canonical loop body.
   // TODO: support error propagation in OpenMPIRBuilder and use it instead of
   // relying on captured variables.
+  SmallVector<llvm::CanonicalLoopInfo *> loopInfos;
+  SmallVector<llvm::OpenMPIRBuilder::InsertPointTy> bodyInsertPoints;
   LogicalResult bodyGenStatus = success();
   auto bodyGen = [&](llvm::OpenMPIRBuilder::InsertPointTy ip, llvm::Value *iv) {
-    llvm::IRBuilder<>::InsertPointGuard guard(builder);
-
     // Make sure further conversions know about the induction variable.
-    moduleTranslation.mapValue(loop.getRegion().front().getArgument(0), iv);
+    moduleTranslation.mapValue(
+        loop.getRegion().front().getArgument(loopInfos.size()), iv);
+
+    // Capture the body insertion point for use in nested loops. BodyIP of the
+    // CanonicalLoopInfo always points to the beginning of the entry block of
+    // the body.
+    bodyInsertPoints.push_back(ip);
+
+    if (loopInfos.size() != loop.getNumLoops() - 1)
+      return;
 
+    // Convert the body of the loop.
     llvm::BasicBlock *entryBlock = ip.getBlock();
     llvm::BasicBlock *exitBlock =
         entryBlock->splitBasicBlock(ip.getPoint(), "omp.wsloop.exit");
-
-    // Convert the body of the loop.
     convertOmpOpRegions(loop.region(), "omp.wsloop.region", *entryBlock,
                         *exitBlock, builder, moduleTranslation, bodyGenStatus);
   };
@@ -303,21 +297,49 @@ convertOmpWsLoop(Operation &opInst, llvm::IRBuilderBase &builder,
   // TODO: this currently assumes WsLoop is semantically similar to SCF loop,
   // i.e. it has a positive step, uses signed integer semantics. Reconsider
   // this code when WsLoop clearly supports more cases.
+  llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
+  for (unsigned i = 0, e = loop.getNumLoops(); i < e; ++i) {
+    llvm::Value *lowerBound =
+        moduleTranslation.lookupValue(loop.lowerBound()[i]);
+    llvm::Value *upperBound =
+        moduleTranslation.lookupValue(loop.upperBound()[i]);
+    llvm::Value *step = moduleTranslation.lookupValue(loop.step()[i]);
+
+    // Make sure loop trip counts are emitted in the preheader of the outermost
+    // loop at the latest so that they are all available for the new collapsed
+    // loop that will be created below.
+    llvm::OpenMPIRBuilder::LocationDescription loc = ompLoc;
+    llvm::OpenMPIRBuilder::InsertPointTy computeIP = ompLoc.IP;
+    if (i != 0) {
+      loc = llvm::OpenMPIRBuilder::LocationDescription(bodyInsertPoints.back(),
+                                                       llvm::DebugLoc(diLoc));
+      computeIP = loopInfos.front()->getPreheaderIP();
+    }
+    loopInfos.push_back(ompBuilder->createCanonicalLoop(
+        loc, bodyGen, lowerBound, upperBound, step,
+        /*IsSigned=*/true, loop.inclusive(), computeIP));
+
+    if (failed(bodyGenStatus))
+      return failure();
+  }
+
+  // Collapse loops. Store the insertion point because LoopInfos may get
+  // invalidated.
+  llvm::IRBuilderBase::InsertPoint afterIP = loopInfos.front()->getAfterIP();
   llvm::CanonicalLoopInfo *loopInfo =
-      moduleTranslation.getOpenMPBuilder()->createCanonicalLoop(
-          ompLoc, bodyGen, lowerBound, upperBound, step, /*IsSigned=*/true,
-          /*InclusiveStop=*/loop.inclusive());
-  if (failed(bodyGenStatus))
-    return failure();
+      ompBuilder->collapseLoops(diLoc, loopInfos, {});
 
+  // Find the loop configuration.
+  llvm::Type *ivType = loopInfo->getIndVar()->getType();
+  llvm::Value *chunk =
+      loop.schedule_chunk_var()
+          ? moduleTranslation.lookupValue(loop.schedule_chunk_var())
+          : llvm::ConstantInt::get(ivType, 1);
   llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
       findAllocaInsertPoint(builder, moduleTranslation);
-  llvm::OpenMPIRBuilder::InsertPointTy afterIP;
-  llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
   if (schedule == omp::ClauseScheduleKind::Static) {
-    loopInfo = ompBuilder->createStaticWorkshareLoop(ompLoc, loopInfo, allocaIP,
-                                                     !loop.nowait(), chunk);
-    afterIP = loopInfo->getAfterIP();
+    ompBuilder->createStaticWorkshareLoop(ompLoc, loopInfo, allocaIP,
+                                          !loop.nowait(), chunk);
   } else {
     llvm::omp::OMPScheduleType schedType;
     switch (schedule) {
@@ -338,11 +360,14 @@ convertOmpWsLoop(Operation &opInst, llvm::IRBuilderBase &builder,
       break;
     }
 
-    afterIP = ompBuilder->createDynamicWorkshareLoop(
-        ompLoc, loopInfo, allocaIP, schedType, !loop.nowait(), chunk);
+    ompBuilder->createDynamicWorkshareLoop(ompLoc, loopInfo, allocaIP,
+                                           schedType, !loop.nowait(), chunk);
   }
 
-  // Continue building IR after the loop.
+  // Continue building IR after the loop. Note that the LoopInfo returned by
+  // `collapseLoops` points inside the outermost loop and is intended for
+  // potential further loop transformations. Use the insertion point stored
+  // before collapsing loops instead.
   builder.restoreIP(afterIP);
   return success();
 }

diff  --git a/mlir/test/Target/LLVMIR/openmp-llvm.mlir b/mlir/test/Target/LLVMIR/openmp-llvm.mlir
index 51cefc1ae2ebb..544c85e865feb 100644
--- a/mlir/test/Target/LLVMIR/openmp-llvm.mlir
+++ b/mlir/test/Target/LLVMIR/openmp-llvm.mlir
@@ -467,6 +467,8 @@ llvm.func @test_omp_wsloop_guided(%lb : i64, %ub : i64, %step : i64) -> () {
  llvm.return
 }
 
+// -----
+
 // CHECK-LABEL: @omp_critical
 llvm.func @omp_critical(%x : !llvm.ptr<i32>, %xval : i32) -> () {
   // CHECK: call void @__kmpc_critical_with_hint({{.*}}critical_user_.var{{.*}}, i32 0)
@@ -488,6 +490,65 @@ llvm.func @omp_critical(%x : !llvm.ptr<i32>, %xval : i32) -> () {
     omp.terminator
   }
   // CHECK: call void @__kmpc_end_critical({{.*}}critical_user_mutex.var{{.*}})
+  llvm.return
+}
+
+// -----
 
+// Check that the loop bounds are emitted in the correct location in case of
+// collapse. This only checks the overall shape of the IR, detailed checking
+// is done by the OpenMPIRBuilder.
+
+// CHECK-LABEL: @collapse_wsloop
+// CHECK: i32* noalias %[[TIDADDR:[0-9A-Za-z.]*]]
+// CHECK: load i32, i32* %[[TIDADDR]]
+// CHECK: store
+// CHECK: load
+// CHECK: %[[LB0:.*]] = load i32
+// CHECK: %[[UB0:.*]] = load i32
+// CHECK: %[[STEP0:.*]] = load i32
+// CHECK: %[[LB1:.*]] = load i32
+// CHECK: %[[UB1:.*]] = load i32
+// CHECK: %[[STEP1:.*]] = load i32
+// CHECK: %[[LB2:.*]] = load i32
+// CHECK: %[[UB2:.*]] = load i32
+// CHECK: %[[STEP2:.*]] = load i32
+llvm.func @collapse_wsloop(
+    %0: i32, %1: i32, %2: i32,
+    %3: i32, %4: i32, %5: i32,
+    %6: i32, %7: i32, %8: i32,
+    %20: !llvm.ptr<i32>) {
+  omp.parallel {
+    // CHECK: icmp slt i32 %[[LB0]], 0
+    // CHECK-COUNT-4: select
+    // CHECK: %[[TRIPCOUNT0:.*]] = select
+    // CHECK: br label %[[PREHEADER:.*]]
+    //
+    // CHECK: [[PREHEADER]]:
+    // CHECK: icmp slt i32 %[[LB1]], 0
+    // CHECK-COUNT-4: select
+    // CHECK: %[[TRIPCOUNT1:.*]] = select
+    // CHECK: icmp slt i32 %[[LB2]], 0
+    // CHECK-COUNT-4: select
+    // CHECK: %[[TRIPCOUNT2:.*]] = select
+    // CHECK: %[[PROD:.*]] = mul nuw i32 %[[TRIPCOUNT0]], %[[TRIPCOUNT1]]
+    // CHECK: %[[TOTAL:.*]] = mul nuw i32 %[[PROD]], %[[TRIPCOUNT2]]
+    // CHECK: br label %[[COLLAPSED_PREHEADER:.*]]
+    //
+    // CHECK: [[COLLAPSED_PREHEADER]]:
+    // CHECK: store i32 0, i32*
+    // CHECK: %[[TOTAL_SUB_1:.*]] = sub i32 %[[TOTAL]], 1
+    // CHECK: store i32 %[[TOTAL_SUB_1]], i32*
+    // CHECK: call void @__kmpc_for_static_init_4u
+    omp.wsloop (%arg0, %arg1, %arg2) : i32 = (%0, %1, %2) to (%3, %4, %5) step (%6, %7, %8) collapse(3) {
+      %31 = llvm.load %20 : !llvm.ptr<i32>
+      %32 = llvm.add %31, %arg0 : i32
+      %33 = llvm.add %32, %arg1 : i32
+      %34 = llvm.add %33, %arg2 : i32
+      llvm.store %34, %20 : !llvm.ptr<i32>
+      omp.yield
+    }
+    omp.terminator
+  }
   llvm.return
 }


        


More information about the Mlir-commits mailing list