[Mlir-commits] [mlir] 5ccac05 - [mlir][Linalg] Modify callback for getting id/nprocs in
llvmlistbot at llvm.org
Tue Aug 18 14:05:10 PDT 2020
Author: MaheshRavishankar
Date: 2020-08-18T14:04:40-07:00
New Revision: 5ccac05d433cf8a46683acb5293fb43280d0f2ed
URL: https://github.com/llvm/llvm-project/commit/5ccac05d433cf8a46683acb5293fb43280d0f2ed
DIFF: https://github.com/llvm/llvm-project/commit/5ccac05d433cf8a46683acb5293fb43280d0f2ed.diff
LOG: [mlir][Linalg] Modify callback for getting id/nprocs in
LinalgDistribution options to allow more general distributions.
This changes the signature of the callback to pass in the ranges of all
the parallel loops and to expect back a vector with the Values to use for
the processor ID and number of processors for each of the parallel loops.
Differential Revision: https://reviews.llvm.org/D86095
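For illustration only (not part of this commit), a callback matching the new
signature might look like the sketch below. It mirrors the updated test
pattern in TestLinalgTransforms.cpp and cyclically maps the first parallel
loop to the "y" dimension and the second to "x" via gpu.block_id/gpu.grid_dim;
the helper name getBlockProcInfo and the two-loop limit are assumptions made
for the example.

static SmallVector<ProcInfo, 2>
getBlockProcInfo(OpBuilder &b, Location loc,
                 ArrayRef<SubViewOp::Range> parallelLoopRanges) {
  // One {procId, nprocs} pair must be returned for each entry in
  // `parallelLoopRanges`. The ranges themselves are not needed for this
  // simple cyclic mapping.
  assert(parallelLoopRanges.size() <= 2 && "example handles at most 2 loops");
  StringRef dims[] = {"y", "x"};
  Type indexType = b.getIndexType();
  SmallVector<ProcInfo, 2> procInfo;
  for (auto en : llvm::enumerate(parallelLoopRanges)) {
    StringRef dim = dims[en.index()];
    procInfo.push_back(
        {b.create<gpu::BlockIdOp>(loc, indexType, b.getStringAttr(dim)),
         b.create<gpu::GridDimOp>(loc, indexType, b.getStringAttr(dim))});
  }
  return procInfo;
}

Such a callback would then be installed as the `procInfo` field of
LinalgLoopDistributionOptions, alongside one DistributionMethod entry per
parallel loop to be distributed.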
Added:
Modified:
mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
mlir/lib/Dialect/Linalg/Utils/Utils.cpp
mlir/test/Dialect/Linalg/tile-and-distribute.mlir
mlir/test/lib/Transforms/TestLinalgTransforms.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
index 794ebcbc2645..beef1a70096e 100644
--- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
@@ -198,19 +198,23 @@ enum class DistributionMethod {
};
/// Callback function type used to get processor ID, and number of processors
-/// used for distribution.
+/// used for distribution for all parallel loops generated.
struct ProcInfo {
Value procId;
Value nprocs;
};
-using ProcInfoCallBackFn =
- std::function<ProcInfo(OpBuilder &b, Location loc, unsigned loopNum)>;
+using ProcInfoCallBackFn = std::function<SmallVector<ProcInfo, 2>(
+ OpBuilder &b, Location loc, ArrayRef<SubViewOp::Range> parallelLoopRanges)>;
/// Options that allow distribution of loops generated in Linalg transforms to
/// processors while generating the loops.
struct LinalgLoopDistributionOptions {
- /// Callback function that returns the Value for processor ID, and number of
- /// processors used to execute a given loop.
+ /// Callback function that returns the Values for processor ID (`procId`), and
+ /// number of processors (`nprocs`) used to execute the parallel loops. The
+ /// number of `{procId, nprocs}` pairs returned must be equal to the number of
+ /// `parallelLoopRanges` passed into the callback, which in turn is the same as
+ /// the number of parallel loops for which the `distributionMethod` is
+ /// specified below.
ProcInfoCallBackFn procInfo;
/// Specification of how to distribute the `scf.parallel` loops that are
/// generated. As the `scf.parallel` loop is generated, the elements of this
diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
index 4e9cbe9d913d..cf14555aa63f 100644
--- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -334,21 +334,31 @@ void GenerateLoopNest<scf::ParallelOp>::doit(
SmallVector<DistributionMethod, 0> distributionMethod;
if (distributionOptions) {
auto &options = distributionOptions.getValue();
- unsigned index = 0;
OpBuilder &builder = edsc::ScopedContext::getBuilderRef();
Location loc = edsc::ScopedContext::getLocation();
distributionMethod.assign(distributionOptions->distributionMethod.begin(),
distributionOptions->distributionMethod.end());
- for (auto iteratorType : enumerate(iteratorTypes))
- if (isParallelIteratorType(iteratorType.value()) &&
- index < distributionMethod.size()) {
+ SmallVector<SubViewOp::Range, 2> parallelLoopRanges;
+ for (auto iteratorType : enumerate(iteratorTypes)) {
+ if (isParallelIteratorType(iteratorType.value()))
+ parallelLoopRanges.push_back(loopRanges[iteratorType.index()]);
+ }
+ if (distributionMethod.size() < parallelLoopRanges.size())
+ parallelLoopRanges.resize(distributionMethod.size());
+ SmallVector<ProcInfo, 2> procInfo =
+ options.procInfo(builder, loc, parallelLoopRanges);
+ unsigned index = 0;
+ for (auto iteratorType : enumerate(iteratorTypes)) {
+ if (index >= procInfo.size())
+ break;
+ if (isParallelIteratorType(iteratorType.value())) {
unsigned i = iteratorType.index();
- ProcInfo procInfo = options.procInfo(builder, loc, index);
- updateBoundsForCyclicDistribution(builder, loc, procInfo.procId,
- procInfo.nprocs, lbsStorage[i],
+ updateBoundsForCyclicDistribution(builder, loc, procInfo[index].procId,
+ procInfo[index].nprocs, lbsStorage[i],
ubsStorage[i], stepsStorage[i]);
index++;
}
+ }
}
ValueRange lbs(lbsStorage), ubs(ubsStorage), steps(stepsStorage);
generateParallelLoopNest(lbs, ubs, steps, iteratorTypes, bodyBuilderFn, ivs,
diff --git a/mlir/test/Dialect/Linalg/tile-and-distribute.mlir b/mlir/test/Dialect/Linalg/tile-and-distribute.mlir
index e1bc28e133bd..08f6d19fe6d6 100644
--- a/mlir/test/Dialect/Linalg/tile-and-distribute.mlir
+++ b/mlir/test/Dialect/Linalg/tile-and-distribute.mlir
@@ -11,16 +11,16 @@ func @gemm1(%a : memref<?x?xf32>, %b : memref<?x?xf32>, %c : memref<?x?xf32>)
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]: memref<?x?xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]: memref<?x?xf32>
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]*]]: memref<?x?xf32>
-// CHECK: %[[T1:.*]] = "gpu.block_id"() {dimension = "y"}
-// CHECK: %[[T2:.*]] = "gpu.block_id"() {dimension = "x"}
+// CHECK: %[[BIDY:.*]] = "gpu.block_id"() {dimension = "y"}
+// CHECK: %[[BIDX:.*]] = "gpu.block_id"() {dimension = "x"}
// CHECK: scf.for %[[ARG3:.*]] =
-// CHECK: %[[T3:.*]] = affine.apply #[[MAP0]]()[%[[T1]]]
-// CHECK: %[[SV1:.*]] = subview %[[ARG0]][%[[T3]], %[[ARG3]]]
-// CHECK: %[[T11:.*]] = affine.apply #[[MAP0]]()[%[[T2]]]
-// CHECK: %[[SV2:.*]] = subview %[[ARG1]][%[[ARG3]], %[[T11]]]
-// CHECK: %[[T15:.*]] = affine.apply #[[MAP0]]()[%[[T1]]]
-// CHECK: %[[T18:.*]] = affine.apply #[[MAP0]]()[%[[T2]]]
-// CHECK: %[[SV3:.*]] = subview %[[ARG2]][%[[T15]], %[[T18]]]
+// CHECK: %[[OFFSETY:.*]] = affine.apply #[[MAP0]]()[%[[BIDY]]]
+// CHECK: %[[SV1:.*]] = subview %[[ARG0]][%[[OFFSETY]], %[[ARG3]]]
+// CHECK: %[[OFFSETX:.*]] = affine.apply #[[MAP0]]()[%[[BIDX]]]
+// CHECK: %[[SV2:.*]] = subview %[[ARG1]][%[[ARG3]], %[[OFFSETX]]]
+// CHECK: %[[OFFSETY_2:.*]] = affine.apply #[[MAP0]]()[%[[BIDY]]]
+// CHECK: %[[OFFSETX:.*]] = affine.apply #[[MAP0]]()[%[[BIDX]]]
+// CHECK: %[[SV3:.*]] = subview %[[ARG2]][%[[OFFSETY_2]], %[[OFFSETX]]]
// CHECK: linalg.matmul %[[SV1]], %[[SV2]], %[[SV3]]
// -----
@@ -36,22 +36,22 @@ func @gemm2(%a : memref<?x?xf32>, %b : memref<?x?xf32>, %c : memref<?x?xf32>)
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]: memref<?x?xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]: memref<?x?xf32>
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]*]]: memref<?x?xf32>
-// CHECK: %[[T3:.*]] = "gpu.block_id"() {dimension = "y"}
-// CHECK: %[[T4:.*]] = affine.apply #[[MAP0]]()[%[[T3]]]
-// CHECK: %[[T5:.*]] = "gpu.block_id"() {dimension = "x"}
-// CHECK: %[[T6:.*]] = affine.apply #[[MAP0]]()[%[[T5]]]
-// CHECK: %[[T7:.*]] = cmpi "slt", %[[T4]], %{{.*}}
-// CHECK: %[[T8:.*]] = cmpi "slt", %[[T6]], %{{.*}}
-// CHECK: %[[T9:.*]] = and %[[T7]], %[[T8]]
-// CHECK: scf.if %[[T9]]
+// CHECK-DAG: %[[BIDY:.*]] = "gpu.block_id"() {dimension = "y"}
+// CHECK-DAG: %[[BIDX:.*]] = "gpu.block_id"() {dimension = "x"}
+// CHECK: %[[ITERY:.*]] = affine.apply #[[MAP0]]()[%[[BIDY]]]
+// CHECK: %[[ITERX:.*]] = affine.apply #[[MAP0]]()[%[[BIDX]]]
+// CHECK: %[[INBOUNDSY:.*]] = cmpi "slt", %[[ITERY]], %{{.*}}
+// CHECK: %[[INBOUNDSX:.*]] = cmpi "slt", %[[ITERX]], %{{.*}}
+// CHECK: %[[INBOUNDS:.*]] = and %[[INBOUNDSY]], %[[INBOUNDSX]]
+// CHECK: scf.if %[[INBOUNDS]]
// CHECK: scf.for %[[ARG3:.*]] =
-// CHECK: %[[T10:.*]] = affine.apply #[[MAP0]]()[%[[T3]]]
-// CHECK: %[[SV1:.*]] = subview %[[ARG0]][%[[T10]], %[[ARG3]]]
-// CHECK: %[[T18:.*]] = affine.apply #[[MAP0]]()[%[[T5]]]
-// CHECK: %[[SV2:.*]] = subview %[[ARG1]][%[[ARG3]], %[[T18]]]
-// CHECK: %[[T22:.*]] = affine.apply #[[MAP0]]()[%[[T3]]]
-// CHECK: %[[T25:.*]] = affine.apply #[[MAP0]]()[%[[T5]]]
-// CHECK: %[[SV3:.*]] = subview %[[ARG2]][%[[T22]], %[[T25]]]
+// CHECK: %[[OFFSETY:.*]] = affine.apply #[[MAP0]]()[%[[BIDY]]]
+// CHECK: %[[SV1:.*]] = subview %[[ARG0]][%[[OFFSETY]], %[[ARG3]]]
+// CHECK: %[[OFFSETX:.*]] = affine.apply #[[MAP0]]()[%[[BIDX]]]
+// CHECK: %[[SV2:.*]] = subview %[[ARG1]][%[[ARG3]], %[[OFFSETX]]]
+// CHECK: %[[OFFSETY_2:.*]] = affine.apply #[[MAP0]]()[%[[BIDY]]]
+// CHECK: %[[OFFSETX_2:.*]] = affine.apply #[[MAP0]]()[%[[BIDX]]]
+// CHECK: %[[SV3:.*]] = subview %[[ARG2]][%[[OFFSETY_2]], %[[OFFSETX_2]]]
// CHECK: linalg.matmul %[[SV1]], %[[SV2]], %[[SV3]]
// -----
@@ -67,15 +67,15 @@ func @gemm3(%a : memref<?x?xf32>, %b : memref<?x?xf32>, %c : memref<?x?xf32>)
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]: memref<?x?xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]: memref<?x?xf32>
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]*]]: memref<?x?xf32>
-// CHECK: %[[T3:.*]] = "gpu.block_id"() {dimension = "y"}
-// CHECK: %[[T4:.*]] = "gpu.grid_dim"() {dimension = "y"}
-// CHECK: %[[T5:.*]] = affine.apply #[[MAP0]]()[%[[T3]]]
-// CHECK: %[[T6:.*]] = affine.apply #[[MAP0]]()[%[[T4]]]
-// CHECK: %[[T7:.*]] = "gpu.block_id"() {dimension = "x"}
-// CHECK: %[[T8:.*]] = "gpu.grid_dim"() {dimension = "x"}
-// CHECK: %[[T9:.*]] = affine.apply #[[MAP0]]()[%[[T7]]]
-// CHECK: %[[T10:.*]] = affine.apply #[[MAP0]]()[%[[T8]]]
-// CHECK: scf.parallel (%[[ARG3:.*]], %[[ARG4:.*]]) = (%[[T5]], %[[T9]]) to (%{{.*}}, %{{.*}}) step (%[[T6]], %[[T10]])
+// CHECK: %[[BIDY:.*]] = "gpu.block_id"() {dimension = "y"}
+// CHECK: %[[NBLOCKSY:.*]] = "gpu.grid_dim"() {dimension = "y"}
+// CHECK: %[[BIDX:.*]] = "gpu.block_id"() {dimension = "x"}
+// CHECK: %[[NBLOCKSX:.*]] = "gpu.grid_dim"() {dimension = "x"}
+// CHECK: %[[LBY:.*]] = affine.apply #[[MAP0]]()[%[[BIDY]]]
+// CHECK: %[[STEPY:.*]] = affine.apply #[[MAP0]]()[%[[NBLOCKSY]]]
+// CHECK: %[[LBX:.*]] = affine.apply #[[MAP0]]()[%[[BIDX]]]
+// CHECK: %[[STEPX:.*]] = affine.apply #[[MAP0]]()[%[[NBLOCKSX]]]
+// CHECK: scf.parallel (%[[ARG3:.*]], %[[ARG4:.*]]) = (%[[LBY]], %[[LBX]]) to (%{{.*}}, %{{.*}}) step (%[[STEPY]], %[[STEPX]])
// CHECK: scf.for %[[ARG5:.*]] =
// CHECK: %[[SV1:.*]] = subview %[[ARG0]][%[[ARG3]], %[[ARG5]]]
// CHECK: %[[SV2:.*]] = subview %[[ARG1]][%[[ARG5]], %[[ARG4]]]
@@ -95,19 +95,19 @@ func @gemm4(%a : memref<?x?xf32>, %b : memref<?x?xf32>, %c : memref<?x?xf32>)
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]: memref<?x?xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]: memref<?x?xf32>
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]*]]: memref<?x?xf32>
-// CHECK: %[[T2:.*]] = "gpu.block_id"() {dimension = "y"}
-// CHECK: %[[T3:.*]] = "gpu.block_id"() {dimension = "x"}
-// CHECK: %[[T4:.*]] = affine.apply #[[MAP0]]()[%[[T3]]]
-// CHECK: %[[T5:.*]] = cmpi "slt", %[[T4]], %{{.*}}
-// CHECK: scf.if %[[T5]]
+// CHECK: %[[BIDY:.*]] = "gpu.block_id"() {dimension = "y"}
+// CHECK: %[[BIDX:.*]] = "gpu.block_id"() {dimension = "x"}
+// CHECK: %[[LBX:.*]] = affine.apply #[[MAP0]]()[%[[BIDX]]]
+// CHECK: %[[INBOUNDS:.*]] = cmpi "slt", %[[LBX]], %{{.*}}
+// CHECK: scf.if %[[INBOUNDS]]
// CHECK: scf.for %[[ARG3:.*]] =
-// CHECK: %[[T6:.*]] = affine.apply #[[MAP0]]()[%[[T2]]]
-// CHECK: %[[SV1:.*]] = subview %[[ARG0]][%[[T6]], %[[ARG3]]]
-// CHECK: %[[T14:.*]] = affine.apply #[[MAP0]]()[%[[T3]]]
-// CHECK: %[[SV2:.*]] = subview %[[ARG1]][%[[ARG3]], %[[T14]]]
-// CHECK: %[[T18:.*]] = affine.apply #[[MAP0]]()[%[[T2]]]
-// CHECK: %[[T21:.*]] = affine.apply #[[MAP0]]()[%[[T3]]]
-// CHECK: %[[SV3:.*]] = subview %[[ARG2]][%[[T18]], %[[T21]]]
+// CHECK: %[[OFFSETY:.*]] = affine.apply #[[MAP0]]()[%[[BIDY]]]
+// CHECK: %[[SV1:.*]] = subview %[[ARG0]][%[[OFFSETY]], %[[ARG3]]]
+// CHECK: %[[OFFSETX:.*]] = affine.apply #[[MAP0]]()[%[[BIDX]]]
+// CHECK: %[[SV2:.*]] = subview %[[ARG1]][%[[ARG3]], %[[OFFSETX]]]
+// CHECK: %[[OFFSETY_2:.*]] = affine.apply #[[MAP0]]()[%[[BIDY]]]
+// CHECK: %[[OFFSETX_2:.*]] = affine.apply #[[MAP0]]()[%[[BIDX]]]
+// CHECK: %[[SV3:.*]] = subview %[[ARG2]][%[[OFFSETY_2]], %[[OFFSETX_2]]]
// CHECK: linalg.matmul %[[SV1]], %[[SV2]], %[[SV3]]
// -----
@@ -123,21 +123,21 @@ func @gemm5(%a : memref<?x?xf32>, %b : memref<?x?xf32>, %c : memref<?x?xf32>)
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]: memref<?x?xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]: memref<?x?xf32>
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]*]]: memref<?x?xf32>
-// CHECK: %[[T3:.*]] = "gpu.block_id"() {dimension = "y"}
-// CHECK: %[[T4:.*]] = affine.apply #[[MAP0]]()[%[[T3]]]
-// CHECK: %[[T5:.*]] = "gpu.block_id"() {dimension = "x"}
-// CHECK: %[[T6:.*]] = "gpu.grid_dim"() {dimension = "x"}
-// CHECK: %[[T7:.*]] = affine.apply #[[MAP0]]()[%[[T5]]]
-// CHECK: %[[T8:.*]] = affine.apply #[[MAP0]]()[%[[T6]]]
-// CHECK: %[[T9:.*]] = cmpi "slt", %[[T4]], %{{.*}}
-// CHECK: scf.if %[[T9]]
-// CHECK: scf.parallel (%[[ARG3.*]]) = (%[[T7]]) to (%{{.*}}) step (%[[T8]])
+// CHECK: %[[BIDY:.*]] = "gpu.block_id"() {dimension = "y"}
+// CHECK: %[[BIDX:.*]] = "gpu.block_id"() {dimension = "x"}
+// CHECK: %[[NBLOCKSX:.*]] = "gpu.grid_dim"() {dimension = "x"}
+// CHECK: %[[LBY:.*]] = affine.apply #[[MAP0]]()[%[[BIDY]]]
+// CHECK: %[[LBX:.*]] = affine.apply #[[MAP0]]()[%[[BIDX]]]
+// CHECK: %[[STEPX:.*]] = affine.apply #[[MAP0]]()[%[[NBLOCKSX]]]
+// CHECK: %[[INBOUNDS:.*]] = cmpi "slt", %[[LBY]], %{{.*}}
+// CHECK: scf.if %[[INBOUNDS]]
+// CHECK: scf.parallel (%[[ARG3:.*]]) = (%[[LBX]]) to (%{{.*}}) step (%[[STEPX]])
// CHECK: scf.for %[[ARG4:.*]] =
-// CHECK: %[[T10:.*]] = affine.apply #[[MAP0]]()[%[[T3]]]
-// CHECK: %[[SV1:.*]] = subview %[[ARG0]][%[[T10]], %[[ARG4]]]
+// CHECK: %[[OFFSETY:.*]] = affine.apply #[[MAP0]]()[%[[BIDY]]]
+// CHECK: %[[SV1:.*]] = subview %[[ARG0]][%[[OFFSETY]], %[[ARG4]]]
// CHECK: %[[SV2:.*]] = subview %[[ARG1]][%[[ARG4]], %[[ARG3]]]
-// CHECK: %[[T21:.*]] = affine.apply #[[MAP0]]()[%[[T3]]]
-// CHECK: %[[SV3:.*]] = subview %[[ARG2]][%[[T21]], %[[ARG3]]]
+// CHECK: %[[OFFSETY_2:.*]] = affine.apply #[[MAP0]]()[%[[BIDY]]]
+// CHECK: %[[SV3:.*]] = subview %[[ARG2]][%[[OFFSETY_2]], %[[ARG3]]]
// CHECK: linalg.matmul %[[SV1]], %[[SV2]], %[[SV3]]
// -----
@@ -153,16 +153,16 @@ func @gemm6(%a : memref<?x?xf32>, %b : memref<?x?xf32>, %c : memref<?x?xf32>)
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]: memref<?x?xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]: memref<?x?xf32>
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]*]]: memref<?x?xf32>
-// CHECK: %[[T2:.*]] = "gpu.block_id"() {dimension = "y"}
-// CHECK: %[[T3:.*]] = "gpu.grid_dim"() {dimension = "y"}
-// CHECK: %[[T4:.*]] = affine.apply #[[MAP0]]()[%[[T2]]]
-// CHECK: %[[T5:.*]] = affine.apply #[[MAP0]]()[%[[T3]]]
-// CHECK: %[[T6:.*]] = "gpu.block_id"() {dimension = "x"}
-// CHECK: scf.parallel (%[[ARG3.*]]) = (%[[T4]]) to (%{{.*}}) step (%[[T5]])
+// CHECK: %[[BIDY:.*]] = "gpu.block_id"() {dimension = "y"}
+// CHECK: %[[NBLOCKSY:.*]] = "gpu.grid_dim"() {dimension = "y"}
+// CHECK: %[[BIDX:.*]] = "gpu.block_id"() {dimension = "x"}
+// CHECK: %[[LBY:.*]] = affine.apply #[[MAP0]]()[%[[BIDY]]]
+// CHECK: %[[STEPY:.*]] = affine.apply #[[MAP0]]()[%[[NBLOCKSY]]]
+// CHECK: scf.parallel (%[[ARG3:.*]]) = (%[[LBY]]) to (%{{.*}}) step (%[[STEPY]])
// CHECK: scf.for %[[ARG4:.*]] =
// CHECK: %[[SV1:.*]] = subview %[[ARG0]][%[[ARG3]], %[[ARG4]]]
-// CHECK: %[[T14:.*]] = affine.apply #[[MAP0]]()[%[[T6]]]
-// CHECK: %[[SV2:.*]] = subview %[[ARG1]][%[[ARG4]], %[[T14]]]
-// CHECK: %[[T20:.*]] = affine.apply #[[MAP0]]()[%[[T6]]]
-// CHECK: %[[SV3:.*]] = subview %[[ARG2]][%[[ARG3]], %[[T20]]]
+// CHECK: %[[OFFSETX:.*]] = affine.apply #[[MAP0]]()[%[[BIDX]]]
+// CHECK: %[[SV2:.*]] = subview %[[ARG1]][%[[ARG4]], %[[OFFSETX]]]
+// CHECK: %[[OFFSETX_2:.*]] = affine.apply #[[MAP0]]()[%[[BIDX]]]
+// CHECK: %[[SV3:.*]] = subview %[[ARG2]][%[[ARG3]], %[[OFFSETX_2]]]
// CHECK: linalg.matmul %[[SV1]], %[[SV2]], %[[SV3]]
diff --git a/mlir/test/lib/Transforms/TestLinalgTransforms.cpp b/mlir/test/lib/Transforms/TestLinalgTransforms.cpp
index f6c1160d35b0..dffe4f2a0796 100644
--- a/mlir/test/lib/Transforms/TestLinalgTransforms.cpp
+++ b/mlir/test/lib/Transforms/TestLinalgTransforms.cpp
@@ -289,19 +289,16 @@ static void fillPromotionCallBackPatterns(MLIRContext *ctx,
}
template <typename IdOp, typename NProcsOp>
-static ProcInfo getGpuProcIds(OpBuilder &b, Location loc, unsigned loopNum) {
+static SmallVector<ProcInfo, 2>
+getGpuProcIds(OpBuilder &b, Location loc,
+ ArrayRef<SubViewOp::Range> parallelLoopRanges) {
Type indexType = b.getIndexType();
- switch (loopNum) {
- case 0:
- return {b.create<IdOp>(loc, indexType, b.getStringAttr("y")),
- b.create<NProcsOp>(loc, indexType, b.getStringAttr("y"))};
- case 1:
- return {b.create<IdOp>(loc, indexType, b.getStringAttr("x")),
- b.create<NProcsOp>(loc, indexType, b.getStringAttr("x"))};
- default:
- llvm_unreachable("test patterns handles only upto 2-level nested loops");
- }
- return {nullptr, nullptr};
+ SmallVector<ProcInfo, 2> procInfo(2);
+ procInfo[0] = {b.create<IdOp>(loc, indexType, b.getStringAttr("y")),
+ b.create<NProcsOp>(loc, indexType, b.getStringAttr("y"))};
+ procInfo[1] = {b.create<IdOp>(loc, indexType, b.getStringAttr("x")),
+ b.create<NProcsOp>(loc, indexType, b.getStringAttr("x"))};
+ return procInfo;
}
static void fillTileAndDistributePatterns(MLIRContext *context,