[polly] r281441 - Perform copying to created arrays according to the packing transformation

Roman Gareev via llvm-commits llvm-commits at lists.llvm.org
Tue Sep 13 23:26:09 PDT 2016


Author: romangareev
Date: Wed Sep 14 01:26:09 2016
New Revision: 281441

URL: http://llvm.org/viewvc/llvm-project?rev=281441&view=rev
Log:
Perform copying to created arrays according to the packing transformation

This is the fourth patch to apply the BLIS matmul optimization pattern on matmul
kernels (http://www.cs.utexas.edu/users/flame/pubs/TOMS-BLIS-Analytical.pdf).
BLIS implements gemm as three nested loops around a macro-kernel, plus two
packing routines. The macro-kernel is implemented in terms of two additional
loops around a micro-kernel. The micro-kernel is a loop around a rank-1
(i.e., outer product) update. In this change we perform copying to created
arrays, which is the last step to implement the packing transformation.

Reviewed-by: Tobias Grosser <tobias at grosser.es>

Differential Revision: https://reviews.llvm.org/D23260

Modified:
    polly/trunk/include/polly/CodeGen/IslExprBuilder.h
    polly/trunk/include/polly/CodeGen/IslNodeBuilder.h
    polly/trunk/include/polly/ScheduleOptimizer.h
    polly/trunk/include/polly/ScopInfo.h
    polly/trunk/lib/Analysis/DependenceInfo.cpp
    polly/trunk/lib/Analysis/PolyhedralInfo.cpp
    polly/trunk/lib/Analysis/ScopInfo.cpp
    polly/trunk/lib/CodeGen/BlockGenerators.cpp
    polly/trunk/lib/CodeGen/IRBuilder.cpp
    polly/trunk/lib/CodeGen/IslAst.cpp
    polly/trunk/lib/CodeGen/IslNodeBuilder.cpp
    polly/trunk/lib/Exchange/JSONExporter.cpp
    polly/trunk/lib/Transform/DeadCodeElimination.cpp
    polly/trunk/lib/Transform/ScheduleOptimizer.cpp
    polly/trunk/test/ScheduleOptimizer/mat_mul_pattern_data_layout.ll

Modified: polly/trunk/include/polly/CodeGen/IslExprBuilder.h
URL: http://llvm.org/viewvc/llvm-project/polly/trunk/include/polly/CodeGen/IslExprBuilder.h?rev=281441&r1=281440&r2=281441&view=diff
==============================================================================
--- polly/trunk/include/polly/CodeGen/IslExprBuilder.h (original)
+++ polly/trunk/include/polly/CodeGen/IslExprBuilder.h Wed Sep 14 01:26:09 2016
@@ -166,6 +166,17 @@ public:
   ///         was enabled.
   llvm::Value *getOverflowState() const;
 
+  /// Create LLVM-IR that computes the memory location of an access expression.
+  ///
+  /// For a given isl_ast_expr[ession] of type isl_ast_op_access this function
+  /// creates IR that computes the address the access expression refers to.
+  ///
+  /// @param Expr The ast expression of type isl_ast_op_access
+  ///             for which we generate LLVM-IR.
+  ///
+  /// @return The llvm::Value* containing the result of the computation.
+  llvm::Value *createAccessAddress(__isl_take isl_ast_expr *Expr);
+
 private:
   Scop &S;
 
@@ -203,7 +214,6 @@ private:
   llvm::Value *createId(__isl_take isl_ast_expr *Expr);
   llvm::Value *createInt(__isl_take isl_ast_expr *Expr);
   llvm::Value *createOpAddressOf(__isl_take isl_ast_expr *Expr);
-  llvm::Value *createAccessAddress(__isl_take isl_ast_expr *Expr);
 
   /// Create a binary operation @p Opc and track overflows if requested.
   ///

Modified: polly/trunk/include/polly/CodeGen/IslNodeBuilder.h
URL: http://llvm.org/viewvc/llvm-project/polly/trunk/include/polly/CodeGen/IslNodeBuilder.h?rev=281441&r1=281440&r2=281441&view=diff
==============================================================================
--- polly/trunk/include/polly/CodeGen/IslNodeBuilder.h (original)
+++ polly/trunk/include/polly/CodeGen/IslNodeBuilder.h Wed Sep 14 01:26:09 2016
@@ -375,6 +375,21 @@ protected:
   ///
   virtual __isl_give isl_union_map *
   getScheduleForAstNode(__isl_take isl_ast_node *Node);
+
+private:
+  /// Create code for a copy statement.
+  ///
+  /// A copy statement is expected to have one write memory access and one read
+  /// memory access (in this very order). Data is loaded from the location
+  /// described by the read memory access and written to the location described
+  /// by the write memory access. @p NewAccesses contains for each access
+  /// the isl ast expression that describes the location accessed.
+  ///
+  /// @param Stmt The copy statement that contains the accesses.
+  /// @param NewAccesses The hash table that contains remappings from memory
+  ///                    ids to new access expressions.
+  void generateCopyStmt(ScopStmt *Stmt,
+                        __isl_keep isl_id_to_ast_expr *NewAccesses);
 };
 
 #endif

Modified: polly/trunk/include/polly/ScheduleOptimizer.h
URL: http://llvm.org/viewvc/llvm-project/polly/trunk/include/polly/ScheduleOptimizer.h?rev=281441&r1=281440&r2=281441&view=diff
==============================================================================
--- polly/trunk/include/polly/ScheduleOptimizer.h (original)
+++ polly/trunk/include/polly/ScheduleOptimizer.h Wed Sep 14 01:26:09 2016
@@ -88,7 +88,7 @@ public:
   ///
   /// @return True, if we believe @p NewSchedule is an improvement for @p S.
   static bool isProfitableSchedule(polly::Scop &S,
-                                   __isl_keep isl_union_map *NewSchedule);
+                                   __isl_keep isl_schedule *NewSchedule);
 
   /// Isolate a set of partial tile prefixes.
   ///

Modified: polly/trunk/include/polly/ScopInfo.h
URL: http://llvm.org/viewvc/llvm-project/polly/trunk/include/polly/ScopInfo.h?rev=281441&r1=281440&r2=281441&view=diff
==============================================================================
--- polly/trunk/include/polly/ScopInfo.h (original)
+++ polly/trunk/include/polly/ScopInfo.h Wed Sep 14 01:26:09 2016
@@ -689,6 +689,19 @@ public:
                ArrayRef<const SCEV *> Subscripts, ArrayRef<const SCEV *> Sizes,
                Value *AccessValue, ScopArrayInfo::MemoryKind Kind,
                StringRef BaseName);
+
+  /// Create a new MemoryAccess that corresponds to @p AccRel.
+  ///
+  /// Along with @p Stmt and @p AccType it uses information about the dimension
+  /// lengths of the accessed array, the type of its elements, and its name,
+  /// all of which are derived from the array object accessible via the output
+  /// tuple id of @p AccRel.
+  ///
+  /// @param Stmt       The parent statement.
+  /// @param AccType    Whether read or write access.
+  /// @param AccRel     The access relation that describes the memory access.
+  MemoryAccess(ScopStmt *Stmt, AccessType AccType, __isl_take isl_map *AccRel);
+
   ~MemoryAccess();
 
   /// Add a new incoming block/value pairs for this PHI/ExitPHI access.
@@ -1083,6 +1096,16 @@ public:
   /// Create an overapproximating ScopStmt for the region @p R.
   ScopStmt(Scop &parent, Region &R);
 
+  /// Create a copy statement.
+  ///
+  /// @param parent     The parent SCoP.
+  /// @param SourceRel  The source location.
+  /// @param TargetRel  The target location.
+  /// @param Domain     The original domain under which the copy statement would
+  ///                   be executed.
+  ScopStmt(Scop &parent, __isl_take isl_map *SourceRel,
+           __isl_take isl_map *TargetRel, __isl_take isl_set *Domain);
+
   /// Initialize members after all MemoryAccesses have been added.
   void init(LoopInfo &LI);
 
@@ -1217,10 +1240,14 @@ public:
 
   /// Get the schedule function of this ScopStmt.
   ///
-  /// @return The schedule function of this ScopStmt.
+  /// @return The schedule function of this ScopStmt, if it does not contain
+  /// extension nodes, and nullptr, otherwise.
   __isl_give isl_map *getSchedule() const;
 
   /// Get an isl string representing this schedule.
+  ///
+  /// @return An isl string representing this schedule, if it does not contain
+  /// extension nodes, and an empty string, otherwise.
   std::string getScheduleStr() const;
 
   /// Get the invalid domain for this statement.
@@ -1245,6 +1272,9 @@ public:
   /// Return true if this statement represents a single basic block.
   bool isBlockStmt() const { return BB != nullptr; }
 
+  /// Return true if this is a copy statement.
+  bool isCopyStmt() const { return BB == nullptr && R == nullptr; }
+
   /// Get the region represented by this ScopStmt (if any).
   ///
   /// @return The region represented by this ScopStmt, or null if the statement
@@ -1448,6 +1478,9 @@ private:
   /// Max loop depth.
   unsigned MaxLoopDepth;
 
+  /// Number of copy statements.
+  unsigned CopyStmtsNum;
+
   typedef std::list<ScopStmt> StmtSet;
   /// The statements in this Scop.
   StmtSet Stmts;
@@ -1615,11 +1648,6 @@ private:
   Scop(Region &R, ScalarEvolution &SE, LoopInfo &LI,
        ScopDetection::DetectionContext &DC);
 
-  /// Add the access function to all MemoryAccess objects of the Scop
-  ///        created in this pass.
-  void addAccessFunction(MemoryAccess *Access) {
-    AccessFunctions.emplace_back(Access);
-  }
   //@}
 
   /// Initialize this ScopBuilder.
@@ -1927,6 +1955,30 @@ private:
 public:
   ~Scop();
 
+  /// Get the count of copy statements added to this Scop.
+  ///
+  /// @return The count of copy statements added to this Scop.
+  unsigned getCopyStmtsNum() { return CopyStmtsNum; }
+
+  /// Create a new copy statement.
+  ///
+  /// A new statement will be created and added to the statement vector.
+  ///
+  /// @param SourceRel  The source location.
+  /// @param TargetRel  The target location.
+  /// @param Domain     The original domain under which the copy statement would
+  ///                   be executed.
+  /// @return A pointer to the newly created copy statement.
+  ScopStmt *addScopStmt(__isl_take isl_map *SourceRel,
+                        __isl_take isl_map *TargetRel,
+                        __isl_take isl_set *Domain);
+
+  /// Add the access function to all MemoryAccess objects of the Scop
+  ///        created in this pass.
+  void addAccessFunction(MemoryAccess *Access) {
+    AccessFunctions.emplace_back(Access);
+  }
+
   ScalarEvolution *getSE() const;
 
   /// Get the count of parameters used in this Scop.
@@ -2349,6 +2401,9 @@ public:
   __isl_give isl_union_map *getAccesses();
 
   /// Get the schedule of all the statements in the SCoP.
+  ///
+  /// @return The schedule of all the statements in the SCoP, if the schedule of
+  /// the Scop does not contain extension nodes, and nullptr, otherwise.
   __isl_give isl_union_map *getSchedule() const;
 
   /// Get a schedule tree describing the schedule of all statements.
@@ -2380,6 +2435,11 @@ public:
   /// Find the ScopArrayInfo associated with an isl Id
   ///        that has name @p Name.
   ScopArrayInfo *getArrayInfoByName(const std::string BaseName);
+
+  /// Check whether @p Schedule contains extension nodes.
+  ///
+  /// @return true if @p Schedule contains extension nodes.
+  static bool containsExtensionNode(__isl_keep isl_schedule *Schedule);
 };
 
 /// Print Scop scop to raw_ostream O.

Modified: polly/trunk/lib/Analysis/DependenceInfo.cpp
URL: http://llvm.org/viewvc/llvm-project/polly/trunk/lib/Analysis/DependenceInfo.cpp?rev=281441&r1=281440&r2=281441&view=diff
==============================================================================
--- polly/trunk/lib/Analysis/DependenceInfo.cpp (original)
+++ polly/trunk/lib/Analysis/DependenceInfo.cpp Wed Sep 14 01:26:09 2016
@@ -153,6 +153,8 @@ static void collectInfo(Scop &S, isl_uni
         // to match the new access domains, thus we need
         //   [Stmt[i0, i1] -> MemAcc_A[i0 + i1]] -> [0, i0, 2, i1, 0]
         isl_map *Schedule = Stmt.getSchedule();
+        assert(Schedule && "Schedules that contain extension nodes require "
+                           "special handling.");
         Schedule = isl_map_apply_domain(
             Schedule,
             isl_map_reverse(isl_map_domain_map(isl_map_copy(accdom))));
@@ -162,7 +164,10 @@ static void collectInfo(Scop &S, isl_uni
       } else {
         accdom = tag(accdom, MA, Level);
         if (Level > Dependences::AL_Statement) {
-          isl_map *Schedule = tag(Stmt.getSchedule(), MA, Level);
+          auto *StmtScheduleMap = Stmt.getSchedule();
+          assert(StmtScheduleMap && "Schedules that contain extension nodes "
+                                    "require special handling.");
+          isl_map *Schedule = tag(StmtScheduleMap, MA, Level);
           *StmtSchedule = isl_union_map_add_map(*StmtSchedule, Schedule);
         }
       }
@@ -610,6 +615,8 @@ bool Dependences::isValidSchedule(Scop &
       StmtScat = Stmt.getSchedule();
     else
       StmtScat = isl_map_copy((*NewSchedule)[&Stmt]);
+    assert(StmtScat &&
+           "Schedules that contain extension nodes require special handling.");
 
     if (!ScheduleSpace)
       ScheduleSpace = isl_space_range(isl_map_get_space(StmtScat));

Modified: polly/trunk/lib/Analysis/PolyhedralInfo.cpp
URL: http://llvm.org/viewvc/llvm-project/polly/trunk/lib/Analysis/PolyhedralInfo.cpp?rev=281441&r1=281440&r2=281441&view=diff
==============================================================================
--- polly/trunk/lib/Analysis/PolyhedralInfo.cpp (original)
+++ polly/trunk/lib/Analysis/PolyhedralInfo.cpp Wed Sep 14 01:26:09 2016
@@ -134,6 +134,8 @@ __isl_give isl_union_map *PolyhedralInfo
     unsigned int MaxDim = SS->getNumIterators();
     DEBUG(dbgs() << "Maximum depth of Stmt:\t" << MaxDim << "\n");
     auto *ScheduleMap = SS->getSchedule();
+    assert(ScheduleMap &&
+           "Schedules that contain extension nodes require special handling.");
 
     ScheduleMap = isl_map_project_out(ScheduleMap, isl_dim_out, CurrDim + 1,
                                       MaxDim - CurrDim - 1);

Modified: polly/trunk/lib/Analysis/ScopInfo.cpp
URL: http://llvm.org/viewvc/llvm-project/polly/trunk/lib/Analysis/ScopInfo.cpp?rev=281441&r1=281440&r2=281441&view=diff
==============================================================================
--- polly/trunk/lib/Analysis/ScopInfo.cpp (original)
+++ polly/trunk/lib/Analysis/ScopInfo.cpp Wed Sep 14 01:26:09 2016
@@ -857,6 +857,28 @@ MemoryAccess::MemoryAccess(ScopStmt *Stm
   Id = isl_id_alloc(Stmt->getParent()->getIslCtx(), IdName.c_str(), this);
 }
 
+MemoryAccess::MemoryAccess(ScopStmt *Stmt, AccessType AccType,
+                           __isl_take isl_map *AccRel)
+    : Kind(ScopArrayInfo::MemoryKind::MK_Array), AccType(AccType),
+      RedType(RT_NONE), Statement(Stmt), InvalidDomain(nullptr),
+      AccessInstruction(nullptr), IsAffine(true), AccessRelation(nullptr),
+      NewAccessRelation(AccRel) {
+  auto *ArrayInfoId = isl_map_get_tuple_id(NewAccessRelation, isl_dim_out);
+  auto *SAI = ScopArrayInfo::getFromId(ArrayInfoId);
+  Sizes.push_back(nullptr);
+  for (unsigned i = 1; i < SAI->getNumberOfDimensions(); i++)
+    Sizes.push_back(SAI->getDimensionSize(i));
+  ElementType = SAI->getElementType();
+  BaseAddr = SAI->getBasePtr();
+  BaseName = SAI->getName();
+  static const std::string TypeStrings[] = {"", "_Read", "_Write", "_MayWrite"};
+  const std::string Access = TypeStrings[AccType] + utostr(Stmt->size()) + "_";
+
+  std::string IdName =
+      getIslCompatibleName(Stmt->getBaseName(), Access, BaseName);
+  Id = isl_id_alloc(Stmt->getParent()->getIslCtx(), IdName.c_str(), this);
+}
+
 void MemoryAccess::realignParams() {
   auto *Ctx = Statement->getParent()->getContext();
   InvalidDomain = isl_set_gist_params(InvalidDomain, isl_set_copy(Ctx));
@@ -1040,6 +1062,10 @@ __isl_give isl_map *ScopStmt::getSchedul
         isl_aff_zero_on_domain(isl_local_space_from_space(getDomainSpace())));
   }
   auto *Schedule = getParent()->getSchedule();
+  if (!Schedule) {
+    isl_set_free(Domain);
+    return nullptr;
+  }
   Schedule = isl_union_map_intersect_domain(
       Schedule, isl_union_set_from_set(isl_set_copy(Domain)));
   if (isl_union_map_is_empty(Schedule)) {
@@ -1430,6 +1456,25 @@ ScopStmt::ScopStmt(Scop &parent, BasicBl
   BaseName = getIslCompatibleName("Stmt_", &bb, "");
 }
 
+ScopStmt::ScopStmt(Scop &parent, __isl_take isl_map *SourceRel,
+                   __isl_take isl_map *TargetRel, __isl_take isl_set *NewDomain)
+    : Parent(parent), InvalidDomain(nullptr), Domain(NewDomain), BB(nullptr),
+      R(nullptr), Build(nullptr) {
+  BaseName = getIslCompatibleName("CopyStmt_", "",
+                                  std::to_string(parent.getCopyStmtsNum()));
+  auto *Id = isl_id_alloc(getIslCtx(), getBaseName(), this);
+  Domain = isl_set_set_tuple_id(Domain, isl_id_copy(Id));
+  TargetRel = isl_map_set_tuple_id(TargetRel, isl_dim_in, Id);
+  auto *Access =
+      new MemoryAccess(this, MemoryAccess::AccessType::MUST_WRITE, TargetRel);
+  parent.addAccessFunction(Access);
+  addAccess(Access);
+  SourceRel = isl_map_set_tuple_id(SourceRel, isl_dim_in, isl_id_copy(Id));
+  Access = new MemoryAccess(this, MemoryAccess::AccessType::READ, SourceRel);
+  parent.addAccessFunction(Access);
+  addAccess(Access);
+}
+
 void ScopStmt::init(LoopInfo &LI) {
   assert(!Domain && "init must be called only once");
 
@@ -1576,6 +1621,8 @@ std::string ScopStmt::getDomainStr() con
 
 std::string ScopStmt::getScheduleStr() const {
   auto *S = getSchedule();
+  if (!S)
+    return "";
   auto Str = stringFromIslObj(S);
   isl_map_free(S);
   return Str;
@@ -3041,9 +3088,10 @@ Scop::Scop(Region &R, ScalarEvolution &S
            ScopDetection::DetectionContext &DC)
     : SE(&ScalarEvolution), R(R), IsOptimized(false),
       HasSingleExitEdge(R.getExitingBlock()), HasErrorBlock(false),
-      MaxLoopDepth(0), DC(DC), IslCtx(isl_ctx_alloc(), isl_ctx_free),
-      Context(nullptr), Affinator(this, LI), AssumedContext(nullptr),
-      InvalidContext(nullptr), Schedule(nullptr) {
+      MaxLoopDepth(0), CopyStmtsNum(0), DC(DC),
+      IslCtx(isl_ctx_alloc(), isl_ctx_free), Context(nullptr),
+      Affinator(this, LI), AssumedContext(nullptr), InvalidContext(nullptr),
+      Schedule(nullptr) {
   if (IslOnErrorAbort)
     isl_options_set_on_error(getIslCtx(), ISL_ON_ERROR_ABORT);
   buildContext();
@@ -3922,8 +3970,27 @@ __isl_give isl_union_map *Scop::getAcces
   return getAccessesOfType([](MemoryAccess &MA) { return true; });
 }
 
+// Check whether @p Node is not an extension node.
+//
+// @return isl_bool_error for an extension node, isl_bool_true otherwise.
+isl_bool isNotExtNode(__isl_keep isl_schedule_node *Node, void *User) {
+  if (isl_schedule_node_get_type(Node) == isl_schedule_node_extension)
+    return isl_bool_error;
+  else
+    return isl_bool_true;
+}
+
+bool Scop::containsExtensionNode(__isl_keep isl_schedule *Schedule) {
+  return isl_schedule_foreach_schedule_node_top_down(Schedule, isNotExtNode,
+                                                     nullptr) == isl_stat_error;
+}
+
 __isl_give isl_union_map *Scop::getSchedule() const {
   auto *Tree = getScheduleTree();
+  if (containsExtensionNode(Tree)) {
+    isl_schedule_free(Tree);
+    return nullptr;
+  }
   auto *S = isl_schedule_get_map(Tree);
   isl_schedule_free(Tree);
   return S;
@@ -4059,6 +4126,14 @@ void Scop::addScopStmt(BasicBlock *BB, R
   }
 }
 
+ScopStmt *Scop::addScopStmt(__isl_take isl_map *SourceRel,
+                            __isl_take isl_map *TargetRel,
+                            __isl_take isl_set *Domain) {
+  Stmts.emplace_back(*this, SourceRel, TargetRel, Domain);
+  CopyStmtsNum++;
+  return &(Stmts.back());
+}
+
 void Scop::buildSchedule(LoopInfo &LI) {
   Loop *L = getLoopSurroundingScop(*this, LI);
   LoopStackTy LoopStack({LoopStackElementTy(L, nullptr, 0)});

Modified: polly/trunk/lib/CodeGen/BlockGenerators.cpp
URL: http://llvm.org/viewvc/llvm-project/polly/trunk/lib/CodeGen/BlockGenerators.cpp?rev=281441&r1=281440&r2=281441&view=diff
==============================================================================
--- polly/trunk/lib/CodeGen/BlockGenerators.cpp (original)
+++ polly/trunk/lib/CodeGen/BlockGenerators.cpp Wed Sep 14 01:26:09 2016
@@ -681,7 +681,9 @@ void BlockGenerator::createExitPHINodeMe
 
 void BlockGenerator::invalidateScalarEvolution(Scop &S) {
   for (auto &Stmt : S)
-    if (Stmt.isBlockStmt())
+    if (Stmt.isCopyStmt())
+      continue;
+    else if (Stmt.isBlockStmt())
       for (auto &Inst : *Stmt.getBasicBlock())
         SE.forgetValue(&Inst);
     else if (Stmt.isRegionStmt())

Modified: polly/trunk/lib/CodeGen/IRBuilder.cpp
URL: http://llvm.org/viewvc/llvm-project/polly/trunk/lib/CodeGen/IRBuilder.cpp?rev=281441&r1=281440&r2=281441&view=diff
==============================================================================
--- polly/trunk/lib/CodeGen/IRBuilder.cpp (original)
+++ polly/trunk/lib/CodeGen/IRBuilder.cpp Wed Sep 14 01:26:09 2016
@@ -61,7 +61,8 @@ void ScopAnnotator::buildAliasScopes(Sco
   SetVector<Value *> BasePtrs;
   for (ScopStmt &Stmt : S)
     for (MemoryAccess *MA : Stmt)
-      BasePtrs.insert(MA->getBaseAddr());
+      if (!Stmt.isCopyStmt())
+        BasePtrs.insert(MA->getBaseAddr());
 
   std::string AliasScopeStr = "polly.alias.scope.";
   for (Value *BasePtr : BasePtrs)

Modified: polly/trunk/lib/CodeGen/IslAst.cpp
URL: http://llvm.org/viewvc/llvm-project/polly/trunk/lib/CodeGen/IslAst.cpp?rev=281441&r1=281440&r2=281441&view=diff
==============================================================================
--- polly/trunk/lib/CodeGen/IslAst.cpp (original)
+++ polly/trunk/lib/CodeGen/IslAst.cpp Wed Sep 14 01:26:09 2016
@@ -593,8 +593,7 @@ void IslAstInfo::printScop(raw_ostream &
   P = isl_ast_node_print(RootNode, P, Options);
   AstStr = isl_printer_get_str(P);
 
-  isl_union_map *Schedule =
-      isl_union_map_intersect_domain(S.getSchedule(), S.getDomains());
+  auto *Schedule = S.getScheduleTree();
 
   DEBUG({
     dbgs() << S.getContextStr() << "\n";
@@ -609,7 +608,7 @@ void IslAstInfo::printScop(raw_ostream &
   free(AstStr);
 
   isl_ast_expr_free(RunCondition);
-  isl_union_map_free(Schedule);
+  isl_schedule_free(Schedule);
   isl_ast_node_free(RootNode);
   isl_printer_free(P);
 }

Modified: polly/trunk/lib/CodeGen/IslNodeBuilder.cpp
URL: http://llvm.org/viewvc/llvm-project/polly/trunk/lib/CodeGen/IslNodeBuilder.cpp?rev=281441&r1=281440&r2=281441&view=diff
==============================================================================
--- polly/trunk/lib/CodeGen/IslNodeBuilder.cpp (original)
+++ polly/trunk/lib/CodeGen/IslNodeBuilder.cpp Wed Sep 14 01:26:09 2016
@@ -767,6 +767,23 @@ void IslNodeBuilder::createSubstitutions
   isl_ast_expr_free(Expr);
 }
 
+void IslNodeBuilder::generateCopyStmt(
+    ScopStmt *Stmt, __isl_keep isl_id_to_ast_expr *NewAccesses) {
+  assert(Stmt->size() == 2);
+  auto ReadAccess = Stmt->begin();
+  auto WriteAccess = ReadAccess++;
+  assert((*ReadAccess)->isRead() && (*WriteAccess)->isMustWrite());
+  assert((*ReadAccess)->getElementType() == (*WriteAccess)->getElementType() &&
+         "Accesses use the same data type");
+  assert((*ReadAccess)->isArrayKind() && (*WriteAccess)->isArrayKind());
+  auto *AccessExpr =
+      isl_id_to_ast_expr_get(NewAccesses, (*ReadAccess)->getId());
+  auto *LoadValue = ExprBuilder.create(AccessExpr);
+  AccessExpr = isl_id_to_ast_expr_get(NewAccesses, (*WriteAccess)->getId());
+  auto *StoreAddr = ExprBuilder.createAccessAddress(AccessExpr);
+  Builder.CreateStore(LoadValue, StoreAddr);
+}
+
 void IslNodeBuilder::createUser(__isl_take isl_ast_node *User) {
   LoopToScevMapT LTS;
   isl_id *Id;
@@ -781,12 +798,17 @@ void IslNodeBuilder::createUser(__isl_ta
 
   Stmt = (ScopStmt *)isl_id_get_user(Id);
   auto *NewAccesses = createNewAccesses(Stmt, User);
-  createSubstitutions(Expr, Stmt, LTS);
-
-  if (Stmt->isBlockStmt())
-    BlockGen.copyStmt(*Stmt, LTS, NewAccesses);
-  else
-    RegionGen.copyStmt(*Stmt, LTS, NewAccesses);
+  if (Stmt->isCopyStmt()) {
+    generateCopyStmt(Stmt, NewAccesses);
+    isl_ast_expr_free(Expr);
+  } else {
+    createSubstitutions(Expr, Stmt, LTS);
+
+    if (Stmt->isBlockStmt())
+      BlockGen.copyStmt(*Stmt, LTS, NewAccesses);
+    else
+      RegionGen.copyStmt(*Stmt, LTS, NewAccesses);
+  }
 
   isl_id_to_ast_expr_free(NewAccesses);
   isl_ast_node_free(User);

Modified: polly/trunk/lib/Exchange/JSONExporter.cpp
URL: http://llvm.org/viewvc/llvm-project/polly/trunk/lib/Exchange/JSONExporter.cpp?rev=281441&r1=281440&r2=281441&view=diff
==============================================================================
--- polly/trunk/lib/Exchange/JSONExporter.cpp (original)
+++ polly/trunk/lib/Exchange/JSONExporter.cpp Wed Sep 14 01:26:09 2016
@@ -294,6 +294,8 @@ bool JSONImporter::importSchedule(Scop &
   int Index = 0;
   for (ScopStmt &Stmt : S) {
     Json::Value Schedule = JScop["statements"][Index]["schedule"];
+    assert(!Schedule.asString().empty() &&
+           "Schedules that contain extension nodes require special handling.");
     isl_map *Map = isl_map_read_from_str(S.getIslCtx(), Schedule.asCString());
     isl_space *Space = Stmt.getDomainSpace();
 

Modified: polly/trunk/lib/Transform/DeadCodeElimination.cpp
URL: http://llvm.org/viewvc/llvm-project/polly/trunk/lib/Transform/DeadCodeElimination.cpp?rev=281441&r1=281440&r2=281441&view=diff
==============================================================================
--- polly/trunk/lib/Transform/DeadCodeElimination.cpp (original)
+++ polly/trunk/lib/Transform/DeadCodeElimination.cpp Wed Sep 14 01:26:09 2016
@@ -92,6 +92,8 @@ char DeadCodeElim::ID = 0;
 // no point in trying to remove them from the live-out set.
 __isl_give isl_union_set *DeadCodeElim::getLiveOut(Scop &S) {
   isl_union_map *Schedule = S.getSchedule();
+  assert(Schedule &&
+         "Schedules that contain extension nodes require special handling.");
   isl_union_map *WriteIterations = isl_union_map_reverse(S.getMustWrites());
   isl_union_map *WriteTimes =
       isl_union_map_apply_range(WriteIterations, isl_union_map_copy(Schedule));

Modified: polly/trunk/lib/Transform/ScheduleOptimizer.cpp
URL: http://llvm.org/viewvc/llvm-project/polly/trunk/lib/Transform/ScheduleOptimizer.cpp?rev=281441&r1=281440&r2=281441&view=diff
==============================================================================
--- polly/trunk/lib/Transform/ScheduleOptimizer.cpp (original)
+++ polly/trunk/lib/Transform/ScheduleOptimizer.cpp Wed Sep 14 01:26:09 2016
@@ -660,6 +660,76 @@ identifyAccessByAccessRelation(ScopStmt
   return IdentifiedAccess;
 }
 
+/// Add constraints to the @p Dim dimension of @p ExtMap.
+///
+/// If @p ExtMap has the following form [O0, O1, O2]->[I1, I2, I3],
+/// the following constraint will be added
+/// Bound * OM <= IM <= Bound * (OM + 1) - 1,
+/// where M is @p Dim and Bound is @p Bound.
+///
+/// @param ExtMap The isl map to be modified.
+/// @param Dim The output dimension to be modfied.
+/// @param Bound The value that is used to specify the constraint.
+/// @return The modified isl map
+__isl_give isl_map *
+addExtensionMapMatMulDimConstraint(__isl_take isl_map *ExtMap, unsigned Dim,
+                                   unsigned Bound) {
+  assert(Bound != 0);
+  auto *ExtMapSpace = isl_map_get_space(ExtMap);
+  auto *ConstrSpace = isl_local_space_from_space(ExtMapSpace);
+  auto *Constr =
+      isl_constraint_alloc_inequality(isl_local_space_copy(ConstrSpace));
+  Constr = isl_constraint_set_coefficient_si(Constr, isl_dim_out, Dim, 1);
+  Constr =
+      isl_constraint_set_coefficient_si(Constr, isl_dim_in, Dim, Bound * (-1));
+  ExtMap = isl_map_add_constraint(ExtMap, Constr);
+  Constr = isl_constraint_alloc_inequality(ConstrSpace);
+  Constr = isl_constraint_set_coefficient_si(Constr, isl_dim_out, Dim, -1);
+  Constr = isl_constraint_set_coefficient_si(Constr, isl_dim_in, Dim, Bound);
+  Constr = isl_constraint_set_constant_si(Constr, Bound - 1);
+  return isl_map_add_constraint(ExtMap, Constr);
+}
+
+/// Create an access relation that is specific to the matrix multiplication
+/// pattern.
+///
+/// Create an access relation of the following form:
+/// { [O0, O1, O2]->[I1, I2, I3] :
+///   FirstOutputDimBound * O0 <= I1 <= FirstOutputDimBound * (O0 + 1) - 1
+///   and SecondOutputDimBound * O1 <= I2 <= SecondOutputDimBound * (O1 + 1) - 1
+///   and ThirdOutputDimBound * O2 <= I3 <= ThirdOutputDimBound * (O2 + 1) - 1}
+///   where FirstOutputDimBound is @p FirstOutputDimBound,
+///   SecondOutputDimBound is @p SecondOutputDimBound,
+///   ThirdOutputDimBound is @p ThirdOutputDimBound
+///
+/// @param Ctx The isl context.
+/// @param FirstOutputDimBound,
+///        SecondOutputDimBound,
+///        ThirdOutputDimBound The parameters of the access relation.
+/// @return The specified access relation.
+__isl_give isl_map *getMatMulExt(isl_ctx *Ctx, unsigned FirstOutputDimBound,
+                                 unsigned SecondOutputDimBound,
+                                 unsigned ThirdOutputDimBound) {
+  auto *NewRelSpace = isl_space_alloc(Ctx, 0, 3, 3);
+  auto *extensionMap = isl_map_universe(NewRelSpace);
+  if (!FirstOutputDimBound)
+    extensionMap = isl_map_fix_si(extensionMap, isl_dim_out, 0, 0);
+  else
+    extensionMap = addExtensionMapMatMulDimConstraint(extensionMap, 0,
+                                                      FirstOutputDimBound);
+  if (!SecondOutputDimBound)
+    extensionMap = isl_map_fix_si(extensionMap, isl_dim_out, 1, 0);
+  else
+    extensionMap = addExtensionMapMatMulDimConstraint(extensionMap, 1,
+                                                      SecondOutputDimBound);
+  if (!ThirdOutputDimBound)
+    extensionMap = isl_map_fix_si(extensionMap, isl_dim_out, 2, 0);
+  else
+    extensionMap = addExtensionMapMatMulDimConstraint(extensionMap, 2,
+                                                      ThirdOutputDimBound);
+  return extensionMap;
+}
+
 /// Create an access relation that is specific to the matrix
 ///        multiplication pattern.
 ///
@@ -758,6 +828,14 @@ __isl_give isl_map *getMatMulAccRel(__is
   return isl_map_apply_range(MapOldIndVar, AccessRel);
 }
 
+__isl_give isl_schedule_node *
+createExtensionNode(__isl_take isl_schedule_node *Node,
+                    __isl_take isl_map *ExtensionMap) {
+  auto *Extension = isl_union_map_from_map(ExtensionMap);
+  auto *NewNode = isl_schedule_node_from_extension(Extension);
+  return isl_schedule_node_graft_before(Node, NewNode);
+}
+
 /// Apply the packing transformation.
 ///
 /// The packing transformation can be described as a data-layout
@@ -772,9 +850,9 @@ __isl_give isl_map *getMatMulAccRel(__is
 /// @param MicroParams, MacroParams Parameters of the BLIS kernel
 ///                                 to be taken into account.
 /// @return The optimized schedule node.
-static void optimizeDataLayoutMatrMulPattern(__isl_take isl_map *MapOldIndVar,
-                                             MicroKernelParamsTy MicroParams,
-                                             MacroKernelParamsTy MacroParams) {
+static __isl_give isl_schedule_node *optimizeDataLayoutMatrMulPattern(
+    __isl_take isl_schedule_node *Node, __isl_take isl_map *MapOldIndVar,
+    MicroKernelParamsTy MicroParams, MacroKernelParamsTy MacroParams) {
   auto InputDimsId = isl_map_get_tuple_id(MapOldIndVar, isl_dim_in);
   auto *Stmt = static_cast<ScopStmt *>(isl_id_get_user(InputDimsId));
   isl_id_free(InputDimsId);
@@ -782,8 +860,12 @@ static void optimizeDataLayoutMatrMulPat
   MemoryAccess *MemAccessB = identifyAccessB(Stmt);
   if (!MemAccessA || !MemAccessB) {
     isl_map_free(MapOldIndVar);
-    return;
+    return Node; // Accesses not identified: hand the node back unchanged.
   }
+  Node = isl_schedule_node_parent(isl_schedule_node_parent(Node)); // Walk five levels up the schedule tree (this line and the next two).
+  Node = isl_schedule_node_parent(isl_schedule_node_parent(Node));
+  Node = isl_schedule_node_parent(Node);
+  Node = isl_schedule_node_child(isl_schedule_node_band_split(Node, 2), 0); // Split the band after two members and step into the inner part.
   auto *AccRel =
       getMatMulAccRel(isl_map_copy(MapOldIndVar), MacroParams.Kc, 3, 6);
   unsigned FirstDimSize = MacroParams.Mc * MacroParams.Kc / MicroParams.Mr;
@@ -791,14 +873,34 @@ static void optimizeDataLayoutMatrMulPat
   auto *SAI = Stmt->getParent()->createScopArrayInfo(
       MemAccessA->getElementType(), "Packed_A", {FirstDimSize, SecondDimSize});
   AccRel = isl_map_set_tuple_id(AccRel, isl_dim_out, SAI->getBasePtrId());
+  auto *OldAcc = MemAccessA->getAccessRelation(); // Saved before the access is redirected to Packed_A.
   MemAccessA->setNewAccessRelation(AccRel);
+  auto *ExtMap =
+      getMatMulExt(Stmt->getIslCtx(), MacroParams.Mc, 0, MacroParams.Kc);
+  ExtMap = isl_map_project_out(ExtMap, isl_dim_in, 1, 1);
+  auto *Domain = Stmt->getDomain();
+  auto *NewStmt = Stmt->getParent()->addScopStmt(
+      OldAcc, MemAccessA->getAccessRelation(), isl_set_copy(Domain)); // Domain is copied here because it is passed on again below.
+  ExtMap = isl_map_set_tuple_id(ExtMap, isl_dim_out, NewStmt->getDomainId());
+  Node = createExtensionNode(Node, ExtMap);
+  Node = isl_schedule_node_child(Node, 0);
   AccRel = getMatMulAccRel(MapOldIndVar, MacroParams.Kc, 4, 7);
   FirstDimSize = MacroParams.Nc * MacroParams.Kc / MicroParams.Nr;
   SecondDimSize = MicroParams.Nr;
   SAI = Stmt->getParent()->createScopArrayInfo(
       MemAccessB->getElementType(), "Packed_B", {FirstDimSize, SecondDimSize});
   AccRel = isl_map_set_tuple_id(AccRel, isl_dim_out, SAI->getBasePtrId());
+  OldAcc = MemAccessB->getAccessRelation(); // Saved before the access is redirected to Packed_B.
   MemAccessB->setNewAccessRelation(AccRel);
+  ExtMap = getMatMulExt(Stmt->getIslCtx(), 0, MacroParams.Nc, MacroParams.Kc);
+  ExtMap = isl_map_move_dims(ExtMap, isl_dim_out, 0, isl_dim_in, 1, 1); // isl_map_move_dims consumes its argument; only the returned map is valid.
+  ExtMap = isl_map_move_dims(ExtMap, isl_dim_in, 2, isl_dim_out, 0, 1);
+  NewStmt = Stmt->getParent()->addScopStmt(
+      OldAcc, MemAccessB->getAccessRelation(), Domain);
+  ExtMap = isl_map_set_tuple_id(ExtMap, isl_dim_out, NewStmt->getDomainId());
+  Node = createExtensionNode(Node, ExtMap);
+  Node = isl_schedule_node_child(isl_schedule_node_child(Node, 0), 0); // Descend four levels in total (this line and the return).
+  return isl_schedule_node_child(isl_schedule_node_child(Node, 0), 0);
 }
 
 /// Get a relation mapping induction variables produced by schedule
@@ -842,9 +944,8 @@ __isl_give isl_schedule_node *ScheduleTr
       Node, MicroKernelParams, MacroKernelParams);
   if (!MapOldIndVar)
     return Node;
-  optimizeDataLayoutMatrMulPattern(MapOldIndVar, MicroKernelParams,
-                                   MacroKernelParams);
-  return Node;
+  return optimizeDataLayoutMatrMulPattern(Node, MapOldIndVar, MicroKernelParams,
+                                          MacroKernelParams);
 }
 
 bool ScheduleTreeOptimizer::isMatrMultPattern(
@@ -901,7 +1002,7 @@ __isl_give isl_schedule_node *ScheduleTr
 }
 
 bool ScheduleTreeOptimizer::isProfitableSchedule(
-    Scop &S, __isl_keep isl_union_map *NewSchedule) {
+    Scop &S, __isl_keep isl_schedule *NewSchedule) {
   // To understand if the schedule has been optimized we check if the schedule
   // has changed at all.
   // TODO: We can improve this by tracking if any necessarily beneficial
@@ -911,9 +1012,15 @@ bool ScheduleTreeOptimizer::isProfitable
   // optimizations, by comparing (yet to be defined) performance metrics
   // before/after the scheduling optimizer
   // (e.g., #stride-one accesses)
+  if (S.containsExtensionNode(NewSchedule)) // A schedule with extension nodes is accepted without any comparison.
+    return true;
+  auto *NewScheduleMap = isl_schedule_get_map(NewSchedule); // Union-map representation of the new schedule tree, owned by this function.
   isl_union_map *OldSchedule = S.getSchedule();
-  bool changed = !isl_union_map_is_equal(OldSchedule, NewSchedule);
+  assert(OldSchedule && "Only IslScheduleOptimizer can insert extension nodes "
+                        "that make Scop::getSchedule() return nullptr.");
+  bool changed = !isl_union_map_is_equal(OldSchedule, NewScheduleMap);
   isl_union_map_free(OldSchedule);
+  isl_union_map_free(NewScheduleMap); // Freed here; NewSchedule itself is __isl_keep and stays with the caller.
   return changed;
 }
 
@@ -1090,10 +1197,8 @@ bool IslScheduleOptimizer::runOnScop(Sco
   auto *TTI = &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
   isl_schedule *NewSchedule =
       ScheduleTreeOptimizer::optimizeSchedule(Schedule, TTI);
-  isl_union_map *NewScheduleMap = isl_schedule_get_map(NewSchedule);
 
-  if (!ScheduleTreeOptimizer::isProfitableSchedule(S, NewScheduleMap)) {
-    isl_union_map_free(NewScheduleMap);
+  if (!ScheduleTreeOptimizer::isProfitableSchedule(S, NewSchedule)) {
     isl_schedule_free(NewSchedule);
     return false;
   }
@@ -1104,7 +1209,6 @@ bool IslScheduleOptimizer::runOnScop(Sco
   if (OptimizedScops)
     S.dump();
 
-  isl_union_map_free(NewScheduleMap);
   return false;
 }
 

Modified: polly/trunk/test/ScheduleOptimizer/mat_mul_pattern_data_layout.ll
URL: http://llvm.org/viewvc/llvm-project/polly/trunk/test/ScheduleOptimizer/mat_mul_pattern_data_layout.ll?rev=281441&r1=281440&r2=281441&view=diff
==============================================================================
--- polly/trunk/test/ScheduleOptimizer/mat_mul_pattern_data_layout.ll (original)
+++ polly/trunk/test/ScheduleOptimizer/mat_mul_pattern_data_layout.ll Wed Sep 14 01:26:09 2016
@@ -12,11 +12,34 @@
 ; CHECK:        double Packed_A[ { [] -> [(1024)] } ][ { [] -> [(4)] } ]; // Element size 8
 ; CHECK:        double Packed_B[ { [] -> [(3072)] } ][ { [] -> [(8)] } ]; // Element size 8
 ;
-; CHECK:                { Stmt_bb14[i0, i1, i2] -> MemRef_arg6[i0, i2] };
-; CHECK:           new: { Stmt_bb14[i0, i1, i2] -> Packed_A[o0, o1] : 256*floor((-i2 + o0)/256) = -i2 + o0 and 4*floor((-i0 + o1)/4) = -i0 + o1 and 0 <= o1 <= 3 and -3 + i0 - 16*floor((i0)/16) <= 4*floor((o0)/256) <= i0 - 16*floor((i0)/16) };
+; CHECK:                { Stmt_Copy_0[i0, i1, i2] -> MemRef_arg6[i0, i2] };
+; CHECK:           new: { Stmt_Copy_0[i0, i1, i2] -> Packed_A[o0, o1] : 256*floor((-i2 + o0)/256) = -i2 + o0 and 4*floor((-i0 + o1)/4) = -i0 + o1 and 0 <= o1 <= 3 and -3 + i0 - 16*floor((i0)/16) <= 4*floor((o0)/256) <= i0 - 16*floor((i0)/16) };
 ;
-; CHECK:                { Stmt_bb14[i0, i1, i2] -> MemRef_arg7[i2, i1] };
-; CHECK:           new: { Stmt_bb14[i0, i1, i2] -> Packed_B[o0, o1] : 256*floor((-i2 + o0)/256) = -i2 + o0 and 8*floor((-i1 + o1)/8) = -i1 + o1 and 0 <= o1 <= 7 and -7 + i1 - 96*floor((i1)/96) <= 8*floor((o0)/256) <= i1 - 96*floor((i1)/96) };
+; CHECK:                { Stmt_Copy_0[i0, i1, i2] -> MemRef_arg7[i2, i1] };
+; CHECK:           new: { Stmt_Copy_0[i0, i1, i2] -> Packed_B[o0, o1] : 256*floor((-i2 + o0)/256) = -i2 + o0 and 8*floor((-i1 + o1)/8) = -i1 + o1 and 0 <= o1 <= 7 and -7 + i1 - 96*floor((i1)/96) <= 8*floor((o0)/256) <= i1 - 96*floor((i1)/96) };
+;
+; CHECK:    	CopyStmt_0
+; CHECK:            Domain :=
+; CHECK:                { CopyStmt_0[i0, i1, i2] : 0 <= i0 <= 1055 and 0 <= i1 <= 1055 and 0 <= i2 <= 1023 };
+; CHECK:            Schedule :=
+; CHECK:                ;
+; CHECK:            MustWriteAccess :=	[Reduction Type: NONE] [Scalar: 0]
+; CHECK:                null;
+; CHECK:           new: { CopyStmt_0[i0, i1, i2] -> Packed_A[o0, o1] : 256*floor((-i2 + o0)/256) = -i2 + o0 and 4*floor((-i0 + o1)/4) = -i0 + o1 and 0 <= o1 <= 3 and -3 + i0 - 16*floor((i0)/16) <= 4*floor((o0)/256) <= i0 - 16*floor((i0)/16) };
+; CHECK:            ReadAccess :=	[Reduction Type: NONE] [Scalar: 0]
+; CHECK:                null;
+; CHECK:           new: { CopyStmt_0[i0, i1, i2] -> MemRef_arg6[i0, i2] };
+; CHECK:    	CopyStmt_1
+; CHECK:            Domain :=
+; CHECK:                { CopyStmt_1[i0, i1, i2] : 0 <= i0 <= 1055 and 0 <= i1 <= 1055 and 0 <= i2 <= 1023 };
+; CHECK:            Schedule :=
+; CHECK:                ;
+; CHECK:            MustWriteAccess :=	[Reduction Type: NONE] [Scalar: 0]
+; CHECK:                null;
+; CHECK:           new: { CopyStmt_1[i0, i1, i2] -> Packed_B[o0, o1] : 256*floor((-i2 + o0)/256) = -i2 + o0 and 8*floor((-i1 + o1)/8) = -i1 + o1 and 0 <= o1 <= 7 and -7 + i1 - 96*floor((i1)/96) <= 8*floor((o0)/256) <= i1 - 96*floor((i1)/96) };
+; CHECK:            ReadAccess :=	[Reduction Type: NONE] [Scalar: 0]
+; CHECK:                null;
+; CHECK:           new: { CopyStmt_1[i0, i1, i2] -> MemRef_arg7[i2, i1] };
 ;
 target datalayout = "e-m:e-i64:64-f80:128-n8:16:32:64-S128"
 target triple = "x86_64-unknown-unknown"
@@ -35,10 +58,10 @@ bb9:
   %tmp12 = load double, double* %tmp11, align 8
   %tmp13 = fmul double %tmp12, %arg4
   store double %tmp13, double* %tmp11, align 8
-  br label %bb14
+  br label %Copy_0
 
-bb14:                                             ; preds = %bb14, %bb9
-  %tmp15 = phi i64 [ 0, %bb9 ], [ %tmp24, %bb14 ]
+Copy_0:                                             ; preds = %Copy_0, %bb9
+  %tmp15 = phi i64 [ 0, %bb9 ], [ %tmp24, %Copy_0 ]
   %tmp16 = getelementptr inbounds [1024 x double], [1024 x double]* %arg6, i64 %tmp, i64 %tmp15
   %tmp17 = load double, double* %tmp16, align 8
   %tmp18 = fmul double %tmp17, %arg3
@@ -50,9 +73,9 @@ bb14:
   store double %tmp23, double* %tmp11, align 8
   %tmp24 = add nuw nsw i64 %tmp15, 1
   %tmp25 = icmp ne i64 %tmp24, 1024
-  br i1 %tmp25, label %bb14, label %bb26
+  br i1 %tmp25, label %Copy_0, label %bb26
 
-bb26:                                             ; preds = %bb14
+bb26:                                             ; preds = %Copy_0
   %tmp27 = add nuw nsw i64 %tmp10, 1
   %tmp28 = icmp ne i64 %tmp27, 1056
   br i1 %tmp28, label %bb9, label %bb29




More information about the llvm-commits mailing list