[PATCH] D114336: [Polly] Generalize the pattern matching to the case of tensor contractions.

Roman via Phabricator via llvm-commits llvm-commits at lists.llvm.org
Mon May 9 08:06:11 PDT 2022


gareevroman marked 16 inline comments as done.
gareevroman added a comment.

1.

> The following is successfully detected as tensor contraction. Is this intended?



> void foo(double C[1024][1024], double A[1024][64][64], double B[64][1024][64]) {
>
>   for (int i = 0; i < 1024; i++)
>       for (int j = 0; j < 1024; j++)
>         for (int l = 0; l < 64; ++l)
>            if (l != 0)
>              for (int w = 0; w < 64; ++w)
>                C[i][j] += A[i][l][w] * B[w][j][l];
>
> }

Yes, it was intended. The transformation helps to optimize a class of programs that is broader than tensor contractions. However, it heavily depends on the codegen part. Improving the detection could be a goal of future work.

> It might be if the codegen part is able exclude the element 0. In contrast, this one is rejected:

In this case, the codegen excludes the element 0 for i2. I added a test case for this.

domain: "{ Stmt4[i0, i1, i2, i3] : 0 <= i0 <= 1023 and 0 <= i1 <= 1023 and 0 < i2 <= 63 and 0 <= i3 <= 63 }"
…

> In contrast, this one is rejected:



> void foo(int n, double C[1024][1024], double A[1024][64][64], double B[64][1024][64]) {
>
>   for (int i = 0; i < 1024; i++)
>       for (int j = 0; j < 1024; j++)
>         for (int l = 0; l < 64; l++)
>           for (int w = 0; w < 64; ++w)
>              if (w != 0)
>                C[i][j] += A[i][l][w] * B[w][j][l];
>
> }



> or this:



> void foo(int n, double C[1024][1024], double A[1024][64][64], double B[64][1024][64]) {
>
>   for (int i = 0; i < 1024; i++)
>       for (int j = 0; j < 1024; j++)
>         for (int l = 0; l < 64; l+=2)
>           for (int w = 0; w < 64; ++w)
>                C[i][j] += A[i][l][w] * B[w][j][l];
>
> }

As far as I know, in these cases the codegen modifies some memory accesses, so they no longer correspond to the current pattern. For example, in the first case the accesses look as follows:

  ReadAccess :=	[Reduction Type: NONE] [Scalar: 0]
      { Stmt3[i0, i1, i2, i3] -> MemRef0[i0, i2, 1 + i3] };
  ReadAccess :=	[Reduction Type: NONE] [Scalar: 0]
      { Stmt3[i0, i1, i2, i3] -> MemRef1[1 + i3, i1, i2] };
  ReadAccess :=	[Reduction Type: +] [Scalar: 0]
      { Stmt3[i0, i1, i2, i3] -> MemRef2[i0, i1] };
  MustWriteAccess :=	[Reduction Type: +] [Scalar: 0]
      { Stmt3[i0, i1, i2, i3] -> MemRef2[i0, i1] };
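A minimal C++ sketch of where the `1 + i3` subscripts above can come from, assuming (for illustration only, not as a description of Polly's actual pipeline) that the `if (w != 0)` guard is turned into the lower bound `w = 1` and the loop is then normalized to start at 0, i.e. `w = 1 + i3`. Both functions compute the same result; only the second one has the shifted subscripts of MemRef0 and MemRef1. The extents and function names are hypothetical.

```cpp
#include <cassert>
#include <cstddef>

// Small hypothetical extents instead of 1024/64 to keep the example cheap.
constexpr std::size_t NI = 4, NJ = 4, NK = 8;

// Original loop nest: the guard skips w == 0.
void guarded(double C[NI][NJ], const double A[NI][NK][NK],
             const double B[NK][NJ][NK]) {
  for (std::size_t i = 0; i < NI; i++)
    for (std::size_t j = 0; j < NJ; j++)
      for (std::size_t l = 0; l < NK; l++)
        for (std::size_t w = 0; w < NK; ++w)
          if (w != 0)
            C[i][j] += A[i][l][w] * B[w][j][l];
}

// Assumed normalized form: i3 = w - 1 starts at 0, so the subscripts
// read A[i0][i2][1 + i3] and B[1 + i3][i1][i2], as in the accesses above.
void normalized(double C[NI][NJ], const double A[NI][NK][NK],
                const double B[NK][NJ][NK]) {
  for (std::size_t i0 = 0; i0 < NI; i0++)
    for (std::size_t i1 = 0; i1 < NJ; i1++)
      for (std::size_t i2 = 0; i2 < NK; i2++)
        for (std::size_t i3 = 0; i3 + 1 < NK; ++i3)
          C[i0][i1] += A[i0][i2][1 + i3] * B[1 + i3][i1][i2];
}

// Runs both variants on the same small-integer inputs, asserts that they
// agree and returns the number of compared elements.
std::size_t checkEquivalence() {
  static double A[NI][NK][NK], B[NK][NJ][NK];
  double C1[NI][NJ] = {}, C2[NI][NJ] = {};
  for (std::size_t x = 0; x < NI; ++x)
    for (std::size_t y = 0; y < NK; ++y)
      for (std::size_t z = 0; z < NK; ++z)
        A[x][y][z] = static_cast<double>((x + 2 * y + 3 * z) % 7);
  for (std::size_t x = 0; x < NK; ++x)
    for (std::size_t y = 0; y < NJ; ++y)
      for (std::size_t z = 0; z < NK; ++z)
        B[x][y][z] = static_cast<double>((3 * x + y + 2 * z) % 5);
  guarded(C1, A, B);
  normalized(C2, A, B);
  std::size_t Compared = 0;
  for (std::size_t i = 0; i < NI; ++i)
    for (std::size_t j = 0; j < NJ; ++j) {
      assert(C1[i][j] == C2[i][j] && "both loop nests must agree");
      ++Compared;
    }
  return Compared;
}
```

The shifted form is equivalent, but because the subscripts are no longer plain induction variables, it no longer matches the pattern as currently written.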

2.

> I do get an assertion failure with this one:



> void foo(double C[64][64], double A[64][64][64], double B[64][64][64]) {
>
>   for (int i = 0; i < 32; i++)
>       for (int j = 0; j < 32; j++)
>         for (int l = 0; l < 32; l++)
>           for (int w = 0; w < 32; ++w)
>                C[i][j] += A[i][l][w] * B[w][j][i+3];
>
> }

I fixed the isTCOperandAcc function and checked that all other asserts are used properly.

3.

> Here, i occurs as indices for A, B, and C and detected as TC. Is this supported?



> void foo(double C[64][64], double A[64][64][64], double B[64][64][64]) {
>
>   for (int i = 0; i < 32; i++)
>   for (int j = 0; j < 32; j++)
>   for (int l = 0; l < 32; l++)
>   for (int w = 0; w < 32; w++)
>       C[i][j] += A[i][l][w] * B[w][j][i];
>
> }

For some reason, I cannot reproduce that. I have added a corresponding test case. As far as I understand, this case should be caught (and rejected) by the check on line 1365:

1347  static bool containsOnlyTCAcc(isl::set Domain, isl::map PartialSchedule,
                                    TCInfoTy &TCI) {
  ...

1365  if (intersect(IandJIndexSet, TCI.P).size() != 0)
1366    return false;
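A simplified C++ sketch of the idea behind the check on line 1365, under hypothetical assumptions: each operand is described only by the names of the induction variables in its subscripts, and the bundles are derived as I = A intersect C, J = B intersect C, and P = A intersect B. Under these assumptions, the reviewer's example (where i also indexes B) is rejected, while the first example in this thread is accepted. The actual implementation works on isl maps and may derive the bundles differently; all names below are illustrative.

```cpp
#include <algorithm>
#include <cassert>
#include <iterator>
#include <set>
#include <string>

// Hypothetical model: an operand is the set of induction-variable names
// that appear in its subscripts.
using IndexSet = std::set<std::string>;

static IndexSet intersect(const IndexSet &X, const IndexSet &Y) {
  IndexSet R;
  std::set_intersection(X.begin(), X.end(), Y.begin(), Y.end(),
                        std::inserter(R, R.end()));
  return R;
}

static IndexSet unite(const IndexSet &X, const IndexSet &Y) {
  IndexSet R = X;
  R.insert(Y.begin(), Y.end());
  return R;
}

// Assumed bundle derivation for C += A * B. Returns false (reject) if an
// index belongs both to I or J and to the contracted bundle P, in the
// spirit of `intersect(IandJIndexSet, TCI.P).size() != 0`.
bool hasDisjointBundles(const IndexSet &C, const IndexSet &A,
                        const IndexSet &B) {
  const IndexSet IandJ = unite(intersect(A, C), intersect(B, C));
  const IndexSet P = intersect(A, B);
  return intersect(IandJ, P).empty();
}
```

For `C[i][j] += A[i][l][w] * B[w][j][i]`, P becomes {i, w}, which overlaps with I and J on i, so the statement is rejected.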

4.

I think that it is redundant to require that bands be marked as permutable, since we already check the form of the dependencies and memory accesses. I propose removing such checks from the pattern-matching optimizations.



================
Comment at: polly/lib/Transform/MatmulOptimizer.cpp:1093
+  isl::map PossibleTensor = isl::manage(Universe.copy());
+  for (int i = 0; i < static_cast<int>(Dimensions.size()); i++) {
+    const int InPos = Dimensions[i];
----------------
Meinersbur wrote:
> gareevroman wrote:
> > Meinersbur wrote:
> > > or introduce `intFromIslSize`.
> > As far I understand, Dimensions.size() returns a value of type size_t instead of a value of the type isl_size. So, in the new version I used the unsigned type to avoid the cast.
> `rangeIslSize` should make it easier.
I think rangeIslSize can’t be used in this case. However, I’ve tried to use rangeIslSize elsewhere to improve the patch.


================
Comment at: polly/lib/Transform/MatmulOptimizer.cpp:1261
+  TCI.ReadFromC = nullptr;
+  SmallVector<MemoryAccess *, 32> Accesses = getAccessesInOrder(*Stmt);
+  for (auto *MemA = Accesses.begin(); *MemA != TCI.WriteToC; MemA++) {
----------------
Meinersbur wrote:
> gareevroman wrote:
> > Meinersbur wrote:
> > > `getAccessesInOrder` requires `Stmt` to not be a RegionStmt. Please add a test for it.
> > I’ve added a check to containsOnlyTCAcc. Could you clarify how the test case should look like? Should it be a region statement that contains a matrix multiplication with right order of memory accesses?
> Test in `containsOnlyTCAcc` is exactly what I was looking for. A region statement could look like this:
> 
> ```
> c = C[i][j];
> if (/*non-affine condition*/) {
>   (void)A[i][k] + B[k][j];
> } else {
>   C[i][j] = c;
> }
> ```
> which has the correct order of accesses but is obviously not what we are looking for.
> 
Thanks for the example! I have added a corresponding test case. If I am not mistaken, it requires DeLICM.


================
Comment at: polly/lib/Transform/MatmulOptimizer.cpp:1136
+  // words, it is a Partial write and Partial writes must be rejected.
+  return AccMap.is_equal(PossibleTensor);
+}
----------------
Meinersbur wrote:
> I like the idea of verifying the correctness by reconstructing and comparing to the original.
> 
> Maybe do it at the end to verify that the entire `TCInfoTy` is correct? On the other size, earlier fail would be better. What do you think?
Other parts of TCInfoTy are verified in isTCOperandAcc too. I think it is better to verify related information in one place, and as early as possible.

An early failure would probably simplify debugging, since after the check we know the exact form of the memory accesses and can rely on it. It may also improve performance, since failing early avoids additional set operations.


================
Comment at: polly/lib/Transform/MatmulOptimizer.cpp:1219
+  for (MemoryAccess *MemA : reverse(Accesses)) {
+    if (!MemA->isLatestArrayKind())
+      continue;
----------------
Meinersbur wrote:
> I can use JScopImport to set a scalar memory access to a partial write without adding additional dependencies; that is, I don't think this can be just ignored.
> 
> I suggest to have a single function that calls `getAccessesInOrder` and sort out which MemAccess is read/write in there, then analyze them.
I am not sure whether variants of tensor contractions that contain scalar read and write memory accesses are relevant in practice.

Moreover, since the bundles of induction variables I, J, and P can contain an arbitrary number of dimensions, we may not be able to follow the approach of the containsOnlyMatrMultAcc function, which permutes dimensions and checks that additional memory accesses have stride 0 with respect to the dimensions MMM.i, MMM.j, and MMM.k, so that such accesses can be treated as scalar memory accesses. I have not come up with an effective alternative yet.

That is why I do not consider scalar memory accesses in the getWriteAccess and setReadAccesses functions. Could we mark this as a TODO and address it in the future?


================
Comment at: polly/lib/Transform/MatmulOptimizer.cpp:1395
+///         and false, otherwise.
+static bool isTcDep(isl::set DepDelta, unsigned Pos, isl::set BoundDeltas,
+                    SmallDenseSet<int> *IndexSet) {
----------------
Meinersbur wrote:
> This seems to check setermine whether there is a reduction (contraction) carried by loop number `Pos`. The function name could be more meaningful. Suggestion: `isDepCarryingReductionOverDim` (not nice, but "TcDep" can mean anything)
Could we name it isReductionCarriedOverDim? In that case, I think we should also rename the parameter Pos to Dim for readability.


================
Comment at: polly/lib/Transform/MatmulOptimizer.cpp:1410-1412
+  isl::map DepDeltaNegToBoundDeltas = isl::map::from_domain_and_range(
+      isl::manage(isl_set_neg(DepDelta.copy())), BoundDeltas);
+  isl::set Complement = DepDeltaNegToBoundDeltas.deltas();
----------------
Meinersbur wrote:
> Why not `BoundDeltas.subtract()` instead of deltas?
As far as I understand, these operations are not equivalent.

deltas computes the set of differences between each image element and the corresponding domain element of the input map, whereas subtract computes the set difference.

For example, for the following sets they produce:

BoundDeltas : {Stmt_for_body15[31, 31, 31, 31, 31, 31] }
isl::manage(isl_set_neg(DepDelta.copy())): {Stmt_for_body15[0, 0, 0, 0, 0, -1]}

BoundDeltas.subtract(isl::manage(isl_set_neg(DepDelta.copy()))) : {Stmt_for_body15[31, 31, 31, 31, 31, 31]}
deltas: {Stmt_for_body15[31, 31, 31, 31, 31, 32]}

BoundDeltas : {Stmt_for_body15[31, 31, 31, 31, 31, 31]}
isl::manage(isl_set_neg(DepDelta.copy())): {Stmt_for_body15[0, 0, 0, -1, 0, 31]}

BoundDeltas.subtract(isl::manage(isl_set_neg(DepDelta.copy()))) : {Stmt_for_body15[31, 31, 31, 31, 31, 31]}
deltas: {Stmt_for_body15[31, 31, 31, 32, 31, 0]}

This comment overlaps with the comment about pw_multi_aff. Consequently, I replaced the use of isl_map_deltas with operations on pw_multi_aff.


================
Comment at: polly/lib/Transform/MatmulOptimizer.cpp:1451-1454
+  isl::union_map Dep = D->getDependences(Dependences::TYPE_RAW);
+  isl::union_map Red = D->getDependences(Dependences::TYPE_RED);
+  if (!Red.is_null())
+    Dep = Dep.unite(Red);
----------------
Meinersbur wrote:
> Should we also check whether WAW, RAW dependences are incompatible?
As far as I understand, that is not necessary, because we subsequently check that the statement has the form C(shuffle(I, J)) = E(A(shuffle(I, P)), B(shuffle(P, J)), C(shuffle(I, J))), where E is an expression that contains reads from the tensors A, B, and C, and an arbitrary number of reads from constants with respect to the bundles I, J, and P.

I have added a comment that describes this.

"The form of anti and output dependencies is determined by the form of the SCoP statement, which is checked by subsequent analysis."
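An illustrative C++ sketch of one instance of this form, assuming the first example in this thread, C[i][j] += A[i][l][w] * B[w][j][l], i.e. I = {i}, J = {j}, P = {l, w}, with E being a multiply-add. It shows one way to see why matmul-style optimizations can apply: after fusing the bundle P into a single index and permuting ("unshuffling") B, the contraction becomes an ordinary matrix multiplication. The extents and function names are hypothetical.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical extents: |I| = NI, |J| = NJ, and P = {l, w} with extents NL, NW.
constexpr std::size_t NI = 3, NJ = 2, NL = 2, NW = 3;
using Tensor = std::vector<double>; // row-major storage

// Direct evaluation of C[i][j] += A[i][l][w] * B[w][j][l]
// with A: NI x NL x NW, B: NW x NJ x NL, C: NI x NJ.
Tensor contract(const Tensor &A, const Tensor &B) {
  assert(A.size() == NI * NL * NW && B.size() == NW * NJ * NL);
  Tensor C(NI * NJ, 0.0);
  for (std::size_t i = 0; i < NI; ++i)
    for (std::size_t j = 0; j < NJ; ++j)
      for (std::size_t l = 0; l < NL; ++l)
        for (std::size_t w = 0; w < NW; ++w)
          C[i * NJ + j] += A[(i * NL + l) * NW + w] * B[(w * NJ + j) * NL + l];
  return C;
}

// The same contraction as a matrix multiplication: fuse P into one index
// p = l * NW + w. A is then already an NI x (NL * NW) matrix, while B has
// to be permuted into an (NL * NW) x NJ matrix first.
Tensor contractViaGemm(const Tensor &A, const Tensor &B) {
  assert(A.size() == NI * NL * NW && B.size() == NW * NJ * NL);
  const std::size_t NP = NL * NW;
  Tensor Bm(NP * NJ);
  for (std::size_t l = 0; l < NL; ++l)
    for (std::size_t w = 0; w < NW; ++w)
      for (std::size_t j = 0; j < NJ; ++j)
        Bm[(l * NW + w) * NJ + j] = B[(w * NJ + j) * NL + l];
  Tensor C(NI * NJ, 0.0);
  for (std::size_t i = 0; i < NI; ++i)
    for (std::size_t j = 0; j < NJ; ++j)
      for (std::size_t p = 0; p < NP; ++p)
        C[i * NJ + j] += A[i * NP + p] * Bm[p * NJ + j];
  return C;
}
```

Both functions accumulate the same products in the same order, so their results are identical.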


================
Comment at: polly/lib/Transform/MatmulOptimizer.cpp:1460-1461
+  unsigned DeltasDimNum = unsignedFromIslSize(DepDeltas.dim(isl::dim::set));
+  isl::pw_multi_aff LowerBound = Domain.lexmin_pw_multi_aff();
+  isl::pw_multi_aff BoundSub = Domain.lexmax_pw_multi_aff().sub(LowerBound);
+  auto BoundDeltas = isl::manage(isl_set_from_pw_multi_aff(BoundSub.release()));
----------------
Meinersbur wrote:
> lexmin/lexmax can be expensive. Wrap into a `IslMaxOperationsGuard`?
What maximum number of computational steps should we use by default? I set it to 500000, following DependenceInfo.cpp.


================
Comment at: polly/lib/Transform/MatmulOptimizer.cpp:1465-1470
+    // In the case of the tensor contraction, the difference between image
+    // elements and domain elements lies on a hyperplane where a dimension
+    // has the fixed value one.
+    isl::set Intersection = DepDeltas.fix_si(isl::dim::set, i, 1);
+    if (Intersection.is_empty())
+      continue;
----------------
Meinersbur wrote:
> This is going to check whether each element out of `Intersection` is a contraction over dimension `i`. Don't we also need to check that every iteration out of the band `i` is contributing to that contraction?
Could you clarify what you mean by the band i? Do you mean the indices ki that describe the dependencies?

isTcDep checks that the dependency has the form

/// S(..., ki, max(k(i + 1)), ..., max(kn), ...) -> S(..., ki + 1, min(k(i + 1)), ..., min(kn), …)
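A minimal C++ sketch of the dependence form above, evaluated on a single dependence distance vector (target minus source). This is not the patch's isTcDep, which operates on isl sets. The function name follows the rename suggested above, and the WrapDims parameter is a hypothetical stand-in for the dimensions k(i + 1), ..., kn that must wrap from their maximum to their minimum; all other dimensions must stay fixed.

```cpp
#include <cassert>
#include <cstddef>
#include <set>
#include <vector>

// Checks whether Delta has the form
//   (0, ..., 0, 1, min(k(i+1)) - max(k(i+1)), ..., min(kn) - max(kn), 0, ...)
// where position Dim carries the reduction.
// BoundDeltas[k] = max(kk) - min(kk) for every dimension.
bool isReductionCarriedOverDim(const std::vector<long> &Delta, std::size_t Dim,
                               const std::vector<long> &BoundDeltas,
                               const std::set<std::size_t> &WrapDims) {
  assert(Delta.size() == BoundDeltas.size() && Dim < Delta.size());
  for (std::size_t K = 0; K < Delta.size(); ++K) {
    long Expected = 0; // dimensions outside the reduction do not move
    if (K == Dim) {
      Expected = 1; // the carrying dimension advances by one
    } else if (WrapDims.count(K)) {
      if (K < Dim)
        return false; // only inner dimensions may wrap around
      Expected = -BoundDeltas[K]; // jump from max(kk) back to min(kk)
    }
    if (Delta[K] != Expected)
      return false;
  }
  return true;
}
```

For example, with bounds 0..31 in every dimension, the distance (0, 1, -31) describes a reduction carried over dimension 1 whose inner dimension 2 wraps around.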


================
Comment at: polly/lib/Transform/MatmulOptimizer.cpp:1602
+  Node = Node.child(0);
+  auto LeafType = isl_schedule_node_get_type(Node.get());
+  isl::union_map PartialSchedule = Node.get_prefix_schedule_union_map();
----------------
Meinersbur wrote:
> Prefer `Node.isa<isl::schedule_node_leaf>()` (and then typed subclass: `Node.as<isl_schedule_node_leaf>()`)
Could we factor out this condition into ScheduleTreeOptimizer::isPMOptimizableBandNode, since it is shared by the isTCPattern and isMatrMultPattern functions? The new version of the patch shows what this could look like.


================
Comment at: polly/lib/Transform/MatmulOptimizer.cpp:1613
+  // The partial schedule should contain only one statement.
+  if (isl_union_set_n_set(Domain.get()) != 1)
+    return false;
----------------
Meinersbur wrote:
> This constraint should not be intrinsic to the algorithm, but I agree it to be easier to handle for now.
Could we add a TODO comment for this?


================
Comment at: polly/lib/Transform/MatmulOptimizer.cpp:1622
+  // Subsequently, such band nodes will be replaced by a single band node.
+  while (NodeType != isl_schedule_node_domain) {
+    if (HasFilterNode && (NodeType == isl_schedule_node_band))
----------------
Meinersbur wrote:
> This looks for the outermost node that is not a filter or band. Is it possible that while that outermost node is not a TC contraction, one of the inner ones might? What if the outermost node is a filter, looks like it would just `return false` in this case.
If I am not mistaken, this only checks that the band nodes representing the statement are not separated by filter nodes. This accepts straightforward implementations of TC with or without DeLICM. For example:

 domain: "{ Stmt_for_body8[i0, i1, i2]  : 0 <= i0 <= 1599 and
                                          0 <= i1 <= 1799 and
                                          0 <= i2 <= 2199;
            Stmt_for_body3[i0, i1] :      0 <= i0 <= 1599 and
                                          0 <= i1 <= 1799;
            Stmt_for_body3_last[i0, i1] : 0 <= i0 <= 1599 and
                                          0 <= i1 <= 1799 }"
 child:
  sequence:
  - filter: "{ Stmt_for_body3[i0, i1] }"
    child:
      schedule: "[{ Stmt_for_body3[i0, i1] -> [(i0)] }, { Stmt_for_body3[i0, i1] -> [(i1)] }]"
      permutable: 1
      coincident: [ 1, 1 ]
  - filter: "{ Stmt_for_body3_last[i0, i1] }"
    child:
      schedule: "[{ Stmt_for_body3_last[i0, i1] -> [(i0)] }, { Stmt_for_body3_last[i0, i1] -> [(i1)] }]"
      permutable: 1
      coincident: [ 1, 1 ]
  - filter: "{ Stmt_for_body8[i0, i1, i2] }"
    child:
      schedule: "[{ Stmt_for_body8[i0, i1, i2] -> [(i0)] },
                  { Stmt_for_body8[i0, i1, i2] -> [(i1)] },
                  { Stmt_for_body8[i0, i1, i2] -> [(i2)] }]"
      permutable: 1
      coincident: [ 1, 1, 0 ]

domain: "{ Stmt2[i0, i1, i2] : 0 <= i0 <= 31 and 0 <= i1 <= 31 and 0 <= i2 <= 31 }"
 child:
  schedule: "[{ Stmt2[i0, i1, i2] -> [(i0)] }, { Stmt2[i0, i1, i2] -> [(i1)] }, { Stmt2[i0, i1, i2] -> [(i2)] }]"
  permutable: 1
  coincident: [ 1, 1, 0 ]

Sorry, I have not yet committed the updated version of the TC optimization to my GitHub repo. However, I believe that, if this is the case, we can safely replace all such nodes:

+  auto NodeType = isl_schedule_node_get_type(Node.get());
+  while ((NodeType != isl_schedule_node_domain) &&
+         (NodeType != isl_schedule_node_filter)) {
+    assert((NodeType != isl_schedule_node_sequence) &&
+           "Prevent the undefined behavior");
+    Node = Node.parent();
+    NodeType = isl_schedule_node_get_type(Node.get());
+  }
+  Node = Node.child(0);
+  Node = isl::manage(isl_schedule_node_cut(Node.release()));
+  return Node.insert_partial_schedule(Dimensions);

I think that detecting more sophisticated implementations of TC is a possible goal for future research.

I have described this in the comment.


================
Comment at: polly/test/ScheduleOptimizer/pattern-matching-based-opts-after-delicm_2.ll:1-3
+; RUN: opt %loadPolly -polly-delicm -polly-simplify -polly-opt-isl \
+; RUN: -polly-pattern-matching-based-opts=true \
+; RUN: -polly-tc-opt=true -debug < %s 2>&1 | FileCheck %s
----------------
Meinersbur wrote:
> Since this is not FileCheck-ing the LLVM-IR output, suppress it with `-disable-output`
Could we fix the existing test cases in a separate patch?

polly/test/ScheduleOptimizer/pattern-matching-based-opts-after-delicm_2.ll

; RUN: -polly-tc-opt=true -debug -disable-output < %s 2>&1 | FileCheck %s
; REQUIRES: asserts

polly/test/ScheduleOptimizer/pattern-matching-based-opts_16.ll

; RUN: opt %loadPolly -polly-opt-isl -polly-pattern-matching-based-opts=true \
; RUN: -polly-tc-opt=true -debug -disable-output < %s 2>&1 | FileCheck %s

polly/test/ScheduleOptimizer/pattern-matching-based-opts_17.ll

; RUN: opt %loadPolly -polly-opt-isl -polly-pattern-matching-based-opts=true \
; RUN: -polly-tc-opt=true -debug -disable-output < %s 2>&1 | FileCheck %s

polly/test/ScheduleOptimizer/pattern-matching-based-opts_18.ll

; RUN: opt %loadPolly -polly-opt-isl -polly-pattern-matching-based-opts=true \
; RUN: -polly-tc-opt=true -debug -disable-output < %s 2>&1 | FileCheck %s

polly/test/ScheduleOptimizer/pattern-matching-based-opts_19.ll

; RUN: opt %loadPolly -polly-opt-isl -polly-pattern-matching-based-opts=true \
; RUN: -polly-tc-opt=true -debug -disable-output < %s 2>&1 | FileCheck %s

polly/test/ScheduleOptimizer/pattern-matching-based-opts_20.ll

; RUN: opt %loadPolly -polly-opt-isl -polly-pattern-matching-based-opts=true \
; RUN: -polly-tc-opt=true -debug -disable-output < %s 2>&1 | FileCheck %s



https://reviews.llvm.org/D114336
