[Mlir-commits] [mlir] 2ae5e47 - [mlir][spirv] Use SmallVector<ArrayRef> for availability queries
Lei Zhang
llvmlistbot at llvm.org
Thu Mar 12 16:39:58 PDT 2020
Author: Lei Zhang
Date: 2020-03-12T19:37:45-04:00
New Revision: 2ae5e472e6427795ce0efc727f3dc616c912856b
URL: https://github.com/llvm/llvm-project/commit/2ae5e472e6427795ce0efc727f3dc616c912856b
DIFF: https://github.com/llvm/llvm-project/commit/2ae5e472e6427795ce0efc727f3dc616c912856b.diff
LOG: [mlir][spirv] Use SmallVector<ArrayRef> for availability queries
Previously, extension and capability requirements were returned as
SmallVector<SmallVector>, which is an anti-pattern. This commit improves
on that somewhat by returning SmallVector<ArrayRef> instead. This is
possible because the inner sequences are always known statically (from
the spec), so each one can be stored in a static constant array and
referenced through an ArrayRef.
Differential Revision: https://reviews.llvm.org/D75874
Added:
Modified:
mlir/include/mlir/Dialect/SPIRV/SPIRVAvailability.td
mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td
mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVAvailability.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVAvailability.td
index 0231d4805fa1..712237c895dd 100644
--- a/mlir/include/mlir/Dialect/SPIRV/SPIRVAvailability.td
+++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVAvailability.td
@@ -47,6 +47,9 @@ class Availability {
// The following are fields for a concrete availability instance.
+ // The code for preparing a concrete instance. This should be C++ statements
+ // and will be generated before the `mergeAction` logic.
+ code instancePreparation = "";
// The availability requirement carried by a concrete instance.
string instance = ?;
}
diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td
index 8ef1e363eebc..faedfb3993cb 100644
--- a/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td
@@ -148,23 +148,27 @@ class Extension<list<StrEnumAttrCase> extensions> : Availability {
AND (`Extension::C`) AND (`Extension::D` OR `Extension::E`) is enabled.
}];
- // TODO(antiagainst): Using SmallVector<SmallVector<...>> is an anti-pattern.
+ // TODO(antiagainst): Returning SmallVector<ArrayRef<...>> is not recommended.
// Find a better way for this.
- let queryFnRetType = "::llvm::SmallVector<::llvm::SmallVector<"
- "::mlir::spirv::Extension, 1>, 1>";
+ let queryFnRetType = "::llvm::SmallVector<::llvm::ArrayRef<"
+ "::mlir::spirv::Extension>, 1>";
let queryFnName = "getExtensions";
let mergeAction = !if(
!empty(extensions), "", "$overall.emplace_back($instance)");
let initializer = "{}";
- let instanceType = "::llvm::SmallVector<::mlir::spirv::Extension, 1>";
+ let instanceType = "::llvm::ArrayRef<::mlir::spirv::Extension>";
- // Compose all capabilities as an C++ initializer list
- let instance = "std::initializer_list<::mlir::spirv::Extension>{" #
- StrJoin<!foreach(
- ext, extensions,
- "::mlir::spirv::Extension::" # ext.symbol)>.result #
- "}";
+ // Pack all extensions as a static array and get its reference.
+ let instancePreparation = !if(!empty(extensions), "",
+ "static const ::mlir::spirv::Extension exts[] = {" #
+ StrJoin<!foreach(ext, extensions,
+ "::mlir::spirv::Extension::" # ext.symbol)>.result #
+ "}; " #
+ // The following manual ArrayRef constructor call is to satisfy GCC 5.
+ "ArrayRef<::mlir::spirv::Extension> " #
+ "ref(exts, ::llvm::array_lengthof(exts));");
+ let instance = "ref";
}
class Capability<list<I32EnumAttrCase> capabilities> : Availability {
@@ -187,21 +191,25 @@ class Capability<list<I32EnumAttrCase> capabilities> : Availability {
AND (`Capability::C`) AND (`Capability::D` OR `Capability::E`) is enabled.
}];
- let queryFnRetType = "::llvm::SmallVector<::llvm::SmallVector<"
- "::mlir::spirv::Capability, 1>, 1>";
+ let queryFnRetType = "::llvm::SmallVector<::llvm::ArrayRef<"
+ "::mlir::spirv::Capability>, 1>";
let queryFnName = "getCapabilities";
let mergeAction = !if(
!empty(capabilities), "", "$overall.emplace_back($instance)");
let initializer = "{}";
- let instanceType = "::llvm::SmallVector<::mlir::spirv::Capability, 1>";
+ let instanceType = "::llvm::ArrayRef<::mlir::spirv::Capability>";
- // Compose all capabilities as an C++ initializer list
- let instance = "std::initializer_list<::mlir::spirv::Capability>{" #
- StrJoin<!foreach(
- cap, capabilities,
- "::mlir::spirv::Capability::" # cap.symbol)>.result #
- "}";
+ // Pack all capabilities as a static array and get its reference.
+ let instancePreparation = !if(!empty(capabilities), "",
+ "static const ::mlir::spirv::Capability caps[] = {" #
+ StrJoin<!foreach(cap, capabilities,
+ "::mlir::spirv::Capability::" # cap.symbol)>.result #
+ "}; " #
+ // The following manual ArrayRef constructor call is to satisfy GCC 5.
+ "ArrayRef<::mlir::spirv::Capability> " #
+ "ref(caps, ::llvm::array_lengthof(caps));");
+ let instance = "ref";
}
// TODO(antiagainst): the following interfaces definitions are duplicating with
@@ -216,13 +224,13 @@ def QueryMaxVersionInterface : OpInterface<"QueryMaxVersionInterface"> {
def QueryExtensionInterface : OpInterface<"QueryExtensionInterface"> {
let methods = [InterfaceMethod<
"",
- "::llvm::SmallVector<::llvm::SmallVector<::mlir::spirv::Extension, 1>, 1>",
+ "::llvm::SmallVector<::llvm::ArrayRef<::mlir::spirv::Extension>, 1>",
"getExtensions">];
}
def QueryCapabilityInterface : OpInterface<"QueryCapabilityInterface"> {
let methods = [InterfaceMethod<
"",
- "::llvm::SmallVector<::llvm::SmallVector<::mlir::spirv::Capability, 1>, 1>",
+ "::llvm::SmallVector<::llvm::ArrayRef<::mlir::spirv::Capability>, 1>",
"getCapabilities">];
}
diff --git a/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp b/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp
index 52c87cc74cdd..a96f2bca8a35 100644
--- a/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp
+++ b/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp
@@ -85,6 +85,9 @@ class Availability {
// Returns the C++ type for an availability instance.
StringRef getMergeInstanceType() const;
+ // Returns the C++ statements for preparing availability instance.
+ StringRef getMergeInstancePreparation() const;
+
// Returns the concrete availability instance carried in this case.
StringRef getMergeInstance() const;
@@ -137,6 +140,10 @@ StringRef Availability::getMergeInstanceType() const {
return def->getValueAsString("instanceType");
}
+StringRef Availability::getMergeInstancePreparation() const {
+ return def->getValueAsString("instancePreparation");
+}
+
StringRef Availability::getMergeInstance() const {
return def->getValueAsString("instance");
}
@@ -321,9 +328,9 @@ static void emitAvailabilityQueryForIntEnum(const Record &enumDef,
for (const auto &caseSpecPair : classCasePair.getValue()) {
EnumAttrCase enumerant = caseSpecPair.first;
Availability avail = caseSpecPair.second;
- os << formatv(" case {0}::{1}: return {2}({3});\n", enumName,
- enumerant.getSymbol(), avail.getMergeInstanceType(),
- avail.getMergeInstance());
+ os << formatv(" case {0}::{1}: { {2} return {3}({4}); }\n", enumName,
+ enumerant.getSymbol(), avail.getMergeInstancePreparation(),
+ avail.getMergeInstanceType(), avail.getMergeInstance());
}
// Only emit default if uncovered cases.
if (classCasePair.getValue().size() < enumAttr.getAllCases().size())
@@ -368,9 +375,9 @@ static void emitAvailabilityQueryForBitEnum(const Record &enumDef,
for (const auto &caseSpecPair : classCasePair.getValue()) {
EnumAttrCase enumerant = caseSpecPair.first;
Availability avail = caseSpecPair.second;
- os << formatv(" case {0}::{1}: return {2}({3});\n", enumName,
- enumerant.getSymbol(), avail.getMergeInstanceType(),
- avail.getMergeInstance());
+ os << formatv(" case {0}::{1}: { {2} return {3}({4}); }\n", enumName,
+ enumerant.getSymbol(), avail.getMergeInstancePreparation(),
+ avail.getMergeInstanceType(), avail.getMergeInstance());
}
os << " default: break;\n";
os << " }\n"
@@ -1162,7 +1169,7 @@ static mlir::GenRegistration
static void emitAvailabilityImpl(const Operator &srcOp, raw_ostream &os) {
mlir::tblgen::FmtContext fctx;
- fctx.addSubst("overall", "overall");
+ fctx.addSubst("overall", "tblgen_overall");
std::vector<Availability> opAvailabilities =
getAvailabilities(srcOp.getDef());
@@ -1195,17 +1202,23 @@ static void emitAvailabilityImpl(const Operator &srcOp, raw_ostream &os) {
srcOp.getCppClassName(), avail.getQueryFnName());
// Create the variable for the final requirement and initialize it.
- os << formatv(" {0} overall = {1};\n", avail.getQueryFnRetType(),
+ os << formatv(" {0} tblgen_overall = {1};\n", avail.getQueryFnRetType(),
avail.getMergeInitializer());
// Update with the op's specific availability spec.
for (const Availability &avail : opAvailabilities)
- if (avail.getClass() == availClassName) {
- os << " "
+ if (avail.getClass() == availClassName &&
+ (!avail.getMergeInstancePreparation().empty() ||
+ !avail.getMergeActionCode().empty())) {
+ os << " {\n "
+ // Prepare this instance.
+ << avail.getMergeInstancePreparation()
+ << "\n "
+ // Merge this instance.
<< std::string(
tgfmt(avail.getMergeActionCode(),
&fctx.addSubst("instance", avail.getMergeInstance())))
- << ";\n";
+ << ";\n }\n";
}
// Update with enum attributes' specific availability spec.
@@ -1236,30 +1249,32 @@ static void emitAvailabilityImpl(const Operator &srcOp, raw_ostream &os) {
os << formatv(" for (unsigned i = 0; "
"i < std::numeric_limits<{0}>::digits; ++i) {{\n",
enumAttr->getUnderlyingType());
- os << formatv(" {0}::{1} attrVal = this->{2}() & "
+ os << formatv(" {0}::{1} tblgen_attrVal = this->{2}() & "
"static_cast<{0}::{1}>(1 << i);\n",
enumAttr->getCppNamespace(), enumAttr->getEnumClassName(),
namedAttr.name);
- os << formatv(" if (static_cast<{0}>(attrVal) == 0) continue;\n",
- enumAttr->getUnderlyingType());
+ os << formatv(
+ " if (static_cast<{0}>(tblgen_attrVal) == 0) continue;\n",
+ enumAttr->getUnderlyingType());
} else {
// For IntEnumAttr, we just need to query the value as a whole.
os << " {\n";
- os << formatv(" auto attrVal = this->{0}();\n", namedAttr.name);
+ os << formatv(" auto tblgen_attrVal = this->{0}();\n",
+ namedAttr.name);
}
- os << formatv(" auto instance = {0}::{1}(attrVal);\n",
+ os << formatv(" auto tblgen_instance = {0}::{1}(tblgen_attrVal);\n",
enumAttr->getCppNamespace(), avail.getQueryFnName());
- os << " if (instance) "
+ os << " if (tblgen_instance) "
// TODO(antiagainst): use `avail.getMergeCode()` here once ODS supports
// dialect-specific contents so that we can use not implementing the
// availability interface as indication of no requirements.
<< std::string(tgfmt(caseSpecs.front().second.getMergeActionCode(),
- &fctx.addSubst("instance", "*instance")))
+ &fctx.addSubst("instance", "*tblgen_instance")))
<< ";\n";
os << " }\n";
}
- os << " return overall;\n";
+ os << " return tblgen_overall;\n";
os << "}\n";
}
}
More information about the Mlir-commits
mailing list