[llvm] [Matrix] Use data layout index type for lower matrix intrinsics (PR #162646)

via llvm-commits llvm-commits at lists.llvm.org
Thu Oct 9 05:33:06 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-llvm-transforms

Author: Nathan Corbyn (cofibrant)

<details>
<summary>Changes</summary>

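I've also included a commit that slightly refactors how shape information is propagated: `isUniformShape` is renamed to `isShapePreserving` and now also covers `select`, and the logic that skips a `select`'s condition operand is factored out into a new `getShapedOperands` helper.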

CC @<!-- -->fhahn 

---

Patch is 70.28 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/162646.diff


7 Files Affected:

- (modified) llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp (+63-42) 
- (modified) llvm/test/Transforms/LowerMatrixIntrinsics/strided-load-double.ll (+92-42) 
- (modified) llvm/test/Transforms/LowerMatrixIntrinsics/strided-load-float.ll (+70-32) 
- (modified) llvm/test/Transforms/LowerMatrixIntrinsics/strided-load-i32.ll (+70-32) 
- (modified) llvm/test/Transforms/LowerMatrixIntrinsics/strided-store-double.ll (+94-44) 
- (modified) llvm/test/Transforms/LowerMatrixIntrinsics/strided-store-float.ll (+72-34) 
- (modified) llvm/test/Transforms/LowerMatrixIntrinsics/strided-store-i32.ll (+72-34) 


``````````diff
diff --git a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
index 56e0569831e83..408372efdb93b 100644
--- a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
+++ b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
@@ -44,7 +44,7 @@
 #include "llvm/Support/Alignment.h"
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Support/Compiler.h"
-#include "llvm/Support/Debug.h"
+#include "llvm/Support/DebugLog.h"
 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
 #include "llvm/Transforms/Utils/LoopUtils.h"
 #include "llvm/Transforms/Utils/MatrixUtils.h"
@@ -241,11 +241,16 @@ raw_ostream &operator<<(raw_ostream &OS, ShapeInfo SI) {
 
 } // namespace
 
-static bool isUniformShape(Value *V) {
+/// Returns true if \p V is an instruction whose result is of the same shape
+/// as its operands (or if \p V is a non-instruction value).
+static bool isShapePreserving(Value *V) {
   Instruction *I = dyn_cast<Instruction>(V);
   if (!I)
     return true;
 
+  if (isa<SelectInst>(I))
+    return true;
+
   if (I->isBinaryOp())
     return true;
 
@@ -296,6 +301,13 @@ static bool isUniformShape(Value *V) {
   }
 }
 
+/// Returns the operands of \p I that are expected to share its result shape:
+/// all of them, except for the leading condition operand of a `select`.
+static iterator_range<Use *> getShapedOperands(Instruction *I) {
+  auto Ops = I->operands();
+  return isa<SelectInst>(I) ? drop_begin(Ops) : Ops;
+}
+
 /// Return the ShapeInfo for the result of \p I, it it can be determined.
 static std::optional<ShapeInfo>
 computeShapeInfoForInst(Instruction *I,
@@ -325,9 +337,8 @@ computeShapeInfoForInst(Instruction *I,
       return OpShape->second;
   }
 
-  if (isUniformShape(I) || isa<SelectInst>(I)) {
-    auto Ops = I->operands();
-    auto ShapedOps = isa<SelectInst>(I) ? drop_begin(Ops) : Ops;
+  if (isShapePreserving(I)) {
+    auto ShapedOps = getShapedOperands(I);
     // Find the first operand that has a known shape and use that.
     for (auto &Op : ShapedOps) {
       auto OpShape = ShapeMap.find(Op.get());
@@ -633,18 +644,16 @@ class LowerMatrixIntrinsics {
       if (Found != Inst2ColumnMatrix.end()) {
         // FIXME: re: "at least": SplitVecs.size() doesn't count the shuffles
         // that embedInVector created.
-        LLVM_DEBUG(dbgs() << "matrix reshape from " << Found->second.shape()
-                          << " to " << SI << " using at least "
-                          << SplitVecs.size() << " shuffles on behalf of:\n"
-                          << *Inst << '\n');
+        LDBG() << "matrix reshape from " << Found->second.shape() << " to "
+               << SI << " using at least " << SplitVecs.size()
+               << " shuffles on behalf of:\n"
+               << *Inst;
         ReshapedMatrices++;
       } else if (!ShapeMap.contains(MatrixVal)) {
-        LLVM_DEBUG(
-            dbgs()
-            << "splitting a " << SI << " matrix with " << SplitVecs.size()
-            << " shuffles beacuse we do not have a shape-aware lowering for "
-               "its def:\n"
-            << *Inst << '\n');
+        LDBG() << "splitting a " << SI << " matrix with " << SplitVecs.size()
+               << " shuffles because we do not have a shape-aware lowering for "
+                  "its def:\n"
+               << *Inst;
         (void)Inst;
         SplitMatrices++;
       } else {
@@ -675,15 +684,14 @@ class LowerMatrixIntrinsics {
             "Matrix shape verification failed, compilation aborted!");
       }
 
-      LLVM_DEBUG(dbgs() << "  not overriding existing shape: "
-                        << SIter->second.NumRows << " "
-                        << SIter->second.NumColumns << " for " << *V << "\n");
+      LDBG() << "  not overriding existing shape: " << SIter->second.NumRows
+             << " " << SIter->second.NumColumns << " for " << *V;
       return false;
     }
 
     ShapeMap.insert({V, Shape});
-    LLVM_DEBUG(dbgs() << "  " << Shape.NumRows << " x " << Shape.NumColumns
-                      << " for " << *V << "\n");
+    LDBG() << "  " << Shape.NumRows << " x " << Shape.NumColumns << " for "
+           << *V;
     return true;
   }
 
@@ -703,10 +711,9 @@ class LowerMatrixIntrinsics {
       case Intrinsic::matrix_column_major_store:
         return true;
       default:
-        return isUniformShape(II);
+        break;
       }
-    return isUniformShape(V) || isa<StoreInst>(V) || isa<LoadInst>(V) ||
-           isa<SelectInst>(V);
+    return isShapePreserving(V) || isa<StoreInst>(V) || isa<LoadInst>(V);
   }
 
   /// Propagate the shape information of instructions to their users.
@@ -719,7 +726,7 @@ class LowerMatrixIntrinsics {
     // Pop an element for which we guaranteed to have at least one of the
     // operand shapes.  Add the shape for this and then add users to the work
     // list.
-    LLVM_DEBUG(dbgs() << "Forward-propagate shapes:\n");
+    LDBG() << "Forward-propagate shapes:";
     while (!WorkList.empty()) {
       Instruction *Inst = WorkList.pop_back_val();
 
@@ -754,7 +761,7 @@ class LowerMatrixIntrinsics {
     // Pop an element with known shape.  Traverse the operands, if their shape
     // derives from the result shape and is unknown, add it and add them to the
     // worklist.
-    LLVM_DEBUG(dbgs() << "Backward-propagate shapes:\n");
+    LDBG() << "Backward-propagate shapes:";
     while (!WorkList.empty()) {
       Value *V = WorkList.pop_back_val();
 
@@ -778,7 +785,8 @@ class LowerMatrixIntrinsics {
 
       } else if (match(V, m_Intrinsic<Intrinsic::matrix_transpose>(
                               m_Value(MatrixA), m_Value(M), m_Value(N)))) {
-        // Flip dimensions.
+        // MatrixA is M x N, so propagate this shape to it directly.
+        // Unlike computeShapeInfoForInst, the dimensions are not flipped here.
         if (setShapeInfo(MatrixA, {M, N}))
           pushInstruction(MatrixA, WorkList);
       } else if (match(V, m_Intrinsic<Intrinsic::matrix_column_major_store>(
@@ -793,10 +801,9 @@ class LowerMatrixIntrinsics {
       } else if (isa<StoreInst>(V)) {
         // Nothing to do.  We forward-propagated to this so we would just
         // backward propagate to an instruction with an already known shape.
-      } else if (isUniformShape(V) || isa<SelectInst>(V)) {
-        auto Ops = cast<Instruction>(V)->operands();
-        auto ShapedOps = isa<SelectInst>(V) ? drop_begin(Ops) : Ops;
-        // Propagate to all operands.
+      } else if (isShapePreserving(V)) {
+        auto ShapedOps = getShapedOperands(cast<Instruction>(V));
+        // Propagate to all shaped operands.
         ShapeInfo Shape = ShapeMap[V];
         for (Use &U : ShapedOps) {
           if (setShapeInfo(U.get(), Shape))
@@ -1295,6 +1302,25 @@ class LowerMatrixIntrinsics {
     return commonAlignment(InitialAlign, ElementSizeInBits / 8);
   }
 
+  /// Returns the integer type the data layout uses for indexing (e.g. GEP
+  /// offsets) into the address space of \p Ptr.
+  IntegerType *getIndexType(Value *Ptr) const {
+    return cast<IntegerType>(DL.getIndexType(Ptr->getType()));
+  }
+
+  /// Returns the constant \p V as a value of the index type for \p Ptr.
+  Value *getIndex(Value *Ptr, uint64_t V) const {
+    return ConstantInt::get(getIndexType(Ptr), V);
+  }
+
+  /// Converts the integer \p V to the index type for \p Ptr, truncating it or
+  /// zero-extending it (i.e. treating \p V as unsigned) as needed.
+  Value *truncateToIndexType(Value *Ptr, Value *V, IRBuilder<> &Builder) const {
+    assert(isa<IntegerType>(V->getType()) && "expected an integer value");
+    return Builder.CreateZExtOrTrunc(V, getIndexType(Ptr),
+                                     V->getName() + ".trunc");
+  }
+
   /// Load a matrix with \p Shape starting at \p Ptr and using \p Stride between
   /// vectors.
   MatrixTy loadMatrix(Type *Ty, Value *Ptr, MaybeAlign MAlign, Value *Stride,
@@ -1304,6 +1324,7 @@ class LowerMatrixIntrinsics {
     Type *VecTy = FixedVectorType::get(EltTy, Shape.getStride());
     Value *EltPtr = Ptr;
     MatrixTy Result;
+    Stride = truncateToIndexType(Ptr, Stride, Builder);
     for (unsigned I = 0, E = Shape.getNumVectors(); I < E; ++I) {
       Value *GEP = computeVectorAddr(
           EltPtr, Builder.getIntN(Stride->getType()->getScalarSizeInBits(), I),
@@ -1325,14 +1346,14 @@ class LowerMatrixIntrinsics {
                       ShapeInfo ResultShape, Type *EltTy,
                       IRBuilder<> &Builder) {
     Value *Offset = Builder.CreateAdd(
-        Builder.CreateMul(J, Builder.getInt64(MatrixShape.getStride())), I);
+        Builder.CreateMul(J, getIndex(MatrixPtr, MatrixShape.getStride())), I);
 
     Value *TileStart = Builder.CreateGEP(EltTy, MatrixPtr, Offset);
     auto *TileTy = FixedVectorType::get(EltTy, ResultShape.NumRows *
                                                    ResultShape.NumColumns);
 
     return loadMatrix(TileTy, TileStart, Align,
-                      Builder.getInt64(MatrixShape.getStride()), IsVolatile,
+                      getIndex(MatrixPtr, MatrixShape.getStride()), IsVolatile,
                       ResultShape, Builder);
   }
 
@@ -1363,14 +1384,15 @@ class LowerMatrixIntrinsics {
                    MaybeAlign MAlign, bool IsVolatile, ShapeInfo MatrixShape,
                    Value *I, Value *J, Type *EltTy, IRBuilder<> &Builder) {
     Value *Offset = Builder.CreateAdd(
-        Builder.CreateMul(J, Builder.getInt64(MatrixShape.getStride())), I);
+        Builder.CreateMul(J, getIndex(MatrixPtr, MatrixShape.getStride())), I);
 
     Value *TileStart = Builder.CreateGEP(EltTy, MatrixPtr, Offset);
     auto *TileTy = FixedVectorType::get(EltTy, StoreVal.getNumRows() *
                                                    StoreVal.getNumColumns());
 
     storeMatrix(TileTy, StoreVal, TileStart, MAlign,
-                Builder.getInt64(MatrixShape.getStride()), IsVolatile, Builder);
+                getIndex(MatrixPtr, MatrixShape.getStride()), IsVolatile,
+                Builder);
   }
 
   /// Store matrix \p StoreVal starting at \p Ptr and using \p Stride between
@@ -1380,6 +1402,7 @@ class LowerMatrixIntrinsics {
                        IRBuilder<> &Builder) {
     auto *VType = cast<FixedVectorType>(Ty);
     Value *EltPtr = Ptr;
+    Stride = truncateToIndexType(Ptr, Stride, Builder);
     for (auto Vec : enumerate(StoreVal.vectors())) {
       Value *GEP = computeVectorAddr(
           EltPtr,
@@ -2011,18 +2034,17 @@ class LowerMatrixIntrinsics {
             const unsigned TileM = std::min(M - K, unsigned(TileSize));
             MatrixTy A =
                 loadMatrix(APtr, LoadOp0->getAlign(), LoadOp0->isVolatile(),
-                           LShape, Builder.getInt64(I), Builder.getInt64(K),
+                           LShape, getIndex(APtr, I), getIndex(APtr, K),
                            {TileR, TileM}, EltType, Builder);
             MatrixTy B =
                 loadMatrix(BPtr, LoadOp1->getAlign(), LoadOp1->isVolatile(),
-                           RShape, Builder.getInt64(K), Builder.getInt64(J),
+                           RShape, getIndex(BPtr, K), getIndex(BPtr, J),
                            {TileM, TileC}, EltType, Builder);
             emitMatrixMultiply(Res, A, B, Builder, true, false,
                                getFastMathFlags(MatMul));
           }
           storeMatrix(Res, CPtr, Store->getAlign(), Store->isVolatile(), {R, M},
-                      Builder.getInt64(I), Builder.getInt64(J), EltType,
-                      Builder);
+                      getIndex(CPtr, I), getIndex(CPtr, J), EltType, Builder);
         }
     }
 
@@ -2254,15 +2276,14 @@ class LowerMatrixIntrinsics {
   /// Lower load instructions.
   MatrixTy VisitLoad(LoadInst *Inst, const ShapeInfo &SI, Value *Ptr,
                      IRBuilder<> &Builder) {
-    return LowerLoad(Inst, Ptr, Inst->getAlign(),
-                     Builder.getInt64(SI.getStride()), Inst->isVolatile(), SI,
-                     Builder);
+    return LowerLoad(Inst, Ptr, Inst->getAlign(), getIndex(Ptr, SI.getStride()),
+                     Inst->isVolatile(), SI, Builder);
   }
 
   MatrixTy VisitStore(StoreInst *Inst, const ShapeInfo &SI, Value *StoredVal,
                       Value *Ptr, IRBuilder<> &Builder) {
     return LowerStore(Inst, StoredVal, Ptr, Inst->getAlign(),
-                      Builder.getInt64(SI.getStride()), Inst->isVolatile(), SI,
+                      getIndex(Ptr, SI.getStride()), Inst->isVolatile(), SI,
                       Builder);
   }
 
diff --git a/llvm/test/Transforms/LowerMatrixIntrinsics/strided-load-double.ll b/llvm/test/Transforms/LowerMatrixIntrinsics/strided-load-double.ll
index ae7da19e1641e..72a12dc1e7c4c 100644
--- a/llvm/test/Transforms/LowerMatrixIntrinsics/strided-load-double.ll
+++ b/llvm/test/Transforms/LowerMatrixIntrinsics/strided-load-double.ll
@@ -1,22 +1,40 @@
 ; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
-; RUN: opt -passes='lower-matrix-intrinsics' -S < %s | FileCheck %s
+; RUN: opt -passes='lower-matrix-intrinsics' -data-layout='p:64:64' -S < %s | FileCheck %s --check-prefix=PTR64
+; RUN: opt -passes='lower-matrix-intrinsics' -data-layout='p:32:32' -S < %s | FileCheck %s --check-prefix=PTR32
 
 define <9 x double> @strided_load_3x3(ptr %in, i64 %stride) {
-; CHECK-LABEL: @strided_load_3x3(
-; CHECK-NEXT:  entry:
-; CHECK-NEXT:    [[VEC_START:%.*]] = mul i64 0, [[STRIDE:%.*]]
-; CHECK-NEXT:    [[VEC_GEP:%.*]] = getelementptr double, ptr [[IN:%.*]], i64 [[VEC_START]]
-; CHECK-NEXT:    [[COL_LOAD:%.*]] = load <3 x double>, ptr [[VEC_GEP]], align 8
-; CHECK-NEXT:    [[VEC_START1:%.*]] = mul i64 1, [[STRIDE]]
-; CHECK-NEXT:    [[VEC_GEP2:%.*]] = getelementptr double, ptr [[IN]], i64 [[VEC_START1]]
-; CHECK-NEXT:    [[COL_LOAD4:%.*]] = load <3 x double>, ptr [[VEC_GEP2]], align 8
-; CHECK-NEXT:    [[VEC_START5:%.*]] = mul i64 2, [[STRIDE]]
-; CHECK-NEXT:    [[VEC_GEP6:%.*]] = getelementptr double, ptr [[IN]], i64 [[VEC_START5]]
-; CHECK-NEXT:    [[COL_LOAD8:%.*]] = load <3 x double>, ptr [[VEC_GEP6]], align 8
-; CHECK-NEXT:    [[TMP0:%.*]] = shufflevector <3 x double> [[COL_LOAD]], <3 x double> [[COL_LOAD4]], <6 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5>
-; CHECK-NEXT:    [[TMP1:%.*]] = shufflevector <3 x double> [[COL_LOAD8]], <3 x double> poison, <6 x i32> <i32 0, i32 1, i32 2, i32 poison, i32 poison, i32 poison>
-; CHECK-NEXT:    [[TMP2:%.*]] = shufflevector <6 x double> [[TMP0]], <6 x double> [[TMP1]], <9 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7, i32 8>
-; CHECK-NEXT:    ret <9 x double> [[TMP2]]
+; PTR64-LABEL: @strided_load_3x3(
+; PTR64-NEXT:  entry:
+; PTR64-NEXT:    [[VEC_START:%.*]] = mul i64 0, [[STRIDE:%.*]]
+; PTR64-NEXT:    [[VEC_GEP:%.*]] = getelementptr double, ptr [[IN:%.*]], i64 [[VEC_START]]
+; PTR64-NEXT:    [[COL_LOAD:%.*]] = load <3 x double>, ptr [[VEC_GEP]], align 8
+; PTR64-NEXT:    [[VEC_START1:%.*]] = mul i64 1, [[STRIDE]]
+; PTR64-NEXT:    [[VEC_GEP2:%.*]] = getelementptr double, ptr [[IN]], i64 [[VEC_START1]]
+; PTR64-NEXT:    [[COL_LOAD3:%.*]] = load <3 x double>, ptr [[VEC_GEP2]], align 8
+; PTR64-NEXT:    [[VEC_START4:%.*]] = mul i64 2, [[STRIDE]]
+; PTR64-NEXT:    [[VEC_GEP5:%.*]] = getelementptr double, ptr [[IN]], i64 [[VEC_START4]]
+; PTR64-NEXT:    [[COL_LOAD6:%.*]] = load <3 x double>, ptr [[VEC_GEP5]], align 8
+; PTR64-NEXT:    [[TMP0:%.*]] = shufflevector <3 x double> [[COL_LOAD]], <3 x double> [[COL_LOAD3]], <6 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5>
+; PTR64-NEXT:    [[TMP1:%.*]] = shufflevector <3 x double> [[COL_LOAD6]], <3 x double> poison, <6 x i32> <i32 0, i32 1, i32 2, i32 poison, i32 poison, i32 poison>
+; PTR64-NEXT:    [[TMP2:%.*]] = shufflevector <6 x double> [[TMP0]], <6 x double> [[TMP1]], <9 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7, i32 8>
+; PTR64-NEXT:    ret <9 x double> [[TMP2]]
+;
+; PTR32-LABEL: @strided_load_3x3(
+; PTR32-NEXT:  entry:
+; PTR32-NEXT:    [[STRIDE:%.*]] = trunc i64 [[STRIDE1:%.*]] to i32
+; PTR32-NEXT:    [[VEC_START:%.*]] = mul i32 0, [[STRIDE]]
+; PTR32-NEXT:    [[VEC_GEP:%.*]] = getelementptr double, ptr [[IN:%.*]], i32 [[VEC_START]]
+; PTR32-NEXT:    [[COL_LOAD:%.*]] = load <3 x double>, ptr [[VEC_GEP]], align 8
+; PTR32-NEXT:    [[VEC_START1:%.*]] = mul i32 1, [[STRIDE]]
+; PTR32-NEXT:    [[VEC_GEP2:%.*]] = getelementptr double, ptr [[IN]], i32 [[VEC_START1]]
+; PTR32-NEXT:    [[COL_LOAD3:%.*]] = load <3 x double>, ptr [[VEC_GEP2]], align 8
+; PTR32-NEXT:    [[VEC_START4:%.*]] = mul i32 2, [[STRIDE]]
+; PTR32-NEXT:    [[VEC_GEP5:%.*]] = getelementptr double, ptr [[IN]], i32 [[VEC_START4]]
+; PTR32-NEXT:    [[COL_LOAD6:%.*]] = load <3 x double>, ptr [[VEC_GEP5]], align 8
+; PTR32-NEXT:    [[TMP0:%.*]] = shufflevector <3 x double> [[COL_LOAD]], <3 x double> [[COL_LOAD3]], <6 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5>
+; PTR32-NEXT:    [[TMP1:%.*]] = shufflevector <3 x double> [[COL_LOAD6]], <3 x double> poison, <6 x i32> <i32 0, i32 1, i32 2, i32 poison, i32 poison, i32 poison>
+; PTR32-NEXT:    [[TMP2:%.*]] = shufflevector <6 x double> [[TMP0]], <6 x double> [[TMP1]], <9 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7, i32 8>
+; PTR32-NEXT:    ret <9 x double> [[TMP2]]
 ;
 entry:
   %load = call <9 x double> @llvm.matrix.column.major.load.v9f64.i64(ptr %in, i64 %stride, i1 false, i32 3, i32 3)
@@ -26,12 +44,20 @@ entry:
 declare <9 x double> @llvm.matrix.column.major.load.v9f64.i64(ptr, i64, i1, i32, i32)
 
 define <9 x double> @strided_load_9x1(ptr %in, i64 %stride) {
-; CHECK-LABEL: @strided_load_9x1(
-; CHECK-NEXT:  entry:
-; CHECK-NEXT:    [[VEC_START:%.*]] = mul i64 0, [[STRIDE:%.*]]
-; CHECK-NEXT:    [[VEC_GEP:%.*]] = getelementptr double, ptr [[IN:%.*]], i64 [[VEC_START]]
-; CHECK-NEXT:    [[COL_LOAD:%.*]] = load <9 x double>, ptr [[VEC_GEP]], align 8
-; CHECK-NEXT:    ret <9 x double> [[COL_LOAD]]
+; PTR64-LABEL: @strided_load_9x1(
+; PTR64-NEXT:  entry:
+; PTR64-NEXT:    [[VEC_START:%.*]] = mul i64 0, [[STRIDE:%.*]]
+; PTR64-NEXT:    [[VEC_GEP:%.*]] = getelementptr double, ptr [[IN:%.*]], i64 [[VEC_START]]
+; PTR64-NEXT:    [[COL_LOAD:%.*]] = load <9 x double>, ptr [[VEC_GEP]], align 8
+; PTR64-NEXT:    ret <9 x double> [[COL_LOAD]]
+;
+; PTR32-LABEL: @strided_load_9x1(
+; PTR32-NEXT:  entry:
+; PTR32-NEXT:    [[STRIDE_TRUNC:%.*]] = trunc i64 [[STRIDE:%.*]] to i32
+; PTR32-NEXT:    [[VEC_START:%.*]] = mul i32 0, [[STRIDE_TRUNC]]
+; PTR32-NEXT:    [[VEC_GEP:%.*]] = getelementptr double, ptr [[IN:%.*]], i32 [[VEC_START]]
+; PTR32-NEXT:    [[COL_LOAD:%.*]] = load <9 x double>, ptr [[VEC_GEP]], align 8
+; PTR32-NEXT:    ret <9 x double> [[COL_LOAD]]
 ;
 entry:
   %load = call <9 x double> @llvm.matrix.column.major.load.v9f64.i64(ptr %in, i64 %stride, i1 false, i32 9, i32 1)
@@ -41,16 +67,28 @@ entry:
 declare <8 x double> @llvm.matrix.column.major.load.v8f64.i64(ptr, i64, i1, i32, i32)
 
 define <8 x double> @strided_load_4x2(ptr %in, i64 %stride) {
-; CHECK-LABEL: @strided_load_4x2(
-; CHECK-NEXT:  entry:
-; CHECK-NEXT:    [[VEC_START:%.*]] = mul i64 0, [[STRIDE:%.*]]
-; CHECK-NEXT:    [[VEC_GEP:%.*]] = getelementptr double, ptr [[IN:%.*]], i64 [[VEC_START]]
-; CHECK-NEXT:    [[COL_LOAD:%.*]] = load <4 x double>, ptr [[VEC_GEP]], align 8
-; CHECK-NEXT:    [[VEC_START1:%.*]] = mul i64 1, [[STRIDE]]
-; CHECK-NEXT:    [[VEC_GEP2:%.*]] = getelementptr double, ptr [[IN]], i64 [[VEC_START1]]
-; CHECK-NEXT:    [[COL_LOAD4:%.*]] = load <4 x double>, ptr [[VEC_GEP2]], align 8
-; CHECK-NEXT:    [[TMP0:%.*]] = shufflevector <4 x double> [[COL_LOAD]], <4 x double> [[COL_LOAD4]], <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
-; CHECK-NEXT:    ret <8 x double> [[TMP0]]
+; PTR64-LABEL: @strided_load_4x2(
+; PTR64-NEXT:  entry:
+; PTR64-NEXT:    [[VEC_START:%.*]] = mul i64 0, [[STRIDE:%.*]]
+; PTR64-NEXT:    [[VEC_GEP:%.*]] = getelementptr double, ptr [[IN:%.*]], i64 [[VEC_START]]
+; PTR64-NEXT:    [[COL_LOAD:%.*]] = load <4 x double>, ptr [[VEC_GEP]], align 8
+; PTR64-NEXT:    [[VEC_START1:%.*]] = mul i64 1, [[STRIDE]]
+; PTR64-NEXT:    [[VEC_GEP2:%.*]] = getelementptr double, ptr [[IN]], i64 [[VEC_START1]]
+; PTR64-NEXT:    [[COL_LOAD3:%.*]] = load <4 x double>, ptr [[VEC_GEP2]], align 8
+; PTR64-NEXT:    [[TMP0:%.*]] = shufflevector <4 x double> [[COL_LOAD]], <4 x double> [[COL_LOAD3]], <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
+; PTR64-NEXT:    ret <8 x double> [[TMP0]]
+;
+; PTR32-LABEL: @strided_load_4x2(
+; PTR32-NEXT:  entry:
+; PTR32-NEXT:    [[STRIDE_TRUNC:%.*]] = trunc i64 [[STRIDE:%.*]] to i32
+; PTR32-NEXT:    [[VEC_START:%.*]] = mul i32 0, [[STRIDE_TRUNC]]
+; PTR32-NEXT:    [[VEC_GEP:%.*]] = getelementptr double, ptr [[IN:%.*]], i32 [[VEC_START]]
+; PTR32-NEXT:    [[COL_LOAD:%.*]] = load <4 x double>, ptr [[VEC_GEP]], align...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/162646


More information about the llvm-commits mailing list