[Mlir-commits] [mlir] 96fef4d - [mlir][sparse] Added new SparseTensorEncodingAttr::withoutOrdering factory

wren romano llvmlistbot at llvm.org
Thu Dec 15 18:15:01 PST 2022


Author: wren romano
Date: 2022-12-15T18:14:54-08:00
New Revision: 96fef4dc3c9313d476fafe96d4c380988cf0ecda

URL: https://github.com/llvm/llvm-project/commit/96fef4dc3c9313d476fafe96d4c380988cf0ecda
DIFF: https://github.com/llvm/llvm-project/commit/96fef4dc3c9313d476fafe96d4c380988cf0ecda.diff

LOG: [mlir][sparse] Added new SparseTensorEncodingAttr::withoutOrdering factory

Reviewed By: aartbik, Peiming

Differential Revision: https://reviews.llvm.org/D140171

Added: 
    (none)

Modified: 
    mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
    mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
    mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp

Removed: 
    (none)


################################################################################
diff  --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
index 47993b075fca3..70c74bf7f82f5 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
@@ -169,6 +169,10 @@ def SparseTensorEncodingAttr : SparseTensor_Attr<"SparseTensorEncoding",
 
     /// Returns the type for index storage based on indexBitWidth
     Type getIndexType() const;
+
+    /// Constructs a new encoding with the dimOrdering and higherOrdering
+    /// reset to the default/identity.
+    SparseTensorEncodingAttr withoutOrdering() const;
   }];
 
   let genVerifyDecl = 1;

diff  --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index 06a451fc78ffe..aecde7dfd4e2d 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -57,6 +57,12 @@ Type SparseTensorEncodingAttr::getIndexType() const {
   return idxWidth ? IntegerType::get(getContext(), idxWidth) : indexType;
 }
 
+SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutOrdering() const {
+  return SparseTensorEncodingAttr::get(
+      getContext(), getDimLevelType(), AffineMap(), AffineMap(),
+      getPointerBitWidth(), getIndexBitWidth());
+}
+
 Attribute SparseTensorEncodingAttr::parse(AsmParser &parser, Type type) {
   if (failed(parser.parseLess()))
     return {};

diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
index 4f4dbe49926f7..b4d1491ad6af6 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
@@ -529,9 +529,7 @@ genSparse2SparseReshape(ReshapeOp op, typename ReshapeOp::Adaptor adaptor,
   assert(elemTp == dstTp.getElementType() &&
          "reshape should not change element type");
   // Start an iterator over the source tensor (in original index order).
-  auto noPerm = SparseTensorEncodingAttr::get(
-      op->getContext(), encSrc.getDimLevelType(), AffineMap(), AffineMap(),
-      encSrc.getPointerBitWidth(), encSrc.getIndexBitWidth());
+  const auto noPerm = encSrc.withoutOrdering();
   SmallVector<Value> srcDimSizes =
       getDimSizes(rewriter, loc, encSrc, srcTp, adaptor.getSrc());
   NewCallParams params(rewriter, loc);
@@ -596,9 +594,7 @@ static void genSparseCOOIterationLoop(
   Type elemTp = tensorTp.getElementType();
 
   // Start an iterator over the tensor (in original index order).
-  auto noPerm = SparseTensorEncodingAttr::get(
-      rewriter.getContext(), enc.getDimLevelType(), AffineMap(), AffineMap(),
-      enc.getPointerBitWidth(), enc.getIndexBitWidth());
+  const auto noPerm = enc.withoutOrdering();
   SmallVector<Value> dimSizes = getDimSizes(rewriter, loc, noPerm, tensorTp, t);
   Value iter = NewCallParams(rewriter, loc)
                    .genBuffers(noPerm, dimSizes, tensorTp)
@@ -1485,9 +1481,7 @@ class SparseTensorOutConverter : public OpConversionPattern<OutOp> {
     auto encSrc = getSparseTensorEncoding(srcType);
     SmallVector<Value> dimSizes =
         getDimSizes(rewriter, loc, encSrc, srcType, src);
-    auto enc = SparseTensorEncodingAttr::get(
-        op->getContext(), encSrc.getDimLevelType(), AffineMap(), AffineMap(),
-        encSrc.getPointerBitWidth(), encSrc.getIndexBitWidth());
+    const auto enc = encSrc.withoutOrdering();
     Value coo = NewCallParams(rewriter, loc)
                     .genBuffers(enc, dimSizes, srcType)
                     .genNewCall(Action::kToCOO, src);


        


More information about the Mlir-commits mailing list