[flang-commits] [flang] [flang] Inline hlfir.reshape as hlfir.elemental. (PR #124683)

David Green via flang-commits flang-commits at lists.llvm.org
Thu Jan 30 14:01:20 PST 2025


================
@@ -951,6 +951,218 @@ class DotProductConversion
   }
 };
 
+class ReshapeAsElementalConversion
+    : public mlir::OpRewritePattern<hlfir::ReshapeOp> {
+public:
+  using mlir::OpRewritePattern<hlfir::ReshapeOp>::OpRewritePattern;
+
+  llvm::LogicalResult
+  matchAndRewrite(hlfir::ReshapeOp reshape,
+                  mlir::PatternRewriter &rewriter) const override {
+    // Do not inline RESHAPE with ORDER yet. The runtime implementation
+    // may be good enough, unless the temporary creation overhead
+    // is high.
+    // TODO: If ORDER is constant, then we can still easily inline.
+    // TODO: If the result's rank is 1, then we can assume ORDER == (/1/).
+    if (reshape.getOrder())
+      return rewriter.notifyMatchFailure(reshape,
+                                         "RESHAPE with ORDER argument");
+
+    // Verify that the element types of ARRAY, PAD and the result
+    // match before doing any transformations. For example,
+    // the character types of different lengths may appear in the dead
+    // code, and it just does not make sense to inline hlfir.reshape
+    // in this case (a runtime call might have less code size footprint).
+    hlfir::Entity result = hlfir::Entity{reshape};
+    hlfir::Entity array = hlfir::Entity{reshape.getArray()};
+    mlir::Type elementType = array.getFortranElementType();
+    if (result.getFortranElementType() != elementType)
+      return rewriter.notifyMatchFailure(
+          reshape, "ARRAY and result have different types");
+    mlir::Value pad = reshape.getPad();
+    if (pad && hlfir::getFortranElementType(pad.getType()) != elementType)
+      return rewriter.notifyMatchFailure(reshape,
+                                         "ARRAY and PAD have different types");
+
+    // TODO: selecting between ARRAY and PAD of non-trivial element types
+    // requires more work. We have to select between two references
+    // to elements in ARRAY and PAD. This requires conditional
+    // bufferization of the element, if ARRAY/PAD is an expression.
+    if (pad && !fir::isa_trivial(elementType))
+      return rewriter.notifyMatchFailure(reshape,
+                                         "PAD present with non-trivial type");
+
+    mlir::Location loc = reshape.getLoc();
+    fir::FirOpBuilder builder{rewriter, reshape.getOperation()};
+    // Assume that all the indices arithmetic does not overflow
+    // the IndexType.
+    builder.setIntegerOverflowFlags(mlir::arith::IntegerOverflowFlags::nuw);
+
+    llvm::SmallVector<mlir::Value, 1> typeParams;
+    hlfir::genLengthParameters(loc, builder, array, typeParams);
+
+    // Fetch the extents of ARRAY, PAD and result beforehand.
+    llvm::SmallVector<mlir::Value, Fortran::common::maxRank> arrayExtents =
+        hlfir::genExtentsVector(loc, builder, array);
+
+    // If PAD is present, we have to use array size to start taking
+    // elements from the PAD array.
+    mlir::Value arraySize =
+        pad ? computeArraySize(loc, builder, arrayExtents) : nullptr;
+    hlfir::Entity shape = hlfir::Entity{reshape.getShape()};
+    llvm::SmallVector<mlir::Value, Fortran::common::maxRank> resultExtents;
+    mlir::Type indexType = builder.getIndexType();
+    for (int idx = 0; idx < result.getRank(); ++idx)
+      resultExtents.push_back(hlfir::loadElementAt(
+          loc, builder, shape,
+          builder.createIntegerConstant(loc, indexType, idx + 1)));
+    auto resultShape = builder.create<fir::ShapeOp>(loc, resultExtents);
+
+    auto genKernel = [&](mlir::Location loc, fir::FirOpBuilder &builder,
+                         mlir::ValueRange inputIndices) -> hlfir::Entity {
+      mlir::Value linearIndex =
+          computeLinearIndex(loc, builder, resultExtents, inputIndices);
+      fir::IfOp ifOp;
+      if (pad) {
+        // PAD is present. Check if this element comes from the PAD array.
+        mlir::Value isInsideArray = builder.create<mlir::arith::CmpIOp>(
+            loc, mlir::arith::CmpIPredicate::ult, linearIndex, arraySize);
+        ifOp = builder.create<fir::IfOp>(loc, elementType, isInsideArray,
+                                         /*withElseRegion=*/true);
+
+        // In the 'else' block, return an element from the PAD.
+        builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
+        // PAD is dynamically optional, but we can unconditionally access it
+        // in the 'else' block. If we have to start taking elements from it,
+        // then it must be present in a valid program.
+        llvm::SmallVector<mlir::Value, Fortran::common::maxRank> padExtents =
+            hlfir::genExtentsVector(loc, builder, hlfir::Entity{pad});
+        // Subtract the ARRAY size from the zero-based linear index
+        // to get the zero-based linear index into PAD.
+        mlir::Value padLinearIndex =
+            builder.create<mlir::arith::SubIOp>(loc, linearIndex, arraySize);
+        llvm::SmallVector<mlir::Value, Fortran::common::maxRank> padIndices =
+            delinearizeIndex(loc, builder, padExtents, padLinearIndex,
+                             /*wrapAround=*/true);
+        mlir::Value padElement =
+            hlfir::loadElementAt(loc, builder, hlfir::Entity{pad}, padIndices);
+        builder.create<fir::ResultOp>(loc, padElement);
+
+        // In the 'then' block, return an element from the ARRAY.
+        builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
+      }
+
+      llvm::SmallVector<mlir::Value, Fortran::common::maxRank> arrayIndices =
+          delinearizeIndex(loc, builder, arrayExtents, linearIndex,
+                           /*wrapAround=*/false);
+      mlir::Value arrayElement =
+          hlfir::loadElementAt(loc, builder, array, arrayIndices);
+
+      if (ifOp) {
+        builder.create<fir::ResultOp>(loc, arrayElement);
+        builder.setInsertionPointAfter(ifOp);
+        arrayElement = ifOp.getResult(0);
+      }
+
+      return hlfir::Entity{arrayElement};
+    };
+    hlfir::ElementalOp elementalOp = hlfir::genElementalOp(
+        loc, builder, elementType, resultShape, typeParams, genKernel,
+        /*isUnordered=*/true,
+        /*polymorphicMold=*/result.isPolymorphic() ? array : mlir::Value{},
+        reshape.getResult().getType());
+    assert(elementalOp.getResult().getType() == reshape.getResult().getType());
+    rewriter.replaceOp(reshape, elementalOp);
+    return mlir::success();
+  }
+
+private:
+  /// Compute zero-based linear index given an array extents
+  /// and one-based indices:
+  ///   \p extents: [e0, e1, ..., en]
+  ///   \p indices: [i0, i1, ..., in]
+  ///
+  /// linear-index :=
+  ///   (...((in-1)*e(n-1)+(i(n-1)-1))*e(n-2)+...)*e0+(i0-1)
+  static mlir::Value computeLinearIndex(mlir::Location loc,
+                                        fir::FirOpBuilder &builder,
+                                        mlir::ValueRange extents,
+                                        mlir::ValueRange indices) {
+    std::size_t rank = extents.size();
+    assert(rank = indices.size());
----------------
davemgreen wrote:

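Nit: Should this be `==` rather than `=`? As written, the assert assigns `indices.size()` to `rank` instead of comparing the two, and GCC was giving a warning.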

https://github.com/llvm/llvm-project/pull/124683


More information about the flang-commits mailing list