[flang-commits] [flang] 8bd76ac - [flang] Support multidimensional reductions in SimplifyIntrinsicsPass.
Slava Zakharin via flang-commits
flang-commits at lists.llvm.org
Mon Sep 19 12:17:48 PDT 2022
Author: Slava Zakharin
Date: 2022-09-19T12:16:23-07:00
New Revision: 8bd76ac151534d2b9534ed919c0a7f4511002d84
URL: https://github.com/llvm/llvm-project/commit/8bd76ac151534d2b9534ed919c0a7f4511002d84
DIFF: https://github.com/llvm/llvm-project/commit/8bd76ac151534d2b9534ed919c0a7f4511002d84.diff
LOG: [flang] Support multidimensional reductions in SimplifyIntrinsicsPass.
Create simplified functions for each rank with "x<rank>" suffix
that implement multidimensional reductions. To enable this I had to fix
an issue with taking incorrect box shape in cases of sliced embox/rebox.
Differential Revision: https://reviews.llvm.org/D133820
Added:
Modified:
flang/lib/Optimizer/Transforms/SimplifyIntrinsics.cpp
flang/test/Transforms/simplifyintrinsics.fir
Removed:
################################################################################
diff --git a/flang/lib/Optimizer/Transforms/SimplifyIntrinsics.cpp b/flang/lib/Optimizer/Transforms/SimplifyIntrinsics.cpp
index d23736ef8a68e..5682fa2816714 100644
--- a/flang/lib/Optimizer/Transforms/SimplifyIntrinsics.cpp
+++ b/flang/lib/Optimizer/Transforms/SimplifyIntrinsics.cpp
@@ -61,7 +61,7 @@ class SimplifyIntrinsicsPass
using FunctionBodyGeneratorTy =
llvm::function_ref<void(fir::FirOpBuilder &, mlir::func::FuncOp &)>;
using GenReductionBodyTy = llvm::function_ref<void(
- fir::FirOpBuilder &builder, mlir::func::FuncOp &funcOp)>;
+ fir::FirOpBuilder &builder, mlir::func::FuncOp &funcOp, unsigned rank)>;
public:
/// Generate a new function implementing a simplified version
@@ -110,10 +110,11 @@ using InitValGeneratorTy = llvm::function_ref<mlir::Value(
/// the reduction value
/// \p genBody is called to fill in the actual reduciton operation
/// for example add for SUM, MAX for MAXVAL, etc.
+/// \p rank is the rank of the input argument.
static void genReductionLoop(fir::FirOpBuilder &builder,
mlir::func::FuncOp &funcOp,
InitValGeneratorTy initVal,
- BodyOpGeneratorTy genBody) {
+ BodyOpGeneratorTy genBody, unsigned rank) {
auto loc = mlir::UnknownLoc::get(builder.getContext());
mlir::Type elementType = funcOp.getResultTypes()[0];
builder.setInsertionPointToEnd(funcOp.addEntryBlock());
@@ -125,59 +126,98 @@ static void genReductionLoop(fir::FirOpBuilder &builder,
mlir::Value zeroIdx = builder.createIntegerConstant(loc, idxTy, 0);
- fir::SequenceType::Shape flatShape = {fir::SequenceType::getUnknownExtent()};
+ fir::SequenceType::Shape flatShape(rank,
+ fir::SequenceType::getUnknownExtent());
mlir::Type arrTy = fir::SequenceType::get(flatShape, elementType);
mlir::Type boxArrTy = fir::BoxType::get(arrTy);
mlir::Value array = builder.create<fir::ConvertOp>(loc, boxArrTy, arg);
- auto dims =
- builder.create<fir::BoxDimsOp>(loc, idxTy, idxTy, idxTy, array, zeroIdx);
- mlir::Value len = dims.getResult(1);
+ mlir::Value init = initVal(builder, loc, elementType);
+
+ llvm::SmallVector<mlir::Value, 15> bounds;
+
+ assert(rank > 0 && "rank cannot be zero");
mlir::Value one = builder.createIntegerConstant(loc, idxTy, 1);
- mlir::Value step = one;
- // We use C indexing here, so len-1 as loopcount
- mlir::Value loopCount = builder.create<mlir::arith::SubIOp>(loc, len, one);
- mlir::Value init = initVal(builder, loc, elementType);
- auto loop = builder.create<fir::DoLoopOp>(loc, zeroIdx, loopCount, step,
- /*unordered=*/false,
- /*finalCountValue=*/false, init);
- mlir::Value reductionVal = loop.getRegionIterArgs()[0];
+ // Compute all the upper bounds before the loop nest.
+ // It is not strictly necessary for performance, since the loop nest
+ // does not have any store operations and any LICM optimization
+ // should be able to optimize the redundancy.
+ for (unsigned i = 0; i < rank; ++i) {
+ mlir::Value dimIdx = builder.createIntegerConstant(loc, idxTy, i);
+ auto dims =
+ builder.create<fir::BoxDimsOp>(loc, idxTy, idxTy, idxTy, array, dimIdx);
+ mlir::Value len = dims.getResult(1);
+ // We use C indexing here, so len-1 as loopcount
+ mlir::Value loopCount = builder.create<mlir::arith::SubIOp>(loc, len, one);
+ bounds.push_back(loopCount);
+ }
- // Begin loop code
- mlir::OpBuilder::InsertPoint loopEndPt = builder.saveInsertionPoint();
- builder.setInsertionPointToStart(loop.getBody());
+ // Create a loop nest consisting of DoLoopOp operations.
+ // Collect the loops' induction variables into indices array,
+ // which will be used in the innermost loop to load the input
+ // array's element.
+ // The loops are generated such that the innermost loop processes
+ // the 0 dimension.
+ llvm::SmallVector<mlir::Value, 15> indices;
+ for (unsigned i = rank; 0 < i; --i) {
+ mlir::Value step = one;
+ mlir::Value loopCount = bounds[i - 1];
+ auto loop = builder.create<fir::DoLoopOp>(loc, zeroIdx, loopCount, step,
+ /*unordered=*/false,
+ /*finalCountValue=*/false, init);
+ init = loop.getRegionIterArgs()[0];
+ indices.push_back(loop.getInductionVar());
+ // Set insertion point to the loop body so that the next loop
+ // is inserted inside the current one.
+ builder.setInsertionPointToStart(loop.getBody());
+ }
+
+ // Reverse the indices such that they are ordered as:
+ // <dim-0-idx, dim-1-idx, ...>
+ std::reverse(indices.begin(), indices.end());
+ // We are in the innermost loop: generate the reduction body.
mlir::Type eleRefTy = builder.getRefType(elementType);
- mlir::Value index = loop.getInductionVar();
mlir::Value addr =
- builder.create<fir::CoordinateOp>(loc, eleRefTy, array, index);
+ builder.create<fir::CoordinateOp>(loc, eleRefTy, array, indices);
mlir::Value elem = builder.create<fir::LoadOp>(loc, addr);
- reductionVal = genBody(builder, loc, elementType, elem, reductionVal);
-
- builder.create<fir::ResultOp>(loc, reductionVal);
- // End of loop.
- builder.restoreInsertionPoint(loopEndPt);
+ mlir::Value reductionVal = genBody(builder, loc, elementType, elem, init);
+
+ // Unwind the loop nest and insert ResultOp on each level
+ // to return the updated value of the reduction to the enclosing
+ // loops.
+ for (unsigned i = 0; i < rank; ++i) {
+ auto result = builder.create<fir::ResultOp>(loc, reductionVal);
+ // Proceed to the outer loop.
+ auto loop = mlir::cast<fir::DoLoopOp>(result->getParentOp());
+ reductionVal = loop.getResult(0);
+ // Set insertion point after the loop operation that we have
+ // just processed.
+ builder.setInsertionPointAfter(loop.getOperation());
+ }
- mlir::Value resultVal = loop.getResult(0);
- builder.create<mlir::func::ReturnOp>(loc, resultVal);
+ // End of loop nest. The insertion point is after the outermost loop.
+ // Return the reduction value from the function.
+ builder.create<mlir::func::ReturnOp>(loc, reductionVal);
}
/// Generate function body of the simplified version of RTNAME(Sum)
/// with signature provided by \p funcOp. The caller is responsible
/// for saving/restoring the original insertion point of \p builder.
/// \p funcOp is expected to be empty on entry to this function.
+/// \p rank specifies the rank of the input argument.
static void genRuntimeSumBody(fir::FirOpBuilder &builder,
- mlir::func::FuncOp &funcOp) {
- // function RTNAME(Sum)<T>_simplified(arr)
+ mlir::func::FuncOp &funcOp, unsigned rank) {
+ // function RTNAME(Sum)<T>x<rank>_simplified(arr)
// T, dimension(:) :: arr
// T sum = 0
// integer iter
// do iter = 0, extent(arr)
// sum = sum + arr[iter]
// end do
- // RTNAME(Sum)<T>_simplified = sum
- // end function RTNAME(Sum)<T>_simplified
+ // RTNAME(Sum)<T>x<rank>_simplified = sum
+ // end function RTNAME(Sum)<T>x<rank>_simplified
auto zero = [](fir::FirOpBuilder builder, mlir::Location loc,
mlir::Type elementType) {
if (auto ty = elementType.dyn_cast<mlir::FloatType>()) {
@@ -200,11 +240,11 @@ static void genRuntimeSumBody(fir::FirOpBuilder &builder,
return {};
};
- genReductionLoop(builder, funcOp, zero, genBodyOp);
+ genReductionLoop(builder, funcOp, zero, genBodyOp, rank);
}
static void genRuntimeMaxvalBody(fir::FirOpBuilder &builder,
- mlir::func::FuncOp &funcOp) {
+ mlir::func::FuncOp &funcOp, unsigned rank) {
auto init = [](fir::FirOpBuilder builder, mlir::Location loc,
mlir::Type elementType) {
if (auto ty = elementType.dyn_cast<mlir::FloatType>()) {
@@ -228,7 +268,7 @@ static void genRuntimeMaxvalBody(fir::FirOpBuilder &builder,
llvm_unreachable("unsupported type");
return {};
};
- genReductionLoop(builder, funcOp, init, genBodyOp);
+ genReductionLoop(builder, funcOp, init, genBodyOp, rank);
}
/// Generate function type for the simplified version of RTNAME(DotProduct)
@@ -410,21 +450,31 @@ static bool isZero(mlir::Value val) {
return false;
}
-static mlir::Value findShape(mlir::Value val) {
+static mlir::Value findBoxDef(mlir::Value val) {
if (auto op = expectConvertOp(val)) {
assert(op->getOperands().size() != 0);
if (auto box = mlir::dyn_cast_or_null<fir::EmboxOp>(
op->getOperand(0).getDefiningOp()))
- return box.getShape();
+ return box.getResult();
+ if (auto box = mlir::dyn_cast_or_null<fir::ReboxOp>(
+ op->getOperand(0).getDefiningOp()))
+ return box.getResult();
}
return {};
}
static unsigned getDimCount(mlir::Value val) {
- if (mlir::Value shapeVal = findShape(val)) {
- mlir::Type resType = shapeVal.getDefiningOp()->getResultTypes()[0];
- return fir::getRankOfShapeType(resType);
- }
+ // In order to find the dimensions count, we look for EmboxOp/ReboxOp
+ // and take the count from its *result* type. Note that in case
+ // of sliced emboxing the operand and the result of EmboxOp/ReboxOp
+ // have different types.
+ // Actually, we can take the box type from the operand of
+ // the first ConvertOp that has non-opaque box type that we meet
+ // going through the ConvertOp chain.
+ if (mlir::Value emboxVal = findBoxDef(val))
+ if (auto boxTy = emboxVal.getType().dyn_cast<fir::BoxType>())
+ if (auto seqTy = boxTy.getEleTy().dyn_cast<fir::SequenceType>())
+ return seqTy.getDimension();
return 0;
}
@@ -455,7 +505,6 @@ void SimplifyIntrinsicsPass::simplifyReduction(fir::CallOp call,
const fir::KindMapping &kindMap,
GenReductionBodyTy genBodyFunc) {
mlir::SymbolRefAttr callee = call.getCalleeAttr();
- mlir::StringRef funcName = callee.getLeafReference().getValue();
mlir::Operation::operand_range args = call.getArgs();
// args[1] and args[2] are source filename and line number, ignored.
const mlir::Value &dim = args[3];
@@ -464,7 +513,7 @@ void SimplifyIntrinsicsPass::simplifyReduction(fir::CallOp call,
// detail in the runtime library.
bool dimAndMaskAbsent = isZero(dim) && isOperandAbsent(mask);
unsigned rank = getDimCount(args[0]);
- if (dimAndMaskAbsent && rank == 1) {
+ if (dimAndMaskAbsent && rank > 0) {
mlir::Location loc = call.getLoc();
fir::FirOpBuilder builder(call, kindMap);
@@ -483,8 +532,17 @@ void SimplifyIntrinsicsPass::simplifyReduction(fir::CallOp call,
auto typeGenerator = [&resultType](fir::FirOpBuilder &builder) {
return genNoneBoxType(builder, resultType);
};
+ auto bodyGenerator = [&rank, &genBodyFunc](fir::FirOpBuilder &builder,
+ mlir::func::FuncOp &funcOp) {
+ genBodyFunc(builder, funcOp, rank);
+ };
+ // Mangle the function name with the rank value as "x<rank>".
+ std::string funcName =
+ (mlir::Twine{callee.getLeafReference().getValue(), "x"} +
+ mlir::Twine{rank})
+ .str();
mlir::func::FuncOp newFunc =
- getOrCreateFunction(builder, funcName, typeGenerator, genBodyFunc);
+ getOrCreateFunction(builder, funcName, typeGenerator, bodyGenerator);
auto newCall =
builder.create<fir::CallOp>(loc, newFunc, mlir::ValueRange{args[0]});
call->replaceAllUsesWith(newCall.getResults());
diff --git a/flang/test/Transforms/simplifyintrinsics.fir b/flang/test/Transforms/simplifyintrinsics.fir
index b5d24c5785243..e3ac9c930d299 100644
--- a/flang/test/Transforms/simplifyintrinsics.fir
+++ b/flang/test/Transforms/simplifyintrinsics.fir
@@ -34,20 +34,21 @@ module attributes {fir.defaultkind = "a1c4d8i4l4r4", fir.kindmap = "", llvm.targ
// CHECK: %[[A_BOX_I32:.*]] = fir.embox %[[A]](%[[SHAPE]]) : (!fir.ref<!fir.array<10xi32>>, !fir.shape<1>) -> !fir.box<!fir.array<10xi32>>
// CHECK: %[[A_BOX_NONE:.*]] = fir.convert %[[A_BOX_I32]] : (!fir.box<!fir.array<10xi32>>) -> !fir.box<none>
// CHECK-NOT: fir.call @_FortranASumInteger4({{.*}})
-// CHECK: %[[RES:.*]] = fir.call @_FortranASumInteger4_simplified(%[[A_BOX_NONE]]) : (!fir.box<none>) -> i32
+// CHECK: %[[RES:.*]] = fir.call @_FortranASumInteger4x1_simplified(%[[A_BOX_NONE]]) : (!fir.box<none>) -> i32
// CHECK-NOT: fir.call @_FortranASumInteger4({{.*}})
// CHECK: return %{{.*}} : i32
// CHECK: }
// CHECK: func.func private @_FortranASumInteger4(!fir.box<none>, !fir.ref<i8>, i32, i32, !fir.box<none>) -> i32 attributes {fir.runtime}
-// CHECK-LABEL: func.func private @_FortranASumInteger4_simplified(
+// CHECK-LABEL: func.func private @_FortranASumInteger4x1_simplified(
// CHECK-SAME: %[[ARR:.*]]: !fir.box<none>) -> i32 attributes {llvm.linkage = #llvm.linkage<linkonce_odr>} {
// CHECK: %[[CINDEX_0:.*]] = arith.constant 0 : index
// CHECK: %[[ARR_BOX_I32:.*]] = fir.convert %[[ARR]] : (!fir.box<none>) -> !fir.box<!fir.array<?xi32>>
-// CHECK: %[[DIMS:.*]]:3 = fir.box_dims %[[ARR_BOX_I32]], %[[CINDEX_0]] : (!fir.box<!fir.array<?xi32>>, index) -> (index, index, index)
+// CHECK: %[[CI32_0:.*]] = arith.constant 0 : i32
// CHECK: %[[CINDEX_1:.*]] = arith.constant 1 : index
+// CHECK: %[[DIMIDX_0:.*]] = arith.constant 0 : index
+// CHECK: %[[DIMS:.*]]:3 = fir.box_dims %[[ARR_BOX_I32]], %[[DIMIDX_0]] : (!fir.box<!fir.array<?xi32>>, index) -> (index, index, index)
// CHECK: %[[EXTENT:.*]] = arith.subi %[[DIMS]]#1, %[[CINDEX_1]] : index
-// CHECK: %[[CI32_0:.*]] = arith.constant 0 : i32
// CHECK: %[[RES:.*]] = fir.do_loop %[[ITER:.*]] = %[[CINDEX_0]] to %[[EXTENT]] step %[[CINDEX_1]] iter_args(%[[SUM:.*]] = %[[CI32_0]]) -> (i32) {
// CHECK: %[[ITEM:.*]] = fir.coordinate_of %[[ARR_BOX_I32]], %[[ITER]] : (!fir.box<!fir.array<?xi32>>, index) -> !fir.ref<i32>
// CHECK: %[[ITEM_VAL:.*]] = fir.load %[[ITEM]] : !fir.ref<i32>
@@ -59,7 +60,7 @@ module attributes {fir.defaultkind = "a1c4d8i4l4r4", fir.kindmap = "", llvm.targ
// -----
-// Call to SUM with 2D I32 arrays is not replaced.
+// Call to SUM with 2D I32 arrays is replaced.
module attributes {fir.defaultkind = "a1c4d8i4l4r4", fir.kindmap = "", llvm.target_triple = "native"} {
func.func @sum_2d_array_int(%arg0: !fir.ref<!fir.array<10x10xi32>> {fir.bindc_name = "a"}) -> i32 {
%c10 = arith.constant 10 : index
@@ -88,9 +89,39 @@ module attributes {fir.defaultkind = "a1c4d8i4l4r4", fir.kindmap = "", llvm.targ
}
// CHECK-LABEL: func.func @sum_2d_array_int({{.*}} !fir.ref<!fir.array<10x10xi32>> {fir.bindc_name = "a"}) -> i32 {
-// CHECK-NOT: fir.call @_FortranASumInteger4_simplified({{.*}})
-// CHECK: fir.call @_FortranASumInteger4({{.*}}) : (!fir.box<none>, !fir.ref<i8>, i32, i32, !fir.box<none>) -> i32
-// CHECK-NOT: fir.call @_FortranASumInteger4_simplified({{.*}})
+// CHECK: %[[SHAPE:.*]] = fir.shape %{{.*}} : (index, index) -> !fir.shape<2>
+// CHECK: %[[A_BOX_I32:.*]] = fir.embox %[[A]](%[[SHAPE]]) : (!fir.ref<!fir.array<10x10xi32>>, !fir.shape<2>) -> !fir.box<!fir.array<10x10xi32>>
+// CHECK: %[[A_BOX_NONE:.*]] = fir.convert %[[A_BOX_I32]] : (!fir.box<!fir.array<10x10xi32>>) -> !fir.box<none>
+// CHECK-NOT: fir.call @_FortranASumInteger4({{.*}})
+// CHECK: %[[RES:.*]] = fir.call @_FortranASumInteger4x2_simplified(%[[A_BOX_NONE]]) : (!fir.box<none>) -> i32
+// CHECK-NOT: fir.call @_FortranASumInteger4({{.*}})
+// CHECK: return %{{.*}} : i32
+// CHECK: }
+// CHECK: func.func private @_FortranASumInteger4(!fir.box<none>, !fir.ref<i8>, i32, i32, !fir.box<none>) -> i32 attributes {fir.runtime}
+
+// CHECK-LABEL: func.func private @_FortranASumInteger4x2_simplified(
+// CHECK-SAME: %[[ARR:.*]]: !fir.box<none>) -> i32 attributes {llvm.linkage = #llvm.linkage<linkonce_odr>} {
+// CHECK: %[[CINDEX_0:.*]] = arith.constant 0 : index
+// CHECK: %[[ARR_BOX_I32:.*]] = fir.convert %[[ARR]] : (!fir.box<none>) -> !fir.box<!fir.array<?x?xi32>>
+// CHECK: %[[CI32_0:.*]] = arith.constant 0 : i32
+// CHECK: %[[CINDEX_1:.*]] = arith.constant 1 : index
+// CHECK: %[[DIMIDX_0:.*]] = arith.constant 0 : index
+// CHECK: %[[DIMS_0:.*]]:3 = fir.box_dims %[[ARR_BOX_I32]], %[[DIMIDX_0]] : (!fir.box<!fir.array<?x?xi32>>, index) -> (index, index, index)
+// CHECK: %[[EXTENT_0:.*]] = arith.subi %[[DIMS_0]]#1, %[[CINDEX_1]] : index
+// CHECK: %[[DIMIDX_1:.*]] = arith.constant 1 : index
+// CHECK: %[[DIMS_1:.*]]:3 = fir.box_dims %[[ARR_BOX_I32]], %[[DIMIDX_1]] : (!fir.box<!fir.array<?x?xi32>>, index) -> (index, index, index)
+// CHECK: %[[EXTENT_1:.*]] = arith.subi %[[DIMS_1]]#1, %[[CINDEX_1]] : index
+// CHECK: %[[RES_1:.*]] = fir.do_loop %[[ITER_1:.*]] = %[[CINDEX_0]] to %[[EXTENT_1]] step %[[CINDEX_1]] iter_args(%[[SUM_1:.*]] = %[[CI32_0]]) -> (i32) {
+// CHECK: %[[RES_0:.*]] = fir.do_loop %[[ITER_0:.*]] = %[[CINDEX_0]] to %[[EXTENT_0]] step %[[CINDEX_1]] iter_args(%[[SUM_0:.*]] = %[[SUM_1]]) -> (i32) {
+// CHECK: %[[ITEM:.*]] = fir.coordinate_of %[[ARR_BOX_I32]], %[[ITER_0]], %[[ITER_1]] : (!fir.box<!fir.array<?x?xi32>>, index, index) -> !fir.ref<i32>
+// CHECK: %[[ITEM_VAL:.*]] = fir.load %[[ITEM]] : !fir.ref<i32>
+// CHECK: %[[NEW_SUM:.*]] = arith.addi %[[ITEM_VAL]], %[[SUM_0]] : i32
+// CHECK: fir.result %[[NEW_SUM]] : i32
+// CHECK: }
+// CHECK: fir.result %[[RES_0]]
+// CHECK: }
+// CHECK: return %[[RES_1]] : i32
+// CHECK: }
// -----
@@ -129,19 +160,20 @@ module attributes {fir.defaultkind = "a1c4d8i4l4r4", fir.kindmap = "", llvm.targ
// CHECK: %[[A_BOX_F64:.*]] = fir.embox %[[A]](%[[SHAPE]]) : (!fir.ref<!fir.array<10xf64>>, !fir.shape<1>) -> !fir.box<!fir.array<10xf64>>
// CHECK: %[[A_BOX_NONE:.*]] = fir.convert %[[A_BOX_F64]] : (!fir.box<!fir.array<10xf64>>) -> !fir.box<none>
// CHECK-NOT: fir.call @_FortranASumReal8({{.*}})
-// CHECK: %[[RES:.*]] = fir.call @_FortranASumReal8_simplified(%[[A_BOX_NONE]]) : (!fir.box<none>) -> f64
+// CHECK: %[[RES:.*]] = fir.call @_FortranASumReal8x1_simplified(%[[A_BOX_NONE]]) : (!fir.box<none>) -> f64
// CHECK-NOT: fir.call @_FortranASumReal8({{.*}})
// CHECK: return %{{.*}} : f64
// CHECK: }
-// CHECK-LABEL: func.func private @_FortranASumReal8_simplified(
+// CHECK-LABEL: func.func private @_FortranASumReal8x1_simplified(
// CHECK-SAME: %[[ARR:.*]]: !fir.box<none>) -> f64 attributes {llvm.linkage = #llvm.linkage<linkonce_odr>} {
// CHECK: %[[CINDEX_0:.*]] = arith.constant 0 : index
// CHECK: %[[ARR_BOX_F64:.*]] = fir.convert %[[ARR]] : (!fir.box<none>) -> !fir.box<!fir.array<?xf64>>
-// CHECK: %[[DIMS:.*]]:3 = fir.box_dims %[[ARR_BOX_F64]], %[[CINDEX_0]] : (!fir.box<!fir.array<?xf64>>, index) -> (index, index, index)
+// CHECK: %[[ZERO:.*]] = arith.constant 0.000000e+00 : f64
// CHECK: %[[CINDEX_1:.*]] = arith.constant 1 : index
+// CHECK: %[[DIMIDX_0:.*]] = arith.constant 0 : index
+// CHECK: %[[DIMS:.*]]:3 = fir.box_dims %[[ARR_BOX_F64]], %[[DIMIDX_0]] : (!fir.box<!fir.array<?xf64>>, index) -> (index, index, index)
// CHECK: %[[EXTENT:.*]] = arith.subi %[[DIMS]]#1, %[[CINDEX_1]] : index
-// CHECK: %[[ZERO:.*]] = arith.constant 0.000000e+00 : f64
// CHECK: %[[RES:.*]] = fir.do_loop %[[ITER:.*]] = %[[CINDEX_0]] to %[[EXTENT]] step %[[CINDEX_1]] iter_args(%[[SUM]] = %[[ZERO]]) -> (f64) {
// CHECK: %[[ITEM:.*]] = fir.coordinate_of %[[ARR_BOX_F64]], %[[ITER]] : (!fir.box<!fir.array<?xf64>>, index) -> !fir.ref<f64>
// CHECK: %[[ITEM_VAL:.*]] = fir.load %[[ITEM]] : !fir.ref<f64>
@@ -188,19 +220,20 @@ module attributes {fir.defaultkind = "a1c4d8i4l4r4", fir.kindmap = "", llvm.targ
// CHECK: %[[A_BOX_F32:.*]] = fir.embox %[[A]](%[[SHAPE]]) : (!fir.ref<!fir.array<10xf32>>, !fir.shape<1>) -> !fir.box<!fir.array<10xf32>>
// CHECK: %[[A_BOX_NONE:.*]] = fir.convert %[[A_BOX_F32]] : (!fir.box<!fir.array<10xf32>>) -> !fir.box<none>
// CHECK-NOT: fir.call @_FortranASumReal4({{.*}})
-// CHECK: %[[RES:.*]] = fir.call @_FortranASumReal4_simplified(%[[A_BOX_NONE]]) : (!fir.box<none>) -> f32
+// CHECK: %[[RES:.*]] = fir.call @_FortranASumReal4x1_simplified(%[[A_BOX_NONE]]) : (!fir.box<none>) -> f32
// CHECK-NOT: fir.call @_FortranASumReal4({{.*}})
// CHECK: return %{{.*}} : f32
// CHECK: }
-// CHECK-LABEL: func.func private @_FortranASumReal4_simplified(
+// CHECK-LABEL: func.func private @_FortranASumReal4x1_simplified(
// CHECK-SAME: %[[ARR:.*]]: !fir.box<none>) -> f32 attributes {llvm.linkage = #llvm.linkage<linkonce_odr>} {
// CHECK: %[[CINDEX_0:.*]] = arith.constant 0 : index
// CHECK: %[[ARR_BOX_F32:.*]] = fir.convert %[[ARR]] : (!fir.box<none>) -> !fir.box<!fir.array<?xf32>>
-// CHECK: %[[DIMS:.*]]:3 = fir.box_dims %[[ARR_BOX_F32]], %[[CINDEX_0]] : (!fir.box<!fir.array<?xf32>>, index) -> (index, index, index)
+// CHECK: %[[ZERO:.*]] = arith.constant 0.000000e+00 : f32
// CHECK: %[[CINDEX_1:.*]] = arith.constant 1 : index
+// CHECK: %[[DIMIDX_0:.*]] = arith.constant 0 : index
+// CHECK: %[[DIMS:.*]]:3 = fir.box_dims %[[ARR_BOX_F32]], %[[DIMIDX_0]] : (!fir.box<!fir.array<?xf32>>, index) -> (index, index, index)
// CHECK: %[[EXTENT:.*]] = arith.subi %[[DIMS]]#1, %[[CINDEX_1]] : index
-// CHECK: %[[ZERO:.*]] = arith.constant 0.000000e+00 : f32
// CHECK: %[[RES:.*]] = fir.do_loop %[[ITER:.*]] = %[[CINDEX_0]] to %[[EXTENT]] step %[[CINDEX_1]] iter_args(%[[SUM]] = %[[ZERO]]) -> (f32) {
// CHECK: %[[ITEM:.*]] = fir.coordinate_of %[[ARR_BOX_F32]], %[[ITER]] : (!fir.box<!fir.array<?xf32>>, index) -> !fir.ref<f32>
// CHECK: %[[ITEM_VAL:.*]] = fir.load %[[ITEM]] : !fir.ref<f32>
@@ -243,9 +276,9 @@ module attributes {fir.defaultkind = "a1c4d8i4l4r4", fir.kindmap = "", llvm.targ
}
// CHECK-LABEL: func.func @sum_1d_complex(%{{.*}}: !fir.ref<!fir.array<10x!fir.complex<4>>> {fir.bindc_name = "a"}) -> !fir.complex<4> {
-// CHECK-NOT: fir.call @_FortranACppSumComplex4_simplified({{.*}})
+// CHECK-NOT: fir.call @_FortranACppSumComplex4x1_simplified({{.*}})
// CHECK: fir.call @_FortranACppSumComplex4({{.*}}) : (!fir.ref<complex<f32>>, !fir.box<none>, !fir.ref<i8>, i32, i32, !fir.box<none>) -> none
-// CHECK-NOT: fir.call @_FortranACppSumComplex4_simplified({{.*}})
+// CHECK-NOT: fir.call @_FortranACppSumComplex4x1_simplified({{.*}})
// -----
@@ -298,20 +331,20 @@ module attributes {fir.defaultkind = "a1c4d8i4l4r4", fir.kindmap = "", llvm.targ
// CHECK-LABEL: func.func @sum_1d_calla(%{{.*}}) -> i32 {
// CHECK-NOT: fir.call @_FortranASumInteger4({{.*}})
-// CHECK: fir.call @_FortranASumInteger4_simplified(%{{.*}})
+// CHECK: fir.call @_FortranASumInteger4x1_simplified(%{{.*}})
// CHECK-NOT: fir.call @_FortranASumInteger4({{.*}})
// CHECK: }
// CHECK-LABEL: func.func @sum_1d_callb(%{{.*}}) -> i32 {
// CHECK-NOT: fir.call @_FortranASumInteger4({{.*}})
-// CHECK: fir.call @_FortranASumInteger4_simplified(%{{.*}})
+// CHECK: fir.call @_FortranASumInteger4x1_simplified(%{{.*}})
// CHECK-NOT: fir.call @_FortranASumInteger4({{.*}})
// CHECK: }
-// CHECK-LABEL: func.func private @_FortranASumInteger4_simplified({{.*}}) -> i32 {{.*}} {
+// CHECK-LABEL: func.func private @_FortranASumInteger4x1_simplified({{.*}}) -> i32 {{.*}} {
// CHECK: return %{{.*}} : i32
// CHECK: }
-// CHECK-NOT: func.func private @_FortranASumInteger4_simplified({{.*}})
+// CHECK-NOT: func.func private @_FortranASumInteger4x1_simplified({{.*}})
// -----
@@ -354,14 +387,14 @@ module attributes {fir.defaultkind = "a1c4d8i4l4r4", fir.kindmap = "", llvm.targ
// CHECK: %[[SLICE:.*]] = fir.slice %{{.*}}, %{{.*}}, %[[CINDEX_2]] : (index, index, index) -> !fir.slice<1>
// CHECK: %[[A_BOX_I32:.*]] = fir.embox %{{.*}}(%[[SHAPE]]) {{\[}}%[[SLICE]]] : (!fir.ref<!fir.array<20xi32>>, !fir.shape<1>, !fir.slice<1>) -> !fir.box<!fir.array<?xi32>>
// CHECK: %[[A_BOX_NONE:.*]] = fir.convert %[[A_BOX_I32]] : (!fir.box<!fir.array<?xi32>>) -> !fir.box<none>
-// CHECK: %{{.*}} = fir.call @_FortranASumInteger4_simplified(%[[A_BOX_NONE]]) : (!fir.box<none>) -> i32
+// CHECK: %{{.*}} = fir.call @_FortranASumInteger4x1_simplified(%[[A_BOX_NONE]]) : (!fir.box<none>) -> i32
// CHECK: return %{{.*}} : i32
// CHECK: }
-// CHECK-LABEL: func.func private @_FortranASumInteger4_simplified(%{{.*}}) -> i32 attributes {llvm.linkage = #llvm.linkage<linkonce_odr>} {
+// CHECK-LABEL: func.func private @_FortranASumInteger4x1_simplified(%{{.*}}) -> i32 attributes {llvm.linkage = #llvm.linkage<linkonce_odr>} {
// CHECK: %[[ARR_BOX_I32:.*]] = fir.convert %{{.*}} : (!fir.box<none>) -> !fir.box<!fir.array<?xi32>>
-// CHECK: %[[DIMS:.*]]:3 = fir.box_dims %[[ARR_BOX_I32]], %{{.*}} : (!fir.box<!fir.array<?xi32>>, index) -> (index, index, index)
// CHECK: %[[CINDEX_1:.*]] = arith.constant 1 : index
+// CHECK: %[[DIMS:.*]]:3 = fir.box_dims %[[ARR_BOX_I32]], %{{.*}} : (!fir.box<!fir.array<?xi32>>, index) -> (index, index, index)
// CHECK: %[[EXTENT:.*]] = arith.subi %[[DIMS]]#1, %[[CINDEX_1]] : index
// CHECK: %[[RES:.*]] = fir.do_loop %[[ITER:.*]] = %{{.*}} to %[[EXTENT]] step %[[CINDEX_1]] iter_args({{.*}}) -> (i32) {
// CHECK: %{{.*}} = fir.coordinate_of %[[ARR_BOX_I32]], %[[ITER]] : (!fir.box<!fir.array<?xi32>>, index) -> !fir.ref<i32>
@@ -792,18 +825,19 @@ module attributes {fir.defaultkind = "a1c4d8i4l4r4", fir.kindmap = "", llvm.targ
// CHECK: %[[SHAPE:.*]] = fir.shape %{{.*}} : (index) -> !fir.shape<1>
// CHECK: %[[A_BOX_I32:.*]] = fir.embox %[[A]](%[[SHAPE]]) : (!fir.ref<!fir.array<10xi32>>, !fir.shape<1>) -> !fir.box<!fir.array<10xi32>>
// CHECK: %[[A_BOX_NONE:.*]] = fir.convert %[[A_BOX_I32]] : (!fir.box<!fir.array<10xi32>>) -> !fir.box<none>
-// CHECK: %[[RES:.*]] = fir.call @_FortranAMaxvalInteger4_simplified(%[[A_BOX_NONE]]) : (!fir.box<none>) -> i32
+// CHECK: %[[RES:.*]] = fir.call @_FortranAMaxvalInteger4x1_simplified(%[[A_BOX_NONE]]) : (!fir.box<none>) -> i32
// CHECK: return %{{.*}} : i32
// CHECK: }
-// CHECK-LABEL: func.func private @_FortranAMaxvalInteger4_simplified(
+// CHECK-LABEL: func.func private @_FortranAMaxvalInteger4x1_simplified(
// CHECK-SAME: %[[ARR:.*]]: !fir.box<none>) -> i32 attributes {llvm.linkage = #llvm.linkage<linkonce_odr>} {
// CHECK: %[[CINDEX_0:.*]] = arith.constant 0 : index
// CHECK: %[[ARR_BOX_I32:.*]] = fir.convert %[[ARR]] : (!fir.box<none>) -> !fir.box<!fir.array<?xi32>>
-// CHECK: %[[DIMS:.*]]:3 = fir.box_dims %[[ARR_BOX_I32]], %[[CINDEX_0]] : (!fir.box<!fir.array<?xi32>>, index) -> (index, index, index)
+// CHECK: %[[CI32_MININT:.*]] = arith.constant -2147483648 : i32
// CHECK: %[[CINDEX_1:.*]] = arith.constant 1 : index
+// CHECK: %[[DIMIDX_0:.*]] = arith.constant 0 : index
+// CHECK: %[[DIMS:.*]]:3 = fir.box_dims %[[ARR_BOX_I32]], %[[DIMIDX_0]] : (!fir.box<!fir.array<?xi32>>, index) -> (index, index, index)
// CHECK: %[[EXTENT:.*]] = arith.subi %[[DIMS]]#1, %[[CINDEX_1]] : index
-// CHECK: %[[CI32_MININT:.*]] = arith.constant -2147483648 : i32
// CHECK: %[[RES:.*]] = fir.do_loop %[[ITER:.*]] = %[[CINDEX_0]] to %[[EXTENT]] step %[[CINDEX_1]] iter_args(%[[MAX:.*]] = %[[CI32_MININT]]) -> (i32) {
// CHECK: %[[ITEM:.*]] = fir.coordinate_of %[[ARR_BOX_I32]], %[[ITER]] : (!fir.box<!fir.array<?xi32>>, index) -> !fir.ref<i32>
// CHECK: %[[ITEM_VAL:.*]] = fir.load %[[ITEM]] : !fir.ref<i32>
@@ -849,18 +883,19 @@ module attributes {fir.defaultkind = "a1c4d8i4l4r4", fir.kindmap = "", llvm.targ
// CHECK: %[[SHAPE:.*]] = fir.shape %[[CINDEX_10]] : (index) -> !fir.shape<1>
// CHECK: %[[A_BOX_F64:.*]] = fir.embox %[[A]](%[[SHAPE]]) : (!fir.ref<!fir.array<10xf64>>, !fir.shape<1>) -> !fir.box<!fir.array<10xf64>>
// CHECK: %[[A_BOX_NONE:.*]] = fir.convert %[[A_BOX_F64]] : (!fir.box<!fir.array<10xf64>>) -> !fir.box<none>
-// CHECK: %[[RES:.*]] = fir.call @_FortranAMaxvalReal8_simplified(%[[A_BOX_NONE]]) : (!fir.box<none>) -> f64
+// CHECK: %[[RES:.*]] = fir.call @_FortranAMaxvalReal8x1_simplified(%[[A_BOX_NONE]]) : (!fir.box<none>) -> f64
// CHECK: return %{{.*}} : f64
// CHECK: }
-// CHECK-LABEL: func.func private @_FortranAMaxvalReal8_simplified(
+// CHECK-LABEL: func.func private @_FortranAMaxvalReal8x1_simplified(
// CHECK-SAME: %[[ARR:.*]]: !fir.box<none>) -> f64 attributes {llvm.linkage = #llvm.linkage<linkonce_odr>} {
// CHECK: %[[CINDEX_0:.*]] = arith.constant 0 : index
// CHECK: %[[ARR_BOX_F64:.*]] = fir.convert %[[ARR]] : (!fir.box<none>) -> !fir.box<!fir.array<?xf64>>
-// CHECK: %[[DIMS:.*]]:3 = fir.box_dims %[[ARR_BOX_F64]], %[[CINDEX_0]] : (!fir.box<!fir.array<?xf64>>, index) -> (index, index, index)
+// CHECK: %[[NEG_DBL_MAX:.*]] = arith.constant -1.7976931348623157E+308 : f64
// CHECK: %[[CINDEX_1:.*]] = arith.constant 1 : index
+// CHECK: %[[DIMIDX_0:.*]] = arith.constant 0 : index
+// CHECK: %[[DIMS:.*]]:3 = fir.box_dims %[[ARR_BOX_F64]], %[[DIMIDX_0]] : (!fir.box<!fir.array<?xf64>>, index) -> (index, index, index)
// CHECK: %[[EXTENT:.*]] = arith.subi %[[DIMS]]#1, %[[CINDEX_1]] : index
-// CHECK: %[[NEG_DBL_MAX:.*]] = arith.constant -1.7976931348623157E+308 : f64
// CHECK: %[[RES:.*]] = fir.do_loop %[[ITER:.*]] = %[[CINDEX_0]] to %[[EXTENT]] step %[[CINDEX_1]] iter_args(%[[MAX]] = %[[NEG_DBL_MAX]]) -> (f64) {
// CHECK: %[[ITEM:.*]] = fir.coordinate_of %[[ARR_BOX_F64]], %[[ITER]] : (!fir.box<!fir.array<?xf64>>, index) -> !fir.ref<f64>
// CHECK: %[[ITEM_VAL:.*]] = fir.load %[[ITEM]] : !fir.ref<f64>
@@ -869,3 +904,97 @@ module attributes {fir.defaultkind = "a1c4d8i4l4r4", fir.kindmap = "", llvm.targ
// CHECK: }
// CHECK: return %[[RES]] : f64
// CHECK: }
+
+// -----
+
+// SUM reduction of sliced explicit-shape array is replaced with
+// 2D simplified implementation.
+func.func @sum_sliced_embox_i64(%arg0: !fir.ref<!fir.array<10x10x10xi64>> {fir.bindc_name = "a"}) -> f32 {
+ %c10 = arith.constant 10 : index
+ %c10_0 = arith.constant 10 : index
+ %c10_1 = arith.constant 10 : index
+ %0 = fir.alloca f32 {bindc_name = "sum_sliced_embox_i64", uniq_name = "_QFsum_sliced_embox_i64Esum_sliced_embox_i64"}
+ %1 = fir.alloca i64 {bindc_name = "sum_sliced_i64", uniq_name = "_QFsum_sliced_embox_i64Esum_sliced_i64"}
+ %c1 = arith.constant 1 : index
+ %c1_i64 = arith.constant 1 : i64
+ %2 = fir.convert %c1_i64 : (i64) -> index
+ %3 = arith.addi %c1, %c10 : index
+ %4 = arith.subi %3, %c1 : index
+ %c1_i64_2 = arith.constant 1 : i64
+ %5 = fir.convert %c1_i64_2 : (i64) -> index
+ %6 = arith.addi %c1, %c10_0 : index
+ %7 = arith.subi %6, %c1 : index
+ %c1_i64_3 = arith.constant 1 : i64
+ %8 = fir.undefined index
+ %9 = fir.shape %c10, %c10_0, %c10_1 : (index, index, index) -> !fir.shape<3>
+ %10 = fir.slice %c1, %4, %2, %c1, %7, %5, %c1_i64_3, %8, %8 : (index, index, index, index, index, index, i64, index, index) -> !fir.slice<3>
+ %11 = fir.embox %arg0(%9) [%10] : (!fir.ref<!fir.array<10x10x10xi64>>, !fir.shape<3>, !fir.slice<3>) -> !fir.box<!fir.array<?x?xi64>>
+ %12 = fir.absent !fir.box<i1>
+ %c0 = arith.constant 0 : index
+ %13 = fir.address_of(@_QQcl.2E2F746573742E66393000) : !fir.ref<!fir.char<1,11>>
+ %c3_i32 = arith.constant 3 : i32
+ %14 = fir.convert %11 : (!fir.box<!fir.array<?x?xi64>>) -> !fir.box<none>
+ %15 = fir.convert %13 : (!fir.ref<!fir.char<1,11>>) -> !fir.ref<i8>
+ %16 = fir.convert %c0 : (index) -> i32
+ %17 = fir.convert %12 : (!fir.box<i1>) -> !fir.box<none>
+ %18 = fir.call @_FortranASumInteger8(%14, %15, %c3_i32, %16, %17) : (!fir.box<none>, !fir.ref<i8>, i32, i32, !fir.box<none>) -> i64
+ fir.store %18 to %1 : !fir.ref<i64>
+ %19 = fir.load %0 : !fir.ref<f32>
+ return %19 : f32
+}
+func.func private @_FortranASumInteger8(!fir.box<none>, !fir.ref<i8>, i32, i32, !fir.box<none>) -> i64 attributes {fir.runtime}
+fir.global linkonce @_QQcl.2E2F746573742E66393000 constant : !fir.char<1,11> {
+ %0 = fir.string_lit "./test.f90\00"(11) : !fir.char<1,11>
+ fir.has_value %0 : !fir.char<1,11>
+}
+
+// CHECK-NOT: call{{.*}}_FortranASumInteger8(
+// CHECK: call @_FortranASumInteger8x2_simplified(
+// CHECK-NOT: call{{.*}}_FortranASumInteger8(
+
+// -----
+
+// SUM reduction of sliced assumed-shape array is replaced with
+// 2D simplified implementation.
+func.func @_QPsum_sliced_rebox_i64(%arg0: !fir.box<!fir.array<?x?x?xi64>> {fir.bindc_name = "a"}) -> f32 {
+ %0 = fir.alloca i64 {bindc_name = "sum_sliced_i64", uniq_name = "_QFsum_sliced_rebox_i64Esum_sliced_i64"}
+ %1 = fir.alloca f32 {bindc_name = "sum_sliced_rebox_i64", uniq_name = "_QFsum_sliced_rebox_i64Esum_sliced_rebox_i64"}
+ %c1 = arith.constant 1 : index
+ %c1_i64 = arith.constant 1 : i64
+ %2 = fir.convert %c1_i64 : (i64) -> index
+ %c0 = arith.constant 0 : index
+ %3:3 = fir.box_dims %arg0, %c0 : (!fir.box<!fir.array<?x?x?xi64>>, index) -> (index, index, index)
+ %4 = arith.addi %c1, %3#1 : index
+ %5 = arith.subi %4, %c1 : index
+ %c1_i64_0 = arith.constant 1 : i64
+ %6 = fir.convert %c1_i64_0 : (i64) -> index
+ %c1_1 = arith.constant 1 : index
+ %7:3 = fir.box_dims %arg0, %c1_1 : (!fir.box<!fir.array<?x?x?xi64>>, index) -> (index, index, index)
+ %8 = arith.addi %c1, %7#1 : index
+ %9 = arith.subi %8, %c1 : index
+ %c1_i64_2 = arith.constant 1 : i64
+ %10 = fir.undefined index
+ %11 = fir.slice %c1, %5, %2, %c1, %9, %6, %c1_i64_2, %10, %10 : (index, index, index, index, index, index, i64, index, index) -> !fir.slice<3>
+ %12 = fir.rebox %arg0 [%11] : (!fir.box<!fir.array<?x?x?xi64>>, !fir.slice<3>) -> !fir.box<!fir.array<?x?xi64>>
+ %13 = fir.absent !fir.box<i1>
+ %c0_3 = arith.constant 0 : index
+ %14 = fir.address_of(@_QQcl.2E2F746573742E66393000) : !fir.ref<!fir.char<1,11>>
+ %c8_i32 = arith.constant 8 : i32
+ %15 = fir.convert %12 : (!fir.box<!fir.array<?x?xi64>>) -> !fir.box<none>
+ %16 = fir.convert %14 : (!fir.ref<!fir.char<1,11>>) -> !fir.ref<i8>
+ %17 = fir.convert %c0_3 : (index) -> i32
+ %18 = fir.convert %13 : (!fir.box<i1>) -> !fir.box<none>
+ %19 = fir.call @_FortranASumInteger8(%15, %16, %c8_i32, %17, %18) : (!fir.box<none>, !fir.ref<i8>, i32, i32, !fir.box<none>) -> i64
+ fir.store %19 to %0 : !fir.ref<i64>
+ %20 = fir.load %1 : !fir.ref<f32>
+ return %20 : f32
+}
+func.func private @_FortranASumInteger8(!fir.box<none>, !fir.ref<i8>, i32, i32, !fir.box<none>) -> i64 attributes {fir.runtime}
+fir.global linkonce @_QQcl.2E2F746573742E66393000 constant : !fir.char<1,11> {
+ %0 = fir.string_lit "./test.f90\00"(11) : !fir.char<1,11>
+ fir.has_value %0 : !fir.char<1,11>
+}
+
+// CHECK-NOT: call{{.*}}_FortranASumInteger8(
+// CHECK: call @_FortranASumInteger8x2_simplified(
+// CHECK-NOT: call{{.*}}_FortranASumInteger8(
More information about the flang-commits
mailing list