[Mlir-commits] [mlir] 5ccac05 - [mlir][Linalg] Modify callback for getting id/nprocs in LinalgDistribution options to allow more general distributions.

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Aug 18 14:05:10 PDT 2020


Author: MaheshRavishankar
Date: 2020-08-18T14:04:40-07:00
New Revision: 5ccac05d433cf8a46683acb5293fb43280d0f2ed

URL: https://github.com/llvm/llvm-project/commit/5ccac05d433cf8a46683acb5293fb43280d0f2ed
DIFF: https://github.com/llvm/llvm-project/commit/5ccac05d433cf8a46683acb5293fb43280d0f2ed.diff

LOG: [mlir][Linalg] Modify callback for getting id/nprocs in
LinalgDistribution options to allow more general distributions.

Change the signature of the callback so that it is passed the ranges of
all the parallel loops and is expected to return a vector with the Values
to use for the processor ID and the number of processors for each of the
parallel loops.

Differential Revision: https://reviews.llvm.org/D86095

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
    mlir/lib/Dialect/Linalg/Utils/Utils.cpp
    mlir/test/Dialect/Linalg/tile-and-distribute.mlir
    mlir/test/lib/Transforms/TestLinalgTransforms.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
index 794ebcbc2645..beef1a70096e 100644
--- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
@@ -198,19 +198,23 @@ enum class DistributionMethod {
 };
 
 /// Callback function type used to get processor ID, and number of processors
-/// used for distribution.
+/// used for distribution for all parallel loops generated.
 struct ProcInfo {
   Value procId;
   Value nprocs;
 };
-using ProcInfoCallBackFn =
-    std::function<ProcInfo(OpBuilder &b, Location loc, unsigned loopNum)>;
+using ProcInfoCallBackFn = std::function<SmallVector<ProcInfo, 2>(
+    OpBuilder &b, Location loc, ArrayRef<SubViewOp::Range> parallelLoopRanges)>;
 
 /// Options that allow distribution of loops generated in Linalg transforms to
 /// processors while generating the loops.
 struct LinalgLoopDistributionOptions {
-  /// Callback function that returns the Value for processor ID, and number of
-  /// processors used to execute a given loop.
+  /// Callback function that returns the Values for processor ID (`procId`), and
+  /// number of processors (`nprocs`) used to execute the parallel loops. The
+  /// number of `{procId, nprocs}` pairs returned must be equal to the number of
+  /// `parallelLoopRanges` passed into the callback, which in-turn is same as
+  /// the number of parallel loops for which the `distributionMethod` is
+  /// specified below.
   ProcInfoCallBackFn procInfo;
   /// Specification of how to distribute the `scf.parallel` loops that are
   /// generated. As the `scf.parallel` loop is generated, the elements of this

diff  --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
index 4e9cbe9d913d..cf14555aa63f 100644
--- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -334,21 +334,31 @@ void GenerateLoopNest<scf::ParallelOp>::doit(
   SmallVector<DistributionMethod, 0> distributionMethod;
   if (distributionOptions) {
     auto &options = distributionOptions.getValue();
-    unsigned index = 0;
     OpBuilder &builder = edsc::ScopedContext::getBuilderRef();
     Location loc = edsc::ScopedContext::getLocation();
     distributionMethod.assign(distributionOptions->distributionMethod.begin(),
                               distributionOptions->distributionMethod.end());
-    for (auto iteratorType : enumerate(iteratorTypes))
-      if (isParallelIteratorType(iteratorType.value()) &&
-          index < distributionMethod.size()) {
+    SmallVector<SubViewOp::Range, 2> parallelLoopRanges;
+    for (auto iteratorType : enumerate(iteratorTypes)) {
+      if (isParallelIteratorType(iteratorType.value()))
+        parallelLoopRanges.push_back(loopRanges[iteratorType.index()]);
+    }
+    if (distributionMethod.size() < parallelLoopRanges.size())
+      parallelLoopRanges.resize(distributionMethod.size());
+    SmallVector<ProcInfo, 2> procInfo =
+        options.procInfo(builder, loc, parallelLoopRanges);
+    unsigned index = 0;
+    for (auto iteratorType : enumerate(iteratorTypes)) {
+      if (index >= procInfo.size())
+        break;
+      if (isParallelIteratorType(iteratorType.value())) {
         unsigned i = iteratorType.index();
-        ProcInfo procInfo = options.procInfo(builder, loc, index);
-        updateBoundsForCyclicDistribution(builder, loc, procInfo.procId,
-                                          procInfo.nprocs, lbsStorage[i],
+        updateBoundsForCyclicDistribution(builder, loc, procInfo[index].procId,
+                                          procInfo[index].nprocs, lbsStorage[i],
                                           ubsStorage[i], stepsStorage[i]);
         index++;
       }
+    }
   }
   ValueRange lbs(lbsStorage), ubs(ubsStorage), steps(stepsStorage);
   generateParallelLoopNest(lbs, ubs, steps, iteratorTypes, bodyBuilderFn, ivs,

diff  --git a/mlir/test/Dialect/Linalg/tile-and-distribute.mlir b/mlir/test/Dialect/Linalg/tile-and-distribute.mlir
index e1bc28e133bd..08f6d19fe6d6 100644
--- a/mlir/test/Dialect/Linalg/tile-and-distribute.mlir
+++ b/mlir/test/Dialect/Linalg/tile-and-distribute.mlir
@@ -11,16 +11,16 @@ func @gemm1(%a : memref<?x?xf32>, %b : memref<?x?xf32>, %c : memref<?x?xf32>)
 // CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]: memref<?x?xf32>
 // CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]: memref<?x?xf32>
 // CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]*]]: memref<?x?xf32>
-//      CHECK: %[[T1:.*]] = "gpu.block_id"() {dimension = "y"}
-//      CHECK: %[[T2:.*]] = "gpu.block_id"() {dimension = "x"}
+//      CHECK: %[[BIDY:.*]] = "gpu.block_id"() {dimension = "y"}
+//      CHECK: %[[BIDX:.*]] = "gpu.block_id"() {dimension = "x"}
 //      CHECK: scf.for %[[ARG3:.*]] =
-//      CHECK:   %[[T3:.*]] = affine.apply #[[MAP0]]()[%[[T1]]]
-//      CHECK:   %[[SV1:.*]] = subview %[[ARG0]][%[[T3]], %[[ARG3]]]
-//      CHECK:   %[[T11:.*]] = affine.apply #[[MAP0]]()[%[[T2]]]
-//      CHECK:   %[[SV2:.*]] = subview %[[ARG1]][%[[ARG3]], %[[T11]]]
-//      CHECK:   %[[T15:.*]] = affine.apply #[[MAP0]]()[%[[T1]]]
-//      CHECK:   %[[T18:.*]] = affine.apply #[[MAP0]]()[%[[T2]]]
-//      CHECK:   %[[SV3:.*]] = subview %[[ARG2]][%[[T15]], %[[T18]]]
+//      CHECK:   %[[OFFSETY:.*]] = affine.apply #[[MAP0]]()[%[[BIDY]]]
+//      CHECK:   %[[SV1:.*]] = subview %[[ARG0]][%[[OFFSETY]], %[[ARG3]]]
+//      CHECK:   %[[OFFSETX:.*]] = affine.apply #[[MAP0]]()[%[[BIDX]]]
+//      CHECK:   %[[SV2:.*]] = subview %[[ARG1]][%[[ARG3]], %[[OFFSETX]]]
+//      CHECK:   %[[OFFSETY_2:.*]] = affine.apply #[[MAP0]]()[%[[BIDY]]]
+//      CHECK:   %[[OFFSETX:.*]] = affine.apply #[[MAP0]]()[%[[BIDX]]]
+//      CHECK:   %[[SV3:.*]] = subview %[[ARG2]][%[[OFFSETY_2]], %[[OFFSETX]]]
 //      CHECK:   linalg.matmul %[[SV1]], %[[SV2]], %[[SV3]]
 
 // -----
@@ -36,22 +36,22 @@ func @gemm2(%a : memref<?x?xf32>, %b : memref<?x?xf32>, %c : memref<?x?xf32>)
 // CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]: memref<?x?xf32>
 // CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]: memref<?x?xf32>
 // CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]*]]: memref<?x?xf32>
-//      CHECK: %[[T3:.*]] = "gpu.block_id"() {dimension = "y"}
-//      CHECK: %[[T4:.*]] = affine.apply #[[MAP0]]()[%[[T3]]]
-//      CHECK: %[[T5:.*]] = "gpu.block_id"() {dimension = "x"}
-//      CHECK: %[[T6:.*]] = affine.apply #[[MAP0]]()[%[[T5]]]
-//      CHECK: %[[T7:.*]] = cmpi "slt", %[[T4]], %{{.*}}
-//      CHECK: %[[T8:.*]] = cmpi "slt", %[[T6]], %{{.*}}
-//      CHECK: %[[T9:.*]] = and %[[T7]], %[[T8]]
-//      CHECK: scf.if %[[T9]]
+//  CHECK-DAG: %[[BIDY:.*]] = "gpu.block_id"() {dimension = "y"}
+//  CHECK-DAG: %[[BIDX:.*]] = "gpu.block_id"() {dimension = "x"}
+//      CHECK: %[[ITERY:.*]] = affine.apply #[[MAP0]]()[%[[BIDY]]]
+//      CHECK: %[[ITERX:.*]] = affine.apply #[[MAP0]]()[%[[BIDX]]]
+//      CHECK: %[[INBOUNDSY:.*]] = cmpi "slt", %[[ITERY]], %{{.*}}
+//      CHECK: %[[INBOUNDSX:.*]] = cmpi "slt", %[[ITERX]], %{{.*}}
+//      CHECK: %[[INBOUNDS:.*]] = and %[[INBOUNDSY]], %[[INBOUNDSX]]
+//      CHECK: scf.if %[[INBOUNDS]]
 //      CHECK:   scf.for %[[ARG3:.*]] =
-//      CHECK:     %[[T10:.*]] = affine.apply #[[MAP0]]()[%[[T3]]]
-//      CHECK:     %[[SV1:.*]] = subview %[[ARG0]][%[[T10]], %[[ARG3]]]
-//      CHECK:     %[[T18:.*]] = affine.apply #[[MAP0]]()[%[[T5]]]
-//      CHECK:     %[[SV2:.*]] = subview %[[ARG1]][%[[ARG3]], %[[T18]]]
-//      CHECK:     %[[T22:.*]] = affine.apply #[[MAP0]]()[%[[T3]]]
-//      CHECK:     %[[T25:.*]] = affine.apply #[[MAP0]]()[%[[T5]]]
-//      CHECK:     %[[SV3:.*]] = subview %[[ARG2]][%[[T22]], %[[T25]]]
+//      CHECK:     %[[OFFSETY:.*]] = affine.apply #[[MAP0]]()[%[[BIDY]]]
+//      CHECK:     %[[SV1:.*]] = subview %[[ARG0]][%[[OFFSETY]], %[[ARG3]]]
+//      CHECK:     %[[OFFSETX:.*]] = affine.apply #[[MAP0]]()[%[[BIDX]]]
+//      CHECK:     %[[SV2:.*]] = subview %[[ARG1]][%[[ARG3]], %[[OFFSETX]]]
+//      CHECK:     %[[OFFSETY_2:.*]] = affine.apply #[[MAP0]]()[%[[BIDY]]]
+//      CHECK:     %[[OFFSETX_2:.*]] = affine.apply #[[MAP0]]()[%[[BIDX]]]
+//      CHECK:     %[[SV3:.*]] = subview %[[ARG2]][%[[OFFSETY_2]], %[[OFFSETX_2]]]
 //      CHECK:     linalg.matmul %[[SV1]], %[[SV2]], %[[SV3]]
 
 // -----
@@ -67,15 +67,15 @@ func @gemm3(%a : memref<?x?xf32>, %b : memref<?x?xf32>, %c : memref<?x?xf32>)
 // CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]: memref<?x?xf32>
 // CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]: memref<?x?xf32>
 // CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]*]]: memref<?x?xf32>
-//      CHECK: %[[T3:.*]] = "gpu.block_id"() {dimension = "y"}
-//      CHECK: %[[T4:.*]] = "gpu.grid_dim"() {dimension = "y"}
-//      CHECK: %[[T5:.*]] = affine.apply #[[MAP0]]()[%[[T3]]]
-//      CHECK: %[[T6:.*]] = affine.apply #[[MAP0]]()[%[[T4]]]
-//      CHECK: %[[T7:.*]] = "gpu.block_id"() {dimension = "x"}
-//      CHECK: %[[T8:.*]] = "gpu.grid_dim"() {dimension = "x"}
-//      CHECK: %[[T9:.*]] = affine.apply #[[MAP0]]()[%[[T7]]]
-//      CHECK: %[[T10:.*]] = affine.apply #[[MAP0]]()[%[[T8]]]
-//      CHECK: scf.parallel (%[[ARG3:.*]], %[[ARG4:.*]]) = (%[[T5]], %[[T9]]) to (%{{.*}}, %{{.*}}) step (%[[T6]], %[[T10]])
+//      CHECK: %[[BIDY:.*]] = "gpu.block_id"() {dimension = "y"}
+//      CHECK: %[[NBLOCKSY:.*]] = "gpu.grid_dim"() {dimension = "y"}
+//      CHECK: %[[BIDX:.*]] = "gpu.block_id"() {dimension = "x"}
+//      CHECK: %[[NBLOCKSX:.*]] = "gpu.grid_dim"() {dimension = "x"}
+//      CHECK: %[[LBY:.*]] = affine.apply #[[MAP0]]()[%[[BIDY]]]
+//      CHECK: %[[STEPY:.*]] = affine.apply #[[MAP0]]()[%[[NBLOCKSY]]]
+//      CHECK: %[[LBX:.*]] = affine.apply #[[MAP0]]()[%[[BIDX]]]
+//      CHECK: %[[STEPX:.*]] = affine.apply #[[MAP0]]()[%[[NBLOCKSX]]]
+//      CHECK: scf.parallel (%[[ARG3:.*]], %[[ARG4:.*]]) = (%[[LBY]], %[[LBX]]) to (%{{.*}}, %{{.*}}) step (%[[STEPY]], %[[STEPX]])
 //      CHECK:   scf.for %[[ARG5:.*]] =
 //      CHECK:     %[[SV1:.*]] = subview %[[ARG0]][%[[ARG3]], %[[ARG5]]]
 //      CHECK:     %[[SV2:.*]] = subview %[[ARG1]][%[[ARG5]], %[[ARG4]]]
@@ -95,19 +95,19 @@ func @gemm4(%a : memref<?x?xf32>, %b : memref<?x?xf32>, %c : memref<?x?xf32>)
 // CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]: memref<?x?xf32>
 // CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]: memref<?x?xf32>
 // CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]*]]: memref<?x?xf32>
-//      CHECK: %[[T2:.*]] = "gpu.block_id"() {dimension = "y"}
-//      CHECK: %[[T3:.*]] = "gpu.block_id"() {dimension = "x"}
-//      CHECK: %[[T4:.*]] = affine.apply #[[MAP0]]()[%[[T3]]]
-//      CHECK: %[[T5:.*]] = cmpi "slt", %[[T4]], %{{.*}}
-//      CHECK: scf.if %[[T5]]
+//      CHECK: %[[BIDY:.*]] = "gpu.block_id"() {dimension = "y"}
+//      CHECK: %[[BIDX:.*]] = "gpu.block_id"() {dimension = "x"}
+//      CHECK: %[[LBX:.*]] = affine.apply #[[MAP0]]()[%[[BIDX]]]
+//      CHECK: %[[INBOUNDS:.*]] = cmpi "slt", %[[LBX]], %{{.*}}
+//      CHECK: scf.if %[[INBOUNDS]]
 //      CHECK:   scf.for %[[ARG3:.*]] =
-//      CHECK:     %[[T6:.*]] = affine.apply #[[MAP0]]()[%[[T2]]]
-//      CHECK:     %[[SV1:.*]] = subview %[[ARG0]][%[[T6]], %[[ARG3]]]
-//      CHECK:     %[[T14:.*]] = affine.apply #[[MAP0]]()[%[[T3]]]
-//      CHECK:     %[[SV2:.*]] = subview %[[ARG1]][%[[ARG3]], %[[T14]]]
-//      CHECK:     %[[T18:.*]] = affine.apply #[[MAP0]]()[%[[T2]]]
-//      CHECK:     %[[T21:.*]] = affine.apply #[[MAP0]]()[%[[T3]]]
-//      CHECK:     %[[SV3:.*]] = subview %[[ARG2]][%[[T18]], %[[T21]]]
+//      CHECK:     %[[OFFSETY:.*]] = affine.apply #[[MAP0]]()[%[[BIDY]]]
+//      CHECK:     %[[SV1:.*]] = subview %[[ARG0]][%[[OFFSETY]], %[[ARG3]]]
+//      CHECK:     %[[OFFSETX:.*]] = affine.apply #[[MAP0]]()[%[[BIDX]]]
+//      CHECK:     %[[SV2:.*]] = subview %[[ARG1]][%[[ARG3]], %[[OFFSETX]]]
+//      CHECK:     %[[OFFSETY_2:.*]] = affine.apply #[[MAP0]]()[%[[BIDY]]]
+//      CHECK:     %[[OFFSETX_2:.*]] = affine.apply #[[MAP0]]()[%[[BIDX]]]
+//      CHECK:     %[[SV3:.*]] = subview %[[ARG2]][%[[OFFSETY_2]], %[[OFFSETX_2]]]
 //      CHECK:     linalg.matmul %[[SV1]], %[[SV2]], %[[SV3]]
 
 // -----
@@ -123,21 +123,21 @@ func @gemm5(%a : memref<?x?xf32>, %b : memref<?x?xf32>, %c : memref<?x?xf32>)
 // CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]: memref<?x?xf32>
 // CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]: memref<?x?xf32>
 // CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]*]]: memref<?x?xf32>
-//      CHECK: %[[T3:.*]] = "gpu.block_id"() {dimension = "y"}
-//      CHECK: %[[T4:.*]] = affine.apply #[[MAP0]]()[%[[T3]]]
-//      CHECK: %[[T5:.*]] = "gpu.block_id"() {dimension = "x"}
-//      CHECK: %[[T6:.*]] = "gpu.grid_dim"() {dimension = "x"}
-//      CHECK: %[[T7:.*]] = affine.apply #[[MAP0]]()[%[[T5]]]
-//      CHECK: %[[T8:.*]] = affine.apply #[[MAP0]]()[%[[T6]]]
-//      CHECK: %[[T9:.*]] = cmpi "slt", %[[T4]], %{{.*}}
-//      CHECK: scf.if %[[T9]]
-//      CHECK:   scf.parallel (%[[ARG3.*]]) = (%[[T7]]) to (%{{.*}}) step (%[[T8]])
+//      CHECK: %[[BIDY:.*]] = "gpu.block_id"() {dimension = "y"}
+//      CHECK: %[[BIDX:.*]] = "gpu.block_id"() {dimension = "x"}
+//      CHECK: %[[NBLOCKSX:.*]] = "gpu.grid_dim"() {dimension = "x"}
+//      CHECK: %[[LBY:.*]] = affine.apply #[[MAP0]]()[%[[BIDY]]]
+//      CHECK: %[[LBX:.*]] = affine.apply #[[MAP0]]()[%[[BIDX]]]
+//      CHECK: %[[STEPX:.*]] = affine.apply #[[MAP0]]()[%[[NBLOCKSX]]]
+//      CHECK: %[[INBOUNDS:.*]] = cmpi "slt", %[[LBY]], %{{.*}}
+//      CHECK: scf.if %[[INBOUNDS]]
+//      CHECK:   scf.parallel (%[[ARG3.*]]) = (%[[LBX]]) to (%{{.*}}) step (%[[STEPX]])
 //      CHECK:     scf.for %[[ARG4:.*]] =
-//      CHECK:      %[[T10:.*]] = affine.apply #[[MAP0]]()[%[[T3]]]
-//      CHECK:       %[[SV1:.*]] = subview %[[ARG0]][%[[T10]], %[[ARG4]]]
+//      CHECK:      %[[OFFSETY:.*]] = affine.apply #[[MAP0]]()[%[[BIDY]]]
+//      CHECK:       %[[SV1:.*]] = subview %[[ARG0]][%[[OFFSETY]], %[[ARG4]]]
 //      CHECK:       %[[SV2:.*]] = subview %[[ARG1]][%[[ARG4]], %[[ARG3]]]
-//      CHECK:       %[[T21:.*]] = affine.apply #[[MAP0]]()[%[[T3]]]
-//      CHECK:       %[[SV3:.*]] = subview %[[ARG2]][%[[T21]], %[[ARG3]]]
+//      CHECK:       %[[OFFSETY_2:.*]] = affine.apply #[[MAP0]]()[%[[BIDY]]]
+//      CHECK:       %[[SV3:.*]] = subview %[[ARG2]][%[[OFFSETY_2]], %[[ARG3]]]
 //      CHECK:       linalg.matmul %[[SV1]], %[[SV2]], %[[SV3]]
 
 // -----
@@ -153,16 +153,16 @@ func @gemm6(%a : memref<?x?xf32>, %b : memref<?x?xf32>, %c : memref<?x?xf32>)
 // CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]: memref<?x?xf32>
 // CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]: memref<?x?xf32>
 // CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]*]]: memref<?x?xf32>
-//      CHECK: %[[T2:.*]] = "gpu.block_id"() {dimension = "y"}
-//      CHECK: %[[T3:.*]] = "gpu.grid_dim"() {dimension = "y"}
-//      CHECK: %[[T4:.*]] = affine.apply #[[MAP0]]()[%[[T2]]]
-//      CHECK: %[[T5:.*]] = affine.apply #[[MAP0]]()[%[[T3]]]
-//      CHECK: %[[T6:.*]] = "gpu.block_id"() {dimension = "x"}
-//      CHECK: scf.parallel (%[[ARG3.*]]) = (%[[T4]]) to (%{{.*}}) step (%[[T5]])
+//      CHECK: %[[BIDY:.*]] = "gpu.block_id"() {dimension = "y"}
+//      CHECK: %[[NBLOCKSY:.*]] = "gpu.grid_dim"() {dimension = "y"}
+//      CHECK: %[[BIDX:.*]] = "gpu.block_id"() {dimension = "x"}
+//      CHECK: %[[LBY:.*]] = affine.apply #[[MAP0]]()[%[[BIDY]]]
+//      CHECK: %[[STEPY:.*]] = affine.apply #[[MAP0]]()[%[[NBLOCKSY]]]
+//      CHECK: scf.parallel (%[[ARG3.*]]) = (%[[LBY]]) to (%{{.*}}) step (%[[STEPY]])
 //      CHECK:   scf.for %[[ARG4:.*]] =
 //      CHECK:     %[[SV1:.*]] = subview %[[ARG0]][%[[ARG3]], %[[ARG4]]]
-//      CHECK:     %[[T14:.*]] = affine.apply #[[MAP0]]()[%[[T6]]]
-//      CHECK:     %[[SV2:.*]] = subview %[[ARG1]][%[[ARG4]], %[[T14]]]
-//      CHECK:     %[[T20:.*]] = affine.apply #[[MAP0]]()[%[[T6]]]
-//      CHECK:     %[[SV3:.*]] = subview %[[ARG2]][%[[ARG3]], %[[T20]]]
+//      CHECK:     %[[OFFSETX:.*]] = affine.apply #[[MAP0]]()[%[[BIDX]]]
+//      CHECK:     %[[SV2:.*]] = subview %[[ARG1]][%[[ARG4]], %[[OFFSETX]]]
+//      CHECK:     %[[OFFSETX_2:.*]] = affine.apply #[[MAP0]]()[%[[BIDX]]]
+//      CHECK:     %[[SV3:.*]] = subview %[[ARG2]][%[[ARG3]], %[[OFFSETX_2]]]
 //      CHECK:     linalg.matmul %[[SV1]], %[[SV2]], %[[SV3]]

diff  --git a/mlir/test/lib/Transforms/TestLinalgTransforms.cpp b/mlir/test/lib/Transforms/TestLinalgTransforms.cpp
index f6c1160d35b0..dffe4f2a0796 100644
--- a/mlir/test/lib/Transforms/TestLinalgTransforms.cpp
+++ b/mlir/test/lib/Transforms/TestLinalgTransforms.cpp
@@ -289,19 +289,16 @@ static void fillPromotionCallBackPatterns(MLIRContext *ctx,
 }
 
 template <typename IdOp, typename NProcsOp>
-static ProcInfo getGpuProcIds(OpBuilder &b, Location loc, unsigned loopNum) {
+static SmallVector<ProcInfo, 2>
+getGpuProcIds(OpBuilder &b, Location loc,
+              ArrayRef<SubViewOp::Range> parallelLoopRanges) {
   Type indexType = b.getIndexType();
-  switch (loopNum) {
-  case 0:
-    return {b.create<IdOp>(loc, indexType, b.getStringAttr("y")),
-            b.create<NProcsOp>(loc, indexType, b.getStringAttr("y"))};
-  case 1:
-    return {b.create<IdOp>(loc, indexType, b.getStringAttr("x")),
-            b.create<NProcsOp>(loc, indexType, b.getStringAttr("x"))};
-  default:
-    llvm_unreachable("test patterns handles only upto 2-level nested loops");
-  }
-  return {nullptr, nullptr};
+  SmallVector<ProcInfo, 2> procInfo(2);
+  procInfo[0] = {b.create<IdOp>(loc, indexType, b.getStringAttr("y")),
+                 b.create<NProcsOp>(loc, indexType, b.getStringAttr("y"))};
+  procInfo[1] = {b.create<IdOp>(loc, indexType, b.getStringAttr("x")),
+                 b.create<NProcsOp>(loc, indexType, b.getStringAttr("x"))};
+  return procInfo;
 }
 
 static void fillTileAndDistributePatterns(MLIRContext *context,


        


More information about the Mlir-commits mailing list