[PATCH] R600/SI: Use mad for fsub + fmul

Tom Stellard tom at stellard.net
Fri Aug 22 12:48:52 PDT 2014


On Sun, Aug 17, 2014 at 02:17:56AM +0000, Matt Arsenault wrote:
> We can use a negate source modifier to match this for fsub.
> 
> http://reviews.llvm.org/D4942
> 
> Files:
>   lib/Target/R600/AMDGPUISelLowering.cpp
>   lib/Target/R600/AMDGPUISelLowering.h
>   lib/Target/R600/AMDGPUInstrInfo.td
>   lib/Target/R600/SIISelLowering.cpp
>   lib/Target/R600/SIInstrInfo.td
>   lib/Target/R600/SIInstructions.td
>   test/CodeGen/R600/mad-sub.ll

> Index: lib/Target/R600/AMDGPUISelLowering.cpp
> ===================================================================
> --- lib/Target/R600/AMDGPUISelLowering.cpp
> +++ lib/Target/R600/AMDGPUISelLowering.cpp
> @@ -2157,6 +2157,7 @@
>    NODE_NAME_CASE(DWORDADDR)
>    NODE_NAME_CASE(FRACT)
>    NODE_NAME_CASE(CLAMP)
> +  NODE_NAME_CASE(MAD)
>    NODE_NAME_CASE(FMAX)
>    NODE_NAME_CASE(SMAX)
>    NODE_NAME_CASE(UMAX)
> Index: lib/Target/R600/AMDGPUISelLowering.h
> ===================================================================
> --- lib/Target/R600/AMDGPUISelLowering.h
> +++ lib/Target/R600/AMDGPUISelLowering.h
> @@ -180,6 +180,7 @@
>    DWORDADDR,
>    FRACT,
>    CLAMP,
> +  MAD, // Multiply + add with same result as the separate operations.
>  
>    // SIN_HW, COS_HW - f32 for SI, 1 ULP max error, valid from -100 pi to 100 pi.
>    // Denormals handled on some parts.
> Index: lib/Target/R600/AMDGPUInstrInfo.td
> ===================================================================
> --- lib/Target/R600/AMDGPUInstrInfo.td
> +++ lib/Target/R600/AMDGPUInstrInfo.td
> @@ -64,6 +64,7 @@
>  >;
>  
>  def AMDGPUclamp : SDNode<"AMDGPUISD::CLAMP", SDTFPTernaryOp, []>;
> +def AMDGPUmad : SDNode<"AMDGPUISD::MAD", SDTFPTernaryOp, []>;
>  
>  // out = max(a, b) a and b are signed ints
>  def AMDGPUsmax : SDNode<"AMDGPUISD::SMAX", SDTIntBinOp,
> Index: lib/Target/R600/SIISelLowering.cpp
> ===================================================================
> --- lib/Target/R600/SIISelLowering.cpp
> +++ lib/Target/R600/SIISelLowering.cpp
> @@ -225,6 +225,7 @@
>  
>    setOperationAction(ISD::FDIV, MVT::f32, Custom);
>  
> +  setTargetDAGCombine(ISD::FSUB);
>    setTargetDAGCombine(ISD::SELECT_CC);
>    setTargetDAGCombine(ISD::SETCC);
>  
> @@ -1380,6 +1381,43 @@
>  
>    case ISD::UINT_TO_FP: {
>      return performUCharToFloatCombine(N, DCI);
>    }
> +  case ISD::FSUB: {
> +    if (DCI.getDAGCombineLevel() < AfterLegalizeDAG)
> +      break;
> +
> +    EVT VT = N->getValueType(0);
> +
> +    // Try to get the fneg to fold into the source modifier. This undoes generic
> +    // DAG combines and folds them into the mad.
> +    if (VT == MVT::f32) {
> +      SDValue LHS = N->getOperand(0);
> +      SDValue RHS = N->getOperand(1);
> +
> +      // Only absorb an fmul with no other users; otherwise the fmul node
> +      // stays alive and the multiply would be computed twice.
> +      if (LHS.getOpcode() == ISD::FMUL && LHS.hasOneUse()) {
> +        // (fsub (fmul a, b), c) -> mad a, b, (fneg c)
> +
> +        SDValue A = LHS.getOperand(0);
> +        SDValue B = LHS.getOperand(1);
> +        SDValue C = DAG.getNode(ISD::FNEG, DL, VT, RHS);
> +
> +        return DAG.getNode(AMDGPUISD::MAD, DL, VT, A, B, C);
> +      }
> +
> +      if (RHS.getOpcode() == ISD::FMUL && RHS.hasOneUse()) {
> +        // (fsub c, (fmul a, b)) -> mad (fneg a), b, c
> +
> +        SDValue A = DAG.getNode(ISD::FNEG, DL, VT, RHS.getOperand(0));
> +        SDValue B = RHS.getOperand(1);
> +        SDValue C = LHS;
> +
> +        return DAG.getNode(AMDGPUISD::MAD, DL, VT, A, B, C);
> +      }
> +    }
> +
> +    break;
> +  }
>    case ISD::LOAD:
>    case ISD::STORE:
> Index: lib/Target/R600/SIInstrInfo.td
> ===================================================================
> --- lib/Target/R600/SIInstrInfo.td
> +++ lib/Target/R600/SIInstrInfo.td
> @@ -793,6 +793,17 @@
>  multiclass VOP3b_32 <bits<9> op, string opName, list<dag> pattern> :
>    VOP3b_Helper <op, VReg_32, VSrc_32, opName, pattern>;
>  
> +// Select Inst for ternary node; fold src modifiers (src0 also yields clamp/omod).
> +class Vop3ModPat<Instruction Inst, VOPProfile P, SDPatternOperator node> : Pat<
> +  (node (P.Src0VT (VOP3Mods0 P.Src0VT:$src0, i32:$src0_modifiers, i32:$clamp, i32:$omod)),
> +        (P.Src1VT (VOP3Mods P.Src1VT:$src1, i32:$src1_modifiers)),
> +        (P.Src2VT (VOP3Mods P.Src2VT:$src2, i32:$src2_modifiers))),
> +  (Inst i32:$src0_modifiers, P.Src0VT:$src0,
> +        i32:$src1_modifiers, P.Src1VT:$src1,
> +        i32:$src2_modifiers, P.Src2VT:$src2,
> +        i32:$clamp,
> +        i32:$omod)>;
> +
>  //===----------------------------------------------------------------------===//
>  // Vector I/O classes
>  //===----------------------------------------------------------------------===//
> Index: lib/Target/R600/SIInstructions.td
> ===================================================================
> --- lib/Target/R600/SIInstructions.td
> +++ lib/Target/R600/SIInstructions.td
> @@ -2513,6 +2513,9 @@
>    (V_MUL_HI_I32 $src0, $src1)
>  >;
>  
> +def : Vop3ModPat<V_MAD_F32, VOP_F32_F32_F32_F32, AMDGPUmad>;
> +
> +

It would be nice if there were some way to do this without adding an extra
pattern, but I'm not sure how.

LGTM.
-Tom

>  defm : BFIPatterns <V_BFI_B32, S_MOV_B32>;
>  def : ROTRPattern <V_ALIGNBIT_B32>;
>  
> Index: test/CodeGen/R600/mad-sub.ll
> ===================================================================
> --- /dev/null
> +++ test/CodeGen/R600/mad-sub.ll
> @@ -0,0 +1,173 @@
> +; RUN: llc -march=r600 -mcpu=SI -verify-machineinstrs < %s | FileCheck -check-prefix=SI -check-prefix=FUNC %s
> +
> +declare i32 @llvm.r600.read.tidig.x() #0
> +declare float @llvm.fabs.f32(float) #0
> +
> +; FUNC-LABEL: @mad_sub_f32
> +; SI: BUFFER_LOAD_DWORD [[REGA:v[0-9]+]]
> +; SI: BUFFER_LOAD_DWORD [[REGB:v[0-9]+]]
> +; SI: BUFFER_LOAD_DWORD [[REGC:v[0-9]+]]
> +; SI: V_MAD_F32 [[RESULT:v[0-9]+]], [[REGA]], [[REGB]], -[[REGC]]
> +; SI: BUFFER_STORE_DWORD [[RESULT]]
> +define void @mad_sub_f32(float addrspace(1)* noalias nocapture %out, float addrspace(1)* noalias nocapture readonly %ptr) #1 {
> +  %tid = tail call i32 @llvm.r600.read.tidig.x() #0
> +  %tid.ext = sext i32 %tid to i64
> +  %gep0 = getelementptr float addrspace(1)* %ptr, i64 %tid.ext
> +  %add1 = add i64 %tid.ext, 1
> +  %gep1 = getelementptr float addrspace(1)* %ptr, i64 %add1
> +  %add2 = add i64 %tid.ext, 2
> +  %gep2 = getelementptr float addrspace(1)* %ptr, i64 %add2
> +  %outgep = getelementptr float addrspace(1)* %out, i64 %tid.ext
> +  %a = load float addrspace(1)* %gep0, align 4
> +  %b = load float addrspace(1)* %gep1, align 4
> +  %c = load float addrspace(1)* %gep2, align 4
> +  %mul = fmul float %a, %b
> +  %sub = fsub float %mul, %c
> +  store float %sub, float addrspace(1)* %outgep, align 4
> +  ret void
> +}
> +
> +; FUNC-LABEL: @mad_sub_inv_f32
> +; SI: BUFFER_LOAD_DWORD [[REGA:v[0-9]+]]
> +; SI: BUFFER_LOAD_DWORD [[REGB:v[0-9]+]]
> +; SI: BUFFER_LOAD_DWORD [[REGC:v[0-9]+]]
> +; SI: V_MAD_F32 [[RESULT:v[0-9]+]], -[[REGA]], [[REGB]], [[REGC]]
> +; SI: BUFFER_STORE_DWORD [[RESULT]]
> +define void @mad_sub_inv_f32(float addrspace(1)* noalias nocapture %out, float addrspace(1)* noalias nocapture readonly %ptr) #1 {
> +  %tid = tail call i32 @llvm.r600.read.tidig.x() #0
> +  %tid.ext = sext i32 %tid to i64
> +  %gep0 = getelementptr float addrspace(1)* %ptr, i64 %tid.ext
> +  %add1 = add i64 %tid.ext, 1
> +  %gep1 = getelementptr float addrspace(1)* %ptr, i64 %add1
> +  %add2 = add i64 %tid.ext, 2
> +  %gep2 = getelementptr float addrspace(1)* %ptr, i64 %add2
> +  %outgep = getelementptr float addrspace(1)* %out, i64 %tid.ext
> +  %a = load float addrspace(1)* %gep0, align 4
> +  %b = load float addrspace(1)* %gep1, align 4
> +  %c = load float addrspace(1)* %gep2, align 4
> +  %mul = fmul float %a, %b
> +  %sub = fsub float %c, %mul
> +  store float %sub, float addrspace(1)* %outgep, align 4
> +  ret void
> +}
> +
> +; FUNC-LABEL: @mad_sub_f64
> +; SI: V_MUL_F64
> +; SI: V_ADD_F64
> +define void @mad_sub_f64(double addrspace(1)* noalias nocapture %out, double addrspace(1)* noalias nocapture readonly %ptr) #1 {
> +  %tid = tail call i32 @llvm.r600.read.tidig.x() #0
> +  %tid.ext = sext i32 %tid to i64
> +  %gep0 = getelementptr double addrspace(1)* %ptr, i64 %tid.ext
> +  %add1 = add i64 %tid.ext, 1
> +  %gep1 = getelementptr double addrspace(1)* %ptr, i64 %add1
> +  %add2 = add i64 %tid.ext, 2
> +  %gep2 = getelementptr double addrspace(1)* %ptr, i64 %add2
> +  %outgep = getelementptr double addrspace(1)* %out, i64 %tid.ext
> +  %a = load double addrspace(1)* %gep0, align 8
> +  %b = load double addrspace(1)* %gep1, align 8
> +  %c = load double addrspace(1)* %gep2, align 8
> +  %mul = fmul double %a, %b
> +  %sub = fsub double %mul, %c
> +  store double %sub, double addrspace(1)* %outgep, align 8
> +  ret void
> +}
> +
> +; FUNC-LABEL: @mad_sub_fabs_f32
> +; SI: BUFFER_LOAD_DWORD [[REGA:v[0-9]+]]
> +; SI: BUFFER_LOAD_DWORD [[REGB:v[0-9]+]]
> +; SI: BUFFER_LOAD_DWORD [[REGC:v[0-9]+]]
> +; SI: V_MAD_F32 [[RESULT:v[0-9]+]], [[REGA]], [[REGB]], -|[[REGC]]|
> +; SI: BUFFER_STORE_DWORD [[RESULT]]
> +define void @mad_sub_fabs_f32(float addrspace(1)* noalias nocapture %out, float addrspace(1)* noalias nocapture readonly %ptr) #1 {
> +  %tid = tail call i32 @llvm.r600.read.tidig.x() #0
> +  %tid.ext = sext i32 %tid to i64
> +  %gep0 = getelementptr float addrspace(1)* %ptr, i64 %tid.ext
> +  %add1 = add i64 %tid.ext, 1
> +  %gep1 = getelementptr float addrspace(1)* %ptr, i64 %add1
> +  %add2 = add i64 %tid.ext, 2
> +  %gep2 = getelementptr float addrspace(1)* %ptr, i64 %add2
> +  %outgep = getelementptr float addrspace(1)* %out, i64 %tid.ext
> +  %a = load float addrspace(1)* %gep0, align 4
> +  %b = load float addrspace(1)* %gep1, align 4
> +  %c = load float addrspace(1)* %gep2, align 4
> +  %c.abs = call float @llvm.fabs.f32(float %c) #0
> +  %mul = fmul float %a, %b
> +  %sub = fsub float %mul, %c.abs
> +  store float %sub, float addrspace(1)* %outgep, align 4
> +  ret void
> +}
> +
> +; FUNC-LABEL: @mad_sub_fabs_inv_f32
> +; SI: BUFFER_LOAD_DWORD [[REGA:v[0-9]+]]
> +; SI: BUFFER_LOAD_DWORD [[REGB:v[0-9]+]]
> +; SI: BUFFER_LOAD_DWORD [[REGC:v[0-9]+]]
> +; SI: V_MAD_F32 [[RESULT:v[0-9]+]], -[[REGA]], [[REGB]], |[[REGC]]|
> +; SI: BUFFER_STORE_DWORD [[RESULT]]
> +define void @mad_sub_fabs_inv_f32(float addrspace(1)* noalias nocapture %out, float addrspace(1)* noalias nocapture readonly %ptr) #1 {
> +  %tid = tail call i32 @llvm.r600.read.tidig.x() #0
> +  %tid.ext = sext i32 %tid to i64
> +  %gep0 = getelementptr float addrspace(1)* %ptr, i64 %tid.ext
> +  %add1 = add i64 %tid.ext, 1
> +  %gep1 = getelementptr float addrspace(1)* %ptr, i64 %add1
> +  %add2 = add i64 %tid.ext, 2
> +  %gep2 = getelementptr float addrspace(1)* %ptr, i64 %add2
> +  %outgep = getelementptr float addrspace(1)* %out, i64 %tid.ext
> +  %a = load float addrspace(1)* %gep0, align 4
> +  %b = load float addrspace(1)* %gep1, align 4
> +  %c = load float addrspace(1)* %gep2, align 4
> +  %c.abs = call float @llvm.fabs.f32(float %c) #0
> +  %mul = fmul float %a, %b
> +  %sub = fsub float %c.abs, %mul
> +  store float %sub, float addrspace(1)* %outgep, align 4
> +  ret void
> +}
> +
> +; FUNC-LABEL: @neg_neg_mad_f32
> +; SI: V_MAD_F32 {{v[0-9]+}}, {{v[0-9]+}}, {{v[0-9]+}}, {{v[0-9]+}}
> +define void @neg_neg_mad_f32(float addrspace(1)* noalias nocapture %out, float addrspace(1)* noalias nocapture readonly %ptr) #1 {
> +  %tid = tail call i32 @llvm.r600.read.tidig.x() #0
> +  %tid.ext = sext i32 %tid to i64
> +  %gep0 = getelementptr float addrspace(1)* %ptr, i64 %tid.ext
> +  %add1 = add i64 %tid.ext, 1
> +  %gep1 = getelementptr float addrspace(1)* %ptr, i64 %add1
> +  %add2 = add i64 %tid.ext, 2
> +  %gep2 = getelementptr float addrspace(1)* %ptr, i64 %add2
> +  %outgep = getelementptr float addrspace(1)* %out, i64 %tid.ext
> +  %a = load float addrspace(1)* %gep0, align 4
> +  %b = load float addrspace(1)* %gep1, align 4
> +  %c = load float addrspace(1)* %gep2, align 4
> +  %nega = fsub float -0.000000e+00, %a
> +  %negb = fsub float -0.000000e+00, %b
> +  %mul = fmul float %nega, %negb
> +  %sub = fadd float %mul, %c
> +  store float %sub, float addrspace(1)* %outgep, align 4
> +  ret void
> +}
> +
> +; FUNC-LABEL: @mad_fabs_sub_f32
> +; SI: BUFFER_LOAD_DWORD [[REGA:v[0-9]+]]
> +; SI: BUFFER_LOAD_DWORD [[REGB:v[0-9]+]]
> +; SI: BUFFER_LOAD_DWORD [[REGC:v[0-9]+]]
> +; SI: V_MAD_F32 [[RESULT:v[0-9]+]], [[REGA]], |[[REGB]]|, -[[REGC]]
> +; SI: BUFFER_STORE_DWORD [[RESULT]]
> +define void @mad_fabs_sub_f32(float addrspace(1)* noalias nocapture %out, float addrspace(1)* noalias nocapture readonly %ptr) #1 {
> +  %tid = tail call i32 @llvm.r600.read.tidig.x() #0
> +  %tid.ext = sext i32 %tid to i64
> +  %gep0 = getelementptr float addrspace(1)* %ptr, i64 %tid.ext
> +  %add1 = add i64 %tid.ext, 1
> +  %gep1 = getelementptr float addrspace(1)* %ptr, i64 %add1
> +  %add2 = add i64 %tid.ext, 2
> +  %gep2 = getelementptr float addrspace(1)* %ptr, i64 %add2
> +  %outgep = getelementptr float addrspace(1)* %out, i64 %tid.ext
> +  %a = load float addrspace(1)* %gep0, align 4
> +  %b = load float addrspace(1)* %gep1, align 4
> +  %c = load float addrspace(1)* %gep2, align 4
> +  %b.abs = call float @llvm.fabs.f32(float %b) #0
> +  %mul = fmul float %a, %b.abs
> +  %sub = fsub float %mul, %c
> +  store float %sub, float addrspace(1)* %outgep, align 4
> +  ret void
> +}
> +
> +attributes #0 = { nounwind readnone }
> +attributes #1 = { nounwind }

> _______________________________________________
> llvm-commits mailing list
> llvm-commits at cs.uiuc.edu
> http://lists.cs.uiuc.edu/mailman/listinfo/llvm-commits




More information about the llvm-commits mailing list