[llvm] [Matrix] Propagate shape information through cast insts (PR #141869)

Jon Roelofs via llvm-commits llvm-commits at lists.llvm.org
Thu May 29 09:50:43 PDT 2025


https://github.com/jroelofs updated https://github.com/llvm/llvm-project/pull/141869

>From 3779c5b6461c737598de9d73f69400a84752247f Mon Sep 17 00:00:00 2001
From: Jon Roelofs <jonathan_roelofs at apple.com>
Date: Wed, 28 May 2025 15:32:53 -0700
Subject: [PATCH 1/5] [Matrix] Propagate shape information through cast
 instructions

---
 .../Scalar/LowerMatrixIntrinsics.cpp          |  65 ++++-
 .../Transforms/LowerMatrixIntrinsics/unary.ll | 237 ++++++++++++++++++
 2 files changed, 300 insertions(+), 2 deletions(-)
 create mode 100644 llvm/test/Transforms/LowerMatrixIntrinsics/unary.ll

diff --git a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
index 756a72e6d97bc..4702355ca4577 100644
--- a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
+++ b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
@@ -32,8 +32,10 @@
 #include "llvm/IR/CFG.h"
 #include "llvm/IR/DataLayout.h"
 #include "llvm/IR/DebugInfoMetadata.h"
+#include "llvm/IR/DerivedTypes.h"
 #include "llvm/IR/Function.h"
 #include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/InstrTypes.h"
 #include "llvm/IR/Instructions.h"
 #include "llvm/IR/IntrinsicInst.h"
 #include "llvm/IR/MatrixBuilder.h"
@@ -232,6 +234,32 @@ static bool isUniformShape(Value *V) {
   if (I->isBinaryOp())
     return true;
 
+  if (auto *Cast = dyn_cast<CastInst>(V))
+    switch (Cast->getOpcode()) {
+    case llvm::Instruction::Trunc:
+    case llvm::Instruction::ZExt:
+    case llvm::Instruction::SExt:
+    case llvm::Instruction::FPToUI:
+    case llvm::Instruction::FPToSI:
+    case llvm::Instruction::UIToFP:
+    case llvm::Instruction::SIToFP:
+    case llvm::Instruction::FPTrunc:
+    case llvm::Instruction::FPExt:
+      return true;
+    case llvm::Instruction::AddrSpaceCast:
+    case CastInst::PtrToInt:
+    case CastInst::IntToPtr:
+      return false;
+    case CastInst::BitCast: {
+      if (auto *SrcVTy = dyn_cast<FixedVectorType>(Cast->getSrcTy()))
+        if (auto *DestVTy = dyn_cast<FixedVectorType>(Cast->getDestTy()))
+          return SrcVTy->getNumElements() == DestVTy->getNumElements();
+      return false;
+    }
+    case llvm::Instruction::CastOpsEnd:
+      llvm_unreachable("not an actual cast op");
+    }
+
   switch (I->getOpcode()) {
   case Instruction::FNeg:
     return true;
@@ -1066,9 +1094,11 @@ class LowerMatrixIntrinsics {
       Value *Op2;
       if (auto *BinOp = dyn_cast<BinaryOperator>(Inst))
         Changed |= VisitBinaryOperator(BinOp);
-      if (auto *UnOp = dyn_cast<UnaryOperator>(Inst))
+      else if (auto *UnOp = dyn_cast<UnaryOperator>(Inst))
         Changed |= VisitUnaryOperator(UnOp);
-      if (match(Inst, m_Load(m_Value(Op1))))
+      else if (auto *Cast = dyn_cast<CastInst>(Inst))
+        Changed |= VisitCastInstruction(Cast);
+      else if (match(Inst, m_Load(m_Value(Op1))))
         Changed |= VisitLoad(cast<LoadInst>(Inst), Op1, Builder);
       else if (match(Inst, m_Store(m_Value(Op1), m_Value(Op2))))
         Changed |= VisitStore(cast<StoreInst>(Inst), Op1, Op2, Builder);
@@ -2198,6 +2228,37 @@ class LowerMatrixIntrinsics {
     return true;
   }
 
+  /// Lower the cast \p Inst one vector at a time, if shape information is
+  /// available for it. Returns true if \p Inst was lowered, false otherwise.
+  bool VisitCastInstruction(CastInst *Inst) {
+    auto ShapeIt = ShapeMap.find(Inst);
+    if (ShapeIt == ShapeMap.end())
+      return false;
+
+    IRBuilder<> Builder(Inst);
+    ShapeInfo &Shape = ShapeIt->second;
+    MatrixTy Result;
+    MatrixTy M = getMatrix(Inst->getOperand(0), Shape, Builder);
+
+    // Apply the flags reported for the original cast to the casts built below.
+    Builder.setFastMathFlags(getFastMathFlags(Inst));
+
+    // Every new cast yields M.getStride() elements of the original result type.
+    auto *OrigVTy = cast<VectorType>(Inst->getType());
+    auto *NewVTy = VectorType::get(OrigVTy->getElementType(),
+                                   ElementCount::getFixed(M.getStride()));
+
+    for (unsigned Idx = 0; Idx < Shape.getNumVectors(); ++Idx)
+      Result.addVector(
+          Builder.CreateCast(Inst->getOpcode(), M.getVector(Idx), NewVTy));
+
+    finalizeLowering(Inst,
+                     Result.addNumComputeOps(getNumOps(Result.getVectorTy()) *
+                                             Result.getNumVectors()),
+                     Builder);
+    return true;
+  }
+
   /// Helper to linearize a matrix expression tree into a string. Currently
   /// matrix expressions are linarized by starting at an expression leaf and
   /// linearizing bottom up.
diff --git a/llvm/test/Transforms/LowerMatrixIntrinsics/unary.ll b/llvm/test/Transforms/LowerMatrixIntrinsics/unary.ll
new file mode 100644
index 0000000000000..80527cf717c7b
--- /dev/null
+++ b/llvm/test/Transforms/LowerMatrixIntrinsics/unary.ll
@@ -0,0 +1,237 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
+; RUN: opt -passes='lower-matrix-intrinsics' -S < %s | FileCheck %s
+
+define void @fneg_2x2(ptr %in, ptr %out) {
+; CHECK-LABEL: @fneg_2x2(
+; CHECK-NEXT:    [[COL_LOAD:%.*]] = load <2 x float>, ptr [[IN:%.*]], align 16
+; CHECK-NEXT:    [[VEC_GEP:%.*]] = getelementptr float, ptr [[IN]], i64 2
+; CHECK-NEXT:    [[COL_LOAD1:%.*]] = load <2 x float>, ptr [[VEC_GEP]], align 8
+; CHECK-NEXT:    [[TMP1:%.*]] = fneg <2 x float> [[COL_LOAD]]
+; CHECK-NEXT:    [[TMP2:%.*]] = fneg <2 x float> [[COL_LOAD1]]
+; CHECK-NEXT:    store <2 x float> [[TMP1]], ptr [[OUT:%.*]], align 16
+; CHECK-NEXT:    [[VEC_GEP2:%.*]] = getelementptr float, ptr [[OUT]], i64 2
+; CHECK-NEXT:    store <2 x float> [[TMP2]], ptr [[VEC_GEP2]], align 8
+; CHECK-NEXT:    ret void
+;
+  %inv = load <4 x float>, ptr %in
+  %op = fneg <4 x float> %inv
+  %opt  = call <4 x float> @llvm.matrix.transpose(<4 x float> %op, i32 2, i32 2)
+  %optt = call <4 x float> @llvm.matrix.transpose(<4 x float> %opt, i32 2, i32 2)
+  store <4 x float> %optt, ptr %out
+  ret void
+}
+
+define void @trunc_2x2(ptr %in, ptr %out) {
+; CHECK-LABEL: @trunc_2x2(
+; CHECK-NEXT:    [[COL_LOAD:%.*]] = load <2 x i64>, ptr [[IN:%.*]], align 32
+; CHECK-NEXT:    [[VEC_GEP:%.*]] = getelementptr i64, ptr [[IN]], i64 2
+; CHECK-NEXT:    [[COL_LOAD1:%.*]] = load <2 x i64>, ptr [[VEC_GEP]], align 16
+; CHECK-NEXT:    [[TMP1:%.*]] = trunc <2 x i64> [[COL_LOAD]] to <2 x i32>
+; CHECK-NEXT:    [[TMP2:%.*]] = trunc <2 x i64> [[COL_LOAD1]] to <2 x i32>
+; CHECK-NEXT:    store <2 x i32> [[TMP1]], ptr [[OUT:%.*]], align 16
+; CHECK-NEXT:    [[VEC_GEP2:%.*]] = getelementptr i32, ptr [[OUT]], i64 2
+; CHECK-NEXT:    store <2 x i32> [[TMP2]], ptr [[VEC_GEP2]], align 8
+; CHECK-NEXT:    ret void
+;
+  %inv = load <4 x i64>, ptr %in
+  %op = trunc <4 x i64> %inv to <4 x i32>
+  %opt  = call <4 x i32> @llvm.matrix.transpose(<4 x i32> %op, i32 2, i32 2)
+  %optt = call <4 x i32> @llvm.matrix.transpose(<4 x i32> %opt, i32 2, i32 2)
+  store <4 x i32> %optt, ptr %out
+  ret void
+}
+
+define void @zext_2x2(ptr %in, ptr %out) {
+; CHECK-LABEL: @zext_2x2(
+; CHECK-NEXT:    [[COL_LOAD:%.*]] = load <2 x i16>, ptr [[IN:%.*]], align 8
+; CHECK-NEXT:    [[VEC_GEP:%.*]] = getelementptr i16, ptr [[IN]], i64 2
+; CHECK-NEXT:    [[COL_LOAD1:%.*]] = load <2 x i16>, ptr [[VEC_GEP]], align 4
+; CHECK-NEXT:    [[TMP1:%.*]] = zext <2 x i16> [[COL_LOAD]] to <2 x i32>
+; CHECK-NEXT:    [[TMP2:%.*]] = zext <2 x i16> [[COL_LOAD1]] to <2 x i32>
+; CHECK-NEXT:    store <2 x i32> [[TMP1]], ptr [[OUT:%.*]], align 16
+; CHECK-NEXT:    [[VEC_GEP2:%.*]] = getelementptr i32, ptr [[OUT]], i64 2
+; CHECK-NEXT:    store <2 x i32> [[TMP2]], ptr [[VEC_GEP2]], align 8
+; CHECK-NEXT:    ret void
+;
+  %inv = load <4 x i16>, ptr %in
+  %op = zext <4 x i16> %inv to <4 x i32>
+  %opt  = call <4 x i32> @llvm.matrix.transpose(<4 x i32> %op, i32 2, i32 2)
+  %optt = call <4 x i32> @llvm.matrix.transpose(<4 x i32> %opt, i32 2, i32 2)
+  store <4 x i32> %optt, ptr %out
+  ret void
+}
+
+define void @sext_2x2(ptr %in, ptr %out) {
+; CHECK-LABEL: @sext_2x2(
+; CHECK-NEXT:    [[COL_LOAD:%.*]] = load <2 x i8>, ptr [[IN:%.*]], align 4
+; CHECK-NEXT:    [[VEC_GEP:%.*]] = getelementptr i8, ptr [[IN]], i64 2
+; CHECK-NEXT:    [[COL_LOAD1:%.*]] = load <2 x i8>, ptr [[VEC_GEP]], align 2
+; CHECK-NEXT:    [[TMP1:%.*]] = sext <2 x i8> [[COL_LOAD]] to <2 x i16>
+; CHECK-NEXT:    [[TMP2:%.*]] = sext <2 x i8> [[COL_LOAD1]] to <2 x i16>
+; CHECK-NEXT:    store <2 x i16> [[TMP1]], ptr [[OUT:%.*]], align 8
+; CHECK-NEXT:    [[VEC_GEP2:%.*]] = getelementptr i16, ptr [[OUT]], i64 2
+; CHECK-NEXT:    store <2 x i16> [[TMP2]], ptr [[VEC_GEP2]], align 4
+; CHECK-NEXT:    ret void
+;
+  %inv = load <4 x i8>, ptr %in
+  %op = sext <4 x i8> %inv to <4 x i16>
+  %opt  = call <4 x i16> @llvm.matrix.transpose(<4 x i16> %op, i32 2, i32 2)
+  %optt = call <4 x i16> @llvm.matrix.transpose(<4 x i16> %opt, i32 2, i32 2)
+  store <4 x i16> %optt, ptr %out
+  ret void
+}
+
+define void @fptoui_2x2(ptr %in, ptr %out) {
+; CHECK-LABEL: @fptoui_2x2(
+; CHECK-NEXT:    [[COL_LOAD:%.*]] = load <2 x float>, ptr [[IN:%.*]], align 16
+; CHECK-NEXT:    [[VEC_GEP:%.*]] = getelementptr float, ptr [[IN]], i64 2
+; CHECK-NEXT:    [[COL_LOAD1:%.*]] = load <2 x float>, ptr [[VEC_GEP]], align 8
+; CHECK-NEXT:    [[TMP1:%.*]] = fptoui <2 x float> [[COL_LOAD]] to <2 x i32>
+; CHECK-NEXT:    [[TMP2:%.*]] = fptoui <2 x float> [[COL_LOAD1]] to <2 x i32>
+; CHECK-NEXT:    store <2 x i32> [[TMP1]], ptr [[OUT:%.*]], align 16
+; CHECK-NEXT:    [[VEC_GEP2:%.*]] = getelementptr i32, ptr [[OUT]], i64 2
+; CHECK-NEXT:    store <2 x i32> [[TMP2]], ptr [[VEC_GEP2]], align 8
+; CHECK-NEXT:    ret void
+;
+  %inv = load <4 x float>, ptr %in
+  %op = fptoui <4 x float> %inv to <4 x i32>
+  %opt  = call <4 x i32> @llvm.matrix.transpose(<4 x i32> %op, i32 2, i32 2)
+  %optt = call <4 x i32> @llvm.matrix.transpose(<4 x i32> %opt, i32 2, i32 2)
+  store <4 x i32> %optt, ptr %out
+  ret void
+}
+
+define void @fptosi_2x2(ptr %in, ptr %out) {
+; CHECK-LABEL: @fptosi_2x2(
+; CHECK-NEXT:    [[COL_LOAD:%.*]] = load <2 x float>, ptr [[IN:%.*]], align 16
+; CHECK-NEXT:    [[VEC_GEP:%.*]] = getelementptr float, ptr [[IN]], i64 2
+; CHECK-NEXT:    [[COL_LOAD1:%.*]] = load <2 x float>, ptr [[VEC_GEP]], align 8
+; CHECK-NEXT:    [[TMP1:%.*]] = fptosi <2 x float> [[COL_LOAD]] to <2 x i32>
+; CHECK-NEXT:    [[TMP2:%.*]] = fptosi <2 x float> [[COL_LOAD1]] to <2 x i32>
+; CHECK-NEXT:    store <2 x i32> [[TMP1]], ptr [[OUT:%.*]], align 16
+; CHECK-NEXT:    [[VEC_GEP2:%.*]] = getelementptr i32, ptr [[OUT]], i64 2
+; CHECK-NEXT:    store <2 x i32> [[TMP2]], ptr [[VEC_GEP2]], align 8
+; CHECK-NEXT:    ret void
+;
+  %inv = load <4 x float>, ptr %in
+  %op = fptosi <4 x float> %inv to <4 x i32>
+  %opt  = call <4 x i32> @llvm.matrix.transpose(<4 x i32> %op, i32 2, i32 2)
+  %optt = call <4 x i32> @llvm.matrix.transpose(<4 x i32> %opt, i32 2, i32 2)
+  store <4 x i32> %optt, ptr %out
+  ret void
+}
+
+define void @uitofp_2x2(ptr %in, ptr %out) {
+; CHECK-LABEL: @uitofp_2x2(
+; CHECK-NEXT:    [[COL_LOAD:%.*]] = load <2 x i64>, ptr [[IN:%.*]], align 32
+; CHECK-NEXT:    [[VEC_GEP:%.*]] = getelementptr i64, ptr [[IN]], i64 2
+; CHECK-NEXT:    [[COL_LOAD1:%.*]] = load <2 x i64>, ptr [[VEC_GEP]], align 16
+; CHECK-NEXT:    [[TMP1:%.*]] = uitofp <2 x i64> [[COL_LOAD]] to <2 x double>
+; CHECK-NEXT:    [[TMP2:%.*]] = uitofp <2 x i64> [[COL_LOAD1]] to <2 x double>
+; CHECK-NEXT:    store <2 x double> [[TMP1]], ptr [[OUT:%.*]], align 32
+; CHECK-NEXT:    [[VEC_GEP2:%.*]] = getelementptr double, ptr [[OUT]], i64 2
+; CHECK-NEXT:    store <2 x double> [[TMP2]], ptr [[VEC_GEP2]], align 16
+; CHECK-NEXT:    ret void
+;
+  %inv = load <4 x i64>, ptr %in
+  %op = uitofp <4 x i64> %inv to <4 x double>
+  %opt  = call <4  x double> @llvm.matrix.transpose(<4  x double> %op, i32 2, i32 2)
+  %optt = call <4  x double> @llvm.matrix.transpose(<4  x double> %opt, i32 2, i32 2)
+  store <4  x double> %optt, ptr %out
+  ret void
+}
+
+define void @sitofp_2x2(ptr %in, ptr %out) {
+; CHECK-LABEL: @sitofp_2x2(
+; CHECK-NEXT:    [[COL_LOAD:%.*]] = load <2 x i64>, ptr [[IN:%.*]], align 32
+; CHECK-NEXT:    [[VEC_GEP:%.*]] = getelementptr i64, ptr [[IN]], i64 2
+; CHECK-NEXT:    [[COL_LOAD1:%.*]] = load <2 x i64>, ptr [[VEC_GEP]], align 16
+; CHECK-NEXT:    [[TMP1:%.*]] = sitofp <2 x i64> [[COL_LOAD]] to <2 x double>
+; CHECK-NEXT:    [[TMP2:%.*]] = sitofp <2 x i64> [[COL_LOAD1]] to <2 x double>
+; CHECK-NEXT:    store <2 x double> [[TMP1]], ptr [[OUT:%.*]], align 32
+; CHECK-NEXT:    [[VEC_GEP2:%.*]] = getelementptr double, ptr [[OUT]], i64 2
+; CHECK-NEXT:    store <2 x double> [[TMP2]], ptr [[VEC_GEP2]], align 16
+; CHECK-NEXT:    ret void
+;
+  %inv = load <4 x i64>, ptr %in
+  %op = sitofp <4 x i64> %inv to <4 x double>
+  %opt  = call <4  x double> @llvm.matrix.transpose(<4  x double> %op, i32 2, i32 2)
+  %optt = call <4  x double> @llvm.matrix.transpose(<4  x double> %opt, i32 2, i32 2)
+  store <4  x double> %optt, ptr %out
+  ret void
+}
+
+define void @fptrunc_2x2(ptr %in, ptr %out) {
+; CHECK-LABEL: @fptrunc_2x2(
+; CHECK-NEXT:    [[COL_LOAD:%.*]] = load <2 x double>, ptr [[IN:%.*]], align 32
+; CHECK-NEXT:    [[VEC_GEP:%.*]] = getelementptr double, ptr [[IN]], i64 2
+; CHECK-NEXT:    [[COL_LOAD1:%.*]] = load <2 x double>, ptr [[VEC_GEP]], align 16
+; CHECK-NEXT:    [[TMP1:%.*]] = fptrunc nnan <2 x double> [[COL_LOAD]] to <2 x float>
+; CHECK-NEXT:    [[TMP2:%.*]] = fptrunc nnan <2 x double> [[COL_LOAD1]] to <2 x float>
+; CHECK-NEXT:    store <2 x float> [[TMP1]], ptr [[OUT:%.*]], align 16
+; CHECK-NEXT:    [[VEC_GEP2:%.*]] = getelementptr float, ptr [[OUT]], i64 2
+; CHECK-NEXT:    store <2 x float> [[TMP2]], ptr [[VEC_GEP2]], align 8
+; CHECK-NEXT:    ret void
+;
+  %inv = load <4 x double>, ptr %in
+  %op = fptrunc nnan <4 x double> %inv to <4 x float>
+  %opt  = call <4 x float> @llvm.matrix.transpose(<4 x float> %op, i32 2, i32 2)
+  %optt = call <4 x float> @llvm.matrix.transpose(<4 x float> %opt, i32 2, i32 2)
+  store <4 x float> %optt, ptr %out
+  ret void
+}
+
+define void @fpext_2x2(ptr %in, ptr %out) {
+; CHECK-LABEL: @fpext_2x2(
+; CHECK-NEXT:    [[COL_LOAD:%.*]] = load <2 x float>, ptr [[IN:%.*]], align 16
+; CHECK-NEXT:    [[VEC_GEP:%.*]] = getelementptr float, ptr [[IN]], i64 2
+; CHECK-NEXT:    [[COL_LOAD1:%.*]] = load <2 x float>, ptr [[VEC_GEP]], align 8
+; CHECK-NEXT:    [[TMP1:%.*]] = fpext <2 x float> [[COL_LOAD]] to <2 x double>
+; CHECK-NEXT:    [[TMP2:%.*]] = fpext <2 x float> [[COL_LOAD1]] to <2 x double>
+; CHECK-NEXT:    store <2 x double> [[TMP1]], ptr [[OUT:%.*]], align 32
+; CHECK-NEXT:    [[VEC_GEP2:%.*]] = getelementptr double, ptr [[OUT]], i64 2
+; CHECK-NEXT:    store <2 x double> [[TMP2]], ptr [[VEC_GEP2]], align 16
+; CHECK-NEXT:    ret void
+;
+  %inv = load <4 x float>, ptr %in
+  %op = fpext <4 x float> %inv to <4 x double>
+  %opt  = call <4 x double> @llvm.matrix.transpose(<4 x double> %op, i32 2, i32 2)
+  %optt = call <4 x double> @llvm.matrix.transpose(<4 x double> %opt, i32 2, i32 2)
+  store <4 x double> %optt, ptr %out
+  ret void
+}
+
+define void @bitcast_2x2_v4f64_to_v4i64(ptr %in, ptr %out) {
+; CHECK-LABEL: @bitcast_2x2_v4f64_to_v4i64(
+; CHECK-NEXT:    [[COL_LOAD:%.*]] = load <2 x double>, ptr [[IN:%.*]], align 32
+; CHECK-NEXT:    [[VEC_GEP:%.*]] = getelementptr double, ptr [[IN]], i64 2
+; CHECK-NEXT:    [[COL_LOAD1:%.*]] = load <2 x double>, ptr [[VEC_GEP]], align 16
+; CHECK-NEXT:    [[TMP1:%.*]] = bitcast <2 x double> [[COL_LOAD]] to <2 x i64>
+; CHECK-NEXT:    [[TMP2:%.*]] = bitcast <2 x double> [[COL_LOAD1]] to <2 x i64>
+; CHECK-NEXT:    store <2 x i64> [[TMP1]], ptr [[OUT:%.*]], align 32
+; CHECK-NEXT:    [[VEC_GEP2:%.*]] = getelementptr i64, ptr [[OUT]], i64 2
+; CHECK-NEXT:    store <2 x i64> [[TMP2]], ptr [[VEC_GEP2]], align 16
+; CHECK-NEXT:    ret void
+;
+  %inv = load <4 x double>, ptr %in
+  %op = bitcast <4 x double> %inv to <4 x i64>
+  %opt  = call <4 x i64> @llvm.matrix.transpose(<4 x i64> %op, i32 2, i32 2)
+  %optt = call <4 x i64> @llvm.matrix.transpose(<4 x i64> %opt, i32 2, i32 2)
+  store <4 x i64> %optt, ptr %out
+  ret void
+}
+
+define void @bitcast_2x2_i256_to_v4i64(ptr %in, ptr %out) {
+; CHECK-LABEL: @bitcast_2x2_i256_to_v4i64(
+; CHECK-NEXT:    [[INV:%.*]] = load i256, ptr [[IN:%.*]], align 4
+; CHECK-NEXT:    [[OP:%.*]] = bitcast i256 [[INV]] to <4 x double>
+; CHECK-NEXT:    store <4 x double> [[OP]], ptr [[OUT:%.*]], align 32
+; CHECK-NEXT:    ret void
+;
+  %inv = load i256, ptr %in
+  %op = bitcast i256 %inv to <4 x double>
+  %opt  = call <4 x double> @llvm.matrix.transpose(<4 x double> %op, i32 2, i32 2)
+  %optt = call <4 x double> @llvm.matrix.transpose(<4 x double> %opt, i32 2, i32 2)
+  store <4 x double> %optt, ptr %out
+  ret void
+}

>From f2f146ff46a1e61d2ba725c9557517c92af63c80 Mon Sep 17 00:00:00 2001
From: Jon Roelofs <jonathan_roelofs at apple.com>
Date: Thu, 29 May 2025 08:14:54 -0700
Subject: [PATCH 2/5] use llvm.column.major.store for shape info

---
 .../Transforms/LowerMatrixIntrinsics/unary.ll | 98 ++++++++-----------
 1 file changed, 39 insertions(+), 59 deletions(-)

diff --git a/llvm/test/Transforms/LowerMatrixIntrinsics/unary.ll b/llvm/test/Transforms/LowerMatrixIntrinsics/unary.ll
index 80527cf717c7b..6c26ffac64462 100644
--- a/llvm/test/Transforms/LowerMatrixIntrinsics/unary.ll
+++ b/llvm/test/Transforms/LowerMatrixIntrinsics/unary.ll
@@ -8,16 +8,14 @@ define void @fneg_2x2(ptr %in, ptr %out) {
 ; CHECK-NEXT:    [[COL_LOAD1:%.*]] = load <2 x float>, ptr [[VEC_GEP]], align 8
 ; CHECK-NEXT:    [[TMP1:%.*]] = fneg <2 x float> [[COL_LOAD]]
 ; CHECK-NEXT:    [[TMP2:%.*]] = fneg <2 x float> [[COL_LOAD1]]
-; CHECK-NEXT:    store <2 x float> [[TMP1]], ptr [[OUT:%.*]], align 16
+; CHECK-NEXT:    store <2 x float> [[TMP1]], ptr [[OUT:%.*]], align 4
 ; CHECK-NEXT:    [[VEC_GEP2:%.*]] = getelementptr float, ptr [[OUT]], i64 2
-; CHECK-NEXT:    store <2 x float> [[TMP2]], ptr [[VEC_GEP2]], align 8
+; CHECK-NEXT:    store <2 x float> [[TMP2]], ptr [[VEC_GEP2]], align 4
 ; CHECK-NEXT:    ret void
 ;
   %inv = load <4 x float>, ptr %in
   %op = fneg <4 x float> %inv
-  %opt  = call <4 x float> @llvm.matrix.transpose(<4 x float> %op, i32 2, i32 2)
-  %optt = call <4 x float> @llvm.matrix.transpose(<4 x float> %opt, i32 2, i32 2)
-  store <4 x float> %optt, ptr %out
+  call void @llvm.matrix.column.major.store(<4 x float> %op, ptr %out, i64 2, i1 false, i32 2, i32 2)
   ret void
 }
 
@@ -28,16 +26,14 @@ define void @trunc_2x2(ptr %in, ptr %out) {
 ; CHECK-NEXT:    [[COL_LOAD1:%.*]] = load <2 x i64>, ptr [[VEC_GEP]], align 16
 ; CHECK-NEXT:    [[TMP1:%.*]] = trunc <2 x i64> [[COL_LOAD]] to <2 x i32>
 ; CHECK-NEXT:    [[TMP2:%.*]] = trunc <2 x i64> [[COL_LOAD1]] to <2 x i32>
-; CHECK-NEXT:    store <2 x i32> [[TMP1]], ptr [[OUT:%.*]], align 16
+; CHECK-NEXT:    store <2 x i32> [[TMP1]], ptr [[OUT:%.*]], align 4
 ; CHECK-NEXT:    [[VEC_GEP2:%.*]] = getelementptr i32, ptr [[OUT]], i64 2
-; CHECK-NEXT:    store <2 x i32> [[TMP2]], ptr [[VEC_GEP2]], align 8
+; CHECK-NEXT:    store <2 x i32> [[TMP2]], ptr [[VEC_GEP2]], align 4
 ; CHECK-NEXT:    ret void
 ;
   %inv = load <4 x i64>, ptr %in
   %op = trunc <4 x i64> %inv to <4 x i32>
-  %opt  = call <4 x i32> @llvm.matrix.transpose(<4 x i32> %op, i32 2, i32 2)
-  %optt = call <4 x i32> @llvm.matrix.transpose(<4 x i32> %opt, i32 2, i32 2)
-  store <4 x i32> %optt, ptr %out
+  call void @llvm.matrix.column.major.store(<4 x i32> %op, ptr %out, i64 2, i1 false, i32 2, i32 2)
   ret void
 }
 
@@ -48,16 +44,14 @@ define void @zext_2x2(ptr %in, ptr %out) {
 ; CHECK-NEXT:    [[COL_LOAD1:%.*]] = load <2 x i16>, ptr [[VEC_GEP]], align 4
 ; CHECK-NEXT:    [[TMP1:%.*]] = zext <2 x i16> [[COL_LOAD]] to <2 x i32>
 ; CHECK-NEXT:    [[TMP2:%.*]] = zext <2 x i16> [[COL_LOAD1]] to <2 x i32>
-; CHECK-NEXT:    store <2 x i32> [[TMP1]], ptr [[OUT:%.*]], align 16
+; CHECK-NEXT:    store <2 x i32> [[TMP1]], ptr [[OUT:%.*]], align 4
 ; CHECK-NEXT:    [[VEC_GEP2:%.*]] = getelementptr i32, ptr [[OUT]], i64 2
-; CHECK-NEXT:    store <2 x i32> [[TMP2]], ptr [[VEC_GEP2]], align 8
+; CHECK-NEXT:    store <2 x i32> [[TMP2]], ptr [[VEC_GEP2]], align 4
 ; CHECK-NEXT:    ret void
 ;
   %inv = load <4 x i16>, ptr %in
   %op = zext <4 x i16> %inv to <4 x i32>
-  %opt  = call <4 x i32> @llvm.matrix.transpose(<4 x i32> %op, i32 2, i32 2)
-  %optt = call <4 x i32> @llvm.matrix.transpose(<4 x i32> %opt, i32 2, i32 2)
-  store <4 x i32> %optt, ptr %out
+  call void @llvm.matrix.column.major.store(<4 x i32> %op, ptr %out, i64 2, i1 false, i32 2, i32 2)
   ret void
 }
 
@@ -68,16 +62,14 @@ define void @sext_2x2(ptr %in, ptr %out) {
 ; CHECK-NEXT:    [[COL_LOAD1:%.*]] = load <2 x i8>, ptr [[VEC_GEP]], align 2
 ; CHECK-NEXT:    [[TMP1:%.*]] = sext <2 x i8> [[COL_LOAD]] to <2 x i16>
 ; CHECK-NEXT:    [[TMP2:%.*]] = sext <2 x i8> [[COL_LOAD1]] to <2 x i16>
-; CHECK-NEXT:    store <2 x i16> [[TMP1]], ptr [[OUT:%.*]], align 8
+; CHECK-NEXT:    store <2 x i16> [[TMP1]], ptr [[OUT:%.*]], align 2
 ; CHECK-NEXT:    [[VEC_GEP2:%.*]] = getelementptr i16, ptr [[OUT]], i64 2
-; CHECK-NEXT:    store <2 x i16> [[TMP2]], ptr [[VEC_GEP2]], align 4
+; CHECK-NEXT:    store <2 x i16> [[TMP2]], ptr [[VEC_GEP2]], align 2
 ; CHECK-NEXT:    ret void
 ;
   %inv = load <4 x i8>, ptr %in
   %op = sext <4 x i8> %inv to <4 x i16>
-  %opt  = call <4 x i16> @llvm.matrix.transpose(<4 x i16> %op, i32 2, i32 2)
-  %optt = call <4 x i16> @llvm.matrix.transpose(<4 x i16> %opt, i32 2, i32 2)
-  store <4 x i16> %optt, ptr %out
+  call void @llvm.matrix.column.major.store(<4 x i16> %op, ptr %out, i64 2, i1 false, i32 2, i32 2)
   ret void
 }
 
@@ -88,16 +80,14 @@ define void @fptoui_2x2(ptr %in, ptr %out) {
 ; CHECK-NEXT:    [[COL_LOAD1:%.*]] = load <2 x float>, ptr [[VEC_GEP]], align 8
 ; CHECK-NEXT:    [[TMP1:%.*]] = fptoui <2 x float> [[COL_LOAD]] to <2 x i32>
 ; CHECK-NEXT:    [[TMP2:%.*]] = fptoui <2 x float> [[COL_LOAD1]] to <2 x i32>
-; CHECK-NEXT:    store <2 x i32> [[TMP1]], ptr [[OUT:%.*]], align 16
+; CHECK-NEXT:    store <2 x i32> [[TMP1]], ptr [[OUT:%.*]], align 4
 ; CHECK-NEXT:    [[VEC_GEP2:%.*]] = getelementptr i32, ptr [[OUT]], i64 2
-; CHECK-NEXT:    store <2 x i32> [[TMP2]], ptr [[VEC_GEP2]], align 8
+; CHECK-NEXT:    store <2 x i32> [[TMP2]], ptr [[VEC_GEP2]], align 4
 ; CHECK-NEXT:    ret void
 ;
   %inv = load <4 x float>, ptr %in
   %op = fptoui <4 x float> %inv to <4 x i32>
-  %opt  = call <4 x i32> @llvm.matrix.transpose(<4 x i32> %op, i32 2, i32 2)
-  %optt = call <4 x i32> @llvm.matrix.transpose(<4 x i32> %opt, i32 2, i32 2)
-  store <4 x i32> %optt, ptr %out
+  call void @llvm.matrix.column.major.store(<4 x i32> %op, ptr %out, i64 2, i1 false, i32 2, i32 2)
   ret void
 }
 
@@ -108,16 +98,14 @@ define void @fptosi_2x2(ptr %in, ptr %out) {
 ; CHECK-NEXT:    [[COL_LOAD1:%.*]] = load <2 x float>, ptr [[VEC_GEP]], align 8
 ; CHECK-NEXT:    [[TMP1:%.*]] = fptosi <2 x float> [[COL_LOAD]] to <2 x i32>
 ; CHECK-NEXT:    [[TMP2:%.*]] = fptosi <2 x float> [[COL_LOAD1]] to <2 x i32>
-; CHECK-NEXT:    store <2 x i32> [[TMP1]], ptr [[OUT:%.*]], align 16
+; CHECK-NEXT:    store <2 x i32> [[TMP1]], ptr [[OUT:%.*]], align 4
 ; CHECK-NEXT:    [[VEC_GEP2:%.*]] = getelementptr i32, ptr [[OUT]], i64 2
-; CHECK-NEXT:    store <2 x i32> [[TMP2]], ptr [[VEC_GEP2]], align 8
+; CHECK-NEXT:    store <2 x i32> [[TMP2]], ptr [[VEC_GEP2]], align 4
 ; CHECK-NEXT:    ret void
 ;
   %inv = load <4 x float>, ptr %in
   %op = fptosi <4 x float> %inv to <4 x i32>
-  %opt  = call <4 x i32> @llvm.matrix.transpose(<4 x i32> %op, i32 2, i32 2)
-  %optt = call <4 x i32> @llvm.matrix.transpose(<4 x i32> %opt, i32 2, i32 2)
-  store <4 x i32> %optt, ptr %out
+  call void @llvm.matrix.column.major.store(<4 x i32> %op, ptr %out, i64 2, i1 false, i32 2, i32 2)
   ret void
 }
 
@@ -128,16 +116,14 @@ define void @uitofp_2x2(ptr %in, ptr %out) {
 ; CHECK-NEXT:    [[COL_LOAD1:%.*]] = load <2 x i64>, ptr [[VEC_GEP]], align 16
 ; CHECK-NEXT:    [[TMP1:%.*]] = uitofp <2 x i64> [[COL_LOAD]] to <2 x double>
 ; CHECK-NEXT:    [[TMP2:%.*]] = uitofp <2 x i64> [[COL_LOAD1]] to <2 x double>
-; CHECK-NEXT:    store <2 x double> [[TMP1]], ptr [[OUT:%.*]], align 32
+; CHECK-NEXT:    store <2 x double> [[TMP1]], ptr [[OUT:%.*]], align 8
 ; CHECK-NEXT:    [[VEC_GEP2:%.*]] = getelementptr double, ptr [[OUT]], i64 2
-; CHECK-NEXT:    store <2 x double> [[TMP2]], ptr [[VEC_GEP2]], align 16
+; CHECK-NEXT:    store <2 x double> [[TMP2]], ptr [[VEC_GEP2]], align 8
 ; CHECK-NEXT:    ret void
 ;
   %inv = load <4 x i64>, ptr %in
   %op = uitofp <4 x i64> %inv to <4 x double>
-  %opt  = call <4  x double> @llvm.matrix.transpose(<4  x double> %op, i32 2, i32 2)
-  %optt = call <4  x double> @llvm.matrix.transpose(<4  x double> %opt, i32 2, i32 2)
-  store <4  x double> %optt, ptr %out
+  call void @llvm.matrix.column.major.store(<4 x double> %op, ptr %out, i64 2, i1 false, i32 2, i32 2)
   ret void
 }
 
@@ -148,16 +134,14 @@ define void @sitofp_2x2(ptr %in, ptr %out) {
 ; CHECK-NEXT:    [[COL_LOAD1:%.*]] = load <2 x i64>, ptr [[VEC_GEP]], align 16
 ; CHECK-NEXT:    [[TMP1:%.*]] = sitofp <2 x i64> [[COL_LOAD]] to <2 x double>
 ; CHECK-NEXT:    [[TMP2:%.*]] = sitofp <2 x i64> [[COL_LOAD1]] to <2 x double>
-; CHECK-NEXT:    store <2 x double> [[TMP1]], ptr [[OUT:%.*]], align 32
+; CHECK-NEXT:    store <2 x double> [[TMP1]], ptr [[OUT:%.*]], align 8
 ; CHECK-NEXT:    [[VEC_GEP2:%.*]] = getelementptr double, ptr [[OUT]], i64 2
-; CHECK-NEXT:    store <2 x double> [[TMP2]], ptr [[VEC_GEP2]], align 16
+; CHECK-NEXT:    store <2 x double> [[TMP2]], ptr [[VEC_GEP2]], align 8
 ; CHECK-NEXT:    ret void
 ;
   %inv = load <4 x i64>, ptr %in
   %op = sitofp <4 x i64> %inv to <4 x double>
-  %opt  = call <4  x double> @llvm.matrix.transpose(<4  x double> %op, i32 2, i32 2)
-  %optt = call <4  x double> @llvm.matrix.transpose(<4  x double> %opt, i32 2, i32 2)
-  store <4  x double> %optt, ptr %out
+  call void @llvm.matrix.column.major.store(<4 x double> %op, ptr %out, i64 2, i1 false, i32 2, i32 2)
   ret void
 }
 
@@ -168,16 +152,14 @@ define void @fptrunc_2x2(ptr %in, ptr %out) {
 ; CHECK-NEXT:    [[COL_LOAD1:%.*]] = load <2 x double>, ptr [[VEC_GEP]], align 16
 ; CHECK-NEXT:    [[TMP1:%.*]] = fptrunc nnan <2 x double> [[COL_LOAD]] to <2 x float>
 ; CHECK-NEXT:    [[TMP2:%.*]] = fptrunc nnan <2 x double> [[COL_LOAD1]] to <2 x float>
-; CHECK-NEXT:    store <2 x float> [[TMP1]], ptr [[OUT:%.*]], align 16
+; CHECK-NEXT:    store <2 x float> [[TMP1]], ptr [[OUT:%.*]], align 4
 ; CHECK-NEXT:    [[VEC_GEP2:%.*]] = getelementptr float, ptr [[OUT]], i64 2
-; CHECK-NEXT:    store <2 x float> [[TMP2]], ptr [[VEC_GEP2]], align 8
+; CHECK-NEXT:    store <2 x float> [[TMP2]], ptr [[VEC_GEP2]], align 4
 ; CHECK-NEXT:    ret void
 ;
   %inv = load <4 x double>, ptr %in
   %op = fptrunc nnan <4 x double> %inv to <4 x float>
-  %opt  = call <4 x float> @llvm.matrix.transpose(<4 x float> %op, i32 2, i32 2)
-  %optt = call <4 x float> @llvm.matrix.transpose(<4 x float> %opt, i32 2, i32 2)
-  store <4 x float> %optt, ptr %out
+  call void @llvm.matrix.column.major.store(<4 x float> %op, ptr %out, i64 2, i1 false, i32 2, i32 2)
   ret void
 }
 
@@ -188,16 +170,14 @@ define void @fpext_2x2(ptr %in, ptr %out) {
 ; CHECK-NEXT:    [[COL_LOAD1:%.*]] = load <2 x float>, ptr [[VEC_GEP]], align 8
 ; CHECK-NEXT:    [[TMP1:%.*]] = fpext <2 x float> [[COL_LOAD]] to <2 x double>
 ; CHECK-NEXT:    [[TMP2:%.*]] = fpext <2 x float> [[COL_LOAD1]] to <2 x double>
-; CHECK-NEXT:    store <2 x double> [[TMP1]], ptr [[OUT:%.*]], align 32
+; CHECK-NEXT:    store <2 x double> [[TMP1]], ptr [[OUT:%.*]], align 8
 ; CHECK-NEXT:    [[VEC_GEP2:%.*]] = getelementptr double, ptr [[OUT]], i64 2
-; CHECK-NEXT:    store <2 x double> [[TMP2]], ptr [[VEC_GEP2]], align 16
+; CHECK-NEXT:    store <2 x double> [[TMP2]], ptr [[VEC_GEP2]], align 8
 ; CHECK-NEXT:    ret void
 ;
   %inv = load <4 x float>, ptr %in
   %op = fpext <4 x float> %inv to <4 x double>
-  %opt  = call <4 x double> @llvm.matrix.transpose(<4 x double> %op, i32 2, i32 2)
-  %optt = call <4 x double> @llvm.matrix.transpose(<4 x double> %opt, i32 2, i32 2)
-  store <4 x double> %optt, ptr %out
+  call void @llvm.matrix.column.major.store(<4 x double> %op, ptr %out, i64 2, i1 false, i32 2, i32 2)
   ret void
 }
 
@@ -208,16 +188,14 @@ define void @bitcast_2x2_v4f64_to_v4i64(ptr %in, ptr %out) {
 ; CHECK-NEXT:    [[COL_LOAD1:%.*]] = load <2 x double>, ptr [[VEC_GEP]], align 16
 ; CHECK-NEXT:    [[TMP1:%.*]] = bitcast <2 x double> [[COL_LOAD]] to <2 x i64>
 ; CHECK-NEXT:    [[TMP2:%.*]] = bitcast <2 x double> [[COL_LOAD1]] to <2 x i64>
-; CHECK-NEXT:    store <2 x i64> [[TMP1]], ptr [[OUT:%.*]], align 32
+; CHECK-NEXT:    store <2 x i64> [[TMP1]], ptr [[OUT:%.*]], align 4
 ; CHECK-NEXT:    [[VEC_GEP2:%.*]] = getelementptr i64, ptr [[OUT]], i64 2
-; CHECK-NEXT:    store <2 x i64> [[TMP2]], ptr [[VEC_GEP2]], align 16
+; CHECK-NEXT:    store <2 x i64> [[TMP2]], ptr [[VEC_GEP2]], align 4
 ; CHECK-NEXT:    ret void
 ;
   %inv = load <4 x double>, ptr %in
   %op = bitcast <4 x double> %inv to <4 x i64>
-  %opt  = call <4 x i64> @llvm.matrix.transpose(<4 x i64> %op, i32 2, i32 2)
-  %optt = call <4 x i64> @llvm.matrix.transpose(<4 x i64> %opt, i32 2, i32 2)
-  store <4 x i64> %optt, ptr %out
+  call void @llvm.matrix.column.major.store(<4 x i64> %op, ptr %out, i64 2, i1 false, i32 2, i32 2)
   ret void
 }
 
@@ -225,13 +203,15 @@ define void @bitcast_2x2_i256_to_v4i64(ptr %in, ptr %out) {
 ; CHECK-LABEL: @bitcast_2x2_i256_to_v4i64(
 ; CHECK-NEXT:    [[INV:%.*]] = load i256, ptr [[IN:%.*]], align 4
 ; CHECK-NEXT:    [[OP:%.*]] = bitcast i256 [[INV]] to <4 x double>
-; CHECK-NEXT:    store <4 x double> [[OP]], ptr [[OUT:%.*]], align 32
+; CHECK-NEXT:    [[SPLIT:%.*]] = shufflevector <4 x double> [[OP]], <4 x double> poison, <2 x i32> <i32 0, i32 1>
+; CHECK-NEXT:    [[SPLIT1:%.*]] = shufflevector <4 x double> [[OP]], <4 x double> poison, <2 x i32> <i32 2, i32 3>
+; CHECK-NEXT:    store <2 x double> [[SPLIT]], ptr [[OUT:%.*]], align 8
+; CHECK-NEXT:    [[VEC_GEP:%.*]] = getelementptr double, ptr [[OUT]], i64 2
+; CHECK-NEXT:    store <2 x double> [[SPLIT1]], ptr [[VEC_GEP]], align 8
 ; CHECK-NEXT:    ret void
 ;
   %inv = load i256, ptr %in
   %op = bitcast i256 %inv to <4 x double>
-  %opt  = call <4 x double> @llvm.matrix.transpose(<4 x double> %op, i32 2, i32 2)
-  %optt = call <4 x double> @llvm.matrix.transpose(<4 x double> %opt, i32 2, i32 2)
-  store <4 x double> %optt, ptr %out
+  call void @llvm.matrix.column.major.store(<4 x double> %op, ptr %out, i64 2, i1 false, i32 2, i32 2)
   ret void
 }

>From 514340382361dc3648f2a85c9d68b7b07bb4dacc Mon Sep 17 00:00:00 2001
From: Jon Roelofs <jonathan_roelofs at apple.com>
Date: Thu, 29 May 2025 08:17:38 -0700
Subject: [PATCH 3/5] ensure switch is fully covered

---
 llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
index 4702355ca4577..f964ce2d5e499 100644
--- a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
+++ b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
@@ -234,7 +234,7 @@ static bool isUniformShape(Value *V) {
   if (I->isBinaryOp())
     return true;
 
-  if (auto *Cast = dyn_cast<CastInst>(V))
+  if (auto *Cast = dyn_cast<CastInst>(V)) {
     switch (Cast->getOpcode()) {
     case llvm::Instruction::Trunc:
     case llvm::Instruction::ZExt:
@@ -259,6 +259,8 @@ static bool isUniformShape(Value *V) {
     case llvm::Instruction::CastOpsEnd:
       llvm_unreachable("not an actual cast op");
     }
+    llvm_unreachable("unhandled cast opcode");
+  }
 
   switch (I->getOpcode()) {
   case Instruction::FNeg:

>From 278dbd115e882122b115291a242b4ed2e700bd17 Mon Sep 17 00:00:00 2001
From: Jon Roelofs <jonathan_roelofs at apple.com>
Date: Thu, 29 May 2025 09:48:18 -0700
Subject: [PATCH 4/5] add a test for bitcast <4 x double> %inv to <8 x i32>

---
 .../Transforms/LowerMatrixIntrinsics/unary.ll   | 17 +++++++++++++++++
 1 file changed, 17 insertions(+)

diff --git a/llvm/test/Transforms/LowerMatrixIntrinsics/unary.ll b/llvm/test/Transforms/LowerMatrixIntrinsics/unary.ll
index 6c26ffac64462..e1d754bffcd69 100644
--- a/llvm/test/Transforms/LowerMatrixIntrinsics/unary.ll
+++ b/llvm/test/Transforms/LowerMatrixIntrinsics/unary.ll
@@ -199,6 +199,23 @@ define void @bitcast_2x2_v4f64_to_v4i64(ptr %in, ptr %out) {
   ret void
 }
 
+define void @bitcast_2x2_v4f64_to_v8i32(ptr %in, ptr %out) {
+; CHECK-LABEL: @bitcast_2x2_v4f64_to_v8i32(
+; CHECK-NEXT:    [[INV:%.*]] = load <4 x double>, ptr [[IN:%.*]], align 32
+; CHECK-NEXT:    [[OP:%.*]] = bitcast <4 x double> [[INV]] to <8 x i32>
+; CHECK-NEXT:    [[SPLIT:%.*]] = shufflevector <8 x i32> [[OP]], <8 x i32> poison, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
+; CHECK-NEXT:    [[SPLIT1:%.*]] = shufflevector <8 x i32> [[OP]], <8 x i32> poison, <4 x i32> <i32 4, i32 5, i32 6, i32 7>
+; CHECK-NEXT:    store <4 x i32> [[SPLIT]], ptr [[OUT:%.*]], align 4
+; CHECK-NEXT:    [[VEC_GEP:%.*]] = getelementptr i32, ptr [[OUT]], i64 4
+; CHECK-NEXT:    store <4 x i32> [[SPLIT1]], ptr [[VEC_GEP]], align 4
+; CHECK-NEXT:    ret void
+;
+  %inv = load <4 x double>, ptr %in
+  %op = bitcast <4 x double> %inv to <8 x i32>
+  call void @llvm.matrix.column.major.store(<8 x i32> %op, ptr %out, i64 4, i1 false, i32 4, i32 2)
+  ret void
+}
+
 define void @bitcast_2x2_i256_to_v4i64(ptr %in, ptr %out) {
 ; CHECK-LABEL: @bitcast_2x2_i256_to_v4i64(
 ; CHECK-NEXT:    [[INV:%.*]] = load i256, ptr [[IN:%.*]], align 4

>From 2d10a4f360c40a395c5449b10ea57ff8001c350e Mon Sep 17 00:00:00 2001
From: Jon Roelofs <jonathan_roelofs at apple.com>
Date: Thu, 29 May 2025 09:50:26 -0700
Subject: [PATCH 5/5] and a test for bitcast <4 x double> %inv to i256

---
 .../Transforms/LowerMatrixIntrinsics/unary.ll    | 16 ++++++++++++++++
 1 file changed, 16 insertions(+)

diff --git a/llvm/test/Transforms/LowerMatrixIntrinsics/unary.ll b/llvm/test/Transforms/LowerMatrixIntrinsics/unary.ll
index e1d754bffcd69..a4bd516868bcd 100644
--- a/llvm/test/Transforms/LowerMatrixIntrinsics/unary.ll
+++ b/llvm/test/Transforms/LowerMatrixIntrinsics/unary.ll
@@ -232,3 +232,19 @@ define void @bitcast_2x2_i256_to_v4i64(ptr %in, ptr %out) {
   call void @llvm.matrix.column.major.store(<4 x double> %op, ptr %out, i64 2, i1 false, i32 2, i32 2)
   ret void
 }
+
+define void @bitcast_2x2_4i64_to_i256(ptr %in, ptr %out) {
+; CHECK-LABEL: @bitcast_2x2_4i64_to_i256(
+; CHECK-NEXT:    [[COL_LOAD:%.*]] = load <2 x double>, ptr [[IN:%.*]], align 8
+; CHECK-NEXT:    [[VEC_GEP:%.*]] = getelementptr double, ptr [[IN]], i64 2
+; CHECK-NEXT:    [[COL_LOAD1:%.*]] = load <2 x double>, ptr [[VEC_GEP]], align 8
+; CHECK-NEXT:    [[TMP1:%.*]] = shufflevector <2 x double> [[COL_LOAD]], <2 x double> [[COL_LOAD1]], <4 x i32> <i32 0, i32 1, i32 2, i32 3>
+; CHECK-NEXT:    [[OP:%.*]] = bitcast <4 x double> [[TMP1]] to i256
+; CHECK-NEXT:    store i256 [[OP]], ptr [[OUT:%.*]], align 4
+; CHECK-NEXT:    ret void
+;
+  %inv = call <4 x double> @llvm.matrix.column.major.load(ptr %in, i64 2, i1 false, i32 2, i32 2)
+  %op = bitcast <4 x double> %inv to i256
+  store i256 %op, ptr %out
+  ret void
+}



More information about the llvm-commits mailing list