[polly] r273397 - Apply all necessary tilings and unrollings to get a micro-kernel

Roman Gareev via llvm-commits llvm-commits at lists.llvm.org
Wed Jun 22 02:52:38 PDT 2016


Author: romangareev
Date: Wed Jun 22 04:52:37 2016
New Revision: 273397

URL: http://llvm.org/viewvc/llvm-project?rev=273397&view=rev
Log:
Apply all necessary tilings and unrollings to get a micro-kernel

This is the first patch to apply the BLIS matmul optimization pattern
to matmul kernels
(http://www.cs.utexas.edu/users/flame/pubs/TOMS-BLIS-Analytical.pdf).
BLIS implements gemm as three nested loops around a macro-kernel,
plus two packing routines. The macro-kernel is implemented in terms
of two additional loops around a micro-kernel. The micro-kernel
is a loop around a rank-1 (i.e., outer product) update.
In this change we create the BLIS micro-kernel by applying
a combination of tiling and unrolling. In subsequent changes
we will add the extraction of the BLIS macro-kernel
and implement the packing transformation.
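
For reference, the loop structure described above can be sketched in C++
roughly as follows. This is an illustrative sketch only, not code generated
by this patch: the packing of A and B is omitted, and the block sizes
Mc, Nc, Kc, Mr, Nr are hypothetical parameters rather than values taken
from Polly.

// Schematic BLIS-like gemm: three loops around a macro-kernel, which is
// two loops around a micro-kernel, which is a loop of rank-1 updates.
// Assumes row-major matrices whose dimensions are divisible by the
// block sizes.
void blis_like_gemm(int M, int N, int K, const double *A, const double *B,
                    double *C, int Mc, int Nc, int Kc, int Mr, int Nr) {
  for (int jc = 0; jc < N; jc += Nc)     // three outer loops around the
    for (int pc = 0; pc < K; pc += Kc)   // macro-kernel (the two packing
      for (int ic = 0; ic < M; ic += Mc) // routines would be invoked here)
        // Macro-kernel: two loops around the micro-kernel.
        for (int jr = jc; jr < jc + Nc; jr += Nr)
          for (int ir = ic; ir < ic + Mc; ir += Mr)
            // Micro-kernel: a loop of rank-1 (i.e., outer product) updates
            // of an Mr x Nr block of C that is meant to stay in registers.
            for (int k = pc; k < pc + Kc; ++k)
              for (int i = 0; i < Mr; ++i)
                for (int j = 0; j < Nr; ++j)
                  C[(ir + i) * N + jr + j] +=
                      A[(ir + i) * K + k] * B[k * N + jr + j];
}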

Contributed-by: Roman Gareev <gareevroman at gmail.com>
Reviewed-by: Tobias Grosser <tobias at grosser.es>

Differential Revision: http://reviews.llvm.org/D21140

Added:
    polly/trunk/test/ScheduleOptimizer/pattern-matching-based-opts_3.ll
Modified:
    polly/trunk/include/polly/ScheduleOptimizer.h
    polly/trunk/lib/Transform/ScheduleOptimizer.cpp

Modified: polly/trunk/include/polly/ScheduleOptimizer.h
URL: http://llvm.org/viewvc/llvm-project/polly/trunk/include/polly/ScheduleOptimizer.h?rev=273397&r1=273396&r2=273397&view=diff
==============================================================================
--- polly/trunk/include/polly/ScheduleOptimizer.h (original)
+++ polly/trunk/include/polly/ScheduleOptimizer.h Wed Jun 22 04:52:37 2016
@@ -13,6 +13,7 @@
 #define POLLY_SCHEDULE_OPTIMIZER_H
 
 #include "llvm/ADT/ArrayRef.h"
+#include "llvm/Analysis/TargetTransformInfo.h"
 #include "isl/ctx.h"
 
 struct isl_schedule;
@@ -37,9 +38,11 @@ public:
   ///
   /// @param Schedule The schedule object the transformations will be applied
   ///                 to.
+  /// @param TTI      Target Transform Info.
   /// @returns        The transformed schedule.
   static __isl_give isl_schedule *
-  optimizeSchedule(__isl_take isl_schedule *Schedule);
+  optimizeSchedule(__isl_take isl_schedule *Schedule,
+                   const llvm::TargetTransformInfo *TTI = nullptr);
 
   /// @brief Apply schedule tree transformations.
   ///
@@ -51,9 +54,11 @@ public:
   ///   - Prevectorization
   ///
   /// @param Node The schedule object post-transformations will be applied to.
+  /// @param TTI  Target Transform Info.
   /// @returns    The transformed schedule.
   static __isl_give isl_schedule_node *
-  optimizeScheduleNode(__isl_take isl_schedule_node *Node);
+  optimizeScheduleNode(__isl_take isl_schedule_node *Node,
+                       const llvm::TargetTransformInfo *TTI = nullptr);
 
   /// @brief Decide if the @p NewSchedule is profitable for @p S.
   ///
@@ -100,6 +105,32 @@ private:
   applyRegisterTiling(__isl_take isl_schedule_node *Node,
                       llvm::ArrayRef<int> TileSizes, int DefaultTileSize);
 
+  /// @brief Apply the BLIS matmul optimization pattern
+  ///
+  /// Apply the BLIS matmul optimization pattern
+  /// (http://www.cs.utexas.edu/users/flame/pubs/TOMS-BLIS-Analytical.pdf).
+  /// BLIS implements gemm as three nested loops around a macro-kernel,
+  /// plus two packing routines. The macro-kernel is implemented in terms
+  /// of two additional loops around a micro-kernel. The micro-kernel
+  /// is a loop around a rank-1 (i.e., outer product) update.
+  ///
+  /// We create the BLIS micro-kernel by applying a combination of tiling
+  /// and unrolling. In subsequent changes we will add the extraction
+  /// of the BLIS macro-kernel and implement the packing transformation.
+  ///
+  /// It is assumed that @p Node has already been checked successfully
+  /// by ScheduleTreeOptimizer::isMatrMultPattern. Consequently, in the case
+  /// of matmul kernels, the application of optimizeMatMulPattern can lead
+  /// to close-to-peak performance. It may be possible to generalize it
+  /// to effectively optimize the whole class of successfully checked
+  /// statements.
+  ///
+  /// @param Node the node that contains a band to be optimized.
+  /// @return Modified isl_schedule_node.
+  static __isl_give isl_schedule_node *
+  optimizeMatMulPattern(__isl_take isl_schedule_node *Node,
+                        const llvm::TargetTransformInfo *TTI);
+
   /// @brief Check if this node is a band node we want to tile.
   ///
   /// We look for innermost band nodes where individual dimensions are marked as

Modified: polly/trunk/lib/Transform/ScheduleOptimizer.cpp
URL: http://llvm.org/viewvc/llvm-project/polly/trunk/lib/Transform/ScheduleOptimizer.cpp?rev=273397&r1=273396&r2=273397&view=diff
==============================================================================
--- polly/trunk/lib/Transform/ScheduleOptimizer.cpp (original)
+++ polly/trunk/lib/Transform/ScheduleOptimizer.cpp Wed Jun 22 04:52:37 2016
@@ -53,6 +53,7 @@
 #include "polly/Options.h"
 #include "polly/ScopInfo.h"
 #include "polly/Support/GICHelper.h"
+#include "llvm/Analysis/TargetTransformInfo.h"
 #include "llvm/Support/Debug.h"
 #include "isl/aff.h"
 #include "isl/band.h"
@@ -119,6 +120,20 @@ static cl::opt<bool> FirstLevelTiling("p
                                       cl::init(true), cl::ZeroOrMore,
                                       cl::cat(PollyCategory));
 
+static cl::opt<int> LatencyVectorFma(
+    "polly-target-latency-vector-fma",
+    cl::desc("The minimal number of cycles between issuing two "
+             "dependent consecutive vector fused multiply-add "
+             "instructions."),
+    cl::Hidden, cl::init(8), cl::ZeroOrMore, cl::cat(PollyCategory));
+
+static cl::opt<int> ThrougputVectorFma(
+    "polly-target-througput-vector-fma",
+    cl::desc("A throughput of the processor floating-point arithmetic units "
+             "expressed in the number of vector fused multiply-add "
+             "instructions per clock cycle."),
+    cl::Hidden, cl::init(1), cl::ZeroOrMore, cl::cat(PollyCategory));
+
 static cl::opt<int> FirstLevelDefaultTileSize(
     "polly-default-tile-size",
     cl::desc("The default tile size (if not enough were provided by"
@@ -478,6 +493,23 @@ static __isl_give isl_map *circularShift
   return isl_map_set_tuple_id(IslMap, isl_dim_in, InputDimsId);
 }
 
+__isl_give isl_schedule_node *ScheduleTreeOptimizer::optimizeMatMulPattern(
+    __isl_take isl_schedule_node *Node, const llvm::TargetTransformInfo *TTI) {
+  assert(TTI && "The target transform info should be provided.");
+  // Get a micro-kernel.
+  // Nvec - Number of double-precision floating-point numbers that can be held
+  // by a vector register. Use 2 by default.
+  auto Nvec = TTI->getRegisterBitWidth(true) / 64;
+  if (Nvec == 0)
+    Nvec = 2;
+  int Nr =
+      ceil(sqrt(Nvec * LatencyVectorFma * ThrougputVectorFma) / Nvec) * Nvec;
+  int Mr = ceil(Nvec * LatencyVectorFma * ThrougputVectorFma / Nr);
+  std::vector<int> MicroKernelParams{Mr, Nr};
+  Node = applyRegisterTiling(Node, MicroKernelParams, 1);
+  return Node;
+}
+
 bool ScheduleTreeOptimizer::isMatrMultPattern(
     __isl_keep isl_schedule_node *Node) {
   auto *PartialSchedule =
@@ -508,16 +540,21 @@ ScheduleTreeOptimizer::optimizeBand(__is
   if (!isTileableBandNode(Node))
     return Node;
 
-  if (PMBasedOpts && isMatrMultPattern(Node))
+  if (PMBasedOpts && User && isMatrMultPattern(Node)) {
     DEBUG(dbgs() << "The matrix multiplication pattern was detected\n");
+    const llvm::TargetTransformInfo *TTI;
+    TTI = static_cast<const llvm::TargetTransformInfo *>(User);
+    Node = optimizeMatMulPattern(Node, TTI);
+  }
 
   return standardBandOpts(Node, User);
 }
 
 __isl_give isl_schedule *
-ScheduleTreeOptimizer::optimizeSchedule(__isl_take isl_schedule *Schedule) {
+ScheduleTreeOptimizer::optimizeSchedule(__isl_take isl_schedule *Schedule,
+                                        const llvm::TargetTransformInfo *TTI) {
   isl_schedule_node *Root = isl_schedule_get_root(Schedule);
-  Root = optimizeScheduleNode(Root);
+  Root = optimizeScheduleNode(Root, TTI);
   isl_schedule_free(Schedule);
   auto S = isl_schedule_node_get_schedule(Root);
   isl_schedule_node_free(Root);
@@ -525,8 +562,9 @@ ScheduleTreeOptimizer::optimizeSchedule(
 }
 
 __isl_give isl_schedule_node *ScheduleTreeOptimizer::optimizeScheduleNode(
-    __isl_take isl_schedule_node *Node) {
-  Node = isl_schedule_node_map_descendant_bottom_up(Node, optimizeBand, NULL);
+    __isl_take isl_schedule_node *Node, const llvm::TargetTransformInfo *TTI) {
+  Node = isl_schedule_node_map_descendant_bottom_up(
+      Node, optimizeBand, const_cast<void *>(static_cast<const void *>(TTI)));
   return Node;
 }
 
@@ -714,7 +752,10 @@ bool IslScheduleOptimizer::runOnScop(Sco
     isl_printer_free(P);
   });
 
-  isl_schedule *NewSchedule = ScheduleTreeOptimizer::optimizeSchedule(Schedule);
+  Function &F = S.getFunction();
+  auto *TTI = &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
+  isl_schedule *NewSchedule =
+      ScheduleTreeOptimizer::optimizeSchedule(Schedule, TTI);
   isl_union_map *NewScheduleMap = isl_schedule_get_map(NewSchedule);
 
   if (!ScheduleTreeOptimizer::isProfitableSchedule(S, NewScheduleMap)) {
@@ -752,6 +793,7 @@ void IslScheduleOptimizer::printScop(raw
 void IslScheduleOptimizer::getAnalysisUsage(AnalysisUsage &AU) const {
   ScopPass::getAnalysisUsage(AU);
   AU.addRequired<DependenceInfo>();
+  AU.addRequired<TargetTransformInfoWrapperPass>();
 }
 
 Pass *polly::createIslScheduleOptimizerPass() {
@@ -762,5 +804,6 @@ INITIALIZE_PASS_BEGIN(IslScheduleOptimiz
                       "Polly - Optimize schedule of SCoP", false, false);
 INITIALIZE_PASS_DEPENDENCY(DependenceInfo);
 INITIALIZE_PASS_DEPENDENCY(ScopInfoRegionPass);
+INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass);
 INITIALIZE_PASS_END(IslScheduleOptimizer, "polly-opt-isl",
                     "Polly - Optimize schedule of SCoP", false, false)

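To illustrate the micro-kernel sizing in optimizeMatMulPattern above, here is
a minimal standalone C++ sketch of the same arithmetic (using floating-point
division for clarity), assuming 256-bit vector registers, i.e. Nvec = 4
doubles per register, as for the AVX target in the test case added below, and
the default values of -polly-target-latency-vector-fma (8) and
-polly-target-througput-vector-fma (1). It yields Mr = 4 and Nr = 8, which
matches the 4x8 register tiling checked in that test.

#include <cmath>
#include <cstdio>

int main() {
  // Assumption: 256-bit vector registers, i.e. 4 doubles per register.
  int Nvec = 256 / 64;
  int Latency = 8;    // default of -polly-target-latency-vector-fma
  int Throughput = 1; // default of -polly-target-througput-vector-fma
  int Nr =
      std::ceil(std::sqrt(double(Nvec * Latency * Throughput)) / Nvec) * Nvec;
  int Mr = std::ceil(double(Nvec * Latency * Throughput) / Nr);
  std::printf("Mr = %d, Nr = %d\n", Mr, Nr); // prints: Mr = 4, Nr = 8
  return 0;
}
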
Added: polly/trunk/test/ScheduleOptimizer/pattern-matching-based-opts_3.ll
URL: http://llvm.org/viewvc/llvm-project/polly/trunk/test/ScheduleOptimizer/pattern-matching-based-opts_3.ll?rev=273397&view=auto
==============================================================================
--- polly/trunk/test/ScheduleOptimizer/pattern-matching-based-opts_3.ll (added)
+++ polly/trunk/test/ScheduleOptimizer/pattern-matching-based-opts_3.ll Wed Jun 22 04:52:37 2016
@@ -0,0 +1,128 @@
+; RUN: opt %loadPolly -polly-opt-isl -polly-pattern-matching-based-opts=true -polly-target-througput-vector-fma=1 -polly-target-latency-vector-fma=8 -analyze -polly-ast < %s 2>&1 | FileCheck %s
+;
+;    /* C := alpha*A*B + beta*C */
+;    for (i = 0; i < _PB_NI; i++)
+;      for (j = 0; j < _PB_NJ; j++)
+;        {
+;	   C[i][j] *= beta;
+;	   for (k = 0; k < _PB_NK; ++k)
+;	     C[i][j] += alpha * A[i][k] * B[k][j];
+;        }
+;
+; CHECK:    {
+; CHECK:      // 1st level tiling - Tiles
+; CHECK:      for (int c0 = 0; c0 <= 32; c0 += 1)
+; CHECK:        for (int c1 = 0; c1 <= 32; c1 += 1) {
+; CHECK:          // 1st level tiling - Points
+; CHECK:          for (int c2 = 0; c2 <= 31; c2 += 1)
+; CHECK:            for (int c3 = 0; c3 <= 31; c3 += 1)
+; CHECK:              Stmt_bb14(32 * c0 + c2, 32 * c1 + c3);
+; CHECK:        }
+; CHECK:      // Register tiling - Tiles
+; CHECK:      for (int c0 = 0; c0 <= 263; c0 += 1)
+; CHECK:        for (int c1 = 0; c1 <= 131; c1 += 1)
+; CHECK:          for (int c2 = 0; c2 <= 1023; c2 += 1) {
+; CHECK:            // Register tiling - Points
+; CHECK:            // 1st level tiling - Tiles
+; CHECK:            // 1st level tiling - Points
+; CHECK:            {
+; CHECK:              Stmt_bb24(4 * c0, 8 * c1, c2);
+; CHECK:              Stmt_bb24(4 * c0, 8 * c1 + 1, c2);
+; CHECK:              Stmt_bb24(4 * c0, 8 * c1 + 2, c2);
+; CHECK:              Stmt_bb24(4 * c0, 8 * c1 + 3, c2);
+; CHECK:              Stmt_bb24(4 * c0, 8 * c1 + 4, c2);
+; CHECK:              Stmt_bb24(4 * c0, 8 * c1 + 5, c2);
+; CHECK:              Stmt_bb24(4 * c0, 8 * c1 + 6, c2);
+; CHECK:              Stmt_bb24(4 * c0, 8 * c1 + 7, c2);
+; CHECK:              Stmt_bb24(4 * c0 + 1, 8 * c1, c2);
+; CHECK:              Stmt_bb24(4 * c0 + 1, 8 * c1 + 1, c2);
+; CHECK:              Stmt_bb24(4 * c0 + 1, 8 * c1 + 2, c2);
+; CHECK:              Stmt_bb24(4 * c0 + 1, 8 * c1 + 3, c2);
+; CHECK:              Stmt_bb24(4 * c0 + 1, 8 * c1 + 4, c2);
+; CHECK:              Stmt_bb24(4 * c0 + 1, 8 * c1 + 5, c2);
+; CHECK:              Stmt_bb24(4 * c0 + 1, 8 * c1 + 6, c2);
+; CHECK:              Stmt_bb24(4 * c0 + 1, 8 * c1 + 7, c2);
+; CHECK:              Stmt_bb24(4 * c0 + 2, 8 * c1, c2);
+; CHECK:              Stmt_bb24(4 * c0 + 2, 8 * c1 + 1, c2);
+; CHECK:              Stmt_bb24(4 * c0 + 2, 8 * c1 + 2, c2);
+; CHECK:              Stmt_bb24(4 * c0 + 2, 8 * c1 + 3, c2);
+; CHECK:              Stmt_bb24(4 * c0 + 2, 8 * c1 + 4, c2);
+; CHECK:              Stmt_bb24(4 * c0 + 2, 8 * c1 + 5, c2);
+; CHECK:              Stmt_bb24(4 * c0 + 2, 8 * c1 + 6, c2);
+; CHECK:              Stmt_bb24(4 * c0 + 2, 8 * c1 + 7, c2);
+; CHECK:              Stmt_bb24(4 * c0 + 3, 8 * c1, c2);
+; CHECK:              Stmt_bb24(4 * c0 + 3, 8 * c1 + 1, c2);
+; CHECK:              Stmt_bb24(4 * c0 + 3, 8 * c1 + 2, c2);
+; CHECK:              Stmt_bb24(4 * c0 + 3, 8 * c1 + 3, c2);
+; CHECK:              Stmt_bb24(4 * c0 + 3, 8 * c1 + 4, c2);
+; CHECK:              Stmt_bb24(4 * c0 + 3, 8 * c1 + 5, c2);
+; CHECK:              Stmt_bb24(4 * c0 + 3, 8 * c1 + 6, c2);
+; CHECK:              Stmt_bb24(4 * c0 + 3, 8 * c1 + 7, c2);
+; CHECK:            }
+; CHECK:          }
+; CHECK:    }
+;
+target datalayout = "e-m:e-i64:64-f80:128-n8:16:32:64-S128"
+target triple = "x86_64-unknown-unknown"
+
+define internal void @kernel_gemm(i32 %arg, i32 %arg1, i32 %arg2, double %arg3, double %arg4, [1056 x double]* %arg5, [1024 x double]* %arg6, [1056 x double]* %arg7) #0 {
+bb:
+  br label %bb8
+
+bb8:                                              ; preds = %bb39, %bb
+  %tmp = phi i32 [ 0, %bb ], [ %tmp40, %bb39 ]
+  %tmp9 = icmp slt i32 %tmp, 1056
+  br i1 %tmp9, label %bb10, label %bb41
+
+bb10:                                             ; preds = %bb8
+  br label %bb11
+
+bb11:                                             ; preds = %bb37, %bb10
+  %tmp12 = phi i32 [ 0, %bb10 ], [ %tmp38, %bb37 ]
+  %tmp13 = icmp slt i32 %tmp12, 1056
+  br i1 %tmp13, label %bb14, label %bb39
+
+bb14:                                             ; preds = %bb11
+  %tmp15 = sext i32 %tmp12 to i64
+  %tmp16 = sext i32 %tmp to i64
+  %tmp17 = getelementptr inbounds [1056 x double], [1056 x double]* %arg5, i64 %tmp16
+  %tmp18 = getelementptr inbounds [1056 x double], [1056 x double]* %tmp17, i64 0, i64 %tmp15
+  %tmp19 = load double, double* %tmp18, align 8
+  %tmp20 = fmul double %tmp19, %arg4
+  store double %tmp20, double* %tmp18, align 8
+  br label %bb21
+
+bb21:                                             ; preds = %bb24, %bb14
+  %tmp22 = phi i32 [ 0, %bb14 ], [ %tmp36, %bb24 ]
+  %tmp23 = icmp slt i32 %tmp22, 1024
+  br i1 %tmp23, label %bb24, label %bb37
+
+bb24:                                             ; preds = %bb21
+  %tmp25 = sext i32 %tmp22 to i64
+  %tmp26 = getelementptr inbounds [1024 x double], [1024 x double]* %arg6, i64 %tmp16
+  %tmp27 = getelementptr inbounds [1024 x double], [1024 x double]* %tmp26, i64 0, i64 %tmp25
+  %tmp28 = load double, double* %tmp27, align 8
+  %tmp29 = fmul double %arg3, %tmp28
+  %tmp30 = getelementptr inbounds [1056 x double], [1056 x double]* %arg7, i64 %tmp25
+  %tmp31 = getelementptr inbounds [1056 x double], [1056 x double]* %tmp30, i64 0, i64 %tmp15
+  %tmp32 = load double, double* %tmp31, align 8
+  %tmp33 = fmul double %tmp29, %tmp32
+  %tmp34 = load double, double* %tmp18, align 8
+  %tmp35 = fadd double %tmp34, %tmp33
+  store double %tmp35, double* %tmp18, align 8
+  %tmp36 = add nsw i32 %tmp22, 1
+  br label %bb21
+
+bb37:                                             ; preds = %bb21
+  %tmp38 = add nsw i32 %tmp12, 1
+  br label %bb11
+
+bb39:                                             ; preds = %bb11
+  %tmp40 = add nsw i32 %tmp, 1
+  br label %bb8
+
+bb41:                                             ; preds = %bb8
+  ret void
+}
+
+attributes #0 = { nounwind uwtable "target-cpu"="x86-64" "target-features"="+aes,+avx,+cmov,+cx16,+fxsr,+mmx,+pclmul,+popcnt,+sse,+sse2,+sse3,+sse4.1,+sse4.2,+ssse3,+x87,+xsave,+xsaveopt" }
