[PATCH] D114336: [Polly] Generalize the pattern matching to the case of tensor contractions.

Michael Kruse via Phabricator via llvm-commits llvm-commits at lists.llvm.org
Wed Nov 24 18:36:39 PST 2021


Meinersbur added a comment.

Thanks for upstreaming your tensor optimization!

The naming of `polly-pattern-matching-based-opts` suggests that it covers all pattern-based optimizations, yet this patch introduces another flag, `polly-pattern-matching-based-tc-opts`. I'd prefer `polly-pattern-matching-based-opts` to control both optimizations, with additional flags for enabling the matrix-multiplication and tensor optimizations individually. Alternatively, rename `polly-pattern-matching-based-opts` to, e.g., `polly-matmul-opt`. Also, you are adding this functionality to a file called `MatmulOptimizer` and a function called `tryOptimizeMatMulPattern`. Since matrix multiplication is a special case of tensor contraction, is the TC optimization supposed to supersede the matrix-multiplication optimization? In short, I would like to know what the relation between those two optimizations should be. I'd prefer not to maintain two optimizations if one is strictly more powerful than the other.

I'd appreciate occasional comments and grouping of statements (empty lines) inside the functions, in addition to the doxygen comments. For instance, `isTCOperandAcc` is just a wall of code. In such property-checking functions, each `return` should ideally state which property is violated and why that property is required for the tensor optimization.
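To illustrate the requested style, here is a minimal sketch of a property-checking function in which every early `return` is preceded by a comment naming the violated property. The function `isOperandAccessCompatible`, its parameters, the representation of an access as a list of schedule-dimension indices, and the specific properties checked are all hypothetical and not taken from the patch.

```cpp
#include <set>
#include <vector>

// Hypothetical, simplified check. AccessDims lists, for each tensor dimension
// of an operand, the schedule (input) dimension that indexes it. FreeIdx holds
// the free indices this operand must use (e.g., I for A), ContractedIdx the
// contracted indices (P).
static bool isOperandAccessCompatible(const std::vector<int> &AccessDims,
                                      const std::set<int> &FreeIdx,
                                      const std::set<int> &ContractedIdx) {
  // Each tensor dimension must be indexed by a distinct loop; a repeated
  // index (e.g., A[i][i]) selects a diagonal, not a contraction operand.
  std::set<int> Seen(AccessDims.begin(), AccessDims.end());
  if (Seen.size() != AccessDims.size())
    return false;

  // Every index must be free or contracted; any other loop indexing the
  // operand is not part of C(I,J) += A(I,P) * B(P,J).
  for (int Dim : AccessDims)
    if (!FreeIdx.count(Dim) && !ContractedIdx.count(Dim))
      return false;

  // All free indices must appear; otherwise the operand is broadcast along
  // a dimension of C, which the contraction form does not describe.
  for (int Dim : FreeIdx)
    if (!Seen.count(Dim))
      return false;

  // All contracted indices must appear; otherwise the summation over P
  // cannot be expressed with this operand.
  for (int Dim : ContractedIdx)
    if (!Seen.count(Dim))
      return false;

  return true;
}
```

Grouping each check with its own comment keeps the reason for every rejection next to the `return` that implements it.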



================
Comment at: polly/lib/Transform/MatmulOptimizer.cpp:172-173
+///
+/// Parameters, which describe access relations that represent operands of the
+/// tensor contraction.
+struct TCInfoTy {
----------------
Please add more details on what the members represent.
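One possible way to document the members, assuming the usual formulation C(I, J) += A(I, P) * B(P, J), where I and J are free indices and P are contracted indices. The meanings in the comments are inferred from the member names and from how `Dimensions[i]` is used in `isCorrectAccessMap`; they are not confirmed by the patch. The sketch substitutes `std::` containers and a placeholder `MemoryAccessStub` for Polly's types so that it compiles on its own; `TCInfoSketch` is a hypothetical name.

```cpp
#include <set>
#include <vector>

// Placeholder for polly::MemoryAccess so that the sketch is self-contained.
struct MemoryAccessStub;

/// Access relations forming the (assumed) tensor contraction
///   C(I, J) += A(I, P) * B(P, J).
struct TCInfoSketch {
  /// Read accesses to the input operands A and B.
  MemoryAccessStub *A = nullptr;
  MemoryAccessStub *B = nullptr;

  /// Read from and write to the result tensor C.
  MemoryAccessStub *ReadFromC = nullptr;
  MemoryAccessStub *WriteToC = nullptr;

  /// Schedule (input) dimensions acting as free indices shared by A and C (I),
  /// free indices shared by B and C (J), and contracted indices of A and B (P).
  std::set<int> I, J, P;

  /// For each schedule dimension, the size of the tensor dimension it indexes.
  std::vector<int> DimensionSizes;

  /// For each tensor dimension of A, B, and C (outermost first), the schedule
  /// dimension that indexes it.
  std::vector<int> ADimensions, BDimensions, CDimensions;

  /// The elements of I, J, and P in the order in which they index the operands.
  std::vector<int> OrderedI, OrderedJ, OrderedP;
};
```

Under these assumptions, a plain matrix multiplication C[i][j] += A[i][k] * B[k][j] with loops (i, j, k) as schedule dimensions 0, 1, 2 would give I = {0}, J = {1}, P = {2}, ADimensions = {0, 2}, BDimensions = {2, 1}, and CDimensions = {0, 1}.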


================
Comment at: polly/lib/Transform/MatmulOptimizer.cpp:179-181
+  std::set<int> I;
+  std::set<int> J;
+  std::set<int> P;
----------------
`std::set` is a high-overhead implementation. Consider using `DenseSet` or `SmallDenseSet`. See https://www.llvm.org/docs/ProgrammersManual.html#llvm-adt-denseset-h


================
Comment at: polly/lib/Transform/MatmulOptimizer.cpp:182-188
+  SmallVector<int, 30> DimensionSizes;
+  SmallVector<int, 30> ADimensions;
+  SmallVector<int, 30> BDimensions;
+  SmallVector<int, 30> CDimensions;
+  SmallVector<int, 30> OrderedI;
+  SmallVector<int, 30> OrderedJ;
+  SmallVector<int, 30> OrderedP;
----------------
Is there a reason to use 30 as the small size? If not, consider using just `SmallVector<int>`, which picks a default small size.


================
Comment at: polly/lib/Transform/MatmulOptimizer.cpp:1068
+  const llvm::SCEV *SCEVDimSize = SAI->getDimensionSize(Pos);
+  assert(SCEVDimSize && L"Prevent the undefined behavior");
+  auto *ConstantDimSize = dyn_cast<const SCEVConstant>(SCEVDimSize);
----------------
[style] There is no reason to make this a wide string literal, especially if it is only used as an assertion-failure message.

Applies to other occurrences as well.
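For reference, the common idiom joins the condition and a narrow string literal with `&&`. A string literal converts to a non-null pointer, so it never changes the result and only appears in the failure message. A minimal standalone sketch; the helper `getConstantDimSize` and its parameter are hypothetical:

```cpp
#include <cassert>

// Hypothetical helper: returns the size stored behind DimSize.
static int getConstantDimSize(const int *DimSize) {
  // A narrow literal suffices; L"..." would only change the literal's type.
  assert(DimSize && "Dimension size must be known");
  return *DimSize;
}
```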


================
Comment at: polly/lib/Transform/MatmulOptimizer.cpp:1087
+static bool isCorrectAccessMap(isl::set Domain, isl::map AccMap,
+                               const SmallVector<int, 30> &Dimensions) {
+  isl::space Space = AccMap.get_space();
----------------



================
Comment at: polly/lib/Transform/MatmulOptimizer.cpp:1093
+  isl::map PossibleTensor = isl::manage(Universe.copy());
+  for (int i = 0; i < static_cast<int>(Dimensions.size()); i++) {
+    const int InPos = Dimensions[i];
----------------
or introduce `intFromIslSize`.


================
Comment at: polly/lib/Transform/MatmulOptimizer.cpp:1126-1127
+                           std::set<int> &IndexSet,
+                           SmallVector<int, 30> &DimensionSizes,
+                           SmallVector<int, 30> &Dimensions) {
+  isl::id Id = AccMap.get_tuple_id(isl::dim::out);
----------------
Consider taking `SmallVectorImpl<int> &` here; unlike `SmallVector<int, 30> &`, it does not fix the vector's small size.


================
Comment at: polly/lib/Transform/MatmulOptimizer.cpp:1137
+  unsigned InDimNum = unsignedFromIslSize(CheckMap.dim(isl::dim::in));
+  for (unsigned i = 0; i < InDimNum; i++) {
+    isl::val Val = isl::manage(
----------------
Consider using `polly::rangeIslSize` for iterating over dimensions.


================
Comment at: polly/lib/Transform/MatmulOptimizer.cpp:1164
+/// @return The set intersection.
+std::set<int> intersect(const std::set<int> &A, const std::set<int> &B) {
+  std::set<int> Intersection;
----------------
Although these are already in an anonymous namespace, the other functions are marked `static` as well. I found that it helps the compiler warn if a static function is unused.


================
Comment at: polly/lib/Transform/MatmulOptimizer.cpp:1166-1167
+  std::set<int> Intersection;
+  set_intersection(A.begin(), A.end(), B.begin(), B.end(),
+                   std::inserter(Intersection, Intersection.begin()));
+  return Intersection;
----------------
Do you know of `#include <llvm/ADT/SetOperations.h>`? Unfortunately, these modify one set rather than returning a new set.
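If a new set is still wanted with in-place operations, one option is to copy one operand and intersect the copy. The sketch below only mimics the in-place semantics described above (removing from the first set every element not contained in the second) using `std::set`; `intersectInPlace` is a hypothetical helper, not LLVM's implementation.

```cpp
#include <set>

// Removes from S1 every element that is not contained in S2 (S1 := S1 & S2).
template <typename SetT> void intersectInPlace(SetT &S1, const SetT &S2) {
  for (auto It = S1.begin(); It != S1.end();) {
    if (S2.count(*It))
      ++It;
    else
      It = S1.erase(It); // erase returns the iterator to the next element
  }
}

// Returns a new set by intersecting a copy, leaving both inputs untouched.
static std::set<int> intersect(const std::set<int> &A, const std::set<int> &B) {
  std::set<int> Result = A;
  intersectInPlace(Result, B);
  return Result;
}
```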


================
Comment at: polly/lib/Transform/MatmulOptimizer.cpp:1196
+  auto *MemA = Accesses.end() - 1;
+  for (; MemA != Accesses.begin(); MemA--) {
+    MemoryAccess *MemAccessPtr = *MemA;
----------------



================
Comment at: polly/lib/Transform/MatmulOptimizer.cpp:1206-1207
+      return;
+    if (intersect(IandJIndexSet, TCI.P).size())
+      return;
+    TCI.WriteToC = MemAccessPtr;
----------------
This only checks whether two sets are disjoint; that should not require computing the intersection.
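A minimal sketch of a disjointness test that stops at the first common element instead of building the intersection. `areDisjoint` is a hypothetical name; the template works with any set type that provides `size()`, `count()`, and iteration.

```cpp
#include <set>

// Returns true if A and B share no element. Iterates over the smaller set,
// looks each element up in the larger one, and returns at the first match.
template <typename SetT>
static bool areDisjoint(const SetT &A, const SetT &B) {
  const SetT &Small = A.size() <= B.size() ? A : B;
  const SetT &Large = A.size() <= B.size() ? B : A;
  for (const auto &Elem : Small)
    if (Large.count(Elem))
      return false;
  return true;
}
```

At the call site, `if (intersect(IandJIndexSet, TCI.P).size()) return;` would become `if (!areDisjoint(IandJIndexSet, TCI.P)) return;`.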


================
Comment at: polly/lib/Transform/MatmulOptimizer.cpp:1223
+                   const std::set<int> &IandJIndexSet,
+                   SmallVector<int, 30> Dimensions, TCInfoTy &TCI) {
+  if (!TCI.A) {
----------------



================
Comment at: polly/lib/Transform/MatmulOptimizer.cpp:1261
+  TCI.ReadFromC = nullptr;
+  SmallVector<MemoryAccess *, 32> Accesses = getAccessesInOrder(*Stmt);
+  for (auto *MemA = Accesses.begin(); *MemA != TCI.WriteToC; MemA++) {
----------------
`getAccessesInOrder` requires `Stmt` to not be a RegionStmt. Please add a test for it.


================
Comment at: polly/lib/Transform/MatmulOptimizer.cpp:1271
+      if (TCI.ReadFromC)
+        return;
+      TCI.ReadFromC = MemAccessPtr;
----------------
If any of the returns are executed, what causes the pattern to be rejected (it's not `return false`)?
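One way to make the rejection explicit is to have the function return `bool`, so that every early exit is a visible `return false` the caller must propagate. The sketch uses a hypothetical simplified `Access` record (just the accessed array's name) in place of `MemoryAccess`; the function `findReadFromC` and its rules are illustrative only.

```cpp
#include <string>
#include <vector>

// Hypothetical simplified memory access: only the accessed array's name.
struct Access {
  std::string Array;
};

// Looks for the single read of the result array CArray among the accesses
// that precede the write to C. Returns false to reject the pattern.
static bool findReadFromC(const std::vector<Access> &ReadsBeforeWrite,
                          const std::string &CArray, const Access *&ReadFromC) {
  ReadFromC = nullptr;
  for (const Access &Acc : ReadsBeforeWrite) {
    if (Acc.Array != CArray)
      continue; // reads of the operands are checked elsewhere

    // A second read of C does not fit C += A * B: reject.
    if (ReadFromC)
      return false;
    ReadFromC = &Acc;
  }
  // Without a read of C the statement is not an accumulation: reject.
  return ReadFromC != nullptr;
}
```

The caller then writes `if (!findReadFromC(...)) return false;`, so a failed check cannot be silently ignored.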


================
Comment at: polly/lib/Transform/MatmulOptimizer.cpp:1340-1341
+///         and false, otherwise.
+static bool isTcDep(isl::set DepDelta, unsigned Pos, isl::set BoundDeltas,
+                    std::set<int> *IndexSet) {
+  if ((unsignedFromIslSize(DepDelta.n_basic_set()) != 1) ||
----------------



================
Comment at: polly/lib/Transform/MatmulOptimizer.cpp:1342-1344
+  if ((unsignedFromIslSize(DepDelta.n_basic_set()) != 1) ||
+      !DepDelta.plain_get_val_if_fixed(isl::dim::set, Pos).is_one())
+    return false;
----------------
The check should not depend on `n_basic_set`, which is fragile because it depends on, e.g., whether `coalesce` succeeded. Consider using something like `polly::getConstant`.


================
Comment at: polly/lib/Transform/MatmulOptimizer.cpp:1388-1389
+  unsigned DeltasDimNum = unsignedFromIslSize(DepDeltas.dim(isl::dim::set));
+  isl::set UpperBound = Domain.lexmax();
+  isl::set LowerBound = Domain.lexmin();
+  isl::map BoundMap = isl::map::from_domain_and_range(LowerBound, UpperBound);
----------------
Consider `lexmin_pw_multi_aff`/`lexmax_pw_multi_aff`


================
Comment at: polly/lib/Transform/ScheduleOptimizer.cpp:459-463
-  auto Space = isl::manage(isl_schedule_node_band_get_space(Node.get()));
-
-  if (unsignedFromIslSize(Space.dim(isl::dim::set)) <= 1u)
-    return false;
-
----------------
Instead of modifying the idea of whether a node is tilable, consider introducing another constraint-checking function, as we should have done with prevectorization as well.
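A sketch of the suggested structure: the tilability test keeps its original meaning and a separate predicate states the constraints of the tensor-contraction optimization. `BandSketch`, `isTileableBand`, and `isTCCandidateBand` are hypothetical stand-ins for isl schedule-band queries; the dimension check mirrors the lines removed in the diff excerpt, while the permutability condition and the exact TC constraints are assumptions.

```cpp
// Hypothetical summary of a schedule band node.
struct BandSketch {
  unsigned NumDims; // number of band dimensions
  bool Permutable;  // whether the band's dimensions may be permuted
};

// Tilability is unchanged: a permutable band with more than one dimension
// (the dimension check is the one the patch removed).
static bool isTileableBand(const BandSketch &Band) {
  if (!Band.Permutable)
    return false;
  if (Band.NumDims <= 1u)
    return false;
  return true;
}

// Separate constraints for the tensor-contraction optimization. Allowing
// one-dimensional bands here does not change what "tileable" means for
// other transformations.
static bool isTCCandidateBand(const BandSketch &Band) {
  return Band.Permutable && Band.NumDims >= 1u;
}
```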


Repository:
  rG LLVM Github Monorepo

https://reviews.llvm.org/D114336


