[Mlir-commits] [mlir] d5d4fb6 - [mlir][linalg] Add support for using scalar attributes in TC ops.
Hanhan Wang
llvmlistbot at llvm.org
Wed Mar 10 01:51:44 PST 2021
Author: Hanhan Wang
Date: 2021-03-10T01:51:12-08:00
New Revision: d5d4fb635ee0b3852f4fee1c98bac87c5811e051
URL: https://github.com/llvm/llvm-project/commit/d5d4fb635ee0b3852f4fee1c98bac87c5811e051
DIFF: https://github.com/llvm/llvm-project/commit/d5d4fb635ee0b3852f4fee1c98bac87c5811e051.diff
LOG: [mlir][linalg] Add support for using scalar attributes in TC ops.
Reviewed By: antiagainst
Differential Revision: https://reviews.llvm.org/D97876
Added:
Modified:
mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc
mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp
Removed:
################################################################################
diff --git a/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc b/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc
index 7627a017a07e..8c4c31da2fce 100644
--- a/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc
+++ b/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc
@@ -85,12 +85,12 @@ def test3(A: f32(Batch, M, K), B: f32(K, N)) -> (C: f32(Batch, M, N)) {
// Test attribute definitions
// ODS-LABEL: def Test4Op
// ODS: F32ArrayAttr:$array_attr,
-// ODS: F32:$f32_attr,
+// ODS: F32Attr:$f32_attr,
// ODS: RankedF32ElementsAttr<[4]>:$fvec_attr,
-// ODS: I32:$i32_attr,
-// ODS: I64:$i64_attr,
+// ODS: I32Attr:$i32_attr,
+// ODS: I64Attr:$i64_attr,
// ODS: RankedI32ElementsAttr<[5, 6]>:$ivec_attr,
-// ODS: OptionalAttr<F32>:$optional_attr
+// ODS: OptionalAttr<F32Attr>:$optional_attr
//
// ODS: bool hasDynamicIndexingMaps();
// ODS: LogicalResult verifyIndexingMapRequiredAttributes();
diff --git a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp
index 9d2d26a5cbd2..7165a0fe89fe 100644
--- a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp
+++ b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp
@@ -1174,7 +1174,7 @@ class TCParser {
// Returns the function to get values at the given indices from this
// attribute.
- std::string getValueFn(ArrayRef<uint64_t> indices) const;
+ llvm::Optional<std::string> getValueFn(ArrayRef<uint64_t> indices) const;
};
//===--------------------------------------------------------------------===//
@@ -1841,16 +1841,19 @@ void TCParser::printODS(llvm::raw_ostream &os, StringRef cppOpName,
const auto &dims = attr.second.vectorDims;
if (!dims.empty()) {
+ // Vector case
SmallVector<std::string, 4> dimStrs;
for (uint64_t dim : dims)
dimStrs.push_back(std::to_string(dim));
odsType = llvm::formatv("Ranked{0}ElementsAttr<[{1}]>", odsType,
llvm::join(dimStrs, ", "));
- }
-
- assert(dims.empty() || !attr.second.isArray);
- if (attr.second.isArray)
+ } else if (attr.second.isArray) {
+ // Array case
odsType = llvm::formatv("{0}ArrayAttr", odsType);
+ } else {
+ // Scalar case
+ odsType = llvm::formatv("{0}Attr", odsType);
+ }
if (attr.second.isOptional)
odsType = llvm::formatv("OptionalAttr<{0}>", odsType);
@@ -2242,13 +2245,14 @@ void TCParser::printReferenceIndexingMaps(llvm::raw_ostream &os,
StringRef attrName = attrUse.value().attrName;
auto it = registeredAttrs.find(attrName.str());
assert(it != registeredAttrs.end() && "uses should point to valid attr!");
- std::string getValueFn = it->second.getValueFn(attrUse.value().indices);
- if (getValueFn.empty()) {
+ llvm::Optional<std::string> getValueFn =
+ it->second.getValueFn(attrUse.value().indices);
+ if (!getValueFn) {
(void)parser.emitError("unimplemented getValueFn for attribute: " +
attrName);
return;
}
- std::string cstVal = llvm::formatv("{0}().{1}", attrName, getValueFn);
+ std::string cstVal = llvm::formatv("{0}(){1}", attrName, *getValueFn);
const char *cstFmt =
"\n\tauto cst{0} = getAffineConstantExpr({1}, context);";
mapsStringStream << llvm::formatv(cstFmt, attrUse.index(), cstVal);
@@ -2374,10 +2378,10 @@ void TCParser::printRegionBuilder(llvm::raw_ostream &os, StringRef cppOpName,
expressionsStr, yieldStr);
}
-std::string
+llvm::Optional<std::string>
TCParser::RegisteredAttr::getValueFn(ArrayRef<uint64_t> indices) const {
if (isArray)
- return "";
+ return llvm::None;
if (!vectorDims.empty()) {
SmallVector<std::string, 4> indexStrs;
@@ -2385,20 +2389,20 @@ TCParser::RegisteredAttr::getValueFn(ArrayRef<uint64_t> indices) const {
indexStrs.push_back(std::to_string(index));
std::string indexList = llvm::join(indexStrs, ", ");
if (elementType == "f32")
- return llvm::formatv("getValue<float>({ {0} })", indexList);
+ return llvm::formatv(".getValue<float>({ {0} })", indexList).str();
if (elementType == "i32")
- return llvm::formatv("getValue<int>({ {0} })", indexList);
+ return llvm::formatv(".getValue<int>({ {0} })", indexList).str();
if (elementType == "i64")
- return llvm::formatv("getValue<int64_t>({ {0} })", indexList);
+ return llvm::formatv(".getValue<int64_t>({ {0} })", indexList).str();
- return "";
+ return llvm::None;
}
if (elementType == "f32")
- return "getValue().convertToFloat()";
+ return std::string(".convertToFloat()");
if (elementType == "i32" || elementType == "i64")
- return "getInt()";
- return "";
+ return std::string("");
+ return llvm::None;
}
/// Iterate over each Tensor Comprehension def.
More information about the Mlir-commits mailing list