[llvm] 1669fdd - [Matrix] Use alignment info when lowering loads/stores.

Florian Hahn via llvm-commits llvm-commits at lists.llvm.org
Thu Jun 18 05:29:49 PDT 2020


Author: Florian Hahn
Date: 2020-06-18T13:19:31+01:00
New Revision: 1669fddc9f6096a5c674ba0a459403ea2c5ffd9c

URL: https://github.com/llvm/llvm-project/commit/1669fddc9f6096a5c674ba0a459403ea2c5ffd9c
DIFF: https://github.com/llvm/llvm-project/commit/1669fddc9f6096a5c674ba0a459403ea2c5ffd9c.diff

LOG: [Matrix] Use alignment info when lowering loads/stores.

This patch updates LowerMatrixIntrinsics to preserve the alignment
specified on the original loads/stores and the align attribute for the
pointer argument of the column.major.load/store intrinsics.

We can always use the specified alignment for the load of the first
column. For subsequent columns, the alignment may need to be reduced.

For ConstantInt strides, compute the offset for the start of the column in
bytes and use commonAlignment to get the largest valid alignment.

For non-ConstantInt strides, we need to take the common alignment of the
initial alignment and the element size in bytes, because the start of each
column is then only known to be a multiple of the element size.

Reviewers: anemet, Gerolf, hfinkel, andrew.w.kaylor, LuoYuanke, rjmccall

Reviewed By: rjmccall

Differential Revision: https://reviews.llvm.org/D81960

Added: 
    

Modified: 
    llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
    llvm/test/Transforms/LowerMatrixIntrinsics/const-gep.ll
    llvm/test/Transforms/LowerMatrixIntrinsics/load-align-volatile.ll
    llvm/test/Transforms/LowerMatrixIntrinsics/store-align-volatile.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
index 97a40b806aab..c663714ea66d 100644
--- a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
+++ b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
@@ -37,6 +37,7 @@
 #include "llvm/IR/PatternMatch.h"
 #include "llvm/InitializePasses.h"
 #include "llvm/Pass.h"
+#include "llvm/Support/Alignment.h"
 #include "llvm/Support/Debug.h"
 #include "llvm/Transforms/Scalar.h"
 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
@@ -732,20 +733,6 @@ class LowerMatrixIntrinsics {
     return Changed;
   }
 
-  LoadInst *createVectorLoad(Value *ColumnPtr, Type *EltType, bool IsVolatile,
-                             IRBuilder<> &Builder) {
-    return Builder.CreateAlignedLoad(ColumnPtr,
-                                     Align(DL.getABITypeAlignment(EltType)),
-                                     IsVolatile, "col.load");
-  }
-
-  StoreInst *createVectorStore(Value *ColumnValue, Value *ColumnPtr,
-                               Type *EltType, bool IsVolatile,
-                               IRBuilder<> &Builder) {
-    return Builder.CreateAlignedStore(ColumnValue, ColumnPtr,
-                                      DL.getABITypeAlign(EltType), IsVolatile);
-  }
-
   /// Turns \p BasePtr into an elementwise pointer to \p EltType.
   Value *createElementPtr(Value *BasePtr, Type *EltType, IRBuilder<> &Builder) {
     unsigned AS = cast<PointerType>(BasePtr->getType())->getAddressSpace();
@@ -777,10 +764,30 @@ class LowerMatrixIntrinsics {
     return true;
   }
 
+  /// Compute the alignment for a column/row \p Idx with \p Stride between them.
+  /// The address at \p Idx == 0 has alignment \p A (default: ABI alignment).
+  /// Offsets are in units of the element's alloc size, as vectors are
+  /// addressed with element-typed GEPs. A ConstantInt \p Stride gives an exact
+  /// offset; otherwise only the element alloc size is known to divide it.
+  Align getAlignForIndex(unsigned Idx, Value *Stride, Type *ElementTy,
+                         MaybeAlign A) const {
+    Align InitialAlign = DL.getValueOrABITypeAlignment(A, ElementTy);
+    if (Idx == 0)
+      return InitialAlign;
+
+    uint64_t ElementSizeInBytes = DL.getTypeAllocSize(ElementTy);
+    if (auto *ConstStride = dyn_cast<ConstantInt>(Stride)) {
+      uint64_t StrideInBytes =
+          ConstStride->getZExtValue() * ElementSizeInBytes;
+      return commonAlignment(InitialAlign, Idx * StrideInBytes);
+    }
+    return commonAlignment(InitialAlign, ElementSizeInBytes);
+  }
+
   /// Load a matrix with \p Shape starting at \p Ptr and using \p Stride between
   /// vectors.
-  MatrixTy loadMatrix(Type *Ty, Value *Ptr, Value *Stride, bool IsVolatile,
-                      ShapeInfo Shape, IRBuilder<> &Builder) {
+  MatrixTy loadMatrix(Type *Ty, Value *Ptr, MaybeAlign MAlign, Value *Stride,
+                      bool IsVolatile, ShapeInfo Shape, IRBuilder<> &Builder) {
     auto VType = cast<VectorType>(Ty);
     Value *EltPtr = createElementPtr(Ptr, VType->getElementType(), Builder);
     MatrixTy Result;
@@ -788,8 +795,10 @@ class LowerMatrixIntrinsics {
       Value *GEP = computeVectorAddr(EltPtr, Builder.getInt64(I), Stride,
                                      Shape.getStride(), VType->getElementType(),
                                      Builder);
-      Value *Vector =
-          createVectorLoad(GEP, VType->getElementType(), IsVolatile, Builder);
+      Value *Vector = Builder.CreateAlignedLoad(
+          GEP, getAlignForIndex(I, Stride, VType->getElementType(), MAlign),
+          IsVolatile, "col.load");
+
       Result.addVector(Vector);
     }
     return Result.addNumLoads(getNumOps(Result.getVectorTy()) *
@@ -798,8 +807,9 @@ class LowerMatrixIntrinsics {
 
   /// Loads a sub-matrix with shape \p ResultShape from a \p R x \p C matrix,
   /// starting at \p MatrixPtr[I][J].
-  MatrixTy loadMatrix(Value *MatrixPtr, bool IsVolatile, ShapeInfo MatrixShape,
-                      Value *I, Value *J, ShapeInfo ResultShape, Type *EltTy,
+  MatrixTy loadMatrix(Value *MatrixPtr, MaybeAlign Align, bool IsVolatile,
+                      ShapeInfo MatrixShape, Value *I, Value *J,
+                      ShapeInfo ResultShape, Type *EltTy,
                       IRBuilder<> &Builder) {
 
     Value *Offset = Builder.CreateAdd(
@@ -815,19 +825,19 @@ class LowerMatrixIntrinsics {
     Value *TilePtr =
         Builder.CreatePointerCast(TileStart, TilePtrTy, "col.cast");
 
-    return loadMatrix(TileTy, TilePtr,
+    return loadMatrix(TileTy, TilePtr, Align,
                       Builder.getInt64(MatrixShape.getStride()), IsVolatile,
                       ResultShape, Builder);
   }
 
   /// Lower a load instruction with shape information.
-  void LowerLoad(Instruction *Inst, Value *Ptr, Value *Stride, bool IsVolatile,
-                 ShapeInfo Shape) {
+  void LowerLoad(Instruction *Inst, Value *Ptr, MaybeAlign Align, Value *Stride,
+                 bool IsVolatile, ShapeInfo Shape) {
     IRBuilder<> Builder(Inst);
-    finalizeLowering(
-        Inst,
-        loadMatrix(Inst->getType(), Ptr, Stride, IsVolatile, Shape, Builder),
-        Builder);
+    finalizeLowering(Inst,
+                     loadMatrix(Inst->getType(), Ptr, Align, Stride, IsVolatile,
+                                Shape, Builder),
+                     Builder);
   }
 
   /// Lowers llvm.matrix.column.major.load.
@@ -838,16 +848,16 @@ class LowerMatrixIntrinsics {
            "Intrinsic only supports column-major layout!");
     Value *Ptr = Inst->getArgOperand(0);
     Value *Stride = Inst->getArgOperand(1);
-    LowerLoad(Inst, Ptr, Stride,
+    LowerLoad(Inst, Ptr, Inst->getParamAlign(0), Stride,
               cast<ConstantInt>(Inst->getArgOperand(2))->isOne(),
               {Inst->getArgOperand(3), Inst->getArgOperand(4)});
   }
 
   /// Stores a sub-matrix \p StoreVal into the \p R x \p C matrix starting at \p
   /// MatrixPtr[I][J].
-  void storeMatrix(const MatrixTy &StoreVal, Value *MatrixPtr, bool IsVolatile,
-                   ShapeInfo MatrixShape, Value *I, Value *J, Type *EltTy,
-                   IRBuilder<> &Builder) {
+  void storeMatrix(const MatrixTy &StoreVal, Value *MatrixPtr,
+                   MaybeAlign MAlign, bool IsVolatile, ShapeInfo MatrixShape,
+                   Value *I, Value *J, Type *EltTy, IRBuilder<> &Builder) {
     Value *Offset = Builder.CreateAdd(
         Builder.CreateMul(J, Builder.getInt64(MatrixShape.getStride())), I);
 
@@ -861,34 +871,38 @@ class LowerMatrixIntrinsics {
     Value *TilePtr =
         Builder.CreatePointerCast(TileStart, TilePtrTy, "col.cast");
 
-    storeMatrix(TileTy, StoreVal, TilePtr,
+    storeMatrix(TileTy, StoreVal, TilePtr, MAlign,
                 Builder.getInt64(MatrixShape.getStride()), IsVolatile, Builder);
   }
 
   /// Store matrix \p StoreVal starting at \p Ptr and using \p Stride between
   /// vectors.
-  MatrixTy storeMatrix(Type *Ty, MatrixTy StoreVal, Value *Ptr, Value *Stride,
-                       bool IsVolatile, IRBuilder<> &Builder) {
+  MatrixTy storeMatrix(Type *Ty, MatrixTy StoreVal, Value *Ptr,
+                       MaybeAlign MAlign, Value *Stride, bool IsVolatile,
+                       IRBuilder<> &Builder) {
     auto VType = cast<VectorType>(Ty);
     Value *EltPtr = createElementPtr(Ptr, VType->getElementType(), Builder);
     for (auto Vec : enumerate(StoreVal.vectors())) {
       Value *GEP = computeVectorAddr(EltPtr, Builder.getInt64(Vec.index()),
                                      Stride, StoreVal.getStride(),
                                      VType->getElementType(), Builder);
-      createVectorStore(Vec.value(), GEP, VType->getElementType(), IsVolatile,
-                        Builder);
+      Builder.CreateAlignedStore(Vec.value(), GEP,
+                                 getAlignForIndex(Vec.index(), Stride,
+                                                  VType->getElementType(),
+                                                  MAlign),
+                                 IsVolatile);
     }
     return MatrixTy().addNumStores(getNumOps(StoreVal.getVectorTy()) *
                                    StoreVal.getNumVectors());
   }
 
   /// Lower a store instruction with shape information.
-  void LowerStore(Instruction *Inst, Value *Matrix, Value *Ptr, Value *Stride,
-                  bool IsVolatile, ShapeInfo Shape) {
+  void LowerStore(Instruction *Inst, Value *Matrix, Value *Ptr, MaybeAlign A,
+                  Value *Stride, bool IsVolatile, ShapeInfo Shape) {
     IRBuilder<> Builder(Inst);
     auto StoreVal = getMatrix(Matrix, Shape, Builder);
     finalizeLowering(Inst,
-                     storeMatrix(Matrix->getType(), StoreVal, Ptr, Stride,
+                     storeMatrix(Matrix->getType(), StoreVal, Ptr, A, Stride,
                                  IsVolatile, Builder),
                      Builder);
   }
@@ -902,7 +916,7 @@ class LowerMatrixIntrinsics {
     Value *Matrix = Inst->getArgOperand(0);
     Value *Ptr = Inst->getArgOperand(1);
     Value *Stride = Inst->getArgOperand(2);
-    LowerStore(Inst, Matrix, Ptr, Stride,
+    LowerStore(Inst, Matrix, Ptr, Inst->getParamAlign(1), Stride,
                cast<ConstantInt>(Inst->getArgOperand(3))->isOne(),
                {Inst->getArgOperand(4), Inst->getArgOperand(5)});
   }
@@ -1215,16 +1229,18 @@ class LowerMatrixIntrinsics {
 
         for (unsigned K = 0; K < M; K += TileSize) {
           const unsigned TileM = std::min(M - K, unsigned(TileSize));
-          MatrixTy A = loadMatrix(APtr, LoadOp0->isVolatile(), LShape,
-                                  Builder.getInt64(I), Builder.getInt64(K),
-                                  {TileR, TileM}, EltType, Builder);
-          MatrixTy B = loadMatrix(BPtr, LoadOp1->isVolatile(), RShape,
-                                  Builder.getInt64(K), Builder.getInt64(J),
-                                  {TileM, TileC}, EltType, Builder);
+          MatrixTy A =
+              loadMatrix(APtr, LoadOp0->getAlign(), LoadOp0->isVolatile(),
+                         LShape, Builder.getInt64(I), Builder.getInt64(K),
+                         {TileR, TileM}, EltType, Builder);
+          MatrixTy B =
+              loadMatrix(BPtr, LoadOp1->getAlign(), LoadOp1->isVolatile(),
+                         RShape, Builder.getInt64(K), Builder.getInt64(J),
+                         {TileM, TileC}, EltType, Builder);
           emitMatrixMultiply(Res, A, B, AllowContract, Builder, true);
         }
-        storeMatrix(Res, CPtr, Store->isVolatile(), {R, M}, Builder.getInt64(I),
-                    Builder.getInt64(J), EltType, Builder);
+        storeMatrix(Res, CPtr, Store->getAlign(), Store->isVolatile(), {R, M},
+                    Builder.getInt64(I), Builder.getInt64(J), EltType, Builder);
       }
 
     // Mark eliminated instructions as fused and remove them.
@@ -1337,8 +1353,9 @@ class LowerMatrixIntrinsics {
     if (I == ShapeMap.end())
       return false;
 
-    LowerLoad(Inst, Ptr, Builder.getInt64(I->second.getStride()),
-              Inst->isVolatile(), I->second);
+    LowerLoad(Inst, Ptr, Inst->getAlign(),
+              Builder.getInt64(I->second.getStride()), Inst->isVolatile(),
+              I->second);
     return true;
   }
 
@@ -1348,8 +1365,9 @@ class LowerMatrixIntrinsics {
     if (I == ShapeMap.end())
       return false;
 
-    LowerStore(Inst, StoredVal, Ptr, Builder.getInt64(I->second.getStride()),
-               Inst->isVolatile(), I->second);
+    LowerStore(Inst, StoredVal, Ptr, Inst->getAlign(),
+               Builder.getInt64(I->second.getStride()), Inst->isVolatile(),
+               I->second);
     return true;
   }
 

diff  --git a/llvm/test/Transforms/LowerMatrixIntrinsics/const-gep.ll b/llvm/test/Transforms/LowerMatrixIntrinsics/const-gep.ll
index 8caddb0b91a1..69bc882caaa8 100644
--- a/llvm/test/Transforms/LowerMatrixIntrinsics/const-gep.ll
+++ b/llvm/test/Transforms/LowerMatrixIntrinsics/const-gep.ll
@@ -76,9 +76,9 @@ entry:
   %c.addr = alloca i32, align 4
   store i32 %r, i32* %r.addr, align 4
   store i32 %c, i32* %c.addr, align 4
-  %0 = load <4 x double>, <4 x double>* getelementptr inbounds ([5 x <4 x double>], [5 x <4 x double>]* @foo, i64 0, i64 0), align 16
+  %0 = load <4 x double>, <4 x double>* getelementptr inbounds ([5 x <4 x double>], [5 x <4 x double>]* @foo, i64 0, i64 0), align 8
   %mul = call <4 x double> @llvm.matrix.multiply(<4 x double> %0, <4 x double> %0, i32 2, i32 2, i32 2)
-  store <4 x double> %0, <4 x double>* getelementptr inbounds ([5 x <4 x double>], [5 x <4 x double>]* @foo, i64 0, i64 2), align 16
+  store <4 x double> %0, <4 x double>* getelementptr inbounds ([5 x <4 x double>], [5 x <4 x double>]* @foo, i64 0, i64 2), align 8
   ret void
 }
 

diff  --git a/llvm/test/Transforms/LowerMatrixIntrinsics/load-align-volatile.ll b/llvm/test/Transforms/LowerMatrixIntrinsics/load-align-volatile.ll
index 9da5c20c3000..14b81a1d8d9b 100644
--- a/llvm/test/Transforms/LowerMatrixIntrinsics/load-align-volatile.ll
+++ b/llvm/test/Transforms/LowerMatrixIntrinsics/load-align-volatile.ll
@@ -51,7 +51,7 @@ define <9 x double> @strided_load_3x3_align32(<9 x double>* %in, i64 %stride) {
 ; CHECK-NEXT:    [[VEC_START:%.*]] = mul i64 0, [[STRIDE:%.*]]
 ; CHECK-NEXT:    [[VEC_GEP:%.*]] = getelementptr double, double* [[TMP0]], i64 [[VEC_START]]
 ; CHECK-NEXT:    [[VEC_CAST:%.*]] = bitcast double* [[VEC_GEP]] to <3 x double>*
-; CHECK-NEXT:    load <3 x double>, <3 x double>* [[VEC_CAST]], align 8
+; CHECK-NEXT:    load <3 x double>, <3 x double>* [[VEC_CAST]], align 32
 ; CHECK-NEXT:    [[VEC_START1:%.*]] = mul i64 1, [[STRIDE]]
 ; CHECK-NEXT:    [[VEC_GEP2:%.*]] = getelementptr double, double* [[TMP0]], i64 [[VEC_START1]]
 ; CHECK-NEXT:    [[VEC_CAST3:%.*]] = bitcast double* [[VEC_GEP2]] to <3 x double>*
@@ -74,15 +74,15 @@ define <9 x double> @strided_load_3x3_align2(<9 x double>* %in, i64 %stride) {
 ; CHECK-NEXT:    [[VEC_START:%.*]] = mul i64 0, [[STRIDE:%.*]]
 ; CHECK-NEXT:    [[VEC_GEP:%.*]] = getelementptr double, double* [[TMP0]], i64 [[VEC_START]]
 ; CHECK-NEXT:    [[VEC_CAST:%.*]] = bitcast double* [[VEC_GEP]] to <3 x double>*
-; CHECK-NEXT:    load <3 x double>, <3 x double>* [[VEC_CAST]], align 8
+; CHECK-NEXT:    load <3 x double>, <3 x double>* [[VEC_CAST]], align 2
 ; CHECK-NEXT:    [[VEC_START1:%.*]] = mul i64 1, [[STRIDE]]
 ; CHECK-NEXT:    [[VEC_GEP2:%.*]] = getelementptr double, double* [[TMP0]], i64 [[VEC_START1]]
 ; CHECK-NEXT:    [[VEC_CAST3:%.*]] = bitcast double* [[VEC_GEP2]] to <3 x double>*
-; CHECK-NEXT:    load <3 x double>, <3 x double>* [[VEC_CAST3]], align 8
+; CHECK-NEXT:    load <3 x double>, <3 x double>* [[VEC_CAST3]], align 2
 ; CHECK-NEXT:    [[VEC_START5:%.*]] = mul i64 2, [[STRIDE]]
 ; CHECK-NEXT:    [[VEC_GEP6:%.*]] = getelementptr double, double* [[TMP0]], i64 [[VEC_START5]]
 ; CHECK-NEXT:    [[VEC_CAST7:%.*]] = bitcast double* [[VEC_GEP6]] to <3 x double>*
-; CHECK-NEXT:    load <3 x double>, <3 x double>* [[VEC_CAST7]], align 8
+; CHECK-NEXT:    load <3 x double>, <3 x double>* [[VEC_CAST7]], align 2
 ; CHECK-NOT:     = load
 ;
 entry:
@@ -95,10 +95,10 @@ define <4 x double> @load_align2_multiply(<4 x double>* %in) {
 ; CHECK-LABEL: @load_align2_multiply(
 ; CHECK-NEXT:    [[TMP1:%.*]] = bitcast <4 x double>* [[IN:%.*]] to double*
 ; CHECK-NEXT:    [[VEC_CAST:%.*]] = bitcast double* [[TMP1]] to <2 x double>*
-; CHECK-NEXT:    load <2 x double>, <2 x double>* [[VEC_CAST]], align 8
+; CHECK-NEXT:    load <2 x double>, <2 x double>* [[VEC_CAST]], align 2
 ; CHECK-NEXT:    [[VEC_GEP:%.*]] = getelementptr double, double* [[TMP1]], i64 2
 ; CHECK-NEXT:    [[VEC_CAST1:%.*]] = bitcast double* [[VEC_GEP]] to <2 x double>*
-; CHECK-NEXT:    load <2 x double>, <2 x double>* [[VEC_CAST1]], align 8
+; CHECK-NEXT:    load <2 x double>, <2 x double>* [[VEC_CAST1]], align 2
 ; CHECK-NOT:     = load
 ;
   %in.m = load <4 x double>, <4 x double>* %in, align 2
@@ -111,13 +111,13 @@ define <6 x float> @strided_load_2x3_align16_stride2(<6 x float>* %in) {
 ; CHECK-NEXT:  entry:
 ; CHECK-NEXT:    [[TMP0:%.*]] = bitcast <6 x float>* [[IN:%.*]] to float*
 ; CHECK-NEXT:    [[VEC_CAST:%.*]] = bitcast float* [[TMP0]] to <2 x float>*
-; CHECK-NEXT:    [[COL_LOAD:%.*]] = load <2 x float>, <2 x float>* [[VEC_CAST]], align 4
+; CHECK-NEXT:    [[COL_LOAD:%.*]] = load <2 x float>, <2 x float>* [[VEC_CAST]], align 16
 ; CHECK-NEXT:    [[VEC_GEP:%.*]] = getelementptr float, float* [[TMP0]], i64 2
 ; CHECK-NEXT:    [[VEC_CAST1:%.*]] = bitcast float* [[VEC_GEP]] to <2 x float>*
-; CHECK-NEXT:    [[COL_LOAD2:%.*]] = load <2 x float>, <2 x float>* [[VEC_CAST1]], align 4
+; CHECK-NEXT:    [[COL_LOAD2:%.*]] = load <2 x float>, <2 x float>* [[VEC_CAST1]], align 8
 ; CHECK-NEXT:    [[VEC_GEP3:%.*]] = getelementptr float, float* [[TMP0]], i64 4
 ; CHECK-NEXT:    [[VEC_CAST4:%.*]] = bitcast float* [[VEC_GEP3]] to <2 x float>*
-; CHECK-NEXT:    [[COL_LOAD5:%.*]] = load <2 x float>, <2 x float>* [[VEC_CAST4]], align 4
+; CHECK-NEXT:    [[COL_LOAD5:%.*]] = load <2 x float>, <2 x float>* [[VEC_CAST4]], align 16
 ; CHECK-NEXT:    [[TMP1:%.*]] = shufflevector <2 x float> [[COL_LOAD]], <2 x float> [[COL_LOAD2]], <4 x i32> <i32 0, i32 1, i32 2, i32 3>
 ; CHECK-NEXT:    [[TMP2:%.*]] = shufflevector <2 x float> [[COL_LOAD5]], <2 x float> undef, <4 x i32> <i32 0, i32 1, i32 undef, i32 undef>
 ; CHECK-NEXT:    [[TMP3:%.*]] = shufflevector <4 x float> [[TMP1]], <4 x float> [[TMP2]], <6 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5>

diff  --git a/llvm/test/Transforms/LowerMatrixIntrinsics/store-align-volatile.ll b/llvm/test/Transforms/LowerMatrixIntrinsics/store-align-volatile.ll
index 7b60d69b7d9e..6688dadbac0a 100644
--- a/llvm/test/Transforms/LowerMatrixIntrinsics/store-align-volatile.ll
+++ b/llvm/test/Transforms/LowerMatrixIntrinsics/store-align-volatile.ll
@@ -43,7 +43,7 @@ define void @strided_store_align32(<6 x i32> %in, i64 %stride, i32* %out) {
 ; CHECK-NEXT:    [[VEC_START:%.*]] = mul i64 0, [[STRIDE:%.*]]
 ; CHECK-NEXT:    [[VEC_GEP:%.*]] = getelementptr i32, i32* [[OUT:%.*]], i64 [[VEC_START]]
 ; CHECK-NEXT:    [[VEC_CAST:%.*]] = bitcast i32* [[VEC_GEP]] to <3 x i32>*
-; CHECK-NEXT:    store volatile <3 x i32> [[SPLIT]], <3 x i32>* [[VEC_CAST]], align 4
+; CHECK-NEXT:    store volatile <3 x i32> [[SPLIT]], <3 x i32>* [[VEC_CAST]], align 32
 ; CHECK-NEXT:    [[VEC_START2:%.*]] = mul i64 1, [[STRIDE]]
 ; CHECK-NEXT:    [[VEC_GEP3:%.*]] = getelementptr i32, i32* [[OUT]], i64 [[VEC_START2]]
 ; CHECK-NEXT:    [[VEC_CAST4:%.*]] = bitcast i32* [[VEC_GEP3]] to <3 x i32>*
@@ -61,11 +61,11 @@ define void @strided_store_align2(<6 x i32> %in, i64 %stride, i32* %out) {
 ; CHECK-NEXT:    [[VEC_START:%.*]] = mul i64 0, [[STRIDE:%.*]]
 ; CHECK-NEXT:    [[VEC_GEP:%.*]] = getelementptr i32, i32* [[OUT:%.*]], i64 [[VEC_START]]
 ; CHECK-NEXT:    [[VEC_CAST:%.*]] = bitcast i32* [[VEC_GEP]] to <3 x i32>*
-; CHECK-NEXT:    store volatile <3 x i32> [[SPLIT]], <3 x i32>* [[VEC_CAST]], align 4
+; CHECK-NEXT:    store volatile <3 x i32> [[SPLIT]], <3 x i32>* [[VEC_CAST]], align 2
 ; CHECK-NEXT:    [[VEC_START2:%.*]] = mul i64 1, [[STRIDE]]
 ; CHECK-NEXT:    [[VEC_GEP3:%.*]] = getelementptr i32, i32* [[OUT]], i64 [[VEC_START2]]
 ; CHECK-NEXT:    [[VEC_CAST4:%.*]] = bitcast i32* [[VEC_GEP3]] to <3 x i32>*
-; CHECK-NEXT:    store volatile <3 x i32> [[SPLIT1]], <3 x i32>* [[VEC_CAST4]], align 4
+; CHECK-NEXT:    store volatile <3 x i32> [[SPLIT1]], <3 x i32>* [[VEC_CAST4]], align 2
 ; CHECK-NEXT:    ret void
 ;
   call void @llvm.matrix.column.major.store(<6 x i32> %in, i32* align 2 %out, i64 %stride, i1 true, i32 3, i32 2)
@@ -76,10 +76,10 @@ define void @multiply_store_align16_stride8(<4 x i32> %in, <4 x i32>* %out) {
 ; CHECK-LABEL: @multiply_store_align16_stride8(
 ; CHECK:         [[TMP29:%.*]] = bitcast <4 x i32>* %out to i32*
 ; CHECK-NEXT:    [[VEC_CAST:%.*]] = bitcast i32* [[TMP29]] to <2 x i32>*
-; CHECK-NEXT:    store <2 x i32> {{.*}}, <2 x i32>* [[VEC_CAST]], align 4
+; CHECK-NEXT:    store <2 x i32> {{.*}}, <2 x i32>* [[VEC_CAST]], align 16
 ; CHECK-NEXT:    [[VEC_GEP:%.*]] = getelementptr i32, i32* [[TMP29]], i64 2
 ; CHECK-NEXT:    [[VEC_CAST25:%.*]] = bitcast i32* [[VEC_GEP]] to <2 x i32>*
-; CHECK-NEXT:    store <2 x i32> {{.*}}, <2 x i32>* [[VEC_CAST25]], align 4
+; CHECK-NEXT:    store <2 x i32> {{.*}}, <2 x i32>* [[VEC_CAST25]], align 8
 ; CHECK-NEXT:    ret void
 ;
   %res = call <4 x i32> @llvm.matrix.multiply(<4 x i32> %in, <4 x i32> %in, i32 2, i32 2, i32 2)
@@ -93,13 +93,13 @@ define void @strided_store_align8_stride12(<6 x i32> %in, i32* %out) {
 ; CHECK-NEXT:    [[SPLIT1:%.*]] = shufflevector <6 x i32> [[IN]], <6 x i32> undef, <2 x i32> <i32 2, i32 3>
 ; CHECK-NEXT:    [[SPLIT2:%.*]] = shufflevector <6 x i32> [[IN]], <6 x i32> undef, <2 x i32> <i32 4, i32 5>
 ; CHECK-NEXT:    [[VEC_CAST:%.*]] = bitcast i32* [[OUT:%.*]] to <2 x i32>*
-; CHECK-NEXT:    store <2 x i32> [[SPLIT]], <2 x i32>* [[VEC_CAST]], align 4
+; CHECK-NEXT:    store <2 x i32> [[SPLIT]], <2 x i32>* [[VEC_CAST]], align 8
 ; CHECK-NEXT:    [[VEC_GEP:%.*]] = getelementptr i32, i32* [[OUT]], i64 3
 ; CHECK-NEXT:    [[VEC_CAST3:%.*]] = bitcast i32* [[VEC_GEP]] to <2 x i32>*
 ; CHECK-NEXT:    store <2 x i32> [[SPLIT1]], <2 x i32>* [[VEC_CAST3]], align 4
 ; CHECK-NEXT:    [[VEC_GEP4:%.*]] = getelementptr i32, i32* [[OUT]], i64 6
 ; CHECK-NEXT:    [[VEC_CAST5:%.*]] = bitcast i32* [[VEC_GEP4]] to <2 x i32>*
-; CHECK-NEXT:    store <2 x i32> [[SPLIT2]], <2 x i32>* [[VEC_CAST5]], align 4
+; CHECK-NEXT:    store <2 x i32> [[SPLIT2]], <2 x i32>* [[VEC_CAST5]], align 8
 ; CHECK-NEXT:    ret void
 ;
   call void @llvm.matrix.column.major.store(<6 x i32> %in, i32* align 8 %out, i64 3, i1 false, i32 2, i32 3)


        


More information about the llvm-commits mailing list