[flang-commits] [flang] [flang] Inline hlfir.matmul[_transpose]. (PR #122821)

via flang-commits flang-commits at lists.llvm.org
Tue Jan 14 07:39:04 PST 2025


================
@@ -467,9 +474,431 @@ class CShiftAsElementalConversion
   }
 };
 
+template <typename Op>
+class MatmulConversion : public mlir::OpRewritePattern<Op> {
+public:
+  using mlir::OpRewritePattern<Op>::OpRewritePattern;
+
+  llvm::LogicalResult
+  matchAndRewrite(Op matmul, mlir::PatternRewriter &rewriter) const override {
+    mlir::Location loc = matmul.getLoc();
+    fir::FirOpBuilder builder{rewriter, matmul.getOperation()};
+    hlfir::Entity lhs = hlfir::Entity{matmul.getLhs()};
+    hlfir::Entity rhs = hlfir::Entity{matmul.getRhs()};
+    mlir::Value resultShape, innerProductExtent;
+    std::tie(resultShape, innerProductExtent) =
+        genResultShape(loc, builder, lhs, rhs);
+
+    if (forceMatmulAsElemental || isMatmulTranspose) {
+      // Generate hlfir.elemental that produces the result of
+      // MATMUL/MATMUL(TRANSPOSE).
+      // Note that this implementation is very suboptimal for MATMUL,
+      // but is quite good for MATMUL(TRANSPOSE), e.g.:
+      //   R(1:N) = R(1:N) + MATMUL(TRANSPOSE(X(1:N,1:N)), Y(1:N))
+      // Inlining MATMUL(TRANSPOSE) as hlfir.elemental may result
+      // in merging the inner product computation with the elemental
+      // addition. Note that the inner product computation will
+      // benefit from processing the lowermost dimensions of X and Y,
+      // which may be the best when they are contiguous.
+      //
+      // This is why we always inline MATMUL(TRANSPOSE) as an elemental.
+      // MATMUL is inlined below by default unless forceMatmulAsElemental.
+      hlfir::ExprType resultType =
+          mlir::cast<hlfir::ExprType>(matmul.getType());
+      hlfir::ElementalOp newOp = genElementalMatmul(
+          loc, builder, resultType, resultShape, lhs, rhs, innerProductExtent);
+      rewriter.replaceOp(matmul, newOp);
+      return mlir::success();
+    }
+
+    // Generate hlfir.eval_in_mem to mimic the MATMUL implementation
+    // from Fortran runtime. The implementation needs to operate
+    // with the result array as an in-memory object.
+    hlfir::EvaluateInMemoryOp evalOp =
+        builder.create<hlfir::EvaluateInMemoryOp>(
+            loc, mlir::cast<hlfir::ExprType>(matmul.getType()), resultShape);
+    builder.setInsertionPointToStart(&evalOp.getBody().front());
+
+    // Embox the raw array pointer to simplify designating it.
+    // TODO: this currently results in redundant lower bounds
+    // addition for the designator, but this should be fixed in
+    // hlfir::Entity::mayHaveNonDefaultLowerBounds().
+    mlir::Value resultArray = evalOp.getMemory();
+    mlir::Type arrayType = fir::dyn_cast_ptrEleTy(resultArray.getType());
+    resultArray = builder.createBox(loc, fir::BoxType::get(arrayType),
+                                    resultArray, resultShape, /*slice=*/nullptr,
+                                    /*lengths=*/{}, /*tdesc=*/nullptr);
+
+    // The contiguous MATMUL version is best for the cases
+    // where the input arrays and (maybe) the result are contiguous
+    // in their lowermost dimensions.
+    // Especially, when LLVM can recognize the continuity
+    // and vectorize the loops properly.
+    // TODO: we need to recognize the cases when the continuity
+    // is not statically obvious and try to generate an explicitly
+    // continuous version under a dynamic check. The fallback
+    // implementation may use genElementalMatmul() with
+    // an hlfir.assign into the result of eval_in_mem.
+    mlir::LogicalResult rewriteResult =
+        genContiguousMatmul(loc, builder, hlfir::Entity{resultArray},
+                            resultShape, lhs, rhs, innerProductExtent);
+
+    if (mlir::failed(rewriteResult)) {
+      // Erase the unclaimed eval_in_mem op.
+      rewriter.eraseOp(evalOp);
+      return rewriter.notifyMatchFailure(matmul,
+                                         "genContiguousMatmul() failed");
+    }
+
+    rewriter.replaceOp(matmul, evalOp);
+    return mlir::success();
+  }
+
+private:
+  static constexpr bool isMatmulTranspose =
+      std::is_same_v<Op, hlfir::MatmulTransposeOp>;
+
+  // Return a tuple of:
+  //   * A fir.shape operation representing the shape of the result
+  //     of a MATMUL/MATMUL(TRANSPOSE).
+  //   * An extent of the dimensions of the input array
+  //     that are processed during the inner product computation.
+  static std::tuple<mlir::Value, mlir::Value>
+  genResultShape(mlir::Location loc, fir::FirOpBuilder &builder,
+                 hlfir::Entity input1, hlfir::Entity input2) {
+    mlir::Value input1Shape = hlfir::genShape(loc, builder, input1);
+    llvm::SmallVector<mlir::Value> input1Extents =
+        hlfir::getExplicitExtentsFromShape(input1Shape, builder);
+    if (input1Shape.getUses().empty())
+      input1Shape.getDefiningOp()->erase();
+    mlir::Value input2Shape = hlfir::genShape(loc, builder, input2);
+    llvm::SmallVector<mlir::Value> input2Extents =
+        hlfir::getExplicitExtentsFromShape(input2Shape, builder);
+    if (input2Shape.getUses().empty())
+      input2Shape.getDefiningOp()->erase();
+
+    llvm::SmallVector<mlir::Value, 2> newExtents;
+    mlir::Value innerProduct1Extent, innerProduct2Extent;
+    if (input1Extents.size() == 1) {
+      assert(!isMatmulTranspose &&
+             "hlfir.matmul_transpose's first operand must be rank-2 array");
+      assert(input2Extents.size() == 2 &&
+             "hlfir.matmul second argument must be rank-2 array");
+      newExtents.push_back(input2Extents[1]);
+      innerProduct1Extent = input1Extents[0];
+      innerProduct2Extent = input2Extents[0];
+    } else {
+      if (input2Extents.size() == 1) {
+        assert(input1Extents.size() == 2 &&
+               "hlfir.matmul first argument must be rank-2 array");
+        if constexpr (isMatmulTranspose)
+          newExtents.push_back(input1Extents[1]);
+        else
+          newExtents.push_back(input1Extents[0]);
+      } else {
+        assert(input1Extents.size() == 2 && input2Extents.size() == 2 &&
+               "hlfir.matmul arguments must be rank-2 arrays");
+        if constexpr (isMatmulTranspose)
+          newExtents.push_back(input1Extents[1]);
+        else
+          newExtents.push_back(input1Extents[0]);
+
+        newExtents.push_back(input2Extents[1]);
+      }
+      if constexpr (isMatmulTranspose)
+        innerProduct1Extent = input1Extents[0];
+      else
+        innerProduct1Extent = input1Extents[1];
+
+      innerProduct2Extent = input2Extents[0];
+    }
+    // The inner product dimensions of the input arrays
+    // must match. Pick the best (e.g. constant) out of them
+    // so that the inner product loop bound can be used in
+    // optimizations.
+    llvm::SmallVector<mlir::Value> innerProductExtent =
+        fir::factory::deduceOptimalExtents({innerProduct1Extent},
+                                           {innerProduct2Extent});
+    return {builder.create<fir::ShapeOp>(loc, newExtents),
+            innerProductExtent[0]};
+  }
+
+  static mlir::Value castToProductType(mlir::Location loc,
+                                       fir::FirOpBuilder &builder,
+                                       mlir::Value value, mlir::Type type) {
+    if (mlir::isa<fir::LogicalType>(type))
+      return builder.createConvert(loc, builder.getIntegerType(1), value);
+
+    if (fir::isa_complex(type) && !fir::isa_complex(value.getType())) {
+      mlir::Value zeroCmplx = fir::factory::createZeroValue(builder, loc, type);
----------------
jeanPerier wrote:

Isn't it beneficial to special-case `complex * real`? Maybe the multiplication by zero is optimized out later, though.
I am not asking you to do it; I'm just curious. If it would be better to special-case it, you can add a TODO note.

https://github.com/llvm/llvm-project/pull/122821


More information about the flang-commits mailing list