[llvm] [SPIRV] Added support for the constrained arithmetic intrinsic (PR #157441)

via llvm-commits llvm-commits at lists.llvm.org
Mon Sep 8 05:39:36 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-llvm-support

Author: Subash B (SubashBoopathi)

<details>
<summary>Changes</summary>

Added SPIR-V support for the constrained arithmetic intrinsic fmuladd, including legalization, instruction selection, and tests. The intrinsic is lowered as a sequence of OpFMul and OpFAdd, consistent with the SPIR-V translator.

---
Full diff: https://github.com/llvm/llvm-project/pull/157441.diff


6 Files Affected:

- (modified) llvm/include/llvm/Support/TargetOpcodes.def (+1) 
- (modified) llvm/include/llvm/Target/GenericOpcodes.td (+1) 
- (modified) llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp (+2) 
- (modified) llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp (+48) 
- (modified) llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp (+1-1) 
- (modified) llvm/test/CodeGen/SPIRV/llvm-intrinsics/constrained-arithmetic.ll (+4) 


``````````diff
diff --git a/llvm/include/llvm/Support/TargetOpcodes.def b/llvm/include/llvm/Support/TargetOpcodes.def
index b905576b61791..dcf80c56c47b3 100644
--- a/llvm/include/llvm/Support/TargetOpcodes.def
+++ b/llvm/include/llvm/Support/TargetOpcodes.def
@@ -643,6 +643,7 @@ HANDLE_TARGET_OPCODE(G_FMA)
 
 /// Generic FP multiply and add. Behaves as separate fmul and fadd.
 HANDLE_TARGET_OPCODE(G_FMAD)
+HANDLE_TARGET_OPCODE(G_STRICT_FMULADD)
 
 /// Generic FP division.
 HANDLE_TARGET_OPCODE(G_FDIV)
diff --git a/llvm/include/llvm/Target/GenericOpcodes.td b/llvm/include/llvm/Target/GenericOpcodes.td
index ce4750db88c9a..08fbd2253edf2 100644
--- a/llvm/include/llvm/Target/GenericOpcodes.td
+++ b/llvm/include/llvm/Target/GenericOpcodes.td
@@ -1716,6 +1716,7 @@ def G_STRICT_FREM : ConstrainedInstruction<G_FREM>;
 def G_STRICT_FMA : ConstrainedInstruction<G_FMA>;
 def G_STRICT_FSQRT : ConstrainedInstruction<G_FSQRT>;
 def G_STRICT_FLDEXP : ConstrainedInstruction<G_FLDEXP>;
+def G_STRICT_FMULADD : ConstrainedInstruction<G_FMAD>;
 
 //------------------------------------------------------------------------------
 // Memory intrinsics
diff --git a/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp b/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
index d7280eaba2440..3742df2f1732b 100644
--- a/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
@@ -2061,6 +2061,8 @@ static unsigned getConstrainedOpcode(Intrinsic::ID ID) {
     return TargetOpcode::G_STRICT_FSQRT;
   case Intrinsic::experimental_constrained_ldexp:
     return TargetOpcode::G_STRICT_FLDEXP;
+  case Intrinsic::experimental_constrained_fmuladd:
+    return TargetOpcode::G_STRICT_FMULADD;
   default:
     return 0;
   }
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
index 6608b3f2cbefd..9d56c18f1c096 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
@@ -227,6 +227,9 @@ class SPIRVInstructionSelector : public InstructionSelector {
   bool selectExt(Register ResVReg, const SPIRVType *ResType, MachineInstr &I,
                  bool IsSigned) const;
 
+  bool selectStrictFMulAdd(Register ResVReg, const SPIRVType *ResType,
+                           MachineInstr &I) const;
+
   bool selectTrunc(Register ResVReg, const SPIRVType *ResType,
                    MachineInstr &I) const;
 
@@ -689,6 +692,9 @@ bool SPIRVInstructionSelector::spvSelect(Register ResVReg,
   case TargetOpcode::G_FMA:
     return selectExtInst(ResVReg, ResType, I, CL::fma, GL::Fma);
 
+  case TargetOpcode::G_STRICT_FMULADD:
+    return selectStrictFMulAdd(ResVReg, ResType, I);
+
   case TargetOpcode::G_STRICT_FLDEXP:
     return selectExtInst(ResVReg, ResType, I, CL::ldexp);
 
@@ -1038,6 +1044,48 @@ bool SPIRVInstructionSelector::selectOpWithSrcs(Register ResVReg,
   return MIB.constrainAllUses(TII, TRI, RBI);
 }
 
+bool SPIRVInstructionSelector::selectStrictFMulAdd(Register ResVReg,
+                                                   const SPIRVType *ResType,
+                                                   MachineInstr &I) const {
+  assert(I.getNumOperands() == 4 &&
+         "FMulAdd should have 3 operands and result");
+  assert(I.getOperand(1).isReg() && I.getOperand(2).isReg() &&
+         I.getOperand(3).isReg() && "Operands should be registers");
+  MachineBasicBlock &BB = *I.getParent();
+  Register MulLHS = I.getOperand(1).getReg();
+  Register MulRHS = I.getOperand(2).getReg();
+  Register AddRHS = I.getOperand(3).getReg();
+  SPIRVType *MulLHSType = GR.getSPIRVTypeForVReg(MulLHS);
+  SPIRVType *MulRHSType = GR.getSPIRVTypeForVReg(MulRHS);
+  SPIRVType *AddRHSType = GR.getSPIRVTypeForVReg(AddRHS);
+  if (!MulLHSType || !MulRHSType || !AddRHSType)
+    report_fatal_error("Input Type could not be determined.");
+  if (!GR.isScalarOrVectorOfType(MulLHS, SPIRV::OpTypeFloat) ||
+      !GR.isScalarOrVectorOfType(MulRHS, SPIRV::OpTypeFloat) ||
+      !GR.isScalarOrVectorOfType(AddRHS, SPIRV::OpTypeFloat)) {
+    report_fatal_error("FMulAdd requires floating-point operands");
+  }
+  bool IsScalar = (MulLHSType->getOpcode() == SPIRV::OpTypeFloat);
+  bool IsVector = (MulLHSType->getOpcode() == SPIRV::OpTypeVector);
+  if (!IsScalar && !IsVector)
+    report_fatal_error("Unsupported type for FMulAdd operation");
+  unsigned MulOpcode = IsScalar ? SPIRV::OpFMulS : SPIRV::OpFMulV;
+  unsigned AddOpcode = IsScalar ? SPIRV::OpFAddS : SPIRV::OpFAddV;
+  Register MulTemp = MRI->createVirtualRegister(MRI->getRegClass(MulLHS));
+  BuildMI(BB, I, I.getDebugLoc(), TII.get(MulOpcode))
+      .addDef(MulTemp)
+      .addUse(GR.getSPIRVTypeID(ResType))
+      .addUse(MulLHS)
+      .addUse(MulRHS)
+      .constrainAllUses(TII, TRI, RBI);
+  return BuildMI(BB, I, I.getDebugLoc(), TII.get(AddOpcode))
+      .addDef(ResVReg)
+      .addUse(GR.getSPIRVTypeID(ResType))
+      .addUse(MulTemp)
+      .addUse(AddRHS)
+      .constrainAllUses(TII, TRI, RBI);
+}
+
 bool SPIRVInstructionSelector::selectUnOp(Register ResVReg,
                                           const SPIRVType *ResType,
                                           MachineInstr &I,
diff --git a/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp b/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp
index 721f64a329d31..8041ef67cfa56 100644
--- a/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp
@@ -193,7 +193,7 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) {
       .legalFor(allIntScalarsAndVectors)
       .legalIf(extendedScalarsAndVectors);
 
-  getActionDefinitionsBuilder({G_FMA, G_STRICT_FMA})
+  getActionDefinitionsBuilder({G_FMA, G_STRICT_FMA, G_STRICT_FMULADD})
       .legalFor(allFloatScalarsAndVectors);
 
   getActionDefinitionsBuilder(G_STRICT_FLDEXP)
diff --git a/llvm/test/CodeGen/SPIRV/llvm-intrinsics/constrained-arithmetic.ll b/llvm/test/CodeGen/SPIRV/llvm-intrinsics/constrained-arithmetic.ll
index 11bedfa605f9b..b2b8bbfce7321 100644
--- a/llvm/test/CodeGen/SPIRV/llvm-intrinsics/constrained-arithmetic.ll
+++ b/llvm/test/CodeGen/SPIRV/llvm-intrinsics/constrained-arithmetic.ll
@@ -7,6 +7,7 @@
 ; CHECK-DAG: OpName %[[#r4:]] "r4"
 ; CHECK-DAG: OpName %[[#r5:]] "r5"
 ; CHECK-DAG: OpName %[[#r6:]] "r6"
+; CHECK-DAG: OpName %[[#r7:]] "r7"
 
 ; CHECK-NOT: OpDecorate %[[#r5]] FPRoundingMode
 ; CHECK-NOT: OpDecorate %[[#r6]] FPRoundingMode
@@ -22,6 +23,8 @@
 ; CHECK: OpFMul %[[#]] %[[#]]
 ; CHECK: OpExtInst %[[#]] %[[#]] fma %[[#]] %[[#]] %[[#]]
 ; CHECK: OpFRem
+; CHECK: OpFMul %[[#]] %[[#]]
+; CHECK: OpFAdd %[[#]] %[[#]]
 
 ; Function Attrs: norecurse nounwind strictfp
 define dso_local spir_kernel void @test(float %a, i32 %in, i32 %ui) {
@@ -32,6 +35,7 @@ entry:
   %r4 = tail call float @llvm.experimental.constrained.fmul.f32(float %a, float %a, metadata !"round.downward", metadata !"fpexcept.strict")
   %r5 = tail call float @llvm.experimental.constrained.fma.f32(float %a, float %a, float %a, metadata !"round.dynamic", metadata !"fpexcept.strict")
   %r6 = tail call float @llvm.experimental.constrained.frem.f32(float %a, float %a, metadata !"round.dynamic", metadata !"fpexcept.strict")
+  %r7 = tail call float @llvm.experimental.constrained.fmuladd.f32(float %a, float %a, float %a, metadata !"round.dynamic", metadata !"fpexcept.strict")
   ret void
 }
 

``````````

</details>


https://github.com/llvm/llvm-project/pull/157441


More information about the llvm-commits mailing list