[llvm] [Matrix] Optimize static extracts with ShapeInfo (PR #141815)
Jon Roelofs via llvm-commits
llvm-commits at lists.llvm.org
Thu May 29 11:07:27 PDT 2025
https://github.com/jroelofs updated https://github.com/llvm/llvm-project/pull/141815
>From 323c9a8d2de459a0f81f32f7537e8b2e087ffc00 Mon Sep 17 00:00:00 2001
From: Jon Roelofs <jonathan_roelofs at apple.com>
Date: Wed, 28 May 2025 09:30:50 -0700
Subject: [PATCH 1/2] [Matrix] Optimize static extracts with ShapeInfo
For ExtractElementInsts with static indices that extract from a Matrix, use the
known layout of the Rows/Columns, avoiding some of the shuffles that
embedInVector creates.
---
.../Scalar/LowerMatrixIntrinsics.cpp | 43 ++++++++++++++---
.../LowerMatrixIntrinsics/extract.ll | 47 +++++++++++++++++++
2 files changed, 84 insertions(+), 6 deletions(-)
create mode 100644 llvm/test/Transforms/LowerMatrixIntrinsics/extract.ll
diff --git a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
index 756a72e6d97bc..8b322afd9b6e4 100644
--- a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
+++ b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
@@ -34,6 +34,7 @@
#include "llvm/IR/DebugInfoMetadata.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/Instruction.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/MatrixBuilder.h"
@@ -623,7 +624,8 @@ class LowerMatrixIntrinsics {
default:
return false;
}
- return isUniformShape(V) || isa<StoreInst>(V) || isa<LoadInst>(V);
+ return isUniformShape(V) || isa<StoreInst>(V) || isa<LoadInst>(V) ||
+ isa<ExtractElementInst>(V);
}
/// Propagate the shape information of instructions to their users.
@@ -1337,6 +1339,28 @@ class LowerMatrixIntrinsics {
return Builder.CreateAdd(Sum, Mul);
}
+ /// Redirect a constant-index extractelement whose vector operand has a cached
+ /// MatrixTy to the one vector holding that lane. Returns true if \p Inst was
+ /// rewritten or erased, false if Inst2ColumnMatrix has no entry for it.
+ bool VisitExtractElt(ExtractElementInst *Inst, uint64_t Index) {
+   Value *Op0 = Inst->getOperand(0);
+   auto *VTy = cast<VectorType>(Op0->getType());
+   // Valid lanes are [0, N); Index == N is out of bounds too, so use <=.
+   if (VTy->getElementCount().getKnownMinValue() <= Index) {
+     Inst->replaceAllUsesWith(PoisonValue::get(VTy->getElementType()));
+     Inst->eraseFromParent();
+     return true;
+   }
+   auto *I = Inst2ColumnMatrix.find(Op0);
+   if (I == Inst2ColumnMatrix.end())
+     return false;
+   const MatrixTy &M = I->second;
+   IRBuilder<> Builder(Inst);
+   Inst->setOperand(0, M.getVector(Index / M.getStride()));
+   Inst->setOperand(1, Builder.getInt32(Index % M.getStride()));
+   return true;
+ }
+
/// Cache \p Matrix as result of \p Inst and update the uses of \p Inst. For
/// users with shape information, there's nothing to do: they will use the
/// cached value when they are lowered. For other users, \p Matrix is
@@ -1351,11 +1375,18 @@ class LowerMatrixIntrinsics {
ToRemove.push_back(Inst);
Value *Flattened = nullptr;
for (Use &U : llvm::make_early_inc_range(Inst->uses())) {
- if (!ShapeMap.contains(U.getUser())) {
- if (!Flattened)
- Flattened = Matrix.embedInVector(Builder);
- U.set(Flattened);
- }
+ if (ShapeMap.contains(U.getUser()))
+ continue;
+
+ Value *Op1;
+ uint64_t Index;
+ if (match(U.getUser(), m_ExtractElt(m_Value(Op1), m_ConstantInt(Index))))
+ if (VisitExtractElt(cast<ExtractElementInst>(U.getUser()), Index))
+ continue;
+
+ if (!Flattened)
+ Flattened = Matrix.embedInVector(Builder);
+ U.set(Flattened);
}
}
diff --git a/llvm/test/Transforms/LowerMatrixIntrinsics/extract.ll b/llvm/test/Transforms/LowerMatrixIntrinsics/extract.ll
new file mode 100644
index 0000000000000..0bac9492d654a
--- /dev/null
+++ b/llvm/test/Transforms/LowerMatrixIntrinsics/extract.ll
@@ -0,0 +1,47 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
+; RUN: opt -passes='lower-matrix-intrinsics' -S < %s | FileCheck %s
+
+define float @extract_static(ptr %in, ptr %out) {
+; CHECK-LABEL: @extract_static(
+; CHECK-NEXT: [[COL_LOAD:%.*]] = load <2 x float>, ptr [[IN:%.*]], align 16
+; CHECK-NEXT: [[VEC_GEP:%.*]] = getelementptr float, ptr [[IN]], i64 2
+; CHECK-NEXT: [[COL_LOAD1:%.*]] = load <2 x float>, ptr [[VEC_GEP]], align 8
+; CHECK-NEXT: [[EXTRACT:%.*]] = extractelement <2 x float> [[COL_LOAD1]], i32 1
+; CHECK-NEXT: ret float [[EXTRACT]]
+;
+ %inv = load <4 x float>, ptr %in
+ %invt = call <4 x float> @llvm.matrix.transpose(<4 x float> %inv, i32 2, i32 2)
+ %invtt = call <4 x float> @llvm.matrix.transpose(<4 x float> %invt, i32 2, i32 2)
+ %extract = extractelement <4 x float> %invtt, i32 3
+ ret float %extract
+}
+
+define float @extract_static_outofbounds(ptr %in, ptr %out) {
+; CHECK-LABEL: @extract_static_outofbounds(
+; CHECK-NEXT: [[COL_LOAD:%.*]] = load <2 x float>, ptr [[IN:%.*]], align 16
+; CHECK-NEXT: [[VEC_GEP:%.*]] = getelementptr float, ptr [[IN]], i64 2
+; CHECK-NEXT: [[COL_LOAD1:%.*]] = load <2 x float>, ptr [[VEC_GEP]], align 8
+; CHECK-NEXT: ret float poison
+;
+ %inv = load <4 x float>, ptr %in
+ %invt = call <4 x float> @llvm.matrix.transpose(<4 x float> %inv, i32 2, i32 2)
+ %invtt = call <4 x float> @llvm.matrix.transpose(<4 x float> %invt, i32 2, i32 2)
+ %extract = extractelement <4 x float> %invtt, i32 5
+ ret float %extract
+}
+
+define float @extract_dynamic(ptr %in, i32 %idx, ptr %out) {
+; CHECK-LABEL: @extract_dynamic(
+; CHECK-NEXT: [[COL_LOAD:%.*]] = load <2 x float>, ptr [[IN:%.*]], align 16
+; CHECK-NEXT: [[VEC_GEP:%.*]] = getelementptr float, ptr [[IN]], i64 2
+; CHECK-NEXT: [[COL_LOAD1:%.*]] = load <2 x float>, ptr [[VEC_GEP]], align 8
+; CHECK-NEXT: [[TMP1:%.*]] = shufflevector <2 x float> [[COL_LOAD]], <2 x float> [[COL_LOAD1]], <4 x i32> <i32 0, i32 1, i32 2, i32 3>
+; CHECK-NEXT: [[EXTRACT:%.*]] = extractelement <4 x float> [[TMP1]], i32 [[IDX:%.*]]
+; CHECK-NEXT: ret float [[EXTRACT]]
+;
+ %inv = load <4 x float>, ptr %in
+ %invt = call <4 x float> @llvm.matrix.transpose(<4 x float> %inv, i32 2, i32 2)
+ %invtt = call <4 x float> @llvm.matrix.transpose(<4 x float> %invt, i32 2, i32 2)
+ %extract = extractelement <4 x float> %invtt, i32 %idx
+ ret float %extract
+}
>From c962ce9a79028d9514951d8ec07a818114aeea92 Mon Sep 17 00:00:00 2001
From: Jon Roelofs <jonathan_roelofs at apple.com>
Date: Thu, 29 May 2025 11:07:13 -0700
Subject: [PATCH 2/2] use column major load intrinsic for shape info
---
.../LowerMatrixIntrinsics/extract.ll | 30 ++++++++-----------
1 file changed, 12 insertions(+), 18 deletions(-)
diff --git a/llvm/test/Transforms/LowerMatrixIntrinsics/extract.ll b/llvm/test/Transforms/LowerMatrixIntrinsics/extract.ll
index 0bac9492d654a..db5444ca036ae 100644
--- a/llvm/test/Transforms/LowerMatrixIntrinsics/extract.ll
+++ b/llvm/test/Transforms/LowerMatrixIntrinsics/extract.ll
@@ -3,45 +3,39 @@
define float @extract_static(ptr %in, ptr %out) {
; CHECK-LABEL: @extract_static(
-; CHECK-NEXT: [[COL_LOAD:%.*]] = load <2 x float>, ptr [[IN:%.*]], align 16
+; CHECK-NEXT: [[COL_LOAD:%.*]] = load volatile <2 x float>, ptr [[IN:%.*]], align 4
; CHECK-NEXT: [[VEC_GEP:%.*]] = getelementptr float, ptr [[IN]], i64 2
-; CHECK-NEXT: [[COL_LOAD1:%.*]] = load <2 x float>, ptr [[VEC_GEP]], align 8
+; CHECK-NEXT: [[COL_LOAD1:%.*]] = load volatile <2 x float>, ptr [[VEC_GEP]], align 4
; CHECK-NEXT: [[EXTRACT:%.*]] = extractelement <2 x float> [[COL_LOAD1]], i32 1
; CHECK-NEXT: ret float [[EXTRACT]]
;
- %inv = load <4 x float>, ptr %in
- %invt = call <4 x float> @llvm.matrix.transpose(<4 x float> %inv, i32 2, i32 2)
- %invtt = call <4 x float> @llvm.matrix.transpose(<4 x float> %invt, i32 2, i32 2)
- %extract = extractelement <4 x float> %invtt, i32 3
+ %inv = call <4 x float> @llvm.matrix.column.major.load(ptr %in, i64 2, i1 1, i32 2, i32 2)
+ %extract = extractelement <4 x float> %inv, i32 3
ret float %extract
}
define float @extract_static_outofbounds(ptr %in, ptr %out) {
; CHECK-LABEL: @extract_static_outofbounds(
-; CHECK-NEXT: [[COL_LOAD:%.*]] = load <2 x float>, ptr [[IN:%.*]], align 16
+; CHECK-NEXT: [[COL_LOAD:%.*]] = load volatile <2 x float>, ptr [[IN:%.*]], align 4
; CHECK-NEXT: [[VEC_GEP:%.*]] = getelementptr float, ptr [[IN]], i64 2
-; CHECK-NEXT: [[COL_LOAD1:%.*]] = load <2 x float>, ptr [[VEC_GEP]], align 8
+; CHECK-NEXT: [[COL_LOAD1:%.*]] = load volatile <2 x float>, ptr [[VEC_GEP]], align 4
; CHECK-NEXT: ret float poison
;
- %inv = load <4 x float>, ptr %in
- %invt = call <4 x float> @llvm.matrix.transpose(<4 x float> %inv, i32 2, i32 2)
- %invtt = call <4 x float> @llvm.matrix.transpose(<4 x float> %invt, i32 2, i32 2)
- %extract = extractelement <4 x float> %invtt, i32 5
+ %inv = call <4 x float> @llvm.matrix.column.major.load(ptr %in, i64 2, i1 1, i32 2, i32 2)
+ %extract = extractelement <4 x float> %inv, i32 5
ret float %extract
}
define float @extract_dynamic(ptr %in, i32 %idx, ptr %out) {
; CHECK-LABEL: @extract_dynamic(
-; CHECK-NEXT: [[COL_LOAD:%.*]] = load <2 x float>, ptr [[IN:%.*]], align 16
+; CHECK-NEXT: [[COL_LOAD:%.*]] = load volatile <2 x float>, ptr [[IN:%.*]], align 4
; CHECK-NEXT: [[VEC_GEP:%.*]] = getelementptr float, ptr [[IN]], i64 2
-; CHECK-NEXT: [[COL_LOAD1:%.*]] = load <2 x float>, ptr [[VEC_GEP]], align 8
+; CHECK-NEXT: [[COL_LOAD1:%.*]] = load volatile <2 x float>, ptr [[VEC_GEP]], align 4
; CHECK-NEXT: [[TMP1:%.*]] = shufflevector <2 x float> [[COL_LOAD]], <2 x float> [[COL_LOAD1]], <4 x i32> <i32 0, i32 1, i32 2, i32 3>
; CHECK-NEXT: [[EXTRACT:%.*]] = extractelement <4 x float> [[TMP1]], i32 [[IDX:%.*]]
; CHECK-NEXT: ret float [[EXTRACT]]
;
- %inv = load <4 x float>, ptr %in
- %invt = call <4 x float> @llvm.matrix.transpose(<4 x float> %inv, i32 2, i32 2)
- %invtt = call <4 x float> @llvm.matrix.transpose(<4 x float> %invt, i32 2, i32 2)
- %extract = extractelement <4 x float> %invtt, i32 %idx
+ %inv = call <4 x float> @llvm.matrix.column.major.load(ptr %in, i64 2, i1 1, i32 2, i32 2)
+ %extract = extractelement <4 x float> %inv, i32 %idx
ret float %extract
}
More information about the llvm-commits
mailing list