[llvm] [Matrix] Optimize shuffle extracts with ShapeInfo (PR #142276)
Jon Roelofs via llvm-commits
llvm-commits at lists.llvm.org
Sat May 31 08:19:05 PDT 2025
https://github.com/jroelofs created https://github.com/llvm/llvm-project/pull/142276
When a shuffle extracts a vector that we have as part of the ShapeInfo for a Matrix (i.e. one column of a column-major matrix, or one row of a row-major matrix), replace the shuffle with that vector during lowering.
>From adb2f15583e236fc63f555ddb37d205f935b6bc9 Mon Sep 17 00:00:00 2001
From: Jon Roelofs <jonathan_roelofs at apple.com>
Date: Fri, 30 May 2025 16:09:00 -0700
Subject: [PATCH] [Matrix] Optimize shuffle extracts with ShapeInfo
When a shuffle extracts a vector that we have as part of the ShapeInfo for a
Matrix (i.e. one column of a column-major matrix, or one row of a row-major
matrix), replace the shuffle with that vector during lowering.
---
.../Scalar/LowerMatrixIntrinsics.cpp | 68 +++++++++++++++++--
.../LowerMatrixIntrinsics/shuffle.ll | 34 ++++++++++
2 files changed, 97 insertions(+), 5 deletions(-)
create mode 100644 llvm/test/Transforms/LowerMatrixIntrinsics/shuffle.ll
diff --git a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
index 756a72e6d97bc..93d56a9a7bd4a 100644
--- a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
+++ b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
@@ -19,6 +19,7 @@
#include "llvm/Transforms/Scalar/LowerMatrixIntrinsics.h"
#include "llvm/ADT/PostOrderIterator.h"
+#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/ScopeExit.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/SmallVector.h"
@@ -32,6 +33,7 @@
#include "llvm/IR/CFG.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/DebugInfoMetadata.h"
+#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instructions.h"
@@ -1337,6 +1339,57 @@ class LowerMatrixIntrinsics {
return Builder.CreateAdd(Sum, Mul);
}
+  /// Try to replace the shufflevector \p Inst with one of the vectors of an
+  /// already-lowered matrix operand: one column of a column-major matrix, or
+  /// one row of a row-major matrix.
+  ///
+  /// The transform applies only when the mask selects, without poison lanes,
+  /// exactly one whole vector of a single operand that has an entry in
+  /// Inst2ColumnMatrix, starting on a vector boundary. On success every use of
+  /// \p Inst is replaced by that vector, the vector takes \p Inst's name,
+  /// \p Inst is erased, and true is returned. Otherwise the IR is left
+  /// unchanged and false is returned.
+  bool tryLowerShuffleVector(ShuffleVectorInst *Inst) {
+    Value *Op0 = Inst->getOperand(0), *Op1 = Inst->getOperand(1);
+
+    // The caller walks the uses of a lowered matrix with an early-increment
+    // iterator. If both operands are the same value, that iterator may already
+    // point at Inst's *other* use of it, which erasing Inst would free.
+    if (Op0 == Op1)
+      return false;
+
+    ArrayRef<int> Mask = Inst->getShuffleMask();
+    if (Mask.empty() || Mask.front() == PoisonMaskElem)
+      return false;
+
+    // Check if the Mask implies a contiguous extraction. Because Mask[0] is not
+    // poison, requiring each index to be one past its predecessor also rules
+    // out poison lanes further along the mask.
+    for (size_t I = 1, E = Mask.size(); I != E; ++I)
+      if (Mask[I] != Mask[I - 1] + 1)
+        return false;
+
+    // Determine which operand is being read. Mask indices >= NumSrcElts refer
+    // to Op1 and must be rebased before they can index that operand's elements.
+    auto *SrcTy = cast<FixedVectorType>(Op0->getType());
+    const int NumSrcElts = SrcTy->getNumElements();
+    Value *Src = Op0;
+    int First = Mask.front();
+    int Last = Mask.back();
+    if (First >= NumSrcElts) {
+      Src = Op1;
+      First -= NumSrcElts;
+      Last -= NumSrcElts;
+    }
+    // Reject masks that straddle the boundary between Op0 and Op1.
+    if (Last >= NumSrcElts)
+      return false;
+
+    auto It = Inst2ColumnMatrix.find(Src);
+    if (It == Inst2ColumnMatrix.end())
+      return false;
+    const MatrixTy &M = It->second;
+
+    // Check if the Mask extracts one entire vector from the matrix, starting
+    // on a vector boundary.
+    // TODO: an extraction spanning two of the matrix's vectors could be lowered
+    // to a new shuffle of just those two, if we see that happening in the wild.
+    const unsigned Stride = M.getStride();
+    const unsigned Start = First;
+    if (Mask.size() != Stride || Start % Stride != 0)
+      return false;
+
+    const unsigned VecIdx = Start / Stride;
+    if (VecIdx >= M.getNumVectors())
+      return false;
+
+    Value *Result = M.getVector(VecIdx);
+    if (Result->getType() != Inst->getType())
+      return false;
+
+    Inst->replaceAllUsesWith(Result);
+    Result->takeName(Inst);
+    Inst->eraseFromParent();
+    return true;
+  }
+
/// Cache \p Matrix as result of \p Inst and update the uses of \p Inst. For
/// users with shape information, there's nothing to do: they will use the
/// cached value when they are lowered. For other users, \p Matrix is
@@ -1351,11 +1404,16 @@ class LowerMatrixIntrinsics {
ToRemove.push_back(Inst);
Value *Flattened = nullptr;
for (Use &U : llvm::make_early_inc_range(Inst->uses())) {
- if (!ShapeMap.contains(U.getUser())) {
- if (!Flattened)
- Flattened = Matrix.embedInVector(Builder);
- U.set(Flattened);
- }
+ if (ShapeMap.contains(U.getUser()))
+ continue;
+
+ if (auto *Intr = dyn_cast<ShuffleVectorInst>(U.getUser()))
+ if (tryLowerShuffleVector(Intr))
+ continue;
+
+ if (!Flattened)
+ Flattened = Matrix.embedInVector(Builder);
+ U.set(Flattened);
}
}
diff --git a/llvm/test/Transforms/LowerMatrixIntrinsics/shuffle.ll b/llvm/test/Transforms/LowerMatrixIntrinsics/shuffle.ll
new file mode 100644
index 0000000000000..21f49d2561d4e
--- /dev/null
+++ b/llvm/test/Transforms/LowerMatrixIntrinsics/shuffle.ll
@@ -0,0 +1,34 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
+; RUN: opt -passes='lower-matrix-intrinsics' -S < %s | FileCheck %s
+
+; Mask <6, 7, 8> selects the last column of a 3x3 column-major matrix, so the
+; shuffle is replaced directly by that column's load and no flattening
+; shuffles are emitted.
+define <3 x double> @extract_column(ptr %in, ptr %out) {
+; CHECK-LABEL: @extract_column(
+; CHECK-NEXT: [[COL_LOAD:%.*]] = load volatile <3 x double>, ptr [[IN:%.*]], align 8
+; CHECK-NEXT: [[VEC_GEP:%.*]] = getelementptr double, ptr [[IN]], i64 3
+; CHECK-NEXT: [[COL_LOAD1:%.*]] = load volatile <3 x double>, ptr [[VEC_GEP]], align 8
+; CHECK-NEXT: [[VEC_GEP2:%.*]] = getelementptr double, ptr [[IN]], i64 6
+; CHECK-NEXT: [[COL_LOAD3:%.*]] = load volatile <3 x double>, ptr [[VEC_GEP2]], align 8
+; CHECK-NEXT: ret <3 x double> [[COL_LOAD3]]
+;
+ %inv = call <9 x double> @llvm.matrix.column.major.load(ptr %in, i64 3, i1 1, i32 3, i32 3)
+ %col = shufflevector <9 x double> %inv, <9 x double> poison, <3 x i32> <i32 6, i32 7, i32 8>
+ ret <3 x double> %col
+}
+
+; Mask <0, 3, 6> reads a row of a column-major matrix. That is a strided, not
+; contiguous, extraction, so the shuffle is kept and its matrix operand is
+; flattened back into a single <9 x double> vector.
+define <3 x double> @extract_row(ptr %in, ptr %out) {
+; CHECK-LABEL: @extract_row(
+; CHECK-NEXT: [[COL_LOAD:%.*]] = load volatile <3 x double>, ptr [[IN:%.*]], align 8
+; CHECK-NEXT: [[VEC_GEP:%.*]] = getelementptr double, ptr [[IN]], i64 3
+; CHECK-NEXT: [[COL_LOAD1:%.*]] = load volatile <3 x double>, ptr [[VEC_GEP]], align 8
+; CHECK-NEXT: [[VEC_GEP2:%.*]] = getelementptr double, ptr [[IN]], i64 6
+; CHECK-NEXT: [[COL_LOAD3:%.*]] = load volatile <3 x double>, ptr [[VEC_GEP2]], align 8
+; CHECK-NEXT: [[TMP1:%.*]] = shufflevector <3 x double> [[COL_LOAD]], <3 x double> [[COL_LOAD1]], <6 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5>
+; CHECK-NEXT: [[TMP2:%.*]] = shufflevector <3 x double> [[COL_LOAD3]], <3 x double> poison, <6 x i32> <i32 0, i32 1, i32 2, i32 poison, i32 poison, i32 poison>
+; CHECK-NEXT: [[TMP3:%.*]] = shufflevector <6 x double> [[TMP1]], <6 x double> [[TMP2]], <9 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7, i32 8>
+; CHECK-NEXT: [[ROW:%.*]] = shufflevector <9 x double> [[TMP3]], <9 x double> poison, <3 x i32> <i32 0, i32 3, i32 6>
+; CHECK-NEXT: ret <3 x double> [[ROW]]
+;
+ %inv = call <9 x double> @llvm.matrix.column.major.load(ptr %in, i64 3, i1 1, i32 3, i32 3)
+ %row = shufflevector <9 x double> %inv, <9 x double> poison, <3 x i32> <i32 0, i32 3, i32 6>
+ ret <3 x double> %row
+}
More information about the llvm-commits
mailing list