[llvm] 2da4b6e - [IR] Allow fast math flags on calls with floating point array type.

Jay Foad via llvm-commits llvm-commits at lists.llvm.org
Wed Oct 30 07:00:43 PDT 2019


Author: Jay Foad
Date: 2019-10-30T14:00:33Z
New Revision: 2da4b6e51450e8a6a40755cc5a40ebb6289766a5

URL: https://github.com/llvm/llvm-project/commit/2da4b6e51450e8a6a40755cc5a40ebb6289766a5
DIFF: https://github.com/llvm/llvm-project/commit/2da4b6e51450e8a6a40755cc5a40ebb6289766a5.diff

LOG: [IR] Allow fast math flags on calls with floating point array type.

Summary:
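This extends the rules for when a call instruction is deemed to be an
FPMathOperator. These rules are based on the type of the call (i.e. the
return type of the function being called). Previously we only allowed
scalar floating-point and vector-of-floating-point types. Now we also
allow arrays (nested to any depth) of floating-point and
vector-of-floating-point types.

Phi and select instructions share the same type-based rule in
FPMathOperator, so they gain the same array support. The fast-math-flag
validity checks in LLParser and BitcodeReader now use
isa<FPMathOperator> instead of testing isFPOrFPVectorTy() directly, so
that they stay consistent with that rule.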

This was motivated by llpc, the pipeline compiler for AMD GPUs
(https://github.com/GPUOpen-Drivers/llpc). llpc has many math library
functions that operate on vectors, typically represented as <4 x float>,
and some that operate on matrices, typically represented as
[4 x <4 x float>], and it's useful to be able to decorate calls to all
of them with fast math flags.

Reviewers: spatel, wristow, arsenm, hfinkel, aemerson, efriedma, cameron.mcinally, mcberg2017, jmolloy

Subscribers: wdng, llvm-commits

Tags: #llvm

Differential Revision: https://reviews.llvm.org/D69161

Added: 
    

Modified: 
    llvm/docs/LangRef.rst
    llvm/include/llvm/IR/Operator.h
    llvm/lib/AsmParser/LLParser.cpp
    llvm/lib/Bitcode/Reader/BitcodeReader.cpp
    llvm/test/Bitcode/compatibility.ll
    llvm/unittests/IR/InstructionsTest.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/docs/LangRef.rst b/llvm/docs/LangRef.rst
index 6c86c27ea466..c6c2cdff257b 100644
--- a/llvm/docs/LangRef.rst
+++ b/llvm/docs/LangRef.rst
@@ -10134,7 +10134,8 @@ The optional ``fast-math-flags`` marker indicates that the phi has one
 or more :ref:`fast-math-flags <fastmath>`. These are optimization hints
 to enable otherwise unsafe floating-point optimizations. Fast-math-flags
 are only valid for phis that return a floating-point scalar or vector
-type.
+type, or an array (nested to any depth) of floating-point scalar or vector
+types.
 
 Semantics:
 """"""""""
@@ -10183,7 +10184,8 @@ class <t_firstclass>` type.
 #. The optional ``fast-math flags`` marker indicates that the select has one or more
    :ref:`fast-math flags <fastmath>`. These are optimization hints to enable
    otherwise unsafe floating-point optimizations. Fast-math flags are only valid
-   for selects that return a floating-point scalar or vector type.
+   for selects that return a floating-point scalar or vector type, or an array
+   (nested to any depth) of floating-point scalar or vector types.
 
 Semantics:
 """"""""""
@@ -10282,7 +10284,8 @@ This instruction requires several arguments:
 #. The optional ``fast-math flags`` marker indicates that the call has one or more
    :ref:`fast-math flags <fastmath>`, which are optimization hints to enable
    otherwise unsafe floating-point optimizations. Fast-math flags are only valid
-   for calls that return a floating-point scalar or vector type.
+   for calls that return a floating-point scalar or vector type, or an array
+   (nested to any depth) of floating-point scalar or vector types.
 
 #. The optional "cconv" marker indicates which :ref:`calling
    convention <callingconv>` the call should use. If none is

diff  --git a/llvm/include/llvm/IR/Operator.h b/llvm/include/llvm/IR/Operator.h
index 037f5aed03ee..c8ca7e9a00e8 100644
--- a/llvm/include/llvm/IR/Operator.h
+++ b/llvm/include/llvm/IR/Operator.h
@@ -394,8 +394,12 @@ class FPMathOperator : public Operator {
       return true;
     case Instruction::PHI:
     case Instruction::Select:
-    case Instruction::Call:
-      return V->getType()->isFPOrFPVectorTy();
+    case Instruction::Call: {
+      Type *Ty = V->getType();
+      while (ArrayType *ArrTy = dyn_cast<ArrayType>(Ty))
+        Ty = ArrTy->getElementType();
+      return Ty->isFPOrFPVectorTy();
+    }
     default:
       return false;
     }

diff  --git a/llvm/lib/AsmParser/LLParser.cpp b/llvm/lib/AsmParser/LLParser.cpp
index 41e1d0bd889a..db78fa383682 100644
--- a/llvm/lib/AsmParser/LLParser.cpp
+++ b/llvm/lib/AsmParser/LLParser.cpp
@@ -5799,7 +5799,7 @@ int LLParser::ParseInstruction(Instruction *&Inst, BasicBlock *BB,
     if (Res != 0)
       return Res;
     if (FMF.any()) {
-      if (!Inst->getType()->isFPOrFPVectorTy())
+      if (!isa<FPMathOperator>(Inst))
         return Error(Loc, "fast-math-flags specified for select without "
                           "floating-point scalar or vector return type");
       Inst->setFastMathFlags(FMF);
@@ -5816,7 +5816,7 @@ int LLParser::ParseInstruction(Instruction *&Inst, BasicBlock *BB,
     if (Res != 0)
       return Res;
     if (FMF.any()) {
-      if (!Inst->getType()->isFPOrFPVectorTy())
+      if (!isa<FPMathOperator>(Inst))
         return Error(Loc, "fast-math-flags specified for phi without "
                           "floating-point scalar or vector return type");
       Inst->setFastMathFlags(FMF);
@@ -6787,10 +6787,6 @@ bool LLParser::ParseCall(Instruction *&Inst, PerFunctionState &PFS,
       ParseOptionalOperandBundles(BundleList, PFS))
     return true;
 
-  if (FMF.any() && !RetType->isFPOrFPVectorTy())
-    return Error(CallLoc, "fast-math-flags specified for call without "
-                          "floating-point scalar or vector return type");
-
   // If RetType is a non-function pointer type, then this is the short syntax
   // for the call, which means that RetType is just the return type.  Infer the
   // rest of the function argument types from the arguments that are present.
@@ -6853,8 +6849,12 @@ bool LLParser::ParseCall(Instruction *&Inst, PerFunctionState &PFS,
   CallInst *CI = CallInst::Create(Ty, Callee, Args, BundleList);
   CI->setTailCallKind(TCK);
   CI->setCallingConv(CC);
-  if (FMF.any())
+  if (FMF.any()) {
+    if (!isa<FPMathOperator>(CI))
+      return Error(CallLoc, "fast-math-flags specified for call without "
+                   "floating-point scalar or vector return type");
     CI->setFastMathFlags(FMF);
+  }
   CI->setAttributes(PAL);
   ForwardRefAttrGroups[CI] = FwdRefAttrGrps;
   Inst = CI;

diff  --git a/llvm/lib/Bitcode/Reader/BitcodeReader.cpp b/llvm/lib/Bitcode/Reader/BitcodeReader.cpp
index 3dac550b45ce..be375549d990 100644
--- a/llvm/lib/Bitcode/Reader/BitcodeReader.cpp
+++ b/llvm/lib/Bitcode/Reader/BitcodeReader.cpp
@@ -4641,10 +4641,9 @@ Error BitcodeReader::parseFunctionBody(Function *F) {
       // There is an optional final record for fast-math-flags if this phi has a
       // floating-point type.
       size_t NumArgs = (Record.size() - 1) / 2;
-      if ((Record.size() - 1) % 2 == 1 && !Ty->isFPOrFPVectorTy())
-        return error("Invalid record");
-
       PHINode *PN = PHINode::Create(Ty, NumArgs);
+      if ((Record.size() - 1) % 2 == 1 && !isa<FPMathOperator>(PN))
+        return error("Invalid record");
       InstructionList.push_back(PN);
 
       for (unsigned i = 0; i != NumArgs; i++) {

diff  --git a/llvm/test/Bitcode/compatibility.ll b/llvm/test/Bitcode/compatibility.ll
index 617a3fe059da..8269a90e1460 100644
--- a/llvm/test/Bitcode/compatibility.ll
+++ b/llvm/test/Bitcode/compatibility.ll
@@ -861,6 +861,14 @@ define void @fastmathflags_vector_select(<2 x i1> %cond, <2 x double> %op1, <2 x
   ret void
 }
 
+define void @fastmathflags_array_select(i1 %cond, [2 x double] %op1, [2 x double] %op2) {
+  %f.nnan.nsz = select nnan nsz i1 %cond, [2 x double] %op1, [2 x double] %op2
+  ; CHECK: %f.nnan.nsz = select nnan nsz i1 %cond, [2 x double] %op1, [2 x double] %op2
+  %f.fast = select fast i1 %cond, [2 x double] %op1, [2 x double] %op2
+  ; CHECK: %f.fast = select fast i1 %cond, [2 x double] %op1, [2 x double] %op2
+  ret void
+}
+
 define void @fastmathflags_phi(i1 %cond, float %f1, float %f2, double %d1, double %d2, half %h1, half %h2) {
 entry:
   br i1 %cond, label %L1, label %L2
@@ -903,24 +911,65 @@ exit:
   ret void
 }
 
+define void @fastmathflags_array_phi(i1 %cond, [4 x float] %f1, [4 x float] %f2, [2 x double] %d1, [2 x double] %d2, [8 x half] %h1, [8 x half] %h2) {
+entry:
+  br i1 %cond, label %L1, label %L2
+L1:
+  br label %exit
+L2:
+  br label %exit
+exit:
+  %p.nnan = phi nnan [4 x float] [ %f1, %L1 ], [ %f2, %L2 ]
+  ; CHECK: %p.nnan = phi nnan [4 x float] [ %f1, %L1 ], [ %f2, %L2 ]
+  %p.ninf = phi ninf [2 x double] [ %d1, %L1 ], [ %d2, %L2 ]
+  ; CHECK: %p.ninf = phi ninf [2 x double] [ %d1, %L1 ], [ %d2, %L2 ]
+  %p.contract = phi contract [8 x half] [ %h1, %L1 ], [ %h2, %L2 ]
+  ; CHECK: %p.contract = phi contract [8 x half] [ %h1, %L1 ], [ %h2, %L2 ]
+  %p.nsz.reassoc = phi reassoc nsz [4 x float] [ %f1, %L1 ], [ %f2, %L2 ]
+  ; CHECK: %p.nsz.reassoc = phi reassoc nsz [4 x float] [ %f1, %L1 ], [ %f2, %L2 ]
+  %p.fast = phi fast [8 x half] [ %h2, %L1 ], [ %h1, %L2 ]
+  ; CHECK: %p.fast = phi fast [8 x half] [ %h2, %L1 ], [ %h1, %L2 ]
+  ret void
+}
+
 ; Check various fast math flags and floating-point types on calls.
 
-declare float @fmf1()
-declare double @fmf2()
-declare <4 x double> @fmf3()
+declare float @fmf_f32()
+declare double @fmf_f64()
+declare <4 x double> @fmf_v4f64()
 
 ; CHECK-LABEL: fastMathFlagsForCalls(
 define void @fastMathFlagsForCalls(float %f, double %d1, <4 x double> %d2) {
-  %call.fast = call fast float @fmf1()
-  ; CHECK: %call.fast = call fast float @fmf1()
+  %call.fast = call fast float @fmf_f32()
+  ; CHECK: %call.fast = call fast float @fmf_f32()
+
+  ; Throw in some other attributes to make sure those stay in the right places.
+
+  %call.nsz.arcp = notail call nsz arcp double @fmf_f64()
+  ; CHECK: %call.nsz.arcp = notail call nsz arcp double @fmf_f64()
+
+  %call.nnan.ninf = tail call nnan ninf fastcc <4 x double> @fmf_v4f64()
+  ; CHECK: %call.nnan.ninf = tail call nnan ninf fastcc <4 x double> @fmf_v4f64()
+
+  ret void
+}
+
+declare [2 x float] @fmf_a2f32()
+declare [2 x double] @fmf_a2f64()
+declare [2 x <4 x double>] @fmf_a2v4f64()
+
+; CHECK-LABEL: fastMathFlagsForArrayCalls(
+define void @fastMathFlagsForArrayCalls([2 x float] %f, [2 x double] %d1, [2 x <4 x double>] %d2) {
+  %call.fast = call fast [2 x float] @fmf_a2f32()
+  ; CHECK: %call.fast = call fast [2 x float] @fmf_a2f32()
 
   ; Throw in some other attributes to make sure those stay in the right places.
 
-  %call.nsz.arcp = notail call nsz arcp double @fmf2()
-  ; CHECK: %call.nsz.arcp = notail call nsz arcp double @fmf2()
+  %call.nsz.arcp = notail call nsz arcp [2 x double] @fmf_a2f64()
+  ; CHECK: %call.nsz.arcp = notail call nsz arcp [2 x double] @fmf_a2f64()
 
-  %call.nnan.ninf = tail call nnan ninf fastcc <4 x double> @fmf3()
-  ; CHECK: %call.nnan.ninf = tail call nnan ninf fastcc <4 x double> @fmf3()
+  %call.nnan.ninf = tail call nnan ninf fastcc [2 x <4 x double>] @fmf_a2v4f64()
+  ; CHECK: %call.nnan.ninf = tail call nnan ninf fastcc [2 x <4 x double>] @fmf_a2v4f64()
 
   ret void
 }

diff  --git a/llvm/unittests/IR/InstructionsTest.cpp b/llvm/unittests/IR/InstructionsTest.cpp
index ea2265655cb3..556c41058e7d 100644
--- a/llvm/unittests/IR/InstructionsTest.cpp
+++ b/llvm/unittests/IR/InstructionsTest.cpp
@@ -1046,6 +1046,60 @@ TEST(InstructionsTest, PhiMightNotBeFPMathOperator) {
   FP->deleteValue();
 }
 
+TEST(InstructionsTest, FPCallIsFPMathOperator) {
+  LLVMContext C;
+
+  Type *ITy = Type::getInt32Ty(C);
+  FunctionType *IFnTy = FunctionType::get(ITy, {});
+  Value *ICallee = Constant::getNullValue(IFnTy->getPointerTo());
+  std::unique_ptr<CallInst> ICall(CallInst::Create(IFnTy, ICallee, {}, ""));
+  EXPECT_FALSE(isa<FPMathOperator>(ICall));
+
+  Type *VITy = VectorType::get(ITy, 2);
+  FunctionType *VIFnTy = FunctionType::get(VITy, {});
+  Value *VICallee = Constant::getNullValue(VIFnTy->getPointerTo());
+  std::unique_ptr<CallInst> VICall(CallInst::Create(VIFnTy, VICallee, {}, ""));
+  EXPECT_FALSE(isa<FPMathOperator>(VICall));
+
+  Type *AITy = ArrayType::get(ITy, 2);
+  FunctionType *AIFnTy = FunctionType::get(AITy, {});
+  Value *AICallee = Constant::getNullValue(AIFnTy->getPointerTo());
+  std::unique_ptr<CallInst> AICall(CallInst::Create(AIFnTy, AICallee, {}, ""));
+  EXPECT_FALSE(isa<FPMathOperator>(AICall));
+
+  Type *FTy = Type::getFloatTy(C);
+  FunctionType *FFnTy = FunctionType::get(FTy, {});
+  Value *FCallee = Constant::getNullValue(FFnTy->getPointerTo());
+  std::unique_ptr<CallInst> FCall(CallInst::Create(FFnTy, FCallee, {}, ""));
+  EXPECT_TRUE(isa<FPMathOperator>(FCall));
+
+  Type *VFTy = VectorType::get(FTy, 2);
+  FunctionType *VFFnTy = FunctionType::get(VFTy, {});
+  Value *VFCallee = Constant::getNullValue(VFFnTy->getPointerTo());
+  std::unique_ptr<CallInst> VFCall(CallInst::Create(VFFnTy, VFCallee, {}, ""));
+  EXPECT_TRUE(isa<FPMathOperator>(VFCall));
+
+  Type *AFTy = ArrayType::get(FTy, 2);
+  FunctionType *AFFnTy = FunctionType::get(AFTy, {});
+  Value *AFCallee = Constant::getNullValue(AFFnTy->getPointerTo());
+  std::unique_ptr<CallInst> AFCall(CallInst::Create(AFFnTy, AFCallee, {}, ""));
+  EXPECT_TRUE(isa<FPMathOperator>(AFCall));
+
+  Type *AVFTy = ArrayType::get(VFTy, 2);
+  FunctionType *AVFFnTy = FunctionType::get(AVFTy, {});
+  Value *AVFCallee = Constant::getNullValue(AVFFnTy->getPointerTo());
+  std::unique_ptr<CallInst> AVFCall(
+      CallInst::Create(AVFFnTy, AVFCallee, {}, ""));
+  EXPECT_TRUE(isa<FPMathOperator>(AVFCall));
+
+  Type *AAVFTy = ArrayType::get(AVFTy, 2);
+  FunctionType *AAVFFnTy = FunctionType::get(AAVFTy, {});
+  Value *AAVFCallee = Constant::getNullValue(AAVFFnTy->getPointerTo());
+  std::unique_ptr<CallInst> AAVFCall(
+      CallInst::Create(AAVFFnTy, AAVFCallee, {}, ""));
+  EXPECT_TRUE(isa<FPMathOperator>(AAVFCall));
+}
+
 TEST(InstructionsTest, FNegInstruction) {
   LLVMContext Context;
   Type *FltTy = Type::getFloatTy(Context);


        


More information about the llvm-commits mailing list