[Mlir-commits] [mlir] 2ae5e47 - [mlir][spirv] Use SmallVector<ArrayRef> for availability queries

Lei Zhang llvmlistbot at llvm.org
Thu Mar 12 16:39:58 PDT 2020


Author: Lei Zhang
Date: 2020-03-12T19:37:45-04:00
New Revision: 2ae5e472e6427795ce0efc727f3dc616c912856b

URL: https://github.com/llvm/llvm-project/commit/2ae5e472e6427795ce0efc727f3dc616c912856b
DIFF: https://github.com/llvm/llvm-project/commit/2ae5e472e6427795ce0efc727f3dc616c912856b.diff

LOG: [mlir][spirv] Use SmallVector<ArrayRef> for availability queries

Previously, extension and capability requirements were returned as
SmallVector<SmallVector>, which is an anti-pattern. This commit improves
on that by returning SmallVector<ArrayRef> instead. This is possible because
the inner sequence is always known statically (from the spec),
so we can store it in a static constant array and take an ArrayRef to it.

Differential Revision: https://reviews.llvm.org/D75874

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/SPIRV/SPIRVAvailability.td
    mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td
    mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVAvailability.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVAvailability.td
index 0231d4805fa1..712237c895dd 100644
--- a/mlir/include/mlir/Dialect/SPIRV/SPIRVAvailability.td
+++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVAvailability.td
@@ -47,6 +47,9 @@ class Availability {
 
   // The following are fields for a concrete availability instance.
 
+  // The code for preparing a concrete instance. This should be C++ statements
+  // and will be generated before the `mergeAction` logic.
+  code instancePreparation = "";
   // The availability requirement carried by a concrete instance.
   string instance = ?;
 }

diff  --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td
index 8ef1e363eebc..faedfb3993cb 100644
--- a/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td
@@ -148,23 +148,27 @@ class Extension<list<StrEnumAttrCase> extensions> : Availability {
     AND (`Extension::C`) AND (`Extension::D` OR `Extension::E`) is enabled.
   }];
 
-  // TODO(antiagainst): Using SmallVector<SmallVector<...>> is an anti-pattern.
+  // TODO(antiagainst): Returning SmallVector<ArrayRef<...>> is not recommended.
   // Find a better way for this.
-  let queryFnRetType = "::llvm::SmallVector<::llvm::SmallVector<"
-                          "::mlir::spirv::Extension, 1>, 1>";
+  let queryFnRetType = "::llvm::SmallVector<::llvm::ArrayRef<"
+                          "::mlir::spirv::Extension>, 1>";
   let queryFnName = "getExtensions";
 
   let mergeAction = !if(
       !empty(extensions), "", "$overall.emplace_back($instance)");
   let initializer = "{}";
-  let instanceType = "::llvm::SmallVector<::mlir::spirv::Extension, 1>";
+  let instanceType = "::llvm::ArrayRef<::mlir::spirv::Extension>";
 
-  // Compose all capabilities as an C++ initializer list
-  let instance = "std::initializer_list<::mlir::spirv::Extension>{" #
-                 StrJoin<!foreach(
-                   ext, extensions,
-                   "::mlir::spirv::Extension::" # ext.symbol)>.result #
-                 "}";
+  // Pack all extensions as a static array and get its reference.
+  let instancePreparation = !if(!empty(extensions), "",
+    "static const ::mlir::spirv::Extension exts[] = {" #
+    StrJoin<!foreach(ext, extensions,
+      "::mlir::spirv::Extension::" # ext.symbol)>.result #
+    "}; " #
+    // The following manual ArrayRef constructor call is to satisfy GCC 5.
+    "ArrayRef<::mlir::spirv::Extension> " #
+      "ref(exts, ::llvm::array_lengthof(exts));");
+  let instance = "ref";
 }
 
 class Capability<list<I32EnumAttrCase> capabilities> : Availability {
@@ -187,21 +191,25 @@ class Capability<list<I32EnumAttrCase> capabilities> : Availability {
     AND (`Capability::C`) AND (`Capability::D` OR `Capability::E`) is enabled.
   }];
 
-  let queryFnRetType = "::llvm::SmallVector<::llvm::SmallVector<"
-                          "::mlir::spirv::Capability, 1>, 1>";
+  let queryFnRetType = "::llvm::SmallVector<::llvm::ArrayRef<"
+                          "::mlir::spirv::Capability>, 1>";
   let queryFnName = "getCapabilities";
 
   let mergeAction = !if(
       !empty(capabilities), "", "$overall.emplace_back($instance)");
   let initializer = "{}";
-  let instanceType = "::llvm::SmallVector<::mlir::spirv::Capability, 1>";
+  let instanceType = "::llvm::ArrayRef<::mlir::spirv::Capability>";
 
-  // Compose all capabilities as an C++ initializer list
-  let instance = "std::initializer_list<::mlir::spirv::Capability>{" #
-                 StrJoin<!foreach(
-                   cap, capabilities,
-                   "::mlir::spirv::Capability::" # cap.symbol)>.result #
-                 "}";
+  // Pack all capabilities as a static array and get its reference.
+  let instancePreparation = !if(!empty(capabilities), "",
+    "static const ::mlir::spirv::Capability caps[] = {" #
+    StrJoin<!foreach(cap, capabilities,
+      "::mlir::spirv::Capability::" # cap.symbol)>.result #
+    "}; " #
+    // The following manual ArrayRef constructor call is to satisfy GCC 5.
+    "ArrayRef<::mlir::spirv::Capability> " #
+      "ref(caps, ::llvm::array_lengthof(caps));");
+  let instance = "ref";
 }
 
 // TODO(antiagainst): the following interfaces definitions are duplicating with
@@ -216,13 +224,13 @@ def QueryMaxVersionInterface : OpInterface<"QueryMaxVersionInterface"> {
 def QueryExtensionInterface : OpInterface<"QueryExtensionInterface"> {
   let methods = [InterfaceMethod<
     "",
-    "::llvm::SmallVector<::llvm::SmallVector<::mlir::spirv::Extension, 1>, 1>",
+    "::llvm::SmallVector<::llvm::ArrayRef<::mlir::spirv::Extension>, 1>",
     "getExtensions">];
 }
 def QueryCapabilityInterface : OpInterface<"QueryCapabilityInterface"> {
   let methods = [InterfaceMethod<
     "",
-    "::llvm::SmallVector<::llvm::SmallVector<::mlir::spirv::Capability, 1>, 1>",
+    "::llvm::SmallVector<::llvm::ArrayRef<::mlir::spirv::Capability>, 1>",
     "getCapabilities">];
 }
 

diff  --git a/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp b/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp
index 52c87cc74cdd..a96f2bca8a35 100644
--- a/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp
+++ b/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp
@@ -85,6 +85,9 @@ class Availability {
   // Returns the C++ type for an availability instance.
   StringRef getMergeInstanceType() const;
 
+  // Returns the C++ statements for preparing an availability instance.
+  StringRef getMergeInstancePreparation() const;
+
   // Returns the concrete availability instance carried in this case.
   StringRef getMergeInstance() const;
 
@@ -137,6 +140,10 @@ StringRef Availability::getMergeInstanceType() const {
   return def->getValueAsString("instanceType");
 }
 
+StringRef Availability::getMergeInstancePreparation() const {
+  return def->getValueAsString("instancePreparation");
+}
+
 StringRef Availability::getMergeInstance() const {
   return def->getValueAsString("instance");
 }
@@ -321,9 +328,9 @@ static void emitAvailabilityQueryForIntEnum(const Record &enumDef,
     for (const auto &caseSpecPair : classCasePair.getValue()) {
       EnumAttrCase enumerant = caseSpecPair.first;
       Availability avail = caseSpecPair.second;
-      os << formatv("  case {0}::{1}: return {2}({3});\n", enumName,
-                    enumerant.getSymbol(), avail.getMergeInstanceType(),
-                    avail.getMergeInstance());
+      os << formatv("  case {0}::{1}: { {2} return {3}({4}); }\n", enumName,
+                    enumerant.getSymbol(), avail.getMergeInstancePreparation(),
+                    avail.getMergeInstanceType(), avail.getMergeInstance());
     }
     // Only emit default if uncovered cases.
     if (classCasePair.getValue().size() < enumAttr.getAllCases().size())
@@ -368,9 +375,9 @@ static void emitAvailabilityQueryForBitEnum(const Record &enumDef,
     for (const auto &caseSpecPair : classCasePair.getValue()) {
       EnumAttrCase enumerant = caseSpecPair.first;
       Availability avail = caseSpecPair.second;
-      os << formatv("  case {0}::{1}: return {2}({3});\n", enumName,
-                    enumerant.getSymbol(), avail.getMergeInstanceType(),
-                    avail.getMergeInstance());
+      os << formatv("  case {0}::{1}: { {2} return {3}({4}); }\n", enumName,
+                    enumerant.getSymbol(), avail.getMergeInstancePreparation(),
+                    avail.getMergeInstanceType(), avail.getMergeInstance());
     }
     os << "  default: break;\n";
     os << "  }\n"
@@ -1162,7 +1169,7 @@ static mlir::GenRegistration
 
 static void emitAvailabilityImpl(const Operator &srcOp, raw_ostream &os) {
   mlir::tblgen::FmtContext fctx;
-  fctx.addSubst("overall", "overall");
+  fctx.addSubst("overall", "tblgen_overall");
 
   std::vector<Availability> opAvailabilities =
       getAvailabilities(srcOp.getDef());
@@ -1195,17 +1202,23 @@ static void emitAvailabilityImpl(const Operator &srcOp, raw_ostream &os) {
                   srcOp.getCppClassName(), avail.getQueryFnName());
 
     // Create the variable for the final requirement and initialize it.
-    os << formatv("  {0} overall = {1};\n", avail.getQueryFnRetType(),
+    os << formatv("  {0} tblgen_overall = {1};\n", avail.getQueryFnRetType(),
                   avail.getMergeInitializer());
 
     // Update with the op's specific availability spec.
     for (const Availability &avail : opAvailabilities)
-      if (avail.getClass() == availClassName) {
-        os << "  "
+      if (avail.getClass() == availClassName &&
+          (!avail.getMergeInstancePreparation().empty() ||
+           !avail.getMergeActionCode().empty())) {
+        os << "  {\n    "
+           // Prepare this instance.
+           << avail.getMergeInstancePreparation()
+           << "\n    "
+           // Merge this instance.
            << std::string(
                   tgfmt(avail.getMergeActionCode(),
                         &fctx.addSubst("instance", avail.getMergeInstance())))
-           << ";\n";
+           << ";\n  }\n";
       }
 
     // Update with enum attributes' specific availability spec.
@@ -1236,30 +1249,32 @@ static void emitAvailabilityImpl(const Operator &srcOp, raw_ostream &os) {
         os << formatv("  for (unsigned i = 0; "
                       "i < std::numeric_limits<{0}>::digits; ++i) {{\n",
                       enumAttr->getUnderlyingType());
-        os << formatv("    {0}::{1} attrVal = this->{2}() & "
+        os << formatv("    {0}::{1} tblgen_attrVal = this->{2}() & "
                       "static_cast<{0}::{1}>(1 << i);\n",
                       enumAttr->getCppNamespace(), enumAttr->getEnumClassName(),
                       namedAttr.name);
-        os << formatv("    if (static_cast<{0}>(attrVal) == 0) continue;\n",
-                      enumAttr->getUnderlyingType());
+        os << formatv(
+            "    if (static_cast<{0}>(tblgen_attrVal) == 0) continue;\n",
+            enumAttr->getUnderlyingType());
       } else {
         // For IntEnumAttr, we just need to query the value as a whole.
         os << "  {\n";
-        os << formatv("    auto attrVal = this->{0}();\n", namedAttr.name);
+        os << formatv("    auto tblgen_attrVal = this->{0}();\n",
+                      namedAttr.name);
       }
-      os << formatv("    auto instance = {0}::{1}(attrVal);\n",
+      os << formatv("    auto tblgen_instance = {0}::{1}(tblgen_attrVal);\n",
                     enumAttr->getCppNamespace(), avail.getQueryFnName());
-      os << "    if (instance) "
+      os << "    if (tblgen_instance) "
          // TODO(antiagainst): use `avail.getMergeCode()` here once ODS supports
          // dialect-specific contents so that we can use not implementing the
          // availability interface as indication of no requirements.
          << std::string(tgfmt(caseSpecs.front().second.getMergeActionCode(),
-                              &fctx.addSubst("instance", "*instance")))
+                              &fctx.addSubst("instance", "*tblgen_instance")))
          << ";\n";
       os << "  }\n";
     }
 
-    os << "  return overall;\n";
+    os << "  return tblgen_overall;\n";
     os << "}\n";
   }
 }


        


More information about the Mlir-commits mailing list