[Mlir-commits] [mlir] 10db57b - [mlir][sparse] replace magic constant with symbol

Aart Bik llvmlistbot at llvm.org
Tue Nov 1 11:13:30 PDT 2022


Author: Aart Bik
Date: 2022-11-01T11:13:20-07:00
New Revision: 10db57b7ea4ef756b3ab2269263bffc2dfff9310

URL: https://github.com/llvm/llvm-project/commit/10db57b7ea4ef756b3ab2269263bffc2dfff9310
DIFF: https://github.com/llvm/llvm-project/commit/10db57b7ea4ef756b3ab2269263bffc2dfff9310.diff

LOG: [mlir][sparse] replace magic constant with symbol

Reviewed By: Peiming

Differential Revision: https://reviews.llvm.org/D137177

Added: 
    

Modified: 
    mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
index 85f4c4e073ad6..944139f38626a 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -31,9 +31,7 @@ using namespace mlir::sparse_tensor;
 
 namespace {
 
-// TODO: start using these when insertions are implemented
-// static constexpr uint64_t DimSizesIdx = 0;
-// static constexpr uint64_t DimCursorIdx = 1;
+static constexpr uint64_t DimSizesIdx = 0;
 static constexpr uint64_t MemSizesIdx = 2;
 static constexpr uint64_t FieldsIdx = 3;
 
@@ -88,11 +86,12 @@ static Optional<Value> sizeFromTensorAtDim(OpBuilder &rewriter, Location loc,
   if (!ShapedType::isDynamic(shape[dim]))
     return constantIndex(rewriter, loc, shape[dim]);
 
-  // Any other query can consult the dimSizes array at field 0 using,
+  // Any other query can consult the dimSizes array at field DimSizesIdx,
   // accounting for the reordering applied to the sparse storage.
   auto tuple = getTuple(adaptedValue);
   Value idx = constantIndex(rewriter, loc, toStoredDim(tensorTp, dim));
-  return rewriter.create<memref::LoadOp>(loc, tuple.getInputs().front(), idx)
+  return rewriter
+      .create<memref::LoadOp>(loc, tuple.getInputs()[DimSizesIdx], idx)
       .getResult();
 }
 


        


More information about the Mlir-commits mailing list