[flang-commits] [flang] [flang][openacc][NFC] Simplify lowering of recipe (PR #68836)
via flang-commits
flang-commits at lists.llvm.org
Wed Oct 11 14:21:36 PDT 2023
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-flang-fir-hlfir
Author: Valentin Clement (バレンタイン クレメン) (clementval)
<details>
<summary>Changes</summary>
Refactor some of the lowering in the reduction and firstprivate recipes to avoid duplicated code.
---
Full diff: https://github.com/llvm/llvm-project/pull/68836.diff
1 File Affected:
- (modified) flang/lib/Lower/OpenACC.cpp (+74-101)
``````````diff
diff --git a/flang/lib/Lower/OpenACC.cpp b/flang/lib/Lower/OpenACC.cpp
index 61a1b9fd86717cb..e09266121cdb997 100644
--- a/flang/lib/Lower/OpenACC.cpp
+++ b/flang/lib/Lower/OpenACC.cpp
@@ -463,7 +463,7 @@ bool isConstantBound(mlir::acc::DataBoundsOp &op) {
}
/// Return true iff all the bounds are expressed with constant values.
-bool areAllBoundConstant(llvm::SmallVector<mlir::Value> &bounds) {
+bool areAllBoundConstant(const llvm::SmallVector<mlir::Value> &bounds) {
for (auto bound : bounds) {
auto dataBound =
mlir::dyn_cast<mlir::acc::DataBoundsOp>(bound.getDefiningOp());
@@ -474,27 +474,6 @@ bool areAllBoundConstant(llvm::SmallVector<mlir::Value> &bounds) {
return true;
}
-/// Build a fir.shape from consecutive (lb, ub, step) triplets in \p args,
-/// one extent per triplet: max((ub - lb + 1) / step, 0).
-static fir::ShapeOp
-genShapeFromBounds(mlir::Location loc, fir::FirOpBuilder &builder,
- const llvm::SmallVector<mlir::Value> &args) {
- assert(args.size() % 3 == 0 && "Triplets must be a multiple of 3");
- llvm::SmallVector<mlir::Value> extents;
- mlir::Type idxTy = builder.getIndexType();
- mlir::Value one = builder.createIntegerConstant(loc, idxTy, 1);
- mlir::Value zero = builder.createIntegerConstant(loc, idxTy, 0);
- for (unsigned i = 0; i < args.size(); i += 3) {
- // NOTE(review): args[0] is the first triplet's lower bound; for i > 0 this
- // looks like it should be args[i] — confirm.
- mlir::Value s1 =
- builder.create<mlir::arith::SubIOp>(loc, args[i + 1], args[0]);
- mlir::Value s2 = builder.create<mlir::arith::AddIOp>(loc, s1, one);
- mlir::Value s3 = builder.create<mlir::arith::DivSIOp>(loc, s2, args[i + 2]);
- // Clamp non-positive results to 0 (empty extent).
- mlir::Value cmp = builder.create<mlir::arith::CmpIOp>(
- loc, mlir::arith::CmpIPredicate::sgt, s3, zero);
- mlir::Value ext = builder.create<mlir::arith::SelectOp>(loc, cmp, s3, zero);
- extents.push_back(ext);
- }
- return builder.create<fir::ShapeOp>(loc, extents);
-}
-
static llvm::SmallVector<mlir::Value>
genConstantBounds(fir::FirOpBuilder &builder, mlir::Location loc,
mlir::acc::DataBoundsOp &dataBound) {
@@ -520,6 +499,63 @@ genConstantBounds(fir::FirOpBuilder &builder, mlir::Location loc,
return {lb, ub, step};
}
+/// Build a fir.shape for an array of type \p seqTy from (lb, ub, step)
+/// triplets, producing one extent per triplet: max((ub - lb + 1) / step, 0).
+///
+/// When areAllBoundConstant(\p bounds) holds, the triplets are materialized
+/// from the acc.bounds ops with genConstantBounds, visiting \p bounds in
+/// reverse order. Otherwise they are taken from \p arguments after skipping
+/// its first two entries, three values per dimension of \p seqTy.
+static fir::ShapeOp genShapeFromBoundsOrArgs(
+    mlir::Location loc, fir::FirOpBuilder &builder, fir::SequenceType seqTy,
+    const llvm::SmallVector<mlir::Value> &bounds, mlir::ValueRange arguments) {
+  llvm::SmallVector<mlir::Value> args;
+  if (areAllBoundConstant(bounds)) {
+    for (auto bound : llvm::reverse(bounds)) {
+      auto dataBound =
+          mlir::cast<mlir::acc::DataBoundsOp>(bound.getDefiningOp());
+      args.append(genConstantBounds(builder, loc, dataBound));
+    }
+  } else {
+    assert(((arguments.size() - 2) / 3 == seqTy.getDimension()) &&
+           "Expect 3 block arguments per dimension");
+    for (auto arg : arguments.drop_front(2))
+      args.push_back(arg);
+  }
+
+  assert(args.size() % 3 == 0 && "Triplets must be a multiple of 3");
+  llvm::SmallVector<mlir::Value> extents;
+  extents.reserve(args.size() / 3);
+  mlir::Type idxTy = builder.getIndexType();
+  mlir::Value one = builder.createIntegerConstant(loc, idxTy, 1);
+  mlir::Value zero = builder.createIntegerConstant(loc, idxTy, 0);
+  for (unsigned i = 0; i < args.size(); i += 3) {
+    mlir::Value lb = args[i];
+    mlir::Value ub = args[i + 1];
+    mlir::Value step = args[i + 2];
+    // Each dimension must use its own lower bound: subtracting args[0] is only
+    // correct for the first triplet.
+    mlir::Value s1 = builder.create<mlir::arith::SubIOp>(loc, ub, lb);
+    mlir::Value s2 = builder.create<mlir::arith::AddIOp>(loc, s1, one);
+    mlir::Value s3 = builder.create<mlir::arith::DivSIOp>(loc, s2, step);
+    // Clamp non-positive results to 0 (empty extent).
+    mlir::Value cmp = builder.create<mlir::arith::CmpIOp>(
+        loc, mlir::arith::CmpIPredicate::sgt, s3, zero);
+    mlir::Value ext = builder.create<mlir::arith::SelectOp>(loc, cmp, s3, zero);
+    extents.push_back(ext);
+  }
+  return builder.create<fir::ShapeOp>(loc, extents);
+}
+
+/// Build hlfir.designate triplet subscripts from \p args: the first two
+/// entries are skipped and each following (lb, ub, step) group becomes one
+/// Triplet, in order.
+static hlfir::DesignateOp::Subscripts
+getSubscriptsFromArgs(mlir::ValueRange args) {
+  // The loop reads args[i + 1] and args[i + 2]; require complete groups so it
+  // can never index past the end of the range.
+  assert(args.size() >= 2 && (args.size() - 2) % 3 == 0 &&
+         "Expect 2 leading arguments followed by 3 per dimension");
+  hlfir::DesignateOp::Subscripts triplets;
+  for (unsigned i = 2; i < args.size(); i += 3)
+    triplets.emplace_back(
+        hlfir::DesignateOp::Triplet{args[i], args[i + 1], args[i + 2]});
+  return triplets;
+}
+
+/// Create an hlfir.designate of \p entity sliced by \p triplets with result
+/// shape \p shape. The result type is the type of the entity's base, and the
+/// length parameters computed by hlfir::genLengthParameters are forwarded.
+/// Neither \p entity nor \p triplets is modified, so both are taken by const
+/// reference.
+static hlfir::Entity genDesignateWithTriplets(
+    fir::FirOpBuilder &builder, mlir::Location loc, const hlfir::Entity &entity,
+    const hlfir::DesignateOp::Subscripts &triplets, mlir::Value shape) {
+  llvm::SmallVector<mlir::Value> lenParams;
+  hlfir::genLengthParameters(loc, builder, entity, lenParams);
+  auto designate = builder.create<hlfir::DesignateOp>(
+      loc, entity.getBase().getType(), entity, /*component=*/"",
+      /*componentShape=*/mlir::Value{}, triplets,
+      /*substring=*/mlir::ValueRange{}, /*complexPartAttr=*/std::nullopt, shape,
+      lenParams);
+  return hlfir::Entity{designate.getResult()};
+}
+
mlir::acc::FirstprivateRecipeOp Fortran::lower::createOrGetFirstprivateRecipe(
mlir::OpBuilder &builder, llvm::StringRef recipeName, mlir::Location loc,
mlir::Type ty, llvm::SmallVector<mlir::Value> &bounds) {
@@ -600,47 +636,16 @@ mlir::acc::FirstprivateRecipeOp Fortran::lower::createOrGetFirstprivateRecipe(
if (!seqTy)
TODO(loc, "Unsupported boxed type in OpenACC firstprivate");
- if (allConstantBound) {
- for (auto bound : llvm::reverse(bounds)) {
- auto dataBound =
- mlir::cast<mlir::acc::DataBoundsOp>(bound.getDefiningOp());
- tripletArgs.append(genConstantBounds(firBuilder, loc, dataBound));
- }
- } else {
- assert(((recipe.getCopyRegion().getArguments().size() - 2) / 3 ==
- seqTy.getDimension()) &&
- "Expect 3 block arguments per dimension");
- for (auto arg : recipe.getCopyRegion().getArguments().drop_front(2))
- tripletArgs.push_back(arg);
- }
- auto shape = genShapeFromBounds(loc, firBuilder, tripletArgs);
- hlfir::DesignateOp::Subscripts triplets;
- for (unsigned i = 2; i < recipe.getCopyRegion().getArguments().size();
- i += 3)
- triplets.emplace_back(hlfir::DesignateOp::Triplet{
- recipe.getCopyRegion().getArgument(i),
- recipe.getCopyRegion().getArgument(i + 1),
- recipe.getCopyRegion().getArgument(i + 2)});
-
- llvm::SmallVector<mlir::Value> lenParamsLeft;
+ auto shape = genShapeFromBoundsOrArgs(
+ loc, firBuilder, seqTy, bounds, recipe.getCopyRegion().getArguments());
+ hlfir::DesignateOp::Subscripts triplets =
+ getSubscriptsFromArgs(recipe.getCopyRegion().getArguments());
auto leftEntity = hlfir::Entity{recipe.getCopyRegion().getArgument(0)};
- hlfir::genLengthParameters(loc, firBuilder, leftEntity, lenParamsLeft);
- auto leftDesignate = firBuilder.create<hlfir::DesignateOp>(
- loc, leftEntity.getBase().getType(), leftEntity, /*component=*/"",
- /*componentShape=*/mlir::Value{}, triplets,
- /*substring=*/mlir::ValueRange{}, /*complexPartAttr=*/std::nullopt,
- shape, lenParamsLeft);
- auto left = hlfir::Entity{leftDesignate.getResult()};
-
- llvm::SmallVector<mlir::Value> lenParamsRight;
+ auto left =
+ genDesignateWithTriplets(firBuilder, loc, leftEntity, triplets, shape);
auto rightEntity = hlfir::Entity{recipe.getCopyRegion().getArgument(1)};
- hlfir::genLengthParameters(loc, firBuilder, rightEntity, lenParamsRight);
- auto rightDesignate = firBuilder.create<hlfir::DesignateOp>(
- loc, rightEntity.getBase().getType(), rightEntity, /*component=*/"",
- /*componentShape=*/mlir::Value{}, triplets,
- /*substring=*/mlir::ValueRange{}, /*complexPartAttr=*/std::nullopt,
- shape, lenParamsRight);
- auto right = hlfir::Entity{rightDesignate.getResult()};
+ auto right =
+ genDesignateWithTriplets(firBuilder, loc, rightEntity, triplets, shape);
firBuilder.create<hlfir::AssignOp>(loc, left, right);
}
@@ -1110,48 +1115,16 @@ static void genCombiner(fir::FirOpBuilder &builder, mlir::Location loc,
if (!seqTy)
TODO(loc, "Unsupported boxed type in OpenACC reduction");
- if (allConstantBound) {
- for (auto bound : llvm::reverse(bounds)) {
- auto dataBound =
- mlir::cast<mlir::acc::DataBoundsOp>(bound.getDefiningOp());
- tripletArgs.append(genConstantBounds(builder, loc, dataBound));
- }
- } else {
- assert(((recipe.getCombinerRegion().getArguments().size() - 2) / 3 ==
- seqTy.getDimension()) &&
- "Expect 3 block arguments per dimension");
- for (auto arg : recipe.getCombinerRegion().getArguments().drop_front(2))
- tripletArgs.push_back(arg);
- }
- auto shape = genShapeFromBounds(loc, builder, tripletArgs);
-
- hlfir::DesignateOp::Subscripts triplets;
- for (unsigned i = 2; i < recipe.getCombinerRegion().getArguments().size();
- i += 3)
- triplets.emplace_back(hlfir::DesignateOp::Triplet{
- recipe.getCombinerRegion().getArgument(i),
- recipe.getCombinerRegion().getArgument(i + 1),
- recipe.getCombinerRegion().getArgument(i + 2)});
-
- llvm::SmallVector<mlir::Value> lenParamsLeft;
+ auto shape = genShapeFromBoundsOrArgs(
+ loc, builder, seqTy, bounds, recipe.getCombinerRegion().getArguments());
+ hlfir::DesignateOp::Subscripts triplets =
+ getSubscriptsFromArgs(recipe.getCombinerRegion().getArguments());
auto leftEntity = hlfir::Entity{value1};
- hlfir::genLengthParameters(loc, builder, leftEntity, lenParamsLeft);
- auto leftDesignate = builder.create<hlfir::DesignateOp>(
- loc, value1.getType(), leftEntity, /*component=*/"",
- /*componentShape=*/mlir::Value{}, triplets,
- /*substring=*/mlir::ValueRange{}, /*complexPartAttr=*/std::nullopt,
- shape, lenParamsLeft);
- auto left = hlfir::Entity{leftDesignate.getResult()};
-
- llvm::SmallVector<mlir::Value> lenParamsRight;
+ auto left =
+ genDesignateWithTriplets(builder, loc, leftEntity, triplets, shape);
auto rightEntity = hlfir::Entity{value2};
- hlfir::genLengthParameters(loc, builder, rightEntity, lenParamsRight);
- auto rightDesignate = builder.create<hlfir::DesignateOp>(
- loc, value2.getType(), rightEntity, /*component=*/"",
- /*componentShape=*/mlir::Value{}, triplets,
- /*substring=*/mlir::ValueRange{}, /*complexPartAttr=*/std::nullopt,
- shape, lenParamsRight);
- auto right = hlfir::Entity{rightDesignate.getResult()};
+ auto right =
+ genDesignateWithTriplets(builder, loc, rightEntity, triplets, shape);
llvm::SmallVector<mlir::Value, 1> typeParams;
auto genKernel = [&builder, &loc, op, seqTy, &left, &right](
``````````
</details>
https://github.com/llvm/llvm-project/pull/68836
More information about the flang-commits
mailing list