[PATCH] D114336: [Polly] Generalize the pattern matching to the case of tensor contractions.
Roman via Phabricator via llvm-commits
llvm-commits at lists.llvm.org
Sun Aug 7 04:27:18 PDT 2022
gareevroman marked 6 inline comments as done.
gareevroman added a comment.
In D114336#3694323 <https://reviews.llvm.org/D114336#3694323>, @Meinersbur wrote:
> Thank you Gareev. I think the description can still be improved, but we should also move forward and can improve iteratively.
>
> Looking forward to the actual TC optimization.
Thanks! I've tried to address new comments in the committed patch.
================
Comment at: polly/lib/Transform/MatmulOptimizer.cpp:199
+/// Tensor contraction (TC) of tensors A, B into tensor C can be represented as
+/// C(shuffle(I,J))=∑α·A(shuffle(I,P))·B(shuffle(P,J))+β·C(shuffle(I,J)),
+/// where ∑ is a summation over all contracted indices of P,
----------------
Meinersbur wrote:
> AFAIU multiplication by β is not part of this detection, but required to be loop-distributed by the isl scheduler.
Right, it is not part of the detection. I've added a comment about this.
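A minimal C++ sketch of one contraction of this form, with I = {i, l}, J = {j}, and P = {k} chosen arbitrarily for illustration. The sizes, type aliases, and the `contract` function are hypothetical and not part of the patch. Because the β·C term is not part of the detected pattern, the sketch keeps the β scaling in its own loop nest. That is the shape the isl scheduler would have to produce by loop distribution before the contraction itself could be matched.

```cpp
#include <array>
#include <cassert>
#include <cstddef>

// Hypothetical sizes chosen only for illustration.
constexpr std::size_t NI = 4, NJ = 3, NL = 2, NK = 5;

using TensorA = std::array<std::array<std::array<double, NL>, NK>, NI>; // A[i][k][l]
using TensorB = std::array<std::array<double, NJ>, NK>;                 // B[k][j]
using TensorC = std::array<std::array<std::array<double, NL>, NJ>, NI>; // C[i][j][l]

// C(i,j,l) = alpha * sum_k A(i,k,l) * B(k,j) + beta * C(i,j,l),
// i.e. I = {i, l}, J = {j}, P = {k} (hypothetical choice).
void contract(const TensorA &A, const TensorB &B, TensorC &C,
              double alpha, double beta) {
  // Loop nest 1: the beta * C term, kept separate from the contraction.
  for (std::size_t i = 0; i < NI; i++)
    for (std::size_t j = 0; j < NJ; j++)
      for (std::size_t l = 0; l < NL; l++)
        C[i][j][l] *= beta;

  // Loop nest 2: the contraction itself, accumulating alpha * A * B over k.
  for (std::size_t i = 0; i < NI; i++)
    for (std::size_t j = 0; j < NJ; j++)
      for (std::size_t l = 0; l < NL; l++)
        for (std::size_t k = 0; k < NK; k++)
          C[i][j][l] += alpha * A[i][k][l] * B[k][j];
}
```

Only the second loop nest corresponds to the pattern described in the comment; the first is what loop distribution would split off.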
================
Comment at: polly/lib/Transform/MatmulOptimizer.cpp:1261
+ TCI.ReadFromC = nullptr;
+ SmallVector<MemoryAccess *, 32> Accesses = getAccessesInOrder(*Stmt);
+ for (auto *MemA = Accesses.begin(); *MemA != TCI.WriteToC; MemA++) {
----------------
Meinersbur wrote:
> gareevroman wrote:
> > Meinersbur wrote:
> > > gareevroman wrote:
> > > > Meinersbur wrote:
> > > > > `getAccessesInOrder` requires `Stmt` to not be a RegionStmt. Please add a test for it.
> > > > I’ve added a check to containsOnlyTCAcc. Could you clarify what the test case should look like? Should it be a region statement that contains a matrix multiplication with the right order of memory accesses?
> > > Test in `containsOnlyTCAcc` is exactly what I was looking for. A region statement could look like this:
> > >
> > > ```
> > > c = C[i][j];
> > > if (/*non-affine condition*/) {
> > > (void)A[i][k] + B[k][j];
> > > } else {
> > > C[i][j] = c;
> > > }
> > > ```
> > > which has the correct order of accesses but is obviously not what we are looking for.
> > >
> > Thanks for the example! I have added a corresponding test case. If I am not mistaken, it requires DeLICM.
> It does not require DeLICM, but `-polly-allow-nonaffine-branches` (which is enabled by default)
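For reference, the loop in the hunk at `MatmulOptimizer.cpp:1261` walks the statement's accesses in program order up to the write to C, which is why it relies on `getAccessesInOrder` and therefore on the statement not being a RegionStmt. Below is a simplified, hypothetical model of that scan. `Kind`, `Access`, and `findReadFromCBeforeWrite` are illustrative names, not Polly API, and whether the real code keeps the first or last read of C (or rejects several) is not modeled here. Unlike a loop that only compares against the write, this one is also bounded by the list size, so it does not rely on the write being present.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-ins for Polly's MemoryAccess; not Polly's API.
enum class Kind { Read, Write };
struct Access {
  Kind K;
  char Array; // 'A', 'B', or 'C'
};

// Scans accesses in program order up to the first write to C and returns the
// index of the last read of C seen before it. Returns -1 if there is no write
// to C, or if no read of C precedes it.
int findReadFromCBeforeWrite(const std::vector<Access> &Ordered) {
  int ReadFromC = -1;
  for (std::size_t I = 0; I < Ordered.size(); I++) {
    const Access &Acc = Ordered[I];
    if (Acc.Array != 'C')
      continue;
    if (Acc.K == Kind::Write)
      return ReadFromC;              // reached the write to C
    ReadFromC = static_cast<int>(I); // a read of C before the write
  }
  return -1;                         // no write to C at all
}
```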
If I'm not mistaken, in your example the form of the dependencies doesn't correspond to the pattern.
```
c = C[i][j];
if (/*non-affine condition*/) {
A[i][k] + B[k][j];
} else {
C[i][j] = c;
}
```
In that example, the write to C is only a may-write:
```
MayWrite: { Stmt_for_body8__TO__for_inc[i0, i1, i2] -> MemRef_C[i0, i1] : 0 <= i0 <= 31 and 0 <= i1 <= 31 and 0 <= i2 <= 31 }
```
I've added a slightly modified version of it to polly/test/ScheduleOptimizer/pattern-matching-based-opts_23.ll. It produces a region statement too.
```
for (int i = 0; i < 32; i++)
for (int j = 0; j < 32; j++)
for (int k = 0; k < 32; k++) {
int c = C[i][j];
if (i*j*k < 10) {
C[i][j] = A[i][k] + B[k][j];
} else {
C[i][j] = c;
}
}
```
However, it introduces store-merge PHI nodes (the `storemerge` scalar below), which makes DeLICM necessary.
```
Statements {
Stmt_for_body8__TO__if_end
Domain :=
{ Stmt_for_body8__TO__if_end[i0, i1, i2] : 0 <= i0 <= 31 and 0 <= i1 <= 31 and 0 <= i2 <= 31 };
Schedule :=
{ Stmt_for_body8__TO__if_end[i0, i1, i2] -> [i0, i1, i2, 0] };
ReadAccess := [Reduction Type: NONE] [Scalar: 0]
{ Stmt_for_body8__TO__if_end[i0, i1, i2] -> MemRef_A[i0, i2] };
ReadAccess := [Reduction Type: NONE] [Scalar: 0]
{ Stmt_for_body8__TO__if_end[i0, i1, i2] -> MemRef_B[i2, i1] };
ReadAccess := [Reduction Type: NONE] [Scalar: 0]
{ Stmt_for_body8__TO__if_end[i0, i1, i2] -> MemRef_C[i0, i1] };
MustWriteAccess := [Reduction Type: NONE] [Scalar: 1]
{ Stmt_for_body8__TO__if_end[i0, i1, i2] -> MemRef_storemerge__phi[] };
new: { Stmt_for_body8__TO__if_end[i0, i1, i2] -> MemRef_C[i0, i1] };
Stmt_if_end
Domain :=
{ Stmt_if_end[i0, i1, i2] : 0 <= i0 <= 31 and 0 <= i1 <= 31 and 0 <= i2 <= 31 };
Schedule :=
{ Stmt_if_end[i0, i1, i2] -> [i0, i1, i2, 1] };
ReadAccess := [Reduction Type: NONE] [Scalar: 1]
{ Stmt_if_end[i0, i1, i2] -> MemRef_storemerge__phi[] };
new: { Stmt_if_end[i0, i1, i2] -> MemRef_C[i0, i1] };
MustWriteAccess := [Reduction Type: NONE] [Scalar: 0]
{ Stmt_if_end[i0, i1, i2] -> MemRef_C[i0, i1] };
```
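The `new:` lines in the output above show the scalar `storemerge` PHI being mapped onto `C[i0, i1]`. At the source level, the effect is roughly the difference between the two functions below. This is a hand-written sketch, not output produced by Polly: `beforeDeLICM` and `afterDeLICM` are hypothetical names, it assumes the 32×32×32 `int` arrays of the test, and `beforeDeLICM` models the store-merged form rather than the test's original source.

```cpp
#include <array>
#include <cassert>

constexpr int N = 32;
using Mat = std::array<std::array<int, N>, N>;

// Store-merged form: the region statement produces a scalar (the storemerge
// PHI) and Stmt_if_end stores it to C.
void beforeDeLICM(const Mat &A, const Mat &B, Mat &C) {
  for (int i = 0; i < N; i++)
    for (int j = 0; j < N; j++)
      for (int k = 0; k < N; k++) {
        // Stmt_for_body8__TO__if_end: reads C, A, B; writes the scalar.
        int c = C[i][j];
        int storemerge = (i * j * k < 10) ? A[i][k] + B[k][j] : c;
        // Stmt_if_end: reads the scalar and writes C.
        C[i][j] = storemerge;
      }
}

// After mapping: the scalar write and read both use C[i][j] directly, so no
// scalar dependency between the two statements remains.
void afterDeLICM(const Mat &A, const Mat &B, Mat &C) {
  for (int i = 0; i < N; i++)
    for (int j = 0; j < N; j++)
      for (int k = 0; k < N; k++) {
        // Stmt_for_body8__TO__if_end: PHI write mapped to C[i][j].
        int c = C[i][j];
        C[i][j] = (i * j * k < 10) ? A[i][k] + B[k][j] : c;
        // Stmt_if_end: PHI read mapped to C[i][j], then the store to C.
        int merged = C[i][j];
        C[i][j] = merged;
      }
}
```

Both versions compute the same result; the mapping only removes the scalar as a separate memory location.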
Repository:
rG LLVM Github Monorepo
CHANGES SINCE LAST ACTION
https://reviews.llvm.org/D114336/new/
https://reviews.llvm.org/D114336