[Mlir-commits] [mlir] 7e71823 - [mlir][linalg] Restrict distribution to parallel dims

Lei Zhang llvmlistbot at llvm.org
Mon May 10 12:23:14 PDT 2021


Author: Lei Zhang
Date: 2021-05-10T15:23:00-04:00
New Revision: 7e71823f1deb54a1465bc4040f4e3158357f71df

URL: https://github.com/llvm/llvm-project/commit/7e71823f1deb54a1465bc4040f4e3158357f71df
DIFF: https://github.com/llvm/llvm-project/commit/7e71823f1deb54a1465bc4040f4e3158357f71df.diff

LOG: [mlir][linalg] Restrict distribution to parallel dims

According to the API contract, LinalgLoopDistributionOptions
expects to work on parallel iterators. When getting processor
information, only loop ranges for parallel dimensions should
be fed in. But currently, after generating scf.for loop nests,
we feed in *all* loops, including the ones materialized for
reduction iterators. This can cause unexpected distribution
of reduction dimensions. This commit fixes it by passing only
the parallel loop ranges to procInfo and distributing only the
corresponding scf.for loops.

Reviewed By: mravishankar

Differential Revision: https://reviews.llvm.org/D102079

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/Utils/Utils.cpp
    mlir/test/Dialect/Linalg/tile-and-distribute.mlir
    mlir/test/lib/Transforms/TestLinalgTransforms.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
index 022a57343bc82..0bba27d931df9 100644
--- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -205,21 +205,39 @@ void GenerateLoopNest<scf::ForOp>::doit(
   // Create procInfo so it dominates loops, if appropriate.
   OpBuilder &builder = edsc::ScopedContext::getBuilderRef();
   Location loc = edsc::ScopedContext::getLocation();
-  SmallVector<ProcInfo, 2> procInfo;
-  if (distributionOptions.hasValue())
-    procInfo = distributionOptions->procInfo(builder, loc, loopRanges);
+
+  SmallVector<ProcInfo, 4> procInfo;
+  SmallVector<DistributionMethod, 0> distributionMethod;
+  if (distributionOptions.hasValue()) {
+    // Collect loop ranges for parallel dimensions.
+    SmallVector<Range, 2> parallelLoopRanges;
+    for (auto iteratorType : enumerate(iteratorTypes))
+      if (isParallelIteratorType(iteratorType.value()))
+        parallelLoopRanges.push_back(loopRanges[iteratorType.index()]);
+
+    // Get their distribution schemes.
+    distributionMethod = distributionOptions->distributionMethod;
+    if (distributionMethod.size() < parallelLoopRanges.size())
+      parallelLoopRanges.resize(distributionMethod.size());
+    procInfo = distributionOptions->procInfo(builder, loc, parallelLoopRanges);
+  }
 
   SmallVector<Value, 4> lbs, ubs, steps;
   unpackRanges(loopRanges, lbs, ubs, steps);
   LoopNest loopNest =
       edsc::loopNestBuilder(lbs, ubs, steps, iterArgInitValues, bodyBuilderFn);
 
-  if (!distributionOptions.hasValue() || loopNest.loops.empty())
+  if (!distributionOptions || loopNest.loops.empty())
     return;
 
-  // Only supports cyclic distribution for now.
-  for (auto it : llvm::zip(loopNest.loops, procInfo,
-                           distributionOptions->distributionMethod))
+  // Select the scf.for loops that were created out of parallel dimensions.
+  SmallVector<scf::ForOp, 4> loops;
+  for (auto iteratorType : enumerate(iteratorTypes))
+    if (isParallelIteratorType(iteratorType.value()))
+      loops.push_back(loopNest.loops[iteratorType.index()]);
+
+  // Distribute - only supports cyclic distribution for now.
+  for (auto it : llvm::zip(loops, procInfo, distributionMethod))
     if (std::get<2>(it) == DistributionMethod::Cyclic)
       mapLoopToProcessorIds(std::get<0>(it), std::get<1>(it).procId,
                             std::get<1>(it).nprocs);

diff  --git a/mlir/test/Dialect/Linalg/tile-and-distribute.mlir b/mlir/test/Dialect/Linalg/tile-and-distribute.mlir
index a6756306cff89..59c34bfacd934 100644
--- a/mlir/test/Dialect/Linalg/tile-and-distribute.mlir
+++ b/mlir/test/Dialect/Linalg/tile-and-distribute.mlir
@@ -12,8 +12,8 @@ func @gemm1(%a : memref<?x?xf32>, %b : memref<?x?xf32>, %c : memref<?x?xf32>)
 // CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]: memref<?x?xf32>
 // CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]: memref<?x?xf32>
 // CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]*]]: memref<?x?xf32>
-//      CHECK: %[[BIDY:.*]] = "gpu.block_id"() {dimension = "y"}
-//      CHECK: %[[BIDX:.*]] = "gpu.block_id"() {dimension = "x"}
+//  CHECK-DAG: %[[BIDY:.*]] = "gpu.block_id"() {dimension = "y"}
+//  CHECK-DAG: %[[BIDX:.*]] = "gpu.block_id"() {dimension = "x"}
 //      CHECK: scf.for %[[ARG3:.*]] =
 //      CHECK:   %[[OFFSETY:.*]] = affine.apply #[[MAP0]]()[%[[BIDY]]]
 //      CHECK:   %[[SV1:.*]] = memref.subview %[[ARG0]][%[[OFFSETY]], %[[ARG3]]]
@@ -70,10 +70,10 @@ func @gemm3(%a : memref<?x?xf32>, %b : memref<?x?xf32>, %c : memref<?x?xf32>)
 // CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]: memref<?x?xf32>
 // CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]: memref<?x?xf32>
 // CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]*]]: memref<?x?xf32>
-//      CHECK: %[[BIDY:.*]] = "gpu.block_id"() {dimension = "y"}
-//      CHECK: %[[NBLOCKSY:.*]] = "gpu.grid_dim"() {dimension = "y"}
-//      CHECK: %[[BIDX:.*]] = "gpu.block_id"() {dimension = "x"}
-//      CHECK: %[[NBLOCKSX:.*]] = "gpu.grid_dim"() {dimension = "x"}
+//  CHECK-DAG: %[[BIDY:.*]] = "gpu.block_id"() {dimension = "y"}
+//  CHECK-DAG: %[[NBLOCKSY:.*]] = "gpu.grid_dim"() {dimension = "y"}
+//  CHECK-DAG: %[[BIDX:.*]] = "gpu.block_id"() {dimension = "x"}
+//  CHECK-DAG: %[[NBLOCKSX:.*]] = "gpu.grid_dim"() {dimension = "x"}
 //      CHECK: %[[LBY:.*]] = affine.apply #[[MAP0]]()[%[[BIDY]]]
 //      CHECK: %[[STEPY:.*]] = affine.apply #[[MAP0]]()[%[[NBLOCKSY]]]
 //      CHECK: %[[LBX:.*]] = affine.apply #[[MAP0]]()[%[[BIDX]]]
@@ -99,8 +99,8 @@ func @gemm4(%a : memref<?x?xf32>, %b : memref<?x?xf32>, %c : memref<?x?xf32>)
 // CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]: memref<?x?xf32>
 // CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]: memref<?x?xf32>
 // CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]*]]: memref<?x?xf32>
-//      CHECK: %[[BIDY:.*]] = "gpu.block_id"() {dimension = "y"}
-//      CHECK: %[[BIDX:.*]] = "gpu.block_id"() {dimension = "x"}
+//  CHECK-DAG: %[[BIDY:.*]] = "gpu.block_id"() {dimension = "y"}
+//  CHECK-DAG: %[[BIDX:.*]] = "gpu.block_id"() {dimension = "x"}
 //      CHECK: %[[LBX:.*]] = affine.apply #[[MAP0]]()[%[[BIDX]]]
 //      CHECK: %[[INBOUNDS:.*]] = cmpi slt, %[[LBX]], %{{.*}}
 //      CHECK: scf.if %[[INBOUNDS]]
@@ -128,9 +128,9 @@ func @gemm5(%a : memref<?x?xf32>, %b : memref<?x?xf32>, %c : memref<?x?xf32>)
 // CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]: memref<?x?xf32>
 // CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]: memref<?x?xf32>
 // CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]*]]: memref<?x?xf32>
-//      CHECK: %[[BIDY:.*]] = "gpu.block_id"() {dimension = "y"}
-//      CHECK: %[[BIDX:.*]] = "gpu.block_id"() {dimension = "x"}
-//      CHECK: %[[NBLOCKSX:.*]] = "gpu.grid_dim"() {dimension = "x"}
+//  CHECK-DAG: %[[BIDY:.*]] = "gpu.block_id"() {dimension = "y"}
+//  CHECK-DAG: %[[BIDX:.*]] = "gpu.block_id"() {dimension = "x"}
+//  CHECK-DAG: %[[NBLOCKSX:.*]] = "gpu.grid_dim"() {dimension = "x"}
 //      CHECK: %[[LBY:.*]] = affine.apply #[[MAP0]]()[%[[BIDY]]]
 //      CHECK: %[[LBX:.*]] = affine.apply #[[MAP0]]()[%[[BIDX]]]
 //      CHECK: %[[STEPX:.*]] = affine.apply #[[MAP0]]()[%[[NBLOCKSX]]]
@@ -159,9 +159,9 @@ func @gemm6(%a : memref<?x?xf32>, %b : memref<?x?xf32>, %c : memref<?x?xf32>)
 // CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]: memref<?x?xf32>
 // CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]: memref<?x?xf32>
 // CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]*]]: memref<?x?xf32>
-//      CHECK: %[[BIDY:.*]] = "gpu.block_id"() {dimension = "y"}
-//      CHECK: %[[NBLOCKSY:.*]] = "gpu.grid_dim"() {dimension = "y"}
-//      CHECK: %[[BIDX:.*]] = "gpu.block_id"() {dimension = "x"}
+//  CHECK-DAG: %[[BIDY:.*]] = "gpu.block_id"() {dimension = "y"}
+//  CHECK-DAG: %[[NBLOCKSY:.*]] = "gpu.grid_dim"() {dimension = "y"}
+//  CHECK-DAG: %[[BIDX:.*]] = "gpu.block_id"() {dimension = "x"}
 //      CHECK: %[[LBY:.*]] = affine.apply #[[MAP0]]()[%[[BIDY]]]
 //      CHECK: %[[STEPY:.*]] = affine.apply #[[MAP0]]()[%[[NBLOCKSY]]]
 //      CHECK: scf.parallel (%[[ARG3:.*]]) = (%[[LBY]]) to (%{{.*}}) step (%[[STEPY]])
@@ -186,10 +186,10 @@ func @matmul_tensors(
     -> tensor<?x?xf32> {
 //  CHECK-DAG: %[[C8:.*]] = constant 8 : index
 //  CHECK-DAG: %[[C0:.*]] = constant 0 : index
-//      CHECK: %[[BIDY:.*]] = "gpu.block_id"() {dimension = "y"}
-//      CHECK: %[[NBLOCKSY:.*]] = "gpu.grid_dim"() {dimension = "y"}
-//      CHECK: %[[BIDX:.*]] = "gpu.block_id"() {dimension = "x"}
-//      CHECK: %[[NBLOCKSX:.*]] = "gpu.grid_dim"() {dimension = "x"}
+//  CHECK-DAG: %[[BIDY:.*]] = "gpu.block_id"() {dimension = "y"}
+//  CHECK-DAG: %[[NBLOCKSY:.*]] = "gpu.grid_dim"() {dimension = "y"}
+//  CHECK-DAG: %[[BIDX:.*]] = "gpu.block_id"() {dimension = "x"}
+//  CHECK-DAG: %[[NBLOCKSX:.*]] = "gpu.grid_dim"() {dimension = "x"}
 //      CHECK: %[[MUL:.+]] = affine.apply #[[MULMAP]]()[%[[BIDY]], %[[C8]]]
 //      CHECK: %[[LBY:.+]] = affine.apply #[[ADDMAP]]()[%[[MUL]], %[[C0]]]
 //      CHECK: %[[STEPY:.+]] = affine.apply #[[MULMAP]]()[%[[NBLOCKSY]], %[[C8]]]

diff  --git a/mlir/test/lib/Transforms/TestLinalgTransforms.cpp b/mlir/test/lib/Transforms/TestLinalgTransforms.cpp
index 73282748e0c7c..94ab9b951c37e 100644
--- a/mlir/test/lib/Transforms/TestLinalgTransforms.cpp
+++ b/mlir/test/lib/Transforms/TestLinalgTransforms.cpp
@@ -333,12 +333,15 @@ static void fillPromotionCallBackPatterns(MLIRContext *ctx,
 template <typename IdOp, typename NProcsOp>
 static SmallVector<ProcInfo, 2>
 getGpuProcIds(OpBuilder &b, Location loc, ArrayRef<Range> parallelLoopRanges) {
+  size_t count = std::min<size_t>(3, parallelLoopRanges.size());
+  SmallVector<ProcInfo, 2> procInfo(count);
+  const char *xyz[] = {"x", "y", "z"};
   Type indexType = b.getIndexType();
-  SmallVector<ProcInfo, 2> procInfo(2);
-  procInfo[0] = {b.create<IdOp>(loc, indexType, b.getStringAttr("y")),
-                 b.create<NProcsOp>(loc, indexType, b.getStringAttr("y"))};
-  procInfo[1] = {b.create<IdOp>(loc, indexType, b.getStringAttr("x")),
-                 b.create<NProcsOp>(loc, indexType, b.getStringAttr("x"))};
+  for (unsigned i = 0; i < count; ++i) {
+    procInfo[count - 1 - i] = {
+        b.create<IdOp>(loc, indexType, b.getStringAttr(xyz[i])),
+        b.create<NProcsOp>(loc, indexType, b.getStringAttr(xyz[i]))};
+  }
   return procInfo;
 }
 


        


More information about the Mlir-commits mailing list