[llvm] [LoopInterchange] Relax the legality check to accept more patterns (PR #118267)
Ryotaro Kasuga via llvm-commits
llvm-commits at lists.llvm.org
Wed Jan 29 03:52:21 PST 2025
https://github.com/kasuga-fj updated https://github.com/llvm/llvm-project/pull/118267
From 5ff4eb276dc47834b08633af736fc6e7303d6250 Mon Sep 17 00:00:00 2001
From: Ryotaro Kasuga <kasuga.ryotaro at fujitsu.com>
Date: Wed, 27 Nov 2024 12:32:24 +0000
Subject: [PATCH] [LoopInterchange] Relax the legality check to accept more
patterns
We lose opportunities to interchange loops because the current legality
check is stricter than necessary. This patch relaxes the restriction and
increases the number of acceptable patterns. Here is a motivating
example.
```
// From TSVC s231
for (int nl=0;nl<100;nl++) {
for (int i=0;i<256;i++) {
for (int j=1;j<256;j++)
aa[j][i] = aa[j-1][i] + bb[j][i];
}
dummy(aa, bb);
}
```
This patch allows us to interchange the two innermost loops in the above
code. Note, however, that the current implementation interchanges these
loops twice, so they end up back in the original order.
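
For illustration, the relaxed per-row check can be sketched as standalone
C++ with no LLVM dependencies. This is a minimal sketch rather than the code
in the patch: the names `Pattern`, `classify`, and `isLegalForRow` are
hypothetical stand-ins for `DirectionVectorPattern`,
`classifyDirectionVector`, and `isLegalToInterchangeLoopsForRow`, and unknown
characters are treated like '=' instead of hitting `llvm_unreachable`.

```cpp
#include <cassert>
#include <utility>
#include <vector>

// Hypothetical stand-in for DirectionVectorPattern in the patch.
enum class Pattern { Zero, Positive, Negative, All };

// Classify DV[Left..Right) by its leftmost element other than '=' and 'I'.
static Pattern classify(const std::vector<char> &DV, unsigned Left,
                        unsigned Right) {
  assert(Left <= Right && Right <= DV.size() && "Invalid range");
  for (unsigned I = Left; I < Right; ++I) {
    switch (DV[I]) {
    case '<':
      return Pattern::Positive;
    case '>':
      return Pattern::Negative;
    case '*':
      return Pattern::All;
    default:
      break; // '=' or 'I' (assumption: anything else is also skipped here).
    }
  }
  return Pattern::Zero;
}

// Dep is taken by value because it is modified locally.
static bool isLegalForRow(std::vector<char> Dep, unsigned InnerLoopId,
                          unsigned OuterLoopId) {
  // A '<' or '>' to the left of the outer loop already fixes the execution
  // order, so swapping the two loops cannot change it. Any '*' found on the
  // way is replaced with '='.
  for (unsigned I = 0; I < OuterLoopId; ++I) {
    if (Dep[I] == '<' || Dep[I] == '>')
      return true;
    Dep[I] = '=';
  }
  // Compare the leftmost non-'=' element before and after the swap.
  Pattern Before = classify(Dep, OuterLoopId, InnerLoopId + 1);
  if (Before == Pattern::All)
    return false;
  std::swap(Dep[InnerLoopId], Dep[OuterLoopId]);
  Pattern After = classify(Dep, OuterLoopId, InnerLoopId + 1);
  return Before == After;
}
```

With this sketch, `isLegalForRow({'*', '=', '>'}, 2, 1)` accepts the s231
vector, while `isLegalForRow({'*', '>'}, 1, 0)` and
`isLegalForRow({'<', '>'}, 1, 0)` are rejected, matching the expectations in
legality-checks.ll.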
---
.../lib/Transforms/Scalar/LoopInterchange.cpp | 138 ++++++++++--
.../LoopInterchange/inner-only-reductions.ll | 2 +-
.../LoopInterchange/legality-checks.ll | 209 ++++++++++++++++++
3 files changed, 330 insertions(+), 19 deletions(-)
create mode 100644 llvm/test/Transforms/LoopInterchange/legality-checks.ll
diff --git a/llvm/lib/Transforms/Scalar/LoopInterchange.cpp b/llvm/lib/Transforms/Scalar/LoopInterchange.cpp
index ca125d2c0c490c..d3de0cdc859fc6 100644
--- a/llvm/lib/Transforms/Scalar/LoopInterchange.cpp
+++ b/llvm/lib/Transforms/Scalar/LoopInterchange.cpp
@@ -72,6 +72,15 @@ using LoopVector = SmallVector<Loop *, 8>;
// TODO: Check if we can use a sparse matrix here.
using CharMatrix = std::vector<std::vector<char>>;
+// Classification of a direction vector by the leftmost element after removing
+// '=' and 'I' from it.
+enum class DirectionVectorPattern {
+ Zero, ///< The direction vector contains only '=' or 'I'.
+ Positive, ///< The leftmost element after removing '=' and 'I' is '<'.
+ Negative, ///< The leftmost element after removing '=' and 'I' is '>'.
+ All, ///< The leftmost element after removing '=' and 'I' is '*'.
+};
+
} // end anonymous namespace
// Minimum loop depth supported.
@@ -199,18 +208,115 @@ static void interChangeDependencies(CharMatrix &DepMatrix, unsigned FromIndx,
std::swap(DepMatrix[I][ToIndx], DepMatrix[I][FromIndx]);
}
-// After interchanging, check if the direction vector is valid.
-// [Theorem] A permutation of the loops in a perfect nest is legal if and only
-// if the direction matrix, after the same permutation is applied to its
-// columns, has no ">" direction as the leftmost non-"=" direction in any row.
-static bool isLexicographicallyPositive(std::vector<char> &DV) {
- for (unsigned char Direction : DV) {
- if (Direction == '<')
+// Classify the direction vector into the four patterns. The target vector is
+// [DV[Left], DV[Left+1], ..., DV[Right-1]], not the whole of \p DV.
+static DirectionVectorPattern
+classifyDirectionVector(const std::vector<char> &DV, unsigned Left,
+ unsigned Right) {
+ assert(Left <= Right && "Left must be less than or equal to Right");
+ for (unsigned I = Left; I < Right; I++) {
+ unsigned char Direction = DV[I];
+ switch (Direction) {
+ case '<':
+ return DirectionVectorPattern::Positive;
+ case '>':
+ return DirectionVectorPattern::Negative;
+ case '*':
+ return DirectionVectorPattern::All;
+ case '=':
+ case 'I':
+ break;
+ default:
+ llvm_unreachable("Unknown element in direction vector");
+ }
+ }
+ return DirectionVectorPattern::Zero;
+}
+
+// Check whether the requested interchange is legal or not. The interchange is
+// valid if the following condition holds:
+//
+// [Cond] For two instructions that can access the same location, the execution
+// order of the instructions is the same before and after the interchange.
+//
+// If the direction vector doesn't contain '*', the above Cond is equivalent to
+// one of the following:
+//
+// - The leftmost non-'=' element is '<' before and after interchanging.
+// - The leftmost non-'=' element is '>' before and after interchanging.
+// - All the elements in the direction vector are '='.
+//
+// As for '*', we must treat it as having dependency in all directions. It could
+// be '<', it could be '>', it could be '='. We can eliminate '*'s from the
+// direction vector by enumerating all possible patterns by replacing '*' with
+// '<' or '>' or '=', and then doing the above checks for all of them. The
+// enumeration can grow exponentially, so it is not practical to run it as it
+// is. Fortunately, we can perform the following pruning.
+//
+// - For '*' to the left of \p OuterLoopId, replacing it with '=' is allowed.
+//
+// This is because, for patterns where '<' (or '>') is assigned to some '*' to
+// the left of \p OuterLoopId, the first (or second) condition above holds
+// regardless of interchanging. After doing this pruning, the interchange is
+// legal if the leftmost non-'=' element is the same before and after swapping
+// the element of \p OuterLoopId and \p InnerLoopId.
+//
+//
+// Example: Consider the following loop.
+//
+// ```
+// for (i=0; i<=32; i++)
+// for (j=0; j<N-1; j++)
+// for (k=0; k<N-1; k++) {
+// Src: A[i][j][k] = ...;
+// Dst: use(A[32-i][j+1][k+1]);
+// }
+// ```
+//
+// In this case, the direction vector is [* < <] (if the analysis is powerful
+// enough). The enumeration of all possible patterns by replacing '*' is as
+// follows:
+//
+// - [< < <] : when i < 16
+// - [= < <] : when i = 16
+// - [> < <] : when i > 16
+//
+// We can prove that it is safe to interchange the innermost two loops here,
+// because the interchange doesn't change the leftmost non-'=' element for all
+// enumerated vectors.
+//
+// TODO: There are cases where the interchange is legal but rejected. At least
+// the following patterns are legal:
+// - If both Dep[OuterLoopId] and Dep[InnerLoopId] are '=', the interchange is
+// legal regardless of any other elements.
+// - If the loops are adjacent to each other and at least one of them is '=',
+// the interchange is legal even if the other is '*'.
+static bool isLegalToInterchangeLoopsForRow(std::vector<char> Dep,
+ unsigned InnerLoopId,
+ unsigned OuterLoopId) {
+ // Replace '*' to the left of OuterLoopId with '='. The presence of '<' or
+ // '>' means that the direction vector is something like [= = = < ...] or
+ // [= = = > ...], where the interchange is safe.
+ for (unsigned I = 0; I < OuterLoopId; I++) {
+ if (Dep[I] == '<' || Dep[I] == '>')
return true;
- if (Direction == '>' || Direction == '*')
- return false;
+ Dep[I] = '=';
}
- return true;
+
+ // From this point on, all elements to the left of OuterLoopId are considered
+ // to be '='.
+
+ // Perform legality checks by comparing the leftmost non-'=' element between
+ // before and after the interchange. If either one is '*', then the
+ // interchange is unsafe. Otherwise it is safe if the two elements are equal.
+ auto BeforePattern =
+ classifyDirectionVector(Dep, OuterLoopId, InnerLoopId + 1);
+ if (BeforePattern == DirectionVectorPattern::All)
+ return false;
+ std::swap(Dep[InnerLoopId], Dep[OuterLoopId]);
+ auto AfterPattern =
+ classifyDirectionVector(Dep, OuterLoopId, InnerLoopId + 1);
+ return BeforePattern == AfterPattern;
}
// Checks if it is legal to interchange 2 loops.
@@ -218,16 +324,12 @@ static bool isLegalToInterChangeLoops(CharMatrix &DepMatrix,
unsigned InnerLoopId,
unsigned OuterLoopId) {
unsigned NumRows = DepMatrix.size();
- std::vector<char> Cur;
// For each row check if it is valid to interchange.
for (unsigned Row = 0; Row < NumRows; ++Row) {
- // Create temporary DepVector check its lexicographical order
- // before and after swapping OuterLoop vs InnerLoop
- Cur = DepMatrix[Row];
- if (!isLexicographicallyPositive(Cur))
- return false;
- std::swap(Cur[InnerLoopId], Cur[OuterLoopId]);
- if (!isLexicographicallyPositive(Cur))
+ // The vector is copied because its elements are modified in
+ // `isLegalToInterchangeLoopsForRow`.
+ if (!isLegalToInterchangeLoopsForRow(DepMatrix[Row], InnerLoopId,
+ OuterLoopId))
return false;
}
return true;
diff --git a/llvm/test/Transforms/LoopInterchange/inner-only-reductions.ll b/llvm/test/Transforms/LoopInterchange/inner-only-reductions.ll
index ee1a7f16199288..c6c2e2a8a51874 100644
--- a/llvm/test/Transforms/LoopInterchange/inner-only-reductions.ll
+++ b/llvm/test/Transforms/LoopInterchange/inner-only-reductions.ll
@@ -74,7 +74,7 @@ for.end8: ; preds = %for.cond1.for.inc6_
; CHECK: --- !Missed
; CHECK-NEXT: Pass: loop-interchange
-; CHECK-NEXT: Name: Dependence
+; CHECK-NEXT: Name: UnsupportedPHIOuter
; CHECK-NEXT: Function: reduction_03
; IR-LABEL: @reduction_03(
diff --git a/llvm/test/Transforms/LoopInterchange/legality-checks.ll b/llvm/test/Transforms/LoopInterchange/legality-checks.ll
new file mode 100644
index 00000000000000..334003e922ec6e
--- /dev/null
+++ b/llvm/test/Transforms/LoopInterchange/legality-checks.ll
@@ -0,0 +1,209 @@
+; REQUIRES: asserts
+; RUN: opt < %s -passes=loop-interchange -verify-dom-info -verify-loop-info \
+; RUN: -disable-output -debug 2>&1 | FileCheck %s
+
+ at a = dso_local global [20 x [20 x [20 x i32]]] zeroinitializer, align 4
+ at aa = dso_local global [256 x [256 x float]] zeroinitializer, align 64
+ at bb = dso_local global [256 x [256 x float]] zeroinitializer, align 64
+
+;; for (int nl=0;nl<100;++nl)
+;; for (int i=0;i<256;++i)
+;; for (int j=1;j<256;++j)
+;; aa[j][i] = aa[j-1][i] + bb[j][i];
+;;
+;; The direction vector of `aa` is [* = >]. We can interchange the innermost
+;; two loops. The direction vector after interchanging will be [* > =].
+
+; CHECK: Dependency matrix before interchange:
+; CHECK-NEXT: * = >
+; CHECK-NEXT: * = =
+; CHECK-NEXT: Processing InnerLoopId = 2 and OuterLoopId = 1
+; CHECK-NEXT: Checking if loops are tightly nested
+; CHECK-NEXT: Checking instructions in Loop header and Loop latch
+; CHECK-NEXT: Loops are perfectly nested
+; CHECK-NEXT: Loops are legal to interchange
+; CHECK: Dependency matrix after interchange:
+; CHECK-NEXT: * > =
+; CHECK-NEXT: * = =
+
+define void @all_eq_gt() {
+entry:
+ br label %for.cond1.preheader
+
+for.cond1.preheader:
+ %nl.036 = phi i32 [ 0, %entry ], [ %inc23, %for.cond.cleanup3 ]
+ br label %for.cond5.preheader
+
+for.cond.cleanup3:
+ %inc23 = add nuw nsw i32 %nl.036, 1
+ %exitcond43 = icmp ne i32 %inc23, 100
+ br i1 %exitcond43, label %for.cond1.preheader, label %for.cond.cleanup
+
+for.cond.cleanup7:
+ %indvars.iv.next40 = add nuw nsw i64 %indvars.iv39, 1
+ %exitcond42 = icmp ne i64 %indvars.iv.next40, 256
+ br i1 %exitcond42, label %for.cond5.preheader, label %for.cond.cleanup3
+
+for.body8:
+ %indvars.iv = phi i64 [ 1, %for.cond5.preheader ], [ %indvars.iv.next, %for.body8 ]
+ %0 = add nsw i64 %indvars.iv, -1
+ %arrayidx10 = getelementptr inbounds [256 x [256 x float]], ptr @aa, i64 0, i64 %0, i64 %indvars.iv39
+ %1 = load float, ptr %arrayidx10, align 4
+ %arrayidx14 = getelementptr inbounds [256 x [256 x float]], ptr @bb, i64 0, i64 %indvars.iv, i64 %indvars.iv39
+ %2 = load float, ptr %arrayidx14, align 4
+ %add = fadd fast float %2, %1
+ %arrayidx18 = getelementptr inbounds [256 x [256 x float]], ptr @aa, i64 0, i64 %indvars.iv, i64 %indvars.iv39
+ store float %add, ptr %arrayidx18, align 4
+ %indvars.iv.next = add nuw nsw i64 %indvars.iv, 1
+ %exitcond = icmp ne i64 %indvars.iv.next, 256
+ br i1 %exitcond, label %for.body8, label %for.cond.cleanup7
+
+for.cond5.preheader:
+ %indvars.iv39 = phi i64 [ 0, %for.cond1.preheader ], [ %indvars.iv.next40, %for.cond.cleanup7 ]
+ br label %for.body8
+
+for.cond.cleanup:
+ ret void
+}
+
+;; for (int i=0;i<256;++i)
+;; for (int j=1;j<256;++j)
+;; aa[j][i] = aa[j-1][255-i] + bb[j][i];
+;;
+;; The direction vector of `aa` is [* >]. We cannot interchange the loops
+;; because we must handle a `*` dependence conservatively.
+
+; CHECK: Dependency matrix before interchange:
+; CHECK-NEXT: * >
+; CHECK-NEXT: Processing InnerLoopId = 1 and OuterLoopId = 0
+; CHECK-NEXT: Failed interchange InnerLoopId = 1 and OuterLoopId = 0 due to dependence
+; CHECK-NEXT: Not interchanging loops. Cannot prove legality.
+
+define void @all_gt() {
+entry:
+ br label %for.cond1.preheader
+
+for.cond1.preheader:
+ %indvars.iv31 = phi i64 [ 0, %entry ], [ %indvars.iv.next32, %for.cond.cleanup3 ]
+ %0 = sub nuw nsw i64 255, %indvars.iv31
+ br label %for.body4
+
+for.cond.cleanup3:
+ %indvars.iv.next32 = add nuw nsw i64 %indvars.iv31, 1
+ %exitcond35 = icmp ne i64 %indvars.iv.next32, 256
+ br i1 %exitcond35, label %for.cond1.preheader, label %for.cond.cleanup
+
+for.body4:
+ %indvars.iv = phi i64 [ 1, %for.cond1.preheader ], [ %indvars.iv.next, %for.body4 ]
+ %1 = add nsw i64 %indvars.iv, -1
+ %arrayidx7 = getelementptr inbounds [256 x [256 x float]], ptr @aa, i64 0, i64 %1, i64 %0
+ %2 = load float, ptr %arrayidx7, align 4
+ %arrayidx11 = getelementptr inbounds [256 x [256 x float]], ptr @bb, i64 0, i64 %indvars.iv, i64 %indvars.iv31
+ %3 = load float, ptr %arrayidx11, align 4
+ %add = fadd fast float %3, %2
+ %arrayidx15 = getelementptr inbounds [256 x [256 x float]], ptr @aa, i64 0, i64 %indvars.iv, i64 %indvars.iv31
+ store float %add, ptr %arrayidx15, align 4
+ %indvars.iv.next = add nuw nsw i64 %indvars.iv, 1
+ %exitcond = icmp ne i64 %indvars.iv.next, 256
+ br i1 %exitcond, label %for.body4, label %for.cond.cleanup3
+
+for.cond.cleanup:
+ ret void
+}
+
+;; for (int i=0;i<255;++i)
+;; for (int j=1;j<256;++j)
+;; aa[j][i] = aa[j-1][i+1] + bb[j][i];
+;;
+;; The direction vector of `aa` is [< >]. We cannot interchange the loops
+;; because the read/write order for `aa` cannot be changed.
+
+; CHECK: Dependency matrix before interchange:
+; CHECK-NEXT: < >
+; CHECK-NEXT: Processing InnerLoopId = 1 and OuterLoopId = 0
+; CHECK-NEXT: Failed interchange InnerLoopId = 1 and OuterLoopId = 0 due to dependence
+; CHECK-NEXT: Not interchanging loops. Cannot prove legality.
+
+define void @lt_gt() {
+entry:
+ br label %for.cond1.preheader
+
+for.cond1.preheader:
+ %indvars.iv31 = phi i64 [ 0, %entry ], [ %indvars.iv.next32, %for.cond.cleanup3 ]
+ %indvars.iv.next32 = add nuw nsw i64 %indvars.iv31, 1
+ br label %for.body4
+
+for.cond.cleanup3:
+ %exitcond34 = icmp ne i64 %indvars.iv.next32, 255
+ br i1 %exitcond34, label %for.cond1.preheader, label %for.cond.cleanup
+
+for.body4:
+ %indvars.iv = phi i64 [ 1, %for.cond1.preheader ], [ %indvars.iv.next, %for.body4 ]
+ %0 = add nsw i64 %indvars.iv, -1
+ %arrayidx6 = getelementptr inbounds [256 x [256 x float]], ptr @aa, i64 0, i64 %0, i64 %indvars.iv.next32
+ %1 = load float, ptr %arrayidx6, align 4
+ %arrayidx10 = getelementptr inbounds [256 x [256 x float]], ptr @bb, i64 0, i64 %indvars.iv, i64 %indvars.iv31
+ %2 = load float, ptr %arrayidx10, align 4
+ %add11 = fadd fast float %2, %1
+ %arrayidx15 = getelementptr inbounds [256 x [256 x float]], ptr @aa, i64 0, i64 %indvars.iv, i64 %indvars.iv31
+ store float %add11, ptr %arrayidx15, align 4
+ %indvars.iv.next = add nuw nsw i64 %indvars.iv, 1
+ %exitcond = icmp ne i64 %indvars.iv.next, 256
+ br i1 %exitcond, label %for.body4, label %for.cond.cleanup3
+
+for.cond.cleanup:
+ ret void
+}
+
+;; for (int i=0;i<20;i++)
+;; for (int j=0;j<20;j++)
+;; for (int k=1;k<20;k++)
+;; a[i][j][k] = a[i][5][k-1];
+;;
+;; The direction vector of `a` is [= * >]. We cannot interchange any pair of
+;; adjacent loops.
+
+; CHECK: Dependency matrix before interchange:
+; CHECK-NEXT: = * >
+; CHECK-NEXT: Processing InnerLoopId = 2 and OuterLoopId = 1
+; CHECK-NEXT: Failed interchange InnerLoopId = 2 and OuterLoopId = 1 due to dependence
+; CHECK-NEXT: Not interchanging loops. Cannot prove legality.
+; CHECK-NEXT: Processing InnerLoopId = 1 and OuterLoopId = 0
+; CHECK-NEXT: Failed interchange InnerLoopId = 1 and OuterLoopId = 0 due to dependence
+; CHECK-NEXT: Not interchanging loops. Cannot prove legality.
+
+define void @eq_all_gt() {
+entry:
+ br label %for.cond1.preheader
+
+for.cond1.preheader:
+ %indvars.iv44 = phi i64 [ 0, %entry ], [ %indvars.iv.next45, %for.cond.cleanup3 ]
+ br label %for.cond5.preheader
+
+for.cond.cleanup3:
+ %indvars.iv.next45 = add nuw nsw i64 %indvars.iv44, 1
+ %exitcond47 = icmp ne i64 %indvars.iv.next45, 20
+ br i1 %exitcond47, label %for.cond1.preheader, label %for.cond.cleanup
+
+for.cond.cleanup7:
+ %indvars.iv.next41 = add nuw nsw i64 %indvars.iv40, 1
+ %exitcond43 = icmp ne i64 %indvars.iv.next41, 20
+ br i1 %exitcond43, label %for.cond5.preheader, label %for.cond.cleanup3
+
+for.body8:
+ %indvars.iv = phi i64 [ 1, %for.cond5.preheader ], [ %indvars.iv.next, %for.body8 ]
+ %0 = add nsw i64 %indvars.iv, -1
+ %arrayidx11 = getelementptr inbounds [20 x [20 x [20 x i32]]], ptr @a, i64 0, i64 %indvars.iv44, i64 5, i64 %0
+ %1 = load i32, ptr %arrayidx11, align 4
+ %arrayidx17 = getelementptr inbounds nuw [20 x [20 x [20 x i32]]], ptr @a, i64 0, i64 %indvars.iv44, i64 %indvars.iv40, i64 %indvars.iv
+ store i32 %1, ptr %arrayidx17, align 4
+ %indvars.iv.next = add nuw nsw i64 %indvars.iv, 1
+ %exitcond = icmp ne i64 %indvars.iv.next, 20
+ br i1 %exitcond, label %for.body8, label %for.cond.cleanup7
+
+for.cond5.preheader:
+ %indvars.iv40 = phi i64 [ 0, %for.cond1.preheader ], [ %indvars.iv.next41, %for.cond.cleanup7 ]
+ br label %for.body8
+
+for.cond.cleanup:
+ ret void
+}