[Mlir-commits] [mlir] 898bf53 - [mlir][sparse] Surface syntax change in parsing
Yinying Li
llvmlistbot at llvm.org
Fri Sep 1 12:25:39 PDT 2023
Author: Yinying Li
Date: 2023-09-01T19:25:00Z
New Revision: 898bf539a74faae6935cf49bbeb42ce3c777911f
URL: https://github.com/llvm/llvm-project/commit/898bf539a74faae6935cf49bbeb42ce3c777911f
DIFF: https://github.com/llvm/llvm-project/commit/898bf539a74faae6935cf49bbeb42ce3c777911f.diff
LOG: [mlir][sparse] Surface syntax change in parsing
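Example: a level previously written as `compressed_nu_no` is now written as `compressed(nonunique, nonordered)` or, equivalently, `compressed(nonordered, nonunique)` — the order of properties inside the parentheses does not matter. Accepted properties are `nonunique`, `nonordered`, `high`, and `block2_4`.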
Reviewed By: aartbik
Differential Revision: https://reviews.llvm.org/D159366
Added:
Modified:
mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMap.cpp
mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMapParser.cpp
mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.cpp
mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.h
mlir/test/Dialect/SparseTensor/roundtrip_encoding.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h b/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
index b51d272a790f4a..675c1534779192 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
@@ -169,6 +169,9 @@ enum class Action : uint32_t {
///
// TODO: We should generalize TwoOutOfFour to N out of M and use property to
// encode the value of N and M.
+// TODO: Update DimLevelType to use lower 8 bits for storage formats and the
+// higher 4 bits to store level properties. Consider CompressedWithHi and
+// TwoOutOfFour as properties instead of formats.
enum class DimLevelType : uint8_t {
Undef = 0, // 0b00000_00
Dense = 4, // 0b00001_00
@@ -197,6 +200,14 @@ enum class LevelFormat : uint8_t {
TwoOutOfFour = 64, // 0b10000_00
};
+/// This enum defines all the nondefault properties for storage formats.
+enum class LevelNondefaultProperty : uint8_t {
+ Nonunique = 1, // 0b00000_01
+ Nonordered = 2, // 0b00000_10
+ High = 32, // 0b01000_00
+ Block2_4 = 64 // 0b10000_00
+};
+
/// Returns string representation of the given dimension level type.
constexpr const char *toMLIRString(DimLevelType dlt) {
switch (dlt) {
diff --git a/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMap.cpp b/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMap.cpp
index e1eaa8a4d3f9c5..81302f200f686b 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMap.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMap.cpp
@@ -391,7 +391,7 @@ void DimLvlMap::print(llvm::raw_ostream &os, bool wantElision) const {
os << '{';
llvm::interleaveComma(
lvlSpecs, os, [&](LvlSpec const &spec) { os << spec.getBoundVar(); });
- os << '}';
+ os << "} ";
}
// Dimension specifiers.
diff --git a/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMapParser.cpp b/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMapParser.cpp
index 3b6cedd6596297..680411f8008ea3 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMapParser.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMapParser.cpp
@@ -354,7 +354,7 @@ ParseResult DimLvlMapParser::parseLvlSpec(bool requireLvlVarBinding) {
const auto type = lvlTypeParser.parseLvlType(parser);
FAILURE_IF_FAILED(type)
- lvlSpecs.emplace_back(var, expr, *type);
+ lvlSpecs.emplace_back(var, expr, static_cast<DimLevelType>(*type));
return success();
}
diff --git a/mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.cpp b/mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.cpp
index 986529c45983ea..b8483f5db130dc 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.cpp
@@ -7,6 +7,7 @@
//===----------------------------------------------------------------------===//
#include "LvlTypeParser.h"
+#include "mlir/Dialect/SparseTensor/IR/Enums.h"
using namespace mlir;
using namespace mlir::sparse_tensor;
@@ -46,34 +47,57 @@ using namespace mlir::sparse_tensor::ir_detail;
// `LvlTypeParser` implementation.
//===----------------------------------------------------------------------===//
-std::optional<DimLevelType> LvlTypeParser::lookup(StringRef str) const {
- // NOTE: `StringMap::lookup` will return a default-constructed value if
- // the key isn't found; which for enums means zero, and therefore makes
- // it impossible to distinguish between actual zero-DimLevelType vs
- // not-found. Whereas `StringMap::at` asserts that the key is found,
- // which we don't want either.
- const auto it = map.find(str);
- return it == map.end() ? std::nullopt : std::make_optional(it->second);
-}
+FailureOr<uint8_t> LvlTypeParser::parseLvlType(AsmParser &parser) const {
+ StringRef base;
+ FAILURE_IF_FAILED(parser.parseOptionalKeyword(&base));
+ uint8_t properties = 0;
+ const auto loc = parser.getCurrentLocation();
-std::optional<DimLevelType> LvlTypeParser::lookup(StringAttr str) const {
- return str ? lookup(str.getValue()) : std::nullopt;
-}
+ ParseResult res = parser.parseCommaSeparatedList(
+ mlir::OpAsmParser::Delimiter::OptionalParen,
+ [&]() -> ParseResult { return parseProperty(parser, &properties); },
+ " in level property list");
+ FAILURE_IF_FAILED(res)
-FailureOr<DimLevelType> LvlTypeParser::parseLvlType(AsmParser &parser) const {
- DimLevelType out;
- FAILURE_IF_FAILED(parseLvlType(parser, out))
- return out;
+ // Set the base bit for properties.
+ if (base.compare("dense") == 0) {
+ properties |= static_cast<uint8_t>(LevelFormat::Dense);
+ } else if (base.compare("compressed") == 0) {
+ // TODO: Remove this condition once dimLvlType enum is refactored. Current
+ // enum treats High and TwoOutOfFour as formats instead of properties.
+ if (!(properties & static_cast<uint8_t>(LevelNondefaultProperty::High) ||
+ properties &
+ static_cast<uint8_t>(LevelNondefaultProperty::Block2_4))) {
+ properties |= static_cast<uint8_t>(LevelFormat::Compressed);
+ }
+ } else if (base.compare("singleton") == 0) {
+ properties |= static_cast<uint8_t>(LevelFormat::Singleton);
+ } else {
+ parser.emitError(loc, "unknown level format");
+ return failure();
+ }
+
+ ERROR_IF(!isValidDLT(static_cast<DimLevelType>(properties)),
+ "invalid level type");
+ return properties;
}
-ParseResult LvlTypeParser::parseLvlType(AsmParser &parser,
- DimLevelType &out) const {
- const auto loc = parser.getCurrentLocation();
+ParseResult LvlTypeParser::parseProperty(AsmParser &parser,
+ uint8_t *properties) const {
StringRef strVal;
FAILURE_IF_FAILED(parser.parseOptionalKeyword(&strVal));
- const auto lvlType = lookup(strVal);
- ERROR_IF(!lvlType, "unknown level-type '" + strVal + "'")
- out = *lvlType;
+ if (strVal.compare("nonunique") == 0) {
+ *properties |= static_cast<uint8_t>(LevelNondefaultProperty::Nonunique);
+ } else if (strVal.compare("nonordered") == 0) {
+ *properties |= static_cast<uint8_t>(LevelNondefaultProperty::Nonordered);
+ } else if (strVal.compare("high") == 0) {
+ *properties |= static_cast<uint8_t>(LevelNondefaultProperty::High);
+ } else if (strVal.compare("block2_4") == 0) {
+ *properties |= static_cast<uint8_t>(LevelNondefaultProperty::Block2_4);
+ } else {
+ parser.emitError(parser.getCurrentLocation(), "unknown level property");
+ return failure();
+ }
return success();
}
diff --git a/mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.h b/mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.h
index b1fb3a42e41fe2..10fb6c8f1c0473 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.h
+++ b/mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.h
@@ -9,56 +9,19 @@
#ifndef MLIR_DIALECT_SPARSETENSOR_IR_DETAIL_LVLTYPEPARSER_H
#define MLIR_DIALECT_SPARSETENSOR_IR_DETAIL_LVLTYPEPARSER_H
-#include "mlir/Dialect/SparseTensor/IR/Enums.h"
#include "mlir/IR/OpImplementation.h"
-#include "llvm/ADT/StringMap.h"
namespace mlir {
namespace sparse_tensor {
namespace ir_detail {
-//===----------------------------------------------------------------------===//
-// These macros are for generating a C++ expression of type
-// `std::initializer_list<std::pair<StringRef,DimLevelType>>` since there's
-// no way to construct an object of that type directly via C++ code.
-#define FOREVERY_LEVELTYPE(DO) \
- DO(DimLevelType::Dense) \
- DO(DimLevelType::Compressed) \
- DO(DimLevelType::CompressedNu) \
- DO(DimLevelType::CompressedNo) \
- DO(DimLevelType::CompressedNuNo) \
- DO(DimLevelType::Singleton) \
- DO(DimLevelType::SingletonNu) \
- DO(DimLevelType::SingletonNo) \
- DO(DimLevelType::SingletonNuNo) \
- DO(DimLevelType::CompressedWithHi) \
- DO(DimLevelType::CompressedWithHiNu) \
- DO(DimLevelType::CompressedWithHiNo) \
- DO(DimLevelType::CompressedWithHiNuNo) \
- DO(DimLevelType::TwoOutOfFour)
-#define LEVELTYPE_INITLIST_ELEMENT(lvlType) \
- std::make_pair(StringRef(toMLIRString(lvlType)), lvlType),
-#define LEVELTYPE_INITLIST \
- { FOREVERY_LEVELTYPE(LEVELTYPE_INITLIST_ELEMENT) }
-
-// TODO(wrengr): Since this parser is non-trivial to construct, is there
-// any way to hook into the parsing process so that we construct it only once
-// at the begining of parsing and then destroy it once parsing has finished?
class LvlTypeParser {
- const llvm::StringMap<DimLevelType> map;
-
public:
- explicit LvlTypeParser() : map(LEVELTYPE_INITLIST) {}
-#undef LEVELTYPE_INITLIST
-#undef LEVELTYPE_INITLIST_ELEMENT
-#undef FOREVERY_LEVELTYPE
+ LvlTypeParser() = default;
+ FailureOr<uint8_t> parseLvlType(AsmParser &parser) const;
- std::optional<DimLevelType> lookup(StringRef str) const;
- std::optional<DimLevelType> lookup(StringAttr str) const;
- ParseResult parseLvlType(AsmParser &parser, DimLevelType &out) const;
- FailureOr<DimLevelType> parseLvlType(AsmParser &parser) const;
- // TODO(wrengr): `parseOptionalLvlType`?
- // TODO(wrengr): `parseLvlTypeList`?
+private:
+ ParseResult parseProperty(AsmParser &parser, uint8_t *properties) const;
};
} // namespace ir_detail
diff --git a/mlir/test/Dialect/SparseTensor/roundtrip_encoding.mlir b/mlir/test/Dialect/SparseTensor/roundtrip_encoding.mlir
index a3cb63e50015af..1cc5c0e3f61527 100644
--- a/mlir/test/Dialect/SparseTensor/roundtrip_encoding.mlir
+++ b/mlir/test/Dialect/SparseTensor/roundtrip_encoding.mlir
@@ -55,7 +55,7 @@ func.func private @sparse_dcsc(tensor<?x?xf32, #DCSC>)
// -----
#COO = #sparse_tensor.encoding<{
- lvlTypes = [ "compressed_nu_no", "singleton_no" ]
+ map = (d0, d1) -> (d0 : compressed(nonunique, nonordered), d1 : singleton(nonordered))
}>
// CHECK-LABEL: func private @sparse_coo(
@@ -65,7 +65,7 @@ func.func private @sparse_coo(tensor<?x?xf32, #COO>)
// -----
#BCOO = #sparse_tensor.encoding<{
- lvlTypes = [ "dense", "compressed_hi_nu", "singleton" ]
+ map = (d0, d1, d2) -> (d0 : dense, d1 : compressed(nonunique, high), d2 : singleton)
}>
// CHECK-LABEL: func private @sparse_bcoo(
@@ -75,7 +75,7 @@ func.func private @sparse_bcoo(tensor<?x?x?xf32, #BCOO>)
// -----
#SortedCOO = #sparse_tensor.encoding<{
- lvlTypes = [ "compressed_nu", "singleton" ]
+ map = (d0, d1) -> (d0 : compressed(nonunique), d1 : singleton)
}>
// CHECK-LABEL: func private @sparse_sorted_coo(
@@ -144,7 +144,7 @@ func.func private @sparse_slice(tensor<?x?xf64, #CSR_SLICE>)
// below) to encode a 2D matrix, but it would require dim2lvl mapping which is not ready yet.
// So we take the simple path for now.
#NV_24= #sparse_tensor.encoding<{
- lvlTypes = [ "dense", "compressed24" ],
+ map = (d0, d1) -> (d0 : dense, d1 : compressed(block2_4))
}>
// CHECK-LABEL: func private @sparse_2_out_of_4(
@@ -195,7 +195,7 @@ func.func private @BCSR_explicit(%arg0: tensor<?x?xf64, #BCSR_explicit>) {
map = ( i, j ) ->
( i : dense,
j floordiv 4 : dense,
- j mod 4 : compressed24
+ j mod 4 : compressed(block2_4)
)
}>
More information about the Mlir-commits mailing list