[Mlir-commits] [mlir] 96fef4d - [mlir][sparse] Added new SparseTensorEncodingAttr::withoutOrdering factory
wren romano
llvmlistbot at llvm.org
Thu Dec 15 18:15:01 PST 2022
Author: wren romano
Date: 2022-12-15T18:14:54-08:00
New Revision: 96fef4dc3c9313d476fafe96d4c380988cf0ecda
URL: https://github.com/llvm/llvm-project/commit/96fef4dc3c9313d476fafe96d4c380988cf0ecda
DIFF: https://github.com/llvm/llvm-project/commit/96fef4dc3c9313d476fafe96d4c380988cf0ecda.diff
LOG: [mlir][sparse] Added new SparseTensorEncodingAttr::withoutOrdering factory
Reviewed By: aartbik, Peiming
Differential Revision: https://reviews.llvm.org/D140171
Added:
Modified:
mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
index 47993b075fca3..70c74bf7f82f5 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
@@ -169,6 +169,10 @@ def SparseTensorEncodingAttr : SparseTensor_Attr<"SparseTensorEncoding",
/// Returns the type for index storage based on indexBitWidth
Type getIndexType() const;
+
+ /// Constructs a new encoding with the dimOrdering and higherOrdering
+ /// reset to the default/identity.
+ SparseTensorEncodingAttr withoutOrdering() const;
}];
let genVerifyDecl = 1;
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index 06a451fc78ffe..aecde7dfd4e2d 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -57,6 +57,12 @@ Type SparseTensorEncodingAttr::getIndexType() const {
return idxWidth ? IntegerType::get(getContext(), idxWidth) : indexType;
}
+SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutOrdering() const {
+ return SparseTensorEncodingAttr::get(
+ getContext(), getDimLevelType(), AffineMap(), AffineMap(),
+ getPointerBitWidth(), getIndexBitWidth());
+}
+
Attribute SparseTensorEncodingAttr::parse(AsmParser &parser, Type type) {
if (failed(parser.parseLess()))
return {};
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
index 4f4dbe49926f7..b4d1491ad6af6 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
@@ -529,9 +529,7 @@ genSparse2SparseReshape(ReshapeOp op, typename ReshapeOp::Adaptor adaptor,
assert(elemTp == dstTp.getElementType() &&
"reshape should not change element type");
// Start an iterator over the source tensor (in original index order).
- auto noPerm = SparseTensorEncodingAttr::get(
- op->getContext(), encSrc.getDimLevelType(), AffineMap(), AffineMap(),
- encSrc.getPointerBitWidth(), encSrc.getIndexBitWidth());
+ const auto noPerm = encSrc.withoutOrdering();
SmallVector<Value> srcDimSizes =
getDimSizes(rewriter, loc, encSrc, srcTp, adaptor.getSrc());
NewCallParams params(rewriter, loc);
@@ -596,9 +594,7 @@ static void genSparseCOOIterationLoop(
Type elemTp = tensorTp.getElementType();
// Start an iterator over the tensor (in original index order).
- auto noPerm = SparseTensorEncodingAttr::get(
- rewriter.getContext(), enc.getDimLevelType(), AffineMap(), AffineMap(),
- enc.getPointerBitWidth(), enc.getIndexBitWidth());
+ const auto noPerm = enc.withoutOrdering();
SmallVector<Value> dimSizes = getDimSizes(rewriter, loc, noPerm, tensorTp, t);
Value iter = NewCallParams(rewriter, loc)
.genBuffers(noPerm, dimSizes, tensorTp)
@@ -1485,9 +1481,7 @@ class SparseTensorOutConverter : public OpConversionPattern<OutOp> {
auto encSrc = getSparseTensorEncoding(srcType);
SmallVector<Value> dimSizes =
getDimSizes(rewriter, loc, encSrc, srcType, src);
- auto enc = SparseTensorEncodingAttr::get(
- op->getContext(), encSrc.getDimLevelType(), AffineMap(), AffineMap(),
- encSrc.getPointerBitWidth(), encSrc.getIndexBitWidth());
+ const auto enc = encSrc.withoutOrdering();
Value coo = NewCallParams(rewriter, loc)
.genBuffers(enc, dimSizes, srcType)
.genNewCall(Action::kToCOO, src);
More information about the Mlir-commits
mailing list