[Mlir-commits] [mlir] e057f25 - [mlir][sparse] auto-insertion of conversion to resolve cycles

Aart Bik llvmlistbot at llvm.org
Wed Jun 29 18:28:26 PDT 2022


Author: Aart Bik
Date: 2022-06-29T18:28:18-07:00
New Revision: e057f25dee59e61b870595156656f90d015b859f

URL: https://github.com/llvm/llvm-project/commit/e057f25dee59e61b870595156656f90d015b859f
DIFF: https://github.com/llvm/llvm-project/commit/e057f25dee59e61b870595156656f90d015b859f.diff

LOG: [mlir][sparse] auto-insertion of conversion to resolve cycles

When the iteration graph is cyclic (even after several attempts using fewer and fewer constraints), the current sparse compiler bails out, and no rewriting happens. This revision adds new logic where the sparse compiler tries to find a single sparse input tensor that breaks the cycle, and then adds a proper sparse conversion operation for it. This way, more incoming kernels can be handled!

Note that the resulting code is not optimal (although it more or less keeps proper "sparse" complexity), and more improvements should be added (especially when the kernel directly yields without computation, as in the transpose example). Still, handling is better than not handling ;-)

Reviewed By: bixia

Differential Revision: https://reviews.llvm.org/D128847

Added: 
    mlir/test/Dialect/SparseTensor/sparse_transpose.mlir

Modified: 
    mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
    mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_transpose.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
index cd777416ba730..614700aba9a16 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -41,7 +41,12 @@ using namespace mlir::sparse_tensor;
 namespace {
 
 // Iteration graph sorting.
-enum SortMask { kSparseOnly = 0x0, kIncludeDense = 0x1, kIncludeUndef = 0x2 };
+enum SortMask { // bit flags, combinable with '|'
+  kSparseOnly = 0,
+  kIncludeDense = 1 << 0,
+  kIncludeUndef = 1 << 1,
+  kIncludeAll = kIncludeDense | kIncludeUndef
+};
 
 // Reduction kinds.
 enum Reduction { kNoReduc, kSum, kProduct, kAnd, kOr, kXor };
@@ -105,6 +110,17 @@ struct CodeGen {
 // Sparse compiler analysis methods.
 //===----------------------------------------------------------------------===//
 
+/// Helper method to construct a permuted dimension ordering that adheres
+/// to the given topological sort. Requires `m` to be a permutation map.
+static AffineMap permute(MLIRContext *context, AffineMap m,
+                         const std::vector<unsigned> &topSort) {
+  assert(m.isPermutation() && m.getNumResults() == topSort.size());
+  SmallVector<unsigned, 4> perm;
+  for (unsigned loop : topSort)
+    perm.push_back(m.getPermutedPosition(loop)); // tensor dim of this loop
+  return AffineMap::getPermutationMap(perm, context);
+}
+
 /// Helper method to apply dimension ordering permutation.
 static unsigned perm(const SparseTensorEncodingAttr &enc, unsigned d) {
   if (enc) {
@@ -231,8 +247,8 @@ static void addAffineOrderings(std::vector<std::vector<bool>> &adjM,
 /// dimensions. Even for dense storage formats, however, the natural index
 /// order yields innermost unit-stride access with better spatial locality.
 static bool computeIterationGraph(Merger &merger, linalg::GenericOp op,
-                                  std::vector<unsigned> &topSort,
-                                  unsigned mask) {
+                                  std::vector<unsigned> &topSort, unsigned mask,
+                                  OpOperand *skip = nullptr) {
   // Set up an n x n from/to adjacency matrix of the iteration graph
   // for the implicit loop indices i_0 .. i_n-1.
   unsigned n = op.getNumLoops();
@@ -240,6 +256,10 @@ static bool computeIterationGraph(Merger &merger, linalg::GenericOp op,
 
   // Iterate over the indexing maps of every tensor in the tensor expression.
   for (OpOperand *t : op.getInputAndOutputOperands()) {
+    // Skip tensor during cycle resolution.
+    if (t == skip)
+      continue;
+    // Get map and encoding.
     auto map = op.getTiedIndexingMap(t);
     auto enc = getSparseTensorEncoding(t->get().getType());
     assert(map.getNumDims() == n);
@@ -328,7 +348,7 @@ static bool isAdmissableTensorExp(Merger &merger, linalg::GenericOp op,
   // A tensor expression with a sparse output tensor that changes its values
   // but not its nonzero structure, an operation called "simply dynamic" in
   // [Bik96,Ch9], is also admissable without special codegen, provided
-  // the tensor's underlying sparse storage scheme can be modified in place.
+  // the tensor's underlying sparse storage scheme can be modified in-place.
   if (merger.isSingleCondition(tensor, exp) && isInPlace(lhs->get()))
     return true;
   // Accept "truly dynamic" if the output tensor materializes uninitialized
@@ -1725,18 +1745,18 @@ struct GenericOpSparsifier : public OpRewritePattern<linalg::GenericOp> {
     if (!findSparseAnnotations(merger, op))
       return failure();
 
-    // Computes a topologically sorted iteration graph to ensure
-    // tensors are visited in natural index order. Fails on cycles.
-    // This assumes that higher-level passes have already put the
-    // tensors in each tensor expression in a feasible order.
+    // Computes a topologically sorted iteration graph to ensure tensors
+    // are visited in natural index order. Gradually relaxes the considered
+    // constraints until an acyclic iteration graph results, such that sparse
+    // code generation can proceed. As a last resort, an attempt is made
+    // to resolve cycles by inserting a conversion.
     std::vector<unsigned> topSort;
-    if (!computeIterationGraph(merger, op, topSort,
-                               SortMask::kIncludeUndef |
-                                   SortMask::kIncludeDense) &&
+    if (!computeIterationGraph(merger, op, topSort, SortMask::kIncludeAll) &&
         !computeIterationGraph(merger, op, topSort, SortMask::kIncludeUndef) &&
         !computeIterationGraph(merger, op, topSort, SortMask::kIncludeDense) &&
-        !computeIterationGraph(merger, op, topSort, SortMask::kSparseOnly))
-      return failure();
+        !computeIterationGraph(merger, op, topSort, SortMask::kSparseOnly)) {
+      return resolveCycle(merger, rewriter, op);
+    }
 
     // Builds the tensor expression for the Linalg operation in SSA form.
     Optional<unsigned> optExp = merger.buildTensorExpFromLinalg(op);
@@ -1761,6 +1781,45 @@ struct GenericOpSparsifier : public OpRewritePattern<linalg::GenericOp> {
   }
 
 private:
+  // Last resort cycle resolution.
+  LogicalResult resolveCycle(Merger &merger, PatternRewriter &rewriter,
+                             linalg::GenericOp op) const {
+    // Compute topological sort while leaving out every sparse input tensor
+    // in succession until an acyclic iteration graph results.
+    std::vector<unsigned> topSort;
+    for (OpOperand *t : op.getInputOperands()) {
+      unsigned tensor = t->getOperandNumber();
+      Value tval = t->get();
+      auto srcEnc = getSparseTensorEncoding(tval.getType());
+      // Only inputs with a permutation map can be reordered (see permute).
+      if (!srcEnc || !op.getTiedIndexingMap(t).isPermutation() ||
+          !computeIterationGraph(merger, op, topSort, SortMask::kSparseOnly, t))
+        continue;
+      // Found an input tensor that resolves the cycle by inserting a
+      // conversion into a sparse tensor that adheres to the iteration
+      // graph order. Also releases the temporary sparse tensor.
+      //
+      // TODO: investigate fusing the conversion with computation,
+      //       especially if it is a direct yield!
+      auto srcTp = tval.getType().cast<RankedTensorType>();
+      auto dstEnc = SparseTensorEncodingAttr::get(
+          getContext(), srcEnc.getDimLevelType(),
+          permute(getContext(), op.getTiedIndexingMap(t), topSort), // new order
+          srcEnc.getPointerBitWidth(), srcEnc.getIndexBitWidth());
+      auto dstTp = RankedTensorType::get(srcTp.getShape(),
+                                         srcTp.getElementType(), dstEnc);
+      auto convert = rewriter.create<ConvertOp>(tval.getLoc(), dstTp, tval);
+      // In-place changes to `op` must go through the rewriter.
+      rewriter.updateRootInPlace(op, [&] { op->setOperand(tensor, convert); });
+      rewriter.setInsertionPointAfter(op);
+      rewriter.create<ReleaseOp>(tval.getLoc(), convert);
+      return success();
+    }
+    // Cannot be resolved with a single conversion.
+    // TODO: convert more than one?
+    return failure();
+  }
+
   /// Options to control sparse code generation.
   SparsificationOptions options;
 };

diff  --git a/mlir/test/Dialect/SparseTensor/sparse_transpose.mlir b/mlir/test/Dialect/SparseTensor/sparse_transpose.mlir
new file mode 100644
index 0000000000000..88505179da6bb
--- /dev/null
+++ b/mlir/test/Dialect/SparseTensor/sparse_transpose.mlir
@@ -0,0 +1,62 @@
+// RUN: mlir-opt %s -sparsification | FileCheck %s
+
+#DCSR = #sparse_tensor.encoding<{
+  dimLevelType = [ "compressed", "compressed" ]  // no dimOrdering: row-wise
+}>
+
+#transpose_trait = {
+  indexing_maps = [
+    affine_map<(i,j) -> (j,i)>,  // A (input)
+    affine_map<(i,j) -> (i,j)>   // X (output)
+  ],
+  iterator_types = ["parallel", "parallel"],
+  doc = "X(i,j) = A(j,i)"
+}
+
+// A and X both DCSR: cyclic iteration graph, resolved by converting A.
+// TODO: improve auto-conversion followed by yield
+// CHECK-LABEL:   func.func @sparse_transpose_auto(
+// CHECK-SAME:                                     %[[VAL_0:.*]]: tensor<3x4xf64, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>>) -> tensor<4x3xf64, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> {
+// CHECK-DAG:       %[[VAL_1:.*]] = arith.constant 0 : index
+// CHECK-DAG:       %[[VAL_2:.*]] = arith.constant 1 : index
+// CHECK-DAG:       %[[VAL_3:.*]] = arith.constant 2 : index
+// CHECK:           %[[VAL_4:.*]] = bufferization.alloc_tensor() : tensor<4x3xf64, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>>
+// CHECK:           %[[VAL_5:.*]] = sparse_tensor.convert %[[VAL_0]] : tensor<3x4xf64, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to tensor<3x4xf64, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ], dimOrdering = affine_map<(d0, d1) -> (d1, d0)>, pointerBitWidth = 0, indexBitWidth = 0 }>>
+// CHECK:           %[[VAL_6:.*]] = sparse_tensor.pointers %[[VAL_5]], %[[VAL_1]] : tensor<3x4xf64, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ], dimOrdering = affine_map<(d0, d1) -> (d1, d0)>, pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
+// CHECK:           %[[VAL_7:.*]] = sparse_tensor.indices %[[VAL_5]], %[[VAL_1]] : tensor<3x4xf64, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ], dimOrdering = affine_map<(d0, d1) -> (d1, d0)>, pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
+// CHECK:           %[[VAL_8:.*]] = sparse_tensor.pointers %[[VAL_5]], %[[VAL_2]] : tensor<3x4xf64, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ], dimOrdering = affine_map<(d0, d1) -> (d1, d0)>, pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
+// CHECK:           %[[VAL_9:.*]] = sparse_tensor.indices %[[VAL_5]], %[[VAL_2]] : tensor<3x4xf64, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ], dimOrdering = affine_map<(d0, d1) -> (d1, d0)>, pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
+// CHECK:           %[[VAL_10:.*]] = sparse_tensor.values %[[VAL_5]] : tensor<3x4xf64, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ], dimOrdering = affine_map<(d0, d1) -> (d1, d0)>, pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xf64>
+// CHECK:           %[[VAL_11:.*]] = memref.alloca(%[[VAL_3]]) : memref<?xindex>
+// CHECK:           %[[VAL_12:.*]] = memref.alloca() : memref<f64>
+// CHECK:           %[[VAL_13:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_1]]] : memref<?xindex>
+// CHECK:           %[[VAL_14:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_2]]] : memref<?xindex>
+// CHECK:           scf.for %[[VAL_15:.*]] = %[[VAL_13]] to %[[VAL_14]] step %[[VAL_2]] {
+// CHECK:             %[[VAL_16:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_15]]] : memref<?xindex>
+// CHECK:             memref.store %[[VAL_16]], %[[VAL_11]]{{\[}}%[[VAL_1]]] : memref<?xindex>
+// CHECK:             %[[VAL_17:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_15]]] : memref<?xindex>
+// CHECK:             %[[VAL_18:.*]] = arith.addi %[[VAL_15]], %[[VAL_2]] : index
+// CHECK:             %[[VAL_19:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_18]]] : memref<?xindex>
+// CHECK:             scf.for %[[VAL_20:.*]] = %[[VAL_17]] to %[[VAL_19]] step %[[VAL_2]] {
+// CHECK:               %[[VAL_21:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_20]]] : memref<?xindex>
+// CHECK:               memref.store %[[VAL_21]], %[[VAL_11]]{{\[}}%[[VAL_2]]] : memref<?xindex>
+// CHECK:               %[[VAL_22:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_20]]] : memref<?xf64>
+// CHECK:               memref.store %[[VAL_22]], %[[VAL_12]][] : memref<f64>
+// CHECK:               sparse_tensor.lex_insert %[[VAL_4]], %[[VAL_11]], %[[VAL_12]] : tensor<4x3xf64, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>>, memref<?xindex>, memref<f64>
+// CHECK:             }
+// CHECK:           }
+// CHECK:           %[[VAL_23:.*]] = sparse_tensor.load %[[VAL_4]] hasInserts : tensor<4x3xf64, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>>
+// CHECK:           sparse_tensor.release %[[VAL_5]] : tensor<3x4xf64, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ], dimOrdering = affine_map<(d0, d1) -> (d1, d0)>, pointerBitWidth = 0, indexBitWidth = 0 }>>
+// CHECK:           return %[[VAL_23]] : tensor<4x3xf64, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>>
+// CHECK:         }
+func.func @sparse_transpose_auto(%arga: tensor<3x4xf64, #DCSR>)
+                                     -> tensor<4x3xf64, #DCSR> {
+  %i = bufferization.alloc_tensor() : tensor<4x3xf64, #DCSR>  // output X
+  %0 = linalg.generic #transpose_trait
+     ins(%arga: tensor<3x4xf64, #DCSR>)  // no manual conversion of A
+     outs(%i: tensor<4x3xf64, #DCSR>) {
+     ^bb(%a: f64, %x: f64):
+       linalg.yield %a : f64  // direct yield of A's value
+  } -> tensor<4x3xf64, #DCSR>
+  return %0 : tensor<4x3xf64, #DCSR>
+}

diff  --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_transpose.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_transpose.mlir
index a73229ffe5666..496da608483fc 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_transpose.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_transpose.mlir
@@ -25,25 +25,42 @@ module {
 
   //
   // Transposing a sparse row-wise matrix into another sparse row-wise
-  // matrix would fail direct codegen, since it introduces a cycle in
-  // the iteration graph. This can be avoided by converting the incoming
+  // matrix introduces a cycle in the iteration graph. This complication
+  // can be avoided by manually inserting a conversion of the incoming
   // matrix into a sparse column-wise matrix first.
   //
-  func.func @sparse_transpose(%arga: tensor<3x4xf64, #DCSR>) -> tensor<4x3xf64, #DCSR> {
-    %t = sparse_tensor.convert %arga : tensor<3x4xf64, #DCSR> to tensor<3x4xf64, #DCSC>
+  func.func @sparse_transpose(%arga: tensor<3x4xf64, #DCSR>)
+                                  -> tensor<4x3xf64, #DCSR> {
+    %t = sparse_tensor.convert %arga  // manual DCSR -> DCSC conversion
+      : tensor<3x4xf64, #DCSR> to tensor<3x4xf64, #DCSC>
 
     %i = bufferization.alloc_tensor() : tensor<4x3xf64, #DCSR>
-
     %0 = linalg.generic #transpose_trait
        ins(%t: tensor<3x4xf64, #DCSC>)
        outs(%i: tensor<4x3xf64, #DCSR>) {
        ^bb(%a: f64, %x: f64):
          linalg.yield %a : f64
-     } -> tensor<4x3xf64, #DCSR>
+    } -> tensor<4x3xf64, #DCSR>
+
+    sparse_tensor.release %t : tensor<3x4xf64, #DCSC>  // free the DCSC copy
 
-     sparse_tensor.release %t : tensor<3x4xf64, #DCSC>
+    return %0 : tensor<4x3xf64, #DCSR>
+  }
 
-     return %0 : tensor<4x3xf64, #DCSR>
+  //
+  // Even better, the sparse compiler can insert such a conversion automatically
+  // to resolve a cycle in the iteration graph, as this kernel shows!
+  //
+  func.func @sparse_transpose_auto(%arga: tensor<3x4xf64, #DCSR>)
+                                       -> tensor<4x3xf64, #DCSR> {
+    %i = bufferization.alloc_tensor() : tensor<4x3xf64, #DCSR>
+    %0 = linalg.generic #transpose_trait
+       ins(%arga: tensor<3x4xf64, #DCSR>)  // DCSR input, no manual convert
+       outs(%i: tensor<4x3xf64, #DCSR>) {
+       ^bb(%a: f64, %x: f64):
+         linalg.yield %a : f64
+    } -> tensor<4x3xf64, #DCSR>
+    return %0 : tensor<4x3xf64, #DCSR>
   }
 
   //
@@ -63,8 +80,11 @@ module {
     ]> : tensor<3x4xf64>
     %a = sparse_tensor.convert %d : tensor<3x4xf64> to tensor<3x4xf64, #DCSR>
 
-    // Call the kernel.
-    %0 = call @sparse_transpose(%a) : (tensor<3x4xf64, #DCSR>) -> tensor<4x3xf64, #DCSR>
+    // Call the kernels.
+    %0 = call @sparse_transpose(%a)
+      : (tensor<3x4xf64, #DCSR>) -> tensor<4x3xf64, #DCSR>
+    %1 = call @sparse_transpose_auto(%a)
+      : (tensor<3x4xf64, #DCSR>) -> tensor<4x3xf64, #DCSR>
 
     //
     // Verify result.
@@ -74,17 +94,30 @@ module {
     // CHECK-NEXT: ( 0, 0, 3.3 )
     // CHECK-NEXT: ( 1.4, 0, 3.4 )
     //
+    // CHECK-NEXT: ( 1.1, 0, 3.1 )
+    // CHECK-NEXT: ( 1.2, 0, 0 )
+    // CHECK-NEXT: ( 0, 0, 3.3 )
+    // CHECK-NEXT: ( 1.4, 0, 3.4 )
+    //
     %x = sparse_tensor.convert %0 : tensor<4x3xf64, #DCSR> to tensor<4x3xf64>
     %m = bufferization.to_memref %x : memref<4x3xf64>
     scf.for %i = %c0 to %c4 step %c1 {
-      %v = vector.transfer_read %m[%i, %c0], %du: memref<4x3xf64>, vector<3xf64>
-      vector.print %v : vector<3xf64>
+      %v1 = vector.transfer_read %m[%i, %c0], %du: memref<4x3xf64>, vector<3xf64>
+      vector.print %v1 : vector<3xf64>
+    }
+    %y = sparse_tensor.convert %1 : tensor<4x3xf64, #DCSR> to tensor<4x3xf64>
+    %n = bufferization.to_memref %y : memref<4x3xf64>
+    scf.for %i = %c0 to %c4 step %c1 {
+      %v2 = vector.transfer_read %n[%i, %c0], %du: memref<4x3xf64>, vector<3xf64>
+      vector.print %v2 : vector<3xf64>
     }
 
     // Release resources.
     sparse_tensor.release %a : tensor<3x4xf64, #DCSR>
     sparse_tensor.release %0 : tensor<4x3xf64, #DCSR>
+    sparse_tensor.release %1 : tensor<4x3xf64, #DCSR>
     memref.dealloc %m : memref<4x3xf64>
+    memref.dealloc %n : memref<4x3xf64>
 
     return
   }


        


More information about the Mlir-commits mailing list