[Mlir-commits] [mlir] 5c4e397 - [mlir][sparse] add parallelization strategies to sparse compiler

Aart Bik llvmlistbot at llvm.org
Tue Nov 24 17:17:34 PST 2020


Author: Aart Bik
Date: 2020-11-24T17:17:13-08:00
New Revision: 5c4e397e6ce5c89d63f590857e5cb0e80237de62

URL: https://github.com/llvm/llvm-project/commit/5c4e397e6ce5c89d63f590857e5cb0e80237de62
DIFF: https://github.com/llvm/llvm-project/commit/5c4e397e6ce5c89d63f590857e5cb0e80237de62.diff

LOG: [mlir][sparse] add parallelization strategies to sparse compiler

This CL adds the ability to request different parallelization strategies
for the generated code. Every "parallel" loop is a candidate, and is converted
to a parallel op if it is an actual for-loop (not a while-loop) and the strategy
allows dense/sparse outer/inner parallelization.

This will connect directly with the work of @ezhulenev on parallel loops.

Still TBD: vectorization strategy

Reviewed By: penpornk

Differential Revision: https://reviews.llvm.org/D91978

Added: 
    mlir/test/Dialect/Linalg/sparse_parallel.mlir

Modified: 
    mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
    mlir/lib/Dialect/Linalg/Transforms/Sparsification.cpp
    mlir/test/lib/Transforms/TestSparsification.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index d67e81ceab87..7a96a5ed1390 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -783,9 +783,62 @@ LogicalResult applyStagedPatterns(
 
 //===----------------------------------------------------------------------===//
 // Support for sparse tensor code generation.
+//
+// The sparse compiler part of MLIR lowers a tensor expression formulated as a
+// Linalg operation into a sequence of loops depending on what dimensions of the
+// tensors are marked dense or sparse. The generated code distinguishes between:
+// (1) for-loops that iterate over a single dense dimension,
+// (2) for-loops that iterate over a single sparse dimension,
+// (3) while-loops that co-iterate over several sparse dimensions.
+// The for-loops may be subsequently optimized for parallel or vector execution.
+//
+// For more details, see the Dialect/Linalg/Transforms/Sparsification.cpp file.
 //===----------------------------------------------------------------------===//
-void populateSparsificationPatterns(MLIRContext *context,
-                                    OwningRewritePatternList &patterns);
+
+/// Defines a parallelization strategy. Any implicit loop in the Linalg
+/// operation that is marked "parallel" (thus not "reduction") is a candidate
+/// for parallelization. The loop is made parallel if (1) allowed by the
+/// strategy (e.g., AnyStorageOuterLoop considers either a dense or sparse
+/// outermost loop only), and (2) the generated code is an actual for-loop
+/// (and not a co-iterating while-loop).
+enum class SparseParallelizationStrategy {
+  kNone,
+  kDenseOuterLoop,
+  kAnyStorageOuterLoop,
+  kDenseAnyLoop,
+  kAnyStorageAnyLoop
+  // TODO: support reduction parallelization too?
+};
+
+/// Defines a vectorization strategy. Any implicit inner loop in the Linalg
+/// operation is a candidate (full SIMD for "parallel" loops and horizontal
+/// SIMD for "reduction" loops). A loop is actually vectorized if (1) allowed
+/// by the strategy, and (2) the emitted code is an actual for-loop (and not
+/// a co-iterating while-loop).
+enum class SparseVectorizationStrategy {
+  kNone,
+  kDenseInnerLoop,
+  kAnyStorageInnerLoop
+};
+
+/// Sparsification options.
+struct SparsificationOptions {
+  SparsificationOptions(SparseParallelizationStrategy p,
+                        SparseVectorizationStrategy v, unsigned vl)
+      : parallelizationStrategy(p), vectorizationStrategy(v), vectorLength(vl) {
+  }
+  SparsificationOptions()
+      : SparsificationOptions(SparseParallelizationStrategy::kNone,
+                              SparseVectorizationStrategy::kNone, 1u) {}
+  SparseParallelizationStrategy parallelizationStrategy;
+  SparseVectorizationStrategy vectorizationStrategy;
+  unsigned vectorLength;
+};
+
+/// Set up sparsification rewriting rules with the given options.
+void populateSparsificationPatterns(
+    MLIRContext *context, OwningRewritePatternList &patterns,
+    const SparsificationOptions &options = SparsificationOptions());
 
 } // namespace linalg
 } // namespace mlir

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Sparsification.cpp b/mlir/lib/Dialect/Linalg/Transforms/Sparsification.cpp
index 69a4d7e5648e..729268393ed9 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Sparsification.cpp
@@ -235,22 +235,30 @@ class Merger {
 
 // Code generation.
 struct CodeGen {
-  CodeGen(unsigned numTensors, unsigned numLoops)
-      : loops(numLoops), sizes(numLoops), buffers(numTensors),
+  CodeGen(linalg::SparsificationOptions o, unsigned numTensors,
+          unsigned numLoops)
+      : options(o), loops(numLoops), sizes(numLoops), buffers(numTensors),
         pointers(numTensors, std::vector<Value>(numLoops)),
         indices(numTensors, std::vector<Value>(numLoops)),
         highs(numTensors, std::vector<Value>(numLoops)),
         pidxs(numTensors, std::vector<Value>(numLoops)),
         idxs(numTensors, std::vector<Value>(numLoops)) {}
-  // Universal dense indices and upper bounds (by index).
+  // Sparsification options.
+  linalg::SparsificationOptions options;
+  // Universal dense indices and upper bounds (by index). The loops array
+  // is updated with the value of the universal dense index in the current
+  // loop. The sizes array is set once with the inferred dimension sizes.
   std::vector<Value> loops;
   std::vector<Value> sizes;
   // Buffers for storing dense and sparse numerical values (by tensor).
+  // This array is set once during bufferization of all tensors.
   std::vector<Value> buffers;
   // Sparse storage schemes (1-D): pointers and indices (by tensor and index).
+  // This array is set once during bufferization of all sparse tensors.
   std::vector<std::vector<Value>> pointers;
   std::vector<std::vector<Value>> indices;
-  // Sparse iteration information (by tensor and index).
+  // Sparse iteration information (by tensor and index). These arrays
+  // are updated to remain current within the current loop.
   std::vector<std::vector<Value>> highs;
   std::vector<std::vector<Value>> pidxs;
   std::vector<std::vector<Value>> idxs;
@@ -388,7 +396,7 @@ static unsigned buildLattices(Merger &merger, linalg::GenericOp op,
                               unsigned exp, unsigned idx) {
   Kind kind = merger.exp(exp).kind;
   if (kind == Kind::kTensor || kind == Kind::kInvariant) {
-    // Either the index is really used in the tensor expression, or it it
+    // Either the index is really used in the tensor expression, or it is
     // set to the "non-existing dense index" in that dimension. Invariant
     // expressions borrow the output tensor indices.
     unsigned s = merger.addSet();
@@ -573,38 +581,81 @@ static bool genInit(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
   return needsUniv;
 }
 
-/// Generates a for-loop or a while-loop, depending on whether it implements
-/// singleton iteration or co-iteration over the given conjunction.
-static void genLoop(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
-                    linalg::GenericOp op, unsigned idx, bool needsUniv,
-                    llvm::BitVector &indices, scf::ForOp &forOp,
-                    scf::WhileOp &whileOp) {
+/// Generates a for-loop on a single index.
+static Operation *genFor(Merger &merger, CodeGen &codegen,
+                         PatternRewriter &rewriter, linalg::GenericOp op,
+                         bool isOuter, unsigned idx, llvm::BitVector &indices) {
+  unsigned fb = indices.find_first();
+  unsigned tensor = merger.tensor(fb);
+  assert(idx == merger.index(fb));
+
+  // Parallelization strategy. Any implicit loop in the Linalg operation that
+  // is marked "parallel" is a candidate. Whether it is actually converted to
+  // a parallel operation depends on the requested strategy.
+  auto iteratorTypes = op.iterator_types().getValue();
+  bool isSparse = merger.isSparseBit(fb);
+  bool isParallel = linalg::isParallelIteratorType(iteratorTypes[idx]);
+  switch (codegen.options.parallelizationStrategy) {
+  case linalg::SparseParallelizationStrategy::kNone:
+    isParallel = false;
+    break;
+  case linalg::SparseParallelizationStrategy::kDenseOuterLoop:
+    isParallel &= isOuter && !isSparse;
+    break;
+  case linalg::SparseParallelizationStrategy::kAnyStorageOuterLoop:
+    isParallel &= isOuter;
+    break;
+  case linalg::SparseParallelizationStrategy::kDenseAnyLoop:
+    isParallel &= !isSparse;
+    break;
+  case linalg::SparseParallelizationStrategy::kAnyStorageAnyLoop:
+    break;
+  }
+
+  // Loop bounds and increment.
   Location loc = op.getLoc();
+  Value lo;
+  Value hi;
+  Value step = rewriter.create<ConstantIndexOp>(loc, 1);
+  Value index;
+  if (isSparse) {
+    lo = codegen.pidxs[tensor][idx];
+    hi = codegen.highs[tensor][idx];
+  } else {
+    lo = codegen.loops[idx];
+    hi = codegen.sizes[idx];
+  }
+
+  // Emit a parallel loop.
+  if (isParallel) {
+    scf::ParallelOp parOp = rewriter.create<scf::ParallelOp>(loc, lo, hi, step);
+    if (isSparse)
+      codegen.pidxs[tensor][idx] = parOp.getInductionVars()[0];
+    else
+      codegen.loops[idx] = parOp.getInductionVars()[0];
+    rewriter.setInsertionPointToStart(parOp.getBody());
+    return parOp;
+  }
+
+  // Emit a sequential loop.
+  scf::ForOp forOp = rewriter.create<scf::ForOp>(loc, lo, hi, step);
+  if (isSparse)
+    codegen.pidxs[tensor][idx] = forOp.getInductionVar();
+  else
+    codegen.loops[idx] = forOp.getInductionVar();
+  rewriter.setInsertionPointToStart(forOp.getBody());
+  return forOp;
+}
 
-  // Emit a for-loop for a single index.
-  if (indices.count() == 1) {
-    unsigned fb = indices.find_first();
-    unsigned tensor = merger.tensor(fb);
-    assert(idx == merger.index(fb));
-    // Emit a sparse for-loop or a dense for-loop.
-    Value one = rewriter.create<ConstantIndexOp>(loc, 1);
-    if (merger.isSparseBit(fb)) {
-      forOp = rewriter.create<scf::ForOp>(loc, codegen.pidxs[tensor][idx],
-                                          codegen.highs[tensor][idx], one);
-      codegen.pidxs[tensor][idx] = forOp.getInductionVar();
-    } else {
-      forOp = rewriter.create<scf::ForOp>(loc, codegen.loops[idx],
-                                          codegen.sizes[idx], one);
-      codegen.loops[idx] = forOp.getInductionVar();
-    }
-    rewriter.setInsertionPointToStart(forOp.getBody());
-    return;
-  }
-
-  // Otherwise, emit a while-loop for co-iteration.
-  Type indexType = rewriter.getIndexType();
+/// Generates a while-loop for co-iteration over multiple indices.
+static Operation *genWhile(Merger &merger, CodeGen &codegen,
+                           PatternRewriter &rewriter, linalg::GenericOp op,
+                           unsigned idx, bool needsUniv,
+                           llvm::BitVector &indices) {
   SmallVector<Type, 4> types;
   SmallVector<Value, 4> operands;
+  // Construct the while-loop with a parameter for each index.
+  Type indexType = rewriter.getIndexType();
   for (unsigned b = 0, be = indices.size(); b < be; b++) {
     if (indices[b] && merger.isSparseBit(b)) {
       unsigned tensor = merger.tensor(b);
@@ -617,9 +668,11 @@ static void genLoop(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
     types.push_back(indexType);
     operands.push_back(codegen.loops[idx]);
   }
-  whileOp = rewriter.create<scf::WhileOp>(loc, types, operands);
+  Location loc = op.getLoc();
+  scf::WhileOp whileOp = rewriter.create<scf::WhileOp>(loc, types, operands);
   Block *before = rewriter.createBlock(&whileOp.before(), {}, types);
   Block *after = rewriter.createBlock(&whileOp.after(), {}, types);
+
   // Build the "before" region, which effectively consists
   // of a conjunction of "i < upper" tests on all induction.
   rewriter.setInsertionPointToStart(&whileOp.before().front());
@@ -641,6 +694,18 @@ static void genLoop(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
   assert(o == operands.size());
   rewriter.create<scf::ConditionOp>(loc, cond, before->getArguments());
   rewriter.setInsertionPointToStart(&whileOp.after().front());
+  return whileOp;
+}
+
+/// Generates a for-loop or a while-loop, depending on whether it implements
+/// singleton iteration or co-iteration over the given conjunction.
+static Operation *genLoop(Merger &merger, CodeGen &codegen,
+                          PatternRewriter &rewriter, linalg::GenericOp op,
+                          bool isOuter, unsigned idx, bool needsUniv,
+                          llvm::BitVector &indices) {
+  if (indices.count() == 1)
+    return genFor(merger, codegen, rewriter, op, isOuter, idx, indices);
+  return genWhile(merger, codegen, rewriter, op, idx, needsUniv, indices);
 }
 
 /// Generates the local variables for this loop, consisting of the sparse
@@ -804,16 +869,16 @@ static void genStmt(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
     LatPoint lati = merger.lat(li);
 
     // Emit loop.
-    scf::ForOp forOp;
-    scf::WhileOp whileOp;
     llvm::BitVector indices = lati.bits;
     optimizeIndices(merger, lsize, indices);
-    genLoop(merger, codegen, rewriter, op, idx, needsUniv, indices, forOp,
-            whileOp);
+    bool isOuter = at == 0;
+    Operation *loop = genLoop(merger, codegen, rewriter, op, isOuter, idx,
+                              needsUniv, indices);
     genLocals(merger, codegen, rewriter, op, topSort, at, needsUniv, lati.bits);
 
     // Visit all lattices points with Li >= Lj to generate the
     // loop-body, possibly with if statements for coiteration.
+    bool isWhile = dyn_cast<scf::WhileOp>(loop) != nullptr;
     scf::IfOp ifOp;
     for (unsigned lj : merger.set(lts)) {
       if (li == lj || merger.latGT(li, lj)) {
@@ -823,22 +888,22 @@ static void genStmt(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
         if (merger.hasAnyOf(tmp, false))
           continue; // dense exhausted within if/else
         // Recurse into body of each branch.
-        if (whileOp)
+        if (isWhile)
           genIf(merger, codegen, rewriter, op, idx, latj.bits, ifOp);
         genStmt(merger, codegen, rewriter, op, topSort, latj.exp, at + 1);
       }
     }
 
     // Wrap-up induction and restore insertion point.
-    if (forOp) {
-      needsUniv = false;
-      rewriter.setInsertionPointAfter(forOp);
-    } else {
+    if (isWhile) {
+      scf::WhileOp whileOp = cast<scf::WhileOp>(loop);
       rewriter.setInsertionPointToEnd(&whileOp.after().front());
       genWhileInduction(merger, codegen, rewriter, op, idx, needsUniv,
                         lati.bits, whileOp.results());
-      rewriter.setInsertionPointAfter(whileOp);
+    } else {
+      needsUniv = false;
     }
+    rewriter.setInsertionPointAfter(loop);
   }
 }
 
@@ -846,7 +911,9 @@ namespace {
 
 /// Sparse rewriting rule for generic Lingalg operation.
 struct GenericOpSparsifier : public OpRewritePattern<linalg::GenericOp> {
-  using OpRewritePattern<linalg::GenericOp>::OpRewritePattern;
+public:
+  GenericOpSparsifier(MLIRContext *context, linalg::SparsificationOptions o)
+      : OpRewritePattern<linalg::GenericOp>(context), options(o) {}
 
   LogicalResult matchAndRewrite(linalg::GenericOp op,
                                 PatternRewriter &rewriter) const override {
@@ -878,7 +945,7 @@ struct GenericOpSparsifier : public OpRewritePattern<linalg::GenericOp> {
       return failure(); // build failure
 
     // Recursively generates code.
-    CodeGen codegen(numTensors, numLoops);
+    CodeGen codegen(options, numTensors, numLoops);
     genBuffers(merger, codegen, rewriter, op);
     genStmt(merger, codegen, rewriter, op, topSort, exp.getValue(), 0);
     Value result =
@@ -886,13 +953,18 @@ struct GenericOpSparsifier : public OpRewritePattern<linalg::GenericOp> {
     rewriter.replaceOp(op, result);
     return success();
   }
+
+private:
+  /// Options to control sparse code generation.
+  linalg::SparsificationOptions options;
 };
 
 } // namespace
 
 /// Populates the given patterns list with rewriting rules required for
 /// the sparsification of linear algebra operations.
-void mlir::linalg::populateSparsificationPatterns(
-    MLIRContext *context, OwningRewritePatternList &patterns) {
-  patterns.insert<GenericOpSparsifier>(context);
+void linalg::populateSparsificationPatterns(
+    MLIRContext *context, OwningRewritePatternList &patterns,
+    const SparsificationOptions &options) {
+  patterns.insert<GenericOpSparsifier>(context, options);
 }

diff  --git a/mlir/test/Dialect/Linalg/sparse_parallel.mlir b/mlir/test/Dialect/Linalg/sparse_parallel.mlir
new file mode 100644
index 000000000000..a75406fbab69
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/sparse_parallel.mlir
@@ -0,0 +1,161 @@
+// RUN: mlir-opt %s -test-sparsification="parallelization-strategy=0" | \
+// RUN:   FileCheck %s --check-prefix=CHECK-PAR0
+// RUN: mlir-opt %s -test-sparsification="parallelization-strategy=1" | \
+// RUN:   FileCheck %s --check-prefix=CHECK-PAR1
+// RUN: mlir-opt %s -test-sparsification="parallelization-strategy=2" | \
+// RUN:   FileCheck %s --check-prefix=CHECK-PAR2
+// RUN: mlir-opt %s -test-sparsification="parallelization-strategy=3" | \
+// RUN:   FileCheck %s --check-prefix=CHECK-PAR3
+// RUN: mlir-opt %s -test-sparsification="parallelization-strategy=4" | \
+// RUN:   FileCheck %s --check-prefix=CHECK-PAR4
+
+#trait_dd = {
+  indexing_maps = [
+    affine_map<(i,j) -> (i,j)>,  // A
+    affine_map<(i,j) -> (i,j)>   // X (out)
+  ],
+  sparse = [
+    [ "D", "D" ],  // A
+    [ "D", "D" ]   // X
+  ],
+  iterator_types = ["parallel", "parallel"],
+  doc = "X(i,j) = A(i,j) * SCALE"
+}
+
+//
+// CHECK-PAR0-LABEL: func @scale_dd
+// CHECK-PAR0:         scf.for
+// CHECK-PAR0:           scf.for
+// CHECK-PAR0:         return
+//
+// CHECK-PAR1-LABEL: func @scale_dd
+// CHECK-PAR1:         scf.parallel
+// CHECK-PAR1:           scf.for
+// CHECK-PAR1:         return
+//
+// CHECK-PAR2-LABEL: func @scale_dd
+// CHECK-PAR2:         scf.parallel
+// CHECK-PAR2:           scf.for
+// CHECK-PAR2:         return
+//
+// CHECK-PAR3-LABEL: func @scale_dd
+// CHECK-PAR3:         scf.parallel
+// CHECK-PAR3:           scf.parallel
+// CHECK-PAR3:         return
+//
+// CHECK-PAR4-LABEL: func @scale_dd
+// CHECK-PAR4:         scf.parallel
+// CHECK-PAR4:           scf.parallel
+// CHECK-PAR4:         return
+//
+func @scale_dd(%scale: f32, %arga: tensor<?x?xf32>) -> tensor<?x?xf32> {
+  %0 = linalg.generic #trait_dd
+    ins(%arga: tensor<?x?xf32>) {
+      ^bb(%a: f32):
+        %0 = mulf %a, %scale  : f32
+        linalg.yield %0 : f32
+  } -> tensor<?x?xf32>
+  return %0 : tensor<?x?xf32>
+}
+
+#trait_ss = {
+  indexing_maps = [
+    affine_map<(i,j) -> (i,j)>,  // A
+    affine_map<(i,j) -> (i,j)>   // X (out)
+  ],
+  sparse = [
+    [ "S", "S" ],  // A
+    [ "D", "D" ]   // X
+  ],
+  iterator_types = ["parallel", "parallel"],
+  doc = "X(i,j) = A(i,j) * SCALE"
+}
+
+//
+// CHECK-PAR0-LABEL: func @scale_ss
+// CHECK-PAR0:         scf.for
+// CHECK-PAR0:           scf.for
+// CHECK-PAR0:         return
+//
+// CHECK-PAR1-LABEL: func @scale_ss
+// CHECK-PAR1:         scf.for
+// CHECK-PAR1:           scf.for
+// CHECK-PAR1:         return
+//
+// CHECK-PAR2-LABEL: func @scale_ss
+// CHECK-PAR2:         scf.parallel
+// CHECK-PAR2:           scf.for
+// CHECK-PAR2:         return
+//
+// CHECK-PAR3-LABEL: func @scale_ss
+// CHECK-PAR3:         scf.for
+// CHECK-PAR3:           scf.for
+// CHECK-PAR3:         return
+//
+// CHECK-PAR4-LABEL: func @scale_ss
+// CHECK-PAR4:         scf.parallel
+// CHECK-PAR4:           scf.parallel
+// CHECK-PAR4:         return
+//
+func @scale_ss(%scale: f32, %arga: tensor<?x?xf32>) -> tensor<?x?xf32> {
+  %0 = linalg.generic #trait_ss
+    ins(%arga: tensor<?x?xf32>) {
+      ^bb(%a: f32):
+        %0 = mulf %a, %scale  : f32
+        linalg.yield %0 : f32
+  } -> tensor<?x?xf32>
+  return %0 : tensor<?x?xf32>
+}
+
+#trait_matvec = {
+  indexing_maps = [
+    affine_map<(i,j) -> (i,j)>,  // A
+    affine_map<(i,j) -> (j)>,    // b
+    affine_map<(i,j) -> (i)>     // x (out)
+  ],
+  sparse = [
+    [ "D", "S" ],  // A
+    [ "D" ],       // b
+    [ "D" ]        // x
+  ],
+  iterator_types = ["parallel", "reduction"],
+  doc = "x(i) += A(i,j) * b(j)"
+}
+
+//
+// CHECK-PAR0-LABEL: func @matvec
+// CHECK-PAR0:         scf.for
+// CHECK-PAR0:           scf.for
+// CHECK-PAR0:         return
+//
+// CHECK-PAR1-LABEL: func @matvec
+// CHECK-PAR1:         scf.parallel
+// CHECK-PAR1:           scf.for
+// CHECK-PAR1:         return
+//
+// CHECK-PAR2-LABEL: func @matvec
+// CHECK-PAR2:         scf.parallel
+// CHECK-PAR2:           scf.for
+// CHECK-PAR2:         return
+//
+// CHECK-PAR3-LABEL: func @matvec
+// CHECK-PAR3:         scf.parallel
+// CHECK-PAR3:           scf.for
+// CHECK-PAR3:         return
+//
+// CHECK-PAR4-LABEL: func @matvec
+// CHECK-PAR4:         scf.parallel
+// CHECK-PAR4:           scf.for
+// CHECK-PAR4:         return
+//
+func @matvec(%argA: tensor<16x32xf32>, %argb: tensor<32xf32>, %argx: tensor<16xf32>) -> tensor<16xf32> {
+  %0 = linalg.generic #trait_matvec
+      ins(%argA, %argb : tensor<16x32xf32>, tensor<32xf32>)
+      init(%argx : tensor<16xf32>) {
+    ^bb(%A: f32, %b: f32, %x: f32):
+      %0 = mulf %A, %b : f32
+      %1 = addf %0, %x : f32
+      linalg.yield %1 : f32
+  } -> tensor<16xf32>
+  return %0 : tensor<16xf32>
+}

diff  --git a/mlir/test/lib/Transforms/TestSparsification.cpp b/mlir/test/lib/Transforms/TestSparsification.cpp
index 038d7fab0656..7544e48174b8 100644
--- a/mlir/test/lib/Transforms/TestSparsification.cpp
+++ b/mlir/test/lib/Transforms/TestSparsification.cpp
@@ -16,13 +16,63 @@ namespace {
 
 struct TestSparsification
     : public PassWrapper<TestSparsification, FunctionPass> {
+
+  TestSparsification() = default;
+  TestSparsification(const TestSparsification &pass) {}
+
+  Option<int32_t> parallelization{
+      *this, "parallelization-strategy",
+      llvm::cl::desc("Set the parallelization strategy"), llvm::cl::init(0)};
+
+  Option<int32_t> vectorization{
+      *this, "vectorization-strategy",
+      llvm::cl::desc("Set the vectorization strategy"), llvm::cl::init(0)};
+
+  Option<int32_t> vectorLength{
+      *this, "vl", llvm::cl::desc("Set the vector length"), llvm::cl::init(1)};
+
+  /// Registers all dialects required by testing.
   void getDependentDialects(DialectRegistry &registry) const override {
-    registry.insert<scf::SCFDialect>();
+    registry.insert<scf::SCFDialect, vector::VectorDialect>();
   }
+
+  /// Returns parallelization strategy given on command line.
+  linalg::SparseParallelizationStrategy parallelOption() {
+    switch (parallelization) {
+    default:
+      return linalg::SparseParallelizationStrategy::kNone;
+    case 1:
+      return linalg::SparseParallelizationStrategy::kDenseOuterLoop;
+    case 2:
+      return linalg::SparseParallelizationStrategy::kAnyStorageOuterLoop;
+    case 3:
+      return linalg::SparseParallelizationStrategy::kDenseAnyLoop;
+    case 4:
+      return linalg::SparseParallelizationStrategy::kAnyStorageAnyLoop;
+    }
+  }
+
+  /// Returns vectorization strategy given on command line.
+  linalg::SparseVectorizationStrategy vectorOption() {
+    switch (vectorization) {
+    default:
+      return linalg::SparseVectorizationStrategy::kNone;
+    case 1:
+      return linalg::SparseVectorizationStrategy::kDenseInnerLoop;
+    case 2:
+      return linalg::SparseVectorizationStrategy::kAnyStorageInnerLoop;
+    }
+  }
+
+  /// Runs the test on a function.
   void runOnFunction() override {
     auto *ctx = &getContext();
     OwningRewritePatternList patterns;
-    linalg::populateSparsificationPatterns(ctx, patterns);
+    // Translate strategy flags to strategy options.
+    linalg::SparsificationOptions options(parallelOption(), vectorOption(),
+                                          vectorLength);
+    // Apply rewriting.
+    linalg::populateSparsificationPatterns(ctx, patterns, options);
     applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
   }
 };


        


More information about the Mlir-commits mailing list