[clang] 1065869 - [Matrix] Add matrix type to Clang.

Florian Hahn via cfe-commits cfe-commits at lists.llvm.org
Mon May 11 11:07:06 PDT 2020


Author: Florian Hahn
Date: 2020-05-11T18:55:45+01:00
New Revision: 10658691951f7e3ffd257f24e29e81a101daa204

URL: https://github.com/llvm/llvm-project/commit/10658691951f7e3ffd257f24e29e81a101daa204
DIFF: https://github.com/llvm/llvm-project/commit/10658691951f7e3ffd257f24e29e81a101daa204.diff

LOG: [Matrix] Add matrix type to Clang.

This patch adds a matrix type to Clang as described in the draft
specification in clang/docs/MatrixSupport.rst. It introduces a new option
-fenable-matrix, which can be used to enable the matrix support.

The patch adds new MatrixType and DependentSizedMatrixType types along
with the plumbing required. Loads of and stores to pointers to matrix
values are lowered to memory operations on 1-D IR arrays. After loading,
the loaded values are cast to a vector. This ensures matrix values use
the alignment of the element type, instead of LLVM's large vector
alignment.

The operators and builtins described in the draft spec will be added in
follow-up patches.

Reviewers: martong, rsmith, Bigcheese, anemet, dexonsmith, rjmccall, aaron.ballman

Reviewed By: rjmccall

Differential Revision: https://reviews.llvm.org/D72281

Added: 
    clang/test/CodeGen/debug-info-matrix-types.c
    clang/test/CodeGen/matrix-type.c
    clang/test/CodeGenCXX/matrix-type.cpp
    clang/test/Parser/matrix-type-disabled.c
    clang/test/SemaCXX/matrix-type.cpp

Modified: 
    clang/include/clang/AST/ASTContext.h
    clang/include/clang/AST/RecursiveASTVisitor.h
    clang/include/clang/AST/Type.h
    clang/include/clang/AST/TypeLoc.h
    clang/include/clang/AST/TypeProperties.td
    clang/include/clang/Basic/Attr.td
    clang/include/clang/Basic/DiagnosticSemaKinds.td
    clang/include/clang/Basic/Features.def
    clang/include/clang/Basic/LangOptions.def
    clang/include/clang/Basic/TypeNodes.td
    clang/include/clang/Driver/Options.td
    clang/include/clang/Sema/Sema.h
    clang/include/clang/Serialization/TypeBitCodes.def
    clang/lib/AST/ASTContext.cpp
    clang/lib/AST/ASTStructuralEquivalence.cpp
    clang/lib/AST/ExprConstant.cpp
    clang/lib/AST/ItaniumMangle.cpp
    clang/lib/AST/MicrosoftMangle.cpp
    clang/lib/AST/Type.cpp
    clang/lib/AST/TypePrinter.cpp
    clang/lib/CodeGen/CGDebugInfo.cpp
    clang/lib/CodeGen/CGDebugInfo.h
    clang/lib/CodeGen/CGExpr.cpp
    clang/lib/CodeGen/CodeGenFunction.cpp
    clang/lib/CodeGen/CodeGenTypes.cpp
    clang/lib/CodeGen/ItaniumCXXABI.cpp
    clang/lib/Driver/ToolChains/Clang.cpp
    clang/lib/Frontend/CompilerInvocation.cpp
    clang/lib/Sema/SemaExpr.cpp
    clang/lib/Sema/SemaLookup.cpp
    clang/lib/Sema/SemaTemplate.cpp
    clang/lib/Sema/SemaTemplateDeduction.cpp
    clang/lib/Sema/SemaType.cpp
    clang/lib/Sema/TreeTransform.h
    clang/lib/Serialization/ASTReader.cpp
    clang/lib/Serialization/ASTWriter.cpp
    clang/tools/libclang/CIndex.cpp

Removed: 
    


################################################################################
diff  --git a/clang/include/clang/AST/ASTContext.h b/clang/include/clang/AST/ASTContext.h
index 8eb5aa0230d9..509ada3c9696 100644
--- a/clang/include/clang/AST/ASTContext.h
+++ b/clang/include/clang/AST/ASTContext.h
@@ -194,6 +194,8 @@ class ASTContext : public RefCountedBase<ASTContext> {
      DependentAddressSpaceTypes;
   mutable llvm::FoldingSet<VectorType> VectorTypes;
   mutable llvm::FoldingSet<DependentVectorType> DependentVectorTypes;
+  mutable llvm::FoldingSet<ConstantMatrixType> MatrixTypes;
+  mutable llvm::FoldingSet<DependentSizedMatrixType> DependentSizedMatrixTypes;
   mutable llvm::FoldingSet<FunctionNoProtoType> FunctionNoProtoTypes;
   mutable llvm::ContextualFoldingSet<FunctionProtoType, ASTContext&>
     FunctionProtoTypes;
@@ -1326,6 +1328,20 @@ class ASTContext : public RefCountedBase<ASTContext> {
                                           Expr *SizeExpr,
                                           SourceLocation AttrLoc) const;
 
+  /// Return the unique reference to the matrix type of the specified element
+  /// type and constant dimensions.
+  ///
+  /// \pre \p ElementType must be a valid matrix element type (see
+  /// MatrixType::isValidElementType).
+  QualType getConstantMatrixType(QualType ElementType, unsigned NumRows,
+                                 unsigned NumColumns) const;
+
+  /// Return the unique reference to the dependently-sized matrix type of the
+  /// specified element type and row/column size expressions.
+  QualType getDependentSizedMatrixType(QualType ElementType, Expr *RowExpr,
+                                       Expr *ColumnExpr,
+                                       SourceLocation AttrLoc) const;
+
   QualType getDependentAddressSpaceType(QualType PointeeType,
                                         Expr *AddrSpaceExpr,
                                         SourceLocation AttrLoc) const;

diff  --git a/clang/include/clang/AST/RecursiveASTVisitor.h b/clang/include/clang/AST/RecursiveASTVisitor.h
index 0680a8f70212..dbd6a04c3660 100644
--- a/clang/include/clang/AST/RecursiveASTVisitor.h
+++ b/clang/include/clang/AST/RecursiveASTVisitor.h
@@ -1006,6 +1006,17 @@ DEF_TRAVERSE_TYPE(VectorType, { TRY_TO(TraverseType(T->getElementType())); })
 
 DEF_TRAVERSE_TYPE(ExtVectorType, { TRY_TO(TraverseType(T->getElementType())); })
 
+DEF_TRAVERSE_TYPE(ConstantMatrixType, // Dimensions are plain integers here.
+                  { TRY_TO(TraverseType(T->getElementType())); })
+
+DEF_TRAVERSE_TYPE(DependentSizedMatrixType, { // Size exprs are null-checked.
+  if (T->getRowExpr())
+    TRY_TO(TraverseStmt(T->getRowExpr()));
+  if (T->getColumnExpr())
+    TRY_TO(TraverseStmt(T->getColumnExpr()));
+  TRY_TO(TraverseType(T->getElementType()));
+})
+
 DEF_TRAVERSE_TYPE(FunctionNoProtoType,
                   { TRY_TO(TraverseType(T->getReturnType())); })
 
@@ -1258,6 +1269,18 @@ DEF_TRAVERSE_TYPELOC(ExtVectorType, {
   TRY_TO(TraverseType(TL.getTypePtr()->getElementType()));
 })
 
+DEF_TRAVERSE_TYPELOC(ConstantMatrixType, { // Operands may be null (initializeLocal).
+  TRY_TO(TraverseStmt(TL.getAttrRowOperand()));
+  TRY_TO(TraverseStmt(TL.getAttrColumnOperand()));
+  TRY_TO(TraverseType(TL.getTypePtr()->getElementType()));
+})
+
+DEF_TRAVERSE_TYPELOC(DependentSizedMatrixType, { // Operands may be null.
+  TRY_TO(TraverseStmt(TL.getAttrRowOperand()));
+  TRY_TO(TraverseStmt(TL.getAttrColumnOperand()));
+  TRY_TO(TraverseType(TL.getTypePtr()->getElementType()));
+})
+
 DEF_TRAVERSE_TYPELOC(FunctionNoProtoType,
                      { TRY_TO(TraverseTypeLoc(TL.getReturnLoc())); })
 

diff  --git a/clang/include/clang/AST/Type.h b/clang/include/clang/AST/Type.h
index 08522a670c99..21b14de997cb 100644
--- a/clang/include/clang/AST/Type.h
+++ b/clang/include/clang/AST/Type.h
@@ -1654,6 +1654,19 @@ class alignas(8) Type : public ExtQualsTypeCommonBase {
     uint32_t NumElements;
   };
 
+  class ConstantMatrixTypeBitfields {
+    friend class ConstantMatrixType;
+
+    unsigned : NumTypeBits;
+
+    /// Number of rows and columns. Using 20 bits for each allows supporting
+    /// very large matrices, while keeping 24 bits to accommodate NumTypeBits.
+    unsigned NumRows : 20;
+    unsigned NumColumns : 20;
+
+    static constexpr uint32_t MaxElementsPerDimension = (1 << 20) - 1; // Max 20-bit value.
+  };
+
   class AttributedTypeBitfields {
     friend class AttributedType;
 
@@ -1763,6 +1776,7 @@ class alignas(8) Type : public ExtQualsTypeCommonBase {
     TypeWithKeywordBitfields TypeWithKeywordBits;
     ElaboratedTypeBitfields ElaboratedTypeBits;
     VectorTypeBitfields VectorTypeBits;
+    ConstantMatrixTypeBitfields ConstantMatrixTypeBits;
     SubstTemplateTypeParmPackTypeBitfields SubstTemplateTypeParmPackTypeBits;
     TemplateSpecializationTypeBitfields TemplateSpecializationTypeBits;
     DependentTemplateSpecializationTypeBitfields
@@ -2021,6 +2035,7 @@ class alignas(8) Type : public ExtQualsTypeCommonBase {
   bool isComplexIntegerType() const;            // GCC _Complex integer type.
   bool isVectorType() const;                    // GCC vector type.
   bool isExtVectorType() const;                 // Extended vector type.
+  bool isConstantMatrixType() const;            // Matrix type with constant dimensions.
   bool isDependentAddressSpaceType() const;     // value-dependent address space qualifier
   bool isObjCObjectPointerType() const;         // pointer to ObjC object
   bool isObjCRetainableType() const;            // ObjC object or block pointer
@@ -3390,6 +3405,128 @@ class ExtVectorType : public VectorType {
   }
 };
 
+/// Represents a matrix type, as defined in the Matrix Types clang extensions.
+/// __attribute__((matrix_type(rows, columns))), where "rows" specifies
+/// number of rows and "columns" specifies the number of columns.
+class MatrixType : public Type, public llvm::FoldingSetNode {
+protected:
+  friend class ASTContext;
+
+  /// The element type of the matrix.
+  QualType ElementType;
+
+  MatrixType(QualType ElementTy, QualType CanonElementTy);
+
+  MatrixType(TypeClass TypeClass, QualType ElementTy, QualType CanonElementTy,
+             const Expr *RowExpr = nullptr, const Expr *ColumnExpr = nullptr);
+
+public:
+  /// Returns the type of the elements stored in the matrix.
+  QualType getElementType() const { return ElementType; }
+
+  /// Valid element types are the following:
+  /// * an integer type (as in C2x 6.2.5p19), but excluding enumerated types
+  ///   and _Bool
+  /// * the standard floating types float or double
+  /// * a half-precision floating point type, if one is supported on the target
+  static bool isValidElementType(QualType T) {
+    return T->isDependentType() ||
+           (T->isRealType() && !T->isBooleanType() && !T->isEnumeralType());
+  }
+
+  bool isSugared() const { return false; }
+  QualType desugar() const { return QualType(this, 0); }
+
+  static bool classof(const Type *T) {
+    return T->getTypeClass() == ConstantMatrix ||
+           T->getTypeClass() == DependentSizedMatrix;
+  }
+};
+
+/// Represents a concrete matrix type with constant number of rows and columns.
+/// The element type is stored in (and read through) the MatrixType base.
+class ConstantMatrixType final : public MatrixType {
+protected:
+  friend class ASTContext;
+
+  ConstantMatrixType(QualType MatrixElementType, unsigned NRows,
+                     unsigned NColumns, QualType CanonElementType);
+
+  ConstantMatrixType(TypeClass typeClass, QualType MatrixType, unsigned NRows,
+                     unsigned NColumns, QualType CanonElementType);
+
+public:
+  /// Returns the number of rows in the matrix.
+  unsigned getNumRows() const { return ConstantMatrixTypeBits.NumRows; }
+
+  /// Returns the number of columns in the matrix.
+  unsigned getNumColumns() const { return ConstantMatrixTypeBits.NumColumns; }
+
+  /// Returns the number of elements required to embed the matrix into a vector.
+  unsigned getNumElementsFlattened() const {
+    return ConstantMatrixTypeBits.NumRows * ConstantMatrixTypeBits.NumColumns;
+  }
+
+  /// Returns true if \p NumElements is a valid matrix dimension.
+  static bool isDimensionValid(uint64_t NumElements) {
+    return NumElements > 0 &&
+           NumElements <= ConstantMatrixTypeBitfields::MaxElementsPerDimension;
+  }
+
+  void Profile(llvm::FoldingSetNodeID &ID) {
+    Profile(ID, getElementType(), getNumRows(), getNumColumns(),
+            getTypeClass());
+  }
+
+  static void Profile(llvm::FoldingSetNodeID &ID, QualType ElementType,
+                      unsigned NumRows, unsigned NumColumns,
+                      TypeClass TypeClass) {
+    ID.AddPointer(ElementType.getAsOpaquePtr());
+    ID.AddInteger(NumRows);
+    ID.AddInteger(NumColumns);
+    ID.AddInteger(TypeClass);
+  }
+
+  static bool classof(const Type *T) {
+    return T->getTypeClass() == ConstantMatrix;
+  }
+};
+
+/// Represents a matrix type where the type and the number of rows and columns
+/// are dependent on a template.
+class DependentSizedMatrixType final : public MatrixType {
+  friend class ASTContext;
+
+  const ASTContext &Context;
+  Expr *RowExpr;
+  Expr *ColumnExpr;
+
+  SourceLocation loc;
+
+  DependentSizedMatrixType(const ASTContext &Context, QualType ElementType,
+                           QualType CanonicalType, Expr *RowExpr,
+                           Expr *ColumnExpr, SourceLocation loc);
+
+public:
+  QualType getElementType() const { return ElementType; }
+  Expr *getRowExpr() const { return RowExpr; }
+  Expr *getColumnExpr() const { return ColumnExpr; }
+  SourceLocation getAttributeLoc() const { return loc; }
+
+  bool isSugared() const { return false; }
+  QualType desugar() const { return QualType(this, 0); }
+
+  static bool classof(const Type *T) {
+    return T->getTypeClass() == DependentSizedMatrix;
+  }
+
+  void Profile(llvm::FoldingSetNodeID &ID) {
+    Profile(ID, Context, getElementType(), getRowExpr(), getColumnExpr());
+  }
+
+  static void Profile(llvm::FoldingSetNodeID &ID, const ASTContext &Context,
+                      QualType ElementType, Expr *RowExpr, Expr *ColumnExpr);
+};
+
 /// FunctionType - C99 6.7.5.3 - Function Declarators.  This is the common base
 /// class of FunctionNoProtoType and FunctionProtoType.
 class FunctionType : public Type {
@@ -6605,6 +6742,10 @@ inline bool Type::isExtVectorType() const {
   return isa<ExtVectorType>(CanonicalType);
 }
 
+inline bool Type::isConstantMatrixType() const {
+  return isa<ConstantMatrixType>(CanonicalType);
+}
+
 inline bool Type::isDependentAddressSpaceType() const {
   return isa<DependentAddressSpaceType>(CanonicalType);
 }

diff  --git a/clang/include/clang/AST/TypeLoc.h b/clang/include/clang/AST/TypeLoc.h
index 2221485983b2..72cc8ef098e7 100644
--- a/clang/include/clang/AST/TypeLoc.h
+++ b/clang/include/clang/AST/TypeLoc.h
@@ -1774,6 +1774,68 @@ class DependentSizedExtVectorTypeLoc :
                                      DependentSizedExtVectorType> {
 };
 
+struct MatrixTypeLocInfo {
+  SourceLocation AttrLoc;
+  SourceRange OperandParens;
+  Expr *RowOperand;
+  Expr *ColumnOperand;
+};
+
+class MatrixTypeLoc : public ConcreteTypeLoc<UnqualTypeLoc, MatrixTypeLoc,
+                                             MatrixType, MatrixTypeLocInfo> {
+public:
+  /// The location of the attribute name, i.e.
+  ///    float __attribute__((matrix_type(4, 2)))
+  ///                         ^~~~~~~~~~~~~~~~~
+  SourceLocation getAttrNameLoc() const { return getLocalData()->AttrLoc; }
+  void setAttrNameLoc(SourceLocation loc) { getLocalData()->AttrLoc = loc; }
+
+  /// The attribute's row operand, if it has one.
+  ///    float __attribute__((matrix_type(4, 2)))
+  ///                                     ^
+  Expr *getAttrRowOperand() const { return getLocalData()->RowOperand; }
+  void setAttrRowOperand(Expr *e) { getLocalData()->RowOperand = e; }
+
+  /// The attribute's column operand, if it has one.
+  ///    float __attribute__((matrix_type(4, 2)))
+  ///                                        ^
+  Expr *getAttrColumnOperand() const { return getLocalData()->ColumnOperand; }
+  void setAttrColumnOperand(Expr *e) { getLocalData()->ColumnOperand = e; }
+
+  /// The location of the parentheses around the operands, if there are
+  /// operands.
+  ///    float __attribute__((matrix_type(4, 2)))
+  ///                                    ^    ^
+  SourceRange getAttrOperandParensRange() const {
+    return getLocalData()->OperandParens;
+  }
+  void setAttrOperandParensRange(SourceRange range) {
+    getLocalData()->OperandParens = range;
+  }
+
+  SourceRange getLocalSourceRange() const {
+    SourceRange range(getAttrNameLoc());
+    range.setEnd(getAttrOperandParensRange().getEnd());
+    return range;
+  }
+
+  void initializeLocal(ASTContext &Context, SourceLocation loc) {
+    setAttrNameLoc(loc);
+    setAttrOperandParensRange(SourceRange(loc));
+    setAttrRowOperand(nullptr);
+    setAttrColumnOperand(nullptr);
+  }
+};
+
+class ConstantMatrixTypeLoc
+    : public InheritingConcreteTypeLoc<MatrixTypeLoc, ConstantMatrixTypeLoc,
+                                       ConstantMatrixType> {};
+
+class DependentSizedMatrixTypeLoc
+    : public InheritingConcreteTypeLoc<MatrixTypeLoc,
+                                       DependentSizedMatrixTypeLoc,
+                                       DependentSizedMatrixType> {};
+
 // FIXME: location of the '_Complex' keyword.
 class ComplexTypeLoc : public InheritingConcreteTypeLoc<TypeSpecTypeLoc,
                                                         ComplexTypeLoc,

diff  --git a/clang/include/clang/AST/TypeProperties.td b/clang/include/clang/AST/TypeProperties.td
index 12bc5a4ee8a3..4540ea0e1952 100644
--- a/clang/include/clang/AST/TypeProperties.td
+++ b/clang/include/clang/AST/TypeProperties.td
@@ -224,6 +224,41 @@ let Class = DependentSizedExtVectorType in {
   }]>;
 }
 
+let Class = MatrixType in { // Properties shared by both matrix type classes.
+  def : Property<"elementType", QualType> {
+    let Read = [{ node->getElementType() }];
+  }
+}
+
+let Class = ConstantMatrixType in {
+  def : Property<"numRows", UInt32> {
+    let Read = [{ node->getNumRows() }];
+  }
+  def : Property<"numColumns", UInt32> {
+    let Read = [{ node->getNumColumns() }];
+  }
+
+  def : Creator<[{
+    return ctx.getConstantMatrixType(elementType, numRows, numColumns);
+  }]>;
+}
+
+let Class = DependentSizedMatrixType in { // Dimensions are serialized as exprs.
+  def : Property<"rows", ExprRef> {
+    let Read = [{ node->getRowExpr() }];
+  }
+  def : Property<"columns", ExprRef> {
+    let Read = [{ node->getColumnExpr() }];
+  }
+  def : Property<"attributeLoc", SourceLocation> {
+    let Read = [{ node->getAttributeLoc() }];
+  }
+
+  def : Creator<[{
+    return ctx.getDependentSizedMatrixType(elementType, rows, columns, attributeLoc);
+  }]>;
+}
+
 let Class = FunctionType in {
   def : Property<"returnType", QualType> {
     let Read = [{ node->getReturnType() }];

diff  --git a/clang/include/clang/Basic/Attr.td b/clang/include/clang/Basic/Attr.td
index fca8366566f4..833b862cfb5f 100644
--- a/clang/include/clang/Basic/Attr.td
+++ b/clang/include/clang/Basic/Attr.td
@@ -2460,6 +2460,15 @@ def VecTypeHint : InheritableAttr {
   let Documentation = [Undocumented];
 }
 
+def MatrixType : TypeAttr { // matrix_type(NumRows, NumColumns) on typedefs only.
+  let Spellings = [Clang<"matrix_type">];
+  let Subjects = SubjectList<[TypedefName], ErrorDiag>;
+  let Args = [ExprArgument<"NumRows">, ExprArgument<"NumColumns">];
+  let Documentation = [Undocumented];
+  let ASTNode = 0;
+  let PragmaAttributeSupport = 0;
+}
+
 def Visibility : InheritableAttr {
   let Clone = 0;
   let Spellings = [GCC<"visibility">];

diff  --git a/clang/include/clang/Basic/DiagnosticSemaKinds.td b/clang/include/clang/Basic/DiagnosticSemaKinds.td
index fb2d8e48fa70..44a34080b239 100644
--- a/clang/include/clang/Basic/DiagnosticSemaKinds.td
+++ b/clang/include/clang/Basic/DiagnosticSemaKinds.td
@@ -2774,6 +2774,7 @@ def err_attribute_too_many_arguments : Error<
 def err_attribute_too_few_arguments : Error<
   "%0 attribute takes at least %1 argument%s1">;
 def err_attribute_invalid_vector_type : Error<"invalid vector element type %0">;
+def err_attribute_invalid_matrix_type : Error<"invalid matrix element type %0">;
 def err_attribute_bad_neon_vector_size : Error<
   "Neon vector size must be 64 or 128 bits">;
 def err_attribute_requires_positive_integer : Error<
@@ -2877,8 +2878,8 @@ def err_init_method_bad_return_type : Error<
   "init methods must return an object pointer type, not %0">;
 def err_attribute_invalid_size : Error<
   "vector size not an integral multiple of component size">;
-def err_attribute_zero_size : Error<"zero vector size">;
-def err_attribute_size_too_large : Error<"vector size too large">;
+def err_attribute_zero_size : Error<"zero %0 size">; // NOTE(review): %0 presumably "vector" or "matrix"
+def err_attribute_size_too_large : Error<"%0 size too large">;
 def err_typecheck_vector_not_convertable_implict_truncation : Error<
    "cannot convert between %select{scalar|vector}0 type %1 and vector type"
    " %2 as implicit conversion would cause truncation">;
@@ -10741,6 +10742,9 @@ def err_builtin_launder_invalid_arg : Error<
   "%select{non-pointer|function pointer|void pointer}0 argument to "
   "'__builtin_launder' is not allowed">;
 
+def err_builtin_matrix_disabled: Error<
+  "matrix types extension is disabled. Pass -fenable-matrix to enable it">;
+
 def err_preserve_field_info_not_field : Error<
   "__builtin_preserve_field_info argument %0 not a field access">;
 def err_preserve_field_info_not_const: Error<

diff  --git a/clang/include/clang/Basic/Features.def b/clang/include/clang/Basic/Features.def
index 20e1b141a3ef..eb69734571cc 100644
--- a/clang/include/clang/Basic/Features.def
+++ b/clang/include/clang/Basic/Features.def
@@ -253,6 +253,7 @@ EXTENSION(pragma_clang_attribute_namespaces, true)
 EXTENSION(pragma_clang_attribute_external_declaration, true)
 EXTENSION(gnu_asm, LangOpts.GNUAsm)
 EXTENSION(gnu_asm_goto_with_outputs, LangOpts.GNUAsm)
+EXTENSION(matrix_types, LangOpts.MatrixTypes) // Gated on the MatrixTypes language option.
 
 #undef EXTENSION
 #undef FEATURE

diff  --git a/clang/include/clang/Basic/LangOptions.def b/clang/include/clang/Basic/LangOptions.def
index 8073031e5975..55784c3911dd 100644
--- a/clang/include/clang/Basic/LangOptions.def
+++ b/clang/include/clang/Basic/LangOptions.def
@@ -358,6 +358,8 @@ LANGOPT(PaddingOnUnsignedFixedPoint, 1, 0,
 
 LANGOPT(RegisterStaticDestructors, 1, 1, "Register C++ static destructors")
 
+LANGOPT(MatrixTypes, 1, 0, "Enable or disable the builtin matrix type") // NOTE(review): presumably set by -fenable-matrix
+
 COMPATIBLE_VALUE_LANGOPT(MaxTokens, 32, 0, "Max number of tokens per TU or 0")
 
 ENUM_LANGOPT(SignReturnAddressScope, SignReturnAddressScopeKind, 2, SignReturnAddressScopeKind::None,

diff  --git a/clang/include/clang/Basic/TypeNodes.td b/clang/include/clang/Basic/TypeNodes.td
index cd15a498642f..a4e3002b9075 100644
--- a/clang/include/clang/Basic/TypeNodes.td
+++ b/clang/include/clang/Basic/TypeNodes.td
@@ -69,6 +69,9 @@ def DependentAddressSpaceType : TypeNode<Type>, AlwaysDependent;
 def VectorType : TypeNode<Type>;
 def DependentVectorType : TypeNode<Type>, AlwaysDependent;
 def ExtVectorType : TypeNode<VectorType>;
+def MatrixType : TypeNode<Type, 1>; // Base of the two matrix nodes; declared like FunctionType.
+def ConstantMatrixType : TypeNode<MatrixType>;
+def DependentSizedMatrixType : TypeNode<MatrixType>, AlwaysDependent;
 def FunctionType : TypeNode<Type, 1>;
 def FunctionProtoType : TypeNode<FunctionType>;
 def FunctionNoProtoType : TypeNode<FunctionType>;

diff  --git a/clang/include/clang/Driver/Options.td b/clang/include/clang/Driver/Options.td
index 0d00e34040bc..813016d1e8a5 100644
--- a/clang/include/clang/Driver/Options.td
+++ b/clang/include/clang/Driver/Options.td
@@ -2014,6 +2014,10 @@ def fstrict_return : Flag<["-"], "fstrict-return">, Group<f_Group>,
 def fno_strict_return : Flag<["-"], "fno-strict-return">, Group<f_Group>,
   Flags<[CC1Option]>;
 
+def fenable_matrix : Flag<["-"], "fenable-matrix">, Group<f_Group>, // Also a cc1 option.
+    Flags<[CC1Option]>,
+    HelpText<"Enable matrix data type and related builtin functions">;
+
 def fallow_editor_placeholders : Flag<["-"], "fallow-editor-placeholders">,
   Group<f_Group>, Flags<[CC1Option]>,
   HelpText<"Treat editor placeholders as valid source code">;

diff  --git a/clang/include/clang/Sema/Sema.h b/clang/include/clang/Sema/Sema.h
index d4c8e483a961..5aeb410e7288 100644
--- a/clang/include/clang/Sema/Sema.h
+++ b/clang/include/clang/Sema/Sema.h
@@ -1627,6 +1627,9 @@ class Sema final {
   QualType BuildVectorType(QualType T, Expr *VecSize, SourceLocation AttrLoc);
   QualType BuildExtVectorType(QualType T, Expr *ArraySize,
                               SourceLocation AttrLoc);
+  QualType BuildMatrixType(QualType T, Expr *NumRows, Expr *NumColumns,
+                           SourceLocation AttrLoc); // NOTE(review): presumably backs matrix_type(NumRows, NumColumns)
+
   QualType BuildAddressSpaceAttr(QualType &T, LangAS ASIdx, Expr *AddrSpace,
                                  SourceLocation AttrLoc);
 

diff  --git a/clang/include/clang/Serialization/TypeBitCodes.def b/clang/include/clang/Serialization/TypeBitCodes.def
index 561c8869ead6..e92e05810648 100644
--- a/clang/include/clang/Serialization/TypeBitCodes.def
+++ b/clang/include/clang/Serialization/TypeBitCodes.def
@@ -60,5 +60,7 @@ TYPE_BIT_CODE(DependentVector, DEPENDENT_SIZED_VECTOR, 48)
 TYPE_BIT_CODE(MacroQualified, MACRO_QUALIFIED, 49)
 TYPE_BIT_CODE(ExtInt, EXT_INT, 50)
 TYPE_BIT_CODE(DependentExtInt, DEPENDENT_EXT_INT, 51)
+TYPE_BIT_CODE(ConstantMatrix, CONSTANT_MATRIX, 52)
+TYPE_BIT_CODE(DependentSizedMatrix, DEPENDENT_SIZE_MATRIX, 53) // NOTE(review): "SIZE_" vs "SIZED_" used by DEPENDENT_SIZED_VECTOR
 
 #undef TYPE_BIT_CODE

diff  --git a/clang/lib/AST/ASTContext.cpp b/clang/lib/AST/ASTContext.cpp
index 4ed0073eba08..8f3858126253 100644
--- a/clang/lib/AST/ASTContext.cpp
+++ b/clang/lib/AST/ASTContext.cpp
@@ -1932,6 +1932,17 @@ TypeInfo ASTContext::getTypeInfoImpl(const Type *T) const {
     break;
   }
 
+  case Type::ConstantMatrix: {
+    const auto *MT = cast<ConstantMatrixType>(T);
+    TypeInfo ElementInfo = getTypeInfo(MT->getElementType());
+    // The internal layout of a matrix value is implementation defined.
+    // Initially, be ABI compatible with arrays with respect to alignment and
+    // size: Rows * Columns elements, aligned like a single element.
+    Width = ElementInfo.Width * MT->getNumRows() * MT->getNumColumns();
+    Align = ElementInfo.Align;
+    break;
+  }
+
   case Type::Builtin:
     switch (cast<BuiltinType>(T)->getKind()) {
     default: llvm_unreachable("Unknown builtin type!");
@@ -3362,6 +3373,8 @@ QualType ASTContext::getVariableArrayDecayedType(QualType type) const {
   case Type::DependentVector:
   case Type::ExtVector:
   case Type::DependentSizedExtVector:
+  case Type::ConstantMatrix:
+  case Type::DependentSizedMatrix:
   case Type::DependentAddressSpace:
   case Type::ObjCObject:
   case Type::ObjCInterface:
@@ -3775,6 +3788,78 @@ ASTContext::getDependentSizedExtVectorType(QualType vecType,
   return QualType(New, 0);
 }
 
+QualType ASTContext::getConstantMatrixType(QualType ElementTy, unsigned NumRows,
+                                           unsigned NumColumns) const {
+  llvm::FoldingSetNodeID ID;
+  ConstantMatrixType::Profile(ID, ElementTy, NumRows, NumColumns,
+                              Type::ConstantMatrix);
+
+  assert(MatrixType::isValidElementType(ElementTy) &&
+         "need a valid element type");
+  assert(ConstantMatrixType::isDimensionValid(NumRows) &&
+         ConstantMatrixType::isDimensionValid(NumColumns) &&
+         "need valid matrix dimensions");
+  void *InsertPos = nullptr;
+  if (ConstantMatrixType *MTP = MatrixTypes.FindNodeOrInsertPos(ID, InsertPos))
+    return QualType(MTP, 0);
+
+  QualType Canonical;
+  if (!ElementTy.isCanonical()) {
+    Canonical =
+        getConstantMatrixType(getCanonicalType(ElementTy), NumRows, NumColumns);
+
+    ConstantMatrixType *NewIP = MatrixTypes.FindNodeOrInsertPos(ID, InsertPos);
+    assert(!NewIP && "Matrix type shouldn't already exist in the map");
+    (void)NewIP;
+  }
+
+  auto *New = new (*this, TypeAlignment)
+      ConstantMatrixType(ElementTy, NumRows, NumColumns, Canonical);
+  MatrixTypes.InsertNode(New, InsertPos);
+  Types.push_back(New);
+  return QualType(New, 0);
+}
+
+QualType ASTContext::getDependentSizedMatrixType(QualType ElementTy,
+                                                 Expr *RowExpr,
+                                                 Expr *ColumnExpr,
+                                                 SourceLocation AttrLoc) const {
+  QualType CanonElementTy = getCanonicalType(ElementTy);
+  llvm::FoldingSetNodeID ID;
+  DependentSizedMatrixType::Profile(ID, *this, CanonElementTy, RowExpr,
+                                    ColumnExpr);
+
+  void *InsertPos = nullptr;
+  DependentSizedMatrixType *Canon =
+      DependentSizedMatrixTypes.FindNodeOrInsertPos(ID, InsertPos);
+
+  if (!Canon) {
+    Canon = new (*this, TypeAlignment) DependentSizedMatrixType(
+        *this, CanonElementTy, QualType(), RowExpr, ColumnExpr, AttrLoc);
+#ifndef NDEBUG
+    DependentSizedMatrixType *CanonCheck =
+        DependentSizedMatrixTypes.FindNodeOrInsertPos(ID, InsertPos);
+    assert(!CanonCheck && "Dependent-sized matrix canonical type broken");
+#endif
+    DependentSizedMatrixTypes.InsertNode(Canon, InsertPos);
+    Types.push_back(Canon);
+  }
+
+  // Canon is now the canonical version of the matrix type.
+  //
+  // If it exactly matches the requested type (element, rows AND columns), reuse it.
+  if (Canon->getElementType() == ElementTy && Canon->getRowExpr() == RowExpr &&
+      Canon->getColumnExpr() == ColumnExpr)
+    return QualType(Canon, 0);
+
+  // Use Canon as the canonical type for newly-built type.
+  DependentSizedMatrixType *New = new (*this, TypeAlignment)
+      DependentSizedMatrixType(*this, ElementTy, QualType(Canon, 0), RowExpr,
+                               ColumnExpr, AttrLoc);
+  Types.push_back(New);
+  return QualType(New, 0);
+}
+
 QualType ASTContext::getDependentAddressSpaceType(QualType PointeeType,
                                                   Expr *AddrSpaceExpr,
                                                   SourceLocation AttrLoc) const {
@@ -7338,6 +7423,11 @@ void ASTContext::getObjCEncodingForTypeImpl(QualType T, std::string &S,
       *NotEncodedT = T;
     return;
 
+  case Type::ConstantMatrix: // No encoding emitted; report T via NotEncodedT.
+    if (NotEncodedT)
+      *NotEncodedT = T;
+    return;
+
   // We could see an undeduced auto type here during error recovery.
   // Just ignore it.
   case Type::Auto:
@@ -8217,6 +8307,16 @@ static bool areCompatVectorTypes(const VectorType *LHS,
          LHS->getNumElements() == RHS->getNumElements();
 }
 
+/// areCompatMatrixTypes - Return true if the two specified matrix types are
+/// compatible, i.e. have the same element type and the same dimensions.
+static bool areCompatMatrixTypes(const ConstantMatrixType *LHS,
+                                 const ConstantMatrixType *RHS) {
+  assert(LHS->isCanonicalUnqualified() && RHS->isCanonicalUnqualified());
+  return LHS->getElementType() == RHS->getElementType() &&
+         LHS->getNumRows() == RHS->getNumRows() &&
+         LHS->getNumColumns() == RHS->getNumColumns();
+}
+
 bool ASTContext::areCompatibleVectorTypes(QualType FirstVec,
                                           QualType SecondVec) {
   assert(FirstVec->isVectorType() && "FirstVec should be a vector type");
@@ -9414,6 +9514,11 @@ QualType ASTContext::mergeTypes(QualType LHS, QualType RHS,
                              RHSCan->castAs<VectorType>()))
       return LHS;
     return {};
+  case Type::ConstantMatrix: // Merge only if element type and dims match.
+    if (areCompatMatrixTypes(LHSCan->castAs<ConstantMatrixType>(),
+                             RHSCan->castAs<ConstantMatrixType>()))
+      return LHS;
+    return {};
   case Type::ObjCObject: {
     // Check if the types are assignment compatible.
     // FIXME: This should be type compatibility, e.g. whether

diff  --git a/clang/lib/AST/ASTStructuralEquivalence.cpp b/clang/lib/AST/ASTStructuralEquivalence.cpp
index c562830c41e1..8b5b2444f1e2 100644
--- a/clang/lib/AST/ASTStructuralEquivalence.cpp
+++ b/clang/lib/AST/ASTStructuralEquivalence.cpp
@@ -617,6 +617,34 @@ static bool IsStructurallyEquivalent(StructuralEquivalenceContext &Context,
     break;
   }
 
+  case Type::DependentSizedMatrix: {
+    const DependentSizedMatrixType *Mat1 = cast<DependentSizedMatrixType>(T1);
+    const DependentSizedMatrixType *Mat2 = cast<DependentSizedMatrixType>(T2);
+    // The element types and the row and column expressions must all be
+    // structurally equivalent.
+    if (!IsStructurallyEquivalent(Context, Mat1->getRowExpr(),
+                                  Mat2->getRowExpr()) ||
+        !IsStructurallyEquivalent(Context, Mat1->getColumnExpr(),
+                                  Mat2->getColumnExpr()) ||
+        !IsStructurallyEquivalent(Context, Mat1->getElementType(),
+                                  Mat2->getElementType()))
+      return false;
+    break;
+  }
+
+  case Type::ConstantMatrix: {
+    const ConstantMatrixType *Mat1 = cast<ConstantMatrixType>(T1);
+    const ConstantMatrixType *Mat2 = cast<ConstantMatrixType>(T2);
+    // The element types must be structurally equivalent and the number of rows
+    // and columns must match exactly.
+    if (!IsStructurallyEquivalent(Context, Mat1->getElementType(),
+                                  Mat2->getElementType()) ||
+        Mat1->getNumRows() != Mat2->getNumRows() ||
+        Mat1->getNumColumns() != Mat2->getNumColumns())
+      return false;
+    break;
+  }
+
   case Type::FunctionProto: {
     const auto *Proto1 = cast<FunctionProtoType>(T1);
     const auto *Proto2 = cast<FunctionProtoType>(T2);

diff  --git a/clang/lib/AST/ExprConstant.cpp b/clang/lib/AST/ExprConstant.cpp
index 1c738ff51175..40c60f1e45cc 100644
--- a/clang/lib/AST/ExprConstant.cpp
+++ b/clang/lib/AST/ExprConstant.cpp
@@ -10350,6 +10350,7 @@ EvaluateBuiltinClassifyType(QualType T, const LangOptions &LangOpts) {
   case Type::BlockPointer:
   case Type::Vector:
   case Type::ExtVector:
+  case Type::ConstantMatrix:
   case Type::ObjCObject:
   case Type::ObjCInterface:
   case Type::ObjCObjectPointer:

diff  --git a/clang/lib/AST/ItaniumMangle.cpp b/clang/lib/AST/ItaniumMangle.cpp
index d60cacf07534..dbf004b6e2be 100644
--- a/clang/lib/AST/ItaniumMangle.cpp
+++ b/clang/lib/AST/ItaniumMangle.cpp
@@ -2079,6 +2079,8 @@ bool CXXNameMangler::mangleUnresolvedTypeOrSimpleId(QualType Ty,
   case Type::DependentSizedExtVector:
   case Type::Vector:
   case Type::ExtVector:
+  case Type::ConstantMatrix:
+  case Type::DependentSizedMatrix:
   case Type::FunctionProto:
   case Type::FunctionNoProto:
   case Type::Paren:
@@ -3343,6 +3345,31 @@ void CXXNameMangler::mangleType(const DependentSizedExtVectorType *T) {
   mangleType(T->getElementType());
 }
 
+void CXXNameMangler::mangleType(const ConstantMatrixType *T) {
+  // Mangle matrix types using a vendor extended type qualifier:
+  // U<Len>matrix_type<Rows><Columns><element type>
+  StringRef VendorQualifier = "matrix_type";
+  Out << "U" << VendorQualifier.size() << VendorQualifier;
+  auto &ASTCtx = getASTContext();
+  unsigned BitWidth = ASTCtx.getTypeSize(ASTCtx.getSizeType());
+  llvm::APSInt Rows(BitWidth);
+  Rows = T->getNumRows();
+  mangleIntegerLiteral(ASTCtx.getSizeType(), Rows);
+  llvm::APSInt Columns(BitWidth);
+  Columns = T->getNumColumns();
+  mangleIntegerLiteral(ASTCtx.getSizeType(), Columns);
+  mangleType(T->getElementType());
+}
+
+void CXXNameMangler::mangleType(const DependentSizedMatrixType *T) {
+  // U<Len>matrix_type<row expr><column expr><element type>
+  StringRef VendorQualifier = "matrix_type";
+  Out << "U" << VendorQualifier.size() << VendorQualifier;
+  mangleTemplateArg(T->getRowExpr());
+  mangleTemplateArg(T->getColumnExpr());
+  mangleType(T->getElementType());
+}
+
 void CXXNameMangler::mangleType(const DependentAddressSpaceType *T) {
   SplitQualType split = T->getPointeeType().split();
   mangleQualifiers(split.Quals, T);

diff  --git a/clang/lib/AST/MicrosoftMangle.cpp b/clang/lib/AST/MicrosoftMangle.cpp
index dc5c15fbef68..e3796accb26e 100644
--- a/clang/lib/AST/MicrosoftMangle.cpp
+++ b/clang/lib/AST/MicrosoftMangle.cpp
@@ -2730,6 +2730,23 @@ void MicrosoftCXXNameMangler::mangleType(const DependentSizedExtVectorType *T,
     << Range;
 }
 
+void MicrosoftCXXNameMangler::mangleType(const ConstantMatrixType *T,
+                                         Qualifiers quals, SourceRange Range) {
+  DiagnosticsEngine &Diags = Context.getDiags();
+  unsigned DiagID = Diags.getCustomDiagID(DiagnosticsEngine::Error,
+                                          "Cannot mangle this matrix type yet");
+  Diags.Report(Range.getBegin(), DiagID) << Range;
+}
+
+void MicrosoftCXXNameMangler::mangleType(const DependentSizedMatrixType *T,
+                                         Qualifiers quals, SourceRange Range) {
+  DiagnosticsEngine &Diags = Context.getDiags();
+  unsigned DiagID = Diags.getCustomDiagID(
+      DiagnosticsEngine::Error,
+      "Cannot mangle this dependent-sized matrix type yet");
+  Diags.Report(Range.getBegin(), DiagID) << Range;
+}
+
 void MicrosoftCXXNameMangler::mangleType(const DependentAddressSpaceType *T,
                                          Qualifiers, SourceRange Range) {
   DiagnosticsEngine &Diags = Context.getDiags();

diff  --git a/clang/lib/AST/Type.cpp b/clang/lib/AST/Type.cpp
index 3408149bd760..e4d8af9c70f5 100644
--- a/clang/lib/AST/Type.cpp
+++ b/clang/lib/AST/Type.cpp
@@ -282,6 +282,53 @@ void DependentAddressSpaceType::Profile(llvm::FoldingSetNodeID &ID,
   AddrSpaceExpr->Profile(ID, Context, true);
 }
 
+MatrixType::MatrixType(TypeClass tc, QualType matrixType, QualType canonType,
+                       const Expr *RowExpr, const Expr *ColumnExpr)
+    : Type(tc, canonType,
+           (RowExpr ? (matrixType->getDependence() | TypeDependence::Dependent |
+                       TypeDependence::Instantiation |
+                       (matrixType->isVariablyModifiedType()
+                            ? TypeDependence::VariablyModified
+                            : TypeDependence::None) |
+                       (matrixType->containsUnexpandedParameterPack() ||
+                                (RowExpr &&
+                                 RowExpr->containsUnexpandedParameterPack()) ||
+                                (ColumnExpr &&
+                                 ColumnExpr->containsUnexpandedParameterPack())
+                            ? TypeDependence::UnexpandedPack
+                            : TypeDependence::None))
+                    : matrixType->getDependence())),
+      ElementType(matrixType) {}
+
+ConstantMatrixType::ConstantMatrixType(QualType matrixType, unsigned nRows,
+                                       unsigned nColumns, QualType canonType)
+    : ConstantMatrixType(ConstantMatrix, matrixType, nRows, nColumns,
+                         canonType) {}
+
+ConstantMatrixType::ConstantMatrixType(TypeClass tc, QualType matrixType,
+                                       unsigned nRows, unsigned nColumns,
+                                       QualType canonType)
+    : MatrixType(tc, matrixType, canonType) {
+  ConstantMatrixTypeBits.NumRows = nRows;
+  ConstantMatrixTypeBits.NumColumns = nColumns;
+}
+
+DependentSizedMatrixType::DependentSizedMatrixType(
+    const ASTContext &CTX, QualType ElementType, QualType CanonicalType,
+    Expr *RowExpr, Expr *ColumnExpr, SourceLocation loc)
+    : MatrixType(DependentSizedMatrix, ElementType, CanonicalType, RowExpr,
+                 ColumnExpr),
+      Context(CTX), RowExpr(RowExpr), ColumnExpr(ColumnExpr), loc(loc) {}
+
+void DependentSizedMatrixType::Profile(llvm::FoldingSetNodeID &ID,
+                                       const ASTContext &CTX,
+                                       QualType ElementType, Expr *RowExpr,
+                                       Expr *ColumnExpr) {
+  ID.AddPointer(ElementType.getAsOpaquePtr());
+  RowExpr->Profile(ID, CTX, true);
+  ColumnExpr->Profile(ID, CTX, true);
+}
+
 VectorType::VectorType(QualType vecType, unsigned nElements, QualType canonType,
                        VectorKind vecKind)
     : VectorType(Vector, vecType, nElements, canonType, vecKind) {}
@@ -971,6 +1018,17 @@ struct SimpleTransformVisitor : public TypeVisitor<Derived, QualType> {
     return Ctx.getExtVectorType(elementType, T->getNumElements());
   }
 
+  QualType VisitConstantMatrixType(const ConstantMatrixType *T) {
+    QualType elementType = recurse(T->getElementType());
+    if (elementType.isNull())
+      return {};
+    if (elementType.getAsOpaquePtr() == T->getElementType().getAsOpaquePtr())
+      return QualType(T, 0);
+
+    return Ctx.getConstantMatrixType(elementType, T->getNumRows(),
+                                     T->getNumColumns());
+  }
+
   QualType VisitFunctionNoProtoType(const FunctionNoProtoType *T) {
     QualType returnType = recurse(T->getReturnType());
     if (returnType.isNull())
@@ -1790,6 +1848,14 @@ namespace {
       return Visit(T->getElementType());
     }
 
+    Type *VisitDependentSizedMatrixType(const DependentSizedMatrixType *T) {
+      return Visit(T->getElementType());
+    }
+
+    Type *VisitConstantMatrixType(const ConstantMatrixType *T) {
+      return Visit(T->getElementType());
+    }
+
     Type *VisitFunctionProtoType(const FunctionProtoType *T) {
       if (Syntactic && T->hasTrailingReturn())
         return const_cast<FunctionProtoType*>(T);
@@ -3744,6 +3810,8 @@ static CachedProperties computeCachedProperties(const Type *T) {
   case Type::Vector:
   case Type::ExtVector:
     return Cache::get(cast<VectorType>(T)->getElementType());
+  case Type::ConstantMatrix:
+    return Cache::get(cast<ConstantMatrixType>(T)->getElementType());
   case Type::FunctionNoProto:
     return Cache::get(cast<FunctionType>(T)->getReturnType());
   case Type::FunctionProto: {
@@ -3830,6 +3898,9 @@ LinkageInfo LinkageComputer::computeTypeLinkageInfo(const Type *T) {
   case Type::Vector:
   case Type::ExtVector:
     return computeTypeLinkageInfo(cast<VectorType>(T)->getElementType());
+  case Type::ConstantMatrix:
+    return computeTypeLinkageInfo(
+        cast<ConstantMatrixType>(T)->getElementType());
   case Type::FunctionNoProto:
     return computeTypeLinkageInfo(cast<FunctionType>(T)->getReturnType());
   case Type::FunctionProto: {
@@ -3993,6 +4064,8 @@ bool Type::canHaveNullability(bool ResultIfUnknown) const {
   case Type::DependentSizedExtVector:
   case Type::Vector:
   case Type::ExtVector:
+  case Type::ConstantMatrix:
+  case Type::DependentSizedMatrix:
   case Type::DependentAddressSpace:
   case Type::FunctionProto:
   case Type::FunctionNoProto:

diff  --git a/clang/lib/AST/TypePrinter.cpp b/clang/lib/AST/TypePrinter.cpp
index cf82d1a26156..6f6932e65214 100644
--- a/clang/lib/AST/TypePrinter.cpp
+++ b/clang/lib/AST/TypePrinter.cpp
@@ -256,6 +256,8 @@ bool TypePrinter::canPrefixQualifiers(const Type *T,
     case Type::DependentSizedExtVector:
     case Type::Vector:
     case Type::ExtVector:
+    case Type::ConstantMatrix:
+    case Type::DependentSizedMatrix:
     case Type::FunctionProto:
     case Type::FunctionNoProto:
     case Type::Paren:
@@ -720,6 +722,38 @@ void TypePrinter::printExtVectorAfter(const ExtVectorType *T, raw_ostream &OS) {
   OS << ")))";
 }
 
+void TypePrinter::printConstantMatrixBefore(const ConstantMatrixType *T,
+                                            raw_ostream &OS) {
+  printBefore(T->getElementType(), OS);
+  OS << " __attribute__((matrix_type(";
+  OS << T->getNumRows() << ", " << T->getNumColumns();
+  OS << ")))";
+}
+
+void TypePrinter::printConstantMatrixAfter(const ConstantMatrixType *T,
+                                           raw_ostream &OS) {
+  printAfter(T->getElementType(), OS);
+}
+
+void TypePrinter::printDependentSizedMatrixBefore(
+    const DependentSizedMatrixType *T, raw_ostream &OS) {
+  printBefore(T->getElementType(), OS);
+  OS << " __attribute__((matrix_type(";
+  if (T->getRowExpr()) {
+    T->getRowExpr()->printPretty(OS, nullptr, Policy);
+  }
+  OS << ", ";
+  if (T->getColumnExpr()) {
+    T->getColumnExpr()->printPretty(OS, nullptr, Policy);
+  }
+  OS << ")))";
+}
+
+void TypePrinter::printDependentSizedMatrixAfter(
+    const DependentSizedMatrixType *T, raw_ostream &OS) {
+  printAfter(T->getElementType(), OS);
+}
+
 void
 FunctionProtoType::printExceptionSpecification(raw_ostream &OS,
                                                const PrintingPolicy &Policy)

diff  --git a/clang/lib/CodeGen/CGDebugInfo.cpp b/clang/lib/CodeGen/CGDebugInfo.cpp
index e6422a7ff1c3..0c23b16a78d8 100644
--- a/clang/lib/CodeGen/CGDebugInfo.cpp
+++ b/clang/lib/CodeGen/CGDebugInfo.cpp
@@ -2736,6 +2736,23 @@ llvm::DIType *CGDebugInfo::CreateType(const VectorType *Ty,
   return DBuilder.createVectorType(Size, Align, ElementTy, SubscriptArray);
 }
 
+llvm::DIType *CGDebugInfo::CreateType(const ConstantMatrixType *Ty,
+                                      llvm::DIFile *Unit) {
+  // FIXME: Create a dedicated debug type for matrices.
+  // For the time being, treat it like a nested ArrayType.
+
+  llvm::DIType *ElementTy = getOrCreateType(Ty->getElementType(), Unit);
+  uint64_t Size = CGM.getContext().getTypeSize(Ty);
+  uint32_t Align = getTypeAlignIfRequired(Ty, CGM.getContext());
+
+  // Create ranges for both dimensions.
+  llvm::SmallVector<llvm::Metadata *, 2> Subscripts;
+  Subscripts.push_back(DBuilder.getOrCreateSubrange(0, Ty->getNumColumns()));
+  Subscripts.push_back(DBuilder.getOrCreateSubrange(0, Ty->getNumRows()));
+  llvm::DINodeArray SubscriptArray = DBuilder.getOrCreateArray(Subscripts);
+  return DBuilder.createArrayType(Size, Align, ElementTy, SubscriptArray);
+}
+
 llvm::DIType *CGDebugInfo::CreateType(const ArrayType *Ty, llvm::DIFile *Unit) {
   uint64_t Size;
   uint32_t Align;
@@ -3129,6 +3146,8 @@ llvm::DIType *CGDebugInfo::CreateTypeNode(QualType Ty, llvm::DIFile *Unit) {
   case Type::ExtVector:
   case Type::Vector:
     return CreateType(cast<VectorType>(Ty), Unit);
+  case Type::ConstantMatrix:
+    return CreateType(cast<ConstantMatrixType>(Ty), Unit);
   case Type::ObjCObjectPointer:
     return CreateType(cast<ObjCObjectPointerType>(Ty), Unit);
   case Type::ObjCObject:

diff  --git a/clang/lib/CodeGen/CGDebugInfo.h b/clang/lib/CodeGen/CGDebugInfo.h
index 34164fbec90e..367047e79dc9 100644
--- a/clang/lib/CodeGen/CGDebugInfo.h
+++ b/clang/lib/CodeGen/CGDebugInfo.h
@@ -192,6 +192,7 @@ class CGDebugInfo {
   llvm::DIType *CreateType(const ObjCTypeParamType *Ty, llvm::DIFile *Unit);
 
   llvm::DIType *CreateType(const VectorType *Ty, llvm::DIFile *F);
+  llvm::DIType *CreateType(const ConstantMatrixType *Ty, llvm::DIFile *F);
   llvm::DIType *CreateType(const ArrayType *Ty, llvm::DIFile *F);
   llvm::DIType *CreateType(const LValueReferenceType *Ty, llvm::DIFile *F);
   llvm::DIType *CreateType(const RValueReferenceType *Ty, llvm::DIFile *Unit);

diff  --git a/clang/lib/CodeGen/CGExpr.cpp b/clang/lib/CodeGen/CGExpr.cpp
index 85c2d318c196..1701c906e193 100644
--- a/clang/lib/CodeGen/CGExpr.cpp
+++ b/clang/lib/CodeGen/CGExpr.cpp
@@ -145,8 +145,19 @@ Address CodeGenFunction::CreateMemTemp(QualType Ty, const Twine &Name,
 
 Address CodeGenFunction::CreateMemTemp(QualType Ty, CharUnits Align,
                                        const Twine &Name, Address *Alloca) {
-  return CreateTempAlloca(ConvertTypeForMem(Ty), Align, Name,
-                          /*ArraySize=*/nullptr, Alloca);
+  Address Result = CreateTempAlloca(ConvertTypeForMem(Ty), Align, Name,
+                                    /*ArraySize=*/nullptr, Alloca);
+
+  if (Ty->isConstantMatrixType()) {
+    auto *ArrayTy = cast<llvm::ArrayType>(Result.getType()->getElementType());
+    auto *VectorTy = llvm::VectorType::get(ArrayTy->getElementType(),
+                                           ArrayTy->getNumElements());
+
+    Result = Address(
+        Builder.CreateBitCast(Result.getPointer(), VectorTy->getPointerTo()),
+        Result.getAlignment());
+  }
+  return Result;
 }
 
 Address CodeGenFunction::CreateMemTempWithoutCast(QualType Ty, CharUnits Align,
@@ -1732,6 +1743,42 @@ llvm::Value *CodeGenFunction::EmitFromMemory(llvm::Value *Value, QualType Ty) {
   return Value;
 }
 
+// Bitcast \p Addr from array (matrix memory type) to vector (matrix value type)
+// if \p IsVector, or from vector to array otherwise; other pointers unchanged.
+static Address MaybeConvertMatrixAddress(Address Addr, CodeGenFunction &CGF,
+                                         bool IsVector = true) {
+  auto *ArrayTy = dyn_cast<llvm::ArrayType>(
+      cast<llvm::PointerType>(Addr.getPointer()->getType())->getElementType());
+  if (ArrayTy && IsVector) {
+    auto *VectorTy = llvm::VectorType::get(ArrayTy->getElementType(),
+                                           ArrayTy->getNumElements());
+
+    return Address(CGF.Builder.CreateElementBitCast(Addr, VectorTy));
+  }
+  auto *VectorTy = dyn_cast<llvm::VectorType>(
+      cast<llvm::PointerType>(Addr.getPointer()->getType())->getElementType());
+  if (VectorTy && !IsVector) {
+    auto *ArrayTy = llvm::ArrayType::get(VectorTy->getElementType(),
+                                         VectorTy->getNumElements());
+
+    return Address(CGF.Builder.CreateElementBitCast(Addr, ArrayTy));
+  }
+
+  return Addr;
+}
+
+// Emit a store of a matrix LValue. The address may need to be cast from the
+// memory type (ArrayType) to the value type (VectorType), or back to the array
+// type when the stored value is not a vector.
+static void EmitStoreOfMatrixScalar(llvm::Value *value, LValue lvalue,
+                                    bool isInit, CodeGenFunction &CGF) {
+  Address Addr = MaybeConvertMatrixAddress(lvalue.getAddress(CGF), CGF,
+                                           value->getType()->isVectorTy());
+  CGF.EmitStoreOfScalar(value, Addr, lvalue.isVolatile(), lvalue.getType(),
+                        lvalue.getBaseInfo(), lvalue.getTBAAInfo(), isInit,
+                        lvalue.isNontemporal());
+}
+
 void CodeGenFunction::EmitStoreOfScalar(llvm::Value *Value, Address Addr,
                                         bool Volatile, QualType Ty,
                                         LValueBaseInfo BaseInfo,
@@ -1779,11 +1826,26 @@ void CodeGenFunction::EmitStoreOfScalar(llvm::Value *Value, Address Addr,
 
 void CodeGenFunction::EmitStoreOfScalar(llvm::Value *value, LValue lvalue,
                                         bool isInit) {
+  if (lvalue.getType()->isConstantMatrixType()) {
+    EmitStoreOfMatrixScalar(value, lvalue, isInit, *this);
+    return;
+  }
+
   EmitStoreOfScalar(value, lvalue.getAddress(*this), lvalue.isVolatile(),
                     lvalue.getType(), lvalue.getBaseInfo(),
                     lvalue.getTBAAInfo(), isInit, lvalue.isNontemporal());
 }
 
+// Emit a load of an LValue of matrix type. This may require casting its address
+// from an ArrayType pointer (memory type) to a VectorType pointer (value type).
+static RValue EmitLoadOfMatrixLValue(LValue LV, SourceLocation Loc,
+                                     CodeGenFunction &CGF) {
+  assert(LV.getType()->isConstantMatrixType());
+  Address Addr = MaybeConvertMatrixAddress(LV.getAddress(CGF), CGF);
+  LV.setAddress(Addr);
+  return RValue::get(CGF.EmitLoadOfScalar(LV, Loc));
+}
+
 /// EmitLoadOfLValue - Given an expression that represents a value lvalue, this
 /// method emits the address of the lvalue, then loads the result as an rvalue,
 /// returning the rvalue.
@@ -1809,6 +1871,9 @@ RValue CodeGenFunction::EmitLoadOfLValue(LValue LV, SourceLocation Loc) {
   if (LV.isSimple()) {
     assert(!LV.getType()->isFunctionType());
 
+    if (LV.getType()->isConstantMatrixType())
+      return EmitLoadOfMatrixLValue(LV, Loc, *this);
+
     // Everything needs a load.
     return RValue::get(EmitLoadOfScalar(LV, Loc));
   }

diff  --git a/clang/lib/CodeGen/CodeGenFunction.cpp b/clang/lib/CodeGen/CodeGenFunction.cpp
index 4fcf31a5f1aa..cbe4823aca58 100644
--- a/clang/lib/CodeGen/CodeGenFunction.cpp
+++ b/clang/lib/CodeGen/CodeGenFunction.cpp
@@ -247,6 +247,7 @@ TypeEvaluationKind CodeGenFunction::getEvaluationKind(QualType type) {
     case Type::MemberPointer:
     case Type::Vector:
     case Type::ExtVector:
+    case Type::ConstantMatrix:
     case Type::FunctionProto:
     case Type::FunctionNoProto:
     case Type::Enum:
@@ -2000,6 +2001,7 @@ void CodeGenFunction::EmitVariablyModifiedType(QualType type) {
     case Type::Complex:
     case Type::Vector:
     case Type::ExtVector:
+    case Type::ConstantMatrix:
     case Type::Record:
     case Type::Enum:
     case Type::Elaborated:

diff  --git a/clang/lib/CodeGen/CodeGenTypes.cpp b/clang/lib/CodeGen/CodeGenTypes.cpp
index 7ae9e91c8a3b..19b8ff3e8b3f 100644
--- a/clang/lib/CodeGen/CodeGenTypes.cpp
+++ b/clang/lib/CodeGen/CodeGenTypes.cpp
@@ -82,6 +82,13 @@ void CodeGenTypes::addRecordTypeName(const RecordDecl *RD,
 /// a type.  For example, the scalar representation for _Bool is i1, but the
 /// memory representation is usually i8 or i32, depending on the target.
 llvm::Type *CodeGenTypes::ConvertTypeForMem(QualType T, bool ForBitField) {
+  if (T->isConstantMatrixType()) {
+    const Type *Ty = Context.getCanonicalType(T).getTypePtr();
+    const ConstantMatrixType *MT = cast<ConstantMatrixType>(Ty);
+    return llvm::ArrayType::get(ConvertType(MT->getElementType()),
+                                MT->getNumRows() * MT->getNumColumns());
+  }
+
   llvm::Type *R = ConvertType(T);
 
   // If this is a bool type, or an ExtIntType in a bitfield representation,
@@ -646,6 +653,12 @@ llvm::Type *CodeGenTypes::ConvertType(QualType T) {
                                        VT->getNumElements());
     break;
   }
+  case Type::ConstantMatrix: {
+    const ConstantMatrixType *MT = cast<ConstantMatrixType>(Ty);
+    ResultType = llvm::VectorType::get(ConvertType(MT->getElementType()),
+                                       MT->getNumRows() * MT->getNumColumns());
+    break;
+  }
   case Type::FunctionNoProto:
   case Type::FunctionProto:
     ResultType = ConvertFunctionTypeInternal(T);

diff  --git a/clang/lib/CodeGen/ItaniumCXXABI.cpp b/clang/lib/CodeGen/ItaniumCXXABI.cpp
index 4a591cf7aac5..ceef143e9017 100644
--- a/clang/lib/CodeGen/ItaniumCXXABI.cpp
+++ b/clang/lib/CodeGen/ItaniumCXXABI.cpp
@@ -3223,6 +3223,7 @@ void ItaniumRTTIBuilder::BuildVTablePointer(const Type *Ty) {
   // GCC treats vector and complex types as fundamental types.
   case Type::Vector:
   case Type::ExtVector:
+  case Type::ConstantMatrix:
   case Type::Complex:
   case Type::Atomic:
   // FIXME: GCC treats block pointers as fundamental types?!
@@ -3458,6 +3459,7 @@ llvm::Constant *ItaniumRTTIBuilder::BuildTypeInfo(
   case Type::Builtin:
   case Type::Vector:
   case Type::ExtVector:
+  case Type::ConstantMatrix:
   case Type::Complex:
   case Type::BlockPointer:
     // Itanium C++ ABI 2.9.5p4:

diff  --git a/clang/lib/Driver/ToolChains/Clang.cpp b/clang/lib/Driver/ToolChains/Clang.cpp
index 42d5af71c23c..d6f05bea4a29 100644
--- a/clang/lib/Driver/ToolChains/Clang.cpp
+++ b/clang/lib/Driver/ToolChains/Clang.cpp
@@ -4566,6 +4566,13 @@ void Clang::ConstructJob(Compilation &C, const JobAction &JA,
   if (Args.hasFlag(options::OPT_mrtd, options::OPT_mno_rtd, false))
     CmdArgs.push_back("-fdefault-calling-conv=stdcall");
 
+  if (Args.hasArg(options::OPT_fenable_matrix)) {
+    // enable-matrix is needed by both the LangOpts and by LLVM.
+    CmdArgs.push_back("-fenable-matrix");
+    CmdArgs.push_back("-mllvm");
+    CmdArgs.push_back("-enable-matrix");
+  }
+
   CodeGenOptions::FramePointerKind FPKeepKind =
                   getFramePointerKind(Args, RawTriple);
   const char *FPKeepKindStr = nullptr;

diff  --git a/clang/lib/Frontend/CompilerInvocation.cpp b/clang/lib/Frontend/CompilerInvocation.cpp
index c5166ad41160..ec415472ea3c 100644
--- a/clang/lib/Frontend/CompilerInvocation.cpp
+++ b/clang/lib/Frontend/CompilerInvocation.cpp
@@ -3337,6 +3337,8 @@ static void ParseLangArgs(LangOptions &Opts, ArgList &Args, InputKind IK,
   Opts.CompleteMemberPointers = Args.hasArg(OPT_fcomplete_member_pointers);
   Opts.BuildingPCHWithObjectFile = Args.hasArg(OPT_building_pch_with_obj);
 
+  Opts.MatrixTypes = Args.hasArg(OPT_fenable_matrix);
+
   Opts.MaxTokens = getLastArgIntValue(Args, OPT_fmax_tokens_EQ, 0, Diags);
 
   if (Arg *A = Args.getLastArg(OPT_msign_return_address_EQ)) {

diff  --git a/clang/lib/Sema/SemaExpr.cpp b/clang/lib/Sema/SemaExpr.cpp
index 948880e1c136..6f25f71c77dd 100644
--- a/clang/lib/Sema/SemaExpr.cpp
+++ b/clang/lib/Sema/SemaExpr.cpp
@@ -4257,6 +4257,7 @@ static void captureVariablyModifiedType(ASTContext &Context, QualType T,
     case Type::Complex:
     case Type::Vector:
     case Type::ExtVector:
+    case Type::ConstantMatrix:
     case Type::Record:
     case Type::Enum:
     case Type::Elaborated:

diff  --git a/clang/lib/Sema/SemaLookup.cpp b/clang/lib/Sema/SemaLookup.cpp
index aa83f82a1ce8..dbbd190caea9 100644
--- a/clang/lib/Sema/SemaLookup.cpp
+++ b/clang/lib/Sema/SemaLookup.cpp
@@ -2966,6 +2966,7 @@ addAssociatedClassesAndNamespaces(AssociatedLookup &Result, QualType Ty) {
     // These are fundamental types.
     case Type::Vector:
     case Type::ExtVector:
+    case Type::ConstantMatrix:
     case Type::Complex:
     case Type::ExtInt:
       break;

diff  --git a/clang/lib/Sema/SemaTemplate.cpp b/clang/lib/Sema/SemaTemplate.cpp
old mode 100755
new mode 100644
index 66b8e4d97c07..c2324f976eba
--- a/clang/lib/Sema/SemaTemplate.cpp
+++ b/clang/lib/Sema/SemaTemplate.cpp
@@ -5867,6 +5867,11 @@ bool UnnamedLocalNoLinkageFinder::VisitDependentSizedExtVectorType(
   return Visit(T->getElementType());
 }
 
+bool UnnamedLocalNoLinkageFinder::VisitDependentSizedMatrixType(
+    const DependentSizedMatrixType *T) {
+  return Visit(T->getElementType());
+}
+
 bool UnnamedLocalNoLinkageFinder::VisitDependentAddressSpaceType(
     const DependentAddressSpaceType *T) {
   return Visit(T->getPointeeType());
@@ -5885,6 +5890,11 @@ bool UnnamedLocalNoLinkageFinder::VisitExtVectorType(const ExtVectorType* T) {
   return Visit(T->getElementType());
 }
 
+bool UnnamedLocalNoLinkageFinder::VisitConstantMatrixType(
+    const ConstantMatrixType *T) {
+  return Visit(T->getElementType());
+}
+
 bool UnnamedLocalNoLinkageFinder::VisitFunctionProtoType(
                                                   const FunctionProtoType* T) {
   for (const auto &A : T->param_types()) {

diff  --git a/clang/lib/Sema/SemaTemplateDeduction.cpp b/clang/lib/Sema/SemaTemplateDeduction.cpp
index e1d438fcb724..19f8248db6bf 100644
--- a/clang/lib/Sema/SemaTemplateDeduction.cpp
+++ b/clang/lib/Sema/SemaTemplateDeduction.cpp
@@ -2055,6 +2055,101 @@ DeduceTemplateArgumentsByTypeMatch(Sema &S,
       return Sema::TDK_NonDeducedMismatch;
     }
 
+    //     (clang extension)
+    //
+    //     T __attribute__((matrix_type(<integral constant>,
+    //                                  <integral constant>)))
+    case Type::ConstantMatrix: {
+      const ConstantMatrixType *MatrixArg = dyn_cast<ConstantMatrixType>(Arg);
+      if (!MatrixArg)
+        return Sema::TDK_NonDeducedMismatch;
+
+      const ConstantMatrixType *MatrixParam = cast<ConstantMatrixType>(Param);
+      // Check that the dimensions are the same
+      if (MatrixParam->getNumRows() != MatrixArg->getNumRows() ||
+          MatrixParam->getNumColumns() != MatrixArg->getNumColumns()) {
+        return Sema::TDK_NonDeducedMismatch;
+      }
+      // Perform deduction on element types.
+      return DeduceTemplateArgumentsByTypeMatch(
+          S, TemplateParams, MatrixParam->getElementType(),
+          MatrixArg->getElementType(), Info, Deduced, TDF);
+    }
+
+    case Type::DependentSizedMatrix: {
+      const MatrixType *MatrixArg = dyn_cast<MatrixType>(Arg);
+      if (!MatrixArg)
+        return Sema::TDK_NonDeducedMismatch;
+
+      // Check the element types of the matrices.
+      const DependentSizedMatrixType *MatrixParam =
+          cast<DependentSizedMatrixType>(Param);
+      if (Sema::TemplateDeductionResult Result =
+              DeduceTemplateArgumentsByTypeMatch(
+                  S, TemplateParams, MatrixParam->getElementType(),
+                  MatrixArg->getElementType(), Info, Deduced, TDF))
+        return Result;
+
+      // Try to deduce a matrix dimension.
+      auto DeduceMatrixArg =
+          [&S, &Info, &Deduced, &TemplateParams](
+              Expr *ParamExpr, const MatrixType *Arg,
+              unsigned (ConstantMatrixType::*GetArgDimension)() const,
+              Expr *(DependentSizedMatrixType::*GetArgDimensionExpr)() const) {
+            const auto *ArgConstMatrix = dyn_cast<ConstantMatrixType>(Arg);
+            const auto *ArgDepMatrix = dyn_cast<DependentSizedMatrixType>(Arg);
+            if (!ParamExpr->isValueDependent()) {
+              llvm::APSInt ParamConst(
+                  S.Context.getTypeSize(S.Context.getSizeType()));
+              if (!ParamExpr->isIntegerConstantExpr(ParamConst, S.Context))
+                return Sema::TDK_NonDeducedMismatch;
+
+              if (ArgConstMatrix) {
+                if ((ArgConstMatrix->*GetArgDimension)() == ParamConst)
+                  return Sema::TDK_Success;
+                return Sema::TDK_NonDeducedMismatch;
+              }
+
+              Expr *ArgExpr = (ArgDepMatrix->*GetArgDimensionExpr)();
+              llvm::APSInt ArgConst(
+                  S.Context.getTypeSize(S.Context.getSizeType()));
+              if (!ArgExpr->isValueDependent() &&
+                  ArgExpr->isIntegerConstantExpr(ArgConst, S.Context) &&
+                  ArgConst == ParamConst)
+                return Sema::TDK_Success;
+              return Sema::TDK_NonDeducedMismatch;
+            }
+
+            NonTypeTemplateParmDecl *NTTP =
+                getDeducedParameterFromExpr(Info, ParamExpr);
+            if (!NTTP)
+              return Sema::TDK_Success;
+
+            if (ArgConstMatrix) {
+              llvm::APSInt ArgConst(
+                  S.Context.getTypeSize(S.Context.getSizeType()));
+              ArgConst = (ArgConstMatrix->*GetArgDimension)();
+              return DeduceNonTypeTemplateArgument(
+                  S, TemplateParams, NTTP, ArgConst, S.Context.getSizeType(),
+                  /*ArrayBound=*/true, Info, Deduced);
+            }
+
+            return DeduceNonTypeTemplateArgument(
+                S, TemplateParams, NTTP, (ArgDepMatrix->*GetArgDimensionExpr)(),
+                Info, Deduced);
+          };
+
+      auto Result = DeduceMatrixArg(MatrixParam->getRowExpr(), MatrixArg,
+                                    &ConstantMatrixType::getNumRows,
+                                    &DependentSizedMatrixType::getRowExpr);
+      if (Result)
+        return Result;
+
+      return DeduceMatrixArg(MatrixParam->getColumnExpr(), MatrixArg,
+                             &ConstantMatrixType::getNumColumns,
+                             &DependentSizedMatrixType::getColumnExpr);
+    }
+
     //     (clang extension)
     //
     //     T __attribute__(((address_space(N))))
@@ -5723,6 +5818,24 @@ MarkUsedTemplateParameters(ASTContext &Ctx, QualType T,
     break;
   }
 
+  case Type::ConstantMatrix: {
+    const ConstantMatrixType *MatType = cast<ConstantMatrixType>(T);
+    MarkUsedTemplateParameters(Ctx, MatType->getElementType(), OnlyDeduced,
+                               Depth, Used);
+    break;
+  }
+
+  case Type::DependentSizedMatrix: {
+    const DependentSizedMatrixType *MatType = cast<DependentSizedMatrixType>(T);
+    MarkUsedTemplateParameters(Ctx, MatType->getElementType(), OnlyDeduced,
+                               Depth, Used);
+    MarkUsedTemplateParameters(Ctx, MatType->getRowExpr(), OnlyDeduced, Depth,
+                               Used);
+    MarkUsedTemplateParameters(Ctx, MatType->getColumnExpr(), OnlyDeduced,
+                               Depth, Used);
+    break;
+  }
+
   case Type::FunctionProto: {
     const FunctionProtoType *Proto = cast<FunctionProtoType>(T);
     MarkUsedTemplateParameters(Ctx, Proto->getReturnType(), OnlyDeduced, Depth,

diff  --git a/clang/lib/Sema/SemaType.cpp b/clang/lib/Sema/SemaType.cpp
index 338e33589b1d..df8ad7ad78b9 100644
--- a/clang/lib/Sema/SemaType.cpp
+++ b/clang/lib/Sema/SemaType.cpp
@@ -2492,14 +2492,15 @@ QualType Sema::BuildVectorType(QualType CurType, Expr *SizeExpr,
   if (!VecSize.isIntN(61)) {
     // Bit size will overflow uint64.
     Diag(AttrLoc, diag::err_attribute_size_too_large)
-        << SizeExpr->getSourceRange();
+        << SizeExpr->getSourceRange() << "vector";
     return QualType();
   }
   uint64_t VectorSizeBits = VecSize.getZExtValue() * 8;
   unsigned TypeSize = static_cast<unsigned>(Context.getTypeSize(CurType));
 
   if (VectorSizeBits == 0) {
-    Diag(AttrLoc, diag::err_attribute_zero_size) << SizeExpr->getSourceRange();
+    Diag(AttrLoc, diag::err_attribute_zero_size)
+        << SizeExpr->getSourceRange() << "vector";
     return QualType();
   }
 
@@ -2511,7 +2512,7 @@ QualType Sema::BuildVectorType(QualType CurType, Expr *SizeExpr,
 
   if (VectorSizeBits / TypeSize > std::numeric_limits<uint32_t>::max()) {
     Diag(AttrLoc, diag::err_attribute_size_too_large)
-        << SizeExpr->getSourceRange();
+        << SizeExpr->getSourceRange() << "vector";
     return QualType();
   }
 
@@ -2549,7 +2550,7 @@ QualType Sema::BuildExtVectorType(QualType T, Expr *ArraySize,
 
     if (!vecSize.isIntN(32)) {
       Diag(AttrLoc, diag::err_attribute_size_too_large)
-          << ArraySize->getSourceRange();
+          << ArraySize->getSourceRange() << "vector";
       return QualType();
     }
     // Unlike gcc's vector_size attribute, the size is specified as the
@@ -2558,7 +2559,7 @@ QualType Sema::BuildExtVectorType(QualType T, Expr *ArraySize,
 
     if (vectorSize == 0) {
       Diag(AttrLoc, diag::err_attribute_zero_size)
-      << ArraySize->getSourceRange();
+          << ArraySize->getSourceRange() << "vector";
       return QualType();
     }
 
@@ -2568,6 +2569,84 @@ QualType Sema::BuildExtVectorType(QualType T, Expr *ArraySize,
   return Context.getDependentSizedExtVectorType(T, ArraySize, AttrLoc);
 }
 
+QualType Sema::BuildMatrixType(QualType ElementTy, Expr *NumRows, Expr *NumCols,
+                               SourceLocation AttrLoc) {
+  assert(Context.getLangOpts().MatrixTypes &&
+         "Should never build a matrix type when it is disabled");
+
+  // Check the element type first (unless it is dependent), so an invalid
+  // element type is diagnosed even while the dimensions are still dependent.
+  if (!ElementTy->isDependentType() &&
+      !MatrixType::isValidElementType(ElementTy)) {
+    Diag(AttrLoc, diag::err_attribute_invalid_matrix_type) << ElementTy;
+    return QualType();
+  }
+
+  if (NumRows->isTypeDependent() || NumCols->isTypeDependent() ||
+      NumRows->isValueDependent() || NumCols->isValueDependent())
+    return Context.getDependentSizedMatrixType(ElementTy, NumRows, NumCols,
+                                               AttrLoc);
+
+  // Both dimensions must be integer constant expressions.
+  llvm::APSInt ValueRows(32), ValueColumns(32);
+  bool const RowsIsInteger = NumRows->isIntegerConstantExpr(ValueRows, Context);
+  bool const ColumnsIsInteger =
+      NumCols->isIntegerConstantExpr(ValueColumns, Context);
+  auto const RowRange = NumRows->getSourceRange();
+  auto const ColRange = NumCols->getSourceRange();
+
+  // Both the row and the column expressions are invalid.
+  if (!RowsIsInteger && !ColumnsIsInteger) {
+    Diag(AttrLoc, diag::err_attribute_argument_type)
+        << "matrix_type" << AANT_ArgumentIntegerConstant << RowRange
+        << ColRange;
+    return QualType();
+  }
+
+  // Only the row expression is invalid.
+  if (!RowsIsInteger) {
+    Diag(AttrLoc, diag::err_attribute_argument_type)
+        << "matrix_type" << AANT_ArgumentIntegerConstant << RowRange;
+    return QualType();
+  }
+
+  // Only the column expression is invalid.
+  if (!ColumnsIsInteger) {
+    Diag(AttrLoc, diag::err_attribute_argument_type)
+        << "matrix_type" << AANT_ArgumentIntegerConstant << ColRange;
+    return QualType();
+  }
+
+  // Check the matrix dimensions on the full (saturated 64-bit) values;
+  // narrowing to 'unsigned' first would wrap e.g. 2^32 + 1 to 1 and accept it.
+  uint64_t MatrixRows = ValueRows.getLimitedValue();
+  uint64_t MatrixColumns = ValueColumns.getLimitedValue();
+  if (MatrixRows == 0 && MatrixColumns == 0) {
+    Diag(AttrLoc, diag::err_attribute_zero_size)
+        << "matrix" << RowRange << ColRange;
+    return QualType();
+  }
+  if (MatrixRows == 0) {
+    Diag(AttrLoc, diag::err_attribute_zero_size) << "matrix" << RowRange;
+    return QualType();
+  }
+  if (MatrixColumns == 0) {
+    Diag(AttrLoc, diag::err_attribute_zero_size) << "matrix" << ColRange;
+    return QualType();
+  }
+  if (!ConstantMatrixType::isDimensionValid(MatrixRows)) {
+    Diag(AttrLoc, diag::err_attribute_size_too_large)
+        << RowRange << "matrix row";
+    return QualType();
+  }
+  if (!ConstantMatrixType::isDimensionValid(MatrixColumns)) {
+    Diag(AttrLoc, diag::err_attribute_size_too_large)
+        << ColRange << "matrix column";
+    return QualType();
+  }
+  return Context.getConstantMatrixType(ElementTy, MatrixRows, MatrixColumns);
+}
+
 bool Sema::CheckFunctionReturnType(QualType T, SourceLocation Loc) {
   if (T->isArrayType() || T->isFunctionType()) {
     Diag(Loc, diag::err_func_returning_array_function)
@@ -6013,6 +6092,21 @@ fillDependentAddressSpaceTypeLoc(DependentAddressSpaceTypeLoc DASTL,
       "no address_space attribute found at the expected location!");
 }
 
+static void fillMatrixTypeLoc(MatrixTypeLoc MTL,
+                              const ParsedAttributesView &Attrs) {
+  for (const ParsedAttr &PA : Attrs) {
+    if (PA.getKind() != ParsedAttr::AT_MatrixType)
+      continue;
+    MTL.setAttrNameLoc(PA.getLoc());
+    MTL.setAttrRowOperand(PA.getArgAsExpr(0));
+    MTL.setAttrColumnOperand(PA.getArgAsExpr(1));
+    MTL.setAttrOperandParensRange(SourceRange());
+    return;
+  }
+  // The chunk that produced a MatrixTypeLoc must carry a matrix_type attr.
+  llvm_unreachable("no matrix_type attribute found at the expected location!");
+}
+
 /// Create and instantiate a TypeSourceInfo with type source information.
 ///
 /// \param T QualType referring to the type as written in source code.
@@ -6061,6 +6155,9 @@ GetTypeSourceInfoForDeclarator(TypeProcessingState &State,
       CurrTL = TL.getPointeeTypeLoc().getUnqualifiedLoc();
     }
 
+    if (MatrixTypeLoc TL = CurrTL.getAs<MatrixTypeLoc>())
+      fillMatrixTypeLoc(TL, D.getTypeObject(i).getAttrs());
+
     // FIXME: Ordering here?
     while (AdjustedTypeLoc TL = CurrTL.getAs<AdjustedTypeLoc>())
       CurrTL = TL.getNextTypeLoc().getUnqualifiedLoc();
@@ -7706,6 +7803,68 @@ static void HandleOpenCLAccessAttr(QualType &CurType, const ParsedAttr &Attr,
   }
 }
 
+/// HandleMatrixTypeAttr - "matrix_type(rows, columns)" attribute, like
+/// ext_vector_type. On error, diagnoses and leaves CurType unchanged.
+static void HandleMatrixTypeAttr(QualType &CurType, const ParsedAttr &Attr,
+                                 Sema &S) {
+  if (!S.getLangOpts().MatrixTypes) {
+    S.Diag(Attr.getLoc(), diag::err_builtin_matrix_disabled);
+    return;
+  }
+
+  if (Attr.getNumArgs() != 2) {
+    S.Diag(Attr.getLoc(), diag::err_attribute_wrong_number_arguments)
+        << Attr << 2;
+    return;
+  }
+
+  Expr *RowsExpr = nullptr;
+  Expr *ColsExpr = nullptr;
+
+  // TODO: Refactor parameter extraction into separate function
+  // Get the number of rows (argument 0).
+  if (Attr.isArgIdent(0)) {
+    CXXScopeSpec SS;
+    SourceLocation TemplateKeywordLoc;
+    UnqualifiedId id;
+    id.setIdentifier(Attr.getArgAsIdent(0)->Ident, Attr.getLoc());
+    ExprResult Rows = S.ActOnIdExpression(S.getCurScope(), SS,
+                                          TemplateKeywordLoc, id, false, false);
+    if (Rows.isInvalid())
+      // TODO: maybe a good error message would be nice here
+      return;
+    RowsExpr = Rows.get();
+  } else {
+    assert(Attr.isArgExpr(0) &&
+           "Argument should either be an identifier or expression");
+    RowsExpr = Attr.getArgAsExpr(0);
+  }
+
+  // Get the number of columns (argument 1).
+  if (Attr.isArgIdent(1)) {
+    CXXScopeSpec SS;
+    SourceLocation TemplateKeywordLoc;
+    UnqualifiedId id;
+    id.setIdentifier(Attr.getArgAsIdent(1)->Ident, Attr.getLoc());
+    ExprResult Columns = S.ActOnIdExpression(
+        S.getCurScope(), SS, TemplateKeywordLoc, id, false, false);
+    if (Columns.isInvalid())
+      // TODO: a good error message would be nice here
+      return;
+    // Store into ColsExpr; BuildMatrixType dereferences both operands.
+    ColsExpr = Columns.get();
+  } else {
+    assert(Attr.isArgExpr(1) &&
+           "Argument should either be an identifier or expression");
+    ColsExpr = Attr.getArgAsExpr(1);
+  }
+
+  // Create the matrix type.
+  QualType T = S.BuildMatrixType(CurType, RowsExpr, ColsExpr, Attr.getLoc());
+  if (!T.isNull())
+    CurType = T;
+}
+
 static void HandleLifetimeBoundAttr(TypeProcessingState &State,
                                     QualType &CurType,
                                     ParsedAttr &Attr) {
@@ -7857,6 +8016,11 @@ static void processTypeAttrs(TypeProcessingState &state, QualType &type,
       break;
     }
 
+    case ParsedAttr::AT_MatrixType:
+      HandleMatrixTypeAttr(type, attr, state.getSema());
+      attr.setUsedAsTypeAttr();
+      break;
+
     MS_TYPE_ATTRS_CASELIST:
       if (!handleMSPointerTypeQualifierAttr(state, attr, type))
         attr.setUsedAsTypeAttr();

diff  --git a/clang/lib/Sema/TreeTransform.h b/clang/lib/Sema/TreeTransform.h
index 0987fee29182..99f1784e4625 100644
--- a/clang/lib/Sema/TreeTransform.h
+++ b/clang/lib/Sema/TreeTransform.h
@@ -894,6 +894,16 @@ class TreeTransform {
                                               Expr *SizeExpr,
                                               SourceLocation AttributeLoc);
 
+  /// Build a new constant matrix type given the element type and dimensions.
+  QualType RebuildConstantMatrixType(QualType ElementType, unsigned NumRows,
+                                     unsigned NumColumns);
+
+  /// Build a new matrix type given the element type and row/column
+  /// expressions; the result may be constant if they are no longer dependent.
+  QualType RebuildDependentSizedMatrixType(QualType ElementType, Expr *RowExpr,
+                                           Expr *ColumnExpr,
+                                           SourceLocation AttributeLoc);
+
   /// Build a new DependentAddressSpaceType or return the pointee
   /// type variable with the correct address space (retrieved from
   /// AddrSpaceExpr) applied to it. The former will be returned in cases
@@ -5179,6 +5189,86 @@ QualType TreeTransform<Derived>::TransformDependentSizedExtVectorType(
   return Result;
 }
 
+template <typename Derived>
+QualType
+TreeTransform<Derived>::TransformConstantMatrixType(TypeLocBuilder &TLB,
+                                                    ConstantMatrixTypeLoc TL) {
+  const ConstantMatrixType *MatTy = TL.getTypePtr();
+  QualType OldElemTy = MatTy->getElementType();
+  QualType NewElemTy = getDerived().TransformType(OldElemTy);
+  if (NewElemTy.isNull())
+    return QualType();
+
+  // The dimensions are already constants; only the element type can change.
+  QualType Result = TL.getType();
+  if (getDerived().AlwaysRebuild() || NewElemTy != OldElemTy) {
+    Result = getDerived().RebuildConstantMatrixType(
+        NewElemTy, MatTy->getNumRows(), MatTy->getNumColumns());
+    if (Result.isNull())
+      return QualType();
+  }
+  ConstantMatrixTypeLoc NewTL = TLB.push<ConstantMatrixTypeLoc>(Result);
+  NewTL.setAttrNameLoc(TL.getAttrNameLoc());
+  NewTL.setAttrOperandParensRange(TL.getAttrOperandParensRange());
+  NewTL.setAttrRowOperand(TL.getAttrRowOperand());
+  NewTL.setAttrColumnOperand(TL.getAttrColumnOperand());
+  return Result;
+}
+
+template <typename Derived>
+QualType TreeTransform<Derived>::TransformDependentSizedMatrixType(
+    TypeLocBuilder &TLB, DependentSizedMatrixTypeLoc TL) {
+  const DependentSizedMatrixType *MatTy = TL.getTypePtr();
+
+  QualType NewElemTy = getDerived().TransformType(MatTy->getElementType());
+  if (NewElemTy.isNull())
+    return QualType();
+
+  // Matrix dimensions are constant expressions; evaluate them as such.
+  EnterExpressionEvaluationContext ConstantContext(
+      SemaRef, Sema::ExpressionEvaluationContext::ConstantEvaluated);
+
+  // Prefer the TypeLoc's operands, falling back to the type's own exprs.
+  Expr *OrigRows = TL.getAttrRowOperand();
+  if (!OrigRows)
+    OrigRows = MatTy->getRowExpr();
+  Expr *OrigCols = TL.getAttrColumnOperand();
+  if (!OrigCols)
+    OrigCols = MatTy->getColumnExpr();
+
+  ExprResult RowsRes =
+      SemaRef.ActOnConstantExpression(getDerived().TransformExpr(OrigRows));
+  if (RowsRes.isInvalid())
+    return QualType();
+
+  ExprResult ColsRes =
+      SemaRef.ActOnConstantExpression(getDerived().TransformExpr(OrigCols));
+  if (ColsRes.isInvalid())
+    return QualType();
+
+  Expr *NewRows = RowsRes.get();
+  Expr *NewCols = ColsRes.get();
+  bool NeedsRebuild = getDerived().AlwaysRebuild() ||
+                      NewElemTy != MatTy->getElementType() ||
+                      NewRows != OrigRows || NewCols != OrigCols;
+  QualType Result = TL.getType();
+  if (NeedsRebuild) {
+    Result = getDerived().RebuildDependentSizedMatrixType(
+        NewElemTy, NewRows, NewCols, MatTy->getAttributeLoc());
+    if (Result.isNull())
+      return QualType();
+  }
+
+  // Rebuilding may yield either a constant or a dependent matrix type; both
+  // share the MatrixTypeLoc layout, so push the common base.
+  MatrixTypeLoc NewTL = TLB.push<MatrixTypeLoc>(Result);
+  NewTL.setAttrNameLoc(TL.getAttrNameLoc());
+  NewTL.setAttrOperandParensRange(TL.getAttrOperandParensRange());
+  NewTL.setAttrRowOperand(NewRows);
+  NewTL.setAttrColumnOperand(NewCols);
+  return Result;
+}
+
 template <typename Derived>
 QualType TreeTransform<Derived>::TransformDependentAddressSpaceType(
     TypeLocBuilder &TLB, DependentAddressSpaceTypeLoc TL) {
@@ -13750,6 +13840,21 @@ TreeTransform<Derived>::RebuildDependentSizedExtVectorType(QualType ElementType,
   return SemaRef.BuildExtVectorType(ElementType, SizeExpr, AttributeLoc);
 }
 
+template <typename Derived>
+QualType TreeTransform<Derived>::RebuildConstantMatrixType(
+    QualType ElementType, unsigned NumRows, unsigned NumColumns) {
+  ASTContext &Ctx = getSema().Context;
+  return Ctx.getConstantMatrixType(ElementType, NumRows, NumColumns);
+}
+
+template <typename Derived>
+QualType TreeTransform<Derived>::RebuildDependentSizedMatrixType(
+    QualType ElementType, Expr *RowExpr, Expr *ColumnExpr,
+    SourceLocation AttributeLoc) {
+  return getSema().BuildMatrixType(ElementType, RowExpr, ColumnExpr,
+                                   AttributeLoc);
+}
+
 template<typename Derived>
 QualType TreeTransform<Derived>::RebuildFunctionProtoType(
     QualType T,

diff  --git a/clang/lib/Serialization/ASTReader.cpp b/clang/lib/Serialization/ASTReader.cpp
index 187665be5255..3f4164690c78 100644
--- a/clang/lib/Serialization/ASTReader.cpp
+++ b/clang/lib/Serialization/ASTReader.cpp
@@ -6554,6 +6554,21 @@ void TypeLocReader::VisitExtVectorTypeLoc(ExtVectorTypeLoc TL) {
   TL.setNameLoc(readSourceLocation());
 }
 
+void TypeLocReader::VisitConstantMatrixTypeLoc(ConstantMatrixTypeLoc TL) {
+  TL.setAttrNameLoc(readSourceLocation()); // Same order as TypeLocWriter.
+  TL.setAttrOperandParensRange(Reader.readSourceRange());
+  TL.setAttrRowOperand(Reader.readExpr());
+  TL.setAttrColumnOperand(Reader.readExpr());
+}
+
+void TypeLocReader::VisitDependentSizedMatrixTypeLoc(
+    DependentSizedMatrixTypeLoc TL) {
+  TL.setAttrNameLoc(readSourceLocation()); // Same order as TypeLocWriter.
+  TL.setAttrOperandParensRange(Reader.readSourceRange());
+  TL.setAttrRowOperand(Reader.readExpr());
+  TL.setAttrColumnOperand(Reader.readExpr());
+}
+
 void TypeLocReader::VisitFunctionTypeLoc(FunctionTypeLoc TL) {
   TL.setLocalRangeBegin(readSourceLocation());
   TL.setLParenLoc(readSourceLocation());

diff  --git a/clang/lib/Serialization/ASTWriter.cpp b/clang/lib/Serialization/ASTWriter.cpp
index b2281abe5ba5..f4e54e4358f9 100644
--- a/clang/lib/Serialization/ASTWriter.cpp
+++ b/clang/lib/Serialization/ASTWriter.cpp
@@ -288,6 +288,25 @@ void TypeLocWriter::VisitExtVectorTypeLoc(ExtVectorTypeLoc TL) {
   Record.AddSourceLocation(TL.getNameLoc());
 }
 
+void TypeLocWriter::VisitConstantMatrixTypeLoc(ConstantMatrixTypeLoc TL) {
+  // Field order must match TypeLocReader::VisitConstantMatrixTypeLoc:
+  // attribute name, operand parens range, row operand, column operand.
+  Record.AddSourceLocation(TL.getAttrNameLoc());
+  Record.AddSourceRange(TL.getAttrOperandParensRange());
+  Record.AddStmt(TL.getAttrRowOperand());
+  Record.AddStmt(TL.getAttrColumnOperand());
+}
+
+void TypeLocWriter::VisitDependentSizedMatrixTypeLoc(
+    DependentSizedMatrixTypeLoc TL) {
+  // Field order must match TypeLocReader::VisitDependentSizedMatrixTypeLoc:
+  // attribute name, operand parens range, row operand, column operand.
+  Record.AddSourceLocation(TL.getAttrNameLoc());
+  Record.AddSourceRange(TL.getAttrOperandParensRange());
+  Record.AddStmt(TL.getAttrRowOperand());
+  Record.AddStmt(TL.getAttrColumnOperand());
+}
+
 void TypeLocWriter::VisitFunctionTypeLoc(FunctionTypeLoc TL) {
   Record.AddSourceLocation(TL.getLocalRangeBegin());
   Record.AddSourceLocation(TL.getLParenLoc());

diff  --git a/clang/test/CodeGen/debug-info-matrix-types.c b/clang/test/CodeGen/debug-info-matrix-types.c
new file mode 100644
index 000000000000..4c0bbfc516e1
--- /dev/null
+++ b/clang/test/CodeGen/debug-info-matrix-types.c
@@ -0,0 +1,19 @@
+// RUN: %clang_cc1 -fenable-matrix -triple x86_64-apple-darwin %s -debug-info-kind=limited -emit-llvm -disable-llvm-passes -o - | FileCheck %s
+
+typedef double dx2x3_t __attribute__((matrix_type(2, 3)));
+// dx2x3_t: 2 rows x 3 columns of double, i.e. 6 elements / 384 bits.
+void load_store_double(dx2x3_t *a, dx2x3_t *b) {
+  // CHECK-DAG:  @llvm.dbg.declare(metadata [6 x double]** %a.addr, metadata [[EXPR_A:![0-9]+]]
+  // CHECK-DAG:  @llvm.dbg.declare(metadata [6 x double]** %b.addr, metadata [[EXPR_B:![0-9]+]]
+  // CHECK: [[PTR_TY:![0-9]+]] = !DIDerivedType(tag: DW_TAG_pointer_type, baseType: [[TYPEDEF:![0-9]+]], size: 64)
+  // CHECK: [[TYPEDEF]] = !DIDerivedType(tag: DW_TAG_typedef, name: "dx2x3_t", {{.+}} baseType: [[MATRIX_TY:![0-9]+]])
+  // CHECK: [[MATRIX_TY]] = !DICompositeType(tag: DW_TAG_array_type, baseType: [[ELT_TY:![0-9]+]], size: 384, elements: [[ELEMENTS:![0-9]+]])
+  // CHECK: [[ELT_TY]] = !DIBasicType(name: "double", size: 64, encoding: DW_ATE_float)
+  // CHECK: [[ELEMENTS]] = !{[[COLS:![0-9]+]], [[ROWS:![0-9]+]]}
+  // CHECK: [[COLS]] = !DISubrange(count: 3)
+  // CHECK: [[ROWS]] = !DISubrange(count: 2)
+  // CHECK: [[EXPR_A]] = !DILocalVariable(name: "a", arg: 1, {{.+}} type: [[PTR_TY]])
+  // CHECK: [[EXPR_B]] = !DILocalVariable(name: "b", arg: 2, {{.+}} type: [[PTR_TY]])
+  // The directives above cover only debug-info intrinsics and metadata.
+  *a = *b;
+}

diff  --git a/clang/test/CodeGen/matrix-type.c b/clang/test/CodeGen/matrix-type.c
new file mode 100644
index 000000000000..31d2497b7986
--- /dev/null
+++ b/clang/test/CodeGen/matrix-type.c
@@ -0,0 +1,158 @@
+// RUN: %clang_cc1 -fenable-matrix -triple x86_64-apple-darwin %s -emit-llvm -disable-llvm-passes -o - | FileCheck %s
+
+#if !__has_extension(matrix_types)
+#error Expected extension 'matrix_types' to be enabled
+#endif // matrix_types is expected to be enabled by -fenable-matrix.
+
+typedef double dx5x5_t __attribute__((matrix_type(5, 5)));
+
+// CHECK: %struct.Matrix = type { i8, [12 x float], float }
+
+void load_store_double(dx5x5_t *a, dx5x5_t *b) {
+  // CHECK-LABEL:  define void @load_store_double(
+  // CHECK-NEXT:  entry:
+  // CHECK-NEXT:    %a.addr = alloca [25 x double]*, align 8
+  // CHECK-NEXT:    %b.addr = alloca [25 x double]*, align 8
+  // CHECK-NEXT:    store [25 x double]* %a, [25 x double]** %a.addr, align 8
+  // CHECK-NEXT:    store [25 x double]* %b, [25 x double]** %b.addr, align 8
+  // CHECK-NEXT:    %0 = load [25 x double]*, [25 x double]** %b.addr, align 8
+  // CHECK-NEXT:    %1 = bitcast [25 x double]* %0 to <25 x double>*
+  // CHECK-NEXT:    %2 = load <25 x double>, <25 x double>* %1, align 8
+  // CHECK-NEXT:    %3 = load [25 x double]*, [25 x double]** %a.addr, align 8
+  // CHECK-NEXT:    %4 = bitcast [25 x double]* %3 to <25 x double>*
+  // CHECK-NEXT:    store <25 x double> %2, <25 x double>* %4, align 8
+  // CHECK-NEXT:   ret void
+  // Lowered as a vector load/store that keeps the element alignment.
+  *a = *b;
+}
+
+typedef float fx3x4_t __attribute__((matrix_type(3, 4)));
+void load_store_float(fx3x4_t *a, fx3x4_t *b) {
+  // CHECK-LABEL:  define void @load_store_float(
+  // CHECK-NEXT:  entry:
+  // CHECK-NEXT:    %a.addr = alloca [12 x float]*, align 8
+  // CHECK-NEXT:    %b.addr = alloca [12 x float]*, align 8
+  // CHECK-NEXT:    store [12 x float]* %a, [12 x float]** %a.addr, align 8
+  // CHECK-NEXT:    store [12 x float]* %b, [12 x float]** %b.addr, align 8
+  // CHECK-NEXT:    %0 = load [12 x float]*, [12 x float]** %b.addr, align 8
+  // CHECK-NEXT:    %1 = bitcast [12 x float]* %0 to <12 x float>*
+  // CHECK-NEXT:    %2 = load <12 x float>, <12 x float>* %1, align 4
+  // CHECK-NEXT:    %3 = load [12 x float]*, [12 x float]** %a.addr, align 8
+  // CHECK-NEXT:    %4 = bitcast [12 x float]* %3 to <12 x float>*
+  // CHECK-NEXT:    store <12 x float> %2, <12 x float>* %4, align 4
+  // CHECK-NEXT:   ret void
+  // The vector accesses use float's 4-byte alignment, not vector alignment.
+  *a = *b;
+}
+
+typedef int ix4x3_t __attribute__((matrix_type(4, 3)));
+void load_store_int(ix4x3_t *a, ix4x3_t *b) {
+  // CHECK-LABEL:  define void @load_store_int(
+  // CHECK-NEXT:  entry:
+  // CHECK-NEXT:    %a.addr = alloca [12 x i32]*, align 8
+  // CHECK-NEXT:    %b.addr = alloca [12 x i32]*, align 8
+  // CHECK-NEXT:    store [12 x i32]* %a, [12 x i32]** %a.addr, align 8
+  // CHECK-NEXT:    store [12 x i32]* %b, [12 x i32]** %b.addr, align 8
+  // CHECK-NEXT:    %0 = load [12 x i32]*, [12 x i32]** %b.addr, align 8
+  // CHECK-NEXT:    %1 = bitcast [12 x i32]* %0 to <12 x i32>*
+  // CHECK-NEXT:    %2 = load <12 x i32>, <12 x i32>* %1, align 4
+  // CHECK-NEXT:    %3 = load [12 x i32]*, [12 x i32]** %a.addr, align 8
+  // CHECK-NEXT:    %4 = bitcast [12 x i32]* %3 to <12 x i32>*
+  // CHECK-NEXT:    store <12 x i32> %2, <12 x i32>* %4, align 4
+  // CHECK-NEXT:   ret void
+  // 4 rows x 3 columns: 12 elements, accessed as a <12 x i32> vector.
+  *a = *b;
+}
+
+typedef unsigned long long ullx4x3_t __attribute__((matrix_type(4, 3)));
+void load_store_ull(ullx4x3_t *a, ullx4x3_t *b) {
+  // CHECK-LABEL:  define void @load_store_ull(
+  // CHECK-NEXT:  entry:
+  // CHECK-NEXT:    %a.addr = alloca [12 x i64]*, align 8
+  // CHECK-NEXT:    %b.addr = alloca [12 x i64]*, align 8
+  // CHECK-NEXT:    store [12 x i64]* %a, [12 x i64]** %a.addr, align 8
+  // CHECK-NEXT:    store [12 x i64]* %b, [12 x i64]** %b.addr, align 8
+  // CHECK-NEXT:    %0 = load [12 x i64]*, [12 x i64]** %b.addr, align 8
+  // CHECK-NEXT:    %1 = bitcast [12 x i64]* %0 to <12 x i64>*
+  // CHECK-NEXT:    %2 = load <12 x i64>, <12 x i64>* %1, align 8
+  // CHECK-NEXT:    %3 = load [12 x i64]*, [12 x i64]** %a.addr, align 8
+  // CHECK-NEXT:    %4 = bitcast [12 x i64]* %3 to <12 x i64>*
+  // CHECK-NEXT:    store <12 x i64> %2, <12 x i64>* %4, align 8
+  // CHECK-NEXT:   ret void
+  // 4 rows x 3 columns of 64-bit integers, accessed as <12 x i64>.
+  *a = *b;
+}
+
+typedef __fp16 fp16x4x3_t __attribute__((matrix_type(4, 3)));
+void load_store_fp16(fp16x4x3_t *a, fp16x4x3_t *b) {
+  // CHECK-LABEL:  define void @load_store_fp16(
+  // CHECK-NEXT:  entry:
+  // CHECK-NEXT:    %a.addr = alloca [12 x half]*, align 8
+  // CHECK-NEXT:    %b.addr = alloca [12 x half]*, align 8
+  // CHECK-NEXT:    store [12 x half]* %a, [12 x half]** %a.addr, align 8
+  // CHECK-NEXT:    store [12 x half]* %b, [12 x half]** %b.addr, align 8
+  // CHECK-NEXT:    %0 = load [12 x half]*, [12 x half]** %b.addr, align 8
+  // CHECK-NEXT:    %1 = bitcast [12 x half]* %0 to <12 x half>*
+  // CHECK-NEXT:    %2 = load <12 x half>, <12 x half>* %1, align 2
+  // CHECK-NEXT:    %3 = load [12 x half]*, [12 x half]** %a.addr, align 8
+  // CHECK-NEXT:    %4 = bitcast [12 x half]* %3 to <12 x half>*
+  // CHECK-NEXT:    store <12 x half> %2, <12 x half>* %4, align 2
+  // CHECK-NEXT:   ret void
+  // 4 rows x 3 columns of __fp16, accessed as <12 x half> with align 2.
+  *a = *b;
+}
+
+typedef float fx3x3_t __attribute__((matrix_type(3, 3)));
+// fx3x3_t values are passed and returned directly as <9 x float>.
+void parameter_passing(fx3x3_t a, fx3x3_t *b) {
+  // CHECK-LABEL: define void @parameter_passing(
+  // CHECK-NEXT:  entry:
+  // CHECK-NEXT:    %a.addr = alloca [9 x float], align 4
+  // CHECK-NEXT:    %b.addr = alloca [9 x float]*, align 8
+  // CHECK-NEXT:    %0 = bitcast [9 x float]* %a.addr to <9 x float>*
+  // CHECK-NEXT:    store <9 x float> %a, <9 x float>* %0, align 4
+  // CHECK-NEXT:    store [9 x float]* %b, [9 x float]** %b.addr, align 8
+  // CHECK-NEXT:    %1 = load <9 x float>, <9 x float>* %0, align 4
+  // CHECK-NEXT:    %2 = load [9 x float]*, [9 x float]** %b.addr, align 8
+  // CHECK-NEXT:    %3 = bitcast [9 x float]* %2 to <9 x float>*
+  // CHECK-NEXT:    store <9 x float> %1, <9 x float>* %3, align 4
+  // CHECK-NEXT:    ret void
+  *b = a;
+}
+
+fx3x3_t return_matrix(fx3x3_t *a) {
+  // CHECK-LABEL: define <9 x float> @return_matrix
+  // CHECK-NEXT:  entry:
+  // CHECK-NEXT:    %a.addr = alloca [9 x float]*, align 8
+  // CHECK-NEXT:    store [9 x float]* %a, [9 x float]** %a.addr, align 8
+  // CHECK-NEXT:    %0 = load [9 x float]*, [9 x float]** %a.addr, align 8
+  // CHECK-NEXT:    %1 = bitcast [9 x float]* %0 to <9 x float>*
+  // CHECK-NEXT:    %2 = load <9 x float>, <9 x float>* %1, align 4
+  // CHECK-NEXT:    ret <9 x float> %2
+  return *a; // Loaded as a vector and returned directly.
+}
+
+typedef struct {
+  char Tmp1;
+  fx3x4_t Data;
+  float Tmp2;
+} Matrix;
+// Data is field index 1; it is accessed via GEP and copied as <12 x float>.
+void matrix_struct(Matrix *a, Matrix *b) {
+  // CHECK-LABEL: define void @matrix_struct(
+  // CHECK-NEXT:  entry:
+  // CHECK-NEXT:    %a.addr = alloca %struct.Matrix*, align 8
+  // CHECK-NEXT:    %b.addr = alloca %struct.Matrix*, align 8
+  // CHECK-NEXT:    store %struct.Matrix* %a, %struct.Matrix** %a.addr, align 8
+  // CHECK-NEXT:    store %struct.Matrix* %b, %struct.Matrix** %b.addr, align 8
+  // CHECK-NEXT:    %0 = load %struct.Matrix*, %struct.Matrix** %a.addr, align 8
+  // CHECK-NEXT:    %Data = getelementptr inbounds %struct.Matrix, %struct.Matrix* %0, i32 0, i32 1
+  // CHECK-NEXT:    %1 = bitcast [12 x float]* %Data to <12 x float>*
+  // CHECK-NEXT:    %2 = load <12 x float>, <12 x float>* %1, align 4
+  // CHECK-NEXT:    %3 = load %struct.Matrix*, %struct.Matrix** %b.addr, align 8
+  // CHECK-NEXT:    %Data1 = getelementptr inbounds %struct.Matrix, %struct.Matrix* %3, i32 0, i32 1
+  // CHECK-NEXT:    %4 = bitcast [12 x float]* %Data1 to <12 x float>*
+  // CHECK-NEXT:    store <12 x float> %2, <12 x float>* %4, align 4
+  // CHECK-NEXT:    ret void
+  b->Data = a->Data;
+}

diff  --git a/clang/test/CodeGenCXX/matrix-type.cpp b/clang/test/CodeGenCXX/matrix-type.cpp
new file mode 100644
index 000000000000..c243574177a0
--- /dev/null
+++ b/clang/test/CodeGenCXX/matrix-type.cpp
@@ -0,0 +1,388 @@
+// RUN: %clang_cc1 -fenable-matrix -triple x86_64-apple-darwin %s -emit-llvm -disable-llvm-passes -o - -std=c++17 | FileCheck %s
+
+typedef double dx5x5_t __attribute__((matrix_type(5, 5)));
+typedef float fx3x4_t __attribute__((matrix_type(3, 4)));
+// Typedef names do not appear in the mangled names (U11matrix_type...).
+// CHECK: %struct.Matrix = type { i8, [12 x float], float }
+
+void load_store(dx5x5_t *a, dx5x5_t *b) {
+  // CHECK-LABEL:  define void @_Z10load_storePU11matrix_typeLm5ELm5EdS0_(
+  // CHECK-NEXT:  entry:
+  // CHECK-NEXT:    %a.addr = alloca [25 x double]*, align 8
+  // CHECK-NEXT:    %b.addr = alloca [25 x double]*, align 8
+  // CHECK-NEXT:    store [25 x double]* %a, [25 x double]** %a.addr, align 8
+  // CHECK-NEXT:    store [25 x double]* %b, [25 x double]** %b.addr, align 8
+  // CHECK-NEXT:    %0 = load [25 x double]*, [25 x double]** %b.addr, align 8
+  // CHECK-NEXT:    %1 = bitcast [25 x double]* %0 to <25 x double>*
+  // CHECK-NEXT:    %2 = load <25 x double>, <25 x double>* %1, align 8
+  // CHECK-NEXT:    %3 = load [25 x double]*, [25 x double]** %a.addr, align 8
+  // CHECK-NEXT:    %4 = bitcast [25 x double]* %3 to <25 x double>*
+  // CHECK-NEXT:    store <25 x double> %2, <25 x double>* %4, align 8
+  // CHECK-NEXT:   ret void
+  // Copied with one vector load and store that keep the element alignment.
+  *a = *b;
+}
+
+typedef float fx3x3_t __attribute__((matrix_type(3, 3)));
+// fx3x3_t is passed by value as <9 x float>; mangled U11matrix_typeLm3ELm3Ef.
+void parameter_passing(fx3x3_t a, fx3x3_t *b) {
+  // CHECK-LABEL: define void @_Z17parameter_passingU11matrix_typeLm3ELm3EfPS_(
+  // CHECK-NEXT:  entry:
+  // CHECK-NEXT:    %a.addr = alloca [9 x float], align 4
+  // CHECK-NEXT:    %b.addr = alloca [9 x float]*, align 8
+  // CHECK-NEXT:    %0 = bitcast [9 x float]* %a.addr to <9 x float>*
+  // CHECK-NEXT:    store <9 x float> %a, <9 x float>* %0, align 4
+  // CHECK-NEXT:    store [9 x float]* %b, [9 x float]** %b.addr, align 8
+  // CHECK-NEXT:    %1 = load <9 x float>, <9 x float>* %0, align 4
+  // CHECK-NEXT:    %2 = load [9 x float]*, [9 x float]** %b.addr, align 8
+  // CHECK-NEXT:    %3 = bitcast [9 x float]* %2 to <9 x float>*
+  // CHECK-NEXT:    store <9 x float> %1, <9 x float>* %3, align 4
+  // CHECK-NEXT:    ret void
+  *b = a;
+}
+
+fx3x3_t return_matrix(fx3x3_t *a) {
+  // CHECK-LABEL: define <9 x float> @_Z13return_matrixPU11matrix_typeLm3ELm3Ef(
+  // CHECK-NEXT:  entry:
+  // CHECK-NEXT:    %a.addr = alloca [9 x float]*, align 8
+  // CHECK-NEXT:    store [9 x float]* %a, [9 x float]** %a.addr, align 8
+  // CHECK-NEXT:    %0 = load [9 x float]*, [9 x float]** %a.addr, align 8
+  // CHECK-NEXT:    %1 = bitcast [9 x float]* %0 to <9 x float>*
+  // CHECK-NEXT:    %2 = load <9 x float>, <9 x float>* %1, align 4
+  // CHECK-NEXT:    ret <9 x float> %2
+  return *a; // Loaded as a vector and returned directly.
+}
+
+struct Matrix {
+  char Tmp1;
+  fx3x4_t Data;
+  float Tmp2;
+};
+// Data is field index 1; each access is a GEP plus a bitcast to a vector.
+void matrix_struct_pointers(Matrix *a, Matrix *b) {
+  // CHECK-LABEL: define void @_Z22matrix_struct_pointersP6MatrixS0_(
+  // CHECK-NEXT:  entry:
+  // CHECK-NEXT:    %a.addr = alloca %struct.Matrix*, align 8
+  // CHECK-NEXT:    %b.addr = alloca %struct.Matrix*, align 8
+  // CHECK-NEXT:    store %struct.Matrix* %a, %struct.Matrix** %a.addr, align 8
+  // CHECK-NEXT:    store %struct.Matrix* %b, %struct.Matrix** %b.addr, align 8
+  // CHECK-NEXT:    %0 = load %struct.Matrix*, %struct.Matrix** %a.addr, align 8
+  // CHECK-NEXT:    %Data = getelementptr inbounds %struct.Matrix, %struct.Matrix* %0, i32 0, i32 1
+  // CHECK-NEXT:    %1 = bitcast [12 x float]* %Data to <12 x float>*
+  // CHECK-NEXT:    %2 = load <12 x float>, <12 x float>* %1, align 4
+  // CHECK-NEXT:    %3 = load %struct.Matrix*, %struct.Matrix** %b.addr, align 8
+  // CHECK-NEXT:    %Data1 = getelementptr inbounds %struct.Matrix, %struct.Matrix* %3, i32 0, i32 1
+  // CHECK-NEXT:    %4 = bitcast [12 x float]* %Data1 to <12 x float>*
+  // CHECK-NEXT:    store <12 x float> %2, <12 x float>* %4, align 4
+  // CHECK-NEXT:    ret void
+  b->Data = a->Data;
+}
+
+void matrix_struct_reference(Matrix &a, Matrix &b) {
+  // CHECK-LABEL: define void @_Z23matrix_struct_referenceR6MatrixS0_(
+  // CHECK-NEXT:  entry:
+  // CHECK-NEXT:    %a.addr = alloca %struct.Matrix*, align 8
+  // CHECK-NEXT:    %b.addr = alloca %struct.Matrix*, align 8
+  // CHECK-NEXT:    store %struct.Matrix* %a, %struct.Matrix** %a.addr, align 8
+  // CHECK-NEXT:    store %struct.Matrix* %b, %struct.Matrix** %b.addr, align 8
+  // CHECK-NEXT:    %0 = load %struct.Matrix*, %struct.Matrix** %a.addr, align 8
+  // CHECK-NEXT:    %Data = getelementptr inbounds %struct.Matrix, %struct.Matrix* %0, i32 0, i32 1
+  // CHECK-NEXT:    %1 = bitcast [12 x float]* %Data to <12 x float>*
+  // CHECK-NEXT:    %2 = load <12 x float>, <12 x float>* %1, align 4
+  // CHECK-NEXT:    %3 = load %struct.Matrix*, %struct.Matrix** %b.addr, align 8
+  // CHECK-NEXT:    %Data1 = getelementptr inbounds %struct.Matrix, %struct.Matrix* %3, i32 0, i32 1
+  // CHECK-NEXT:    %4 = bitcast [12 x float]* %Data1 to <12 x float>*
+  // CHECK-NEXT:    store <12 x float> %2, <12 x float>* %4, align 4
+  // CHECK-NEXT:    ret void
+  b.Data = a.Data;
+}
+
+class MatrixClass {
+public:
+  int Tmp1;
+  fx3x4_t Data;
+  long Tmp2;
+};
+// Matrix member copy through references to a class type.
+void matrix_class_reference(MatrixClass &a, MatrixClass &b) {
+  // CHECK-LABEL: define void @_Z22matrix_class_referenceR11MatrixClassS0_(
+  // CHECK-NEXT:  entry:
+  // CHECK-NEXT:    %a.addr = alloca %class.MatrixClass*, align 8
+  // CHECK-NEXT:    %b.addr = alloca %class.MatrixClass*, align 8
+  // CHECK-NEXT:    store %class.MatrixClass* %a, %class.MatrixClass** %a.addr, align 8
+  // CHECK-NEXT:    store %class.MatrixClass* %b, %class.MatrixClass** %b.addr, align 8
+  // CHECK-NEXT:    %0 = load %class.MatrixClass*, %class.MatrixClass** %a.addr, align 8
+  // CHECK-NEXT:    %Data = getelementptr inbounds %class.MatrixClass, %class.MatrixClass* %0, i32 0, i32 1
+  // CHECK-NEXT:    %1 = bitcast [12 x float]* %Data to <12 x float>*
+  // CHECK-NEXT:    %2 = load <12 x float>, <12 x float>* %1, align 4
+  // CHECK-NEXT:    %3 = load %class.MatrixClass*, %class.MatrixClass** %b.addr, align 8
+  // CHECK-NEXT:    %Data1 = getelementptr inbounds %class.MatrixClass, %class.MatrixClass* %3, i32 0, i32 1
+  // CHECK-NEXT:    %4 = bitcast [12 x float]* %Data1 to <12 x float>*
+  // CHECK-NEXT:    store <12 x float> %2, <12 x float>* %4, align 4
+  // CHECK-NEXT:    ret void
+  b.Data = a.Data;
+}
+
+template <typename Ty, unsigned Rows, unsigned Cols>
+class MatrixClassTemplate {
+public:
+  using MatrixTy = Ty __attribute__((matrix_type(Rows, Cols)));
+  int Tmp1;
+  MatrixTy Data;
+  long Tmp2;
+};
+// Dimensions are template parameters; instantiated as <float, 10, 15> below.
+template <typename Ty, unsigned Rows, unsigned Cols>
+void matrix_template_reference(MatrixClassTemplate<Ty, Rows, Cols> &a, MatrixClassTemplate<Ty, Rows, Cols> &b) {
+  b.Data = a.Data;
+}
+
+// Loads a 10x15 float matrix from a raw float buffer into a class member and
+// passes it to matrix_template_reference. The patterns below check both this
+// function and the linkonce_odr instantiation of matrix_template_reference:
+// the matrix is accessed as [150 x float] memory bitcast to <150 x float>.
+MatrixClassTemplate<float, 10, 15> matrix_template_reference_caller(float *Data) {
+  // CHECK-LABEL: define void @_Z32matrix_template_reference_callerPf(%class.MatrixClassTemplate* noalias sret align 8 %agg.result, float* %Data
+  // CHECK-NEXT:  entry:
+  // CHECK-NEXT:    %Data.addr = alloca float*, align 8
+  // CHECK-NEXT:    %Arg = alloca %class.MatrixClassTemplate, align 8
+  // CHECK-NEXT:    store float* %Data, float** %Data.addr, align 8
+  // CHECK-NEXT:    %0 = load float*, float** %Data.addr, align 8
+  // CHECK-NEXT:    %1 = bitcast float* %0 to [150 x float]*
+  // CHECK-NEXT:    %2 = bitcast [150 x float]* %1 to <150 x float>*
+  // CHECK-NEXT:    %3 = load <150 x float>, <150 x float>* %2, align 4
+  // CHECK-NEXT:    %Data1 = getelementptr inbounds %class.MatrixClassTemplate, %class.MatrixClassTemplate* %Arg, i32 0, i32 1
+  // CHECK-NEXT:    %4 = bitcast [150 x float]* %Data1 to <150 x float>*
+  // CHECK-NEXT:    store <150 x float> %3, <150 x float>* %4, align 4
+  // CHECK-NEXT:    call void @_Z25matrix_template_referenceIfLj10ELj15EEvR19MatrixClassTemplateIT_XT0_EXT1_EES3_(%class.MatrixClassTemplate* dereferenceable(616) %Arg, %class.MatrixClassTemplate* dereferenceable(616) %agg.result)
+  // CHECK-NEXT:    ret void
+
+  // CHECK-LABEL: define linkonce_odr void @_Z25matrix_template_referenceIfLj10ELj15EEvR19MatrixClassTemplateIT_XT0_EXT1_EES3_(%class.MatrixClassTemplate* dereferenceable(616) %a, %class.MatrixClassTemplate* dereferenceable(616) %b)
+  // CHECK-NEXT:  entry:
+  // CHECK-NEXT:    %a.addr = alloca %class.MatrixClassTemplate*, align 8
+  // CHECK-NEXT:    %b.addr = alloca %class.MatrixClassTemplate*, align 8
+  // CHECK-NEXT:    store %class.MatrixClassTemplate* %a, %class.MatrixClassTemplate** %a.addr, align 8
+  // CHECK-NEXT:    store %class.MatrixClassTemplate* %b, %class.MatrixClassTemplate** %b.addr, align 8
+  // CHECK-NEXT:    %0 = load %class.MatrixClassTemplate*, %class.MatrixClassTemplate** %a.addr, align 8
+  // CHECK-NEXT:    %Data = getelementptr inbounds %class.MatrixClassTemplate, %class.MatrixClassTemplate* %0, i32 0, i32 1
+  // CHECK-NEXT:    %1 = bitcast [150 x float]* %Data to <150 x float>*
+  // CHECK-NEXT:    %2 = load <150 x float>, <150 x float>* %1, align 4
+  // CHECK-NEXT:    %3 = load %class.MatrixClassTemplate*, %class.MatrixClassTemplate** %b.addr, align 8
+  // CHECK-NEXT:    %Data1 = getelementptr inbounds %class.MatrixClassTemplate, %class.MatrixClassTemplate* %3, i32 0, i32 1
+  // CHECK-NEXT:    %4 = bitcast [150 x float]* %Data1 to <150 x float>*
+  // CHECK-NEXT:    store <150 x float> %2, <150 x float>* %4, align 4
+  // CHECK-NEXT:    ret void
+
+  MatrixClassTemplate<float, 10, 15> Result, Arg;
+  Arg.Data = *((MatrixClassTemplate<float, 10, 15>::MatrixTy *)Data);
+  matrix_template_reference(Arg, Result);
+  return Result;
+}
+
+// Alias template for a matrix of T with R rows and C columns; the tests below
+// deduce T, R and C through it.
+template <class T, unsigned long R, unsigned long C>
+using matrix = T __attribute__((matrix_type(R, C)));
+
+// Empty tag type; the value of N identifies which overload was selected.
+template <int N>
+struct selector {};
+
+// use_matrix overloads with different amounts of dimension/element-type
+// specialization; the selector<N> return type records the chosen overload.
+// The bodies are intentionally empty. These functions return a non-void type
+// and fall off the end, which is why each instantiation checked in
+// test_template_deduction ends in llvm.trap / unreachable.
+template <class T, unsigned long R, unsigned long C>
+selector<0> use_matrix(matrix<T, R, C> &m) {}
+
+template <class T, unsigned long R>
+selector<1> use_matrix(matrix<T, R, 10> &m) {}
+
+template <class T>
+selector<2> use_matrix(matrix<T, 10, 10> &m) {}
+
+template <class T, unsigned long C>
+selector<3> use_matrix(matrix<T, 10, C> &m) {}
+
+template <unsigned long R, unsigned long C>
+selector<4> use_matrix(matrix<float, R, C> &m) {}
+
+// Each call must resolve to the use_matrix overload whose selector<N> matches
+// the variable it initializes; the mangled callee names in the patterns below
+// encode the deduced template arguments of each chosen overload.
+void test_template_deduction() {
+
+  // CHECK-LABEL: define void @_Z23test_template_deductionv()
+  // CHECK-NEXT:  entry:
+  // CHECK-NEXT:    %m0 = alloca [120 x i32], align 4
+  // CHECK-NEXT:    %w = alloca %struct.selector, align 1
+  // CHECK-NEXT:    %undef.agg.tmp = alloca %struct.selector, align 1
+  // CHECK-NEXT:    %m1 = alloca [100 x i32], align 4
+  // CHECK-NEXT:    %x = alloca %struct.selector.0, align 1
+  // CHECK-NEXT:    %undef.agg.tmp1 = alloca %struct.selector.0, align 1
+  // CHECK-NEXT:    %m2 = alloca [120 x i32], align 4
+  // CHECK-NEXT:    %y = alloca %struct.selector.1, align 1
+  // CHECK-NEXT:    %undef.agg.tmp2 = alloca %struct.selector.1, align 1
+  // CHECK-NEXT:    %m3 = alloca [144 x i32], align 4
+  // CHECK-NEXT:    %z = alloca %struct.selector.2, align 1
+  // CHECK-NEXT:    %undef.agg.tmp3 = alloca %struct.selector.2, align 1
+  // CHECK-NEXT:    %m4 = alloca [144 x float], align 4
+  // CHECK-NEXT:    %v = alloca %struct.selector.3, align 1
+  // CHECK-NEXT:    %undef.agg.tmp4 = alloca %struct.selector.3, align 1
+  // CHECK-NEXT:    call void @_Z10use_matrixIiLm12EE8selectorILi3EERU11matrix_typeXLm10EEXT0_ET_([120 x i32]* dereferenceable(480) %m0)
+  // CHECK-NEXT:    call void @_Z10use_matrixIiE8selectorILi2EERU11matrix_typeLm10ELm10ET_([100 x i32]* dereferenceable(400) %m1)
+  // CHECK-NEXT:    call void @_Z10use_matrixIiLm12EE8selectorILi1EERU11matrix_typeXT0_EXLm10EET_([120 x i32]* dereferenceable(480) %m2)
+  // CHECK-NEXT:    call void @_Z10use_matrixIiLm12ELm12EE8selectorILi0EERU11matrix_typeXT0_EXT1_ET_([144 x i32]* dereferenceable(576) %m3)
+  // CHECK-NEXT:    call void @_Z10use_matrixILm12ELm12EE8selectorILi4EERU11matrix_typeXT_EXT0_Ef([144 x float]* dereferenceable(576) %m4)
+  // CHECK-NEXT:    ret void
+
+  // CHECK-LABEL: define linkonce_odr void @_Z10use_matrixIiLm12EE8selectorILi3EERU11matrix_typeXLm10EEXT0_ET_([120 x i32]* dereferenceable(480) %m)
+  // CHECK-NEXT:  entry:
+  // CHECK-NEXT:    %m.addr = alloca [120 x i32]*, align 8
+  // CHECK-NEXT:    store [120 x i32]* %m, [120 x i32]** %m.addr, align 8
+  // CHECK-NEXT:    call void @llvm.trap()
+  // CHECK-NEXT:    unreachable
+
+  // CHECK-LABEL: define linkonce_odr void @_Z10use_matrixIiE8selectorILi2EERU11matrix_typeLm10ELm10ET_([100 x i32]* dereferenceable(400) %m)
+  // CHECK-NEXT:  entry:
+  // CHECK-NEXT:    %m.addr = alloca [100 x i32]*, align 8
+  // CHECK-NEXT:    store [100 x i32]* %m, [100 x i32]** %m.addr, align 8
+  // CHECK-NEXT:    call void @llvm.trap()
+  // CHECK-NEXT:    unreachable
+
+  // CHECK-LABEL: define linkonce_odr void @_Z10use_matrixIiLm12EE8selectorILi1EERU11matrix_typeXT0_EXLm10EET_([120 x i32]* dereferenceable(480) %m)
+  // CHECK-NEXT:  entry:
+  // CHECK-NEXT:    %m.addr = alloca [120 x i32]*, align 8
+  // CHECK-NEXT:    store [120 x i32]* %m, [120 x i32]** %m.addr, align 8
+  // CHECK-NEXT:    call void @llvm.trap()
+  // CHECK-NEXT:    unreachable
+
+  // CHECK-LABEL: define linkonce_odr void @_Z10use_matrixIiLm12ELm12EE8selectorILi0EERU11matrix_typeXT0_EXT1_ET_([144 x i32]* dereferenceable(576) %m)
+  // CHECK-NEXT:  entry:
+  // CHECK-NEXT:    %m.addr = alloca [144 x i32]*, align 8
+  // CHECK-NEXT:    store [144 x i32]* %m, [144 x i32]** %m.addr, align 8
+  // CHECK-NEXT:    call void @llvm.trap()
+  // CHECK-NEXT:    unreachable
+
+  // CHECK-LABEL: define linkonce_odr void @_Z10use_matrixILm12ELm12EE8selectorILi4EERU11matrix_typeXT_EXT0_Ef([144 x float]* dereferenceable(576)
+  // CHECK-NEXT:  entry:
+  // CHECK-NEXT:    %m.addr = alloca [144 x float]*, align 8
+  // CHECK-NEXT:    store [144 x float]* %m, [144 x float]** %m.addr, align 8
+  // CHECK-NEXT:    call void @llvm.trap()
+  // CHECK-NEXT:    unreachable
+
+  matrix<int, 10, 12> m0;
+  selector<3> w = use_matrix(m0);
+  matrix<int, 10, 10> m1;
+  selector<2> x = use_matrix(m1);
+  matrix<int, 12, 10> m2;
+  selector<1> y = use_matrix(m2);
+  matrix<int, 12, 12> m3;
+  selector<0> z = use_matrix(m3);
+  matrix<float, 12, 12> m4;
+  selector<4> v = use_matrix(m4);
+}
+
+// Deduces an 'auto' non-type template parameter from a matrix row count.
+// test_auto_t expects R to be deduced with type unsigned long (Lm13 in the
+// mangled name).
+template <auto R>
+void foo(matrix<int, R, 10> &m) {
+}
+
+// Calls foo with a 13x10 int matrix; the patterns below pin the instantiation
+// that is called and emitted.
+void test_auto_t() {
+  // CHECK-LABEL: define void @_Z11test_auto_tv()
+  // CHECK-NEXT:  entry:
+  // CHECK-NEXT:    %m = alloca [130 x i32], align 4
+  // CHECK-NEXT:    call void @_Z3fooILm13EEvRU11matrix_typeXT_EXLm10EEi([130 x i32]* dereferenceable(520) %m)
+  // CHECK-NEXT:    ret void
+
+  // CHECK-LABEL: define linkonce_odr void @_Z3fooILm13EEvRU11matrix_typeXT_EXLm10EEi([130 x i32]* dereferenceable(520) %m)
+  // CHECK-NEXT:  entry:
+  // CHECK-NEXT:    %m.addr = alloca [130 x i32]*, align 8
+  // CHECK-NEXT:    store [130 x i32]* %m, [130 x i32]** %m.addr, align 8
+  // CHECK-NEXT:    ret void
+
+  matrix<int, 13, 10> m;
+  foo(m);
+}
+
+// Overloads whose parameter and/or return types use dimension expressions
+// built from the template parameters (R + 1, C / 2, C - R, R - 2, ...).
+// R and C are deduced only from dimensions written as a bare parameter.
+// The bodies are intentionally empty; all of these return non-void types, so
+// their instantiations checked in test_use_matrix_2 end in llvm.trap.
+template <unsigned long R, unsigned long C>
+matrix<float, R + 1, C + 2> use_matrix_2(matrix<int, R, C> &m) {}
+
+template <unsigned long R, unsigned long C>
+selector<0> use_matrix_2(matrix<int, R + 2, C / 2> &m1, matrix<float, R, C> &m2) {}
+
+template <unsigned long R, unsigned long C>
+selector<1> use_matrix_2(matrix<int, R + C, C> &m1, matrix<float, R, C - R> &m2) {}
+
+template <unsigned long R>
+matrix<float, R + R, R - 3> use_matrix_2(matrix<int, R, 10> &m1) {}
+
+template <unsigned long R>
+selector<2> use_matrix_3(matrix<int, R - 2, R> &m) {}
+
+// Exercises deduction through the use_matrix_2/use_matrix_3 overloads with
+// dimension expressions. Matrix return values are stored into the result
+// variables through a bitcast of the [N x float] alloca to <N x float>.
+void test_use_matrix_2() {
+  // CHECK-LABEL: define void @_Z17test_use_matrix_2v()
+  // CHECK-NEXT:  entry:
+  // CHECK-NEXT:    %m1 = alloca [24 x i32], align 4
+  // CHECK-NEXT:    %r1 = alloca [40 x float], align 4
+  // CHECK-NEXT:    %m2 = alloca [24 x float], align 4
+  // CHECK-NEXT:    %r2 = alloca %struct.selector.2, align 1
+  // CHECK-NEXT:    %undef.agg.tmp = alloca %struct.selector.2, align 1
+  // CHECK-NEXT:    %m3 = alloca [104 x i32], align 4
+  // CHECK-NEXT:    %m4 = alloca [15 x float], align 4
+  // CHECK-NEXT:    %r3 = alloca %struct.selector.1, align 1
+  // CHECK-NEXT:    %undef.agg.tmp1 = alloca %struct.selector.1, align 1
+  // CHECK-NEXT:    %m5 = alloca [50 x i32], align 4
+  // CHECK-NEXT:    %r4 = alloca [20 x float], align 4
+  // CHECK-NEXT:    %r5 = alloca %struct.selector.0, align 1
+  // CHECK-NEXT:    %undef.agg.tmp3 = alloca %struct.selector.0, align 1
+  // CHECK-NEXT:    %call = call <40 x float> @_Z12use_matrix_2ILm4ELm6EEU11matrix_typeXplT_Li1EEXplT0_Li2EEfRU11matrix_typeXT_EXT0_Ei([24 x i32]* dereferenceable(96) %m1)
+  // CHECK-NEXT:    %0 = bitcast [40 x float]* %r1 to <40 x float>*
+  // CHECK-NEXT:    store <40 x float> %call, <40 x float>* %0, align 4
+  // CHECK-NEXT:    call void @_Z12use_matrix_2ILm2ELm12EE8selectorILi0EERU11matrix_typeXplT_Li2EEXdvT0_Li2EEiRU11matrix_typeXT_EXT0_Ef([24 x i32]* dereferenceable(96) %m1, [24 x float]* dereferenceable(96) %m2)
+  // CHECK-NEXT:    call void @_Z12use_matrix_2ILm5ELm8EE8selectorILi1EERU11matrix_typeXplT_T0_EXT0_EiRU11matrix_typeXT_EXmiT0_T_Ef([104 x i32]* dereferenceable(416) %m3, [15 x float]* dereferenceable(60) %m4)
+  // CHECK-NEXT:    %call2 = call <20 x float> @_Z12use_matrix_2ILm5EEU11matrix_typeXplT_T_EXmiT_Li3EEfRU11matrix_typeXT_EXLm10EEi([50 x i32]* dereferenceable(200) %m5)
+  // CHECK-NEXT:    %1 = bitcast [20 x float]* %r4 to <20 x float>*
+  // CHECK-NEXT:    store <20 x float> %call2, <20 x float>* %1, align 4
+  // CHECK-NEXT:    call void @_Z12use_matrix_3ILm6EE8selectorILi2EERU11matrix_typeXmiT_Li2EEXT_Ei([24 x i32]* dereferenceable(96) %m1)
+  // CHECK-NEXT:    ret void
+
+  // CHECK-LABEL: define linkonce_odr <40 x float> @_Z12use_matrix_2ILm4ELm6EEU11matrix_typeXplT_Li1EEXplT0_Li2EEfRU11matrix_typeXT_EXT0_Ei([24 x i32]* dereferenceable(96) %m)
+  // CHECK-NEXT:  entry:
+  // CHECK-NEXT:    %m.addr = alloca [24 x i32]*, align 8
+  // CHECK-NEXT:    store [24 x i32]* %m, [24 x i32]** %m.addr, align 8
+  // CHECK-NEXT:    call void @llvm.trap()
+  // CHECK-NEXT:    unreachable
+
+  // CHECK-LABEL: define linkonce_odr void @_Z12use_matrix_2ILm2ELm12EE8selectorILi0EERU11matrix_typeXplT_Li2EEXdvT0_Li2EEiRU11matrix_typeXT_EXT0_Ef([24 x i32]* dereferenceable(96) %m1, [24 x float]* dereferenceable(96) %m2)
+  // CHECK-NEXT:  entry:
+  // CHECK-NEXT:    %m1.addr = alloca [24 x i32]*, align 8
+  // CHECK-NEXT:    %m2.addr = alloca [24 x float]*, align 8
+  // CHECK-NEXT:    store [24 x i32]* %m1, [24 x i32]** %m1.addr, align 8
+  // CHECK-NEXT:    store [24 x float]* %m2, [24 x float]** %m2.addr, align 8
+  // CHECK-NEXT:    call void @llvm.trap()
+  // CHECK-NEXT:    unreachable
+
+  // CHECK-LABEL: define linkonce_odr void @_Z12use_matrix_2ILm5ELm8EE8selectorILi1EERU11matrix_typeXplT_T0_EXT0_EiRU11matrix_typeXT_EXmiT0_T_Ef([104 x i32]* dereferenceable(416) %m1, [15 x float]* dereferenceable(60) %m2)
+  // CHECK-NEXT:  entry:
+  // CHECK-NEXT:    %m1.addr = alloca [104 x i32]*, align 8
+  // CHECK-NEXT:    %m2.addr = alloca [15 x float]*, align 8
+  // CHECK-NEXT:    store [104 x i32]* %m1, [104 x i32]** %m1.addr, align 8
+  // CHECK-NEXT:    store [15 x float]* %m2, [15 x float]** %m2.addr, align 8
+  // CHECK-NEXT:    call void @llvm.trap()
+  // CHECK-NEXT:    unreachable
+
+  // CHECK-LABEL: define linkonce_odr <20 x float> @_Z12use_matrix_2ILm5EEU11matrix_typeXplT_T_EXmiT_Li3EEfRU11matrix_typeXT_EXLm10EEi([50 x i32]* dereferenceable(200) %m1)
+  // CHECK-NEXT:  entry:
+  // CHECK-NEXT:    %m1.addr = alloca [50 x i32]*, align 8
+  // CHECK-NEXT:    store [50 x i32]* %m1, [50 x i32]** %m1.addr, align 8
+  // CHECK-NEXT:    call void @llvm.trap()
+  // CHECK-NEXT:    unreachable
+
+  // CHECK-LABEL: define linkonce_odr void @_Z12use_matrix_3ILm6EE8selectorILi2EERU11matrix_typeXmiT_Li2EEXT_Ei([24 x i32]* dereferenceable(96) %m)
+  // CHECK-NEXT:  entry:
+  // CHECK-NEXT:    %m.addr = alloca [24 x i32]*, align 8
+  // CHECK-NEXT:    store [24 x i32]* %m, [24 x i32]** %m.addr, align 8
+  // CHECK-NEXT:    call void @llvm.trap()
+  // CHECK-NEXT:    unreachable
+
+  matrix<int, 4, 6> m1;
+  matrix<float, 5, 8> r1 = use_matrix_2(m1);
+
+  matrix<float, 2, 12> m2;
+  selector<0> r2 = use_matrix_2(m1, m2);
+
+  matrix<int, 13, 8> m3;
+  matrix<float, 5, 3> m4;
+  selector<1> r3 = use_matrix_2(m3, m4);
+
+  matrix<int, 5, 10> m5;
+  matrix<float, 10, 2> r4 = use_matrix_2(m5);
+
+  selector<2> r5 = use_matrix_3(m1);
+}

diff  --git a/clang/test/Parser/matrix-type-disabled.c b/clang/test/Parser/matrix-type-disabled.c
new file mode 100644
index 000000000000..ccecd7eda20c
--- /dev/null
+++ b/clang/test/Parser/matrix-type-disabled.c
@@ -0,0 +1,14 @@
+// RUN: %clang_cc1 %s -triple i686-apple-darwin -verify -fsyntax-only
+
+// Matrix types are disabled by default.
+
+// Without -fenable-matrix on the command line, __has_extension(matrix_types)
+// must evaluate to 0; otherwise preprocessing stops with #error.
+#if __has_extension(matrix_types)
+#error Expected extension 'matrix_types' to be disabled
+#endif
+
+typedef double dx5x5_t __attribute__((matrix_type(5, 5)));
+// expected-error at -1 {{matrix types extension is disabled. Pass -fenable-matrix to enable it}}
+
+// Loads and stores through the rejected typedef. No directive is attached
+// here, so -verify requires that this function produces no diagnostics.
+void load_store_double(dx5x5_t *a, dx5x5_t *b) {
+  *a = *b;
+}

diff  --git a/clang/test/SemaCXX/matrix-type.cpp b/clang/test/SemaCXX/matrix-type.cpp
new file mode 100644
index 000000000000..76f288d83e6e
--- /dev/null
+++ b/clang/test/SemaCXX/matrix-type.cpp
@@ -0,0 +1,129 @@
+// RUN: %clang_cc1 -fsyntax-only -pedantic -fenable-matrix -std=c++11 -verify -triple x86_64-apple-darwin %s
+
+// Valid 6x6 matrix aliases with double, float and int element types.
+using matrix_double_t = double __attribute__((matrix_type(6, 6)));
+using matrix_float_t = float __attribute__((matrix_type(6, 6)));
+using matrix_int_t = int __attribute__((matrix_type(6, 6)));
+
+// Matrix dimensions must be integer constant expressions that are non-zero and
+// not too large; each alias below triggers exactly one diagnostic. Negative
+// values are reported as "too large".
+void matrix_var_dimensions(int Rows, unsigned Columns, char C) {
+  using matrix1_t = int __attribute__((matrix_type(Rows, 1)));    // expected-error{{matrix_type attribute requires an integer constant}}
+  using matrix2_t = int __attribute__((matrix_type(1, Columns))); // expected-error{{matrix_type attribute requires an integer constant}}
+  using matrix3_t = int __attribute__((matrix_type(C, C)));       // expected-error{{matrix_type attribute requires an integer constant}}
+  using matrix4_t = int __attribute__((matrix_type(-1, 1)));      // expected-error{{matrix row size too large}}
+  using matrix5_t = int __attribute__((matrix_type(1, -1)));      // expected-error{{matrix column size too large}}
+  using matrix6_t = int __attribute__((matrix_type(0, 1)));       // expected-error{{zero matrix size}}
+  using matrix7_t = int __attribute__((matrix_type(1, 0)));       // expected-error{{zero matrix size}}
+  // NOTE(review): matrix7_t is declared a second time on the next line. That
+  // declaration only exercises a parse error in the attribute arguments, and
+  // -verify requires that no redefinition diagnostic is emitted for it.
+  using matrix7_t = int __attribute__((matrix_type(char, 0)));    // expected-error{{expected '(' for function-style cast or type construction}}
+  using matrix8_t = int __attribute__((matrix_type(1048576, 1))); // expected-error{{matrix row size too large}}
+}
+
+// Class and enum types used below as invalid matrix element types.
+struct S1 {};
+
+enum TestEnum {
+  A,
+  B
+};
+
+// Pointer, class, bool and enum element types are all rejected.
+void matrix_unsupported_element_type() {
+  using matrix1_t = char *__attribute__((matrix_type(1, 1)));    // expected-error{{invalid matrix element type 'char *'}}
+  using matrix2_t = S1 __attribute__((matrix_type(1, 1)));       // expected-error{{invalid matrix element type 'S1'}}
+  using matrix3_t = bool __attribute__((matrix_type(1, 1)));     // expected-error{{invalid matrix element type 'bool'}}
+  using matrix4_t = TestEnum __attribute__((matrix_type(1, 1))); // expected-error{{invalid matrix element type 'TestEnum'}}
+}
+
+// A type template parameter (declared with 'typename') cannot be used as a
+// matrix dimension.
+template <typename T> // expected-note{{declared here}}
+void matrix_template_1() {
+  using matrix1_t = float __attribute__((matrix_type(T, T))); // expected-error{{'T' does not refer to a value}}
+}
+
+// Same check as matrix_template_1, with the type parameter declared with
+// 'class'.
+template <class C> // expected-note{{declared here}}
+void matrix_template_2() {
+  using matrix1_t = float __attribute__((matrix_type(C, C))); // expected-error{{'C' does not refer to a value}}
+}
+
+// Dimensions taken from non-type template parameters are checked when the
+// template is instantiated; a zero row count is diagnosed here.
+template <unsigned Rows, unsigned Cols>
+void matrix_template_3() {
+  using matrix1_t = float __attribute__((matrix_type(Rows, Cols))); // expected-error{{zero matrix size}}
+}
+
+// The 1x10 instantiation is valid; the 0x10 one triggers the error in
+// matrix_template_3.
+void instantiate_template_3() {
+  matrix_template_3<1, 10>();
+  matrix_template_3<0, 10>(); // expected-note{{in instantiation of function template specialization 'matrix_template_3<0, 10>' requested here}}
+}
+
+// A negative signed template argument used as the row count is diagnosed as a
+// too-large row size at instantiation.
+template <int Rows, unsigned Cols>
+void matrix_template_4() {
+  using matrix1_t = float __attribute__((matrix_type(Rows, Cols))); // expected-error{{matrix row size too large}}
+}
+
+// Only the instantiation with Rows = -3 is diagnosed.
+void instantiate_template_4() {
+  matrix_template_4<2, 10>();
+  matrix_template_4<-3, 10>(); // expected-note{{in instantiation of function template specialization 'matrix_template_4<-3, 10>' requested here}}
+}
+
+// Alias template for a matrix of T with R rows and C columns, used to test
+// template argument deduction in the functions below.
+template <class T, unsigned long R, unsigned long C>
+using matrix = T __attribute__((matrix_type(R, C)));
+
+// Fixes the column count to 10 and deduces the row count R.
+template <class T, unsigned long R>
+void use_matrix(matrix<T, R, 10> &m) {}
+// expected-note@-1 {{candidate function [with T = float, R = 10]}}
+
+// Fixes the row count to 10 and deduces the column count C.
+template <class T, unsigned long C>
+void use_matrix(matrix<T, 10, C> &m) {}
+// expected-note@-1 {{candidate function [with T = float, C = 10]}}
+
+// A 10x10 matrix matches both use_matrix overloads and neither is more
+// specialized, so the call is ambiguous.
+void test_ambigous_deduction1() {
+  matrix<float, 10, 10> m;
+  use_matrix(m);
+  // expected-error@-1 {{call to 'use_matrix' is ambiguous}}
+}
+
+// T is deduced independently from the matrix element type and from x.
+template <class T, long R>
+void type_conflict(matrix<T, R, 10> &m, T x) {}
+// expected-note@-1 {{candidate template ignored: deduced conflicting types for parameter 'T' ('float' vs. 'char *')}}
+
+// Deduction fails: the matrix yields T = float, the pointer yields T = char *.
+void test_type_conflict(char *p) {
+  matrix<float, 10, 10> m;
+  type_conflict(m, p);
+  // expected-error@-1 {{no matching function for call to 'type_conflict'}}
+}
+
+// Single-argument overload whose return type's dimensions are expressions of
+// the deduced R and C.
+template <unsigned long R, unsigned long C>
+matrix<float, R + 1, C + 2> use_matrix_2(matrix<int, R, C> &m) {}
+// expected-note@-1 {{candidate function template not viable: requires single argument 'm', but 2 arguments were provided}}
+// expected-note@-2 {{candidate function template not viable: requires single argument 'm', but 2 arguments were provided}}
+
+// R and C are deduced from m2; m1's dimensions (R + 2, C / 2) are expressions
+// and are only checked against the argument after deduction.
+template <unsigned long R, unsigned long C>
+void use_matrix_2(matrix<int, R + 2, C / 2> &m1, matrix<float, R, C> &m2) {}
+// expected-note@-1 {{candidate function [with R = 3, C = 11] not viable: no known conversion from 'matrix<int, 5, 6>' (aka 'int __attribute__((matrix_type(5, 6)))') to 'matrix<int, 3UL + 2, 11UL / 2> &' (aka 'int  __attribute__((matrix_type(5, 5)))&') for 1st argument}}
+// expected-note@-2 {{candidate template ignored: deduced type 'matrix<float, 3UL, 4UL>' of 2nd parameter does not match adjusted type 'matrix<int, 3, 4>' of argument [with R = 3, C = 4]}}
+
+// C is deduced from m1's column count and R from m2's row count; the other
+// dimensions (R + C and C - R) are checked after deduction.
+template <typename T, unsigned long R, unsigned long C>
+void use_matrix_2(matrix<T, R + C, C> &m1, matrix<T, R, C - R> &m2) {}
+// expected-note@-1 {{candidate template ignored: deduced conflicting types for parameter 'T' ('int' vs. 'float')}}
+// expected-note@-2 {{candidate template ignored: deduced type 'matrix<[...], 3UL + 4UL, 4UL>' of 1st parameter does not match adjusted type 'matrix<[...], 3, 4>' of argument [with T = int, R = 3, C = 4]}}
+
+// R is deduced from the column count; the row count must then equal R - 2.
+template <typename T, unsigned long R>
+void use_matrix_3(matrix<T, R - 2, R> &m) {}
+// expected-note@-1 {{candidate template ignored: deduced type 'matrix<[...], 5UL - 2, 5UL>' of 1st parameter does not match adjusted type 'matrix<[...], 5, 5>' of argument [with T = unsigned int, R = 5]}}
+
+// Calls whose argument dimensions do not satisfy the dimension expressions
+// in the use_matrix_2/use_matrix_3 parameter or return types.
+void test_use_matrix_2() {
+  matrix<int, 5, 6> m1;
+  matrix<float, 5, 8> r1 = use_matrix_2(m1);
+  // expected-error@-1 {{cannot initialize a variable of type 'matrix<[...], 5, 8>' with an rvalue of type 'matrix<[...], 5UL + 1, 6UL + 2>'}}
+
+  matrix<int, 4, 5> m2;
+  matrix<float, 5, 8> r2 = use_matrix_2(m2);
+  // expected-error@-1 {{cannot initialize a variable of type 'matrix<[...], 5, 8>' with an rvalue of type 'matrix<[...], 4UL + 1, 5UL + 2>'}}
+
+  matrix<float, 3, 11> m3;
+  use_matrix_2(m1, m3);
+  // expected-error@-1 {{no matching function for call to 'use_matrix_2'}}
+
+  matrix<int, 3, 4> m4;
+  use_matrix_2(m4, m4);
+  // expected-error@-1 {{no matching function for call to 'use_matrix_2'}}
+
+  matrix<unsigned, 5, 5> m5;
+  use_matrix_3(m5);
+  // expected-error@-1 {{no matching function for call to 'use_matrix_3'}}
+}

diff  --git a/clang/tools/libclang/CIndex.cpp b/clang/tools/libclang/CIndex.cpp
index 2afc5a4eb842..583950c4620b 100644
--- a/clang/tools/libclang/CIndex.cpp
+++ b/clang/tools/libclang/CIndex.cpp
@@ -1795,6 +1795,8 @@ DEFAULT_TYPELOC_IMPL(DependentVector, Type)
 DEFAULT_TYPELOC_IMPL(DependentSizedExtVector, Type)
 DEFAULT_TYPELOC_IMPL(Vector, Type)
 DEFAULT_TYPELOC_IMPL(ExtVector, VectorType)
+DEFAULT_TYPELOC_IMPL(ConstantMatrix, MatrixType)
+DEFAULT_TYPELOC_IMPL(DependentSizedMatrix, MatrixType)
 DEFAULT_TYPELOC_IMPL(FunctionProto, FunctionType)
 DEFAULT_TYPELOC_IMPL(FunctionNoProto, FunctionType)
 DEFAULT_TYPELOC_IMPL(Record, TagType)


        


More information about the cfe-commits mailing list