[PATCH] D144911: adding bf16 support to NVPTX

Kushan Ahmadian via Phabricator via cfe-commits cfe-commits at lists.llvm.org
Mon Jun 5 11:21:27 PDT 2023


kushanam updated this revision to Diff 528526.
kushanam added a comment.

Add min/max lowering (FMINNUM/FMAXNUM and FMINIMUM/FMAXIMUM) for bf16 and refactor the surrounding code.


Repository:
  rG LLVM Github Monorepo

CHANGES SINCE LAST ACTION
  https://reviews.llvm.org/D144911/new/

https://reviews.llvm.org/D144911

Files:
  llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
  llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp


Index: llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
===================================================================
--- llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -150,23 +150,11 @@
 }
 
 static bool Isv2f16Orv2bf16Type(MVT VT) {
-  switch (VT.SimpleTy) {
-  default:
-    return false;
-  case MVT::v2f16:
-  case MVT::v2bf16:
-    return true;
-  }
+  return VT == MVT::v2f16 || VT == MVT::v2bf16;
 }
 
 static bool Isf16Orbf16Type(MVT VT) {
-  switch (VT.SimpleTy) {
-  default:
-    return false;
-  case MVT::f16:
-  case MVT::bf16:
-    return true;
-  }
+  return VT == MVT::f16 || VT == MVT::bf16;
 }
 
 /// ComputePTXValueVTs - For the given Type \p Ty, returns the set of primitive
@@ -624,9 +612,6 @@
   for (const auto &Op : {ISD::FADD, ISD::FMUL, ISD::FSUB, ISD::FMA}) {
     setFP16OperationAction(Op, MVT::f16, Legal, Promote);
     setFP16OperationAction(Op, MVT::v2f16, Legal, Expand);
-  }
-
-  for (const auto &Op : {ISD::FADD, ISD::FMUL, ISD::FSUB, ISD::FMA}) {
     setBF16OperationAction(Op, MVT::bf16, Legal, Promote);
     setBF16OperationAction(Op, MVT::v2bf16, Legal, Expand);
   }
@@ -693,20 +678,19 @@
   };
   for (const auto &Op : {ISD::FMINNUM, ISD::FMAXNUM}) {
     setFP16OperationAction(Op, MVT::f16, GetMinMaxAction(Promote), Promote);
+    setBF16OperationAction(Op, MVT::bf16, GetMinMaxAction(Promote), Promote);
     setOperationAction(Op, MVT::f32, Legal);
     setOperationAction(Op, MVT::f64, Legal);
     setFP16OperationAction(Op, MVT::v2f16, GetMinMaxAction(Expand), Expand);
+    setBF16OperationAction(Op, MVT::v2bf16, GetMinMaxAction(Expand), Expand);
   }
   for (const auto &Op : {ISD::FMINIMUM, ISD::FMAXIMUM}) {
     setFP16OperationAction(Op, MVT::f16, GetMinMaxAction(Expand), Expand);
+    // Use the bf16 helper for bf16 types, as in the FMINNUM/FMAXNUM loop.
+    setBF16OperationAction(Op, MVT::bf16, GetMinMaxAction(Expand), Expand);
     setOperationAction(Op, MVT::f32, GetMinMaxAction(Expand));
     setFP16OperationAction(Op, MVT::v2f16, GetMinMaxAction(Expand), Expand);
-  }
-  for (const auto &Op : {ISD::FMINNUM, ISD::FMAXNUM}) {
-    setBF16OperationAction(Op, MVT::bf16, GetMinMaxAction(Promote), Promote);
-    setBF16OperationAction(Op, MVT::v2bf16, GetMinMaxAction(Expand), Expand);
-    setBF16OperationAction(Op, MVT::bf16, GetMinMaxAction(Expand), Expand);
-    setBF16OperationAction(Op, MVT::v2bf16, GetMinMaxAction(Expand), Expand);
+    setBF16OperationAction(Op, MVT::v2bf16, GetMinMaxAction(Expand), Expand);
   }
 
   // No FEXP2, FLOG2.  The PTX ex2 and log2 functions are always approximate.
Index: llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
===================================================================
--- llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
+++ llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
@@ -1294,14 +1294,11 @@
     NumElts = EltVT.getVectorNumElements();
     EltVT = EltVT.getVectorElementType();
-    // vectors of f16 are loaded/stored as multiples of v2f16 elements.
-    if (EltVT == MVT::f16 && N->getValueType(0) == MVT::v2f16) {
-      assert(NumElts % 2 == 0 && "Vector must have even number of elements");
-      EltVT = MVT::v2f16;
-      NumElts /= 2;
-    } else if (EltVT == MVT::bf16 && N->getValueType(0) == MVT::v2bf16) {
-      assert(NumElts % 2 == 0 && "Vector must have even number of elements");
-      EltVT = MVT::v2bf16;
-      NumElts /= 2;
+    // f16/bf16 vectors are loaded/stored as multiples of v2f16/v2bf16 elements.
+    if ((EltVT == MVT::f16 && N->getValueType(0) == MVT::v2f16) ||
+        (EltVT == MVT::bf16 && N->getValueType(0) == MVT::v2bf16)) {
+      assert(NumElts % 2 == 0 && "Vector must have even number of elements");
+      EltVT = N->getValueType(0);
+      NumElts /= 2;
     }
   }
 


-------------- next part --------------
A non-text attachment was scrubbed...
Name: D144911.528526.patch
Type: text/x-patch
Size: 3690 bytes
Desc: not available
URL: <http://lists.llvm.org/pipermail/cfe-commits/attachments/20230605/6bc05a19/attachment.bin>


More information about the cfe-commits mailing list