[polly] r271128 - Determination of statements that contain matrix multiplication

Johannes Doerfert via llvm-commits llvm-commits at lists.llvm.org
Mon May 30 02:04:39 PDT 2016


Hey Roman,

First, thanks for the patch!

I inlined three small change requests you can think about.

Cheers,
  Johannes

On 05/28, Roman Gareev via llvm-commits wrote:
> Author: romangareev
> Date: Sat May 28 11:17:58 2016
> New Revision: 271128
> 
> URL: http://llvm.org/viewvc/llvm-project?rev=271128&view=rev
> Log:
> Determination of statements that contain matrix multiplication
> 
> Add determination of statements that contain, in particular,
> matrix multiplications and can be optimized with [1] to try to
> get close-to-peak performance. It can be enabled
> via polly-pattern-matching-based-opts, which is false by default.
> 
> Refs:
> [1] - http://www.cs.utexas.edu/users/flame/pubs/TOMS-BLIS-Analytical.pdf
> 
> Contributed-by: Roman Gareev <gareevroman at gmail.com>
> Reviewed-by: Tobias Grosser <tobias at grosser.es>
> 
> Differential Revision: http://reviews.llvm.org/D20575
> 
> Added:
>     polly/trunk/test/ScheduleOptimizer/pattern-matching-based-opts.ll
>     polly/trunk/test/ScheduleOptimizer/pattern-matching-based-opts_2.ll
> Modified:
>     polly/trunk/include/polly/ScheduleOptimizer.h
>     polly/trunk/lib/Transform/ScheduleOptimizer.cpp
> 
> Modified: polly/trunk/include/polly/ScheduleOptimizer.h
> URL: http://llvm.org/viewvc/llvm-project/polly/trunk/include/polly/ScheduleOptimizer.h?rev=271128&r1=271127&r2=271128&view=diff
> ==============================================================================
> --- polly/trunk/include/polly/ScheduleOptimizer.h (original)
> +++ polly/trunk/include/polly/ScheduleOptimizer.h Sat May 28 11:17:58 2016
> @@ -147,8 +147,45 @@ private:
>    ///      - if vectorization is enabled
>    ///
>    /// @param Node The schedule node to (possibly) optimize.
> -  /// @param User A pointer to forward some use information (currently unused).
> +  /// @param User A pointer to forward some use information
> +  ///        (currently unused).
>    static isl_schedule_node *optimizeBand(isl_schedule_node *Node, void *User);
> +
> +  /// @brief Apply additional optimizations on the bands in the schedule tree.
> +  ///
> +  /// We apply the following
> +  /// transformations:
> +  ///
> +  ///  - Tile the band
> +  ///  - Prevectorize the schedule of the band (or the point loop in case of
> +  ///    tiling).
> +  ///      - if vectorization is enabled
> +  ///
> +  /// @param Node The schedule node to (possibly) optimize.
> +  /// @param User A pointer to forward some use information
> +  ///        (currently unused).
> +  static isl_schedule_node *standardBandOpts(__isl_take isl_schedule_node *Node,
> +                                             void *User);
> +
> +  /// @brief Check if this node contains a partial schedule that could
> +  ///        probably be optimized with analytical modeling.
> +  ///
> +  /// isMatrMultPattern tries to determine whether the following conditions
> +  /// are true:
> +  /// 1. the partial schedule contains only one statement.
> +  /// 2. there are exactly three input dimensions.
> +  /// 3. all memory accesses of the statement will have stride 0 or 1, if we
> +  ///    interchange loops (switch the variable used in the inner loop to
> +  ///    the outer loop).
> +  /// 4. all memory accesses of the statement except from the last one, are
> +  ///    read memory access and the last one is write memory access.
> +  /// 5. all subscripts of the last memory access of the statement don’t
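Can you use the plain ASCII apostrophe (') in the word "don't"? My editor
is fine with this typographic (probably UTF-8) character, but I do not want
to find out which editors are not. The same character also appears in the
comment on containsMatrMult in ScheduleOptimizer.cpp.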
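
One cheap way to catch such characters before committing is to scan for
bytes outside the 7-bit ASCII range. A minimal sketch of the idea follows;
firstNonAsciiByte and checkExamples are made-up names for illustration, not
existing Polly or LLVM helpers:

```cpp
#include <cassert>
#include <cstddef>
#include <string>

// Returns the index of the first byte outside the 7-bit ASCII range,
// or std::string::npos if the text is plain ASCII.
std::size_t firstNonAsciiByte(const std::string &Text) {
  for (std::size_t I = 0; I < Text.size(); ++I)
    if (static_cast<unsigned char>(Text[I]) > 0x7F)
      return I;
  return std::string::npos;
}

// Usage: U+2019 (right single quotation mark) is E2 80 99 in UTF-8, so
// the first offending byte in the second string sits at index 3.
void checkExamples() {
  assert(firstNonAsciiByte("don't") == std::string::npos);
  assert(firstNonAsciiByte("don\xE2\x80\x99t") == 3);
}
```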

> +  ///    contain the variable used in the inner loop.
> +  /// If this is the case, we could try to use an approach that is similar to
> +  /// the one used to get close-to-peak performance of matrix multiplications.
> +  ///
> +  /// @param Node The node to check.
> +  static bool isMatrMultPattern(__isl_keep isl_schedule_node *Node);
>  };
>  
>  #endif
> 
> Modified: polly/trunk/lib/Transform/ScheduleOptimizer.cpp
> URL: http://llvm.org/viewvc/llvm-project/polly/trunk/lib/Transform/ScheduleOptimizer.cpp?rev=271128&r1=271127&r2=271128&view=diff
> ==============================================================================
> --- polly/trunk/lib/Transform/ScheduleOptimizer.cpp (original)
> +++ polly/trunk/lib/Transform/ScheduleOptimizer.cpp Sat May 28 11:17:58 2016
> @@ -166,6 +166,11 @@ static cl::list<int>
>                        cl::Hidden, cl::ZeroOrMore, cl::CommaSeparated,
>                        cl::cat(PollyCategory));
>  
> +static cl::opt<bool>
> +    PMBasedOpts("polly-pattern-matching-based-opts",
> +                cl::desc("Perform optimizations based on pattern matching"),
> +                cl::init(false), cl::ZeroOrMore, cl::cat(PollyCategory));
> +
>  /// @brief Create an isl_union_set, which describes the isolate option based
>  ///        on IsoalteDomain.
>  ///
> @@ -359,11 +364,8 @@ bool ScheduleTreeOptimizer::isTileableBa
>  }
>  
>  __isl_give isl_schedule_node *
> -ScheduleTreeOptimizer::optimizeBand(__isl_take isl_schedule_node *Node,
> -                                    void *User) {
> -  if (!isTileableBandNode(Node))
> -    return Node;
> -
> +ScheduleTreeOptimizer::standardBandOpts(__isl_take isl_schedule_node *Node,
> +                                        void *User) {
>    if (FirstLevelTiling)
>      Node = tileNode(Node, "1st level tiling", FirstLevelTileSizes,
>                      FirstLevelDefaultTileSize);
> @@ -396,6 +398,110 @@ ScheduleTreeOptimizer::optimizeBand(__is
>    return Node;
>  }
>  
> +/// @brief Check whether output dimensions of the map rely on the specified
> +///        input dimension.
> +///
> +/// @param IslMap The isl map to be considered.
> +/// @param DimNum The number of an input dimension to be checked.
> +static bool isInputDimUsed(__isl_take isl_map *IslMap, unsigned DimNum) {
> +  auto *CheckedAccessRelation =
> +      isl_map_project_out(isl_map_copy(IslMap), isl_dim_in, DimNum, 1);
> +  CheckedAccessRelation =
> +      isl_map_insert_dims(CheckedAccessRelation, isl_dim_in, DimNum, 1);
> +  auto *InputDimsId = isl_map_get_tuple_id(IslMap, isl_dim_in);
> +  CheckedAccessRelation =
> +      isl_map_set_tuple_id(CheckedAccessRelation, isl_dim_in, InputDimsId);
> +  InputDimsId = isl_map_get_tuple_id(IslMap, isl_dim_out);
> +  CheckedAccessRelation =
> +      isl_map_set_tuple_id(CheckedAccessRelation, isl_dim_out, InputDimsId);
> +  auto res = !isl_map_is_equal(CheckedAccessRelation, IslMap);
> +  isl_map_free(CheckedAccessRelation);
> +  isl_map_free(IslMap);
> +  return res;
> +}
> +
> +/// @brief Check if the SCoP statement could probably be optimized with
> +///        analytical modeling.
> +///
> +/// containsMatrMult tries to determine whether the following conditions
> +/// are true:
> +/// 1. all memory accesses of the statement will have stride 0 or 1,
> +///    if we interchange loops (switch the variable used in the inner
> +///    loop to the outer loop).
> +/// 2. all memory accesses of the statement except from the last one, are
> +///    read memory access and the last one is write memory access.
> +/// 3. all subscripts of the last memory access of the statement don’t contain
> +///    the variable used in the inner loop.
> +///
> +/// @param PartialSchedule The PartialSchedule that contains a SCoP statement
> +///        to check.
> +static bool containsMatrMult(__isl_keep isl_map *PartialSchedule) {
> +  auto InputDimsId = isl_map_get_tuple_id(PartialSchedule, isl_dim_in);
> +  auto *ScpStmt = static_cast<ScopStmt *>(isl_id_get_user(InputDimsId));
> +  isl_id_free(InputDimsId);
> +  if (ScpStmt->size() <= 1)
> +    return false;
> +  auto MemA = ScpStmt->begin();
> +  for (unsigned i = 0; i < ScpStmt->size() - 2 && MemA != ScpStmt->end();
> +       i++, MemA++)
> +    if (!(*MemA)->isRead() or
> +        ((*MemA)->isArrayKind() and
> +         !((*MemA)->isStrideOne(isl_map_copy(PartialSchedule)) or
> +           (*MemA)->isStrideZero(isl_map_copy(PartialSchedule)))))
> +      return false;
> +  MemA++;
> +  if (!(*MemA)->isWrite() or !(*MemA)->isArrayKind() or
> +      !((*MemA)->isStrideOne(isl_map_copy(PartialSchedule)) or
> +        (*MemA)->isStrideZero(isl_map_copy(PartialSchedule))))
> +    return false;
I just realized (again) that "and" and "or" are valid operator spellings
in C++ (alternative tokens for "&&" and "||"). I would still not use them
here, as the rest of the code base does not.
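
To illustrate, the two conditions could be spelled with the symbolic
operators roughly as below. This is only a sketch: FakeAccess is a
hypothetical stand-in for the memory access class with plain bool fields
(the real isStrideOne/isStrideZero take a schedule argument), and
rejectsRead/rejectsWrite are names invented here.

```cpp
#include <cassert>

// Hypothetical stand-in for a memory access; the fields mirror the
// predicates queried in containsMatrMult.
struct FakeAccess {
  bool Read;
  bool Write;
  bool ArrayKind;
  bool StrideOne;
  bool StrideZero;
};

// Condition used for every access except the last one: bail out unless it
// is a read and, if it is an array access, has stride 0 or 1.
bool rejectsRead(const FakeAccess &A) {
  return !A.Read || (A.ArrayKind && !(A.StrideOne || A.StrideZero));
}

// Condition used for the last access: bail out unless it is an array
// write with stride 0 or 1.
bool rejectsWrite(const FakeAccess &A) {
  return !A.Write || !A.ArrayKind || !(A.StrideOne || A.StrideZero);
}
```

The behavior is identical to the "and"/"or" version; only the spelling of
the operators changes.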

> +  auto DimNum = isl_map_dim(PartialSchedule, isl_dim_in);
> +  return !isInputDimUsed((*MemA)->getAccessRelation(), DimNum - 1);
> +}
> +
> +/// @brief Circular shift of output dimensions of the integer map.
> +///
> +/// @param IslMap The isl map to be modified.
> +static __isl_give isl_map *circularShiftOutputDims(__isl_take isl_map *IslMap) {
> +  auto InputDimsId = isl_map_get_tuple_id(IslMap, isl_dim_in);
> +  auto DimNum = isl_map_dim(IslMap, isl_dim_out);
> +  IslMap = isl_map_move_dims(IslMap, isl_dim_in, 0, isl_dim_out, DimNum - 1, 1);
> +  IslMap = isl_map_move_dims(IslMap, isl_dim_out, 0, isl_dim_in, 0, 1);
> +  return isl_map_set_tuple_id(IslMap, isl_dim_in, InputDimsId);
> +}
> +
> +bool ScheduleTreeOptimizer::isMatrMultPattern(
> +    __isl_keep isl_schedule_node *Node) {
> +  auto *PartialSchedule =
> +      isl_schedule_node_band_get_partial_schedule_union_map(Node);
> +  if (isl_union_map_n_map(PartialSchedule) != 1)
> +    return false;
> +  auto *NewPartialSchedule = isl_map_from_union_map(PartialSchedule);
> +  auto DimNum = isl_map_dim(NewPartialSchedule, isl_dim_in);
> +  if (DimNum != 3) {
> +    isl_map_free(NewPartialSchedule);
> +    return false;
> +  }
> +  NewPartialSchedule = circularShiftOutputDims(NewPartialSchedule);
> +  if (containsMatrMult(NewPartialSchedule)) {
> +    isl_map_free(NewPartialSchedule);
> +    return true;
> +  }
> +  isl_map_free(NewPartialSchedule);
> +  return false;
> +}
> +
> +__isl_give isl_schedule_node *
> +ScheduleTreeOptimizer::optimizeBand(__isl_take isl_schedule_node *Node,
> +                                    void *User) {
> +  if (!isTileableBandNode(Node))
> +    return Node;
> +
> +  if (PMBasedOpts && isMatrMultPattern(Node))
> +    DEBUG(dbgs() << "The matrix multiplication pattern was detected\n");
> +
> +  return standardBandOpts(Node, User);
> +}
> +
>  __isl_give isl_schedule *
>  ScheduleTreeOptimizer::optimizeSchedule(__isl_take isl_schedule *Schedule) {
>    isl_schedule_node *Root = isl_schedule_get_root(Schedule);
> 
> Added: polly/trunk/test/ScheduleOptimizer/pattern-matching-based-opts.ll
> URL: http://llvm.org/viewvc/llvm-project/polly/trunk/test/ScheduleOptimizer/pattern-matching-based-opts.ll?rev=271128&view=auto
> ==============================================================================
> --- polly/trunk/test/ScheduleOptimizer/pattern-matching-based-opts.ll (added)
> +++ polly/trunk/test/ScheduleOptimizer/pattern-matching-based-opts.ll Sat May 28 11:17:58 2016
> @@ -0,0 +1,65 @@
> +; RUN: opt %loadPolly -polly-opt-isl -debug < %s 2>&1| FileCheck %s
> +; RUN: opt %loadPolly -polly-opt-isl -polly-pattern-matching-based-opts=true -debug < %s 2>&1| FileCheck %s --check-prefix=PATTERN-MATCHING-OPTS
> +; REQUIRES: asserts
> +; CHECK-NOT: The matrix multiplication pattern was detected
> +; PATTERN-MATCHING-OPTS: The matrix multiplication pattern was detected
> +
Without the C code and/or a description, the test cases are hard to
understand. For example, it took me a while to diff the two test cases here
"by hand". I usually use the ./test/create_ll.sh script to get IR from
C code, and it already inlines the C code for you. Additionally, it
would be good to write one sentence about a test case if its purpose
is not obvious; e.g., the positive test is pretty straightforward, but
the negative test is not.
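
Read by hand from the IR, both tests appear to correspond to the loop nest
sketched below. This is a reconstruction, not the original source: the
name kernelGemm, the flat row-major vectors, and the parameters NI, NJ, NK
and JStep are assumptions of the sketch. The IR hard-codes NI = NJ = 1056
and NK = 1024. The only difference between the two function bodies seems
to be the increment of the j loop: 1 in pattern-matching-based-opts.ll and
2 in pattern-matching-based-opts_2.ll, which presumably means the accesses
no longer all have stride 0 or 1.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// C is NI x NJ, A is NI x NK, B is NK x NJ, all stored row-major.
// JStep = 1 models the positive test, JStep = 2 the negative one.
void kernelGemm(std::size_t NI, std::size_t NJ, std::size_t NK,
                std::size_t JStep, double Alpha, double Beta,
                std::vector<double> &C, const std::vector<double> &A,
                const std::vector<double> &B) {
  assert(JStep > 0 && C.size() == NI * NJ && A.size() == NI * NK &&
         B.size() == NK * NJ);
  for (std::size_t I = 0; I < NI; ++I)
    for (std::size_t J = 0; J < NJ; J += JStep) {
      // bb14 in the IR: scale C[i][j] by the second scalar argument.
      C[I * NJ + J] *= Beta;
      // bb24 in the IR: accumulate Alpha * A[i][k] * B[k][j].
      for (std::size_t K = 0; K < NK; ++K)
        C[I * NJ + J] += Alpha * A[I * NK + K] * B[K * NJ + J];
    }
}
```

Something along these lines, inlined as a comment at the top of each .ll
file, would make the intent of the tests much easier to see.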

> +define internal void @kernel_gemm(i32 %arg, i32 %arg1, i32 %arg2, double %arg3, double %arg4, [1056 x double]* %arg5, [1024 x double]* %arg6, [1056 x double]* %arg7) {
> +bb:
> +  br label %bb8
> +
> +bb8:                                              ; preds = %bb39, %bb
> +  %tmp = phi i32 [ 0, %bb ], [ %tmp40, %bb39 ]
> +  %tmp9 = icmp slt i32 %tmp, 1056
> +  br i1 %tmp9, label %bb10, label %bb41
> +
> +bb10:                                             ; preds = %bb8
> +  br label %bb11
> +
> +bb11:                                             ; preds = %bb37, %bb10
> +  %tmp12 = phi i32 [ 0, %bb10 ], [ %tmp38, %bb37 ]
> +  %tmp13 = icmp slt i32 %tmp12, 1056
> +  br i1 %tmp13, label %bb14, label %bb39
> +
> +bb14:                                             ; preds = %bb11
> +  %tmp15 = sext i32 %tmp12 to i64
> +  %tmp16 = sext i32 %tmp to i64
> +  %tmp17 = getelementptr inbounds [1056 x double], [1056 x double]* %arg5, i64 %tmp16
> +  %tmp18 = getelementptr inbounds [1056 x double], [1056 x double]* %tmp17, i64 0, i64 %tmp15
> +  %tmp19 = load double, double* %tmp18, align 8
> +  %tmp20 = fmul double %tmp19, %arg4
> +  store double %tmp20, double* %tmp18, align 8
> +  br label %bb21
> +
> +bb21:                                             ; preds = %bb24, %bb14
> +  %tmp22 = phi i32 [ 0, %bb14 ], [ %tmp36, %bb24 ]
> +  %tmp23 = icmp slt i32 %tmp22, 1024
> +  br i1 %tmp23, label %bb24, label %bb37
> +
> +bb24:                                             ; preds = %bb21
> +  %tmp25 = sext i32 %tmp22 to i64
> +  %tmp26 = getelementptr inbounds [1024 x double], [1024 x double]* %arg6, i64 %tmp16
> +  %tmp27 = getelementptr inbounds [1024 x double], [1024 x double]* %tmp26, i64 0, i64 %tmp25
> +  %tmp28 = load double, double* %tmp27, align 8
> +  %tmp29 = fmul double %arg3, %tmp28
> +  %tmp30 = getelementptr inbounds [1056 x double], [1056 x double]* %arg7, i64 %tmp25
> +  %tmp31 = getelementptr inbounds [1056 x double], [1056 x double]* %tmp30, i64 0, i64 %tmp15
> +  %tmp32 = load double, double* %tmp31, align 8
> +  %tmp33 = fmul double %tmp29, %tmp32
> +  %tmp34 = load double, double* %tmp18, align 8
> +  %tmp35 = fadd double %tmp34, %tmp33
> +  store double %tmp35, double* %tmp18, align 8
> +  %tmp36 = add nsw i32 %tmp22, 1
> +  br label %bb21
> +
> +bb37:                                             ; preds = %bb21
> +  %tmp38 = add nsw i32 %tmp12, 1
> +  br label %bb11
> +
> +bb39:                                             ; preds = %bb11
> +  %tmp40 = add nsw i32 %tmp, 1
> +  br label %bb8
> +
> +bb41:                                             ; preds = %bb8
> +  ret void
> +}
> 
> Added: polly/trunk/test/ScheduleOptimizer/pattern-matching-based-opts_2.ll
> URL: http://llvm.org/viewvc/llvm-project/polly/trunk/test/ScheduleOptimizer/pattern-matching-based-opts_2.ll?rev=271128&view=auto
> ==============================================================================
> --- polly/trunk/test/ScheduleOptimizer/pattern-matching-based-opts_2.ll (added)
> +++ polly/trunk/test/ScheduleOptimizer/pattern-matching-based-opts_2.ll Sat May 28 11:17:58 2016
> @@ -0,0 +1,63 @@
> +; RUN: opt %loadPolly -polly-opt-isl -polly-pattern-matching-based-opts=true -debug < %s 2>&1 | FileCheck %s
> +; REQUIRES: asserts
> +; CHECK-NOT: The matrix multiplication pattern was detected
> +
> +define internal void @kernel_gemm(i32 %arg, i32 %arg1, i32 %arg2, double %arg3, double %arg4, [1056 x double]* %arg5, [1024 x double]* %arg6, [1056 x double]* %arg7) {
> +bb:
> +  br label %bb8
> +
> +bb8:                                              ; preds = %bb39, %bb
> +  %tmp = phi i32 [ 0, %bb ], [ %tmp40, %bb39 ]
> +  %tmp9 = icmp slt i32 %tmp, 1056
> +  br i1 %tmp9, label %bb10, label %bb41
> +
> +bb10:                                             ; preds = %bb8
> +  br label %bb11
> +
> +bb11:                                             ; preds = %bb37, %bb10
> +  %tmp12 = phi i32 [ 0, %bb10 ], [ %tmp38, %bb37 ]
> +  %tmp13 = icmp slt i32 %tmp12, 1056
> +  br i1 %tmp13, label %bb14, label %bb39
> +
> +bb14:                                             ; preds = %bb11
> +  %tmp15 = sext i32 %tmp12 to i64
> +  %tmp16 = sext i32 %tmp to i64
> +  %tmp17 = getelementptr inbounds [1056 x double], [1056 x double]* %arg5, i64 %tmp16
> +  %tmp18 = getelementptr inbounds [1056 x double], [1056 x double]* %tmp17, i64 0, i64 %tmp15
> +  %tmp19 = load double, double* %tmp18, align 8
> +  %tmp20 = fmul double %tmp19, %arg4
> +  store double %tmp20, double* %tmp18, align 8
> +  br label %bb21
> +
> +bb21:                                             ; preds = %bb24, %bb14
> +  %tmp22 = phi i32 [ 0, %bb14 ], [ %tmp36, %bb24 ]
> +  %tmp23 = icmp slt i32 %tmp22, 1024
> +  br i1 %tmp23, label %bb24, label %bb37
> +
> +bb24:                                             ; preds = %bb21
> +  %tmp25 = sext i32 %tmp22 to i64
> +  %tmp26 = getelementptr inbounds [1024 x double], [1024 x double]* %arg6, i64 %tmp16
> +  %tmp27 = getelementptr inbounds [1024 x double], [1024 x double]* %tmp26, i64 0, i64 %tmp25
> +  %tmp28 = load double, double* %tmp27, align 8
> +  %tmp29 = fmul double %arg3, %tmp28
> +  %tmp30 = getelementptr inbounds [1056 x double], [1056 x double]* %arg7, i64 %tmp25
> +  %tmp31 = getelementptr inbounds [1056 x double], [1056 x double]* %tmp30, i64 0, i64 %tmp15
> +  %tmp32 = load double, double* %tmp31, align 8
> +  %tmp33 = fmul double %tmp29, %tmp32
> +  %tmp34 = load double, double* %tmp18, align 8
> +  %tmp35 = fadd double %tmp34, %tmp33
> +  store double %tmp35, double* %tmp18, align 8
> +  %tmp36 = add nsw i32 %tmp22, 1
> +  br label %bb21
> +
> +bb37:                                             ; preds = %bb21
> +  %tmp38 = add nsw i32 %tmp12, 2
> +  br label %bb11
> +
> +bb39:                                             ; preds = %bb11
> +  %tmp40 = add nsw i32 %tmp, 1
> +  br label %bb8
> +
> +bb41:                                             ; preds = %bb8
> +  ret void
> +}
> 
> 

-- 

Johannes Doerfert
Researcher / PhD Student

Compiler Design Lab (Prof. Hack)
Saarland University, Computer Science
Building E1.3, Room 4.31

Tel. +49 (0)681 302-57521 : doerfert at cs.uni-saarland.de
Fax. +49 (0)681 302-3065  : http://www.cdl.uni-saarland.de/people/doerfert