[Mlir-commits] [mlir] 27c201a - [MLIR] Add parallel loop collapsing.
Tres Popp
llvmlistbot at llvm.org
Thu Mar 26 01:33:32 PDT 2020
Author: Tres Popp
Date: 2020-03-26T09:32:52+01:00
New Revision: 27c201aa1d972c10bfc38ff5b4abaf78ab3d2ad2
URL: https://github.com/llvm/llvm-project/commit/27c201aa1d972c10bfc38ff5b4abaf78ab3d2ad2
DIFF: https://github.com/llvm/llvm-project/commit/27c201aa1d972c10bfc38ff5b4abaf78ab3d2ad2.diff
LOG: [MLIR] Add parallel loop collapsing.
This allows conversion of a ParallelLoop from N induction variables to
some number of induction variables less than N.
The first intended use of this is for the GPUDialect to convert
ParallelLoops to iterate over 3 dimensions so they can be launched as
GPU Kernels.
To implement this:
- Normalize each iteration space of the ParallelLoop
- Use the same induction variable in a new ParallelLoop for multiple
original iterations.
- Split the new induction variable back into the original set of values
inside the body of the ParallelLoop.
Subscribers: mgorny, mehdi_amini, rriddle, jpienaar, burmako, shauheen, antiagainst, nicolasvasilache, arpith-jacob, mgester, lucyrfox, aartbik, liufengdb, Joonsoo, llvm-commits
Tags: #llvm
Differential Revision: https://reviews.llvm.org/D76363
Added:
mlir/lib/Transforms/ParallelLoopCollapsing.cpp
mlir/test/Transforms/parallel-loop-collapsing.mlir
mlir/test/Transforms/single-parallel-loop-collapsing.mlir
Modified:
mlir/include/mlir/InitAllPasses.h
mlir/include/mlir/Transforms/LoopUtils.h
mlir/include/mlir/Transforms/Passes.h
mlir/lib/Transforms/CMakeLists.txt
mlir/lib/Transforms/Utils/LoopUtils.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/InitAllPasses.h b/mlir/include/mlir/InitAllPasses.h
index b983ecb478dd..7cde5d1c701e 100644
--- a/mlir/include/mlir/InitAllPasses.h
+++ b/mlir/include/mlir/InitAllPasses.h
@@ -113,6 +113,7 @@ inline void registerAllPasses() {
LLVM::createLegalizeForExportPass();
// LoopOps
+ createParallelLoopCollapsingPass();
createParallelLoopFusionPass();
createParallelLoopSpecializationPass();
createParallelLoopTilingPass();
diff --git a/mlir/include/mlir/Transforms/LoopUtils.h b/mlir/include/mlir/Transforms/LoopUtils.h
index 479b7ce81a42..861a2821ee00 100644
--- a/mlir/include/mlir/Transforms/LoopUtils.h
+++ b/mlir/include/mlir/Transforms/LoopUtils.h
@@ -28,6 +28,7 @@ struct MemRefRegion;
namespace loop {
class ForOp;
+class ParallelOp;
} // end namespace loop
/// Unrolls this for operation completely if the trip count is known to be
@@ -226,6 +227,12 @@ TileLoops extractFixedOuterLoops(loop::ForOp rootFOrOp,
/// independent of any loop induction variable involved in the nest.
void coalesceLoops(MutableArrayRef<loop::ForOp> loops);
+/// Collapse the dimensions of the ParallelOp `loops`: each entry of
+/// `combinedDimensions` lists the original dimension indices that are merged
+/// into a single new dimension. Each index into `loops` must occur exactly once.
+void collapsePLoops(loop::ParallelOp loops,
+ ArrayRef<std::vector<unsigned>> combinedDimensions);
+
/// Maps `forOp` for execution on a parallel grid of virtual `processorIds` of
/// size given by `numProcessors`. This is achieved by embedding the SSA values
/// corresponding to `processorIds` and `numProcessors` into the bounds and step
diff --git a/mlir/include/mlir/Transforms/Passes.h b/mlir/include/mlir/Transforms/Passes.h
index 28b29eb29ad4..a8462e862083 100644
--- a/mlir/include/mlir/Transforms/Passes.h
+++ b/mlir/include/mlir/Transforms/Passes.h
@@ -57,6 +57,18 @@ std::unique_ptr<OpPassBase<FuncOp>> createLowerAffinePass();
/// bounds into a single loop.
std::unique_ptr<OpPassBase<FuncOp>> createLoopCoalescingPass();
+/// Creates a pass that transforms a single ParallelLoop over N induction
+/// variables into another ParallelLoop over fewer than N induction variables.
+std::unique_ptr<Pass> createParallelLoopCollapsingPass();
+
+/// Creates a pass that performs packing (or explicit copying) of accessed
+/// memref regions into buffers in the specified faster memory space through
+/// either pointwise copies or DMA operations.
+std::unique_ptr<OpPassBase<FuncOp>> createAffineDataCopyGenerationPass(
+ unsigned slowMemorySpace, unsigned fastMemorySpace,
+ unsigned tagMemorySpace = 0, int minDmaTransferSize = 1024,
+ uint64_t fastMemCapacityBytes = std::numeric_limits<uint64_t>::max());
+
/// Creates a pass to perform optimizations relying on memref dataflow such as
/// store to load forwarding, elimination of dead stores, and dead allocs.
std::unique_ptr<OpPassBase<FuncOp>> createMemRefDataFlowOptPass();
diff --git a/mlir/lib/Transforms/CMakeLists.txt b/mlir/lib/Transforms/CMakeLists.txt
index 8d3c5f46f8c4..acb0a5b8b9a1 100644
--- a/mlir/lib/Transforms/CMakeLists.txt
+++ b/mlir/lib/Transforms/CMakeLists.txt
@@ -11,6 +11,7 @@ add_mlir_library(MLIRTransforms
LoopInvariantCodeMotion.cpp
MemRefDataFlowOpt.cpp
OpStats.cpp
+ ParallelLoopCollapsing.cpp
PipelineDataTransfer.cpp
StripDebugInfo.cpp
SymbolDCE.cpp
diff --git a/mlir/lib/Transforms/ParallelLoopCollapsing.cpp b/mlir/lib/Transforms/ParallelLoopCollapsing.cpp
new file mode 100644
index 000000000000..4660af0e3d84
--- /dev/null
+++ b/mlir/lib/Transforms/ParallelLoopCollapsing.cpp
@@ -0,0 +1,69 @@
+//===- ParallelLoopCollapsing.cpp - Pass collapsing parallel loop indices -===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/LoopOps/LoopOps.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/LoopUtils.h"
+#include "mlir/Transforms/Passes.h"
+#include "mlir/Transforms/RegionUtils.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/Debug.h"
+
+#define PASS_NAME "parallel-loop-collapsing"
+#define DEBUG_TYPE PASS_NAME
+
+using namespace mlir;
+
+namespace {
+/// Pass that collapses the induction variables of every loop::ParallelOp found
+/// under the operation it runs on. Up to three groups of original loop indices
+/// can be supplied through the `collapsed-indices-{0,1,2}` options; every
+/// non-empty group is handed to collapsePLoops as one collapsed dimension.
+struct ParallelLoopCollapsing : public OperationPass<ParallelLoopCollapsing> {
+  ParallelLoopCollapsing() = default;
+  // Copying deliberately transfers nothing: the option members below are
+  // re-created by their default member initializers, bound to the new pass.
+  ParallelLoopCollapsing(const ParallelLoopCollapsing &) {}
+
+  void runOnOperation() override {
+    // The common case for GPU dialect will be simplifying the ParallelOp to 3
+    // arguments, so only three groups are supported. The groups depend solely
+    // on the pass options, so they are assembled once rather than per loop.
+    llvm::SmallVector<std::vector<unsigned>, 3> combinedLoops;
+    if (clCollapsedIndices0.size())
+      combinedLoops.push_back(clCollapsedIndices0);
+    if (clCollapsedIndices1.size())
+      combinedLoops.push_back(clCollapsedIndices1);
+    if (clCollapsedIndices2.size())
+      combinedLoops.push_back(clCollapsedIndices2);
+
+    // With no group requested there is no dimension to collapse into; leave
+    // the IR untouched instead of asking for a ParallelOp with no dimensions.
+    if (combinedLoops.empty())
+      return;
+
+    getOperation()->walk(
+        [&](loop::ParallelOp op) { collapsePLoops(op, combinedLoops); });
+  }
+
+  // Original loop indices that are merged into induction variable 0.
+  ListOption<unsigned> clCollapsedIndices0{
+      *this, "collapsed-indices-0",
+      llvm::cl::desc(
+          "Which loop indices to combine into the position 0 loop index"),
+      llvm::cl::MiscFlags::CommaSeparated};
+
+  // Original loop indices that are merged into induction variable 1.
+  ListOption<unsigned> clCollapsedIndices1{
+      *this, "collapsed-indices-1",
+      llvm::cl::desc(
+          "Which loop indices to combine into the position 1 loop index"),
+      llvm::cl::MiscFlags::CommaSeparated};
+
+  // Original loop indices that are merged into induction variable 2.
+  ListOption<unsigned> clCollapsedIndices2{
+      *this, "collapsed-indices-2",
+      llvm::cl::desc(
+          "Which loop indices to combine into the position 2 loop index"),
+      llvm::cl::MiscFlags::CommaSeparated};
+};
+
+} // namespace
+
+std::unique_ptr<Pass> mlir::createParallelLoopCollapsingPass() { // Factory declared in Passes.h.
+ return std::make_unique<ParallelLoopCollapsing>(); // Default-constructed pass instance.
+}
+
+static PassRegistration<ParallelLoopCollapsing> // Registered under PASS_NAME ("parallel-loop-collapsing").
+ reg(PASS_NAME, "collapse parallel loops to use less induction variables.");
diff --git a/mlir/lib/Transforms/Utils/LoopUtils.cpp b/mlir/lib/Transforms/Utils/LoopUtils.cpp
index a1c888b35af8..d5976d0278ed 100644
--- a/mlir/lib/Transforms/Utils/LoopUtils.cpp
+++ b/mlir/lib/Transforms/Utils/LoopUtils.cpp
@@ -36,6 +36,16 @@ using namespace mlir;
using llvm::SetVector;
using llvm::SmallMapVector;
+namespace {
+// Bundles the three SSA values describing a loop's iteration space so that they
+// can be passed and returned by name instead of by position.
+struct LoopParams {
+ Value lowerBound; // Value the induction variable starts from.
+ Value upperBound; // Upper bound of the iteration space.
+ Value step; // Increment between consecutive iterations.
+};
+} // namespace
+
/// Computes the cleanup loop lower bound of the loop being unrolled with
/// the specified unroll factor; this bound will also be upper bound of the main
/// part of the unrolled loop. Computes the bound as an AffineMap with its
@@ -1094,69 +1104,78 @@ replaceAllUsesExcept(Value orig, Value replacement,
}
}
-// Transform a loop with a strictly positive step
-// for %i = %lb to %ub step %s
-// into a 0-based loop with step 1
-// for %ii = 0 to ceildiv(%ub - %lb, %s) step 1 {
-// %i = %ii * %s + %lb
-// Insert the induction variable remapping in the body of `inner`, which is
-// expected to be either `loop` or another loop perfectly nested under `loop`.
-// Insert the definition of new bounds immediate before `outer`, which is
-// expected to be either `loop` or its parent in the loop nest.
-static void normalizeLoop(loop::ForOp loop, loop::ForOp outer,
- loop::ForOp inner) {
- OpBuilder builder(outer);
- Location loc = loop.getLoc();
-
+/// Return the new lower bound, upper bound, and step in that order. Insert any
+/// additional bounds calculations before the given builder and any additional
+/// conversion back to the original loop induction value inside the given Block.
+static LoopParams normalizeLoop(OpBuilder &boundsBuilder,
+ OpBuilder &insideLoopBuilder, Location loc,
+ Value lowerBound, Value upperBound, Value step,
+ Value inductionVar) {
// Check if the loop is already known to have a constant zero lower bound or
// a constant one step.
bool isZeroBased = false;
if (auto ubCst =
- dyn_cast_or_null<ConstantIndexOp>(loop.lowerBound().getDefiningOp()))
+ dyn_cast_or_null<ConstantIndexOp>(lowerBound.getDefiningOp()))
isZeroBased = ubCst.getValue() == 0;
bool isStepOne = false;
- if (auto stepCst =
- dyn_cast_or_null<ConstantIndexOp>(loop.step().getDefiningOp()))
+ if (auto stepCst = dyn_cast_or_null<ConstantIndexOp>(step.getDefiningOp()))
isStepOne = stepCst.getValue() == 1;
- if (isZeroBased && isStepOne)
- return;
// Compute the number of iterations the loop executes: ceildiv(ub - lb, step)
// assuming the step is strictly positive. Update the bounds and the step
// of the loop to go from 0 to the number of iterations, if necessary.
// TODO(zinenko): introduce support for negative steps or emit dynamic asserts
// on step positivity, whatever gets implemented first.
- Value diff =
- builder.create<SubIOp>(loc, loop.upperBound(), loop.lowerBound());
- Value numIterations = ceilDivPositive(builder, loc, diff, loop.step());
- loop.setUpperBound(numIterations);
-
- Value lb = loop.lowerBound();
- if (!isZeroBased) {
- Value cst0 = builder.create<ConstantIndexOp>(loc, 0);
- loop.setLowerBound(cst0);
- }
+ if (isZeroBased && isStepOne)
+ return {/*lowerBound=*/lowerBound, /*upperBound=*/upperBound,
+ /*step=*/step};
- Value step = loop.step();
- if (!isStepOne) {
- Value cst1 = builder.create<ConstantIndexOp>(loc, 1);
- loop.setStep(cst1);
- }
+ Value diff = boundsBuilder.create<SubIOp>(loc, upperBound, lowerBound);
+ Value newUpperBound = ceilDivPositive(boundsBuilder, loc, diff, step);
+
+ Value newLowerBound =
+ isZeroBased ? lowerBound : boundsBuilder.create<ConstantIndexOp>(loc, 0);
+ Value newStep =
+ isStepOne ? step : boundsBuilder.create<ConstantIndexOp>(loc, 1);
// Insert code computing the value of the original loop induction variable
// from the "normalized" one.
- builder.setInsertionPointToStart(inner.getBody());
Value scaled =
- isStepOne ? loop.getInductionVar()
- : builder.create<MulIOp>(loc, loop.getInductionVar(), step);
+ isStepOne ? inductionVar
+ : insideLoopBuilder.create<MulIOp>(loc, inductionVar, step);
Value shifted =
- isZeroBased ? scaled : builder.create<AddIOp>(loc, scaled, lb);
+ isZeroBased ? scaled
+ : insideLoopBuilder.create<AddIOp>(loc, scaled, lowerBound);
SmallPtrSet<Operation *, 2> preserve{scaled.getDefiningOp(),
shifted.getDefiningOp()};
- replaceAllUsesExcept(loop.getInductionVar(), shifted, preserve);
+ replaceAllUsesExcept(inductionVar, shifted, preserve);
+ return {/*lowerBound=*/newLowerBound, /*upperBound=*/newUpperBound,
+ /*step=*/newStep};
+}
+
+/// Rewrite `loop`, which is assumed to have a strictly positive step,
+///   for %i = %lb to %ub step %s
+/// as a 0-based loop with step 1:
+///   for %ii = 0 to ceildiv(%ub - %lb, %s) step 1 {
+///     %i = %ii * %s + %lb
+/// The induction variable remapping is inserted at the start of the body of
+/// `inner`, expected to be `loop` or a loop perfectly nested under `loop`.
+/// The new bounds are defined immediately before `outer`, which is
+/// expected to be either `loop` or its parent in the loop nest.
static void normalizeLoop(loop::ForOp loop, loop::ForOp outer,
loop::ForOp inner) {
OpBuilder builder(outer);
+ OpBuilder innerBuilder(inner.getBody(), inner.getBody()->begin()); // Inserts at the start of inner's body.
+ auto loopPieces =
+ normalizeLoop(builder, innerBuilder, loop.getLoc(), loop.lowerBound(),
+ loop.upperBound(), loop.step(), loop.getInductionVar());
+
+ loop.setLowerBound(loopPieces.lowerBound); // Install the returned bounds and step on `loop`.
+ loop.setUpperBound(loopPieces.upperBound);
+ loop.setStep(loopPieces.step);
}
void mlir::coalesceLoops(MutableArrayRef<loop::ForOp> loops) {
@@ -1214,6 +1233,86 @@ void mlir::coalesceLoops(MutableArrayRef<loop::ForOp> loops) {
second.erase();
}
+/// Collapse the dimensions of the ParallelOp `loops`. Each entry of
+/// `combinedDimensions` lists the indices of the original dimensions that are
+/// merged, in that order, into one dimension of the replacement ParallelOp;
+/// every entry must be non-empty. The iteration space of `loops` is
+/// normalized, a new ParallelOp iterating over the products of the merged trip
+/// counts is created in front of it, the original induction values are
+/// recovered from the new ones inside the body, and `loops` is erased.
+void mlir::collapsePLoops(loop::ParallelOp loops,
+                          ArrayRef<std::vector<unsigned>> combinedDimensions) {
+  OpBuilder outsideBuilder(loops);
+  Location loc = loops.getLoc();
+
+  // Normalize ParallelOp's iteration pattern. The combined loop always runs
+  // from 0 with step 1, so only the normalized upper bounds (the trip counts)
+  // have to be remembered.
+  SmallVector<Value, 3> normalizedUpperBounds;
+  for (unsigned i = 0, e = loops.getNumLoops(); i < e; ++i) {
+    OpBuilder insideLoopBuilder(loops.getBody(), loops.getBody()->begin());
+    auto resultBounds =
+        normalizeLoop(outsideBuilder, insideLoopBuilder, loc,
+                      loops.lowerBound()[i], loops.upperBound()[i],
+                      loops.step()[i], loops.getBody()->getArgument(i));
+    normalizedUpperBounds.push_back(resultBounds.upperBound);
+  }
+
+  // Combine iteration spaces: the trip count of a collapsed dimension is the
+  // product of the trip counts of the dimensions merged into it.
+  SmallVector<Value, 3> lowerBounds;
+  SmallVector<Value, 3> steps;
+  SmallVector<Value, 3> upperBounds;
+  auto cst0 = outsideBuilder.create<ConstantIndexOp>(loc, 0);
+  auto cst1 = outsideBuilder.create<ConstantIndexOp>(loc, 1);
+  for (unsigned i = 0, e = combinedDimensions.size(); i < e; ++i) {
+    Value newUpperBound = outsideBuilder.create<ConstantIndexOp>(loc, 1);
+    for (auto idx : combinedDimensions[i]) {
+      newUpperBound = outsideBuilder.create<MulIOp>(loc, newUpperBound,
+                                                    normalizedUpperBounds[idx]);
+    }
+    lowerBounds.push_back(cst0);
+    steps.push_back(cst1);
+    upperBounds.push_back(newUpperBound);
+  }
+
+  // Create new ParallelLoop with conversions to the original induction values.
+  // Each new induction value is a linearized index over its merged dimensions.
+  // The remainder by a dimension's trip count gives that dimension's
+  // (normalized) iteration; the quotient by the same trip count is the
+  // linearized index over the dimensions that are still to be extracted. The
+  // normalized values are un-normalized by the code inserted above.
+  auto newPloop = outsideBuilder.create<loop::ParallelOp>(loc, lowerBounds,
+                                                          upperBounds, steps);
+  OpBuilder insideBuilder(newPloop.getBody(), newPloop.getBody()->begin());
+  for (unsigned i = 0, e = combinedDimensions.size(); i < e; ++i) {
+    // An empty group would make `size() - 1` below wrap around.
+    assert(!combinedDimensions[i].empty() &&
+           "expected at least one dimension in every collapsed group");
+    Value previous = newPloop.getBody()->getArgument(i);
+    unsigned numberCombinedDimensions = combinedDimensions[i].size();
+    // Iterate over all except the last induction value.
+    for (unsigned j = 0, last = numberCombinedDimensions - 1; j < last; ++j) {
+      unsigned idx = combinedDimensions[i][j];
+
+      // Determine the current induction value's current loop iteration.
+      Value iv = insideBuilder.create<SignedRemIOp>(loc, previous,
+                                                    normalizedUpperBounds[idx]);
+      replaceAllUsesInRegionWith(loops.getBody()->getArgument(idx), iv,
+                                 loops.region());
+
+      // Remove the effect of the current induction value to prepare for the
+      // next value. The divisor has to be the trip count of the dimension that
+      // was just extracted, i.e. the same value used for the remainder; any
+      // other dimension's trip count yields wrong induction values.
+      previous = insideBuilder.create<SignedDivIOp>(
+          loc, previous, normalizedUpperBounds[idx]);
+    }
+
+    // The final induction value is just the remaining value.
+    unsigned idx = combinedDimensions[i][numberCombinedDimensions - 1];
+    replaceAllUsesInRegionWith(loops.getBody()->getArgument(idx), previous,
+                               loops.region());
+  }
+
+  // Replace the old loop with the new loop: drop the old terminator, then move
+  // the remaining body in front of the new loop's terminator.
+  loops.getBody()->back().erase();
+  newPloop.getBody()->getOperations().splice(
+      Block::iterator(newPloop.getBody()->back()),
+      loops.getBody()->getOperations());
+  loops.erase();
+}
+
void mlir::mapLoopToProcessorIds(loop::ForOp forOp, ArrayRef<Value> processorId,
ArrayRef<Value> numProcessors) {
assert(processorId.size() == numProcessors.size());
diff --git a/mlir/test/Transforms/parallel-loop-collapsing.mlir b/mlir/test/Transforms/parallel-loop-collapsing.mlir
new file mode 100644
index 000000000000..68aa4b0e3782
--- /dev/null
+++ b/mlir/test/Transforms/parallel-loop-collapsing.mlir
@@ -0,0 +1,52 @@
+// RUN: mlir-opt %s -pass-pipeline='func(parallel-loop-collapsing{collapsed-indices-0=0,3 collapsed-indices-1=1,4 collapsed-indices-2=2}, canonicalize)' | FileCheck %s
+
+// CHECK-LABEL: func @parallel_many_dims() {
+func @parallel_many_dims() {
+ // CHECK: [[VAL_0:%.*]] = constant 6 : index
+ // CHECK: [[VAL_1:%.*]] = constant 7 : index
+ // CHECK: [[VAL_2:%.*]] = constant 9 : index
+ // CHECK: [[VAL_3:%.*]] = constant 10 : index
+ // CHECK: [[VAL_4:%.*]] = constant 12 : index
+ // CHECK: [[VAL_5:%.*]] = constant 13 : index
+ // CHECK: [[VAL_6:%.*]] = constant 3 : index
+ // CHECK: [[VAL_7:%.*]] = constant 0 : index
+ // CHECK: [[VAL_8:%.*]] = constant 1 : index
+ // CHECK: [[VAL_9:%.*]] = constant 2 : index
+ %c0 = constant 0 : index
+ %c1 = constant 1 : index
+ %c2 = constant 2 : index
+ %c3 = constant 3 : index
+ %c4 = constant 4 : index
+ %c5 = constant 5 : index
+ %c6 = constant 6 : index
+ %c7 = constant 7 : index
+ %c8 = constant 8 : index
+ %c9 = constant 9 : index
+ %c10 = constant 10 : index
+ %c11 = constant 11 : index
+ %c12 = constant 12 : index
+ %c13 = constant 13 : index
+ %c14 = constant 14 : index
+
+ // CHECK: loop.parallel ([[VAL_10:%.*]], [[VAL_11:%.*]], [[VAL_12:%.*]]) = ([[VAL_7]], [[VAL_7]], [[VAL_7]]) to ([[VAL_9]], [[VAL_8]], [[VAL_8]]) step ([[VAL_8]], [[VAL_8]], [[VAL_8]]) {
+ loop.parallel (%i0, %i1, %i2, %i3, %i4) = (%c0, %c3, %c6, %c9, %c12) to (%c2, %c5, %c8, %c11, %c14)
+ step (%c1, %c4, %c7, %c10, %c13) {
+ // CHECK: [[VAL_13:%.*]] = remi_signed [[VAL_10]], [[VAL_9]] : index
+ // CHECK: [[VAL_14:%.*]] = divi_signed [[VAL_10]], [[VAL_8]] : index
+ // CHECK: [[VAL_15:%.*]] = divi_signed [[VAL_11]], [[VAL_8]] : index
+ // CHECK: [[VAL_16:%.*]] = muli [[VAL_15]], [[VAL_5]] : index
+ // CHECK: [[VAL_17:%.*]] = addi [[VAL_16]], [[VAL_4]] : index
+ // CHECK: [[VAL_18:%.*]] = muli [[VAL_14]], [[VAL_3]] : index
+ // CHECK: [[VAL_19:%.*]] = addi [[VAL_18]], [[VAL_2]] : index
+ // CHECK: [[VAL_20:%.*]] = muli [[VAL_12]], [[VAL_1]] : index
+ // CHECK: [[VAL_21:%.*]] = addi [[VAL_20]], [[VAL_0]] : index
+ // CHECK: [[VAL_22:%.*]] = "magic.op"([[VAL_13]], [[VAL_6]], [[VAL_21]], [[VAL_19]], [[VAL_17]]) : (index, index, index, index, index) -> index
+ %result = "magic.op"(%i0, %i1, %i2, %i3, %i4): (index, index, index, index, index) -> index
+ }
+ return
+}
+// CHECK: loop.yield
+// CHECK: }
+// CHECK: return
+// CHECK: }
+
diff --git a/mlir/test/Transforms/single-parallel-loop-collapsing.mlir b/mlir/test/Transforms/single-parallel-loop-collapsing.mlir
new file mode 100644
index 000000000000..0b05c61a62d1
--- /dev/null
+++ b/mlir/test/Transforms/single-parallel-loop-collapsing.mlir
@@ -0,0 +1,36 @@
+// RUN: mlir-opt %s -pass-pipeline='func(parallel-loop-collapsing{collapsed-indices-0=0,1}, canonicalize)' | FileCheck %s
+
+// CHECK-LABEL: func @collapse_to_single() {
+func @collapse_to_single() {
+ // CHECK: [[VAL_0:%.*]] = constant 7 : index
+ // CHECK: [[VAL_1:%.*]] = constant 4 : index
+ // CHECK: [[VAL_2:%.*]] = constant 18 : index
+ // CHECK: [[VAL_3:%.*]] = constant 3 : index
+ // CHECK: [[VAL_4:%.*]] = constant 6 : index
+ // CHECK: [[VAL_5:%.*]] = constant 0 : index
+ // CHECK: [[VAL_6:%.*]] = constant 1 : index
+ %c0 = constant 3 : index
+ %c1 = constant 7 : index
+ %c2 = constant 11 : index
+ %c3 = constant 29 : index
+ %c4 = constant 3 : index
+ %c5 = constant 4 : index
+ // CHECK: loop.parallel ([[VAL_7:%.*]]) = ([[VAL_5]]) to ([[VAL_2]]) step ([[VAL_6]]) {
+ loop.parallel (%i0, %i1) = (%c0, %c1) to (%c2, %c3) step (%c4, %c5) {
+ // CHECK: [[VAL_8:%.*]] = remi_signed [[VAL_7]], [[VAL_3]] : index
+ // CHECK: [[VAL_9:%.*]] = divi_signed [[VAL_7]], [[VAL_4]] : index
+ // CHECK: [[VAL_10:%.*]] = muli [[VAL_9]], [[VAL_1]] : index
+ // CHECK: [[VAL_11:%.*]] = addi [[VAL_10]], [[VAL_0]] : index
+ // CHECK: [[VAL_12:%.*]] = muli [[VAL_8]], [[VAL_3]] : index
+ // CHECK: [[VAL_13:%.*]] = addi [[VAL_12]], [[VAL_3]] : index
+ // CHECK: [[VAL_14:%.*]] = "magic.op"([[VAL_13]], [[VAL_11]]) : (index, index) -> index
+ %result = "magic.op"(%i0, %i1): (index, index) -> index
+ }
+ return
+}
+// CHECK: loop.yield
+// CHECK: }
+// CHECK: return
+// CHECK: }
+
+
More information about the Mlir-commits
mailing list