[Mlir-commits] [mlir] 83b7f01 - [mlir][sparse] fix crashes when the tensor that defines the loop bound cannot be found

Peiming Liu llvmlistbot at llvm.org
Wed Jun 14 13:27:57 PDT 2023


Author: Peiming Liu
Date: 2023-06-14T20:27:50Z
New Revision: 83b7f018fd704714fbb4c15081bce3dde48eb707

URL: https://github.com/llvm/llvm-project/commit/83b7f018fd704714fbb4c15081bce3dde48eb707
DIFF: https://github.com/llvm/llvm-project/commit/83b7f018fd704714fbb4c15081bce3dde48eb707.diff

LOG: [mlir][sparse] fix crashes when the tensor that defines the loop bound cannot be found
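
Previously, translateBitsToTidLvlPairs asserted that at least one tensor
conditions every generated loop (assert(numloopCond > 0)), so a loop whose
bound was defined only by an otherwise-unused operand crashed the sparsifier.
The patch drops the assertion and falls back to iterating over the synthetic
tensor in that case, threading a new SynTensorBoundSetter callback through
initializeLoopEmit so that the synthetic tensor's per-level upper bounds are
seeded from the linalg loop ranges up front.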

Reviewed By: aartbik, K-Wu

Differential Revision: https://reviews.llvm.org/D152877

Added: 
    mlir/test/Dialect/SparseTensor/unused-tensor.mlir

Modified: 
    mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
    mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h
    mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp

Removed: 
    


################################################################################
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
index d0884ca482de2..0a70cb0a970fc 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
@@ -376,8 +376,15 @@ void LoopEmitter::initialize(ValueRange ts, StringAttr loopTag, bool hasOutput,
     loopIdToOrd[topSort[n]] = n;
 }
 
-void LoopEmitter::initializeLoopEmit(OpBuilder &builder, Location loc,
-                                     LoopEmitter::OutputUpdater updater) {
+void LoopEmitter::initializeLoopEmit(
+    OpBuilder &builder, Location loc, LoopEmitter::OutputUpdater updater,
+    LoopEmitter::SynTensorBoundSetter synSetter) {
+
+  // For each level of the synthetic tensor, set the high bound via the callback.
+  if (synSetter)
+    for (unsigned i = 0, e = highs[getSynTensorId()].size(); i < e; i++)
+      highs[getSynTensorId()][i] = synSetter(builder, loc, i);
+
   // For every manifest tensor:
   // * get the values buffer.
   // * For every level:
@@ -534,27 +541,15 @@ void LoopEmitter::enterNewLoopSeq(OpBuilder &builder, Location loc,
   // Prepares for all the tensors used in the current loop sequence.
   std::vector<std::tuple<TensorId, Level, bool>> slicedTids;
 
-  bool hasSynTensor = false;
-  std::optional<std::pair<TensorId, Level>> loopBoundDefLevel = std::nullopt;
   for (auto [tid, lvl] : unpackTensorLevelRange(tidLvls)) {
     if (!dependentLvlMap[tid][lvl].empty()) {
       bool fullyRed = genSliceBegin(builder, loc, tid, lvl);
       slicedTids.emplace_back(tid, lvl, fullyRed);
-    } else {
-      if (isSynTensor(tid)) {
-        hasSynTensor = true;
-      } else {
-        loopBoundDefLevel = std::make_pair(tid, lvl);
-        prepareLoopOverTensorAtLvl(builder, loc, tid, lvl);
-      }
+    } else if (!isSynTensor(tid)) {
+      prepareLoopOverTensorAtLvl(builder, loc, tid, lvl);
     }
   }
 
-  if (hasSynTensor && loopBoundDefLevel.has_value()) {
-    // TODO: compute the loopBound for index reduction by d - sum(unres_lvls).
-    highs[getSynTensorId()][getCurrentDepth()] =
-        lvlSizes[loopBoundDefLevel->first][loopBoundDefLevel->second];
-  }
   // Universal Index starts from 0.
   loopSeqStack.emplace_back(C_IDX(0), std::move(slicedTids));
 }
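
Note that enterNewLoopSeq no longer infers the synthetic tensor's bound lazily
from whichever co-iterated tensor happens to define the level (the removed
loopBoundDefLevel bookkeeping above); the bounds are now preset once in
initializeLoopEmit via the callback, which also covers loops where no defining
tensor exists at all.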

diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h
index f178366a738a4..4f5100cc36104 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h
@@ -78,6 +78,12 @@ class LoopEmitter {
   /// initializing the loop emitter (e.g., to fill a dense output with zeros).
   using OutputUpdater = function_ref<Value(OpBuilder &builder, Location loc,
                                            Value memref, Value tensor)>;
+
+  /// Optional callback function to set the bound for the synthetic tensor,
+  /// which is essentially the dense loop bound.
+  using SynTensorBoundSetter =
+      function_ref<Value(OpBuilder &builder, Location loc, Level lvl)>;
+
   // Map from [tid, dim] to a list of dependent [tid, dim] for affine expression
   // index on sparse tensors.
   // E.g., for affine index (d0 + d1), it depends on two [tid, dim] that defines
@@ -114,7 +120,8 @@ class LoopEmitter {
   /// Starts a loop emitting session by generating all the buffers needed
   /// for iterating over the tensors.
   void initializeLoopEmit(OpBuilder &builder, Location loc,
-                          OutputUpdater updater = nullptr);
+                          OutputUpdater updater = nullptr,
+                          SynTensorBoundSetter synSetter = nullptr);
 
   /// Generates code to compute an affine expression whose variables are
   /// `LoopId`s (i.e., `a.cast<AffineDimExpr>().getPosition()` is a valid
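
To illustrate the new parameter, here is a minimal sketch of a call that seeds
the synthetic tensor's bounds; the loopEmitter variable and the constant trip
count are hypothetical, and the actual caller in Sparsification.cpp (next
hunk) instead derives each bound from the sorted linalg loop ranges:

  // Sketch only: supply a SynTensorBoundSetter that returns the dense upper
  // bound (an index Value) for each level of the synthetic tensor.
  // (Uses mlir::arith::ConstantIndexOp from mlir/Dialect/Arith/IR/Arith.h.)
  loopEmitter.initializeLoopEmit(
      builder, loc, /*updater=*/nullptr,
      /*synSetter=*/[&](OpBuilder &b, Location l, Level lvl) -> Value {
        // Hypothetical fixed trip count of 8; real code computes it per level.
        return b.create<arith::ConstantIndexOp>(l, 8);
      });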

diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
index 637c16f92a293..7e69a737b0661 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -832,6 +832,21 @@ static void genBuffers(CodegenEnv &env, OpBuilder &builder) {
   Location loc = op.getLoc();
   assert(op.getNumOperands() == op.getNumDpsInputs() + 1);
 
+  SmallVector<Range, 4> loopRange =
+      llvm::cast<linalg::LinalgOp>(op.getOperation())
+          .createLoopRanges(builder, loc);
+
+  assert(loopRange.size() == env.merger().getStartingFilterLoopId());
+  SmallVector<Range, 4> sortedRange;
+  for (unsigned i = 0, e = env.topSortSize(); i < e; i++) {
+    LoopId ldx = env.topSortAt(i);
+    // FIXME: Get rid of filter loops, since we now have a better algorithm
+    // to deal with affine index expressions.
+    if (ldx < env.merger().getStartingFilterLoopId()) {
+      sortedRange.push_back(loopRange[ldx]);
+    }
+  }
+
   env.emitter().initializeLoopEmit(
       builder, loc,
       /// Generates buffer for the output tensor.
@@ -865,6 +880,16 @@ static void genBuffers(CodegenEnv &env, OpBuilder &builder) {
                                          ValueRange{init});
         }
         return init;
+      },
+      [&sortedRange, &env](OpBuilder &b, Location loc, Level l) {
+        assert(l < env.topSortSize());
+        // FIXME: Remove filter loops, since we now have a better algorithm
+        // to deal with affine index expressions.
+        if (l >= env.merger().getStartingFilterLoopId())
+          return Value();
+
+        return mlir::getValueOrCreateConstantIndexOp(b, loc,
+                                                     sortedRange[l].size);
       });
 }
 
@@ -1594,7 +1619,9 @@ static bool translateBitsToTidLvlPairs(
             // iterate based on the level of output tensor.  E.g., this
             // could be a synthetic tensor (for invariants and sparse
             // output tensor).
-            if (env.isReduc() && env.merger().getSynTensorID() == tid) {
+            auto itType = env.op().getIteratorTypesArray()[ldx];
+            if (linalg::isReductionIterator(itType) &&
+                env.merger().getSynTensorID() == tid) {
               // Coiterating with an invariant, and this is a reduction loop
               // e.g., out = prod(in[i][j] op invariant);
              // In this case, we cannot infer the loop bound from output
@@ -1669,7 +1696,14 @@ static bool translateBitsToTidLvlPairs(
     tidLvls.push_back(env.makeTensorLevel(outTid, *outLvl));
   }
 
-  assert(numloopCond > 0);
+  if (numloopCond == 0) {
+    // Corner case where the loop bound is defined by an *unused* operand; in
+    // this case, we just generate a dense "fake" loop by iterating over the
+    // synthetic tensor.
+    tidLvls.push_back(env.makeTensorLevel(env.merger().getSynTensorID(),
+                                          env.emitter().getCurrentDepth()));
+    numloopCond++;
+  }
   // If we only need one loop condition and the condition is not imposed on a
   // non-unique level, the loop can be generated by a for loop.
   return numloopCond == 1 && !hasNonUnique;
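
For the kernel in the new regression test below, the k-loop is conditioned
only by %argB, which never appears in the region body: the old assertion made
the sparsifier crash on it, whereas the k-loop is now emitted as a plain dense
scf.for over the synthetic tensor, with its trip count (the constant 8 that
linalg derives from dimension 0 of %argB) supplied through the new
bound-setter callback.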

diff --git a/mlir/test/Dialect/SparseTensor/unused-tensor.mlir b/mlir/test/Dialect/SparseTensor/unused-tensor.mlir
new file mode 100644
index 0000000000000..05da6c455135c
--- /dev/null
+++ b/mlir/test/Dialect/SparseTensor/unused-tensor.mlir
@@ -0,0 +1,57 @@
+// RUN: mlir-opt %s -sparsification | FileCheck %s
+
+//
+// A contrived example where the sparse tensor B is only
+// used in the linalg op to determine the number of iterations
+// for the k-loop. This is included to make sure the sparse
+// compiler still generates the correct loop nest for this case.
+//
+
+#SM = #sparse_tensor.encoding<{ lvlTypes = [ "dense", "compressed" ] }>
+
+#trait = {
+  indexing_maps = [
+    affine_map<(i,j,k) -> (i,j)>,  // A
+    affine_map<(i,j,k) -> (k,j)>,  // B
+    affine_map<(i,j,k) -> (i,j)>   // S_out
+  ],
+  iterator_types = ["parallel", "parallel", "reduction"],
+  doc = "S_out(i,j) = SUM_k A(i,j)"
+}
+
+// CHECK-LABEL:   func.func @b_unused(
+// CHECK-SAME:      %[[VAL_0:.*]]: tensor<2x4xf64>,
+// CHECK-SAME:      %[[VAL_1:.*]]: tensor<8x4xf64, #sparse_tensor.encoding<{{.*}}>>,
+// CHECK-SAME:      %[[VAL_2:.*]]: tensor<2x4xf64>) -> tensor<2x4xf64> {
+// CHECK-DAG:       %[[VAL_3:.*]] = arith.constant 8 : index
+// CHECK-DAG:       %[[VAL_4:.*]] = arith.constant 2 : index
+// CHECK-DAG:       %[[VAL_5:.*]] = arith.constant 4 : index
+// CHECK-DAG:       %[[VAL_6:.*]] = arith.constant 0 : index
+// CHECK-DAG:       %[[VAL_7:.*]] = arith.constant 1 : index
+// CHECK-DAG:       %[[VAL_8:.*]] = bufferization.to_memref %[[VAL_0]] : memref<2x4xf64>
+// CHECK-DAG:       %[[VAL_9:.*]] = bufferization.to_memref %[[VAL_2]] : memref<2x4xf64>
+// CHECK:           scf.for %[[VAL_10:.*]] = %[[VAL_6]] to %[[VAL_4]] step %[[VAL_7]] {
+// CHECK:             scf.for %[[VAL_11:.*]] = %[[VAL_6]] to %[[VAL_3]] step %[[VAL_7]] {
+// CHECK:               scf.for %[[VAL_12:.*]] = %[[VAL_6]] to %[[VAL_5]] step %[[VAL_7]] {
+// CHECK:                 %[[VAL_13:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_10]], %[[VAL_12]]] : memref<2x4xf64>
+// CHECK:                 %[[VAL_14:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_10]], %[[VAL_12]]] : memref<2x4xf64>
+// CHECK:                 %[[VAL_15:.*]] = arith.addf %[[VAL_13]], %[[VAL_14]] : f64
+// CHECK:                 memref.store %[[VAL_15]], %[[VAL_9]]{{\[}}%[[VAL_10]], %[[VAL_12]]] : memref<2x4xf64>
+// CHECK:               }
+// CHECK:             }
+// CHECK:           }
+// CHECK:           %[[VAL_16:.*]] = bufferization.to_tensor %[[VAL_9]] : memref<2x4xf64>
+// CHECK:           return %[[VAL_16]] : tensor<2x4xf64>
+// CHECK:         }
+func.func @b_unused(%argA: tensor<2x4xf64>,
+                    %argB: tensor<8x4xf64, #SM>,
+                    %argC: tensor<2x4xf64>) -> tensor<2x4xf64> {
+  %result = linalg.generic #trait
+    ins(%argA, %argB: tensor<2x4xf64>, tensor<8x4xf64, #SM>)
+    outs(%argC: tensor<2x4xf64>) {
+      ^bb(%a: f64, %b: f64, %c: f64):
+         %0 = arith.addf %c, %a : f64
+         linalg.yield %0 : f64
+  } -> tensor<2x4xf64>
+  return %result : tensor<2x4xf64>
+}
