[Mlir-commits] [mlir] f365e85 - [mlir] Revisit `LinalgLoopDistributionOptions`.

Mahesh Ravishankar llvmlistbot at llvm.org
Mon Aug 15 08:56:39 PDT 2022


Author: Mahesh Ravishankar
Date: 2022-08-15T15:56:17Z
New Revision: f365e85c8314657e0c0f9257fbcb367d74e9d0cd

URL: https://github.com/llvm/llvm-project/commit/f365e85c8314657e0c0f9257fbcb367d74e9d0cd
DIFF: https://github.com/llvm/llvm-project/commit/f365e85c8314657e0c0f9257fbcb367d74e9d0cd.diff

LOG: [mlir] Revisit `LinalgLoopDistributionOptions`.

This patch cleans up the way `LinalgLoopDistributionOptions` is meant
to be used. The option now contains just a callback that takes the list
of loop ranges representing the loops that are to be distributed.
These loops are the outer parallel loops of the tiled operation which
have non-zero tile sizes specified. For each of these loops, the
callback returns:
- The procId to use,
- The number of processors,
- The distribution method to use for that loop.

Reviewed By: antiagainst, hanchung

Differential Revision: https://reviews.llvm.org/D131232

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
    mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
    mlir/lib/Dialect/Linalg/Utils/Utils.cpp
    mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
index 46e662c6724ec..c3806e035cc8f 100644
--- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
@@ -369,7 +369,10 @@ enum class DistributionMethod {
   /// to
   ///
   /// %iv = %lb + %procId * %step
-  CyclicNumProcsEqNumIters = 2
+  CyclicNumProcsEqNumIters = 2,
+
+  /// No Distribution.
+  None = 3
 };
 
 /// Callback function type used to get processor ID, and number of processors
@@ -377,11 +380,10 @@ enum class DistributionMethod {
 struct ProcInfo {
   Value procId;
   Value nprocs;
+  DistributionMethod distributionMethod;
 };
-using ProcInfoCallBackFn = std::function<SmallVector<ProcInfo, 2>(
+using ProcInfoCallBackFn = std::function<SmallVector<ProcInfo>(
     OpBuilder &b, Location loc, ArrayRef<Range> parallelLoopRanges)>;
-using OneDimProcInfoCallBackFn =
-    std::function<ProcInfo(OpBuilder &b, Location loc)>;
 
 /// Options that allow distribution of loops generated in Linalg transforms to
 /// processors while generating the loops.
@@ -389,21 +391,10 @@ struct LinalgLoopDistributionOptions {
   /// Callback function that returns the Values for processor ID (`procId`), and
   /// number of processors (`nprocs`) used to execute the parallel loops. The
   /// number of `{procId, nprocs}` pairs returned must be equal to the number of
-  /// `parallelLoopRanges` passed into the callback, which in-turn is same as
-  /// the number of parallel loops for which the `distributionMethod` is
-  /// specified below.
+  /// `parallelLoopRanges` passed into the callback. The `parallelLoopRanges`
+  /// are ranges of the outer parallel loops of the operation that
+  /// do have non-zero tile sizes specified.
   ProcInfoCallBackFn procInfo;
-  /// Specification of how to distribute the `scf.parallel` loops that are
-  /// generated. As the `scf.parallel` loop is generated, the elements of this
-  /// vector is used (from left to right) and the specified distribution is
-  /// applied. If the vector is less than the number of `scf.parallel` loops
-  /// generated, then no distribution is applied.
-  SmallVector<DistributionMethod, 0> distributionMethod = {};
-
-  /// The map keyed by the distribution type that contains callback functions
-  /// that return the Values for processor ID (`procId`), and number of
-  /// processors (`nprocs`) used to execute the parallel loops.
-  DenseMap<StringRef, OneDimProcInfoCallBackFn> procInfoMap;
 };
 
 /// Update the `lb`, `ub` and `step` to get per processor `lb`, `ub` and `step`.
@@ -521,8 +512,7 @@ struct GenerateLoopNest {
                    function_ref<scf::ValueVector(OpBuilder &, Location,
                                                  ValueRange, ValueRange)>
                        bodyBuilderFn,
-                   Optional<LinalgLoopDistributionOptions> = None,
-                   ArrayRef<StringRef> distributionTypes = {});
+                   ArrayRef<linalg::ProcInfo> procInfo = {});
 };
 
 } // namespace linalg

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
index ddcc3470c9dcf..417e727f9dd83 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
@@ -450,6 +450,31 @@ tileLinalgOpImpl(RewriterBase &b, LinalgOp op, ArrayRef<OpFoldResult> tileSizes,
     applyPermutationToVector(iteratorTypes, permutation);
   }
 
+  // Handle distribution. Create a vector of the same size of loops that are to
+  // be tiled.
+  SmallVector<linalg::ProcInfo> procInfo;
+  if (options.distribution) {
+    procInfo.resize(
+        iteratorTypes.size(),
+        linalg::ProcInfo{nullptr, nullptr, linalg::DistributionMethod::None});
+    // Collect loop ranges of tiled loops, loops that are parallel.
+    SmallVector<Range> parallelLoopRanges;
+    for (auto iteratorType : llvm::enumerate(iteratorTypes)) {
+      if (!isParallelIterator(iteratorType.value()))
+        break;
+      parallelLoopRanges.push_back(loopRanges[iteratorType.index()]);
+    }
+    auto returnedProcInfo =
+        options.distribution->procInfo(b, op.getLoc(), parallelLoopRanges);
+    unsigned procIdIdx = 0;
+    // Update the distribution information for the loops.
+    for (auto iteratorType : llvm::enumerate(iteratorTypes)) {
+      if (!isParallelIterator(iteratorType.value()))
+        break;
+      procInfo[iteratorType.index()] = returnedProcInfo[procIdIdx++];
+    }
+  }
+
   // 2. Create the tiled loops.
   LinalgOp res = op;
   SmallVector<Value, 4> ivs, tensorResults;
@@ -489,8 +514,7 @@ tileLinalgOpImpl(RewriterBase &b, LinalgOp op, ArrayRef<OpFoldResult> tileSizes,
     return scf::ValueVector(tensorResults.begin(), tensorResults.end());
   };
   GenerateLoopNest<LoopTy>::doit(b, op.getLoc(), loopRanges, op, iteratorTypes,
-                                 tiledLoopBodyBuilder, options.distribution,
-                                 options.distributionTypes);
+                                 tiledLoopBodyBuilder, procInfo);
 
   // 3. Transform IndexOp results w.r.t. the tiling.
   transformIndexOps(b, res, ivs, loopIndexToRangeIndex);

diff  --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
index 0390e0c7c19f4..a9c44f0ac8635 100644
--- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -518,25 +518,11 @@ void GenerateLoopNest<scf::ForOp>::doit(
     function_ref<scf::ValueVector(OpBuilder &, Location, ValueRange,
                                   ValueRange)>
         bodyBuilderFn,
-    Optional<LinalgLoopDistributionOptions> distributionOptions,
-    ArrayRef<StringRef> distributionTypes) {
+    ArrayRef<linalg::ProcInfo> procInfo) {
+  assert((procInfo.empty() || (procInfo.size() == loopRanges.size())) &&
+         "expected as many entries for proc info as number of loops, even if "
+         "they are null entries");
   SmallVector<Value> iterArgInitValues = linalgOp.getOutputTensorOperands();
-  // Create procInfo so it dominates loops, if appropriate.
-  SmallVector<ProcInfo, 4> procInfo;
-  SmallVector<DistributionMethod, 0> distributionMethod;
-  if (distributionOptions) {
-    // Collect loop ranges for parallel dimensions.
-    SmallVector<Range, 2> parallelLoopRanges;
-    for (const auto &iteratorType : enumerate(iteratorTypes))
-      if (isParallelIterator(iteratorType.value()))
-        parallelLoopRanges.push_back(loopRanges[iteratorType.index()]);
-
-    // Get their distribution schemes.
-    distributionMethod = distributionOptions->distributionMethod;
-    if (distributionMethod.size() < parallelLoopRanges.size())
-      parallelLoopRanges.resize(distributionMethod.size());
-    procInfo = distributionOptions->procInfo(b, loc, parallelLoopRanges);
-  }
 
   SmallVector<Value, 4> lbs, ubs, steps;
   unpackRanges(b, loc, loopRanges, lbs, ubs, steps);
@@ -554,20 +540,17 @@ void GenerateLoopNest<scf::ForOp>::doit(
         return bodyBuilderFn(b, loc, ivs, operandValuesToUse);
       });
 
-  if (!distributionOptions || loopNest.loops.empty())
+  if (loopNest.loops.empty() || procInfo.empty())
     return;
 
   // Filter out scf.for loops that were created out of parallel dimensions.
-  SmallVector<scf::ForOp, 4> loops;
-  for (const auto &iteratorType : enumerate(iteratorTypes))
-    if (isParallelIterator(iteratorType.value()))
-      loops.push_back(loopNest.loops[iteratorType.index()]);
-
-  // Distribute - only supports cyclic distribution for now.
-  for (auto it : llvm::zip(loops, procInfo, distributionMethod))
-    if (std::get<2>(it) == DistributionMethod::Cyclic)
-      mapLoopToProcessorIds(std::get<0>(it), std::get<1>(it).procId,
-                            std::get<1>(it).nprocs);
+  for (auto loop : llvm::enumerate(loopNest.loops)) {
+    if (procInfo[loop.index()].distributionMethod ==
+        DistributionMethod::Cyclic) {
+      mapLoopToProcessorIds(loop.value(), procInfo[loop.index()].procId,
+                            procInfo[loop.index()].nprocs);
+    }
+  }
 }
 
 /// Specialization to build affine "for" nest.
@@ -578,7 +561,7 @@ void GenerateLoopNest<AffineForOp>::doit(
     function_ref<scf::ValueVector(OpBuilder &, Location, ValueRange,
                                   ValueRange)>
         bodyBuilderFn,
-    Optional<LinalgLoopDistributionOptions>, ArrayRef<StringRef>) {
+    ArrayRef<linalg::ProcInfo> /*procInfo*/) {
   SmallVector<Value> iterArgInitValues = linalgOp.getOutputTensorOperands();
   assert(iterArgInitValues.empty() && "unexpected AffineForOp init values");
   SmallVector<Value, 4> lbs, ubs, steps;
@@ -625,12 +608,13 @@ void updateBoundsForCyclicDistribution(OpBuilder &b, Location loc, Value procId,
 static void generateParallelLoopNest(
     OpBuilder &b, Location loc, ValueRange lbs, ValueRange ubs,
     ValueRange steps, ArrayRef<Attribute> iteratorTypes,
+    ArrayRef<linalg::ProcInfo> procInfo,
     function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn,
-    SmallVectorImpl<Value> &ivStorage,
-    ArrayRef<DistributionMethod> distributionMethod = {}) {
+    SmallVectorImpl<Value> &ivStorage) {
   assert(lbs.size() == ubs.size());
   assert(lbs.size() == steps.size());
   assert(lbs.size() == iteratorTypes.size());
+  assert(procInfo.empty() || (lbs.size() == procInfo.size()));
 
   // If there are no (more) loops to be generated, generate the body and be
   // done with it.
@@ -639,55 +623,56 @@ static void generateParallelLoopNest(
     return;
   }
 
-  // Find the outermost parallel loops and drop their types from the list.
-  unsigned nLoops = iteratorTypes.size();
-  unsigned nOuterPar =
-      nLoops - iteratorTypes.drop_while(isParallelIterator).size();
-
   // If there are no outer parallel loops, generate one sequential loop and
-  // recurse. Note that we wouldn't have dropped anything from `iteratorTypes`
-  // in this case.
-  if (nOuterPar == 0) {
+  // recurse.
+  if (!isParallelIterator(iteratorTypes.front())) {
     LoopNest singleLoop = buildLoopNest(
         b, loc, lbs.take_front(), ubs.take_front(), steps.take_front(),
         [&](OpBuilder &b, Location loc, ValueRange ivs) {
           ivStorage.append(ivs.begin(), ivs.end());
-          generateParallelLoopNest(b, loc, lbs.drop_front(), ubs.drop_front(),
-                                   steps.drop_front(),
-                                   iteratorTypes.drop_front(), bodyBuilderFn,
-                                   ivStorage, distributionMethod);
+          generateParallelLoopNest(
+              b, loc, lbs.drop_front(), ubs.drop_front(), steps.drop_front(),
+              iteratorTypes.drop_front(),
+              procInfo.empty() ? procInfo : procInfo.drop_front(),
+              bodyBuilderFn, ivStorage);
         });
     return;
   }
-  if (distributionMethod.empty()) {
+
+  unsigned nLoops = iteratorTypes.size();
+  unsigned numProcessed = 0;
+  DistributionMethod distributionMethod = DistributionMethod::None;
+  if (procInfo.empty()) {
+    numProcessed = nLoops - iteratorTypes.drop_while(isParallelIterator).size();
+  } else {
+    distributionMethod = procInfo.front().distributionMethod;
+    numProcessed =
+        nLoops - procInfo
+                     .drop_while([&](linalg::ProcInfo p) {
+                       return p.distributionMethod == distributionMethod;
+                     })
+                     .size();
+  }
+
+  auto remainderProcInfo =
+      procInfo.empty() ? procInfo : procInfo.drop_front(numProcessed);
+  switch (distributionMethod) {
+  case DistributionMethod::None: {
     // Generate a single parallel loop-nest operation for all outermost
     // parallel loops and recurse.
     b.create<scf::ParallelOp>(
-        loc, lbs.take_front(nOuterPar), ubs.take_front(nOuterPar),
-        steps.take_front(nOuterPar),
+        loc, lbs.take_front(numProcessed), ubs.take_front(numProcessed),
+        steps.take_front(numProcessed),
         [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange localIvs) {
           ivStorage.append(localIvs.begin(), localIvs.end());
           generateParallelLoopNest(
-              nestedBuilder, nestedLoc, lbs.drop_front(nOuterPar),
-              ubs.drop_front(nOuterPar), steps.drop_front(nOuterPar),
-              iteratorTypes.drop_front(nOuterPar), bodyBuilderFn, ivStorage,
-              (distributionMethod.size() < nOuterPar)
-                  ? ArrayRef<DistributionMethod>()
-                  : distributionMethod.drop_front(nOuterPar));
+              nestedBuilder, nestedLoc, lbs.drop_front(numProcessed),
+              ubs.drop_front(numProcessed), steps.drop_front(numProcessed),
+              iteratorTypes.drop_front(numProcessed), remainderProcInfo,
+              bodyBuilderFn, ivStorage);
         });
     return;
   }
-
-  // Process all consecutive similarly distributed loops simultaneously.
-  DistributionMethod methodToUse = distributionMethod[0];
-  unsigned numProcessed = 1;
-  for (unsigned i = 1; i < nOuterPar && i < distributionMethod.size(); ++i) {
-    if (distributionMethod[i] != methodToUse)
-      break;
-    numProcessed++;
-  }
-
-  switch (methodToUse) {
   case DistributionMethod::Cyclic: {
     // Generate a single parallel loop-nest operation for all outermost
     // parallel loops and recurse.
@@ -699,10 +684,8 @@ static void generateParallelLoopNest(
           generateParallelLoopNest(
               nestedBuilder, nestedLoc, lbs.drop_front(numProcessed),
               ubs.drop_front(numProcessed), steps.drop_front(numProcessed),
-              iteratorTypes.drop_front(numProcessed), bodyBuilderFn, ivStorage,
-              (distributionMethod.size() < numProcessed)
-                  ? ArrayRef<DistributionMethod>()
-                  : distributionMethod.drop_front(numProcessed));
+              iteratorTypes.drop_front(numProcessed), remainderProcInfo,
+              bodyBuilderFn, ivStorage);
         });
     return;
   }
@@ -714,11 +697,11 @@ static void generateParallelLoopNest(
       cond = ab._and(cond, ab.slt(lbs[i], ubs[i]));
     ivStorage.append(lbs.begin(), std::next(lbs.begin(), numProcessed));
     b.create<scf::IfOp>(loc, cond, [&](OpBuilder &b, Location loc) {
-      generateParallelLoopNest(
-          b, loc, lbs.drop_front(numProcessed), ubs.drop_front(numProcessed),
-          steps.drop_front(numProcessed),
-          iteratorTypes.drop_front(numProcessed), bodyBuilderFn, ivStorage,
-          distributionMethod.drop_front(numProcessed));
+      generateParallelLoopNest(b, loc, lbs.drop_front(numProcessed),
+                               ubs.drop_front(numProcessed),
+                               steps.drop_front(numProcessed),
+                               iteratorTypes.drop_front(numProcessed),
+                               remainderProcInfo, bodyBuilderFn, ivStorage);
       b.create<scf::YieldOp>(loc, ValueRange{});
     });
     return;
@@ -730,7 +713,7 @@ static void generateParallelLoopNest(
     generateParallelLoopNest(
         b, loc, lbs.drop_front(numProcessed), ubs.drop_front(numProcessed),
         steps.drop_front(numProcessed), iteratorTypes.drop_front(numProcessed),
-        bodyBuilderFn, ivStorage, distributionMethod.drop_front(numProcessed));
+        remainderProcInfo, bodyBuilderFn, ivStorage);
     return;
   }
 }
@@ -743,13 +726,14 @@ void GenerateLoopNest<scf::ParallelOp>::doit(
     function_ref<scf::ValueVector(OpBuilder &, Location, ValueRange,
                                   ValueRange)>
         bodyBuilderFn,
-    Optional<LinalgLoopDistributionOptions> distributionOptions,
-    ArrayRef<StringRef> distributionTypes) {
+    ArrayRef<linalg::ProcInfo> procInfo) {
   SmallVector<Value> iterArgInitValues = linalgOp.getOutputTensorOperands();
   assert(iterArgInitValues.empty() && "unexpected ParallelOp init values");
   // This function may be passed more iterator types than ranges.
   assert(iteratorTypes.size() >= loopRanges.size() &&
          "expected iterator type for all ranges");
+  assert((procInfo.empty() || (procInfo.size() == loopRanges.size())) &&
+         "expected proc information for all loops when present");
   iteratorTypes = iteratorTypes.take_front(loopRanges.size());
   SmallVector<Value, 8> lbsStorage, ubsStorage, stepsStorage, ivs;
   unsigned numLoops = iteratorTypes.size();
@@ -762,42 +746,22 @@ void GenerateLoopNest<scf::ParallelOp>::doit(
   unpackRanges(b, loc, loopRanges, lbsStorage, ubsStorage, stepsStorage);
 
   // Modify the lb, ub, and step based on the distribution options.
-  SmallVector<DistributionMethod, 0> distributionMethod;
-  if (distributionOptions) {
-    auto &options = *distributionOptions;
-    distributionMethod.assign(distributionOptions->distributionMethod.begin(),
-                              distributionOptions->distributionMethod.end());
-    SmallVector<Range, 2> parallelLoopRanges;
-    for (const auto &iteratorType : enumerate(iteratorTypes)) {
-      if (isParallelIterator(iteratorType.value()))
-        parallelLoopRanges.push_back(loopRanges[iteratorType.index()]);
-    }
-    if (distributionMethod.size() < parallelLoopRanges.size())
-      parallelLoopRanges.resize(distributionMethod.size());
-    SmallVector<ProcInfo, 2> procInfo =
-        options.procInfo(b, loc, parallelLoopRanges);
-    unsigned index = 0;
-    for (const auto &iteratorType : enumerate(iteratorTypes)) {
-      if (index >= procInfo.size())
-        break;
-      if (isParallelIterator(iteratorType.value())) {
-        unsigned i = iteratorType.index();
-        updateBoundsForCyclicDistribution(b, loc, procInfo[index].procId,
-                                          procInfo[index].nprocs, lbsStorage[i],
-                                          ubsStorage[i], stepsStorage[i]);
-        index++;
-      }
+  for (auto it : llvm::enumerate(procInfo)) {
+    if (it.value().distributionMethod != linalg::DistributionMethod::None) {
+      updateBoundsForCyclicDistribution(
+          b, loc, it.value().procId, it.value().nprocs, lbsStorage[it.index()],
+          ubsStorage[it.index()], stepsStorage[it.index()]);
     }
   }
   ValueRange lbs(lbsStorage), ubs(ubsStorage), steps(stepsStorage);
   generateParallelLoopNest(
-      b, loc, lbs, ubs, steps, iteratorTypes,
+      b, loc, lbs, ubs, steps, iteratorTypes, procInfo,
       [&](OpBuilder &b, Location loc, ValueRange ivs) {
         SmallVector<Value> operandValuesToUse =
             linalgOp.getInputAndOutputOperands();
         bodyBuilderFn(b, loc, ivs, operandValuesToUse);
       },
-      ivs, distributionMethod);
+      ivs);
 
   assert(ivs.size() == iteratorTypes.size() && "did not generate enough loops");
 }

diff  --git a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
index 67e1c368461d6..fc988eaa267c1 100644
--- a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
@@ -249,14 +249,16 @@ static void applyPatterns(func::FuncOp funcOp) {
 
 template <typename IdOp, typename NProcsOp>
 static SmallVector<ProcInfo, 2>
-getGpuProcIds(OpBuilder &b, Location loc, ArrayRef<Range> parallelLoopRanges) {
+getGpuProcIds(OpBuilder &b, Location loc, ArrayRef<Range> parallelLoopRanges,
+              ArrayRef<linalg::DistributionMethod> distributionMethod) {
   size_t count = std::min<size_t>(3, parallelLoopRanges.size());
   SmallVector<ProcInfo, 2> procInfo(count);
   Type indexType = b.getIndexType();
   for (unsigned i = 0; i < count; ++i) {
     gpu::Dimension dim = *gpu::symbolizeDimension(i);
     procInfo[count - 1 - i] = {b.create<IdOp>(loc, indexType, dim),
-                               b.create<NProcsOp>(loc, indexType, dim)};
+                               b.create<NProcsOp>(loc, indexType, dim),
+                               distributionMethod[count - 1 - i]};
   }
   return procInfo;
 }
@@ -265,10 +267,15 @@ static void fillTileAndDistributePatterns(MLIRContext *context,
                                           RewritePatternSet &patterns) {
   {
     LinalgLoopDistributionOptions cyclicNprocsEqNiters;
-    cyclicNprocsEqNiters.distributionMethod.resize(
-        2, DistributionMethod::CyclicNumProcsEqNumIters);
+    SmallVector<linalg::DistributionMethod> distributionMethod = {
+        DistributionMethod::CyclicNumProcsEqNumIters,
+        DistributionMethod::CyclicNumProcsEqNumIters};
     cyclicNprocsEqNiters.procInfo =
-        getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
+        [distributionMethod](OpBuilder &b, Location loc,
+                             ArrayRef<Range> parallelLoopRanges) {
+          return getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>(
+              b, loc, parallelLoopRanges, distributionMethod);
+        };
     patterns.add<LinalgTilingPattern>(
         MatmulOp::getOperationName(), context,
         LinalgTilingOptions()
@@ -282,10 +289,15 @@ static void fillTileAndDistributePatterns(MLIRContext *context,
 
   {
     LinalgLoopDistributionOptions cyclicNprocsGeNiters;
-    cyclicNprocsGeNiters.distributionMethod.resize(
-        2, DistributionMethod::CyclicNumProcsGeNumIters);
+    SmallVector<linalg::DistributionMethod> distributionMethod = {
+        DistributionMethod::CyclicNumProcsGeNumIters,
+        DistributionMethod::CyclicNumProcsGeNumIters};
     cyclicNprocsGeNiters.procInfo =
-        getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
+        [distributionMethod](OpBuilder &b, Location loc,
+                             ArrayRef<Range> parallelLoopRanges) {
+          return getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>(
+              b, loc, parallelLoopRanges, distributionMethod);
+        };
     patterns.add<LinalgTilingPattern>(
         MatmulOp::getOperationName(), context,
         LinalgTilingOptions()
@@ -299,10 +311,14 @@ static void fillTileAndDistributePatterns(MLIRContext *context,
 
   {
     LinalgLoopDistributionOptions cyclicNprocsDefault;
-    cyclicNprocsDefault.distributionMethod.resize(2,
-                                                  DistributionMethod::Cyclic);
+    SmallVector<linalg::DistributionMethod> distributionMethod = {
+        DistributionMethod::Cyclic, DistributionMethod::Cyclic};
     cyclicNprocsDefault.procInfo =
-        getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
+        [distributionMethod](OpBuilder &b, Location loc,
+                             ArrayRef<Range> parallelLoopRanges) {
+          return getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>(
+              b, loc, parallelLoopRanges, distributionMethod);
+        };
     patterns.add<LinalgTilingPattern>(
         MatmulOp::getOperationName(), context,
         LinalgTilingOptions()
@@ -316,10 +332,15 @@ static void fillTileAndDistributePatterns(MLIRContext *context,
 
   {
     LinalgLoopDistributionOptions cyclicNprocsMixed1;
-    cyclicNprocsMixed1.distributionMethod = {
+    SmallVector<linalg::DistributionMethod> distributionMethod = {
         DistributionMethod::CyclicNumProcsEqNumIters,
         DistributionMethod::CyclicNumProcsGeNumIters};
-    cyclicNprocsMixed1.procInfo = getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
+    cyclicNprocsMixed1.procInfo =
+        [distributionMethod](OpBuilder &b, Location loc,
+                             ArrayRef<Range> parallelLoopRanges) {
+          return getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>(
+              b, loc, parallelLoopRanges, distributionMethod);
+        };
     patterns.add<LinalgTilingPattern>(
         MatmulOp::getOperationName(), context,
         LinalgTilingOptions()
@@ -333,10 +354,15 @@ static void fillTileAndDistributePatterns(MLIRContext *context,
 
   {
     LinalgLoopDistributionOptions cyclicNprocsMixed2;
-    cyclicNprocsMixed2.distributionMethod = {
+    SmallVector<linalg::DistributionMethod> distributionMethod = {
         DistributionMethod::CyclicNumProcsGeNumIters,
         DistributionMethod::Cyclic};
-    cyclicNprocsMixed2.procInfo = getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
+    cyclicNprocsMixed2.procInfo =
+        [distributionMethod](OpBuilder &b, Location loc,
+                             ArrayRef<Range> parallelLoopRanges) {
+          return getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>(
+              b, loc, parallelLoopRanges, distributionMethod);
+        };
     patterns.add<LinalgTilingPattern>(
         MatmulOp::getOperationName(), context,
         LinalgTilingOptions()
@@ -350,10 +376,15 @@ static void fillTileAndDistributePatterns(MLIRContext *context,
 
   {
     LinalgLoopDistributionOptions cyclicNprocsMixed3;
-    cyclicNprocsMixed3.distributionMethod = {
+    SmallVector<linalg::DistributionMethod> distributionMethod = {
         DistributionMethod::Cyclic,
         DistributionMethod::CyclicNumProcsEqNumIters};
-    cyclicNprocsMixed3.procInfo = getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
+    cyclicNprocsMixed3.procInfo =
+        [distributionMethod](OpBuilder &b, Location loc,
+                             ArrayRef<Range> parallelLoopRanges) {
+          return getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>(
+              b, loc, parallelLoopRanges, distributionMethod);
+        };
 
     patterns.add<LinalgTilingPattern>(
         MatmulOp::getOperationName(), context,
@@ -368,10 +399,14 @@ static void fillTileAndDistributePatterns(MLIRContext *context,
 
   {
     LinalgLoopDistributionOptions cyclicNprocsEqNiters;
-    cyclicNprocsEqNiters.distributionMethod.resize(2,
-                                                   DistributionMethod::Cyclic);
+    SmallVector<linalg::DistributionMethod> distributionMethod = {
+        DistributionMethod::Cyclic, DistributionMethod::Cyclic};
     cyclicNprocsEqNiters.procInfo =
-        getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
+        [distributionMethod](OpBuilder &b, Location loc,
+                             ArrayRef<Range> parallelLoopRanges) {
+          return getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>(
+              b, loc, parallelLoopRanges, distributionMethod);
+        };
     patterns.add<LinalgTilingPattern>(
         MatmulOp::getOperationName(), context,
         LinalgTilingOptions()
@@ -387,8 +422,14 @@ static void fillTileAndDistributePatterns(MLIRContext *context,
 static void fillTileFuseAndDistributePatterns(MLIRContext *context,
                                               RewritePatternSet &patterns) {
   LinalgLoopDistributionOptions cyclicNprocsEqNiters;
-  cyclicNprocsEqNiters.distributionMethod.resize(2, DistributionMethod::Cyclic);
-  cyclicNprocsEqNiters.procInfo = getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
+  SmallVector<linalg::DistributionMethod> distributionMethod = {
+      DistributionMethod::Cyclic, DistributionMethod::Cyclic};
+  cyclicNprocsEqNiters.procInfo =
+      [distributionMethod](OpBuilder &b, Location loc,
+                           ArrayRef<Range> parallelLoopRanges) {
+        return getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>(
+            b, loc, parallelLoopRanges, distributionMethod);
+      };
   patterns.add<LinalgTileAndFuseTensorOpsPattern>(
       MatmulOp::getOperationName(), context,
       LinalgTilingAndFusionOptions()


        


More information about the Mlir-commits mailing list