[clang] 4affc44 - [Matrix] Implement * binary operator for MatrixType.

Florian Hahn via cfe-commits cfe-commits at lists.llvm.org
Sun Jun 7 03:24:27 PDT 2020


Author: Florian Hahn
Date: 2020-06-07T11:11:27+01:00
New Revision: 4affc444b499ba2a12fee7f9ef0bb6ef33789c12

URL: https://github.com/llvm/llvm-project/commit/4affc444b499ba2a12fee7f9ef0bb6ef33789c12
DIFF: https://github.com/llvm/llvm-project/commit/4affc444b499ba2a12fee7f9ef0bb6ef33789c12.diff

LOG: [Matrix] Implement * binary operator for MatrixType.

This patch implements the * binary operator for values of
MatrixType. It adds support for matrix * matrix, scalar * matrix and
matrix * scalar.

For the matrix * matrix case, the number of columns of the first operand
must match the number of rows of the second. For the scalar * matrix and
matrix * scalar variants, the element type of the matrix must match the
scalar type.

Reviewers: rjmccall, anemet, Bigcheese, rsmith, martong

Reviewed By: rjmccall

Differential Revision: https://reviews.llvm.org/D76794

Added: 
    

Modified: 
    clang/include/clang/Sema/Sema.h
    clang/lib/CodeGen/CGExprScalar.cpp
    clang/lib/Sema/SemaExpr.cpp
    clang/lib/Sema/SemaOverload.cpp
    clang/test/CodeGen/matrix-type-operators.c
    clang/test/CodeGenCXX/matrix-type-operators.cpp
    clang/test/Sema/matrix-type-operators.c
    clang/test/SemaCXX/matrix-type-operators.cpp
    llvm/include/llvm/IR/MatrixBuilder.h

Removed: 
    


################################################################################
diff  --git a/clang/include/clang/Sema/Sema.h b/clang/include/clang/Sema/Sema.h
index cef25fc927aa..8f914ec451f6 100644
--- a/clang/include/clang/Sema/Sema.h
+++ b/clang/include/clang/Sema/Sema.h
@@ -11218,6 +11218,8 @@ class Sema final {
   QualType CheckMatrixElementwiseOperands(ExprResult &LHS, ExprResult &RHS,
                                           SourceLocation Loc,
                                           bool IsCompAssign);
+  QualType CheckMatrixMultiplyOperands(ExprResult &LHS, ExprResult &RHS,
+                                       SourceLocation Loc, bool IsCompAssign);
 
   bool areLaxCompatibleVectorTypes(QualType srcType, QualType destType);
   bool isLaxVectorConversion(QualType srcType, QualType destType);

diff  --git a/clang/lib/CodeGen/CGExprScalar.cpp b/clang/lib/CodeGen/CGExprScalar.cpp
index b169462f535a..cabfe0cca281 100644
--- a/clang/lib/CodeGen/CGExprScalar.cpp
+++ b/clang/lib/CodeGen/CGExprScalar.cpp
@@ -741,6 +741,22 @@ class ScalarExprEmitter
       }
     }
 
+    if (Ops.Ty->isConstantMatrixType()) {
+      llvm::MatrixBuilder<CGBuilderTy> MB(Builder);
+      // We need to check the types of the operands of the operator to get the
+      // correct matrix dimensions.
+      auto *BO = cast<BinaryOperator>(Ops.E);
+      auto *LHSMatTy = dyn_cast<ConstantMatrixType>(
+          BO->getLHS()->getType().getCanonicalType());
+      auto *RHSMatTy = dyn_cast<ConstantMatrixType>(
+          BO->getRHS()->getType().getCanonicalType());
+      if (LHSMatTy && RHSMatTy)
+        return MB.CreateMatrixMultiply(Ops.LHS, Ops.RHS, LHSMatTy->getNumRows(),
+                                       LHSMatTy->getNumColumns(),
+                                       RHSMatTy->getNumColumns());
+      return MB.CreateScalarMultiply(Ops.LHS, Ops.RHS);
+    }
+
     if (Ops.Ty->isUnsignedIntegerType() &&
         CGF.SanOpts.has(SanitizerKind::UnsignedIntegerOverflow) &&
         !CanElideOverflowCheck(CGF.getContext(), Ops))

diff  --git a/clang/lib/Sema/SemaExpr.cpp b/clang/lib/Sema/SemaExpr.cpp
index d7a39cf6a6b5..45c1acecbe94 100644
--- a/clang/lib/Sema/SemaExpr.cpp
+++ b/clang/lib/Sema/SemaExpr.cpp
@@ -10058,6 +10058,9 @@ QualType Sema::CheckMultiplyDivideOperands(ExprResult &LHS, ExprResult &RHS,
     return CheckVectorOperands(LHS, RHS, Loc, IsCompAssign,
                                /*AllowBothBool*/getLangOpts().AltiVec,
                                /*AllowBoolConversions*/false);
+  if (!IsDiv && (LHS.get()->getType()->isConstantMatrixType() ||
+                 RHS.get()->getType()->isConstantMatrixType()))
+    return CheckMatrixMultiplyOperands(LHS, RHS, Loc, IsCompAssign);
 
   QualType compType = UsualArithmeticConversions(
       LHS, RHS, Loc, IsCompAssign ? ACK_CompAssign : ACK_Arithmetic);
@@ -12120,6 +12123,37 @@ QualType Sema::CheckMatrixElementwiseOperands(ExprResult &LHS, ExprResult &RHS,
   return InvalidOperands(Loc, LHS, RHS);
 }
 
+QualType Sema::CheckMatrixMultiplyOperands(ExprResult &LHS, ExprResult &RHS,
+                                           SourceLocation Loc,
+                                           bool IsCompAssign) {
+  if (!IsCompAssign) {
+    LHS = DefaultFunctionArrayLvalueConversion(LHS.get());
+    if (LHS.isInvalid())
+      return QualType();
+  }
+  RHS = DefaultFunctionArrayLvalueConversion(RHS.get());
+  if (RHS.isInvalid())
+    return QualType();
+
+  auto *LHSMatType = LHS.get()->getType()->getAs<ConstantMatrixType>();
+  auto *RHSMatType = RHS.get()->getType()->getAs<ConstantMatrixType>();
+  assert((LHSMatType || RHSMatType) && "At least one operand must be a matrix");
+
+  if (LHSMatType && RHSMatType) {
+    if (LHSMatType->getNumColumns() != RHSMatType->getNumRows())
+      return InvalidOperands(Loc, LHS, RHS);
+
+    if (!Context.hasSameType(LHSMatType->getElementType(),
+                             RHSMatType->getElementType()))
+      return InvalidOperands(Loc, LHS, RHS);
+
+    return Context.getConstantMatrixType(LHSMatType->getElementType(),
+                                         LHSMatType->getNumRows(),
+                                         RHSMatType->getNumColumns());
+  }
+  return CheckMatrixElementwiseOperands(LHS, RHS, Loc, IsCompAssign);
+}
+
 inline QualType Sema::CheckBitwiseOperands(ExprResult &LHS, ExprResult &RHS,
                                            SourceLocation Loc,
                                            BinaryOperatorKind Opc) {

diff  --git a/clang/lib/Sema/SemaOverload.cpp b/clang/lib/Sema/SemaOverload.cpp
index d7fc9664772b..5e5f53991a68 100644
--- a/clang/lib/Sema/SemaOverload.cpp
+++ b/clang/lib/Sema/SemaOverload.cpp
@@ -9167,8 +9167,10 @@ void Sema::AddBuiltinOperatorCandidates(OverloadedOperatorKind Op,
   case OO_Star: // '*' is either unary or binary
     if (Args.size() == 1)
       OpBuilder.addUnaryStarPointerOverloads();
-    else
+    else {
       OpBuilder.addGenericBinaryArithmeticOverloads();
+      OpBuilder.addMatrixBinaryArithmeticOverloads();
+    }
     break;
 
   case OO_Slash:

diff  --git a/clang/test/CodeGen/matrix-type-operators.c b/clang/test/CodeGen/matrix-type-operators.c
index c649922aa2c5..f44ecd77337c 100644
--- a/clang/test/CodeGen/matrix-type-operators.c
+++ b/clang/test/CodeGen/matrix-type-operators.c
@@ -173,6 +173,134 @@ void add_matrix_scalar_long_long_int_unsigned_long_long(ullx4x2_t b, unsigned lo
   b = vulli + b;
 }
 
+// Tests for matrix multiplication.
+
+void multiply_matrix_matrix_double(dx5x5_t b, dx5x5_t c) {
+  // CHECK-LABEL: @multiply_matrix_matrix_double(
+  // CHECK:         [[B:%.*]] = load <25 x double>, <25 x double>* {{.*}}, align 8
+  // CHECK-NEXT:    [[C:%.*]] = load <25 x double>, <25 x double>* {{.*}}, align 8
+  // CHECK-NEXT:    [[RES:%.*]] = call <25 x double> @llvm.matrix.multiply.v25f64.v25f64.v25f64(<25 x double> [[B]], <25 x double> [[C]], i32 5, i32 5, i32 5)
+  // CHECK-NEXT:    [[A_ADDR:%.*]] = bitcast [25 x double]* %a to <25 x double>*
+  // CHECK-NEXT:    store <25 x double> [[RES]], <25 x double>* [[A_ADDR]], align 8
+  // CHECK-NEXT:    ret void
+  //
+
+  dx5x5_t a;
+  a = b * c;
+}
+
+typedef int ix3x9_t __attribute__((matrix_type(3, 9)));
+typedef int ix9x9_t __attribute__((matrix_type(9, 9)));
+// CHECK-LABEL: @multiply_matrix_matrix_int(
+// CHECK:         [[B:%.*]] = load <27 x i32>, <27 x i32>* {{.*}}, align 4
+// CHECK-NEXT:    [[C:%.*]] = load <27 x i32>, <27 x i32>* {{.*}}, align 4
+// CHECK-NEXT:    [[RES:%.*]] = call <81 x i32> @llvm.matrix.multiply.v81i32.v27i32.v27i32(<27 x i32> [[B]], <27 x i32> [[C]], i32 9, i32 3, i32 9)
+// CHECK-NEXT:    [[A_ADDR:%.*]] = bitcast [81 x i32]* %a to <81 x i32>*
+// CHECK-NEXT:    store <81 x i32> [[RES]], <81 x i32>* [[A_ADDR]], align 4
+// CHECK-NEXT:    ret void
+//
+void multiply_matrix_matrix_int(ix9x3_t b, ix3x9_t c) {
+  ix9x9_t a;
+  a = b * c;
+}
+
+// CHECK-LABEL: @multiply_double_matrix_scalar_float(
+// CHECK:         [[A:%.*]] = load <25 x double>, <25 x double>* {{.*}}, align 8
+// CHECK-NEXT:    [[S:%.*]] = load float, float* %s.addr, align 4
+// CHECK-NEXT:    [[S_EXT:%.*]] = fpext float [[S]] to double
+// CHECK-NEXT:    [[VECINSERT:%.*]] = insertelement <25 x double> undef, double [[S_EXT]], i32 0
+// CHECK-NEXT:    [[VECSPLAT:%.*]] = shufflevector <25 x double> [[VECINSERT]], <25 x double> undef, <25 x i32> zeroinitializer
+// CHECK-NEXT:    [[RES:%.*]] = fmul <25 x double> [[A]], [[VECSPLAT]]
+// CHECK-NEXT:    store <25 x double> [[RES]], <25 x double>* {{.*}}, align 8
+// CHECK-NEXT:    ret void
+//
+void multiply_double_matrix_scalar_float(dx5x5_t a, float s) {
+  a = a * s;
+}
+
+// CHECK-LABEL: @multiply_double_matrix_scalar_double(
+// CHECK:         [[A:%.*]] = load <25 x double>, <25 x double>* {{.*}}, align 8
+// CHECK-NEXT:    [[S:%.*]] = load double, double* %s.addr, align 8
+// CHECK-NEXT:    [[VECINSERT:%.*]] = insertelement <25 x double> undef, double [[S]], i32 0
+// CHECK-NEXT:    [[VECSPLAT:%.*]] = shufflevector <25 x double> [[VECINSERT]], <25 x double> undef, <25 x i32> zeroinitializer
+// CHECK-NEXT:    [[RES:%.*]] = fmul <25 x double> [[A]], [[VECSPLAT]]
+// CHECK-NEXT:    store <25 x double> [[RES]], <25 x double>* {{.*}}, align 8
+// CHECK-NEXT:    ret void
+//
+void multiply_double_matrix_scalar_double(dx5x5_t a, double s) {
+  a = a * s;
+}
+
+// CHECK-LABEL: @multiply_float_matrix_scalar_double(
+// CHECK:         [[S:%.*]] = load double, double* %s.addr, align 8
+// CHECK-NEXT:    [[S_TRUNC:%.*]] = fptrunc double [[S]] to float
+// CHECK-NEXT:    [[MAT:%.*]] = load <6 x float>, <6 x float>* [[MAT_ADDR:%.*]], align 4
+// CHECK-NEXT:    [[VECINSERT:%.*]] = insertelement <6 x float> undef, float [[S_TRUNC]], i32 0
+// CHECK-NEXT:    [[VECSPLAT:%.*]] = shufflevector <6 x float> [[VECINSERT]], <6 x float> undef, <6 x i32> zeroinitializer
+// CHECK-NEXT:    [[RES:%.*]] = fmul <6 x float> [[VECSPLAT]], [[MAT]]
+// CHECK-NEXT:    store <6 x float> [[RES]], <6 x float>* [[MAT_ADDR]], align 4
+// CHECK-NEXT:    ret void
+//
+void multiply_float_matrix_scalar_double(fx2x3_t b, double s) {
+  b = s * b;
+}
+
+// CHECK-LABEL: @multiply_int_matrix_scalar_short(
+// CHECK:         [[S:%.*]] = load i16, i16* %s.addr, align 2
+// CHECK-NEXT:    [[S_EXT:%.*]] = sext i16 [[S]] to i32
+// CHECK-NEXT:    [[MAT:%.*]] = load <27 x i32>, <27 x i32>* [[MAT_ADDR:%.*]], align 4
+// CHECK-NEXT:    [[VECINSERT:%.*]] = insertelement <27 x i32> undef, i32 [[S_EXT]], i32 0
+// CHECK-NEXT:    [[VECSPLAT:%.*]] = shufflevector <27 x i32> [[VECINSERT]], <27 x i32> undef, <27 x i32> zeroinitializer
+// CHECK-NEXT:    [[RES:%.*]] = mul <27 x i32> [[VECSPLAT]], [[MAT]]
+// CHECK-NEXT:    store <27 x i32> [[RES]], <27 x i32>* [[MAT_ADDR]], align 4
+// CHECK-NEXT:    ret void
+//
+void multiply_int_matrix_scalar_short(ix9x3_t b, short s) {
+  b = s * b;
+}
+
+// CHECK-LABEL: @multiply_int_matrix_scalar_ull(
+// CHECK:         [[MAT:%.*]] = load <27 x i32>, <27 x i32>* [[MAT_ADDR:%.*]], align 4
+// CHECK-NEXT:    [[S:%.*]] = load i64, i64* %s.addr, align 8
+// CHECK-NEXT:    [[S_TRUNC:%.*]] = trunc i64 [[S]] to i32
+// CHECK-NEXT:    [[VECINSERT:%.*]] = insertelement <27 x i32> undef, i32 [[S_TRUNC]], i32 0
+// CHECK-NEXT:    [[VECSPLAT:%.*]] = shufflevector <27 x i32> [[VECINSERT]], <27 x i32> undef, <27 x i32> zeroinitializer
+// CHECK-NEXT:    [[RES:%.*]] = mul <27 x i32> [[MAT]], [[VECSPLAT]]
+// CHECK-NEXT:    store <27 x i32> [[RES]], <27 x i32>* [[MAT_ADDR]], align 4
+// CHECK-NEXT:    ret void
+//
+void multiply_int_matrix_scalar_ull(ix9x3_t b, unsigned long long s) {
+  b = b * s;
+}
+
+// CHECK-LABEL: @multiply_float_matrix_constant(
+// CHECK-NEXT:  entry:
+// CHECK-NEXT:    [[A_ADDR:%.*]] = alloca [6 x float], align 4
+// CHECK-NEXT:    [[MAT_ADDR:%.*]] = bitcast [6 x float]* [[A_ADDR]] to <6 x float>*
+// CHECK-NEXT:    store <6 x float> [[A:%.*]], <6 x float>* [[MAT_ADDR]], align 4
+// CHECK-NEXT:    [[MAT:%.*]] = load <6 x float>, <6 x float>* [[MAT_ADDR]], align 4
+// CHECK-NEXT:    [[RES:%.*]] = fmul <6 x float> [[MAT]], <float 2.500000e+00, float 2.500000e+00, float 2.500000e+00, float 2.500000e+00, float 2.500000e+00, float 2.500000e+00>
+// CHECK-NEXT:    store <6 x float> [[RES]], <6 x float>* [[MAT_ADDR]], align 4
+// CHECK-NEXT:    ret void
+//
+void multiply_float_matrix_constant(fx2x3_t a) {
+  a = a * 2.5;
+}
+
+// CHECK-LABEL: @multiply_int_matrix_constant(
+// CHECK-NEXT:  entry:
+// CHECK-NEXT:    [[A_ADDR:%.*]] = alloca [27 x i32], align 4
+// CHECK-NEXT:    [[MAT_ADDR:%.*]] = bitcast [27 x i32]* [[A_ADDR]] to <27 x i32>*
+// CHECK-NEXT:    store <27 x i32> [[A:%.*]], <27 x i32>* [[MAT_ADDR]], align 4
+// CHECK-NEXT:    [[MAT:%.*]] = load <27 x i32>, <27 x i32>* [[MAT_ADDR]], align 4
+// CHECK-NEXT:    [[RES:%.*]] = mul <27 x i32> <i32 5, i32 5, i32 5, i32 5, i32 5, i32 5, i32 5, i32 5, i32 5, i32 5, i32 5, i32 5, i32 5, i32 5, i32 5, i32 5, i32 5, i32 5, i32 5, i32 5, i32 5, i32 5, i32 5, i32 5, i32 5, i32 5, i32 5>, [[MAT]]
+// CHECK-NEXT:    store <27 x i32> [[RES]], <27 x i32>* [[MAT_ADDR]], align 4
+// CHECK-NEXT:    ret void
+//
+void multiply_int_matrix_constant(ix9x3_t a) {
+  a = 5 * a;
+}
+
 // Tests for the matrix type operators.
 
 typedef double dx5x5_t __attribute__((matrix_type(5, 5)));

diff  --git a/clang/test/CodeGenCXX/matrix-type-operators.cpp b/clang/test/CodeGenCXX/matrix-type-operators.cpp
index a9eec96e2be1..5ca55d259756 100644
--- a/clang/test/CodeGenCXX/matrix-type-operators.cpp
+++ b/clang/test/CodeGenCXX/matrix-type-operators.cpp
@@ -157,6 +157,45 @@ void test_IntWrapper_Sub(MyMatrix<double, 10, 9> &m) {
   m.value = w3 - m.value;
 }
 
+template <typename EltTy0, unsigned R0, unsigned C0, unsigned C1>
+typename MyMatrix<EltTy0, R0, C1>::matrix_t multiply(MyMatrix<EltTy0, R0, C0> &A, MyMatrix<EltTy0, C0, C1> &B) {
+  return A.value * B.value;
+}
+
+MyMatrix<float, 2, 2> test_multiply_template(MyMatrix<float, 2, 5> Mat1,
+                                             MyMatrix<float, 5, 2> Mat2) {
+  // CHECK-LABEL: define void @_Z22test_multiply_template8MyMatrixIfLj2ELj5EES_IfLj5ELj2EE(
+  // CHECK-NEXT:  entry:
+  // CHECK-NEXT:    [[RES:%.*]] = call <4 x float> @_Z8multiplyIfLj2ELj5ELj2EEN8MyMatrixIT_XT0_EXT2_EE8matrix_tERS0_IS1_XT0_EXT1_EERS0_IS1_XT1_EXT2_EE(%struct.MyMatrix* nonnull align 4 dereferenceable(40) %Mat1, %struct.MyMatrix.2* nonnull align 4 dereferenceable(40) %Mat2)
+  // CHECK-NEXT:    %value = getelementptr inbounds %struct.MyMatrix.1, %struct.MyMatrix.1* %agg.result, i32 0, i32 0
+  // CHECK-NEXT:    [[VALUE_ADDR:%.*]] = bitcast [4 x float]* %value to <4 x float>*
+  // CHECK-NEXT:    store <4 x float> [[RES]], <4 x float>* [[VALUE_ADDR]], align 4
+  // CHECK-NEXT:    ret void
+  //
+  // CHECK-LABEL:  define linkonce_odr <4 x float> @_Z8multiplyIfLj2ELj5ELj2EEN8MyMatrixIT_XT0_EXT2_EE8matrix_tERS0_IS1_XT0_EXT1_EERS0_IS1_XT1_EXT2_EE(
+  // CHECK:         [[MAT1:%.*]] = load <10 x float>, <10 x float>* {{.*}}, align 4
+  // CHECK:         [[MAT2:%.*]] = load <10 x float>, <10 x float>* {{.*}}, align 4
+  // CHECK-NEXT:    [[RES:%.*]] = call <4 x float> @llvm.matrix.multiply.v4f32.v10f32.v10f32(<10 x float> [[MAT1]], <10 x float> [[MAT2]], i32 2, i32 5, i32 2)
+  // CHECK-NEXT:    ret <4 x float> [[RES]]
+
+  MyMatrix<float, 2, 2> Res;
+  Res.value = multiply(Mat1, Mat2);
+  return Res;
+}
+
+void test_IntWrapper_Multiply(MyMatrix<double, 10, 9> &m, IntWrapper &w3) {
+  // CHECK-LABEL: define void @_Z24test_IntWrapper_MultiplyR8MyMatrixIdLj10ELj9EER10IntWrapper(
+  // CHECK:       [[SCALAR:%.*]] = call i32 @_ZN10IntWrappercviEv(%struct.IntWrapper* {{.*}})
+  // CHECK-NEXT:  [[SCALAR_FP:%.*]] = sitofp i32 %call to double
+  // CHECK:       [[MATRIX:%.*]] = load <90 x double>, <90 x double>* {{.*}}, align 8
+  // CHECK-NEXT:  [[SCALAR_EMBED:%.*]] = insertelement <90 x double> undef, double [[SCALAR_FP]], i32 0
+  // CHECK-NEXT:  [[SCALAR_EMBED1:%.*]] = shufflevector <90 x double> [[SCALAR_EMBED]], <90 x double> undef, <90 x i32> zeroinitializer
+  // CHECK-NEXT:  [[RES:%.*]] = fmul <90 x double> [[SCALAR_EMBED1]], [[MATRIX]]
+  // CHECK:       store <90 x double> [[RES]], <90 x double>* {{.*}}, align 8
+  // CHECK:       ret void
+  m.value = w3 * m.value;
+}
+
 template <typename EltTy, unsigned Rows, unsigned Columns>
 void insert(MyMatrix<EltTy, Rows, Columns> &Mat, EltTy e, unsigned i, unsigned j) {
   Mat.value[i][j] = e;
@@ -164,11 +203,11 @@ void insert(MyMatrix<EltTy, Rows, Columns> &Mat, EltTy e, unsigned i, unsigned j
 
 void test_insert_template1(MyMatrix<unsigned, 2, 2> &Mat, unsigned e, unsigned i, unsigned j) {
   // CHECK-LABEL: @_Z21test_insert_template1R8MyMatrixIjLj2ELj2EEjjj(
-  // CHECK:         [[MAT_ADDR:%.*]] = load %struct.MyMatrix.1*, %struct.MyMatrix.1** %Mat.addr, align 8
+  // CHECK:         [[MAT_ADDR:%.*]] = load %struct.MyMatrix.3*, %struct.MyMatrix.3** %Mat.addr, align 8
   // CHECK-NEXT:    [[E:%.*]] = load i32, i32* %e.addr, align 4
   // CHECK-NEXT:    [[I:%.*]] = load i32, i32* %i.addr, align 4
   // CHECK-NEXT:    [[J:%.*]] = load i32, i32* %j.addr, align 4
-  // CHECK-NEXT:    call void @_Z6insertIjLj2ELj2EEvR8MyMatrixIT_XT0_EXT1_EES1_jj(%struct.MyMatrix.1* nonnull align 4 dereferenceable(16) [[MAT_ADDR]], i32 [[E]], i32 [[I]], i32 [[J]])
+  // CHECK-NEXT:    call void @_Z6insertIjLj2ELj2EEvR8MyMatrixIT_XT0_EXT1_EES1_jj(%struct.MyMatrix.3* nonnull align 4 dereferenceable(16) [[MAT_ADDR]], i32 [[E]], i32 [[I]], i32 [[J]])
   // CHECK-NEXT:    ret void
   //
   // CHECK-LABEL: define linkonce_odr void @_Z6insertIjLj2ELj2EEvR8MyMatrixIT_XT0_EXT1_EES1_jj(
@@ -190,9 +229,9 @@ void test_insert_template1(MyMatrix<unsigned, 2, 2> &Mat, unsigned e, unsigned i
 
 void test_insert_template2(MyMatrix<float, 3, 8> &Mat, float e) {
   // CHECK-LABEL: @_Z21test_insert_template2R8MyMatrixIfLj3ELj8EEf(
-  // CHECK:         [[MAT_ADDR:%.*]] = load %struct.MyMatrix.2*, %struct.MyMatrix.2** %Mat.addr, align 8
+  // CHECK:         [[MAT_ADDR:%.*]] = load %struct.MyMatrix.4*, %struct.MyMatrix.4** %Mat.addr, align 8
   // CHECK-NEXT:    [[E:%.*]] = load float, float* %e.addr, align 4
-  // CHECK-NEXT:    call void @_Z6insertIfLj3ELj8EEvR8MyMatrixIT_XT0_EXT1_EES1_jj(%struct.MyMatrix.2* nonnull align 4 dereferenceable(96) [[MAT_ADDR]], float [[E]], i32 2, i32 5)
+  // CHECK-NEXT:    call void @_Z6insertIfLj3ELj8EEvR8MyMatrixIT_XT0_EXT1_EES1_jj(%struct.MyMatrix.4* nonnull align 4 dereferenceable(96) [[MAT_ADDR]], float [[E]], i32 2, i32 5)
   // CHECK-NEXT:    ret void
   //
   // CHECK-LABEL: define linkonce_odr void @_Z6insertIfLj3ELj8EEvR8MyMatrixIT_XT0_EXT1_EES1_jj(
@@ -220,7 +259,7 @@ EltTy extract(MyMatrix<EltTy, Rows, Columns> &Mat) {
 int test_extract_template(MyMatrix<int, 2, 2> Mat1) {
   // CHECK-LABEL: @_Z21test_extract_template8MyMatrixIiLj2ELj2EE(
   // CHECK-NEXT:  entry:
-  // CHECK-NEXT:    [[CALL:%.*]] = call i32 @_Z7extractIiLj2ELj2EET_R8MyMatrixIS0_XT0_EXT1_EE(%struct.MyMatrix.3* nonnull align 4 dereferenceable(16) [[MAT1:%.*]])
+  // CHECK-NEXT:    [[CALL:%.*]] = call i32 @_Z7extractIiLj2ELj2EET_R8MyMatrixIS0_XT0_EXT1_EE(%struct.MyMatrix.5* nonnull align 4 dereferenceable(16) [[MAT1:%.*]])
   // CHECK-NEXT:    ret i32 [[CALL]]
   //
   // CHECK-LABEL: define linkonce_odr i32 @_Z7extractIiLj2ELj2EET_R8MyMatrixIS0_XT0_EXT1_EE(
@@ -301,7 +340,7 @@ struct identmatrix_t {
 constexpr identmatrix_t identmatrix;
 
 void test_constexpr1(matrix_type<float, 4, 4> &m) {
-  // CHECK-LABEL: define void @_Z15test_constexpr1RU11matrix_typeLm4ELm4Ef([16 x float]* nonnull align 4 dereferenceable(64) %m) #3 {
+  // CHECK-LABEL: define void @_Z15test_constexpr1RU11matrix_typeLm4ELm4Ef(
   // CHECK:         [[MAT:%.*]] = load <16 x float>, <16 x float>* {{.*}}, align 4
   // CHECK-NEXT:    [[IM:%.*]] = call <16 x float> @_ZNK13identmatrix_tcvU11matrix_typeXT0_EXT0_ET_IfLj4EEEv(%struct.identmatrix_t* @_ZL11identmatrix)
   // CHECK-NEXT:    [[ADD:%.*]] = fadd <16 x float> [[MAT]], [[IM]]
@@ -327,7 +366,7 @@ void test_constexpr1(matrix_type<float, 4, 4> &m) {
 }
 
 void test_constexpr2(matrix_type<int, 5, 5> &m) {
-  // CHECK-LABEL: define void @_Z15test_constexpr2RU11matrix_typeLm5ELm5Ei([25 x i32]* nonnull align 4 dereferenceable(100) %m) #4 {
+  // CHECK-LABEL: define void @_Z15test_constexpr2RU11matrix_typeLm5ELm5Ei(
   // CHECK:         [[IM:%.*]] = call <25 x i32> @_ZNK13identmatrix_tcvU11matrix_typeXT0_EXT0_ET_IiLj5EEEv(%struct.identmatrix_t* @_ZL11identmatrix)
   // CHECK:         [[MAT:%.*]] = load <25 x i32>, <25 x i32>* {{.*}}, align 4
   // CHECK-NEXT:    [[SUB:%.*]] = sub <25 x i32> [[IM]], [[MAT]]

diff  --git a/clang/test/Sema/matrix-type-operators.c b/clang/test/Sema/matrix-type-operators.c
index 0de0dda92746..397fc3513557 100644
--- a/clang/test/Sema/matrix-type-operators.c
+++ b/clang/test/Sema/matrix-type-operators.c
@@ -32,6 +32,44 @@ void sub(sx10x10_t a, sx5x10_t b, sx10x5_t c) {
   // expected-error@-2 {{casting 'sx10x5_t *' (aka 'float  __attribute__((matrix_type(10, 5)))*') to incompatible type 'float'}}
 }
 
+typedef int ix10x5_t __attribute__((matrix_type(10, 5)));
+typedef int ix10x10_t __attribute__((matrix_type(10, 10)));
+
+void matrix_matrix_multiply(sx10x10_t a, sx5x10_t b, ix10x5_t c, ix10x10_t d, float sf, char *p) {
+  // Check dimension mismatches.
+  a = a * b;
+  // expected-error@-1 {{invalid operands to binary expression ('sx10x10_t' (aka 'float __attribute__((matrix_type(10, 10)))') and 'sx5x10_t' (aka 'float __attribute__((matrix_type(5, 10)))'))}}
+  b = a * a;
+  // expected-error@-1 {{assigning to 'sx5x10_t' (aka 'float __attribute__((matrix_type(5, 10)))') from incompatible type 'float __attribute__((matrix_type(10, 10)))'}}
+
+  // Check element type mismatches.
+  a = b * c;
+  // expected-error@-1 {{invalid operands to binary expression ('sx5x10_t' (aka 'float __attribute__((matrix_type(5, 10)))') and 'ix10x5_t' (aka 'int __attribute__((matrix_type(10, 5)))'))}}
+  d = a * a;
+  // expected-error@-1 {{assigning to 'ix10x10_t' (aka 'int __attribute__((matrix_type(10, 10)))') from incompatible type 'float __attribute__((matrix_type(10, 10)))'}}
+
+  p = a * a;
+  // expected-error@-1 {{assigning to 'char *' from incompatible type 'float __attribute__((matrix_type(10, 10)))'}}
+}
+
+void mat_scalar_multiply(sx10x10_t a, sx5x10_t b, float sf, char *p) {
+  // Shape of multiplication result does not match the type of b.
+  b = a * sf;
+  // expected-error@-1 {{assigning to 'sx5x10_t' (aka 'float __attribute__((matrix_type(5, 10)))') from incompatible type 'sx10x10_t' (aka 'float __attribute__((matrix_type(10, 10)))')}}
+  b = sf * a;
+  // expected-error@-1 {{assigning to 'sx5x10_t' (aka 'float __attribute__((matrix_type(5, 10)))') from incompatible type 'sx10x10_t' (aka 'float __attribute__((matrix_type(10, 10)))')}}
+
+  a = a * p;
+  // expected-error@-1 {{casting 'char *' to incompatible type 'float'}}
+  // expected-error@-2 {{invalid operands to binary expression ('sx10x10_t' (aka 'float __attribute__((matrix_type(10, 10)))') and 'char *')}}
+  a = p * a;
+  // expected-error@-1 {{casting 'char *' to incompatible type 'float'}}
+  // expected-error@-2 {{invalid operands to binary expression ('char *' and 'sx10x10_t' (aka 'float __attribute__((matrix_type(10, 10)))'))}}
+
+  sf = a * sf;
+  // expected-error@-1 {{assigning to 'float' from incompatible type 'sx10x10_t' (aka 'float __attribute__((matrix_type(10, 10)))')}}
+}
+
 sx5x10_t get_matrix();
 
 void insert(sx5x10_t a, float f) {

diff  --git a/clang/test/SemaCXX/matrix-type-operators.cpp b/clang/test/SemaCXX/matrix-type-operators.cpp
index 337d44f56e64..808e47451461 100644
--- a/clang/test/SemaCXX/matrix-type-operators.cpp
+++ b/clang/test/SemaCXX/matrix-type-operators.cpp
@@ -65,6 +65,45 @@ void test_subtract_template(unsigned *Ptr1, float *Ptr2) {
   // expected-note@-1 {{in instantiation of function template specialization 'subtract<unsigned int, 3, 3, float, 2, 2, unsigned int, 2, 2>' requested here}}
 }
 
+template <typename EltTy0, unsigned R0, unsigned C0, typename EltTy1, unsigned R1, unsigned C1, typename EltTy2, unsigned R2, unsigned C2>
+typename MyMatrix<EltTy2, R2, C2>::matrix_t multiply(MyMatrix<EltTy0, R0, C0> &A, MyMatrix<EltTy1, R1, C1> &B) {
+  char *v1 = A.value * B.value;
+  // expected-error@-1 {{cannot initialize a variable of type 'char *' with an rvalue of type 'unsigned int __attribute__((matrix_type(2, 2)))'}}
+  // expected-error@-2 {{invalid operands to binary expression ('MyMatrix<unsigned int, 3, 2>::matrix_t' (aka 'unsigned int __attribute__((matrix_type(3, 2)))') and 'MyMatrix<unsigned int, 3, 3>::matrix_t' (aka 'unsigned int __attribute__((matrix_type(3, 3)))'))}}
+  // expected-error@-3 {{invalid operands to binary expression ('MyMatrix<float, 2, 2>::matrix_t' (aka 'float __attribute__((matrix_type(2, 2)))') and 'MyMatrix<unsigned int, 2, 2>::matrix_t' (aka 'unsigned int __attribute__((matrix_type(2, 2)))'))}}
+
+  MyMatrix<int, 5, 6> m;
+  B.value = m.value * A.value;
+  // expected-error@-1 {{invalid operands to binary expression ('MyMatrix<int, 5, 6>::matrix_t' (aka 'int __attribute__((matrix_type(5, 6)))') and 'MyMatrix<unsigned int, 2, 2>::matrix_t' (aka 'unsigned int __attribute__((matrix_type(2, 2)))'))}}
+  // expected-error@-2 {{invalid operands to binary expression ('MyMatrix<int, 5, 6>::matrix_t' (aka 'int __attribute__((matrix_type(5, 6)))') and 'MyMatrix<unsigned int, 3, 2>::matrix_t' (aka 'unsigned int __attribute__((matrix_type(3, 2)))'))}}
+  // expected-error@-3 {{invalid operands to binary expression ('MyMatrix<int, 5, 6>::matrix_t' (aka 'int __attribute__((matrix_type(5, 6)))') and 'MyMatrix<float, 2, 2>::matrix_t' (aka 'float __attribute__((matrix_type(2, 2)))'))}}
+
+  return A.value * B.value;
+  // expected-error@-1 {{invalid operands to binary expression ('MyMatrix<unsigned int, 3, 2>::matrix_t' (aka 'unsigned int __attribute__((matrix_type(3, 2)))') and 'MyMatrix<unsigned int, 3, 3>::matrix_t' (aka 'unsigned int __attribute__((matrix_type(3, 3)))'))}}
+  // expected-error@-2 {{invalid operands to binary expression ('MyMatrix<float, 2, 2>::matrix_t' (aka 'float __attribute__((matrix_type(2, 2)))') and 'MyMatrix<unsigned int, 2, 2>::matrix_t' (aka 'unsigned int __attribute__((matrix_type(2, 2)))'))}}
+}
+
+void test_multiply_template(unsigned *Ptr1, float *Ptr2) {
+  MyMatrix<unsigned, 2, 2> Mat1;
+  MyMatrix<unsigned, 3, 3> Mat2;
+  MyMatrix<float, 2, 2> Mat3;
+  Mat1.value = *((decltype(Mat1)::matrix_t *)Ptr1);
+  unsigned v1 = multiply<unsigned, 2, 2, unsigned, 2, 2, unsigned, 2, 2>(Mat1, Mat1);
+  // expected-note@-1 {{in instantiation of function template specialization 'multiply<unsigned int, 2, 2, unsigned int, 2, 2, unsigned int, 2, 2>' requested here}}
+  // expected-error@-2 {{cannot initialize a variable of type 'unsigned int' with an rvalue of type 'typename MyMatrix<unsigned int, 2U, 2U>::matrix_t' (aka 'unsigned int __attribute__((matrix_type(2, 2)))')}}
+
+  MyMatrix<unsigned, 3, 2> Mat4;
+  Mat1.value = multiply<unsigned, 3, 2, unsigned, 3, 3, unsigned, 2, 2>(Mat4, Mat2);
+  // expected-note@-1 {{in instantiation of function template specialization 'multiply<unsigned int, 3, 2, unsigned int, 3, 3, unsigned int, 2, 2>' requested here}}
+
+  Mat1.value = multiply<float, 2, 2, unsigned, 2, 2, unsigned, 2, 2>(Mat3, Mat1);
+  // expected-note@-1 {{in instantiation of function template specialization 'multiply<float, 2, 2, unsigned int, 2, 2, unsigned int, 2, 2>' requested here}}
+
+  Mat4.value = Mat4.value * Mat1;
+  // expected-error@-1 {{no viable conversion from 'MyMatrix<unsigned int, 2, 2>' to 'unsigned int'}}
+  // expected-error@-2 {{invalid operands to binary expression ('MyMatrix<unsigned int, 3, 2>::matrix_t' (aka 'unsigned int __attribute__((matrix_type(3, 2)))') and 'MyMatrix<unsigned int, 2, 2>')}}
+}
+
 struct UserT {};
 
 struct StructWithC {

diff  --git a/llvm/include/llvm/IR/MatrixBuilder.h b/llvm/include/llvm/IR/MatrixBuilder.h
index 20aa47ef6013..4833ed2a9ec1 100644
--- a/llvm/include/llvm/IR/MatrixBuilder.h
+++ b/llvm/include/llvm/IR/MatrixBuilder.h
@@ -33,6 +33,21 @@ template <class IRBuilderTy> class MatrixBuilder {
   IRBuilderTy &B;
   Module *getModule() { return B.GetInsertBlock()->getParent()->getParent(); }
 
+  std::pair<Value *, Value *> splatScalarOperandIfNeeded(Value *LHS,
+                                                         Value *RHS) {
+    assert((LHS->getType()->isVectorTy() || RHS->getType()->isVectorTy()) &&
+           "One of the operands must be a matrix (embedded in a vector)");
+    if (LHS->getType()->isVectorTy() && !RHS->getType()->isVectorTy())
+      RHS = B.CreateVectorSplat(
+          cast<VectorType>(LHS->getType())->getNumElements(), RHS,
+          "scalar.splat");
+    else if (!LHS->getType()->isVectorTy() && RHS->getType()->isVectorTy())
+      LHS = B.CreateVectorSplat(
+          cast<VectorType>(RHS->getType())->getNumElements(), LHS,
+          "scalar.splat");
+    return {LHS, RHS};
+  }
+
 public:
   MatrixBuilder(IRBuilderTy &Builder) : B(Builder) {}
 
@@ -164,15 +179,13 @@ template <class IRBuilderTy> class MatrixBuilder {
                : B.CreateSub(LHS, RHS);
   }
 
-  /// Multiply matrix \p LHS with scalar \p RHS.
+  /// Multiply matrix \p LHS with scalar \p RHS or scalar \p LHS with matrix \p
+  /// RHS.
   Value *CreateScalarMultiply(Value *LHS, Value *RHS) {
-    Value *ScalarVector =
-        B.CreateVectorSplat(cast<VectorType>(LHS->getType())->getNumElements(),
-                            RHS, "scalar.splat");
-    if (RHS->getType()->isFloatingPointTy())
-      return B.CreateFMul(LHS, ScalarVector);
-
-    return B.CreateMul(LHS, ScalarVector);
+    std::tie(LHS, RHS) = splatScalarOperandIfNeeded(LHS, RHS);
+    if (LHS->getType()->getScalarType()->isFloatingPointTy())
+      return B.CreateFMul(LHS, RHS);
+    return B.CreateMul(LHS, RHS);
   }
 
   /// Extracts the element at (\p RowIdx, \p ColumnIdx) from \p Matrix.


        


More information about the cfe-commits mailing list