[flang-commits] [flang] [Flang] Minloc elemental intrinsic lowering (PR #74828)

Slava Zakharin via flang-commits flang-commits at lists.llvm.org
Mon Dec 18 13:37:32 PST 2023


================
@@ -659,6 +677,194 @@ mlir::LogicalResult VariableAssignBufferization::matchAndRewrite(
   return mlir::success();
 }
 
+// Look for assign(minloc(mask=elemental)) and generate the minloc loop with
+// inlined elemental and no extra temporaries.
+//  %e = hlfir.elemental %shape ({ ... })
+//  %m = hlfir.minloc %array mask %e
+//  hlfir.assign %m to %result
+//  hlfir.destroy %m
+class AssignMinMaxlocElementalConversion
+    : public mlir::OpRewritePattern<hlfir::AssignOp> {
+public:
+  using mlir::OpRewritePattern<hlfir::AssignOp>::OpRewritePattern;
+
+  mlir::LogicalResult
+  matchAndRewrite(hlfir::AssignOp assign,
+                  mlir::PatternRewriter &rewriter) const override {
+    auto minloc = assign.getOperand(0).getDefiningOp<hlfir::MinlocOp>();
+    if (!minloc || !minloc.getMask() || minloc.getDim() || minloc.getBack())
+      return rewriter.notifyMatchFailure(assign,
+                                         "Did not find minloc with kind");
+
+    auto elemental = minloc.getMask().getDefiningOp<hlfir::ElementalOp>();
+    if (!elemental || hlfir::elementalOpMustProduceTemp(elemental))
+      return rewriter.notifyMatchFailure(assign, "Did not find elemental");
+
+    mlir::Operation::user_range users = minloc->getUsers();
+    if (std::distance(users.begin(), users.end()) != 2)
+      return rewriter.notifyMatchFailure(assign, "Did not find minloc users");
+    auto destroy = mlir::dyn_cast<hlfir::DestroyOp>(
+        *users.begin() == minloc ? *++users.begin() : *users.begin());
+    if (!destroy)
+      return rewriter.notifyMatchFailure(assign, "Did not find destroy");
+
+    if (!checkForElementalEffectsBetween(elemental, assign, minloc.getArray(),
+                                         minloc))
+      return rewriter.notifyMatchFailure(assign, "Had unhandled effects");
+
+    mlir::Value resultArr = assign.getOperand(1);
+    mlir::Value array = minloc.getArray();
+
+    unsigned rank = mlir::cast<hlfir::ExprType>(minloc.getType()).getShape()[0];
+    mlir::Type arrayType = array.getType();
+    if (!arrayType.isa<fir::BoxType>())
+      return rewriter.notifyMatchFailure(
+          assign, "Currently requires a boxed type input");
+    mlir::Type elementType = hlfir::getFortranElementType(arrayType);
+    if (!fir::isa_trivial(elementType))
+      return rewriter.notifyMatchFailure(
+          assign, "Character arrays are currently not handled");
+
+    auto init = [](fir::FirOpBuilder builder, mlir::Location loc,
+                   mlir::Type elementType) {
+      if (auto ty = elementType.dyn_cast<mlir::FloatType>()) {
+        const llvm::fltSemantics &sem = ty.getFloatSemantics();
+        return builder.createRealConstant(
+            loc, elementType,
+            llvm::APFloat::getLargest(sem, /*Negative=*/false));
+      }
+      unsigned bits = elementType.getIntOrFloatBitWidth();
+      int64_t maxInt = llvm::APInt::getSignedMaxValue(bits).getSExtValue();
+      return builder.createIntegerConstant(loc, elementType, maxInt);
+    };
+
+    auto genBodyOp =
+        [&rank, &resultArr, &elemental](
+            fir::FirOpBuilder builder, mlir::Location loc,
+            mlir::Type elementType, mlir::Value array, mlir::Value flagRef,
+            mlir::Value reduction,
+            const llvm::SmallVectorImpl<mlir::Value> &indices) -> mlir::Value {
+      // We are in the innermost loop: generate the elemental inline
+      mlir::Value oneIdx =
+          builder.createIntegerConstant(loc, builder.getIndexType(), 1);
+      llvm::SmallVector<mlir::Value> oneBasedIndices;
+      llvm::transform(
+          indices, std::back_inserter(oneBasedIndices), [&](mlir::Value V) {
+            return builder.create<mlir::arith::AddIOp>(loc, V, oneIdx);
+          });
+      hlfir::YieldElementOp yield =
+          hlfir::inlineElementalOp(loc, builder, elemental, oneBasedIndices);
+      mlir::Value maskElem = yield.getElementValue();
+      yield->erase();
+
+      mlir::Type ifCompatType = builder.getI1Type();
+      mlir::Value ifCompatElem =
+          builder.create<fir::ConvertOp>(loc, ifCompatType, maskElem);
+
+      llvm::SmallVector<mlir::Type> resultsTy = {elementType, elementType};
+      fir::IfOp maskIfOp =
+          builder.create<fir::IfOp>(loc, elementType, ifCompatElem,
+                                    /*withElseRegion=*/true);
+      builder.setInsertionPointToStart(&maskIfOp.getThenRegion().front());
+
+      // Set flag that mask was true at some point
+      mlir::Value flagSet = builder.createIntegerConstant(
+          loc, mlir::cast<fir::ReferenceType>(flagRef.getType()).getEleTy(), 1);
+      builder.create<fir::StoreOp>(loc, flagSet, flagRef);
+      mlir::Type eleRefTy = builder.getRefType(elementType);
+      mlir::Value addr = builder.create<hlfir::DesignateOp>(
----------------
vzakhari wrote:

Can you please try using `hlfir::getElementAt` here instead of creating the `hlfir::DesignateOp` directly? It should also work for `!hlfir.expr` input.

https://github.com/llvm/llvm-project/pull/74828


More information about the flang-commits mailing list