[polly] r311302 - [MatMul] Make MatMul detection independent of internal isl representations.
Michael Kruse via llvm-commits
llvm-commits at lists.llvm.org
Sun Aug 20 14:31:11 PDT 2017
Author: meinersbur
Date: Sun Aug 20 14:31:11 2017
New Revision: 311302
URL: http://llvm.org/viewvc/llvm-project?rev=311302&view=rev
Log:
[MatMul] Make MatMul detection independent of internal isl representations.
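The pattern recognition for MatMul was overly restrictive.
The access relation (an isl_map) of each MatMul memory access was previously
required to consist of exactly one disjunct (basic map). isl_*_coalesce
should ideally reduce such a map to a single disjunct, but does not under
some circumstances, so valid MatMul kernels were not detected.
The check now restricts the access relation to the statement domain and
compares it against each candidate MatMul access pattern, so it no longer
depends on how many disjuncts isl uses to represent the relation.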
Contributed-by: Annanay Agarwal <cs14btech11001 at iith.ac.in>
Differential Revision: https://reviews.llvm.org/D36460
Added:
polly/trunk/test/ScheduleOptimizer/pattern_matching_based_opts_splitmap.ll
polly/trunk/test/ScheduleOptimizer/pattern_matching_based_opts_splitmap___%for.body---%for.end23.jscop
polly/trunk/test/ScheduleOptimizer/pattern_matching_based_opts_splitmap___%for.body---%for.end23.jscop.transformed
Modified:
polly/trunk/lib/Transform/ScheduleOptimizer.cpp
Modified: polly/trunk/lib/Transform/ScheduleOptimizer.cpp
URL: http://llvm.org/viewvc/llvm-project/polly/trunk/lib/Transform/ScheduleOptimizer.cpp?rev=311302&r1=311301&r2=311302&view=diff
==============================================================================
--- polly/trunk/lib/Transform/ScheduleOptimizer.cpp (original)
+++ polly/trunk/lib/Transform/ScheduleOptimizer.cpp Sun Aug 20 14:31:11 2017
@@ -483,61 +483,6 @@ ScheduleTreeOptimizer::standardBandOpts(
return Node;
}
-/// Get the position of a dimension with a non-zero coefficient.
-///
-/// Check that isl constraint @p Constraint has only one non-zero
-/// coefficient for dimensions that have type @p DimType. If this is true,
-/// return the position of the dimension corresponding to the non-zero
-/// coefficient and negative value, otherwise.
-///
-/// @param Constraint The isl constraint to be checked.
-/// @param DimType The type of the dimensions.
-/// @return The position of the dimension in case the isl
-/// constraint satisfies the requirements, a negative
-/// value, otherwise.
-static int getMatMulConstraintDim(isl::constraint Constraint,
- isl::dim DimType) {
- int DimPos = -1;
- auto LocalSpace = Constraint.get_local_space();
- int LocalSpaceDimNum = LocalSpace.dim(DimType);
- for (int i = 0; i < LocalSpaceDimNum; i++) {
- auto Val = Constraint.get_coefficient_val(DimType, i);
- if (Val.is_zero())
- continue;
- if (DimPos >= 0 || (DimType == isl::dim::out && !Val.is_one()) ||
- (DimType == isl::dim::in && !Val.is_negone()))
- return -1;
- DimPos = i;
- }
- return DimPos;
-}
-
-/// Check the form of the isl constraint.
-///
-/// Check that the @p DimInPos input dimension of the isl constraint
-/// @p Constraint has a coefficient that is equal to negative one, the @p
-/// DimOutPos has a coefficient that is equal to one and others
-/// have coefficients equal to zero.
-///
-/// @param Constraint The isl constraint to be checked.
-/// @param DimInPos The input dimension of the isl constraint.
-/// @param DimOutPos The output dimension of the isl constraint.
-/// @return isl_stat_ok in case the isl constraint satisfies
-/// the requirements, isl_stat_error otherwise.
-static isl_stat isMatMulOperandConstraint(isl::constraint Constraint,
- int &DimInPos, int &DimOutPos) {
- auto Val = Constraint.get_constant_val();
- if (!isl_constraint_is_equality(Constraint.get()) || !Val.is_zero())
- return isl_stat_error;
- DimInPos = getMatMulConstraintDim(Constraint, isl::dim::in);
- if (DimInPos < 0)
- return isl_stat_error;
- DimOutPos = getMatMulConstraintDim(Constraint, isl::dim::out);
- if (DimOutPos < 0)
- return isl_stat_error;
- return isl_stat_ok;
-}
-
/// Permute the two dimensions of the isl map.
///
/// Permute @p DstPos and @p SrcPos dimensions of the isl map @p Map that
@@ -585,30 +530,49 @@ isl::map permuteDimensions(isl::map Map,
/// second output dimension.
/// @return True in case @p AccMap has the expected form and false,
/// otherwise.
-static bool isMatMulOperandAcc(isl::map AccMap, int &FirstPos, int &SecondPos) {
- int DimInPos[] = {FirstPos, SecondPos};
- auto Lambda = [=, &DimInPos](isl::basic_map BasicMap) -> isl::stat {
- auto Constraints = BasicMap.get_constraint_list();
- if (isl_constraint_list_n_constraint(Constraints.get()) != 2)
- return isl::stat::error;
- for (int i = 0; i < 2; i++) {
- auto Constraint =
- isl::manage(isl_constraint_list_get_constraint(Constraints.get(), i));
- int InPos, OutPos;
- if (isMatMulOperandConstraint(Constraint, InPos, OutPos) ==
- isl_stat_error ||
- OutPos > 1 || (DimInPos[OutPos] >= 0 && DimInPos[OutPos] != InPos))
- return isl::stat::error;
- DimInPos[OutPos] = InPos;
- }
- return isl::stat::ok;
- };
- if (AccMap.foreach_basic_map(Lambda) != isl::stat::ok || DimInPos[0] < 0 ||
- DimInPos[1] < 0)
+static bool isMatMulOperandAcc(isl::set Domain, isl::map AccMap, int &FirstPos,
+ int &SecondPos) {
+ // Domain is the statement domain; FirstPos/SecondPos are -1 if not fixed.
+ isl::space Space = AccMap.get_space();
+ isl::map Universe = isl::map::universe(Space);
+
+ if (Space.dim(isl::dim::out) != 2)
return false;
- FirstPos = DimInPos[0];
- SecondPos = DimInPos[1];
- return true;
+
+ // MatMul has the form:
+ // for (i = 0; i < N; i++)
+ // for (j = 0; j < M; j++)
+ // for (k = 0; k < P; k++)
+ // C[i, j] += A[i, k] * B[k, j]
+ //
+ // Permutation of three outer loops: 3! = 6 possibilities.
+ int FirstDims[] = {0, 0, 1, 1, 2, 2};
+ int SecondDims[] = {1, 2, 2, 0, 0, 1};
+ AccMap = AccMap.intersect_domain(Domain); // Loop-invariant: do it once.
+ for (int i = 0; i < 6; i += 1) {
+ auto PossibleMatMul =
+ Universe.equate(isl::dim::in, FirstDims[i], isl::dim::out, 0)
+ .equate(isl::dim::in, SecondDims[i], isl::dim::out, 1);
+ PossibleMatMul = PossibleMatMul.intersect_domain(Domain);
+
+ // Both maps are restricted to the statement domain, so they are equal
+ // only if every statement instance accesses exactly the element this
+ // candidate selects. Anything else (e.g. a partial write, which must be
+ // rejected) does not match this candidate.
+ if (AccMap.is_equal(PossibleMatMul)) {
+ // Reject candidates contradicting positions fixed by earlier operands.
+ // Check both before assigning, so a rejected candidate never leaves
+ // FirstPos/SecondPos half-updated for the caller's next query.
+ if ((FirstPos != -1 && FirstPos != FirstDims[i]) ||
+ (SecondPos != -1 && SecondPos != SecondDims[i]))
+ continue;
+ FirstPos = FirstDims[i];
+ SecondPos = SecondDims[i];
+ return true;
+ }
+ }
+
+ return false;
}
/// Does the memory access represent a non-scalar operand of the matrix
@@ -627,18 +591,16 @@ static bool isMatMulNonScalarReadAccess(
if (!MemAccess->isLatestArrayKind() || !MemAccess->isRead())
return false;
auto AccMap = MemAccess->getLatestAccessRelation();
- if (isMatMulOperandAcc(AccMap, MMI.i, MMI.j) && !MMI.ReadFromC &&
- isl_map_n_basic_map(AccMap.get()) == 1) {
+ isl::set StmtDomain = MemAccess->getStatement()->getDomain();
+ if (isMatMulOperandAcc(StmtDomain, AccMap, MMI.i, MMI.j) && !MMI.ReadFromC) {
MMI.ReadFromC = MemAccess;
return true;
}
- if (isMatMulOperandAcc(AccMap, MMI.i, MMI.k) && !MMI.A &&
- isl_map_n_basic_map(AccMap.get()) == 1) {
+ if (isMatMulOperandAcc(StmtDomain, AccMap, MMI.i, MMI.k) && !MMI.A) {
MMI.A = MemAccess;
return true;
}
- if (isMatMulOperandAcc(AccMap, MMI.k, MMI.j) && !MMI.B &&
- isl_map_n_basic_map(AccMap.get()) == 1) {
+ if (isMatMulOperandAcc(StmtDomain, AccMap, MMI.k, MMI.j) && !MMI.B) {
MMI.B = MemAccess;
return true;
}
@@ -758,8 +720,7 @@ static bool containsMatrMult(isl::map Pa
if (!MemAccessPtr->isWrite())
return false;
auto AccMap = MemAccessPtr->getLatestAccessRelation();
- if (isl_map_n_basic_map(AccMap.get()) != 1 ||
- !isMatMulOperandAcc(AccMap, MMI.i, MMI.j))
+ if (!isMatMulOperandAcc(Stmt->getDomain(), AccMap, MMI.i, MMI.j))
return false;
MMI.WriteToC = MemAccessPtr;
break;
Added: polly/trunk/test/ScheduleOptimizer/pattern_matching_based_opts_splitmap.ll
URL: http://llvm.org/viewvc/llvm-project/polly/trunk/test/ScheduleOptimizer/pattern_matching_based_opts_splitmap.ll?rev=311302&view=auto
==============================================================================
--- polly/trunk/test/ScheduleOptimizer/pattern_matching_based_opts_splitmap.ll (added)
+++ polly/trunk/test/ScheduleOptimizer/pattern_matching_based_opts_splitmap.ll Sun Aug 20 14:31:11 2017
@@ -0,0 +1,59 @@
+; RUN: opt %loadPolly -polly-import-jscop -polly-import-jscop-postfix=transformed -polly-opt-isl -debug-only=polly-opt-isl -disable-output < %s 2>&1 | FileCheck %s
+; REQUIRES: asserts
+;
+; void pattern_matching_based_opts_splitmap(double C[static const restrict 2][2], double A[static const restrict 2][784], double B[static const restrict 784][2]) {
+; for (int i = 0; i < 2; i+=1)
+; for (int j = 0; j < 2; j+=1)
+; for (int k = 0; k < 784; k+=1)
+; C[i][j] += A[i][k] * B[k][j];
+; }
+;
+; Check that the pattern matching detects the matrix multiplication pattern
+; when the AccMap cannot be reduced to a single disjunct: the .transformed
+; jscop gives the write to C two disjuncts that together cover the whole domain.
+; CHECK: The matrix multiplication pattern was detected
+;
+; ModuleID = 'pattern_matching_based_opts_splitmap.ll'
+;
+; Function Attrs: noinline nounwind uwtable
+define void @pattern_matching_based_opts_splitmap([2 x double]* noalias dereferenceable(32) %C, [784 x double]* noalias dereferenceable(12544) %A, [2 x double]* noalias dereferenceable(12544) %B) {
+entry:
+ br label %for.body
+
+for.body: ; preds = %entry, %for.inc21
+ %i = phi i64 [ 0, %entry ], [ %add22, %for.inc21 ] ; i = 0 .. 1
+ br label %for.body3
+
+for.body3: ; preds = %for.body, %for.inc18
+ %j = phi i64 [ 0, %for.body ], [ %add19, %for.inc18 ] ; j = 0 .. 1
+ br label %for.body6
+
+for.body6: ; preds = %for.body3, %for.body6
+ %k = phi i64 [ 0, %for.body3 ], [ %add17, %for.body6 ] ; k = 0 .. 783
+ %arrayidx8 = getelementptr inbounds [784 x double], [784 x double]* %A, i64 %i, i64 %k
+ %tmp6 = load double, double* %arrayidx8, align 8 ; A[i][k]
+ %arrayidx12 = getelementptr inbounds [2 x double], [2 x double]* %B, i64 %k, i64 %j
+ %tmp10 = load double, double* %arrayidx12, align 8 ; B[k][j]
+ %mul = fmul double %tmp6, %tmp10
+ %arrayidx16 = getelementptr inbounds [2 x double], [2 x double]* %C, i64 %i, i64 %j
+ %tmp14 = load double, double* %arrayidx16, align 8 ; C[i][j]
+ %add = fadd double %tmp14, %mul
+ store double %add, double* %arrayidx16, align 8 ; C[i][j] = C[i][j] + A[i][k] * B[k][j]
+ %add17 = add nsw i64 %k, 1
+ %cmp5 = icmp slt i64 %add17, 784
+ br i1 %cmp5, label %for.body6, label %for.inc18
+
+for.inc18: ; preds = %for.body6
+ %add19 = add nsw i64 %j, 1
+ %cmp2 = icmp slt i64 %add19, 2
+ br i1 %cmp2, label %for.body3, label %for.inc21
+
+for.inc21: ; preds = %for.inc18
+ %add22 = add nsw i64 %i, 1
+ %cmp = icmp slt i64 %add22, 2
+ br i1 %cmp, label %for.body, label %for.end23
+
+for.end23: ; preds = %for.inc21
+ ret void
+}
+
Added: polly/trunk/test/ScheduleOptimizer/pattern_matching_based_opts_splitmap___%for.body---%for.end23.jscop
URL: http://llvm.org/viewvc/llvm-project/polly/trunk/test/ScheduleOptimizer/pattern_matching_based_opts_splitmap___%25for.body---%25for.end23.jscop?rev=311302&view=auto
==============================================================================
--- polly/trunk/test/ScheduleOptimizer/pattern_matching_based_opts_splitmap___%for.body---%for.end23.jscop (added)
+++ polly/trunk/test/ScheduleOptimizer/pattern_matching_based_opts_splitmap___%for.body---%for.end23.jscop Sun Aug 20 14:31:11 2017
@@ -0,0 +1,46 @@
+{
+ "arrays" : [
+ {
+ "name" : "MemRef_A",
+ "sizes" : [ "*", "784" ],
+ "type" : "double"
+ },
+ {
+ "name" : "MemRef_B",
+ "sizes" : [ "*", "2" ],
+ "type" : "double"
+ },
+ {
+ "name" : "MemRef_C",
+ "sizes" : [ "*", "2" ],
+ "type" : "double"
+ }
+ ],
+ "context" : "{ : }",
+ "name" : "%for.body---%for.end23",
+ "statements" : [
+ {
+ "accesses" : [
+ {
+ "kind" : "read",
+ "relation" : "{ Stmt_for_body6[i0, i1, i2] -> MemRef_A[i0, i2] }"
+ },
+ {
+ "kind" : "read",
+ "relation" : "{ Stmt_for_body6[i0, i1, i2] -> MemRef_B[i2, i1] }"
+ },
+ {
+ "kind" : "read",
+ "relation" : "{ Stmt_for_body6[i0, i1, i2] -> MemRef_C[i0, i1] }"
+ },
+ {
+ "kind" : "write",
+ "relation" : "{ Stmt_for_body6[i0, i1, i2] -> MemRef_C[i0, i1] }"
+ }
+ ],
+ "domain" : "{ Stmt_for_body6[i0, i1, i2] : 0 <= i0 <= 1 and 0 <= i1 <= 1 and 0 <= i2 <= 783 }",
+ "name" : "Stmt_for_body6",
+ "schedule" : "{ Stmt_for_body6[i0, i1, i2] -> [i0, i1, i2] }"
+ }
+ ]
+}
Added: polly/trunk/test/ScheduleOptimizer/pattern_matching_based_opts_splitmap___%for.body---%for.end23.jscop.transformed
URL: http://llvm.org/viewvc/llvm-project/polly/trunk/test/ScheduleOptimizer/pattern_matching_based_opts_splitmap___%25for.body---%25for.end23.jscop.transformed?rev=311302&view=auto
==============================================================================
--- polly/trunk/test/ScheduleOptimizer/pattern_matching_based_opts_splitmap___%for.body---%for.end23.jscop.transformed (added)
+++ polly/trunk/test/ScheduleOptimizer/pattern_matching_based_opts_splitmap___%for.body---%for.end23.jscop.transformed Sun Aug 20 14:31:11 2017
@@ -0,0 +1,46 @@
+{
+ "arrays" : [
+ {
+ "name" : "MemRef_A",
+ "sizes" : [ "*", "784" ],
+ "type" : "double"
+ },
+ {
+ "name" : "MemRef_B",
+ "sizes" : [ "*", "2" ],
+ "type" : "double"
+ },
+ {
+ "name" : "MemRef_C",
+ "sizes" : [ "*", "2" ],
+ "type" : "double"
+ }
+ ],
+ "context" : "{ : }",
+ "name" : "%for.body---%for.end23",
+ "statements" : [
+ {
+ "accesses" : [
+ {
+ "kind" : "read",
+ "relation" : "{ Stmt_for_body6[i0, i1, i2] -> MemRef_A[i0, i2] }"
+ },
+ {
+ "kind" : "read",
+ "relation" : "{ Stmt_for_body6[i0, i1, i2] -> MemRef_B[i2, i1] }"
+ },
+ {
+ "kind" : "read",
+ "relation" : "{ Stmt_for_body6[i0, i1, i2] -> MemRef_C[i0, i1] }"
+ },
+ {
+ "kind" : "write",
+ "relation" : "{ Stmt_for_body6[i0, i1, i2] -> MemRef_C[i0, i1] : i2 <= 784 - i0 - i1; Stmt_for_body6[1, 1, 783] -> MemRef_C[1, 1] }"
+ }
+ ],
+ "domain" : "{ Stmt_for_body6[i0, i1, i2] : 0 <= i0 <= 1 and 0 <= i1 <= 1 and 0 <= i2 <= 783 }",
+ "name" : "Stmt_for_body6",
+ "schedule" : "{ Stmt_for_body6[i0, i1, i2] -> [i0, i1, i2] }"
+ }
+ ]
+}
More information about the llvm-commits
mailing list