[Mlir-commits] [mlir] 099ecdf - [mlir][OpenMP] map argument to reduction initialization region (#86979)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Apr 4 02:55:46 PDT 2024


Author: Tom Eccles
Date: 2024-04-04T10:55:42+01:00
New Revision: 099ecdf1ec2f87b5bae74518166daf1d2b09da45

URL: https://github.com/llvm/llvm-project/commit/099ecdf1ec2f87b5bae74518166daf1d2b09da45
DIFF: https://github.com/llvm/llvm-project/commit/099ecdf1ec2f87b5bae74518166daf1d2b09da45.diff

LOG: [mlir][OpenMP] map argument to reduction initialization region (#86979)

The argument to the initialization region of reduction declarations was
never mapped. This meant that if this argument was accessed inside the
initialization region, the corresponding MLIR operation would be translated
to an LLVM operation with a null operand, failing verification.

Adding the mapping ensures that the right LLVM value can be found when
inlining and converting the initialization region.
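For context, a declaration of the shape below exercises this path. This is a
minimal sketch adapted from the test added in this patch; the symbol name and
the descriptor-like struct type are illustrative only:

  omp.declare_reduction @add_byref : !llvm.ptr init {
  ^bb0(%arg0: !llvm.ptr):
    // %arg0 is the original variable; before this change it had no LLVM
    // mapping, so translating this load produced a null operand
    %0 = llvm.load %arg0 : !llvm.ptr -> !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>)>
    omp.yield(%arg0 : !llvm.ptr)
  } combiner {
  ^bb0(%lhs: !llvm.ptr, %rhs: !llvm.ptr):
    omp.yield(%lhs : !llvm.ptr)
  }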

We have to establish and clean up these mappings separately for each use of
the reduction declaration, because repeated uses of the same declaration
inline it with a different concrete value for the block argument.
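For example, the new test below uses the same declaration for two different
variables in a single clause, so the init region is inlined twice with two
distinct sources (sketch; %a and %b stand for the two allocas in the test):

  omp.parallel byref reduction(@add_byref %a -> %prv0 : !llvm.ptr,
                               @add_byref %b -> %prv1 : !llvm.ptr) {
    omp.terminator
  }

Without clearing the mapping between uses, the value mapped for the first
variable would still be in place when the declaration is inlined for the
second.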

This argument was never used before because, in most cases, the value used
for initialization depends only on the type of the reduction, not on the
original variable. It is needed now so that the array extents for the local
copy can be read from the mold.
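As an illustration only (not part of this change): once the block argument is
mapped, an init region can inspect the mold, e.g. read an extent out of a
descriptor-shaped struct. The GEP indices below are hypothetical and merely
show the kind of access this enables:

  ^bb0(%mold: !llvm.ptr):
    // hypothetical descriptor layout: field 7 holds per-dimension
    // (lower bound, extent, stride) triples; read the extent of dim 0
    %p = llvm.getelementptr %mold[0, 7, 0, 1]
        : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>)>
    %extent = llvm.load %p : !llvm.ptr -> i64
    // ... allocate and initialize the private copy using %extent ...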

Patch 2/3 of Flang support for reductions on assumed-shape arrays.

Added: 
    mlir/test/Target/LLVMIR/openmp-reduction-init-arg.mlir

Modified: 
    mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp

Removed: 
    


################################################################################
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index cacf2c37e38d1c..c4bf6a20ebe71c 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -825,6 +825,25 @@ static void allocByValReductionVars(
   }
 }
 
+/// Map the input argument to the \p i-th reduction initialization region
+template <typename T>
+static void
+mapInitializationArg(T loop, LLVM::ModuleTranslation &moduleTranslation,
+                     SmallVectorImpl<omp::DeclareReductionOp> &reductionDecls,
+                     unsigned i) {
+  // map input argument to the initialization region
+  mlir::omp::DeclareReductionOp &reduction = reductionDecls[i];
+  Region &initializerRegion = reduction.getInitializerRegion();
+  Block &entry = initializerRegion.front();
+  assert(entry.getNumArguments() == 1 &&
+         "the initialization region has one argument");
+
+  mlir::Value mlirSource = loop.getReductionVars()[i];
+  llvm::Value *llvmSource = moduleTranslation.lookupValue(mlirSource);
+  assert(llvmSource && "lookup reduction var");
+  moduleTranslation.mapValue(entry.getArgument(0), llvmSource);
+}
+
 /// Collect reduction info
 template <typename T>
 static void collectReductionInfo(
@@ -902,6 +921,10 @@ convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
       loop.getRegion().getArguments().take_back(loop.getNumReductionVars());
   for (unsigned i = 0; i < loop.getNumReductionVars(); ++i) {
     SmallVector<llvm::Value *> phis;
+
+    // map block argument to initializer region
+    mapInitializationArg(loop, moduleTranslation, reductionDecls, i);
+
     if (failed(inlineConvertOmpRegions(reductionDecls[i].getInitializerRegion(),
                                        "omp.reduction.neutral", builder,
                                        moduleTranslation, &phis)))
@@ -925,6 +948,11 @@ convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
       builder.CreateStore(phis[0], privateReductionVariables[i]);
       // the rest was handled in allocByValReductionVars
     }
+
+    // forget the mapping for the initializer region because we might need a
+    // different mapping if this reduction declaration is re-used for a
+    // different variable
+    moduleTranslation.forgetMapping(reductionDecls[i].getInitializerRegion());
   }
 
   // Store the mapping between reduction variables and their private copies on
@@ -1118,6 +1146,9 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
             opInst.getNumReductionVars());
     for (unsigned i = 0; i < opInst.getNumReductionVars(); ++i) {
       SmallVector<llvm::Value *> phis;
+
+      // map the block argument
+      mapInitializationArg(opInst, moduleTranslation, reductionDecls, i);
       if (failed(inlineConvertOmpRegions(
               reductionDecls[i].getInitializerRegion(), "omp.reduction.neutral",
               builder, moduleTranslation, &phis)))
@@ -1144,6 +1175,10 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
         builder.CreateStore(phis[0], privateReductionVariables[i]);
         // the rest is done in allocByValReductionVars
       }
+
+      // clear block argument mapping in case it needs to be re-created with a
+      // different source for another use of the same reduction decl
+      moduleTranslation.forgetMapping(reductionDecls[i].getInitializerRegion());
     }
 
     // Store the mapping between reduction variables and their private copies on

diff --git a/mlir/test/Target/LLVMIR/openmp-reduction-init-arg.mlir b/mlir/test/Target/LLVMIR/openmp-reduction-init-arg.mlir
new file mode 100644
index 00000000000000..5dd31c425566c1
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/openmp-reduction-init-arg.mlir
@@ -0,0 +1,111 @@
+// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
+
+// Test that the block argument to the initialization region of
+// omp.declare_reduction gets mapped properly when translating to LLVMIR.
+
+module {
+  omp.declare_reduction @add_reduction_byref_box_Uxf64 : !llvm.ptr init {
+  ^bb0(%arg0: !llvm.ptr):
+// test usage of %arg0:
+    %11 = llvm.load %arg0 : !llvm.ptr -> !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>)>
+    omp.yield(%arg0 : !llvm.ptr)
+  } combiner {
+  ^bb0(%arg0: !llvm.ptr, %arg1: !llvm.ptr):
+    omp.yield(%arg0 : !llvm.ptr)
+  }
+
+  llvm.func internal @_QFPreduce(%arg0: !llvm.ptr {fir.bindc_name = "r"}, %arg1: !llvm.ptr {fir.bindc_name = "r2"}) attributes {sym_visibility = "private"} {
+    %8 = llvm.mlir.constant(1 : i32) : i32
+    %9 = llvm.mlir.constant(10 : i32) : i32
+    %10 = llvm.mlir.constant(0 : i32) : i32
+    %83 = llvm.mlir.constant(1 : i64) : i64
+    %84 = llvm.alloca %83 x !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>)> : (i64) -> !llvm.ptr
+    %86 = llvm.mlir.constant(1 : i64) : i64
+    %87 = llvm.alloca %86 x !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>)> : (i64) -> !llvm.ptr
+// test multiple reduction variables to ensure they don't interfere with each other
+// when inlining the reduction init region multiple times
+    omp.parallel byref reduction(@add_reduction_byref_box_Uxf64 %84 -> %arg3 : !llvm.ptr, @add_reduction_byref_box_Uxf64 %87 -> %arg4 : !llvm.ptr) {
+      omp.terminator
+    }
+    llvm.return
+  }
+}
+
+// CHECK-LABEL: define internal void @_QFPreduce
+// CHECK:         %[[VAL_0:.*]] = alloca { ptr, ptr }, align 8
+// CHECK:         %[[VAL_1:.*]] = alloca { ptr, i64, i32, i8, i8, i8, i8, [1 x [3 x i64]] }, i64 1, align 8
+// CHECK:         %[[VAL_2:.*]] = alloca { ptr, i64, i32, i8, i8, i8, i8, [1 x [3 x i64]] }, i64 1, align 8
+// CHECK:         br label %[[VAL_3:.*]]
+// CHECK:       entry:                                            ; preds = %[[VAL_4:.*]]
+// CHECK:         %[[VAL_5:.*]] = call i32 @__kmpc_global_thread_num(ptr @1)
+// CHECK:         br label %[[VAL_6:.*]]
+// CHECK:       omp_parallel:                                     ; preds = %[[VAL_3]]
+// CHECK:         %[[VAL_7:.*]] = getelementptr { ptr, ptr }, ptr %[[VAL_0]], i32 0, i32 0
+// CHECK:         store ptr %[[VAL_1]], ptr %[[VAL_7]], align 8
+// CHECK:         %[[VAL_8:.*]] = getelementptr { ptr, ptr }, ptr %[[VAL_0]], i32 0, i32 1
+// CHECK:         store ptr %[[VAL_2]], ptr %[[VAL_8]], align 8
+// CHECK:         call void (ptr, i32, ptr, ...) @__kmpc_fork_call(ptr @1, i32 1, ptr @_QFPreduce..omp_par, ptr %[[VAL_0]])
+// CHECK:         br label %[[VAL_9:.*]]
+// CHECK:       omp.par.outlined.exit:                            ; preds = %[[VAL_6]]
+// CHECK:         br label %[[VAL_10:.*]]
+// CHECK:       omp.par.exit.split:                               ; preds = %[[VAL_9]]
+// CHECK:         ret void
+// CHECK:       omp.par.entry:
+// CHECK:         %[[VAL_11:.*]] = getelementptr { ptr, ptr }, ptr %[[VAL_12:.*]], i32 0, i32 0
+// CHECK:         %[[VAL_13:.*]] = load ptr, ptr %[[VAL_11]], align 8
+// CHECK:         %[[VAL_14:.*]] = getelementptr { ptr, ptr }, ptr %[[VAL_12]], i32 0, i32 1
+// CHECK:         %[[VAL_15:.*]] = load ptr, ptr %[[VAL_14]], align 8
+// CHECK:         %[[VAL_16:.*]] = alloca i32, align 4
+// CHECK:         %[[VAL_17:.*]] = load i32, ptr %[[VAL_18:.*]], align 4
+// CHECK:         store i32 %[[VAL_17]], ptr %[[VAL_16]], align 4
+// CHECK:         %[[VAL_19:.*]] = load i32, ptr %[[VAL_16]], align 4
+// CHECK:         %[[VAL_20:.*]] = load { ptr, i64, i32, i8, i8, i8, i8, [1 x [3 x i64]] }, ptr %[[VAL_13]], align 8
+// CHECK:         %[[VAL_21:.*]] = alloca ptr, align 8
+// CHECK:         store ptr %[[VAL_13]], ptr %[[VAL_21]], align 8
+// CHECK:         %[[VAL_22:.*]] = load { ptr, i64, i32, i8, i8, i8, i8, [1 x [3 x i64]] }, ptr %[[VAL_15]], align 8
+// CHECK:         %[[VAL_23:.*]] = alloca ptr, align 8
+// CHECK:         store ptr %[[VAL_15]], ptr %[[VAL_23]], align 8
+// CHECK:         %[[VAL_24:.*]] = alloca [2 x ptr], align 8
+// CHECK:         br label %[[VAL_25:.*]]
+// CHECK:       omp.par.region:                                   ; preds = %[[VAL_26:.*]]
+// CHECK:         br label %[[VAL_27:.*]]
+// CHECK:       omp.par.region1:                                  ; preds = %[[VAL_25]]
+// CHECK:         br label %[[VAL_28:.*]]
+// CHECK:       omp.region.cont:                                  ; preds = %[[VAL_27]]
+// CHECK:         %[[VAL_29:.*]] = getelementptr inbounds [2 x ptr], ptr %[[VAL_24]], i64 0, i64 0
+// CHECK:         store ptr %[[VAL_21]], ptr %[[VAL_29]], align 8
+// CHECK:         %[[VAL_30:.*]] = getelementptr inbounds [2 x ptr], ptr %[[VAL_24]], i64 0, i64 1
+// CHECK:         store ptr %[[VAL_23]], ptr %[[VAL_30]], align 8
+// CHECK:         %[[VAL_31:.*]] = call i32 @__kmpc_global_thread_num(ptr @1)
+// CHECK:         %[[VAL_32:.*]] = call i32 @__kmpc_reduce(ptr @1, i32 %[[VAL_31]], i32 2, i64 16, ptr %[[VAL_24]], ptr @.omp.reduction.func, ptr @.gomp_critical_user_.reduction.var)
+// CHECK:         switch i32 %[[VAL_32]], label %[[VAL_33:.*]] [
+// CHECK:           i32 1, label %[[VAL_34:.*]]
+// CHECK:           i32 2, label %[[VAL_35:.*]]
+// CHECK:         ]
+// CHECK:       reduce.switch.atomic:                             ; preds = %[[VAL_28]]
+// CHECK:         unreachable
+// CHECK:       reduce.switch.nonatomic:                          ; preds = %[[VAL_28]]
+// CHECK:         %[[VAL_36:.*]] = load ptr, ptr %[[VAL_21]], align 8
+// CHECK:         %[[VAL_37:.*]] = load ptr, ptr %[[VAL_23]], align 8
+// CHECK:         call void @__kmpc_end_reduce(ptr @1, i32 %[[VAL_31]], ptr @.gomp_critical_user_.reduction.var)
+// CHECK:         br label %[[VAL_33]]
+// CHECK:       reduce.finalize:                                  ; preds = %[[VAL_34]], %[[VAL_28]]
+// CHECK:         br label %[[VAL_38:.*]]
+// CHECK:       omp.par.pre_finalize:                             ; preds = %[[VAL_33]]
+// CHECK:         br label %[[VAL_39:.*]]
+// CHECK:       omp.par.outlined.exit.exitStub:                   ; preds = %[[VAL_38]]
+// CHECK:         ret void
+// CHECK:         %[[VAL_40:.*]] = getelementptr inbounds [2 x ptr], ptr %[[VAL_41:.*]], i64 0, i64 0
+// CHECK:         %[[VAL_42:.*]] = load ptr, ptr %[[VAL_40]], align 8
+// CHECK:         %[[VAL_43:.*]] = load ptr, ptr %[[VAL_42]], align 8
+// CHECK:         %[[VAL_44:.*]] = getelementptr inbounds [2 x ptr], ptr %[[VAL_45:.*]], i64 0, i64 0
+// CHECK:         %[[VAL_46:.*]] = load ptr, ptr %[[VAL_44]], align 8
+// CHECK:         %[[VAL_47:.*]] = load ptr, ptr %[[VAL_46]], align 8
+// CHECK:         %[[VAL_48:.*]] = getelementptr inbounds [2 x ptr], ptr %[[VAL_41]], i64 0, i64 1
+// CHECK:         %[[VAL_49:.*]] = load ptr, ptr %[[VAL_48]], align 8
+// CHECK:         %[[VAL_50:.*]] = load ptr, ptr %[[VAL_49]], align 8
+// CHECK:         %[[VAL_51:.*]] = getelementptr inbounds [2 x ptr], ptr %[[VAL_45]], i64 0, i64 1
+// CHECK:         %[[VAL_52:.*]] = load ptr, ptr %[[VAL_51]], align 8
+// CHECK:         %[[VAL_53:.*]] = load ptr, ptr %[[VAL_52]], align 8
+// CHECK:         ret void
+


        


More information about the Mlir-commits mailing list