[polly] r271128 - Determination of statements that contain matrix multiplication
Roman Gareev via llvm-commits
llvm-commits at lists.llvm.org
Sat May 28 09:17:59 PDT 2016
Author: romangareev
Date: Sat May 28 11:17:58 2016
New Revision: 271128
URL: http://llvm.org/viewvc/llvm-project?rev=271128&view=rev
Log:
Determination of statements that contain matrix multiplication
Add determination of statements that contain, in particular,
matrix multiplications and can be optimized with [1] to try to
get close-to-peak performance. It can be enabled
via -polly-pattern-matching-based-opts, which is false by default.
Refs:
[1] - http://www.cs.utexas.edu/users/flame/pubs/TOMS-BLIS-Analytical.pdf
Contributed-by: Roman Gareev <gareevroman at gmail.com>
Reviewed-by: Tobias Grosser <tobias at grosser.es>
Differential Revision: http://reviews.llvm.org/D20575
Added:
polly/trunk/test/ScheduleOptimizer/pattern-matching-based-opts.ll
polly/trunk/test/ScheduleOptimizer/pattern-matching-based-opts_2.ll
Modified:
polly/trunk/include/polly/ScheduleOptimizer.h
polly/trunk/lib/Transform/ScheduleOptimizer.cpp
Modified: polly/trunk/include/polly/ScheduleOptimizer.h
URL: http://llvm.org/viewvc/llvm-project/polly/trunk/include/polly/ScheduleOptimizer.h?rev=271128&r1=271127&r2=271128&view=diff
==============================================================================
--- polly/trunk/include/polly/ScheduleOptimizer.h (original)
+++ polly/trunk/include/polly/ScheduleOptimizer.h Sat May 28 11:17:58 2016
@@ -147,8 +147,45 @@ private:
/// - if vectorization is enabled
///
/// @param Node The schedule node to (possibly) optimize.
- /// @param User A pointer to forward some use information (currently unused).
+ /// @param User A pointer to forward some use information
+ /// (currently unused).
static isl_schedule_node *optimizeBand(isl_schedule_node *Node, void *User);
+
+ /// @brief Apply the standard (not pattern-specific) optimizations to a
+ /// single band node.
+ ///
+ /// We apply the following transformations:
+ ///
+ /// - Tile the band, if the corresponding tiling option is enabled.
+ /// - Prevectorize the schedule of the band (or the point loop in case of
+ ///   tiling), if vectorization is enabled.
+ ///
+ /// @param Node The schedule node to (possibly) optimize.
+ /// @param User A pointer to forward some use information
+ /// (currently unused).
+ /// @return The (possibly) transformed schedule node.
+ static isl_schedule_node *standardBandOpts(__isl_take isl_schedule_node *Node,
+ void *User);
+
+ /// @brief Check if this node contains a partial schedule that could
+ /// probably be optimized with analytical modeling.
+ ///
+ /// isMatrMultPattern tries to determine whether the following conditions
+ /// are true:
+ /// 1. the partial schedule contains only one statement.
+ /// 2. the partial schedule has exactly three input dimensions (i.e. the
+ /// statement domain is three-dimensional).
+ /// 3. all memory accesses of the statement will have stride 0 or 1, if we
+ /// interchange loops (move the innermost loop to the outermost
+ /// position).
+ /// 4. all memory accesses of the statement except the last one are
+ /// read accesses, and the last one is a write access.
+ /// 5. no subscript of the last memory access of the statement contains
+ /// the variable used in the inner loop.
+ /// If this is the case, we could try to use an approach that is similar to
+ /// the one used to get close-to-peak performance of matrix multiplications.
+ ///
+ /// @param Node The node to check; it is not consumed (__isl_keep).
+ /// @return true if all of the conditions above hold.
+ static bool isMatrMultPattern(__isl_keep isl_schedule_node *Node);
};
#endif
Modified: polly/trunk/lib/Transform/ScheduleOptimizer.cpp
URL: http://llvm.org/viewvc/llvm-project/polly/trunk/lib/Transform/ScheduleOptimizer.cpp?rev=271128&r1=271127&r2=271128&view=diff
==============================================================================
--- polly/trunk/lib/Transform/ScheduleOptimizer.cpp (original)
+++ polly/trunk/lib/Transform/ScheduleOptimizer.cpp Sat May 28 11:17:58 2016
@@ -166,6 +166,11 @@ static cl::list<int>
cl::Hidden, cl::ZeroOrMore, cl::CommaSeparated,
cl::cat(PollyCategory));
+/// When set, tileable bands are checked against the matrix-multiplication
+/// pattern (see isMatrMultPattern); a match is currently only reported in
+/// the debug output. Disabled by default.
+static cl::opt<bool> PMBasedOpts(
+    "polly-pattern-matching-based-opts",
+    cl::desc("Perform optimizations based on pattern matching"),
+    cl::init(false), cl::ZeroOrMore, cl::cat(PollyCategory));
+
/// @brief Create an isl_union_set, which describes the isolate option based
/// on IsoalteDomain.
///
@@ -359,11 +364,8 @@ bool ScheduleTreeOptimizer::isTileableBa
}
__isl_give isl_schedule_node *
-ScheduleTreeOptimizer::optimizeBand(__isl_take isl_schedule_node *Node,
- void *User) {
- if (!isTileableBandNode(Node))
- return Node;
-
+ScheduleTreeOptimizer::standardBandOpts(__isl_take isl_schedule_node *Node,
+ void *User) {
if (FirstLevelTiling)
Node = tileNode(Node, "1st level tiling", FirstLevelTileSizes,
FirstLevelDefaultTileSize);
@@ -396,6 +398,110 @@ ScheduleTreeOptimizer::optimizeBand(__is
return Node;
}
+/// @brief Check whether output dimensions of the map rely on the specified
+/// input dimension.
+///
+/// The input dimension is projected out and re-inserted as an unconstrained
+/// dimension. If the resulting map is still equal to the original one, no
+/// output dimension depends on that input dimension.
+///
+/// @param IslMap The isl map to be considered (consumed).
+/// @param DimNum The number of an input dimension to be checked.
+/// @return true if some output dimension depends on input dimension DimNum,
+///         or if isl could not decide the equality (conservative answer).
+static bool isInputDimUsed(__isl_take isl_map *IslMap, unsigned DimNum) {
+  auto *CheckedAccessRelation =
+      isl_map_project_out(isl_map_copy(IslMap), isl_dim_in, DimNum, 1);
+  CheckedAccessRelation =
+      isl_map_insert_dims(CheckedAccessRelation, isl_dim_in, DimNum, 1);
+  // Restore the tuple ids (isl resets them when dimensions are removed or
+  // inserted) so that the comparison below only reflects the constraints.
+  auto *InputDimsId = isl_map_get_tuple_id(IslMap, isl_dim_in);
+  CheckedAccessRelation =
+      isl_map_set_tuple_id(CheckedAccessRelation, isl_dim_in, InputDimsId);
+  auto *OutputDimsId = isl_map_get_tuple_id(IslMap, isl_dim_out);
+  CheckedAccessRelation =
+      isl_map_set_tuple_id(CheckedAccessRelation, isl_dim_out, OutputDimsId);
+  auto IsEqual = isl_map_is_equal(CheckedAccessRelation, IslMap);
+  isl_map_free(CheckedAccessRelation);
+  isl_map_free(IslMap);
+  // An isl error (isl_bool_error) is treated like "used", so that callers do
+  // not wrongly conclude that the access is independent of the dimension.
+  return IsEqual != isl_bool_true;
+}
+
+/// @brief Check if the SCoP statement could probably be optimized with
+/// analytical modeling.
+///
+/// containsMatrMult tries to determine whether the following conditions
+/// are true:
+/// 1. all memory accesses of the statement will have stride 0 or 1,
+///    if we interchange loops (move the innermost loop to the outermost
+///    position).
+/// 2. all memory accesses of the statement except the last one are
+///    read accesses, and the last one is a write access.
+/// 3. no subscript of the last memory access of the statement contains
+///    the variable used in the inner loop.
+///
+/// @param PartialSchedule The PartialSchedule that contains a SCoP statement
+///                        to check (not consumed).
+/// @return true if all of the conditions above hold.
+static bool containsMatrMult(__isl_keep isl_map *PartialSchedule) {
+  auto InputDimsId = isl_map_get_tuple_id(PartialSchedule, isl_dim_in);
+  auto *ScpStmt = static_cast<ScopStmt *>(isl_id_get_user(InputDimsId));
+  isl_id_free(InputDimsId);
+  if (ScpStmt->size() <= 1)
+    return false;
+
+  // Every access but the last one must be a read; array reads must also have
+  // stride 0 or 1. The loop visits exactly size() - 1 accesses, so that
+  // afterwards MemA refers to the last access of the statement.
+  auto MemA = ScpStmt->begin();
+  for (unsigned i = 0; i < ScpStmt->size() - 1 && MemA != ScpStmt->end();
+       i++, MemA++)
+    if (!(*MemA)->isRead() ||
+        ((*MemA)->isArrayKind() &&
+         !((*MemA)->isStrideOne(isl_map_copy(PartialSchedule)) ||
+           (*MemA)->isStrideZero(isl_map_copy(PartialSchedule)))))
+      return false;
+
+  // The last access must be an array write with stride 0 or 1 ...
+  if (!(*MemA)->isWrite() || !(*MemA)->isArrayKind() ||
+      !((*MemA)->isStrideOne(isl_map_copy(PartialSchedule)) ||
+        (*MemA)->isStrideZero(isl_map_copy(PartialSchedule))))
+    return false;
+
+  // ... whose subscripts do not use the last input dimension.
+  auto DimNum = isl_map_dim(PartialSchedule, isl_dim_in);
+  return !isInputDimUsed((*MemA)->getAccessRelation(), DimNum - 1);
+}
+
+/// @brief Circular shift of output dimensions of the integer map.
+///
+/// The last output dimension becomes the first one and all other output
+/// dimensions move one position to the right, i.e. the output tuple
+/// (o0, ..., oN-1) becomes (oN-1, o0, ..., oN-2). The tuple id of the input
+/// space is restored after the shift. A map without output dimensions is
+/// returned unchanged.
+///
+/// @param IslMap The isl map to be modified (consumed).
+/// @return The map with its output dimensions rotated.
+static __isl_give isl_map *circularShiftOutputDims(__isl_take isl_map *IslMap) {
+  auto DimNum = isl_map_dim(IslMap, isl_dim_out);
+  // Nothing to rotate; also keeps "DimNum - 1" below from wrapping around.
+  if (DimNum == 0)
+    return IslMap;
+  auto InputDimsId = isl_map_get_tuple_id(IslMap, isl_dim_in);
+  // Park the last output dimension at the front of the input tuple, then
+  // move it back as the first output dimension.
+  IslMap = isl_map_move_dims(IslMap, isl_dim_in, 0, isl_dim_out, DimNum - 1, 1);
+  IslMap = isl_map_move_dims(IslMap, isl_dim_out, 0, isl_dim_in, 0, 1);
+  return isl_map_set_tuple_id(IslMap, isl_dim_in, InputDimsId);
+}
+
+/// Check whether the band Node schedules a single three-dimensional
+/// statement whose accesses match the matrix-multiplication pattern once the
+/// innermost schedule dimension is moved to the outermost position.
+/// Node is not consumed; all isl objects created here are freed on every path.
+bool ScheduleTreeOptimizer::isMatrMultPattern(
+    __isl_keep isl_schedule_node *Node) {
+  auto *PartialSchedule =
+      isl_schedule_node_band_get_partial_schedule_union_map(Node);
+  // Only bands that schedule exactly one statement are considered.
+  if (isl_union_map_n_map(PartialSchedule) != 1) {
+    isl_union_map_free(PartialSchedule);
+    return false;
+  }
+  auto *NewPartialSchedule = isl_map_from_union_map(PartialSchedule);
+  // The statement domain must have exactly three dimensions.
+  auto DimNum = isl_map_dim(NewPartialSchedule, isl_dim_in);
+  if (DimNum != 3) {
+    isl_map_free(NewPartialSchedule);
+    return false;
+  }
+  // Interchange: the innermost schedule dimension becomes the outermost one.
+  NewPartialSchedule = circularShiftOutputDims(NewPartialSchedule);
+  bool IsMatrMult = containsMatrMult(NewPartialSchedule);
+  isl_map_free(NewPartialSchedule);
+  return IsMatrMult;
+}
+
+/// Optimize a single band node: tileable bands optionally go through the
+/// pattern detection and then always through the standard band
+/// optimizations; all other nodes are returned as they are.
+__isl_give isl_schedule_node *
+ScheduleTreeOptimizer::optimizeBand(__isl_take isl_schedule_node *Node,
+                                    void *User) {
+  if (isTileableBandNode(Node)) {
+    // A detected pattern is only reported for now; the band still receives
+    // the standard optimizations below.
+    if (PMBasedOpts && isMatrMultPattern(Node))
+      DEBUG(dbgs() << "The matrix multiplication pattern was detected\n");
+    Node = standardBandOpts(Node, User);
+  }
+  return Node;
+}
+
__isl_give isl_schedule *
ScheduleTreeOptimizer::optimizeSchedule(__isl_take isl_schedule *Schedule) {
isl_schedule_node *Root = isl_schedule_get_root(Schedule);
Added: polly/trunk/test/ScheduleOptimizer/pattern-matching-based-opts.ll
URL: http://llvm.org/viewvc/llvm-project/polly/trunk/test/ScheduleOptimizer/pattern-matching-based-opts.ll?rev=271128&view=auto
==============================================================================
--- polly/trunk/test/ScheduleOptimizer/pattern-matching-based-opts.ll (added)
+++ polly/trunk/test/ScheduleOptimizer/pattern-matching-based-opts.ll Sat May 28 11:17:58 2016
@@ -0,0 +1,65 @@
+; RUN: opt %loadPolly -polly-opt-isl -debug < %s 2>&1| FileCheck %s
+; RUN: opt %loadPolly -polly-opt-isl -polly-pattern-matching-based-opts=true -debug < %s 2>&1| FileCheck %s --check-prefix=PATTERN-MATCHING-OPTS
+; REQUIRES: asserts
+; CHECK-NOT: The matrix multiplication pattern was detected
+; PATTERN-MATCHING-OPTS: The matrix multiplication pattern was detected
+
+; Positive test input, roughly:
+;   for (i = 0; i < 1056; i++)
+;     for (j = 0; j < 1056; j++) {
+;       C[i][j] *= %arg4;                          // block bb14
+;       for (k = 0; k < 1024; k++)
+;         C[i][j] += %arg3 * A[i][k] * B[k][j];    // block bb24
+;     }
+; with C = %arg5, A = %arg6 and B = %arg7. The pattern detection inspects the
+; order of the memory accesses in bb24 (loads of A, B and C, then the store
+; to C), so keep that order when editing this test.
+define internal void @kernel_gemm(i32 %arg, i32 %arg1, i32 %arg2, double %arg3, double %arg4, [1056 x double]* %arg5, [1024 x double]* %arg6, [1056 x double]* %arg7) {
+bb:
+ br label %bb8
+
+bb8: ; preds = %bb39, %bb
+ %tmp = phi i32 [ 0, %bb ], [ %tmp40, %bb39 ]
+ %tmp9 = icmp slt i32 %tmp, 1056
+ br i1 %tmp9, label %bb10, label %bb41
+
+bb10: ; preds = %bb8
+ br label %bb11
+
+bb11: ; preds = %bb37, %bb10
+ %tmp12 = phi i32 [ 0, %bb10 ], [ %tmp38, %bb37 ]
+ %tmp13 = icmp slt i32 %tmp12, 1056
+ br i1 %tmp13, label %bb14, label %bb39
+
+; C[i][j] = C[i][j] * %arg4
+bb14: ; preds = %bb11
+ %tmp15 = sext i32 %tmp12 to i64
+ %tmp16 = sext i32 %tmp to i64
+ %tmp17 = getelementptr inbounds [1056 x double], [1056 x double]* %arg5, i64 %tmp16
+ %tmp18 = getelementptr inbounds [1056 x double], [1056 x double]* %tmp17, i64 0, i64 %tmp15
+ %tmp19 = load double, double* %tmp18, align 8
+ %tmp20 = fmul double %tmp19, %arg4
+ store double %tmp20, double* %tmp18, align 8
+ br label %bb21
+
+bb21: ; preds = %bb24, %bb14
+ %tmp22 = phi i32 [ 0, %bb14 ], [ %tmp36, %bb24 ]
+ %tmp23 = icmp slt i32 %tmp22, 1024
+ br i1 %tmp23, label %bb24, label %bb37
+
+; C[i][j] = C[i][j] + %arg3 * A[i][k] * B[k][j]
+bb24: ; preds = %bb21
+ %tmp25 = sext i32 %tmp22 to i64
+ %tmp26 = getelementptr inbounds [1024 x double], [1024 x double]* %arg6, i64 %tmp16
+ %tmp27 = getelementptr inbounds [1024 x double], [1024 x double]* %tmp26, i64 0, i64 %tmp25
+ %tmp28 = load double, double* %tmp27, align 8
+ %tmp29 = fmul double %arg3, %tmp28
+ %tmp30 = getelementptr inbounds [1056 x double], [1056 x double]* %arg7, i64 %tmp25
+ %tmp31 = getelementptr inbounds [1056 x double], [1056 x double]* %tmp30, i64 0, i64 %tmp15
+ %tmp32 = load double, double* %tmp31, align 8
+ %tmp33 = fmul double %tmp29, %tmp32
+ %tmp34 = load double, double* %tmp18, align 8
+ %tmp35 = fadd double %tmp34, %tmp33
+ store double %tmp35, double* %tmp18, align 8
+ %tmp36 = add nsw i32 %tmp22, 1
+ br label %bb21
+
+bb37: ; preds = %bb21
+ %tmp38 = add nsw i32 %tmp12, 1
+ br label %bb11
+
+bb39: ; preds = %bb11
+ %tmp40 = add nsw i32 %tmp, 1
+ br label %bb8
+
+bb41: ; preds = %bb8
+ ret void
+}
Added: polly/trunk/test/ScheduleOptimizer/pattern-matching-based-opts_2.ll
URL: http://llvm.org/viewvc/llvm-project/polly/trunk/test/ScheduleOptimizer/pattern-matching-based-opts_2.ll?rev=271128&view=auto
==============================================================================
--- polly/trunk/test/ScheduleOptimizer/pattern-matching-based-opts_2.ll (added)
+++ polly/trunk/test/ScheduleOptimizer/pattern-matching-based-opts_2.ll Sat May 28 11:17:58 2016
@@ -0,0 +1,63 @@
+; RUN: opt %loadPolly -polly-opt-isl -polly-pattern-matching-based-opts=true -debug < %s 2>&1 | FileCheck %s
+; REQUIRES: asserts
+; CHECK-NOT: The matrix multiplication pattern was detected
+
+; Negative test input: the same gemm-like kernel as in
+; pattern-matching-based-opts.ll, except that the j loop is incremented by 2
+; (see bb37). The accesses to C[i][j] and B[k][j] therefore advance by two
+; elements per j iteration, so the matrix multiplication pattern must not be
+; detected.
+define internal void @kernel_gemm(i32 %arg, i32 %arg1, i32 %arg2, double %arg3, double %arg4, [1056 x double]* %arg5, [1024 x double]* %arg6, [1056 x double]* %arg7) {
+bb:
+ br label %bb8
+
+bb8: ; preds = %bb39, %bb
+ %tmp = phi i32 [ 0, %bb ], [ %tmp40, %bb39 ]
+ %tmp9 = icmp slt i32 %tmp, 1056
+ br i1 %tmp9, label %bb10, label %bb41
+
+bb10: ; preds = %bb8
+ br label %bb11
+
+bb11: ; preds = %bb37, %bb10
+ %tmp12 = phi i32 [ 0, %bb10 ], [ %tmp38, %bb37 ]
+ %tmp13 = icmp slt i32 %tmp12, 1056
+ br i1 %tmp13, label %bb14, label %bb39
+
+bb14: ; preds = %bb11
+ %tmp15 = sext i32 %tmp12 to i64
+ %tmp16 = sext i32 %tmp to i64
+ %tmp17 = getelementptr inbounds [1056 x double], [1056 x double]* %arg5, i64 %tmp16
+ %tmp18 = getelementptr inbounds [1056 x double], [1056 x double]* %tmp17, i64 0, i64 %tmp15
+ %tmp19 = load double, double* %tmp18, align 8
+ %tmp20 = fmul double %tmp19, %arg4
+ store double %tmp20, double* %tmp18, align 8
+ br label %bb21
+
+bb21: ; preds = %bb24, %bb14
+ %tmp22 = phi i32 [ 0, %bb14 ], [ %tmp36, %bb24 ]
+ %tmp23 = icmp slt i32 %tmp22, 1024
+ br i1 %tmp23, label %bb24, label %bb37
+
+bb24: ; preds = %bb21
+ %tmp25 = sext i32 %tmp22 to i64
+ %tmp26 = getelementptr inbounds [1024 x double], [1024 x double]* %arg6, i64 %tmp16
+ %tmp27 = getelementptr inbounds [1024 x double], [1024 x double]* %tmp26, i64 0, i64 %tmp25
+ %tmp28 = load double, double* %tmp27, align 8
+ %tmp29 = fmul double %arg3, %tmp28
+ %tmp30 = getelementptr inbounds [1056 x double], [1056 x double]* %arg7, i64 %tmp25
+ %tmp31 = getelementptr inbounds [1056 x double], [1056 x double]* %tmp30, i64 0, i64 %tmp15
+ %tmp32 = load double, double* %tmp31, align 8
+ %tmp33 = fmul double %tmp29, %tmp32
+ %tmp34 = load double, double* %tmp18, align 8
+ %tmp35 = fadd double %tmp34, %tmp33
+ store double %tmp35, double* %tmp18, align 8
+ %tmp36 = add nsw i32 %tmp22, 1
+ br label %bb21
+
+; The j induction variable advances by 2 here (1 in the positive test).
+bb37: ; preds = %bb21
+ %tmp38 = add nsw i32 %tmp12, 2
+ br label %bb11
+
+bb39: ; preds = %bb11
+ %tmp40 = add nsw i32 %tmp, 1
+ br label %bb8
+
+bb41: ; preds = %bb8
+ ret void
+}
More information about the llvm-commits
mailing list