[flang-commits] [flang] [Flang] Minloc elemental intrinsic lowering (PR #74828)
Slava Zakharin via flang-commits
flang-commits at lists.llvm.org
Mon Dec 18 13:37:32 PST 2023
================
@@ -659,6 +677,194 @@ mlir::LogicalResult VariableAssignBufferization::matchAndRewrite(
return mlir::success();
}
+// Look for assign(minloc(mask=elemental)) and generate the minloc loop with
+// inlined elemental and no extra temporaries.
+// %e = hlfir.elemental %shape ({ ... })
+// %m = hlfir.minloc %array mask %e
+// hlfir.assign %m to %result
+// hlfir.destroy %m
+class AssignMinMaxlocElementalConversion
+ : public mlir::OpRewritePattern<hlfir::AssignOp> {
+public:
+ using mlir::OpRewritePattern<hlfir::AssignOp>::OpRewritePattern;
+
+ mlir::LogicalResult
+ matchAndRewrite(hlfir::AssignOp assign,
+ mlir::PatternRewriter &rewriter) const override {
+ auto minloc = assign.getOperand(0).getDefiningOp<hlfir::MinlocOp>();
+ if (!minloc || !minloc.getMask() || minloc.getDim() || minloc.getBack())
+ return rewriter.notifyMatchFailure(assign,
+ "Did not find minloc with kind");
+
+ auto elemental = minloc.getMask().getDefiningOp<hlfir::ElementalOp>();
+ if (!elemental || hlfir::elementalOpMustProduceTemp(elemental))
+ return rewriter.notifyMatchFailure(assign, "Did not find elemental");
+
+ mlir::Operation::user_range users = minloc->getUsers();
+ if (std::distance(users.begin(), users.end()) != 2)
+ return rewriter.notifyMatchFailure(assign, "Did not find minloc users");
+ auto destroy = mlir::dyn_cast<hlfir::DestroyOp>(
+ *users.begin() == minloc ? *++users.begin() : *users.begin());
+ if (!destroy)
+ return rewriter.notifyMatchFailure(assign, "Did not find destroy");
+
+ if (!checkForElementalEffectsBetween(elemental, assign, minloc.getArray(),
+ minloc))
+ return rewriter.notifyMatchFailure(assign, "Had unhandled effects");
+
+ mlir::Value resultArr = assign.getOperand(1);
+ mlir::Value array = minloc.getArray();
+
+ unsigned rank = mlir::cast<hlfir::ExprType>(minloc.getType()).getShape()[0];
+ mlir::Type arrayType = array.getType();
+ if (!arrayType.isa<fir::BoxType>())
+ return rewriter.notifyMatchFailure(
+ assign, "Currently requires a boxed type input");
+ mlir::Type elementType = hlfir::getFortranElementType(arrayType);
+ if (!fir::isa_trivial(elementType))
+ return rewriter.notifyMatchFailure(
+ assign, "Character arrays are currently not handled");
+
+ auto init = [](fir::FirOpBuilder builder, mlir::Location loc,
+ mlir::Type elementType) {
+ if (auto ty = elementType.dyn_cast<mlir::FloatType>()) {
+ const llvm::fltSemantics &sem = ty.getFloatSemantics();
+ return builder.createRealConstant(
+ loc, elementType,
+ llvm::APFloat::getLargest(sem, /*Negative=*/false));
+ }
+ unsigned bits = elementType.getIntOrFloatBitWidth();
+ int64_t maxInt = llvm::APInt::getSignedMaxValue(bits).getSExtValue();
+ return builder.createIntegerConstant(loc, elementType, maxInt);
+ };
+
+ auto genBodyOp =
+ [&rank, &resultArr, &elemental](
+ fir::FirOpBuilder builder, mlir::Location loc,
+ mlir::Type elementType, mlir::Value array, mlir::Value flagRef,
+ mlir::Value reduction,
+ const llvm::SmallVectorImpl<mlir::Value> &indices) -> mlir::Value {
+ // We are in the innermost loop: generate the elemental inline
+ mlir::Value oneIdx =
+ builder.createIntegerConstant(loc, builder.getIndexType(), 1);
+ llvm::SmallVector<mlir::Value> oneBasedIndices;
+ llvm::transform(
+ indices, std::back_inserter(oneBasedIndices), [&](mlir::Value V) {
+ return builder.create<mlir::arith::AddIOp>(loc, V, oneIdx);
+ });
+ hlfir::YieldElementOp yield =
+ hlfir::inlineElementalOp(loc, builder, elemental, oneBasedIndices);
+ mlir::Value maskElem = yield.getElementValue();
+ yield->erase();
+
+ mlir::Type ifCompatType = builder.getI1Type();
+ mlir::Value ifCompatElem =
+ builder.create<fir::ConvertOp>(loc, ifCompatType, maskElem);
+
+ llvm::SmallVector<mlir::Type> resultsTy = {elementType, elementType};
+ fir::IfOp maskIfOp =
+ builder.create<fir::IfOp>(loc, elementType, ifCompatElem,
+ /*withElseRegion=*/true);
+ builder.setInsertionPointToStart(&maskIfOp.getThenRegion().front());
+
+ // Set flag that mask was true at some point
+ mlir::Value flagSet = builder.createIntegerConstant(
+ loc, mlir::cast<fir::ReferenceType>(flagRef.getType()).getEleTy(), 1);
+ builder.create<fir::StoreOp>(loc, flagSet, flagRef);
+ mlir::Type eleRefTy = builder.getRefType(elementType);
+ mlir::Value addr = builder.create<hlfir::DesignateOp>(
----------------
vzakhari wrote:
Can you please try using `hlfir::getElementAt` here instead of creating the `hlfir::DesignateOp` directly? It should also work for `!hlfir.expr` input.
https://github.com/llvm/llvm-project/pull/74828
More information about the flang-commits
mailing list