[Mlir-commits] [mlir] af2bec7 - [mlir][sparse] Adding new STEA::{with, without}DimSlices factories

wren romano llvmlistbot at llvm.org
Tue May 30 15:53:38 PDT 2023


Author: wren romano
Date: 2023-05-30T15:53:30-07:00
New Revision: af2bec7c4a967c9e2e009cdbc4470eb5ba8332f6

URL: https://github.com/llvm/llvm-project/commit/af2bec7c4a967c9e2e009cdbc4470eb5ba8332f6
DIFF: https://github.com/llvm/llvm-project/commit/af2bec7c4a967c9e2e009cdbc4470eb5ba8332f6.diff

LOG: [mlir][sparse] Adding new STEA::{with,without}DimSlices factories

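(STEA is SparseTensorEncodingAttr. These factories exist mainly for downstream code; within the MLIR codebase itself, their only use is the assertion in SparseTensorCodegen.cpp that relies on `withoutDimSlices`.)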

Depends On D151513

Reviewed By: Peiming

Differential Revision: https://reviews.llvm.org/D151518

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
    mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h
    mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
    mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
index f0a502e5dcd9..9fe425a40415 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
@@ -304,6 +304,14 @@ def SparseTensorEncodingAttr : SparseTensor_Attr<"SparseTensorEncoding",
     /// reset to the default, and all other fields inherited from `this`.
     SparseTensorEncodingAttr withoutBitWidths() const;
 
+    /// Constructs a new encoding with the given dimSlices, and all
+    /// other fields inherited from `this`.
+    SparseTensorEncodingAttr withDimSlices(ArrayRef<::mlir::sparse_tensor::SparseTensorDimSliceAttr> dimSlices) const;
+
+    /// Constructs a new encoding with the dimSlices reset to the default,
+    /// and all other fields inherited from `this`.
+    SparseTensorEncodingAttr withoutDimSlices() const;
+
     //
     // Rank methods.
     //

diff  --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h
index 6cae09db36cc..cfc3374148f9 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h
@@ -111,6 +111,15 @@ class SparseTensorType {
     return withEncoding(enc.withoutBitWidths());
   }
 
+  SparseTensorType
+  withDimSlices(ArrayRef<SparseTensorDimSliceAttr> dimSlices) const {
+    return withEncoding(enc.withDimSlices(dimSlices));
+  }
+
+  SparseTensorType withoutDimSlices() const {
+    return withEncoding(enc.withoutDimSlices());
+  }
+
   //
   // Other methods.
   //

diff  --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index 962e0ac21c63..a1eda8968a55 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -291,6 +291,17 @@ SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutBitWidths() const {
   return withBitWidths(0, 0);
 }
 
+SparseTensorEncodingAttr SparseTensorEncodingAttr::withDimSlices(
+    ArrayRef<SparseTensorDimSliceAttr> dimSlices) const {
+  return SparseTensorEncodingAttr::get(getContext(), getLvlTypes(),
+                                       getDimToLvl(), getPosWidth(),
+                                       getCrdWidth(), dimSlices);
+}
+
+SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutDimSlices() const {
+  return withDimSlices(ArrayRef<SparseTensorDimSliceAttr>{});
+}
+
 bool SparseTensorEncodingAttr::isAllDense() const {
   return !getImpl() || llvm::all_of(getLvlTypes(), isDenseDLT);
 }

diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
index f84009c4b63b..a7f37e8189ea 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -1138,10 +1138,7 @@ class SparseExtractSliceConverter
     // TODO: We should check these in ExtractSliceOp::verify.
     if (!srcEnc || !dstEnc || !dstEnc.isSlice())
       return failure();
-    assert(srcEnc.getLvlTypes() == dstEnc.getLvlTypes());
-    assert(srcEnc.getDimToLvl() == dstEnc.getDimToLvl());
-    assert(srcEnc.getPosWidth() == dstEnc.getPosWidth());
-    assert(srcEnc.getCrdWidth() == dstEnc.getCrdWidth());
+    assert(srcEnc.withoutDimSlices() == dstEnc.withoutDimSlices());
 
     SmallVector<Value> fields;
     auto desc = getMutDescriptorFromTensorTuple(adaptor.getSource(), fields);


        


More information about the Mlir-commits mailing list