[llvm] 3f08ad6 - [SVE][CodeGen] Scalable vector MVT size queries

Graham Hunter via llvm-commits llvm-commits at lists.llvm.org
Mon Nov 18 04:31:35 PST 2019


Author: Graham Hunter
Date: 2019-11-18T12:30:59Z
New Revision: 3f08ad611aa26db2e719705b8fb60f4661d97b98

URL: https://github.com/llvm/llvm-project/commit/3f08ad611aa26db2e719705b8fb60f4661d97b98
DIFF: https://github.com/llvm/llvm-project/commit/3f08ad611aa26db2e719705b8fb60f4661d97b98.diff

LOG: [SVE][CodeGen] Scalable vector MVT size queries

* Implements scalable size queries for MVTs, split out from D53137.

* Contains a fix for FindMemType so that it no longer uses scalable vector
  types to hold non-scalable types.

* Adds explicit casts in several places where implicit integer sign
  changes or promotion from 32 to 64 bits caused problems.

* CodeGenDAGPatterns now treats scalable and non-scalable vector types
  as distinct types.

Reviewers: greened, cameron.mcinally, sdesmalen, rovka

Reviewed By: rovka

Differential Revision: https://reviews.llvm.org/D66871

Added: 
    

Modified: 
    llvm/include/llvm/CodeGen/SelectionDAGNodes.h
    llvm/include/llvm/CodeGen/ValueTypes.h
    llvm/include/llvm/Support/MachineValueType.h
    llvm/include/llvm/Support/TypeSize.h
    llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
    llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
    llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
    llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
    llvm/lib/CodeGen/SelectionDAG/StatepointLowering.cpp
    llvm/lib/CodeGen/ValueTypes.cpp
    llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
    llvm/lib/Target/AArch64/AArch64StackOffset.h
    llvm/lib/Target/ARM/ARMISelLowering.cpp
    llvm/lib/Target/Hexagon/HexagonISelLowering.cpp
    llvm/lib/Target/Mips/MipsISelLowering.cpp
    llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
    llvm/lib/Target/PowerPC/PPCISelDAGToDAG.cpp
    llvm/lib/Target/X86/X86ISelLowering.cpp
    llvm/unittests/CodeGen/ScalableVectorMVTsTest.cpp
    llvm/utils/TableGen/CodeGenDAGPatterns.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/CodeGen/SelectionDAGNodes.h b/llvm/include/llvm/CodeGen/SelectionDAGNodes.h
index ceb8b72635a2..abcd3fb17333 100644
--- a/llvm/include/llvm/CodeGen/SelectionDAGNodes.h
+++ b/llvm/include/llvm/CodeGen/SelectionDAGNodes.h
@@ -42,6 +42,7 @@
 #include "llvm/Support/Casting.h"
 #include "llvm/Support/ErrorHandling.h"
 #include "llvm/Support/MachineValueType.h"
+#include "llvm/Support/TypeSize.h"
 #include <algorithm>
 #include <cassert>
 #include <climits>
@@ -170,11 +171,15 @@ class SDValue {
   }
 
   /// Returns the size of the value in bits.
-  unsigned getValueSizeInBits() const {
+  ///
+  /// If the value type is a scalable vector type, the scalable property will
+  /// be set and the runtime size will be a positive integer multiple of the
+  /// base size.
+  TypeSize getValueSizeInBits() const {
     return getValueType().getSizeInBits();
   }
 
-  unsigned getScalarValueSizeInBits() const {
+  TypeSize getScalarValueSizeInBits() const {
     return getValueType().getScalarType().getSizeInBits();
   }
 
@@ -1022,7 +1027,11 @@ END_TWO_BYTE_PACK()
   }
 
   /// Returns MVT::getSizeInBits(getValueType(ResNo)).
-  unsigned getValueSizeInBits(unsigned ResNo) const {
+  ///
+  /// If the value type is a scalable vector type, the scalable property will
+  /// be set and the runtime size will be a positive integer multiple of the
+  /// base size.
+  TypeSize getValueSizeInBits(unsigned ResNo) const {
     return getValueType(ResNo).getSizeInBits();
   }
 

diff  --git a/llvm/include/llvm/CodeGen/ValueTypes.h b/llvm/include/llvm/CodeGen/ValueTypes.h
index cd4c4ca64081..bcf417762920 100644
--- a/llvm/include/llvm/CodeGen/ValueTypes.h
+++ b/llvm/include/llvm/CodeGen/ValueTypes.h
@@ -18,6 +18,7 @@
 #include "llvm/Support/Compiler.h"
 #include "llvm/Support/MachineValueType.h"
 #include "llvm/Support/MathExtras.h"
+#include "llvm/Support/TypeSize.h"
 #include <cassert>
 #include <cstdint>
 #include <string>
@@ -209,11 +210,13 @@ namespace llvm {
 
     /// Return true if the bit size is a multiple of 8.
     bool isByteSized() const {
-      return (getSizeInBits() & 7) == 0;
+      return getSizeInBits().isByteSized();
     }
 
     /// Return true if the size is a power-of-two number of bytes.
     bool isRound() const {
+      if (isScalableVector())
+        return false;
       unsigned BitSize = getSizeInBits();
       return BitSize >= 8 && !(BitSize & (BitSize - 1));
     }
@@ -288,25 +291,38 @@ namespace llvm {
     }
 
     /// Return the size of the specified value type in bits.
-    unsigned getSizeInBits() const {
+    ///
+    /// If the value type is a scalable vector type, the scalable property will
+    /// be set and the runtime size will be a positive integer multiple of the
+    /// base size.
+    TypeSize getSizeInBits() const {
       if (isSimple())
         return V.getSizeInBits();
       return getExtendedSizeInBits();
     }
 
-    unsigned getScalarSizeInBits() const {
+    TypeSize getScalarSizeInBits() const {
       return getScalarType().getSizeInBits();
     }
 
     /// Return the number of bytes overwritten by a store of the specified value
     /// type.
-    unsigned getStoreSize() const {
-      return (getSizeInBits() + 7) / 8;
+    ///
+    /// If the value type is a scalable vector type, the scalable property will
+    /// be set and the runtime size will be a positive integer multiple of the
+    /// base size.
+    TypeSize getStoreSize() const {
+      TypeSize BaseSize = getSizeInBits();
+      return {(BaseSize.getKnownMinSize() + 7) / 8, BaseSize.isScalable()};
     }
 
     /// Return the number of bits overwritten by a store of the specified value
     /// type.
-    unsigned getStoreSizeInBits() const {
+    ///
+    /// If the value type is a scalable vector type, the scalable property will
+    /// be set and the runtime size will be a positive integer multiple of the
+    /// base size.
+    TypeSize getStoreSizeInBits() const {
       return getStoreSize() * 8;
     }
 
@@ -428,7 +444,7 @@ namespace llvm {
     bool isExtended2048BitVector() const LLVM_READONLY;
     EVT getExtendedVectorElementType() const;
     unsigned getExtendedVectorNumElements() const LLVM_READONLY;
-    unsigned getExtendedSizeInBits() const LLVM_READONLY;
+    TypeSize getExtendedSizeInBits() const LLVM_READONLY;
   };
 
 } // end namespace llvm

diff  --git a/llvm/include/llvm/Support/MachineValueType.h b/llvm/include/llvm/Support/MachineValueType.h
index 7f9f0b85c55e..26b45a602763 100644
--- a/llvm/include/llvm/Support/MachineValueType.h
+++ b/llvm/include/llvm/Support/MachineValueType.h
@@ -671,7 +671,12 @@ namespace llvm {
       return { getVectorNumElements(), isScalableVector() };
     }
 
-    unsigned getSizeInBits() const {
+    /// Returns the size of the specified MVT in bits.
+    ///
+    /// If the value type is a scalable vector type, the scalable property will
+    /// be set and the runtime size will be a positive integer multiple of the
+    /// base size.
+    TypeSize getSizeInBits() const {
       switch (SimpleTy) {
       default:
         llvm_unreachable("getSizeInBits called on extended MVT.");
@@ -691,25 +696,25 @@ namespace llvm {
       case Metadata:
         llvm_unreachable("Value type is metadata.");
       case i1:
-      case v1i1:
-      case nxv1i1: return 1;
-      case v2i1:
-      case nxv2i1: return 2;
-      case v4i1:
-      case nxv4i1: return 4;
+      case v1i1: return TypeSize::Fixed(1);
+      case nxv1i1: return TypeSize::Scalable(1);
+      case v2i1: return TypeSize::Fixed(2);
+      case nxv2i1: return TypeSize::Scalable(2);
+      case v4i1: return TypeSize::Fixed(4);
+      case nxv4i1: return TypeSize::Scalable(4);
       case i8  :
       case v1i8:
-      case v8i1:
+      case v8i1: return TypeSize::Fixed(8);
       case nxv1i8:
-      case nxv8i1: return 8;
+      case nxv8i1: return TypeSize::Scalable(8);
       case i16 :
       case f16:
       case v16i1:
       case v2i8:
-      case v1i16:
+      case v1i16: return TypeSize::Fixed(16);
       case nxv16i1:
       case nxv2i8:
-      case nxv1i16: return 16;
+      case nxv1i16: return TypeSize::Scalable(16);
       case f32 :
       case i32 :
       case v32i1:
@@ -717,15 +722,15 @@ namespace llvm {
       case v2i16:
       case v2f16:
       case v1f32:
-      case v1i32:
+      case v1i32: return TypeSize::Fixed(32);
       case nxv32i1:
       case nxv4i8:
       case nxv2i16:
       case nxv1i32:
       case nxv2f16:
-      case nxv1f32: return 32;
+      case nxv1f32: return TypeSize::Scalable(32);
       case v3i16:
-      case v3f16: return 48;
+      case v3f16: return TypeSize::Fixed(48);
       case x86mmx:
       case f64 :
       case i64 :
@@ -736,17 +741,17 @@ namespace llvm {
       case v1i64:
       case v4f16:
       case v2f32:
-      case v1f64:
+      case v1f64: return TypeSize::Fixed(64);
       case nxv8i8:
       case nxv4i16:
       case nxv2i32:
       case nxv1i64:
       case nxv4f16:
       case nxv2f32:
-      case nxv1f64: return 64;
-      case f80 :  return 80;
+      case nxv1f64: return TypeSize::Scalable(64);
+      case f80 :  return TypeSize::Fixed(80);
       case v3i32:
-      case v3f32: return 96;
+      case v3f32: return TypeSize::Fixed(96);
       case f128:
       case ppcf128:
       case i128:
@@ -758,16 +763,16 @@ namespace llvm {
       case v1i128:
       case v8f16:
       case v4f32:
-      case v2f64:
+      case v2f64: return TypeSize::Fixed(128);
       case nxv16i8:
       case nxv8i16:
       case nxv4i32:
       case nxv2i64:
       case nxv8f16:
       case nxv4f32:
-      case nxv2f64: return 128;
+      case nxv2f64: return TypeSize::Scalable(128);
       case v5i32:
-      case v5f32: return 160;
+      case v5f32: return TypeSize::Fixed(160);
       case v256i1:
       case v32i8:
       case v16i16:
@@ -775,13 +780,13 @@ namespace llvm {
       case v4i64:
       case v16f16:
       case v8f32:
-      case v4f64:
+      case v4f64: return TypeSize::Fixed(256);
       case nxv32i8:
       case nxv16i16:
       case nxv8i32:
       case nxv4i64:
       case nxv8f32:
-      case nxv4f64: return 256;
+      case nxv4f64: return TypeSize::Scalable(256);
       case v512i1:
       case v64i8:
       case v32i16:
@@ -789,56 +794,71 @@ namespace llvm {
       case v8i64:
       case v32f16:
       case v16f32:
-      case v8f64:
+      case v8f64: return TypeSize::Fixed(512);
       case nxv32i16:
       case nxv16i32:
       case nxv8i64:
       case nxv16f32:
-      case nxv8f64: return 512;
+      case nxv8f64: return TypeSize::Scalable(512);
       case v1024i1:
       case v128i8:
       case v64i16:
       case v32i32:
       case v16i64:
-      case v32f32:
+      case v32f32: return TypeSize::Fixed(1024);
       case nxv32i32:
-      case nxv16i64: return 1024;
+      case nxv16i64: return TypeSize::Scalable(1024);
       case v256i8:
       case v128i16:
       case v64i32:
       case v32i64:
-      case v64f32:
-      case nxv32i64: return 2048;
+      case v64f32: return TypeSize::Fixed(2048);
+      case nxv32i64: return TypeSize::Scalable(2048);
       case v128i32:
-      case v128f32:  return 4096;
+      case v128f32:  return TypeSize::Fixed(4096);
       case v256i32:
-      case v256f32:  return 8192;
+      case v256f32:  return TypeSize::Fixed(8192);
       case v512i32:
-      case v512f32:  return 16384;
+      case v512f32:  return TypeSize::Fixed(16384);
       case v1024i32:
-      case v1024f32:  return 32768;
+      case v1024f32:  return TypeSize::Fixed(32768);
       case v2048i32:
-      case v2048f32:  return 65536;
-      case exnref: return 0; // opaque type
+      case v2048f32:  return TypeSize::Fixed(65536);
+      case exnref: return TypeSize::Fixed(0); // opaque type
       }
     }
 
-    unsigned getScalarSizeInBits() const {
+    TypeSize getScalarSizeInBits() const {
       return getScalarType().getSizeInBits();
     }
 
     /// Return the number of bytes overwritten by a store of the specified value
     /// type.
-    unsigned getStoreSize() const {
-      return (getSizeInBits() + 7) / 8;
+    ///
+    /// If the value type is a scalable vector type, the scalable property will
+    /// be set and the runtime size will be a positive integer multiple of the
+    /// base size.
+    TypeSize getStoreSize() const {
+      TypeSize BaseSize = getSizeInBits();
+      return {(BaseSize.getKnownMinSize() + 7) / 8, BaseSize.isScalable()};
     }
 
     /// Return the number of bits overwritten by a store of the specified value
     /// type.
-    unsigned getStoreSizeInBits() const {
+    ///
+    /// If the value type is a scalable vector type, the scalable property will
+    /// be set and the runtime size will be a positive integer multiple of the
+    /// base size.
+    TypeSize getStoreSizeInBits() const {
       return getStoreSize() * 8;
     }
 
+    /// Returns true if the number of bits for the type is a multiple of an
+    /// 8-bit byte.
+    bool isByteSized() const {
+      return getSizeInBits().isByteSized();
+    }
+
     /// Return true if this has more bits than VT.
     bool bitsGT(MVT VT) const {
       return getSizeInBits() > VT.getSizeInBits();

diff  --git a/llvm/include/llvm/Support/TypeSize.h b/llvm/include/llvm/Support/TypeSize.h
index 711679cdcacb..7ea651f0f22c 100644
--- a/llvm/include/llvm/Support/TypeSize.h
+++ b/llvm/include/llvm/Support/TypeSize.h
@@ -138,6 +138,11 @@ class TypeSize {
     return IsScalable;
   }
 
+  // Returns true if the number of bits is a multiple of an 8-bit byte.
+  bool isByteSized() const {
+    return (MinSize & 7) == 0;
+  }
+
   // Casts to a uint64_t if this is a fixed-width size.
   //
   // NOTE: This interface is obsolete and will be removed in a future version

diff  --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index 6f68313c71cf..9780b6992fbb 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -220,11 +220,13 @@ namespace {
       ForCodeSize = DAG.getMachineFunction().getFunction().hasOptSize();
 
       MaximumLegalStoreInBits = 0;
+      // We use the minimum store size here, since that's all we can guarantee
+      // for the scalable vector types.
       for (MVT VT : MVT::all_valuetypes())
         if (EVT(VT).isSimple() && VT != MVT::Other &&
             TLI.isTypeLegal(EVT(VT)) &&
-            VT.getSizeInBits() >= MaximumLegalStoreInBits)
-          MaximumLegalStoreInBits = VT.getSizeInBits();
+            VT.getSizeInBits().getKnownMinSize() >= MaximumLegalStoreInBits)
+          MaximumLegalStoreInBits = VT.getSizeInBits().getKnownMinSize();
     }
 
     void ConsiderForPruning(SDNode *N) {
@@ -13969,8 +13971,8 @@ SDValue DAGCombiner::ForwardStoreValueToDirectLoad(LoadSDNode *LD) {
   // the stored value). With Offset=n (for n > 0) the loaded value starts at the
   // n:th least significant byte of the stored value.
   if (DAG.getDataLayout().isBigEndian())
-    Offset = (STMemType.getStoreSizeInBits() -
-              LDMemType.getStoreSizeInBits()) / 8 - Offset;
+    Offset = ((int64_t)STMemType.getStoreSizeInBits() -
+              (int64_t)LDMemType.getStoreSizeInBits()) / 8 - Offset;
 
   // Check that the stored value cover all bits that are loaded.
   bool STCoversLD =
@@ -15127,7 +15129,7 @@ bool DAGCombiner::MergeStoresOfConstantsOrVecElts(
   // The latest Node in the DAG.
   SDLoc DL(StoreNodes[0].MemNode);
 
-  int64_t ElementSizeBits = MemVT.getStoreSizeInBits();
+  TypeSize ElementSizeBits = MemVT.getStoreSizeInBits();
   unsigned SizeInBits = NumStores * ElementSizeBits;
   unsigned NumMemElts = MemVT.isVector() ? MemVT.getVectorNumElements() : 1;
 
@@ -15512,7 +15514,7 @@ bool DAGCombiner::MergeConsecutiveStores(StoreSDNode *St) {
       Attribute::NoImplicitFloat);
 
   // This function cannot currently deal with non-byte-sized memory sizes.
-  if (ElementSizeBytes * 8 != MemVT.getSizeInBits())
+  if (ElementSizeBytes * 8 != (int64_t)MemVT.getSizeInBits())
     return false;
 
   if (!MemVT.isSimple())

diff  --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
index 6c47c5b60ad2..70c0951bfd86 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
@@ -23,6 +23,7 @@
 #include "llvm/IR/DataLayout.h"
 #include "llvm/Support/ErrorHandling.h"
 #include "llvm/Support/raw_ostream.h"
+#include "llvm/Support/TypeSize.h"
 using namespace llvm;
 
 #define DEBUG_TYPE "legalize-types"
@@ -4680,7 +4681,8 @@ static EVT FindMemType(SelectionDAG& DAG, const TargetLowering &TLI,
                        unsigned Width, EVT WidenVT,
                        unsigned Align = 0, unsigned WidenEx = 0) {
   EVT WidenEltVT = WidenVT.getVectorElementType();
-  unsigned WidenWidth = WidenVT.getSizeInBits();
+  const bool Scalable = WidenVT.isScalableVector();
+  unsigned WidenWidth = WidenVT.getSizeInBits().getKnownMinSize();
   unsigned WidenEltWidth = WidenEltVT.getSizeInBits();
   unsigned AlignInBits = Align*8;
 
@@ -4691,23 +4693,27 @@ static EVT FindMemType(SelectionDAG& DAG, const TargetLowering &TLI,
 
   // See if there is larger legal integer than the element type to load/store.
   unsigned VT;
-  for (VT = (unsigned)MVT::LAST_INTEGER_VALUETYPE;
-       VT >= (unsigned)MVT::FIRST_INTEGER_VALUETYPE; --VT) {
-    EVT MemVT((MVT::SimpleValueType) VT);
-    unsigned MemVTWidth = MemVT.getSizeInBits();
-    if (MemVT.getSizeInBits() <= WidenEltWidth)
-      break;
-    auto Action = TLI.getTypeAction(*DAG.getContext(), MemVT);
-    if ((Action == TargetLowering::TypeLegal ||
-         Action == TargetLowering::TypePromoteInteger) &&
-        (WidenWidth % MemVTWidth) == 0 &&
-        isPowerOf2_32(WidenWidth / MemVTWidth) &&
-        (MemVTWidth <= Width ||
-         (Align!=0 && MemVTWidth<=AlignInBits && MemVTWidth<=Width+WidenEx))) {
-      if (MemVTWidth == WidenWidth)
-        return MemVT;
-      RetVT = MemVT;
-      break;
+  // Don't bother looking for an integer type if the vector is scalable, skip
+  // to vector types.
+  if (!Scalable) {
+    for (VT = (unsigned)MVT::LAST_INTEGER_VALUETYPE;
+         VT >= (unsigned)MVT::FIRST_INTEGER_VALUETYPE; --VT) {
+      EVT MemVT((MVT::SimpleValueType) VT);
+      unsigned MemVTWidth = MemVT.getSizeInBits();
+      if (MemVT.getSizeInBits() <= WidenEltWidth)
+        break;
+      auto Action = TLI.getTypeAction(*DAG.getContext(), MemVT);
+      if ((Action == TargetLowering::TypeLegal ||
+           Action == TargetLowering::TypePromoteInteger) &&
+          (WidenWidth % MemVTWidth) == 0 &&
+          isPowerOf2_32(WidenWidth / MemVTWidth) &&
+          (MemVTWidth <= Width ||
+           (Align!=0 && MemVTWidth<=AlignInBits && MemVTWidth<=Width+WidenEx))) {
+        if (MemVTWidth == WidenWidth)
+          return MemVT;
+        RetVT = MemVT;
+        break;
+      }
     }
   }
 
@@ -4716,7 +4722,10 @@ static EVT FindMemType(SelectionDAG& DAG, const TargetLowering &TLI,
   for (VT = (unsigned)MVT::LAST_VECTOR_VALUETYPE;
        VT >= (unsigned)MVT::FIRST_VECTOR_VALUETYPE; --VT) {
     EVT MemVT = (MVT::SimpleValueType) VT;
-    unsigned MemVTWidth = MemVT.getSizeInBits();
+    // Skip vector MVTs which don't match the scalable property of WidenVT.
+    if (Scalable != MemVT.isScalableVector())
+      continue;
+    unsigned MemVTWidth = MemVT.getSizeInBits().getKnownMinSize();
     auto Action = TLI.getTypeAction(*DAG.getContext(), MemVT);
     if ((Action == TargetLowering::TypeLegal ||
          Action == TargetLowering::TypePromoteInteger) &&

diff  --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
index 1b02f96cf279..abd046530ed9 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -8842,7 +8842,9 @@ MemSDNode::MemSDNode(unsigned Opc, unsigned Order, const DebugLoc &dl,
   // We check here that the size of the memory operand fits within the size of
   // the MMO. This is because the MMO might indicate only a possible address
   // range instead of specifying the affected memory addresses precisely.
-  assert(memvt.getStoreSize() <= MMO->getSize() && "Size mismatch!");
+  // TODO: Make MachineMemOperands aware of scalable vectors.
+  assert(memvt.getStoreSize().getKnownMinSize() <= MMO->getSize() &&
+         "Size mismatch!");
 }
 
 /// Profile - Gather unique data for the node.

diff  --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
index 4bb5d1e96c96..54a31424b202 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
@@ -4304,7 +4304,10 @@ void SelectionDAGBuilder::visitMaskedStore(const CallInst &I,
   MachineMemOperand *MMO =
     DAG.getMachineFunction().
     getMachineMemOperand(MachinePointerInfo(PtrOperand),
-                          MachineMemOperand::MOStore,  VT.getStoreSize(),
+                          MachineMemOperand::MOStore,
+                          // TODO: Make MachineMemOperands aware of scalable
+                          // vectors.
+                          VT.getStoreSize().getKnownMinSize(),
                           Alignment, AAInfo);
   SDValue StoreNode = DAG.getMaskedStore(getRoot(), sdl, Src0, Ptr, Mask, VT,
                                          MMO, false /* Truncating */,
@@ -4408,7 +4411,10 @@ void SelectionDAGBuilder::visitMaskedScatter(const CallInst &I) {
   const Value *MemOpBasePtr = UniformBase ? BasePtr : nullptr;
   MachineMemOperand *MMO = DAG.getMachineFunction().
     getMachineMemOperand(MachinePointerInfo(MemOpBasePtr),
-                         MachineMemOperand::MOStore,  VT.getStoreSize(),
+                         MachineMemOperand::MOStore,
+                         // TODO: Make MachineMemOperands aware of scalable
+                         // vectors.
+                         VT.getStoreSize().getKnownMinSize(),
                          Alignment, AAInfo);
   if (!UniformBase) {
     Base = DAG.getConstant(0, sdl, TLI.getPointerTy(DAG.getDataLayout()));
@@ -4477,7 +4483,10 @@ void SelectionDAGBuilder::visitMaskedLoad(const CallInst &I, bool IsExpanding) {
   MachineMemOperand *MMO =
     DAG.getMachineFunction().
     getMachineMemOperand(MachinePointerInfo(PtrOperand),
-                          MachineMemOperand::MOLoad,  VT.getStoreSize(),
+                          MachineMemOperand::MOLoad,
+                          // TODO: Make MachineMemOperands aware of scalable
+                          // vectors.
+                          VT.getStoreSize().getKnownMinSize(),
                           Alignment, AAInfo, Ranges);
 
   SDValue Load = DAG.getMaskedLoad(VT, sdl, InChain, Ptr, Mask, Src0, VT, MMO,
@@ -4528,7 +4537,10 @@ void SelectionDAGBuilder::visitMaskedGather(const CallInst &I) {
   MachineMemOperand *MMO =
     DAG.getMachineFunction().
     getMachineMemOperand(MachinePointerInfo(UniformBase ? BasePtr : nullptr),
-                         MachineMemOperand::MOLoad,  VT.getStoreSize(),
+                         MachineMemOperand::MOLoad,
+                         // TODO: Make MachineMemOperands aware of scalable
+                         // vectors.
+                         VT.getStoreSize().getKnownMinSize(),
                          Alignment, AAInfo, Ranges);
 
   if (!UniformBase) {
@@ -9248,9 +9260,11 @@ TargetLowering::LowerCallTo(TargetLowering::CallLoweringInfo &CLI) const {
 
       for (unsigned j = 0; j != NumParts; ++j) {
         // if it isn't first piece, alignment must be 1
+        // For scalable vectors the scalable part is currently handled
+        // by individual targets, so we just use the known minimum size here.
         ISD::OutputArg MyFlags(Flags, Parts[j].getValueType(), VT,
-                               i < CLI.NumFixedArgs,
-                               i, j*Parts[j].getValueType().getStoreSize());
+                    i < CLI.NumFixedArgs, i,
+                    j*Parts[j].getValueType().getStoreSize().getKnownMinSize());
         if (NumParts > 1 && j == 0)
           MyFlags.Flags.setSplit();
         else if (j != 0) {
@@ -9719,8 +9733,11 @@ void SelectionDAGISel::LowerArguments(const Function &F) {
       unsigned NumRegs = TLI->getNumRegistersForCallingConv(
           *CurDAG->getContext(), F.getCallingConv(), VT);
       for (unsigned i = 0; i != NumRegs; ++i) {
+        // For scalable vectors, use the minimum size; individual targets
+        // are responsible for handling scalable vector arguments and
+        // return values.
         ISD::InputArg MyFlags(Flags, RegisterVT, VT, isArgValueUsed,
-                              ArgNo, PartBase+i*RegisterVT.getStoreSize());
+                 ArgNo, PartBase+i*RegisterVT.getStoreSize().getKnownMinSize());
         if (NumRegs > 1 && i == 0)
           MyFlags.Flags.setSplit();
         // if it isn't first piece, alignment must be 1
@@ -9733,7 +9750,7 @@ void SelectionDAGISel::LowerArguments(const Function &F) {
       }
       if (NeedsRegBlock && Value == NumValues - 1)
         Ins[Ins.size() - 1].Flags.setInConsecutiveRegsLast();
-      PartBase += VT.getStoreSize();
+      PartBase += VT.getStoreSize().getKnownMinSize();
     }
   }
 

diff  --git a/llvm/lib/CodeGen/SelectionDAG/StatepointLowering.cpp b/llvm/lib/CodeGen/SelectionDAG/StatepointLowering.cpp
index fad98b6f50dc..c628f379e415 100644
--- a/llvm/lib/CodeGen/SelectionDAG/StatepointLowering.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/StatepointLowering.cpp
@@ -384,7 +384,8 @@ spillIncomingStatepointValue(SDValue Incoming, SDValue Chain,
     // can consider allowing spills of smaller values to larger slots
     // (i.e. change the '==' in the assert below to a '>=').
     MachineFrameInfo &MFI = Builder.DAG.getMachineFunction().getFrameInfo();
-    assert((MFI.getObjectSize(Index) * 8) == Incoming.getValueSizeInBits() &&
+    assert((MFI.getObjectSize(Index) * 8) ==
+           (int64_t)Incoming.getValueSizeInBits() &&
            "Bad spill:  stack slot does not match!");
 
     // Note: Using the alignment of the spill slot (rather than the abi or

diff  --git a/llvm/lib/CodeGen/ValueTypes.cpp b/llvm/lib/CodeGen/ValueTypes.cpp
index 73b862d51c0f..b868abf69582 100644
--- a/llvm/lib/CodeGen/ValueTypes.cpp
+++ b/llvm/lib/CodeGen/ValueTypes.cpp
@@ -11,6 +11,7 @@
 #include "llvm/IR/DerivedTypes.h"
 #include "llvm/IR/Type.h"
 #include "llvm/Support/ErrorHandling.h"
+#include "llvm/Support/TypeSize.h"
 using namespace llvm;
 
 EVT EVT::changeExtendedTypeToInteger() const {
@@ -101,12 +102,12 @@ unsigned EVT::getExtendedVectorNumElements() const {
   return cast<VectorType>(LLVMTy)->getNumElements();
 }
 
-unsigned EVT::getExtendedSizeInBits() const {
+TypeSize EVT::getExtendedSizeInBits() const {
   assert(isExtended() && "Type is not extended!");
   if (IntegerType *ITy = dyn_cast<IntegerType>(LLVMTy))
-    return ITy->getBitWidth();
+    return TypeSize::Fixed(ITy->getBitWidth());
   if (VectorType *VTy = dyn_cast<VectorType>(LLVMTy))
-    return VTy->getBitWidth();
+    return VTy->getPrimitiveSizeInBits();
   llvm_unreachable("Unrecognized extended type!");
 }
 

diff  --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 2a3b3a3ac2f8..a9471a7acaf7 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -9937,7 +9937,7 @@ static SDValue performBitcastCombine(SDNode *N,
 
   // Only interested in 64-bit vectors as the ultimate result.
   EVT VT = N->getValueType(0);
-  if (!VT.isVector())
+  if (!VT.isVector() || VT.isScalableVector())
     return SDValue();
   if (VT.getSimpleVT().getSizeInBits() != 64)
     return SDValue();

diff  --git a/llvm/lib/Target/AArch64/AArch64StackOffset.h b/llvm/lib/Target/AArch64/AArch64StackOffset.h
index 13f12a6c9c30..f95b5dc5246e 100644
--- a/llvm/lib/Target/AArch64/AArch64StackOffset.h
+++ b/llvm/lib/Target/AArch64/AArch64StackOffset.h
@@ -15,6 +15,7 @@
 #define LLVM_LIB_TARGET_AARCH64_AARCH64STACKOFFSET_H
 
 #include "llvm/Support/MachineValueType.h"
+#include "llvm/Support/TypeSize.h"
 
 namespace llvm {
 
@@ -45,8 +46,7 @@ class StackOffset {
   StackOffset() : Bytes(0), ScalableBytes(0) {}
 
   StackOffset(int64_t Offset, MVT::SimpleValueType T) : StackOffset() {
-    assert(MVT(T).getSizeInBits() % 8 == 0 &&
-           "Offset type is not a multiple of bytes");
+    assert(MVT(T).isByteSized() && "Offset type is not a multiple of bytes");
     *this += Part(Offset, T);
   }
 
@@ -56,11 +56,11 @@ class StackOffset {
   StackOffset &operator=(const StackOffset &) = default;
 
   StackOffset &operator+=(const StackOffset::Part &Other) {
-    int64_t OffsetInBytes = Other.first * (Other.second.getSizeInBits() / 8);
-    if (Other.second.isScalableVector())
-      ScalableBytes += OffsetInBytes;
+    const TypeSize Size = Other.second.getSizeInBits();
+    if (Size.isScalable())
+      ScalableBytes += Other.first * ((int64_t)Size.getKnownMinSize() / 8);
     else
-      Bytes += OffsetInBytes;
+      Bytes += Other.first * ((int64_t)Size.getFixedSize() / 8);
     return *this;
   }
 

diff  --git a/llvm/lib/Target/ARM/ARMISelLowering.cpp b/llvm/lib/Target/ARM/ARMISelLowering.cpp
index b1d1d4fd5fc9..c9314007c0a7 100644
--- a/llvm/lib/Target/ARM/ARMISelLowering.cpp
+++ b/llvm/lib/Target/ARM/ARMISelLowering.cpp
@@ -14886,7 +14886,7 @@ static bool isLegalT2AddressImmediate(int64_t V, EVT VT,
     V = -V;
   }
 
-  unsigned NumBytes = std::max(VT.getSizeInBits() / 8, 1U);
+  unsigned NumBytes = std::max((unsigned)VT.getSizeInBits() / 8, 1U);
 
   // MVE: size * imm7
   if (VT.isVector() && Subtarget->hasMVEIntegerOps()) {

diff  --git a/llvm/lib/Target/Hexagon/HexagonISelLowering.cpp b/llvm/lib/Target/Hexagon/HexagonISelLowering.cpp
index 09f5fd82cade..7345100f178f 100644
--- a/llvm/lib/Target/Hexagon/HexagonISelLowering.cpp
+++ b/llvm/lib/Target/Hexagon/HexagonISelLowering.cpp
@@ -475,7 +475,7 @@ HexagonTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
       MemAddr = DAG.getNode(ISD::ADD, dl, MVT::i32, StackPtr, MemAddr);
       if (ArgAlign)
         LargestAlignSeen = std::max(LargestAlignSeen,
-                                    VA.getLocVT().getStoreSizeInBits() >> 3);
+                             (unsigned)VA.getLocVT().getStoreSizeInBits() >> 3);
       if (Flags.isByVal()) {
         // The argument is a struct passed by value. According to LLVM, "Arg"
         // is a pointer.

diff  --git a/llvm/lib/Target/Mips/MipsISelLowering.cpp b/llvm/lib/Target/Mips/MipsISelLowering.cpp
index 56db378ae6d4..f34100c66469 100644
--- a/llvm/lib/Target/Mips/MipsISelLowering.cpp
+++ b/llvm/lib/Target/Mips/MipsISelLowering.cpp
@@ -122,7 +122,8 @@ unsigned MipsTargetLowering::getNumRegistersForCallingConv(LLVMContext &Context,
                                                            CallingConv::ID CC,
                                                            EVT VT) const {
   if (VT.isVector())
-    return std::max((VT.getSizeInBits() / (Subtarget.isABI_O32() ? 32 : 64)),
+    return std::max(((unsigned)VT.getSizeInBits() /
+                     (Subtarget.isABI_O32() ? 32 : 64)),
                     1U);
   return MipsTargetLowering::getNumRegisters(Context, VT);
 }

diff  --git a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
index 3d2447d75c77..eb9b99610659 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
@@ -885,7 +885,7 @@ bool NVPTXDAGToDAGISel::tryLoad(SDNode *N) {
   MVT SimpleVT = LoadedVT.getSimpleVT();
   MVT ScalarVT = SimpleVT.getScalarType();
   // Read at least 8 bits (predicates are stored as 8-bit values)
-  unsigned fromTypeWidth = std::max(8U, ScalarVT.getSizeInBits());
+  unsigned fromTypeWidth = std::max(8U, (unsigned)ScalarVT.getSizeInBits());
   unsigned int fromType;
 
   // Vector Setting
@@ -1030,7 +1030,7 @@ bool NVPTXDAGToDAGISel::tryLoadVector(SDNode *N) {
   // Float  : ISD::NON_EXTLOAD or ISD::EXTLOAD and the type is float
   MVT ScalarVT = SimpleVT.getScalarType();
   // Read at least 8 bits (predicates are stored as 8-bit values)
-  unsigned FromTypeWidth = std::max(8U, ScalarVT.getSizeInBits());
+  unsigned FromTypeWidth = std::max(8U, (unsigned)ScalarVT.getSizeInBits());
   unsigned int FromType;
   // The last operand holds the original LoadSDNode::getExtensionType() value
   unsigned ExtensionType = cast<ConstantSDNode>(

diff  --git a/llvm/lib/Target/PowerPC/PPCISelDAGToDAG.cpp b/llvm/lib/Target/PowerPC/PPCISelDAGToDAG.cpp
index 9e8906fa2f76..1d1d6b8baf7d 100644
--- a/llvm/lib/Target/PowerPC/PPCISelDAGToDAG.cpp
+++ b/llvm/lib/Target/PowerPC/PPCISelDAGToDAG.cpp
@@ -1044,7 +1044,7 @@ static unsigned allUsesTruncate(SelectionDAG *CurDAG, SDNode *N) {
       if (Use->isMachineOpcode())
         return 0;
       MaxTruncation =
-        std::max(MaxTruncation, Use->getValueType(0).getSizeInBits());
+        std::max(MaxTruncation, (unsigned)Use->getValueType(0).getSizeInBits());
       continue;
     case ISD::STORE: {
       if (Use->isMachineOpcode())

diff  --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index fff0d7d8b3f7..41d2899c7020 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -5835,7 +5835,7 @@ static SDValue getExtendInVec(unsigned Opcode, const SDLoc &DL, EVT VT,
            "Expected VTs to be the same size!");
     unsigned Scale = VT.getScalarSizeInBits() / InVT.getScalarSizeInBits();
     In = extractSubVector(In, 0, DAG, DL,
-                          std::max(128U, VT.getSizeInBits() / Scale));
+                          std::max(128U, (unsigned)VT.getSizeInBits() / Scale));
     InVT = In.getValueType();
   }
 
@@ -8626,7 +8626,7 @@ static SDValue LowerBUILD_VECTORvXi1(SDValue Op, SelectionDAG &DAG,
       ImmH = DAG.getBitcast(MVT::v32i1, ImmH);
       DstVec = DAG.getNode(ISD::CONCAT_VECTORS, dl, MVT::v64i1, ImmL, ImmH);
     } else {
-      MVT ImmVT = MVT::getIntegerVT(std::max(VT.getSizeInBits(), 8U));
+      MVT ImmVT = MVT::getIntegerVT(std::max((unsigned)VT.getSizeInBits(), 8U));
       SDValue Imm = DAG.getConstant(Immediate, dl, ImmVT);
       MVT VecVT = VT.getSizeInBits() >= 8 ? VT : MVT::v8i1;
       DstVec = DAG.getBitcast(VecVT, Imm);
@@ -32849,7 +32849,8 @@ static SDValue combineX86ShuffleChainWithExtract(
       Offset += Src.getConstantOperandVal(1);
       Src = Src.getOperand(0);
     }
-    WideSizeInBits = std::max(WideSizeInBits, Src.getValueSizeInBits());
+    WideSizeInBits = std::max(WideSizeInBits,
+                              (unsigned)Src.getValueSizeInBits());
     assert((Offset % BaseVT.getVectorNumElements()) == 0 &&
            "Unexpected subvector extraction");
     Offset /= BaseVT.getVectorNumElements();
@@ -35786,7 +35787,7 @@ static SDValue createPSADBW(SelectionDAG &DAG, const SDValue &Zext0,
                             const X86Subtarget &Subtarget) {
   // Find the appropriate width for the PSADBW.
   EVT InVT = Zext0.getOperand(0).getValueType();
-  unsigned RegSize = std::max(128u, InVT.getSizeInBits());
+  unsigned RegSize = std::max(128u, (unsigned)InVT.getSizeInBits());
 
   // "Zero-extend" the i8 vectors. This is not a per-element zext, rather we
   // fill in the missing vector elements with 0.

diff  --git a/llvm/unittests/CodeGen/ScalableVectorMVTsTest.cpp b/llvm/unittests/CodeGen/ScalableVectorMVTsTest.cpp
index 14a619653744..9a4e049f635a 100644
--- a/llvm/unittests/CodeGen/ScalableVectorMVTsTest.cpp
+++ b/llvm/unittests/CodeGen/ScalableVectorMVTsTest.cpp
@@ -120,4 +120,61 @@ TEST(ScalableVectorMVTsTest, VTToIRTranslation) {
             ScV4Float64Ty->getElementType());
 }
 
+TEST(ScalableVectorMVTsTest, SizeQueries) {
+  LLVMContext Ctx;
+
+  EVT nxv4i32 = EVT::getVectorVT(Ctx, MVT::i32, 4, /*Scalable=*/ true);
+  EVT nxv2i32 = EVT::getVectorVT(Ctx, MVT::i32, 2, /*Scalable=*/ true);
+  EVT nxv2i64 = EVT::getVectorVT(Ctx, MVT::i64, 2, /*Scalable=*/ true);
+  EVT nxv2f64 = EVT::getVectorVT(Ctx, MVT::f64, 2, /*Scalable=*/ true);
+
+  EVT v4i32 = EVT::getVectorVT(Ctx, MVT::i32, 4);
+  EVT v2i32 = EVT::getVectorVT(Ctx, MVT::i32, 2);
+  EVT v2i64 = EVT::getVectorVT(Ctx, MVT::i64, 2);
+  EVT v2f64 = EVT::getVectorVT(Ctx, MVT::f64, 2);
+
+  // Check equivalence and ordering on scalable types.
+  EXPECT_EQ(nxv4i32.getSizeInBits(), nxv2i64.getSizeInBits());
+  EXPECT_EQ(nxv2f64.getSizeInBits(), nxv2i64.getSizeInBits());
+  EXPECT_NE(nxv2i32.getSizeInBits(), nxv4i32.getSizeInBits());
+  EXPECT_LT(nxv2i32.getSizeInBits(), nxv2i64.getSizeInBits());
+  EXPECT_LE(nxv4i32.getSizeInBits(), nxv2i64.getSizeInBits());
+  EXPECT_GT(nxv4i32.getSizeInBits(), nxv2i32.getSizeInBits());
+  EXPECT_GE(nxv2i64.getSizeInBits(), nxv4i32.getSizeInBits());
+
+  // Check equivalence and ordering on fixed types.
+  EXPECT_EQ(v4i32.getSizeInBits(), v2i64.getSizeInBits());
+  EXPECT_EQ(v2f64.getSizeInBits(), v2i64.getSizeInBits());
+  EXPECT_NE(v2i32.getSizeInBits(), v4i32.getSizeInBits());
+  EXPECT_LT(v2i32.getSizeInBits(), v2i64.getSizeInBits());
+  EXPECT_LE(v4i32.getSizeInBits(), v2i64.getSizeInBits());
+  EXPECT_GT(v4i32.getSizeInBits(), v2i32.getSizeInBits());
+  EXPECT_GE(v2i64.getSizeInBits(), v4i32.getSizeInBits());
+
+  // Check that scalable and non-scalable types with the same minimum size
+  // are not considered equal.
+  ASSERT_TRUE(v4i32.getSizeInBits() != nxv4i32.getSizeInBits());
+  ASSERT_FALSE(v2i64.getSizeInBits() == nxv2f64.getSizeInBits());
+
+  // Check that we can obtain a known-exact size from a non-scalable type.
+  EXPECT_EQ(v4i32.getSizeInBits(), 128U);
+  EXPECT_EQ(v2i64.getSizeInBits().getFixedSize(), 128U);
+
+  // Check that we can query the known minimum size for both scalable and
+  // fixed length types.
+  EXPECT_EQ(nxv2i32.getSizeInBits().getKnownMinSize(), 64U);
+  EXPECT_EQ(nxv2f64.getSizeInBits().getKnownMinSize(), 128U);
+  EXPECT_EQ(v2i32.getSizeInBits().getKnownMinSize(),
+            nxv2i32.getSizeInBits().getKnownMinSize());
+
+  // Check scalable property.
+  ASSERT_FALSE(v4i32.getSizeInBits().isScalable());
+  ASSERT_TRUE(nxv4i32.getSizeInBits().isScalable());
+
+  // Check convenience size scaling methods.
+  EXPECT_EQ(v2i32.getSizeInBits() * 2, v4i32.getSizeInBits());
+  EXPECT_EQ(2 * nxv2i32.getSizeInBits(), nxv4i32.getSizeInBits());
+  EXPECT_EQ(nxv2f64.getSizeInBits() / 2, nxv2i32.getSizeInBits());
+}
+
 } // end anonymous namespace

diff  --git a/llvm/utils/TableGen/CodeGenDAGPatterns.cpp b/llvm/utils/TableGen/CodeGenDAGPatterns.cpp
index 003662cb94ea..0424c43b9822 100644
--- a/llvm/utils/TableGen/CodeGenDAGPatterns.cpp
+++ b/llvm/utils/TableGen/CodeGenDAGPatterns.cpp
@@ -23,6 +23,7 @@
 #include "llvm/ADT/Twine.h"
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/ErrorHandling.h"
+#include "llvm/Support/TypeSize.h"
 #include "llvm/TableGen/Error.h"
 #include "llvm/TableGen/Record.h"
 #include <algorithm>
@@ -503,18 +504,24 @@ bool TypeInfer::EnforceSmallerThan(TypeSetByHwMode &Small,
   }
 
   auto LT = [](MVT A, MVT B) -> bool {
-    return A.getScalarSizeInBits() < B.getScalarSizeInBits() ||
-           (A.getScalarSizeInBits() == B.getScalarSizeInBits() &&
-            A.getSizeInBits() < B.getSizeInBits());
+    // Always treat non-scalable MVTs as smaller than scalable MVTs for the
+    // purposes of ordering.
+    auto ASize = std::make_tuple(A.isScalableVector(), A.getScalarSizeInBits(),
+                                 A.getSizeInBits());
+    auto BSize = std::make_tuple(B.isScalableVector(), B.getScalarSizeInBits(),
+                                 B.getSizeInBits());
+    return ASize < BSize;
   };
-  auto LE = [&LT](MVT A, MVT B) -> bool {
+  auto SameKindLE = [](MVT A, MVT B) -> bool {
     // This function is used when removing elements: when a vector is compared
-    // to a non-vector, it should return false (to avoid removal).
-    if (A.isVector() != B.isVector())
+    // to a non-vector or a scalable vector to any non-scalable MVT, it should
+    // return false (to avoid removal).
+    if (std::make_tuple(A.isVector(), A.isScalableVector()) !=
+        std::make_tuple(B.isVector(), B.isScalableVector()))
       return false;
 
-    return LT(A, B) || (A.getScalarSizeInBits() == B.getScalarSizeInBits() &&
-                        A.getSizeInBits() == B.getSizeInBits());
+    return std::make_tuple(A.getScalarSizeInBits(), A.getSizeInBits()) <=
+           std::make_tuple(B.getScalarSizeInBits(), B.getSizeInBits());
   };
 
   for (unsigned M : Modes) {
@@ -524,25 +531,29 @@ bool TypeInfer::EnforceSmallerThan(TypeSetByHwMode &Small,
     // smaller-or-equal than MinS.
     auto MinS = min_if(S.begin(), S.end(), isScalar, LT);
     if (MinS != S.end())
-      Changed |= berase_if(B, std::bind(LE, std::placeholders::_1, *MinS));
+      Changed |= berase_if(B, std::bind(SameKindLE,
+                                        std::placeholders::_1, *MinS));
 
     // MaxS = max scalar in Big, remove all scalars from Small that are
     // larger than MaxS.
     auto MaxS = max_if(B.begin(), B.end(), isScalar, LT);
     if (MaxS != B.end())
-      Changed |= berase_if(S, std::bind(LE, *MaxS, std::placeholders::_1));
+      Changed |= berase_if(S, std::bind(SameKindLE,
+                                        *MaxS, std::placeholders::_1));
 
     // MinV = min vector in Small, remove all vectors from Big that are
     // smaller-or-equal than MinV.
     auto MinV = min_if(S.begin(), S.end(), isVector, LT);
     if (MinV != S.end())
-      Changed |= berase_if(B, std::bind(LE, std::placeholders::_1, *MinV));
+      Changed |= berase_if(B, std::bind(SameKindLE,
+                                        std::placeholders::_1, *MinV));
 
     // MaxV = max vector in Big, remove all vectors from Small that are
     // larger than MaxV.
     auto MaxV = max_if(B.begin(), B.end(), isVector, LT);
     if (MaxV != B.end())
-      Changed |= berase_if(S, std::bind(LE, *MaxV, std::placeholders::_1));
+      Changed |= berase_if(S, std::bind(SameKindLE,
+                                        *MaxV, std::placeholders::_1));
   }
 
   return Changed;


        


More information about the llvm-commits mailing list