[Mlir-commits] [mlir] 755285f - [mlir][sparse] Factoring out LoopEmitter::isValidLevel

wren romano llvmlistbot at llvm.org
Fri Mar 24 15:52:07 PDT 2023


Author: wren romano
Date: 2023-03-24T15:51:59-07:00
New Revision: 755285f1e99e534034002ccf113669a68c7b5369

URL: https://github.com/llvm/llvm-project/commit/755285f1e99e534034002ccf113669a68c7b5369
DIFF: https://github.com/llvm/llvm-project/commit/755285f1e99e534034002ccf113669a68c7b5369.diff

LOG: [mlir][sparse] Factoring out LoopEmitter::isValidLevel

Depends On D146674

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D146676

Added: 
    

Modified: 
    mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
    mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h

Removed: 
    


################################################################################
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
index cae92c34e258..7a5605346f50 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
@@ -456,7 +456,7 @@ Operation *LoopEmitter::enterLoopOverTensorAtLvl(
   for (auto [t, l] : llvm::zip(tids, lvls)) {
     // TODO: this check for validity of the (t,l) pairs should be
     // checked/enforced at the callsites, if possible.
-    assert(t < lvlTypes.size() && l < lvlTypes[t].size());
+    assert(isValidLevel(t, l));
     assert(!coords[t][l]); // We cannot re-enter the same level
     const auto lvlTp = lvlTypes[t][l];
     const bool isSparse = isCompressedDLT(lvlTp) || isSingletonDLT(lvlTp);
@@ -572,7 +572,7 @@ Operation *LoopEmitter::enterLoopOverTensorAtLvl(
 Operation *LoopEmitter::enterFilterLoopOverTensorAtLvl(
     OpBuilder &builder, Location loc, TensorId tid, Level lvl,
     AffineExpr affine, MutableArrayRef<Value> reduc) {
-  assert(tid < lvlTypes.size() && lvl < lvlTypes[tid].size());
+  assert(isValidLevel(tid, lvl));
   assert(!affine.isa<AffineDimExpr>() && !isDenseDLT(lvlTypes[tid][lvl]));
   // We can not re-enter the same level.
   assert(!coords[tid][lvl]);
@@ -862,7 +862,7 @@ Operation *LoopEmitter::enterCoIterationOverTensorsAtLvls(
 
 void LoopEmitter::prepareLoopOverTensorAtLvl(OpBuilder &builder, Location loc,
                                              TensorId tid, Level dstLvl) {
-  assert(tid < lvlTypes.size() && dstLvl < lvlTypes[tid].size());
+  assert(isValidLevel(tid, dstLvl));
   const auto lvlTp = lvlTypes[tid][dstLvl];
 
   if (isDenseDLT(lvlTp))

diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h
index 8cfe00100eba..b5772d6f7a10 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h
@@ -265,6 +265,10 @@ class LoopEmitter {
     return isOutputTensor(tid) && isSparseOut;
   }
 
+  bool isValidLevel(TensorId tid, Level lvl) const {
+    return tid < lvlTypes.size() && lvl < lvlTypes[tid].size();
+  }
+
   /// Prepares loop for iterating over `tensor[lvl]`, under the assumption
   /// that `tensor[0...lvl-1]` loops have already been set up.
   void prepareLoopOverTensorAtLvl(OpBuilder &builder, Location loc,


        


More information about the Mlir-commits mailing list