[flang-commits] [flang] afc43a7 - Revert "[flang] Inline hlfir.dot_product. (#123143)"
Philip Reames via flang-commits
flang-commits at lists.llvm.org
Thu Jan 16 17:39:00 PST 2025
Author: Philip Reames
Date: 2025-01-16T17:38:40-08:00
New Revision: afc43a7b626ae07f56e6534320e0b46d26070750
URL: https://github.com/llvm/llvm-project/commit/afc43a7b626ae07f56e6534320e0b46d26070750
DIFF: https://github.com/llvm/llvm-project/commit/afc43a7b626ae07f56e6534320e0b46d26070750.diff
LOG: Revert "[flang] Inline hlfir.dot_product. (#123143)"
This reverts commit 9a6433f0ff1b8e294ac785ea3b992304574e0d8f because `ninja check-flang` fails to compile on an x86 host.
Added:
Modified:
flang/include/flang/Optimizer/Builder/HLFIRTools.h
flang/lib/Optimizer/Builder/HLFIRTools.cpp
flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp
Removed:
flang/test/HLFIR/simplify-hlfir-intrinsics-dotproduct.fir
################################################################################
diff --git a/flang/include/flang/Optimizer/Builder/HLFIRTools.h b/flang/include/flang/Optimizer/Builder/HLFIRTools.h
index dc439fb323f88a..6e85b8f4ddf86e 100644
--- a/flang/include/flang/Optimizer/Builder/HLFIRTools.h
+++ b/flang/include/flang/Optimizer/Builder/HLFIRTools.h
@@ -513,12 +513,6 @@ genTypeAndKindConvert(mlir::Location loc, fir::FirOpBuilder &builder,
Entity loadElementAt(mlir::Location loc, fir::FirOpBuilder &builder,
Entity entity, mlir::ValueRange oneBasedIndices);
-/// Return a vector of extents for the given entity.
-/// The function creates new operations, but tries to clean-up
-/// after itself.
-llvm::SmallVector<mlir::Value>
-genExtentsVector(mlir::Location loc, fir::FirOpBuilder &builder, Entity entity);
-
} // namespace hlfir
#endif // FORTRAN_OPTIMIZER_BUILDER_HLFIRTOOLS_H
diff --git a/flang/lib/Optimizer/Builder/HLFIRTools.cpp b/flang/lib/Optimizer/Builder/HLFIRTools.cpp
index 66b2298a986b11..5e5d0bbd681326 100644
--- a/flang/lib/Optimizer/Builder/HLFIRTools.cpp
+++ b/flang/lib/Optimizer/Builder/HLFIRTools.cpp
@@ -1421,15 +1421,3 @@ hlfir::Entity hlfir::loadElementAt(mlir::Location loc,
return loadTrivialScalar(loc, builder,
getElementAt(loc, builder, entity, oneBasedIndices));
}
-
-llvm::SmallVector<mlir::Value>
-hlfir::genExtentsVector(mlir::Location loc, fir::FirOpBuilder &builder,
- hlfir::Entity entity) {
- entity = hlfir::derefPointersAndAllocatables(loc, builder, entity);
- mlir::Value shape = hlfir::genShape(loc, builder, entity);
- llvm::SmallVector<mlir::Value, Fortran::common::maxRank> extents =
- hlfir::getExplicitExtentsFromShape(shape, builder);
- if (shape.getUses().empty())
- shape.getDefiningOp()->erase();
- return extents;
-}
diff --git a/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp b/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp
index fe7ae0eeed3cc3..0fe3620b7f1ae3 100644
--- a/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp
+++ b/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp
@@ -37,79 +37,6 @@ static llvm::cl::opt<bool> forceMatmulAsElemental(
namespace {
-// Helper class to generate operations related to computing
-// product of values.
-class ProductFactory {
-public:
- ProductFactory(mlir::Location loc, fir::FirOpBuilder &builder)
- : loc(loc), builder(builder) {}
-
- // Generate an update of the inner product value:
- // acc += v1 * v2, OR
- // acc += CONJ(v1) * v2, OR
- // acc ||= v1 && v2
- //
- // CONJ parameter specifies whether the first complex product argument
- // needs to be conjugated.
- template <bool CONJ = false>
- mlir::Value genAccumulateProduct(mlir::Value acc, mlir::Value v1,
- mlir::Value v2) {
- mlir::Type resultType = acc.getType();
- acc = castToProductType(acc, resultType);
- v1 = castToProductType(v1, resultType);
- v2 = castToProductType(v2, resultType);
- mlir::Value result;
- if (mlir::isa<mlir::FloatType>(resultType)) {
- result = builder.create<mlir::arith::AddFOp>(
- loc, acc, builder.create<mlir::arith::MulFOp>(loc, v1, v2));
- } else if (mlir::isa<mlir::ComplexType>(resultType)) {
- if constexpr (CONJ)
- result = fir::IntrinsicLibrary{builder, loc}.genConjg(resultType, v1);
- else
- result = v1;
-
- result = builder.create<fir::AddcOp>(
- loc, acc, builder.create<fir::MulcOp>(loc, result, v2));
- } else if (mlir::isa<mlir::IntegerType>(resultType)) {
- result = builder.create<mlir::arith::AddIOp>(
- loc, acc, builder.create<mlir::arith::MulIOp>(loc, v1, v2));
- } else if (mlir::isa<fir::LogicalType>(resultType)) {
- result = builder.create<mlir::arith::OrIOp>(
- loc, acc, builder.create<mlir::arith::AndIOp>(loc, v1, v2));
- } else {
- llvm_unreachable("unsupported type");
- }
-
- return builder.createConvert(loc, resultType, result);
- }
-
-private:
- mlir::Location loc;
- fir::FirOpBuilder &builder;
-
- mlir::Value castToProductType(mlir::Value value, mlir::Type type) {
- if (mlir::isa<fir::LogicalType>(type))
- return builder.createConvert(loc, builder.getIntegerType(1), value);
-
- // TODO: the multiplications/additions by/of zero resulting from
- // complex * real are optimized by LLVM under -fno-signed-zeros
- // -fno-honor-nans.
- // We can make them disappear by default if we:
- // * either expand the complex multiplication into real
- // operations, OR
- // * set nnan nsz fast-math flags to the complex operations.
- if (fir::isa_complex(type) && !fir::isa_complex(value.getType())) {
- mlir::Value zeroCmplx = fir::factory::createZeroValue(builder, loc, type);
- fir::factory::Complex helper(builder, loc);
- mlir::Type partType = helper.getComplexPartType(type);
- return helper.insertComplexPart(zeroCmplx,
- castToProductType(value, partType),
- /*isImagPart=*/false);
- }
- return builder.createConvert(loc, type, value);
- }
-};
-
class TransposeAsElementalConversion
: public mlir::OpRewritePattern<hlfir::TransposeOp> {
public:
@@ -163,8 +90,11 @@ class TransposeAsElementalConversion
static mlir::Value genResultShape(mlir::Location loc,
fir::FirOpBuilder &builder,
hlfir::Entity array) {
- llvm::SmallVector<mlir::Value, 2> inExtents =
- hlfir::genExtentsVector(loc, builder, array);
+ mlir::Value inShape = hlfir::genShape(loc, builder, array);
+ llvm::SmallVector<mlir::Value> inExtents =
+ hlfir::getExplicitExtentsFromShape(inShape, builder);
+ if (inShape.getUses().empty())
+ inShape.getDefiningOp()->erase();
// transpose indices
assert(inExtents.size() == 2 && "checked in TransposeOp::validate");
@@ -207,7 +137,7 @@ class SumAsElementalConversion : public mlir::OpRewritePattern<hlfir::SumOp> {
mlir::Value resultShape, dimExtent;
llvm::SmallVector<mlir::Value> arrayExtents;
if (isTotalReduction)
- arrayExtents = hlfir::genExtentsVector(loc, builder, array);
+ arrayExtents = genArrayExtents(loc, builder, array);
else
std::tie(resultShape, dimExtent) =
genResultShapeForPartialReduction(loc, builder, array, dimVal);
@@ -233,8 +163,7 @@ class SumAsElementalConversion : public mlir::OpRewritePattern<hlfir::SumOp> {
// If DIM is not present, do total reduction.
// Initial value for the reduction.
- mlir::Value reductionInitValue =
- fir::factory::createZeroValue(builder, loc, elementType);
+ mlir::Value reductionInitValue = genInitValue(loc, builder, elementType);
// The reduction loop may be unordered if FastMathFlags::reassoc
// transformations are allowed. The integer reduction is always
@@ -335,6 +264,17 @@ class SumAsElementalConversion : public mlir::OpRewritePattern<hlfir::SumOp> {
}
private:
+ static llvm::SmallVector<mlir::Value>
+ genArrayExtents(mlir::Location loc, fir::FirOpBuilder &builder,
+ hlfir::Entity array) {
+ mlir::Value inShape = hlfir::genShape(loc, builder, array);
+ llvm::SmallVector<mlir::Value> inExtents =
+ hlfir::getExplicitExtentsFromShape(inShape, builder);
+ if (inShape.getUses().empty())
+ inShape.getDefiningOp()->erase();
+ return inExtents;
+ }
+
// Return fir.shape specifying the shape of the result
// of a SUM reduction with DIM=dimVal. The second return value
// is the extent of the DIM dimension.
@@ -343,7 +283,7 @@ class SumAsElementalConversion : public mlir::OpRewritePattern<hlfir::SumOp> {
fir::FirOpBuilder &builder,
hlfir::Entity array, int64_t dimVal) {
llvm::SmallVector<mlir::Value> inExtents =
- hlfir::genExtentsVector(loc, builder, array);
+ genArrayExtents(loc, builder, array);
assert(dimVal > 0 && dimVal <= static_cast<int64_t>(inExtents.size()) &&
"DIM must be present and a positive constant not exceeding "
"the array's rank");
@@ -353,6 +293,26 @@ class SumAsElementalConversion : public mlir::OpRewritePattern<hlfir::SumOp> {
return {builder.create<fir::ShapeOp>(loc, inExtents), dimExtent};
}
+ // Generate the initial value for a SUM reduction with the given
+ // data type.
+ static mlir::Value genInitValue(mlir::Location loc,
+ fir::FirOpBuilder &builder,
+ mlir::Type elementType) {
+ if (auto ty = mlir::dyn_cast<mlir::FloatType>(elementType)) {
+ const llvm::fltSemantics &sem = ty.getFloatSemantics();
+ return builder.createRealConstant(loc, elementType,
+ llvm::APFloat::getZero(sem));
+ } else if (auto ty = mlir::dyn_cast<mlir::ComplexType>(elementType)) {
+ mlir::Value initValue = genInitValue(loc, builder, ty.getElementType());
+ return fir::factory::Complex{builder, loc}.createComplex(ty, initValue,
+ initValue);
+ } else if (mlir::isa<mlir::IntegerType>(elementType)) {
+ return builder.createIntegerConstant(loc, elementType, 0);
+ }
+
+ llvm_unreachable("unsupported SUM reduction type");
+ }
+
// Generate scalar addition of the two values (of the same data type).
static mlir::Value genScalarAdd(mlir::Location loc,
fir::FirOpBuilder &builder,
@@ -610,10 +570,16 @@ class MatmulConversion : public mlir::OpRewritePattern<Op> {
static std::tuple<mlir::Value, mlir::Value>
genResultShape(mlir::Location loc, fir::FirOpBuilder &builder,
hlfir::Entity input1, hlfir::Entity input2) {
- llvm::SmallVector<mlir::Value, 2> input1Extents =
- hlfir::genExtentsVector(loc, builder, input1);
- llvm::SmallVector<mlir::Value, 2> input2Extents =
- hlfir::genExtentsVector(loc, builder, input2);
+ mlir::Value input1Shape = hlfir::genShape(loc, builder, input1);
+ llvm::SmallVector<mlir::Value> input1Extents =
+ hlfir::getExplicitExtentsFromShape(input1Shape, builder);
+ if (input1Shape.getUses().empty())
+ input1Shape.getDefiningOp()->erase();
+ mlir::Value input2Shape = hlfir::genShape(loc, builder, input2);
+ llvm::SmallVector<mlir::Value> input2Extents =
+ hlfir::getExplicitExtentsFromShape(input2Shape, builder);
+ if (input2Shape.getUses().empty())
+ input2Shape.getDefiningOp()->erase();
llvm::SmallVector<mlir::Value, 2> newExtents;
mlir::Value innerProduct1Extent, innerProduct2Extent;
@@ -661,6 +627,60 @@ class MatmulConversion : public mlir::OpRewritePattern<Op> {
innerProductExtent[0]};
}
+ static mlir::Value castToProductType(mlir::Location loc,
+ fir::FirOpBuilder &builder,
+ mlir::Value value, mlir::Type type) {
+ if (mlir::isa<fir::LogicalType>(type))
+ return builder.createConvert(loc, builder.getIntegerType(1), value);
+
+ // TODO: the multiplications/additions by/of zero resulting from
+ // complex * real are optimized by LLVM under -fno-signed-zeros
+ // -fno-honor-nans.
+ // We can make them disappear by default if we:
+ // * either expand the complex multiplication into real
+ // operations, OR
+ // * set nnan nsz fast-math flags to the complex operations.
+ if (fir::isa_complex(type) && !fir::isa_complex(value.getType())) {
+ mlir::Value zeroCmplx = fir::factory::createZeroValue(builder, loc, type);
+ fir::factory::Complex helper(builder, loc);
+ mlir::Type partType = helper.getComplexPartType(type);
+ return helper.insertComplexPart(
+ zeroCmplx, castToProductType(loc, builder, value, partType),
+ /*isImagPart=*/false);
+ }
+ return builder.createConvert(loc, type, value);
+ }
+
+ // Generate an update of the inner product value:
+ // acc += v1 * v2, OR
+ // acc ||= v1 && v2
+ static mlir::Value genAccumulateProduct(mlir::Location loc,
+ fir::FirOpBuilder &builder,
+ mlir::Type resultType,
+ mlir::Value acc, mlir::Value v1,
+ mlir::Value v2) {
+ acc = castToProductType(loc, builder, acc, resultType);
+ v1 = castToProductType(loc, builder, v1, resultType);
+ v2 = castToProductType(loc, builder, v2, resultType);
+ mlir::Value result;
+ if (mlir::isa<mlir::FloatType>(resultType))
+ result = builder.create<mlir::arith::AddFOp>(
+ loc, acc, builder.create<mlir::arith::MulFOp>(loc, v1, v2));
+ else if (mlir::isa<mlir::ComplexType>(resultType))
+ result = builder.create<fir::AddcOp>(
+ loc, acc, builder.create<fir::MulcOp>(loc, v1, v2));
+ else if (mlir::isa<mlir::IntegerType>(resultType))
+ result = builder.create<mlir::arith::AddIOp>(
+ loc, acc, builder.create<mlir::arith::MulIOp>(loc, v1, v2));
+ else if (mlir::isa<fir::LogicalType>(resultType))
+ result = builder.create<mlir::arith::OrIOp>(
+ loc, acc, builder.create<mlir::arith::AndIOp>(loc, v1, v2));
+ else
+ llvm_unreachable("unsupported type");
+
+ return builder.createConvert(loc, resultType, result);
+ }
+
static mlir::LogicalResult
genContiguousMatmul(mlir::Location loc, fir::FirOpBuilder &builder,
hlfir::Entity result, mlir::Value resultShape,
@@ -728,9 +748,9 @@ class MatmulConversion : public mlir::OpRewritePattern<Op> {
hlfir::loadElementAt(loc, builder, lhs, {I, K});
hlfir::Entity rhsElementValue =
hlfir::loadElementAt(loc, builder, rhs, {K, J});
- mlir::Value productValue =
- ProductFactory{loc, builder}.genAccumulateProduct(
- resultElementValue, lhsElementValue, rhsElementValue);
+ mlir::Value productValue = genAccumulateProduct(
+ loc, builder, resultElementType, resultElementValue,
+ lhsElementValue, rhsElementValue);
builder.create<hlfir::AssignOp>(loc, productValue, resultElement);
return {};
};
@@ -765,9 +785,9 @@ class MatmulConversion : public mlir::OpRewritePattern<Op> {
hlfir::loadElementAt(loc, builder, lhs, {J, K});
hlfir::Entity rhsElementValue =
hlfir::loadElementAt(loc, builder, rhs, {K});
- mlir::Value productValue =
- ProductFactory{loc, builder}.genAccumulateProduct(
- resultElementValue, lhsElementValue, rhsElementValue);
+ mlir::Value productValue = genAccumulateProduct(
+ loc, builder, resultElementType, resultElementValue,
+ lhsElementValue, rhsElementValue);
builder.create<hlfir::AssignOp>(loc, productValue, resultElement);
return {};
};
@@ -797,9 +817,9 @@ class MatmulConversion : public mlir::OpRewritePattern<Op> {
hlfir::loadElementAt(loc, builder, lhs, {K});
hlfir::Entity rhsElementValue =
hlfir::loadElementAt(loc, builder, rhs, {K, J});
- mlir::Value productValue =
- ProductFactory{loc, builder}.genAccumulateProduct(
- resultElementValue, lhsElementValue, rhsElementValue);
+ mlir::Value productValue = genAccumulateProduct(
+ loc, builder, resultElementType, resultElementValue,
+ lhsElementValue, rhsElementValue);
builder.create<hlfir::AssignOp>(loc, productValue, resultElement);
return {};
};
@@ -865,9 +885,9 @@ class MatmulConversion : public mlir::OpRewritePattern<Op> {
hlfir::loadElementAt(loc, builder, lhs, lhsIndices);
hlfir::Entity rhsElementValue =
hlfir::loadElementAt(loc, builder, rhs, rhsIndices);
- mlir::Value productValue =
- ProductFactory{loc, builder}.genAccumulateProduct(
- reductionArgs[0], lhsElementValue, rhsElementValue);
+ mlir::Value productValue = genAccumulateProduct(
+ loc, builder, resultElementType, reductionArgs[0], lhsElementValue,
+ rhsElementValue);
return {productValue};
};
llvm::SmallVector<mlir::Value, 1> innerProductValue =
@@ -884,73 +904,6 @@ class MatmulConversion : public mlir::OpRewritePattern<Op> {
}
};
-class DotProductConversion
- : public mlir::OpRewritePattern<hlfir::DotProductOp> {
-public:
- using mlir::OpRewritePattern<hlfir::DotProductOp>::OpRewritePattern;
-
- llvm::LogicalResult
- matchAndRewrite(hlfir::DotProductOp product,
- mlir::PatternRewriter &rewriter) const override {
- hlfir::Entity op = hlfir::Entity{product};
- if (!op.isScalar())
- return rewriter.notifyMatchFailure(product, "produces non-scalar result");
-
- mlir::Location loc = product.getLoc();
- fir::FirOpBuilder builder{rewriter, product.getOperation()};
- hlfir::Entity lhs = hlfir::Entity{product.getLhs()};
- hlfir::Entity rhs = hlfir::Entity{product.getRhs()};
- mlir::Type resultElementType = product.getType();
- bool isUnordered = mlir::isa<mlir::IntegerType>(resultElementType) ||
- mlir::isa<fir::LogicalType>(resultElementType) ||
- static_cast<bool>(builder.getFastMathFlags() &
- mlir::arith::FastMathFlags::reassoc);
-
- mlir::Value extent = genProductExtent(loc, builder, lhs, rhs);
-
- auto genBody = [&](mlir::Location loc, fir::FirOpBuilder &builder,
- mlir::ValueRange oneBasedIndices,
- mlir::ValueRange reductionArgs)
- -> llvm::SmallVector<mlir::Value, 1> {
- hlfir::Entity lhsElementValue =
- hlfir::loadElementAt(loc, builder, lhs, oneBasedIndices);
- hlfir::Entity rhsElementValue =
- hlfir::loadElementAt(loc, builder, rhs, oneBasedIndices);
- mlir::Value productValue =
- ProductFactory{loc, builder}.genAccumulateProduct</*CONJ=*/true>(
- reductionArgs[0], lhsElementValue, rhsElementValue);
- return {productValue};
- };
-
- mlir::Value initValue =
- fir::factory::createZeroValue(builder, loc, resultElementType);
-
- llvm::SmallVector<mlir::Value, 1> result = hlfir::genLoopNestWithReductions(
- loc, builder, {extent},
- /*reductionInits=*/{initValue}, genBody, isUnordered);
-
- rewriter.replaceOp(product, result[0]);
- return mlir::success();
- }
-
-private:
- static mlir::Value genProductExtent(mlir::Location loc,
- fir::FirOpBuilder &builder,
- hlfir::Entity input1,
- hlfir::Entity input2) {
- llvm::SmallVector<mlir::Value, 1> input1Extents =
- hlfir::genExtentsVector(loc, builder, input1);
- llvm::SmallVector<mlir::Value, 1> input2Extents =
- hlfir::genExtentsVector(loc, builder, input2);
-
- assert(input1Extents.size() == 1 && input2Extents.size() == 1 &&
- "hlfir.dot_product arguments must be vectors");
- llvm::SmallVector<mlir::Value, 1> extent =
- fir::factory::deduceOptimalExtents(input1Extents, input2Extents);
- return extent[0];
- }
-};
-
class SimplifyHLFIRIntrinsics
: public hlfir::impl::SimplifyHLFIRIntrinsicsBase<SimplifyHLFIRIntrinsics> {
public:
@@ -986,8 +939,6 @@ class SimplifyHLFIRIntrinsics
if (forceMatmulAsElemental || this->allowNewSideEffects)
patterns.insert<MatmulConversion<hlfir::MatmulOp>>(context);
- patterns.insert<DotProductConversion>(context);
-
if (mlir::failed(mlir::applyPatternsGreedily(
getOperation(), std::move(patterns), config))) {
mlir::emitError(getOperation()->getLoc(),
diff --git a/flang/test/HLFIR/simplify-hlfir-intrinsics-dotproduct.fir b/flang/test/HLFIR/simplify-hlfir-intrinsics-dotproduct.fir
deleted file mode 100644
index f59b1422dbc849..00000000000000
--- a/flang/test/HLFIR/simplify-hlfir-intrinsics-dotproduct.fir
+++ /dev/null
@@ -1,144 +0,0 @@
-// Test hlfir.dot_product simplification to a reduction loop:
-// RUN: fir-opt --simplify-hlfir-intrinsics %s | FileCheck %s
-
-func.func @dot_product_integer(%arg0: !hlfir.expr<?xi16>, %arg1: !hlfir.expr<?xi32>) -> i32 {
- %res = hlfir.dot_product %arg0 %arg1 : (!hlfir.expr<?xi16>, !hlfir.expr<?xi32>) -> i32
- return %res : i32
-}
-// CHECK-LABEL: func.func @dot_product_integer(
-// CHECK-SAME: %[[VAL_0:.*]]: !hlfir.expr<?xi16>,
-// CHECK-SAME: %[[VAL_1:.*]]: !hlfir.expr<?xi32>) -> i32 {
-// CHECK: %[[VAL_2:.*]] = arith.constant 1 : index
-// CHECK: %[[VAL_3:.*]] = arith.constant 0 : i32
-// CHECK: %[[VAL_4:.*]] = hlfir.shape_of %[[VAL_0]] : (!hlfir.expr<?xi16>) -> !fir.shape<1>
-// CHECK: %[[VAL_5:.*]] = hlfir.get_extent %[[VAL_4]] {dim = 0 : index} : (!fir.shape<1>) -> index
-// CHECK: %[[VAL_6:.*]] = fir.do_loop %[[VAL_7:.*]] = %[[VAL_2]] to %[[VAL_5]] step %[[VAL_2]] unordered iter_args(%[[VAL_8:.*]] = %[[VAL_3]]) -> (i32) {
-// CHECK: %[[VAL_9:.*]] = hlfir.apply %[[VAL_0]], %[[VAL_7]] : (!hlfir.expr<?xi16>, index) -> i16
-// CHECK: %[[VAL_10:.*]] = hlfir.apply %[[VAL_1]], %[[VAL_7]] : (!hlfir.expr<?xi32>, index) -> i32
-// CHECK: %[[VAL_11:.*]] = fir.convert %[[VAL_9]] : (i16) -> i32
-// CHECK: %[[VAL_12:.*]] = arith.muli %[[VAL_11]], %[[VAL_10]] : i32
-// CHECK: %[[VAL_13:.*]] = arith.addi %[[VAL_8]], %[[VAL_12]] : i32
-// CHECK: fir.result %[[VAL_13]] : i32
-// CHECK: }
-// CHECK: return %[[VAL_6]] : i32
-// CHECK: }
-
-func.func @dot_product_real(%arg0: !hlfir.expr<?xf32>, %arg1: !hlfir.expr<?xf16>) -> f32 {
- %res = hlfir.dot_product %arg0 %arg1 : (!hlfir.expr<?xf32>, !hlfir.expr<?xf16>) -> f32
- return %res : f32
-}
-// CHECK-LABEL: func.func @dot_product_real(
-// CHECK-SAME: %[[VAL_0:.*]]: !hlfir.expr<?xf32>,
-// CHECK-SAME: %[[VAL_1:.*]]: !hlfir.expr<?xf16>) -> f32 {
-// CHECK: %[[VAL_2:.*]] = arith.constant 1 : index
-// CHECK: %[[VAL_3:.*]] = arith.constant 0.000000e+00 : f32
-// CHECK: %[[VAL_4:.*]] = hlfir.shape_of %[[VAL_0]] : (!hlfir.expr<?xf32>) -> !fir.shape<1>
-// CHECK: %[[VAL_5:.*]] = hlfir.get_extent %[[VAL_4]] {dim = 0 : index} : (!fir.shape<1>) -> index
-// CHECK: %[[VAL_6:.*]] = fir.do_loop %[[VAL_7:.*]] = %[[VAL_2]] to %[[VAL_5]] step %[[VAL_2]] iter_args(%[[VAL_8:.*]] = %[[VAL_3]]) -> (f32) {
-// CHECK: %[[VAL_9:.*]] = hlfir.apply %[[VAL_0]], %[[VAL_7]] : (!hlfir.expr<?xf32>, index) -> f32
-// CHECK: %[[VAL_10:.*]] = hlfir.apply %[[VAL_1]], %[[VAL_7]] : (!hlfir.expr<?xf16>, index) -> f16
-// CHECK: %[[VAL_11:.*]] = fir.convert %[[VAL_10]] : (f16) -> f32
-// CHECK: %[[VAL_12:.*]] = arith.mulf %[[VAL_9]], %[[VAL_11]] : f32
-// CHECK: %[[VAL_13:.*]] = arith.addf %[[VAL_8]], %[[VAL_12]] : f32
-// CHECK: fir.result %[[VAL_13]] : f32
-// CHECK: }
-// CHECK: return %[[VAL_6]] : f32
-// CHECK: }
-
-func.func @dot_product_complex(%arg0: !hlfir.expr<?xcomplex<f32>>, %arg1: !hlfir.expr<?xcomplex<f16>>) -> complex<f32> {
- %res = hlfir.dot_product %arg0 %arg1 : (!hlfir.expr<?xcomplex<f32>>, !hlfir.expr<?xcomplex<f16>>) -> complex<f32>
- return %res : complex<f32>
-}
-// CHECK-LABEL: func.func @dot_product_complex(
-// CHECK-SAME: %[[VAL_0:.*]]: !hlfir.expr<?xcomplex<f32>>,
-// CHECK-SAME: %[[VAL_1:.*]]: !hlfir.expr<?xcomplex<f16>>) -> complex<f32> {
-// CHECK: %[[VAL_2:.*]] = arith.constant 1 : index
-// CHECK: %[[VAL_3:.*]] = arith.constant 0.000000e+00 : f32
-// CHECK: %[[VAL_4:.*]] = hlfir.shape_of %[[VAL_0]] : (!hlfir.expr<?xcomplex<f32>>) -> !fir.shape<1>
-// CHECK: %[[VAL_5:.*]] = hlfir.get_extent %[[VAL_4]] {dim = 0 : index} : (!fir.shape<1>) -> index
-// CHECK: %[[VAL_6:.*]] = fir.undefined complex<f32>
-// CHECK: %[[VAL_7:.*]] = fir.insert_value %[[VAL_6]], %[[VAL_3]], [0 : index] : (complex<f32>, f32) -> complex<f32>
-// CHECK: %[[VAL_8:.*]] = fir.insert_value %[[VAL_7]], %[[VAL_3]], [1 : index] : (complex<f32>, f32) -> complex<f32>
-// CHECK: %[[VAL_9:.*]] = fir.do_loop %[[VAL_10:.*]] = %[[VAL_2]] to %[[VAL_5]] step %[[VAL_2]] iter_args(%[[VAL_11:.*]] = %[[VAL_8]]) -> (complex<f32>) {
-// CHECK: %[[VAL_12:.*]] = hlfir.apply %[[VAL_0]], %[[VAL_10]] : (!hlfir.expr<?xcomplex<f32>>, index) -> complex<f32>
-// CHECK: %[[VAL_13:.*]] = hlfir.apply %[[VAL_1]], %[[VAL_10]] : (!hlfir.expr<?xcomplex<f16>>, index) -> complex<f16>
-// CHECK: %[[VAL_14:.*]] = fir.convert %[[VAL_13]] : (complex<f16>) -> complex<f32>
-// CHECK: %[[VAL_15:.*]] = fir.extract_value %[[VAL_12]], [1 : index] : (complex<f32>) -> f32
-// CHECK: %[[VAL_16:.*]] = arith.negf %[[VAL_15]] : f32
-// CHECK: %[[VAL_17:.*]] = fir.insert_value %[[VAL_12]], %[[VAL_16]], [1 : index] : (complex<f32>, f32) -> complex<f32>
-// CHECK: %[[VAL_18:.*]] = fir.mulc %[[VAL_17]], %[[VAL_14]] : complex<f32>
-// CHECK: %[[VAL_19:.*]] = fir.addc %[[VAL_11]], %[[VAL_18]] : complex<f32>
-// CHECK: fir.result %[[VAL_19]] : complex<f32>
-// CHECK: }
-// CHECK: return %[[VAL_9]] : complex<f32>
-// CHECK: }
-
-func.func @dot_product_real_complex(%arg0: !hlfir.expr<?xf32>, %arg1: !hlfir.expr<?xcomplex<f16>>) -> complex<f32> {
- %res = hlfir.dot_product %arg0 %arg1 : (!hlfir.expr<?xf32>, !hlfir.expr<?xcomplex<f16>>) -> complex<f32>
- return %res : complex<f32>
-}
-// CHECK-LABEL: func.func @dot_product_real_complex(
-// CHECK-SAME: %[[VAL_0:.*]]: !hlfir.expr<?xf32>,
-// CHECK-SAME: %[[VAL_1:.*]]: !hlfir.expr<?xcomplex<f16>>) -> complex<f32> {
-// CHECK: %[[VAL_2:.*]] = arith.constant 1 : index
-// CHECK: %[[VAL_3:.*]] = arith.constant 0.000000e+00 : f32
-// CHECK: %[[VAL_4:.*]] = hlfir.shape_of %[[VAL_0]] : (!hlfir.expr<?xf32>) -> !fir.shape<1>
-// CHECK: %[[VAL_5:.*]] = hlfir.get_extent %[[VAL_4]] {dim = 0 : index} : (!fir.shape<1>) -> index
-// CHECK: %[[VAL_6:.*]] = fir.undefined complex<f32>
-// CHECK: %[[VAL_7:.*]] = fir.insert_value %[[VAL_6]], %[[VAL_3]], [0 : index] : (complex<f32>, f32) -> complex<f32>
-// CHECK: %[[VAL_8:.*]] = fir.insert_value %[[VAL_7]], %[[VAL_3]], [1 : index] : (complex<f32>, f32) -> complex<f32>
-// CHECK: %[[VAL_9:.*]] = fir.do_loop %[[VAL_10:.*]] = %[[VAL_2]] to %[[VAL_5]] step %[[VAL_2]] iter_args(%[[VAL_11:.*]] = %[[VAL_8]]) -> (complex<f32>) {
-// CHECK: %[[VAL_12:.*]] = hlfir.apply %[[VAL_0]], %[[VAL_10]] : (!hlfir.expr<?xf32>, index) -> f32
-// CHECK: %[[VAL_13:.*]] = hlfir.apply %[[VAL_1]], %[[VAL_10]] : (!hlfir.expr<?xcomplex<f16>>, index) -> complex<f16>
-// CHECK: %[[VAL_14:.*]] = fir.undefined complex<f32>
-// CHECK: %[[VAL_15:.*]] = fir.insert_value %[[VAL_14]], %[[VAL_3]], [0 : index] : (complex<f32>, f32) -> complex<f32>
-// CHECK: %[[VAL_16:.*]] = fir.insert_value %[[VAL_15]], %[[VAL_3]], [1 : index] : (complex<f32>, f32) -> complex<f32>
-// CHECK: %[[VAL_17:.*]] = fir.insert_value %[[VAL_16]], %[[VAL_12]], [0 : index] : (complex<f32>, f32) -> complex<f32>
-// CHECK: %[[VAL_18:.*]] = fir.convert %[[VAL_13]] : (complex<f16>) -> complex<f32>
-// CHECK: %[[VAL_19:.*]] = fir.extract_value %[[VAL_17]], [1 : index] : (complex<f32>) -> f32
-// CHECK: %[[VAL_20:.*]] = arith.negf %[[VAL_19]] : f32
-// CHECK: %[[VAL_21:.*]] = fir.insert_value %[[VAL_17]], %[[VAL_20]], [1 : index] : (complex<f32>, f32) -> complex<f32>
-// CHECK: %[[VAL_22:.*]] = fir.mulc %[[VAL_21]], %[[VAL_18]] : complex<f32>
-// CHECK: %[[VAL_23:.*]] = fir.addc %[[VAL_11]], %[[VAL_22]] : complex<f32>
-// CHECK: fir.result %[[VAL_23]] : complex<f32>
-// CHECK: }
-// CHECK: return %[[VAL_9]] : complex<f32>
-// CHECK: }
-
-func.func @dot_product_logical(%arg0: !hlfir.expr<?x!fir.logical<1>>, %arg1: !hlfir.expr<?x!fir.logical<4>>) -> !fir.logical<4> {
- %res = hlfir.dot_product %arg0 %arg1 : (!hlfir.expr<?x!fir.logical<1>>, !hlfir.expr<?x!fir.logical<4>>) -> !fir.logical<4>
- return %res : !fir.logical<4>
-}
-// CHECK-LABEL: func.func @dot_product_logical(
-// CHECK-SAME: %[[VAL_0:.*]]: !hlfir.expr<?x!fir.logical<1>>,
-// CHECK-SAME: %[[VAL_1:.*]]: !hlfir.expr<?x!fir.logical<4>>) -> !fir.logical<4> {
-// CHECK: %[[VAL_2:.*]] = arith.constant 1 : index
-// CHECK: %[[VAL_3:.*]] = arith.constant false
-// CHECK: %[[VAL_4:.*]] = hlfir.shape_of %[[VAL_0]] : (!hlfir.expr<?x!fir.logical<1>>) -> !fir.shape<1>
-// CHECK: %[[VAL_5:.*]] = hlfir.get_extent %[[VAL_4]] {dim = 0 : index} : (!fir.shape<1>) -> index
-// CHECK: %[[VAL_6:.*]] = fir.convert %[[VAL_3]] : (i1) -> !fir.logical<4>
-// CHECK: %[[VAL_7:.*]] = fir.do_loop %[[VAL_8:.*]] = %[[VAL_2]] to %[[VAL_5]] step %[[VAL_2]] unordered iter_args(%[[VAL_9:.*]] = %[[VAL_6]]) -> (!fir.logical<4>) {
-// CHECK: %[[VAL_10:.*]] = hlfir.apply %[[VAL_0]], %[[VAL_8]] : (!hlfir.expr<?x!fir.logical<1>>, index) -> !fir.logical<1>
-// CHECK: %[[VAL_11:.*]] = hlfir.apply %[[VAL_1]], %[[VAL_8]] : (!hlfir.expr<?x!fir.logical<4>>, index) -> !fir.logical<4>
-// CHECK: %[[VAL_12:.*]] = fir.convert %[[VAL_9]] : (!fir.logical<4>) -> i1
-// CHECK: %[[VAL_13:.*]] = fir.convert %[[VAL_10]] : (!fir.logical<1>) -> i1
-// CHECK: %[[VAL_14:.*]] = fir.convert %[[VAL_11]] : (!fir.logical<4>) -> i1
-// CHECK: %[[VAL_15:.*]] = arith.andi %[[VAL_13]], %[[VAL_14]] : i1
-// CHECK: %[[VAL_16:.*]] = arith.ori %[[VAL_12]], %[[VAL_15]] : i1
-// CHECK: %[[VAL_17:.*]] = fir.convert %[[VAL_16]] : (i1) -> !fir.logical<4>
-// CHECK: fir.result %[[VAL_17]] : !fir.logical<4>
-// CHECK: }
-// CHECK: return %[[VAL_7]] : !fir.logical<4>
-// CHECK: }
-
-func.func @dot_product_known_dim(%arg0: !hlfir.expr<10xf32>, %arg1: !hlfir.expr<?xi16>) -> f32 {
- %res1 = hlfir.dot_product %arg0 %arg1 : (!hlfir.expr<10xf32>, !hlfir.expr<?xi16>) -> f32
- %res2 = hlfir.dot_product %arg1 %arg0 : (!hlfir.expr<?xi16>, !hlfir.expr<10xf32>) -> f32
- %res = arith.addf %res1, %res2 : f32
- return %res : f32
-}
-// CHECK-LABEL: func.func @dot_product_known_dim(
-// CHECK: %[[VAL_2:.*]] = arith.constant 1 : index
-// CHECK: %[[VAL_4:.*]] = arith.constant 10 : index
-// CHECK: fir.do_loop %{{.*}} = %[[VAL_2]] to %[[VAL_4]] step %[[VAL_2]]
-// CHECK: fir.do_loop %{{.*}} = %[[VAL_2]] to %[[VAL_4]] step %[[VAL_2]]
More information about the flang-commits mailing list