[flang-commits] [flang] [Flang] Maxloc elemental intrinsic lowering. (PR #79469)

Tom Eccles via flang-commits flang-commits at lists.llvm.org
Sat Jan 27 15:12:53 PST 2024


================
@@ -812,50 +812,55 @@ class ReductionElementalConversion : public mlir::OpRewritePattern<Op> {
 // inlined elemental.
 //  %e = hlfir.elemental %shape ({ ... })
 //  %m = hlfir.minloc %array mask %e
-class MinMaxlocElementalConversion
-    : public mlir::OpRewritePattern<hlfir::MinlocOp> {
+template <typename Op>
+class MinMaxlocElementalConversion : public mlir::OpRewritePattern<Op> {
 public:
-  using mlir::OpRewritePattern<hlfir::MinlocOp>::OpRewritePattern;
+  using mlir::OpRewritePattern<Op>::OpRewritePattern;
 
   mlir::LogicalResult
-  matchAndRewrite(hlfir::MinlocOp minloc,
-                  mlir::PatternRewriter &rewriter) const override {
-    if (!minloc.getMask() || minloc.getDim() || minloc.getBack())
-      return rewriter.notifyMatchFailure(minloc, "Did not find valid minloc");
+  matchAndRewrite(Op mloc, mlir::PatternRewriter &rewriter) const override {
+    if (!mloc.getMask() || mloc.getDim() || mloc.getBack())
+      return rewriter.notifyMatchFailure(mloc,
+                                         "Did not find valid minloc/maxloc");
 
-    auto elemental = minloc.getMask().getDefiningOp<hlfir::ElementalOp>();
+    constexpr bool isMax = std::is_same_v<Op, hlfir::MaxlocOp>;
+
+    auto elemental =
+        mloc.getMask().template getDefiningOp<hlfir::ElementalOp>();
     if (!elemental || hlfir::elementalOpMustProduceTemp(elemental))
-      return rewriter.notifyMatchFailure(minloc, "Did not find elemental");
+      return rewriter.notifyMatchFailure(mloc, "Did not find elemental");
 
-    mlir::Value array = minloc.getArray();
+    mlir::Value array = mloc.getArray();
 
-    unsigned rank = mlir::cast<hlfir::ExprType>(minloc.getType()).getShape()[0];
+    unsigned rank = mlir::cast<hlfir::ExprType>(mloc.getType()).getShape()[0];
     mlir::Type arrayType = array.getType();
     if (!arrayType.isa<fir::BoxType>())
       return rewriter.notifyMatchFailure(
-          minloc, "Currently requires a boxed type input");
+          mloc, "Currently requires a boxed type input");
     mlir::Type elementType = hlfir::getFortranElementType(arrayType);
     if (!fir::isa_trivial(elementType))
       return rewriter.notifyMatchFailure(
-          minloc, "Character arrays are currently not handled");
+          mloc, "Character arrays are currently not handled");
 
-    mlir::Location loc = minloc.getLoc();
-    fir::FirOpBuilder builder{rewriter, minloc.getOperation()};
+    mlir::Location loc = mloc.getLoc();
+    fir::FirOpBuilder builder{rewriter, mloc.getOperation()};
     mlir::Value resultArr = builder.createTemporary(
         loc, fir::SequenceType::get(
-                 rank, hlfir::getFortranElementType(minloc.getType())));
+                 rank, hlfir::getFortranElementType(mloc.getType())));
 
     auto init = [](fir::FirOpBuilder builder, mlir::Location loc,
----------------
tblah wrote:

Windows CI needs `isMax` to be added to this lambda's capture list.

https://github.com/llvm/llvm-project/pull/79469


More information about the flang-commits mailing list