[llvm] ca5b71a - [Matrix] Propagate shape information through Select insts (#141876)

via llvm-commits llvm-commits at lists.llvm.org
Thu Jun 12 07:52:29 PDT 2025


Author: Jon Roelofs
Date: 2025-06-12T07:52:25-07:00
New Revision: ca5b71a4559890a9768558ddea724782fb638bfa

URL: https://github.com/llvm/llvm-project/commit/ca5b71a4559890a9768558ddea724782fb638bfa
DIFF: https://github.com/llvm/llvm-project/commit/ca5b71a4559890a9768558ddea724782fb638bfa.diff

LOG: [Matrix] Propagate shape information through Select insts (#141876)

Added: 
    llvm/test/Transforms/LowerMatrixIntrinsics/select.ll

Modified: 
    llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
index a7072ea719292..ce6eaa292d8fb 100644
--- a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
+++ b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
@@ -323,9 +323,11 @@ computeShapeInfoForInst(Instruction *I,
       return OpShape->second;
   }
 
-  if (isUniformShape(I)) {
+  if (isUniformShape(I) || isa<SelectInst>(I)) {
+    auto Ops = I->operands();
+    auto ShapedOps = isa<SelectInst>(I) ? drop_begin(Ops) : Ops;
     // Find the first operand that has a known shape and use that.
-    for (auto &Op : I->operands()) {
+    for (auto &Op : ShapedOps) {
       auto OpShape = ShapeMap.find(Op.get());
       if (OpShape != ShapeMap.end())
         return OpShape->second;
@@ -701,7 +703,8 @@ class LowerMatrixIntrinsics {
       default:
         return isUniformShape(II);
       }
-    return isUniformShape(V) || isa<StoreInst>(V) || isa<LoadInst>(V);
+    return isUniformShape(V) || isa<StoreInst>(V) || isa<LoadInst>(V) ||
+           isa<SelectInst>(V);
   }
 
   /// Propagate the shape information of instructions to their users.
@@ -788,10 +791,12 @@ class LowerMatrixIntrinsics {
       } else if (isa<StoreInst>(V)) {
         // Nothing to do.  We forward-propagated to this so we would just
         // backward propagate to an instruction with an already known shape.
-      } else if (isUniformShape(V)) {
+      } else if (isUniformShape(V) || isa<SelectInst>(V)) {
+        auto Ops = cast<Instruction>(V)->operands();
+        auto ShapedOps = isa<SelectInst>(V) ? drop_begin(Ops) : Ops;
         // Propagate to all operands.
         ShapeInfo Shape = ShapeMap[V];
-        for (Use &U : cast<Instruction>(V)->operands()) {
+        for (Use &U : ShapedOps) {
           if (setShapeInfo(U.get(), Shape))
             pushInstruction(U.get(), WorkList);
         }
@@ -1148,6 +1153,8 @@ class LowerMatrixIntrinsics {
         Result = VisitUnaryOperator(UnOp, SI);
       else if (auto *Intr = dyn_cast<IntrinsicInst>(Inst))
         Result = VisitIntrinsicInst(Intr, SI);
+      else if (auto *Select = dyn_cast<SelectInst>(Inst))
+        Result = VisitSelectInst(Select, SI);
       else if (match(Inst, m_Load(m_Value(Op1))))
         Result = VisitLoad(cast<LoadInst>(Inst), SI, Op1);
       else if (match(Inst, m_Store(m_Value(Op1), m_Value(Op2))))
@@ -2307,6 +2314,36 @@ class LowerMatrixIntrinsics {
                                    Result.getNumVectors());
   }
 
+  /// Lower selects. The true/false operands are split into the vectors of
+  /// \p Shape and one select is emitted per vector. A scalar condition is
+  /// reused for every vector; a vector condition is split with \p Shape too.
+  MatrixTy VisitSelectInst(SelectInst *Inst, const ShapeInfo &Shape) {
+    Value *Cond = Inst->getOperand(0);
+    Value *OpA = Inst->getOperand(1);
+    Value *OpB = Inst->getOperand(2);
+
+    IRBuilder<> Builder(Inst);
+
+    MatrixTy Result;
+    MatrixTy A = getMatrix(OpA, Shape, Builder);
+    MatrixTy B = getMatrix(OpB, Shape, Builder);
+
+    // Default every entry to Cond; overwritten below for vector conditions.
+    SmallVector<Value *, 16> CondV(Shape.getNumVectors(), Cond);
+    if (isa<FixedVectorType>(Cond->getType())) {
+      MatrixTy C = getMatrix(Cond, Shape, Builder);
+      for (unsigned I = 0, E = Shape.getNumVectors(); I != E; ++I)
+        CondV[I] = C.getVector(I);
+    }
+
+    for (unsigned I = 0, E = Shape.getNumVectors(); I != E; ++I)
+      Result.addVector(
+          Builder.CreateSelect(CondV[I], A.getVector(I), B.getVector(I)));
+
+    return Result.addNumComputeOps(getNumOps(Result.getVectorTy()) *
+                                   Result.getNumVectors());
+  }
+
   /// Helper to linearize a matrix expression tree into a string. Currently
   /// matrix expressions are linarized by starting at an expression leaf and
   /// linearizing bottom up.

diff  --git a/llvm/test/Transforms/LowerMatrixIntrinsics/select.ll b/llvm/test/Transforms/LowerMatrixIntrinsics/select.ll
new file mode 100644
index 0000000000000..70b0dfdb3e7e8
--- /dev/null
+++ b/llvm/test/Transforms/LowerMatrixIntrinsics/select.ll
@@ -0,0 +1,146 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
+; RUN: opt -passes='lower-matrix-intrinsics' -S < %s | FileCheck %s
+
+define void @select_2x2_bot(i1 %cond, ptr %lhs, ptr %rhs, ptr %out) {
+; CHECK-LABEL: @select_2x2_bot(
+; CHECK-NEXT:    [[COL_LOAD:%.*]] = load <2 x float>, ptr [[LHS:%.*]], align 16
+; CHECK-NEXT:    [[VEC_GEP:%.*]] = getelementptr float, ptr [[LHS]], i64 2
+; CHECK-NEXT:    [[COL_LOAD1:%.*]] = load <2 x float>, ptr [[VEC_GEP]], align 8
+; CHECK-NEXT:    [[COL_LOAD2:%.*]] = load <2 x float>, ptr [[RHS:%.*]], align 16
+; CHECK-NEXT:    [[VEC_GEP3:%.*]] = getelementptr float, ptr [[RHS]], i64 2
+; CHECK-NEXT:    [[COL_LOAD4:%.*]] = load <2 x float>, ptr [[VEC_GEP3]], align 8
+; CHECK-NEXT:    [[TMP1:%.*]] = select i1 [[COND:%.*]], <2 x float> [[COL_LOAD]], <2 x float> [[COL_LOAD2]]
+; CHECK-NEXT:    [[TMP2:%.*]] = select i1 [[COND]], <2 x float> [[COL_LOAD1]], <2 x float> [[COL_LOAD4]]
+; CHECK-NEXT:    store <2 x float> [[TMP1]], ptr [[OUT:%.*]], align 4
+; CHECK-NEXT:    [[VEC_GEP5:%.*]] = getelementptr float, ptr [[OUT]], i64 2
+; CHECK-NEXT:    store <2 x float> [[TMP2]], ptr [[VEC_GEP5]], align 4
+; CHECK-NEXT:    ret void
+;
+  %lhsv = load <4 x float>, ptr %lhs
+  %rhsv = load <4 x float>, ptr %rhs
+  %op = select i1 %cond, <4 x float> %lhsv, <4 x float> %rhsv
+  call void @llvm.matrix.column.major.store(<4 x float> %op, ptr %out, i64 2, i1 false, i32 2, i32 2)
+  ret void
+}
+
+define void @select_2x2_lhs(i1 %cond, ptr %lhs, ptr %rhs, ptr %out) {
+; CHECK-LABEL: @select_2x2_lhs(
+; CHECK-NEXT:    [[COL_LOAD:%.*]] = load <2 x float>, ptr [[LHS:%.*]], align 4
+; CHECK-NEXT:    [[VEC_GEP:%.*]] = getelementptr float, ptr [[LHS]], i64 2
+; CHECK-NEXT:    [[COL_LOAD1:%.*]] = load <2 x float>, ptr [[VEC_GEP]], align 4
+; CHECK-NEXT:    [[COL_LOAD2:%.*]] = load <2 x float>, ptr [[RHS:%.*]], align 16
+; CHECK-NEXT:    [[VEC_GEP3:%.*]] = getelementptr float, ptr [[RHS]], i64 2
+; CHECK-NEXT:    [[COL_LOAD4:%.*]] = load <2 x float>, ptr [[VEC_GEP3]], align 8
+; CHECK-NEXT:    [[TMP1:%.*]] = select i1 [[COND:%.*]], <2 x float> [[COL_LOAD]], <2 x float> [[COL_LOAD2]]
+; CHECK-NEXT:    [[TMP2:%.*]] = select i1 [[COND]], <2 x float> [[COL_LOAD1]], <2 x float> [[COL_LOAD4]]
+; CHECK-NEXT:    store <2 x float> [[TMP1]], ptr [[OUT:%.*]], align 16
+; CHECK-NEXT:    [[VEC_GEP5:%.*]] = getelementptr float, ptr [[OUT]], i64 2
+; CHECK-NEXT:    store <2 x float> [[TMP2]], ptr [[VEC_GEP5]], align 8
+; CHECK-NEXT:    ret void
+;
+  %lhsv = call <4 x float> @llvm.matrix.column.major.load(ptr %lhs, i64 2, i1 false, i32 2, i32 2)
+  %rhsv = load <4 x float>, ptr %rhs
+  %op = select i1 %cond, <4 x float> %lhsv, <4 x float> %rhsv
+  store <4 x float> %op, ptr %out
+  ret void
+}
+
+define void @select_2x2_rhs(i1 %cond, ptr %lhs, ptr %rhs, ptr %out) {
+; CHECK-LABEL: @select_2x2_rhs(
+; CHECK-NEXT:    [[COL_LOAD:%.*]] = load <2 x float>, ptr [[RHS:%.*]], align 16
+; CHECK-NEXT:    [[VEC_GEP:%.*]] = getelementptr float, ptr [[RHS]], i64 2
+; CHECK-NEXT:    [[COL_LOAD1:%.*]] = load <2 x float>, ptr [[VEC_GEP]], align 8
+; CHECK-NEXT:    [[COL_LOAD2:%.*]] = load <2 x float>, ptr [[RHS1:%.*]], align 4
+; CHECK-NEXT:    [[VEC_GEP3:%.*]] = getelementptr float, ptr [[RHS1]], i64 2
+; CHECK-NEXT:    [[COL_LOAD4:%.*]] = load <2 x float>, ptr [[VEC_GEP3]], align 4
+; CHECK-NEXT:    [[TMP1:%.*]] = select i1 [[COND:%.*]], <2 x float> [[COL_LOAD]], <2 x float> [[COL_LOAD2]]
+; CHECK-NEXT:    [[TMP2:%.*]] = select i1 [[COND]], <2 x float> [[COL_LOAD1]], <2 x float> [[COL_LOAD4]]
+; CHECK-NEXT:    store <2 x float> [[TMP1]], ptr [[OUT:%.*]], align 16
+; CHECK-NEXT:    [[VEC_GEP2:%.*]] = getelementptr float, ptr [[OUT]], i64 2
+; CHECK-NEXT:    store <2 x float> [[TMP2]], ptr [[VEC_GEP2]], align 8
+; CHECK-NEXT:    ret void
+;
+  %lhsv = load <4 x float>, ptr %lhs
+  %rhsv = call <4 x float> @llvm.matrix.column.major.load(ptr %rhs, i64 2, i1 false, i32 2, i32 2)
+  %op = select i1 %cond, <4 x float> %lhsv, <4 x float> %rhsv
+  store <4 x float> %op, ptr %out
+  ret void
+}
+
+define void @select_2x2_vcond_shape1(ptr %cond, ptr %lhs, ptr %rhs, ptr %out) {
+; CHECK-LABEL: @select_2x2_vcond_shape1(
+; CHECK-NEXT:    [[COL_LOAD:%.*]] = load <2 x float>, ptr [[LHS:%.*]], align 16
+; CHECK-NEXT:    [[VEC_GEP:%.*]] = getelementptr float, ptr [[LHS]], i64 2
+; CHECK-NEXT:    [[COL_LOAD1:%.*]] = load <2 x float>, ptr [[VEC_GEP]], align 8
+; CHECK-NEXT:    [[CONDV:%.*]] = load <4 x i1>, ptr [[COND:%.*]], align 1
+; CHECK-NEXT:    [[COL_LOAD5:%.*]] = load <2 x float>, ptr [[RHS1:%.*]], align 4
+; CHECK-NEXT:    [[VEC_GEP6:%.*]] = getelementptr float, ptr [[RHS1]], i64 2
+; CHECK-NEXT:    [[COL_LOAD7:%.*]] = load <2 x float>, ptr [[VEC_GEP6]], align 4
+; CHECK-NEXT:    [[COL_LOAD2:%.*]] = shufflevector <4 x i1> [[CONDV]], <4 x i1> poison, <2 x i32> <i32 0, i32 1>
+; CHECK-NEXT:    [[COL_LOAD4:%.*]] = shufflevector <4 x i1> [[CONDV]], <4 x i1> poison, <2 x i32> <i32 2, i32 3>
+; CHECK-NEXT:    [[TMP1:%.*]] = select <2 x i1> [[COL_LOAD2]], <2 x float> [[COL_LOAD]], <2 x float> [[COL_LOAD5]]
+; CHECK-NEXT:    [[TMP2:%.*]] = select <2 x i1> [[COL_LOAD4]], <2 x float> [[COL_LOAD1]], <2 x float> [[COL_LOAD7]]
+; CHECK-NEXT:    store <2 x float> [[TMP1]], ptr [[OUT:%.*]], align 16
+; CHECK-NEXT:    [[VEC_GEP8:%.*]] = getelementptr float, ptr [[OUT]], i64 2
+; CHECK-NEXT:    store <2 x float> [[TMP2]], ptr [[VEC_GEP8]], align 8
+; CHECK-NEXT:    ret void
+;
+  %lhsv = load <4 x float>, ptr %lhs
+  %condv = load <4 x i1>, ptr %cond
+  %rhsv = call <4 x float> @llvm.matrix.column.major.load(ptr %rhs, i64 2, i1 false, i32 2, i32 2)
+  %op = select <4 x i1> %condv, <4 x float> %lhsv, <4 x float> %rhsv
+  store <4 x float> %op, ptr %out
+  ret void
+}
+
+define void @select_2x2_vcond_shape2(ptr %cond, ptr %lhs, ptr %rhs, ptr %out) {
+; CHECK-LABEL: @select_2x2_vcond_shape2(
+; CHECK-NEXT:    [[COL_LOAD:%.*]] = load <2 x float>, ptr [[LHS:%.*]], align 16
+; CHECK-NEXT:    [[VEC_GEP:%.*]] = getelementptr float, ptr [[LHS]], i64 2
+; CHECK-NEXT:    [[COL_LOAD1:%.*]] = load <2 x float>, ptr [[VEC_GEP]], align 8
+; CHECK-NEXT:    [[COL_LOAD2:%.*]] = load <2 x i1>, ptr [[COND:%.*]], align 1
+; CHECK-NEXT:    [[VEC_GEP3:%.*]] = getelementptr i1, ptr [[COND]], i64 2
+; CHECK-NEXT:    [[COL_LOAD4:%.*]] = load <2 x i1>, ptr [[VEC_GEP3]], align 1
+; CHECK-NEXT:    [[COL_LOAD5:%.*]] = load <2 x float>, ptr [[RHS:%.*]], align 4
+; CHECK-NEXT:    [[VEC_GEP6:%.*]] = getelementptr float, ptr [[RHS]], i64 2
+; CHECK-NEXT:    [[COL_LOAD7:%.*]] = load <2 x float>, ptr [[VEC_GEP6]], align 4
+; CHECK-NEXT:    [[TMP1:%.*]] = select <2 x i1> [[COL_LOAD2]], <2 x float> [[COL_LOAD]], <2 x float> [[COL_LOAD5]]
+; CHECK-NEXT:    [[TMP2:%.*]] = select <2 x i1> [[COL_LOAD4]], <2 x float> [[COL_LOAD1]], <2 x float> [[COL_LOAD7]]
+; CHECK-NEXT:    store <2 x float> [[TMP1]], ptr [[OUT:%.*]], align 16
+; CHECK-NEXT:    [[VEC_GEP8:%.*]] = getelementptr float, ptr [[OUT]], i64 2
+; CHECK-NEXT:    store <2 x float> [[TMP2]], ptr [[VEC_GEP8]], align 8
+; CHECK-NEXT:    ret void
+;
+  %lhsv = load <4 x float>, ptr %lhs
+  %condv = call <4 x i1> @llvm.matrix.column.major.load(ptr %cond, i64 2, i1 false, i32 2, i32 2)
+  %rhsv = call <4 x float> @llvm.matrix.column.major.load(ptr %rhs, i64 2, i1 false, i32 2, i32 2)
+  %op = select <4 x i1> %condv, <4 x float> %lhsv, <4 x float> %rhsv
+  store <4 x float> %op, ptr %out
+  ret void
+}
+
+define void @select_2x2_vcond_shape3(ptr %cond, ptr %lhs, ptr %rhs, ptr %out) {
+; CHECK-LABEL: @select_2x2_vcond_shape3(
+; CHECK-NEXT:    [[COL_LOAD:%.*]] = load <2 x float>, ptr [[LHS:%.*]], align 16
+; CHECK-NEXT:    [[VEC_GEP:%.*]] = getelementptr float, ptr [[LHS]], i64 2
+; CHECK-NEXT:    [[COL_LOAD1:%.*]] = load <2 x float>, ptr [[VEC_GEP]], align 8
+; CHECK-NEXT:    [[COL_LOAD2:%.*]] = load <4 x i1>, ptr [[COND:%.*]], align 1
+; CHECK-NEXT:    [[COL_LOAD3:%.*]] = load <2 x float>, ptr [[RHS:%.*]], align 4
+; CHECK-NEXT:    [[VEC_GEP4:%.*]] = getelementptr float, ptr [[RHS]], i64 2
+; CHECK-NEXT:    [[COL_LOAD5:%.*]] = load <2 x float>, ptr [[VEC_GEP4]], align 4
+; CHECK-NEXT:    [[SPLIT:%.*]] = shufflevector <4 x i1> [[COL_LOAD2]], <4 x i1> poison, <2 x i32> <i32 0, i32 1>
+; CHECK-NEXT:    [[SPLIT6:%.*]] = shufflevector <4 x i1> [[COL_LOAD2]], <4 x i1> poison, <2 x i32> <i32 2, i32 3>
+; CHECK-NEXT:    [[TMP1:%.*]] = select <2 x i1> [[SPLIT]], <2 x float> [[COL_LOAD]], <2 x float> [[COL_LOAD3]]
+; CHECK-NEXT:    [[TMP2:%.*]] = select <2 x i1> [[SPLIT6]], <2 x float> [[COL_LOAD1]], <2 x float> [[COL_LOAD5]]
+; CHECK-NEXT:    store <2 x float> [[TMP1]], ptr [[OUT:%.*]], align 16
+; CHECK-NEXT:    [[VEC_GEP7:%.*]] = getelementptr float, ptr [[OUT]], i64 2
+; CHECK-NEXT:    store <2 x float> [[TMP2]], ptr [[VEC_GEP7]], align 8
+; CHECK-NEXT:    ret void
+;
+  %lhsv = load <4 x float>, ptr %lhs
+  %condv = call <4 x i1> @llvm.matrix.column.major.load(ptr %cond, i64 4, i1 false, i32 4, i32 1)
+  %rhsv = call <4 x float> @llvm.matrix.column.major.load(ptr %rhs, i64 2, i1 false, i32 2, i32 2)
+  %op = select <4 x i1> %condv, <4 x float> %lhsv, <4 x float> %rhsv
+  store <4 x float> %op, ptr %out
+  ret void
+}


        


More information about the llvm-commits mailing list