[Mlir-commits] [mlir] [mlir][sparse] Add verification for explicit/implicit value (PR #90111)
Yinying Li
llvmlistbot at llvm.org
Thu May 2 10:42:54 PDT 2024
https://github.com/yinying-lisa-li updated https://github.com/llvm/llvm-project/pull/90111
>From 1cb0885be2aa1ab1e46e516b7863704d6d6e77ff Mon Sep 17 00:00:00 2001
From: Yinying Li <yinyingli at google.com>
Date: Thu, 25 Apr 2024 00:39:00 +0000
Subject: [PATCH 1/6] add verification for explicit/implicit values
---
.../SparseTensor/IR/SparseTensorDialect.cpp | 35 +++++++++
.../SparseTensor/invalid_encoding.mlir | 72 +++++++++++++++++++
2 files changed, 107 insertions(+)
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index de3d3006ebaac5..c4524a24346eb9 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -912,6 +912,41 @@ LogicalResult SparseTensorEncodingAttr::verifyEncoding(
return emitError()
<< "dimension-rank mismatch between encoding and tensor shape: "
<< getDimRank() << " != " << dimRank;
+ Type expType, impType;
+ if (getExplicitVal()) {
+ auto fVal = llvm::dyn_cast<FloatAttr>(getExplicitVal());
+ auto intVal = llvm::dyn_cast<IntegerAttr>(getExplicitVal());
+ if (fVal && fVal.getType() != elementType) {
+ expType = fVal.getType();
+ } else if (intVal && intVal.getType() != elementType) {
+ expType = intVal.getType();
+ }
+ if (expType) {
+ return emitError() << "explicit value type mismatch between encoding and "
+ << "tensor element type: " << expType
+ << " != " << elementType;
+ }
+ }
+
+ if (getImplicitVal()) {
+ auto impFVal = llvm::dyn_cast<FloatAttr>(getImplicitVal());
+ auto impIntVal = llvm::dyn_cast<IntegerAttr>(getImplicitVal());
+ if (impFVal && impFVal.getType() != elementType) {
+ impType = impFVal.getType();
+ } else if (impIntVal && impIntVal.getType() != elementType) {
+ impType = impIntVal.getType();
+ }
+ if (impType) {
+ return emitError() << "implicit value type mismatch between encoding and "
+ << "tensor element type: " << impType
+ << " != " << elementType;
+ }
+ // Currently, we only support zero as the implicit value.
+ if ((impFVal && impFVal.getValueAsDouble() != 0.0) ||
+ (impIntVal && impIntVal.getInt() != 0)) {
+ return emitError() << "implicit value must be zero";
+ }
+ }
return success();
}
diff --git a/mlir/test/Dialect/SparseTensor/invalid_encoding.mlir b/mlir/test/Dialect/SparseTensor/invalid_encoding.mlir
index 8096c010ac935a..19e8fc95e22813 100644
--- a/mlir/test/Dialect/SparseTensor/invalid_encoding.mlir
+++ b/mlir/test/Dialect/SparseTensor/invalid_encoding.mlir
@@ -443,3 +443,75 @@ func.func private @NOutOfM(%arg0: tensor<?x?x?xf64, #NOutOfM>) {
func.func private @NOutOfM(%arg0: tensor<?x?x?xf64, #NOutOfM>) {
return
}
+
+// -----
+
+#CSR_ExpType = #sparse_tensor.encoding<{
+ map = (d0, d1) -> (d0 : dense, d1 : compressed),
+ posWidth = 32,
+ crdWidth = 32,
+ explicitVal = 1 : i32,
+ implicitVal = 0.0 : f32
+}>
+
+// expected-error@+1 {{explicit value type mismatch between encoding and tensor element type: 'i32' != 'f32'}}
+func.func private @sparse_csr(tensor<?x?xf32, #CSR_ExpType>)
+
+// -----
+
+#CSR_ImpType = #sparse_tensor.encoding<{
+ map = (d0, d1) -> (d0 : dense, d1 : compressed),
+ posWidth = 32,
+ crdWidth = 32,
+ explicitVal = 1 : i32,
+ implicitVal = 0.0 : f32
+}>
+
+// expected-error@+1 {{implicit value type mismatch between encoding and tensor element type: 'f32' != 'i32'}}
+func.func private @sparse_csr(tensor<?x?xi32, #CSR_ImpType>)
+
+// -----
+
+// expected-error@+1 {{expected a numeric value for explicitVal}}
+#CSR_ExpType = #sparse_tensor.encoding<{
+ map = (d0, d1) -> (d0 : dense, d1 : compressed),
+ posWidth = 32,
+ crdWidth = 32,
+ explicitVal = "str"
+}>
+func.func private @sparse_csr(tensor<?x?xi32, #CSR_ExpType>)
+
+// -----
+
+// expected-error@+1 {{expected a numeric value for implicitVal}}
+#CSR_ImpType = #sparse_tensor.encoding<{
+ map = (d0, d1) -> (d0 : dense, d1 : compressed),
+ posWidth = 32,
+ crdWidth = 32,
+ implicitVal = "str"
+}>
+func.func private @sparse_csr(tensor<?x?xi32, #CSR_ImpType>)
+
+// -----
+
+#CSR_ImpVal = #sparse_tensor.encoding<{
+ map = (d0, d1) -> (d0 : dense, d1 : compressed),
+ posWidth = 32,
+ crdWidth = 32,
+ implicitVal = 1 : i32
+}>
+
+// expected-error@+1 {{implicit value must be zero}}
+func.func private @sparse_csr(tensor<?x?xi32, #CSR_ImpVal>)
+
+// -----
+
+#CSR_ImpVal = #sparse_tensor.encoding<{
+ map = (d0, d1) -> (d0 : dense, d1 : compressed),
+ posWidth = 32,
+ crdWidth = 32,
+ implicitVal = 1.0 : f32
+}>
+
+// expected-error@+1 {{implicit value must be zero}}
+func.func private @sparse_csr(tensor<?x?xf32, #CSR_ImpVal>)
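
For reference, an encoding along these lines (a hypothetical sketch modeled on the test cases above, not part of this patch) passes the new checks, since the explicit value type matches the tensor element type and the implicit value is zero:

// Hypothetical example: the explicit value type matches i32 and the implicit
// value is zero, so verifyEncoding succeeds.
#CSR_OK = #sparse_tensor.encoding<{
  map = (d0, d1) -> (d0 : dense, d1 : compressed),
  posWidth = 32,
  crdWidth = 32,
  explicitVal = 1 : i32,
  implicitVal = 0 : i32
}>

func.func private @sparse_csr_ok(tensor<?x?xi32, #CSR_OK>)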
>From 7497b4815bca0f7de62f97737dfd35c179c2f174 Mon Sep 17 00:00:00 2001
From: Yinying Li <yinyingli at google.com>
Date: Thu, 25 Apr 2024 17:37:28 +0000
Subject: [PATCH 2/6] new function
---
.../SparseTensor/IR/SparseTensorAttrDefs.td | 5 +++
.../SparseTensor/IR/SparseTensorDialect.cpp | 41 ++++++++++---------
2 files changed, 26 insertions(+), 20 deletions(-)
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
index eefa4c71bbd2ca..37fa4913aa6a60 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
@@ -512,6 +512,11 @@ def SparseTensorEncodingAttr : SparseTensor_Attr<"SparseTensorEncoding",
void printSymbols(AffineMap &map, AsmPrinter &printer) const;
void printDimensions(AffineMap &map, AsmPrinter &printer, ArrayRef<::mlir::sparse_tensor::SparseTensorDimSliceAttr> dimSlices) const;
void printLevels(AffineMap &map, AsmPrinter &printer, ArrayRef<::mlir::sparse_tensor::LevelType> lvlTypes) const;
+
+ //
+ // Explicit/implicit value methods.
+ //
+ Type getMismatchedValueType(Type elementType, Attribute val) const;
}];
let genVerifyDecl = 1;
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index c4524a24346eb9..d451864c66db8a 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -893,6 +893,19 @@ LogicalResult SparseTensorEncodingAttr::verify(
return success();
}
+Type SparseTensorEncodingAttr::getMismatchedValueType(Type elementType,
+ Attribute val) const {
+ Type type;
+ auto fVal = llvm::dyn_cast<FloatAttr>(val);
+ auto intVal = llvm::dyn_cast<IntegerAttr>(val);
+ if (fVal && fVal.getType() != elementType) {
+ type = fVal.getType();
+ } else if (intVal && intVal.getType() != elementType) {
+ type = intVal.getType();
+ }
+ return type;
+}
+
LogicalResult SparseTensorEncodingAttr::verifyEncoding(
ArrayRef<Size> dimShape, Type elementType,
function_ref<InFlightDiagnostic()> emitError) const {
@@ -912,36 +925,24 @@ LogicalResult SparseTensorEncodingAttr::verifyEncoding(
return emitError()
<< "dimension-rank mismatch between encoding and tensor shape: "
<< getDimRank() << " != " << dimRank;
- Type expType, impType;
+ Type type;
if (getExplicitVal()) {
- auto fVal = llvm::dyn_cast<FloatAttr>(getExplicitVal());
- auto intVal = llvm::dyn_cast<IntegerAttr>(getExplicitVal());
- if (fVal && fVal.getType() != elementType) {
- expType = fVal.getType();
- } else if (intVal && intVal.getType() != elementType) {
- expType = intVal.getType();
- }
- if (expType) {
+ if ((type = getMismatchedValueType(elementType, getExplicitVal()))) {
return emitError() << "explicit value type mismatch between encoding and "
- << "tensor element type: " << expType
+ << "tensor element type: " << type
<< " != " << elementType;
}
}
-
if (getImplicitVal()) {
- auto impFVal = llvm::dyn_cast<FloatAttr>(getImplicitVal());
- auto impIntVal = llvm::dyn_cast<IntegerAttr>(getImplicitVal());
- if (impFVal && impFVal.getType() != elementType) {
- impType = impFVal.getType();
- } else if (impIntVal && impIntVal.getType() != elementType) {
- impType = impIntVal.getType();
- }
- if (impType) {
+ auto impVal = getImplicitVal();
+ if ((type = getMismatchedValueType(elementType, impVal))) {
return emitError() << "implicit value type mismatch between encoding and "
- << "tensor element type: " << impType
+ << "tensor element type: " << type
<< " != " << elementType;
}
// Currently, we only support zero as the implicit value.
+ auto impFVal = llvm::dyn_cast<FloatAttr>(impVal);
+ auto impIntVal = llvm::dyn_cast<IntegerAttr>(impVal);
if ((impFVal && impFVal.getValueAsDouble() != 0.0) ||
(impIntVal && impIntVal.getInt() != 0)) {
return emitError() << "implicit value must be zero";
>From 94424277f21c479032e1d0d6fd0f5c630a39e989 Mon Sep 17 00:00:00 2001
From: Yinying Li <yinyingli at google.com>
Date: Thu, 25 Apr 2024 23:49:03 +0000
Subject: [PATCH 3/6] use TypedAttr
---
.../SparseTensor/IR/SparseTensorAttrDefs.td | 5 ---
.../SparseTensor/IR/SparseTensorDialect.cpp | 40 +++++++++----------
2 files changed, 18 insertions(+), 27 deletions(-)
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
index 37fa4913aa6a60..eefa4c71bbd2ca 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
@@ -512,11 +512,6 @@ def SparseTensorEncodingAttr : SparseTensor_Attr<"SparseTensorEncoding",
void printSymbols(AffineMap &map, AsmPrinter &printer) const;
void printDimensions(AffineMap &map, AsmPrinter &printer, ArrayRef<::mlir::sparse_tensor::SparseTensorDimSliceAttr> dimSlices) const;
void printLevels(AffineMap &map, AsmPrinter &printer, ArrayRef<::mlir::sparse_tensor::LevelType> lvlTypes) const;
-
- //
- // Explicit/implicit value methods.
- //
- Type getMismatchedValueType(Type elementType, Attribute val) const;
}];
let genVerifyDecl = 1;
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index d451864c66db8a..8c74d5a14779e3 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -893,19 +893,6 @@ LogicalResult SparseTensorEncodingAttr::verify(
return success();
}
-Type SparseTensorEncodingAttr::getMismatchedValueType(Type elementType,
- Attribute val) const {
- Type type;
- auto fVal = llvm::dyn_cast<FloatAttr>(val);
- auto intVal = llvm::dyn_cast<IntegerAttr>(val);
- if (fVal && fVal.getType() != elementType) {
- type = fVal.getType();
- } else if (intVal && intVal.getType() != elementType) {
- type = intVal.getType();
- }
- return type;
-}
-
LogicalResult SparseTensorEncodingAttr::verifyEncoding(
ArrayRef<Size> dimShape, Type elementType,
function_ref<InFlightDiagnostic()> emitError) const {
@@ -925,20 +912,29 @@ LogicalResult SparseTensorEncodingAttr::verifyEncoding(
return emitError()
<< "dimension-rank mismatch between encoding and tensor shape: "
<< getDimRank() << " != " << dimRank;
- Type type;
if (getExplicitVal()) {
- if ((type = getMismatchedValueType(elementType, getExplicitVal()))) {
- return emitError() << "explicit value type mismatch between encoding and "
- << "tensor element type: " << type
- << " != " << elementType;
+ if (auto typedAttr = llvm::dyn_cast<TypedAttr>(getExplicitVal())) {
+ Type attrType = typedAttr.getType();
+ if (attrType != elementType) {
+ return emitError()
+ << "explicit value type mismatch between encoding and "
+ << "tensor element type: " << attrType << " != " << elementType;
+ }
+ } else {
+ return emitError() << "expected typed explicit value";
}
}
if (getImplicitVal()) {
auto impVal = getImplicitVal();
- if ((type = getMismatchedValueType(elementType, impVal))) {
- return emitError() << "implicit value type mismatch between encoding and "
- << "tensor element type: " << type
- << " != " << elementType;
+ if (auto typedAttr = llvm::dyn_cast<TypedAttr>(getImplicitVal())) {
+ Type attrType = typedAttr.getType();
+ if (attrType != elementType) {
+ return emitError()
+ << "implicit value type mismatch between encoding and "
+ << "tensor element type: " << attrType << " != " << elementType;
+ }
+ } else {
+ return emitError() << "expected typed implicit value";
}
// Currently, we only support zero as the implicit value.
auto impFVal = llvm::dyn_cast<FloatAttr>(impVal);
>From c96df1e08041bdd4fdc465846fd008bbb18d8aaa Mon Sep 17 00:00:00 2001
From: Yinying Li <yinyingli at google.com>
Date: Mon, 29 Apr 2024 19:37:43 +0000
Subject: [PATCH 4/6] remove redundant call
---
mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index 8c74d5a14779e3..1b397e950179b8 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -926,7 +926,7 @@ LogicalResult SparseTensorEncodingAttr::verifyEncoding(
}
if (getImplicitVal()) {
auto impVal = getImplicitVal();
- if (auto typedAttr = llvm::dyn_cast<TypedAttr>(getImplicitVal())) {
+ if (auto typedAttr = llvm::dyn_cast<TypedAttr>(impVal)) {
Type attrType = typedAttr.getType();
if (attrType != elementType) {
return emitError()
>From bf8d54a0c75ad4ca380b57bac645a4b689fe8b9b Mon Sep 17 00:00:00 2001
From: Yinying Li <yinyingli at google.com>
Date: Thu, 2 May 2024 17:09:16 +0000
Subject: [PATCH 5/6] add complex type verification
---
.../SparseTensor/IR/SparseTensorAttrDefs.td | 28 ++++++
.../SparseTensor/IR/SparseTensorType.h | 23 +----
.../SparseTensor/IR/SparseTensorDialect.cpp | 95 +++++++++----------
.../SparseTensor/invalid_encoding.mlir | 13 +++
4 files changed, 90 insertions(+), 69 deletions(-)
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
index eefa4c71bbd2ca..86d7de0e66faa2 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
@@ -502,9 +502,37 @@ def SparseTensorEncodingAttr : SparseTensor_Attr<"SparseTensorEncoding",
//
// Helper function to translate between level/dimension space.
//
+
SmallVector<int64_t> translateShape(::mlir::ArrayRef<int64_t> srcShape, ::mlir::sparse_tensor::CrdTransDirectionKind) const;
ValueRange translateCrds(::mlir::OpBuilder &builder, ::mlir::Location loc, ::mlir::ValueRange crds, ::mlir::sparse_tensor::CrdTransDirectionKind) const;
+ //
+ // COO struct and methods.
+ //
+
+ /// A simple structure that encodes a range of levels in the sparse tensors
+ /// that forms a COO segment.
+ struct COOSegment {
+ std::pair<Level, Level> lvlRange; // [low, high)
+ bool isSoA;
+
+ bool isAoS() const { return !isSoA; }
+ bool isSegmentStart(Level l) const { return l == lvlRange.first; }
+ bool inSegment(Level l) const {
+ return l >= lvlRange.first && l < lvlRange.second;
+ }
+ };
+
+ /// Returns the starting level of this sparse tensor type for a
+ /// trailing COO region that spans **at least** two levels. If
+ /// no such COO region is found, then returns the level-rank.
+ ///
+ /// DEPRECATED: use getCOOSegment instead;
+ Level getAoSCOOStart() const;
+
+ /// Returns a list of COO segments in the sparse tensor types.
+ SmallVector<COOSegment> getCOOSegments() const;
+
//
// Printing methods.
//
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h
index ea3d8013b45671..365a8cba30bd59 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h
@@ -18,18 +18,6 @@
namespace mlir {
namespace sparse_tensor {
-/// A simple structure that encodes a range of levels in the sparse tensors that
-/// forms a COO segment.
-struct COOSegment {
- std::pair<Level, Level> lvlRange; // [low, high)
- bool isSoA;
-
- bool isAoS() const { return !isSoA; }
- bool isSegmentStart(Level l) const { return l == lvlRange.first; }
- bool inSegment(Level l) const {
- return l >= lvlRange.first && l < lvlRange.second;
- }
-};
//===----------------------------------------------------------------------===//
/// A wrapper around `RankedTensorType`, which has three goals:
@@ -73,11 +61,6 @@ class SparseTensorType {
: SparseTensorType(
RankedTensorType::get(stp.getShape(), stp.getElementType(), enc)) {}
- // TODO: remove?
- SparseTensorType(SparseTensorEncodingAttr enc)
- : SparseTensorType(RankedTensorType::get(
- SmallVector<Size>(enc.getDimRank(), ShapedType::kDynamic),
- Float32Type::get(enc.getContext()), enc)) {}
SparseTensorType &operator=(const SparseTensorType &) = delete;
SparseTensorType(const SparseTensorType &) = default;
@@ -369,13 +352,15 @@ class SparseTensorType {
/// no such COO region is found, then returns the level-rank.
///
/// DEPRECATED: use getCOOSegment instead;
- Level getAoSCOOStart() const;
+ Level getAoSCOOStart() const { return getEncoding().getAoSCOOStart(); };
/// Returns [un]ordered COO type for this sparse tensor type.
RankedTensorType getCOOType(bool ordered) const;
/// Returns a list of COO segments in the sparse tensor types.
- SmallVector<COOSegment> getCOOSegments() const;
+ SmallVector<SparseTensorEncodingAttr::COOSegment> getCOOSegments() const {
+ return getEncoding().getCOOSegments();
+ }
private:
// These two must be const, to ensure coherence of the memoized fields.
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index 1b397e950179b8..8626cb141abfbb 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -104,7 +104,8 @@ void StorageLayout::foreachField(
callback) const {
const auto lvlTypes = enc.getLvlTypes();
const Level lvlRank = enc.getLvlRank();
- SmallVector<COOSegment> cooSegs = SparseTensorType(enc).getCOOSegments();
+ SmallVector<SparseTensorEncodingAttr::COOSegment> cooSegs =
+ enc.getCOOSegments();
FieldIndex fieldIdx = kDataFieldStartingIdx;
ArrayRef cooSegsRef = cooSegs;
@@ -211,7 +212,7 @@ StorageLayout::getFieldIndexAndStride(SparseTensorFieldKind kind,
unsigned stride = 1;
if (kind == SparseTensorFieldKind::CrdMemRef) {
assert(lvl.has_value());
- const Level cooStart = SparseTensorType(enc).getAoSCOOStart();
+ const Level cooStart = enc.getAoSCOOStart();
const Level lvlRank = enc.getLvlRank();
if (lvl.value() >= cooStart && lvl.value() < lvlRank) {
lvl = cooStart;
@@ -912,78 +913,53 @@ LogicalResult SparseTensorEncodingAttr::verifyEncoding(
return emitError()
<< "dimension-rank mismatch between encoding and tensor shape: "
<< getDimRank() << " != " << dimRank;
- if (getExplicitVal()) {
- if (auto typedAttr = llvm::dyn_cast<TypedAttr>(getExplicitVal())) {
- Type attrType = typedAttr.getType();
- if (attrType != elementType) {
- return emitError()
- << "explicit value type mismatch between encoding and "
- << "tensor element type: " << attrType << " != " << elementType;
- }
- } else {
- return emitError() << "expected typed explicit value";
+ if (auto expVal = getExplicitVal()) {
+ Type attrType = llvm::dyn_cast<TypedAttr>(expVal).getType();
+ if (attrType != elementType) {
+ return emitError() << "explicit value type mismatch between encoding and "
+ << "tensor element type: " << attrType
+ << " != " << elementType;
}
}
- if (getImplicitVal()) {
- auto impVal = getImplicitVal();
- if (auto typedAttr = llvm::dyn_cast<TypedAttr>(impVal)) {
- Type attrType = typedAttr.getType();
- if (attrType != elementType) {
- return emitError()
- << "implicit value type mismatch between encoding and "
- << "tensor element type: " << attrType << " != " << elementType;
- }
- } else {
- return emitError() << "expected typed implicit value";
+ if (auto impVal = getImplicitVal()) {
+ Type attrType = llvm::dyn_cast<TypedAttr>(impVal).getType();
+ if (attrType != elementType) {
+ return emitError() << "implicit value type mismatch between encoding and "
+ << "tensor element type: " << attrType
+ << " != " << elementType;
}
// Currently, we only support zero as the implicit value.
auto impFVal = llvm::dyn_cast<FloatAttr>(impVal);
auto impIntVal = llvm::dyn_cast<IntegerAttr>(impVal);
- if ((impFVal && impFVal.getValueAsDouble() != 0.0) ||
- (impIntVal && impIntVal.getInt() != 0)) {
+ auto impComplexVal = llvm::dyn_cast<complex::NumberAttr>(impVal);
+ if ((impFVal && impFVal.getValue().isNonZero()) ||
+ (impIntVal && !impIntVal.getValue().isZero()) ||
+ (impComplexVal && (impComplexVal.getImag().isNonZero() ||
+ impComplexVal.getReal().isNonZero()))) {
return emitError() << "implicit value must be zero";
}
}
return success();
}
-//===----------------------------------------------------------------------===//
-// SparseTensorType Methods.
-//===----------------------------------------------------------------------===//
-
-bool mlir::sparse_tensor::SparseTensorType::isCOOType(Level startLvl,
- bool isUnique) const {
- if (!hasEncoding())
- return false;
- if (!isCompressedLvl(startLvl) && !isLooseCompressedLvl(startLvl))
- return false;
- for (Level l = startLvl + 1; l < lvlRank; ++l)
- if (!isSingletonLvl(l))
- return false;
- // If isUnique is true, then make sure that the last level is unique,
- // that is, when lvlRank == 1, the only compressed level is unique,
- // and when lvlRank > 1, the last singleton is unique.
- return !isUnique || isUniqueLvl(lvlRank - 1);
-}
-
-Level mlir::sparse_tensor::SparseTensorType::getAoSCOOStart() const {
+Level mlir::sparse_tensor::SparseTensorEncodingAttr::getAoSCOOStart() const {
SmallVector<COOSegment> coo = getCOOSegments();
assert(coo.size() == 1 || coo.empty());
if (!coo.empty() && coo.front().isAoS()) {
return coo.front().lvlRange.first;
}
- return lvlRank;
+ return getLvlRank();
}
-SmallVector<COOSegment>
-mlir::sparse_tensor::SparseTensorType::getCOOSegments() const {
+SmallVector<SparseTensorEncodingAttr::COOSegment>
+mlir::sparse_tensor::SparseTensorEncodingAttr::getCOOSegments() const {
SmallVector<COOSegment> ret;
- if (!hasEncoding() || lvlRank <= 1)
+ if (getLvlRank() <= 1)
return ret;
ArrayRef<LevelType> lts = getLvlTypes();
Level l = 0;
- while (l < lvlRank) {
+ while (l < getLvlRank()) {
auto lt = lts[l];
if (lt.isa<LevelFormat::Compressed, LevelFormat::LooseCompressed>()) {
auto cur = lts.begin() + l;
@@ -1007,6 +983,25 @@ mlir::sparse_tensor::SparseTensorType::getCOOSegments() const {
return ret;
}
+//===----------------------------------------------------------------------===//
+// SparseTensorType Methods.
+//===----------------------------------------------------------------------===//
+
+bool mlir::sparse_tensor::SparseTensorType::isCOOType(Level startLvl,
+ bool isUnique) const {
+ if (!hasEncoding())
+ return false;
+ if (!isCompressedLvl(startLvl) && !isLooseCompressedLvl(startLvl))
+ return false;
+ for (Level l = startLvl + 1; l < lvlRank; ++l)
+ if (!isSingletonLvl(l))
+ return false;
+ // If isUnique is true, then make sure that the last level is unique,
+ // that is, when lvlRank == 1, the only compressed level is unique,
+ // and when lvlRank > 1, the last singleton is unique.
+ return !isUnique || isUniqueLvl(lvlRank - 1);
+}
+
RankedTensorType
mlir::sparse_tensor::SparseTensorType::getCOOType(bool ordered) const {
SmallVector<LevelType> lvlTypes;
diff --git a/mlir/test/Dialect/SparseTensor/invalid_encoding.mlir b/mlir/test/Dialect/SparseTensor/invalid_encoding.mlir
index 19e8fc95e22813..a3f72bd3ae971c 100644
--- a/mlir/test/Dialect/SparseTensor/invalid_encoding.mlir
+++ b/mlir/test/Dialect/SparseTensor/invalid_encoding.mlir
@@ -515,3 +515,16 @@ func.func private @sparse_csr(tensor<?x?xi32, #CSR_ImpVal>)
// expected-error@+1 {{implicit value must be zero}}
func.func private @sparse_csr(tensor<?x?xf32, #CSR_ImpVal>)
+
+// -----
+
+#CSR_OnlyOnes = #sparse_tensor.encoding<{
+ map = (d0, d1) -> (d0 : dense, d1 : compressed),
+ posWidth = 64,
+ crdWidth = 64,
+ explicitVal = #complex.number<:f32 1.0, 0.0>,
+ implicitVal = #complex.number<:f32 1.0, 0.0>
+}>
+
+// expected-error@+1 {{implicit value must be zero}}
+func.func private @sparse_csr(tensor<?x?xcomplex<f32>, #CSR_OnlyOnes>)
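
For comparison with the rejected all-ones case above, a complex encoding whose implicit value is complex zero (a hypothetical sketch, names not taken from this patch) would be accepted by the new check:

// Hypothetical example: the implicit value is complex zero, so the
// "implicit value must be zero" check passes.
#CSR_ComplexZero = #sparse_tensor.encoding<{
  map = (d0, d1) -> (d0 : dense, d1 : compressed),
  explicitVal = #complex.number<:f32 1.0, 0.0>,
  implicitVal = #complex.number<:f32 0.0, 0.0>
}>

func.func private @sparse_csr_complex(tensor<?x?xcomplex<f32>, #CSR_ComplexZero>)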
>From 35c1a47ffc7fbc5ff9dab9e09c3615c3dd326c3d Mon Sep 17 00:00:00 2001
From: Yinying Li <yinyingli at google.com>
Date: Thu, 2 May 2024 17:42:09 +0000
Subject: [PATCH 6/6] format
---
mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h | 2 --
1 file changed, 2 deletions(-)
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h
index 365a8cba30bd59..664ca6f0127213 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h
@@ -18,7 +18,6 @@
namespace mlir {
namespace sparse_tensor {
-
//===----------------------------------------------------------------------===//
/// A wrapper around `RankedTensorType`, which has three goals:
///
@@ -61,7 +60,6 @@ class SparseTensorType {
: SparseTensorType(
RankedTensorType::get(stp.getShape(), stp.getElementType(), enc)) {}
-
SparseTensorType &operator=(const SparseTensorType &) = delete;
SparseTensorType(const SparseTensorType &) = default;
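
As background for the COO helpers that patch 5 moves onto the encoding attribute: an AoS COO segment is a compressed(nonunique) level followed by singleton levels. The sketch below is illustrative only (the encoding and function names are hypothetical):

// Illustrative AoS COO encoding: the trailing COO segment spans levels [0, 2),
// so getAoSCOOStart() should return 0 here; for the CSR encodings above, which
// have no COO segment, it returns the level rank instead.
#COO = #sparse_tensor.encoding<{
  map = (d0, d1) -> (d0 : compressed(nonunique), d1 : singleton)
}>

func.func private @sparse_coo(tensor<?x?xf64, #COO>)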