[flang-commits] [flang] [flang] Inline hlfir.matmul[_transpose]. (PR #122821)

Slava Zakharin via flang-commits flang-commits at lists.llvm.org
Tue Jan 14 10:13:27 PST 2025


https://github.com/vzakhari updated https://github.com/llvm/llvm-project/pull/122821

From 9e1c219361b41f878ad53218a79b4cf1554aa4cf Mon Sep 17 00:00:00 2001
From: Slava Zakharin <szakharin at nvidia.com>
Date: Mon, 13 Jan 2025 14:27:27 -0800
Subject: [PATCH 1/3] [flang] Inline hlfir.matmul[_transpose].

Inlining `hlfir.matmul` as `hlfir.eval_in_mem` does not eliminate
the temporary array in many cases, but it may still be much better
because it allows us to:
  * Get rid of any overhead related to calling the runtime MATMUL
    (such as descriptor creation).
  * Use a CPU-specific vectorization cost model for the matmul loops,
    which the Fortran runtime currently cannot do.
  * Optimize matmul of known-size arrays by complete unrolling.

One of the drawbacks of `hlfir.eval_in_mem` inlining is that
the ops inside it that have memory write effects block the current
MLIR CSE, so I decided to run this inlining late in the pipeline.
There is a source comment explaining the CSE issue in more detail.

Straightforward inlining of `hlfir.matmul` as an `hlfir.elemental`
is not good for performance: it caused performance regressions
compared to the Fortran runtime implementation. I put it under
an engineering option (`flang-inline-matmul-as-elemental`)
for experiments.
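For illustration only, the C++ sketch below contrasts the two loop orders for a matrix-matrix product. It is not flang code: the helper names `matmulElemental`, `matmulContiguous` and `idx` are hypothetical, and column-major storage is assumed. The elemental form computes each RESULT(I,J) as an inner product, so its innermost loop strides through LHS by NROWS elements. The contiguous form zero-initializes RESULT and then runs the K/J/I loop nest that genContiguousMatmul() emits, so its innermost loop walks RESULT and LHS with unit stride, which is friendlier to vectorization.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Column-major index of element (i, j) in an array with `rows` rows.
inline std::size_t idx(std::size_t i, std::size_t j, std::size_t rows) {
  return i + j * rows;
}

// Elemental form: RESULT(I,J) = SUM over K of LHS(I,K) * RHS(K,J).
// The innermost loop over K reads LHS with a stride of `nrows` elements.
std::vector<double> matmulElemental(const std::vector<double> &lhs,
                                    const std::vector<double> &rhs,
                                    std::size_t nrows, std::size_t n,
                                    std::size_t ncols) {
  assert(lhs.size() == nrows * n && rhs.size() == n * ncols);
  std::vector<double> result(nrows * ncols);
  for (std::size_t j = 0; j < ncols; ++j)
    for (std::size_t i = 0; i < nrows; ++i) {
      double acc = 0.0;
      for (std::size_t k = 0; k < n; ++k)
        acc += lhs[idx(i, k, nrows)] * rhs[idx(k, j, n)];
      result[idx(i, j, nrows)] = acc;
    }
  return result;
}

// Contiguous form: zero-initialize RESULT, then
//   DO K / DO J / DO I: RESULT(I,J) = RESULT(I,J) + LHS(I,K) * RHS(K,J)
// The innermost loop over I has unit stride in both RESULT and LHS.
std::vector<double> matmulContiguous(const std::vector<double> &lhs,
                                     const std::vector<double> &rhs,
                                     std::size_t nrows, std::size_t n,
                                     std::size_t ncols) {
  assert(lhs.size() == nrows * n && rhs.size() == n * ncols);
  std::vector<double> result(nrows * ncols, 0.0);
  for (std::size_t k = 0; k < n; ++k)
    for (std::size_t j = 0; j < ncols; ++j)
      for (std::size_t i = 0; i < nrows; ++i)
        result[idx(i, j, nrows)] += lhs[idx(i, k, nrows)] * rhs[idx(k, j, n)];
  return result;
}
```

Both functions add the products for a given element in the same K order; only the memory access pattern differs.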

At the same time, inlining `hlfir.matmul_transpose` as `hlfir.elemental`
seems to be a good approach, e.g. it allows getting rid of a temporary
array in cases like: `A(:)=B(:)+MATMUL(TRANSPOSE(C(:,:)),D(:))`.
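A minimal C++ sketch of why the elemental form helps in this case. This is not the code flang generates: the function names are hypothetical, storage is assumed column-major, and C is N-by-M. Without fusion, the MATMUL result is materialized in a temporary before the addition. Once MATMUL(TRANSPOSE(C),D) is an elemental, each A(I) can be computed as B(I) plus an inner product over column I of C. That removes the temporary, and the inner loop reads C with unit stride.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Unfused: T = MATMUL(TRANSPOSE(C), D) is stored in a temporary array,
// then A = B + T. C is an n-by-m column-major array, D has n elements,
// B and A have m elements.
std::vector<double> matmulTransposeAddWithTemporary(
    const std::vector<double> &b, const std::vector<double> &c,
    const std::vector<double> &d, std::size_t n, std::size_t m) {
  assert(c.size() == n * m && d.size() == n && b.size() == m);
  std::vector<double> tmp(m, 0.0);
  for (std::size_t i = 0; i < m; ++i)
    for (std::size_t k = 0; k < n; ++k)
      tmp[i] += c[k + i * n] * d[k];
  std::vector<double> a(m);
  for (std::size_t i = 0; i < m; ++i)
    a[i] = b[i] + tmp[i];
  return a;
}

// Fused: each A(I) is B(I) plus an inner product computed on the fly,
// so no temporary array is needed. The inner loop walks down column I
// of C (row I of TRANSPOSE(C)) with unit stride.
std::vector<double> matmulTransposeAddFused(const std::vector<double> &b,
                                            const std::vector<double> &c,
                                            const std::vector<double> &d,
                                            std::size_t n, std::size_t m) {
  assert(c.size() == n * m && d.size() == n && b.size() == m);
  std::vector<double> a(m);
  for (std::size_t i = 0; i < m; ++i) {
    double acc = 0.0;
    for (std::size_t k = 0; k < n; ++k)
      acc += c[k + i * n] * d[k];
    a[i] = b[i] + acc;
  }
  return a;
}
```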

This patch slightly improves the performance of galgel and tonto.
---
 .../flang/Optimizer/Builder/FIRBuilder.h      |   9 +
 .../flang/Optimizer/Builder/HLFIRTools.h      |   5 +
 flang/include/flang/Optimizer/HLFIR/Passes.td |  11 +
 flang/lib/Optimizer/Builder/FIRBuilder.cpp    |  14 +
 flang/lib/Optimizer/Builder/HLFIRTools.cpp    |  17 +-
 .../Transforms/SimplifyHLFIRIntrinsics.cpp    | 452 ++++++++++++
 flang/lib/Optimizer/Passes/Pipelines.cpp      |   6 +
 flang/test/Driver/mlir-pass-pipeline.f90      |   4 +
 flang/test/Fir/basic-program.fir              |   4 +
 .../simplify-hlfir-intrinsics-matmul.fir      | 660 ++++++++++++++++++
 10 files changed, 1179 insertions(+), 3 deletions(-)
 create mode 100644 flang/test/HLFIR/simplify-hlfir-intrinsics-matmul.fir

diff --git a/flang/include/flang/Optimizer/Builder/FIRBuilder.h b/flang/include/flang/Optimizer/Builder/FIRBuilder.h
index c5d86e713f253a..ea658fb16a36c3 100644
--- a/flang/include/flang/Optimizer/Builder/FIRBuilder.h
+++ b/flang/include/flang/Optimizer/Builder/FIRBuilder.h
@@ -804,6 +804,15 @@ elideLengthsAlreadyInType(mlir::Type type, mlir::ValueRange lenParams);
 /// Get the address space which should be used for allocas
 uint64_t getAllocaAddressSpace(mlir::DataLayout *dataLayout);
 
+/// The two vectors of MLIR values have the following property:
+///   \p extents1[i] must have the same value as \p extents2[i]
+/// The function returns a new vector of MLIR values that preserves
+/// the same property vs \p extents1 and \p extents2, but allows
+/// more optimizations. For example, if extents1[j] is a known constant,
+/// and extents2[j] is not, then result[j] is the MLIR value extents1[j].
+llvm::SmallVector<mlir::Value> deduceOptimalExtents(mlir::ValueRange extents1,
+                                                    mlir::ValueRange extents2);
+
 } // namespace fir::factory
 
 #endif // FORTRAN_OPTIMIZER_BUILDER_FIRBUILDER_H
diff --git a/flang/include/flang/Optimizer/Builder/HLFIRTools.h b/flang/include/flang/Optimizer/Builder/HLFIRTools.h
index c8aad644bc784a..6e85b8f4ddf86e 100644
--- a/flang/include/flang/Optimizer/Builder/HLFIRTools.h
+++ b/flang/include/flang/Optimizer/Builder/HLFIRTools.h
@@ -508,6 +508,11 @@ genTypeAndKindConvert(mlir::Location loc, fir::FirOpBuilder &builder,
                       hlfir::Entity source, mlir::Type toType,
                       bool preserveLowerBounds);
 
+/// A shortcut for loadTrivialScalar(getElementAt()),
+/// which designates and loads an element of an array.
+Entity loadElementAt(mlir::Location loc, fir::FirOpBuilder &builder,
+                     Entity entity, mlir::ValueRange oneBasedIndices);
+
 } // namespace hlfir
 
 #endif // FORTRAN_OPTIMIZER_BUILDER_HLFIRTOOLS_H
diff --git a/flang/include/flang/Optimizer/HLFIR/Passes.td b/flang/include/flang/Optimizer/HLFIR/Passes.td
index 644f1e3c3af2b8..90cf6e74241bd0 100644
--- a/flang/include/flang/Optimizer/HLFIR/Passes.td
+++ b/flang/include/flang/Optimizer/HLFIR/Passes.td
@@ -43,6 +43,17 @@ def LowerHLFIROrderedAssignments : Pass<"lower-hlfir-ordered-assignments", "::ml
 
 def SimplifyHLFIRIntrinsics : Pass<"simplify-hlfir-intrinsics"> {
   let summary = "Simplify HLFIR intrinsic operations that don't need to result in runtime calls";
+  let options = [Option<"allowNewSideEffects", "allow-new-side-effects", "bool",
+                        /*default=*/"false",
+                        "If enabled, then the HLFIR operations simplification "
+                        "may introduce operations with side effects. "
+                        "For example, hlfir.matmul may be inlined as "
+                        "an hlfir.eval_in_mem with hlfir.assign inside it. "
+                        "The hlfir.assign has a write effect on the memory "
+                        "argument of hlfir.eval_in_mem, which may block "
+                        "some existing MLIR transformations (e.g. CSE) "
+                        "that otherwise would have been possible across "
+                        "the hlfir.matmul.">];
 }
 
 def InlineElementals : Pass<"inline-elementals"> {
diff --git a/flang/lib/Optimizer/Builder/FIRBuilder.cpp b/flang/lib/Optimizer/Builder/FIRBuilder.cpp
index d01becfe800937..218f98ef9ef429 100644
--- a/flang/lib/Optimizer/Builder/FIRBuilder.cpp
+++ b/flang/lib/Optimizer/Builder/FIRBuilder.cpp
@@ -1740,3 +1740,17 @@ uint64_t fir::factory::getAllocaAddressSpace(mlir::DataLayout *dataLayout) {
       return mlir::cast<mlir::IntegerAttr>(addrSpace).getUInt();
   return 0;
 }
+
+llvm::SmallVector<mlir::Value>
+fir::factory::deduceOptimalExtents(mlir::ValueRange extents1,
+                                   mlir::ValueRange extents2) {
+  llvm::SmallVector<mlir::Value> extents;
+  extents.reserve(extents1.size());
+  for (auto [extent1, extent2] : llvm::zip(extents1, extents2)) {
+    if (!fir::getIntIfConstant(extent1) && fir::getIntIfConstant(extent2))
+      extents.push_back(extent2);
+    else
+      extents.push_back(extent1);
+  }
+  return extents;
+}
diff --git a/flang/lib/Optimizer/Builder/HLFIRTools.cpp b/flang/lib/Optimizer/Builder/HLFIRTools.cpp
index 94238bc24e453d..5e5d0bbd681326 100644
--- a/flang/lib/Optimizer/Builder/HLFIRTools.cpp
+++ b/flang/lib/Optimizer/Builder/HLFIRTools.cpp
@@ -939,8 +939,10 @@ llvm::SmallVector<mlir::Value> hlfir::genLoopNestWithReductions(
       doLoop = builder.create<fir::DoLoopOp>(loc, one, ub, one, isUnordered,
                                              /*finalCountValue=*/false,
                                              parentLoop.getRegionIterArgs());
-      // Return the results of the child loop from its parent loop.
-      builder.create<fir::ResultOp>(loc, doLoop.getResults());
+      if (!reductionInits.empty()) {
+        // Return the results of the child loop from its parent loop.
+        builder.create<fir::ResultOp>(loc, doLoop.getResults());
+      }
     }
 
     builder.setInsertionPointToStart(doLoop.getBody());
@@ -955,7 +957,8 @@ llvm::SmallVector<mlir::Value> hlfir::genLoopNestWithReductions(
   reductionValues =
       genBody(loc, builder, oneBasedIndices, parentLoop.getRegionIterArgs());
   builder.setInsertionPointToEnd(parentLoop.getBody());
-  builder.create<fir::ResultOp>(loc, reductionValues);
+  if (!reductionValues.empty())
+    builder.create<fir::ResultOp>(loc, reductionValues);
   builder.setInsertionPointAfter(outerLoop);
   return outerLoop->getResults();
 }
@@ -1410,3 +1413,11 @@ void hlfir::computeEvaluateOpIn(mlir::Location loc, fir::FirOpBuilder &builder,
     builder.clone(op, mapper);
   return;
 }
+
+hlfir::Entity hlfir::loadElementAt(mlir::Location loc,
+                                   fir::FirOpBuilder &builder,
+                                   hlfir::Entity entity,
+                                   mlir::ValueRange oneBasedIndices) {
+  return loadTrivialScalar(loc, builder,
+                           getElementAt(loc, builder, entity, oneBasedIndices));
+}
diff --git a/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp b/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp
index 314ced8679521a..0fd535b4290799 100644
--- a/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp
+++ b/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp
@@ -28,6 +28,13 @@ namespace hlfir {
 #include "flang/Optimizer/HLFIR/Passes.h.inc"
 } // namespace hlfir
 
+#define DEBUG_TYPE "simplify-hlfir-intrinsics"
+
+static llvm::cl::opt<bool> forceMatmulAsElemental(
+    "flang-inline-matmul-as-elemental",
+    llvm::cl::desc("Expand hlfir.matmul as elemental operation"),
+    llvm::cl::init(false));
+
 namespace {
 
 class TransposeAsElementalConversion
@@ -467,9 +474,438 @@ class CShiftAsElementalConversion
   }
 };
 
+template <typename Op>
+class MatmulConversion : public mlir::OpRewritePattern<Op> {
+public:
+  using mlir::OpRewritePattern<Op>::OpRewritePattern;
+
+  llvm::LogicalResult
+  matchAndRewrite(Op matmul, mlir::PatternRewriter &rewriter) const override {
+    mlir::Location loc = matmul.getLoc();
+    fir::FirOpBuilder builder{rewriter, matmul.getOperation()};
+    hlfir::Entity lhs = hlfir::Entity{matmul.getLhs()};
+    hlfir::Entity rhs = hlfir::Entity{matmul.getRhs()};
+    mlir::Value resultShape, innerProductExtent;
+    std::tie(resultShape, innerProductExtent) =
+        genResultShape(loc, builder, lhs, rhs);
+
+    if (forceMatmulAsElemental || isMatmulTranspose) {
+      // Generate hlfir.elemental that produces the result of
+      // MATMUL/MATMUL(TRANSPOSE).
+      // Note that this implementation is very suboptimal for MATMUL,
+      // but is quite good for MATMUL(TRANSPOSE), e.g.:
+      //   R(1:N) = R(1:N) + MATMUL(TRANSPOSE(X(1:N,1:N)), Y(1:N))
+      // Inlining MATMUL(TRANSPOSE) as hlfir.elemental may result
+      // in merging the inner product computation with the elemental
+      // addition. Note that the inner product computation will
+      // benefit from processing the lowermost dimensions of X and Y,
+      // which may be the best when they are contiguous.
+      //
+      // This is why we always inline MATMUL(TRANSPOSE) as an elemental.
+      // MATMUL is inlined below by default unless forceMatmulAsElemental.
+      hlfir::ExprType resultType =
+          mlir::cast<hlfir::ExprType>(matmul.getType());
+      hlfir::ElementalOp newOp = genElementalMatmul(
+          loc, builder, resultType, resultShape, lhs, rhs, innerProductExtent);
+      rewriter.replaceOp(matmul, newOp);
+      return mlir::success();
+    }
+
+    // Generate hlfir.eval_in_mem to mimic the MATMUL implementation
+    // from Fortran runtime. The implementation needs to operate
+    // with the result array as an in-memory object.
+    hlfir::EvaluateInMemoryOp evalOp =
+        builder.create<hlfir::EvaluateInMemoryOp>(
+            loc, mlir::cast<hlfir::ExprType>(matmul.getType()), resultShape);
+    builder.setInsertionPointToStart(&evalOp.getBody().front());
+
+    // Embox the raw array pointer to simplify designating it.
+    // TODO: this currently results in redundant lower bounds
+    // addition for the designator, but this should be fixed in
+    // hlfir::Entity::mayHaveNonDefaultLowerBounds().
+    mlir::Value resultArray = evalOp.getMemory();
+    mlir::Type arrayType = fir::dyn_cast_ptrEleTy(resultArray.getType());
+    resultArray = builder.createBox(loc, fir::BoxType::get(arrayType),
+                                    resultArray, resultShape, /*slice=*/nullptr,
+                                    /*lengths=*/{}, /*tdesc=*/nullptr);
+
+    // The contiguous MATMUL version is best for the cases
+    // where the input arrays and (maybe) the result are contiguous
+    // in their lowermost dimensions.
+    // This is especially true when LLVM can recognize the contiguity
+    // and vectorize the loops properly.
+    // TODO: we need to recognize the cases when the contiguity
+    // is not statically obvious and try to generate an explicitly
+    // contiguous version under a dynamic check. The fallback
+    // implementation may use genElementalMatmul() with
+    // an hlfir.assign into the result of eval_in_mem.
+    mlir::LogicalResult rewriteResult =
+        genContiguousMatmul(loc, builder, hlfir::Entity{resultArray},
+                            resultShape, lhs, rhs, innerProductExtent);
+
+    if (mlir::failed(rewriteResult)) {
+      // Erase the unclaimed eval_in_mem op.
+      rewriter.eraseOp(evalOp);
+      return rewriter.notifyMatchFailure(matmul,
+                                         "genContiguousMatmul() failed");
+    }
+
+    rewriter.replaceOp(matmul, evalOp);
+    return mlir::success();
+  }
+
+private:
+  static constexpr bool isMatmulTranspose =
+      std::is_same_v<Op, hlfir::MatmulTransposeOp>;
+
+  // Return a tuple of:
+  //   * A fir.shape operation representing the shape of the result
+  //     of a MATMUL/MATMUL(TRANSPOSE).
+  //   * An extent of the dimensions of the input array
+  //     that are processed during the inner product computation.
+  static std::tuple<mlir::Value, mlir::Value>
+  genResultShape(mlir::Location loc, fir::FirOpBuilder &builder,
+                 hlfir::Entity input1, hlfir::Entity input2) {
+    mlir::Value input1Shape = hlfir::genShape(loc, builder, input1);
+    llvm::SmallVector<mlir::Value> input1Extents =
+        hlfir::getExplicitExtentsFromShape(input1Shape, builder);
+    if (input1Shape.getUses().empty())
+      input1Shape.getDefiningOp()->erase();
+    mlir::Value input2Shape = hlfir::genShape(loc, builder, input2);
+    llvm::SmallVector<mlir::Value> input2Extents =
+        hlfir::getExplicitExtentsFromShape(input2Shape, builder);
+    if (input2Shape.getUses().empty())
+      input2Shape.getDefiningOp()->erase();
+
+    llvm::SmallVector<mlir::Value, 2> newExtents;
+    mlir::Value innerProduct1Extent, innerProduct2Extent;
+    if (input1Extents.size() == 1) {
+      assert(!isMatmulTranspose &&
+             "hlfir.matmul_transpose's first operand must be rank-2 array");
+      assert(input2Extents.size() == 2 &&
+             "hlfir.matmul second argument must be rank-2 array");
+      newExtents.push_back(input2Extents[1]);
+      innerProduct1Extent = input1Extents[0];
+      innerProduct2Extent = input2Extents[0];
+    } else {
+      if (input2Extents.size() == 1) {
+        assert(input1Extents.size() == 2 &&
+               "hlfir.matmul first argument must be rank-2 array");
+        if constexpr (isMatmulTranspose)
+          newExtents.push_back(input1Extents[1]);
+        else
+          newExtents.push_back(input1Extents[0]);
+      } else {
+        assert(input1Extents.size() == 2 && input2Extents.size() == 2 &&
+               "hlfir.matmul arguments must be rank-2 arrays");
+        if constexpr (isMatmulTranspose)
+          newExtents.push_back(input1Extents[1]);
+        else
+          newExtents.push_back(input1Extents[0]);
+
+        newExtents.push_back(input2Extents[1]);
+      }
+      if constexpr (isMatmulTranspose)
+        innerProduct1Extent = input1Extents[0];
+      else
+        innerProduct1Extent = input1Extents[1];
+
+      innerProduct2Extent = input2Extents[0];
+    }
+    // The inner product dimensions of the input arrays
+    // must match. Pick the best (e.g. constant) out of them
+    // so that the inner product loop bound can be used in
+    // optimizations.
+    llvm::SmallVector<mlir::Value> innerProductExtent =
+        fir::factory::deduceOptimalExtents({innerProduct1Extent},
+                                           {innerProduct2Extent});
+    return {builder.create<fir::ShapeOp>(loc, newExtents),
+            innerProductExtent[0]};
+  }
+
+  static mlir::Value castToProductType(mlir::Location loc,
+                                       fir::FirOpBuilder &builder,
+                                       mlir::Value value, mlir::Type type) {
+    if (mlir::isa<fir::LogicalType>(type))
+      return builder.createConvert(loc, builder.getIntegerType(1), value);
+
+    if (fir::isa_complex(type) && !fir::isa_complex(value.getType())) {
+      mlir::Value zeroCmplx = fir::factory::createZeroValue(builder, loc, type);
+      fir::factory::Complex helper(builder, loc);
+      mlir::Type partType = helper.getComplexPartType(type);
+      return helper.insertComplexPart(
+          zeroCmplx, castToProductType(loc, builder, value, partType),
+          /*isImagPart=*/false);
+    }
+    return builder.createConvert(loc, type, value);
+  }
+
+  // Generate an update of the inner product value:
+  //   acc += v1 * v2, OR
+  //   acc ||= v1 && v2
+  static mlir::Value genAccumulateProduct(mlir::Location loc,
+                                          fir::FirOpBuilder &builder,
+                                          mlir::Type resultType,
+                                          mlir::Value acc, mlir::Value v1,
+                                          mlir::Value v2) {
+    acc = castToProductType(loc, builder, acc, resultType);
+    v1 = castToProductType(loc, builder, v1, resultType);
+    v2 = castToProductType(loc, builder, v2, resultType);
+    mlir::Value result;
+    if (mlir::isa<mlir::FloatType>(resultType))
+      result = builder.create<mlir::arith::AddFOp>(
+          loc, acc, builder.create<mlir::arith::MulFOp>(loc, v1, v2));
+    else if (mlir::isa<mlir::ComplexType>(resultType))
+      result = builder.create<fir::AddcOp>(
+          loc, acc, builder.create<fir::MulcOp>(loc, v1, v2));
+    else if (mlir::isa<mlir::IntegerType>(resultType))
+      result = builder.create<mlir::arith::AddIOp>(
+          loc, acc, builder.create<mlir::arith::MulIOp>(loc, v1, v2));
+    else if (mlir::isa<fir::LogicalType>(resultType))
+      result = builder.create<mlir::arith::OrIOp>(
+          loc, acc, builder.create<mlir::arith::AndIOp>(loc, v1, v2));
+    else
+      llvm_unreachable("unsupported type");
+
+    return builder.createConvert(loc, resultType, result);
+  }
+
+  static mlir::LogicalResult
+  genContiguousMatmul(mlir::Location loc, fir::FirOpBuilder &builder,
+                      hlfir::Entity result, mlir::Value resultShape,
+                      hlfir::Entity lhs, hlfir::Entity rhs,
+                      mlir::Value innerProductExtent) {
+    // This code does not support MATMUL(TRANSPOSE), and it is supposed
+    // to be inlined as hlfir.elemental.
+    if constexpr (isMatmulTranspose)
+      return mlir::failure();
+
+    mlir::OpBuilder::InsertionGuard guard(builder);
+    mlir::Type resultElementType = result.getFortranElementType();
+    llvm::SmallVector<mlir::Value, 2> resultExtents =
+        mlir::cast<fir::ShapeOp>(resultShape.getDefiningOp()).getExtents();
+
+    // The inner product loop may be unordered if FastMathFlags::reassoc
+    // transformations are allowed. The integer/logical inner product is
+    // always unordered.
+    // Note that isUnordered is currently applied to all loops
+    // in the loop nests generated below, while it only needs
+    // to be applied to the inner product loop.
+    bool isUnordered = mlir::isa<mlir::IntegerType>(resultElementType) ||
+                       mlir::isa<fir::LogicalType>(resultElementType) ||
+                       static_cast<bool>(builder.getFastMathFlags() &
+                                         mlir::arith::FastMathFlags::reassoc);
+
+    // Insert the initialization loop nest that fills the whole result with
+    // zeroes.
+    mlir::Value initValue =
+        fir::factory::createZeroValue(builder, loc, resultElementType);
+    auto genInitBody = [&](mlir::Location loc, fir::FirOpBuilder &builder,
+                           mlir::ValueRange oneBasedIndices,
+                           mlir::ValueRange reductionArgs)
+        -> llvm::SmallVector<mlir::Value, 0> {
+      hlfir::Entity resultElement =
+          hlfir::getElementAt(loc, builder, result, oneBasedIndices);
+      //      builder.create<fir::StoreOp>(loc, initValue, resultElement);
+      builder.create<hlfir::AssignOp>(loc, initValue, resultElement);
+      return {};
+    };
+
+    hlfir::genLoopNestWithReductions(loc, builder, resultExtents,
+                                     /*reductionInits=*/{}, genInitBody,
+                                     /*isUnordered=*/true);
+
+    if (lhs.getRank() == 2 && rhs.getRank() == 2) {
+      //   LHS(NROWS,N) * RHS(N,NCOLS) -> RESULT(NROWS,NCOLS)
+      //
+      // Insert the computation loop nest:
+      //   DO 2 K = 1, N
+      //    DO 2 J = 1, NCOLS
+      //     DO 2 I = 1, NROWS
+      //   2  RESULT(I,J) = RESULT(I,J) + LHS(I,K)*RHS(K,J)
+      auto genMatrixMatrix = [&](mlir::Location loc, fir::FirOpBuilder &builder,
+                                 mlir::ValueRange oneBasedIndices,
+                                 mlir::ValueRange reductionArgs)
+          -> llvm::SmallVector<mlir::Value, 0> {
+        mlir::Value I = oneBasedIndices[0];
+        mlir::Value J = oneBasedIndices[1];
+        mlir::Value K = oneBasedIndices[2];
+        hlfir::Entity resultElement =
+            hlfir::getElementAt(loc, builder, result, {I, J});
+        hlfir::Entity resultElementValue =
+            hlfir::loadTrivialScalar(loc, builder, resultElement);
+        hlfir::Entity lhsElementValue =
+            hlfir::loadElementAt(loc, builder, lhs, {I, K});
+        hlfir::Entity rhsElementValue =
+            hlfir::loadElementAt(loc, builder, rhs, {K, J});
+        mlir::Value productValue = genAccumulateProduct(
+            loc, builder, resultElementType, resultElementValue,
+            lhsElementValue, rhsElementValue);
+        builder.create<hlfir::AssignOp>(loc, productValue, resultElement);
+        //        builder.create<fir::StoreOp>(loc, productValue,
+        //        resultElement);
+        return {};
+      };
+
+      // Note that the loops are inserted in reverse order,
+      // so innerProductExtent should be passed as the last extent.
+      hlfir::genLoopNestWithReductions(
+          loc, builder,
+          {resultExtents[0], resultExtents[1], innerProductExtent},
+          /*reductionInits=*/{}, genMatrixMatrix, isUnordered);
+      return mlir::success();
+    }
+
+    if (lhs.getRank() == 2 && rhs.getRank() == 1) {
+      //   LHS(NROWS,N) * RHS(N) -> RESULT(NROWS)
+      //
+      // Insert the computation loop nest:
+      //   DO 2 K = 1, N
+      //    DO 2 J = 1, NROWS
+      //   2 RES(J) = RES(J) + LHS(J,K)*RHS(K)
+      auto genMatrixVector = [&](mlir::Location loc, fir::FirOpBuilder &builder,
+                                 mlir::ValueRange oneBasedIndices,
+                                 mlir::ValueRange reductionArgs)
+          -> llvm::SmallVector<mlir::Value, 0> {
+        mlir::Value J = oneBasedIndices[0];
+        mlir::Value K = oneBasedIndices[1];
+        hlfir::Entity resultElement =
+            hlfir::getElementAt(loc, builder, result, {J});
+        hlfir::Entity resultElementValue =
+            hlfir::loadTrivialScalar(loc, builder, resultElement);
+        hlfir::Entity lhsElementValue =
+            hlfir::loadElementAt(loc, builder, lhs, {J, K});
+        hlfir::Entity rhsElementValue =
+            hlfir::loadElementAt(loc, builder, rhs, {K});
+        mlir::Value productValue = genAccumulateProduct(
+            loc, builder, resultElementType, resultElementValue,
+            lhsElementValue, rhsElementValue);
+        builder.create<hlfir::AssignOp>(loc, productValue, resultElement);
+        //        builder.create<fir::StoreOp>(loc, productValue,
+        //        resultElement);
+        return {};
+      };
+      hlfir::genLoopNestWithReductions(
+          loc, builder, {resultExtents[0], innerProductExtent},
+          /*reductionInits=*/{}, genMatrixVector, isUnordered);
+      return mlir::success();
+    }
+    if (lhs.getRank() == 1 && rhs.getRank() == 2) {
+      //   LHS(N) * RHS(N,NCOLS) -> RESULT(NCOLS)
+      //
+      // Insert the computation loop nest:
+      //   DO 2 K = 1, N
+      //    DO 2 J = 1, NCOLS
+      //   2 RES(J) = RES(J) + LHS(K)*RHS(K,J)
+      auto genVectorMatrix = [&](mlir::Location loc, fir::FirOpBuilder &builder,
+                                 mlir::ValueRange oneBasedIndices,
+                                 mlir::ValueRange reductionArgs)
+          -> llvm::SmallVector<mlir::Value, 0> {
+        mlir::Value J = oneBasedIndices[0];
+        mlir::Value K = oneBasedIndices[1];
+        hlfir::Entity resultElement =
+            hlfir::getElementAt(loc, builder, result, {J});
+        hlfir::Entity resultElementValue =
+            hlfir::loadTrivialScalar(loc, builder, resultElement);
+        hlfir::Entity lhsElementValue =
+            hlfir::loadElementAt(loc, builder, lhs, {K});
+        hlfir::Entity rhsElementValue =
+            hlfir::loadElementAt(loc, builder, rhs, {K, J});
+        mlir::Value productValue = genAccumulateProduct(
+            loc, builder, resultElementType, resultElementValue,
+            lhsElementValue, rhsElementValue);
+        builder.create<hlfir::AssignOp>(loc, productValue, resultElement);
+        //        builder.create<fir::StoreOp>(loc, productValue,
+        //        resultElement);
+        return {};
+      };
+      hlfir::genLoopNestWithReductions(
+          loc, builder, {resultExtents[0], innerProductExtent},
+          /*reductionInits=*/{}, genVectorMatrix, isUnordered);
+      return mlir::success();
+    }
+
+    llvm_unreachable("unsupported MATMUL arguments' ranks");
+  }
+
+  static hlfir::ElementalOp
+  genElementalMatmul(mlir::Location loc, fir::FirOpBuilder &builder,
+                     hlfir::ExprType resultType, mlir::Value resultShape,
+                     hlfir::Entity lhs, hlfir::Entity rhs,
+                     mlir::Value innerProductExtent) {
+    mlir::OpBuilder::InsertionGuard guard(builder);
+    mlir::Type resultElementType = resultType.getElementType();
+    auto genKernel = [&](mlir::Location loc, fir::FirOpBuilder &builder,
+                         mlir::ValueRange resultIndices) -> hlfir::Entity {
+      mlir::Value initValue =
+          fir::factory::createZeroValue(builder, loc, resultElementType);
+      // The inner product loop may be unordered if FastMathFlags::reassoc
+      // transformations are allowed. The integer/logical inner product is
+      // always unordered.
+      bool isUnordered = mlir::isa<mlir::IntegerType>(resultElementType) ||
+                         mlir::isa<fir::LogicalType>(resultElementType) ||
+                         static_cast<bool>(builder.getFastMathFlags() &
+                                           mlir::arith::FastMathFlags::reassoc);
+
+      auto genBody = [&](mlir::Location loc, fir::FirOpBuilder &builder,
+                         mlir::ValueRange oneBasedIndices,
+                         mlir::ValueRange reductionArgs)
+          -> llvm::SmallVector<mlir::Value, 1> {
+        llvm::SmallVector<mlir::Value, 2> lhsIndices;
+        llvm::SmallVector<mlir::Value, 2> rhsIndices;
+        // MATMUL:
+        //   LHS(NROWS,N) * RHS(N,NCOLS) -> RESULT(NROWS,NCOLS)
+        //   LHS(NROWS,N) * RHS(N) -> RESULT(NROWS)
+        //   LHS(N) * RHS(N,NCOLS) -> RESULT(NCOLS)
+        //
+        // MATMUL(TRANSPOSE):
+        //   TRANSPOSE(LHS(N,NROWS)) * RHS(N,NCOLS) -> RESULT(NROWS,NCOLS)
+        //   TRANSPOSE(LHS(N,NROWS)) * RHS(N) -> RESULT(NROWS)
+        //
+        // The resultIndices iterate over (NROWS[,NCOLS]).
+        // The oneBasedIndices iterate over (N).
+        if (lhs.getRank() > 1)
+          lhsIndices.push_back(resultIndices[0]);
+        lhsIndices.push_back(oneBasedIndices[0]);
+
+        if constexpr (isMatmulTranspose) {
+          // Swap the LHS indices for TRANSPOSE.
+          std::swap(lhsIndices[0], lhsIndices[1]);
+        }
+
+        rhsIndices.push_back(oneBasedIndices[0]);
+        if (rhs.getRank() > 1)
+          rhsIndices.push_back(resultIndices.back());
+
+        hlfir::Entity lhsElementValue =
+            hlfir::loadElementAt(loc, builder, lhs, lhsIndices);
+        hlfir::Entity rhsElementValue =
+            hlfir::loadElementAt(loc, builder, rhs, rhsIndices);
+        mlir::Value productValue = genAccumulateProduct(
+            loc, builder, resultElementType, reductionArgs[0], lhsElementValue,
+            rhsElementValue);
+        return {productValue};
+      };
+      llvm::SmallVector<mlir::Value, 1> innerProductValue =
+          hlfir::genLoopNestWithReductions(loc, builder, {innerProductExtent},
+                                           {initValue}, genBody, isUnordered);
+      return hlfir::Entity{innerProductValue[0]};
+    };
+    hlfir::ElementalOp elementalOp = hlfir::genElementalOp(
+        loc, builder, resultElementType, resultShape, /*typeParams=*/{},
+        genKernel,
+        /*isUnordered=*/true, /*polymorphicMold=*/nullptr, resultType);
+
+    return elementalOp;
+  }
+};
+
 class SimplifyHLFIRIntrinsics
     : public hlfir::impl::SimplifyHLFIRIntrinsicsBase<SimplifyHLFIRIntrinsics> {
 public:
+  using SimplifyHLFIRIntrinsicsBase<
+      SimplifyHLFIRIntrinsics>::SimplifyHLFIRIntrinsicsBase;
+
   void runOnOperation() override {
     mlir::MLIRContext *context = &getContext();
 
@@ -482,6 +918,22 @@ class SimplifyHLFIRIntrinsics
     patterns.insert<TransposeAsElementalConversion>(context);
     patterns.insert<SumAsElementalConversion>(context);
     patterns.insert<CShiftAsElementalConversion>(context);
+    patterns.insert<MatmulConversion<hlfir::MatmulTransposeOp>>(context);
+
+    // If forceMatmulAsElemental is false, then hlfir.matmul inlining
+    // will introduce an hlfir.eval_in_mem operation with new memory side
+    // effects. This conflicts with CSE and optimized bufferization, e.g.:
+    //   A(1:N,1:N) = A(1:N,1:N) - MATMUL(...)
+    // If we introduce hlfir.eval_in_mem before CSE, then the current
+    // MLIR CSE won't be able to optimize the trivial loads of the 'N'
+    // value that happen before and after hlfir.matmul.
+    // If the 'N' loads are not optimized, then optimized bufferization
+    // won't be able to prove that the slices of A are identical
+    // on both sides of the assignment.
+    // This is actually a CSE problem, but we can work around it
+    // for the time being.
+    if (forceMatmulAsElemental || this->allowNewSideEffects)
+      patterns.insert<MatmulConversion<hlfir::MatmulOp>>(context);
 
     if (mlir::failed(mlir::applyPatternsGreedily(
             getOperation(), std::move(patterns), config))) {
diff --git a/flang/lib/Optimizer/Passes/Pipelines.cpp b/flang/lib/Optimizer/Passes/Pipelines.cpp
index e1d7376ec3805d..1cc3f0b81c20ad 100644
--- a/flang/lib/Optimizer/Passes/Pipelines.cpp
+++ b/flang/lib/Optimizer/Passes/Pipelines.cpp
@@ -232,6 +232,12 @@ void createHLFIRToFIRPassPipeline(mlir::PassManager &pm, bool enableOpenMP,
   if (optLevel.isOptimizingForSpeed()) {
     addCanonicalizerPassWithoutRegionSimplification(pm);
     pm.addPass(mlir::createCSEPass());
+    // Run the SimplifyHLFIRIntrinsics pass late, after CSE,
+    // and allow it to introduce operations with new side effects.
+    addNestedPassToAllTopLevelOperations<PassConstructor>(pm, []() {
+      return hlfir::createSimplifyHLFIRIntrinsics(
+          {/*allowNewSideEffects=*/true});
+    });
     addNestedPassToAllTopLevelOperations<PassConstructor>(
         pm, hlfir::createOptimizedBufferization);
     addNestedPassToAllTopLevelOperations<PassConstructor>(
diff --git a/flang/test/Driver/mlir-pass-pipeline.f90 b/flang/test/Driver/mlir-pass-pipeline.f90
index 55e86da2dfdf14..dd46aecb3274c1 100644
--- a/flang/test/Driver/mlir-pass-pipeline.f90
+++ b/flang/test/Driver/mlir-pass-pipeline.f90
@@ -35,15 +35,19 @@
 ! O2-NEXT: (S) {{.*}} num-dce'd
 ! O2-NEXT: Pipeline Collection : ['fir.global', 'func.func', 'omp.declare_reduction', 'omp.private']
 ! O2-NEXT: 'fir.global' Pipeline
+! O2-NEXT:   SimplifyHLFIRIntrinsics
 ! O2-NEXT:   OptimizedBufferization
 ! O2-NEXT:   InlineHLFIRAssign
 ! O2-NEXT: 'func.func' Pipeline
+! O2-NEXT:   SimplifyHLFIRIntrinsics
 ! O2-NEXT:   OptimizedBufferization
 ! O2-NEXT:   InlineHLFIRAssign
 ! O2-NEXT: 'omp.declare_reduction' Pipeline
+! O2-NEXT:   SimplifyHLFIRIntrinsics
 ! O2-NEXT:   OptimizedBufferization
 ! O2-NEXT:   InlineHLFIRAssign
 ! O2-NEXT: 'omp.private' Pipeline
+! O2-NEXT:   SimplifyHLFIRIntrinsics
 ! O2-NEXT:   OptimizedBufferization
 ! O2-NEXT:   InlineHLFIRAssign
 ! ALL: LowerHLFIROrderedAssignments
diff --git a/flang/test/Fir/basic-program.fir b/flang/test/Fir/basic-program.fir
index 29a0f661579710..51e68d2157631a 100644
--- a/flang/test/Fir/basic-program.fir
+++ b/flang/test/Fir/basic-program.fir
@@ -36,15 +36,19 @@ func.func @_QQmain() {
 // PASSES-NEXT:    (S) 0 num-dce'd - Number of operations DCE'd
 // PASSES-NEXT: Pipeline Collection : ['fir.global', 'func.func', 'omp.declare_reduction', 'omp.private']
 // PASSES-NEXT: 'fir.global' Pipeline
+// PASSES-NEXT:    SimplifyHLFIRIntrinsics
 // PASSES-NEXT:    OptimizedBufferization
 // PASSES-NEXT:    InlineHLFIRAssign
 // PASSES-NEXT: 'func.func' Pipeline
+// PASSES-NEXT:    SimplifyHLFIRIntrinsics
 // PASSES-NEXT:    OptimizedBufferization
 // PASSES-NEXT:    InlineHLFIRAssign
 // PASSES-NEXT: 'omp.declare_reduction' Pipeline
+// PASSES-NEXT:    SimplifyHLFIRIntrinsics
 // PASSES-NEXT:    OptimizedBufferization
 // PASSES-NEXT:    InlineHLFIRAssign
 // PASSES-NEXT: 'omp.private' Pipeline
+// PASSES-NEXT:    SimplifyHLFIRIntrinsics
 // PASSES-NEXT:    OptimizedBufferization
 // PASSES-NEXT:    InlineHLFIRAssign
 // PASSES-NEXT:   LowerHLFIROrderedAssignments
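The checks in the new matmul test distinguish two loop structures for a matrix-matrix product. The C++ sketch below mirrors them under stated assumptions: dense column-major storage (as in Fortran) and `double` elements; the names `Matrix`, `matmulInMemory`, and `matmulElemental` are illustrative, not flang code. The allow-new-side-effects (hlfir.eval_in_mem) form zero-initializes the result buffer and then accumulates in memory with the K loop outermost and the row index innermost, while the elemental form computes each result element (i, j) as an independent reduction over K.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical dense column-major matrix, mirroring Fortran array layout.
struct Matrix {
  std::size_t rows, cols;
  std::vector<double> data;
  Matrix(std::size_t r, std::size_t c) : rows(r), cols(c), data(r * c, 0.0) {}
  double& at(std::size_t i, std::size_t j) { return data[i + j * rows]; }
  double at(std::size_t i, std::size_t j) const { return data[i + j * rows]; }
};

// Shape of the hlfir.eval_in_mem lowering: the zero-initialized buffer plays
// the role of the separate zeroing loop nest; partial sums are accumulated in
// memory with loop order K (outer), J, I (inner, contiguous in column-major).
Matrix matmulInMemory(const Matrix& a, const Matrix& b) {
  assert(a.cols == b.rows && "inner dimensions must agree");
  Matrix c(a.rows, b.cols);
  for (std::size_t k = 0; k < a.cols; ++k)
    for (std::size_t j = 0; j < b.cols; ++j)
      for (std::size_t i = 0; i < a.rows; ++i)
        c.at(i, j) += a.at(i, k) * b.at(k, j);
  return c;
}

// Shape of the hlfir.elemental lowering: every result element (i, j) is an
// independent dot-product reduction over K, kept in a scalar accumulator.
Matrix matmulElemental(const Matrix& a, const Matrix& b) {
  assert(a.cols == b.rows && "inner dimensions must agree");
  Matrix c(a.rows, b.cols);
  for (std::size_t j = 0; j < b.cols; ++j)
    for (std::size_t i = 0; i < a.rows; ++i) {
      double sum = 0.0;
      for (std::size_t k = 0; k < a.cols; ++k)
        sum += a.at(i, k) * b.at(k, j);
      c.at(i, j) = sum;
    }
  return c;
}
```

Both functions compute the same product; they differ only in loop order and in whether partial sums live in the result buffer or in a scalar accumulator.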
diff --git a/flang/test/HLFIR/simplify-hlfir-intrinsics-matmul.fir b/flang/test/HLFIR/simplify-hlfir-intrinsics-matmul.fir
new file mode 100644
index 00000000000000..d29e9a26c20ba9
--- /dev/null
+++ b/flang/test/HLFIR/simplify-hlfir-intrinsics-matmul.fir
@@ -0,0 +1,660 @@
+// Test hlfir.matmul simplification to hlfir.eval_in_mem and hlfir.elemental:
+// RUN: fir-opt --simplify-hlfir-intrinsics=allow-new-side-effects=false %s | FileCheck %s --check-prefixes=ALL,NOANSE
+// RUN: fir-opt --simplify-hlfir-intrinsics=allow-new-side-effects=true %s | FileCheck %s --check-prefixes=ALL,ANSE
+// RUN: fir-opt --simplify-hlfir-intrinsics -flang-inline-matmul-as-elemental %s | FileCheck %s --check-prefixes=ALL,ELEMENTAL
+
+func.func @matmul_matrix_matrix_integer(%arg0: !hlfir.expr<?x?xi16>, %arg1: !hlfir.expr<?x?xi32>) -> !hlfir.expr<?x?xi32> {
+  %res = hlfir.matmul %arg0 %arg1 : (!hlfir.expr<?x?xi16>, !hlfir.expr<?x?xi32>) -> !hlfir.expr<?x?xi32>
+  return %res : !hlfir.expr<?x?xi32>
+}
+// ALL-LABEL:   func.func @matmul_matrix_matrix_integer(
+// ALL-SAME:                                            %[[VAL_0:.*]]: !hlfir.expr<?x?xi16>,
+// ALL-SAME:                                            %[[VAL_1:.*]]: !hlfir.expr<?x?xi32>) -> !hlfir.expr<?x?xi32> {
+
+// NOANSE: hlfir.matmul
+
+// ANSE:           %[[VAL_2:.*]] = arith.constant 0 : index
+// ANSE:           %[[VAL_3:.*]] = arith.constant 1 : index
+// ANSE:           %[[VAL_4:.*]] = arith.constant 0 : i32
+// ANSE:           %[[VAL_5:.*]] = hlfir.shape_of %[[VAL_0]] : (!hlfir.expr<?x?xi16>) -> !fir.shape<2>
+// ANSE:           %[[VAL_6:.*]] = hlfir.get_extent %[[VAL_5]] {dim = 0 : index} : (!fir.shape<2>) -> index
+// ANSE:           %[[VAL_7:.*]] = hlfir.get_extent %[[VAL_5]] {dim = 1 : index} : (!fir.shape<2>) -> index
+// ANSE:           %[[VAL_8:.*]] = hlfir.shape_of %[[VAL_1]] : (!hlfir.expr<?x?xi32>) -> !fir.shape<2>
+// ANSE:           %[[VAL_9:.*]] = hlfir.get_extent %[[VAL_8]] {dim = 1 : index} : (!fir.shape<2>) -> index
+// ANSE:           %[[VAL_10:.*]] = fir.shape %[[VAL_6]], %[[VAL_9]] : (index, index) -> !fir.shape<2>
+// ANSE:           %[[VAL_11:.*]] = hlfir.eval_in_mem shape %[[VAL_10]] : (!fir.shape<2>) -> !hlfir.expr<?x?xi32> {
+// ANSE:           ^bb0(%[[VAL_12:.*]]: !fir.ref<!fir.array<?x?xi32>>):
+// ANSE:             %[[VAL_13:.*]] = fir.embox %[[VAL_12]](%[[VAL_10]]) : (!fir.ref<!fir.array<?x?xi32>>, !fir.shape<2>) -> !fir.box<!fir.array<?x?xi32>>
+// ANSE:             fir.do_loop %[[VAL_14:.*]] = %[[VAL_3]] to %[[VAL_9]] step %[[VAL_3]] unordered {
+// ANSE:               fir.do_loop %[[VAL_15:.*]] = %[[VAL_3]] to %[[VAL_6]] step %[[VAL_3]] unordered {
+// ANSE:                 %[[VAL_16:.*]]:3 = fir.box_dims %[[VAL_13]], %[[VAL_2]] : (!fir.box<!fir.array<?x?xi32>>, index) -> (index, index, index)
+// ANSE:                 %[[VAL_17:.*]]:3 = fir.box_dims %[[VAL_13]], %[[VAL_3]] : (!fir.box<!fir.array<?x?xi32>>, index) -> (index, index, index)
+// ANSE:                 %[[VAL_18:.*]] = arith.subi %[[VAL_16]]#0, %[[VAL_3]] : index
+// ANSE:                 %[[VAL_19:.*]] = arith.addi %[[VAL_15]], %[[VAL_18]] : index
+// ANSE:                 %[[VAL_20:.*]] = arith.subi %[[VAL_17]]#0, %[[VAL_3]] : index
+// ANSE:                 %[[VAL_21:.*]] = arith.addi %[[VAL_14]], %[[VAL_20]] : index
+// ANSE:                 %[[VAL_22:.*]] = hlfir.designate %[[VAL_13]] (%[[VAL_19]], %[[VAL_21]])  : (!fir.box<!fir.array<?x?xi32>>, index, index) -> !fir.ref<i32>
+// ANSE:                 hlfir.assign %[[VAL_4]] to %[[VAL_22]] : i32, !fir.ref<i32>
+// ANSE:               }
+// ANSE:             }
+// ANSE:             fir.do_loop %[[VAL_23:.*]] = %[[VAL_3]] to %[[VAL_7]] step %[[VAL_3]] unordered {
+// ANSE:               fir.do_loop %[[VAL_24:.*]] = %[[VAL_3]] to %[[VAL_9]] step %[[VAL_3]] unordered {
+// ANSE:                 fir.do_loop %[[VAL_25:.*]] = %[[VAL_3]] to %[[VAL_6]] step %[[VAL_3]] unordered {
+// ANSE:                   %[[VAL_26:.*]]:3 = fir.box_dims %[[VAL_13]], %[[VAL_2]] : (!fir.box<!fir.array<?x?xi32>>, index) -> (index, index, index)
+// ANSE:                   %[[VAL_27:.*]]:3 = fir.box_dims %[[VAL_13]], %[[VAL_3]] : (!fir.box<!fir.array<?x?xi32>>, index) -> (index, index, index)
+// ANSE:                   %[[VAL_28:.*]] = arith.subi %[[VAL_26]]#0, %[[VAL_3]] : index
+// ANSE:                   %[[VAL_29:.*]] = arith.addi %[[VAL_25]], %[[VAL_28]] : index
+// ANSE:                   %[[VAL_30:.*]] = arith.subi %[[VAL_27]]#0, %[[VAL_3]] : index
+// ANSE:                   %[[VAL_31:.*]] = arith.addi %[[VAL_24]], %[[VAL_30]] : index
+// ANSE:                   %[[VAL_32:.*]] = hlfir.designate %[[VAL_13]] (%[[VAL_29]], %[[VAL_31]])  : (!fir.box<!fir.array<?x?xi32>>, index, index) -> !fir.ref<i32>
+// ANSE:                   %[[VAL_33:.*]] = fir.load %[[VAL_32]] : !fir.ref<i32>
+// ANSE:                   %[[VAL_34:.*]] = hlfir.apply %[[VAL_0]], %[[VAL_25]], %[[VAL_23]] : (!hlfir.expr<?x?xi16>, index, index) -> i16
+// ANSE:                   %[[VAL_35:.*]] = hlfir.apply %[[VAL_1]], %[[VAL_23]], %[[VAL_24]] : (!hlfir.expr<?x?xi32>, index, index) -> i32
+// ANSE:                   %[[VAL_36:.*]] = fir.convert %[[VAL_34]] : (i16) -> i32
+// ANSE:                   %[[VAL_37:.*]] = arith.muli %[[VAL_36]], %[[VAL_35]] : i32
+// ANSE:                   %[[VAL_38:.*]] = arith.addi %[[VAL_33]], %[[VAL_37]] : i32
+// ANSE:                   hlfir.assign %[[VAL_38]] to %[[VAL_32]] : i32, !fir.ref<i32>
+// ANSE:                 }
+// ANSE:               }
+// ANSE:             }
+// ANSE:           }
+// ANSE:           return %[[VAL_11]] : !hlfir.expr<?x?xi32>
+// ANSE:         }
+
+// ELEMENTAL:           %[[VAL_2:.*]] = arith.constant 1 : index
+// ELEMENTAL:           %[[VAL_3:.*]] = arith.constant 0 : i32
+// ELEMENTAL:           %[[VAL_4:.*]] = hlfir.shape_of %[[VAL_0]] : (!hlfir.expr<?x?xi16>) -> !fir.shape<2>
+// ELEMENTAL:           %[[VAL_5:.*]] = hlfir.get_extent %[[VAL_4]] {dim = 0 : index} : (!fir.shape<2>) -> index
+// ELEMENTAL:           %[[VAL_6:.*]] = hlfir.get_extent %[[VAL_4]] {dim = 1 : index} : (!fir.shape<2>) -> index
+// ELEMENTAL:           %[[VAL_7:.*]] = hlfir.shape_of %[[VAL_1]] : (!hlfir.expr<?x?xi32>) -> !fir.shape<2>
+// ELEMENTAL:           %[[VAL_8:.*]] = hlfir.get_extent %[[VAL_7]] {dim = 1 : index} : (!fir.shape<2>) -> index
+// ELEMENTAL:           %[[VAL_9:.*]] = fir.shape %[[VAL_5]], %[[VAL_8]] : (index, index) -> !fir.shape<2>
+// ELEMENTAL:           %[[VAL_10:.*]] = hlfir.elemental %[[VAL_9]] unordered : (!fir.shape<2>) -> !hlfir.expr<?x?xi32> {
+// ELEMENTAL:           ^bb0(%[[VAL_11:.*]]: index, %[[VAL_12:.*]]: index):
+// ELEMENTAL:             %[[VAL_13:.*]] = fir.do_loop %[[VAL_14:.*]] = %[[VAL_2]] to %[[VAL_6]] step %[[VAL_2]] unordered iter_args(%[[VAL_15:.*]] = %[[VAL_3]]) -> (i32) {
+// ELEMENTAL:               %[[VAL_16:.*]] = hlfir.apply %[[VAL_0]], %[[VAL_11]], %[[VAL_14]] : (!hlfir.expr<?x?xi16>, index, index) -> i16
+// ELEMENTAL:               %[[VAL_17:.*]] = hlfir.apply %[[VAL_1]], %[[VAL_14]], %[[VAL_12]] : (!hlfir.expr<?x?xi32>, index, index) -> i32
+// ELEMENTAL:               %[[VAL_18:.*]] = fir.convert %[[VAL_16]] : (i16) -> i32
+// ELEMENTAL:               %[[VAL_19:.*]] = arith.muli %[[VAL_18]], %[[VAL_17]] : i32
+// ELEMENTAL:               %[[VAL_20:.*]] = arith.addi %[[VAL_15]], %[[VAL_19]] : i32
+// ELEMENTAL:               fir.result %[[VAL_20]] : i32
+// ELEMENTAL:             }
+// ELEMENTAL:             hlfir.yield_element %[[VAL_13]] : i32
+// ELEMENTAL:           }
+// ELEMENTAL:           return %[[VAL_10]] : !hlfir.expr<?x?xi32>
+// ELEMENTAL:         }
+
+func.func @matmul_matrix_matrix_real(%arg0: !hlfir.expr<?x?xf32>, %arg1: !hlfir.expr<?x?xf16>) -> !hlfir.expr<?x?xf32> {
+  %res = hlfir.matmul %arg0 %arg1 : (!hlfir.expr<?x?xf32>, !hlfir.expr<?x?xf16>) -> !hlfir.expr<?x?xf32>
+  return %res : !hlfir.expr<?x?xf32>
+}
+// ALL-LABEL:   func.func @matmul_matrix_matrix_real(
+// ALL-SAME:                                         %[[VAL_0:.*]]: !hlfir.expr<?x?xf32>,
+// ALL-SAME:                                         %[[VAL_1:.*]]: !hlfir.expr<?x?xf16>) -> !hlfir.expr<?x?xf32> {
+
+// NOANSE: hlfir.matmul
+
+// ANSE:           %[[VAL_2:.*]] = arith.constant 0 : index
+// ANSE:           %[[VAL_3:.*]] = arith.constant 1 : index
+// ANSE:           %[[VAL_4:.*]] = arith.constant 0.000000e+00 : f32
+// ANSE:           %[[VAL_5:.*]] = hlfir.shape_of %[[VAL_0]] : (!hlfir.expr<?x?xf32>) -> !fir.shape<2>
+// ANSE:           %[[VAL_6:.*]] = hlfir.get_extent %[[VAL_5]] {dim = 0 : index} : (!fir.shape<2>) -> index
+// ANSE:           %[[VAL_7:.*]] = hlfir.get_extent %[[VAL_5]] {dim = 1 : index} : (!fir.shape<2>) -> index
+// ANSE:           %[[VAL_8:.*]] = hlfir.shape_of %[[VAL_1]] : (!hlfir.expr<?x?xf16>) -> !fir.shape<2>
+// ANSE:           %[[VAL_9:.*]] = hlfir.get_extent %[[VAL_8]] {dim = 1 : index} : (!fir.shape<2>) -> index
+// ANSE:           %[[VAL_10:.*]] = fir.shape %[[VAL_6]], %[[VAL_9]] : (index, index) -> !fir.shape<2>
+// ANSE:           %[[VAL_11:.*]] = hlfir.eval_in_mem shape %[[VAL_10]] : (!fir.shape<2>) -> !hlfir.expr<?x?xf32> {
+// ANSE:           ^bb0(%[[VAL_12:.*]]: !fir.ref<!fir.array<?x?xf32>>):
+// ANSE:             %[[VAL_13:.*]] = fir.embox %[[VAL_12]](%[[VAL_10]]) : (!fir.ref<!fir.array<?x?xf32>>, !fir.shape<2>) -> !fir.box<!fir.array<?x?xf32>>
+// ANSE:             fir.do_loop %[[VAL_14:.*]] = %[[VAL_3]] to %[[VAL_9]] step %[[VAL_3]] unordered {
+// ANSE:               fir.do_loop %[[VAL_15:.*]] = %[[VAL_3]] to %[[VAL_6]] step %[[VAL_3]] unordered {
+// ANSE:                 %[[VAL_16:.*]]:3 = fir.box_dims %[[VAL_13]], %[[VAL_2]] : (!fir.box<!fir.array<?x?xf32>>, index) -> (index, index, index)
+// ANSE:                 %[[VAL_17:.*]]:3 = fir.box_dims %[[VAL_13]], %[[VAL_3]] : (!fir.box<!fir.array<?x?xf32>>, index) -> (index, index, index)
+// ANSE:                 %[[VAL_18:.*]] = arith.subi %[[VAL_16]]#0, %[[VAL_3]] : index
+// ANSE:                 %[[VAL_19:.*]] = arith.addi %[[VAL_15]], %[[VAL_18]] : index
+// ANSE:                 %[[VAL_20:.*]] = arith.subi %[[VAL_17]]#0, %[[VAL_3]] : index
+// ANSE:                 %[[VAL_21:.*]] = arith.addi %[[VAL_14]], %[[VAL_20]] : index
+// ANSE:                 %[[VAL_22:.*]] = hlfir.designate %[[VAL_13]] (%[[VAL_19]], %[[VAL_21]])  : (!fir.box<!fir.array<?x?xf32>>, index, index) -> !fir.ref<f32>
+// ANSE:                 hlfir.assign %[[VAL_4]] to %[[VAL_22]] : f32, !fir.ref<f32>
+// ANSE:               }
+// ANSE:             }
+// ANSE:             fir.do_loop %[[VAL_23:.*]] = %[[VAL_3]] to %[[VAL_7]] step %[[VAL_3]] {
+// ANSE:               fir.do_loop %[[VAL_24:.*]] = %[[VAL_3]] to %[[VAL_9]] step %[[VAL_3]] {
+// ANSE:                 fir.do_loop %[[VAL_25:.*]] = %[[VAL_3]] to %[[VAL_6]] step %[[VAL_3]] {
+// ANSE:                   %[[VAL_26:.*]]:3 = fir.box_dims %[[VAL_13]], %[[VAL_2]] : (!fir.box<!fir.array<?x?xf32>>, index) -> (index, index, index)
+// ANSE:                   %[[VAL_27:.*]]:3 = fir.box_dims %[[VAL_13]], %[[VAL_3]] : (!fir.box<!fir.array<?x?xf32>>, index) -> (index, index, index)
+// ANSE:                   %[[VAL_28:.*]] = arith.subi %[[VAL_26]]#0, %[[VAL_3]] : index
+// ANSE:                   %[[VAL_29:.*]] = arith.addi %[[VAL_25]], %[[VAL_28]] : index
+// ANSE:                   %[[VAL_30:.*]] = arith.subi %[[VAL_27]]#0, %[[VAL_3]] : index
+// ANSE:                   %[[VAL_31:.*]] = arith.addi %[[VAL_24]], %[[VAL_30]] : index
+// ANSE:                   %[[VAL_32:.*]] = hlfir.designate %[[VAL_13]] (%[[VAL_29]], %[[VAL_31]])  : (!fir.box<!fir.array<?x?xf32>>, index, index) -> !fir.ref<f32>
+// ANSE:                   %[[VAL_33:.*]] = fir.load %[[VAL_32]] : !fir.ref<f32>
+// ANSE:                   %[[VAL_34:.*]] = hlfir.apply %[[VAL_0]], %[[VAL_25]], %[[VAL_23]] : (!hlfir.expr<?x?xf32>, index, index) -> f32
+// ANSE:                   %[[VAL_35:.*]] = hlfir.apply %[[VAL_1]], %[[VAL_23]], %[[VAL_24]] : (!hlfir.expr<?x?xf16>, index, index) -> f16
+// ANSE:                   %[[VAL_36:.*]] = fir.convert %[[VAL_35]] : (f16) -> f32
+// ANSE:                   %[[VAL_37:.*]] = arith.mulf %[[VAL_34]], %[[VAL_36]] : f32
+// ANSE:                   %[[VAL_38:.*]] = arith.addf %[[VAL_33]], %[[VAL_37]] : f32
+// ANSE:                   hlfir.assign %[[VAL_38]] to %[[VAL_32]] : f32, !fir.ref<f32>
+// ANSE:                 }
+// ANSE:               }
+// ANSE:             }
+// ANSE:           }
+// ANSE:           return %[[VAL_11]] : !hlfir.expr<?x?xf32>
+// ANSE:         }
+
+// ELEMENTAL:           %[[VAL_2:.*]] = arith.constant 1 : index
+// ELEMENTAL:           %[[VAL_3:.*]] = arith.constant 0.000000e+00 : f32
+// ELEMENTAL:           %[[VAL_4:.*]] = hlfir.shape_of %[[VAL_0]] : (!hlfir.expr<?x?xf32>) -> !fir.shape<2>
+// ELEMENTAL:           %[[VAL_5:.*]] = hlfir.get_extent %[[VAL_4]] {dim = 0 : index} : (!fir.shape<2>) -> index
+// ELEMENTAL:           %[[VAL_6:.*]] = hlfir.get_extent %[[VAL_4]] {dim = 1 : index} : (!fir.shape<2>) -> index
+// ELEMENTAL:           %[[VAL_7:.*]] = hlfir.shape_of %[[VAL_1]] : (!hlfir.expr<?x?xf16>) -> !fir.shape<2>
+// ELEMENTAL:           %[[VAL_8:.*]] = hlfir.get_extent %[[VAL_7]] {dim = 1 : index} : (!fir.shape<2>) -> index
+// ELEMENTAL:           %[[VAL_9:.*]] = fir.shape %[[VAL_5]], %[[VAL_8]] : (index, index) -> !fir.shape<2>
+// ELEMENTAL:           %[[VAL_10:.*]] = hlfir.elemental %[[VAL_9]] unordered : (!fir.shape<2>) -> !hlfir.expr<?x?xf32> {
+// ELEMENTAL:           ^bb0(%[[VAL_11:.*]]: index, %[[VAL_12:.*]]: index):
+// ELEMENTAL:             %[[VAL_13:.*]] = fir.do_loop %[[VAL_14:.*]] = %[[VAL_2]] to %[[VAL_6]] step %[[VAL_2]] iter_args(%[[VAL_15:.*]] = %[[VAL_3]]) -> (f32) {
+// ELEMENTAL:               %[[VAL_16:.*]] = hlfir.apply %[[VAL_0]], %[[VAL_11]], %[[VAL_14]] : (!hlfir.expr<?x?xf32>, index, index) -> f32
+// ELEMENTAL:               %[[VAL_17:.*]] = hlfir.apply %[[VAL_1]], %[[VAL_14]], %[[VAL_12]] : (!hlfir.expr<?x?xf16>, index, index) -> f16
+// ELEMENTAL:               %[[VAL_18:.*]] = fir.convert %[[VAL_17]] : (f16) -> f32
+// ELEMENTAL:               %[[VAL_19:.*]] = arith.mulf %[[VAL_16]], %[[VAL_18]] : f32
+// ELEMENTAL:               %[[VAL_20:.*]] = arith.addf %[[VAL_15]], %[[VAL_19]] : f32
+// ELEMENTAL:               fir.result %[[VAL_20]] : f32
+// ELEMENTAL:             }
+// ELEMENTAL:             hlfir.yield_element %[[VAL_13]] : f32
+// ELEMENTAL:           }
+// ELEMENTAL:           return %[[VAL_10]] : !hlfir.expr<?x?xf32>
+// ELEMENTAL:         }
+
+func.func @matmul_matrix_matrix_complex(%arg0: !hlfir.expr<?x?xcomplex<f32>>, %arg1: !hlfir.expr<?x?xcomplex<f16>>) -> !hlfir.expr<?x?xcomplex<f32>> {
+  %res = hlfir.matmul %arg0 %arg1 : (!hlfir.expr<?x?xcomplex<f32>>, !hlfir.expr<?x?xcomplex<f16>>) -> !hlfir.expr<?x?xcomplex<f32>>
+  return %res : !hlfir.expr<?x?xcomplex<f32>>
+}
+// ALL-LABEL:   func.func @matmul_matrix_matrix_complex(
+// ALL-SAME:                                            %[[VAL_0:.*]]: !hlfir.expr<?x?xcomplex<f32>>,
+// ALL-SAME:                                            %[[VAL_1:.*]]: !hlfir.expr<?x?xcomplex<f16>>) -> !hlfir.expr<?x?xcomplex<f32>> {
+
+// NOANSE: hlfir.matmul
+
+// ANSE:           %[[VAL_2:.*]] = arith.constant 0 : index
+// ANSE:           %[[VAL_3:.*]] = arith.constant 1 : index
+// ANSE:           %[[VAL_4:.*]] = arith.constant 0.000000e+00 : f32
+// ANSE:           %[[VAL_5:.*]] = hlfir.shape_of %[[VAL_0]] : (!hlfir.expr<?x?xcomplex<f32>>) -> !fir.shape<2>
+// ANSE:           %[[VAL_6:.*]] = hlfir.get_extent %[[VAL_5]] {dim = 0 : index} : (!fir.shape<2>) -> index
+// ANSE:           %[[VAL_7:.*]] = hlfir.get_extent %[[VAL_5]] {dim = 1 : index} : (!fir.shape<2>) -> index
+// ANSE:           %[[VAL_8:.*]] = hlfir.shape_of %[[VAL_1]] : (!hlfir.expr<?x?xcomplex<f16>>) -> !fir.shape<2>
+// ANSE:           %[[VAL_9:.*]] = hlfir.get_extent %[[VAL_8]] {dim = 1 : index} : (!fir.shape<2>) -> index
+// ANSE:           %[[VAL_10:.*]] = fir.shape %[[VAL_6]], %[[VAL_9]] : (index, index) -> !fir.shape<2>
+// ANSE:           %[[VAL_11:.*]] = hlfir.eval_in_mem shape %[[VAL_10]] : (!fir.shape<2>) -> !hlfir.expr<?x?xcomplex<f32>> {
+// ANSE:           ^bb0(%[[VAL_12:.*]]: !fir.ref<!fir.array<?x?xcomplex<f32>>>):
+// ANSE:             %[[VAL_13:.*]] = fir.embox %[[VAL_12]](%[[VAL_10]]) : (!fir.ref<!fir.array<?x?xcomplex<f32>>>, !fir.shape<2>) -> !fir.box<!fir.array<?x?xcomplex<f32>>>
+// ANSE:             %[[VAL_14:.*]] = fir.undefined complex<f32>
+// ANSE:             %[[VAL_15:.*]] = fir.insert_value %[[VAL_14]], %[[VAL_4]], [0 : index] : (complex<f32>, f32) -> complex<f32>
+// ANSE:             %[[VAL_16:.*]] = fir.insert_value %[[VAL_15]], %[[VAL_4]], [1 : index] : (complex<f32>, f32) -> complex<f32>
+// ANSE:             fir.do_loop %[[VAL_17:.*]] = %[[VAL_3]] to %[[VAL_9]] step %[[VAL_3]] unordered {
+// ANSE:               fir.do_loop %[[VAL_18:.*]] = %[[VAL_3]] to %[[VAL_6]] step %[[VAL_3]] unordered {
+// ANSE:                 %[[VAL_19:.*]]:3 = fir.box_dims %[[VAL_13]], %[[VAL_2]] : (!fir.box<!fir.array<?x?xcomplex<f32>>>, index) -> (index, index, index)
+// ANSE:                 %[[VAL_20:.*]]:3 = fir.box_dims %[[VAL_13]], %[[VAL_3]] : (!fir.box<!fir.array<?x?xcomplex<f32>>>, index) -> (index, index, index)
+// ANSE:                 %[[VAL_21:.*]] = arith.subi %[[VAL_19]]#0, %[[VAL_3]] : index
+// ANSE:                 %[[VAL_22:.*]] = arith.addi %[[VAL_18]], %[[VAL_21]] : index
+// ANSE:                 %[[VAL_23:.*]] = arith.subi %[[VAL_20]]#0, %[[VAL_3]] : index
+// ANSE:                 %[[VAL_24:.*]] = arith.addi %[[VAL_17]], %[[VAL_23]] : index
+// ANSE:                 %[[VAL_25:.*]] = hlfir.designate %[[VAL_13]] (%[[VAL_22]], %[[VAL_24]])  : (!fir.box<!fir.array<?x?xcomplex<f32>>>, index, index) -> !fir.ref<complex<f32>>
+// ANSE:                 hlfir.assign %[[VAL_16]] to %[[VAL_25]] : complex<f32>, !fir.ref<complex<f32>>
+// ANSE:               }
+// ANSE:             }
+// ANSE:             fir.do_loop %[[VAL_26:.*]] = %[[VAL_3]] to %[[VAL_7]] step %[[VAL_3]] {
+// ANSE:               fir.do_loop %[[VAL_27:.*]] = %[[VAL_3]] to %[[VAL_9]] step %[[VAL_3]] {
+// ANSE:                 fir.do_loop %[[VAL_28:.*]] = %[[VAL_3]] to %[[VAL_6]] step %[[VAL_3]] {
+// ANSE:                   %[[VAL_29:.*]]:3 = fir.box_dims %[[VAL_13]], %[[VAL_2]] : (!fir.box<!fir.array<?x?xcomplex<f32>>>, index) -> (index, index, index)
+// ANSE:                   %[[VAL_30:.*]]:3 = fir.box_dims %[[VAL_13]], %[[VAL_3]] : (!fir.box<!fir.array<?x?xcomplex<f32>>>, index) -> (index, index, index)
+// ANSE:                   %[[VAL_31:.*]] = arith.subi %[[VAL_29]]#0, %[[VAL_3]] : index
+// ANSE:                   %[[VAL_32:.*]] = arith.addi %[[VAL_28]], %[[VAL_31]] : index
+// ANSE:                   %[[VAL_33:.*]] = arith.subi %[[VAL_30]]#0, %[[VAL_3]] : index
+// ANSE:                   %[[VAL_34:.*]] = arith.addi %[[VAL_27]], %[[VAL_33]] : index
+// ANSE:                   %[[VAL_35:.*]] = hlfir.designate %[[VAL_13]] (%[[VAL_32]], %[[VAL_34]])  : (!fir.box<!fir.array<?x?xcomplex<f32>>>, index, index) -> !fir.ref<complex<f32>>
+// ANSE:                   %[[VAL_36:.*]] = fir.load %[[VAL_35]] : !fir.ref<complex<f32>>
+// ANSE:                   %[[VAL_37:.*]] = hlfir.apply %[[VAL_0]], %[[VAL_28]], %[[VAL_26]] : (!hlfir.expr<?x?xcomplex<f32>>, index, index) -> complex<f32>
+// ANSE:                   %[[VAL_38:.*]] = hlfir.apply %[[VAL_1]], %[[VAL_26]], %[[VAL_27]] : (!hlfir.expr<?x?xcomplex<f16>>, index, index) -> complex<f16>
+// ANSE:                   %[[VAL_39:.*]] = fir.convert %[[VAL_38]] : (complex<f16>) -> complex<f32>
+// ANSE:                   %[[VAL_40:.*]] = fir.mulc %[[VAL_37]], %[[VAL_39]] : complex<f32>
+// ANSE:                   %[[VAL_41:.*]] = fir.addc %[[VAL_36]], %[[VAL_40]] : complex<f32>
+// ANSE:                   hlfir.assign %[[VAL_41]] to %[[VAL_35]] : complex<f32>, !fir.ref<complex<f32>>
+// ANSE:                 }
+// ANSE:               }
+// ANSE:             }
+// ANSE:           }
+// ANSE:           return %[[VAL_11]] : !hlfir.expr<?x?xcomplex<f32>>
+// ANSE:         }
+
+// ELEMENTAL:           %[[VAL_2:.*]] = arith.constant 1 : index
+// ELEMENTAL:           %[[VAL_3:.*]] = arith.constant 0.000000e+00 : f32
+// ELEMENTAL:           %[[VAL_4:.*]] = hlfir.shape_of %[[VAL_0]] : (!hlfir.expr<?x?xcomplex<f32>>) -> !fir.shape<2>
+// ELEMENTAL:           %[[VAL_5:.*]] = hlfir.get_extent %[[VAL_4]] {dim = 0 : index} : (!fir.shape<2>) -> index
+// ELEMENTAL:           %[[VAL_6:.*]] = hlfir.get_extent %[[VAL_4]] {dim = 1 : index} : (!fir.shape<2>) -> index
+// ELEMENTAL:           %[[VAL_7:.*]] = hlfir.shape_of %[[VAL_1]] : (!hlfir.expr<?x?xcomplex<f16>>) -> !fir.shape<2>
+// ELEMENTAL:           %[[VAL_8:.*]] = hlfir.get_extent %[[VAL_7]] {dim = 1 : index} : (!fir.shape<2>) -> index
+// ELEMENTAL:           %[[VAL_9:.*]] = fir.shape %[[VAL_5]], %[[VAL_8]] : (index, index) -> !fir.shape<2>
+// ELEMENTAL:           %[[VAL_10:.*]] = hlfir.elemental %[[VAL_9]] unordered : (!fir.shape<2>) -> !hlfir.expr<?x?xcomplex<f32>> {
+// ELEMENTAL:           ^bb0(%[[VAL_11:.*]]: index, %[[VAL_12:.*]]: index):
+// ELEMENTAL:             %[[VAL_13:.*]] = fir.undefined complex<f32>
+// ELEMENTAL:             %[[VAL_14:.*]] = fir.insert_value %[[VAL_13]], %[[VAL_3]], [0 : index] : (complex<f32>, f32) -> complex<f32>
+// ELEMENTAL:             %[[VAL_15:.*]] = fir.insert_value %[[VAL_14]], %[[VAL_3]], [1 : index] : (complex<f32>, f32) -> complex<f32>
+// ELEMENTAL:             %[[VAL_16:.*]] = fir.do_loop %[[VAL_17:.*]] = %[[VAL_2]] to %[[VAL_6]] step %[[VAL_2]] iter_args(%[[VAL_18:.*]] = %[[VAL_15]]) -> (complex<f32>) {
+// ELEMENTAL:               %[[VAL_19:.*]] = hlfir.apply %[[VAL_0]], %[[VAL_11]], %[[VAL_17]] : (!hlfir.expr<?x?xcomplex<f32>>, index, index) -> complex<f32>
+// ELEMENTAL:               %[[VAL_20:.*]] = hlfir.apply %[[VAL_1]], %[[VAL_17]], %[[VAL_12]] : (!hlfir.expr<?x?xcomplex<f16>>, index, index) -> complex<f16>
+// ELEMENTAL:               %[[VAL_21:.*]] = fir.convert %[[VAL_20]] : (complex<f16>) -> complex<f32>
+// ELEMENTAL:               %[[VAL_22:.*]] = fir.mulc %[[VAL_19]], %[[VAL_21]] : complex<f32>
+// ELEMENTAL:               %[[VAL_23:.*]] = fir.addc %[[VAL_18]], %[[VAL_22]] : complex<f32>
+// ELEMENTAL:               fir.result %[[VAL_23]] : complex<f32>
+// ELEMENTAL:             }
+// ELEMENTAL:             hlfir.yield_element %[[VAL_16]] : complex<f32>
+// ELEMENTAL:           }
+// ELEMENTAL:           return %[[VAL_10]] : !hlfir.expr<?x?xcomplex<f32>>
+// ELEMENTAL:         }
+
+func.func @matmul_matrix_matrix_complex_real(%arg0: !hlfir.expr<?x?xcomplex<f32>>, %arg1: !hlfir.expr<?x?xf16>) -> !hlfir.expr<?x?xcomplex<f32>> {
+  %res = hlfir.matmul %arg0 %arg1 : (!hlfir.expr<?x?xcomplex<f32>>, !hlfir.expr<?x?xf16>) -> !hlfir.expr<?x?xcomplex<f32>>
+  return %res : !hlfir.expr<?x?xcomplex<f32>>
+}
+// ALL-LABEL:   func.func @matmul_matrix_matrix_complex_real(
+// ALL-SAME:                                                 %[[VAL_0:.*]]: !hlfir.expr<?x?xcomplex<f32>>,
+// ALL-SAME:                                                 %[[VAL_1:.*]]: !hlfir.expr<?x?xf16>) -> !hlfir.expr<?x?xcomplex<f32>> {
+
+// NOANSE: hlfir.matmul
+
+// ANSE:           %[[VAL_2:.*]] = arith.constant 0 : index
+// ANSE:           %[[VAL_3:.*]] = arith.constant 1 : index
+// ANSE:           %[[VAL_4:.*]] = arith.constant 0.000000e+00 : f32
+// ANSE:           %[[VAL_5:.*]] = hlfir.shape_of %[[VAL_0]] : (!hlfir.expr<?x?xcomplex<f32>>) -> !fir.shape<2>
+// ANSE:           %[[VAL_6:.*]] = hlfir.get_extent %[[VAL_5]] {dim = 0 : index} : (!fir.shape<2>) -> index
+// ANSE:           %[[VAL_7:.*]] = hlfir.get_extent %[[VAL_5]] {dim = 1 : index} : (!fir.shape<2>) -> index
+// ANSE:           %[[VAL_8:.*]] = hlfir.shape_of %[[VAL_1]] : (!hlfir.expr<?x?xf16>) -> !fir.shape<2>
+// ANSE:           %[[VAL_9:.*]] = hlfir.get_extent %[[VAL_8]] {dim = 1 : index} : (!fir.shape<2>) -> index
+// ANSE:           %[[VAL_10:.*]] = fir.shape %[[VAL_6]], %[[VAL_9]] : (index, index) -> !fir.shape<2>
+// ANSE:           %[[VAL_11:.*]] = hlfir.eval_in_mem shape %[[VAL_10]] : (!fir.shape<2>) -> !hlfir.expr<?x?xcomplex<f32>> {
+// ANSE:           ^bb0(%[[VAL_12:.*]]: !fir.ref<!fir.array<?x?xcomplex<f32>>>):
+// ANSE:             %[[VAL_13:.*]] = fir.embox %[[VAL_12]](%[[VAL_10]]) : (!fir.ref<!fir.array<?x?xcomplex<f32>>>, !fir.shape<2>) -> !fir.box<!fir.array<?x?xcomplex<f32>>>
+// ANSE:             %[[VAL_14:.*]] = fir.undefined complex<f32>
+// ANSE:             %[[VAL_15:.*]] = fir.insert_value %[[VAL_14]], %[[VAL_4]], [0 : index] : (complex<f32>, f32) -> complex<f32>
+// ANSE:             %[[VAL_16:.*]] = fir.insert_value %[[VAL_15]], %[[VAL_4]], [1 : index] : (complex<f32>, f32) -> complex<f32>
+// ANSE:             fir.do_loop %[[VAL_17:.*]] = %[[VAL_3]] to %[[VAL_9]] step %[[VAL_3]] unordered {
+// ANSE:               fir.do_loop %[[VAL_18:.*]] = %[[VAL_3]] to %[[VAL_6]] step %[[VAL_3]] unordered {
+// ANSE:                 %[[VAL_19:.*]]:3 = fir.box_dims %[[VAL_13]], %[[VAL_2]] : (!fir.box<!fir.array<?x?xcomplex<f32>>>, index) -> (index, index, index)
+// ANSE:                 %[[VAL_20:.*]]:3 = fir.box_dims %[[VAL_13]], %[[VAL_3]] : (!fir.box<!fir.array<?x?xcomplex<f32>>>, index) -> (index, index, index)
+// ANSE:                 %[[VAL_21:.*]] = arith.subi %[[VAL_19]]#0, %[[VAL_3]] : index
+// ANSE:                 %[[VAL_22:.*]] = arith.addi %[[VAL_18]], %[[VAL_21]] : index
+// ANSE:                 %[[VAL_23:.*]] = arith.subi %[[VAL_20]]#0, %[[VAL_3]] : index
+// ANSE:                 %[[VAL_24:.*]] = arith.addi %[[VAL_17]], %[[VAL_23]] : index
+// ANSE:                 %[[VAL_25:.*]] = hlfir.designate %[[VAL_13]] (%[[VAL_22]], %[[VAL_24]])  : (!fir.box<!fir.array<?x?xcomplex<f32>>>, index, index) -> !fir.ref<complex<f32>>
+// ANSE:                 hlfir.assign %[[VAL_16]] to %[[VAL_25]] : complex<f32>, !fir.ref<complex<f32>>
+// ANSE:               }
+// ANSE:             }
+// ANSE:             fir.do_loop %[[VAL_26:.*]] = %[[VAL_3]] to %[[VAL_7]] step %[[VAL_3]] {
+// ANSE:               fir.do_loop %[[VAL_27:.*]] = %[[VAL_3]] to %[[VAL_9]] step %[[VAL_3]] {
+// ANSE:                 fir.do_loop %[[VAL_28:.*]] = %[[VAL_3]] to %[[VAL_6]] step %[[VAL_3]] {
+// ANSE:                   %[[VAL_29:.*]]:3 = fir.box_dims %[[VAL_13]], %[[VAL_2]] : (!fir.box<!fir.array<?x?xcomplex<f32>>>, index) -> (index, index, index)
+// ANSE:                   %[[VAL_30:.*]]:3 = fir.box_dims %[[VAL_13]], %[[VAL_3]] : (!fir.box<!fir.array<?x?xcomplex<f32>>>, index) -> (index, index, index)
+// ANSE:                   %[[VAL_31:.*]] = arith.subi %[[VAL_29]]#0, %[[VAL_3]] : index
+// ANSE:                   %[[VAL_32:.*]] = arith.addi %[[VAL_28]], %[[VAL_31]] : index
+// ANSE:                   %[[VAL_33:.*]] = arith.subi %[[VAL_30]]#0, %[[VAL_3]] : index
+// ANSE:                   %[[VAL_34:.*]] = arith.addi %[[VAL_27]], %[[VAL_33]] : index
+// ANSE:                   %[[VAL_35:.*]] = hlfir.designate %[[VAL_13]] (%[[VAL_32]], %[[VAL_34]])  : (!fir.box<!fir.array<?x?xcomplex<f32>>>, index, index) -> !fir.ref<complex<f32>>
+// ANSE:                   %[[VAL_36:.*]] = fir.load %[[VAL_35]] : !fir.ref<complex<f32>>
+// ANSE:                   %[[VAL_37:.*]] = hlfir.apply %[[VAL_0]], %[[VAL_28]], %[[VAL_26]] : (!hlfir.expr<?x?xcomplex<f32>>, index, index) -> complex<f32>
+// ANSE:                   %[[VAL_38:.*]] = hlfir.apply %[[VAL_1]], %[[VAL_26]], %[[VAL_27]] : (!hlfir.expr<?x?xf16>, index, index) -> f16
+// ANSE:                   %[[VAL_39:.*]] = fir.undefined complex<f32>
+// ANSE:                   %[[VAL_40:.*]] = fir.insert_value %[[VAL_39]], %[[VAL_4]], [0 : index] : (complex<f32>, f32) -> complex<f32>
+// ANSE:                   %[[VAL_41:.*]] = fir.insert_value %[[VAL_40]], %[[VAL_4]], [1 : index] : (complex<f32>, f32) -> complex<f32>
+// ANSE:                   %[[VAL_42:.*]] = fir.convert %[[VAL_38]] : (f16) -> f32
+// ANSE:                   %[[VAL_43:.*]] = fir.insert_value %[[VAL_41]], %[[VAL_42]], [0 : index] : (complex<f32>, f32) -> complex<f32>
+// ANSE:                   %[[VAL_44:.*]] = fir.mulc %[[VAL_37]], %[[VAL_43]] : complex<f32>
+// ANSE:                   %[[VAL_45:.*]] = fir.addc %[[VAL_36]], %[[VAL_44]] : complex<f32>
+// ANSE:                   hlfir.assign %[[VAL_45]] to %[[VAL_35]] : complex<f32>, !fir.ref<complex<f32>>
+// ANSE:                 }
+// ANSE:               }
+// ANSE:             }
+// ANSE:           }
+// ANSE:           return %[[VAL_11]] : !hlfir.expr<?x?xcomplex<f32>>
+// ANSE:         }
+
+// ELEMENTAL:           %[[VAL_2:.*]] = arith.constant 1 : index
+// ELEMENTAL:           %[[VAL_3:.*]] = arith.constant 0.000000e+00 : f32
+// ELEMENTAL:           %[[VAL_4:.*]] = hlfir.shape_of %[[VAL_0]] : (!hlfir.expr<?x?xcomplex<f32>>) -> !fir.shape<2>
+// ELEMENTAL:           %[[VAL_5:.*]] = hlfir.get_extent %[[VAL_4]] {dim = 0 : index} : (!fir.shape<2>) -> index
+// ELEMENTAL:           %[[VAL_6:.*]] = hlfir.get_extent %[[VAL_4]] {dim = 1 : index} : (!fir.shape<2>) -> index
+// ELEMENTAL:           %[[VAL_7:.*]] = hlfir.shape_of %[[VAL_1]] : (!hlfir.expr<?x?xf16>) -> !fir.shape<2>
+// ELEMENTAL:           %[[VAL_8:.*]] = hlfir.get_extent %[[VAL_7]] {dim = 1 : index} : (!fir.shape<2>) -> index
+// ELEMENTAL:           %[[VAL_9:.*]] = fir.shape %[[VAL_5]], %[[VAL_8]] : (index, index) -> !fir.shape<2>
+// ELEMENTAL:           %[[VAL_10:.*]] = hlfir.elemental %[[VAL_9]] unordered : (!fir.shape<2>) -> !hlfir.expr<?x?xcomplex<f32>> {
+// ELEMENTAL:           ^bb0(%[[VAL_11:.*]]: index, %[[VAL_12:.*]]: index):
+// ELEMENTAL:             %[[VAL_13:.*]] = fir.undefined complex<f32>
+// ELEMENTAL:             %[[VAL_14:.*]] = fir.insert_value %[[VAL_13]], %[[VAL_3]], [0 : index] : (complex<f32>, f32) -> complex<f32>
+// ELEMENTAL:             %[[VAL_15:.*]] = fir.insert_value %[[VAL_14]], %[[VAL_3]], [1 : index] : (complex<f32>, f32) -> complex<f32>
+// ELEMENTAL:             %[[VAL_16:.*]] = fir.do_loop %[[VAL_17:.*]] = %[[VAL_2]] to %[[VAL_6]] step %[[VAL_2]] iter_args(%[[VAL_18:.*]] = %[[VAL_15]]) -> (complex<f32>) {
+// ELEMENTAL:               %[[VAL_19:.*]] = hlfir.apply %[[VAL_0]], %[[VAL_11]], %[[VAL_17]] : (!hlfir.expr<?x?xcomplex<f32>>, index, index) -> complex<f32>
+// ELEMENTAL:               %[[VAL_20:.*]] = hlfir.apply %[[VAL_1]], %[[VAL_17]], %[[VAL_12]] : (!hlfir.expr<?x?xf16>, index, index) -> f16
+// ELEMENTAL:               %[[VAL_21:.*]] = fir.undefined complex<f32>
+// ELEMENTAL:               %[[VAL_22:.*]] = fir.insert_value %[[VAL_21]], %[[VAL_3]], [0 : index] : (complex<f32>, f32) -> complex<f32>
+// ELEMENTAL:               %[[VAL_23:.*]] = fir.insert_value %[[VAL_22]], %[[VAL_3]], [1 : index] : (complex<f32>, f32) -> complex<f32>
+// ELEMENTAL:               %[[VAL_24:.*]] = fir.convert %[[VAL_20]] : (f16) -> f32
+// ELEMENTAL:               %[[VAL_25:.*]] = fir.insert_value %[[VAL_23]], %[[VAL_24]], [0 : index] : (complex<f32>, f32) -> complex<f32>
+// ELEMENTAL:               %[[VAL_26:.*]] = fir.mulc %[[VAL_19]], %[[VAL_25]] : complex<f32>
+// ELEMENTAL:               %[[VAL_27:.*]] = fir.addc %[[VAL_18]], %[[VAL_26]] : complex<f32>
+// ELEMENTAL:               fir.result %[[VAL_27]] : complex<f32>
+// ELEMENTAL:             }
+// ELEMENTAL:             hlfir.yield_element %[[VAL_16]] : complex<f32>
+// ELEMENTAL:           }
+// ELEMENTAL:           return %[[VAL_10]] : !hlfir.expr<?x?xcomplex<f32>>
+// ELEMENTAL:         }
+
+func.func @matmul_matrix_matrix_logical(%arg0: !hlfir.expr<?x?x!fir.logical<1>>, %arg1: !hlfir.expr<?x?x!fir.logical<4>>) -> !hlfir.expr<?x?x!fir.logical<4>> {
+  %res = hlfir.matmul %arg0 %arg1 : (!hlfir.expr<?x?x!fir.logical<1>>, !hlfir.expr<?x?x!fir.logical<4>>) -> !hlfir.expr<?x?x!fir.logical<4>>
+  return %res : !hlfir.expr<?x?x!fir.logical<4>>
+}
+// ALL-LABEL:   func.func @matmul_matrix_matrix_logical(
+// ALL-SAME:                                            %[[VAL_0:.*]]: !hlfir.expr<?x?x!fir.logical<1>>,
+// ALL-SAME:                                            %[[VAL_1:.*]]: !hlfir.expr<?x?x!fir.logical<4>>) -> !hlfir.expr<?x?x!fir.logical<4>> {
+
+// NOANSE: hlfir.matmul
+
+// ANSE:           %[[VAL_2:.*]] = arith.constant 0 : index
+// ANSE:           %[[VAL_3:.*]] = arith.constant 1 : index
+// ANSE:           %[[VAL_4:.*]] = arith.constant false
+// ANSE:           %[[VAL_5:.*]] = hlfir.shape_of %[[VAL_0]] : (!hlfir.expr<?x?x!fir.logical<1>>) -> !fir.shape<2>
+// ANSE:           %[[VAL_6:.*]] = hlfir.get_extent %[[VAL_5]] {dim = 0 : index} : (!fir.shape<2>) -> index
+// ANSE:           %[[VAL_7:.*]] = hlfir.get_extent %[[VAL_5]] {dim = 1 : index} : (!fir.shape<2>) -> index
+// ANSE:           %[[VAL_8:.*]] = hlfir.shape_of %[[VAL_1]] : (!hlfir.expr<?x?x!fir.logical<4>>) -> !fir.shape<2>
+// ANSE:           %[[VAL_9:.*]] = hlfir.get_extent %[[VAL_8]] {dim = 1 : index} : (!fir.shape<2>) -> index
+// ANSE:           %[[VAL_10:.*]] = fir.shape %[[VAL_6]], %[[VAL_9]] : (index, index) -> !fir.shape<2>
+// ANSE:           %[[VAL_11:.*]] = hlfir.eval_in_mem shape %[[VAL_10]] : (!fir.shape<2>) -> !hlfir.expr<?x?x!fir.logical<4>> {
+// ANSE:           ^bb0(%[[VAL_12:.*]]: !fir.ref<!fir.array<?x?x!fir.logical<4>>>):
+// ANSE:             %[[VAL_13:.*]] = fir.embox %[[VAL_12]](%[[VAL_10]]) : (!fir.ref<!fir.array<?x?x!fir.logical<4>>>, !fir.shape<2>) -> !fir.box<!fir.array<?x?x!fir.logical<4>>>
+// ANSE:             %[[VAL_14:.*]] = fir.convert %[[VAL_4]] : (i1) -> !fir.logical<4>
+// ANSE:             fir.do_loop %[[VAL_15:.*]] = %[[VAL_3]] to %[[VAL_9]] step %[[VAL_3]] unordered {
+// ANSE:               fir.do_loop %[[VAL_16:.*]] = %[[VAL_3]] to %[[VAL_6]] step %[[VAL_3]] unordered {
+// ANSE:                 %[[VAL_17:.*]]:3 = fir.box_dims %[[VAL_13]], %[[VAL_2]] : (!fir.box<!fir.array<?x?x!fir.logical<4>>>, index) -> (index, index, index)
+// ANSE:                 %[[VAL_18:.*]]:3 = fir.box_dims %[[VAL_13]], %[[VAL_3]] : (!fir.box<!fir.array<?x?x!fir.logical<4>>>, index) -> (index, index, index)
+// ANSE:                 %[[VAL_19:.*]] = arith.subi %[[VAL_17]]#0, %[[VAL_3]] : index
+// ANSE:                 %[[VAL_20:.*]] = arith.addi %[[VAL_16]], %[[VAL_19]] : index
+// ANSE:                 %[[VAL_21:.*]] = arith.subi %[[VAL_18]]#0, %[[VAL_3]] : index
+// ANSE:                 %[[VAL_22:.*]] = arith.addi %[[VAL_15]], %[[VAL_21]] : index
+// ANSE:                 %[[VAL_23:.*]] = hlfir.designate %[[VAL_13]] (%[[VAL_20]], %[[VAL_22]])  : (!fir.box<!fir.array<?x?x!fir.logical<4>>>, index, index) -> !fir.ref<!fir.logical<4>>
+// ANSE:                 hlfir.assign %[[VAL_14]] to %[[VAL_23]] : !fir.logical<4>, !fir.ref<!fir.logical<4>>
+// ANSE:               }
+// ANSE:             }
+// ANSE:             fir.do_loop %[[VAL_24:.*]] = %[[VAL_3]] to %[[VAL_7]] step %[[VAL_3]] unordered {
+// ANSE:               fir.do_loop %[[VAL_25:.*]] = %[[VAL_3]] to %[[VAL_9]] step %[[VAL_3]] unordered {
+// ANSE:                 fir.do_loop %[[VAL_26:.*]] = %[[VAL_3]] to %[[VAL_6]] step %[[VAL_3]] unordered {
+// ANSE:                   %[[VAL_27:.*]]:3 = fir.box_dims %[[VAL_13]], %[[VAL_2]] : (!fir.box<!fir.array<?x?x!fir.logical<4>>>, index) -> (index, index, index)
+// ANSE:                   %[[VAL_28:.*]]:3 = fir.box_dims %[[VAL_13]], %[[VAL_3]] : (!fir.box<!fir.array<?x?x!fir.logical<4>>>, index) -> (index, index, index)
+// ANSE:                   %[[VAL_29:.*]] = arith.subi %[[VAL_27]]#0, %[[VAL_3]] : index
+// ANSE:                   %[[VAL_30:.*]] = arith.addi %[[VAL_26]], %[[VAL_29]] : index
+// ANSE:                   %[[VAL_31:.*]] = arith.subi %[[VAL_28]]#0, %[[VAL_3]] : index
+// ANSE:                   %[[VAL_32:.*]] = arith.addi %[[VAL_25]], %[[VAL_31]] : index
+// ANSE:                   %[[VAL_33:.*]] = hlfir.designate %[[VAL_13]] (%[[VAL_30]], %[[VAL_32]])  : (!fir.box<!fir.array<?x?x!fir.logical<4>>>, index, index) -> !fir.ref<!fir.logical<4>>
+// ANSE:                   %[[VAL_34:.*]] = fir.load %[[VAL_33]] : !fir.ref<!fir.logical<4>>
+// ANSE:                   %[[VAL_35:.*]] = hlfir.apply %[[VAL_0]], %[[VAL_26]], %[[VAL_24]] : (!hlfir.expr<?x?x!fir.logical<1>>, index, index) -> !fir.logical<1>
+// ANSE:                   %[[VAL_36:.*]] = hlfir.apply %[[VAL_1]], %[[VAL_24]], %[[VAL_25]] : (!hlfir.expr<?x?x!fir.logical<4>>, index, index) -> !fir.logical<4>
+// ANSE:                   %[[VAL_37:.*]] = fir.convert %[[VAL_34]] : (!fir.logical<4>) -> i1
+// ANSE:                   %[[VAL_38:.*]] = fir.convert %[[VAL_35]] : (!fir.logical<1>) -> i1
+// ANSE:                   %[[VAL_39:.*]] = fir.convert %[[VAL_36]] : (!fir.logical<4>) -> i1
+// ANSE:                   %[[VAL_40:.*]] = arith.andi %[[VAL_38]], %[[VAL_39]] : i1
+// ANSE:                   %[[VAL_41:.*]] = arith.ori %[[VAL_37]], %[[VAL_40]] : i1
+// ANSE:                   %[[VAL_42:.*]] = fir.convert %[[VAL_41]] : (i1) -> !fir.logical<4>
+// ANSE:                   hlfir.assign %[[VAL_42]] to %[[VAL_33]] : !fir.logical<4>, !fir.ref<!fir.logical<4>>
+// ANSE:                 }
+// ANSE:               }
+// ANSE:             }
+// ANSE:           }
+// ANSE:           return %[[VAL_11]] : !hlfir.expr<?x?x!fir.logical<4>>
+// ANSE:         }
+
+// ELEMENTAL:           %[[VAL_2:.*]] = arith.constant 1 : index
+// ELEMENTAL:           %[[VAL_3:.*]] = arith.constant false
+// ELEMENTAL:           %[[VAL_4:.*]] = hlfir.shape_of %[[VAL_0]] : (!hlfir.expr<?x?x!fir.logical<1>>) -> !fir.shape<2>
+// ELEMENTAL:           %[[VAL_5:.*]] = hlfir.get_extent %[[VAL_4]] {dim = 0 : index} : (!fir.shape<2>) -> index
+// ELEMENTAL:           %[[VAL_6:.*]] = hlfir.get_extent %[[VAL_4]] {dim = 1 : index} : (!fir.shape<2>) -> index
+// ELEMENTAL:           %[[VAL_7:.*]] = hlfir.shape_of %[[VAL_1]] : (!hlfir.expr<?x?x!fir.logical<4>>) -> !fir.shape<2>
+// ELEMENTAL:           %[[VAL_8:.*]] = hlfir.get_extent %[[VAL_7]] {dim = 1 : index} : (!fir.shape<2>) -> index
+// ELEMENTAL:           %[[VAL_9:.*]] = fir.shape %[[VAL_5]], %[[VAL_8]] : (index, index) -> !fir.shape<2>
+// ELEMENTAL:           %[[VAL_10:.*]] = hlfir.elemental %[[VAL_9]] unordered : (!fir.shape<2>) -> !hlfir.expr<?x?x!fir.logical<4>> {
+// ELEMENTAL:           ^bb0(%[[VAL_11:.*]]: index, %[[VAL_12:.*]]: index):
+// ELEMENTAL:             %[[VAL_13:.*]] = fir.convert %[[VAL_3]] : (i1) -> !fir.logical<4>
+// ELEMENTAL:             %[[VAL_14:.*]] = fir.do_loop %[[VAL_15:.*]] = %[[VAL_2]] to %[[VAL_6]] step %[[VAL_2]] unordered iter_args(%[[VAL_16:.*]] = %[[VAL_13]]) -> (!fir.logical<4>) {
+// ELEMENTAL:               %[[VAL_17:.*]] = hlfir.apply %[[VAL_0]], %[[VAL_11]], %[[VAL_15]] : (!hlfir.expr<?x?x!fir.logical<1>>, index, index) -> !fir.logical<1>
+// ELEMENTAL:               %[[VAL_18:.*]] = hlfir.apply %[[VAL_1]], %[[VAL_15]], %[[VAL_12]] : (!hlfir.expr<?x?x!fir.logical<4>>, index, index) -> !fir.logical<4>
+// ELEMENTAL:               %[[VAL_19:.*]] = fir.convert %[[VAL_16]] : (!fir.logical<4>) -> i1
+// ELEMENTAL:               %[[VAL_20:.*]] = fir.convert %[[VAL_17]] : (!fir.logical<1>) -> i1
+// ELEMENTAL:               %[[VAL_21:.*]] = fir.convert %[[VAL_18]] : (!fir.logical<4>) -> i1
+// ELEMENTAL:               %[[VAL_22:.*]] = arith.andi %[[VAL_20]], %[[VAL_21]] : i1
+// ELEMENTAL:               %[[VAL_23:.*]] = arith.ori %[[VAL_19]], %[[VAL_22]] : i1
+// ELEMENTAL:               %[[VAL_24:.*]] = fir.convert %[[VAL_23]] : (i1) -> !fir.logical<4>
+// ELEMENTAL:               fir.result %[[VAL_24]] : !fir.logical<4>
+// ELEMENTAL:             }
+// ELEMENTAL:             hlfir.yield_element %[[VAL_14]] : !fir.logical<4>
+// ELEMENTAL:           }
+// ELEMENTAL:           return %[[VAL_10]] : !hlfir.expr<?x?x!fir.logical<4>>
+// ELEMENTAL:         }
+
+func.func @matmul_matrix_vector_real(%arg0: !hlfir.expr<?x?xf32>, %arg1: !hlfir.expr<?xf16>) -> !hlfir.expr<?xf32> {
+  %res = hlfir.matmul %arg0 %arg1 : (!hlfir.expr<?x?xf32>, !hlfir.expr<?xf16>) -> !hlfir.expr<?xf32>
+  return %res : !hlfir.expr<?xf32>
+}
+// ALL-LABEL:   func.func @matmul_matrix_vector_real(
+// ALL-SAME:                                         %[[VAL_0:.*]]: !hlfir.expr<?x?xf32>,
+// ALL-SAME:                                         %[[VAL_1:.*]]: !hlfir.expr<?xf16>) -> !hlfir.expr<?xf32> {
+
+// NOANSE: hlfir.matmul
+
+// ANSE:           %[[VAL_2:.*]] = arith.constant 0 : index
+// ANSE:           %[[VAL_3:.*]] = arith.constant 1 : index
+// ANSE:           %[[VAL_4:.*]] = arith.constant 0.000000e+00 : f32
+// ANSE:           %[[VAL_5:.*]] = hlfir.shape_of %[[VAL_0]] : (!hlfir.expr<?x?xf32>) -> !fir.shape<2>
+// ANSE:           %[[VAL_6:.*]] = hlfir.get_extent %[[VAL_5]] {dim = 0 : index} : (!fir.shape<2>) -> index
+// ANSE:           %[[VAL_7:.*]] = hlfir.get_extent %[[VAL_5]] {dim = 1 : index} : (!fir.shape<2>) -> index
+// ANSE:           %[[VAL_8:.*]] = fir.shape %[[VAL_6]] : (index) -> !fir.shape<1>
+// ANSE:           %[[VAL_9:.*]] = hlfir.eval_in_mem shape %[[VAL_8]] : (!fir.shape<1>) -> !hlfir.expr<?xf32> {
+// ANSE:           ^bb0(%[[VAL_10:.*]]: !fir.ref<!fir.array<?xf32>>):
+// ANSE:             %[[VAL_11:.*]] = fir.embox %[[VAL_10]](%[[VAL_8]]) : (!fir.ref<!fir.array<?xf32>>, !fir.shape<1>) -> !fir.box<!fir.array<?xf32>>
+// ANSE:             fir.do_loop %[[VAL_12:.*]] = %[[VAL_3]] to %[[VAL_6]] step %[[VAL_3]] unordered {
+// ANSE:               %[[VAL_13:.*]]:3 = fir.box_dims %[[VAL_11]], %[[VAL_2]] : (!fir.box<!fir.array<?xf32>>, index) -> (index, index, index)
+// ANSE:               %[[VAL_14:.*]] = arith.subi %[[VAL_13]]#0, %[[VAL_3]] : index
+// ANSE:               %[[VAL_15:.*]] = arith.addi %[[VAL_12]], %[[VAL_14]] : index
+// ANSE:               %[[VAL_16:.*]] = hlfir.designate %[[VAL_11]] (%[[VAL_15]])  : (!fir.box<!fir.array<?xf32>>, index) -> !fir.ref<f32>
+// ANSE:               hlfir.assign %[[VAL_4]] to %[[VAL_16]] : f32, !fir.ref<f32>
+// ANSE:             }
+// ANSE:             fir.do_loop %[[VAL_17:.*]] = %[[VAL_3]] to %[[VAL_7]] step %[[VAL_3]] {
+// ANSE:               fir.do_loop %[[VAL_18:.*]] = %[[VAL_3]] to %[[VAL_6]] step %[[VAL_3]] {
+// ANSE:                 %[[VAL_19:.*]]:3 = fir.box_dims %[[VAL_11]], %[[VAL_2]] : (!fir.box<!fir.array<?xf32>>, index) -> (index, index, index)
+// ANSE:                 %[[VAL_20:.*]] = arith.subi %[[VAL_19]]#0, %[[VAL_3]] : index
+// ANSE:                 %[[VAL_21:.*]] = arith.addi %[[VAL_18]], %[[VAL_20]] : index
+// ANSE:                 %[[VAL_22:.*]] = hlfir.designate %[[VAL_11]] (%[[VAL_21]])  : (!fir.box<!fir.array<?xf32>>, index) -> !fir.ref<f32>
+// ANSE:                 %[[VAL_23:.*]] = fir.load %[[VAL_22]] : !fir.ref<f32>
+// ANSE:                 %[[VAL_24:.*]] = hlfir.apply %[[VAL_0]], %[[VAL_18]], %[[VAL_17]] : (!hlfir.expr<?x?xf32>, index, index) -> f32
+// ANSE:                 %[[VAL_25:.*]] = hlfir.apply %[[VAL_1]], %[[VAL_17]] : (!hlfir.expr<?xf16>, index) -> f16
+// ANSE:                 %[[VAL_26:.*]] = fir.convert %[[VAL_25]] : (f16) -> f32
+// ANSE:                 %[[VAL_27:.*]] = arith.mulf %[[VAL_24]], %[[VAL_26]] : f32
+// ANSE:                 %[[VAL_28:.*]] = arith.addf %[[VAL_23]], %[[VAL_27]] : f32
+// ANSE:                 hlfir.assign %[[VAL_28]] to %[[VAL_22]] : f32, !fir.ref<f32>
+// ANSE:               }
+// ANSE:             }
+// ANSE:           }
+// ANSE:           return %[[VAL_9]] : !hlfir.expr<?xf32>
+// ANSE:         }
+
+// ELEMENTAL:           %[[VAL_2:.*]] = arith.constant 1 : index
+// ELEMENTAL:           %[[VAL_3:.*]] = arith.constant 0.000000e+00 : f32
+// ELEMENTAL:           %[[VAL_4:.*]] = hlfir.shape_of %[[VAL_0]] : (!hlfir.expr<?x?xf32>) -> !fir.shape<2>
+// ELEMENTAL:           %[[VAL_5:.*]] = hlfir.get_extent %[[VAL_4]] {dim = 0 : index} : (!fir.shape<2>) -> index
+// ELEMENTAL:           %[[VAL_6:.*]] = hlfir.get_extent %[[VAL_4]] {dim = 1 : index} : (!fir.shape<2>) -> index
+// ELEMENTAL:           %[[VAL_7:.*]] = fir.shape %[[VAL_5]] : (index) -> !fir.shape<1>
+// ELEMENTAL:           %[[VAL_8:.*]] = hlfir.elemental %[[VAL_7]] unordered : (!fir.shape<1>) -> !hlfir.expr<?xf32> {
+// ELEMENTAL:           ^bb0(%[[VAL_9:.*]]: index):
+// ELEMENTAL:             %[[VAL_10:.*]] = fir.do_loop %[[VAL_11:.*]] = %[[VAL_2]] to %[[VAL_6]] step %[[VAL_2]] iter_args(%[[VAL_12:.*]] = %[[VAL_3]]) -> (f32) {
+// ELEMENTAL:               %[[VAL_13:.*]] = hlfir.apply %[[VAL_0]], %[[VAL_9]], %[[VAL_11]] : (!hlfir.expr<?x?xf32>, index, index) -> f32
+// ELEMENTAL:               %[[VAL_14:.*]] = hlfir.apply %[[VAL_1]], %[[VAL_11]] : (!hlfir.expr<?xf16>, index) -> f16
+// ELEMENTAL:               %[[VAL_15:.*]] = fir.convert %[[VAL_14]] : (f16) -> f32
+// ELEMENTAL:               %[[VAL_16:.*]] = arith.mulf %[[VAL_13]], %[[VAL_15]] : f32
+// ELEMENTAL:               %[[VAL_17:.*]] = arith.addf %[[VAL_12]], %[[VAL_16]] : f32
+// ELEMENTAL:               fir.result %[[VAL_17]] : f32
+// ELEMENTAL:             }
+// ELEMENTAL:             hlfir.yield_element %[[VAL_10]] : f32
+// ELEMENTAL:           }
+// ELEMENTAL:           return %[[VAL_8]] : !hlfir.expr<?xf32>
+// ELEMENTAL:         }
+
+func.func @matmul_vector_matrix_real(%arg0: !hlfir.expr<?xf32>, %arg1: !hlfir.expr<?x?xf16>) -> !hlfir.expr<?xf32> {
+  %res = hlfir.matmul %arg0 %arg1 : (!hlfir.expr<?xf32>, !hlfir.expr<?x?xf16>) -> !hlfir.expr<?xf32>
+  return %res : !hlfir.expr<?xf32>
+}
+// ALL-LABEL:   func.func @matmul_vector_matrix_real(
+// ALL-SAME:                                         %[[VAL_0:.*]]: !hlfir.expr<?xf32>,
+// ALL-SAME:                                         %[[VAL_1:.*]]: !hlfir.expr<?x?xf16>) -> !hlfir.expr<?xf32> {
+
+// NOANSE: hlfir.matmul
+
+// ANSE:           %[[VAL_2:.*]] = arith.constant 0 : index
+// ANSE:           %[[VAL_3:.*]] = arith.constant 1 : index
+// ANSE:           %[[VAL_4:.*]] = arith.constant 0.000000e+00 : f32
+// ANSE:           %[[VAL_5:.*]] = hlfir.shape_of %[[VAL_0]] : (!hlfir.expr<?xf32>) -> !fir.shape<1>
+// ANSE:           %[[VAL_6:.*]] = hlfir.get_extent %[[VAL_5]] {dim = 0 : index} : (!fir.shape<1>) -> index
+// ANSE:           %[[VAL_7:.*]] = hlfir.shape_of %[[VAL_1]] : (!hlfir.expr<?x?xf16>) -> !fir.shape<2>
+// ANSE:           %[[VAL_8:.*]] = hlfir.get_extent %[[VAL_7]] {dim = 1 : index} : (!fir.shape<2>) -> index
+// ANSE:           %[[VAL_9:.*]] = fir.shape %[[VAL_8]] : (index) -> !fir.shape<1>
+// ANSE:           %[[VAL_10:.*]] = hlfir.eval_in_mem shape %[[VAL_9]] : (!fir.shape<1>) -> !hlfir.expr<?xf32> {
+// ANSE:           ^bb0(%[[VAL_11:.*]]: !fir.ref<!fir.array<?xf32>>):
+// ANSE:             %[[VAL_12:.*]] = fir.embox %[[VAL_11]](%[[VAL_9]]) : (!fir.ref<!fir.array<?xf32>>, !fir.shape<1>) -> !fir.box<!fir.array<?xf32>>
+// ANSE:             fir.do_loop %[[VAL_13:.*]] = %[[VAL_3]] to %[[VAL_8]] step %[[VAL_3]] unordered {
+// ANSE:               %[[VAL_14:.*]]:3 = fir.box_dims %[[VAL_12]], %[[VAL_2]] : (!fir.box<!fir.array<?xf32>>, index) -> (index, index, index)
+// ANSE:               %[[VAL_15:.*]] = arith.subi %[[VAL_14]]#0, %[[VAL_3]] : index
+// ANSE:               %[[VAL_16:.*]] = arith.addi %[[VAL_13]], %[[VAL_15]] : index
+// ANSE:               %[[VAL_17:.*]] = hlfir.designate %[[VAL_12]] (%[[VAL_16]])  : (!fir.box<!fir.array<?xf32>>, index) -> !fir.ref<f32>
+// ANSE:               hlfir.assign %[[VAL_4]] to %[[VAL_17]] : f32, !fir.ref<f32>
+// ANSE:             }
+// ANSE:             fir.do_loop %[[VAL_18:.*]] = %[[VAL_3]] to %[[VAL_6]] step %[[VAL_3]] {
+// ANSE:               fir.do_loop %[[VAL_19:.*]] = %[[VAL_3]] to %[[VAL_8]] step %[[VAL_3]] {
+// ANSE:                 %[[VAL_20:.*]]:3 = fir.box_dims %[[VAL_12]], %[[VAL_2]] : (!fir.box<!fir.array<?xf32>>, index) -> (index, index, index)
+// ANSE:                 %[[VAL_21:.*]] = arith.subi %[[VAL_20]]#0, %[[VAL_3]] : index
+// ANSE:                 %[[VAL_22:.*]] = arith.addi %[[VAL_19]], %[[VAL_21]] : index
+// ANSE:                 %[[VAL_23:.*]] = hlfir.designate %[[VAL_12]] (%[[VAL_22]])  : (!fir.box<!fir.array<?xf32>>, index) -> !fir.ref<f32>
+// ANSE:                 %[[VAL_24:.*]] = fir.load %[[VAL_23]] : !fir.ref<f32>
+// ANSE:                 %[[VAL_25:.*]] = hlfir.apply %[[VAL_0]], %[[VAL_18]] : (!hlfir.expr<?xf32>, index) -> f32
+// ANSE:                 %[[VAL_26:.*]] = hlfir.apply %[[VAL_1]], %[[VAL_18]], %[[VAL_19]] : (!hlfir.expr<?x?xf16>, index, index) -> f16
+// ANSE:                 %[[VAL_27:.*]] = fir.convert %[[VAL_26]] : (f16) -> f32
+// ANSE:                 %[[VAL_28:.*]] = arith.mulf %[[VAL_25]], %[[VAL_27]] : f32
+// ANSE:                 %[[VAL_29:.*]] = arith.addf %[[VAL_24]], %[[VAL_28]] : f32
+// ANSE:                 hlfir.assign %[[VAL_29]] to %[[VAL_23]] : f32, !fir.ref<f32>
+// ANSE:               }
+// ANSE:             }
+// ANSE:           }
+// ANSE:           return %[[VAL_10]] : !hlfir.expr<?xf32>
+// ANSE:         }
+
+// ELEMENTAL:           %[[VAL_2:.*]] = arith.constant 1 : index
+// ELEMENTAL:           %[[VAL_3:.*]] = arith.constant 0.000000e+00 : f32
+// ELEMENTAL:           %[[VAL_4:.*]] = hlfir.shape_of %[[VAL_0]] : (!hlfir.expr<?xf32>) -> !fir.shape<1>
+// ELEMENTAL:           %[[VAL_5:.*]] = hlfir.get_extent %[[VAL_4]] {dim = 0 : index} : (!fir.shape<1>) -> index
+// ELEMENTAL:           %[[VAL_6:.*]] = hlfir.shape_of %[[VAL_1]] : (!hlfir.expr<?x?xf16>) -> !fir.shape<2>
+// ELEMENTAL:           %[[VAL_7:.*]] = hlfir.get_extent %[[VAL_6]] {dim = 1 : index} : (!fir.shape<2>) -> index
+// ELEMENTAL:           %[[VAL_8:.*]] = fir.shape %[[VAL_7]] : (index) -> !fir.shape<1>
+// ELEMENTAL:           %[[VAL_9:.*]] = hlfir.elemental %[[VAL_8]] unordered : (!fir.shape<1>) -> !hlfir.expr<?xf32> {
+// ELEMENTAL:           ^bb0(%[[VAL_10:.*]]: index):
+// ELEMENTAL:             %[[VAL_11:.*]] = fir.do_loop %[[VAL_12:.*]] = %[[VAL_2]] to %[[VAL_5]] step %[[VAL_2]] iter_args(%[[VAL_13:.*]] = %[[VAL_3]]) -> (f32) {
+// ELEMENTAL:               %[[VAL_14:.*]] = hlfir.apply %[[VAL_0]], %[[VAL_12]] : (!hlfir.expr<?xf32>, index) -> f32
+// ELEMENTAL:               %[[VAL_15:.*]] = hlfir.apply %[[VAL_1]], %[[VAL_12]], %[[VAL_10]] : (!hlfir.expr<?x?xf16>, index, index) -> f16
+// ELEMENTAL:               %[[VAL_16:.*]] = fir.convert %[[VAL_15]] : (f16) -> f32
+// ELEMENTAL:               %[[VAL_17:.*]] = arith.mulf %[[VAL_14]], %[[VAL_16]] : f32
+// ELEMENTAL:               %[[VAL_18:.*]] = arith.addf %[[VAL_13]], %[[VAL_17]] : f32
+// ELEMENTAL:               fir.result %[[VAL_18]] : f32
+// ELEMENTAL:             }
+// ELEMENTAL:             hlfir.yield_element %[[VAL_11]] : f32
+// ELEMENTAL:           }
+// ELEMENTAL:           return %[[VAL_9]] : !hlfir.expr<?xf32>
+// ELEMENTAL:         }
+
+func.func @matmul_transpose_matrix_matrix_integer(%arg0: !hlfir.expr<?x?xi16>, %arg1: !hlfir.expr<?x?xi32>) -> !hlfir.expr<?x?xi32> {
+  %res = hlfir.matmul_transpose %arg0 %arg1 : (!hlfir.expr<?x?xi16>, !hlfir.expr<?x?xi32>) -> !hlfir.expr<?x?xi32>
+  return %res : !hlfir.expr<?x?xi32>
+}
+// ALL-LABEL:   func.func @matmul_transpose_matrix_matrix_integer(
+// ALL-SAME:                                                      %[[VAL_0:.*]]: !hlfir.expr<?x?xi16>,
+// ALL-SAME:                                                      %[[VAL_1:.*]]: !hlfir.expr<?x?xi32>) -> !hlfir.expr<?x?xi32> {
+// ALL:           %[[VAL_2:.*]] = arith.constant 1 : index
+// ALL:           %[[VAL_3:.*]] = arith.constant 0 : i32
+// ALL:           %[[VAL_4:.*]] = hlfir.shape_of %[[VAL_0]] : (!hlfir.expr<?x?xi16>) -> !fir.shape<2>
+// ALL:           %[[VAL_5:.*]] = hlfir.get_extent %[[VAL_4]] {dim = 0 : index} : (!fir.shape<2>) -> index
+// ALL:           %[[VAL_6:.*]] = hlfir.get_extent %[[VAL_4]] {dim = 1 : index} : (!fir.shape<2>) -> index
+// ALL:           %[[VAL_7:.*]] = hlfir.shape_of %[[VAL_1]] : (!hlfir.expr<?x?xi32>) -> !fir.shape<2>
+// ALL:           %[[VAL_8:.*]] = hlfir.get_extent %[[VAL_7]] {dim = 1 : index} : (!fir.shape<2>) -> index
+// ALL:           %[[VAL_9:.*]] = fir.shape %[[VAL_6]], %[[VAL_8]] : (index, index) -> !fir.shape<2>
+// ALL:           %[[VAL_10:.*]] = hlfir.elemental %[[VAL_9]] unordered : (!fir.shape<2>) -> !hlfir.expr<?x?xi32> {
+// ALL:           ^bb0(%[[VAL_11:.*]]: index, %[[VAL_12:.*]]: index):
+// ALL:             %[[VAL_13:.*]] = fir.do_loop %[[VAL_14:.*]] = %[[VAL_2]] to %[[VAL_5]] step %[[VAL_2]] unordered iter_args(%[[VAL_15:.*]] = %[[VAL_3]]) -> (i32) {
+// ALL:               %[[VAL_16:.*]] = hlfir.apply %[[VAL_0]], %[[VAL_14]], %[[VAL_11]] : (!hlfir.expr<?x?xi16>, index, index) -> i16
+// ALL:               %[[VAL_17:.*]] = hlfir.apply %[[VAL_1]], %[[VAL_14]], %[[VAL_12]] : (!hlfir.expr<?x?xi32>, index, index) -> i32
+// ALL:               %[[VAL_18:.*]] = fir.convert %[[VAL_16]] : (i16) -> i32
+// ALL:               %[[VAL_19:.*]] = arith.muli %[[VAL_18]], %[[VAL_17]] : i32
+// ALL:               %[[VAL_20:.*]] = arith.addi %[[VAL_15]], %[[VAL_19]] : i32
+// ALL:               fir.result %[[VAL_20]] : i32
+// ALL:             }
+// ALL:             hlfir.yield_element %[[VAL_13]] : i32
+// ALL:           }
+// ALL:           return %[[VAL_10]] : !hlfir.expr<?x?xi32>
+// ALL:         }
+
+func.func @matmul_transpose_matrix_vector_real(%arg0: !hlfir.expr<?x?xf32>, %arg1: !hlfir.expr<?xf16>) -> !hlfir.expr<?xf32> {
+  %res = hlfir.matmul_transpose %arg0 %arg1 : (!hlfir.expr<?x?xf32>, !hlfir.expr<?xf16>) -> !hlfir.expr<?xf32>
+  return %res : !hlfir.expr<?xf32>
+}
+// ALL-LABEL:   func.func @matmul_transpose_matrix_vector_real(
+// ALL-SAME:                                                   %[[VAL_0:.*]]: !hlfir.expr<?x?xf32>,
+// ALL-SAME:                                                   %[[VAL_1:.*]]: !hlfir.expr<?xf16>) -> !hlfir.expr<?xf32> {
+// ALL:           %[[VAL_2:.*]] = arith.constant 1 : index
+// ALL:           %[[VAL_3:.*]] = arith.constant 0.000000e+00 : f32
+// ALL:           %[[VAL_4:.*]] = hlfir.shape_of %[[VAL_0]] : (!hlfir.expr<?x?xf32>) -> !fir.shape<2>
+// ALL:           %[[VAL_5:.*]] = hlfir.get_extent %[[VAL_4]] {dim = 0 : index} : (!fir.shape<2>) -> index
+// ALL:           %[[VAL_6:.*]] = hlfir.get_extent %[[VAL_4]] {dim = 1 : index} : (!fir.shape<2>) -> index
+// ALL:           %[[VAL_7:.*]] = fir.shape %[[VAL_6]] : (index) -> !fir.shape<1>
+// ALL:           %[[VAL_8:.*]] = hlfir.elemental %[[VAL_7]] unordered : (!fir.shape<1>) -> !hlfir.expr<?xf32> {
+// ALL:           ^bb0(%[[VAL_9:.*]]: index):
+// ALL:             %[[VAL_10:.*]] = fir.do_loop %[[VAL_11:.*]] = %[[VAL_2]] to %[[VAL_5]] step %[[VAL_2]] iter_args(%[[VAL_12:.*]] = %[[VAL_3]]) -> (f32) {
+// ALL:               %[[VAL_13:.*]] = hlfir.apply %[[VAL_0]], %[[VAL_11]], %[[VAL_9]] : (!hlfir.expr<?x?xf32>, index, index) -> f32
+// ALL:               %[[VAL_14:.*]] = hlfir.apply %[[VAL_1]], %[[VAL_11]] : (!hlfir.expr<?xf16>, index) -> f16
+// ALL:               %[[VAL_15:.*]] = fir.convert %[[VAL_14]] : (f16) -> f32
+// ALL:               %[[VAL_16:.*]] = arith.mulf %[[VAL_13]], %[[VAL_15]] : f32
+// ALL:               %[[VAL_17:.*]] = arith.addf %[[VAL_12]], %[[VAL_16]] : f32
+// ALL:               fir.result %[[VAL_17]] : f32
+// ALL:             }
+// ALL:             hlfir.yield_element %[[VAL_10]] : f32
+// ALL:           }
+// ALL:           return %[[VAL_8]] : !hlfir.expr<?xf32>
+// ALL:         }
+
+// Check that the inner-product loop uses the best-known extent
+// of the input matrices (the constant 10 from either operand):
+func.func @matmul_matrix_matrix_deduce_bounds(%arg0: !hlfir.expr<?x10xi16>, %arg1: !hlfir.expr<?x?xi32>, %arg2: !hlfir.expr<10x?xi16>) -> (!hlfir.expr<?x?xi32>, !hlfir.expr<?x?xi32>) {
+  %res1 = hlfir.matmul %arg0 %arg1 : (!hlfir.expr<?x10xi16>, !hlfir.expr<?x?xi32>) -> !hlfir.expr<?x?xi32>
+  %res2 = hlfir.matmul %arg1 %arg2 : (!hlfir.expr<?x?xi32>, !hlfir.expr<10x?xi16>) -> !hlfir.expr<?x?xi32>
+  return %res1, %res2 : !hlfir.expr<?x?xi32>, !hlfir.expr<?x?xi32>
+}
+// ALL-LABEL:   func.func @matmul_matrix_matrix_deduce_bounds(
+
+// ANSE:           %[[VAL_6:.*]] = arith.constant 10 : index
+// ANSE:           hlfir.eval_in_mem shape {{.*}}
+// ANSE:             fir.do_loop
+// ANSE:               fir.do_loop
+// ANSE:             fir.do_loop %{{.*}} = %{{.*}} to %[[VAL_6]]
+// ANSE:               fir.do_loop
+// ANSE:                 fir.do_loop
+// ANSE:           hlfir.eval_in_mem shape {{.*}}
+// ANSE:             fir.do_loop
+// ANSE:               fir.do_loop
+// ANSE:             fir.do_loop %{{.*}} = %{{.*}} to %[[VAL_6]]
+// ANSE:               fir.do_loop
+// ANSE:                 fir.do_loop
+
+// ELEMENTAL:           %[[VAL_5:.*]] = arith.constant 10 : index
+// ELEMENTAL:           hlfir.elemental %{{.*}}
+// ELEMENTAL:             fir.do_loop %{{.*}} = %{{.*}} to %[[VAL_5]]
+// ELEMENTAL:           hlfir.elemental %{{.*}}
+// ELEMENTAL:             fir.do_loop %{{.*}} = %{{.*}} to %[[VAL_5]]
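The ANSE and ELEMENTAL checks above encode two different loop structures for the same product. Below is a minimal C++ sketch of both. It assumes column-major storage in flat `std::vector<int>` buffers with 0-based indices, and all function names are hypothetical, not part of flang. It also uses a single `int` element type, whereas the tests cover mixed kinds such as `i16` times `i32` with a `fir.convert` to the result type.

- The `hlfir.eval_in_mem` form zeroes the result first. It then accumulates with k outermost, then j, with i innermost, so the innermost loop walks contiguous columns of A and of the result.
- The `hlfir.elemental` form computes each result element as an independent reduction over k.
- The transposed variant mirrors `hlfir.matmul_transpose`, where the LHS is indexed as A(k, i).

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Column-major offset of element (i, j) in a matrix with `rows` rows.
inline std::size_t at(std::size_t i, std::size_t j, std::size_t rows) {
  return i + j * rows;
}

// Loop shape of the eval_in_mem expansion (hypothetical name).
// a is m x n, b is n x p, the result is m x p.
std::vector<int> matmulInMemory(const std::vector<int> &a,
                                const std::vector<int> &b, std::size_t m,
                                std::size_t n, std::size_t p) {
  std::vector<int> res(m * p, 0);         // separate zero-initialization nest
  for (std::size_t k = 0; k < n; ++k)     // A's second extent
    for (std::size_t j = 0; j < p; ++j)   // B's second extent
      for (std::size_t i = 0; i < m; ++i) // A's first extent, innermost
        res[at(i, j, m)] += a[at(i, k, m)] * b[at(k, j, n)];
  return res;
}

// Loop shape of the elemental expansion (hypothetical name):
// each (i, j) is an independent reduction over k.
std::vector<int> matmulElemental(const std::vector<int> &a,
                                 const std::vector<int> &b, std::size_t m,
                                 std::size_t n, std::size_t p) {
  std::vector<int> res(m * p);
  // The elemental's (i, j) index space; hlfir.elemental is unordered.
  for (std::size_t j = 0; j < p; ++j)
    for (std::size_t i = 0; i < m; ++i) {
      int sum = 0; // initial value passed through iter_args
      for (std::size_t k = 0; k < n; ++k)
        sum += a[at(i, k, m)] * b[at(k, j, n)];
      res[at(i, j, m)] = sum; // corresponds to hlfir.yield_element
    }
  return res;
}

// matmul_transpose (hypothetical name): res(i, j) = sum_k a(k, i) * b(k, j),
// where a is stored as n x m, b as n x p, and the result is m x p.
std::vector<int> matmulTransposeElemental(const std::vector<int> &a,
                                          const std::vector<int> &b,
                                          std::size_t m, std::size_t n,
                                          std::size_t p) {
  std::vector<int> res(m * p);
  for (std::size_t j = 0; j < p; ++j)
    for (std::size_t i = 0; i < m; ++i) {
      int sum = 0;
      for (std::size_t k = 0; k < n; ++k)
        sum += a[at(k, i, n)] * b[at(k, j, n)];
      res[at(i, j, m)] = sum;
    }
  return res;
}
```

Both non-transposed functions produce the same values. They differ only in memory access order and in whether the result buffer is updated in place, which is what distinguishes the two sets of CHECK lines.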

>From 27ec9ad13578da08c7279c016fd5434d85112ae5 Mon Sep 17 00:00:00 2001
From: Slava Zakharin <szakharin at nvidia.com>
Date: Mon, 13 Jan 2025 16:19:05 -0800
Subject: [PATCH 2/3] Removed commented-out code.

---
 .../Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp | 7 -------
 1 file changed, 7 deletions(-)

diff --git a/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp b/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp
index 0fd535b4290799..e0914e0f275fcb 100644
--- a/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp
+++ b/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp
@@ -706,7 +706,6 @@ class MatmulConversion : public mlir::OpRewritePattern<Op> {
         -> llvm::SmallVector<mlir::Value, 0> {
       hlfir::Entity resultElement =
           hlfir::getElementAt(loc, builder, result, oneBasedIndices);
-      //      builder.create<fir::StoreOp>(loc, initValue, resultElement);
       builder.create<hlfir::AssignOp>(loc, initValue, resultElement);
       return {};
     };
@@ -742,8 +741,6 @@ class MatmulConversion : public mlir::OpRewritePattern<Op> {
             loc, builder, resultElementType, resultElementValue,
             lhsElementValue, rhsElementValue);
         builder.create<hlfir::AssignOp>(loc, productValue, resultElement);
-        //        builder.create<fir::StoreOp>(loc, productValue,
-        //        resultElement);
         return {};
       };
 
@@ -781,8 +778,6 @@ class MatmulConversion : public mlir::OpRewritePattern<Op> {
             loc, builder, resultElementType, resultElementValue,
             lhsElementValue, rhsElementValue);
         builder.create<hlfir::AssignOp>(loc, productValue, resultElement);
-        //        builder.create<fir::StoreOp>(loc, productValue,
-        //        resultElement);
         return {};
       };
       hlfir::genLoopNestWithReductions(
@@ -815,8 +810,6 @@ class MatmulConversion : public mlir::OpRewritePattern<Op> {
             loc, builder, resultElementType, resultElementValue,
             lhsElementValue, rhsElementValue);
         builder.create<hlfir::AssignOp>(loc, productValue, resultElement);
-        //        builder.create<fir::StoreOp>(loc, productValue,
-        //        resultElement);
         return {};
       };
       hlfir::genLoopNestWithReductions(

>From 56cdaf4819dce6990e67b4a75aa5b3a775effc8f Mon Sep 17 00:00:00 2001
From: Slava Zakharin <szakharin at nvidia.com>
Date: Tue, 14 Jan 2025 10:13:06 -0800
Subject: [PATCH 3/3] Updated comments.

---
 .../Transforms/SimplifyHLFIRIntrinsics.cpp      | 17 ++++++++++++++---
 1 file changed, 14 insertions(+), 3 deletions(-)

diff --git a/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp b/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp
index e0914e0f275fcb..0fe3620b7f1ae3 100644
--- a/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp
+++ b/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp
@@ -534,10 +534,14 @@ class MatmulConversion : public mlir::OpRewritePattern<Op> {
     // in their lowermost dimensions.
     // Especially, when LLVM can recognize the continuity
     // and vectorize the loops properly.
-    // TODO: we need to recognize the cases when the continuity
+    // Note that the contiguous MATMUL inlining is correct
+    // even when the input arrays are not contiguous.
+    // TODO: we can try to recognize the cases when the contiguity
     // is not statically obvious and try to generate an explicitly
-    // continuous version under a dynamic check. The fallback
-    // implementation may use genElementalMatmul() with
+    // contiguous version under a dynamic check. This should allow
+    // LLVM to vectorize the loops better. Note that this can
+    // also be deferred to the LoopVersioning pass.
+    // The fallback implementation may use genElementalMatmul() with
     // an hlfir.assign into the result of eval_in_mem.
     mlir::LogicalResult rewriteResult =
         genContiguousMatmul(loc, builder, hlfir::Entity{resultArray},
@@ -629,6 +633,13 @@ class MatmulConversion : public mlir::OpRewritePattern<Op> {
     if (mlir::isa<fir::LogicalType>(type))
       return builder.createConvert(loc, builder.getIntegerType(1), value);
 
+    // TODO: the multiplications/additions by/of zero resulting from
+    // complex * real are optimized away by LLVM only under
+    // -fno-signed-zeros -fno-honor-nans.
+    // We can make them disappear by default if we either:
+    //   * expand the complex multiplication into real
+    //     operations, OR
+    //   * set the nnan and nsz fast-math flags on the complex operations.
     if (fir::isa_complex(type) && !fir::isa_complex(value.getType())) {
       mlir::Value zeroCmplx = fir::factory::createZeroValue(builder, loc, type);
       fir::factory::Complex helper(builder, loc);
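The TODO in the last hunk concerns the `fir.mulc`/`fir.addc` sequence in the complex-times-real tests, where the real operand is first widened to a complex value with a zero imaginary part. The small C++ sketch below shows why the terms involving that zero cannot simply be dropped without `nnan`/`nsz`. It assumes the textbook product (ac - bd) + (ad + bc)i with no NaN recovery, and the struct and helper names are hypothetical.

- An infinite component turns `inf * 0` into NaN.
- `+0.0 + -0.0` yields `+0.0`, where plain real scaling would keep `-0.0`.

```cpp
#include <cassert>
#include <cmath>

// Minimal complex value (hypothetical; stands in for complex<f32>).
struct Cplx {
  float re;
  float im;
};

// Textbook product (a + bi)(c + di) = (ac - bd) + (ad + bc)i,
// with no special handling of infinities or NaNs.
Cplx mulNaive(Cplx x, Cplx y) {
  return {x.re * y.re - x.im * y.im, x.re * y.im + x.im * y.re};
}

// What the promotion computes: widen the real factor to (r, 0) first.
Cplx mulByPromotedReal(Cplx x, float r) { return mulNaive(x, {r, 0.0f}); }

// What remains once the zero terms are removed; this rewrite is only
// value-preserving when NaNs and signed zeros can be ignored.
Cplx mulByRealDirect(Cplx x, float r) { return {x.re * r, x.im * r}; }
```

This illustrates why the TODO proposes either expanding the product into real operations, where the zero terms never appear, or tagging the complex operations with `nnan` and `nsz`.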


