[Mlir-commits] [mlir] 55a1d50 - [mlir][sparse] Make sparse compiler more admissible.

Peiming Liu llvmlistbot at llvm.org
Wed Sep 14 08:59:58 PDT 2022


Author: Peiming Liu
Date: 2022-09-14T15:59:47Z
New Revision: 55a1d50fb9abb79c540ae32f1bf16a1fbd29f1a6

URL: https://github.com/llvm/llvm-project/commit/55a1d50fb9abb79c540ae32f1bf16a1fbd29f1a6
DIFF: https://github.com/llvm/llvm-project/commit/55a1d50fb9abb79c540ae32f1bf16a1fbd29f1a6.diff

LOG: [mlir][sparse] Make sparse compiler more admissible.

Previously, the iteration graph was topologically sorted without any priority. This patch adds a heuristic to that topological sort: whenever a parallel (non-reduction) iterator is ready, it is scheduled before any ready reduction iterator. This makes reduction iterators (likely) appear as late in the sorted order as possible.

Before this change, the sparse compiler also failed to compile the newly added test case (a 2-D convolution with a sparse output, `@conv2d_sparse_out`).

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D133738

Added: 
    

Modified: 
    mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
    mlir/test/Dialect/SparseTensor/sparse_sddmm.mlir
    mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_filter_conv2d.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
index 918018a20c457..e283748cb79c8 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -223,43 +223,65 @@ static bool findSparseAnnotations(Merger &merger, linalg::GenericOp op) {
   return annotated;
 }
 
-/// A DFS helper to compute a topological sort. Note that recursion is
-/// bounded by the number of implicit loops, which is always small.
-/// Returns false when a cycle is detected.
-static bool topSortDFS(unsigned i, std::vector<unsigned> &visit,
-                       std::vector<unsigned> &topSort,
-                       std::vector<std::vector<bool>> &adjM) {
-  if (visit[i] != 0)
-    return visit[i] != 1; // 1 denotes cycle!
-  visit[i] = 1;
-  for (unsigned j = 0, e = visit.size(); j < e; j++)
-    if (adjM[i][j])
-      if (!topSortDFS(j, visit, topSort, adjM))
-        return false;
-  visit[i] = 2;
-  topSort.push_back(i);
-  return true;
+/// Kahn-style topological sort of the loop graph adjM in O(n^2). Ready
+/// non-reduction iterators are always taken before ready reductions, so the
+/// first reduction iterator lands as late as possible. inDegree is consumed
+/// and must count distinct edges into each node. Returns false on a cycle.
+static bool topSortOptimal(unsigned n, ArrayRef<Attribute> iteratorTypes,
+                           std::vector<unsigned> &topSort,
+                           std::vector<unsigned> &inDegree,
+                           std::vector<std::vector<bool>> &adjM) {
+  std::vector<unsigned> redIt; // ready (in-degree 0) reduction iterators
+  std::vector<unsigned> parIt; // ready (in-degree 0) non-reduction iterators
+  for (unsigned i = 0; i < n; i++) {
+    if (inDegree[i] == 0) {
+      if (linalg::isReductionIterator(iteratorTypes[i]))
+        redIt.push_back(i);
+      else
+        parIt.push_back(i);
+    }
+  }
+
+  while (!redIt.empty() || !parIt.empty()) {
+    // Prefer a ready non-reduction iterator; take a reduction only if none.
+    auto &it = !parIt.empty() ? parIt : redIt;
+    auto src = it.back();
+    topSort.push_back(src);
+    it.pop_back();
+    // Drop src's out-edges; a successor reaching in-degree 0 becomes ready.
+    for (unsigned dst = 0; dst < n; dst++)
+      if (adjM[src][dst] && --inDegree[dst] == 0) {
+        if (linalg::isReductionIterator(iteratorTypes[dst]))
+          redIt.push_back(dst);
+        else
+          parIt.push_back(dst);
+      }
+  }
+  return topSort.size() == n; // assumes topSort was empty on entry
 }
 
 /// Helper method to add all constraints from the indices in one affine
 /// expression before all indices in the other affine expression. For
 /// example i0+i1 < i2+i3+1 yields i0<i2, i0<i3, i1<i2, and i1<i3.
 static void addAffineOrderings(std::vector<std::vector<bool>> &adjM,
-                               AffineExpr a, AffineExpr b, unsigned fidx) {
+                               std::vector<unsigned> &inDegree, AffineExpr a,
+                               AffineExpr b, unsigned fidx) {
   switch (a.getKind()) {
   case AffineExprKind::DimId: {
     unsigned idx = a.cast<AffineDimExpr>().getPosition();
     if (b)
-      addAffineOrderings(adjM, b, AffineExpr(), idx);
-    else
+      addAffineOrderings(adjM, inDegree, b, AffineExpr(), idx);
+    else if (!adjM[fidx][idx]) {
       adjM[fidx][idx] = true;
+      inDegree[idx]++;
+    }
     break;
   }
   case AffineExprKind::Add:
   case AffineExprKind::Mul: {
     auto binOp = a.cast<AffineBinaryOpExpr>();
-    addAffineOrderings(adjM, binOp.getLHS(), b, fidx);
-    addAffineOrderings(adjM, binOp.getRHS(), b, fidx);
+    addAffineOrderings(adjM, inDegree, binOp.getLHS(), b, fidx);
+    addAffineOrderings(adjM, inDegree, binOp.getRHS(), b, fidx);
     break;
   }
   default:
@@ -279,7 +301,8 @@ static bool computeIterationGraph(Merger &merger, linalg::GenericOp op,
   // for the implicit loop indices i_0 .. i_n-1.
   unsigned n = op.getNumLoops();
   std::vector<std::vector<bool>> adjM(n, std::vector<bool>(n, false));
-
+  std::vector<unsigned> inDegree(n, 0); // in-degree of each node.
+  auto iteratorTypes = op.iterator_types().getValue();
   // Iterate over the indexing maps of every tensor in the tensor expression.
   for (OpOperand *t : op.getInputAndOutputOperands()) {
     // Skip tensor during cycle resolution.
@@ -299,7 +322,7 @@ static bool computeIterationGraph(Merger &merger, linalg::GenericOp op,
     for (unsigned d = 1, rank = map.getNumResults(); d < rank; d++) {
       AffineExpr f = map.getResult(perm(enc, d - 1));
       AffineExpr t = map.getResult(perm(enc, d));
-      addAffineOrderings(adjM, f, t, 0);
+      addAffineOrderings(adjM, inDegree, f, t, 0);
     }
     // Push unrelated loops into sparse iteration space, so these
     // will be skipped more often.
@@ -309,21 +332,17 @@ static bool computeIterationGraph(Merger &merger, linalg::GenericOp op,
         if (merger.isDimLevelType(tensor, i, DimLvlType::kCompressed) ||
             merger.isDimLevelType(tensor, i, DimLvlType::kSingleton))
           for (unsigned j = 0; j < n; j++)
-            if (merger.isDimLevelType(tensor, j, DimLvlType::kUndef))
+            if (merger.isDimLevelType(tensor, j, DimLvlType::kUndef)) {
               adjM[i][j] = true;
+              inDegree[j]++;
+            }
     }
   }
   // Topologically sort the iteration graph to determine loop order.
   // Report failure for a cyclic iteration graph.
   topSort.clear();
   topSort.reserve(n);
-  std::vector<unsigned> visit(n, 0);
-  for (unsigned i = 0; i < n; i++)
-    if (visit[i] == 0)
-      if (!topSortDFS(i, visit, topSort, adjM))
-        return false; // cycle!
-  std::reverse(std::begin(topSort), std::end(topSort));
-  return true;
+  return topSortOptimal(n, iteratorTypes, topSort, inDegree, adjM);
 }
 
 /// Returns true if tensor materializes uninitialized into the computation.
@@ -1271,7 +1290,8 @@ static Operation *genFor(Merger &merger, CodeGen &codegen, OpBuilder &builder,
   bool isParallel =
       isParallelFor(codegen, isOuter, isReduction, isSparse, isVector);
 
-  assert(!merger.isDimLevelType(fb, DimLvlType::kSingleton) && "TODO: implement");
+  assert(!merger.isDimLevelType(fb, DimLvlType::kSingleton) &&
+         "TODO: implement");
 
   // Prepare vector length.
   if (isVector)
@@ -1798,33 +1818,42 @@ struct GenericOpSparsifier : public OpRewritePattern<linalg::GenericOp> {
     if (!findSparseAnnotations(merger, op))
       return failure();
 
+    // Builds the tensor expression for the Linalg operation in SSA form.
+    Optional<unsigned> optExp = merger.buildTensorExpFromLinalg(op);
+    if (!optExp.has_value())
+      return failure();
+
+    unsigned exp = optExp.value();
+    OpOperand *sparseOut = nullptr;
+    unsigned outerParNest = 0;
     // Computes a topologically sorted iteration graph to ensure tensors
     // are visited in natural index order. Gradually relaxes the considered
     // constraints until an acyclic iteration graph results, such that sparse
     // code generation can proceed. As a last resort, an attempt is made
     // to resolve cycles by inserting a conversion.
     std::vector<unsigned> topSort;
-    if (!computeIterationGraph(merger, op, topSort, SortMask::kIncludeAll) &&
-        !computeIterationGraph(merger, op, topSort, SortMask::kIncludeUndef) &&
-        !computeIterationGraph(merger, op, topSort, SortMask::kIncludeDense) &&
-        !computeIterationGraph(merger, op, topSort, SortMask::kSparseOnly)) {
-      return resolveCycle(merger, rewriter, op);
+    // Whether the current GenericOp is admissible
+    bool isAdmissible = false;
+    // An const list of all masks that we used for interation graph
+    // computation. Must be ordered from strict -> loose.
+    const auto allMask = {SortMask::kIncludeAll, SortMask::kIncludeUndef,
+                          SortMask::kIncludeDense, SortMask::kSparseOnly};
+    for (auto mask : allMask) {
+      if (computeIterationGraph(merger, op, topSort, mask) &&
+          isAdmissableTensorExp(merger, op, topSort, exp, &sparseOut,
+                                outerParNest)) {
+        // This is an admissible GenericOp.
+        isAdmissible = true;
+        break;
+      }
+      // else try a less strict constraints.
     }
 
-    // Builds the tensor expression for the Linalg operation in SSA form.
-    Optional<unsigned> optExp = merger.buildTensorExpFromLinalg(op);
-    if (!optExp.has_value())
-      return failure();
-    unsigned exp = optExp.value();
-
-    // Rejects an inadmissable tensor expression.
-    OpOperand *sparseOut = nullptr;
-    unsigned outerParNest = 0;
-    if (!isAdmissableTensorExp(merger, op, topSort, exp, &sparseOut,
-                               outerParNest))
-      return failure();
+    if (!isAdmissible)
+      // Give it one last shot to resolve the cycle.
+      return resolveCycle(merger, rewriter, op);
 
-    // Recursively generates code.
+    // Recursively generates code if admissible.
     merger.setHasSparseOut(sparseOut != nullptr);
     CodeGen codegen(options, numTensors, numLoops, sparseOut, outerParNest);
     genBuffers(merger, codegen, rewriter, op);

diff  --git a/mlir/test/Dialect/SparseTensor/sparse_sddmm.mlir b/mlir/test/Dialect/SparseTensor/sparse_sddmm.mlir
index dfd1eefe5cd28..5d52fa5256502 100755
--- a/mlir/test/Dialect/SparseTensor/sparse_sddmm.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_sddmm.mlir
@@ -123,55 +123,66 @@ func.func @sampled_dd_unfused(%args: tensor<8x8xf64, #SM>,
   return %3 : tensor<8x8xf64>
 }
 
-// CHECK-LABEL: func.func @sparse_sampled_dd_unfused(
-// CHECK-SAME:    %[[VAL_0:.*]]: tensor<8x8xf64, #sparse_tensor.encoding<{{.*}}>>,
-// CHECK-SAME:    %[[VAL_1:.*]]: tensor<8x8xf64>,
-// CHECK-SAME:    %[[VAL_2:.*]]: tensor<8x8xf64>) -> tensor<8x8xf64, #sparse_tensor.encoding<{{.*}}>> {
-// CHECK-DAG:     %[[VAL_3:.*]] = arith.constant 8 : index
-// CHECK-DAG:     %[[VAL_4:.*]] = arith.constant 0 : index
-// CHECK-DAG:     %[[VAL_5:.*]] = arith.constant 1 : index
-// CHECK-DAG:     %[[VAL_6:.*]] = arith.constant 2 : index
-// CHECK-DAG:     %[[VAL_7:.*]] = arith.constant 0.000000e+00 : f64
-// CHECK-DAG:     %[[VAL_8:.*]] = arith.constant dense<0.000000e+00> : tensor<8x8xf64>
-// CHECK:         %[[VAL_9:.*]] = bufferization.alloc_tensor() copy(%[[VAL_8]]) {bufferization.escape = [false]} : tensor<8x8xf64>
-// CHECK:         %[[VAL_10:.*]] = bufferization.alloc_tensor() {bufferization.escape = [false]} : tensor<8x8xf64, #sparse_tensor.encoding<{{.*}}>>
-// CHECK:         %[[VAL_11:.*]] = bufferization.to_memref %[[VAL_1]] : memref<8x8xf64>
-// CHECK:         %[[VAL_12:.*]] = bufferization.to_memref %[[VAL_2]] : memref<8x8xf64>
-// CHECK:         %[[VAL_13:.*]] = sparse_tensor.pointers %[[VAL_0]] {dimension = 0 : index} : tensor<8x8xf64, #sparse_tensor.encoding<{{.*}}>> to memref<?xindex>
-// CHECK:         %[[VAL_14:.*]] = sparse_tensor.indices %[[VAL_0]] {dimension = 0 : index} : tensor<8x8xf64, #sparse_tensor.encoding<{{.*}}>> to memref<?xindex>
-// CHECK:         %[[VAL_15:.*]] = sparse_tensor.pointers %[[VAL_0]] {dimension = 1 : index} : tensor<8x8xf64, #sparse_tensor.encoding<{{.*}}>> to memref<?xindex>
-// CHECK:         %[[VAL_16:.*]] = sparse_tensor.indices %[[VAL_0]] {dimension = 1 : index} : tensor<8x8xf64, #sparse_tensor.encoding<{{.*}}>> to memref<?xindex>
-// CHECK:         %[[VAL_17:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<8x8xf64, #sparse_tensor.encoding<{{.*}}>> to memref<?xf64>
-// CHECK:         %[[VAL_18:.*]] = memref.alloca(%[[VAL_6]]) : memref<?xindex>
-// CHECK:         %[[VAL_19:.*]] = memref.alloca() : memref<f64>
-// CHECK:         %[[VAL_20:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_4]]] : memref<?xindex>
-// CHECK:         %[[VAL_21:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_5]]] : memref<?xindex>
-// CHECK:         scf.for %[[VAL_22:.*]] = %[[VAL_20]] to %[[VAL_21]] step %[[VAL_5]] {
-// CHECK:           %[[VAL_23:.*]] = memref.load %[[VAL_14]]{{\[}}%[[VAL_22]]] : memref<?xindex>
-// CHECK:           memref.store %[[VAL_23]], %[[VAL_18]]{{\[}}%[[VAL_4]]] : memref<?xindex>
-// CHECK:           %[[VAL_24:.*]] = memref.load %[[VAL_15]]{{\[}}%[[VAL_22]]] : memref<?xindex>
-// CHECK:           %[[VAL_25:.*]] = arith.addi %[[VAL_22]], %[[VAL_5]] : index
-// CHECK:           %[[VAL_26:.*]] = memref.load %[[VAL_15]]{{\[}}%[[VAL_25]]] : memref<?xindex>
-// CHECK:           scf.for %[[VAL_27:.*]] = %[[VAL_24]] to %[[VAL_26]] step %[[VAL_5]] {
-// CHECK:             %[[VAL_28:.*]] = memref.load %[[VAL_16]]{{\[}}%[[VAL_27]]] : memref<?xindex>
-// CHECK:             memref.store %[[VAL_28]], %[[VAL_18]]{{\[}}%[[VAL_5]]] : memref<?xindex>
-// CHECK:             %[[VAL_29:.*]] = memref.load %[[VAL_17]]{{\[}}%[[VAL_27]]] : memref<?xf64>
-// CHECK:             %[[VAL_30:.*]] = scf.for %[[VAL_31:.*]] = %[[VAL_4]] to %[[VAL_3]] step %[[VAL_5]] iter_args(%[[VAL_32:.*]] = %[[VAL_7]]) -> (f64) {
-// CHECK:               memref.store %[[VAL_31]], %[[VAL_18]]{{\[}}%[[VAL_6]]] : memref<?xindex>
-// CHECK:               %[[VAL_33:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_23]], %[[VAL_31]]] : memref<8x8xf64>
-// CHECK:               %[[VAL_34:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_31]], %[[VAL_28]]] : memref<8x8xf64>
-// CHECK:               %[[VAL_35:.*]] = arith.mulf %[[VAL_33]], %[[VAL_34]] : f64
-// CHECK:               %[[VAL_36:.*]] = arith.mulf %[[VAL_35]], %[[VAL_29]] : f64
-// CHECK:               %[[VAL_37:.*]] = arith.addf %[[VAL_32]], %[[VAL_36]] : f64
-// CHECK:               scf.yield %[[VAL_37]] : f64
-// CHECK:             }
-// CHECK:             memref.store %[[VAL_30:.*]], %[[VAL_19]][] : memref<f64>
-// CHECK:             sparse_tensor.insert %[[VAL_10]], %[[VAL_18]], %[[VAL_19]] : tensor<8x8xf64, #sparse_tensor.encoding<{{.*}}>>, memref<?xindex>, memref<f64>
-// CHECK:           }
-// CHECK:         }
-// CHECK:         %[[VAL_39:.*]] = sparse_tensor.load %[[VAL_10]] hasInserts : tensor<8x8xf64, #sparse_tensor.encoding<{{.*}}>>
-// CHECK:         return %[[VAL_39]] : tensor<8x8xf64, #sparse_tensor.encoding<{{.*}}>>
-// CHECK:       }
+
+// CHECK-LABEL:  func @sparse_sampled_dd_unfused(
+// CHECK-SAME:   %[[TMP_arg0:.*]]: tensor<8x8xf64, #sparse_tensor.encoding
+// CHECK-SAME:   %[[TMP_arg1:.*]]: tensor<8x8xf64>,
+// CHECK-SAME:   %[[TMP_arg2:.*]]: tensor<8x8xf64>)
+// CHECK-DAG:    %[[TMP_c8:.*]] = arith.constant 8 : index
+// CHECK-DAG:    %[[TMP_c2:.*]] = arith.constant 2 : index
+// CHECK-DAG:    %[[TMP_c0:.*]] = arith.constant 0 : index
+// CHECK-DAG:    %[[TMP_c1:.*]] = arith.constant 1 : index
+// CHECK-DAG:    %[[TMP_false:.*]] = arith.constant false
+// CHECK-DAG:    %[[TMP_true:.*]] = arith.constant true
+// CHECK-DAG:    %[[TMP_cst:.*]] = arith.constant dense<0.000000e+00> : tensor<8x8xf64>
+// CHECK:        %[[TMP_0:.*]] = bufferization.alloc_tensor() copy(%[[TMP_cst]]) {bufferization.escape = [false]}
+// CHECK:        %[[TMP_1:.*]] = bufferization.alloc_tensor() {bufferization.escape = [false]}
+// CHECK:        %[[TMP_2:.*]] = bufferization.to_memref %[[TMP_arg1]] : memref<8x8xf64>
+// CHECK:        %[[TMP_3:.*]] = bufferization.to_memref %[[TMP_arg2]] : memref<8x8xf64>
+// CHECK:        %[[TMP_4:.*]] = sparse_tensor.pointers %[[TMP_arg0]] {dimension = 0 : index}
+// CHECK:        %[[TMP_5:.*]] = sparse_tensor.indices %[[TMP_arg0]] {dimension = 0 : index}
+// CHECK:        %[[TMP_6:.*]] = sparse_tensor.pointers %[[TMP_arg0]] {dimension = 1 : index}
+// CHECK:        %[[TMP_7:.*]] = sparse_tensor.indices %[[TMP_arg0]] {dimension = 1 : index}
+// CHECK:        %[[TMP_8:.*]] = sparse_tensor.values %[[TMP_arg0]]
+// CHECK:        %[[TMP_9:.*]] = memref.alloca(%[[TMP_c2]]) : memref<?xindex>
+// CHECK:        %[[TMP_10:.*]] = memref.load %[[TMP_4]][%[[TMP_c0]]] : memref<?xindex>
+// CHECK:        %[[TMP_11:.*]] = memref.load %[[TMP_4]][%[[TMP_c1]]] : memref<?xindex>
+// CHECK:        scf.for %[[TMP_arg3:.*]] = %[[TMP_10]] to %[[TMP_11]] step %[[TMP_c1]] {
+// CHECK:          %[[TMP_13:.*]] = memref.load %[[TMP_5]][%[[TMP_arg3]]] : memref<?xindex>
+// CHECK:          memref.store %[[TMP_13]], %[[TMP_9]][%[[TMP_c0]]] : memref<?xindex>
+// CHECK:          %[[TMP_values:.*]], %[[TMP_filled:.*]], %[[TMP_added:.*]], %[[TMP_count:.*]] = sparse_tensor.expand %[[TMP_1]]
+// CHECK:          %[[TMP_14:.*]] = scf.for %[[TMP_arg4:.*]] = %[[TMP_c0]] to %[[TMP_c8]] step %[[TMP_c1]] iter_args(%[[TMP_arg5:.*]] = %[[TMP_count]]) -> (index) {
+// CHECK:            %[[TMP_15:.*]] = memref.load %[[TMP_2]][%[[TMP_13]], %[[TMP_arg4]]] : memref<8x8xf64>
+// CHECK:            %[[TMP_16:.*]] = memref.load %[[TMP_6]][%[[TMP_arg3]]] : memref<?xindex>
+// CHECK:            %[[TMP_17:.*]] = arith.addi %[[TMP_arg3]], %[[TMP_c1]] : index
+// CHECK:            %[[TMP_18:.*]] = memref.load %[[TMP_6]][%[[TMP_17]]] : memref<?xindex>
+// CHECK:            %[[TMP_19:.*]] = scf.for %[[TMP_arg6:.*]] = %[[TMP_16]] to %[[TMP_18]] step %[[TMP_c1]] iter_args(%[[TMP_arg7:.*]] = %[[TMP_arg5]]) -> (index) {
+// CHECK:              %[[TMP_20:.*]] = memref.load %[[TMP_7]][%[[TMP_arg6]]] : memref<?xindex>
+// CHECK:              %[[TMP_21:.*]] = memref.load %[[TMP_values]][%[[TMP_20]]] : memref<?xf64>
+// CHECK:              %[[TMP_22:.*]] = memref.load %[[TMP_3]][%[[TMP_arg4]], %[[TMP_20]]] : memref<8x8xf64>
+// CHECK:              %[[TMP_23:.*]] = arith.mulf %[[TMP_15]], %[[TMP_22]] : f64
+// CHECK:              %[[TMP_24:.*]] = memref.load %[[TMP_8]][%[[TMP_arg6]]] : memref<?xf64>
+// CHECK:              %[[TMP_25:.*]] = arith.mulf %[[TMP_23]], %[[TMP_24]] : f64
+// CHECK:              %[[TMP_26:.*]] = arith.addf %[[TMP_21]], %[[TMP_25]] : f64
+// CHECK:              %[[TMP_27:.*]] = memref.load %[[TMP_filled]][%[[TMP_20]]] : memref<?xi1>
+// CHECK:              %[[TMP_28:.*]] = arith.cmpi eq, %[[TMP_27]], %[[TMP_false]] : i1
+// CHECK:              %[[TMP_29:.*]] = scf.if %[[TMP_28]] -> (index) {
+// CHECK:                memref.store %[[TMP_true]], %[[TMP_filled]][%[[TMP_20]]] : memref<?xi1>
+// CHECK:                memref.store %[[TMP_20]], %[[TMP_added]][%[[TMP_arg7]]] : memref<?xindex>
+// CHECK:                %[[TMP_30:.*]] = arith.addi %[[TMP_arg7]], %[[TMP_c1]] : index
+// CHECK:                scf.yield %[[TMP_30]] : index
+// CHECK:              } else {
+// CHECK:                scf.yield %[[TMP_arg7]] : index
+// CHECK:              }
+// CHECK:              memref.store %[[TMP_26]], %[[TMP_values]][%[[TMP_20]]] : memref<?xf64>
+// CHECK:              scf.yield %[[TMP_29]] : index
+// CHECK:            }
+// CHECK:            scf.yield %[[TMP_19]] : index
+// CHECK:          }
+// CHECK:          sparse_tensor.compress %[[TMP_1]], %[[TMP_9]], %[[TMP_values]], %[[TMP_filled]], %[[TMP_added]], %[[TMP_14]]
+// CHECK:        }
+// CHECK:        %[[TMP_12:.*]] = sparse_tensor.load %[[TMP_1]] hasInserts 
+// CHECK:        return %[[TMP_12]] : tensor<8x8xf64, #sparse_tensor.encoding
 func.func @sparse_sampled_dd_unfused(%args: tensor<8x8xf64, #SM>,
                                      %arga: tensor<8x8xf64>,
                                      %argb: tensor<8x8xf64>) -> tensor<8x8xf64, #SM> {

diff  --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_filter_conv2d.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_filter_conv2d.mlir
index 2ff473216b642..38d809c6e09ab 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_filter_conv2d.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_filter_conv2d.mlir
@@ -24,6 +24,15 @@ module {
     return %0 : tensor<6x6xi32>
   }
 
+  func.func @conv2d_sparse_out(%input:  tensor<8x8xi32>,
+               %filter: tensor<3x3xi32, #DCSR>) -> tensor<6x6xi32, #DCSR> {
+    %s = bufferization.alloc_tensor() : tensor<6x6xi32, #DCSR>
+    %0 = linalg.conv_2d
+      ins  (%input, %filter: tensor<8x8xi32>, tensor<3x3xi32, #DCSR>)
+      outs (%s: tensor<6x6xi32, #DCSR>) -> tensor<6x6xi32, #DCSR>
+    return %0 : tensor<6x6xi32, #DCSR>
+  }
+
   func.func @entry() {
     %c0 = arith.constant 0 : index
     %i0 = arith.constant 0 : i32
@@ -53,7 +62,10 @@ module {
     %0 = call @conv2d(%input, %sparse_filter, %output)
        : (tensor<8x8xi32>,
           tensor<3x3xi32, #DCSR>, tensor<6x6xi32>) -> tensor<6x6xi32>
-
+    %1 = call @conv2d_sparse_out(%input, %sparse_filter)
+       : (tensor<8x8xi32>,
+          tensor<3x3xi32, #DCSR>) -> tensor<6x6xi32, #DCSR>
+ 
     // Verify the output.
     //
     // CHECK:    ( ( 0, 0, -1, -6, -1, 6 ),
@@ -67,9 +79,24 @@ module {
       : tensor<6x6xi32>, vector<6x6xi32>
     vector.print %v : vector<6x6xi32>
 
+    //
+    // Should be the same as dense output
+    // CHECK:    ( ( 0, 0, -1, -6, -1, 6 ),
+    // CHECK-SAME: ( -1, 0, 1, 0, 1, 0 ),
+    // CHECK-SAME: ( 0, -1, 1, 0, 0, 0 ),
+    // CHECK-SAME: ( -1, 0, 0, 0, 0, 0 ),
+    // CHECK-SAME: ( 0, 0, 3, 6, -3, -6 ),
+    // CHECK-SAME: ( 2, -1, 3, 0, -3, 0 ) )
+    //
+    %sparse_ret = sparse_tensor.convert %1
+      : tensor<6x6xi32, #DCSR> to tensor<6x6xi32>
+    %v1 = vector.transfer_read %sparse_ret[%c0, %c0], %i0
+      : tensor<6x6xi32>, vector<6x6xi32>
+    vector.print %v1 : vector<6x6xi32>
+
     // Release the resources.
     bufferization.dealloc_tensor %sparse_filter : tensor<3x3xi32, #DCSR>
-
+    bufferization.dealloc_tensor %1 : tensor<6x6xi32, #DCSR>
     return
   }
 }


        


More information about the Mlir-commits mailing list