[Mlir-commits] [mlir] 9414db1 - [mlir][spirv] Add a pass to deduce version/extension/capability
Lei Zhang
llvmlistbot at llvm.org
Thu Mar 12 16:39:51 PDT 2020
Author: Lei Zhang
Date: 2020-03-12T19:37:45-04:00
New Revision: 9414db10906a845e8e485a22102440833d131e48
URL: https://github.com/llvm/llvm-project/commit/9414db10906a845e8e485a22102440833d131e48
DIFF: https://github.com/llvm/llvm-project/commit/9414db10906a845e8e485a22102440833d131e48.diff
LOG: [mlir][spirv] Add a pass to deduce version/extension/capability
Creates an operation pass that deduces and attaches the minimal version/
capabilities/extensions requirements for spv.module ops.
For each spv.module op, this pass requires a `spv.target_env` attribute on
it or an enclosing module-like op to drive the deduction. The reason is
that an op can be enabled by multiple extensions/capabilities. So we need
to know which one to pick. `spv.target_env` gives the hard limit on
what the target environment can support; this pass deduces what is
actually needed for a specific spv.module op.
Differential Revision: https://reviews.llvm.org/D75870
Added:
mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp
mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir
Modified:
mlir/include/mlir/Dialect/SPIRV/Passes.h
mlir/include/mlir/Dialect/SPIRV/TargetAndABI.h
mlir/include/mlir/InitAllPasses.h
mlir/lib/Dialect/SPIRV/TargetAndABI.cpp
mlir/lib/Dialect/SPIRV/Transforms/CMakeLists.txt
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/SPIRV/Passes.h b/mlir/include/mlir/Dialect/SPIRV/Passes.h
index e14fbe918c3b..fc13460b797b 100644
--- a/mlir/include/mlir/Dialect/SPIRV/Passes.h
+++ b/mlir/include/mlir/Dialect/SPIRV/Passes.h
@@ -26,12 +26,24 @@ class ModuleOp;
std::unique_ptr<OpPassBase<mlir::ModuleOp>>
createDecorateSPIRVCompositeTypeLayoutPass();
-/// Creates a module pass that lowers the ABI attributes specified during SPIR-V
-/// Lowering. Specifically,
-/// 1) Creates the global variables for arguments of entry point function using
-/// the specification in the ABI attributes for each argument.
-/// 2) Inserts the EntryPointOp and the ExecutionModeOp for entry point
-/// functions using the specification in the EntryPointAttr.
+/// Creates an operation pass that deduces and attaches the minimal version/
+/// capabilities/extensions requirements for spv.module ops.
+/// For each spv.module op, this pass requires a `spv.target_env` attribute on
+/// it or an enclosing module-like op to drive the deduction. The reason is
+/// that an op can be enabled by multiple extensions/capabilities, so we need
+/// to know which one to pick. `spv.target_env` gives the hard limit on what
+/// the target environment can support; this pass deduces what is actually
+/// needed for a specific spv.module op. The deduced triple is attached to the
+/// spv.module op as its `vce_triple` attribute; the pass fails if no
+/// `spv.target_env` is found or if it cannot satisfy some op's requirements.
+std::unique_ptr<OpPassBase<spirv::ModuleOp>>
+createUpdateVersionCapabilityExtensionPass();
+
+/// Creates an operation pass that lowers the ABI attributes specified during
+/// SPIR-V lowering. Specifically,
+/// 1. Creates the global variables for arguments of entry point functions
+/// using the specification in the `spv.interface_var_abi` attribute for each
+/// argument.
+/// 2. Inserts the EntryPointOp and the ExecutionModeOp for entry point
+/// functions using the specification in the `spv.entry_point_abi` attribute.
std::unique_ptr<OpPassBase<spirv::ModuleOp>> createLowerABIAttributesPass();
} // namespace spirv
diff --git a/mlir/include/mlir/Dialect/SPIRV/TargetAndABI.h b/mlir/include/mlir/Dialect/SPIRV/TargetAndABI.h
index 63f7e267ecff..d00468007446 100644
--- a/mlir/include/mlir/Dialect/SPIRV/TargetAndABI.h
+++ b/mlir/include/mlir/Dialect/SPIRV/TargetAndABI.h
@@ -174,6 +174,10 @@ StringRef getTargetEnvAttrName();
/// and no extra extensions.
TargetEnvAttr getDefaultTargetEnv(MLIRContext *context);
+/// Queries the target environment recursively from enclosing symbol table ops
+/// containing the given `op`. Returns a null attribute if none of them carries
+/// a `spv.target_env` attribute.
+TargetEnvAttr lookupTargetEnv(Operation *op);
+
/// Queries the target environment recursively from enclosing symbol table ops
/// containing the given `op` or returns the default target environment as
/// returned by getDefaultTargetEnv() if not provided.
diff --git a/mlir/include/mlir/InitAllPasses.h b/mlir/include/mlir/InitAllPasses.h
index 4df27e1bc1cd..009cc7309cff 100644
--- a/mlir/include/mlir/InitAllPasses.h
+++ b/mlir/include/mlir/InitAllPasses.h
@@ -122,6 +122,7 @@ inline void registerAllPasses() {
// SPIR-V
spirv::createDecorateSPIRVCompositeTypeLayoutPass();
spirv::createLowerABIAttributesPass();
+ spirv::createUpdateVersionCapabilityExtensionPass();
createConvertGPUToSPIRVPass();
createConvertStandardToSPIRVPass();
createLegalizeStdOpsForSPIRVLoweringPass();
diff --git a/mlir/lib/Dialect/SPIRV/TargetAndABI.cpp b/mlir/lib/Dialect/SPIRV/TargetAndABI.cpp
index 1bd8729347ac..b2c1dda88294 100644
--- a/mlir/lib/Dialect/SPIRV/TargetAndABI.cpp
+++ b/mlir/lib/Dialect/SPIRV/TargetAndABI.cpp
@@ -294,19 +294,25 @@ spirv::TargetEnvAttr spirv::getDefaultTargetEnv(MLIRContext *context) {
spirv::getDefaultResourceLimits(context));
}
-spirv::TargetEnvAttr spirv::lookupTargetEnvOrDefault(Operation *op) {
- Operation *symTable = op;
- while (symTable) {
- symTable = SymbolTable::getNearestSymbolTable(symTable);
- if (!symTable)
+/// Walks outward through the symbol table ops around `op` and returns the first
+/// `spv.target_env` attribute found, or a null attribute if there is none.
+spirv::TargetEnvAttr spirv::lookupTargetEnv(Operation *op) {
+ while (op) {
+ op = SymbolTable::getNearestSymbolTable(op);
+ if (!op)
break;
- if (auto attr = symTable->getAttrOfType<spirv::TargetEnvAttr>(
+ if (auto attr = op->getAttrOfType<spirv::TargetEnvAttr>(
spirv::getTargetEnvAttrName()))
return attr;
- symTable = symTable->getParentOp();
+ // Resume the search from the parent of this symbol table op.
+ op = op->getParentOp();
}
+ // No symbol table op on the way up carries a `spv.target_env` attribute.
+ return {};
+}
+
+/// Same as lookupTargetEnv(), but falls back to getDefaultTargetEnv() when no
+/// `spv.target_env` attribute is found.
+spirv::TargetEnvAttr spirv::lookupTargetEnvOrDefault(Operation *op) {
+ if (spirv::TargetEnvAttr attr = spirv::lookupTargetEnv(op))
+ return attr;
+
return getDefaultTargetEnv(op->getContext());
}
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/CMakeLists.txt b/mlir/lib/Dialect/SPIRV/Transforms/CMakeLists.txt
index a722f9dac2be..f31c4836e7f7 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/SPIRV/Transforms/CMakeLists.txt
@@ -1,6 +1,7 @@
add_mlir_dialect_library(MLIRSPIRVTransforms
DecorateSPIRVCompositeTypeLayoutPass.cpp
LowerABIAttributesPass.cpp
+ UpdateVCEPass.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/SPIRV
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp
new file mode 100644
index 000000000000..26597dc46340
--- /dev/null
+++ b/mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp
@@ -0,0 +1,164 @@
+//===- UpdateVCEPass.cpp - Deduce min version/extension/capability --------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a pass to deduce minimal version/extension/capability
+// requirements for a spirv::ModuleOp.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/SPIRV/Passes.h"
+#include "mlir/Dialect/SPIRV/SPIRVDialect.h"
+#include "mlir/Dialect/SPIRV/SPIRVOps.h"
+#include "mlir/Dialect/SPIRV/SPIRVTypes.h"
+#include "mlir/Dialect/SPIRV/TargetAndABI.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/Visitors.h"
+#include "llvm/ADT/SetVector.h"
+#include "llvm/ADT/SmallSet.h"
+
+using namespace mlir;
+
+namespace {
+/// Pass to deduce minimal version/extension/capability requirements for a
+/// spirv::ModuleOp. The choice is bounded by the `spv.target_env` found via
+/// spirv::lookupTargetEnv(), and the deduced triple is attached to the module
+/// as its `vce_triple` attribute.
+class UpdateVCEPass final
+ : public OperationPass<UpdateVCEPass, spirv::ModuleOp> {
+private:
+ void runOnOperation() override;
+};
+} // namespace
+
+/// Deduces the minimal (version, capabilities, extensions) triple required by
+/// the spv.module op and the ops nested in it, checking every requirement
+/// against the `spv.target_env` found via spirv::lookupTargetEnv(). On success
+/// the triple is attached to the module as its `vce_triple` attribute;
+/// otherwise an error is emitted and the pass fails.
+void UpdateVCEPass::runOnOperation() {
+ spirv::ModuleOp module = getOperation();
+
+ spirv::TargetEnvAttr targetEnv = spirv::lookupTargetEnv(module);
+ if (!targetEnv) {
+ module.emitError("missing 'spv.target_env' attribute");
+ return signalPassFailure();
+ }
+
+ spirv::Version allowedVersion = targetEnv.getVersion();
+
+ // Build a set for available extensions in the target environment.
+ llvm::SmallSet<spirv::Extension, 4> allowedExtensions;
+ for (spirv::Extension ext : targetEnv.getExtensions())
+ allowedExtensions.insert(ext);
+
+ // Add extensions implied by the current version.
+ for (spirv::Extension ext : spirv::getImpliedExtensions(allowedVersion))
+ allowedExtensions.insert(ext);
+
+ // Build a set for available capabilities in the target environment.
+ llvm::SmallSet<spirv::Capability, 8> allowedCapabilities;
+ for (spirv::Capability cap : targetEnv.getCapabilities()) {
+ allowedCapabilities.insert(cap);
+
+ // Add capabilities implied by the current capability.
+ for (spirv::Capability c : spirv::getRecursiveImpliedCapabilities(cap))
+ allowedCapabilities.insert(c);
+ }
+
+ // Start from v1.0 and raise the version only as ops demand it. SetVector
+ // keeps the deduced extensions/capabilities unique, in insertion order.
+ spirv::Version deducedVersion = spirv::Version::V_1_0;
+ llvm::SetVector<spirv::Extension> deducedExtensions;
+ llvm::SetVector<spirv::Capability> deducedCapabilities;
+
+ // Walk each SPIR-V op to deduce the minimal version/extension/capability
+ // requirements.
+ WalkResult walkResult = module.walk([&](Operation *op) -> WalkResult {
+ // Keep the highest min version required so far and fail as soon as it
+ // exceeds what the target environment allows.
+ if (auto minVersion = dyn_cast<spirv::QueryMinVersionInterface>(op)) {
+ deducedVersion = std::max(deducedVersion, minVersion.getMinVersion());
+ if (deducedVersion > allowedVersion) {
+ return op->emitError("'") << op->getName() << "' requires min version "
+ << spirv::stringifyVersion(deducedVersion)
+ << " but target environment allows up to "
+ << spirv::stringifyVersion(allowedVersion);
+ }
+ }
+
+ // Deduce this op's extension requirement. For each op, the query interface
+ // returns a vector of vector for its extension requirements following
+ // ((Extension::A OR Extension::B) AND (Extension::C OR Extension::D))
+ // convention. Ops not implementing QueryExtensionInterface do not require
+ // extensions to be available.
+ if (auto extensions = dyn_cast<spirv::QueryExtensionInterface>(op)) {
+ for (const auto &ors : extensions.getExtensions()) {
+ bool satisfied = false; // True when at least one extension can be used
+ for (spirv::Extension ext : ors) {
+ if (allowedExtensions.count(ext)) {
+ deducedExtensions.insert(ext);
+ satisfied = true;
+ break;
+ }
+ }
+
+ if (!satisfied) {
+ SmallVector<StringRef, 4> extStrings;
+ for (spirv::Extension ext : ors)
+ extStrings.push_back(spirv::stringifyExtension(ext));
+
+ return op->emitError("'")
+ << op->getName() << "' requires at least one extension in ["
+ << llvm::join(extStrings, ", ")
+ << "] but none allowed in target environment";
+ }
+ }
+ }
+
+ // Deduce this op's capability requirement. For each op, the query interface
+ // returns a vector of vector for its capability requirements following
+ // ((Capability::A OR Capability::B) AND (Capability::C OR Capability::D))
+ // convention. Ops not implementing QueryCapabilityInterface do not require
+ // capabilities to be available.
+ if (auto capabilities = dyn_cast<spirv::QueryCapabilityInterface>(op)) {
+ for (const auto &ors : capabilities.getCapabilities()) {
+ bool satisfied = false; // True when at least one capability can be used
+ for (spirv::Capability cap : ors) {
+ if (allowedCapabilities.count(cap)) {
+ deducedCapabilities.insert(cap);
+ satisfied = true;
+ break;
+ }
+ }
+
+ if (!satisfied) {
+ SmallVector<StringRef, 4> capStrings;
+ for (spirv::Capability cap : ors)
+ capStrings.push_back(spirv::stringifyCapability(cap));
+
+ return op->emitError("'")
+ << op->getName() << "' requires at least one capability in ["
+ << llvm::join(capStrings, ", ")
+ << "] but none allowed in target environment";
+ }
+ }
+ }
+
+ return WalkResult::advance();
+ });
+
+ // The only non-advance results above are the emitError() returns, so a
+ // diagnostic has already been reported on the offending op.
+ if (walkResult.wasInterrupted())
+ return signalPassFailure();
+
+ // TODO(antiagainst): verify that the deduced version is consistent with
+ // SPIR-V ops' maximal version requirements.
+
+ // Record the deduced requirements on the module as `vce_triple`.
+ auto triple = spirv::VerCapExtAttr::get(
+ deducedVersion, deducedCapabilities.getArrayRef(),
+ deducedExtensions.getArrayRef(), &getContext());
+ module.setAttr("vce_triple", triple);
+}
+
+/// Creates a new instance of the version/capability/extension deduction pass
+/// (UpdateVCEPass).
+std::unique_ptr<OpPassBase<spirv::ModuleOp>>
+mlir::spirv::createUpdateVersionCapabilityExtensionPass() {
+ return std::make_unique<UpdateVCEPass>();
+}
+
+// Registers the pass under the `-spirv-update-vce` flag (as used by mlir-opt
+// in the accompanying test).
+static PassRegistration<UpdateVCEPass>
+ pass("spirv-update-vce",
+ "Deduce and attach minimal (version, capabilities, extensions) "
+ "requirements to spv.module ops");
diff --git a/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir b/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir
new file mode 100644
index 000000000000..4f43a77c48c9
--- /dev/null
+++ b/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir
@@ -0,0 +1,146 @@
+// RUN: mlir-opt -spirv-update-vce %s | FileCheck %s
+
+//===----------------------------------------------------------------------===//
+// Version
+//===----------------------------------------------------------------------===//
+
+// Test deducing minimal version.
+// spv.IAdd is available from v1.0.
+// Although the target environment allows up to v1.5, only v1.0 should be
+// deduced.
+
+// CHECK: vce_triple = #spv.vce<v1.0, [Shader], []>
+spv.module "Logical" "GLSL450" {
+ spv.func @iadd(%val : i32) -> i32 "None" {
+ %0 = spv.IAdd %val, %val: i32
+ spv.ReturnValue %0: i32
+ }
+} attributes {
+ spv.target_env = #spv.target_env<
+ #spv.vce<v1.5, [Shader], []>,
+ {max_compute_workgroup_invocations = 128 : i32,
+ max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>
+}
+
+// Test deducing minimal version.
+// spv.GroupNonUniformBallot is available since v1.3.
+// The deduced version is raised to v1.3 by this op, not to the v1.5 allowed
+// by the target environment.
+
+// CHECK: vce_triple = #spv.vce<v1.3, [GroupNonUniformBallot, Shader], []>
+spv.module "Logical" "GLSL450" {
+ spv.func @group_non_uniform_ballot(%predicate : i1) -> vector<4xi32> "None" {
+ %0 = spv.GroupNonUniformBallot "Workgroup" %predicate : vector<4xi32>
+ spv.ReturnValue %0: vector<4xi32>
+ }
+} attributes {
+ spv.target_env = #spv.target_env<
+ #spv.vce<v1.5, [Shader, GroupNonUniformBallot], []>,
+ {max_compute_workgroup_invocations = 128 : i32,
+ max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>
+}
+
+//===----------------------------------------------------------------------===//
+// Capability
+//===----------------------------------------------------------------------===//
+
+// Test minimal capabilities.
+// Capabilities that the target environment allows but nothing in the module
+// needs (Float16, Float64, Int16, Int64, VariablePointers) must not be deduced.
+
+// CHECK: vce_triple = #spv.vce<v1.0, [Shader], []>
+spv.module "Logical" "GLSL450" {
+ spv.func @iadd(%val : i32) -> i32 "None" {
+ %0 = spv.IAdd %val, %val: i32
+ spv.ReturnValue %0: i32
+ }
+} attributes {
+ spv.target_env = #spv.target_env<
+ #spv.vce<v1.0, [Shader, Float16, Float64, Int16, Int64, VariablePointers], []>,
+ {max_compute_workgroup_invocations = 128 : i32,
+ max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>
+}
+
+// Test deducing implied capability.
+// AtomicStorage implies Shader.
+// The target environment lists only AtomicStorage, so Shader is available
+// solely through that implication.
+
+// CHECK: vce_triple = #spv.vce<v1.0, [Shader], []>
+spv.module "Logical" "GLSL450" {
+ spv.func @iadd(%val : i32) -> i32 "None" {
+ %0 = spv.IAdd %val, %val: i32
+ spv.ReturnValue %0: i32
+ }
+} attributes {
+ spv.target_env = #spv.target_env<
+ #spv.vce<v1.0, [AtomicStorage], []>,
+ {max_compute_workgroup_invocations = 128 : i32,
+ max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>
+}
+
+// Test selecting the capability available in the target environment.
+// spv.GroupNonUniformIAdd op itself can be enabled via any of
+// * GroupNonUniformArithmetic
+// * GroupNonUniformClustered
+// * GroupNonUniformPartitionedNV
+// Its 'Reduce' group operation can be enabled via any of
+// * Kernel
+// * GroupNonUniformArithmetic
+// * GroupNonUniformBallot
+// Here the target environment offers GroupNonUniformArithmetic, which covers
+// both requirements.
+
+// CHECK: vce_triple = #spv.vce<v1.3, [GroupNonUniformArithmetic, Shader], []>
+spv.module "Logical" "GLSL450" {
+ spv.func @group_non_uniform_iadd(%val : i32) -> i32 "None" {
+ %0 = spv.GroupNonUniformIAdd "Subgroup" "Reduce" %val : i32
+ spv.ReturnValue %0: i32
+ }
+} attributes {
+ spv.target_env = #spv.target_env<
+ #spv.vce<v1.3, [Shader, GroupNonUniformArithmetic], []>,
+ {max_compute_workgroup_invocations = 128 : i32,
+ max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>
+}
+
+// Same op, but the target environment offers GroupNonUniformClustered (for the
+// op itself) and GroupNonUniformBallot (for the 'Reduce' group operation)
+// instead of GroupNonUniformArithmetic.
+
+// CHECK: vce_triple = #spv.vce<v1.3, [GroupNonUniformClustered, GroupNonUniformBallot, Shader], []>
+spv.module "Logical" "GLSL450" {
+ spv.func @group_non_uniform_iadd(%val : i32) -> i32 "None" {
+ %0 = spv.GroupNonUniformIAdd "Subgroup" "Reduce" %val : i32
+ spv.ReturnValue %0: i32
+ }
+} attributes {
+ spv.target_env = #spv.target_env<
+ #spv.vce<v1.3, [Shader, GroupNonUniformClustered, GroupNonUniformBallot], []>,
+ {max_compute_workgroup_invocations = 128 : i32,
+ max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>
+}
+
+//===----------------------------------------------------------------------===//
+// Extension
+//===----------------------------------------------------------------------===//
+
+// Test deducing minimal extensions.
+// spv.SubgroupBallotKHR requires the SPV_KHR_shader_ballot extension.
+// The other extensions allowed by the target environment (SPV_KHR_shader_clock,
+// SPV_KHR_variable_pointers) are not needed and must not be deduced.
+
+// CHECK: vce_triple = #spv.vce<v1.0, [SubgroupBallotKHR, Shader], [SPV_KHR_shader_ballot]>
+spv.module "Logical" "GLSL450" {
+ spv.func @subgroup_ballot(%predicate : i1) -> vector<4xi32> "None" {
+ %0 = spv.SubgroupBallotKHR %predicate: vector<4xi32>
+ spv.ReturnValue %0: vector<4xi32>
+ }
+} attributes {
+ spv.target_env = #spv.target_env<
+ #spv.vce<v1.0, [Shader, SubgroupBallotKHR],
+ [SPV_KHR_shader_ballot, SPV_KHR_shader_clock, SPV_KHR_variable_pointers]>,
+ {max_compute_workgroup_invocations = 128 : i32,
+ max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>
+}
+
+// Test deducing implied extension.
+// The Vulkan memory model requires SPV_KHR_vulkan_memory_model, which is
+// enabled implicitly by v1.5. The deduced version nevertheless stays at v1.0,
+// so the extension appears explicitly in the deduced triple.
+
+// CHECK: vce_triple = #spv.vce<v1.0, [VulkanMemoryModel], [SPV_KHR_vulkan_memory_model]>
+spv.module "Logical" "Vulkan" {
+ spv.func @iadd(%val : i32) -> i32 "None" {
+ %0 = spv.IAdd %val, %val: i32
+ spv.ReturnValue %0: i32
+ }
+} attributes {
+ spv.target_env = #spv.target_env<
+ #spv.vce<v1.5, [Shader, VulkanMemoryModel], []>,
+ {max_compute_workgroup_invocations = 128 : i32,
+ max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>
+}
More information about the Mlir-commits
mailing list