[PATCH] D114336: [Polly] Generalize the pattern matching to the case of tensor contractions.

Roman via Phabricator via llvm-commits llvm-commits at lists.llvm.org
Sun Aug 7 04:27:18 PDT 2022


gareevroman marked 6 inline comments as done.
gareevroman added a comment.

In D114336#3694323 <https://reviews.llvm.org/D114336#3694323>, @Meinersbur wrote:

> Thank you Gareev. I think the description can still be improved, but we should also move forward and can improve iteratively.
>
> Looking forward to the actual TC optimization.

Thanks! I've tried to address the new comments in the committed patch.



================
Comment at: polly/lib/Transform/MatmulOptimizer.cpp:199
+/// Tensor contraction (TC) of tensors A, B into tensor C can be represented as
+/// C(shuffle(I,J))=∑α·A(shuffle(I,P))·B(shuffle(P,J))+β·C(shuffle(I,J)),
+/// where ∑ is a summation over all contracted indices of P,
----------------
Meinersbur wrote:
> AFAIU multiplication by β is not part of this detection, but required to be loop-distributed by the isl scheduler.
Right, it is not part of the detection. I've added a comment saying so.
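For illustration, here is a minimal C sketch of a contraction of this form, with I = {i}, J = {j1, j2} and P = {p}. The extents, the array layouts and the function name `tc_example` are assumptions chosen only to show a non-trivial index shuffle. The β·C scaling is written as its own loop nest, which is the loop-distributed form the isl scheduler would have to produce. Only the second loop nest, which accumulates α·A·B into C, has the shape the pattern describes.

```c
#include <stddef.h>

/* Illustrative extents (assumptions, not taken from the patch). */
#define NI 2
#define NJ1 3
#define NJ2 2
#define NP 4

/* Tensor contraction with I = {i}, J = {j1, j2}, P = {p}:
 *   C[j2][i][j1] = sum_p alpha * A[p][i] * B[j1][p][j2] + beta * C[j2][i][j1]
 * The index order inside each array is an arbitrary permutation
 * ("shuffle") of that array's index sets. */
void tc_example(double alpha, double beta, double A[NP][NI],
                double B[NJ1][NP][NJ2], double C[NJ2][NI][NJ1]) {
  /* beta * C: a separate statement, not part of the detected pattern. */
  for (size_t j2 = 0; j2 < NJ2; j2++)
    for (size_t i = 0; i < NI; i++)
      for (size_t j1 = 0; j1 < NJ1; j1++)
        C[j2][i][j1] *= beta;

  /* Accumulation over the contracted index p: reads of A, B and C and a
   * write to C, i.e. the kind of statement the matcher looks for. */
  for (size_t i = 0; i < NI; i++)
    for (size_t j1 = 0; j1 < NJ1; j1++)
      for (size_t j2 = 0; j2 < NJ2; j2++)
        for (size_t p = 0; p < NP; p++)
          C[j2][i][j1] += alpha * A[p][i] * B[j1][p][j2];
}
```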


================
Comment at: polly/lib/Transform/MatmulOptimizer.cpp:1261
+  TCI.ReadFromC = nullptr;
+  SmallVector<MemoryAccess *, 32> Accesses = getAccessesInOrder(*Stmt);
+  for (auto *MemA = Accesses.begin(); *MemA != TCI.WriteToC; MemA++) {
----------------
Meinersbur wrote:
> gareevroman wrote:
> > Meinersbur wrote:
> > > gareevroman wrote:
> > > > Meinersbur wrote:
> > > > > `getAccessesInOrder` requires `Stmt` to not be a RegionStmt. Please add a test for it.
> > > > I've added a check to containsOnlyTCAcc. Could you clarify what the test case should look like? Should it be a region statement that contains a matrix multiplication with the right order of memory accesses?
> > > Test in `containsOnlyTCAcc` is exactly what I was looking for. A region statement could look like this:
> > > 
> > > ```
> > > c = C[i][j];
> > > if (/*non-affine condition*/) {
> > >   (void)A[i][k] + B[k][j];
> > > } else {
> > >   C[i][j] = c;
> > > }
> > > ```
> > > which has the correct order of accesses but is obviously not what we are looking for.
> > > 
> > Thanks for the example! I have added a corresponding test case. If I am not mistaken, it requires DeLICM.
> It does not require DeLICM, but `-polly-allow-nonaffine-branches` (which is enabled by default)
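The ordering requirement discussed above can be illustrated with a deliberately simplified model. The types `Access` and `Stmt` and the function `find_read_from_c` are hypothetical and not part of Polly. The sketch only mirrors the shape of the loop over `getAccessesInOrder` that stops at `TCI.WriteToC`, and it gives up on region statements, for which this sketch assumes no usable access order exists.

```c
#include <stdbool.h>
#include <stddef.h>

/* Hypothetical model; these types do not exist in Polly. */
typedef enum { ACC_READ, ACC_WRITE } AccessKind;

typedef struct {
  AccessKind kind;
  int array_id; /* identifies the accessed array, e.g. A = 0, B = 1, C = 2 */
} Access;

typedef struct {
  bool is_region;         /* region statement: no linear access order */
  const Access *accesses; /* accesses in program order */
  size_t num_accesses;
} Stmt;

/* Scan the accesses that precede the write to C (at index write_idx) and
 * return the index of the first read from array c_id. Returns -1 if there
 * is no such read, write_idx is out of range, or stmt is a region
 * statement. */
long find_read_from_c(const Stmt *stmt, size_t write_idx, int c_id) {
  if (stmt->is_region || write_idx >= stmt->num_accesses)
    return -1;
  for (size_t i = 0; i < write_idx; i++) {
    const Access *a = &stmt->accesses[i];
    if (a->kind == ACC_READ && a->array_id == c_id)
      return (long)i;
  }
  return -1;
}
```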
If I'm not mistaken, the dependencies in your example don't have the form the pattern expects.

```
c = C[i][j];
if (/*non-affine condition*/) {
  A[i][k] + B[k][j];
} else {
  C[i][j] = c;
}
```

Since C[i][j] is written in only one branch, the write to C is just a may-write:

MayWrite: { Stmt_for_body8__TO__for_inc[i0, i1, i2] -> MemRef_C[i0, i1] : 0 <= i0 <= 31 and 0 <= i1 <= 31 and 0 <= i2 <= 31 }

I've added a slightly modified version of it to polly/test/ScheduleOptimizer/pattern-matching-based-opts_23.ll. It produces a region statement too.

```
for (int i = 0; i < 32; i++)
  for (int j = 0; j < 32; j++)
    for (int k = 0; k < 32; k++) {
      int c = C[i][j];
      if (i*j*k < 10) {
        C[i][j] = A[i][k] + B[k][j];
      } else {
        C[i][j] = c;
      }
    }
```

However, the two stores to C[i][j] are merged into a single store of a PHI node (storemerge), which makes DeLICM necessary.

```
    Statements {
    	Stmt_for_body8__TO__if_end
            Domain :=
                { Stmt_for_body8__TO__if_end[i0, i1, i2] : 0 <= i0 <= 31 and 0 <= i1 <= 31 and 0 <= i2 <= 31 };
            Schedule :=
                { Stmt_for_body8__TO__if_end[i0, i1, i2] -> [i0, i1, i2, 0] };
            ReadAccess :=	[Reduction Type: NONE] [Scalar: 0]
                { Stmt_for_body8__TO__if_end[i0, i1, i2] -> MemRef_A[i0, i2] };
            ReadAccess :=	[Reduction Type: NONE] [Scalar: 0]
                { Stmt_for_body8__TO__if_end[i0, i1, i2] -> MemRef_B[i2, i1] };
            ReadAccess :=	[Reduction Type: NONE] [Scalar: 0]
                { Stmt_for_body8__TO__if_end[i0, i1, i2] -> MemRef_C[i0, i1] };
            MustWriteAccess :=	[Reduction Type: NONE] [Scalar: 1]
                { Stmt_for_body8__TO__if_end[i0, i1, i2] -> MemRef_storemerge__phi[] };
           new: { Stmt_for_body8__TO__if_end[i0, i1, i2] -> MemRef_C[i0, i1] };
    	Stmt_if_end
            Domain :=
                { Stmt_if_end[i0, i1, i2] : 0 <= i0 <= 31 and 0 <= i1 <= 31 and 0 <= i2 <= 31 };
            Schedule :=
                { Stmt_if_end[i0, i1, i2] -> [i0, i1, i2, 1] };
            ReadAccess :=	[Reduction Type: NONE] [Scalar: 1]
                { Stmt_if_end[i0, i1, i2] -> MemRef_storemerge__phi[] };
           new: { Stmt_if_end[i0, i1, i2] -> MemRef_C[i0, i1] };
            MustWriteAccess :=	[Reduction Type: NONE] [Scalar: 0]
                { Stmt_if_end[i0, i1, i2] -> MemRef_C[i0, i1] };
```
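The store-merge PHI in the dump above can be understood through a source-level sketch. This is an approximation, not the actual output of any LLVM or Polly pass: the two conditional stores to C[i][j] become one unconditional store of a value chosen in each branch. The function names `original` and `store_merged` are hypothetical, and the `int` element type is assumed from `int c = C[i][j]` in the test source.

```c
#include <stddef.h>

#define N 32

/* The loop nest from the test, as written in the source. */
void original(int A[N][N], int B[N][N], int C[N][N]) {
  for (int i = 0; i < N; i++)
    for (int j = 0; j < N; j++)
      for (int k = 0; k < N; k++) {
        int c = C[i][j];
        if (i * j * k < 10)
          C[i][j] = A[i][k] + B[k][j];
        else
          C[i][j] = c;
      }
}

/* Approximately the shape after the stores are sunk and merged: one store
 * whose value comes from a PHI (modeled by the local `storemerge`). In the
 * dump this value is the scalar MemRef_storemerge__phi, written by
 * Stmt_for_body8__TO__if_end and read by Stmt_if_end; DeLICM's "new:"
 * lines map it onto MemRef_C[i0, i1]. */
void store_merged(int A[N][N], int B[N][N], int C[N][N]) {
  for (int i = 0; i < N; i++)
    for (int j = 0; j < N; j++)
      for (int k = 0; k < N; k++) {
        int c = C[i][j];
        int storemerge;
        if (i * j * k < 10)
          storemerge = A[i][k] + B[k][j];
        else
          storemerge = c;
        C[i][j] = storemerge;
      }
}
```

Both versions compute the same result; the difference is only in how the stores are represented, which is what forces a scalar dependence that DeLICM has to remove.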


Repository:
  rG LLVM Github Monorepo

CHANGES SINCE LAST ACTION
  https://reviews.llvm.org/D114336/new/

https://reviews.llvm.org/D114336


