[flang-commits] [flang] 4a8305c - [flang] Add TODO for half-precision intrinsic reductions
Valentin Clement via flang-commits
flang-commits at lists.llvm.org
Mon Jun 13 08:40:09 PDT 2022
Author: Valentin Clement
Date: 2022-06-13T17:40:01+02:00
New Revision: 4a8305ce856b6f4d2e49f3100226d00b402dff86
URL: https://github.com/llvm/llvm-project/commit/4a8305ce856b6f4d2e49f3100226d00b402dff86
DIFF: https://github.com/llvm/llvm-project/commit/4a8305ce856b6f4d2e49f3100226d00b402dff86.diff
LOG: [flang] Add TODO for half-precision intrinsic reductions
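Add TODOs for half-precision element types (F16/BF16 reals and kind-2/kind-3 complex) in the lowering of the MAXVAL, MINVAL, PRODUCT, DOT_PRODUCT, and SUM reduction intrinsics, and simplify the integer element-type checks to use `isInteger`.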
This patch is part of the upstreaming effort from the fir-dev branch.
Reviewed By: jeanPerier, PeteSteinfeld
Differential Revision: https://reviews.llvm.org/D127622
Co-authored-by: Eric Schweitz <eschweitz at nvidia.com>
Added:
Modified:
flang/lib/Optimizer/Builder/Runtime/Reduction.cpp
Removed:
################################################################################
diff --git a/flang/lib/Optimizer/Builder/Runtime/Reduction.cpp b/flang/lib/Optimizer/Builder/Runtime/Reduction.cpp
index aa39c8e464eca..579020acb8da5 100644
--- a/flang/lib/Optimizer/Builder/Runtime/Reduction.cpp
+++ b/flang/lib/Optimizer/Builder/Runtime/Reduction.cpp
@@ -545,7 +545,9 @@ mlir::Value fir::runtime::genMaxval(fir::FirOpBuilder &builder,
auto eleTy = arrTy.cast<fir::SequenceType>().getEleTy();
auto dim = builder.createIntegerConstant(loc, builder.getIndexType(), 0);
- if (eleTy.isF32())
+ if (eleTy.isF16() || eleTy.isBF16())
+ TODO(loc, "half-precision MAXVAL");
+ else if (eleTy.isF32())
func = fir::runtime::getRuntimeFunc<mkRTKey(MaxvalReal4)>(loc, builder);
else if (eleTy.isF64())
func = fir::runtime::getRuntimeFunc<mkRTKey(MaxvalReal8)>(loc, builder);
@@ -553,23 +555,18 @@ mlir::Value fir::runtime::genMaxval(fir::FirOpBuilder &builder,
func = fir::runtime::getRuntimeFunc<ForcedMaxvalReal10>(loc, builder);
else if (eleTy.isF128())
func = fir::runtime::getRuntimeFunc<ForcedMaxvalReal16>(loc, builder);
- else if (eleTy ==
- builder.getIntegerType(builder.getKindMap().getIntegerBitsize(1)))
+ else if (eleTy.isInteger(builder.getKindMap().getIntegerBitsize(1)))
func = fir::runtime::getRuntimeFunc<mkRTKey(MaxvalInteger1)>(loc, builder);
- else if (eleTy ==
- builder.getIntegerType(builder.getKindMap().getIntegerBitsize(2)))
+ else if (eleTy.isInteger(builder.getKindMap().getIntegerBitsize(2)))
func = fir::runtime::getRuntimeFunc<mkRTKey(MaxvalInteger2)>(loc, builder);
- else if (eleTy ==
- builder.getIntegerType(builder.getKindMap().getIntegerBitsize(4)))
+ else if (eleTy.isInteger(builder.getKindMap().getIntegerBitsize(4)))
func = fir::runtime::getRuntimeFunc<mkRTKey(MaxvalInteger4)>(loc, builder);
- else if (eleTy ==
- builder.getIntegerType(builder.getKindMap().getIntegerBitsize(8)))
+ else if (eleTy.isInteger(builder.getKindMap().getIntegerBitsize(8)))
func = fir::runtime::getRuntimeFunc<mkRTKey(MaxvalInteger8)>(loc, builder);
- else if (eleTy ==
- builder.getIntegerType(builder.getKindMap().getIntegerBitsize(16)))
+ else if (eleTy.isInteger(builder.getKindMap().getIntegerBitsize(16)))
func = fir::runtime::getRuntimeFunc<ForcedMaxvalInteger16>(loc, builder);
else
- fir::emitFatalError(loc, "invalid type in Maxval lowering");
+ fir::emitFatalError(loc, "invalid type in MAXVAL");
auto fTy = func.getFunctionType();
auto sourceFile = fir::factory::locationToFilename(builder, loc);
@@ -664,7 +661,9 @@ mlir::Value fir::runtime::genMinval(fir::FirOpBuilder &builder,
auto eleTy = arrTy.cast<fir::SequenceType>().getEleTy();
auto dim = builder.createIntegerConstant(loc, builder.getIndexType(), 0);
- if (eleTy.isF32())
+ if (eleTy.isF16() || eleTy.isBF16())
+ TODO(loc, "half-precision MINVAL");
+ else if (eleTy.isF32())
func = fir::runtime::getRuntimeFunc<mkRTKey(MinvalReal4)>(loc, builder);
else if (eleTy.isF64())
func = fir::runtime::getRuntimeFunc<mkRTKey(MinvalReal8)>(loc, builder);
@@ -672,23 +671,18 @@ mlir::Value fir::runtime::genMinval(fir::FirOpBuilder &builder,
func = fir::runtime::getRuntimeFunc<ForcedMinvalReal10>(loc, builder);
else if (eleTy.isF128())
func = fir::runtime::getRuntimeFunc<ForcedMinvalReal16>(loc, builder);
- else if (eleTy ==
- builder.getIntegerType(builder.getKindMap().getIntegerBitsize(1)))
+ else if (eleTy.isInteger(builder.getKindMap().getIntegerBitsize(1)))
func = fir::runtime::getRuntimeFunc<mkRTKey(MinvalInteger1)>(loc, builder);
- else if (eleTy ==
- builder.getIntegerType(builder.getKindMap().getIntegerBitsize(2)))
+ else if (eleTy.isInteger(builder.getKindMap().getIntegerBitsize(2)))
func = fir::runtime::getRuntimeFunc<mkRTKey(MinvalInteger2)>(loc, builder);
- else if (eleTy ==
- builder.getIntegerType(builder.getKindMap().getIntegerBitsize(4)))
+ else if (eleTy.isInteger(builder.getKindMap().getIntegerBitsize(4)))
func = fir::runtime::getRuntimeFunc<mkRTKey(MinvalInteger4)>(loc, builder);
- else if (eleTy ==
- builder.getIntegerType(builder.getKindMap().getIntegerBitsize(8)))
+ else if (eleTy.isInteger(builder.getKindMap().getIntegerBitsize(8)))
func = fir::runtime::getRuntimeFunc<mkRTKey(MinvalInteger8)>(loc, builder);
- else if (eleTy ==
- builder.getIntegerType(builder.getKindMap().getIntegerBitsize(16)))
+ else if (eleTy.isInteger(builder.getKindMap().getIntegerBitsize(16)))
func = fir::runtime::getRuntimeFunc<ForcedMinvalInteger16>(loc, builder);
else
- fir::emitFatalError(loc, "invalid type in Minval lowering");
+ fir::emitFatalError(loc, "invalid type in MINVAL");
auto fTy = func.getFunctionType();
auto sourceFile = fir::factory::locationToFilename(builder, loc);
@@ -721,7 +715,9 @@ mlir::Value fir::runtime::genProduct(fir::FirOpBuilder &builder,
auto eleTy = arrTy.cast<fir::SequenceType>().getEleTy();
auto dim = builder.createIntegerConstant(loc, builder.getIndexType(), 0);
- if (eleTy.isF32())
+ if (eleTy.isF16() || eleTy.isBF16())
+ TODO(loc, "half-precision PRODUCT");
+ else if (eleTy.isF32())
func = fir::runtime::getRuntimeFunc<mkRTKey(ProductReal4)>(loc, builder);
else if (eleTy.isF64())
func = fir::runtime::getRuntimeFunc<mkRTKey(ProductReal8)>(loc, builder);
@@ -729,20 +725,15 @@ mlir::Value fir::runtime::genProduct(fir::FirOpBuilder &builder,
func = fir::runtime::getRuntimeFunc<ForcedProductReal10>(loc, builder);
else if (eleTy.isF128())
func = fir::runtime::getRuntimeFunc<ForcedProductReal16>(loc, builder);
- else if (eleTy ==
- builder.getIntegerType(builder.getKindMap().getIntegerBitsize(1)))
+ else if (eleTy.isInteger(builder.getKindMap().getIntegerBitsize(1)))
func = fir::runtime::getRuntimeFunc<mkRTKey(ProductInteger1)>(loc, builder);
- else if (eleTy ==
- builder.getIntegerType(builder.getKindMap().getIntegerBitsize(2)))
+ else if (eleTy.isInteger(builder.getKindMap().getIntegerBitsize(2)))
func = fir::runtime::getRuntimeFunc<mkRTKey(ProductInteger2)>(loc, builder);
- else if (eleTy ==
- builder.getIntegerType(builder.getKindMap().getIntegerBitsize(4)))
+ else if (eleTy.isInteger(builder.getKindMap().getIntegerBitsize(4)))
func = fir::runtime::getRuntimeFunc<mkRTKey(ProductInteger4)>(loc, builder);
- else if (eleTy ==
- builder.getIntegerType(builder.getKindMap().getIntegerBitsize(8)))
+ else if (eleTy.isInteger(builder.getKindMap().getIntegerBitsize(8)))
func = fir::runtime::getRuntimeFunc<mkRTKey(ProductInteger8)>(loc, builder);
- else if (eleTy ==
- builder.getIntegerType(builder.getKindMap().getIntegerBitsize(16)))
+ else if (eleTy.isInteger(builder.getKindMap().getIntegerBitsize(16)))
func = fir::runtime::getRuntimeFunc<ForcedProductInteger16>(loc, builder);
else if (eleTy == fir::ComplexType::get(builder.getContext(), 4))
func =
@@ -754,8 +745,11 @@ mlir::Value fir::runtime::genProduct(fir::FirOpBuilder &builder,
func = fir::runtime::getRuntimeFunc<ForcedProductComplex10>(loc, builder);
else if (eleTy == fir::ComplexType::get(builder.getContext(), 16))
func = fir::runtime::getRuntimeFunc<ForcedProductComplex16>(loc, builder);
+ else if (eleTy == fir::ComplexType::get(builder.getContext(), 2) ||
+ eleTy == fir::ComplexType::get(builder.getContext(), 3))
+ TODO(loc, "half-precision PRODUCT");
else
- fir::emitFatalError(loc, "invalid type in Product lowering");
+ fir::emitFatalError(loc, "invalid type in PRODUCT");
auto fTy = func.getFunctionType();
auto sourceFile = fir::factory::locationToFilename(builder, loc);
@@ -788,7 +782,9 @@ mlir::Value fir::runtime::genDotProduct(fir::FirOpBuilder &builder,
auto arrTy = fir::dyn_cast_ptrOrBoxEleTy(ty);
auto eleTy = arrTy.cast<fir::SequenceType>().getEleTy();
- if (eleTy.isF32())
+ if (eleTy.isF16() || eleTy.isBF16())
+ TODO(loc, "half-precision DOTPRODUCT");
+ else if (eleTy.isF32())
func = fir::runtime::getRuntimeFunc<mkRTKey(DotProductReal4)>(loc, builder);
else if (eleTy.isF64())
func = fir::runtime::getRuntimeFunc<mkRTKey(DotProductReal8)>(loc, builder);
@@ -808,31 +804,29 @@ mlir::Value fir::runtime::genDotProduct(fir::FirOpBuilder &builder,
else if (eleTy == fir::ComplexType::get(builder.getContext(), 16))
func =
fir::runtime::getRuntimeFunc<ForcedDotProductComplex16>(loc, builder);
- else if (eleTy ==
- builder.getIntegerType(builder.getKindMap().getIntegerBitsize(1)))
+ else if (eleTy == fir::ComplexType::get(builder.getContext(), 2) ||
+ eleTy == fir::ComplexType::get(builder.getContext(), 3))
+ TODO(loc, "half-precision DOTPRODUCT");
+ else if (eleTy.isInteger(builder.getKindMap().getIntegerBitsize(1)))
func =
fir::runtime::getRuntimeFunc<mkRTKey(DotProductInteger1)>(loc, builder);
- else if (eleTy ==
- builder.getIntegerType(builder.getKindMap().getIntegerBitsize(2)))
+ else if (eleTy.isInteger(builder.getKindMap().getIntegerBitsize(2)))
func =
fir::runtime::getRuntimeFunc<mkRTKey(DotProductInteger2)>(loc, builder);
- else if (eleTy ==
- builder.getIntegerType(builder.getKindMap().getIntegerBitsize(4)))
+ else if (eleTy.isInteger(builder.getKindMap().getIntegerBitsize(4)))
func =
fir::runtime::getRuntimeFunc<mkRTKey(DotProductInteger4)>(loc, builder);
- else if (eleTy ==
- builder.getIntegerType(builder.getKindMap().getIntegerBitsize(8)))
+ else if (eleTy.isInteger(builder.getKindMap().getIntegerBitsize(8)))
func =
fir::runtime::getRuntimeFunc<mkRTKey(DotProductInteger8)>(loc, builder);
- else if (eleTy ==
- builder.getIntegerType(builder.getKindMap().getIntegerBitsize(16)))
+ else if (eleTy.isInteger(builder.getKindMap().getIntegerBitsize(16)))
func =
fir::runtime::getRuntimeFunc<ForcedDotProductInteger16>(loc, builder);
else if (eleTy.isa<fir::LogicalType>())
func =
fir::runtime::getRuntimeFunc<mkRTKey(DotProductLogical)>(loc, builder);
else
- fir::emitFatalError(loc, "invalid type in DotProduct lowering");
+ fir::emitFatalError(loc, "invalid type in DOTPRODUCT");
auto fTy = func.getFunctionType();
auto sourceFile = fir::factory::locationToFilename(builder, loc);
@@ -873,7 +867,9 @@ mlir::Value fir::runtime::genSum(fir::FirOpBuilder &builder, mlir::Location loc,
auto eleTy = arrTy.cast<fir::SequenceType>().getEleTy();
auto dim = builder.createIntegerConstant(loc, builder.getIndexType(), 0);
- if (eleTy.isF32())
+ if (eleTy.isF16() || eleTy.isBF16())
+ TODO(loc, "half-precision SUM");
+ else if (eleTy.isF32())
func = fir::runtime::getRuntimeFunc<mkRTKey(SumReal4)>(loc, builder);
else if (eleTy.isF64())
func = fir::runtime::getRuntimeFunc<mkRTKey(SumReal8)>(loc, builder);
@@ -881,20 +877,15 @@ mlir::Value fir::runtime::genSum(fir::FirOpBuilder &builder, mlir::Location loc,
func = fir::runtime::getRuntimeFunc<ForcedSumReal10>(loc, builder);
else if (eleTy.isF128())
func = fir::runtime::getRuntimeFunc<ForcedSumReal16>(loc, builder);
- else if (eleTy ==
- builder.getIntegerType(builder.getKindMap().getIntegerBitsize(1)))
+ else if (eleTy.isInteger(builder.getKindMap().getIntegerBitsize(1)))
func = fir::runtime::getRuntimeFunc<mkRTKey(SumInteger1)>(loc, builder);
- else if (eleTy ==
- builder.getIntegerType(builder.getKindMap().getIntegerBitsize(2)))
+ else if (eleTy.isInteger(builder.getKindMap().getIntegerBitsize(2)))
func = fir::runtime::getRuntimeFunc<mkRTKey(SumInteger2)>(loc, builder);
- else if (eleTy ==
- builder.getIntegerType(builder.getKindMap().getIntegerBitsize(4)))
+ else if (eleTy.isInteger(builder.getKindMap().getIntegerBitsize(4)))
func = fir::runtime::getRuntimeFunc<mkRTKey(SumInteger4)>(loc, builder);
- else if (eleTy ==
- builder.getIntegerType(builder.getKindMap().getIntegerBitsize(8)))
+ else if (eleTy.isInteger(builder.getKindMap().getIntegerBitsize(8)))
func = fir::runtime::getRuntimeFunc<mkRTKey(SumInteger8)>(loc, builder);
- else if (eleTy ==
- builder.getIntegerType(builder.getKindMap().getIntegerBitsize(16)))
+ else if (eleTy.isInteger(builder.getKindMap().getIntegerBitsize(16)))
func = fir::runtime::getRuntimeFunc<ForcedSumInteger16>(loc, builder);
else if (eleTy == fir::ComplexType::get(builder.getContext(), 4))
func = fir::runtime::getRuntimeFunc<mkRTKey(CppSumComplex4)>(loc, builder);
@@ -904,8 +895,11 @@ mlir::Value fir::runtime::genSum(fir::FirOpBuilder &builder, mlir::Location loc,
func = fir::runtime::getRuntimeFunc<ForcedSumComplex10>(loc, builder);
else if (eleTy == fir::ComplexType::get(builder.getContext(), 16))
func = fir::runtime::getRuntimeFunc<ForcedSumComplex16>(loc, builder);
+ else if (eleTy == fir::ComplexType::get(builder.getContext(), 2) ||
+ eleTy == fir::ComplexType::get(builder.getContext(), 3))
+ TODO(loc, "half-precision SUM");
else
- fir::emitFatalError(loc, "invalid type in Sum lowering");
+ fir::emitFatalError(loc, "invalid type in SUM");
auto fTy = func.getFunctionType();
auto sourceFile = fir::factory::locationToFilename(builder, loc);
More information about the flang-commits
mailing list