[Mlir-commits] [mlir] 9414db1 - [mlir][spirv] Add a pass to deduce version/extension/capability

Lei Zhang llvmlistbot at llvm.org
Thu Mar 12 16:39:51 PDT 2020


Author: Lei Zhang
Date: 2020-03-12T19:37:45-04:00
New Revision: 9414db10906a845e8e485a22102440833d131e48

URL: https://github.com/llvm/llvm-project/commit/9414db10906a845e8e485a22102440833d131e48
DIFF: https://github.com/llvm/llvm-project/commit/9414db10906a845e8e485a22102440833d131e48.diff

LOG: [mlir][spirv] Add a pass to deduce version/extension/capability

Creates an operation pass that deduces and attaches the minimal version/
capabilities/extensions requirements for spv.module ops.

For each spv.module op, this pass requires a `spv.target_env` attribute on
it or an enclosing module-like op to drive the deduction. The reason is
that an op can be enabled by multiple extensions/capabilities. So we need
to know which one to pick. `spv.target_env` gives the hard limit on
what the target environment can support; this pass deduces what is
actually needed for a specific spv.module op.

Differential Revision: https://reviews.llvm.org/D75870

Added: 
    mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp
    mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir

Modified: 
    mlir/include/mlir/Dialect/SPIRV/Passes.h
    mlir/include/mlir/Dialect/SPIRV/TargetAndABI.h
    mlir/include/mlir/InitAllPasses.h
    mlir/lib/Dialect/SPIRV/TargetAndABI.cpp
    mlir/lib/Dialect/SPIRV/Transforms/CMakeLists.txt

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/SPIRV/Passes.h b/mlir/include/mlir/Dialect/SPIRV/Passes.h
index e14fbe918c3b..fc13460b797b 100644
--- a/mlir/include/mlir/Dialect/SPIRV/Passes.h
+++ b/mlir/include/mlir/Dialect/SPIRV/Passes.h
@@ -26,12 +26,24 @@ class ModuleOp;
 std::unique_ptr<OpPassBase<mlir::ModuleOp>>
 createDecorateSPIRVCompositeTypeLayoutPass();
 
-/// Creates a module pass that lowers the ABI attributes specified during SPIR-V
-/// Lowering. Specifically,
-/// 1) Creates the global variables for arguments of entry point function using
-/// the specification in the ABI attributes for each argument.
-/// 2) Inserts the EntryPointOp and the ExecutionModeOp for entry point
-/// functions using the specification in the EntryPointAttr.
+/// Creates an operation pass that deduces and attaches the minimal version/
+/// capabilities/extensions requirements for spv.module ops.
+/// For each spv.module op, this pass requires a `spv.target_env` attribute on
+/// it or an enclosing module-like op to drive the deduction. The reason is
+/// that an op can be enabled by multiple extensions/capabilities. So we need
+/// to know which one to pick. `spv.target_env` gives the hard limit on
+/// what the target environment can support; this pass deduces what is
+/// actually needed for a specific spv.module op.
+std::unique_ptr<OpPassBase<spirv::ModuleOp>>
+createUpdateVersionCapabilityExtensionPass();
+
+/// Creates an operation pass that lowers the ABI attributes specified during
+/// SPIR-V Lowering. Specifically,
+/// 1. Creates the global variables for arguments of entry point function using
+///    the specification in the `spv.interface_var_abi` attribute for each
+///    argument.
+/// 2. Inserts the EntryPointOp and the ExecutionModeOp for entry point
+///    functions using the specification in the `spv.entry_point_abi` attribute.
 std::unique_ptr<OpPassBase<spirv::ModuleOp>> createLowerABIAttributesPass();
 
 } // namespace spirv

diff  --git a/mlir/include/mlir/Dialect/SPIRV/TargetAndABI.h b/mlir/include/mlir/Dialect/SPIRV/TargetAndABI.h
index 63f7e267ecff..d00468007446 100644
--- a/mlir/include/mlir/Dialect/SPIRV/TargetAndABI.h
+++ b/mlir/include/mlir/Dialect/SPIRV/TargetAndABI.h
@@ -174,6 +174,10 @@ StringRef getTargetEnvAttrName();
 /// and no extra extensions.
 TargetEnvAttr getDefaultTargetEnv(MLIRContext *context);
 
+/// Queries the target environment recursively from enclosing symbol table ops
+/// containing the given `op`. Returns a null attribute if none is found.
+TargetEnvAttr lookupTargetEnv(Operation *op);
+
 /// Queries the target environment recursively from enclosing symbol table ops
 /// containing the given `op` or returns the default target environment as
 /// returned by getDefaultTargetEnv() if not provided.

diff  --git a/mlir/include/mlir/InitAllPasses.h b/mlir/include/mlir/InitAllPasses.h
index 4df27e1bc1cd..009cc7309cff 100644
--- a/mlir/include/mlir/InitAllPasses.h
+++ b/mlir/include/mlir/InitAllPasses.h
@@ -122,6 +122,7 @@ inline void registerAllPasses() {
   // SPIR-V
   spirv::createDecorateSPIRVCompositeTypeLayoutPass();
   spirv::createLowerABIAttributesPass();
+  spirv::createUpdateVersionCapabilityExtensionPass();
   createConvertGPUToSPIRVPass();
   createConvertStandardToSPIRVPass();
   createLegalizeStdOpsForSPIRVLoweringPass();

diff  --git a/mlir/lib/Dialect/SPIRV/TargetAndABI.cpp b/mlir/lib/Dialect/SPIRV/TargetAndABI.cpp
index 1bd8729347ac..b2c1dda88294 100644
--- a/mlir/lib/Dialect/SPIRV/TargetAndABI.cpp
+++ b/mlir/lib/Dialect/SPIRV/TargetAndABI.cpp
@@ -294,19 +294,25 @@ spirv::TargetEnvAttr spirv::getDefaultTargetEnv(MLIRContext *context) {
                                    spirv::getDefaultResourceLimits(context));
 }
 
-spirv::TargetEnvAttr spirv::lookupTargetEnvOrDefault(Operation *op) {
-  Operation *symTable = op;
-  while (symTable) {
-    symTable = SymbolTable::getNearestSymbolTable(symTable);
-    if (!symTable)
+spirv::TargetEnvAttr spirv::lookupTargetEnv(Operation *op) {
+  while (op) {
+    op = SymbolTable::getNearestSymbolTable(op);
+    if (!op)
       break;
 
-    if (auto attr = symTable->getAttrOfType<spirv::TargetEnvAttr>(
+    if (auto attr = op->getAttrOfType<spirv::TargetEnvAttr>(
             spirv::getTargetEnvAttrName()))
       return attr;
 
-    symTable = symTable->getParentOp();
+    op = op->getParentOp();
   }
 
+  return {};
+}
+
+spirv::TargetEnvAttr spirv::lookupTargetEnvOrDefault(Operation *op) {
+  if (spirv::TargetEnvAttr attr = spirv::lookupTargetEnv(op))
+    return attr;
+
   return getDefaultTargetEnv(op->getContext());
 }

diff  --git a/mlir/lib/Dialect/SPIRV/Transforms/CMakeLists.txt b/mlir/lib/Dialect/SPIRV/Transforms/CMakeLists.txt
index a722f9dac2be..f31c4836e7f7 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/SPIRV/Transforms/CMakeLists.txt
@@ -1,6 +1,7 @@
 add_mlir_dialect_library(MLIRSPIRVTransforms
   DecorateSPIRVCompositeTypeLayoutPass.cpp
   LowerABIAttributesPass.cpp
+  UpdateVCEPass.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/SPIRV

diff  --git a/mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp
new file mode 100644
index 000000000000..26597dc46340
--- /dev/null
+++ b/mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp
@@ -0,0 +1,164 @@
+//===- UpdateVCEPass.cpp --------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a pass to deduce minimal version/extension/capability
+// requirements for a spirv::ModuleOp.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/SPIRV/Passes.h"
+#include "mlir/Dialect/SPIRV/SPIRVDialect.h"
+#include "mlir/Dialect/SPIRV/SPIRVOps.h"
+#include "mlir/Dialect/SPIRV/SPIRVTypes.h"
+#include "mlir/Dialect/SPIRV/TargetAndABI.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/Visitors.h"
+#include "llvm/ADT/SetVector.h"
+#include "llvm/ADT/SmallSet.h"
+
+using namespace mlir;
+
+namespace {
+/// Pass to deduce minimal version/extension/capability requirements for a
+/// spirv::ModuleOp.
+class UpdateVCEPass final
+    : public OperationPass<UpdateVCEPass, spirv::ModuleOp> {
+private:
+  void runOnOperation() override;
+};
+} // namespace
+
+void UpdateVCEPass::runOnOperation() {
+  spirv::ModuleOp module = getOperation();
+
+  spirv::TargetEnvAttr targetEnv = spirv::lookupTargetEnv(module);
+  if (!targetEnv) {
+    module.emitError("missing 'spv.target_env' attribute");
+    return signalPassFailure();
+  }
+
+  spirv::Version allowedVersion = targetEnv.getVersion();
+
+  // Build a set for available extensions in the target environment.
+  llvm::SmallSet<spirv::Extension, 4> allowedExtensions;
+  for (spirv::Extension ext : targetEnv.getExtensions())
+    allowedExtensions.insert(ext);
+
+  // Add extensions implied by the current version.
+  for (spirv::Extension ext : spirv::getImpliedExtensions(allowedVersion))
+    allowedExtensions.insert(ext);
+
+  // Build a set for available capabilities in the target environment.
+  llvm::SmallSet<spirv::Capability, 8> allowedCapabilities;
+  for (spirv::Capability cap : targetEnv.getCapabilities()) {
+    allowedCapabilities.insert(cap);
+
+    // Add capabilities implied by the current capability.
+    for (spirv::Capability c : spirv::getRecursiveImpliedCapabilities(cap))
+      allowedCapabilities.insert(c);
+  }
+
+  spirv::Version deducedVersion = spirv::Version::V_1_0;
+  llvm::SetVector<spirv::Extension> deducedExtensions;
+  llvm::SetVector<spirv::Capability> deducedCapabilities;
+
+  // Walk each SPIR-V op to deduce the minimal version/extension/capability
+  // requirements.
+  WalkResult walkResult = module.walk([&](Operation *op) -> WalkResult {
+    if (auto minVersion = dyn_cast<spirv::QueryMinVersionInterface>(op)) {
+      deducedVersion = std::max(deducedVersion, minVersion.getMinVersion());
+      if (deducedVersion > allowedVersion) {
+        return op->emitError("'") << op->getName() << "' requires min version "
+                                  << spirv::stringifyVersion(deducedVersion)
+                                  << " but target environment allows up to "
+                                  << spirv::stringifyVersion(allowedVersion);
+      }
+    }
+
+    // Deduce this op's extension requirement. For each op, the query interface
+    // returns a vector of vectors for its extension requirements following
+    // ((Extension::A OR Extension::B) AND (Extension::C OR Extension::D))
+    // convention. Ops not implementing QueryExtensionInterface do not require
+    // extensions to be available.
+    if (auto extensions = dyn_cast<spirv::QueryExtensionInterface>(op)) {
+      for (const auto &ors : extensions.getExtensions()) {
+        bool satisfied = false; // True when at least one extension can be used
+        for (spirv::Extension ext : ors) {
+          if (allowedExtensions.count(ext)) {
+            deducedExtensions.insert(ext);
+            satisfied = true;
+            break;
+          }
+        }
+
+        if (!satisfied) {
+          SmallVector<StringRef, 4> extStrings;
+          for (spirv::Extension ext : ors)
+            extStrings.push_back(spirv::stringifyExtension(ext));
+
+          return op->emitError("'")
+                 << op->getName() << "' requires at least one extension in ["
+                 << llvm::join(extStrings, ", ")
+                 << "] but none allowed in target environment";
+        }
+      }
+    }
+
+    // Deduce this op's capability requirement. For each op, the query interface
+    // returns a vector of vectors for its capability requirements following
+    // ((Capability::A OR Capability::B) AND (Capability::C OR Capability::D))
+    // convention. Ops not implementing QueryCapabilityInterface do not require
+    // capabilities to be available.
+    if (auto capabilities = dyn_cast<spirv::QueryCapabilityInterface>(op)) {
+      for (const auto &ors : capabilities.getCapabilities()) {
+        bool satisfied = false; // True when at least one capability can be used
+        for (spirv::Capability cap : ors) {
+          if (allowedCapabilities.count(cap)) {
+            deducedCapabilities.insert(cap);
+            satisfied = true;
+            break;
+          }
+        }
+
+        if (!satisfied) {
+          SmallVector<StringRef, 4> capStrings;
+          for (spirv::Capability cap : ors)
+            capStrings.push_back(spirv::stringifyCapability(cap));
+
+          return op->emitError("'")
+                 << op->getName() << "' requires at least one capability in ["
+                 << llvm::join(capStrings, ", ")
+                 << "] but none allowed in target environment";
+        }
+      }
+    }
+
+    return WalkResult::advance();
+  });
+
+  if (walkResult.wasInterrupted())
+    return signalPassFailure();
+
+  // TODO(antiagainst): verify that the deduced version is consistent with
+  // SPIR-V ops' maximal version requirements.
+
+  auto triple = spirv::VerCapExtAttr::get(
+      deducedVersion, deducedCapabilities.getArrayRef(),
+      deducedExtensions.getArrayRef(), &getContext());
+  module.setAttr("vce_triple", triple);
+}
+
+std::unique_ptr<OpPassBase<spirv::ModuleOp>>
+mlir::spirv::createUpdateVersionCapabilityExtensionPass() {
+  return std::make_unique<UpdateVCEPass>();
+}
+
+static PassRegistration<UpdateVCEPass>
+    pass("spirv-update-vce",
+         "Deduce and attach minimal (version, capabilities, extensions) "
+         "requirements to spv.module ops");

diff  --git a/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir b/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir
new file mode 100644
index 000000000000..4f43a77c48c9
--- /dev/null
+++ b/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir
@@ -0,0 +1,146 @@
+// RUN: mlir-opt -spirv-update-vce %s | FileCheck %s
+
+//===----------------------------------------------------------------------===//
+// Version
+//===----------------------------------------------------------------------===//
+
+// Test deducing minimal version.
+// spv.IAdd is available from v1.0.
+
+// CHECK: vce_triple = #spv.vce<v1.0, [Shader], []>
+spv.module "Logical" "GLSL450" {
+  spv.func @iadd(%val : i32) -> i32 "None" {
+    %0 = spv.IAdd %val, %val: i32
+    spv.ReturnValue %0: i32
+  }
+} attributes {
+  spv.target_env = #spv.target_env<
+    #spv.vce<v1.5, [Shader], []>,
+    {max_compute_workgroup_invocations = 128 : i32,
+     max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>
+}
+
+// Test deducing minimal version.
+// spv.GroupNonUniformBallot is available since v1.3.
+
+// CHECK: vce_triple = #spv.vce<v1.3, [GroupNonUniformBallot, Shader], []>
+spv.module "Logical" "GLSL450" {
+  spv.func @group_non_uniform_ballot(%predicate : i1) -> vector<4xi32> "None" {
+    %0 = spv.GroupNonUniformBallot "Workgroup" %predicate : vector<4xi32>
+    spv.ReturnValue %0: vector<4xi32>
+  }
+} attributes {
+  spv.target_env = #spv.target_env<
+    #spv.vce<v1.5, [Shader, GroupNonUniformBallot], []>,
+    {max_compute_workgroup_invocations = 128 : i32,
+     max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>
+}
+
+//===----------------------------------------------------------------------===//
+// Capability
+//===----------------------------------------------------------------------===//
+
+// Test minimal capabilities.
+
+// CHECK: vce_triple = #spv.vce<v1.0, [Shader], []>
+spv.module "Logical" "GLSL450" {
+  spv.func @iadd(%val : i32) -> i32 "None" {
+    %0 = spv.IAdd %val, %val: i32
+    spv.ReturnValue %0: i32
+  }
+} attributes {
+  spv.target_env = #spv.target_env<
+    #spv.vce<v1.0, [Shader, Float16, Float64, Int16, Int64, VariablePointers], []>,
+    {max_compute_workgroup_invocations = 128 : i32,
+     max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>
+}
+
+// Test deducing implied capability.
+// AtomicStorage implies Shader.
+
+// CHECK: vce_triple = #spv.vce<v1.0, [Shader], []>
+spv.module "Logical" "GLSL450" {
+  spv.func @iadd(%val : i32) -> i32 "None" {
+    %0 = spv.IAdd %val, %val: i32
+    spv.ReturnValue %0: i32
+  }
+} attributes {
+  spv.target_env = #spv.target_env<
+    #spv.vce<v1.0, [AtomicStorage], []>,
+    {max_compute_workgroup_invocations = 128 : i32,
+     max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>
+}
+
+// Test selecting the capability available in the target environment.
+// spv.GroupNonUniformIAdd op itself can be enabled via any of
+// * GroupNonUniformArithmetic
+// * GroupNonUniformClustered
+// * GroupNonUniformPartitionedNV
+// Its 'Reduce' group operation can be enabled via any of
+// * Kernel
+// * GroupNonUniformArithmetic
+// * GroupNonUniformBallot
+
+// CHECK: vce_triple = #spv.vce<v1.3, [GroupNonUniformArithmetic, Shader], []>
+spv.module "Logical" "GLSL450" {
+  spv.func @group_non_uniform_iadd(%val : i32) -> i32 "None" {
+    %0 = spv.GroupNonUniformIAdd "Subgroup" "Reduce" %val : i32
+    spv.ReturnValue %0: i32
+  }
+} attributes {
+  spv.target_env = #spv.target_env<
+    #spv.vce<v1.3, [Shader, GroupNonUniformArithmetic], []>,
+    {max_compute_workgroup_invocations = 128 : i32,
+     max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>
+}
+
+// CHECK: vce_triple = #spv.vce<v1.3, [GroupNonUniformClustered, GroupNonUniformBallot, Shader], []>
+spv.module "Logical" "GLSL450" {
+  spv.func @group_non_uniform_iadd(%val : i32) -> i32 "None" {
+    %0 = spv.GroupNonUniformIAdd "Subgroup" "Reduce" %val : i32
+    spv.ReturnValue %0: i32
+  }
+} attributes {
+  spv.target_env = #spv.target_env<
+    #spv.vce<v1.3, [Shader, GroupNonUniformClustered, GroupNonUniformBallot], []>,
+    {max_compute_workgroup_invocations = 128 : i32,
+     max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>
+}
+
+//===----------------------------------------------------------------------===//
+// Extension
+//===----------------------------------------------------------------------===//
+
+// Test deducing minimal extensions.
+// spv.SubgroupBallotKHR requires the SPV_KHR_shader_ballot extension.
+
+// CHECK: vce_triple = #spv.vce<v1.0, [SubgroupBallotKHR, Shader], [SPV_KHR_shader_ballot]>
+spv.module "Logical" "GLSL450" {
+  spv.func @subgroup_ballot(%predicate : i1) -> vector<4xi32> "None" {
+    %0 = spv.SubgroupBallotKHR %predicate: vector<4xi32>
+    spv.ReturnValue %0: vector<4xi32>
+  }
+} attributes {
+  spv.target_env = #spv.target_env<
+    #spv.vce<v1.0, [Shader, SubgroupBallotKHR],
+             [SPV_KHR_shader_ballot, SPV_KHR_shader_clock, SPV_KHR_variable_pointers]>,
+    {max_compute_workgroup_invocations = 128 : i32,
+     max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>
+}
+
+// Test deducing implied extension.
+// Vulkan memory model requires SPV_KHR_vulkan_memory_model, which is enabled
+// implicitly by v1.5.
+
+// CHECK: vce_triple = #spv.vce<v1.0, [VulkanMemoryModel], [SPV_KHR_vulkan_memory_model]>
+spv.module "Logical" "Vulkan" {
+  spv.func @iadd(%val : i32) -> i32 "None" {
+    %0 = spv.IAdd %val, %val: i32
+    spv.ReturnValue %0: i32
+  }
+} attributes {
+  spv.target_env = #spv.target_env<
+    #spv.vce<v1.5, [Shader, VulkanMemoryModel], []>,
+    {max_compute_workgroup_invocations = 128 : i32,
+     max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>
+}


        


More information about the Mlir-commits mailing list