[flang-commits] [flang] 4a8305c - [flang] Add TODO for half-precision intrinsic reductions

Valentin Clement via flang-commits flang-commits at lists.llvm.org
Mon Jun 13 08:40:09 PDT 2022


Author: Valentin Clement
Date: 2022-06-13T17:40:01+02:00
New Revision: 4a8305ce856b6f4d2e49f3100226d00b402dff86

URL: https://github.com/llvm/llvm-project/commit/4a8305ce856b6f4d2e49f3100226d00b402dff86
DIFF: https://github.com/llvm/llvm-project/commit/4a8305ce856b6f4d2e49f3100226d00b402dff86.diff

LOG: [flang] Add TODO for half-precision intrinsic reductions

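Add TODOs for half-precision (f16/bf16) element types in the lowering of the
MAXVAL, MINVAL, PRODUCT, DOT_PRODUCT, and SUM intrinsic reductions.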

This patch is part of the upstreaming effort from fir-dev branch.

Reviewed By: jeanPerier, PeteSteinfeld

Differential Revision: https://reviews.llvm.org/D127622

Co-authored-by: Eric Schweitz <eschweitz at nvidia.com>

Added: 
    

Modified: 
    flang/lib/Optimizer/Builder/Runtime/Reduction.cpp

Removed: 
    


################################################################################
diff  --git a/flang/lib/Optimizer/Builder/Runtime/Reduction.cpp b/flang/lib/Optimizer/Builder/Runtime/Reduction.cpp
index aa39c8e464eca..579020acb8da5 100644
--- a/flang/lib/Optimizer/Builder/Runtime/Reduction.cpp
+++ b/flang/lib/Optimizer/Builder/Runtime/Reduction.cpp
@@ -545,7 +545,9 @@ mlir::Value fir::runtime::genMaxval(fir::FirOpBuilder &builder,
   auto eleTy = arrTy.cast<fir::SequenceType>().getEleTy();
   auto dim = builder.createIntegerConstant(loc, builder.getIndexType(), 0);
 
-  if (eleTy.isF32())
+  if (eleTy.isF16() || eleTy.isBF16())
+    TODO(loc, "half-precision MAXVAL");
+  else if (eleTy.isF32())
     func = fir::runtime::getRuntimeFunc<mkRTKey(MaxvalReal4)>(loc, builder);
   else if (eleTy.isF64())
     func = fir::runtime::getRuntimeFunc<mkRTKey(MaxvalReal8)>(loc, builder);
@@ -553,23 +555,18 @@ mlir::Value fir::runtime::genMaxval(fir::FirOpBuilder &builder,
     func = fir::runtime::getRuntimeFunc<ForcedMaxvalReal10>(loc, builder);
   else if (eleTy.isF128())
     func = fir::runtime::getRuntimeFunc<ForcedMaxvalReal16>(loc, builder);
-  else if (eleTy ==
-           builder.getIntegerType(builder.getKindMap().getIntegerBitsize(1)))
+  else if (eleTy.isInteger(builder.getKindMap().getIntegerBitsize(1)))
     func = fir::runtime::getRuntimeFunc<mkRTKey(MaxvalInteger1)>(loc, builder);
-  else if (eleTy ==
-           builder.getIntegerType(builder.getKindMap().getIntegerBitsize(2)))
+  else if (eleTy.isInteger(builder.getKindMap().getIntegerBitsize(2)))
     func = fir::runtime::getRuntimeFunc<mkRTKey(MaxvalInteger2)>(loc, builder);
-  else if (eleTy ==
-           builder.getIntegerType(builder.getKindMap().getIntegerBitsize(4)))
+  else if (eleTy.isInteger(builder.getKindMap().getIntegerBitsize(4)))
     func = fir::runtime::getRuntimeFunc<mkRTKey(MaxvalInteger4)>(loc, builder);
-  else if (eleTy ==
-           builder.getIntegerType(builder.getKindMap().getIntegerBitsize(8)))
+  else if (eleTy.isInteger(builder.getKindMap().getIntegerBitsize(8)))
     func = fir::runtime::getRuntimeFunc<mkRTKey(MaxvalInteger8)>(loc, builder);
-  else if (eleTy ==
-           builder.getIntegerType(builder.getKindMap().getIntegerBitsize(16)))
+  else if (eleTy.isInteger(builder.getKindMap().getIntegerBitsize(16)))
     func = fir::runtime::getRuntimeFunc<ForcedMaxvalInteger16>(loc, builder);
   else
-    fir::emitFatalError(loc, "invalid type in Maxval lowering");
+    fir::emitFatalError(loc, "invalid type in MAXVAL");
 
   auto fTy = func.getFunctionType();
   auto sourceFile = fir::factory::locationToFilename(builder, loc);
@@ -664,7 +661,9 @@ mlir::Value fir::runtime::genMinval(fir::FirOpBuilder &builder,
   auto eleTy = arrTy.cast<fir::SequenceType>().getEleTy();
   auto dim = builder.createIntegerConstant(loc, builder.getIndexType(), 0);
 
-  if (eleTy.isF32())
+  if (eleTy.isF16() || eleTy.isBF16())
+    TODO(loc, "half-precision MINVAL");
+  else if (eleTy.isF32())
     func = fir::runtime::getRuntimeFunc<mkRTKey(MinvalReal4)>(loc, builder);
   else if (eleTy.isF64())
     func = fir::runtime::getRuntimeFunc<mkRTKey(MinvalReal8)>(loc, builder);
@@ -672,23 +671,18 @@ mlir::Value fir::runtime::genMinval(fir::FirOpBuilder &builder,
     func = fir::runtime::getRuntimeFunc<ForcedMinvalReal10>(loc, builder);
   else if (eleTy.isF128())
     func = fir::runtime::getRuntimeFunc<ForcedMinvalReal16>(loc, builder);
-  else if (eleTy ==
-           builder.getIntegerType(builder.getKindMap().getIntegerBitsize(1)))
+  else if (eleTy.isInteger(builder.getKindMap().getIntegerBitsize(1)))
     func = fir::runtime::getRuntimeFunc<mkRTKey(MinvalInteger1)>(loc, builder);
-  else if (eleTy ==
-           builder.getIntegerType(builder.getKindMap().getIntegerBitsize(2)))
+  else if (eleTy.isInteger(builder.getKindMap().getIntegerBitsize(2)))
     func = fir::runtime::getRuntimeFunc<mkRTKey(MinvalInteger2)>(loc, builder);
-  else if (eleTy ==
-           builder.getIntegerType(builder.getKindMap().getIntegerBitsize(4)))
+  else if (eleTy.isInteger(builder.getKindMap().getIntegerBitsize(4)))
     func = fir::runtime::getRuntimeFunc<mkRTKey(MinvalInteger4)>(loc, builder);
-  else if (eleTy ==
-           builder.getIntegerType(builder.getKindMap().getIntegerBitsize(8)))
+  else if (eleTy.isInteger(builder.getKindMap().getIntegerBitsize(8)))
     func = fir::runtime::getRuntimeFunc<mkRTKey(MinvalInteger8)>(loc, builder);
-  else if (eleTy ==
-           builder.getIntegerType(builder.getKindMap().getIntegerBitsize(16)))
+  else if (eleTy.isInteger(builder.getKindMap().getIntegerBitsize(16)))
     func = fir::runtime::getRuntimeFunc<ForcedMinvalInteger16>(loc, builder);
   else
-    fir::emitFatalError(loc, "invalid type in Minval lowering");
+    fir::emitFatalError(loc, "invalid type in MINVAL");
 
   auto fTy = func.getFunctionType();
   auto sourceFile = fir::factory::locationToFilename(builder, loc);
@@ -721,7 +715,9 @@ mlir::Value fir::runtime::genProduct(fir::FirOpBuilder &builder,
   auto eleTy = arrTy.cast<fir::SequenceType>().getEleTy();
   auto dim = builder.createIntegerConstant(loc, builder.getIndexType(), 0);
 
-  if (eleTy.isF32())
+  if (eleTy.isF16() || eleTy.isBF16())
+    TODO(loc, "half-precision PRODUCT");
+  else if (eleTy.isF32())
     func = fir::runtime::getRuntimeFunc<mkRTKey(ProductReal4)>(loc, builder);
   else if (eleTy.isF64())
     func = fir::runtime::getRuntimeFunc<mkRTKey(ProductReal8)>(loc, builder);
@@ -729,20 +725,15 @@ mlir::Value fir::runtime::genProduct(fir::FirOpBuilder &builder,
     func = fir::runtime::getRuntimeFunc<ForcedProductReal10>(loc, builder);
   else if (eleTy.isF128())
     func = fir::runtime::getRuntimeFunc<ForcedProductReal16>(loc, builder);
-  else if (eleTy ==
-           builder.getIntegerType(builder.getKindMap().getIntegerBitsize(1)))
+  else if (eleTy.isInteger(builder.getKindMap().getIntegerBitsize(1)))
     func = fir::runtime::getRuntimeFunc<mkRTKey(ProductInteger1)>(loc, builder);
-  else if (eleTy ==
-           builder.getIntegerType(builder.getKindMap().getIntegerBitsize(2)))
+  else if (eleTy.isInteger(builder.getKindMap().getIntegerBitsize(2)))
     func = fir::runtime::getRuntimeFunc<mkRTKey(ProductInteger2)>(loc, builder);
-  else if (eleTy ==
-           builder.getIntegerType(builder.getKindMap().getIntegerBitsize(4)))
+  else if (eleTy.isInteger(builder.getKindMap().getIntegerBitsize(4)))
     func = fir::runtime::getRuntimeFunc<mkRTKey(ProductInteger4)>(loc, builder);
-  else if (eleTy ==
-           builder.getIntegerType(builder.getKindMap().getIntegerBitsize(8)))
+  else if (eleTy.isInteger(builder.getKindMap().getIntegerBitsize(8)))
     func = fir::runtime::getRuntimeFunc<mkRTKey(ProductInteger8)>(loc, builder);
-  else if (eleTy ==
-           builder.getIntegerType(builder.getKindMap().getIntegerBitsize(16)))
+  else if (eleTy.isInteger(builder.getKindMap().getIntegerBitsize(16)))
     func = fir::runtime::getRuntimeFunc<ForcedProductInteger16>(loc, builder);
   else if (eleTy == fir::ComplexType::get(builder.getContext(), 4))
     func =
@@ -754,8 +745,11 @@ mlir::Value fir::runtime::genProduct(fir::FirOpBuilder &builder,
     func = fir::runtime::getRuntimeFunc<ForcedProductComplex10>(loc, builder);
   else if (eleTy == fir::ComplexType::get(builder.getContext(), 16))
     func = fir::runtime::getRuntimeFunc<ForcedProductComplex16>(loc, builder);
+  else if (eleTy == fir::ComplexType::get(builder.getContext(), 2) ||
+           eleTy == fir::ComplexType::get(builder.getContext(), 3))
+    TODO(loc, "half-precision PRODUCT");
   else
-    fir::emitFatalError(loc, "invalid type in Product lowering");
+    fir::emitFatalError(loc, "invalid type in PRODUCT");
 
   auto fTy = func.getFunctionType();
   auto sourceFile = fir::factory::locationToFilename(builder, loc);
@@ -788,7 +782,9 @@ mlir::Value fir::runtime::genDotProduct(fir::FirOpBuilder &builder,
   auto arrTy = fir::dyn_cast_ptrOrBoxEleTy(ty);
   auto eleTy = arrTy.cast<fir::SequenceType>().getEleTy();
 
-  if (eleTy.isF32())
+  if (eleTy.isF16() || eleTy.isBF16())
+    TODO(loc, "half-precision DOTPRODUCT");
+  else if (eleTy.isF32())
     func = fir::runtime::getRuntimeFunc<mkRTKey(DotProductReal4)>(loc, builder);
   else if (eleTy.isF64())
     func = fir::runtime::getRuntimeFunc<mkRTKey(DotProductReal8)>(loc, builder);
@@ -808,31 +804,29 @@ mlir::Value fir::runtime::genDotProduct(fir::FirOpBuilder &builder,
   else if (eleTy == fir::ComplexType::get(builder.getContext(), 16))
     func =
         fir::runtime::getRuntimeFunc<ForcedDotProductComplex16>(loc, builder);
-  else if (eleTy ==
-           builder.getIntegerType(builder.getKindMap().getIntegerBitsize(1)))
+  else if (eleTy == fir::ComplexType::get(builder.getContext(), 2) ||
+           eleTy == fir::ComplexType::get(builder.getContext(), 3))
+    TODO(loc, "half-precision DOTPRODUCT");
+  else if (eleTy.isInteger(builder.getKindMap().getIntegerBitsize(1)))
     func =
         fir::runtime::getRuntimeFunc<mkRTKey(DotProductInteger1)>(loc, builder);
-  else if (eleTy ==
-           builder.getIntegerType(builder.getKindMap().getIntegerBitsize(2)))
+  else if (eleTy.isInteger(builder.getKindMap().getIntegerBitsize(2)))
     func =
         fir::runtime::getRuntimeFunc<mkRTKey(DotProductInteger2)>(loc, builder);
-  else if (eleTy ==
-           builder.getIntegerType(builder.getKindMap().getIntegerBitsize(4)))
+  else if (eleTy.isInteger(builder.getKindMap().getIntegerBitsize(4)))
     func =
         fir::runtime::getRuntimeFunc<mkRTKey(DotProductInteger4)>(loc, builder);
-  else if (eleTy ==
-           builder.getIntegerType(builder.getKindMap().getIntegerBitsize(8)))
+  else if (eleTy.isInteger(builder.getKindMap().getIntegerBitsize(8)))
     func =
         fir::runtime::getRuntimeFunc<mkRTKey(DotProductInteger8)>(loc, builder);
-  else if (eleTy ==
-           builder.getIntegerType(builder.getKindMap().getIntegerBitsize(16)))
+  else if (eleTy.isInteger(builder.getKindMap().getIntegerBitsize(16)))
     func =
         fir::runtime::getRuntimeFunc<ForcedDotProductInteger16>(loc, builder);
   else if (eleTy.isa<fir::LogicalType>())
     func =
         fir::runtime::getRuntimeFunc<mkRTKey(DotProductLogical)>(loc, builder);
   else
-    fir::emitFatalError(loc, "invalid type in DotProduct lowering");
+    fir::emitFatalError(loc, "invalid type in DOTPRODUCT");
 
   auto fTy = func.getFunctionType();
   auto sourceFile = fir::factory::locationToFilename(builder, loc);
@@ -873,7 +867,9 @@ mlir::Value fir::runtime::genSum(fir::FirOpBuilder &builder, mlir::Location loc,
   auto eleTy = arrTy.cast<fir::SequenceType>().getEleTy();
   auto dim = builder.createIntegerConstant(loc, builder.getIndexType(), 0);
 
-  if (eleTy.isF32())
+  if (eleTy.isF16() || eleTy.isBF16())
+    TODO(loc, "half-precision SUM");
+  else if (eleTy.isF32())
     func = fir::runtime::getRuntimeFunc<mkRTKey(SumReal4)>(loc, builder);
   else if (eleTy.isF64())
     func = fir::runtime::getRuntimeFunc<mkRTKey(SumReal8)>(loc, builder);
@@ -881,20 +877,15 @@ mlir::Value fir::runtime::genSum(fir::FirOpBuilder &builder, mlir::Location loc,
     func = fir::runtime::getRuntimeFunc<ForcedSumReal10>(loc, builder);
   else if (eleTy.isF128())
     func = fir::runtime::getRuntimeFunc<ForcedSumReal16>(loc, builder);
-  else if (eleTy ==
-           builder.getIntegerType(builder.getKindMap().getIntegerBitsize(1)))
+  else if (eleTy.isInteger(builder.getKindMap().getIntegerBitsize(1)))
     func = fir::runtime::getRuntimeFunc<mkRTKey(SumInteger1)>(loc, builder);
-  else if (eleTy ==
-           builder.getIntegerType(builder.getKindMap().getIntegerBitsize(2)))
+  else if (eleTy.isInteger(builder.getKindMap().getIntegerBitsize(2)))
     func = fir::runtime::getRuntimeFunc<mkRTKey(SumInteger2)>(loc, builder);
-  else if (eleTy ==
-           builder.getIntegerType(builder.getKindMap().getIntegerBitsize(4)))
+  else if (eleTy.isInteger(builder.getKindMap().getIntegerBitsize(4)))
     func = fir::runtime::getRuntimeFunc<mkRTKey(SumInteger4)>(loc, builder);
-  else if (eleTy ==
-           builder.getIntegerType(builder.getKindMap().getIntegerBitsize(8)))
+  else if (eleTy.isInteger(builder.getKindMap().getIntegerBitsize(8)))
     func = fir::runtime::getRuntimeFunc<mkRTKey(SumInteger8)>(loc, builder);
-  else if (eleTy ==
-           builder.getIntegerType(builder.getKindMap().getIntegerBitsize(16)))
+  else if (eleTy.isInteger(builder.getKindMap().getIntegerBitsize(16)))
     func = fir::runtime::getRuntimeFunc<ForcedSumInteger16>(loc, builder);
   else if (eleTy == fir::ComplexType::get(builder.getContext(), 4))
     func = fir::runtime::getRuntimeFunc<mkRTKey(CppSumComplex4)>(loc, builder);
@@ -904,8 +895,11 @@ mlir::Value fir::runtime::genSum(fir::FirOpBuilder &builder, mlir::Location loc,
     func = fir::runtime::getRuntimeFunc<ForcedSumComplex10>(loc, builder);
   else if (eleTy == fir::ComplexType::get(builder.getContext(), 16))
     func = fir::runtime::getRuntimeFunc<ForcedSumComplex16>(loc, builder);
+  else if (eleTy == fir::ComplexType::get(builder.getContext(), 2) ||
+           eleTy == fir::ComplexType::get(builder.getContext(), 3))
+    TODO(loc, "half-precision SUM");
   else
-    fir::emitFatalError(loc, "invalid type in Sum lowering");
+    fir::emitFatalError(loc, "invalid type in SUM");
 
   auto fTy = func.getFunctionType();
   auto sourceFile = fir::factory::locationToFilename(builder, loc);


        


More information about the flang-commits mailing list