[llvm] [Matrix] Optimize shuffle extracts with ShapeInfo (PR #142276)

Jon Roelofs via llvm-commits llvm-commits at lists.llvm.org
Sat May 31 08:19:05 PDT 2025


https://github.com/jroelofs created https://github.com/llvm/llvm-project/pull/142276

When a shuffle extracts a vector that we have as part of the ShapeInfo for a Matrix (i.e. one column of a column-major matrix, or one row of a row-major matrix), replace the shuffle with that vector during lowering.

From adb2f15583e236fc63f555ddb37d205f935b6bc9 Mon Sep 17 00:00:00 2001
From: Jon Roelofs <jonathan_roelofs at apple.com>
Date: Fri, 30 May 2025 16:09:00 -0700
Subject: [PATCH] [Matrix] Optimize shuffle extracts with ShapeInfo

When a shuffle extracts a vector that we have as part of the ShapeInfo for a
Matrix (i.e. one column of a column-major matrix, or one row of a row-major
matrix), replace the shuffle with that vector during lowering.
---
 .../Scalar/LowerMatrixIntrinsics.cpp          | 68 +++++++++++++++++--
 .../LowerMatrixIntrinsics/shuffle.ll          | 34 ++++++++++
 2 files changed, 97 insertions(+), 5 deletions(-)
 create mode 100644 llvm/test/Transforms/LowerMatrixIntrinsics/shuffle.ll

diff --git a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
index 756a72e6d97bc..93d56a9a7bd4a 100644
--- a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
+++ b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
@@ -19,6 +19,7 @@
 
 #include "llvm/Transforms/Scalar/LowerMatrixIntrinsics.h"
 #include "llvm/ADT/PostOrderIterator.h"
+#include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/ScopeExit.h"
 #include "llvm/ADT/SmallSet.h"
 #include "llvm/ADT/SmallVector.h"
@@ -32,6 +33,7 @@
 #include "llvm/IR/CFG.h"
 #include "llvm/IR/DataLayout.h"
 #include "llvm/IR/DebugInfoMetadata.h"
+#include "llvm/IR/DerivedTypes.h"
 #include "llvm/IR/Function.h"
 #include "llvm/IR/IRBuilder.h"
 #include "llvm/IR/Instructions.h"
@@ -1337,6 +1339,57 @@ class LowerMatrixIntrinsics {
     return Builder.CreateAdd(Sum, Mul);
   }
 
+  /// Try to replace the shuffle \p Inst with a vector already held in the
+  /// lowered form of its source matrix (one column of a column-major matrix,
+  /// or one row of a row-major one). On success \p Inst is erased and true is
+  /// returned; otherwise nothing is changed and false is returned.
+  bool tryLowerShuffleVector(ShuffleVectorInst *Inst) {
+    Value *Op0 = Inst->getOperand(0), *Op1 = Inst->getOperand(1);
+    SmallVector<int> Mask;
+    Inst->getShuffleMask(Mask);
+    int NumSrcElts = cast<FixedVectorType>(Op0->getType())->getNumElements();
+
+    if (Mask[0] == PoisonMaskElem)
+      return false;
+
+    // Check if the Mask implies a contiguous extraction, i.e. one column of a
+    // column-major matrix (or row of a row-major one).
+    for (int I = 1, E = Mask.size(); I != E; ++I) {
+      if (Mask[I] == PoisonMaskElem)
+        return false;
+      if (Mask[I - 1] + 1 != Mask[I])
+        return false;
+    }
+
+    // Check if the Mask extracts from a single source operand. Mask indices
+    // below NumSrcElts select from Op0, all others select from Op1.
+    bool FromOp1 = Mask.front() >= NumSrcElts;
+    if (FromOp1 != (Mask.back() >= NumSrcElts))
+      return false;
+    Value *Op = FromOp1 ? Op1 : Op0;
+    auto *I = Inst2ColumnMatrix.find(Op);
+    if (I == Inst2ColumnMatrix.end())
+      return false;
+    const MatrixTy &M = I->second;
+
+    // Mask indices selecting from Op1 are offset by the operand width; rebase
+    // them so that First indexes into the flattened matrix held in Op.
+    unsigned First = Mask.front() - (FromOp1 ? NumSrcElts : 0);
+
+    // Check if the Mask extracts exactly one entire vector of the matrix,
+    // rather than a part of one or a span across two of them.
+    // TODO: we could handle the spanning case by creating a new shuffle, if
+    // we see that happening in the wild.
+    if (Mask.size() != M.getStride() || First % M.getStride() != 0)
+      return false;
+
+    Value *Result = M.getVector(First / M.getStride());
+    Inst->replaceAllUsesWith(Result);
+    Result->takeName(Inst);
+    Inst->eraseFromParent();
+    return true;
+  }
+
   /// Cache \p Matrix as result of \p Inst and update the uses of \p Inst. For
   /// users with shape information, there's nothing to do: they will use the
   /// cached value when they are lowered. For other users, \p Matrix is
@@ -1351,11 +1404,16 @@ class LowerMatrixIntrinsics {
     ToRemove.push_back(Inst);
     Value *Flattened = nullptr;
     for (Use &U : llvm::make_early_inc_range(Inst->uses())) {
-      if (!ShapeMap.contains(U.getUser())) {
-        if (!Flattened)
-          Flattened = Matrix.embedInVector(Builder);
-        U.set(Flattened);
-      }
+      if (ShapeMap.contains(U.getUser()))
+        continue;
+
+      if (auto *Intr = dyn_cast<ShuffleVectorInst>(U.getUser()))
+        if (tryLowerShuffleVector(Intr))
+          continue;
+
+      if (!Flattened)
+        Flattened = Matrix.embedInVector(Builder);
+      U.set(Flattened);
     }
   }
 
diff --git a/llvm/test/Transforms/LowerMatrixIntrinsics/shuffle.ll b/llvm/test/Transforms/LowerMatrixIntrinsics/shuffle.ll
new file mode 100644
index 0000000000000..21f49d2561d4e
--- /dev/null
+++ b/llvm/test/Transforms/LowerMatrixIntrinsics/shuffle.ll
@@ -0,0 +1,34 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
+; RUN: opt -passes='lower-matrix-intrinsics' -S < %s | FileCheck %s
+
+define <3 x double> @extract_column(ptr %in, ptr %out) {
+; CHECK-LABEL: @extract_column(
+; CHECK-NEXT:    [[COL_LOAD:%.*]] = load volatile <3 x double>, ptr [[IN:%.*]], align 8
+; CHECK-NEXT:    [[VEC_GEP:%.*]] = getelementptr double, ptr [[IN]], i64 3
+; CHECK-NEXT:    [[COL_LOAD1:%.*]] = load volatile <3 x double>, ptr [[VEC_GEP]], align 8
+; CHECK-NEXT:    [[VEC_GEP2:%.*]] = getelementptr double, ptr [[IN]], i64 6
+; CHECK-NEXT:    [[COL_LOAD3:%.*]] = load volatile <3 x double>, ptr [[VEC_GEP2]], align 8
+; CHECK-NEXT:    ret <3 x double> [[COL_LOAD3]]
+;
+  %inv = call <9 x double> @llvm.matrix.column.major.load(ptr %in, i64 3, i1 1, i32 3, i32 3) ; volatile load of a 3x3 matrix with stride 3
+  %col = shufflevector <9 x double> %inv, <9 x double> poison, <3 x i32> <i32 6, i32 7, i32 8> ; elements 6-8 form the third column, so the third column load is returned directly
+  ret <3 x double> %col
+}
+
+define <3 x double> @extract_row(ptr %in, ptr %out) {
+; CHECK-LABEL: @extract_row(
+; CHECK-NEXT:    [[COL_LOAD:%.*]] = load volatile <3 x double>, ptr [[IN:%.*]], align 8
+; CHECK-NEXT:    [[VEC_GEP:%.*]] = getelementptr double, ptr [[IN]], i64 3
+; CHECK-NEXT:    [[COL_LOAD1:%.*]] = load volatile <3 x double>, ptr [[VEC_GEP]], align 8
+; CHECK-NEXT:    [[VEC_GEP2:%.*]] = getelementptr double, ptr [[IN]], i64 6
+; CHECK-NEXT:    [[COL_LOAD3:%.*]] = load volatile <3 x double>, ptr [[VEC_GEP2]], align 8
+; CHECK-NEXT:    [[TMP1:%.*]] = shufflevector <3 x double> [[COL_LOAD]], <3 x double> [[COL_LOAD1]], <6 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5>
+; CHECK-NEXT:    [[TMP2:%.*]] = shufflevector <3 x double> [[COL_LOAD3]], <3 x double> poison, <6 x i32> <i32 0, i32 1, i32 2, i32 poison, i32 poison, i32 poison>
+; CHECK-NEXT:    [[TMP3:%.*]] = shufflevector <6 x double> [[TMP1]], <6 x double> [[TMP2]], <9 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7, i32 8>
+; CHECK-NEXT:    [[ROW:%.*]] = shufflevector <9 x double> [[TMP3]], <9 x double> poison, <3 x i32> <i32 0, i32 3, i32 6>
+; CHECK-NEXT:    ret <3 x double> [[ROW]]
+;
+  %inv = call <9 x double> @llvm.matrix.column.major.load(ptr %in, i64 3, i1 1, i32 3, i32 3) ; volatile load of a 3x3 matrix with stride 3
+  %row = shufflevector <9 x double> %inv, <9 x double> poison, <3 x i32> <i32 0, i32 3, i32 6> ; strided mask (a row of the column-major matrix): not contiguous, so the matrix is flattened and the shuffle is kept
+  ret <3 x double> %row
+}



More information about the llvm-commits mailing list