[llvm] [Matrix] Optimize static extracts with ShapeInfo (PR #141815)

Jon Roelofs via llvm-commits llvm-commits at lists.llvm.org
Wed May 28 12:00:15 PDT 2025


https://github.com/jroelofs updated https://github.com/llvm/llvm-project/pull/141815

>From c647c580cac763f4105be1e6ed10924266518fd7 Mon Sep 17 00:00:00 2001
From: Jon Roelofs <jonathan_roelofs at apple.com>
Date: Tue, 27 May 2025 18:39:56 -0700
Subject: [PATCH 1/6] [Matrix] Propagate shape information through fdiv insts
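
With FDiv handled by isUniformShape, shape information flows through the
division and the lowering emits one op per column instead of one op on the
flattened vector. A minimal IR sketch of the effect, assuming a 2x2 double
matrix (value names are illustrative, not taken from the patch):

  ; before lowering: a single op on the flattened 2x2 matrix
  %div = fdiv <4 x double> %numv, %denomv

  ; after lowering: one op per 2-element column
  %div.col0 = fdiv <2 x double> %num.col0, %denom.col0
  %div.col1 = fdiv <2 x double> %num.col1, %denom.col1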

---
 .../Scalar/LowerMatrixIntrinsics.cpp          |  3 +++
 .../Transforms/LowerMatrixIntrinsics/binop.ll | 26 +++++++++++++++++++
 2 files changed, 29 insertions(+)
 create mode 100644 llvm/test/Transforms/LowerMatrixIntrinsics/binop.ll

diff --git a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
index 56d4be513ea6f..259148124c701 100644
--- a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
+++ b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
@@ -233,6 +233,7 @@ static bool isUniformShape(Value *V) {
   case Instruction::FAdd:
   case Instruction::FSub:
   case Instruction::FMul: // Scalar multiply.
+  case Instruction::FDiv:
   case Instruction::FNeg:
   case Instruction::Add:
   case Instruction::Mul:
@@ -2167,6 +2168,8 @@ class LowerMatrixIntrinsics {
         return Builder.CreateFAdd(LHS, RHS);
       case Instruction::FMul:
         return Builder.CreateFMul(LHS, RHS);
+      case Instruction::FDiv:
+        return Builder.CreateFDiv(LHS, RHS);
       case Instruction::FSub:
         return Builder.CreateFSub(LHS, RHS);
       default:
diff --git a/llvm/test/Transforms/LowerMatrixIntrinsics/binop.ll b/llvm/test/Transforms/LowerMatrixIntrinsics/binop.ll
new file mode 100644
index 0000000000000..fd3e440d779ea
--- /dev/null
+++ b/llvm/test/Transforms/LowerMatrixIntrinsics/binop.ll
@@ -0,0 +1,26 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
+; RUN: opt -passes='lower-matrix-intrinsics' -S < %s | FileCheck %s
+
+define void @fdiv_2x2(ptr %num, ptr %denom, ptr %out) {
+; CHECK-LABEL: @fdiv_2x2(
+; CHECK-NEXT:    [[COL_LOAD:%.*]] = load <2 x double>, ptr [[NUM:%.*]], align 32
+; CHECK-NEXT:    [[VEC_GEP:%.*]] = getelementptr double, ptr [[NUM]], i64 2
+; CHECK-NEXT:    [[COL_LOAD1:%.*]] = load <2 x double>, ptr [[VEC_GEP]], align 16
+; CHECK-NEXT:    [[COL_LOAD2:%.*]] = load <2 x double>, ptr [[DENOM:%.*]], align 32
+; CHECK-NEXT:    [[VEC_GEP3:%.*]] = getelementptr double, ptr [[DENOM]], i64 2
+; CHECK-NEXT:    [[COL_LOAD4:%.*]] = load <2 x double>, ptr [[VEC_GEP3]], align 16
+; CHECK-NEXT:    [[TMP1:%.*]] = fdiv <2 x double> [[COL_LOAD]], [[COL_LOAD2]]
+; CHECK-NEXT:    [[TMP2:%.*]] = fdiv <2 x double> [[COL_LOAD1]], [[COL_LOAD4]]
+; CHECK-NEXT:    store <2 x double> [[TMP1]], ptr [[OUT:%.*]], align 32
+; CHECK-NEXT:    [[VEC_GEP5:%.*]] = getelementptr double, ptr [[OUT]], i64 2
+; CHECK-NEXT:    store <2 x double> [[TMP2]], ptr [[VEC_GEP5]], align 16
+; CHECK-NEXT:    ret void
+;
+  %numv = load <4 x double>, ptr %num
+  %denomv = load <4 x double>, ptr %denom
+  %div = fdiv <4 x double> %numv, %denomv
+  %divt  = call <4 x double> @llvm.matrix.transpose(<4 x double> %div, i32 2, i32 2)
+  %divtt = call <4 x double> @llvm.matrix.transpose(<4 x double> %divt, i32 2, i32 2)
+  store <4 x double> %divtt, ptr %out
+  ret void
+}

>From 41eeb1ea9a9afeca5905e1b2fec2ea45fd07ecb2 Mon Sep 17 00:00:00 2001
From: Jon Roelofs <jonathan_roelofs at apple.com>
Date: Tue, 27 May 2025 19:15:00 -0700
Subject: [PATCH 2/6] [Matrix] Propagate shape information through (f)abs insts
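
Shape information now also propagates through the abs/fabs intrinsics, which
are lowered column-wise like the binary ops. A hedged sketch for a 2x2 i32
matrix (value names are illustrative); note that abs carries its i1
is-int-min-poison flag onto each split call:

  %abs.col0 = call <2 x i32> @llvm.abs.v2i32(<2 x i32> %col0, i1 false)
  %abs.col1 = call <2 x i32> @llvm.abs.v2i32(<2 x i32> %col1, i1 false)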

---
 .../Scalar/LowerMatrixIntrinsics.cpp          | 57 ++++++++++++++++++-
 .../Transforms/LowerMatrixIntrinsics/binop.ll | 43 ++++++++++++++
 2 files changed, 99 insertions(+), 1 deletion(-)

diff --git a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
index 259148124c701..825fa6a36e2eb 100644
--- a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
+++ b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
@@ -229,6 +229,15 @@ static bool isUniformShape(Value *V) {
   if (!I)
     return true;
 
+  if (auto *II = dyn_cast<IntrinsicInst>(V))
+    switch (II->getIntrinsicID()) {
+    case Intrinsic::abs:
+    case Intrinsic::fabs:
+      return true;
+    default:
+      return false;
+    }
+
   switch (I->getOpcode()) {
   case Instruction::FAdd:
   case Instruction::FSub:
@@ -625,7 +634,7 @@ class LowerMatrixIntrinsics {
       case Intrinsic::matrix_column_major_store:
         return true;
       default:
-        return false;
+        return isUniformShape(II);
       }
     return isUniformShape(V) || isa<StoreInst>(V) || isa<LoadInst>(V);
   }
@@ -1131,6 +1140,9 @@ class LowerMatrixIntrinsics {
     case Intrinsic::matrix_column_major_store:
       LowerColumnMajorStore(Inst);
       break;
+    case Intrinsic::abs:
+    case Intrinsic::fabs:
+      return VisitUniformIntrinsic(cast<IntrinsicInst>(Inst));
     default:
       return false;
     }
@@ -2223,6 +2235,49 @@ class LowerMatrixIntrinsics {
     return true;
   }
 
+  /// Lower uniform shape intrinsics, if shape information is available.
+  bool VisitUniformIntrinsic(IntrinsicInst *Inst) {
+    auto I = ShapeMap.find(Inst);
+    if (I == ShapeMap.end())
+      return false;
+
+    IRBuilder<> Builder(Inst);
+    ShapeInfo &Shape = I->second;
+
+    MatrixTy Result;
+
+    switch (Inst->getIntrinsicID()) {
+    case Intrinsic::abs:
+    case Intrinsic::fabs: {
+      Value *Op = Inst->getOperand(0);
+
+      MatrixTy M = getMatrix(Op, Shape, Builder);
+
+      Builder.setFastMathFlags(getFastMathFlags(Inst));
+
+      for (unsigned I = 0; I < Shape.getNumVectors(); ++I)
+        switch (Inst->getIntrinsicID()) {
+        case Intrinsic::abs:
+          Result.addVector(Builder.CreateBinaryIntrinsic(
+              Intrinsic::abs, M.getVector(I), Inst->getOperand(1)));
+          break;
+        case Intrinsic::fabs:
+          Result.addVector(Builder.CreateUnaryIntrinsic(Inst->getIntrinsicID(),
+                                                        M.getVector(I)));
+          break;
+        }
+
+      finalizeLowering(Inst,
+                       Result.addNumComputeOps(getNumOps(Result.getVectorTy()) *
+                                               Result.getNumVectors()),
+                       Builder);
+      return true;
+    }
+    default:
+      llvm_unreachable("unexpected intrinsic");
+    }
+  }
+
   /// Helper to linearize a matrix expression tree into a string. Currently
   /// matrix expressions are linearized by starting at an expression leaf and
   /// linearizing bottom up.
diff --git a/llvm/test/Transforms/LowerMatrixIntrinsics/binop.ll b/llvm/test/Transforms/LowerMatrixIntrinsics/binop.ll
index fd3e440d779ea..1eacb2a32e07d 100644
--- a/llvm/test/Transforms/LowerMatrixIntrinsics/binop.ll
+++ b/llvm/test/Transforms/LowerMatrixIntrinsics/binop.ll
@@ -24,3 +24,46 @@ define void @fdiv_2x2(ptr %num, ptr %denom, ptr %out) {
   store <4 x double> %divtt, ptr %out
   ret void
 }
+
+define void @fabs_2x2f64(ptr %in, ptr %out) {
+; CHECK-LABEL: @fabs_2x2f64(
+; CHECK-NEXT:    [[COL_LOAD:%.*]] = load <2 x double>, ptr [[IN:%.*]], align 32
+; CHECK-NEXT:    [[VEC_GEP:%.*]] = getelementptr double, ptr [[IN]], i64 2
+; CHECK-NEXT:    [[COL_LOAD1:%.*]] = load <2 x double>, ptr [[VEC_GEP]], align 16
+; CHECK-NEXT:    [[TMP1:%.*]] = call <2 x double> @llvm.fabs.v2f64(<2 x double> [[COL_LOAD]])
+; CHECK-NEXT:    [[TMP2:%.*]] = call <2 x double> @llvm.fabs.v2f64(<2 x double> [[COL_LOAD1]])
+; CHECK-NEXT:    store <2 x double> [[TMP1]], ptr [[OUT:%.*]], align 32
+; CHECK-NEXT:    [[VEC_GEP2:%.*]] = getelementptr double, ptr [[OUT]], i64 2
+; CHECK-NEXT:    store <2 x double> [[TMP2]], ptr [[VEC_GEP2]], align 16
+; CHECK-NEXT:    ret void
+;
+  %load = load <4 x double>, ptr %in
+  %fabs = call <4 x double> @llvm.fabs.v4f64(<4 x double> %load)
+  %fabst  = call <4 x double> @llvm.matrix.transpose(<4 x double> %fabs, i32 2, i32 2)
+  %fabstt = call <4 x double> @llvm.matrix.transpose(<4 x double> %fabst, i32 2, i32 2)
+  store <4 x double> %fabstt, ptr %out
+  ret void
+}
+
+define void @fabs_2x2i32(ptr %in, ptr %out) {
+; CHECK-LABEL: @fabs_2x2i32(
+; CHECK-NEXT:    [[COL_LOAD:%.*]] = load <2 x i32>, ptr [[IN:%.*]], align 16
+; CHECK-NEXT:    [[VEC_GEP:%.*]] = getelementptr i32, ptr [[IN]], i64 2
+; CHECK-NEXT:    [[COL_LOAD1:%.*]] = load <2 x i32>, ptr [[VEC_GEP]], align 8
+; CHECK-NEXT:    [[TMP1:%.*]] = call <2 x i32> @llvm.abs.v2i32(<2 x i32> [[COL_LOAD]], i1 false)
+; CHECK-NEXT:    [[TMP2:%.*]] = call <2 x i32> @llvm.abs.v2i32(<2 x i32> [[COL_LOAD1]], i1 false)
+; CHECK-NEXT:    [[TMP3:%.*]] = call <2 x i32> @llvm.abs.v2i32(<2 x i32> [[TMP1]], i1 true)
+; CHECK-NEXT:    [[TMP4:%.*]] = call <2 x i32> @llvm.abs.v2i32(<2 x i32> [[TMP2]], i1 true)
+; CHECK-NEXT:    store <2 x i32> [[TMP3]], ptr [[OUT:%.*]], align 16
+; CHECK-NEXT:    [[VEC_GEP2:%.*]] = getelementptr i32, ptr [[OUT]], i64 2
+; CHECK-NEXT:    store <2 x i32> [[TMP4]], ptr [[VEC_GEP2]], align 8
+; CHECK-NEXT:    ret void
+;
+  %load = load <4 x i32>, ptr %in
+  %abs = call <4 x i32> @llvm.abs.v4i32(<4 x i32> %load, i1 false)
+  %abst  = call <4 x i32> @llvm.matrix.transpose(<4 x i32> %abs, i32 2, i32 2)
+  %abstt = call <4 x i32> @llvm.matrix.transpose(<4 x i32> %abst, i32 2, i32 2)
+  %absabstt = call <4 x i32> @llvm.abs.v4i32(<4 x i32> %abstt, i1 true)
+  store <4 x i32> %absabstt, ptr %out
+  ret void
+}

>From 39c585c34d6775f0d2b141d945a4f7668a9aa9b3 Mon Sep 17 00:00:00 2001
From: Jon Roelofs <jonathan_roelofs at apple.com>
Date: Wed, 28 May 2025 07:43:21 -0700
Subject: [PATCH 3/6] Extend split behavior to any binop
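
Replacing the per-opcode switch with Builder.CreateBinOp means every binary
operator now splits per column, and the fast-math flags set on the builder
carry over to each split op. A small sketch, assuming a 2x2 float matrix
(value names are illustrative):

  ; 'frem fast' on the flattened matrix becomes per-column ops carrying the
  ; same flags
  %op.col0 = frem fast <2 x float> %lhs.col0, %rhs.col0
  %op.col1 = frem fast <2 x float> %lhs.col1, %rhs.col1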

---
 .../Scalar/LowerMatrixIntrinsics.cpp          |  35 +-
 .../Transforms/LowerMatrixIntrinsics/binop.ll | 414 +++++++++++++++++-
 2 files changed, 416 insertions(+), 33 deletions(-)

diff --git a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
index 259148124c701..756a72e6d97bc 100644
--- a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
+++ b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
@@ -229,15 +229,11 @@ static bool isUniformShape(Value *V) {
   if (!I)
     return true;
 
+  if (I->isBinaryOp())
+    return true;
+
   switch (I->getOpcode()) {
-  case Instruction::FAdd:
-  case Instruction::FSub:
-  case Instruction::FMul: // Scalar multiply.
-  case Instruction::FDiv:
   case Instruction::FNeg:
-  case Instruction::Add:
-  case Instruction::Mul:
-  case Instruction::Sub:
     return true;
   default:
     return false;
@@ -2155,30 +2151,9 @@ class LowerMatrixIntrinsics {
 
     Builder.setFastMathFlags(getFastMathFlags(Inst));
 
-    // Helper to perform binary op on vectors.
-    auto BuildVectorOp = [&Builder, Inst](Value *LHS, Value *RHS) {
-      switch (Inst->getOpcode()) {
-      case Instruction::Add:
-        return Builder.CreateAdd(LHS, RHS);
-      case Instruction::Mul:
-        return Builder.CreateMul(LHS, RHS);
-      case Instruction::Sub:
-        return Builder.CreateSub(LHS, RHS);
-      case Instruction::FAdd:
-        return Builder.CreateFAdd(LHS, RHS);
-      case Instruction::FMul:
-        return Builder.CreateFMul(LHS, RHS);
-      case Instruction::FDiv:
-        return Builder.CreateFDiv(LHS, RHS);
-      case Instruction::FSub:
-        return Builder.CreateFSub(LHS, RHS);
-      default:
-        llvm_unreachable("Unsupported binary operator for matrix");
-      }
-    };
-
     for (unsigned I = 0; I < Shape.getNumVectors(); ++I)
-      Result.addVector(BuildVectorOp(A.getVector(I), B.getVector(I)));
+      Result.addVector(Builder.CreateBinOp(Inst->getOpcode(), A.getVector(I),
+                                           B.getVector(I)));
 
     finalizeLowering(Inst,
                      Result.addNumComputeOps(getNumOps(Result.getVectorTy()) *
diff --git a/llvm/test/Transforms/LowerMatrixIntrinsics/binop.ll b/llvm/test/Transforms/LowerMatrixIntrinsics/binop.ll
index fd3e440d779ea..9160ced2715aa 100644
--- a/llvm/test/Transforms/LowerMatrixIntrinsics/binop.ll
+++ b/llvm/test/Transforms/LowerMatrixIntrinsics/binop.ll
@@ -1,6 +1,198 @@
 ; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
 ; RUN: opt -passes='lower-matrix-intrinsics' -S < %s | FileCheck %s
 
+define void @add_2x2(ptr %lhs, ptr %rhs, ptr %out) {
+; CHECK-LABEL: @add_2x2(
+; CHECK-NEXT:    [[COL_LOAD:%.*]] = load <2 x i32>, ptr [[LHS:%.*]], align 16
+; CHECK-NEXT:    [[VEC_GEP:%.*]] = getelementptr i32, ptr [[LHS]], i64 2
+; CHECK-NEXT:    [[COL_LOAD1:%.*]] = load <2 x i32>, ptr [[VEC_GEP]], align 8
+; CHECK-NEXT:    [[COL_LOAD2:%.*]] = load <2 x i32>, ptr [[RHS:%.*]], align 16
+; CHECK-NEXT:    [[VEC_GEP3:%.*]] = getelementptr i32, ptr [[RHS]], i64 2
+; CHECK-NEXT:    [[COL_LOAD4:%.*]] = load <2 x i32>, ptr [[VEC_GEP3]], align 8
+; CHECK-NEXT:    [[TMP1:%.*]] = add <2 x i32> [[COL_LOAD]], [[COL_LOAD2]]
+; CHECK-NEXT:    [[TMP2:%.*]] = add <2 x i32> [[COL_LOAD1]], [[COL_LOAD4]]
+; CHECK-NEXT:    store <2 x i32> [[TMP1]], ptr [[OUT:%.*]], align 16
+; CHECK-NEXT:    [[VEC_GEP5:%.*]] = getelementptr i32, ptr [[OUT]], i64 2
+; CHECK-NEXT:    store <2 x i32> [[TMP2]], ptr [[VEC_GEP5]], align 8
+; CHECK-NEXT:    ret void
+;
+  %lhsv = load <4 x i32>, ptr %lhs
+  %rhsv = load <4 x i32>, ptr %rhs
+  %op = add <4 x i32> %lhsv, %rhsv
+  %opt  = call <4 x i32> @llvm.matrix.transpose(<4 x i32> %op, i32 2, i32 2)
+  %optt = call <4 x i32> @llvm.matrix.transpose(<4 x i32> %opt, i32 2, i32 2)
+  store <4 x i32> %optt, ptr %out
+  ret void
+}
+
+define void @fadd_2x2(ptr %lhs, ptr %rhs, ptr %out) {
+; CHECK-LABEL: @fadd_2x2(
+; CHECK-NEXT:    [[COL_LOAD:%.*]] = load <2 x float>, ptr [[LHS:%.*]], align 16
+; CHECK-NEXT:    [[VEC_GEP:%.*]] = getelementptr float, ptr [[LHS]], i64 2
+; CHECK-NEXT:    [[COL_LOAD1:%.*]] = load <2 x float>, ptr [[VEC_GEP]], align 8
+; CHECK-NEXT:    [[COL_LOAD2:%.*]] = load <2 x float>, ptr [[RHS:%.*]], align 16
+; CHECK-NEXT:    [[VEC_GEP3:%.*]] = getelementptr float, ptr [[RHS]], i64 2
+; CHECK-NEXT:    [[COL_LOAD4:%.*]] = load <2 x float>, ptr [[VEC_GEP3]], align 8
+; CHECK-NEXT:    [[TMP1:%.*]] = fadd <2 x float> [[COL_LOAD]], [[COL_LOAD2]]
+; CHECK-NEXT:    [[TMP2:%.*]] = fadd <2 x float> [[COL_LOAD1]], [[COL_LOAD4]]
+; CHECK-NEXT:    store <2 x float> [[TMP1]], ptr [[OUT:%.*]], align 16
+; CHECK-NEXT:    [[VEC_GEP5:%.*]] = getelementptr float, ptr [[OUT]], i64 2
+; CHECK-NEXT:    store <2 x float> [[TMP2]], ptr [[VEC_GEP5]], align 8
+; CHECK-NEXT:    ret void
+;
+  %lhsv = load <4 x float>, ptr %lhs
+  %rhsv = load <4 x float>, ptr %rhs
+  %op = fadd <4 x float> %lhsv, %rhsv
+  %opt  = call <4 x float> @llvm.matrix.transpose(<4 x float> %op, i32 2, i32 2)
+  %optt = call <4 x float> @llvm.matrix.transpose(<4 x float> %opt, i32 2, i32 2)
+  store <4 x float> %optt, ptr %out
+  ret void
+}
+
+define void @sub_2x2(ptr %lhs, ptr %rhs, ptr %out) {
+; CHECK-LABEL: @sub_2x2(
+; CHECK-NEXT:    [[COL_LOAD:%.*]] = load <2 x i32>, ptr [[LHS:%.*]], align 16
+; CHECK-NEXT:    [[VEC_GEP:%.*]] = getelementptr i32, ptr [[LHS]], i64 2
+; CHECK-NEXT:    [[COL_LOAD1:%.*]] = load <2 x i32>, ptr [[VEC_GEP]], align 8
+; CHECK-NEXT:    [[COL_LOAD2:%.*]] = load <2 x i32>, ptr [[RHS:%.*]], align 16
+; CHECK-NEXT:    [[VEC_GEP3:%.*]] = getelementptr i32, ptr [[RHS]], i64 2
+; CHECK-NEXT:    [[COL_LOAD4:%.*]] = load <2 x i32>, ptr [[VEC_GEP3]], align 8
+; CHECK-NEXT:    [[TMP1:%.*]] = sub <2 x i32> [[COL_LOAD]], [[COL_LOAD2]]
+; CHECK-NEXT:    [[TMP2:%.*]] = sub <2 x i32> [[COL_LOAD1]], [[COL_LOAD4]]
+; CHECK-NEXT:    store <2 x i32> [[TMP1]], ptr [[OUT:%.*]], align 16
+; CHECK-NEXT:    [[VEC_GEP5:%.*]] = getelementptr i32, ptr [[OUT]], i64 2
+; CHECK-NEXT:    store <2 x i32> [[TMP2]], ptr [[VEC_GEP5]], align 8
+; CHECK-NEXT:    ret void
+;
+  %lhsv = load <4 x i32>, ptr %lhs
+  %rhsv = load <4 x i32>, ptr %rhs
+  %op = sub <4 x i32> %lhsv, %rhsv
+  %opt  = call <4 x i32> @llvm.matrix.transpose(<4 x i32> %op, i32 2, i32 2)
+  %optt = call <4 x i32> @llvm.matrix.transpose(<4 x i32> %opt, i32 2, i32 2)
+  store <4 x i32> %optt, ptr %out
+  ret void
+}
+
+define void @fsub_2x2(ptr %lhs, ptr %rhs, ptr %out) {
+; CHECK-LABEL: @fsub_2x2(
+; CHECK-NEXT:    [[COL_LOAD:%.*]] = load <2 x float>, ptr [[LHS:%.*]], align 16
+; CHECK-NEXT:    [[VEC_GEP:%.*]] = getelementptr float, ptr [[LHS]], i64 2
+; CHECK-NEXT:    [[COL_LOAD1:%.*]] = load <2 x float>, ptr [[VEC_GEP]], align 8
+; CHECK-NEXT:    [[COL_LOAD2:%.*]] = load <2 x float>, ptr [[RHS:%.*]], align 16
+; CHECK-NEXT:    [[VEC_GEP3:%.*]] = getelementptr float, ptr [[RHS]], i64 2
+; CHECK-NEXT:    [[COL_LOAD4:%.*]] = load <2 x float>, ptr [[VEC_GEP3]], align 8
+; CHECK-NEXT:    [[TMP1:%.*]] = fsub nnan <2 x float> [[COL_LOAD]], [[COL_LOAD2]]
+; CHECK-NEXT:    [[TMP2:%.*]] = fsub nnan <2 x float> [[COL_LOAD1]], [[COL_LOAD4]]
+; CHECK-NEXT:    store <2 x float> [[TMP1]], ptr [[OUT:%.*]], align 16
+; CHECK-NEXT:    [[VEC_GEP5:%.*]] = getelementptr float, ptr [[OUT]], i64 2
+; CHECK-NEXT:    store <2 x float> [[TMP2]], ptr [[VEC_GEP5]], align 8
+; CHECK-NEXT:    ret void
+;
+  %lhsv = load <4 x float>, ptr %lhs
+  %rhsv = load <4 x float>, ptr %rhs
+  %op = fsub nnan <4 x float> %lhsv, %rhsv
+  %opt  = call <4 x float> @llvm.matrix.transpose(<4 x float> %op, i32 2, i32 2)
+  %optt = call <4 x float> @llvm.matrix.transpose(<4 x float> %opt, i32 2, i32 2)
+  store <4 x float> %optt, ptr %out
+  ret void
+}
+
+define void @mul_2x2(ptr %lhs, ptr %rhs, ptr %out) {
+; CHECK-LABEL: @mul_2x2(
+; CHECK-NEXT:    [[COL_LOAD:%.*]] = load <2 x i32>, ptr [[LHS:%.*]], align 16
+; CHECK-NEXT:    [[VEC_GEP:%.*]] = getelementptr i32, ptr [[LHS]], i64 2
+; CHECK-NEXT:    [[COL_LOAD1:%.*]] = load <2 x i32>, ptr [[VEC_GEP]], align 8
+; CHECK-NEXT:    [[COL_LOAD2:%.*]] = load <2 x i32>, ptr [[RHS:%.*]], align 16
+; CHECK-NEXT:    [[VEC_GEP3:%.*]] = getelementptr i32, ptr [[RHS]], i64 2
+; CHECK-NEXT:    [[COL_LOAD4:%.*]] = load <2 x i32>, ptr [[VEC_GEP3]], align 8
+; CHECK-NEXT:    [[TMP1:%.*]] = mul <2 x i32> [[COL_LOAD]], [[COL_LOAD2]]
+; CHECK-NEXT:    [[TMP2:%.*]] = mul <2 x i32> [[COL_LOAD1]], [[COL_LOAD4]]
+; CHECK-NEXT:    store <2 x i32> [[TMP1]], ptr [[OUT:%.*]], align 16
+; CHECK-NEXT:    [[VEC_GEP5:%.*]] = getelementptr i32, ptr [[OUT]], i64 2
+; CHECK-NEXT:    store <2 x i32> [[TMP2]], ptr [[VEC_GEP5]], align 8
+; CHECK-NEXT:    ret void
+;
+  %lhsv = load <4 x i32>, ptr %lhs
+  %rhsv = load <4 x i32>, ptr %rhs
+  %op = mul <4 x i32> %lhsv, %rhsv
+  %opt  = call <4 x i32> @llvm.matrix.transpose(<4 x i32> %op, i32 2, i32 2)
+  %optt = call <4 x i32> @llvm.matrix.transpose(<4 x i32> %opt, i32 2, i32 2)
+  store <4 x i32> %optt, ptr %out
+  ret void
+}
+
+define void @fmul_2x2(ptr %lhs, ptr %rhs, ptr %out) {
+; CHECK-LABEL: @fmul_2x2(
+; CHECK-NEXT:    [[COL_LOAD:%.*]] = load <2 x float>, ptr [[LHS:%.*]], align 16
+; CHECK-NEXT:    [[VEC_GEP:%.*]] = getelementptr float, ptr [[LHS]], i64 2
+; CHECK-NEXT:    [[COL_LOAD1:%.*]] = load <2 x float>, ptr [[VEC_GEP]], align 8
+; CHECK-NEXT:    [[COL_LOAD2:%.*]] = load <2 x float>, ptr [[RHS:%.*]], align 16
+; CHECK-NEXT:    [[VEC_GEP3:%.*]] = getelementptr float, ptr [[RHS]], i64 2
+; CHECK-NEXT:    [[COL_LOAD4:%.*]] = load <2 x float>, ptr [[VEC_GEP3]], align 8
+; CHECK-NEXT:    [[TMP1:%.*]] = fmul contract <2 x float> [[COL_LOAD]], [[COL_LOAD2]]
+; CHECK-NEXT:    [[TMP2:%.*]] = fmul contract <2 x float> [[COL_LOAD1]], [[COL_LOAD4]]
+; CHECK-NEXT:    store <2 x float> [[TMP1]], ptr [[OUT:%.*]], align 16
+; CHECK-NEXT:    [[VEC_GEP5:%.*]] = getelementptr float, ptr [[OUT]], i64 2
+; CHECK-NEXT:    store <2 x float> [[TMP2]], ptr [[VEC_GEP5]], align 8
+; CHECK-NEXT:    ret void
+;
+  %lhsv = load <4 x float>, ptr %lhs
+  %rhsv = load <4 x float>, ptr %rhs
+  %op = fmul contract <4 x float> %lhsv, %rhsv
+  %opt  = call <4 x float> @llvm.matrix.transpose(<4 x float> %op, i32 2, i32 2)
+  %optt = call <4 x float> @llvm.matrix.transpose(<4 x float> %opt, i32 2, i32 2)
+  store <4 x float> %optt, ptr %out
+  ret void
+}
+
+define void @udiv_2x2(ptr %lhs, ptr %rhs, ptr %out) {
+; CHECK-LABEL: @udiv_2x2(
+; CHECK-NEXT:    [[COL_LOAD:%.*]] = load <2 x i32>, ptr [[LHS:%.*]], align 16
+; CHECK-NEXT:    [[VEC_GEP:%.*]] = getelementptr i32, ptr [[LHS]], i64 2
+; CHECK-NEXT:    [[COL_LOAD1:%.*]] = load <2 x i32>, ptr [[VEC_GEP]], align 8
+; CHECK-NEXT:    [[COL_LOAD2:%.*]] = load <2 x i32>, ptr [[RHS:%.*]], align 16
+; CHECK-NEXT:    [[VEC_GEP3:%.*]] = getelementptr i32, ptr [[RHS]], i64 2
+; CHECK-NEXT:    [[COL_LOAD4:%.*]] = load <2 x i32>, ptr [[VEC_GEP3]], align 8
+; CHECK-NEXT:    [[TMP1:%.*]] = udiv <2 x i32> [[COL_LOAD]], [[COL_LOAD2]]
+; CHECK-NEXT:    [[TMP2:%.*]] = udiv <2 x i32> [[COL_LOAD1]], [[COL_LOAD4]]
+; CHECK-NEXT:    store <2 x i32> [[TMP1]], ptr [[OUT:%.*]], align 16
+; CHECK-NEXT:    [[VEC_GEP5:%.*]] = getelementptr i32, ptr [[OUT]], i64 2
+; CHECK-NEXT:    store <2 x i32> [[TMP2]], ptr [[VEC_GEP5]], align 8
+; CHECK-NEXT:    ret void
+;
+  %lhsv = load <4 x i32>, ptr %lhs
+  %rhsv = load <4 x i32>, ptr %rhs
+  %op = udiv <4 x i32> %lhsv, %rhsv
+  %opt  = call <4 x i32> @llvm.matrix.transpose(<4 x i32> %op, i32 2, i32 2)
+  %optt = call <4 x i32> @llvm.matrix.transpose(<4 x i32> %opt, i32 2, i32 2)
+  store <4 x i32> %optt, ptr %out
+  ret void
+}
+
+define void @sdiv_2x2(ptr %lhs, ptr %rhs, ptr %out) {
+; CHECK-LABEL: @sdiv_2x2(
+; CHECK-NEXT:    [[COL_LOAD:%.*]] = load <2 x i32>, ptr [[LHS:%.*]], align 16
+; CHECK-NEXT:    [[VEC_GEP:%.*]] = getelementptr i32, ptr [[LHS]], i64 2
+; CHECK-NEXT:    [[COL_LOAD1:%.*]] = load <2 x i32>, ptr [[VEC_GEP]], align 8
+; CHECK-NEXT:    [[COL_LOAD2:%.*]] = load <2 x i32>, ptr [[RHS:%.*]], align 16
+; CHECK-NEXT:    [[VEC_GEP3:%.*]] = getelementptr i32, ptr [[RHS]], i64 2
+; CHECK-NEXT:    [[COL_LOAD4:%.*]] = load <2 x i32>, ptr [[VEC_GEP3]], align 8
+; CHECK-NEXT:    [[TMP1:%.*]] = sdiv <2 x i32> [[COL_LOAD]], [[COL_LOAD2]]
+; CHECK-NEXT:    [[TMP2:%.*]] = sdiv <2 x i32> [[COL_LOAD1]], [[COL_LOAD4]]
+; CHECK-NEXT:    store <2 x i32> [[TMP1]], ptr [[OUT:%.*]], align 16
+; CHECK-NEXT:    [[VEC_GEP5:%.*]] = getelementptr i32, ptr [[OUT]], i64 2
+; CHECK-NEXT:    store <2 x i32> [[TMP2]], ptr [[VEC_GEP5]], align 8
+; CHECK-NEXT:    ret void
+;
+  %lhsv = load <4 x i32>, ptr %lhs
+  %rhsv = load <4 x i32>, ptr %rhs
+  %op = sdiv <4 x i32> %lhsv, %rhsv
+  %opt  = call <4 x i32> @llvm.matrix.transpose(<4 x i32> %op, i32 2, i32 2)
+  %optt = call <4 x i32> @llvm.matrix.transpose(<4 x i32> %opt, i32 2, i32 2)
+  store <4 x i32> %optt, ptr %out
+  ret void
+}
+
 define void @fdiv_2x2(ptr %num, ptr %denom, ptr %out) {
 ; CHECK-LABEL: @fdiv_2x2(
 ; CHECK-NEXT:    [[COL_LOAD:%.*]] = load <2 x double>, ptr [[NUM:%.*]], align 32
@@ -9,8 +201,8 @@ define void @fdiv_2x2(ptr %num, ptr %denom, ptr %out) {
 ; CHECK-NEXT:    [[COL_LOAD2:%.*]] = load <2 x double>, ptr [[DENOM:%.*]], align 32
 ; CHECK-NEXT:    [[VEC_GEP3:%.*]] = getelementptr double, ptr [[DENOM]], i64 2
 ; CHECK-NEXT:    [[COL_LOAD4:%.*]] = load <2 x double>, ptr [[VEC_GEP3]], align 16
-; CHECK-NEXT:    [[TMP1:%.*]] = fdiv <2 x double> [[COL_LOAD]], [[COL_LOAD2]]
-; CHECK-NEXT:    [[TMP2:%.*]] = fdiv <2 x double> [[COL_LOAD1]], [[COL_LOAD4]]
+; CHECK-NEXT:    [[TMP1:%.*]] = fdiv nnan <2 x double> [[COL_LOAD]], [[COL_LOAD2]]
+; CHECK-NEXT:    [[TMP2:%.*]] = fdiv nnan <2 x double> [[COL_LOAD1]], [[COL_LOAD4]]
 ; CHECK-NEXT:    store <2 x double> [[TMP1]], ptr [[OUT:%.*]], align 32
 ; CHECK-NEXT:    [[VEC_GEP5:%.*]] = getelementptr double, ptr [[OUT]], i64 2
 ; CHECK-NEXT:    store <2 x double> [[TMP2]], ptr [[VEC_GEP5]], align 16
@@ -18,9 +210,225 @@ define void @fdiv_2x2(ptr %num, ptr %denom, ptr %out) {
 ;
   %numv = load <4 x double>, ptr %num
   %denomv = load <4 x double>, ptr %denom
-  %div = fdiv <4 x double> %numv, %denomv
+  %div = fdiv nnan <4 x double> %numv, %denomv
   %divt  = call <4 x double> @llvm.matrix.transpose(<4 x double> %div, i32 2, i32 2)
   %divtt = call <4 x double> @llvm.matrix.transpose(<4 x double> %divt, i32 2, i32 2)
   store <4 x double> %divtt, ptr %out
   ret void
 }
+
+define void @urem_2x2(ptr %lhs, ptr %rhs, ptr %out) {
+; CHECK-LABEL: @urem_2x2(
+; CHECK-NEXT:    [[COL_LOAD:%.*]] = load <2 x i32>, ptr [[LHS:%.*]], align 16
+; CHECK-NEXT:    [[VEC_GEP:%.*]] = getelementptr i32, ptr [[LHS]], i64 2
+; CHECK-NEXT:    [[COL_LOAD1:%.*]] = load <2 x i32>, ptr [[VEC_GEP]], align 8
+; CHECK-NEXT:    [[COL_LOAD2:%.*]] = load <2 x i32>, ptr [[RHS:%.*]], align 16
+; CHECK-NEXT:    [[VEC_GEP3:%.*]] = getelementptr i32, ptr [[RHS]], i64 2
+; CHECK-NEXT:    [[COL_LOAD4:%.*]] = load <2 x i32>, ptr [[VEC_GEP3]], align 8
+; CHECK-NEXT:    [[TMP1:%.*]] = urem <2 x i32> [[COL_LOAD]], [[COL_LOAD2]]
+; CHECK-NEXT:    [[TMP2:%.*]] = urem <2 x i32> [[COL_LOAD1]], [[COL_LOAD4]]
+; CHECK-NEXT:    store <2 x i32> [[TMP1]], ptr [[OUT:%.*]], align 16
+; CHECK-NEXT:    [[VEC_GEP5:%.*]] = getelementptr i32, ptr [[OUT]], i64 2
+; CHECK-NEXT:    store <2 x i32> [[TMP2]], ptr [[VEC_GEP5]], align 8
+; CHECK-NEXT:    ret void
+;
+  %lhsv = load <4 x i32>, ptr %lhs
+  %rhsv = load <4 x i32>, ptr %rhs
+  %op = urem <4 x i32> %lhsv, %rhsv
+  %opt  = call <4 x i32> @llvm.matrix.transpose(<4 x i32> %op, i32 2, i32 2)
+  %optt = call <4 x i32> @llvm.matrix.transpose(<4 x i32> %opt, i32 2, i32 2)
+  store <4 x i32> %optt, ptr %out
+  ret void
+}
+
+define void @srem_2x2(ptr %lhs, ptr %rhs, ptr %out) {
+; CHECK-LABEL: @srem_2x2(
+; CHECK-NEXT:    [[COL_LOAD:%.*]] = load <2 x i32>, ptr [[LHS:%.*]], align 16
+; CHECK-NEXT:    [[VEC_GEP:%.*]] = getelementptr i32, ptr [[LHS]], i64 2
+; CHECK-NEXT:    [[COL_LOAD1:%.*]] = load <2 x i32>, ptr [[VEC_GEP]], align 8
+; CHECK-NEXT:    [[COL_LOAD2:%.*]] = load <2 x i32>, ptr [[RHS:%.*]], align 16
+; CHECK-NEXT:    [[VEC_GEP3:%.*]] = getelementptr i32, ptr [[RHS]], i64 2
+; CHECK-NEXT:    [[COL_LOAD4:%.*]] = load <2 x i32>, ptr [[VEC_GEP3]], align 8
+; CHECK-NEXT:    [[TMP1:%.*]] = srem <2 x i32> [[COL_LOAD]], [[COL_LOAD2]]
+; CHECK-NEXT:    [[TMP2:%.*]] = srem <2 x i32> [[COL_LOAD1]], [[COL_LOAD4]]
+; CHECK-NEXT:    store <2 x i32> [[TMP1]], ptr [[OUT:%.*]], align 16
+; CHECK-NEXT:    [[VEC_GEP5:%.*]] = getelementptr i32, ptr [[OUT]], i64 2
+; CHECK-NEXT:    store <2 x i32> [[TMP2]], ptr [[VEC_GEP5]], align 8
+; CHECK-NEXT:    ret void
+;
+  %lhsv = load <4 x i32>, ptr %lhs
+  %rhsv = load <4 x i32>, ptr %rhs
+  %op = srem <4 x i32> %lhsv, %rhsv
+  %opt  = call <4 x i32> @llvm.matrix.transpose(<4 x i32> %op, i32 2, i32 2)
+  %optt = call <4 x i32> @llvm.matrix.transpose(<4 x i32> %opt, i32 2, i32 2)
+  store <4 x i32> %optt, ptr %out
+  ret void
+}
+
+define void @frem_2x2(ptr %lhs, ptr %rhs, ptr %out) {
+; CHECK-LABEL: @frem_2x2(
+; CHECK-NEXT:    [[COL_LOAD:%.*]] = load <2 x float>, ptr [[LHS:%.*]], align 16
+; CHECK-NEXT:    [[VEC_GEP:%.*]] = getelementptr float, ptr [[LHS]], i64 2
+; CHECK-NEXT:    [[COL_LOAD1:%.*]] = load <2 x float>, ptr [[VEC_GEP]], align 8
+; CHECK-NEXT:    [[COL_LOAD2:%.*]] = load <2 x float>, ptr [[RHS:%.*]], align 16
+; CHECK-NEXT:    [[VEC_GEP3:%.*]] = getelementptr float, ptr [[RHS]], i64 2
+; CHECK-NEXT:    [[COL_LOAD4:%.*]] = load <2 x float>, ptr [[VEC_GEP3]], align 8
+; CHECK-NEXT:    [[TMP1:%.*]] = frem fast <2 x float> [[COL_LOAD]], [[COL_LOAD2]]
+; CHECK-NEXT:    [[TMP2:%.*]] = frem fast <2 x float> [[COL_LOAD1]], [[COL_LOAD4]]
+; CHECK-NEXT:    store <2 x float> [[TMP1]], ptr [[OUT:%.*]], align 16
+; CHECK-NEXT:    [[VEC_GEP5:%.*]] = getelementptr float, ptr [[OUT]], i64 2
+; CHECK-NEXT:    store <2 x float> [[TMP2]], ptr [[VEC_GEP5]], align 8
+; CHECK-NEXT:    ret void
+;
+  %lhsv = load <4 x float>, ptr %lhs
+  %rhsv = load <4 x float>, ptr %rhs
+  %op = frem fast <4 x float> %lhsv, %rhsv
+  %opt  = call <4 x float> @llvm.matrix.transpose(<4 x float> %op, i32 2, i32 2)
+  %optt = call <4 x float> @llvm.matrix.transpose(<4 x float> %opt, i32 2, i32 2)
+  store <4 x float> %optt, ptr %out
+  ret void
+}
+
+define void @shl_2x2(ptr %lhs, ptr %rhs, ptr %out) {
+; CHECK-LABEL: @shl_2x2(
+; CHECK-NEXT:    [[COL_LOAD:%.*]] = load <2 x i32>, ptr [[LHS:%.*]], align 16
+; CHECK-NEXT:    [[VEC_GEP:%.*]] = getelementptr i32, ptr [[LHS]], i64 2
+; CHECK-NEXT:    [[COL_LOAD1:%.*]] = load <2 x i32>, ptr [[VEC_GEP]], align 8
+; CHECK-NEXT:    [[COL_LOAD2:%.*]] = load <2 x i32>, ptr [[RHS:%.*]], align 16
+; CHECK-NEXT:    [[VEC_GEP3:%.*]] = getelementptr i32, ptr [[RHS]], i64 2
+; CHECK-NEXT:    [[COL_LOAD4:%.*]] = load <2 x i32>, ptr [[VEC_GEP3]], align 8
+; CHECK-NEXT:    [[TMP1:%.*]] = shl <2 x i32> [[COL_LOAD]], [[COL_LOAD2]]
+; CHECK-NEXT:    [[TMP2:%.*]] = shl <2 x i32> [[COL_LOAD1]], [[COL_LOAD4]]
+; CHECK-NEXT:    store <2 x i32> [[TMP1]], ptr [[OUT:%.*]], align 16
+; CHECK-NEXT:    [[VEC_GEP5:%.*]] = getelementptr i32, ptr [[OUT]], i64 2
+; CHECK-NEXT:    store <2 x i32> [[TMP2]], ptr [[VEC_GEP5]], align 8
+; CHECK-NEXT:    ret void
+;
+  %lhsv = load <4 x i32>, ptr %lhs
+  %rhsv = load <4 x i32>, ptr %rhs
+  %op = shl <4 x i32> %lhsv, %rhsv
+  %opt  = call <4 x i32> @llvm.matrix.transpose(<4 x i32> %op, i32 2, i32 2)
+  %optt = call <4 x i32> @llvm.matrix.transpose(<4 x i32> %opt, i32 2, i32 2)
+  store <4 x i32> %optt, ptr %out
+  ret void
+}
+
+define void @lshr_2x2(ptr %lhs, ptr %rhs, ptr %out) {
+; CHECK-LABEL: @lshr_2x2(
+; CHECK-NEXT:    [[COL_LOAD:%.*]] = load <2 x i32>, ptr [[LHS:%.*]], align 16
+; CHECK-NEXT:    [[VEC_GEP:%.*]] = getelementptr i32, ptr [[LHS]], i64 2
+; CHECK-NEXT:    [[COL_LOAD1:%.*]] = load <2 x i32>, ptr [[VEC_GEP]], align 8
+; CHECK-NEXT:    [[COL_LOAD2:%.*]] = load <2 x i32>, ptr [[RHS:%.*]], align 16
+; CHECK-NEXT:    [[VEC_GEP3:%.*]] = getelementptr i32, ptr [[RHS]], i64 2
+; CHECK-NEXT:    [[COL_LOAD4:%.*]] = load <2 x i32>, ptr [[VEC_GEP3]], align 8
+; CHECK-NEXT:    [[TMP1:%.*]] = lshr <2 x i32> [[COL_LOAD]], [[COL_LOAD2]]
+; CHECK-NEXT:    [[TMP2:%.*]] = lshr <2 x i32> [[COL_LOAD1]], [[COL_LOAD4]]
+; CHECK-NEXT:    store <2 x i32> [[TMP1]], ptr [[OUT:%.*]], align 16
+; CHECK-NEXT:    [[VEC_GEP5:%.*]] = getelementptr i32, ptr [[OUT]], i64 2
+; CHECK-NEXT:    store <2 x i32> [[TMP2]], ptr [[VEC_GEP5]], align 8
+; CHECK-NEXT:    ret void
+;
+  %lhsv = load <4 x i32>, ptr %lhs
+  %rhsv = load <4 x i32>, ptr %rhs
+  %op = lshr <4 x i32> %lhsv, %rhsv
+  %opt  = call <4 x i32> @llvm.matrix.transpose(<4 x i32> %op, i32 2, i32 2)
+  %optt = call <4 x i32> @llvm.matrix.transpose(<4 x i32> %opt, i32 2, i32 2)
+  store <4 x i32> %optt, ptr %out
+  ret void
+}
+
+define void @ashr_2x2(ptr %lhs, ptr %rhs, ptr %out) {
+; CHECK-LABEL: @ashr_2x2(
+; CHECK-NEXT:    [[COL_LOAD:%.*]] = load <2 x i32>, ptr [[LHS:%.*]], align 16
+; CHECK-NEXT:    [[VEC_GEP:%.*]] = getelementptr i32, ptr [[LHS]], i64 2
+; CHECK-NEXT:    [[COL_LOAD1:%.*]] = load <2 x i32>, ptr [[VEC_GEP]], align 8
+; CHECK-NEXT:    [[COL_LOAD2:%.*]] = load <2 x i32>, ptr [[RHS:%.*]], align 16
+; CHECK-NEXT:    [[VEC_GEP3:%.*]] = getelementptr i32, ptr [[RHS]], i64 2
+; CHECK-NEXT:    [[COL_LOAD4:%.*]] = load <2 x i32>, ptr [[VEC_GEP3]], align 8
+; CHECK-NEXT:    [[TMP1:%.*]] = ashr <2 x i32> [[COL_LOAD]], [[COL_LOAD2]]
+; CHECK-NEXT:    [[TMP2:%.*]] = ashr <2 x i32> [[COL_LOAD1]], [[COL_LOAD4]]
+; CHECK-NEXT:    store <2 x i32> [[TMP1]], ptr [[OUT:%.*]], align 16
+; CHECK-NEXT:    [[VEC_GEP5:%.*]] = getelementptr i32, ptr [[OUT]], i64 2
+; CHECK-NEXT:    store <2 x i32> [[TMP2]], ptr [[VEC_GEP5]], align 8
+; CHECK-NEXT:    ret void
+;
+  %lhsv = load <4 x i32>, ptr %lhs
+  %rhsv = load <4 x i32>, ptr %rhs
+  %op = ashr <4 x i32> %lhsv, %rhsv
+  %opt  = call <4 x i32> @llvm.matrix.transpose(<4 x i32> %op, i32 2, i32 2)
+  %optt = call <4 x i32> @llvm.matrix.transpose(<4 x i32> %opt, i32 2, i32 2)
+  store <4 x i32> %optt, ptr %out
+  ret void
+}
+
+define void @and_2x2(ptr %lhs, ptr %rhs, ptr %out) {
+; CHECK-LABEL: @and_2x2(
+; CHECK-NEXT:    [[COL_LOAD:%.*]] = load <2 x i32>, ptr [[LHS:%.*]], align 16
+; CHECK-NEXT:    [[VEC_GEP:%.*]] = getelementptr i32, ptr [[LHS]], i64 2
+; CHECK-NEXT:    [[COL_LOAD1:%.*]] = load <2 x i32>, ptr [[VEC_GEP]], align 8
+; CHECK-NEXT:    [[COL_LOAD2:%.*]] = load <2 x i32>, ptr [[RHS:%.*]], align 16
+; CHECK-NEXT:    [[VEC_GEP3:%.*]] = getelementptr i32, ptr [[RHS]], i64 2
+; CHECK-NEXT:    [[COL_LOAD4:%.*]] = load <2 x i32>, ptr [[VEC_GEP3]], align 8
+; CHECK-NEXT:    [[TMP1:%.*]] = and <2 x i32> [[COL_LOAD]], [[COL_LOAD2]]
+; CHECK-NEXT:    [[TMP2:%.*]] = and <2 x i32> [[COL_LOAD1]], [[COL_LOAD4]]
+; CHECK-NEXT:    store <2 x i32> [[TMP1]], ptr [[OUT:%.*]], align 16
+; CHECK-NEXT:    [[VEC_GEP5:%.*]] = getelementptr i32, ptr [[OUT]], i64 2
+; CHECK-NEXT:    store <2 x i32> [[TMP2]], ptr [[VEC_GEP5]], align 8
+; CHECK-NEXT:    ret void
+;
+  %lhsv = load <4 x i32>, ptr %lhs
+  %rhsv = load <4 x i32>, ptr %rhs
+  %op = and <4 x i32> %lhsv, %rhsv
+  %opt  = call <4 x i32> @llvm.matrix.transpose(<4 x i32> %op, i32 2, i32 2)
+  %optt = call <4 x i32> @llvm.matrix.transpose(<4 x i32> %opt, i32 2, i32 2)
+  store <4 x i32> %optt, ptr %out
+  ret void
+}
+
+define void @or_2x2(ptr %lhs, ptr %rhs, ptr %out) {
+; CHECK-LABEL: @or_2x2(
+; CHECK-NEXT:    [[COL_LOAD:%.*]] = load <2 x i32>, ptr [[LHS:%.*]], align 16
+; CHECK-NEXT:    [[VEC_GEP:%.*]] = getelementptr i32, ptr [[LHS]], i64 2
+; CHECK-NEXT:    [[COL_LOAD1:%.*]] = load <2 x i32>, ptr [[VEC_GEP]], align 8
+; CHECK-NEXT:    [[COL_LOAD2:%.*]] = load <2 x i32>, ptr [[RHS:%.*]], align 16
+; CHECK-NEXT:    [[VEC_GEP3:%.*]] = getelementptr i32, ptr [[RHS]], i64 2
+; CHECK-NEXT:    [[COL_LOAD4:%.*]] = load <2 x i32>, ptr [[VEC_GEP3]], align 8
+; CHECK-NEXT:    [[TMP1:%.*]] = or <2 x i32> [[COL_LOAD]], [[COL_LOAD2]]
+; CHECK-NEXT:    [[TMP2:%.*]] = or <2 x i32> [[COL_LOAD1]], [[COL_LOAD4]]
+; CHECK-NEXT:    store <2 x i32> [[TMP1]], ptr [[OUT:%.*]], align 16
+; CHECK-NEXT:    [[VEC_GEP5:%.*]] = getelementptr i32, ptr [[OUT]], i64 2
+; CHECK-NEXT:    store <2 x i32> [[TMP2]], ptr [[VEC_GEP5]], align 8
+; CHECK-NEXT:    ret void
+;
+  %lhsv = load <4 x i32>, ptr %lhs
+  %rhsv = load <4 x i32>, ptr %rhs
+  %op = or <4 x i32> %lhsv, %rhsv
+  %opt  = call <4 x i32> @llvm.matrix.transpose(<4 x i32> %op, i32 2, i32 2)
+  %optt = call <4 x i32> @llvm.matrix.transpose(<4 x i32> %opt, i32 2, i32 2)
+  store <4 x i32> %optt, ptr %out
+  ret void
+}
+
+define void @xor_2x2(ptr %lhs, ptr %rhs, ptr %out) {
+; CHECK-LABEL: @xor_2x2(
+; CHECK-NEXT:    [[COL_LOAD:%.*]] = load <2 x i32>, ptr [[LHS:%.*]], align 16
+; CHECK-NEXT:    [[VEC_GEP:%.*]] = getelementptr i32, ptr [[LHS]], i64 2
+; CHECK-NEXT:    [[COL_LOAD1:%.*]] = load <2 x i32>, ptr [[VEC_GEP]], align 8
+; CHECK-NEXT:    [[COL_LOAD2:%.*]] = load <2 x i32>, ptr [[RHS:%.*]], align 16
+; CHECK-NEXT:    [[VEC_GEP3:%.*]] = getelementptr i32, ptr [[RHS]], i64 2
+; CHECK-NEXT:    [[COL_LOAD4:%.*]] = load <2 x i32>, ptr [[VEC_GEP3]], align 8
+; CHECK-NEXT:    [[TMP1:%.*]] = xor <2 x i32> [[COL_LOAD]], [[COL_LOAD2]]
+; CHECK-NEXT:    [[TMP2:%.*]] = xor <2 x i32> [[COL_LOAD1]], [[COL_LOAD4]]
+; CHECK-NEXT:    store <2 x i32> [[TMP1]], ptr [[OUT:%.*]], align 16
+; CHECK-NEXT:    [[VEC_GEP5:%.*]] = getelementptr i32, ptr [[OUT]], i64 2
+; CHECK-NEXT:    store <2 x i32> [[TMP2]], ptr [[VEC_GEP5]], align 8
+; CHECK-NEXT:    ret void
+;
+  %lhsv = load <4 x i32>, ptr %lhs
+  %rhsv = load <4 x i32>, ptr %rhs
+  %op = xor <4 x i32> %lhsv, %rhsv
+  %opt  = call <4 x i32> @llvm.matrix.transpose(<4 x i32> %op, i32 2, i32 2)
+  %optt = call <4 x i32> @llvm.matrix.transpose(<4 x i32> %opt, i32 2, i32 2)
+  store <4 x i32> %optt, ptr %out
+  ret void
+}

>From 7984bb90535e3efa554e644a8dfb125116952b97 Mon Sep 17 00:00:00 2001
From: Jon Roelofs <jonathan_roelofs at apple.com>
Date: Wed, 28 May 2025 09:30:50 -0700
Subject: [PATCH 4/6] pre-land test for extract

---
 .../LowerMatrixIntrinsics/extract.ll          | 34 +++++++++++++++++++
 1 file changed, 34 insertions(+)
 create mode 100644 llvm/test/Transforms/LowerMatrixIntrinsics/extract.ll

diff --git a/llvm/test/Transforms/LowerMatrixIntrinsics/extract.ll b/llvm/test/Transforms/LowerMatrixIntrinsics/extract.ll
new file mode 100644
index 0000000000000..03f4864e4f8c1
--- /dev/null
+++ b/llvm/test/Transforms/LowerMatrixIntrinsics/extract.ll
@@ -0,0 +1,34 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
+; RUN: opt -passes='lower-matrix-intrinsics' -S < %s | FileCheck %s
+
+define float @extract_static(ptr %in, ptr %out) {
+; CHECK-LABEL: @extract_static(
+; CHECK-NEXT:    [[COL_LOAD:%.*]] = load <2 x float>, ptr [[IN:%.*]], align 16
+; CHECK-NEXT:    [[VEC_GEP:%.*]] = getelementptr float, ptr [[IN]], i64 2
+; CHECK-NEXT:    [[COL_LOAD1:%.*]] = load <2 x float>, ptr [[VEC_GEP]], align 8
+; CHECK-NEXT:    [[TMP1:%.*]] = shufflevector <2 x float> [[COL_LOAD]], <2 x float> [[COL_LOAD1]], <4 x i32> <i32 0, i32 1, i32 2, i32 3>
+; CHECK-NEXT:    [[EXTRACT:%.*]] = extractelement <4 x float> [[TMP1]], i32 0
+; CHECK-NEXT:    ret float [[EXTRACT]]
+;
+  %inv = load <4 x float>, ptr %in
+  %invt  = call <4 x float> @llvm.matrix.transpose(<4 x float> %inv, i32 2, i32 2)
+  %invtt = call <4 x float> @llvm.matrix.transpose(<4 x float> %invt, i32 2, i32 2)
+  %extract = extractelement <4 x float> %invtt, i32 0
+  ret float %extract
+}
+
+define float @extract_dynamic(ptr %in, i32 %idx, ptr %out) {
+; CHECK-LABEL: @extract_dynamic(
+; CHECK-NEXT:    [[COL_LOAD:%.*]] = load <2 x float>, ptr [[IN:%.*]], align 16
+; CHECK-NEXT:    [[VEC_GEP:%.*]] = getelementptr float, ptr [[IN]], i64 2
+; CHECK-NEXT:    [[COL_LOAD1:%.*]] = load <2 x float>, ptr [[VEC_GEP]], align 8
+; CHECK-NEXT:    [[TMP1:%.*]] = shufflevector <2 x float> [[COL_LOAD]], <2 x float> [[COL_LOAD1]], <4 x i32> <i32 0, i32 1, i32 2, i32 3>
+; CHECK-NEXT:    [[EXTRACT:%.*]] = extractelement <4 x float> [[TMP1]], i32 [[IDX:%.*]]
+; CHECK-NEXT:    ret float [[EXTRACT]]
+;
+  %inv = load <4 x float>, ptr %in
+  %invt  = call <4 x float> @llvm.matrix.transpose(<4 x float> %inv, i32 2, i32 2)
+  %invtt = call <4 x float> @llvm.matrix.transpose(<4 x float> %invt, i32 2, i32 2)
+  %extract = extractelement <4 x float> %invtt, i32 %idx
+  ret float %extract
+}

>From 53098ce18f74005e35bf6d66035d21b62945f277 Mon Sep 17 00:00:00 2001
From: Jon Roelofs <jonathan_roelofs at apple.com>
Date: Wed, 28 May 2025 11:10:21 -0700
Subject: [PATCH 5/6] [Matrix] Optimize static extracts with ShapeInfo

For ExtractElementInsts with static indices that extract from a matrix, use
the known row/column layout to look through the shuffles that embedInVector
creates; in some cases this lets us delete those shuffles entirely. A sketch
of the rewrite follows.
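
A minimal before/after sketch, assuming a column-major 2x2 float matrix with
stride 2 (value names are illustrative):

  ; before: extract from the flattened vector built by embedInVector
  %flat = shufflevector <2 x float> %col0, <2 x float> %col1, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
  %e = extractelement <4 x float> %flat, i32 3

  ; after: index 3 maps to column 3/2 = 1, lane 3%2 = 1
  %e = extractelement <2 x float> %col1, i32 1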
---
 .../Scalar/LowerMatrixIntrinsics.cpp          | 44 ++++++++++++++++++-
 .../LowerMatrixIntrinsics/extract.ll          | 20 +++++++--
 .../LowerMatrixIntrinsics/transpose-opts.ll   | 18 ++++----
 3 files changed, 68 insertions(+), 14 deletions(-)

diff --git a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
index a58180c46be4b..f1624a008802b 100644
--- a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
+++ b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
@@ -34,6 +34,7 @@
 #include "llvm/IR/DebugInfoMetadata.h"
 #include "llvm/IR/Function.h"
 #include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/Instruction.h"
 #include "llvm/IR/Instructions.h"
 #include "llvm/IR/IntrinsicInst.h"
 #include "llvm/IR/MatrixBuilder.h"
@@ -568,6 +569,7 @@ class LowerMatrixIntrinsics {
         return M;
 
       MatrixVal = M.embedInVector(Builder);
+      Inst2ColumnMatrix[MatrixVal] = M;
     }
 
     // Otherwise split MatrixVal.
@@ -632,7 +634,7 @@ class LowerMatrixIntrinsics {
       default:
         return isUniformShape(II);
       }
-    return isUniformShape(V) || isa<StoreInst>(V) || isa<LoadInst>(V);
+    return isUniformShape(V) || isa<StoreInst>(V) || isa<LoadInst>(V) || isa<ExtractElementInst>(V);
   }
 
   /// Propagate the shape information of instructions to their users.
@@ -1083,6 +1085,18 @@ class LowerMatrixIntrinsics {
         Changed |= VisitStore(cast<StoreInst>(Inst), Op1, Op2, Builder);
     }
 
+    // Fifth, lower instructions which can make use of shape information, but do
+    // not have shapes themselves.
+    for (auto *BB : RPOT)
+      for (Instruction &Inst : *BB) {
+        IRBuilder<> Builder(&Inst);
+
+        Value *Op1;
+        uint64_t Index;
+        if (match(&Inst, m_ExtractElt(m_Value(Op1), m_ConstantInt(Index))))
+          Changed |= VisitExtractElt(cast<ExtractElementInst>(&Inst), Index);
+      }
+
     if (ORE) {
       RemarkGenerator RemarkGen(Inst2ColumnMatrix, *ORE, Func);
       RemarkGen.emitRemarks();
@@ -1364,8 +1378,10 @@ class LowerMatrixIntrinsics {
     Value *Flattened = nullptr;
     for (Use &U : llvm::make_early_inc_range(Inst->uses())) {
       if (!ShapeMap.contains(U.getUser())) {
-        if (!Flattened)
+        if (!Flattened) {
           Flattened = Matrix.embedInVector(Builder);
+          Inst2ColumnMatrix[Flattened] = Matrix;
+        }
         U.set(Flattened);
       }
     }
@@ -2142,6 +2158,30 @@ class LowerMatrixIntrinsics {
     return true;
   }
 
+  bool VisitExtractElt(ExtractElementInst *Inst, uint64_t Index) {
+    Value *Op0 = Inst->getOperand(0);
+    auto *VTy = cast<VectorType>(Op0->getType());
+
+    if (VTy->getElementCount().getKnownMinValue() <= Index) {
+      Inst->replaceAllUsesWith(PoisonValue::get(VTy->getElementType()));
+      ToRemove.push_back(Inst);
+      return true;
+    }
+
+    auto *I = Inst2ColumnMatrix.find(Op0);
+    if (I == Inst2ColumnMatrix.end())
+      return false;
+
+    const MatrixTy &M = I->second;
+
+    IRBuilder<> Builder(Inst);
+    Inst->setOperand(0, M.getVector(Index / M.getStride()));
+    Inst->setOperand(1, Builder.getInt32(Index % M.getStride()));
+    if (Op0->use_empty() && isa<Instruction>(Op0))
+      ToRemove.push_back(cast<Instruction>(Op0));
+    return true;
+  }
+
   /// Lower binary operators, if shape information is available.
   bool VisitBinaryOperator(BinaryOperator *Inst) {
     auto I = ShapeMap.find(Inst);
diff --git a/llvm/test/Transforms/LowerMatrixIntrinsics/extract.ll b/llvm/test/Transforms/LowerMatrixIntrinsics/extract.ll
index 03f4864e4f8c1..67a1a780f3713 100644
--- a/llvm/test/Transforms/LowerMatrixIntrinsics/extract.ll
+++ b/llvm/test/Transforms/LowerMatrixIntrinsics/extract.ll
@@ -6,14 +6,28 @@ define float @extract_static(ptr %in, ptr %out) {
 ; CHECK-NEXT:    [[COL_LOAD:%.*]] = load <2 x float>, ptr [[IN:%.*]], align 16
 ; CHECK-NEXT:    [[VEC_GEP:%.*]] = getelementptr float, ptr [[IN]], i64 2
 ; CHECK-NEXT:    [[COL_LOAD1:%.*]] = load <2 x float>, ptr [[VEC_GEP]], align 8
-; CHECK-NEXT:    [[TMP1:%.*]] = shufflevector <2 x float> [[COL_LOAD]], <2 x float> [[COL_LOAD1]], <4 x i32> <i32 0, i32 1, i32 2, i32 3>
-; CHECK-NEXT:    [[EXTRACT:%.*]] = extractelement <4 x float> [[TMP1]], i32 0
+; CHECK-NEXT:    [[EXTRACT:%.*]] = extractelement <2 x float> [[COL_LOAD1]], i32 1
 ; CHECK-NEXT:    ret float [[EXTRACT]]
 ;
   %inv = load <4 x float>, ptr %in
   %invt  = call <4 x float> @llvm.matrix.transpose(<4 x float> %inv, i32 2, i32 2)
   %invtt = call <4 x float> @llvm.matrix.transpose(<4 x float> %invt, i32 2, i32 2)
-  %extract = extractelement <4 x float> %invtt, i32 0
+  %extract = extractelement <4 x float> %invtt, i32 3
+  ret float %extract
+}
+
+define float @extract_static_outofbounds(ptr %in, ptr %out) {
+; CHECK-LABEL: @extract_static_outofbounds(
+; CHECK-NEXT:    [[COL_LOAD:%.*]] = load <2 x float>, ptr [[IN:%.*]], align 16
+; CHECK-NEXT:    [[VEC_GEP:%.*]] = getelementptr float, ptr [[IN]], i64 2
+; CHECK-NEXT:    [[COL_LOAD1:%.*]] = load <2 x float>, ptr [[VEC_GEP]], align 8
+; CHECK-NEXT:    [[TMP1:%.*]] = shufflevector <2 x float> [[COL_LOAD]], <2 x float> [[COL_LOAD1]], <4 x i32> <i32 0, i32 1, i32 2, i32 3>
+; CHECK-NEXT:    ret float poison
+;
+  %inv = load <4 x float>, ptr %in
+  %invt  = call <4 x float> @llvm.matrix.transpose(<4 x float> %inv, i32 2, i32 2)
+  %invtt = call <4 x float> @llvm.matrix.transpose(<4 x float> %invt, i32 2, i32 2)
+  %extract = extractelement <4 x float> %invtt, i32 5
   ret float %extract
 }
 
diff --git a/llvm/test/Transforms/LowerMatrixIntrinsics/transpose-opts.ll b/llvm/test/Transforms/LowerMatrixIntrinsics/transpose-opts.ll
index 57d038e5cf947..ea71e980d922c 100644
--- a/llvm/test/Transforms/LowerMatrixIntrinsics/transpose-opts.ll
+++ b/llvm/test/Transforms/LowerMatrixIntrinsics/transpose-opts.ll
@@ -55,11 +55,11 @@ define void @multiply_ntt(ptr %A, ptr %B, ptr %C, ptr %R) {
 ; REMARK-NEXT: Function:        multiply_ntt
 ; REMARK-NEXT: Args:
 ; REMARK-NEXT:   - String:          'Lowered with '
-; REMARK-NEXT:   - NumStores:       '4'
+; REMARK-NEXT:   - NumStores:       '0'
 ; REMARK-NEXT:   - String:          ' stores, '
-; REMARK-NEXT:   - NumLoads:        '10'
+; REMARK-NEXT:   - NumLoads:        '3'
 ; REMARK-NEXT:   - String:          ' loads, '
-; REMARK-NEXT:   - NumComputeOps:   '38'
+; REMARK-NEXT:   - NumComputeOps:   '0'
 ; REMARK-NEXT:   - String:          ' compute ops, '
 ; REMARK-NEXT:   - NumExposedTransposes: '0'
 ; REMARK-NEXT:   - String:          ' exposed transposes'
@@ -443,11 +443,11 @@ define void @multiply_nt_t(ptr %A, ptr %B, ptr %C) {
 ; REMARK-NEXT: Function:        multiply_nt_t
 ; REMARK-NEXT: Args:
 ; REMARK-NEXT:   - String:          'Lowered with '
-; REMARK-NEXT:   - NumStores:       '4'
+; REMARK-NEXT:   - NumStores:       '0'
 ; REMARK-NEXT:   - String:          ' stores, '
-; REMARK-NEXT:   - NumLoads:        '9'
+; REMARK-NEXT:   - NumLoads:        '3'
 ; REMARK-NEXT:   - String:          ' loads, '
-; REMARK-NEXT:   - NumComputeOps:   '20'
+; REMARK-NEXT:   - NumComputeOps:   '0'
 ; REMARK-NEXT:   - String:          ' compute ops, '
 ; REMARK-NEXT:   - NumExposedTransposes: '0'
 ; REMARK-NEXT:   - String:          ' exposed transposes'
@@ -578,11 +578,11 @@ define void @multiply_ntt_t(ptr %A, ptr %B, ptr %C, ptr %R) {
 ; REMARK-NEXT: Function:        multiply_ntt_t
 ; REMARK-NEXT: Args:
 ; REMARK-NEXT:   - String:          'Lowered with '
-; REMARK-NEXT:   - NumStores:       '6'
+; REMARK-NEXT:   - NumStores:       '0'
 ; REMARK-NEXT:   - String:          ' stores, '
-; REMARK-NEXT:   - NumLoads:        '18'
+; REMARK-NEXT:   - NumLoads:        '6'
 ; REMARK-NEXT:   - String:          ' loads, '
-; REMARK-NEXT:   - NumComputeOps:   '60'
+; REMARK-NEXT:   - NumComputeOps:   '0'
 ; REMARK-NEXT:   - String:          ' compute ops, '
 ; REMARK-NEXT:   - NumExposedTransposes: '0'
 ; REMARK-NEXT:   - String:          ' exposed transposes'

>From 10a854e4dd9f59d4f62fedd25405641cf6897348 Mon Sep 17 00:00:00 2001
From: Jon Roelofs <jonathan_roelofs at apple.com>
Date: Wed, 28 May 2025 11:59:59 -0700
Subject: [PATCH 6/6] Simplify the lowering: handle extracts in finalizeLowering
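
The extract handling moves out of the separate RPOT walk: finalizeLowering now
tries VisitExtractElt on each user without shape information before falling
back to embedInVector, and statically out-of-bounds extracts fold to poison at
the same point. A hedged sketch of that fold, mirroring the updated test:

  ; lane 5 is out of bounds for a <4 x float> matrix value...
  %extract = extractelement <4 x float> %invtt, i32 5
  ; ...so its uses are replaced with poison and the extract is erased
  ret float poison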

---
 .../Scalar/LowerMatrixIntrinsics.cpp          | 71 +++++++++----------
 .../LowerMatrixIntrinsics/extract.ll          |  1 -
 .../LowerMatrixIntrinsics/transpose-opts.ll   | 18 ++---
 3 files changed, 44 insertions(+), 46 deletions(-)

diff --git a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
index a4ac0fc57d5a0..5b5e903066224 100644
--- a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
+++ b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
@@ -560,7 +560,6 @@ class LowerMatrixIntrinsics {
         return M;
 
       MatrixVal = M.embedInVector(Builder);
-      Inst2ColumnMatrix[MatrixVal] = M;
     }
 
     // Otherwise split MatrixVal.
@@ -1082,10 +1081,7 @@ class LowerMatrixIntrinsics {
       for (Instruction &Inst : *BB) {
         IRBuilder<> Builder(&Inst);
 
-        Value *Op1;
-        uint64_t Index;
-        if (match(&Inst, m_ExtractElt(m_Value(Op1), m_ConstantInt(Index))))
-          Changed |= VisitExtractElt(cast<ExtractElementInst>(&Inst), Index);
+
       }
 
     if (ORE) {
@@ -1354,6 +1350,28 @@ class LowerMatrixIntrinsics {
     return Builder.CreateAdd(Sum, Mul);
   }
 
+  bool VisitExtractElt(ExtractElementInst *Inst, uint64_t Index) {
+    Value *Op0 = Inst->getOperand(0);
+    auto *VTy = cast<VectorType>(Op0->getType());
+
+    if (VTy->getElementCount().getKnownMinValue() <= Index) {
+      Inst->replaceAllUsesWith(PoisonValue::get(VTy->getElementType()));
+      Inst->eraseFromParent();
+      return true;
+    }
+
+    auto *I = Inst2ColumnMatrix.find(Op0);
+    if (I == Inst2ColumnMatrix.end())
+      return false;
+
+    const MatrixTy &M = I->second;
+
+    IRBuilder<> Builder(Inst);
+    Inst->setOperand(0, M.getVector(Index / M.getStride()));
+    Inst->setOperand(1, Builder.getInt32(Index % M.getStride()));
+    return true;
+  }
+
   /// Cache \p Matrix as result of \p Inst and update the uses of \p Inst. For
   /// users with shape information, there's nothing to do: they will use the
   /// cached value when they are lowered. For other users, \p Matrix is
@@ -1368,13 +1386,18 @@ class LowerMatrixIntrinsics {
     ToRemove.push_back(Inst);
     Value *Flattened = nullptr;
     for (Use &U : llvm::make_early_inc_range(Inst->uses())) {
-      if (!ShapeMap.contains(U.getUser())) {
-        if (!Flattened) {
-          Flattened = Matrix.embedInVector(Builder);
-          Inst2ColumnMatrix[Flattened] = Matrix;
-        }
-        U.set(Flattened);
-      }
+      if (ShapeMap.contains(U.getUser()))
+        continue;
+
+      Value *Op1;
+      uint64_t Index;
+      if (match(U.getUser(), m_ExtractElt(m_Value(Op1), m_ConstantInt(Index))))
+        if (VisitExtractElt(cast<ExtractElementInst>(U.getUser()), Index))
+          continue;
+
+      if (!Flattened)
+        Flattened = Matrix.embedInVector(Builder);
+      U.set(Flattened);
     }
   }
 
@@ -2149,30 +2172,6 @@ class LowerMatrixIntrinsics {
     return true;
   }
 
-  bool VisitExtractElt(ExtractElementInst *Inst, uint64_t Index) {
-    Value *Op0 = Inst->getOperand(0);
-    auto *VTy = cast<VectorType>(Op0->getType());
-
-    if (VTy->getElementCount().getKnownMinValue() <= Index) {
-      Inst->replaceAllUsesWith(PoisonValue::get(VTy->getElementType()));
-      ToRemove.push_back(Inst);
-      return true;
-    }
-
-    auto *I = Inst2ColumnMatrix.find(Op0);
-    if (I == Inst2ColumnMatrix.end())
-      return false;
-
-    const MatrixTy &M = I->second;
-
-    IRBuilder<> Builder(Inst);
-    Inst->setOperand(0, M.getVector(Index / M.getStride()));
-    Inst->setOperand(1, Builder.getInt32(Index % M.getStride()));
-    if (Op0->use_empty() && isa<Instruction>(Op0))
-      ToRemove.push_back(cast<Instruction>(Op0));
-    return true;
-  }
-
   /// Lower binary operators, if shape information is available.
   bool VisitBinaryOperator(BinaryOperator *Inst) {
     auto I = ShapeMap.find(Inst);
diff --git a/llvm/test/Transforms/LowerMatrixIntrinsics/extract.ll b/llvm/test/Transforms/LowerMatrixIntrinsics/extract.ll
index 67a1a780f3713..0bac9492d654a 100644
--- a/llvm/test/Transforms/LowerMatrixIntrinsics/extract.ll
+++ b/llvm/test/Transforms/LowerMatrixIntrinsics/extract.ll
@@ -21,7 +21,6 @@ define float @extract_static_outofbounds(ptr %in, ptr %out) {
 ; CHECK-NEXT:    [[COL_LOAD:%.*]] = load <2 x float>, ptr [[IN:%.*]], align 16
 ; CHECK-NEXT:    [[VEC_GEP:%.*]] = getelementptr float, ptr [[IN]], i64 2
 ; CHECK-NEXT:    [[COL_LOAD1:%.*]] = load <2 x float>, ptr [[VEC_GEP]], align 8
-; CHECK-NEXT:    [[TMP1:%.*]] = shufflevector <2 x float> [[COL_LOAD]], <2 x float> [[COL_LOAD1]], <4 x i32> <i32 0, i32 1, i32 2, i32 3>
 ; CHECK-NEXT:    ret float poison
 ;
   %inv = load <4 x float>, ptr %in
diff --git a/llvm/test/Transforms/LowerMatrixIntrinsics/transpose-opts.ll b/llvm/test/Transforms/LowerMatrixIntrinsics/transpose-opts.ll
index ea71e980d922c..57d038e5cf947 100644
--- a/llvm/test/Transforms/LowerMatrixIntrinsics/transpose-opts.ll
+++ b/llvm/test/Transforms/LowerMatrixIntrinsics/transpose-opts.ll
@@ -55,11 +55,11 @@ define void @multiply_ntt(ptr %A, ptr %B, ptr %C, ptr %R) {
 ; REMARK-NEXT: Function:        multiply_ntt
 ; REMARK-NEXT: Args:
 ; REMARK-NEXT:   - String:          'Lowered with '
-; REMARK-NEXT:   - NumStores:       '0'
+; REMARK-NEXT:   - NumStores:       '4'
 ; REMARK-NEXT:   - String:          ' stores, '
-; REMARK-NEXT:   - NumLoads:        '3'
+; REMARK-NEXT:   - NumLoads:        '10'
 ; REMARK-NEXT:   - String:          ' loads, '
-; REMARK-NEXT:   - NumComputeOps:   '0'
+; REMARK-NEXT:   - NumComputeOps:   '38'
 ; REMARK-NEXT:   - String:          ' compute ops, '
 ; REMARK-NEXT:   - NumExposedTransposes: '0'
 ; REMARK-NEXT:   - String:          ' exposed transposes'
@@ -443,11 +443,11 @@ define void @multiply_nt_t(ptr %A, ptr %B, ptr %C) {
 ; REMARK-NEXT: Function:        multiply_nt_t
 ; REMARK-NEXT: Args:
 ; REMARK-NEXT:   - String:          'Lowered with '
-; REMARK-NEXT:   - NumStores:       '0'
+; REMARK-NEXT:   - NumStores:       '4'
 ; REMARK-NEXT:   - String:          ' stores, '
-; REMARK-NEXT:   - NumLoads:        '3'
+; REMARK-NEXT:   - NumLoads:        '9'
 ; REMARK-NEXT:   - String:          ' loads, '
-; REMARK-NEXT:   - NumComputeOps:   '0'
+; REMARK-NEXT:   - NumComputeOps:   '20'
 ; REMARK-NEXT:   - String:          ' compute ops, '
 ; REMARK-NEXT:   - NumExposedTransposes: '0'
 ; REMARK-NEXT:   - String:          ' exposed transposes'
@@ -578,11 +578,11 @@ define void @multiply_ntt_t(ptr %A, ptr %B, ptr %C, ptr %R) {
 ; REMARK-NEXT: Function:        multiply_ntt_t
 ; REMARK-NEXT: Args:
 ; REMARK-NEXT:   - String:          'Lowered with '
-; REMARK-NEXT:   - NumStores:       '0'
+; REMARK-NEXT:   - NumStores:       '6'
 ; REMARK-NEXT:   - String:          ' stores, '
-; REMARK-NEXT:   - NumLoads:        '6'
+; REMARK-NEXT:   - NumLoads:        '18'
 ; REMARK-NEXT:   - String:          ' loads, '
-; REMARK-NEXT:   - NumComputeOps:   '0'
+; REMARK-NEXT:   - NumComputeOps:   '60'
 ; REMARK-NEXT:   - String:          ' compute ops, '
 ; REMARK-NEXT:   - NumExposedTransposes: '0'
 ; REMARK-NEXT:   - String:          ' exposed transposes'


