[Mlir-commits] [mlir] 5c4e397 - [mlir][sparse] add parallelization strategies to sparse compiler
Aart Bik
llvmlistbot at llvm.org
Tue Nov 24 17:17:34 PST 2020
Author: Aart Bik
Date: 2020-11-24T17:17:13-08:00
New Revision: 5c4e397e6ce5c89d63f590857e5cb0e80237de62
URL: https://github.com/llvm/llvm-project/commit/5c4e397e6ce5c89d63f590857e5cb0e80237de62
DIFF: https://github.com/llvm/llvm-project/commit/5c4e397e6ce5c89d63f590857e5cb0e80237de62.diff
LOG: [mlir][sparse] add parallelization strategies to sparse compiler
This CL adds the ability to request different parallelization strategies
for the generated code. Every "parallel" loop is a candidate, and is converted
to a parallel op if it is an actual for-loop (not a while-loop) and the strategy
allows dense/sparse outer/inner parallelization.
This will connect directly with the work of @ezhulenev on parallel loops.
Still TBD: vectorization strategy
Reviewed By: penpornk
Differential Revision: https://reviews.llvm.org/D91978
Added:
mlir/test/Dialect/Linalg/sparse_parallel.mlir
Modified:
mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
mlir/lib/Dialect/Linalg/Transforms/Sparsification.cpp
mlir/test/lib/Transforms/TestSparsification.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index d67e81ceab87..7a96a5ed1390 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -783,9 +783,62 @@ LogicalResult applyStagedPatterns(
//===----------------------------------------------------------------------===//
// Support for sparse tensor code generation.
+//
+// The sparse compiler part of MLIR lowers a tensor expression formulated as a
+// Linalg operation into a sequence of loops depending on what dimensions of the
+// tensors are marked dense or sparse. The generated code distinguishes between:
+// (1) for-loops that iterate over a single dense dimension,
+// (2) for-loops that iterate over a single sparse dimension,
+// (3) while-loops that co-iterate over several sparse dimensions.
+// The for-loops may be subsequently optimized for parallel or vector execution.
+//
+// For more details, see the Dialect/Linalg/Transforms/Sparsification.cpp file.
//===----------------------------------------------------------------------===//
-void populateSparsificationPatterns(MLIRContext *context,
- OwningRewritePatternList &patterns);
+
+/// Defines a parallelization strategy. Any implicit loop in the Linalg
+/// operation that is marked "parallel" (thus not "reduction") is a candidate
+/// for parallelization. The loop is made parallel if (1) allowed by the
+/// strategy (e.g., AnyStorageOuterLoop considers either a dense or sparse
+/// outermost loop only), and (2) the generated code is an actual for-loop
+/// (and not a co-iterating while-loop).
+enum class SparseParallelizationStrategy {
+ kNone,
+ kDenseOuterLoop,
+ kAnyStorageOuterLoop,
+ kDenseAnyLoop,
+ kAnyStorageAnyLoop
+ // TODO: support reduction parallelization too?
+};
+
+/// Defines a vectorization strategy. Any implicit inner loop in the Linalg
+/// operation is a candidate (full SIMD for "parallel" loops and horizontal
+/// SIMD for "reduction" loops). A loop is actually vectorized if (1) allowed
+/// by the strategy, and (2) the emitted code is an actual for-loop (and not
+/// a co-iterating while-loop).
+enum class SparseVectorizationStrategy {
+ kNone,
+ kDenseInnerLoop,
+ kAnyStorageInnerLoop
+};
+
+/// Sparsification options.
+struct SparsificationOptions {
+ SparsificationOptions(SparseParallelizationStrategy p,
+ SparseVectorizationStrategy v, unsigned vl)
+ : parallelizationStrategy(p), vectorizationStrategy(v), vectorLength(vl) {
+ }
+ SparsificationOptions()
+ : SparsificationOptions(SparseParallelizationStrategy::kNone,
+ SparseVectorizationStrategy::kNone, 1u) {}
+ SparseParallelizationStrategy parallelizationStrategy;
+ SparseVectorizationStrategy vectorizationStrategy;
+ unsigned vectorLength;
+};
+
+/// Set up sparsification rewriting rules with the given options.
+void populateSparsificationPatterns(
+ MLIRContext *context, OwningRewritePatternList &patterns,
+ const SparsificationOptions &options = SparsificationOptions());
} // namespace linalg
} // namespace mlir
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Sparsification.cpp b/mlir/lib/Dialect/Linalg/Transforms/Sparsification.cpp
index 69a4d7e5648e..729268393ed9 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Sparsification.cpp
@@ -235,22 +235,30 @@ class Merger {
// Code generation.
struct CodeGen {
- CodeGen(unsigned numTensors, unsigned numLoops)
- : loops(numLoops), sizes(numLoops), buffers(numTensors),
+ CodeGen(linalg::SparsificationOptions o, unsigned numTensors,
+ unsigned numLoops)
+ : options(o), loops(numLoops), sizes(numLoops), buffers(numTensors),
pointers(numTensors, std::vector<Value>(numLoops)),
indices(numTensors, std::vector<Value>(numLoops)),
highs(numTensors, std::vector<Value>(numLoops)),
pidxs(numTensors, std::vector<Value>(numLoops)),
idxs(numTensors, std::vector<Value>(numLoops)) {}
- // Universal dense indices and upper bounds (by index).
+ // Sparsification options.
+ linalg::SparsificationOptions options;
+ // Universal dense indices and upper bounds (by index). The loops array
+ // is updated with the value of the universal dense index in the current
+ // loop. The sizes array is set once with the inferred dimension sizes.
std::vector<Value> loops;
std::vector<Value> sizes;
// Buffers for storing dense and sparse numerical values (by tensor).
+ // This array is set once during bufferization of all tensors.
std::vector<Value> buffers;
// Sparse storage schemes (1-D): pointers and indices (by tensor and index).
+ // This array is set once during bufferization of all sparse tensors.
std::vector<std::vector<Value>> pointers;
std::vector<std::vector<Value>> indices;
- // Sparse iteration information (by tensor and index).
+ // Sparse iteration information (by tensor and index). These arrays
+ // are updated to remain current within the current loop.
std::vector<std::vector<Value>> highs;
std::vector<std::vector<Value>> pidxs;
std::vector<std::vector<Value>> idxs;
@@ -388,7 +396,7 @@ static unsigned buildLattices(Merger &merger, linalg::GenericOp op,
unsigned exp, unsigned idx) {
Kind kind = merger.exp(exp).kind;
if (kind == Kind::kTensor || kind == Kind::kInvariant) {
- // Either the index is really used in the tensor expression, or it it
+ // Either the index is really used in the tensor expression, or it is
// set to the "non-existing dense index" in that dimension. Invariant
// expressions borrow the output tensor indices.
unsigned s = merger.addSet();
@@ -573,38 +581,81 @@ static bool genInit(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
return needsUniv;
}
-/// Generates a for-loop or a while-loop, depending on whether it implements
-/// singleton iteration or co-iteration over the given conjunction.
-static void genLoop(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
- linalg::GenericOp op, unsigned idx, bool needsUniv,
- llvm::BitVector &indices, scf::ForOp &forOp,
- scf::WhileOp &whileOp) {
+/// Generates a for-loop on a single index: scf.parallel if allowed, else scf.for.
+static Operation *genFor(Merger &merger, CodeGen &codegen,
+ PatternRewriter &rewriter, linalg::GenericOp op,
+ bool isOuter, unsigned idx, llvm::BitVector &indices) {
+ unsigned fb = indices.find_first();
+ unsigned tensor = merger.tensor(fb);
+ assert(idx == merger.index(fb));
+
+ // Parallelization strategy. Any implicit loop in the Linalg operation that
+ // is marked "parallel" is a candidate. Whether it is actually converted to
+ // a parallel operation depends on the requested strategy.
+ auto iteratorTypes = op.iterator_types().getValue();
+ bool isSparse = merger.isSparseBit(fb);
+ bool isParallel = linalg::isParallelIteratorType(iteratorTypes[idx]);
+ switch (codegen.options.parallelizationStrategy) {
+ case linalg::SparseParallelizationStrategy::kNone:
+ isParallel = false;
+ break;
+ case linalg::SparseParallelizationStrategy::kDenseOuterLoop:
+ isParallel &= isOuter && !isSparse;
+ break;
+ case linalg::SparseParallelizationStrategy::kAnyStorageOuterLoop:
+ isParallel &= isOuter;
+ break;
+ case linalg::SparseParallelizationStrategy::kDenseAnyLoop:
+ isParallel &= !isSparse;
+ break;
+ case linalg::SparseParallelizationStrategy::kAnyStorageAnyLoop:
+ break;
+ }
+
+ // Loop bounds (pidxs..highs if sparse, loops..sizes if dense) and unit step.
Location loc = op.getLoc();
+ Value lo;
+ Value hi;
+ Value step = rewriter.create<ConstantIndexOp>(loc, 1);
+ Value index; // NOTE(review): never read below in this function; appears unused.
+ if (isSparse) {
+ lo = codegen.pidxs[tensor][idx];
+ hi = codegen.highs[tensor][idx];
+ } else {
+ lo = codegen.loops[idx];
+ hi = codegen.sizes[idx];
+ }
+
+ // Emit a parallel loop; its induction variable becomes the current index.
+ if (isParallel) {
+ scf::ParallelOp parOp = rewriter.create<scf::ParallelOp>(loc, lo, hi, step);
+ if (isSparse)
+ codegen.pidxs[tensor][idx] = parOp.getInductionVars()[0];
+ else
+ codegen.loops[idx] = parOp.getInductionVars()[0];
+ rewriter.setInsertionPointToStart(parOp.getBody());
+ return parOp;
+ }
+
+ // Emit a sequential loop; its induction variable becomes the current index.
+ scf::ForOp forOp = rewriter.create<scf::ForOp>(loc, lo, hi, step);
+ if (isSparse)
+ codegen.pidxs[tensor][idx] = forOp.getInductionVar();
+ else
+ codegen.loops[idx] = forOp.getInductionVar();
+ rewriter.setInsertionPointToStart(forOp.getBody());
+ return forOp;
+}
- // Emit a for-loop for a single index.
- if (indices.count() == 1) {
- unsigned fb = indices.find_first();
- unsigned tensor = merger.tensor(fb);
- assert(idx == merger.index(fb));
- // Emit a sparse for-loop or a dense for-loop.
- Value one = rewriter.create<ConstantIndexOp>(loc, 1);
- if (merger.isSparseBit(fb)) {
- forOp = rewriter.create<scf::ForOp>(loc, codegen.pidxs[tensor][idx],
- codegen.highs[tensor][idx], one);
- codegen.pidxs[tensor][idx] = forOp.getInductionVar();
- } else {
- forOp = rewriter.create<scf::ForOp>(loc, codegen.loops[idx],
- codegen.sizes[idx], one);
- codegen.loops[idx] = forOp.getInductionVar();
- }
- rewriter.setInsertionPointToStart(forOp.getBody());
- return;
- }
-
- // Otherwise, emit a while-loop for co-iteration.
- Type indexType = rewriter.getIndexType();
+/// Emit a while-loop for co-iteration over multiple indices.
+static Operation *genWhile(Merger &merger, CodeGen &codegen,
+ PatternRewriter &rewriter, linalg::GenericOp op,
+ unsigned idx, bool needsUniv,
+ llvm::BitVector &indices) {
SmallVector<Type, 4> types;
SmallVector<Value, 4> operands;
+ // Construct the while-loop with a parameter for each index.
+ Type indexType = rewriter.getIndexType();
for (unsigned b = 0, be = indices.size(); b < be; b++) {
if (indices[b] && merger.isSparseBit(b)) {
unsigned tensor = merger.tensor(b);
@@ -617,9 +668,11 @@ static void genLoop(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
types.push_back(indexType);
operands.push_back(codegen.loops[idx]);
}
- whileOp = rewriter.create<scf::WhileOp>(loc, types, operands);
+ Location loc = op.getLoc();
+ scf::WhileOp whileOp = rewriter.create<scf::WhileOp>(loc, types, operands);
Block *before = rewriter.createBlock(&whileOp.before(), {}, types);
Block *after = rewriter.createBlock(&whileOp.after(), {}, types);
+
// Build the "before" region, which effectively consists
// of a conjunction of "i < upper" tests on all induction.
rewriter.setInsertionPointToStart(&whileOp.before().front());
@@ -641,6 +694,18 @@ static void genLoop(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
assert(o == operands.size());
rewriter.create<scf::ConditionOp>(loc, cond, before->getArguments());
rewriter.setInsertionPointToStart(&whileOp.after().front());
+ return whileOp;
+}
+
+/// Generates a for-loop (possibly parallel) for singleton iteration, or a while-loop
+/// for co-iteration over the given conjunction; returns the generated loop op.
+static Operation *genLoop(Merger &merger, CodeGen &codegen,
+ PatternRewriter &rewriter, linalg::GenericOp op,
+ bool isOuter, unsigned idx, bool needsUniv,
+ llvm::BitVector &indices) {
+ if (indices.count() == 1)
+ return genFor(merger, codegen, rewriter, op, isOuter, idx, indices);
+ return genWhile(merger, codegen, rewriter, op, idx, needsUniv, indices);
}
/// Generates the local variables for this loop, consisting of the sparse
@@ -804,16 +869,16 @@ static void genStmt(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
LatPoint lati = merger.lat(li);
// Emit loop.
- scf::ForOp forOp;
- scf::WhileOp whileOp;
llvm::BitVector indices = lati.bits;
optimizeIndices(merger, lsize, indices);
- genLoop(merger, codegen, rewriter, op, idx, needsUniv, indices, forOp,
- whileOp);
+ bool isOuter = at == 0;
+ Operation *loop = genLoop(merger, codegen, rewriter, op, isOuter, idx,
+ needsUniv, indices);
genLocals(merger, codegen, rewriter, op, topSort, at, needsUniv, lati.bits);
// Visit all lattices points with Li >= Lj to generate the
// loop-body, possibly with if statements for coiteration.
+ bool isWhile = dyn_cast<scf::WhileOp>(loop) != nullptr;
scf::IfOp ifOp;
for (unsigned lj : merger.set(lts)) {
if (li == lj || merger.latGT(li, lj)) {
@@ -823,22 +888,22 @@ static void genStmt(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
if (merger.hasAnyOf(tmp, false))
continue; // dense exhausted within if/else
// Recurse into body of each branch.
- if (whileOp)
+ if (isWhile)
genIf(merger, codegen, rewriter, op, idx, latj.bits, ifOp);
genStmt(merger, codegen, rewriter, op, topSort, latj.exp, at + 1);
}
}
// Wrap-up induction and restore insertion point.
- if (forOp) {
- needsUniv = false;
- rewriter.setInsertionPointAfter(forOp);
- } else {
+ if (isWhile) {
+ scf::WhileOp whileOp = cast<scf::WhileOp>(loop);
rewriter.setInsertionPointToEnd(&whileOp.after().front());
genWhileInduction(merger, codegen, rewriter, op, idx, needsUniv,
lati.bits, whileOp.results());
- rewriter.setInsertionPointAfter(whileOp);
+ } else {
+ needsUniv = false;
}
+ rewriter.setInsertionPointAfter(loop);
}
}
@@ -846,7 +911,9 @@ namespace {
/// Sparse rewriting rule for generic Lingalg operation.
struct GenericOpSparsifier : public OpRewritePattern<linalg::GenericOp> {
- using OpRewritePattern<linalg::GenericOp>::OpRewritePattern;
+public:
+ GenericOpSparsifier(MLIRContext *context, linalg::SparsificationOptions o)
+ : OpRewritePattern<linalg::GenericOp>(context), options(o) {}
LogicalResult matchAndRewrite(linalg::GenericOp op,
PatternRewriter &rewriter) const override {
@@ -878,7 +945,7 @@ struct GenericOpSparsifier : public OpRewritePattern<linalg::GenericOp> {
return failure(); // build failure
// Recursively generates code.
- CodeGen codegen(numTensors, numLoops);
+ CodeGen codegen(options, numTensors, numLoops);
genBuffers(merger, codegen, rewriter, op);
genStmt(merger, codegen, rewriter, op, topSort, exp.getValue(), 0);
Value result =
@@ -886,13 +953,18 @@ struct GenericOpSparsifier : public OpRewritePattern<linalg::GenericOp> {
rewriter.replaceOp(op, result);
return success();
}
+
+private:
+ /// Options to control sparse code generation.
+ linalg::SparsificationOptions options;
};
} // namespace
/// Populates the given patterns list with rewriting rules required for
/// the sparsification of linear algebra operations.
-void mlir::linalg::populateSparsificationPatterns(
- MLIRContext *context, OwningRewritePatternList &patterns) {
- patterns.insert<GenericOpSparsifier>(context);
+void linalg::populateSparsificationPatterns(
+ MLIRContext *context, OwningRewritePatternList &patterns,
+ const SparsificationOptions &options) {
+ patterns.insert<GenericOpSparsifier>(context, options);
}
diff --git a/mlir/test/Dialect/Linalg/sparse_parallel.mlir b/mlir/test/Dialect/Linalg/sparse_parallel.mlir
new file mode 100644
index 000000000000..a75406fbab69
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/sparse_parallel.mlir
@@ -0,0 +1,161 @@
+// RUN: mlir-opt %s -test-sparsification="parallelization-strategy=0" | \
+// RUN: FileCheck %s --check-prefix=CHECK-PAR0
+// RUN: mlir-opt %s -test-sparsification="parallelization-strategy=1" | \
+// RUN: FileCheck %s --check-prefix=CHECK-PAR1
+// RUN: mlir-opt %s -test-sparsification="parallelization-strategy=2" | \
+// RUN: FileCheck %s --check-prefix=CHECK-PAR2
+// RUN: mlir-opt %s -test-sparsification="parallelization-strategy=3" | \
+// RUN: FileCheck %s --check-prefix=CHECK-PAR3
+// RUN: mlir-opt %s -test-sparsification="parallelization-strategy=4" | \
+// RUN: FileCheck %s --check-prefix=CHECK-PAR4
+
+#trait_dd = {
+ indexing_maps = [
+ affine_map<(i,j) -> (i,j)>, // A
+ affine_map<(i,j) -> (i,j)> // X (out)
+ ],
+ sparse = [
+ [ "D", "D" ], // A
+ [ "D", "D" ] // X
+ ],
+ iterator_types = ["parallel", "parallel"],
+ doc = "X(i,j) = A(i,j) * SCALE"
+}
+
+//
+// CHECK-PAR0-LABEL: func @scale_dd
+// CHECK-PAR0: scf.for
+// CHECK-PAR0: scf.for
+// CHECK-PAR0: return
+//
+// CHECK-PAR1-LABEL: func @scale_dd
+// CHECK-PAR1: scf.parallel
+// CHECK-PAR1: scf.for
+// CHECK-PAR1: return
+//
+// CHECK-PAR2-LABEL: func @scale_dd
+// CHECK-PAR2: scf.parallel
+// CHECK-PAR2: scf.for
+// CHECK-PAR2: return
+//
+// CHECK-PAR3-LABEL: func @scale_dd
+// CHECK-PAR3: scf.parallel
+// CHECK-PAR3: scf.parallel
+// CHECK-PAR3: return
+//
+// CHECK-PAR4-LABEL: func @scale_dd
+// CHECK-PAR4: scf.parallel
+// CHECK-PAR4: scf.parallel
+// CHECK-PAR4: return
+//
+func @scale_dd(%scale: f32, %arga: tensor<?x?xf32>) -> tensor<?x?xf32> {
+ %0 = linalg.generic #trait_dd
+ ins(%arga: tensor<?x?xf32>) {
+ ^bb(%a: f32):
+ %0 = mulf %a, %scale : f32
+ linalg.yield %0 : f32
+ } -> tensor<?x?xf32>
+ return %0 : tensor<?x?xf32>
+}
+
+#trait_ss = {
+ indexing_maps = [
+ affine_map<(i,j) -> (i,j)>, // A
+ affine_map<(i,j) -> (i,j)> // X (out)
+ ],
+ sparse = [
+ [ "S", "S" ], // A
+ [ "D", "D" ] // X
+ ],
+ iterator_types = ["parallel", "parallel"],
+ doc = "X(i,j) = A(i,j) * SCALE"
+}
+
+//
+// CHECK-PAR0-LABEL: func @scale_ss
+// CHECK-PAR0: scf.for
+// CHECK-PAR0: scf.for
+// CHECK-PAR0: return
+//
+// CHECK-PAR1-LABEL: func @scale_ss
+// CHECK-PAR1: scf.for
+// CHECK-PAR1: scf.for
+// CHECK-PAR1: return
+//
+// CHECK-PAR2-LABEL: func @scale_ss
+// CHECK-PAR2: scf.parallel
+// CHECK-PAR2: scf.for
+// CHECK-PAR2: return
+//
+// CHECK-PAR3-LABEL: func @scale_ss
+// CHECK-PAR3: scf.for
+// CHECK-PAR3: scf.for
+// CHECK-PAR3: return
+//
+// CHECK-PAR4-LABEL: func @scale_ss
+// CHECK-PAR4: scf.parallel
+// CHECK-PAR4: scf.parallel
+// CHECK-PAR4: return
+//
+func @scale_ss(%scale: f32, %arga: tensor<?x?xf32>) -> tensor<?x?xf32> {
+ %0 = linalg.generic #trait_ss
+ ins(%arga: tensor<?x?xf32>) {
+ ^bb(%a: f32):
+ %0 = mulf %a, %scale : f32
+ linalg.yield %0 : f32
+ } -> tensor<?x?xf32>
+ return %0 : tensor<?x?xf32>
+}
+
+#trait_matvec = {
+ indexing_maps = [
+ affine_map<(i,j) -> (i,j)>, // A
+ affine_map<(i,j) -> (j)>, // b
+ affine_map<(i,j) -> (i)> // x (out)
+ ],
+ sparse = [
+ [ "D", "S" ], // A
+ [ "D" ], // b
+ [ "D" ] // x
+ ],
+ iterator_types = ["parallel", "reduction"],
+ doc = "x(i) += A(i,j) * b(j)"
+}
+
+//
+// CHECK-PAR0-LABEL: func @matvec
+// CHECK-PAR0: scf.for
+// CHECK-PAR0: scf.for
+// CHECK-PAR0: return
+//
+// CHECK-PAR1-LABEL: func @matvec
+// CHECK-PAR1: scf.parallel
+// CHECK-PAR1: scf.for
+// CHECK-PAR1: return
+//
+// CHECK-PAR2-LABEL: func @matvec
+// CHECK-PAR2: scf.parallel
+// CHECK-PAR2: scf.for
+// CHECK-PAR2: return
+//
+// CHECK-PAR3-LABEL: func @matvec
+// CHECK-PAR3: scf.parallel
+// CHECK-PAR3: scf.for
+// CHECK-PAR3: return
+//
+// CHECK-PAR4-LABEL: func @matvec
+// CHECK-PAR4: scf.parallel
+// CHECK-PAR4: scf.for
+// CHECK-PAR4: return
+//
+func @matvec(%argA: tensor<16x32xf32>, %argb: tensor<32xf32>, %argx: tensor<16xf32>) -> tensor<16xf32> {
+ %0 = linalg.generic #trait_matvec
+ ins(%argA, %argb : tensor<16x32xf32>, tensor<32xf32>)
+ init(%argx : tensor<16xf32>) {
+ ^bb(%A: f32, %b: f32, %x: f32):
+ %0 = mulf %A, %b : f32
+ %1 = addf %0, %x : f32
+ linalg.yield %1 : f32
+ } -> tensor<16xf32>
+ return %0 : tensor<16xf32>
+}
diff --git a/mlir/test/lib/Transforms/TestSparsification.cpp b/mlir/test/lib/Transforms/TestSparsification.cpp
index 038d7fab0656..7544e48174b8 100644
--- a/mlir/test/lib/Transforms/TestSparsification.cpp
+++ b/mlir/test/lib/Transforms/TestSparsification.cpp
@@ -16,13 +16,63 @@ namespace {
struct TestSparsification
: public PassWrapper<TestSparsification, FunctionPass> {
+
+ TestSparsification() = default;
+ TestSparsification(const TestSparsification &pass) {}
+
+ Option<int32_t> parallelization{
+ *this, "parallelization-strategy",
+ llvm::cl::desc("Set the parallelization strategy"), llvm::cl::init(0)};
+
+ Option<int32_t> vectorization{
+ *this, "vectorization-strategy",
+ llvm::cl::desc("Set the vectorization strategy"), llvm::cl::init(0)};
+
+ Option<int32_t> vectorLength{
+ *this, "vl", llvm::cl::desc("Set the vector length"), llvm::cl::init(1)};
+
+ /// Registers all dialects required by testing.
void getDependentDialects(DialectRegistry ®istry) const override {
- registry.insert<scf::SCFDialect>();
+ registry.insert<scf::SCFDialect, vector::VectorDialect>();
}
+
+ /// Returns parallelization strategy given on command line.
+ linalg::SparseParallelizationStrategy parallelOption() {
+ switch (parallelization) {
+ default:
+ return linalg::SparseParallelizationStrategy::kNone;
+ case 1:
+ return linalg::SparseParallelizationStrategy::kDenseOuterLoop;
+ case 2:
+ return linalg::SparseParallelizationStrategy::kAnyStorageOuterLoop;
+ case 3:
+ return linalg::SparseParallelizationStrategy::kDenseAnyLoop;
+ case 4:
+ return linalg::SparseParallelizationStrategy::kAnyStorageAnyLoop;
+ }
+ }
+
+ /// Returns vectorization strategy given on command line.
+ linalg::SparseVectorizationStrategy vectorOption() {
+ switch (vectorization) {
+ default:
+ return linalg::SparseVectorizationStrategy::kNone;
+ case 1:
+ return linalg::SparseVectorizationStrategy::kDenseInnerLoop;
+ case 2:
+ return linalg::SparseVectorizationStrategy::kAnyStorageInnerLoop;
+ }
+ }
+
+ /// Runs the test on a function.
void runOnFunction() override {
auto *ctx = &getContext();
OwningRewritePatternList patterns;
- linalg::populateSparsificationPatterns(ctx, patterns);
+ // Translate strategy flags to strategy options.
+ linalg::SparsificationOptions options(parallelOption(), vectorOption(),
+ vectorLength);
+ // Apply rewriting.
+ linalg::populateSparsificationPatterns(ctx, patterns, options);
applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
}
};
More information about the Mlir-commits
mailing list