[polly] r311302 - [MatMul] Make MatMul detection independent of internal isl representations.

Michael Kruse via llvm-commits llvm-commits at lists.llvm.org
Sun Aug 20 14:31:11 PDT 2017


Author: meinersbur
Date: Sun Aug 20 14:31:11 2017
New Revision: 311302

URL: http://llvm.org/viewvc/llvm-project?rev=311302&view=rev
Log:
[MatMul] Make MatMul detection independent of internal isl representations.

The pattern recognition for MatMul was too restrictive.

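The number of "disjuncts" (basic maps) in the isl_map describing an access
relation was previously required to be 1. isl_*_coalesce should ideally
produce a map with a single disjunct, but it does not under some
circumstances, so valid MatMul kernels could be rejected.

This requirement has been removed: each access relation, restricted to the
statement domain, is now compared for equality with every candidate MatMul
access pattern, so the number of disjuncts no longer matters.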

Contributed-by: Annanay Agarwal <cs14btech11001 at iith.ac.in>

Differential Revision: https://reviews.llvm.org/D36460

Added:
    polly/trunk/test/ScheduleOptimizer/pattern_matching_based_opts_splitmap.ll
    polly/trunk/test/ScheduleOptimizer/pattern_matching_based_opts_splitmap___%for.body---%for.end23.jscop
    polly/trunk/test/ScheduleOptimizer/pattern_matching_based_opts_splitmap___%for.body---%for.end23.jscop.transformed
Modified:
    polly/trunk/lib/Transform/ScheduleOptimizer.cpp

Modified: polly/trunk/lib/Transform/ScheduleOptimizer.cpp
URL: http://llvm.org/viewvc/llvm-project/polly/trunk/lib/Transform/ScheduleOptimizer.cpp?rev=311302&r1=311301&r2=311302&view=diff
==============================================================================
--- polly/trunk/lib/Transform/ScheduleOptimizer.cpp (original)
+++ polly/trunk/lib/Transform/ScheduleOptimizer.cpp Sun Aug 20 14:31:11 2017
@@ -483,61 +483,6 @@ ScheduleTreeOptimizer::standardBandOpts(
   return Node;
 }
 
-/// Get the position of a dimension with a non-zero coefficient.
-///
-/// Check that isl constraint @p Constraint has only one non-zero
-/// coefficient for dimensions that have type @p DimType. If this is true,
-/// return the position of the dimension corresponding to the non-zero
-/// coefficient and negative value, otherwise.
-///
-/// @param Constraint The isl constraint to be checked.
-/// @param DimType    The type of the dimensions.
-/// @return           The position of the dimension in case the isl
-///                   constraint satisfies the requirements, a negative
-///                   value, otherwise.
-static int getMatMulConstraintDim(isl::constraint Constraint,
-                                  isl::dim DimType) {
-  int DimPos = -1;
-  auto LocalSpace = Constraint.get_local_space();
-  int LocalSpaceDimNum = LocalSpace.dim(DimType);
-  for (int i = 0; i < LocalSpaceDimNum; i++) {
-    auto Val = Constraint.get_coefficient_val(DimType, i);
-    if (Val.is_zero())
-      continue;
-    if (DimPos >= 0 || (DimType == isl::dim::out && !Val.is_one()) ||
-        (DimType == isl::dim::in && !Val.is_negone()))
-      return -1;
-    DimPos = i;
-  }
-  return DimPos;
-}
-
-/// Check the form of the isl constraint.
-///
-/// Check that the @p DimInPos input dimension of the isl constraint
-/// @p Constraint has a coefficient that is equal to negative one, the @p
-/// DimOutPos has a coefficient that is equal to one and others
-/// have coefficients equal to zero.
-///
-/// @param Constraint The isl constraint to be checked.
-/// @param DimInPos   The input dimension of the isl constraint.
-/// @param DimOutPos  The output dimension of the isl constraint.
-/// @return           isl_stat_ok in case the isl constraint satisfies
-///                   the requirements, isl_stat_error otherwise.
-static isl_stat isMatMulOperandConstraint(isl::constraint Constraint,
-                                          int &DimInPos, int &DimOutPos) {
-  auto Val = Constraint.get_constant_val();
-  if (!isl_constraint_is_equality(Constraint.get()) || !Val.is_zero())
-    return isl_stat_error;
-  DimInPos = getMatMulConstraintDim(Constraint, isl::dim::in);
-  if (DimInPos < 0)
-    return isl_stat_error;
-  DimOutPos = getMatMulConstraintDim(Constraint, isl::dim::out);
-  if (DimOutPos < 0)
-    return isl_stat_error;
-  return isl_stat_ok;
-}
-
 /// Permute the two dimensions of the isl map.
 ///
 /// Permute @p DstPos and @p SrcPos dimensions of the isl map @p Map that
@@ -585,30 +530,49 @@ isl::map permuteDimensions(isl::map Map,
 ///                  second output dimension.
 /// @return          True in case @p AccMap has the expected form and false,
 ///                  otherwise.
-static bool isMatMulOperandAcc(isl::map AccMap, int &FirstPos, int &SecondPos) {
-  int DimInPos[] = {FirstPos, SecondPos};
-  auto Lambda = [=, &DimInPos](isl::basic_map BasicMap) -> isl::stat {
-    auto Constraints = BasicMap.get_constraint_list();
-    if (isl_constraint_list_n_constraint(Constraints.get()) != 2)
-      return isl::stat::error;
-    for (int i = 0; i < 2; i++) {
-      auto Constraint =
-          isl::manage(isl_constraint_list_get_constraint(Constraints.get(), i));
-      int InPos, OutPos;
-      if (isMatMulOperandConstraint(Constraint, InPos, OutPos) ==
-              isl_stat_error ||
-          OutPos > 1 || (DimInPos[OutPos] >= 0 && DimInPos[OutPos] != InPos))
-        return isl::stat::error;
-      DimInPos[OutPos] = InPos;
-    }
-    return isl::stat::ok;
-  };
-  if (AccMap.foreach_basic_map(Lambda) != isl::stat::ok || DimInPos[0] < 0 ||
-      DimInPos[1] < 0)
+/// @param Domain    The statement domain; @p AccMap is only compared on it.
+/// On entry, -1 leaves a position free; both are written only on success.
+static bool isMatMulOperandAcc(isl::set Domain, isl::map AccMap, int &FirstPos,
+                               int &SecondPos) {
+  isl::space Space = AccMap.get_space();
+  // Candidates equate input dims 0..2; the operand must be two-dimensional.
+  if (Space.dim(isl::dim::in) < 3 || Space.dim(isl::dim::out) != 2)
     return false;
-  FirstPos = DimInPos[0];
-  SecondPos = DimInPos[1];
-  return true;
+
+  // MatMul has the form:
+  // for (i = 0; i < N; i++)
+  //   for (j = 0; j < M; j++)
+  //     for (k = 0; k < P; k++)
+  //       C[i, j] += A[i, k] * B[k, j]
+  //
+  // Permutation of three outer loops: 3! = 6 possibilities, one for each
+  // ordered pair (FirstDims[i], SecondDims[i]) of distinct input dims.
+  const int FirstDims[] = {0, 0, 1, 1, 2, 2};
+  const int SecondDims[] = {1, 2, 2, 0, 0, 1};
+
+  // Compare on the statement domain only (loop-invariant, so done once). A
+  // partial write differs from every candidate and is thus rejected.
+  isl::map DomainAcc = AccMap.intersect_domain(Domain);
+  isl::map Universe = isl::map::universe(Space).intersect_domain(Domain);
+
+  for (int i = 0; i < 6; i += 1) {
+    // Skip candidates contradicting positions fixed by earlier accesses.
+    if ((FirstPos != -1 && FirstPos != FirstDims[i]) ||
+        (SecondPos != -1 && SecondPos != SecondDims[i]))
+      continue;
+
+    isl::map PossibleMatMul =
+        Universe.equate(isl::dim::in, FirstDims[i], isl::dim::out, 0)
+            .equate(isl::dim::in, SecondDims[i], isl::dim::out, 1);
+    if (DomainAcc.is_equal(PossibleMatMul)) {
+      // Commit the out-parameters only on a complete match.
+      FirstPos = FirstDims[i];
+      SecondPos = SecondDims[i];
+      return true;
+    }
+  }
+
+  return false;
 }
 
 /// Does the memory access represent a non-scalar operand of the matrix
@@ -627,18 +591,16 @@ static bool isMatMulNonScalarReadAccess(
   if (!MemAccess->isLatestArrayKind() || !MemAccess->isRead())
     return false;
   auto AccMap = MemAccess->getLatestAccessRelation();
-  if (isMatMulOperandAcc(AccMap, MMI.i, MMI.j) && !MMI.ReadFromC &&
-      isl_map_n_basic_map(AccMap.get()) == 1) {
+  isl::set StmtDomain = MemAccess->getStatement()->getDomain();
+  if (isMatMulOperandAcc(StmtDomain, AccMap, MMI.i, MMI.j) && !MMI.ReadFromC) {
     MMI.ReadFromC = MemAccess;
     return true;
   }
-  if (isMatMulOperandAcc(AccMap, MMI.i, MMI.k) && !MMI.A &&
-      isl_map_n_basic_map(AccMap.get()) == 1) {
+  if (isMatMulOperandAcc(StmtDomain, AccMap, MMI.i, MMI.k) && !MMI.A) {
     MMI.A = MemAccess;
     return true;
   }
-  if (isMatMulOperandAcc(AccMap, MMI.k, MMI.j) && !MMI.B &&
-      isl_map_n_basic_map(AccMap.get()) == 1) {
+  if (isMatMulOperandAcc(StmtDomain, AccMap, MMI.k, MMI.j) && !MMI.B) {
     MMI.B = MemAccess;
     return true;
   }
@@ -758,8 +720,7 @@ static bool containsMatrMult(isl::map Pa
     if (!MemAccessPtr->isWrite())
       return false;
     auto AccMap = MemAccessPtr->getLatestAccessRelation();
-    if (isl_map_n_basic_map(AccMap.get()) != 1 ||
-        !isMatMulOperandAcc(AccMap, MMI.i, MMI.j))
+    if (!isMatMulOperandAcc(Stmt->getDomain(), AccMap, MMI.i, MMI.j))
       return false;
     MMI.WriteToC = MemAccessPtr;
     break;

Added: polly/trunk/test/ScheduleOptimizer/pattern_matching_based_opts_splitmap.ll
URL: http://llvm.org/viewvc/llvm-project/polly/trunk/test/ScheduleOptimizer/pattern_matching_based_opts_splitmap.ll?rev=311302&view=auto
==============================================================================
--- polly/trunk/test/ScheduleOptimizer/pattern_matching_based_opts_splitmap.ll (added)
+++ polly/trunk/test/ScheduleOptimizer/pattern_matching_based_opts_splitmap.ll Sun Aug 20 14:31:11 2017
@@ -0,0 +1,59 @@
+; RUN: opt %loadPolly -polly-import-jscop -polly-import-jscop-postfix=transformed -polly-opt-isl -debug-only=polly-opt-isl -disable-output < %s 2>&1 | FileCheck %s
+; REQUIRES: asserts
+;
+; void pattern_matching_based_opts_splitmap(double C[static const restrict 2][2], double A[static const restrict 2][784], double B[static const restrict 784][2]) {
+;  for (int i = 0; i < 2; i+=1)
+;    for (int j = 0; j < 2; j+=1)
+;      for (int k = 0; k < 784; k+=1)
+;        C[i][j] += A[i][k] * B[k][j];
+;}
+;
+; Check that the pattern matching detects the matrix multiplication pattern
+; when the AccMap cannot be reduced to a single disjunct.
+;
+; CHECK: The matrix multiplication pattern was detected
+;
+; ModuleID = 'pattern_matching_based_opts_splitmap.ll'
+;
+; Function Attrs: noinline nounwind uwtable
+define void @pattern_matching_based_opts_splitmap([2 x double]* noalias dereferenceable(32) %C, [784 x double]* noalias dereferenceable(12544) %A, [2 x double]* noalias dereferenceable(12544) %B) {
+entry:
+  br label %for.body
+
+for.body:                                         ; preds = %entry, %for.inc21
+  %i = phi i64 [ 0, %entry ], [ %add22, %for.inc21 ]
+  br label %for.body3
+
+for.body3:                                        ; preds = %for.body, %for.inc18
+  %j = phi i64 [ 0, %for.body ], [ %add19, %for.inc18 ]
+  br label %for.body6
+
+for.body6:                                        ; preds = %for.body3, %for.body6
+  %k = phi i64 [ 0, %for.body3 ], [ %add17, %for.body6 ]
+  %arrayidx8 = getelementptr inbounds [784 x double], [784 x double]* %A, i64 %i, i64 %k
+  %tmp6 = load double, double* %arrayidx8, align 8
+  %arrayidx12 = getelementptr inbounds [2 x double], [2 x double]* %B, i64 %k, i64 %j
+  %tmp10 = load double, double* %arrayidx12, align 8
+  %mul = fmul double %tmp6, %tmp10
+  %arrayidx16 = getelementptr inbounds [2 x double], [2 x double]* %C, i64 %i, i64 %j
+  %tmp14 = load double, double* %arrayidx16, align 8
+  %add = fadd double %tmp14, %mul
+  store double %add, double* %arrayidx16, align 8
+  %add17 = add nsw i64 %k, 1
+  %cmp5 = icmp slt i64 %add17, 784
+  br i1 %cmp5, label %for.body6, label %for.inc18
+
+for.inc18:                                        ; preds = %for.body6
+  %add19 = add nsw i64 %j, 1
+  %cmp2 = icmp slt i64 %add19, 2
+  br i1 %cmp2, label %for.body3, label %for.inc21
+
+for.inc21:                                        ; preds = %for.inc18
+  %add22 = add nsw i64 %i, 1
+  %cmp = icmp slt i64 %add22, 2
+  br i1 %cmp, label %for.body, label %for.end23
+
+for.end23:                                        ; preds = %for.inc21
+  ret void
+}
+

Added: polly/trunk/test/ScheduleOptimizer/pattern_matching_based_opts_splitmap___%for.body---%for.end23.jscop
URL: http://llvm.org/viewvc/llvm-project/polly/trunk/test/ScheduleOptimizer/pattern_matching_based_opts_splitmap___%25for.body---%25for.end23.jscop?rev=311302&view=auto
==============================================================================
--- polly/trunk/test/ScheduleOptimizer/pattern_matching_based_opts_splitmap___%for.body---%for.end23.jscop (added)
+++ polly/trunk/test/ScheduleOptimizer/pattern_matching_based_opts_splitmap___%for.body---%for.end23.jscop Sun Aug 20 14:31:11 2017
@@ -0,0 +1,46 @@
+{
+   "arrays" : [
+      {
+         "name" : "MemRef_A",
+         "sizes" : [ "*", "784" ],
+         "type" : "double"
+      },
+      {
+         "name" : "MemRef_B",
+         "sizes" : [ "*", "2" ],
+         "type" : "double"
+      },
+      {
+         "name" : "MemRef_C",
+         "sizes" : [ "*", "2" ],
+         "type" : "double"
+      }
+   ],
+   "context" : "{  :  }",
+   "name" : "%for.body---%for.end23",
+   "statements" : [
+      {
+         "accesses" : [
+            {
+               "kind" : "read",
+               "relation" : "{ Stmt_for_body6[i0, i1, i2] -> MemRef_A[i0, i2] }"
+            },
+            {
+               "kind" : "read",
+               "relation" : "{ Stmt_for_body6[i0, i1, i2] -> MemRef_B[i2, i1] }"
+            },
+            {
+               "kind" : "read",
+               "relation" : "{ Stmt_for_body6[i0, i1, i2] -> MemRef_C[i0, i1] }"
+            },
+            {
+               "kind" : "write",
+               "relation" : "{ Stmt_for_body6[i0, i1, i2] -> MemRef_C[i0, i1] }"
+            }
+         ],
+         "domain" : "{ Stmt_for_body6[i0, i1, i2] : 0 <= i0 <= 1 and 0 <= i1 <= 1 and 0 <= i2 <= 783 }",
+         "name" : "Stmt_for_body6",
+         "schedule" : "{ Stmt_for_body6[i0, i1, i2] -> [i0, i1, i2] }"
+      }
+   ]
+}

Added: polly/trunk/test/ScheduleOptimizer/pattern_matching_based_opts_splitmap___%for.body---%for.end23.jscop.transformed
URL: http://llvm.org/viewvc/llvm-project/polly/trunk/test/ScheduleOptimizer/pattern_matching_based_opts_splitmap___%25for.body---%25for.end23.jscop.transformed?rev=311302&view=auto
==============================================================================
--- polly/trunk/test/ScheduleOptimizer/pattern_matching_based_opts_splitmap___%for.body---%for.end23.jscop.transformed (added)
+++ polly/trunk/test/ScheduleOptimizer/pattern_matching_based_opts_splitmap___%for.body---%for.end23.jscop.transformed Sun Aug 20 14:31:11 2017
@@ -0,0 +1,46 @@
+{
+   "arrays" : [
+      {
+         "name" : "MemRef_A",
+         "sizes" : [ "*", "784" ],
+         "type" : "double"
+      },
+      {
+         "name" : "MemRef_B",
+         "sizes" : [ "*", "2" ],
+         "type" : "double"
+      },
+      {
+         "name" : "MemRef_C",
+         "sizes" : [ "*", "2" ],
+         "type" : "double"
+      }
+   ],
+   "context" : "{  :  }",
+   "name" : "%for.body---%for.end23",
+   "statements" : [
+      {
+         "accesses" : [
+            {
+               "kind" : "read",
+               "relation" : "{ Stmt_for_body6[i0, i1, i2] -> MemRef_A[i0, i2] }"
+            },
+            {
+               "kind" : "read",
+               "relation" : "{ Stmt_for_body6[i0, i1, i2] -> MemRef_B[i2, i1] }"
+            },
+            {
+               "kind" : "read",
+               "relation" : "{ Stmt_for_body6[i0, i1, i2] -> MemRef_C[i0, i1] }"
+            },
+            {
+               "kind" : "write",
+               "relation" : "{ Stmt_for_body6[i0, i1, i2] -> MemRef_C[i0, i1] : i2 <= 784 - i0 - i1; Stmt_for_body6[1, 1, 783] -> MemRef_C[1, 1] }"
+            }
+         ],
+         "domain" : "{ Stmt_for_body6[i0, i1, i2] : 0 <= i0 <= 1 and 0 <= i1 <= 1 and 0 <= i2 <= 783 }",
+         "name" : "Stmt_for_body6",
+         "schedule" : "{ Stmt_for_body6[i0, i1, i2] -> [i0, i1, i2] }"
+      }
+   ]
+}




More information about the llvm-commits mailing list