[Mlir-commits] [mlir] af2bec7 - [mlir][sparse] Adding new STEA::{with,without}DimSlices factories
wren romano
llvmlistbot at llvm.org
Tue May 30 15:53:38 PDT 2023
Author: wren romano
Date: 2023-05-30T15:53:30-07:00
New Revision: af2bec7c4a967c9e2e009cdbc4470eb5ba8332f6
URL: https://github.com/llvm/llvm-project/commit/af2bec7c4a967c9e2e009cdbc4470eb5ba8332f6
DIFF: https://github.com/llvm/llvm-project/commit/af2bec7c4a967c9e2e009cdbc4470eb5ba8332f6.diff
LOG: [mlir][sparse] Adding new STEA::{with,without}DimSlices factories
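(These factories are used in downstream code, despite not being used elsewhere within the MLIR codebase; the only in-tree use is `withoutDimSlices()` in the SparseExtractSliceConverter assertion in SparseTensorCodegen.cpp, added by this commit. "STEA" abbreviates SparseTensorEncodingAttr.)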
Depends On D151513
Reviewed By: Peiming
Differential Revision: https://reviews.llvm.org/D151518
Added:
Modified:
mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h
mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
index f0a502e5dcd9..9fe425a40415 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
@@ -304,6 +304,14 @@ def SparseTensorEncodingAttr : SparseTensor_Attr<"SparseTensorEncoding",
/// reset to the default, and all other fields inherited from `this`.
SparseTensorEncodingAttr withoutBitWidths() const;
+ /// Constructs a new encoding with the given dimSlices, and all
+ /// other fields inherited from `this`.
+ SparseTensorEncodingAttr withDimSlices(ArrayRef<::mlir::sparse_tensor::SparseTensorDimSliceAttr> dimSlices) const;
+
+ /// Constructs a new encoding with the dimSlices reset to the default,
+ /// and all other fields inherited from `this`.
+ SparseTensorEncodingAttr withoutDimSlices() const;
+
//
// Rank methods.
//
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h
index 6cae09db36cc..cfc3374148f9 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h
@@ -111,6 +111,15 @@ class SparseTensorType {
return withEncoding(enc.withoutBitWidths());
}
+ SparseTensorType
+ withDimSlices(ArrayRef<SparseTensorDimSliceAttr> dimSlices) const {
+ return withEncoding(enc.withDimSlices(dimSlices));
+ }
+
+ SparseTensorType withoutDimSlices() const {
+ return withEncoding(enc.withoutDimSlices());
+ }
+
//
// Other methods.
//
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index 962e0ac21c63..a1eda8968a55 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -291,6 +291,17 @@ SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutBitWidths() const {
return withBitWidths(0, 0);
}
+SparseTensorEncodingAttr SparseTensorEncodingAttr::withDimSlices(
+ ArrayRef<SparseTensorDimSliceAttr> dimSlices) const {
+ return SparseTensorEncodingAttr::get(getContext(), getLvlTypes(),
+ getDimToLvl(), getPosWidth(),
+ getCrdWidth(), dimSlices);
+}
+
+SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutDimSlices() const {
+ return withDimSlices(ArrayRef<SparseTensorDimSliceAttr>{});
+}
+
bool SparseTensorEncodingAttr::isAllDense() const {
return !getImpl() || llvm::all_of(getLvlTypes(), isDenseDLT);
}
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
index f84009c4b63b..a7f37e8189ea 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -1138,10 +1138,7 @@ class SparseExtractSliceConverter
// TODO: We should check these in ExtractSliceOp::verify.
if (!srcEnc || !dstEnc || !dstEnc.isSlice())
return failure();
- assert(srcEnc.getLvlTypes() == dstEnc.getLvlTypes());
- assert(srcEnc.getDimToLvl() == dstEnc.getDimToLvl());
- assert(srcEnc.getPosWidth() == dstEnc.getPosWidth());
- assert(srcEnc.getCrdWidth() == dstEnc.getCrdWidth());
+ assert(srcEnc.withoutDimSlices() == dstEnc.withoutDimSlices());
SmallVector<Value> fields;
auto desc = getMutDescriptorFromTensorTuple(adaptor.getSource(), fields);
More information about the Mlir-commits mailing list