[polly] r271128 - Determination of statements that contain matrix multiplication
Roman Gareev via llvm-commits
llvm-commits at lists.llvm.org
Tue May 31 03:10:26 PDT 2016
Hi Johannes,
Thank you for the comments and the advice! I’ve tried to address them
in http://reviews.llvm.org/D20806
2016-05-30 14:04 GMT+05:00 Johannes Doerfert <doerfert at cs.uni-saarland.de>:
> Hey Roman,
>
> First, thanks for the patch!
>
> I inlined three small change requests you can think about.
>
> Cheers,
> Johannes
>
> On 05/28, Roman Gareev via llvm-commits wrote:
>> Author: romangareev
>> Date: Sat May 28 11:17:58 2016
>> New Revision: 271128
>>
>> URL: http://llvm.org/viewvc/llvm-project?rev=271128&view=rev
>> Log:
>> Determination of statements that contain matrix multiplication
>>
>> Add determination of statements that contain, in particular,
>> matrix multiplications and can be optimized with [1] to try to
>> get close-to-peak performance. It can be enabled
>> via polly-pattern-matching-based-opts, which is false by default.
>>
>> Refs:
>> [1] - http://www.cs.utexas.edu/users/flame/pubs/TOMS-BLIS-Analytical.pdf
>>
>> Contributed-by: Roman Gareev <gareevroman at gmail.com>
>> Reviewed-by: Tobias Grosser <tobias at grosser.es>
>>
>> Differential Revision: http://reviews.llvm.org/D20575
>>
>> Added:
>> polly/trunk/test/ScheduleOptimizer/pattern-matching-based-opts.ll
>> polly/trunk/test/ScheduleOptimizer/pattern-matching-based-opts_2.ll
>> Modified:
>> polly/trunk/include/polly/ScheduleOptimizer.h
>> polly/trunk/lib/Transform/ScheduleOptimizer.cpp
>>
>> Modified: polly/trunk/include/polly/ScheduleOptimizer.h
>> URL: http://llvm.org/viewvc/llvm-project/polly/trunk/include/polly/ScheduleOptimizer.h?rev=271128&r1=271127&r2=271128&view=diff
>> ==============================================================================
>> --- polly/trunk/include/polly/ScheduleOptimizer.h (original)
>> +++ polly/trunk/include/polly/ScheduleOptimizer.h Sat May 28 11:17:58 2016
>> @@ -147,8 +147,45 @@ private:
>> /// - if vectorization is enabled
>> ///
>> /// @param Node The schedule node to (possibly) optimize.
>> - /// @param User A pointer to forward some use information (currently unused).
>> + /// @param User A pointer to forward some use information
>> + /// (currently unused).
>> static isl_schedule_node *optimizeBand(isl_schedule_node *Node, void *User);
>> +
>> + /// @brief Apply additional optimizations on the bands in the schedule tree.
>> + ///
>> + /// We apply the following
>> + /// transformations:
>> + ///
>> + /// - Tile the band
>> + /// - Prevectorize the schedule of the band (or the point loop in case of
>> + /// tiling).
>> + /// - if vectorization is enabled
>> + ///
>> + /// @param Node The schedule node to (possibly) optimize.
>> + /// @param User A pointer to forward some use information
>> + /// (currently unused).
>> + static isl_schedule_node *standardBandOpts(__isl_take isl_schedule_node *Node,
>> + void *User);
>> +
>> + /// @brief Check if this node contains a partial schedule that could
>> + /// probably be optimized with analytical modeling.
>> + ///
>> + /// isMatrMultPattern tries to determine whether the following conditions
>> + /// are true:
>> + /// 1. the partial schedule contains only one statement.
>> + /// 2. there are exactly three input dimensions.
>> + /// 3. all memory accesses of the statement will have stride 0 or 1, if we
>> + /// interchange loops (switch the variable used in the inner loop to
>> + /// the outer loop).
>> + /// 4. all memory accesses of the statement except from the last one, are
>> + /// read memory access and the last one is write memory access.
>> + /// 5. all subscripts of the last memory access of the statement don’t
> Can you use the plain version of "'" in the word "don't"? While my editor
> is fine with this fancy (maybe utf8) character, I do not want to find out
> which one is not.
>
>> + /// contain the variable used in the inner loop.
>> + /// If this is the case, we could try to use an approach that is similar to
>> + /// the one used to get close-to-peak performance of matrix multiplications.
>> + ///
>> + /// @param Node The node to check.
>> + static bool isMatrMultPattern(__isl_keep isl_schedule_node *Node);
>> };
>>
>> #endif
>>
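For readers following the conditions listed in the isMatrMultPattern comment above, conditions 3 to 5 can be illustrated with a small stand-alone C++ model. This is only a sketch under stated assumptions: Access, Subscript, usesVar, hasStrideZeroOrOne and looksLikeMatMul are hypothetical names and not Polly APIs, subscripts are reduced to integer coefficients of the loop variables (i, j, k), arrays are assumed to be row-major, and the patch itself evaluates these properties on isl maps instead. The model follows the wording of the comment, not the exact iteration order of the committed code.

```cpp
#include <array>
#include <cstddef>
#include <vector>

// Coefficients of the loop variables (i, j, k) in one array subscript.
using Subscript = std::array<int, 3>;

// Hypothetical stand-in for a memory access of the statement.
struct Access {
  bool IsWrite;
  std::vector<Subscript> Subscripts; // outermost array dimension first
};

// True if loop variable Var appears in any subscript of the access.
bool usesVar(const Access &A, unsigned Var) {
  for (const Subscript &S : A.Subscripts)
    if (S[Var] != 0)
      return true;
  return false;
}

// Stride 0 or 1 along Var for a row-major array: Var may appear only in the
// last subscript, and only with coefficient 0 or 1.
bool hasStrideZeroOrOne(const Access &A, unsigned Var) {
  if (A.Subscripts.empty())
    return true; // scalar access, nothing to stride over
  for (std::size_t D = 0; D + 1 < A.Subscripts.size(); ++D)
    if (A.Subscripts[D][Var] != 0)
      return false;
  int Coeff = A.Subscripts.back()[Var];
  return Coeff == 0 || Coeff == 1;
}

// Conditions 3-5 for a three-deep nest (i, j, k) whose innermost loop k is
// moved to the outermost position, so that j becomes the innermost loop.
bool looksLikeMatMul(const std::vector<Access> &Accesses) {
  constexpr unsigned J = 1, K = 2;
  if (Accesses.size() <= 1)
    return false;
  for (std::size_t I = 0; I < Accesses.size(); ++I) {
    bool IsLast = I + 1 == Accesses.size();
    if (Accesses[I].IsWrite != IsLast) // reads first, one write at the end
      return false;
    if (!hasStrideZeroOrOne(Accesses[I], J))
      return false;
  }
  // The written element must not depend on the original inner loop.
  return !usesVar(Accesses.back(), K);
}
```

In this model the accesses A[i][k], B[k][j], C[i][j] (read) and C[i][j] (write) are accepted, while an access such as B[k][2*j] is rejected because its stride along j is 2.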
>> Modified: polly/trunk/lib/Transform/ScheduleOptimizer.cpp
>> URL: http://llvm.org/viewvc/llvm-project/polly/trunk/lib/Transform/ScheduleOptimizer.cpp?rev=271128&r1=271127&r2=271128&view=diff
>> ==============================================================================
>> --- polly/trunk/lib/Transform/ScheduleOptimizer.cpp (original)
>> +++ polly/trunk/lib/Transform/ScheduleOptimizer.cpp Sat May 28 11:17:58 2016
>> @@ -166,6 +166,11 @@ static cl::list<int>
>> cl::Hidden, cl::ZeroOrMore, cl::CommaSeparated,
>> cl::cat(PollyCategory));
>>
>> +static cl::opt<bool>
>> + PMBasedOpts("polly-pattern-matching-based-opts",
>> + cl::desc("Perform optimizations based on pattern matching"),
>> + cl::init(false), cl::ZeroOrMore, cl::cat(PollyCategory));
>> +
>> /// @brief Create an isl_union_set, which describes the isolate option based
>> /// on IsoalteDomain.
>> ///
>> @@ -359,11 +364,8 @@ bool ScheduleTreeOptimizer::isTileableBa
>> }
>>
>> __isl_give isl_schedule_node *
>> -ScheduleTreeOptimizer::optimizeBand(__isl_take isl_schedule_node *Node,
>> - void *User) {
>> - if (!isTileableBandNode(Node))
>> - return Node;
>> -
>> +ScheduleTreeOptimizer::standardBandOpts(__isl_take isl_schedule_node *Node,
>> + void *User) {
>> if (FirstLevelTiling)
>> Node = tileNode(Node, "1st level tiling", FirstLevelTileSizes,
>> FirstLevelDefaultTileSize);
>> @@ -396,6 +398,110 @@ ScheduleTreeOptimizer::optimizeBand(__is
>> return Node;
>> }
>>
>> +/// @brief Check whether output dimensions of the map rely on the specified
>> +/// input dimension.
>> +///
>> +/// @param IslMap The isl map to be considered.
>> +/// @param DimNum The number of an input dimension to be checked.
>> +static bool isInputDimUsed(__isl_take isl_map *IslMap, unsigned DimNum) {
>> + auto *CheckedAccessRelation =
>> + isl_map_project_out(isl_map_copy(IslMap), isl_dim_in, DimNum, 1);
>> + CheckedAccessRelation =
>> + isl_map_insert_dims(CheckedAccessRelation, isl_dim_in, DimNum, 1);
>> + auto *InputDimsId = isl_map_get_tuple_id(IslMap, isl_dim_in);
>> + CheckedAccessRelation =
>> + isl_map_set_tuple_id(CheckedAccessRelation, isl_dim_in, InputDimsId);
>> + InputDimsId = isl_map_get_tuple_id(IslMap, isl_dim_out);
>> + CheckedAccessRelation =
>> + isl_map_set_tuple_id(CheckedAccessRelation, isl_dim_out, InputDimsId);
>> + auto res = !isl_map_is_equal(CheckedAccessRelation, IslMap);
>> + isl_map_free(CheckedAccessRelation);
>> + isl_map_free(IslMap);
>> + return res;
>> +}
>> +
>> +/// @brief Check if the SCoP statement could probably be optimized with
>> +/// analytical modeling.
>> +///
>> +/// containsMatrMult tries to determine whether the following conditions
>> +/// are true:
>> +/// 1. all memory accesses of the statement will have stride 0 or 1,
>> +/// if we interchange loops (switch the variable used in the inner
>> +/// loop to the outer loop).
>> +/// 2. all memory accesses of the statement except from the last one, are
>> +/// read memory access and the last one is write memory access.
>> +/// 3. all subscripts of the last memory access of the statement don’t contain
>> +/// the variable used in the inner loop.
>> +///
>> +/// @param PartialSchedule The PartialSchedule that contains a SCoP statement
>> +/// to check.
>> +static bool containsMatrMult(__isl_keep isl_map *PartialSchedule) {
>> + auto InputDimsId = isl_map_get_tuple_id(PartialSchedule, isl_dim_in);
>> + auto *ScpStmt = static_cast<ScopStmt *>(isl_id_get_user(InputDimsId));
>> + isl_id_free(InputDimsId);
>> + if (ScpStmt->size() <= 1)
>> + return false;
>> + auto MemA = ScpStmt->begin();
>> + for (unsigned i = 0; i < ScpStmt->size() - 2 && MemA != ScpStmt->end();
>> + i++, MemA++)
>> + if (!(*MemA)->isRead() or
>> + ((*MemA)->isArrayKind() and
>> + !((*MemA)->isStrideOne(isl_map_copy(PartialSchedule)) or
>> + (*MemA)->isStrideZero(isl_map_copy(PartialSchedule)))))
>> + return false;
>> + MemA++;
>> + if (!(*MemA)->isWrite() or !(*MemA)->isArrayKind() or
>> + !((*MemA)->isStrideOne(isl_map_copy(PartialSchedule)) or
>> + (*MemA)->isStrideZero(isl_map_copy(PartialSchedule))))
>> + return false;
> I just realized (again) that "and" and "or" are "word operators".
> Consequently, I would not use them as they are not used in the
> remaining code base.
>
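For reference, "and", "or" and "not" are standard C++ alternative tokens for "&&", "||" and "!", so switching to the symbolic spelling does not change behaviour. The sketch below shows the equivalence and the read-access rejection test from containsMatrMult written with symbolic operators. AccessFlags and rejectsRead are hypothetical names used only for illustration; the flags stand in for the isRead, isArrayKind, isStrideOne and isStrideZero queries in the patch.

```cpp
// The alternative tokens and the symbolic operators denote the same
// operations, so each pair of expressions below has the same value.
static_assert((true and false) == (true && false), "'and' is '&&'");
static_assert((true or false) == (true || false), "'or' is '||'");
static_assert((not true) == (!true), "'not' is '!'");

// Hypothetical flags standing in for the MemoryAccess queries in the patch.
struct AccessFlags {
  bool IsRead;
  bool IsArray;
  bool StrideOne;
  bool StrideZero;
};

// The condition under which a leading access is rejected, spelled with the
// symbolic operators used in the rest of the code base.
constexpr bool rejectsRead(AccessFlags F) {
  return !F.IsRead || (F.IsArray && !(F.StrideOne || F.StrideZero));
}
```

Some compilers accept the word forms only in a conforming mode, which is one more reason to prefer the symbolic spelling.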
>> + auto DimNum = isl_map_dim(PartialSchedule, isl_dim_in);
>> + return !isInputDimUsed((*MemA)->getAccessRelation(), DimNum - 1);
>> +}
>> +
>> +/// @brief Circular shift of output dimensions of the integer map.
>> +///
>> +/// @param IslMap The isl map to be modified.
>> +static __isl_give isl_map *circularShiftOutputDims(__isl_take isl_map *IslMap) {
>> + auto InputDimsId = isl_map_get_tuple_id(IslMap, isl_dim_in);
>> + auto DimNum = isl_map_dim(IslMap, isl_dim_out);
>> + IslMap = isl_map_move_dims(IslMap, isl_dim_in, 0, isl_dim_out, DimNum - 1, 1);
>> + IslMap = isl_map_move_dims(IslMap, isl_dim_out, 0, isl_dim_in, 0, 1);
>> + return isl_map_set_tuple_id(IslMap, isl_dim_in, InputDimsId);
>> +}
>> +
>> +bool ScheduleTreeOptimizer::isMatrMultPattern(
>> + __isl_keep isl_schedule_node *Node) {
>> + auto *PartialSchedule =
>> + isl_schedule_node_band_get_partial_schedule_union_map(Node);
>> + if (isl_union_map_n_map(PartialSchedule) != 1)
>> + return false;
>> + auto *NewPartialSchedule = isl_map_from_union_map(PartialSchedule);
>> + auto DimNum = isl_map_dim(NewPartialSchedule, isl_dim_in);
>> + if (DimNum != 3) {
>> + isl_map_free(NewPartialSchedule);
>> + return false;
>> + }
>> + NewPartialSchedule = circularShiftOutputDims(NewPartialSchedule);
>> + if (containsMatrMult(NewPartialSchedule)) {
>> + isl_map_free(NewPartialSchedule);
>> + return true;
>> + }
>> + isl_map_free(NewPartialSchedule);
>> + return false;
>> +}
>> +
>> +__isl_give isl_schedule_node *
>> +ScheduleTreeOptimizer::optimizeBand(__isl_take isl_schedule_node *Node,
>> + void *User) {
>> + if (!isTileableBandNode(Node))
>> + return Node;
>> +
>> + if (PMBasedOpts && isMatrMultPattern(Node))
>> + DEBUG(dbgs() << "The matrix multiplication pattern was detected\n");
>> +
>> + return standardBandOpts(Node, User);
>> +}
>> +
>> __isl_give isl_schedule *
>> ScheduleTreeOptimizer::optimizeSchedule(__isl_take isl_schedule *Schedule) {
>> isl_schedule_node *Root = isl_schedule_get_root(Schedule);
>>
>> Added: polly/trunk/test/ScheduleOptimizer/pattern-matching-based-opts.ll
>> URL: http://llvm.org/viewvc/llvm-project/polly/trunk/test/ScheduleOptimizer/pattern-matching-based-opts.ll?rev=271128&view=auto
>> ==============================================================================
>> --- polly/trunk/test/ScheduleOptimizer/pattern-matching-based-opts.ll (added)
>> +++ polly/trunk/test/ScheduleOptimizer/pattern-matching-based-opts.ll Sat May 28 11:17:58 2016
>> @@ -0,0 +1,65 @@
>> +; RUN: opt %loadPolly -polly-opt-isl -debug < %s 2>&1| FileCheck %s
>> +; RUN: opt %loadPolly -polly-opt-isl -polly-pattern-matching-based-opts=true -debug < %s 2>&1| FileCheck %s --check-prefix=PATTERN-MATCHING-OPTS
>> +; REQUIRES: asserts
>> +; CHECK-NOT: The matrix multiplication pattern was detected
>> +; PATTERN-MATCHING-OPTS: The matrix multiplication pattern was detected
>> +
> Without C code and/or description it is hard to understand the test
> cases. As an example, it took me a while to diff the two test cases here
> "by hand". I usually use the ./test/create_ll.sh script to get IR from
> C code and that already inlines the C code for you. Additionally, it
> might be good to write one sentence about the test case if its purpose
> is not obvious, e.g., the positive test is pretty straightforward but
> the negative test is not.
>
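As a reading aid, the loop nest behind both test cases can be reconstructed by hand from the IR roughly as follows. This is a sketch and not the original source: the names alpha, beta, A, B and C are assumptions based on the function name kernel_gemm, the template parameters NI, NJ, NK and JStep are introduced only for illustration, and std::size_t replaces the i32 induction variables of the IR. The tests correspond to NI = NJ = 1056 and NK = 1024. The only difference between the two files is the increment of the middle loop: 1 in the positive test and 2 in pattern-matching-based-opts_2.ll.

```cpp
#include <cstddef>

// Hand-written reconstruction of @kernel_gemm (assumed names, see above).
// JStep = 1 models the positive test; JStep = 2 models the negative test,
// where the middle loop advances by two.
template <std::size_t NI, std::size_t NJ, std::size_t NK,
          std::size_t JStep = 1>
void kernel_gemm(double alpha, double beta, double (*C)[NJ],
                 double (*A)[NK], double (*B)[NJ]) {
  static_assert(JStep > 0, "the middle loop must make progress");
  for (std::size_t i = 0; i < NI; i++)
    for (std::size_t j = 0; j < NJ; j += JStep) {
      C[i][j] *= beta;                        // block bb14 in the IR
      for (std::size_t k = 0; k < NK; k++)    // blocks bb21 and bb24
        C[i][j] += alpha * A[i][k] * B[k][j];
    }
}
```

The positive case is the matrix multiplication the pattern should match. In the negative case the stride-0-or-1 condition is presumably what fails, because the accesses that use the middle loop variable no longer touch consecutive elements.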
>> +define internal void @kernel_gemm(i32 %arg, i32 %arg1, i32 %arg2, double %arg3, double %arg4, [1056 x double]* %arg5, [1024 x double]* %arg6, [1056 x double]* %arg7) {
>> +bb:
>> + br label %bb8
>> +
>> +bb8: ; preds = %bb39, %bb
>> + %tmp = phi i32 [ 0, %bb ], [ %tmp40, %bb39 ]
>> + %tmp9 = icmp slt i32 %tmp, 1056
>> + br i1 %tmp9, label %bb10, label %bb41
>> +
>> +bb10: ; preds = %bb8
>> + br label %bb11
>> +
>> +bb11: ; preds = %bb37, %bb10
>> + %tmp12 = phi i32 [ 0, %bb10 ], [ %tmp38, %bb37 ]
>> + %tmp13 = icmp slt i32 %tmp12, 1056
>> + br i1 %tmp13, label %bb14, label %bb39
>> +
>> +bb14: ; preds = %bb11
>> + %tmp15 = sext i32 %tmp12 to i64
>> + %tmp16 = sext i32 %tmp to i64
>> + %tmp17 = getelementptr inbounds [1056 x double], [1056 x double]* %arg5, i64 %tmp16
>> + %tmp18 = getelementptr inbounds [1056 x double], [1056 x double]* %tmp17, i64 0, i64 %tmp15
>> + %tmp19 = load double, double* %tmp18, align 8
>> + %tmp20 = fmul double %tmp19, %arg4
>> + store double %tmp20, double* %tmp18, align 8
>> + br label %bb21
>> +
>> +bb21: ; preds = %bb24, %bb14
>> + %tmp22 = phi i32 [ 0, %bb14 ], [ %tmp36, %bb24 ]
>> + %tmp23 = icmp slt i32 %tmp22, 1024
>> + br i1 %tmp23, label %bb24, label %bb37
>> +
>> +bb24: ; preds = %bb21
>> + %tmp25 = sext i32 %tmp22 to i64
>> + %tmp26 = getelementptr inbounds [1024 x double], [1024 x double]* %arg6, i64 %tmp16
>> + %tmp27 = getelementptr inbounds [1024 x double], [1024 x double]* %tmp26, i64 0, i64 %tmp25
>> + %tmp28 = load double, double* %tmp27, align 8
>> + %tmp29 = fmul double %arg3, %tmp28
>> + %tmp30 = getelementptr inbounds [1056 x double], [1056 x double]* %arg7, i64 %tmp25
>> + %tmp31 = getelementptr inbounds [1056 x double], [1056 x double]* %tmp30, i64 0, i64 %tmp15
>> + %tmp32 = load double, double* %tmp31, align 8
>> + %tmp33 = fmul double %tmp29, %tmp32
>> + %tmp34 = load double, double* %tmp18, align 8
>> + %tmp35 = fadd double %tmp34, %tmp33
>> + store double %tmp35, double* %tmp18, align 8
>> + %tmp36 = add nsw i32 %tmp22, 1
>> + br label %bb21
>> +
>> +bb37: ; preds = %bb21
>> + %tmp38 = add nsw i32 %tmp12, 1
>> + br label %bb11
>> +
>> +bb39: ; preds = %bb11
>> + %tmp40 = add nsw i32 %tmp, 1
>> + br label %bb8
>> +
>> +bb41: ; preds = %bb8
>> + ret void
>> +}
>>
>> Added: polly/trunk/test/ScheduleOptimizer/pattern-matching-based-opts_2.ll
>> URL: http://llvm.org/viewvc/llvm-project/polly/trunk/test/ScheduleOptimizer/pattern-matching-based-opts_2.ll?rev=271128&view=auto
>> ==============================================================================
>> --- polly/trunk/test/ScheduleOptimizer/pattern-matching-based-opts_2.ll (added)
>> +++ polly/trunk/test/ScheduleOptimizer/pattern-matching-based-opts_2.ll Sat May 28 11:17:58 2016
>> @@ -0,0 +1,63 @@
>> +; RUN: opt %loadPolly -polly-opt-isl -polly-pattern-matching-based-opts=true -debug < %s 2>&1 | FileCheck %s
>> +; REQUIRES: asserts
>> +; CHECK-NOT: The matrix multiplication pattern was detected
>> +
>> +define internal void @kernel_gemm(i32 %arg, i32 %arg1, i32 %arg2, double %arg3, double %arg4, [1056 x double]* %arg5, [1024 x double]* %arg6, [1056 x double]* %arg7) {
>> +bb:
>> + br label %bb8
>> +
>> +bb8: ; preds = %bb39, %bb
>> + %tmp = phi i32 [ 0, %bb ], [ %tmp40, %bb39 ]
>> + %tmp9 = icmp slt i32 %tmp, 1056
>> + br i1 %tmp9, label %bb10, label %bb41
>> +
>> +bb10: ; preds = %bb8
>> + br label %bb11
>> +
>> +bb11: ; preds = %bb37, %bb10
>> + %tmp12 = phi i32 [ 0, %bb10 ], [ %tmp38, %bb37 ]
>> + %tmp13 = icmp slt i32 %tmp12, 1056
>> + br i1 %tmp13, label %bb14, label %bb39
>> +
>> +bb14: ; preds = %bb11
>> + %tmp15 = sext i32 %tmp12 to i64
>> + %tmp16 = sext i32 %tmp to i64
>> + %tmp17 = getelementptr inbounds [1056 x double], [1056 x double]* %arg5, i64 %tmp16
>> + %tmp18 = getelementptr inbounds [1056 x double], [1056 x double]* %tmp17, i64 0, i64 %tmp15
>> + %tmp19 = load double, double* %tmp18, align 8
>> + %tmp20 = fmul double %tmp19, %arg4
>> + store double %tmp20, double* %tmp18, align 8
>> + br label %bb21
>> +
>> +bb21: ; preds = %bb24, %bb14
>> + %tmp22 = phi i32 [ 0, %bb14 ], [ %tmp36, %bb24 ]
>> + %tmp23 = icmp slt i32 %tmp22, 1024
>> + br i1 %tmp23, label %bb24, label %bb37
>> +
>> +bb24: ; preds = %bb21
>> + %tmp25 = sext i32 %tmp22 to i64
>> + %tmp26 = getelementptr inbounds [1024 x double], [1024 x double]* %arg6, i64 %tmp16
>> + %tmp27 = getelementptr inbounds [1024 x double], [1024 x double]* %tmp26, i64 0, i64 %tmp25
>> + %tmp28 = load double, double* %tmp27, align 8
>> + %tmp29 = fmul double %arg3, %tmp28
>> + %tmp30 = getelementptr inbounds [1056 x double], [1056 x double]* %arg7, i64 %tmp25
>> + %tmp31 = getelementptr inbounds [1056 x double], [1056 x double]* %tmp30, i64 0, i64 %tmp15
>> + %tmp32 = load double, double* %tmp31, align 8
>> + %tmp33 = fmul double %tmp29, %tmp32
>> + %tmp34 = load double, double* %tmp18, align 8
>> + %tmp35 = fadd double %tmp34, %tmp33
>> + store double %tmp35, double* %tmp18, align 8
>> + %tmp36 = add nsw i32 %tmp22, 1
>> + br label %bb21
>> +
>> +bb37: ; preds = %bb21
>> + %tmp38 = add nsw i32 %tmp12, 2
>> + br label %bb11
>> +
>> +bb39: ; preds = %bb11
>> + %tmp40 = add nsw i32 %tmp, 1
>> + br label %bb8
>> +
>> +bb41: ; preds = %bb8
>> + ret void
>> +}
>>
>>
>> _______________________________________________
>> llvm-commits mailing list
>> llvm-commits at lists.llvm.org
>> http://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-commits
>
> --
>
> Johannes Doerfert
> Researcher / PhD Student
>
> Compiler Design Lab (Prof. Hack)
> Saarland University, Computer Science
> Building E1.3, Room 4.31
>
> Tel. +49 (0)681 302-57521 : doerfert at cs.uni-saarland.de
> Fax. +49 (0)681 302-3065 : http://www.cdl.uni-saarland.de/people/doerfert
--
Cheers, Roman Gareev.