[Mlir-commits] [mlir] 9916ab0 - [mlir][sparse] (re)introducing getRankedTensorType/getMemrefType

wren romano llvmlistbot at llvm.org
Wed Jan 25 11:30:02 PST 2023


Author: wren romano
Date: 2023-01-25T11:29:54-08:00
New Revision: 9916ab03f19dc50c688b8567ac0d30b4a6615f9d

URL: https://github.com/llvm/llvm-project/commit/9916ab03f19dc50c688b8567ac0d30b4a6615f9d
DIFF: https://github.com/llvm/llvm-project/commit/9916ab03f19dc50c688b8567ac0d30b4a6615f9d.diff

LOG: [mlir][sparse] (re)introducing getRankedTensorType/getMemrefType

The bulk of D142074 seems to have gotten overwritten due to some sort of merge conflict (afaict there's no record of it having been reverted intentionally).  So this commit redoes those changes.  In addition to the original changes, this commit also:
* moves the definition of `getRankedTensorType` (from `Transforms/CodegenUtils.h` to `IR/SparseTensor.h`), so that it can be used by `IR/SparseTensorDialect.cpp`.
* adds `getMemRefType` as another abbreviation.

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D142503

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
    mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
    mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
    mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
    mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp
    mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
    mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
index c7c0826499091..777a5b40d6119 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
@@ -38,6 +38,18 @@
 namespace mlir {
 namespace sparse_tensor {
 
+/// Convenience method to abbreviate casting `getType()`.
+template <typename T>
+inline RankedTensorType getRankedTensorType(T t) {
+  return t.getType().template cast<RankedTensorType>();
+}
+
+/// Convenience method to abbreviate casting `getType()`.
+template <typename T>
+inline MemRefType getMemRefType(T t) {
+  return t.getType().template cast<MemRefType>();
+}
+
 /// Convenience method to get a sparse encoding attribute from a type.
 /// Returns null-attribute for any type without an encoding.
 SparseTensorEncodingAttr getSparseTensorEncoding(Type type);

diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index f2495da395023..364c7e7d45962 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -28,6 +28,15 @@
 using namespace mlir;
 using namespace mlir::sparse_tensor;
 
+//===----------------------------------------------------------------------===//
+// Additional convenience methods.
+//===----------------------------------------------------------------------===//
+
+template <typename T>
+static inline int64_t getTypeRank(T t) {
+  return getRankedTensorType(t).getRank();
+}
+
 //===----------------------------------------------------------------------===//
 // TensorDialect Attribute Methods.
 //===----------------------------------------------------------------------===//
@@ -525,12 +534,11 @@ Type StorageSpecifierType::getFieldType(StorageSpecifierKind kind,
 //===----------------------------------------------------------------------===//
 
 static LogicalResult isInBounds(uint64_t dim, Value tensor) {
-  return success(dim <
-                 (uint64_t)tensor.getType().cast<RankedTensorType>().getRank());
+  return success(dim < static_cast<uint64_t>(getTypeRank(tensor)));
 }
 
 static LogicalResult isMatchingWidth(Value result, unsigned width) {
-  const Type etp = result.getType().cast<MemRefType>().getElementType();
+  const Type etp = getMemRefType(result).getElementType();
   return success(width == 0 ? etp.isIndex() : etp.isInteger(width));
 }
 
@@ -562,8 +570,7 @@ static LogicalResult verifySparsifierGetterSetter(
 }
 
 LogicalResult NewOp::verify() {
-  if (getExpandSymmetry() &&
-      getResult().getType().cast<RankedTensorType>().getRank() != 2)
+  if (getExpandSymmetry() && getTypeRank(getResult()) != 2)
     return emitOpError("expand_symmetry can only be used for 2D tensors");
   return success();
 }
@@ -624,8 +631,8 @@ LogicalResult ToIndicesBufferOp::verify() {
 }
 
 LogicalResult ToValuesOp::verify() {
-  RankedTensorType ttp = getTensor().getType().cast<RankedTensorType>();
-  MemRefType mtp = getResult().getType().cast<MemRefType>();
+  auto ttp = getRankedTensorType(getTensor());
+  auto mtp = getMemRefType(getResult());
   if (ttp.getElementType() != mtp.getElementType())
     return emitError("unexpected mismatch in element types");
   return success();
@@ -754,7 +761,7 @@ LogicalResult UnaryOp::verify() {
 }
 
 LogicalResult ConcatenateOp::verify() {
-  auto dstTp = getType().cast<RankedTensorType>();
+  auto dstTp = getRankedTensorType(*this);
   uint64_t concatDim = getDimension().getZExtValue();
   unsigned rank = dstTp.getRank();
 
@@ -775,8 +782,7 @@ LogicalResult ConcatenateOp::verify() {
         concatDim));
 
   for (size_t i = 0, e = getInputs().size(); i < e; i++) {
-    Value input = getInputs()[i];
-    auto inputRank = input.getType().cast<RankedTensorType>().getRank();
+    const auto inputRank = getTypeRank(getInputs()[i]);
     if (inputRank != rank)
       return emitError(
          llvm::formatv("The input tensor ${0} has a different rank (rank={1}) "
@@ -785,15 +791,13 @@ LogicalResult ConcatenateOp::verify() {
   }
 
   for (unsigned i = 0; i < rank; i++) {
-    auto dstDim = dstTp.getShape()[i];
+    const auto dstDim = dstTp.getShape()[i];
     if (i == concatDim) {
       if (!ShapedType::isDynamic(dstDim)) {
+        // If we reach here, all inputs should have static shapes.
         unsigned sumDim = 0;
-        for (auto src : getInputs()) {
-          // If we reach here, all inputs should have static shapes.
-          auto d = src.getType().cast<RankedTensorType>().getShape()[i];
-          sumDim += d;
-        }
+        for (auto src : getInputs())
+          sumDim += getRankedTensorType(src).getShape()[i];
         // If all dimension are statically known, the sum of all the input
         // dimensions should be equal to the output dimension.
         if (sumDim != dstDim)
@@ -804,7 +808,7 @@ LogicalResult ConcatenateOp::verify() {
     } else {
       int64_t prev = dstDim;
       for (auto src : getInputs()) {
-        auto d = src.getType().cast<RankedTensorType>().getShape()[i];
+        const auto d = getRankedTensorType(src).getShape()[i];
         if (!ShapedType::isDynamic(prev) && d != prev)
           return emitError("All dimensions (expect for the concatenating one) "
                            "should be equal.");
@@ -817,8 +821,7 @@ LogicalResult ConcatenateOp::verify() {
 }
 
 LogicalResult InsertOp::verify() {
-  RankedTensorType ttp = getTensor().getType().cast<RankedTensorType>();
-  if (ttp.getRank() != static_cast<int64_t>(getIndices().size()))
+  if (getTypeRank(getTensor()) != static_cast<int64_t>(getIndices().size()))
     return emitOpError("incorrect number of indices");
   return success();
 }
@@ -838,8 +841,7 @@ LogicalResult PushBackOp::verify() {
 }
 
 LogicalResult CompressOp::verify() {
-  RankedTensorType ttp = getTensor().getType().cast<RankedTensorType>();
-  if (ttp.getRank() != 1 + static_cast<int64_t>(getIndices().size()))
+  if (getTypeRank(getTensor()) != 1 + static_cast<int64_t>(getIndices().size()))
     return emitOpError("incorrect number of indices");
   return success();
 }
@@ -860,7 +862,7 @@ void ForeachOp::build(
   // Builds foreach body.
   if (!bodyBuilder)
     return;
-  auto rtp = tensor.getType().cast<RankedTensorType>();
+  auto rtp = getRankedTensorType(tensor);
   int64_t rank = rtp.getRank();
 
   SmallVector<Type> blockArgTypes;
@@ -886,7 +888,7 @@ void ForeachOp::build(
 }
 
 LogicalResult ForeachOp::verify() {
-  auto t = getTensor().getType().cast<RankedTensorType>();
+  auto t = getRankedTensorType(getTensor());
   auto args = getBody()->getArguments();
 
   if (static_cast<size_t>(t.getRank()) + 1 + getInitArgs().size() !=
@@ -944,11 +946,11 @@ LogicalResult SortOp::verify() {
 
   auto n = getN().getDefiningOp<arith::ConstantIndexOp>();
 
-  Type xtp = getXs().front().getType().cast<MemRefType>().getElementType();
+  Type xtp = getMemRefType(getXs().front()).getElementType();
   auto checkTypes = [&](ValueRange operands,
                         bool checkEleType = true) -> LogicalResult {
     for (Value opnd : operands) {
-      MemRefType mtp = opnd.getType().cast<MemRefType>();
+      auto mtp = getMemRefType(opnd);
       int64_t dim = mtp.getShape()[0];
       // We can't check the size of dynamic dimension at compile-time, but all
       // xs and ys should have a dimension not less than n at runtime.
@@ -986,7 +988,7 @@ LogicalResult SortCooOp::verify() {
   }
 
   auto checkDim = [&](Value v, uint64_t min, const char *message) {
-    MemRefType tp = v.getType().cast<MemRefType>();
+    auto tp = getMemRefType(v);
     int64_t dim = tp.getShape()[0];
     if (!ShapedType::isDynamic(dim) && dim < (int64_t)min) {
       emitError(llvm::formatv("{0} got {1} < {2}", message, dim, min));

diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
index a73d6275e09c4..cf2f127be05ae 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
@@ -558,7 +558,7 @@ Value sparse_tensor::reshapeValuesToLevels(
   idxBuffer = builder.create<memref::CastOp>(
       loc, MemRefType::get({rank}, builder.getIndexType()), idxBuffer);
   SmallVector<int64_t> shape(rank, ShapedType::kDynamic);
-  Type elemTp = valuesBuffer.getType().cast<MemRefType>().getElementType();
+  Type elemTp = getMemRefType(valuesBuffer).getElementType();
   return builder.create<memref::ReshapeOp>(loc, MemRefType::get(shape, elemTp),
                                            valuesBuffer, idxBuffer);
 }

diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
index 8d8b0f8c0ab26..b07991ef5f64e 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
@@ -78,11 +78,6 @@ StringRef primaryTypeFunctionSuffix(Type elemTp);
 // Misc code generators and utilities.
 //===----------------------------------------------------------------------===//
 
-template <typename T>
-inline RankedTensorType getRankedTensorType(T t) {
-  return t.getType().template cast<RankedTensorType>();
-}
-
 /// Generates a 1-valued attribute of the given type.  This supports
 /// all the same types as `getZeroAttr`; however, unlike `getZeroAttr`,
 /// for unsupported types we raise `llvm_unreachable` rather than

diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp
index fc9476cd2b655..73b5bd48b3f4b 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp
@@ -53,16 +53,15 @@ static void getMangledSortHelperFuncName(llvm::raw_svector_ostream &nameOstream,
                                          StringRef namePrefix, uint64_t nx,
                                          uint64_t ny, bool isCoo,
                                          ValueRange operands) {
-  nameOstream
-      << namePrefix << nx << "_"
-      << operands[xStartIdx].getType().cast<MemRefType>().getElementType();
+  nameOstream << namePrefix << nx << "_"
+              << getMemRefType(operands[xStartIdx]).getElementType();
 
   if (isCoo)
     nameOstream << "_coo_" << ny;
 
   uint64_t yBufferOffset = isCoo ? 1 : nx;
   for (Value v : operands.drop_front(xStartIdx + yBufferOffset))
-    nameOstream << "_" << v.getType().cast<MemRefType>().getElementType();
+    nameOstream << "_" << getMemRefType(v).getElementType();
 }
 
 /// Looks up a function that is appropriate for the given operands being
@@ -719,7 +718,7 @@ LogicalResult matchAndRewriteSortOp(OpTy op, ValueRange xys, uint64_t nx,
 
   // Convert `values` to have dynamic shape and append them to `operands`.
   for (Value v : xys) {
-    auto mtp = v.getType().cast<MemRefType>();
+    auto mtp = getMemRefType(v);
     if (!mtp.isDynamicDim(0)) {
       auto newMtp =
           MemRefType::get({ShapedType::kDynamic}, mtp.getElementType());

diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
index aaeb041eb7bbc..074a8de245e34 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
@@ -505,8 +505,8 @@ static LogicalResult
 genSparse2SparseReshape(ReshapeOp op, typename ReshapeOp::Adaptor adaptor,
                         ConversionPatternRewriter &rewriter) {
   Location loc = op.getLoc();
-  auto srcTp = op.getSrc().getType().template cast<RankedTensorType>();
-  auto dstTp = op.getResult().getType().template cast<RankedTensorType>();
+  auto srcTp = getRankedTensorType(op.getSrc());
+  auto dstTp = getRankedTensorType(op.getResult());
   auto encSrc = getSparseTensorEncoding(srcTp);
   auto encDst = getSparseTensorEncoding(dstTp);
   if (!encDst || !encSrc)
@@ -888,8 +888,8 @@ class SparseTensorConvertConverter : public OpConversionPattern<ConvertOp> {
   matchAndRewrite(ConvertOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     Location loc = op->getLoc();
-    Type resType = op.getType();
-    Type srcType = op.getSource().getType();
+    auto resType = getRankedTensorType(op);
+    auto srcType = getRankedTensorType(op.getSource());
     auto encDst = getSparseTensorEncoding(resType);
     auto encSrc = getSparseTensorEncoding(srcType);
     Value src = adaptor.getOperands()[0];
@@ -953,10 +953,8 @@ class SparseTensorConvertConverter : public OpConversionPattern<ConvertOp> {
       //     dst[elem.indices] = elem.value;
       //   }
       //   delete iter;
-      RankedTensorType dstTensorTp = resType.cast<RankedTensorType>();
-      RankedTensorType srcTensorTp = srcType.cast<RankedTensorType>();
-      unsigned rank = dstTensorTp.getRank();
-      Type elemTp = dstTensorTp.getElementType();
+      const unsigned rank = resType.getRank();
+      const Type elemTp = resType.getElementType();
       // Fabricate a no-permutation encoding for NewCallParams
       // The pointer/index types must be those of `src`.
       // The dimLevelTypes aren't actually used by Action::kToIterator.
@@ -965,16 +963,16 @@ class SparseTensorConvertConverter : public OpConversionPattern<ConvertOp> {
           SmallVector<DimLevelType>(rank, DimLevelType::Dense), AffineMap(),
           AffineMap(), encSrc.getPointerBitWidth(), encSrc.getIndexBitWidth());
       SmallVector<Value> dimSizes =
-          getDimSizes(rewriter, loc, encSrc, srcTensorTp, src);
+          getDimSizes(rewriter, loc, encSrc, srcType, src);
       Value iter = NewCallParams(rewriter, loc)
-                       .genBuffers(encDst, dimSizes, dstTensorTp)
+                       .genBuffers(encDst, dimSizes, resType)
                        .genNewCall(Action::kToIterator, src);
       Value ind = genAlloca(rewriter, loc, rank, rewriter.getIndexType());
       Value elemPtr = genAllocaScalar(rewriter, loc, elemTp);
       Block *insertionBlock = rewriter.getInsertionBlock();
       // TODO: Dense buffers should be allocated/deallocated via the callback
       // in BufferizationOptions.
-      Value dst = allocDenseTensor(rewriter, loc, dstTensorTp, dimSizes);
+      Value dst = allocDenseTensor(rewriter, loc, resType, dimSizes);
       SmallVector<Value> noArgs;
       SmallVector<Type> noTypes;
       auto whileOp = rewriter.create<scf::WhileOp>(loc, noTypes, noArgs);
@@ -1192,7 +1190,7 @@ class SparseTensorInsertConverter : public OpConversionPattern<InsertOp> {
     // index order. All values are passed by reference through stack
     // allocated memrefs.
     Location loc = op->getLoc();
-    auto tp = op.getTensor().getType().cast<RankedTensorType>();
+    auto tp = getRankedTensorType(op.getTensor());
     auto elemTp = tp.getElementType();
     unsigned rank = tp.getRank();
     auto mref = genAlloca(rewriter, loc, rank, rewriter.getIndexType());
@@ -1217,8 +1215,7 @@ class SparseTensorExpandConverter : public OpConversionPattern<ExpandOp> {
   matchAndRewrite(ExpandOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     Location loc = op->getLoc();
-    RankedTensorType srcType =
-        op.getTensor().getType().cast<RankedTensorType>();
+    auto srcType = getRankedTensorType(op.getTensor());
     Type eltType = srcType.getElementType();
     Type boolType = rewriter.getIntegerType(1);
     Type idxType = rewriter.getIndexType();
@@ -1272,7 +1269,7 @@ class SparseTensorCompressConverter : public OpConversionPattern<CompressOp> {
     Value added = adaptor.getAdded();
     Value count = adaptor.getCount();
     Value tensor = adaptor.getTensor();
-    auto tp = op.getTensor().getType().cast<RankedTensorType>();
+    auto tp = getRankedTensorType(op.getTensor());
     Type elemTp = tp.getElementType();
     unsigned rank = tp.getRank();
     auto mref = genAlloca(rewriter, loc, rank, rewriter.getIndexType());
@@ -1326,7 +1323,7 @@ class SparseTensorConcatConverter : public OpConversionPattern<ConcatenateOp> {
     //      a[ adjustForOffset(elem.indices) ] = elem.value
     //    return a
     Location loc = op.getLoc();
-    auto dstTp = op.getType().cast<RankedTensorType>();
+    auto dstTp = getRankedTensorType(op);
     auto encDst = getSparseTensorEncoding(dstTp);
     Type elemTp = dstTp.getElementType();
     uint64_t concatDim = op.getDimension().getZExtValue();
@@ -1381,7 +1378,7 @@ class SparseTensorConcatConverter : public OpConversionPattern<ConcatenateOp> {
     for (auto it : llvm::zip(op.getInputs(), adaptor.getInputs())) {
       Value orignalOp = std::get<0>(it); // Input (with encoding) from Op
       Value adaptedOp = std::get<1>(it); // Input (type converted) from adaptor
-      RankedTensorType srcTp = orignalOp.getType().cast<RankedTensorType>();
+      auto srcTp = getRankedTensorType(orignalOp);
       auto encSrc = getSparseTensorEncoding(srcTp);
       if (encSrc) {
         genSparseCOOIterationLoop(

diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp
index 4b92540c47499..bc05137bcac47 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp
@@ -69,7 +69,7 @@ static VectorType vectorType(VL vl, Type etp) {
 
 /// Constructs vector type from pointer.
 static VectorType vectorType(VL vl, Value ptr) {
-  return vectorType(vl, ptr.getType().cast<MemRefType>().getElementType());
+  return vectorType(vl, getMemRefType(ptr).getElementType());
 }
 
 /// Constructs vector iteration mask.


        


More information about the Mlir-commits mailing list