[flang-commits] [flang] [flang] Inline hlfir.matmul[_transpose]. (PR #122821)
Slava Zakharin via flang-commits
flang-commits at lists.llvm.org
Tue Jan 14 10:13:27 PST 2025
https://github.com/vzakhari updated https://github.com/llvm/llvm-project/pull/122821
From 9e1c219361b41f878ad53218a79b4cf1554aa4cf Mon Sep 17 00:00:00 2001
From: Slava Zakharin <szakharin at nvidia.com>
Date: Mon, 13 Jan 2025 14:27:27 -0800
Subject: [PATCH 1/3] [flang] Inline hlfir.matmul[_transpose].
Inlining `hlfir.matmul` as `hlfir.eval_in_mem` does not eliminate
the temporary array in many cases, but it can still be much better
because it allows us to:
* Get rid of any overhead related to calling the runtime MATMUL
(such as descriptor creation).
* Use the CPU-specific vectorization cost model for matmul loops,
which the Fortran runtime cannot currently do.
* Optimize matmul of known-size arrays by complete unrolling.
One of the drawbacks of `hlfir.eval_in_mem` inlining is that
the ops inside it with store memory effects block the current
MLIR CSE, so I decided to run this inlining late in the pipeline.
There is a source comment explaining the CSE issue in more detail.
Straightforward inlining of `hlfir.matmul` as an `hlfir.elemental`
is not good for performance, and I got performance regressions
with it compared to the Fortran runtime implementation. I put it
under an engineering option (`-flang-inline-matmul-as-elemental`)
for experiments.
At the same time, inlining `hlfir.matmul_transpose` as `hlfir.elemental`
seems to be a good approach, e.g. it allows getting rid of a temporary
array in cases like: `A(:)=B(:)+MATMUL(TRANSPOSE(C(:,:)),D(:))`.
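To illustrate why the elemental form pays off for `MATMUL(TRANSPOSE)`, here is a plain C++ sketch (not the generated MLIR) of the fused loop nest that `A(:)=B(:)+MATMUL(TRANSPOSE(C(:,:)),D(:))` can ideally become once the elemental is inlined into the assignment. It assumes square `n`x`n` column-major storage and `double` elements; the function name `fusedMatmulTransposeAdd` is hypothetical.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// A(1:n) = B(1:n) + MATMUL(TRANSPOSE(C(1:n,1:n)), D(1:n)) as one fused loop
// nest. Column-major storage: C(k,i) is c[k + i * n] (zero-based k, i).
std::vector<double> fusedMatmulTransposeAdd(const std::vector<double> &b,
                                            const std::vector<double> &c,
                                            const std::vector<double> &d,
                                            std::size_t n) {
  assert(b.size() == n && c.size() == n * n && d.size() == n);
  std::vector<double> a(n);
  for (std::size_t i = 0; i < n; ++i) {
    // Row i of TRANSPOSE(C) is column i of C: unit-stride reads of C and D.
    double dot = 0.0;
    for (std::size_t k = 0; k < n; ++k)
      dot += c[k + i * n] * d[k];
    // The elemental addition consumes the inner product directly,
    // so no temporary array holds the MATMUL result.
    a[i] = b[i] + dot;
  }
  return a;
}
```

Because row `i` of `TRANSPOSE(C)` is column `i` of `C`, the inner product reads both operands with unit stride, and each result element is written straight into `A`.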
This patch slightly improves the performance of galgel and tonto.
---
.../flang/Optimizer/Builder/FIRBuilder.h | 9 +
.../flang/Optimizer/Builder/HLFIRTools.h | 5 +
flang/include/flang/Optimizer/HLFIR/Passes.td | 11 +
flang/lib/Optimizer/Builder/FIRBuilder.cpp | 14 +
flang/lib/Optimizer/Builder/HLFIRTools.cpp | 17 +-
.../Transforms/SimplifyHLFIRIntrinsics.cpp | 452 ++++++++++++
flang/lib/Optimizer/Passes/Pipelines.cpp | 6 +
flang/test/Driver/mlir-pass-pipeline.f90 | 4 +
flang/test/Fir/basic-program.fir | 4 +
.../simplify-hlfir-intrinsics-matmul.fir | 660 ++++++++++++++++++
10 files changed, 1179 insertions(+), 3 deletions(-)
create mode 100644 flang/test/HLFIR/simplify-hlfir-intrinsics-matmul.fir
diff --git a/flang/include/flang/Optimizer/Builder/FIRBuilder.h b/flang/include/flang/Optimizer/Builder/FIRBuilder.h
index c5d86e713f253a..ea658fb16a36c3 100644
--- a/flang/include/flang/Optimizer/Builder/FIRBuilder.h
+++ b/flang/include/flang/Optimizer/Builder/FIRBuilder.h
@@ -804,6 +804,15 @@ elideLengthsAlreadyInType(mlir::Type type, mlir::ValueRange lenParams);
/// Get the address space which should be used for allocas
uint64_t getAllocaAddressSpace(mlir::DataLayout *dataLayout);
+/// The two vectors of MLIR values have the following property:
+/// \p extents1[i] must have the same value as \p extents2[i]
+/// The function returns a new vector of MLIR values that preserves
+/// the same property vs \p extents1 and \p extents2, but allows
+/// more optimizations. For example, if extents1[j] is a known constant,
+/// and extents2[j] is not, then result[j] is the MLIR value extents1[j].
+llvm::SmallVector<mlir::Value> deduceOptimalExtents(mlir::ValueRange extents1,
+ mlir::ValueRange extents2);
+
} // namespace fir::factory
#endif // FORTRAN_OPTIMIZER_BUILDER_FIRBUILDER_H
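The selection rule documented for `deduceOptimalExtents()` can be modelled in plain C++ as follows. This is a sketch under assumptions: the `Extent` struct is a hypothetical stand-in for an `mlir::Value` that may or may not fold to a constant, and the rank-mismatch assertion is an addition of this sketch (the patch iterates with `llvm::zip`).

```cpp
#include <cassert>
#include <cstddef>
#include <optional>
#include <string>
#include <vector>

// Hypothetical stand-in for an extent mlir::Value: it either folds to a
// known constant or is only known at run time (identified by name).
struct Extent {
  std::string name;
  std::optional<long long> constant;
};

// extents1[i] and extents2[i] are known to be equal, so prefer whichever
// one is a compile-time constant, defaulting to extents1[i].
std::vector<Extent> deduceOptimalExtents(const std::vector<Extent> &extents1,
                                         const std::vector<Extent> &extents2) {
  assert(extents1.size() == extents2.size() && "rank mismatch");
  std::vector<Extent> result;
  result.reserve(extents1.size());
  for (std::size_t i = 0; i < extents1.size(); ++i) {
    if (!extents1[i].constant && extents2[i].constant)
      result.push_back(extents2[i]);
    else
      result.push_back(extents1[i]);
  }
  return result;
}
```

For example, when `extents1` is `(n, 10)` and `extents2` is `(4, m)`, the result is `(4, 10)`: each position keeps a constant whenever either side provides one, which is what later enables loop-bound optimizations such as complete unrolling.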
diff --git a/flang/include/flang/Optimizer/Builder/HLFIRTools.h b/flang/include/flang/Optimizer/Builder/HLFIRTools.h
index c8aad644bc784a..6e85b8f4ddf86e 100644
--- a/flang/include/flang/Optimizer/Builder/HLFIRTools.h
+++ b/flang/include/flang/Optimizer/Builder/HLFIRTools.h
@@ -508,6 +508,11 @@ genTypeAndKindConvert(mlir::Location loc, fir::FirOpBuilder &builder,
hlfir::Entity source, mlir::Type toType,
bool preserveLowerBounds);
+/// A shortcut for loadTrivialScalar(getElementAt()),
+/// which designates and loads an element of an array.
+Entity loadElementAt(mlir::Location loc, fir::FirOpBuilder &builder,
+ Entity entity, mlir::ValueRange oneBasedIndices);
+
} // namespace hlfir
#endif // FORTRAN_OPTIMIZER_BUILDER_HLFIRTOOLS_H
diff --git a/flang/include/flang/Optimizer/HLFIR/Passes.td b/flang/include/flang/Optimizer/HLFIR/Passes.td
index 644f1e3c3af2b8..90cf6e74241bd0 100644
--- a/flang/include/flang/Optimizer/HLFIR/Passes.td
+++ b/flang/include/flang/Optimizer/HLFIR/Passes.td
@@ -43,6 +43,17 @@ def LowerHLFIROrderedAssignments : Pass<"lower-hlfir-ordered-assignments", "::ml
def SimplifyHLFIRIntrinsics : Pass<"simplify-hlfir-intrinsics"> {
let summary = "Simplify HLFIR intrinsic operations that don't need to result in runtime calls";
+ let options = [Option<"allowNewSideEffects", "allow-new-side-effects", "bool",
+ /*default=*/"false",
+ "If enabled, then the HLFIR operations simplification "
+ "may introduce operations with side effects. "
+ "For example, hlfir.matmul may be inlined as "
+ "an hlfir.eval_in_mem with hlfir.assign inside it. "
+ "The hlfir.assign has a write effect on the memory "
+ "argument of hlfir.eval_in_mem, which may block "
+ "some existing MLIR transformations (e.g. CSE) "
+ "that otherwise would have been possible across "
+ "the hlfir.matmul.">];
}
def InlineElementals : Pass<"inline-elementals"> {
diff --git a/flang/lib/Optimizer/Builder/FIRBuilder.cpp b/flang/lib/Optimizer/Builder/FIRBuilder.cpp
index d01becfe800937..218f98ef9ef429 100644
--- a/flang/lib/Optimizer/Builder/FIRBuilder.cpp
+++ b/flang/lib/Optimizer/Builder/FIRBuilder.cpp
@@ -1740,3 +1740,17 @@ uint64_t fir::factory::getAllocaAddressSpace(mlir::DataLayout *dataLayout) {
return mlir::cast<mlir::IntegerAttr>(addrSpace).getUInt();
return 0;
}
+
+llvm::SmallVector<mlir::Value>
+fir::factory::deduceOptimalExtents(mlir::ValueRange extents1,
+ mlir::ValueRange extents2) {
+ llvm::SmallVector<mlir::Value> extents;
+ extents.reserve(extents1.size());
+ for (auto [extent1, extent2] : llvm::zip(extents1, extents2)) {
+ if (!fir::getIntIfConstant(extent1) && fir::getIntIfConstant(extent2))
+ extents.push_back(extent2);
+ else
+ extents.push_back(extent1);
+ }
+ return extents;
+}
diff --git a/flang/lib/Optimizer/Builder/HLFIRTools.cpp b/flang/lib/Optimizer/Builder/HLFIRTools.cpp
index 94238bc24e453d..5e5d0bbd681326 100644
--- a/flang/lib/Optimizer/Builder/HLFIRTools.cpp
+++ b/flang/lib/Optimizer/Builder/HLFIRTools.cpp
@@ -939,8 +939,10 @@ llvm::SmallVector<mlir::Value> hlfir::genLoopNestWithReductions(
doLoop = builder.create<fir::DoLoopOp>(loc, one, ub, one, isUnordered,
/*finalCountValue=*/false,
parentLoop.getRegionIterArgs());
- // Return the results of the child loop from its parent loop.
- builder.create<fir::ResultOp>(loc, doLoop.getResults());
+ if (!reductionInits.empty()) {
+ // Return the results of the child loop from its parent loop.
+ builder.create<fir::ResultOp>(loc, doLoop.getResults());
+ }
}
builder.setInsertionPointToStart(doLoop.getBody());
@@ -955,7 +957,8 @@ llvm::SmallVector<mlir::Value> hlfir::genLoopNestWithReductions(
reductionValues =
genBody(loc, builder, oneBasedIndices, parentLoop.getRegionIterArgs());
builder.setInsertionPointToEnd(parentLoop.getBody());
- builder.create<fir::ResultOp>(loc, reductionValues);
+ if (!reductionValues.empty())
+ builder.create<fir::ResultOp>(loc, reductionValues);
builder.setInsertionPointAfter(outerLoop);
return outerLoop->getResults();
}
@@ -1410,3 +1413,11 @@ void hlfir::computeEvaluateOpIn(mlir::Location loc, fir::FirOpBuilder &builder,
builder.clone(op, mapper);
return;
}
+
+hlfir::Entity hlfir::loadElementAt(mlir::Location loc,
+ fir::FirOpBuilder &builder,
+ hlfir::Entity entity,
+ mlir::ValueRange oneBasedIndices) {
+ return loadTrivialScalar(loc, builder,
+ getElementAt(loc, builder, entity, oneBasedIndices));
+}
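The change to `genLoopNestWithReductions()` only emits `fir.result` when the nest actually carries reduction values. A plain C++ model of that contract (a sketch, not the MLIR builder code; `runLoopNestWithReductions`, `Values` and `LoopBody` are hypothetical names) is shown here. It assumes, as the comment in `genContiguousMatmul()` and the generated `fir.do_loop` nests in the new test indicate, that the last extent drives the outermost loop and that `oneBasedIndices[0]` belongs to the first extent.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <utility>
#include <vector>

using Values = std::vector<long long>;
// The body gets the one-based indices (indices[0] belongs to extents[0])
// and the current reduction values, and returns the updated values.
using LoopBody =
    std::function<Values(const std::vector<std::size_t> &, const Values &)>;

// extents.back() drives the outermost loop, extents[0] the innermost one.
// With no reductions the nest only has side effects and returns nothing.
Values runLoopNestWithReductions(const std::vector<std::size_t> &extents,
                                 Values reductions, const LoopBody &body) {
  std::vector<std::size_t> indices(extents.size(), 1);
  std::function<void(std::size_t)> loop = [&](std::size_t dim) {
    for (std::size_t i = 1; i <= extents[dim]; ++i) {
      indices[dim] = i;
      if (dim == 0) {
        Values updated = body(indices, reductions);
        assert(updated.size() == reductions.size() &&
               "body must return one value per reduction");
        reductions = std::move(updated);
      } else {
        loop(dim - 1);
      }
    }
  };
  if (!extents.empty())
    loop(extents.size() - 1);
  return reductions;
}
```

With an empty `reductions` vector the nest runs only for its side effects (like the zero-initialization nest in `genContiguousMatmul()`) and returns nothing, which corresponds to the case where no `fir.result` must be created.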
diff --git a/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp b/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp
index 314ced8679521a..0fd535b4290799 100644
--- a/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp
+++ b/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp
@@ -28,6 +28,13 @@ namespace hlfir {
#include "flang/Optimizer/HLFIR/Passes.h.inc"
} // namespace hlfir
+#define DEBUG_TYPE "simplify-hlfir-intrinsics"
+
+static llvm::cl::opt<bool> forceMatmulAsElemental(
+ "flang-inline-matmul-as-elemental",
+ llvm::cl::desc("Expand hlfir.matmul as elemental operation"),
+ llvm::cl::init(false));
+
namespace {
class TransposeAsElementalConversion
@@ -467,9 +474,438 @@ class CShiftAsElementalConversion
}
};
+template <typename Op>
+class MatmulConversion : public mlir::OpRewritePattern<Op> {
+public:
+ using mlir::OpRewritePattern<Op>::OpRewritePattern;
+
+ llvm::LogicalResult
+ matchAndRewrite(Op matmul, mlir::PatternRewriter &rewriter) const override {
+ mlir::Location loc = matmul.getLoc();
+ fir::FirOpBuilder builder{rewriter, matmul.getOperation()};
+ hlfir::Entity lhs = hlfir::Entity{matmul.getLhs()};
+ hlfir::Entity rhs = hlfir::Entity{matmul.getRhs()};
+ mlir::Value resultShape, innerProductExtent;
+ std::tie(resultShape, innerProductExtent) =
+ genResultShape(loc, builder, lhs, rhs);
+
+ if (forceMatmulAsElemental || isMatmulTranspose) {
+ // Generate hlfir.elemental that produces the result of
+ // MATMUL/MATMUL(TRANSPOSE).
+ // Note that this implementation is very suboptimal for MATMUL,
+ // but is quite good for MATMUL(TRANSPOSE), e.g.:
+ // R(1:N) = R(1:N) + MATMUL(TRANSPOSE(X(1:N,1:N)), Y(1:N))
+ // Inlining MATMUL(TRANSPOSE) as hlfir.elemental may result
+ // in merging the inner product computation with the elemental
+ // addition. Note that the inner product computation will
+ // benefit from processing the lowermost dimensions of X and Y,
+ // which may be the best when they are contiguous.
+ //
+ // This is why we always inline MATMUL(TRANSPOSE) as an elemental.
+ // MATMUL is inlined below by default unless forceMatmulAsElemental.
+ hlfir::ExprType resultType =
+ mlir::cast<hlfir::ExprType>(matmul.getType());
+ hlfir::ElementalOp newOp = genElementalMatmul(
+ loc, builder, resultType, resultShape, lhs, rhs, innerProductExtent);
+ rewriter.replaceOp(matmul, newOp);
+ return mlir::success();
+ }
+
+ // Generate hlfir.eval_in_mem to mimic the MATMUL implementation
+ // from Fortran runtime. The implementation needs to operate
+ // with the result array as an in-memory object.
+ hlfir::EvaluateInMemoryOp evalOp =
+ builder.create<hlfir::EvaluateInMemoryOp>(
+ loc, mlir::cast<hlfir::ExprType>(matmul.getType()), resultShape);
+ builder.setInsertionPointToStart(&evalOp.getBody().front());
+
+ // Embox the raw array pointer to simplify designating it.
+ // TODO: this currently results in redundant lower bounds
+ // addition for the designator, but this should be fixed in
+ // hlfir::Entity::mayHaveNonDefaultLowerBounds().
+ mlir::Value resultArray = evalOp.getMemory();
+ mlir::Type arrayType = fir::dyn_cast_ptrEleTy(resultArray.getType());
+ resultArray = builder.createBox(loc, fir::BoxType::get(arrayType),
+ resultArray, resultShape, /*slice=*/nullptr,
+ /*lengths=*/{}, /*tdesc=*/nullptr);
+
+ // The contiguous MATMUL version is best for the cases
+ // where the input arrays and (maybe) the result are contiguous
+ // in their lowermost dimensions.
+ // This is especially true when LLVM can recognize the contiguity
+ // and vectorize the loops properly.
+ // TODO: we need to recognize the cases when the contiguity
+ // is not statically obvious and try to generate an explicitly
+ // contiguous version under a dynamic check. The fallback
+ // implementation may use genElementalMatmul() with
+ // an hlfir.assign into the result of eval_in_mem.
+ mlir::LogicalResult rewriteResult =
+ genContiguousMatmul(loc, builder, hlfir::Entity{resultArray},
+ resultShape, lhs, rhs, innerProductExtent);
+
+ if (mlir::failed(rewriteResult)) {
+ // Erase the unclaimed eval_in_mem op.
+ rewriter.eraseOp(evalOp);
+ return rewriter.notifyMatchFailure(matmul,
+ "genContiguousMatmul() failed");
+ }
+
+ rewriter.replaceOp(matmul, evalOp);
+ return mlir::success();
+ }
+
+private:
+ static constexpr bool isMatmulTranspose =
+ std::is_same_v<Op, hlfir::MatmulTransposeOp>;
+
+ // Return a tuple of:
+ // * A fir.shape operation representing the shape of the result
+ // of a MATMUL/MATMUL(TRANSPOSE).
+ // * An extent of the dimensions of the input array
+ // that are processed during the inner product computation.
+ static std::tuple<mlir::Value, mlir::Value>
+ genResultShape(mlir::Location loc, fir::FirOpBuilder &builder,
+ hlfir::Entity input1, hlfir::Entity input2) {
+ mlir::Value input1Shape = hlfir::genShape(loc, builder, input1);
+ llvm::SmallVector<mlir::Value> input1Extents =
+ hlfir::getExplicitExtentsFromShape(input1Shape, builder);
+ if (input1Shape.getUses().empty())
+ input1Shape.getDefiningOp()->erase();
+ mlir::Value input2Shape = hlfir::genShape(loc, builder, input2);
+ llvm::SmallVector<mlir::Value> input2Extents =
+ hlfir::getExplicitExtentsFromShape(input2Shape, builder);
+ if (input2Shape.getUses().empty())
+ input2Shape.getDefiningOp()->erase();
+
+ llvm::SmallVector<mlir::Value, 2> newExtents;
+ mlir::Value innerProduct1Extent, innerProduct2Extent;
+ if (input1Extents.size() == 1) {
+ assert(!isMatmulTranspose &&
+ "hlfir.matmul_transpose's first operand must be rank-2 array");
+ assert(input2Extents.size() == 2 &&
+ "hlfir.matmul second argument must be rank-2 array");
+ newExtents.push_back(input2Extents[1]);
+ innerProduct1Extent = input1Extents[0];
+ innerProduct2Extent = input2Extents[0];
+ } else {
+ if (input2Extents.size() == 1) {
+ assert(input1Extents.size() == 2 &&
+ "hlfir.matmul first argument must be rank-2 array");
+ if constexpr (isMatmulTranspose)
+ newExtents.push_back(input1Extents[1]);
+ else
+ newExtents.push_back(input1Extents[0]);
+ } else {
+ assert(input1Extents.size() == 2 && input2Extents.size() == 2 &&
+ "hlfir.matmul arguments must be rank-2 arrays");
+ if constexpr (isMatmulTranspose)
+ newExtents.push_back(input1Extents[1]);
+ else
+ newExtents.push_back(input1Extents[0]);
+
+ newExtents.push_back(input2Extents[1]);
+ }
+ if constexpr (isMatmulTranspose)
+ innerProduct1Extent = input1Extents[0];
+ else
+ innerProduct1Extent = input1Extents[1];
+
+ innerProduct2Extent = input2Extents[0];
+ }
+ // The inner product dimensions of the input arrays
+ // must match. Pick the best (e.g. constant) out of them
+ // so that the inner product loop bound can be used in
+ // optimizations.
+ llvm::SmallVector<mlir::Value> innerProductExtent =
+ fir::factory::deduceOptimalExtents({innerProduct1Extent},
+ {innerProduct2Extent});
+ return {builder.create<fir::ShapeOp>(loc, newExtents),
+ innerProductExtent[0]};
+ }
+
+ static mlir::Value castToProductType(mlir::Location loc,
+ fir::FirOpBuilder &builder,
+ mlir::Value value, mlir::Type type) {
+ if (mlir::isa<fir::LogicalType>(type))
+ return builder.createConvert(loc, builder.getIntegerType(1), value);
+
+ if (fir::isa_complex(type) && !fir::isa_complex(value.getType())) {
+ mlir::Value zeroCmplx = fir::factory::createZeroValue(builder, loc, type);
+ fir::factory::Complex helper(builder, loc);
+ mlir::Type partType = helper.getComplexPartType(type);
+ return helper.insertComplexPart(
+ zeroCmplx, castToProductType(loc, builder, value, partType),
+ /*isImagPart=*/false);
+ }
+ return builder.createConvert(loc, type, value);
+ }
+
+ // Generate an update of the inner product value:
+ // acc += v1 * v2, OR
+ // acc ||= v1 && v2
+ static mlir::Value genAccumulateProduct(mlir::Location loc,
+ fir::FirOpBuilder &builder,
+ mlir::Type resultType,
+ mlir::Value acc, mlir::Value v1,
+ mlir::Value v2) {
+ acc = castToProductType(loc, builder, acc, resultType);
+ v1 = castToProductType(loc, builder, v1, resultType);
+ v2 = castToProductType(loc, builder, v2, resultType);
+ mlir::Value result;
+ if (mlir::isa<mlir::FloatType>(resultType))
+ result = builder.create<mlir::arith::AddFOp>(
+ loc, acc, builder.create<mlir::arith::MulFOp>(loc, v1, v2));
+ else if (mlir::isa<mlir::ComplexType>(resultType))
+ result = builder.create<fir::AddcOp>(
+ loc, acc, builder.create<fir::MulcOp>(loc, v1, v2));
+ else if (mlir::isa<mlir::IntegerType>(resultType))
+ result = builder.create<mlir::arith::AddIOp>(
+ loc, acc, builder.create<mlir::arith::MulIOp>(loc, v1, v2));
+ else if (mlir::isa<fir::LogicalType>(resultType))
+ result = builder.create<mlir::arith::OrIOp>(
+ loc, acc, builder.create<mlir::arith::AndIOp>(loc, v1, v2));
+ else
+ llvm_unreachable("unsupported type");
+
+ return builder.createConvert(loc, resultType, result);
+ }
+
+ static mlir::LogicalResult
+ genContiguousMatmul(mlir::Location loc, fir::FirOpBuilder &builder,
+ hlfir::Entity result, mlir::Value resultShape,
+ hlfir::Entity lhs, hlfir::Entity rhs,
+ mlir::Value innerProductExtent) {
+ // This code does not support MATMUL(TRANSPOSE), and it is supposed
+ // to be inlined as hlfir.elemental.
+ if constexpr (isMatmulTranspose)
+ return mlir::failure();
+
+ mlir::OpBuilder::InsertionGuard guard(builder);
+ mlir::Type resultElementType = result.getFortranElementType();
+ llvm::SmallVector<mlir::Value, 2> resultExtents =
+ mlir::cast<fir::ShapeOp>(resultShape.getDefiningOp()).getExtents();
+
+ // The inner product loop may be unordered if FastMathFlags::reassoc
+ // transformations are allowed. The integer/logical inner product is
+ // always unordered.
+ // Note that isUnordered is currently applied to all loops
+ // in the loop nests generated below, while it has to be applied
+ // only to one.
+ bool isUnordered = mlir::isa<mlir::IntegerType>(resultElementType) ||
+ mlir::isa<fir::LogicalType>(resultElementType) ||
+ static_cast<bool>(builder.getFastMathFlags() &
+ mlir::arith::FastMathFlags::reassoc);
+
+ // Insert the initialization loop nest that fills the whole result with
+ // zeroes.
+ mlir::Value initValue =
+ fir::factory::createZeroValue(builder, loc, resultElementType);
+ auto genInitBody = [&](mlir::Location loc, fir::FirOpBuilder &builder,
+ mlir::ValueRange oneBasedIndices,
+ mlir::ValueRange reductionArgs)
+ -> llvm::SmallVector<mlir::Value, 0> {
+ hlfir::Entity resultElement =
+ hlfir::getElementAt(loc, builder, result, oneBasedIndices);
+ // builder.create<fir::StoreOp>(loc, initValue, resultElement);
+ builder.create<hlfir::AssignOp>(loc, initValue, resultElement);
+ return {};
+ };
+
+ hlfir::genLoopNestWithReductions(loc, builder, resultExtents,
+ /*reductionInits=*/{}, genInitBody,
+ /*isUnordered=*/true);
+
+ if (lhs.getRank() == 2 && rhs.getRank() == 2) {
+ // LHS(NROWS,N) * RHS(N,NCOLS) -> RESULT(NROWS,NCOLS)
+ //
+ // Insert the computation loop nest:
+ // DO 2 K = 1, N
+ // DO 2 J = 1, NCOLS
+ // DO 2 I = 1, NROWS
+ // 2 RESULT(I,J) = RESULT(I,J) + LHS(I,K)*RHS(K,J)
+ auto genMatrixMatrix = [&](mlir::Location loc, fir::FirOpBuilder &builder,
+ mlir::ValueRange oneBasedIndices,
+ mlir::ValueRange reductionArgs)
+ -> llvm::SmallVector<mlir::Value, 0> {
+ mlir::Value I = oneBasedIndices[0];
+ mlir::Value J = oneBasedIndices[1];
+ mlir::Value K = oneBasedIndices[2];
+ hlfir::Entity resultElement =
+ hlfir::getElementAt(loc, builder, result, {I, J});
+ hlfir::Entity resultElementValue =
+ hlfir::loadTrivialScalar(loc, builder, resultElement);
+ hlfir::Entity lhsElementValue =
+ hlfir::loadElementAt(loc, builder, lhs, {I, K});
+ hlfir::Entity rhsElementValue =
+ hlfir::loadElementAt(loc, builder, rhs, {K, J});
+ mlir::Value productValue = genAccumulateProduct(
+ loc, builder, resultElementType, resultElementValue,
+ lhsElementValue, rhsElementValue);
+ builder.create<hlfir::AssignOp>(loc, productValue, resultElement);
+ // builder.create<fir::StoreOp>(loc, productValue,
+ // resultElement);
+ return {};
+ };
+
+ // Note that the loops are inserted in reverse order,
+ // so innerProductExtent should be passed as the last extent.
+ hlfir::genLoopNestWithReductions(
+ loc, builder,
+ {resultExtents[0], resultExtents[1], innerProductExtent},
+ /*reductionInits=*/{}, genMatrixMatrix, isUnordered);
+ return mlir::success();
+ }
+
+ if (lhs.getRank() == 2 && rhs.getRank() == 1) {
+ // LHS(NROWS,N) * RHS(N) -> RESULT(NROWS)
+ //
+ // Insert the computation loop nest:
+ // DO 2 K = 1, N
+ // DO 2 J = 1, NROWS
+ // 2 RES(J) = RES(J) + LHS(J,K)*RHS(K)
+ auto genMatrixVector = [&](mlir::Location loc, fir::FirOpBuilder &builder,
+ mlir::ValueRange oneBasedIndices,
+ mlir::ValueRange reductionArgs)
+ -> llvm::SmallVector<mlir::Value, 0> {
+ mlir::Value J = oneBasedIndices[0];
+ mlir::Value K = oneBasedIndices[1];
+ hlfir::Entity resultElement =
+ hlfir::getElementAt(loc, builder, result, {J});
+ hlfir::Entity resultElementValue =
+ hlfir::loadTrivialScalar(loc, builder, resultElement);
+ hlfir::Entity lhsElementValue =
+ hlfir::loadElementAt(loc, builder, lhs, {J, K});
+ hlfir::Entity rhsElementValue =
+ hlfir::loadElementAt(loc, builder, rhs, {K});
+ mlir::Value productValue = genAccumulateProduct(
+ loc, builder, resultElementType, resultElementValue,
+ lhsElementValue, rhsElementValue);
+ builder.create<hlfir::AssignOp>(loc, productValue, resultElement);
+ // builder.create<fir::StoreOp>(loc, productValue,
+ // resultElement);
+ return {};
+ };
+ hlfir::genLoopNestWithReductions(
+ loc, builder, {resultExtents[0], innerProductExtent},
+ /*reductionInits=*/{}, genMatrixVector, isUnordered);
+ return mlir::success();
+ }
+ if (lhs.getRank() == 1 && rhs.getRank() == 2) {
+ // LHS(N) * RHS(N,NCOLS) -> RESULT(NCOLS)
+ //
+ // Insert the computation loop nest:
+ // DO 2 K = 1, N
+ // DO 2 J = 1, NCOLS
+ // 2 RES(J) = RES(J) + LHS(K)*RHS(K,J)
+ auto genVectorMatrix = [&](mlir::Location loc, fir::FirOpBuilder &builder,
+ mlir::ValueRange oneBasedIndices,
+ mlir::ValueRange reductionArgs)
+ -> llvm::SmallVector<mlir::Value, 0> {
+ mlir::Value J = oneBasedIndices[0];
+ mlir::Value K = oneBasedIndices[1];
+ hlfir::Entity resultElement =
+ hlfir::getElementAt(loc, builder, result, {J});
+ hlfir::Entity resultElementValue =
+ hlfir::loadTrivialScalar(loc, builder, resultElement);
+ hlfir::Entity lhsElementValue =
+ hlfir::loadElementAt(loc, builder, lhs, {K});
+ hlfir::Entity rhsElementValue =
+ hlfir::loadElementAt(loc, builder, rhs, {K, J});
+ mlir::Value productValue = genAccumulateProduct(
+ loc, builder, resultElementType, resultElementValue,
+ lhsElementValue, rhsElementValue);
+ builder.create<hlfir::AssignOp>(loc, productValue, resultElement);
+ // builder.create<fir::StoreOp>(loc, productValue,
+ // resultElement);
+ return {};
+ };
+ hlfir::genLoopNestWithReductions(
+ loc, builder, {resultExtents[0], innerProductExtent},
+ /*reductionInits=*/{}, genVectorMatrix, isUnordered);
+ return mlir::success();
+ }
+
+ llvm_unreachable("unsupported MATMUL arguments' ranks");
+ }
+
+ static hlfir::ElementalOp
+ genElementalMatmul(mlir::Location loc, fir::FirOpBuilder &builder,
+ hlfir::ExprType resultType, mlir::Value resultShape,
+ hlfir::Entity lhs, hlfir::Entity rhs,
+ mlir::Value innerProductExtent) {
+ mlir::OpBuilder::InsertionGuard guard(builder);
+ mlir::Type resultElementType = resultType.getElementType();
+ auto genKernel = [&](mlir::Location loc, fir::FirOpBuilder &builder,
+ mlir::ValueRange resultIndices) -> hlfir::Entity {
+ mlir::Value initValue =
+ fir::factory::createZeroValue(builder, loc, resultElementType);
+ // The inner product loop may be unordered if FastMathFlags::reassoc
+ // transformations are allowed. The integer/logical inner product is
+ // always unordered.
+ bool isUnordered = mlir::isa<mlir::IntegerType>(resultElementType) ||
+ mlir::isa<fir::LogicalType>(resultElementType) ||
+ static_cast<bool>(builder.getFastMathFlags() &
+ mlir::arith::FastMathFlags::reassoc);
+
+ auto genBody = [&](mlir::Location loc, fir::FirOpBuilder &builder,
+ mlir::ValueRange oneBasedIndices,
+ mlir::ValueRange reductionArgs)
+ -> llvm::SmallVector<mlir::Value, 1> {
+ llvm::SmallVector<mlir::Value, 2> lhsIndices;
+ llvm::SmallVector<mlir::Value, 2> rhsIndices;
+ // MATMUL:
+ // LHS(NROWS,N) * RHS(N,NCOLS) -> RESULT(NROWS,NCOLS)
+ // LHS(NROWS,N) * RHS(N) -> RESULT(NROWS)
+ // LHS(N) * RHS(N,NCOLS) -> RESULT(NCOLS)
+ //
+ // MATMUL(TRANSPOSE):
+ // TRANSPOSE(LHS(N,NROWS)) * RHS(N,NCOLS) -> RESULT(NROWS,NCOLS)
+ // TRANSPOSE(LHS(N,NROWS)) * RHS(N) -> RESULT(NROWS)
+ //
+ // The resultIndices iterate over (NROWS[,NCOLS]).
+ // The oneBasedIndices iterate over (N).
+ if (lhs.getRank() > 1)
+ lhsIndices.push_back(resultIndices[0]);
+ lhsIndices.push_back(oneBasedIndices[0]);
+
+ if constexpr (isMatmulTranspose) {
+ // Swap the LHS indices for TRANSPOSE.
+ std::swap(lhsIndices[0], lhsIndices[1]);
+ }
+
+ rhsIndices.push_back(oneBasedIndices[0]);
+ if (rhs.getRank() > 1)
+ rhsIndices.push_back(resultIndices.back());
+
+ hlfir::Entity lhsElementValue =
+ hlfir::loadElementAt(loc, builder, lhs, lhsIndices);
+ hlfir::Entity rhsElementValue =
+ hlfir::loadElementAt(loc, builder, rhs, rhsIndices);
+ mlir::Value productValue = genAccumulateProduct(
+ loc, builder, resultElementType, reductionArgs[0], lhsElementValue,
+ rhsElementValue);
+ return {productValue};
+ };
+ llvm::SmallVector<mlir::Value, 1> innerProductValue =
+ hlfir::genLoopNestWithReductions(loc, builder, {innerProductExtent},
+ {initValue}, genBody, isUnordered);
+ return hlfir::Entity{innerProductValue[0]};
+ };
+ hlfir::ElementalOp elementalOp = hlfir::genElementalOp(
+ loc, builder, resultElementType, resultShape, /*typeParams=*/{},
+ genKernel,
+ /*isUnordered=*/true, /*polymorphicMold=*/nullptr, resultType);
+
+ return elementalOp;
+ }
+};
+
class SimplifyHLFIRIntrinsics
: public hlfir::impl::SimplifyHLFIRIntrinsicsBase<SimplifyHLFIRIntrinsics> {
public:
+ using SimplifyHLFIRIntrinsicsBase<
+ SimplifyHLFIRIntrinsics>::SimplifyHLFIRIntrinsicsBase;
+
void runOnOperation() override {
mlir::MLIRContext *context = &getContext();
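For reference, the loop order that `genContiguousMatmul()` produces for the rank-2 by rank-2 case can be written as the following plain C++ sketch. It assumes column-major storage and models Fortran LOGICAL as `bool`; `accumulateProduct` and `contiguousMatmul` are hypothetical names, and the conversions performed by `castToProductType()` (e.g. mixed integer kinds or real-to-complex promotion) are not modelled.

```cpp
#include <cassert>
#include <cstddef>
#include <type_traits>
#include <vector>

// acc + v1 * v2 for numeric types, acc || (v1 && v2) for LOGICAL
// (modelled as bool): the two update forms of genAccumulateProduct().
template <typename T> T accumulateProduct(T acc, T v1, T v2) {
  if constexpr (std::is_same_v<T, bool>)
    return acc || (v1 && v2);
  else
    return acc + v1 * v2;
}

// LHS(NROWS,N) * RHS(N,NCOLS) -> RESULT(NROWS,NCOLS), column-major:
// element (i,j) of an array with `rows` rows lives at [i + j * rows].
template <typename T>
std::vector<T> contiguousMatmul(const std::vector<T> &lhs,
                                const std::vector<T> &rhs, std::size_t nrows,
                                std::size_t n, std::size_t ncols) {
  assert(lhs.size() == nrows * n && rhs.size() == n * ncols);
  // Initialization: fill the whole result with zeroes (false for bool).
  std::vector<T> result(nrows * ncols, T{});
  // Computation: K outermost, J in the middle, I innermost.
  for (std::size_t k = 0; k < n; ++k)
    for (std::size_t j = 0; j < ncols; ++j)
      for (std::size_t i = 0; i < nrows; ++i)
        result[i + j * nrows] = accumulateProduct<T>(
            result[i + j * nrows], lhs[i + k * nrows], rhs[k + j * n]);
  return result;
}
```

Keeping `K` outermost and `I` innermost means the innermost loop walks a column of `LHS` and of `RESULT` with unit stride, which is what gives LLVM a chance to vectorize it.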
@@ -482,6 +918,22 @@ class SimplifyHLFIRIntrinsics
patterns.insert<TransposeAsElementalConversion>(context);
patterns.insert<SumAsElementalConversion>(context);
patterns.insert<CShiftAsElementalConversion>(context);
+ patterns.insert<MatmulConversion<hlfir::MatmulTransposeOp>>(context);
+
+ // If forceMatmulAsElemental is false, then hlfir.matmul inlining
+ // will introduce an hlfir.eval_in_mem operation with new memory side
+ // effects. This conflicts with CSE and optimized bufferization, e.g.:
+ // A(1:N,1:N) = A(1:N,1:N) - MATMUL(...)
+ // If we introduce hlfir.eval_in_mem before CSE, then the current
+ // MLIR CSE won't be able to optimize the trivial loads of 'N' value
+ // that happen before and after hlfir.matmul.
+ // If 'N' loads are not optimized, then the optimized bufferization
+ // won't be able to prove that the slices of A are identical
+ // on both sides of the assignment.
+ // This is actually a CSE problem, but we can work around it
+ // for the time being.
+ if (forceMatmulAsElemental || this->allowNewSideEffects)
+ patterns.insert<MatmulConversion<hlfir::MatmulOp>>(context);
if (mlir::failed(mlir::applyPatternsGreedily(
getOperation(), std::move(patterns), config))) {
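The elemental strategy of `genElementalMatmul()` (always used for `hlfir.matmul_transpose`, and for `hlfir.matmul` only under `-flang-inline-matmul-as-elemental`) computes every result element as an independent inner product. Below is a plain C++ sketch of its index mapping, including the result-shape rules from `genResultShape()`; the `Array` struct and `elementalMatmul` are hypothetical, storage is assumed column-major with `double` elements, and indices are zero-based.

```cpp
#include <cassert>
#include <cstddef>
#include <utility>
#include <vector>

// Hypothetical column-major array of rank 1 or 2.
struct Array {
  std::vector<std::size_t> extents;
  std::vector<double> data;
  // Zero-based access: A(i) for rank 1, A(i,j) for rank 2.
  double at(const std::vector<std::size_t> &idx) const {
    return idx.size() == 1 ? data[idx[0]] : data[idx[0] + idx[1] * extents[0]];
  }
};

// MATMUL(LHS, RHS), or MATMUL(TRANSPOSE(LHS), RHS) when transposeLhs is set,
// computed element by element like the hlfir.elemental kernel.
Array elementalMatmul(const Array &lhs, const Array &rhs, bool transposeLhs) {
  const std::size_t lhsRank = lhs.extents.size();
  const std::size_t rhsRank = rhs.extents.size();
  assert((!transposeLhs || lhsRank == 2) && "TRANSPOSE needs a rank-2 LHS");
  assert((lhsRank == 2 || rhsRank == 2) && "need a rank-2 argument");

  // Result shape (NROWS[,NCOLS]), as in genResultShape().
  Array result;
  if (lhsRank == 2)
    result.extents.push_back(transposeLhs ? lhs.extents[1] : lhs.extents[0]);
  if (rhsRank == 2)
    result.extents.push_back(rhs.extents[1]);
  // Inner-product extent N; both arguments must agree on it.
  const std::size_t n =
      (lhsRank == 1 || transposeLhs) ? lhs.extents[0] : lhs.extents[1];
  assert(rhs.extents[0] == n && "inner product extents must match");

  std::size_t size = 1;
  for (std::size_t e : result.extents)
    size *= e;
  result.data.resize(size);

  for (std::size_t linear = 0; linear < size; ++linear) {
    // Column-major linear index -> result indices.
    std::vector<std::size_t> resIdx{linear % result.extents[0]};
    if (result.extents.size() == 2)
      resIdx.push_back(linear / result.extents[0]);

    double acc = 0.0;
    for (std::size_t k = 0; k < n; ++k) {
      std::vector<std::size_t> lhsIdx, rhsIdx;
      if (lhsRank > 1)
        lhsIdx.push_back(resIdx[0]);
      lhsIdx.push_back(k);
      if (transposeLhs)
        std::swap(lhsIdx[0], lhsIdx[1]); // TRANSPOSE swaps the LHS indices.
      rhsIdx.push_back(k);
      if (rhsRank > 1)
        rhsIdx.push_back(resIdx.back());
      acc += lhs.at(lhsIdx) * rhs.at(rhsIdx);
    }
    result.data[linear] = acc;
  }
  return result;
}
```

For plain MATMUL each inner product walks a row of `LHS`, i.e. with a stride of NROWS, which is one plausible reason for the regressions mentioned in the commit message; with `transposeLhs` the same loop walks a column of `LHS` with unit stride.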
diff --git a/flang/lib/Optimizer/Passes/Pipelines.cpp b/flang/lib/Optimizer/Passes/Pipelines.cpp
index e1d7376ec3805d..1cc3f0b81c20ad 100644
--- a/flang/lib/Optimizer/Passes/Pipelines.cpp
+++ b/flang/lib/Optimizer/Passes/Pipelines.cpp
@@ -232,6 +232,12 @@ void createHLFIRToFIRPassPipeline(mlir::PassManager &pm, bool enableOpenMP,
if (optLevel.isOptimizingForSpeed()) {
addCanonicalizerPassWithoutRegionSimplification(pm);
pm.addPass(mlir::createCSEPass());
+ // Run SimplifyHLFIRIntrinsics pass late after CSE,
+ // and allow introducing operations with new side effects.
+ addNestedPassToAllTopLevelOperations<PassConstructor>(pm, []() {
+ return hlfir::createSimplifyHLFIRIntrinsics(
+ {/*allowNewSideEffects=*/true});
+ });
addNestedPassToAllTopLevelOperations<PassConstructor>(
pm, hlfir::createOptimizedBufferization);
addNestedPassToAllTopLevelOperations<PassConstructor>(
diff --git a/flang/test/Driver/mlir-pass-pipeline.f90 b/flang/test/Driver/mlir-pass-pipeline.f90
index 55e86da2dfdf14..dd46aecb3274c1 100644
--- a/flang/test/Driver/mlir-pass-pipeline.f90
+++ b/flang/test/Driver/mlir-pass-pipeline.f90
@@ -35,15 +35,19 @@
! O2-NEXT: (S) {{.*}} num-dce'd
! O2-NEXT: Pipeline Collection : ['fir.global', 'func.func', 'omp.declare_reduction', 'omp.private']
! O2-NEXT: 'fir.global' Pipeline
+! O2-NEXT: SimplifyHLFIRIntrinsics
! O2-NEXT: OptimizedBufferization
! O2-NEXT: InlineHLFIRAssign
! O2-NEXT: 'func.func' Pipeline
+! O2-NEXT: SimplifyHLFIRIntrinsics
! O2-NEXT: OptimizedBufferization
! O2-NEXT: InlineHLFIRAssign
! O2-NEXT: 'omp.declare_reduction' Pipeline
+! O2-NEXT: SimplifyHLFIRIntrinsics
! O2-NEXT: OptimizedBufferization
! O2-NEXT: InlineHLFIRAssign
! O2-NEXT: 'omp.private' Pipeline
+! O2-NEXT: SimplifyHLFIRIntrinsics
! O2-NEXT: OptimizedBufferization
! O2-NEXT: InlineHLFIRAssign
! ALL: LowerHLFIROrderedAssignments
diff --git a/flang/test/Fir/basic-program.fir b/flang/test/Fir/basic-program.fir
index 29a0f661579710..51e68d2157631a 100644
--- a/flang/test/Fir/basic-program.fir
+++ b/flang/test/Fir/basic-program.fir
@@ -36,15 +36,19 @@ func.func @_QQmain() {
// PASSES-NEXT: (S) 0 num-dce'd - Number of operations DCE'd
// PASSES-NEXT: Pipeline Collection : ['fir.global', 'func.func', 'omp.declare_reduction', 'omp.private']
// PASSES-NEXT: 'fir.global' Pipeline
+// PASSES-NEXT: SimplifyHLFIRIntrinsics
// PASSES-NEXT: OptimizedBufferization
// PASSES-NEXT: InlineHLFIRAssign
// PASSES-NEXT: 'func.func' Pipeline
+// PASSES-NEXT: SimplifyHLFIRIntrinsics
// PASSES-NEXT: OptimizedBufferization
// PASSES-NEXT: InlineHLFIRAssign
// PASSES-NEXT: 'omp.declare_reduction' Pipeline
+// PASSES-NEXT: SimplifyHLFIRIntrinsics
// PASSES-NEXT: OptimizedBufferization
// PASSES-NEXT: InlineHLFIRAssign
// PASSES-NEXT: 'omp.private' Pipeline
+// PASSES-NEXT: SimplifyHLFIRIntrinsics
// PASSES-NEXT: OptimizedBufferization
// PASSES-NEXT: InlineHLFIRAssign
// PASSES-NEXT: LowerHLFIROrderedAssignments
diff --git a/flang/test/HLFIR/simplify-hlfir-intrinsics-matmul.fir b/flang/test/HLFIR/simplify-hlfir-intrinsics-matmul.fir
new file mode 100644
index 00000000000000..d29e9a26c20ba9
--- /dev/null
+++ b/flang/test/HLFIR/simplify-hlfir-intrinsics-matmul.fir
@@ -0,0 +1,660 @@
+// Test hlfir.matmul simplification to hlfir.eval_in_mem and hlfir.elemental:
+// RUN: fir-opt --simplify-hlfir-intrinsics=allow-new-side-effects=false %s | FileCheck %s --check-prefixes=ALL,NOANSE
+// RUN: fir-opt --simplify-hlfir-intrinsics=allow-new-side-effects=true %s | FileCheck %s --check-prefixes=ALL,ANSE
+// RUN: fir-opt --simplify-hlfir-intrinsics -flang-inline-matmul-as-elemental %s | FileCheck %s --check-prefixes=ALL,ELEMENTAL
+
+func.func @matmul_matrix_matrix_integer(%arg0: !hlfir.expr<?x?xi16>, %arg1: !hlfir.expr<?x?xi32>) -> !hlfir.expr<?x?xi32> {
+ %res = hlfir.matmul %arg0 %arg1 : (!hlfir.expr<?x?xi16>, !hlfir.expr<?x?xi32>) -> !hlfir.expr<?x?xi32>
+ return %res : !hlfir.expr<?x?xi32>
+}
+// ALL-LABEL: func.func @matmul_matrix_matrix_integer(
+// ALL-SAME: %[[VAL_0:.*]]: !hlfir.expr<?x?xi16>,
+// ALL-SAME: %[[VAL_1:.*]]: !hlfir.expr<?x?xi32>) -> !hlfir.expr<?x?xi32> {
+
+// NOANSE: hlfir.matmul
+
+// ANSE: %[[VAL_2:.*]] = arith.constant 0 : index
+// ANSE: %[[VAL_3:.*]] = arith.constant 1 : index
+// ANSE: %[[VAL_4:.*]] = arith.constant 0 : i32
+// ANSE: %[[VAL_5:.*]] = hlfir.shape_of %[[VAL_0]] : (!hlfir.expr<?x?xi16>) -> !fir.shape<2>
+// ANSE: %[[VAL_6:.*]] = hlfir.get_extent %[[VAL_5]] {dim = 0 : index} : (!fir.shape<2>) -> index
+// ANSE: %[[VAL_7:.*]] = hlfir.get_extent %[[VAL_5]] {dim = 1 : index} : (!fir.shape<2>) -> index
+// ANSE: %[[VAL_8:.*]] = hlfir.shape_of %[[VAL_1]] : (!hlfir.expr<?x?xi32>) -> !fir.shape<2>
+// ANSE: %[[VAL_9:.*]] = hlfir.get_extent %[[VAL_8]] {dim = 1 : index} : (!fir.shape<2>) -> index
+// ANSE: %[[VAL_10:.*]] = fir.shape %[[VAL_6]], %[[VAL_9]] : (index, index) -> !fir.shape<2>
+// ANSE: %[[VAL_11:.*]] = hlfir.eval_in_mem shape %[[VAL_10]] : (!fir.shape<2>) -> !hlfir.expr<?x?xi32> {
+// ANSE: ^bb0(%[[VAL_12:.*]]: !fir.ref<!fir.array<?x?xi32>>):
+// ANSE: %[[VAL_13:.*]] = fir.embox %[[VAL_12]](%[[VAL_10]]) : (!fir.ref<!fir.array<?x?xi32>>, !fir.shape<2>) -> !fir.box<!fir.array<?x?xi32>>
+// ANSE: fir.do_loop %[[VAL_14:.*]] = %[[VAL_3]] to %[[VAL_9]] step %[[VAL_3]] unordered {
+// ANSE: fir.do_loop %[[VAL_15:.*]] = %[[VAL_3]] to %[[VAL_6]] step %[[VAL_3]] unordered {
+// ANSE: %[[VAL_16:.*]]:3 = fir.box_dims %[[VAL_13]], %[[VAL_2]] : (!fir.box<!fir.array<?x?xi32>>, index) -> (index, index, index)
+// ANSE: %[[VAL_17:.*]]:3 = fir.box_dims %[[VAL_13]], %[[VAL_3]] : (!fir.box<!fir.array<?x?xi32>>, index) -> (index, index, index)
+// ANSE: %[[VAL_18:.*]] = arith.subi %[[VAL_16]]#0, %[[VAL_3]] : index
+// ANSE: %[[VAL_19:.*]] = arith.addi %[[VAL_15]], %[[VAL_18]] : index
+// ANSE: %[[VAL_20:.*]] = arith.subi %[[VAL_17]]#0, %[[VAL_3]] : index
+// ANSE: %[[VAL_21:.*]] = arith.addi %[[VAL_14]], %[[VAL_20]] : index
+// ANSE: %[[VAL_22:.*]] = hlfir.designate %[[VAL_13]] (%[[VAL_19]], %[[VAL_21]]) : (!fir.box<!fir.array<?x?xi32>>, index, index) -> !fir.ref<i32>
+// ANSE: hlfir.assign %[[VAL_4]] to %[[VAL_22]] : i32, !fir.ref<i32>
+// ANSE: }
+// ANSE: }
+// ANSE: fir.do_loop %[[VAL_23:.*]] = %[[VAL_3]] to %[[VAL_7]] step %[[VAL_3]] unordered {
+// ANSE: fir.do_loop %[[VAL_24:.*]] = %[[VAL_3]] to %[[VAL_9]] step %[[VAL_3]] unordered {
+// ANSE: fir.do_loop %[[VAL_25:.*]] = %[[VAL_3]] to %[[VAL_6]] step %[[VAL_3]] unordered {
+// ANSE: %[[VAL_26:.*]]:3 = fir.box_dims %[[VAL_13]], %[[VAL_2]] : (!fir.box<!fir.array<?x?xi32>>, index) -> (index, index, index)
+// ANSE: %[[VAL_27:.*]]:3 = fir.box_dims %[[VAL_13]], %[[VAL_3]] : (!fir.box<!fir.array<?x?xi32>>, index) -> (index, index, index)
+// ANSE: %[[VAL_28:.*]] = arith.subi %[[VAL_26]]#0, %[[VAL_3]] : index
+// ANSE: %[[VAL_29:.*]] = arith.addi %[[VAL_25]], %[[VAL_28]] : index
+// ANSE: %[[VAL_30:.*]] = arith.subi %[[VAL_27]]#0, %[[VAL_3]] : index
+// ANSE: %[[VAL_31:.*]] = arith.addi %[[VAL_24]], %[[VAL_30]] : index
+// ANSE: %[[VAL_32:.*]] = hlfir.designate %[[VAL_13]] (%[[VAL_29]], %[[VAL_31]]) : (!fir.box<!fir.array<?x?xi32>>, index, index) -> !fir.ref<i32>
+// ANSE: %[[VAL_33:.*]] = fir.load %[[VAL_32]] : !fir.ref<i32>
+// ANSE: %[[VAL_34:.*]] = hlfir.apply %[[VAL_0]], %[[VAL_25]], %[[VAL_23]] : (!hlfir.expr<?x?xi16>, index, index) -> i16
+// ANSE: %[[VAL_35:.*]] = hlfir.apply %[[VAL_1]], %[[VAL_23]], %[[VAL_24]] : (!hlfir.expr<?x?xi32>, index, index) -> i32
+// ANSE: %[[VAL_36:.*]] = fir.convert %[[VAL_34]] : (i16) -> i32
+// ANSE: %[[VAL_37:.*]] = arith.muli %[[VAL_36]], %[[VAL_35]] : i32
+// ANSE: %[[VAL_38:.*]] = arith.addi %[[VAL_33]], %[[VAL_37]] : i32
+// ANSE: hlfir.assign %[[VAL_38]] to %[[VAL_32]] : i32, !fir.ref<i32>
+// ANSE: }
+// ANSE: }
+// ANSE: }
+// ANSE: }
+// ANSE: return %[[VAL_11]] : !hlfir.expr<?x?xi32>
+// ANSE: }
+
+// ELEMENTAL: %[[VAL_2:.*]] = arith.constant 1 : index
+// ELEMENTAL: %[[VAL_3:.*]] = arith.constant 0 : i32
+// ELEMENTAL: %[[VAL_4:.*]] = hlfir.shape_of %[[VAL_0]] : (!hlfir.expr<?x?xi16>) -> !fir.shape<2>
+// ELEMENTAL: %[[VAL_5:.*]] = hlfir.get_extent %[[VAL_4]] {dim = 0 : index} : (!fir.shape<2>) -> index
+// ELEMENTAL: %[[VAL_6:.*]] = hlfir.get_extent %[[VAL_4]] {dim = 1 : index} : (!fir.shape<2>) -> index
+// ELEMENTAL: %[[VAL_7:.*]] = hlfir.shape_of %[[VAL_1]] : (!hlfir.expr<?x?xi32>) -> !fir.shape<2>
+// ELEMENTAL: %[[VAL_8:.*]] = hlfir.get_extent %[[VAL_7]] {dim = 1 : index} : (!fir.shape<2>) -> index
+// ELEMENTAL: %[[VAL_9:.*]] = fir.shape %[[VAL_5]], %[[VAL_8]] : (index, index) -> !fir.shape<2>
+// ELEMENTAL: %[[VAL_10:.*]] = hlfir.elemental %[[VAL_9]] unordered : (!fir.shape<2>) -> !hlfir.expr<?x?xi32> {
+// ELEMENTAL: ^bb0(%[[VAL_11:.*]]: index, %[[VAL_12:.*]]: index):
+// ELEMENTAL: %[[VAL_13:.*]] = fir.do_loop %[[VAL_14:.*]] = %[[VAL_2]] to %[[VAL_6]] step %[[VAL_2]] unordered iter_args(%[[VAL_15:.*]] = %[[VAL_3]]) -> (i32) {
+// ELEMENTAL: %[[VAL_16:.*]] = hlfir.apply %[[VAL_0]], %[[VAL_11]], %[[VAL_14]] : (!hlfir.expr<?x?xi16>, index, index) -> i16
+// ELEMENTAL: %[[VAL_17:.*]] = hlfir.apply %[[VAL_1]], %[[VAL_14]], %[[VAL_12]] : (!hlfir.expr<?x?xi32>, index, index) -> i32
+// ELEMENTAL: %[[VAL_18:.*]] = fir.convert %[[VAL_16]] : (i16) -> i32
+// ELEMENTAL: %[[VAL_19:.*]] = arith.muli %[[VAL_18]], %[[VAL_17]] : i32
+// ELEMENTAL: %[[VAL_20:.*]] = arith.addi %[[VAL_15]], %[[VAL_19]] : i32
+// ELEMENTAL: fir.result %[[VAL_20]] : i32
+// ELEMENTAL: }
+// ELEMENTAL: hlfir.yield_element %[[VAL_13]] : i32
+// ELEMENTAL: }
+// ELEMENTAL: return %[[VAL_10]] : !hlfir.expr<?x?xi32>
+// ELEMENTAL: }
+
+func.func @matmul_matrix_matrix_real(%arg0: !hlfir.expr<?x?xf32>, %arg1: !hlfir.expr<?x?xf16>) -> !hlfir.expr<?x?xf32> {
+ %res = hlfir.matmul %arg0 %arg1 : (!hlfir.expr<?x?xf32>, !hlfir.expr<?x?xf16>) -> !hlfir.expr<?x?xf32>
+ return %res : !hlfir.expr<?x?xf32>
+}
+// ALL-LABEL: func.func @matmul_matrix_matrix_real(
+// ALL-SAME: %[[VAL_0:.*]]: !hlfir.expr<?x?xf32>,
+// ALL-SAME: %[[VAL_1:.*]]: !hlfir.expr<?x?xf16>) -> !hlfir.expr<?x?xf32> {
+
+// NOANSE: hlfir.matmul
+
+// ANSE: %[[VAL_2:.*]] = arith.constant 0 : index
+// ANSE: %[[VAL_3:.*]] = arith.constant 1 : index
+// ANSE: %[[VAL_4:.*]] = arith.constant 0.000000e+00 : f32
+// ANSE: %[[VAL_5:.*]] = hlfir.shape_of %[[VAL_0]] : (!hlfir.expr<?x?xf32>) -> !fir.shape<2>
+// ANSE: %[[VAL_6:.*]] = hlfir.get_extent %[[VAL_5]] {dim = 0 : index} : (!fir.shape<2>) -> index
+// ANSE: %[[VAL_7:.*]] = hlfir.get_extent %[[VAL_5]] {dim = 1 : index} : (!fir.shape<2>) -> index
+// ANSE: %[[VAL_8:.*]] = hlfir.shape_of %[[VAL_1]] : (!hlfir.expr<?x?xf16>) -> !fir.shape<2>
+// ANSE: %[[VAL_9:.*]] = hlfir.get_extent %[[VAL_8]] {dim = 1 : index} : (!fir.shape<2>) -> index
+// ANSE: %[[VAL_10:.*]] = fir.shape %[[VAL_6]], %[[VAL_9]] : (index, index) -> !fir.shape<2>
+// ANSE: %[[VAL_11:.*]] = hlfir.eval_in_mem shape %[[VAL_10]] : (!fir.shape<2>) -> !hlfir.expr<?x?xf32> {
+// ANSE: ^bb0(%[[VAL_12:.*]]: !fir.ref<!fir.array<?x?xf32>>):
+// ANSE: %[[VAL_13:.*]] = fir.embox %[[VAL_12]](%[[VAL_10]]) : (!fir.ref<!fir.array<?x?xf32>>, !fir.shape<2>) -> !fir.box<!fir.array<?x?xf32>>
+// ANSE: fir.do_loop %[[VAL_14:.*]] = %[[VAL_3]] to %[[VAL_9]] step %[[VAL_3]] unordered {
+// ANSE: fir.do_loop %[[VAL_15:.*]] = %[[VAL_3]] to %[[VAL_6]] step %[[VAL_3]] unordered {
+// ANSE: %[[VAL_16:.*]]:3 = fir.box_dims %[[VAL_13]], %[[VAL_2]] : (!fir.box<!fir.array<?x?xf32>>, index) -> (index, index, index)
+// ANSE: %[[VAL_17:.*]]:3 = fir.box_dims %[[VAL_13]], %[[VAL_3]] : (!fir.box<!fir.array<?x?xf32>>, index) -> (index, index, index)
+// ANSE: %[[VAL_18:.*]] = arith.subi %[[VAL_16]]#0, %[[VAL_3]] : index
+// ANSE: %[[VAL_19:.*]] = arith.addi %[[VAL_15]], %[[VAL_18]] : index
+// ANSE: %[[VAL_20:.*]] = arith.subi %[[VAL_17]]#0, %[[VAL_3]] : index
+// ANSE: %[[VAL_21:.*]] = arith.addi %[[VAL_14]], %[[VAL_20]] : index
+// ANSE: %[[VAL_22:.*]] = hlfir.designate %[[VAL_13]] (%[[VAL_19]], %[[VAL_21]]) : (!fir.box<!fir.array<?x?xf32>>, index, index) -> !fir.ref<f32>
+// ANSE: hlfir.assign %[[VAL_4]] to %[[VAL_22]] : f32, !fir.ref<f32>
+// ANSE: }
+// ANSE: }
+// ANSE: fir.do_loop %[[VAL_23:.*]] = %[[VAL_3]] to %[[VAL_7]] step %[[VAL_3]] {
+// ANSE: fir.do_loop %[[VAL_24:.*]] = %[[VAL_3]] to %[[VAL_9]] step %[[VAL_3]] {
+// ANSE: fir.do_loop %[[VAL_25:.*]] = %[[VAL_3]] to %[[VAL_6]] step %[[VAL_3]] {
+// ANSE: %[[VAL_26:.*]]:3 = fir.box_dims %[[VAL_13]], %[[VAL_2]] : (!fir.box<!fir.array<?x?xf32>>, index) -> (index, index, index)
+// ANSE: %[[VAL_27:.*]]:3 = fir.box_dims %[[VAL_13]], %[[VAL_3]] : (!fir.box<!fir.array<?x?xf32>>, index) -> (index, index, index)
+// ANSE: %[[VAL_28:.*]] = arith.subi %[[VAL_26]]#0, %[[VAL_3]] : index
+// ANSE: %[[VAL_29:.*]] = arith.addi %[[VAL_25]], %[[VAL_28]] : index
+// ANSE: %[[VAL_30:.*]] = arith.subi %[[VAL_27]]#0, %[[VAL_3]] : index
+// ANSE: %[[VAL_31:.*]] = arith.addi %[[VAL_24]], %[[VAL_30]] : index
+// ANSE: %[[VAL_32:.*]] = hlfir.designate %[[VAL_13]] (%[[VAL_29]], %[[VAL_31]]) : (!fir.box<!fir.array<?x?xf32>>, index, index) -> !fir.ref<f32>
+// ANSE: %[[VAL_33:.*]] = fir.load %[[VAL_32]] : !fir.ref<f32>
+// ANSE: %[[VAL_34:.*]] = hlfir.apply %[[VAL_0]], %[[VAL_25]], %[[VAL_23]] : (!hlfir.expr<?x?xf32>, index, index) -> f32
+// ANSE: %[[VAL_35:.*]] = hlfir.apply %[[VAL_1]], %[[VAL_23]], %[[VAL_24]] : (!hlfir.expr<?x?xf16>, index, index) -> f16
+// ANSE: %[[VAL_36:.*]] = fir.convert %[[VAL_35]] : (f16) -> f32
+// ANSE: %[[VAL_37:.*]] = arith.mulf %[[VAL_34]], %[[VAL_36]] : f32
+// ANSE: %[[VAL_38:.*]] = arith.addf %[[VAL_33]], %[[VAL_37]] : f32
+// ANSE: hlfir.assign %[[VAL_38]] to %[[VAL_32]] : f32, !fir.ref<f32>
+// ANSE: }
+// ANSE: }
+// ANSE: }
+// ANSE: }
+// ANSE: return %[[VAL_11]] : !hlfir.expr<?x?xf32>
+// ANSE: }
+
+// ELEMENTAL: %[[VAL_2:.*]] = arith.constant 1 : index
+// ELEMENTAL: %[[VAL_3:.*]] = arith.constant 0.000000e+00 : f32
+// ELEMENTAL: %[[VAL_4:.*]] = hlfir.shape_of %[[VAL_0]] : (!hlfir.expr<?x?xf32>) -> !fir.shape<2>
+// ELEMENTAL: %[[VAL_5:.*]] = hlfir.get_extent %[[VAL_4]] {dim = 0 : index} : (!fir.shape<2>) -> index
+// ELEMENTAL: %[[VAL_6:.*]] = hlfir.get_extent %[[VAL_4]] {dim = 1 : index} : (!fir.shape<2>) -> index
+// ELEMENTAL: %[[VAL_7:.*]] = hlfir.shape_of %[[VAL_1]] : (!hlfir.expr<?x?xf16>) -> !fir.shape<2>
+// ELEMENTAL: %[[VAL_8:.*]] = hlfir.get_extent %[[VAL_7]] {dim = 1 : index} : (!fir.shape<2>) -> index
+// ELEMENTAL: %[[VAL_9:.*]] = fir.shape %[[VAL_5]], %[[VAL_8]] : (index, index) -> !fir.shape<2>
+// ELEMENTAL: %[[VAL_10:.*]] = hlfir.elemental %[[VAL_9]] unordered : (!fir.shape<2>) -> !hlfir.expr<?x?xf32> {
+// ELEMENTAL: ^bb0(%[[VAL_11:.*]]: index, %[[VAL_12:.*]]: index):
+// ELEMENTAL: %[[VAL_13:.*]] = fir.do_loop %[[VAL_14:.*]] = %[[VAL_2]] to %[[VAL_6]] step %[[VAL_2]] iter_args(%[[VAL_15:.*]] = %[[VAL_3]]) -> (f32) {
+// ELEMENTAL: %[[VAL_16:.*]] = hlfir.apply %[[VAL_0]], %[[VAL_11]], %[[VAL_14]] : (!hlfir.expr<?x?xf32>, index, index) -> f32
+// ELEMENTAL: %[[VAL_17:.*]] = hlfir.apply %[[VAL_1]], %[[VAL_14]], %[[VAL_12]] : (!hlfir.expr<?x?xf16>, index, index) -> f16
+// ELEMENTAL: %[[VAL_18:.*]] = fir.convert %[[VAL_17]] : (f16) -> f32
+// ELEMENTAL: %[[VAL_19:.*]] = arith.mulf %[[VAL_16]], %[[VAL_18]] : f32
+// ELEMENTAL: %[[VAL_20:.*]] = arith.addf %[[VAL_15]], %[[VAL_19]] : f32
+// ELEMENTAL: fir.result %[[VAL_20]] : f32
+// ELEMENTAL: }
+// ELEMENTAL: hlfir.yield_element %[[VAL_13]] : f32
+// ELEMENTAL: }
+// ELEMENTAL: return %[[VAL_10]] : !hlfir.expr<?x?xf32>
+// ELEMENTAL: }
+
+func.func @matmul_matrix_matrix_complex(%arg0: !hlfir.expr<?x?xcomplex<f32>>, %arg1: !hlfir.expr<?x?xcomplex<f16>>) -> !hlfir.expr<?x?xcomplex<f32>> {
+ %res = hlfir.matmul %arg0 %arg1 : (!hlfir.expr<?x?xcomplex<f32>>, !hlfir.expr<?x?xcomplex<f16>>) -> !hlfir.expr<?x?xcomplex<f32>>
+ return %res : !hlfir.expr<?x?xcomplex<f32>>
+}
+// ALL-LABEL: func.func @matmul_matrix_matrix_complex(
+// ALL-SAME: %[[VAL_0:.*]]: !hlfir.expr<?x?xcomplex<f32>>,
+// ALL-SAME: %[[VAL_1:.*]]: !hlfir.expr<?x?xcomplex<f16>>) -> !hlfir.expr<?x?xcomplex<f32>> {
+
+// NOANSE: hlfir.matmul
+
+// ANSE: %[[VAL_2:.*]] = arith.constant 0 : index
+// ANSE: %[[VAL_3:.*]] = arith.constant 1 : index
+// ANSE: %[[VAL_4:.*]] = arith.constant 0.000000e+00 : f32
+// ANSE: %[[VAL_5:.*]] = hlfir.shape_of %[[VAL_0]] : (!hlfir.expr<?x?xcomplex<f32>>) -> !fir.shape<2>
+// ANSE: %[[VAL_6:.*]] = hlfir.get_extent %[[VAL_5]] {dim = 0 : index} : (!fir.shape<2>) -> index
+// ANSE: %[[VAL_7:.*]] = hlfir.get_extent %[[VAL_5]] {dim = 1 : index} : (!fir.shape<2>) -> index
+// ANSE: %[[VAL_8:.*]] = hlfir.shape_of %[[VAL_1]] : (!hlfir.expr<?x?xcomplex<f16>>) -> !fir.shape<2>
+// ANSE: %[[VAL_9:.*]] = hlfir.get_extent %[[VAL_8]] {dim = 1 : index} : (!fir.shape<2>) -> index
+// ANSE: %[[VAL_10:.*]] = fir.shape %[[VAL_6]], %[[VAL_9]] : (index, index) -> !fir.shape<2>
+// ANSE: %[[VAL_11:.*]] = hlfir.eval_in_mem shape %[[VAL_10]] : (!fir.shape<2>) -> !hlfir.expr<?x?xcomplex<f32>> {
+// ANSE: ^bb0(%[[VAL_12:.*]]: !fir.ref<!fir.array<?x?xcomplex<f32>>>):
+// ANSE: %[[VAL_13:.*]] = fir.embox %[[VAL_12]](%[[VAL_10]]) : (!fir.ref<!fir.array<?x?xcomplex<f32>>>, !fir.shape<2>) -> !fir.box<!fir.array<?x?xcomplex<f32>>>
+// ANSE: %[[VAL_14:.*]] = fir.undefined complex<f32>
+// ANSE: %[[VAL_15:.*]] = fir.insert_value %[[VAL_14]], %[[VAL_4]], [0 : index] : (complex<f32>, f32) -> complex<f32>
+// ANSE: %[[VAL_16:.*]] = fir.insert_value %[[VAL_15]], %[[VAL_4]], [1 : index] : (complex<f32>, f32) -> complex<f32>
+// ANSE: fir.do_loop %[[VAL_17:.*]] = %[[VAL_3]] to %[[VAL_9]] step %[[VAL_3]] unordered {
+// ANSE: fir.do_loop %[[VAL_18:.*]] = %[[VAL_3]] to %[[VAL_6]] step %[[VAL_3]] unordered {
+// ANSE: %[[VAL_19:.*]]:3 = fir.box_dims %[[VAL_13]], %[[VAL_2]] : (!fir.box<!fir.array<?x?xcomplex<f32>>>, index) -> (index, index, index)
+// ANSE: %[[VAL_20:.*]]:3 = fir.box_dims %[[VAL_13]], %[[VAL_3]] : (!fir.box<!fir.array<?x?xcomplex<f32>>>, index) -> (index, index, index)
+// ANSE: %[[VAL_21:.*]] = arith.subi %[[VAL_19]]#0, %[[VAL_3]] : index
+// ANSE: %[[VAL_22:.*]] = arith.addi %[[VAL_18]], %[[VAL_21]] : index
+// ANSE: %[[VAL_23:.*]] = arith.subi %[[VAL_20]]#0, %[[VAL_3]] : index
+// ANSE: %[[VAL_24:.*]] = arith.addi %[[VAL_17]], %[[VAL_23]] : index
+// ANSE: %[[VAL_25:.*]] = hlfir.designate %[[VAL_13]] (%[[VAL_22]], %[[VAL_24]]) : (!fir.box<!fir.array<?x?xcomplex<f32>>>, index, index) -> !fir.ref<complex<f32>>
+// ANSE: hlfir.assign %[[VAL_16]] to %[[VAL_25]] : complex<f32>, !fir.ref<complex<f32>>
+// ANSE: }
+// ANSE: }
+// ANSE: fir.do_loop %[[VAL_26:.*]] = %[[VAL_3]] to %[[VAL_7]] step %[[VAL_3]] {
+// ANSE: fir.do_loop %[[VAL_27:.*]] = %[[VAL_3]] to %[[VAL_9]] step %[[VAL_3]] {
+// ANSE: fir.do_loop %[[VAL_28:.*]] = %[[VAL_3]] to %[[VAL_6]] step %[[VAL_3]] {
+// ANSE: %[[VAL_29:.*]]:3 = fir.box_dims %[[VAL_13]], %[[VAL_2]] : (!fir.box<!fir.array<?x?xcomplex<f32>>>, index) -> (index, index, index)
+// ANSE: %[[VAL_30:.*]]:3 = fir.box_dims %[[VAL_13]], %[[VAL_3]] : (!fir.box<!fir.array<?x?xcomplex<f32>>>, index) -> (index, index, index)
+// ANSE: %[[VAL_31:.*]] = arith.subi %[[VAL_29]]#0, %[[VAL_3]] : index
+// ANSE: %[[VAL_32:.*]] = arith.addi %[[VAL_28]], %[[VAL_31]] : index
+// ANSE: %[[VAL_33:.*]] = arith.subi %[[VAL_30]]#0, %[[VAL_3]] : index
+// ANSE: %[[VAL_34:.*]] = arith.addi %[[VAL_27]], %[[VAL_33]] : index
+// ANSE: %[[VAL_35:.*]] = hlfir.designate %[[VAL_13]] (%[[VAL_32]], %[[VAL_34]]) : (!fir.box<!fir.array<?x?xcomplex<f32>>>, index, index) -> !fir.ref<complex<f32>>
+// ANSE: %[[VAL_36:.*]] = fir.load %[[VAL_35]] : !fir.ref<complex<f32>>
+// ANSE: %[[VAL_37:.*]] = hlfir.apply %[[VAL_0]], %[[VAL_28]], %[[VAL_26]] : (!hlfir.expr<?x?xcomplex<f32>>, index, index) -> complex<f32>
+// ANSE: %[[VAL_38:.*]] = hlfir.apply %[[VAL_1]], %[[VAL_26]], %[[VAL_27]] : (!hlfir.expr<?x?xcomplex<f16>>, index, index) -> complex<f16>
+// ANSE: %[[VAL_39:.*]] = fir.convert %[[VAL_38]] : (complex<f16>) -> complex<f32>
+// ANSE: %[[VAL_40:.*]] = fir.mulc %[[VAL_37]], %[[VAL_39]] : complex<f32>
+// ANSE: %[[VAL_41:.*]] = fir.addc %[[VAL_36]], %[[VAL_40]] : complex<f32>
+// ANSE: hlfir.assign %[[VAL_41]] to %[[VAL_35]] : complex<f32>, !fir.ref<complex<f32>>
+// ANSE: }
+// ANSE: }
+// ANSE: }
+// ANSE: }
+// ANSE: return %[[VAL_11]] : !hlfir.expr<?x?xcomplex<f32>>
+// ANSE: }
+
+// ELEMENTAL: %[[VAL_2:.*]] = arith.constant 1 : index
+// ELEMENTAL: %[[VAL_3:.*]] = arith.constant 0.000000e+00 : f32
+// ELEMENTAL: %[[VAL_4:.*]] = hlfir.shape_of %[[VAL_0]] : (!hlfir.expr<?x?xcomplex<f32>>) -> !fir.shape<2>
+// ELEMENTAL: %[[VAL_5:.*]] = hlfir.get_extent %[[VAL_4]] {dim = 0 : index} : (!fir.shape<2>) -> index
+// ELEMENTAL: %[[VAL_6:.*]] = hlfir.get_extent %[[VAL_4]] {dim = 1 : index} : (!fir.shape<2>) -> index
+// ELEMENTAL: %[[VAL_7:.*]] = hlfir.shape_of %[[VAL_1]] : (!hlfir.expr<?x?xcomplex<f16>>) -> !fir.shape<2>
+// ELEMENTAL: %[[VAL_8:.*]] = hlfir.get_extent %[[VAL_7]] {dim = 1 : index} : (!fir.shape<2>) -> index
+// ELEMENTAL: %[[VAL_9:.*]] = fir.shape %[[VAL_5]], %[[VAL_8]] : (index, index) -> !fir.shape<2>
+// ELEMENTAL: %[[VAL_10:.*]] = hlfir.elemental %[[VAL_9]] unordered : (!fir.shape<2>) -> !hlfir.expr<?x?xcomplex<f32>> {
+// ELEMENTAL: ^bb0(%[[VAL_11:.*]]: index, %[[VAL_12:.*]]: index):
+// ELEMENTAL: %[[VAL_13:.*]] = fir.undefined complex<f32>
+// ELEMENTAL: %[[VAL_14:.*]] = fir.insert_value %[[VAL_13]], %[[VAL_3]], [0 : index] : (complex<f32>, f32) -> complex<f32>
+// ELEMENTAL: %[[VAL_15:.*]] = fir.insert_value %[[VAL_14]], %[[VAL_3]], [1 : index] : (complex<f32>, f32) -> complex<f32>
+// ELEMENTAL: %[[VAL_16:.*]] = fir.do_loop %[[VAL_17:.*]] = %[[VAL_2]] to %[[VAL_6]] step %[[VAL_2]] iter_args(%[[VAL_18:.*]] = %[[VAL_15]]) -> (complex<f32>) {
+// ELEMENTAL: %[[VAL_19:.*]] = hlfir.apply %[[VAL_0]], %[[VAL_11]], %[[VAL_17]] : (!hlfir.expr<?x?xcomplex<f32>>, index, index) -> complex<f32>
+// ELEMENTAL: %[[VAL_20:.*]] = hlfir.apply %[[VAL_1]], %[[VAL_17]], %[[VAL_12]] : (!hlfir.expr<?x?xcomplex<f16>>, index, index) -> complex<f16>
+// ELEMENTAL: %[[VAL_21:.*]] = fir.convert %[[VAL_20]] : (complex<f16>) -> complex<f32>
+// ELEMENTAL: %[[VAL_22:.*]] = fir.mulc %[[VAL_19]], %[[VAL_21]] : complex<f32>
+// ELEMENTAL: %[[VAL_23:.*]] = fir.addc %[[VAL_18]], %[[VAL_22]] : complex<f32>
+// ELEMENTAL: fir.result %[[VAL_23]] : complex<f32>
+// ELEMENTAL: }
+// ELEMENTAL: hlfir.yield_element %[[VAL_16]] : complex<f32>
+// ELEMENTAL: }
+// ELEMENTAL: return %[[VAL_10]] : !hlfir.expr<?x?xcomplex<f32>>
+// ELEMENTAL: }
+
+func.func @matmul_matrix_matrix_complex_real(%arg0: !hlfir.expr<?x?xcomplex<f32>>, %arg1: !hlfir.expr<?x?xf16>) -> !hlfir.expr<?x?xcomplex<f32>> {
+ %res = hlfir.matmul %arg0 %arg1 : (!hlfir.expr<?x?xcomplex<f32>>, !hlfir.expr<?x?xf16>) -> !hlfir.expr<?x?xcomplex<f32>>
+ return %res : !hlfir.expr<?x?xcomplex<f32>>
+}
+// ALL-LABEL: func.func @matmul_matrix_matrix_complex_real(
+// ALL-SAME: %[[VAL_0:.*]]: !hlfir.expr<?x?xcomplex<f32>>,
+// ALL-SAME: %[[VAL_1:.*]]: !hlfir.expr<?x?xf16>) -> !hlfir.expr<?x?xcomplex<f32>> {
+
+// NOANSE: hlfir.matmul
+
+// ANSE: %[[VAL_2:.*]] = arith.constant 0 : index
+// ANSE: %[[VAL_3:.*]] = arith.constant 1 : index
+// ANSE: %[[VAL_4:.*]] = arith.constant 0.000000e+00 : f32
+// ANSE: %[[VAL_5:.*]] = hlfir.shape_of %[[VAL_0]] : (!hlfir.expr<?x?xcomplex<f32>>) -> !fir.shape<2>
+// ANSE: %[[VAL_6:.*]] = hlfir.get_extent %[[VAL_5]] {dim = 0 : index} : (!fir.shape<2>) -> index
+// ANSE: %[[VAL_7:.*]] = hlfir.get_extent %[[VAL_5]] {dim = 1 : index} : (!fir.shape<2>) -> index
+// ANSE: %[[VAL_8:.*]] = hlfir.shape_of %[[VAL_1]] : (!hlfir.expr<?x?xf16>) -> !fir.shape<2>
+// ANSE: %[[VAL_9:.*]] = hlfir.get_extent %[[VAL_8]] {dim = 1 : index} : (!fir.shape<2>) -> index
+// ANSE: %[[VAL_10:.*]] = fir.shape %[[VAL_6]], %[[VAL_9]] : (index, index) -> !fir.shape<2>
+// ANSE: %[[VAL_11:.*]] = hlfir.eval_in_mem shape %[[VAL_10]] : (!fir.shape<2>) -> !hlfir.expr<?x?xcomplex<f32>> {
+// ANSE: ^bb0(%[[VAL_12:.*]]: !fir.ref<!fir.array<?x?xcomplex<f32>>>):
+// ANSE: %[[VAL_13:.*]] = fir.embox %[[VAL_12]](%[[VAL_10]]) : (!fir.ref<!fir.array<?x?xcomplex<f32>>>, !fir.shape<2>) -> !fir.box<!fir.array<?x?xcomplex<f32>>>
+// ANSE: %[[VAL_14:.*]] = fir.undefined complex<f32>
+// ANSE: %[[VAL_15:.*]] = fir.insert_value %[[VAL_14]], %[[VAL_4]], [0 : index] : (complex<f32>, f32) -> complex<f32>
+// ANSE: %[[VAL_16:.*]] = fir.insert_value %[[VAL_15]], %[[VAL_4]], [1 : index] : (complex<f32>, f32) -> complex<f32>
+// ANSE: fir.do_loop %[[VAL_17:.*]] = %[[VAL_3]] to %[[VAL_9]] step %[[VAL_3]] unordered {
+// ANSE: fir.do_loop %[[VAL_18:.*]] = %[[VAL_3]] to %[[VAL_6]] step %[[VAL_3]] unordered {
+// ANSE: %[[VAL_19:.*]]:3 = fir.box_dims %[[VAL_13]], %[[VAL_2]] : (!fir.box<!fir.array<?x?xcomplex<f32>>>, index) -> (index, index, index)
+// ANSE: %[[VAL_20:.*]]:3 = fir.box_dims %[[VAL_13]], %[[VAL_3]] : (!fir.box<!fir.array<?x?xcomplex<f32>>>, index) -> (index, index, index)
+// ANSE: %[[VAL_21:.*]] = arith.subi %[[VAL_19]]#0, %[[VAL_3]] : index
+// ANSE: %[[VAL_22:.*]] = arith.addi %[[VAL_18]], %[[VAL_21]] : index
+// ANSE: %[[VAL_23:.*]] = arith.subi %[[VAL_20]]#0, %[[VAL_3]] : index
+// ANSE: %[[VAL_24:.*]] = arith.addi %[[VAL_17]], %[[VAL_23]] : index
+// ANSE: %[[VAL_25:.*]] = hlfir.designate %[[VAL_13]] (%[[VAL_22]], %[[VAL_24]]) : (!fir.box<!fir.array<?x?xcomplex<f32>>>, index, index) -> !fir.ref<complex<f32>>
+// ANSE: hlfir.assign %[[VAL_16]] to %[[VAL_25]] : complex<f32>, !fir.ref<complex<f32>>
+// ANSE: }
+// ANSE: }
+// ANSE: fir.do_loop %[[VAL_26:.*]] = %[[VAL_3]] to %[[VAL_7]] step %[[VAL_3]] {
+// ANSE: fir.do_loop %[[VAL_27:.*]] = %[[VAL_3]] to %[[VAL_9]] step %[[VAL_3]] {
+// ANSE: fir.do_loop %[[VAL_28:.*]] = %[[VAL_3]] to %[[VAL_6]] step %[[VAL_3]] {
+// ANSE: %[[VAL_29:.*]]:3 = fir.box_dims %[[VAL_13]], %[[VAL_2]] : (!fir.box<!fir.array<?x?xcomplex<f32>>>, index) -> (index, index, index)
+// ANSE: %[[VAL_30:.*]]:3 = fir.box_dims %[[VAL_13]], %[[VAL_3]] : (!fir.box<!fir.array<?x?xcomplex<f32>>>, index) -> (index, index, index)
+// ANSE: %[[VAL_31:.*]] = arith.subi %[[VAL_29]]#0, %[[VAL_3]] : index
+// ANSE: %[[VAL_32:.*]] = arith.addi %[[VAL_28]], %[[VAL_31]] : index
+// ANSE: %[[VAL_33:.*]] = arith.subi %[[VAL_30]]#0, %[[VAL_3]] : index
+// ANSE: %[[VAL_34:.*]] = arith.addi %[[VAL_27]], %[[VAL_33]] : index
+// ANSE: %[[VAL_35:.*]] = hlfir.designate %[[VAL_13]] (%[[VAL_32]], %[[VAL_34]]) : (!fir.box<!fir.array<?x?xcomplex<f32>>>, index, index) -> !fir.ref<complex<f32>>
+// ANSE: %[[VAL_36:.*]] = fir.load %[[VAL_35]] : !fir.ref<complex<f32>>
+// ANSE: %[[VAL_37:.*]] = hlfir.apply %[[VAL_0]], %[[VAL_28]], %[[VAL_26]] : (!hlfir.expr<?x?xcomplex<f32>>, index, index) -> complex<f32>
+// ANSE: %[[VAL_38:.*]] = hlfir.apply %[[VAL_1]], %[[VAL_26]], %[[VAL_27]] : (!hlfir.expr<?x?xf16>, index, index) -> f16
+// ANSE: %[[VAL_39:.*]] = fir.undefined complex<f32>
+// ANSE: %[[VAL_40:.*]] = fir.insert_value %[[VAL_39]], %[[VAL_4]], [0 : index] : (complex<f32>, f32) -> complex<f32>
+// ANSE: %[[VAL_41:.*]] = fir.insert_value %[[VAL_40]], %[[VAL_4]], [1 : index] : (complex<f32>, f32) -> complex<f32>
+// ANSE: %[[VAL_42:.*]] = fir.convert %[[VAL_38]] : (f16) -> f32
+// ANSE: %[[VAL_43:.*]] = fir.insert_value %[[VAL_41]], %[[VAL_42]], [0 : index] : (complex<f32>, f32) -> complex<f32>
+// ANSE: %[[VAL_44:.*]] = fir.mulc %[[VAL_37]], %[[VAL_43]] : complex<f32>
+// ANSE: %[[VAL_45:.*]] = fir.addc %[[VAL_36]], %[[VAL_44]] : complex<f32>
+// ANSE: hlfir.assign %[[VAL_45]] to %[[VAL_35]] : complex<f32>, !fir.ref<complex<f32>>
+// ANSE: }
+// ANSE: }
+// ANSE: }
+// ANSE: }
+// ANSE: return %[[VAL_11]] : !hlfir.expr<?x?xcomplex<f32>>
+// ANSE: }
+
+// ELEMENTAL: %[[VAL_2:.*]] = arith.constant 1 : index
+// ELEMENTAL: %[[VAL_3:.*]] = arith.constant 0.000000e+00 : f32
+// ELEMENTAL: %[[VAL_4:.*]] = hlfir.shape_of %[[VAL_0]] : (!hlfir.expr<?x?xcomplex<f32>>) -> !fir.shape<2>
+// ELEMENTAL: %[[VAL_5:.*]] = hlfir.get_extent %[[VAL_4]] {dim = 0 : index} : (!fir.shape<2>) -> index
+// ELEMENTAL: %[[VAL_6:.*]] = hlfir.get_extent %[[VAL_4]] {dim = 1 : index} : (!fir.shape<2>) -> index
+// ELEMENTAL: %[[VAL_7:.*]] = hlfir.shape_of %[[VAL_1]] : (!hlfir.expr<?x?xf16>) -> !fir.shape<2>
+// ELEMENTAL: %[[VAL_8:.*]] = hlfir.get_extent %[[VAL_7]] {dim = 1 : index} : (!fir.shape<2>) -> index
+// ELEMENTAL: %[[VAL_9:.*]] = fir.shape %[[VAL_5]], %[[VAL_8]] : (index, index) -> !fir.shape<2>
+// ELEMENTAL: %[[VAL_10:.*]] = hlfir.elemental %[[VAL_9]] unordered : (!fir.shape<2>) -> !hlfir.expr<?x?xcomplex<f32>> {
+// ELEMENTAL: ^bb0(%[[VAL_11:.*]]: index, %[[VAL_12:.*]]: index):
+// ELEMENTAL: %[[VAL_13:.*]] = fir.undefined complex<f32>
+// ELEMENTAL: %[[VAL_14:.*]] = fir.insert_value %[[VAL_13]], %[[VAL_3]], [0 : index] : (complex<f32>, f32) -> complex<f32>
+// ELEMENTAL: %[[VAL_15:.*]] = fir.insert_value %[[VAL_14]], %[[VAL_3]], [1 : index] : (complex<f32>, f32) -> complex<f32>
+// ELEMENTAL: %[[VAL_16:.*]] = fir.do_loop %[[VAL_17:.*]] = %[[VAL_2]] to %[[VAL_6]] step %[[VAL_2]] iter_args(%[[VAL_18:.*]] = %[[VAL_15]]) -> (complex<f32>) {
+// ELEMENTAL: %[[VAL_19:.*]] = hlfir.apply %[[VAL_0]], %[[VAL_11]], %[[VAL_17]] : (!hlfir.expr<?x?xcomplex<f32>>, index, index) -> complex<f32>
+// ELEMENTAL: %[[VAL_20:.*]] = hlfir.apply %[[VAL_1]], %[[VAL_17]], %[[VAL_12]] : (!hlfir.expr<?x?xf16>, index, index) -> f16
+// ELEMENTAL: %[[VAL_21:.*]] = fir.undefined complex<f32>
+// ELEMENTAL: %[[VAL_22:.*]] = fir.insert_value %[[VAL_21]], %[[VAL_3]], [0 : index] : (complex<f32>, f32) -> complex<f32>
+// ELEMENTAL: %[[VAL_23:.*]] = fir.insert_value %[[VAL_22]], %[[VAL_3]], [1 : index] : (complex<f32>, f32) -> complex<f32>
+// ELEMENTAL: %[[VAL_24:.*]] = fir.convert %[[VAL_20]] : (f16) -> f32
+// ELEMENTAL: %[[VAL_25:.*]] = fir.insert_value %[[VAL_23]], %[[VAL_24]], [0 : index] : (complex<f32>, f32) -> complex<f32>
+// ELEMENTAL: %[[VAL_26:.*]] = fir.mulc %[[VAL_19]], %[[VAL_25]] : complex<f32>
+// ELEMENTAL: %[[VAL_27:.*]] = fir.addc %[[VAL_18]], %[[VAL_26]] : complex<f32>
+// ELEMENTAL: fir.result %[[VAL_27]] : complex<f32>
+// ELEMENTAL: }
+// ELEMENTAL: hlfir.yield_element %[[VAL_16]] : complex<f32>
+// ELEMENTAL: }
+// ELEMENTAL: return %[[VAL_10]] : !hlfir.expr<?x?xcomplex<f32>>
+// ELEMENTAL: }
+
+func.func @matmul_matrix_matrix_logical(%arg0: !hlfir.expr<?x?x!fir.logical<1>>, %arg1: !hlfir.expr<?x?x!fir.logical<4>>) -> !hlfir.expr<?x?x!fir.logical<4>> {
+ %res = hlfir.matmul %arg0 %arg1 : (!hlfir.expr<?x?x!fir.logical<1>>, !hlfir.expr<?x?x!fir.logical<4>>) -> !hlfir.expr<?x?x!fir.logical<4>>
+ return %res : !hlfir.expr<?x?x!fir.logical<4>>
+}
+// ALL-LABEL: func.func @matmul_matrix_matrix_logical(
+// ALL-SAME: %[[VAL_0:.*]]: !hlfir.expr<?x?x!fir.logical<1>>,
+// ALL-SAME: %[[VAL_1:.*]]: !hlfir.expr<?x?x!fir.logical<4>>) -> !hlfir.expr<?x?x!fir.logical<4>> {
+
+// NOANSE: hlfir.matmul
+
+// ANSE: %[[VAL_2:.*]] = arith.constant 0 : index
+// ANSE: %[[VAL_3:.*]] = arith.constant 1 : index
+// ANSE: %[[VAL_4:.*]] = arith.constant false
+// ANSE: %[[VAL_5:.*]] = hlfir.shape_of %[[VAL_0]] : (!hlfir.expr<?x?x!fir.logical<1>>) -> !fir.shape<2>
+// ANSE: %[[VAL_6:.*]] = hlfir.get_extent %[[VAL_5]] {dim = 0 : index} : (!fir.shape<2>) -> index
+// ANSE: %[[VAL_7:.*]] = hlfir.get_extent %[[VAL_5]] {dim = 1 : index} : (!fir.shape<2>) -> index
+// ANSE: %[[VAL_8:.*]] = hlfir.shape_of %[[VAL_1]] : (!hlfir.expr<?x?x!fir.logical<4>>) -> !fir.shape<2>
+// ANSE: %[[VAL_9:.*]] = hlfir.get_extent %[[VAL_8]] {dim = 1 : index} : (!fir.shape<2>) -> index
+// ANSE: %[[VAL_10:.*]] = fir.shape %[[VAL_6]], %[[VAL_9]] : (index, index) -> !fir.shape<2>
+// ANSE: %[[VAL_11:.*]] = hlfir.eval_in_mem shape %[[VAL_10]] : (!fir.shape<2>) -> !hlfir.expr<?x?x!fir.logical<4>> {
+// ANSE: ^bb0(%[[VAL_12:.*]]: !fir.ref<!fir.array<?x?x!fir.logical<4>>>):
+// ANSE: %[[VAL_13:.*]] = fir.embox %[[VAL_12]](%[[VAL_10]]) : (!fir.ref<!fir.array<?x?x!fir.logical<4>>>, !fir.shape<2>) -> !fir.box<!fir.array<?x?x!fir.logical<4>>>
+// ANSE: %[[VAL_14:.*]] = fir.convert %[[VAL_4]] : (i1) -> !fir.logical<4>
+// ANSE: fir.do_loop %[[VAL_15:.*]] = %[[VAL_3]] to %[[VAL_9]] step %[[VAL_3]] unordered {
+// ANSE: fir.do_loop %[[VAL_16:.*]] = %[[VAL_3]] to %[[VAL_6]] step %[[VAL_3]] unordered {
+// ANSE: %[[VAL_17:.*]]:3 = fir.box_dims %[[VAL_13]], %[[VAL_2]] : (!fir.box<!fir.array<?x?x!fir.logical<4>>>, index) -> (index, index, index)
+// ANSE: %[[VAL_18:.*]]:3 = fir.box_dims %[[VAL_13]], %[[VAL_3]] : (!fir.box<!fir.array<?x?x!fir.logical<4>>>, index) -> (index, index, index)
+// ANSE: %[[VAL_19:.*]] = arith.subi %[[VAL_17]]#0, %[[VAL_3]] : index
+// ANSE: %[[VAL_20:.*]] = arith.addi %[[VAL_16]], %[[VAL_19]] : index
+// ANSE: %[[VAL_21:.*]] = arith.subi %[[VAL_18]]#0, %[[VAL_3]] : index
+// ANSE: %[[VAL_22:.*]] = arith.addi %[[VAL_15]], %[[VAL_21]] : index
+// ANSE: %[[VAL_23:.*]] = hlfir.designate %[[VAL_13]] (%[[VAL_20]], %[[VAL_22]]) : (!fir.box<!fir.array<?x?x!fir.logical<4>>>, index, index) -> !fir.ref<!fir.logical<4>>
+// ANSE: hlfir.assign %[[VAL_14]] to %[[VAL_23]] : !fir.logical<4>, !fir.ref<!fir.logical<4>>
+// ANSE: }
+// ANSE: }
+// ANSE: fir.do_loop %[[VAL_24:.*]] = %[[VAL_3]] to %[[VAL_7]] step %[[VAL_3]] unordered {
+// ANSE: fir.do_loop %[[VAL_25:.*]] = %[[VAL_3]] to %[[VAL_9]] step %[[VAL_3]] unordered {
+// ANSE: fir.do_loop %[[VAL_26:.*]] = %[[VAL_3]] to %[[VAL_6]] step %[[VAL_3]] unordered {
+// ANSE: %[[VAL_27:.*]]:3 = fir.box_dims %[[VAL_13]], %[[VAL_2]] : (!fir.box<!fir.array<?x?x!fir.logical<4>>>, index) -> (index, index, index)
+// ANSE: %[[VAL_28:.*]]:3 = fir.box_dims %[[VAL_13]], %[[VAL_3]] : (!fir.box<!fir.array<?x?x!fir.logical<4>>>, index) -> (index, index, index)
+// ANSE: %[[VAL_29:.*]] = arith.subi %[[VAL_27]]#0, %[[VAL_3]] : index
+// ANSE: %[[VAL_30:.*]] = arith.addi %[[VAL_26]], %[[VAL_29]] : index
+// ANSE: %[[VAL_31:.*]] = arith.subi %[[VAL_28]]#0, %[[VAL_3]] : index
+// ANSE: %[[VAL_32:.*]] = arith.addi %[[VAL_25]], %[[VAL_31]] : index
+// ANSE: %[[VAL_33:.*]] = hlfir.designate %[[VAL_13]] (%[[VAL_30]], %[[VAL_32]]) : (!fir.box<!fir.array<?x?x!fir.logical<4>>>, index, index) -> !fir.ref<!fir.logical<4>>
+// ANSE: %[[VAL_34:.*]] = fir.load %[[VAL_33]] : !fir.ref<!fir.logical<4>>
+// ANSE: %[[VAL_35:.*]] = hlfir.apply %[[VAL_0]], %[[VAL_26]], %[[VAL_24]] : (!hlfir.expr<?x?x!fir.logical<1>>, index, index) -> !fir.logical<1>
+// ANSE: %[[VAL_36:.*]] = hlfir.apply %[[VAL_1]], %[[VAL_24]], %[[VAL_25]] : (!hlfir.expr<?x?x!fir.logical<4>>, index, index) -> !fir.logical<4>
+// ANSE: %[[VAL_37:.*]] = fir.convert %[[VAL_34]] : (!fir.logical<4>) -> i1
+// ANSE: %[[VAL_38:.*]] = fir.convert %[[VAL_35]] : (!fir.logical<1>) -> i1
+// ANSE: %[[VAL_39:.*]] = fir.convert %[[VAL_36]] : (!fir.logical<4>) -> i1
+// ANSE: %[[VAL_40:.*]] = arith.andi %[[VAL_38]], %[[VAL_39]] : i1
+// ANSE: %[[VAL_41:.*]] = arith.ori %[[VAL_37]], %[[VAL_40]] : i1
+// ANSE: %[[VAL_42:.*]] = fir.convert %[[VAL_41]] : (i1) -> !fir.logical<4>
+// ANSE: hlfir.assign %[[VAL_42]] to %[[VAL_33]] : !fir.logical<4>, !fir.ref<!fir.logical<4>>
+// ANSE: }
+// ANSE: }
+// ANSE: }
+// ANSE: }
+// ANSE: return %[[VAL_11]] : !hlfir.expr<?x?x!fir.logical<4>>
+// ANSE: }
+
+// ELEMENTAL: %[[VAL_2:.*]] = arith.constant 1 : index
+// ELEMENTAL: %[[VAL_3:.*]] = arith.constant false
+// ELEMENTAL: %[[VAL_4:.*]] = hlfir.shape_of %[[VAL_0]] : (!hlfir.expr<?x?x!fir.logical<1>>) -> !fir.shape<2>
+// ELEMENTAL: %[[VAL_5:.*]] = hlfir.get_extent %[[VAL_4]] {dim = 0 : index} : (!fir.shape<2>) -> index
+// ELEMENTAL: %[[VAL_6:.*]] = hlfir.get_extent %[[VAL_4]] {dim = 1 : index} : (!fir.shape<2>) -> index
+// ELEMENTAL: %[[VAL_7:.*]] = hlfir.shape_of %[[VAL_1]] : (!hlfir.expr<?x?x!fir.logical<4>>) -> !fir.shape<2>
+// ELEMENTAL: %[[VAL_8:.*]] = hlfir.get_extent %[[VAL_7]] {dim = 1 : index} : (!fir.shape<2>) -> index
+// ELEMENTAL: %[[VAL_9:.*]] = fir.shape %[[VAL_5]], %[[VAL_8]] : (index, index) -> !fir.shape<2>
+// ELEMENTAL: %[[VAL_10:.*]] = hlfir.elemental %[[VAL_9]] unordered : (!fir.shape<2>) -> !hlfir.expr<?x?x!fir.logical<4>> {
+// ELEMENTAL: ^bb0(%[[VAL_11:.*]]: index, %[[VAL_12:.*]]: index):
+// ELEMENTAL: %[[VAL_13:.*]] = fir.convert %[[VAL_3]] : (i1) -> !fir.logical<4>
+// ELEMENTAL: %[[VAL_14:.*]] = fir.do_loop %[[VAL_15:.*]] = %[[VAL_2]] to %[[VAL_6]] step %[[VAL_2]] unordered iter_args(%[[VAL_16:.*]] = %[[VAL_13]]) -> (!fir.logical<4>) {
+// ELEMENTAL: %[[VAL_17:.*]] = hlfir.apply %[[VAL_0]], %[[VAL_11]], %[[VAL_15]] : (!hlfir.expr<?x?x!fir.logical<1>>, index, index) -> !fir.logical<1>
+// ELEMENTAL: %[[VAL_18:.*]] = hlfir.apply %[[VAL_1]], %[[VAL_15]], %[[VAL_12]] : (!hlfir.expr<?x?x!fir.logical<4>>, index, index) -> !fir.logical<4>
+// ELEMENTAL: %[[VAL_19:.*]] = fir.convert %[[VAL_16]] : (!fir.logical<4>) -> i1
+// ELEMENTAL: %[[VAL_20:.*]] = fir.convert %[[VAL_17]] : (!fir.logical<1>) -> i1
+// ELEMENTAL: %[[VAL_21:.*]] = fir.convert %[[VAL_18]] : (!fir.logical<4>) -> i1
+// ELEMENTAL: %[[VAL_22:.*]] = arith.andi %[[VAL_20]], %[[VAL_21]] : i1
+// ELEMENTAL: %[[VAL_23:.*]] = arith.ori %[[VAL_19]], %[[VAL_22]] : i1
+// ELEMENTAL: %[[VAL_24:.*]] = fir.convert %[[VAL_23]] : (i1) -> !fir.logical<4>
+// ELEMENTAL: fir.result %[[VAL_24]] : !fir.logical<4>
+// ELEMENTAL: }
+// ELEMENTAL: hlfir.yield_element %[[VAL_14]] : !fir.logical<4>
+// ELEMENTAL: }
+// ELEMENTAL: return %[[VAL_10]] : !hlfir.expr<?x?x!fir.logical<4>>
+// ELEMENTAL: }
+
+func.func @matmul_matrix_vector_real(%arg0: !hlfir.expr<?x?xf32>, %arg1: !hlfir.expr<?xf16>) -> !hlfir.expr<?xf32> {
+ %res = hlfir.matmul %arg0 %arg1 : (!hlfir.expr<?x?xf32>, !hlfir.expr<?xf16>) -> !hlfir.expr<?xf32>
+ return %res : !hlfir.expr<?xf32>
+}
+// ALL-LABEL: func.func @matmul_matrix_vector_real(
+// ALL-SAME: %[[VAL_0:.*]]: !hlfir.expr<?x?xf32>,
+// ALL-SAME: %[[VAL_1:.*]]: !hlfir.expr<?xf16>) -> !hlfir.expr<?xf32> {
+
+// NOANSE: hlfir.matmul
+
+// ANSE: %[[VAL_2:.*]] = arith.constant 0 : index
+// ANSE: %[[VAL_3:.*]] = arith.constant 1 : index
+// ANSE: %[[VAL_4:.*]] = arith.constant 0.000000e+00 : f32
+// ANSE: %[[VAL_5:.*]] = hlfir.shape_of %[[VAL_0]] : (!hlfir.expr<?x?xf32>) -> !fir.shape<2>
+// ANSE: %[[VAL_6:.*]] = hlfir.get_extent %[[VAL_5]] {dim = 0 : index} : (!fir.shape<2>) -> index
+// ANSE: %[[VAL_7:.*]] = hlfir.get_extent %[[VAL_5]] {dim = 1 : index} : (!fir.shape<2>) -> index
+// ANSE: %[[VAL_8:.*]] = fir.shape %[[VAL_6]] : (index) -> !fir.shape<1>
+// ANSE: %[[VAL_9:.*]] = hlfir.eval_in_mem shape %[[VAL_8]] : (!fir.shape<1>) -> !hlfir.expr<?xf32> {
+// ANSE: ^bb0(%[[VAL_10:.*]]: !fir.ref<!fir.array<?xf32>>):
+// ANSE: %[[VAL_11:.*]] = fir.embox %[[VAL_10]](%[[VAL_8]]) : (!fir.ref<!fir.array<?xf32>>, !fir.shape<1>) -> !fir.box<!fir.array<?xf32>>
+// ANSE: fir.do_loop %[[VAL_12:.*]] = %[[VAL_3]] to %[[VAL_6]] step %[[VAL_3]] unordered {
+// ANSE: %[[VAL_13:.*]]:3 = fir.box_dims %[[VAL_11]], %[[VAL_2]] : (!fir.box<!fir.array<?xf32>>, index) -> (index, index, index)
+// ANSE: %[[VAL_14:.*]] = arith.subi %[[VAL_13]]#0, %[[VAL_3]] : index
+// ANSE: %[[VAL_15:.*]] = arith.addi %[[VAL_12]], %[[VAL_14]] : index
+// ANSE: %[[VAL_16:.*]] = hlfir.designate %[[VAL_11]] (%[[VAL_15]]) : (!fir.box<!fir.array<?xf32>>, index) -> !fir.ref<f32>
+// ANSE: hlfir.assign %[[VAL_4]] to %[[VAL_16]] : f32, !fir.ref<f32>
+// ANSE: }
+// ANSE: fir.do_loop %[[VAL_17:.*]] = %[[VAL_3]] to %[[VAL_7]] step %[[VAL_3]] {
+// ANSE: fir.do_loop %[[VAL_18:.*]] = %[[VAL_3]] to %[[VAL_6]] step %[[VAL_3]] {
+// ANSE: %[[VAL_19:.*]]:3 = fir.box_dims %[[VAL_11]], %[[VAL_2]] : (!fir.box<!fir.array<?xf32>>, index) -> (index, index, index)
+// ANSE: %[[VAL_20:.*]] = arith.subi %[[VAL_19]]#0, %[[VAL_3]] : index
+// ANSE: %[[VAL_21:.*]] = arith.addi %[[VAL_18]], %[[VAL_20]] : index
+// ANSE: %[[VAL_22:.*]] = hlfir.designate %[[VAL_11]] (%[[VAL_21]]) : (!fir.box<!fir.array<?xf32>>, index) -> !fir.ref<f32>
+// ANSE: %[[VAL_23:.*]] = fir.load %[[VAL_22]] : !fir.ref<f32>
+// ANSE: %[[VAL_24:.*]] = hlfir.apply %[[VAL_0]], %[[VAL_18]], %[[VAL_17]] : (!hlfir.expr<?x?xf32>, index, index) -> f32
+// ANSE: %[[VAL_25:.*]] = hlfir.apply %[[VAL_1]], %[[VAL_17]] : (!hlfir.expr<?xf16>, index) -> f16
+// ANSE: %[[VAL_26:.*]] = fir.convert %[[VAL_25]] : (f16) -> f32
+// ANSE: %[[VAL_27:.*]] = arith.mulf %[[VAL_24]], %[[VAL_26]] : f32
+// ANSE: %[[VAL_28:.*]] = arith.addf %[[VAL_23]], %[[VAL_27]] : f32
+// ANSE: hlfir.assign %[[VAL_28]] to %[[VAL_22]] : f32, !fir.ref<f32>
+// ANSE: }
+// ANSE: }
+// ANSE: }
+// ANSE: return %[[VAL_9]] : !hlfir.expr<?xf32>
+// ANSE: }
+
+// ELEMENTAL: %[[VAL_2:.*]] = arith.constant 1 : index
+// ELEMENTAL: %[[VAL_3:.*]] = arith.constant 0.000000e+00 : f32
+// ELEMENTAL: %[[VAL_4:.*]] = hlfir.shape_of %[[VAL_0]] : (!hlfir.expr<?x?xf32>) -> !fir.shape<2>
+// ELEMENTAL: %[[VAL_5:.*]] = hlfir.get_extent %[[VAL_4]] {dim = 0 : index} : (!fir.shape<2>) -> index
+// ELEMENTAL: %[[VAL_6:.*]] = hlfir.get_extent %[[VAL_4]] {dim = 1 : index} : (!fir.shape<2>) -> index
+// ELEMENTAL: %[[VAL_7:.*]] = fir.shape %[[VAL_5]] : (index) -> !fir.shape<1>
+// ELEMENTAL: %[[VAL_8:.*]] = hlfir.elemental %[[VAL_7]] unordered : (!fir.shape<1>) -> !hlfir.expr<?xf32> {
+// ELEMENTAL: ^bb0(%[[VAL_9:.*]]: index):
+// ELEMENTAL: %[[VAL_10:.*]] = fir.do_loop %[[VAL_11:.*]] = %[[VAL_2]] to %[[VAL_6]] step %[[VAL_2]] iter_args(%[[VAL_12:.*]] = %[[VAL_3]]) -> (f32) {
+// ELEMENTAL: %[[VAL_13:.*]] = hlfir.apply %[[VAL_0]], %[[VAL_9]], %[[VAL_11]] : (!hlfir.expr<?x?xf32>, index, index) -> f32
+// ELEMENTAL: %[[VAL_14:.*]] = hlfir.apply %[[VAL_1]], %[[VAL_11]] : (!hlfir.expr<?xf16>, index) -> f16
+// ELEMENTAL: %[[VAL_15:.*]] = fir.convert %[[VAL_14]] : (f16) -> f32
+// ELEMENTAL: %[[VAL_16:.*]] = arith.mulf %[[VAL_13]], %[[VAL_15]] : f32
+// ELEMENTAL: %[[VAL_17:.*]] = arith.addf %[[VAL_12]], %[[VAL_16]] : f32
+// ELEMENTAL: fir.result %[[VAL_17]] : f32
+// ELEMENTAL: }
+// ELEMENTAL: hlfir.yield_element %[[VAL_10]] : f32
+// ELEMENTAL: }
+// ELEMENTAL: return %[[VAL_8]] : !hlfir.expr<?xf32>
+// ELEMENTAL: }
+
+func.func @matmul_vector_matrix_real(%arg0: !hlfir.expr<?xf32>, %arg1: !hlfir.expr<?x?xf16>) -> !hlfir.expr<?xf32> {
+ %res = hlfir.matmul %arg0 %arg1 : (!hlfir.expr<?xf32>, !hlfir.expr<?x?xf16>) -> !hlfir.expr<?xf32>
+ return %res : !hlfir.expr<?xf32>
+}
+// ALL-LABEL: func.func @matmul_vector_matrix_real(
+// ALL-SAME: %[[VAL_0:.*]]: !hlfir.expr<?xf32>,
+// ALL-SAME: %[[VAL_1:.*]]: !hlfir.expr<?x?xf16>) -> !hlfir.expr<?xf32> {
+
+// NOANSE: hlfir.matmul
+
+// ANSE: %[[VAL_2:.*]] = arith.constant 0 : index
+// ANSE: %[[VAL_3:.*]] = arith.constant 1 : index
+// ANSE: %[[VAL_4:.*]] = arith.constant 0.000000e+00 : f32
+// ANSE: %[[VAL_5:.*]] = hlfir.shape_of %[[VAL_0]] : (!hlfir.expr<?xf32>) -> !fir.shape<1>
+// ANSE: %[[VAL_6:.*]] = hlfir.get_extent %[[VAL_5]] {dim = 0 : index} : (!fir.shape<1>) -> index
+// ANSE: %[[VAL_7:.*]] = hlfir.shape_of %[[VAL_1]] : (!hlfir.expr<?x?xf16>) -> !fir.shape<2>
+// ANSE: %[[VAL_8:.*]] = hlfir.get_extent %[[VAL_7]] {dim = 1 : index} : (!fir.shape<2>) -> index
+// ANSE: %[[VAL_9:.*]] = fir.shape %[[VAL_8]] : (index) -> !fir.shape<1>
+// ANSE: %[[VAL_10:.*]] = hlfir.eval_in_mem shape %[[VAL_9]] : (!fir.shape<1>) -> !hlfir.expr<?xf32> {
+// ANSE: ^bb0(%[[VAL_11:.*]]: !fir.ref<!fir.array<?xf32>>):
+// ANSE: %[[VAL_12:.*]] = fir.embox %[[VAL_11]](%[[VAL_9]]) : (!fir.ref<!fir.array<?xf32>>, !fir.shape<1>) -> !fir.box<!fir.array<?xf32>>
+// ANSE: fir.do_loop %[[VAL_13:.*]] = %[[VAL_3]] to %[[VAL_8]] step %[[VAL_3]] unordered {
+// ANSE: %[[VAL_14:.*]]:3 = fir.box_dims %[[VAL_12]], %[[VAL_2]] : (!fir.box<!fir.array<?xf32>>, index) -> (index, index, index)
+// ANSE: %[[VAL_15:.*]] = arith.subi %[[VAL_14]]#0, %[[VAL_3]] : index
+// ANSE: %[[VAL_16:.*]] = arith.addi %[[VAL_13]], %[[VAL_15]] : index
+// ANSE: %[[VAL_17:.*]] = hlfir.designate %[[VAL_12]] (%[[VAL_16]]) : (!fir.box<!fir.array<?xf32>>, index) -> !fir.ref<f32>
+// ANSE: hlfir.assign %[[VAL_4]] to %[[VAL_17]] : f32, !fir.ref<f32>
+// ANSE: }
+// ANSE: fir.do_loop %[[VAL_18:.*]] = %[[VAL_3]] to %[[VAL_6]] step %[[VAL_3]] {
+// ANSE: fir.do_loop %[[VAL_19:.*]] = %[[VAL_3]] to %[[VAL_8]] step %[[VAL_3]] {
+// ANSE: %[[VAL_20:.*]]:3 = fir.box_dims %[[VAL_12]], %[[VAL_2]] : (!fir.box<!fir.array<?xf32>>, index) -> (index, index, index)
+// ANSE: %[[VAL_21:.*]] = arith.subi %[[VAL_20]]#0, %[[VAL_3]] : index
+// ANSE: %[[VAL_22:.*]] = arith.addi %[[VAL_19]], %[[VAL_21]] : index
+// ANSE: %[[VAL_23:.*]] = hlfir.designate %[[VAL_12]] (%[[VAL_22]]) : (!fir.box<!fir.array<?xf32>>, index) -> !fir.ref<f32>
+// ANSE: %[[VAL_24:.*]] = fir.load %[[VAL_23]] : !fir.ref<f32>
+// ANSE: %[[VAL_25:.*]] = hlfir.apply %[[VAL_0]], %[[VAL_18]] : (!hlfir.expr<?xf32>, index) -> f32
+// ANSE: %[[VAL_26:.*]] = hlfir.apply %[[VAL_1]], %[[VAL_18]], %[[VAL_19]] : (!hlfir.expr<?x?xf16>, index, index) -> f16
+// ANSE: %[[VAL_27:.*]] = fir.convert %[[VAL_26]] : (f16) -> f32
+// ANSE: %[[VAL_28:.*]] = arith.mulf %[[VAL_25]], %[[VAL_27]] : f32
+// ANSE: %[[VAL_29:.*]] = arith.addf %[[VAL_24]], %[[VAL_28]] : f32
+// ANSE: hlfir.assign %[[VAL_29]] to %[[VAL_23]] : f32, !fir.ref<f32>
+// ANSE: }
+// ANSE: }
+// ANSE: }
+// ANSE: return %[[VAL_10]] : !hlfir.expr<?xf32>
+// ANSE: }
+
+// ELEMENTAL: %[[VAL_2:.*]] = arith.constant 1 : index
+// ELEMENTAL: %[[VAL_3:.*]] = arith.constant 0.000000e+00 : f32
+// ELEMENTAL: %[[VAL_4:.*]] = hlfir.shape_of %[[VAL_0]] : (!hlfir.expr<?xf32>) -> !fir.shape<1>
+// ELEMENTAL: %[[VAL_5:.*]] = hlfir.get_extent %[[VAL_4]] {dim = 0 : index} : (!fir.shape<1>) -> index
+// ELEMENTAL: %[[VAL_6:.*]] = hlfir.shape_of %[[VAL_1]] : (!hlfir.expr<?x?xf16>) -> !fir.shape<2>
+// ELEMENTAL: %[[VAL_7:.*]] = hlfir.get_extent %[[VAL_6]] {dim = 1 : index} : (!fir.shape<2>) -> index
+// ELEMENTAL: %[[VAL_8:.*]] = fir.shape %[[VAL_7]] : (index) -> !fir.shape<1>
+// ELEMENTAL: %[[VAL_9:.*]] = hlfir.elemental %[[VAL_8]] unordered : (!fir.shape<1>) -> !hlfir.expr<?xf32> {
+// ELEMENTAL: ^bb0(%[[VAL_10:.*]]: index):
+// ELEMENTAL: %[[VAL_11:.*]] = fir.do_loop %[[VAL_12:.*]] = %[[VAL_2]] to %[[VAL_5]] step %[[VAL_2]] iter_args(%[[VAL_13:.*]] = %[[VAL_3]]) -> (f32) {
+// ELEMENTAL: %[[VAL_14:.*]] = hlfir.apply %[[VAL_0]], %[[VAL_12]] : (!hlfir.expr<?xf32>, index) -> f32
+// ELEMENTAL: %[[VAL_15:.*]] = hlfir.apply %[[VAL_1]], %[[VAL_12]], %[[VAL_10]] : (!hlfir.expr<?x?xf16>, index, index) -> f16
+// ELEMENTAL: %[[VAL_16:.*]] = fir.convert %[[VAL_15]] : (f16) -> f32
+// ELEMENTAL: %[[VAL_17:.*]] = arith.mulf %[[VAL_14]], %[[VAL_16]] : f32
+// ELEMENTAL: %[[VAL_18:.*]] = arith.addf %[[VAL_13]], %[[VAL_17]] : f32
+// ELEMENTAL: fir.result %[[VAL_18]] : f32
+// ELEMENTAL: }
+// ELEMENTAL: hlfir.yield_element %[[VAL_11]] : f32
+// ELEMENTAL: }
+// ELEMENTAL: return %[[VAL_9]] : !hlfir.expr<?xf32>
+// ELEMENTAL: }
+
+func.func @matmul_transpose_matrix_matrix_integer(%arg0: !hlfir.expr<?x?xi16>, %arg1: !hlfir.expr<?x?xi32>) -> !hlfir.expr<?x?xi32> {
+ %res = hlfir.matmul_transpose %arg0 %arg1 : (!hlfir.expr<?x?xi16>, !hlfir.expr<?x?xi32>) -> !hlfir.expr<?x?xi32>
+ return %res : !hlfir.expr<?x?xi32>
+}
+// ALL-LABEL: func.func @matmul_transpose_matrix_matrix_integer(
+// ALL-SAME: %[[VAL_0:.*]]: !hlfir.expr<?x?xi16>,
+// ALL-SAME: %[[VAL_1:.*]]: !hlfir.expr<?x?xi32>) -> !hlfir.expr<?x?xi32> {
+// ALL: %[[VAL_2:.*]] = arith.constant 1 : index
+// ALL: %[[VAL_3:.*]] = arith.constant 0 : i32
+// ALL: %[[VAL_4:.*]] = hlfir.shape_of %[[VAL_0]] : (!hlfir.expr<?x?xi16>) -> !fir.shape<2>
+// ALL: %[[VAL_5:.*]] = hlfir.get_extent %[[VAL_4]] {dim = 0 : index} : (!fir.shape<2>) -> index
+// ALL: %[[VAL_6:.*]] = hlfir.get_extent %[[VAL_4]] {dim = 1 : index} : (!fir.shape<2>) -> index
+// ALL: %[[VAL_7:.*]] = hlfir.shape_of %[[VAL_1]] : (!hlfir.expr<?x?xi32>) -> !fir.shape<2>
+// ALL: %[[VAL_8:.*]] = hlfir.get_extent %[[VAL_7]] {dim = 1 : index} : (!fir.shape<2>) -> index
+// ALL: %[[VAL_9:.*]] = fir.shape %[[VAL_6]], %[[VAL_8]] : (index, index) -> !fir.shape<2>
+// ALL: %[[VAL_10:.*]] = hlfir.elemental %[[VAL_9]] unordered : (!fir.shape<2>) -> !hlfir.expr<?x?xi32> {
+// ALL: ^bb0(%[[VAL_11:.*]]: index, %[[VAL_12:.*]]: index):
+// ALL: %[[VAL_13:.*]] = fir.do_loop %[[VAL_14:.*]] = %[[VAL_2]] to %[[VAL_5]] step %[[VAL_2]] unordered iter_args(%[[VAL_15:.*]] = %[[VAL_3]]) -> (i32) {
+// ALL: %[[VAL_16:.*]] = hlfir.apply %[[VAL_0]], %[[VAL_14]], %[[VAL_11]] : (!hlfir.expr<?x?xi16>, index, index) -> i16
+// ALL: %[[VAL_17:.*]] = hlfir.apply %[[VAL_1]], %[[VAL_14]], %[[VAL_12]] : (!hlfir.expr<?x?xi32>, index, index) -> i32
+// ALL: %[[VAL_18:.*]] = fir.convert %[[VAL_16]] : (i16) -> i32
+// ALL: %[[VAL_19:.*]] = arith.muli %[[VAL_18]], %[[VAL_17]] : i32
+// ALL: %[[VAL_20:.*]] = arith.addi %[[VAL_15]], %[[VAL_19]] : i32
+// ALL: fir.result %[[VAL_20]] : i32
+// ALL: }
+// ALL: hlfir.yield_element %[[VAL_13]] : i32
+// ALL: }
+// ALL: return %[[VAL_10]] : !hlfir.expr<?x?xi32>
+// ALL: }
+
+func.func @matmul_transpose_matrix_vector_real(%arg0: !hlfir.expr<?x?xf32>, %arg1: !hlfir.expr<?xf16>) -> !hlfir.expr<?xf32> {
+ %res = hlfir.matmul_transpose %arg0 %arg1 : (!hlfir.expr<?x?xf32>, !hlfir.expr<?xf16>) -> !hlfir.expr<?xf32>
+ return %res : !hlfir.expr<?xf32>
+}
+// ALL-LABEL: func.func @matmul_transpose_matrix_vector_real(
+// ALL-SAME: %[[VAL_0:.*]]: !hlfir.expr<?x?xf32>,
+// ALL-SAME: %[[VAL_1:.*]]: !hlfir.expr<?xf16>) -> !hlfir.expr<?xf32> {
+// ALL: %[[VAL_2:.*]] = arith.constant 1 : index
+// ALL: %[[VAL_3:.*]] = arith.constant 0.000000e+00 : f32
+// ALL: %[[VAL_4:.*]] = hlfir.shape_of %[[VAL_0]] : (!hlfir.expr<?x?xf32>) -> !fir.shape<2>
+// ALL: %[[VAL_5:.*]] = hlfir.get_extent %[[VAL_4]] {dim = 0 : index} : (!fir.shape<2>) -> index
+// ALL: %[[VAL_6:.*]] = hlfir.get_extent %[[VAL_4]] {dim = 1 : index} : (!fir.shape<2>) -> index
+// ALL: %[[VAL_7:.*]] = fir.shape %[[VAL_6]] : (index) -> !fir.shape<1>
+// ALL: %[[VAL_8:.*]] = hlfir.elemental %[[VAL_7]] unordered : (!fir.shape<1>) -> !hlfir.expr<?xf32> {
+// ALL: ^bb0(%[[VAL_9:.*]]: index):
+// ALL: %[[VAL_10:.*]] = fir.do_loop %[[VAL_11:.*]] = %[[VAL_2]] to %[[VAL_5]] step %[[VAL_2]] iter_args(%[[VAL_12:.*]] = %[[VAL_3]]) -> (f32) {
+// ALL: %[[VAL_13:.*]] = hlfir.apply %[[VAL_0]], %[[VAL_11]], %[[VAL_9]] : (!hlfir.expr<?x?xf32>, index, index) -> f32
+// ALL: %[[VAL_14:.*]] = hlfir.apply %[[VAL_1]], %[[VAL_11]] : (!hlfir.expr<?xf16>, index) -> f16
+// ALL: %[[VAL_15:.*]] = fir.convert %[[VAL_14]] : (f16) -> f32
+// ALL: %[[VAL_16:.*]] = arith.mulf %[[VAL_13]], %[[VAL_15]] : f32
+// ALL: %[[VAL_17:.*]] = arith.addf %[[VAL_12]], %[[VAL_16]] : f32
+// ALL: fir.result %[[VAL_17]] : f32
+// ALL: }
+// ALL: hlfir.yield_element %[[VAL_10]] : f32
+// ALL: }
+// ALL: return %[[VAL_8]] : !hlfir.expr<?xf32>
+// ALL: }
+
+// Check that the inner-product loop uses the best known extent
+// of the input matrices:
+func.func @matmul_matrix_matrix_deduce_bounds(%arg0: !hlfir.expr<?x10xi16>, %arg1: !hlfir.expr<?x?xi32>, %arg2: !hlfir.expr<10x?xi16>) -> (!hlfir.expr<?x?xi32>, !hlfir.expr<?x?xi32>) {
+ %res1 = hlfir.matmul %arg0 %arg1 : (!hlfir.expr<?x10xi16>, !hlfir.expr<?x?xi32>) -> !hlfir.expr<?x?xi32>
+ %res2 = hlfir.matmul %arg1 %arg2 : (!hlfir.expr<?x?xi32>, !hlfir.expr<10x?xi16>) -> !hlfir.expr<?x?xi32>
+ return %res1, %res2 : !hlfir.expr<?x?xi32>, !hlfir.expr<?x?xi32>
+}
+// ALL-LABEL: func.func @matmul_matrix_matrix_deduce_bounds(
+
+// ANSE: %[[VAL_6:.*]] = arith.constant 10 : index
+// ANSE: hlfir.eval_in_mem shape {{.*}}
+// ANSE: fir.do_loop
+// ANSE: fir.do_loop
+// ANSE: fir.do_loop %{{.*}} = %{{.*}} to %[[VAL_6]]
+// ANSE: fir.do_loop
+// ANSE: fir.do_loop
+// ANSE: hlfir.eval_in_mem shape {{.*}}
+// ANSE: fir.do_loop
+// ANSE: fir.do_loop
+// ANSE: fir.do_loop %{{.*}} = %{{.*}} to %[[VAL_6]]
+// ANSE: fir.do_loop
+// ANSE: fir.do_loop
+
+// ELEMENTAL: %[[VAL_5:.*]] = arith.constant 10 : index
+// ELEMENTAL: hlfir.elemental %{{.*}}
+// ELEMENTAL: fir.do_loop %{{.*}} = %{{.*}} to %[[VAL_5]]
+// ELEMENTAL: hlfir.elemental %{{.*}}
+// ELEMENTAL: fir.do_loop %{{.*}} = %{{.*}} to %[[VAL_5]]
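The tests above check two shapes of inlined MATMUL. The `hlfir.eval_in_mem` form (the `ANSE` check prefix) zero-initializes the result in memory and then accumulates into it. The `hlfir.elemental` form computes each result element as an independent reduction whose accumulator is carried in `iter_args`. The C++ sketch below mirrors the loop orders checked for the matrix-vector and `matmul_transpose` matrix-vector cases. It is not flang code. It assumes column-major (Fortran) storage, 0-based indices, and a single element type (`float`) for both operands. The tests, by contrast, use `f32`/`f16` operands and 1-based loops. The function names are hypothetical.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// A is an n x m matrix stored column-major: A(i,k) == a[i + k*n].

// Mirrors the hlfir.eval_in_mem lowering of MATMUL(A, x): zero-initialize
// the result, then accumulate with the k loop outermost, so the inner i loop
// walks down one contiguous column of A.
std::vector<float> matvecInMemory(const std::vector<float> &a,
                                  const std::vector<float> &x,
                                  std::size_t n, std::size_t m) {
  std::vector<float> r(n, 0.0f);
  for (std::size_t k = 0; k < m; ++k)
    for (std::size_t i = 0; i < n; ++i)
      r[i] += a[i + k * n] * x[k];
  return r;
}

// Mirrors the hlfir.elemental lowering of MATMUL(A, x): each result element
// is a dot product kept in a scalar accumulator (the fir.do_loop iter_args).
// The inner loop strides by n through a row of A.
std::vector<float> matvecElemental(const std::vector<float> &a,
                                   const std::vector<float> &x,
                                   std::size_t n, std::size_t m) {
  std::vector<float> r(n);
  for (std::size_t i = 0; i < n; ++i) {
    float acc = 0.0f;
    for (std::size_t k = 0; k < m; ++k)
      acc += a[i + k * n] * x[k];
    r[i] = acc;
  }
  return r;
}

// Mirrors the hlfir.elemental lowering of MATMUL(TRANSPOSE(A), x), where x
// has n elements and the result has m: r(i) = sum over k of A(k,i) * x(k).
// The inner loop reads A(k,i) down a contiguous column.
std::vector<float> matvecTransposeElemental(const std::vector<float> &a,
                                            const std::vector<float> &x,
                                            std::size_t n, std::size_t m) {
  std::vector<float> r(m);
  for (std::size_t i = 0; i < m; ++i) {
    float acc = 0.0f;
    for (std::size_t k = 0; k < n; ++k)
      acc += a[k + i * n] * x[k];
    r[i] = acc;
  }
  return r;
}
```

Both MATMUL variants add the products in the same order for each result element, so they produce identical results. They differ only in memory access pattern. One plausible reading of the commit message's observations is that the elemental MATMUL form regresses because its inner loop is strided. The elemental `matmul_transpose` form avoids this because both inner-loop reads run along the contiguous first dimension.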
From 27ec9ad13578da08c7279c016fd5434d85112ae5 Mon Sep 17 00:00:00 2001
From: Slava Zakharin <szakharin at nvidia.com>
Date: Mon, 13 Jan 2025 16:19:05 -0800
Subject: [PATCH 2/3] Removed commented out code.
---
.../Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp | 7 -------
1 file changed, 7 deletions(-)
diff --git a/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp b/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp
index 0fd535b4290799..e0914e0f275fcb 100644
--- a/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp
+++ b/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp
@@ -706,7 +706,6 @@ class MatmulConversion : public mlir::OpRewritePattern<Op> {
-> llvm::SmallVector<mlir::Value, 0> {
hlfir::Entity resultElement =
hlfir::getElementAt(loc, builder, result, oneBasedIndices);
- // builder.create<fir::StoreOp>(loc, initValue, resultElement);
builder.create<hlfir::AssignOp>(loc, initValue, resultElement);
return {};
};
@@ -742,8 +741,6 @@ class MatmulConversion : public mlir::OpRewritePattern<Op> {
loc, builder, resultElementType, resultElementValue,
lhsElementValue, rhsElementValue);
builder.create<hlfir::AssignOp>(loc, productValue, resultElement);
- // builder.create<fir::StoreOp>(loc, productValue,
- // resultElement);
return {};
};
@@ -781,8 +778,6 @@ class MatmulConversion : public mlir::OpRewritePattern<Op> {
loc, builder, resultElementType, resultElementValue,
lhsElementValue, rhsElementValue);
builder.create<hlfir::AssignOp>(loc, productValue, resultElement);
- // builder.create<fir::StoreOp>(loc, productValue,
- // resultElement);
return {};
};
hlfir::genLoopNestWithReductions(
@@ -815,8 +810,6 @@ class MatmulConversion : public mlir::OpRewritePattern<Op> {
loc, builder, resultElementType, resultElementValue,
lhsElementValue, rhsElementValue);
builder.create<hlfir::AssignOp>(loc, productValue, resultElement);
- // builder.create<fir::StoreOp>(loc, productValue,
- // resultElement);
return {};
};
hlfir::genLoopNestWithReductions(
From 56cdaf4819dce6990e67b4a75aa5b3a775effc8f Mon Sep 17 00:00:00 2001
From: Slava Zakharin <szakharin at nvidia.com>
Date: Tue, 14 Jan 2025 10:13:06 -0800
Subject: [PATCH 3/3] Updated comments.
---
.../Transforms/SimplifyHLFIRIntrinsics.cpp | 17 ++++++++++++++---
1 file changed, 14 insertions(+), 3 deletions(-)
diff --git a/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp b/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp
index e0914e0f275fcb..0fe3620b7f1ae3 100644
--- a/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp
+++ b/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp
@@ -534,10 +534,14 @@ class MatmulConversion : public mlir::OpRewritePattern<Op> {
// in their lowermost dimensions.
// Especially, when LLVM can recognize the continuity
// and vectorize the loops properly.
- // TODO: we need to recognize the cases when the continuity
+ // Note that the contiguous MATMUL inlining is correct
+ // even when the input arrays are not contiguous.
+ // TODO: we can try to recognize the cases when the continuity
// is not statically obvious and try to generate an explicitly
- // continuous version under a dynamic check. The fallback
- // implementation may use genElementalMatmul() with
+ // continuous version under a dynamic check. This should allow
+ // LLVM to vectorize the loops better. Note that this can
+ // also be postponed up to the LoopVersioning pass.
+ // The fallback implementation may use genElementalMatmul() with
// an hlfir.assign into the result of eval_in_mem.
mlir::LogicalResult rewriteResult =
genContiguousMatmul(loc, builder, hlfir::Entity{resultArray},
@@ -629,6 +633,13 @@ class MatmulConversion : public mlir::OpRewritePattern<Op> {
if (mlir::isa<fir::LogicalType>(type))
return builder.createConvert(loc, builder.getIntegerType(1), value);
+ // TODO: the multiplications/additions by/of zero resulting from
+ // complex * real are optimized by LLVM under -fno-signed-zeros
+ // -fno-honor-nans.
+ // We can make them disappear by default if we:
+ // * either expand the complex multiplication into real
+ // operations, OR
+ // * set nnan nsz fast-math flags to the complex operations.
if (fir::isa_complex(type) && !fir::isa_complex(value.getType())) {
mlir::Value zeroCmplx = fir::factory::createZeroValue(builder, loc, type);
fir::factory::Complex helper(builder, loc);
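The TODO in the last hunk concerns MATMUL of a complex operand with a real one. In that case the real element is promoted to a complex value with a zero imaginary part before the multiply. The minimal C++ sketch below shows why the resulting multiplications and additions involving zero cannot be folded away under default IEEE semantics. It assumes the promoted product uses the textbook formula (a+bi)(c+0i) = (ac - b*0) + (a*0 + bc)i. The expansion that flang and LLVM actually emit may differ. The functions `mulPromoted` and `mulDirect` are hypothetical.

```cpp
#include <cassert>
#include <cmath>
#include <complex>
#include <limits>

// Complex * real after promoting the real operand to (r, 0), expanded with
// the textbook formula and no special-casing of the zero imaginary part.
std::complex<double> mulPromoted(std::complex<double> z, double r) {
  double a = z.real(), b = z.imag();
  double c = r, d = 0.0; // the promoted real operand
  return {a * c - b * d, a * d + b * c};
}

// Component-wise product, i.e. what remains once the zero terms are dropped.
std::complex<double> mulDirect(std::complex<double> z, double r) {
  return {z.real() * r, z.imag() * r};
}
```

With z = (inf, 1) and r = 2, the promoted form computes inf*0 in the imaginary part and yields NaN, while the direct form yields 2. This is why dropping the term requires relaxing NaN semantics (`nnan` / `-fno-honor-nans`). With z = (1, -0.0) and r = 1, the promoted form computes 0 + (-0.0) = +0.0, while the direct form keeps -0.0. This is why the `nsz` / `-fno-signed-zeros` flag is also needed.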