[polly] r271128 - Determination of statements that contain matrix multiplication

Mehdi Amini via llvm-commits llvm-commits at lists.llvm.org
Tue May 31 14:40:42 PDT 2016


> On May 28, 2016, at 9:17 AM, Roman Gareev via llvm-commits <llvm-commits at lists.llvm.org> wrote:
> 
> Author: romangareev
> Date: Sat May 28 11:17:58 2016
> New Revision: 271128
> 
> URL: http://llvm.org/viewvc/llvm-project?rev=271128&view=rev
> Log:
> Determination of statements that contain matrix multiplication
> 
> Add determination of statements that contain, in particular,
> matrix multiplications and can be optimized with [1] to try to
> get close-to-peak performance. It can be enabled
> via -polly-pattern-matching-based-opts, which is false by default.
> 
> Refs:
> [1] - http://www.cs.utexas.edu/users/flame/pubs/TOMS-BLIS-Analytical.pdf
> 
> Contributed-by: Roman Gareev <gareevroman at gmail.com>
> Reviewed-by: Tobias Grosser <tobias at grosser.es>
> 
> Differential Revision: http://reviews.llvm.org/D20575
> 
> Added:
>    polly/trunk/test/ScheduleOptimizer/pattern-matching-based-opts.ll
>    polly/trunk/test/ScheduleOptimizer/pattern-matching-based-opts_2.ll
> Modified:
>    polly/trunk/include/polly/ScheduleOptimizer.h
>    polly/trunk/lib/Transform/ScheduleOptimizer.cpp
> 
> Modified: polly/trunk/include/polly/ScheduleOptimizer.h
> URL: http://llvm.org/viewvc/llvm-project/polly/trunk/include/polly/ScheduleOptimizer.h?rev=271128&r1=271127&r2=271128&view=diff
> ==============================================================================
> --- polly/trunk/include/polly/ScheduleOptimizer.h (original)
> +++ polly/trunk/include/polly/ScheduleOptimizer.h Sat May 28 11:17:58 2016
> @@ -147,8 +147,45 @@ private:
>   ///      - if vectorization is enabled
>   ///
>   /// @param Node The schedule node to (possibly) optimize.
> -  /// @param User A pointer to forward some use information (currently unused).
> +  /// @param User A pointer to forward some use information
> +  ///        (currently unused).
>   static isl_schedule_node *optimizeBand(isl_schedule_node *Node, void *User);
> +
> +  /// @brief Apply the standard optimizations to a band in the schedule tree.
> +  ///
> +  /// This is the default path taken by optimizeBand for every tileable
> +  /// band. We apply the following transformations:
> +  ///
> +  ///  - Tile the band
> +  ///  - Prevectorize the schedule of the band (or the point loop in case of
> +  ///    tiling).
> +  ///      - if vectorization is enabled
> +  ///
> +  /// @param Node The schedule node to (possibly) optimize.
> +  /// @param User A pointer to forward some use information
> +  ///        (currently unused).
> +  static isl_schedule_node *standardBandOpts(__isl_take isl_schedule_node *Node,
> +                                             void *User);
> +
> +  /// @brief Check if this node contains a partial schedule that could
> +  ///        probably be optimized with analytical modeling.
> +  ///
> +  /// isMatrMultPattern tries to determine whether the following conditions
> +  /// are true:
> +  /// 1. the partial schedule contains only one statement.
> +  /// 2. there are exactly three input dimensions.
> +  /// 3. all array memory accesses of the statement will have stride 0 or 1,
> +  ///    if we interchange loops (switch the variable used in the inner loop
> +  ///    to the outer loop).
> +  /// 4. all memory accesses of the statement except the last one are read
> +  ///    accesses, and the last one is a write access.
> +  /// 5. no subscript of the last memory access of the statement contains
> +  ///    the variable used in the inner loop.
> +  /// If this is the case, we could try to use an approach that is similar to
> +  /// the one used to get close-to-peak performance of matrix multiplications.
> +  ///
> +  /// @param Node The node to check; it is not consumed.
> +  /// @return true if all of the above conditions hold.
> +  static bool isMatrMultPattern(__isl_keep isl_schedule_node *Node);
> };
> 
> #endif
> 
> Modified: polly/trunk/lib/Transform/ScheduleOptimizer.cpp
> URL: http://llvm.org/viewvc/llvm-project/polly/trunk/lib/Transform/ScheduleOptimizer.cpp?rev=271128&r1=271127&r2=271128&view=diff
> ==============================================================================
> --- polly/trunk/lib/Transform/ScheduleOptimizer.cpp (original)
> +++ polly/trunk/lib/Transform/ScheduleOptimizer.cpp Sat May 28 11:17:58 2016
> @@ -166,6 +166,11 @@ static cl::list<int>
>                       cl::Hidden, cl::ZeroOrMore, cl::CommaSeparated,
>                       cl::cat(PollyCategory));
> 
> +/// Opt-in switch (off by default) for pattern-matching based optimizations.
> +/// For now, optimizeBand only uses it to detect the matrix multiplication
> +/// pattern and report the match in the debug output.
> +static cl::opt<bool>
> +    PMBasedOpts("polly-pattern-matching-based-opts",
> +                cl::desc("Perform optimizations based on pattern matching"),
> +                cl::init(false), cl::ZeroOrMore, cl::cat(PollyCategory));
> +
> /// @brief Create an isl_union_set, which describes the isolate option based
> ///        on IsoalteDomain.
> ///
> @@ -359,11 +364,8 @@ bool ScheduleTreeOptimizer::isTileableBa
> }
> 
> __isl_give isl_schedule_node *
> -ScheduleTreeOptimizer::optimizeBand(__isl_take isl_schedule_node *Node,
> -                                    void *User) {
> -  if (!isTileableBandNode(Node))
> -    return Node;
> -
> +ScheduleTreeOptimizer::standardBandOpts(__isl_take isl_schedule_node *Node,
> +                                        void *User) {
>   if (FirstLevelTiling)
>     Node = tileNode(Node, "1st level tiling", FirstLevelTileSizes,
>                     FirstLevelDefaultTileSize);
> @@ -396,6 +398,110 @@ ScheduleTreeOptimizer::optimizeBand(__is
>   return Node;
> }
> 
> +/// @brief Check whether output dimensions of the map rely on the specified
> +///        input dimension.
> +///
> +/// Input dimension @p DimNum is projected out and re-inserted as an
> +/// unconstrained dimension. If the resulting map equals the original one,
> +/// no output dimension depends on it.
> +///
> +/// @param IslMap The isl map to be considered (consumed by this function).
> +/// @param DimNum The number of an input dimension to be checked.
> +/// @return true if the output dimensions depend on input dimension
> +///         @p DimNum, or if isl could not decide (conservative answer).
> +static bool isInputDimUsed(__isl_take isl_map *IslMap, unsigned DimNum) {
> +  auto *CheckedAccessRelation =
> +      isl_map_project_out(isl_map_copy(IslMap), isl_dim_in, DimNum, 1);
> +  CheckedAccessRelation =
> +      isl_map_insert_dims(CheckedAccessRelation, isl_dim_in, DimNum, 1);
> +  // Re-attach the tuple ids of IslMap (the dimension manipulation above may
> +  // drop them) so that both maps live in the same space when compared.
> +  auto *InputTupleId = isl_map_get_tuple_id(IslMap, isl_dim_in);
> +  CheckedAccessRelation =
> +      isl_map_set_tuple_id(CheckedAccessRelation, isl_dim_in, InputTupleId);
> +  auto *OutputTupleId = isl_map_get_tuple_id(IslMap, isl_dim_out);
> +  CheckedAccessRelation =
> +      isl_map_set_tuple_id(CheckedAccessRelation, isl_dim_out, OutputTupleId);
> +  // isl_map_is_equal may return isl_bool_error; only a definite "equal"
> +  // proves that the dimension is unused.
> +  auto IsEqual = isl_map_is_equal(CheckedAccessRelation, IslMap);
> +  isl_map_free(CheckedAccessRelation);
> +  isl_map_free(IslMap);
> +  return IsEqual != isl_bool_true;
> +}
> +
> +/// @brief Check if the SCoP statement could probably be optimized with
> +///        analytical modeling.
> +///
> +/// containsMatrMult tries to determine whether the following conditions
> +/// are true:
> +/// 1. all array memory accesses of the statement will have stride 0 or 1,
> +///    if we interchange loops (switch the variable used in the inner
> +///    loop to the outer loop).
> +/// 2. all memory accesses of the statement except the last one are read
> +///    accesses, and the last one is a write access.
> +/// 3. no subscript of the last memory access of the statement contains
> +///    the variable used in the inner loop.
> +///
> +/// Each stride query receives its own copy of the schedule, because the
> +/// queries presumably take ownership of their argument.
> +///
> +/// @param PartialSchedule The PartialSchedule that contains a SCoP statement
> +///        to check.
> +static bool containsMatrMult(__isl_keep isl_map *PartialSchedule) {
> +  auto InputDimsId = isl_map_get_tuple_id(PartialSchedule, isl_dim_in);
> +  auto *ScpStmt = static_cast<ScopStmt *>(isl_id_get_user(InputDimsId));
> +  isl_id_free(InputDimsId);
> +  // At least one read followed by the final write is required.
> +  if (ScpStmt->size() <= 1)
> +    return false;
> +  // The innermost input dimension is DimNum - 1; avoid wrap-around.
> +  auto DimNum = isl_map_dim(PartialSchedule, isl_dim_in);
> +  if (DimNum == 0)
> +    return false;
> +
> +  // Visit every access except the last one: each must be a read, and array
> +  // reads must have stride 0 or 1 in the interchanged schedule.
> +  auto MemA = ScpStmt->begin();
> +  for (unsigned i = 0; i + 1 < ScpStmt->size(); i++, MemA++)
> +    if (!(*MemA)->isRead() ||
> +        ((*MemA)->isArrayKind() &&
> +         !((*MemA)->isStrideOne(isl_map_copy(PartialSchedule)) ||
> +           (*MemA)->isStrideZero(isl_map_copy(PartialSchedule)))))
> +      return false;
> +
> +  // MemA now refers to the last access, which must be an array write with
> +  // stride 0 or 1 ...
> +  if (!(*MemA)->isWrite() || !(*MemA)->isArrayKind() ||
> +      !((*MemA)->isStrideOne(isl_map_copy(PartialSchedule)) ||
> +        (*MemA)->isStrideZero(isl_map_copy(PartialSchedule))))
> +    return false;
> +  // ... whose subscripts do not use the innermost input dimension.
> +  return !isInputDimUsed((*MemA)->getAccessRelation(), DimNum - 1);
> +}
> +
> +/// @brief Circular shift of output dimensions of the integer map.
> +///
> +/// Moves the last output dimension to the front, e.g. it turns
> +/// [i, j, k] -> [a, b, c] into [i, j, k] -> [c, a, b]. A map without output
> +/// dimensions is returned unchanged.
> +///
> +/// @param IslMap The isl map to be modified.
> +static __isl_give isl_map *circularShiftOutputDims(__isl_take isl_map *IslMap) {
> +  auto DimNum = isl_map_dim(IslMap, isl_dim_out);
> +  // Nothing to shift; also, DimNum - 1 would wrap around here
> +  // (Coverity CID 1356130).
> +  if (DimNum == 0)
> +    return IslMap;
> +  auto InputDimsId = isl_map_get_tuple_id(IslMap, isl_dim_in);
> +  IslMap = isl_map_move_dims(IslMap, isl_dim_in, 0, isl_dim_out, DimNum - 1, 1);
> +  IslMap = isl_map_move_dims(IslMap, isl_dim_out, 0, isl_dim_in, 0, 1);
> +  // Restore the input tuple id, which the moves above may have reset.
> +  return isl_map_set_tuple_id(IslMap, isl_dim_in, InputDimsId);
> +}

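Coverity is reporting a potential issue here: when the map has zero output dimensions, `DimNum - 1` underflows before it is passed to isl_map_move_dims: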

*** CID 1356130:  Insecure data handling  (INTEGER_OVERFLOW)
/tools/polly/lib/Transform/ScheduleOptimizer.cpp: 469 in circularShiftOutputDims(isl_map *)()
463     /// @param IslMap The isl map to be modified.
464     static __isl_give isl_map *circularShiftOutputDims(__isl_take isl_map *IslMap) {
465       auto InputDimsId = isl_map_get_tuple_id(IslMap, isl_dim_in);
466       auto DimNum = isl_map_dim(IslMap, isl_dim_out);
467       IslMap = isl_map_move_dims(IslMap, isl_dim_in, 0, isl_dim_out, DimNum - 1, 1);
468       IslMap = isl_map_move_dims(IslMap, isl_dim_out, 0, isl_dim_in, 0, 1);
>>>     CID 1356130:  Insecure data handling  (INTEGER_OVERFLOW)
>>>     Overflowed or truncated value (or a value computed from an overflowed or truncated value) "isl_map_set_tuple_id(IslMap, isl_dim_in, InputDimsId)" used as return value.
469       return isl_map_set_tuple_id(IslMap, isl_dim_in, InputDimsId);
470     }


-- 
Mehdi




> +
> +/// Check whether the band's partial schedule looks like a matrix
> +/// multiplication that could be optimized with analytical modeling.
> +bool ScheduleTreeOptimizer::isMatrMultPattern(
> +    __isl_keep isl_schedule_node *Node) {
> +  auto *PartialSchedule =
> +      isl_schedule_node_band_get_partial_schedule_union_map(Node);
> +  // Only a partial schedule covering a single statement qualifies. The union
> +  // map is owned here, so it must be freed on this early exit too.
> +  if (isl_union_map_n_map(PartialSchedule) != 1) {
> +    isl_union_map_free(PartialSchedule);
> +    return false;
> +  }
> +  auto *NewPartialSchedule = isl_map_from_union_map(PartialSchedule);
> +  // Require exactly three input dimensions, and at least one output dimension
> +  // so that the circular shift below is well defined.
> +  if (isl_map_dim(NewPartialSchedule, isl_dim_in) != 3 ||
> +      isl_map_dim(NewPartialSchedule, isl_dim_out) == 0) {
> +    isl_map_free(NewPartialSchedule);
> +    return false;
> +  }
> +  NewPartialSchedule = circularShiftOutputDims(NewPartialSchedule);
> +  auto IsMatrMult = containsMatrMult(NewPartialSchedule);
> +  isl_map_free(NewPartialSchedule);
> +  return IsMatrMult;
> +}
> +
> +/// Optimize one band of the schedule tree: non-tileable bands are passed
> +/// through untouched, all others get the standard band optimizations.
> +__isl_give isl_schedule_node *
> +ScheduleTreeOptimizer::optimizeBand(__isl_take isl_schedule_node *Node,
> +                                    void *User) {
> +  if (isTileableBandNode(Node)) {
> +    // Pattern matching is opt-in and, for now, only reports what it finds.
> +    if (PMBasedOpts) {
> +      if (isMatrMultPattern(Node))
> +        DEBUG(dbgs() << "The matrix multiplication pattern was detected\n");
> +    }
> +    Node = standardBandOpts(Node, User);
> +  }
> +  return Node;
> +}
> +
> __isl_give isl_schedule *
> ScheduleTreeOptimizer::optimizeSchedule(__isl_take isl_schedule *Schedule) {
>   isl_schedule_node *Root = isl_schedule_get_root(Schedule);
> 
> Added: polly/trunk/test/ScheduleOptimizer/pattern-matching-based-opts.ll
> URL: http://llvm.org/viewvc/llvm-project/polly/trunk/test/ScheduleOptimizer/pattern-matching-based-opts.ll?rev=271128&view=auto
> ==============================================================================
> --- polly/trunk/test/ScheduleOptimizer/pattern-matching-based-opts.ll (added)
> +++ polly/trunk/test/ScheduleOptimizer/pattern-matching-based-opts.ll Sat May 28 11:17:58 2016
> @@ -0,0 +1,65 @@
> +; RUN: opt %loadPolly -polly-opt-isl -debug < %s 2>&1| FileCheck %s
> +; RUN: opt %loadPolly -polly-opt-isl -polly-pattern-matching-based-opts=true -debug < %s 2>&1| FileCheck %s --check-prefix=PATTERN-MATCHING-OPTS
> +; REQUIRES: asserts
> +; CHECK-NOT: The matrix multiplication pattern was detected
> +; PATTERN-MATCHING-OPTS: The matrix multiplication pattern was detected
> +
> +; With i = %tmp, j = %tmp12 and k = %tmp22 this kernel computes
> +;
> +;   for (i = 0; i < 1056; i++)
> +;     for (j = 0; j < 1056; j++) {
> +;       C[i][j] *= beta;                          // bb14
> +;       for (k = 0; k < 1024; k++)
> +;         C[i][j] += alpha * A[i][k] * B[k][j];   // bb24
> +;     }
> +;
> +; where C = %arg5, A = %arg6, B = %arg7, alpha = %arg3 and beta = %arg4.
> +; The statement in bb24 is expected to be reported as a matrix
> +; multiplication only when -polly-pattern-matching-based-opts=true is given.
> +define internal void @kernel_gemm(i32 %arg, i32 %arg1, i32 %arg2, double %arg3, double %arg4, [1056 x double]* %arg5, [1024 x double]* %arg6, [1056 x double]* %arg7) {
> +bb:
> +  br label %bb8
> +
> +bb8:                                              ; preds = %bb39, %bb
> +  %tmp = phi i32 [ 0, %bb ], [ %tmp40, %bb39 ]
> +  %tmp9 = icmp slt i32 %tmp, 1056
> +  br i1 %tmp9, label %bb10, label %bb41
> +
> +bb10:                                             ; preds = %bb8
> +  br label %bb11
> +
> +bb11:                                             ; preds = %bb37, %bb10
> +  %tmp12 = phi i32 [ 0, %bb10 ], [ %tmp38, %bb37 ]
> +  %tmp13 = icmp slt i32 %tmp12, 1056
> +  br i1 %tmp13, label %bb14, label %bb39
> +
> +bb14:                                             ; preds = %bb11
> +  ; C[i][j] *= beta
> +  %tmp15 = sext i32 %tmp12 to i64
> +  %tmp16 = sext i32 %tmp to i64
> +  %tmp17 = getelementptr inbounds [1056 x double], [1056 x double]* %arg5, i64 %tmp16
> +  %tmp18 = getelementptr inbounds [1056 x double], [1056 x double]* %tmp17, i64 0, i64 %tmp15
> +  %tmp19 = load double, double* %tmp18, align 8
> +  %tmp20 = fmul double %tmp19, %arg4
> +  store double %tmp20, double* %tmp18, align 8
> +  br label %bb21
> +
> +bb21:                                             ; preds = %bb24, %bb14
> +  %tmp22 = phi i32 [ 0, %bb14 ], [ %tmp36, %bb24 ]
> +  %tmp23 = icmp slt i32 %tmp22, 1024
> +  br i1 %tmp23, label %bb24, label %bb37
> +
> +bb24:                                             ; preds = %bb21
> +  ; C[i][j] += alpha * A[i][k] * B[k][j]. The reads of A, B and C come before
> +  ; the store to C; the pattern check requires the statement's last memory
> +  ; access to be the write (assuming accesses follow instruction order).
> +  %tmp25 = sext i32 %tmp22 to i64
> +  %tmp26 = getelementptr inbounds [1024 x double], [1024 x double]* %arg6, i64 %tmp16
> +  %tmp27 = getelementptr inbounds [1024 x double], [1024 x double]* %tmp26, i64 0, i64 %tmp25
> +  %tmp28 = load double, double* %tmp27, align 8
> +  %tmp29 = fmul double %arg3, %tmp28
> +  %tmp30 = getelementptr inbounds [1056 x double], [1056 x double]* %arg7, i64 %tmp25
> +  %tmp31 = getelementptr inbounds [1056 x double], [1056 x double]* %tmp30, i64 0, i64 %tmp15
> +  %tmp32 = load double, double* %tmp31, align 8
> +  %tmp33 = fmul double %tmp29, %tmp32
> +  %tmp34 = load double, double* %tmp18, align 8
> +  %tmp35 = fadd double %tmp34, %tmp33
> +  store double %tmp35, double* %tmp18, align 8
> +  %tmp36 = add nsw i32 %tmp22, 1
> +  br label %bb21
> +
> +bb37:                                             ; preds = %bb21
> +  %tmp38 = add nsw i32 %tmp12, 1
> +  br label %bb11
> +
> +bb39:                                             ; preds = %bb11
> +  %tmp40 = add nsw i32 %tmp, 1
> +  br label %bb8
> +
> +bb41:                                             ; preds = %bb8
> +  ret void
> +}
> 
> Added: polly/trunk/test/ScheduleOptimizer/pattern-matching-based-opts_2.ll
> URL: http://llvm.org/viewvc/llvm-project/polly/trunk/test/ScheduleOptimizer/pattern-matching-based-opts_2.ll?rev=271128&view=auto
> ==============================================================================
> --- polly/trunk/test/ScheduleOptimizer/pattern-matching-based-opts_2.ll (added)
> +++ polly/trunk/test/ScheduleOptimizer/pattern-matching-based-opts_2.ll Sat May 28 11:17:58 2016
> @@ -0,0 +1,63 @@
> +; RUN: opt %loadPolly -polly-opt-isl -polly-pattern-matching-based-opts=true -debug < %s 2>&1 | FileCheck %s
> +; REQUIRES: asserts
> +; CHECK-NOT: The matrix multiplication pattern was detected
> +
> +; Same gemm-like kernel as in pattern-matching-based-opts.ll, except that the
> +; j loop (%tmp12) advances by 2 (see bb37). With i = %tmp, j = %tmp12,
> +; k = %tmp22, C = %arg5, A = %arg6 and B = %arg7, bb24 computes
> +; C[i][j] += alpha * A[i][k] * B[k][j]. The statement is expected NOT to be
> +; reported as a matrix multiplication, even with
> +; -polly-pattern-matching-based-opts=true.
> +define internal void @kernel_gemm(i32 %arg, i32 %arg1, i32 %arg2, double %arg3, double %arg4, [1056 x double]* %arg5, [1024 x double]* %arg6, [1056 x double]* %arg7) {
> +bb:
> +  br label %bb8
> +
> +bb8:                                              ; preds = %bb39, %bb
> +  %tmp = phi i32 [ 0, %bb ], [ %tmp40, %bb39 ]
> +  %tmp9 = icmp slt i32 %tmp, 1056
> +  br i1 %tmp9, label %bb10, label %bb41
> +
> +bb10:                                             ; preds = %bb8
> +  br label %bb11
> +
> +bb11:                                             ; preds = %bb37, %bb10
> +  %tmp12 = phi i32 [ 0, %bb10 ], [ %tmp38, %bb37 ]
> +  %tmp13 = icmp slt i32 %tmp12, 1056
> +  br i1 %tmp13, label %bb14, label %bb39
> +
> +bb14:                                             ; preds = %bb11
> +  %tmp15 = sext i32 %tmp12 to i64
> +  %tmp16 = sext i32 %tmp to i64
> +  %tmp17 = getelementptr inbounds [1056 x double], [1056 x double]* %arg5, i64 %tmp16
> +  %tmp18 = getelementptr inbounds [1056 x double], [1056 x double]* %tmp17, i64 0, i64 %tmp15
> +  %tmp19 = load double, double* %tmp18, align 8
> +  %tmp20 = fmul double %tmp19, %arg4
> +  store double %tmp20, double* %tmp18, align 8
> +  br label %bb21
> +
> +bb21:                                             ; preds = %bb24, %bb14
> +  %tmp22 = phi i32 [ 0, %bb14 ], [ %tmp36, %bb24 ]
> +  %tmp23 = icmp slt i32 %tmp22, 1024
> +  br i1 %tmp23, label %bb24, label %bb37
> +
> +bb24:                                             ; preds = %bb21
> +  %tmp25 = sext i32 %tmp22 to i64
> +  %tmp26 = getelementptr inbounds [1024 x double], [1024 x double]* %arg6, i64 %tmp16
> +  %tmp27 = getelementptr inbounds [1024 x double], [1024 x double]* %tmp26, i64 0, i64 %tmp25
> +  %tmp28 = load double, double* %tmp27, align 8
> +  %tmp29 = fmul double %arg3, %tmp28
> +  %tmp30 = getelementptr inbounds [1056 x double], [1056 x double]* %arg7, i64 %tmp25
> +  %tmp31 = getelementptr inbounds [1056 x double], [1056 x double]* %tmp30, i64 0, i64 %tmp15
> +  %tmp32 = load double, double* %tmp31, align 8
> +  %tmp33 = fmul double %tmp29, %tmp32
> +  %tmp34 = load double, double* %tmp18, align 8
> +  %tmp35 = fadd double %tmp34, %tmp33
> +  store double %tmp35, double* %tmp18, align 8
> +  %tmp36 = add nsw i32 %tmp22, 1
> +  br label %bb21
> +
> +bb37:                                             ; preds = %bb21
> +  ; j += 2: B[k][j] and C[i][j] get stride 2 along j, which violates the
> +  ; "stride 0 or 1" requirement of the pattern check.
> +  %tmp38 = add nsw i32 %tmp12, 2
> +  br label %bb11
> +
> +bb39:                                             ; preds = %bb11
> +  %tmp40 = add nsw i32 %tmp, 1
> +  br label %bb8
> +
> +bb41:                                             ; preds = %bb8
> +  ret void
> +}
> 
> 
> _______________________________________________
> llvm-commits mailing list
> llvm-commits at lists.llvm.org
> http://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-commits



More information about the llvm-commits mailing list