[Mlir-commits] [mlir] cb395f6 - [mlir][spirv] Change the return type for {Min|Max}VersionBase

Lei Zhang llvmlistbot at llvm.org
Wed Nov 24 14:33:16 PST 2021


Author: Lei Zhang
Date: 2021-11-24T17:33:01-05:00
New Revision: cb395f66ac3ce60427ca2b99580e716ac6dd551a

URL: https://github.com/llvm/llvm-project/commit/cb395f66ac3ce60427ca2b99580e716ac6dd551a
DIFF: https://github.com/llvm/llvm-project/commit/cb395f66ac3ce60427ca2b99580e716ac6dd551a.diff

LOG: [mlir][spirv] Change the return type for {Min|Max}VersionBase

For synthesizing an op's implementation of the generated interface
from {Min|Max}Version, we need to define an `initializer` and a
`mergeAction`. The `initializer` specifies the initial version,
and `mergeAction` specifies how version specifications from
different parts of the op should be merged to generate the final
version requirement.

Previously we used the specified version enum as the type for both
the initializer and, thus, the final return type. This meant we had
to `static_cast` some hopefully-unused number (`~0u`) to serve as the
initializer. This is quite opaque and not guaranteed to work. Also,
there are ops with an enum attribute where some values declare version
requirements (e.g., enumerant `B` requires v1.1+) but others do not
(e.g., enumerant `A` requires nothing). A concrete op instance with `A`
will still declare that it implements the version interface (because
interface implementation is static for an op), even though there is
actually no version requirement.

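So this commit switches to a more explicit `llvm::Optional` wrapping
the returned version enum, where `llvm::None` means "no version
requirement". This should make the intent clearer. The test pass
registration function is also renamed from
`registerPrintOpAvailabilityPass` to `registerPrintSpirvAvailabilityPass`.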

Reviewed By: jpienaar

Differential Revision: https://reviews.llvm.org/D108312

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAvailability.td
    mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
    mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
    mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp
    mlir/test/lib/Dialect/SPIRV/TestAvailability.cpp
    mlir/tools/mlir-opt/mlir-opt.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAvailability.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAvailability.td
index bfb03dac33a79..211a46b971e16 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAvailability.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAvailability.td
@@ -60,12 +60,15 @@ class MinVersionBase<string name, I32EnumAttr scheme, I32EnumAttrCase min>
     : Availability {
   let interfaceName = name;
 
-  let queryFnRetType = scheme.returnType;
+  let queryFnRetType = "llvm::Optional<" # scheme.returnType # ">";
   let queryFnName = "getMinVersion";
 
-  let mergeAction = "$overall = static_cast<" # scheme.returnType # ">("
-                      "std::max($overall, $instance))";
-  let initializer = "static_cast<" # scheme.returnType # ">(uint32_t(0))";
+  let mergeAction = "{ "
+    "if ($overall.hasValue()) { "
+      "$overall = static_cast<" # scheme.returnType # ">("
+                  "std::max(*$overall, $instance)); "
+    "} else { $overall = $instance; }}";
+  let initializer = "::llvm::None";
   let instanceType = scheme.cppNamespace # "::" # scheme.className;
 
   let instance = scheme.cppNamespace # "::" # scheme.className # "::" #
@@ -76,12 +79,15 @@ class MaxVersionBase<string name, I32EnumAttr scheme, I32EnumAttrCase max>
     : Availability {
   let interfaceName = name;
 
-  let queryFnRetType = scheme.returnType;
+  let queryFnRetType = "llvm::Optional<" # scheme.returnType # ">";
   let queryFnName = "getMaxVersion";
 
-  let mergeAction = "$overall = static_cast<" # scheme.returnType # ">("
-                      "std::min($overall, $instance))";
-  let initializer = "static_cast<" # scheme.returnType # ">(~uint32_t(0))";
+  let mergeAction = "{ "
+    "if ($overall.hasValue()) { "
+      "$overall = static_cast<" # scheme.returnType # ">("
+                  "std::min(*$overall, $instance)); "
+    "} else { $overall = $instance; }}";
+  let initializer = "::llvm::None";
   let instanceType = scheme.cppNamespace # "::" # scheme.className;
 
   let instance = scheme.cppNamespace # "::" # scheme.className # "::" #

diff  --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
index 8dd9b5c1fcea9..5429cf8c79320 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
@@ -239,10 +239,12 @@ class SPIRVOpInterface<string name> : OpInterface<name> {
 // TODO: the following interfaces definitions are duplicating with the above.
 // Remove them once we are able to support dialect-specific contents in ODS.
 def QueryMinVersionInterface : SPIRVOpInterface<"QueryMinVersionInterface"> {
-  let methods = [InterfaceMethod<"", "::mlir::spirv::Version", "getMinVersion">];
+  let methods = [InterfaceMethod<
+    "", "::llvm::Optional<::mlir::spirv::Version>", "getMinVersion">];
 }
 def QueryMaxVersionInterface : SPIRVOpInterface<"QueryMaxVersionInterface"> {
-  let methods = [InterfaceMethod<"", "::mlir::spirv::Version", "getMaxVersion">];
+  let methods = [InterfaceMethod<
+    "", "::llvm::Optional<::mlir::spirv::Version>", "getMaxVersion">];
 }
 def QueryExtensionInterface : SPIRVOpInterface<"QueryExtensionInterface"> {
   let methods = [InterfaceMethod<

diff  --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
index 7c63fce6919be..afa26650b4c44 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
@@ -843,22 +843,24 @@ bool SPIRVConversionTarget::isLegalOp(Operation *op) {
   // Make sure this op is available at the given version. Ops not implementing
   // QueryMinVersionInterface/QueryMaxVersionInterface are available to all
   // SPIR-V versions.
-  if (auto minVersion = dyn_cast<spirv::QueryMinVersionInterface>(op))
-    if (minVersion.getMinVersion() > this->targetEnv.getVersion()) {
+  if (auto minVersionIfx = dyn_cast<spirv::QueryMinVersionInterface>(op)) {
+    Optional<spirv::Version> minVersion = minVersionIfx.getMinVersion();
+    if (minVersion && *minVersion > this->targetEnv.getVersion()) {
       LLVM_DEBUG(llvm::dbgs()
                  << op->getName() << " illegal: requiring min version "
-                 << spirv::stringifyVersion(minVersion.getMinVersion())
-                 << "\n");
+                 << spirv::stringifyVersion(*minVersion) << "\n");
       return false;
     }
-  if (auto maxVersion = dyn_cast<spirv::QueryMaxVersionInterface>(op))
-    if (maxVersion.getMaxVersion() < this->targetEnv.getVersion()) {
+  }
+  if (auto maxVersionIfx = dyn_cast<spirv::QueryMaxVersionInterface>(op)) {
+    Optional<spirv::Version> maxVersion = maxVersionIfx.getMaxVersion();
+    if (maxVersion && *maxVersion < this->targetEnv.getVersion()) {
       LLVM_DEBUG(llvm::dbgs()
                  << op->getName() << " illegal: requiring max version "
-                 << spirv::stringifyVersion(maxVersion.getMaxVersion())
-                 << "\n");
+                 << spirv::stringifyVersion(*maxVersion) << "\n");
       return false;
     }
+  }
 
   // Make sure this op's required extensions are allowed to use. Ops not
   // implementing QueryExtensionInterface do not require extensions to be

diff  --git a/mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp
index 6296369315254..3232c984b290a 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp
@@ -109,13 +109,17 @@ void UpdateVCEPass::runOnOperation() {
   // requirements.
   WalkResult walkResult = module.walk([&](Operation *op) -> WalkResult {
     // Op min version requirements
-    if (auto minVersion = dyn_cast<spirv::QueryMinVersionInterface>(op)) {
-      deducedVersion = std::max(deducedVersion, minVersion.getMinVersion());
-      if (deducedVersion > allowedVersion) {
-        return op->emitError("'") << op->getName() << "' requires min version "
-                                  << spirv::stringifyVersion(deducedVersion)
-                                  << " but target environment allows up to "
-                                  << spirv::stringifyVersion(allowedVersion);
+    if (auto minVersionIfx = dyn_cast<spirv::QueryMinVersionInterface>(op)) {
+      Optional<spirv::Version> minVersion = minVersionIfx.getMinVersion();
+      if (minVersion) {
+        deducedVersion = std::max(deducedVersion, *minVersion);
+        if (deducedVersion > allowedVersion) {
+          return op->emitError("'")
+                 << op->getName() << "' requires min version "
+                 << spirv::stringifyVersion(deducedVersion)
+                 << " but target environment allows up to "
+                 << spirv::stringifyVersion(allowedVersion);
+        }
       }
     }
 

diff  --git a/mlir/test/lib/Dialect/SPIRV/TestAvailability.cpp b/mlir/test/lib/Dialect/SPIRV/TestAvailability.cpp
index 1887fbd353ea5..f68cf63890559 100644
--- a/mlir/test/lib/Dialect/SPIRV/TestAvailability.cpp
+++ b/mlir/test/lib/Dialect/SPIRV/TestAvailability.cpp
@@ -43,13 +43,23 @@ void PrintOpAvailability::runOnFunction() {
     auto opName = op->getName();
     auto &os = llvm::outs();
 
-    if (auto minVersion = dyn_cast<spirv::QueryMinVersionInterface>(op))
-      os << opName << " min version: "
-         << spirv::stringifyVersion(minVersion.getMinVersion()) << "\n";
+    if (auto minVersionIfx = dyn_cast<spirv::QueryMinVersionInterface>(op)) {
+      Optional<spirv::Version> minVersion = minVersionIfx.getMinVersion();
+      os << opName << " min version: ";
+      if (minVersion)
+        os << spirv::stringifyVersion(*minVersion) << "\n";
+      else
+        os << "None\n";
+    }
 
-    if (auto maxVersion = dyn_cast<spirv::QueryMaxVersionInterface>(op))
-      os << opName << " max version: "
-         << spirv::stringifyVersion(maxVersion.getMaxVersion()) << "\n";
+    if (auto maxVersionIfx = dyn_cast<spirv::QueryMaxVersionInterface>(op)) {
+      Optional<spirv::Version> maxVersion = maxVersionIfx.getMaxVersion();
+      os << opName << " max version: ";
+      if (maxVersion)
+        os << spirv::stringifyVersion(*maxVersion) << "\n";
+      else
+        os << "None\n";
+    }
 
     if (auto extension = dyn_cast<spirv::QueryExtensionInterface>(op)) {
       os << opName << " extensions: [";
@@ -81,7 +91,7 @@ void PrintOpAvailability::runOnFunction() {
 }
 
 namespace mlir {
-void registerPrintOpAvailabilityPass() {
+void registerPrintSpirvAvailabilityPass() {
   PassRegistration<PrintOpAvailability>();
 }
 } // namespace mlir

diff  --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index 285b06d7aa831..6b77c37353eca 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -31,7 +31,7 @@ using namespace mlir;
 namespace mlir {
 void registerConvertToTargetEnvPass();
 void registerPassManagerTestPass();
-void registerPrintOpAvailabilityPass();
+void registerPrintSpirvAvailabilityPass();
 void registerShapeFunctionTestPasses();
 void registerSideEffectTestPasses();
 void registerSliceAnalysisTestPass();
@@ -119,7 +119,7 @@ void registerTestDialect(DialectRegistry &);
 void registerTestPasses() {
   registerConvertToTargetEnvPass();
   registerPassManagerTestPass();
-  registerPrintOpAvailabilityPass();
+  registerPrintSpirvAvailabilityPass();
   registerShapeFunctionTestPasses();
   registerSideEffectTestPasses();
   registerSliceAnalysisTestPass();


        


More information about the Mlir-commits mailing list