[llvm] 39c55ce - GlobalISel: Introduce bitcast legalize action

Matt Arsenault via llvm-commits llvm-commits at lists.llvm.org
Tue Mar 24 16:48:07 PDT 2020


Author: Matt Arsenault
Date: 2020-03-24T19:33:33-04:00
New Revision: 39c55cef21b4ca9863edf82b6d9995fb1e28c37d

URL: https://github.com/llvm/llvm-project/commit/39c55cef21b4ca9863edf82b6d9995fb1e28c37d
DIFF: https://github.com/llvm/llvm-project/commit/39c55cef21b4ca9863edf82b6d9995fb1e28c37d.diff

LOG: GlobalISel: Introduce bitcast legalize action

For some operations, the type is unimportant and only the number of
bits matters. For example I don't want to treat <4 x s8> as a legal
type, but I also don't want to decompose loads of this into smaller
pieces to get legal register types.

On AMDGPU in SelectionDAG, we legalize a number of operations (most
notably load and store) by coercing all types to vectors of i32. For
GlobalISel, I'm trying very hard to avoid doing this for every type,
but I don't think this strategy can be completely avoided. I'm trying
to avoid bitcasts for any legitimately legal type we can operate on,
since the intervening bitcasts have proven to be a hassle.

For loads, I think I can get away without ever casting the result
type, and handling any arbitrary bitwidth during selection (I will
eventually want new tablegen support to help with this, rather than
having to add every possible type as legal). The unmerge required to
do anything with the value should expand to the expected shifts. This
is trickier for stores, since it would now require handling a wide
array of truncates during selection which I don't want.

Future potentially interesting cases are for vector indexing, where
sub-dword types should be indexed in s32 pieces.

Added: 
    

Modified: 
    llvm/include/llvm/CodeGen/GlobalISel/LegalizerHelper.h
    llvm/include/llvm/CodeGen/GlobalISel/LegalizerInfo.h
    llvm/lib/CodeGen/GlobalISel/LegalizerHelper.cpp
    llvm/lib/CodeGen/GlobalISel/LegalizerInfo.cpp
    llvm/unittests/CodeGen/GlobalISel/LegalizerHelperTest.cpp
    llvm/unittests/CodeGen/GlobalISel/LegalizerInfoTest.cpp

Removed: 
    


################################################################################
diff --git a/llvm/include/llvm/CodeGen/GlobalISel/LegalizerHelper.h b/llvm/include/llvm/CodeGen/GlobalISel/LegalizerHelper.h
index 08d12a57adff..60c0a71762e5 100644
--- a/llvm/include/llvm/CodeGen/GlobalISel/LegalizerHelper.h
+++ b/llvm/include/llvm/CodeGen/GlobalISel/LegalizerHelper.h
@@ -74,6 +74,9 @@ class LegalizerHelper {
   /// precision, ignoring the unused bits).
   LegalizeResult widenScalar(MachineInstr &MI, unsigned TypeIdx, LLT WideTy);
 
+  /// Legalize an instruction by replacing the value type
+  LegalizeResult bitcast(MachineInstr &MI, unsigned TypeIdx, LLT Ty);
+
   /// Legalize an instruction by splitting it into simpler parts, hopefully
   /// understood by the target.
   LegalizeResult lower(MachineInstr &MI, unsigned TypeIdx, LLT Ty);
@@ -128,6 +131,14 @@ class LegalizerHelper {
   /// original vector type, and replacing the vreg of the operand in place.
   void moreElementsVectorSrc(MachineInstr &MI, LLT MoreTy, unsigned OpIdx);
 
+  /// Legalize a single operand \p OpIdx of the machine instruction \p MI as a
+  /// use by inserting a G_BITCAST to \p CastTy
+  void bitcastSrc(MachineInstr &MI, LLT CastTy, unsigned OpIdx);
+
+  /// Legalize a single operand \p OpIdx of the machine instruction \p MI as a
+  /// def by inserting a G_BITCAST from \p CastTy
+  void bitcastDst(MachineInstr &MI, LLT CastTy, unsigned OpIdx);
+
 private:
   LegalizeResult
   widenScalarMergeValues(MachineInstr &MI, unsigned TypeIdx, LLT WideTy);

diff --git a/llvm/include/llvm/CodeGen/GlobalISel/LegalizerInfo.h b/llvm/include/llvm/CodeGen/GlobalISel/LegalizerInfo.h
index e2ed44e283d8..624fa70f1aa6 100644
--- a/llvm/include/llvm/CodeGen/GlobalISel/LegalizerInfo.h
+++ b/llvm/include/llvm/CodeGen/GlobalISel/LegalizerInfo.h
@@ -68,6 +68,9 @@ enum LegalizeAction : std::uint8_t {
   /// the first two results.
   MoreElements,
 
+  /// Perform the operation on a different, but equivalently sized type.
+  Bitcast,
+
   /// The operation itself must be expressed in terms of simpler actions on
   /// this target. E.g. a SREM replaced by an SDIV and subtraction.
   Lower,

diff --git a/llvm/lib/CodeGen/GlobalISel/LegalizerHelper.cpp b/llvm/lib/CodeGen/GlobalISel/LegalizerHelper.cpp
index 63dff5c96fd6..ac9b35aeb18c 100644
--- a/llvm/lib/CodeGen/GlobalISel/LegalizerHelper.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/LegalizerHelper.cpp
@@ -120,6 +120,9 @@ LegalizerHelper::legalizeInstrStep(MachineInstr &MI) {
   case WidenScalar:
     LLVM_DEBUG(dbgs() << ".. Widen scalar\n");
     return widenScalar(MI, Step.TypeIdx, Step.NewType);
+  case Bitcast:
+    LLVM_DEBUG(dbgs() << ".. Bitcast type\n");
+    return bitcast(MI, Step.TypeIdx, Step.NewType);
   case Lower:
     LLVM_DEBUG(dbgs() << ".. Lower\n");
     return lower(MI, Step.TypeIdx, Step.NewType);
@@ -1251,6 +1254,19 @@ void LegalizerHelper::moreElementsVectorSrc(MachineInstr &MI, LLT MoreTy,
   MO.setReg(MoreReg);
 }
 
+void LegalizerHelper::bitcastSrc(MachineInstr &MI, LLT CastTy, unsigned OpIdx) {
+  MachineOperand &Op = MI.getOperand(OpIdx);
+  Op.setReg(MIRBuilder.buildBitcast(CastTy, Op).getReg(0));
+}
+
+void LegalizerHelper::bitcastDst(MachineInstr &MI, LLT CastTy, unsigned OpIdx) {
+  MachineOperand &MO = MI.getOperand(OpIdx);
+  Register CastDst = MRI.createGenericVirtualRegister(CastTy);
+  MIRBuilder.setInsertPt(MIRBuilder.getMBB(), ++MIRBuilder.getInsertPt());
+  MIRBuilder.buildBitcast(MO, CastDst);
+  MO.setReg(CastDst);
+}
+
 LegalizerHelper::LegalizeResult
 LegalizerHelper::widenScalarMergeValues(MachineInstr &MI, unsigned TypeIdx,
                                         LLT WideTy) {
@@ -2104,6 +2120,61 @@ LegalizerHelper::lowerBitcast(MachineInstr &MI) {
   return UnableToLegalize;
 }
 
+LegalizerHelper::LegalizeResult
+LegalizerHelper::bitcast(MachineInstr &MI, unsigned TypeIdx, LLT CastTy) {
+  MIRBuilder.setInstr(MI);
+
+  switch (MI.getOpcode()) {
+  case TargetOpcode::G_LOAD: {
+    if (TypeIdx != 0)
+      return UnableToLegalize;
+
+    Observer.changingInstr(MI);
+    bitcastDst(MI, CastTy, 0);
+    Observer.changedInstr(MI);
+    return Legalized;
+  }
+  case TargetOpcode::G_STORE: {
+    if (TypeIdx != 0)
+      return UnableToLegalize;
+
+    Observer.changingInstr(MI);
+    bitcastSrc(MI, CastTy, 0);
+    Observer.changedInstr(MI);
+    return Legalized;
+  }
+  case TargetOpcode::G_SELECT: {
+    if (TypeIdx != 0)
+      return UnableToLegalize;
+
+    if (MRI.getType(MI.getOperand(1).getReg()).isVector()) {
+      LLVM_DEBUG(
+          dbgs() << "bitcast action not implemented for vector select\n");
+      return UnableToLegalize;
+    }
+
+    Observer.changingInstr(MI);
+    bitcastSrc(MI, CastTy, 2);
+    bitcastSrc(MI, CastTy, 3);
+    bitcastDst(MI, CastTy, 0);
+    Observer.changedInstr(MI);
+    return Legalized;
+  }
+  case TargetOpcode::G_AND:
+  case TargetOpcode::G_OR:
+  case TargetOpcode::G_XOR: {
+    Observer.changingInstr(MI);
+    bitcastSrc(MI, CastTy, 1);
+    bitcastSrc(MI, CastTy, 2);
+    bitcastDst(MI, CastTy, 0);
+    Observer.changedInstr(MI);
+    return Legalized;
+  }
+  default:
+    return UnableToLegalize;
+  }
+}
+
 LegalizerHelper::LegalizeResult
 LegalizerHelper::lower(MachineInstr &MI, unsigned TypeIdx, LLT Ty) {
   using namespace TargetOpcode;

diff --git a/llvm/lib/CodeGen/GlobalISel/LegalizerInfo.cpp b/llvm/lib/CodeGen/GlobalISel/LegalizerInfo.cpp
index cad08a395936..2658434fe949 100644
--- a/llvm/lib/CodeGen/GlobalISel/LegalizerInfo.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/LegalizerInfo.cpp
@@ -59,6 +59,9 @@ raw_ostream &llvm::operator<<(raw_ostream &OS, LegalizeAction Action) {
   case MoreElements:
     OS << "MoreElements";
     break;
+  case Bitcast:
+    OS << "Bitcast";
+    break;
   case Lower:
     OS << "Lower";
     break;
@@ -173,6 +176,9 @@ static bool mutationIsSane(const LegalizeRule &Rule,
 
     return true;
   }
+  case Bitcast: {
+    return OldTy != NewTy && OldTy.getSizeInBits() == NewTy.getSizeInBits();
+  }
   default:
     return true;
   }
@@ -575,6 +581,7 @@ LegalizerInfo::findAction(const SizeAndActionsVec &Vec, const uint32_t Size) {
   LegalizeAction Action = Vec[VecIdx].second;
   switch (Action) {
   case Legal:
+  case Bitcast:
   case Lower:
   case Libcall:
   case Custom:

diff --git a/llvm/unittests/CodeGen/GlobalISel/LegalizerHelperTest.cpp b/llvm/unittests/CodeGen/GlobalISel/LegalizerHelperTest.cpp
index 3bc8f1f8f240..244cb75c3e29 100644
--- a/llvm/unittests/CodeGen/GlobalISel/LegalizerHelperTest.cpp
+++ b/llvm/unittests/CodeGen/GlobalISel/LegalizerHelperTest.cpp
@@ -2565,4 +2565,161 @@ TEST_F(AArch64GISelMITest, WidenUnmerge) {
   EXPECT_TRUE(CheckMachineFunction(*MF, CheckStr)) << *MF;
 }
 
+TEST_F(AArch64GISelMITest, BitcastLoad) {
+  setUp();
+  if (!TM)
+    return;
+
+  LLT P0 = LLT::pointer(0, 64);
+  LLT S32 = LLT::scalar(32);
+  LLT V4S8 = LLT::vector(4, 8);
+  auto Ptr = B.buildUndef(P0);
+
+  DefineLegalizerInfo(A, {});
+
+  MachineMemOperand *MMO = B.getMF().getMachineMemOperand(
+    MachinePointerInfo(), MachineMemOperand::MOLoad, 4, 4);
+  auto Load = B.buildLoad(V4S8, Ptr, *MMO);
+
+  AInfo Info(MF->getSubtarget());
+  DummyGISelObserver Observer;
+  LegalizerHelper Helper(*MF, Info, Observer, B);
+  EXPECT_EQ(LegalizerHelper::LegalizeResult::Legalized,
+            Helper.bitcast(*Load, 0, S32));
+
+  auto CheckStr = R"(
+  CHECK: [[PTR:%[0-9]+]]:_(p0) = G_IMPLICIT_DEF
+  CHECK: [[LOAD:%[0-9]+]]:_(s32) = G_LOAD
+  CHECK: [[CAST:%[0-9]+]]:_(<4 x s8>) = G_BITCAST [[LOAD]]
+
+  )";
+
+  // Check
+  EXPECT_TRUE(CheckMachineFunction(*MF, CheckStr)) << *MF;
+}
+
+TEST_F(AArch64GISelMITest, BitcastStore) {
+  setUp();
+  if (!TM)
+    return;
+
+  LLT P0 = LLT::pointer(0, 64);
+  LLT S32 = LLT::scalar(32);
+  LLT V4S8 = LLT::vector(4, 8);
+  auto Ptr = B.buildUndef(P0);
+
+  DefineLegalizerInfo(A, {});
+
+  MachineMemOperand *MMO = B.getMF().getMachineMemOperand(
+    MachinePointerInfo(), MachineMemOperand::MOStore, 4, 4);
+  auto Val = B.buildUndef(V4S8);
+  auto Store = B.buildStore(Val, Ptr, *MMO);
+
+  AInfo Info(MF->getSubtarget());
+  DummyGISelObserver Observer;
+  LegalizerHelper Helper(*MF, Info, Observer, B);
+  EXPECT_EQ(LegalizerHelper::LegalizeResult::Legalized,
+            Helper.bitcast(*Store, 0, S32));
+
+  auto CheckStr = R"(
+  CHECK: [[VAL:%[0-9]+]]:_(<4 x s8>) = G_IMPLICIT_DEF
+  CHECK: [[CAST:%[0-9]+]]:_(s32) = G_BITCAST [[VAL]]
+  CHECK: G_STORE [[CAST]]
+  )";
+
+  // Check
+  EXPECT_TRUE(CheckMachineFunction(*MF, CheckStr)) << *MF;
+}
+
+TEST_F(AArch64GISelMITest, BitcastSelect) {
+  setUp();
+  if (!TM)
+    return;
+
+  LLT S1 = LLT::scalar(1);
+  LLT S32 = LLT::scalar(32);
+  LLT V4S8 = LLT::vector(4, 8);
+
+  DefineLegalizerInfo(A, {});
+
+  auto Cond = B.buildUndef(S1);
+  auto Val0 = B.buildConstant(V4S8, 123);
+  auto Val1 = B.buildConstant(V4S8, 99);
+
+  auto Select = B.buildSelect(V4S8, Cond, Val0, Val1);
+
+  AInfo Info(MF->getSubtarget());
+  DummyGISelObserver Observer;
+  LegalizerHelper Helper(*MF, Info, Observer, B);
+  EXPECT_EQ(LegalizerHelper::LegalizeResult::Legalized,
+            Helper.bitcast(*Select, 0, S32));
+
+  auto CheckStr = R"(
+  CHECK: [[VAL0:%[0-9]+]]:_(<4 x s8>) = G_BUILD_VECTOR
+  CHECK: [[VAL1:%[0-9]+]]:_(<4 x s8>) = G_BUILD_VECTOR
+  CHECK: [[CAST0:%[0-9]+]]:_(s32) = G_BITCAST [[VAL0]]
+  CHECK: [[CAST1:%[0-9]+]]:_(s32) = G_BITCAST [[VAL1]]
+  CHECK: [[SELECT:%[0-9]+]]:_(s32) = G_SELECT %{{[0-9]+}}:_(s1), [[CAST0]]:_, [[CAST1]]:_
+  CHECK: [[CAST2:%[0-9]+]]:_(<4 x s8>) = G_BITCAST [[SELECT]]
+  )";
+
+  // Check
+  EXPECT_TRUE(CheckMachineFunction(*MF, CheckStr)) << *MF;
+
+  // Doesn't make sense
+  auto VCond = B.buildUndef(LLT::vector(4, 1));
+  auto VSelect = B.buildSelect(V4S8, VCond, Val0, Val1);
+  EXPECT_EQ(LegalizerHelper::LegalizeResult::UnableToLegalize,
+            Helper.bitcast(*VSelect, 0, S32));
+  EXPECT_EQ(LegalizerHelper::LegalizeResult::UnableToLegalize,
+            Helper.bitcast(*VSelect, 1, LLT::scalar(4)));
+}
+
+TEST_F(AArch64GISelMITest, BitcastBitOps) {
+  setUp();
+  if (!TM)
+    return;
+
+  LLT S32 = LLT::scalar(32);
+  LLT V4S8 = LLT::vector(4, 8);
+
+  DefineLegalizerInfo(A, {});
+
+  auto Val0 = B.buildConstant(V4S8, 123);
+  auto Val1 = B.buildConstant(V4S8, 99);
+  auto And = B.buildAnd(V4S8, Val0, Val1);
+  auto Or = B.buildOr(V4S8, Val0, Val1);
+  auto Xor = B.buildXor(V4S8, Val0, Val1);
+
+  AInfo Info(MF->getSubtarget());
+  DummyGISelObserver Observer;
+  LegalizerHelper Helper(*MF, Info, Observer, B);
+  EXPECT_EQ(LegalizerHelper::LegalizeResult::Legalized,
+            Helper.bitcast(*And, 0, S32));
+  EXPECT_EQ(LegalizerHelper::LegalizeResult::Legalized,
+            Helper.bitcast(*Or, 0, S32));
+  EXPECT_EQ(LegalizerHelper::LegalizeResult::Legalized,
+            Helper.bitcast(*Xor, 0, S32));
+
+  auto CheckStr = R"(
+  CHECK: [[VAL0:%[0-9]+]]:_(<4 x s8>) = G_BUILD_VECTOR
+  CHECK: [[VAL1:%[0-9]+]]:_(<4 x s8>) = G_BUILD_VECTOR
+  CHECK: [[CAST0:%[0-9]+]]:_(s32) = G_BITCAST [[VAL0]]
+  CHECK: [[CAST1:%[0-9]+]]:_(s32) = G_BITCAST [[VAL1]]
+  CHECK: [[AND:%[0-9]+]]:_(s32) = G_AND [[CAST0]]:_, [[CAST1]]:_
+  CHECK: [[CAST_AND:%[0-9]+]]:_(<4 x s8>) = G_BITCAST [[AND]]
+  CHECK: [[CAST2:%[0-9]+]]:_(s32) = G_BITCAST [[VAL0]]
+  CHECK: [[CAST3:%[0-9]+]]:_(s32) = G_BITCAST [[VAL1]]
+  CHECK: [[OR:%[0-9]+]]:_(s32) = G_OR [[CAST2]]:_, [[CAST3]]:_
+  CHECK: [[CAST_OR:%[0-9]+]]:_(<4 x s8>) = G_BITCAST [[OR]]
+  CHECK: [[CAST4:%[0-9]+]]:_(s32) = G_BITCAST [[VAL0]]
+  CHECK: [[CAST5:%[0-9]+]]:_(s32) = G_BITCAST [[VAL1]]
+  CHECK: [[XOR:%[0-9]+]]:_(s32) = G_XOR [[CAST4]]:_, [[CAST5]]:_
+  CHECK: [[CAST_XOR:%[0-9]+]]:_(<4 x s8>) = G_BITCAST [[XOR]]
+  )";
+
+  // Check
+  EXPECT_TRUE(CheckMachineFunction(*MF, CheckStr)) << *MF;
+}
+
 } // namespace

diff --git a/llvm/unittests/CodeGen/GlobalISel/LegalizerInfoTest.cpp b/llvm/unittests/CodeGen/GlobalISel/LegalizerInfoTest.cpp
index b342143e1394..7fd2ea453a2a 100644
--- a/llvm/unittests/CodeGen/GlobalISel/LegalizerInfoTest.cpp
+++ b/llvm/unittests/CodeGen/GlobalISel/LegalizerInfoTest.cpp
@@ -27,6 +27,7 @@ operator<<(std::ostream &OS, const LegalizeAction Act) {
   case MoreElements:  OS << "MoreElements"; break;
   case Libcall: OS << "Libcall"; break;
   case Custom: OS << "Custom"; break;
+  case Bitcast: OS << "Bitcast"; break;
   case Unsupported: OS << "Unsupported"; break;
   case NotFound: OS << "NotFound"; break;
   case UseLegacyRules: OS << "UseLegacyRules"; break;


        


More information about the llvm-commits mailing list