[Mlir-commits] [mlir] 67e8690 - [mlir][spirv] Let SPIRVConversionTarget consider type availability
Lei Zhang
llvmlistbot at llvm.org
Wed Mar 18 17:13:22 PDT 2020
Author: Lei Zhang
Date: 2020-03-18T20:11:04-04:00
New Revision: 67e8690e53c341ba433f9c2de3f5a16b8beb7f0b
URL: https://github.com/llvm/llvm-project/commit/67e8690e53c341ba433f9c2de3f5a16b8beb7f0b
DIFF: https://github.com/llvm/llvm-project/commit/67e8690e53c341ba433f9c2de3f5a16b8beb7f0b.diff
LOG: [mlir][spirv] Let SPIRVConversionTarget consider type availability
Previously we only considered the version/extension/capability requirements
on the op itself. This commit updates SPIRVConversionTarget to also
take into consideration the values' types when deciding op legality.
Differential Revision: https://reviews.llvm.org/D75876
Added:
Modified:
mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp
mlir/test/Conversion/GPUToSPIRV/if.mlir
mlir/test/Conversion/GPUToSPIRV/load-store.mlir
mlir/test/Conversion/GPUToSPIRV/loop.mlir
mlir/test/Conversion/StandardToSPIRV/std-to-spirv.mlir
mlir/test/Conversion/StandardToSPIRV/subview-to-spirv.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp b/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp
index e9250c56a1d2..6d73432fead4 100644
--- a/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp
+++ b/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp
@@ -1,4 +1,4 @@
-//===- SPIRVLowering.cpp - Standard to SPIR-V dialect conversion--===//
+//===- SPIRVLowering.cpp - SPIR-V lowering utilities ----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@@ -15,6 +15,7 @@
#include "mlir/Dialect/SPIRV/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/SPIRVOps.h"
#include "llvm/ADT/Sequence.h"
+#include "llvm/ADT/StringExtras.h"
#include "llvm/Support/Debug.h"
#include <functional>
@@ -443,6 +444,66 @@ spirv::SPIRVConversionTarget::SPIRVConversionTarget(
}
}
+/// Checks that the extension requirements in `candidates` can be satisfied
+/// by the given `allowedExtensions`.
+///
+/// `candidates` is a vector of vectors of extension requirements following
+/// the ((Extension::A OR Extension::B) AND (Extension::C OR Extension::D))
+/// convention: every inner list must contain at least one allowed extension.
+///
+/// `op` is only used to label the debug output. Returns failure as soon as an
+/// inner list with no allowed extension is found, success otherwise.
+static LogicalResult checkExtensionRequirements(
+    Operation *op, const llvm::SmallSet<spirv::Extension, 4> &allowedExtensions,
+    const spirv::SPIRVType::ExtensionArrayRefVector &candidates) {
+  for (const auto &ors : candidates) {
+    // Any one allowed extension in this OR-group satisfies the group.
+    auto chosen = llvm::find_if(ors, [&](spirv::Extension ext) {
+      return allowedExtensions.count(ext);
+    });
+
+    if (chosen == ors.end()) {
+      SmallVector<StringRef, 4> extStrings;
+      for (spirv::Extension ext : ors)
+        extStrings.push_back(spirv::stringifyExtension(ext));
+
+      // Leading space keeps the op name and the message from running together.
+      LLVM_DEBUG(llvm::dbgs() << op->getName()
+                              << " illegal: requires at least one extension in ["
+                              << llvm::join(extStrings, ", ")
+                              << "] but none allowed in target environment\n");
+      return failure();
+    }
+  }
+  return success();
+}
+
+/// Checks that the capability requirements in `candidates` can be satisfied
+/// by the given `allowedCapabilities`.
+///
+/// `candidates` is a vector of vectors of capability requirements following
+/// the ((Capability::A OR Capability::B) AND (Capability::C OR Capability::D))
+/// convention: every inner list must contain at least one allowed capability.
+///
+/// `op` is only used to label the debug output. Returns failure as soon as an
+/// inner list with no allowed capability is found, success otherwise.
+static LogicalResult checkCapabilityRequirements(
+    Operation *op,
+    const llvm::SmallSet<spirv::Capability, 8> &allowedCapabilities,
+    const spirv::SPIRVType::CapabilityArrayRefVector &candidates) {
+  for (const auto &ors : candidates) {
+    // Any one allowed capability in this OR-group satisfies the group.
+    auto chosen = llvm::find_if(ors, [&](spirv::Capability cap) {
+      return allowedCapabilities.count(cap);
+    });
+
+    if (chosen == ors.end()) {
+      SmallVector<StringRef, 4> capStrings;
+      for (spirv::Capability cap : ors)
+        capStrings.push_back(spirv::stringifyCapability(cap));
+
+      // Leading space keeps the op name and the message from running together.
+      LLVM_DEBUG(llvm::dbgs()
+                 << op->getName()
+                 << " illegal: requires at least one capability in ["
+                 << llvm::join(capStrings, ", ")
+                 << "] but none allowed in target environment\n");
+      return failure();
+    }
+  }
+  return success();
+}
+
bool spirv::SPIRVConversionTarget::isLegalOp(Operation *op) {
// Make sure this op is available at the given version. Ops not implementing
// QueryMinVersionInterface/QueryMaxVersionInterface are available to all
@@ -464,38 +525,47 @@ bool spirv::SPIRVConversionTarget::isLegalOp(Operation *op) {
return false;
}
- // Make sure this op's required extensions are allowed to use. For each op,
- // we return a vector of vector for its extension requirements following
- // ((Extension::A OR Extension::B) AND (Extension::C OR Extension::D))
- // convention. Ops not implementing QueryExtensionInterface do not require
- // extensions to be available.
- if (auto extensions = dyn_cast<spirv::QueryExtensionInterface>(op)) {
- auto exts = extensions.getExtensions();
- for (const auto &ors : exts)
- if (llvm::all_of(ors, [this](spirv::Extension ext) {
- return this->givenExtensions.count(ext) == 0;
- })) {
- LLVM_DEBUG(llvm::dbgs() << op->getName()
- << " illegal: missing required extension\n");
- return false;
- }
- }
+ // Make sure this op's required extensions are allowed to use. Ops not
+ // implementing QueryExtensionInterface do not require extensions to be
+ // available.
+ if (auto extensions = dyn_cast<spirv::QueryExtensionInterface>(op))
+ if (failed(checkExtensionRequirements(op, this->givenExtensions,
+ extensions.getExtensions())))
+ return false;
- // Make sure this op's required extensions are allowed to use. For each op,
- // we return a vector of vector for its capability requirements following
- // ((Capability::A OR Extension::B) AND (Capability::C OR Capability::D))
- // convention. Ops not implementing QueryExtensionInterface do not require
- // extensions to be available.
- if (auto capabilities = dyn_cast<spirv::QueryCapabilityInterface>(op)) {
- auto caps = capabilities.getCapabilities();
- for (const auto &ors : caps)
- if (llvm::all_of(ors, [this](spirv::Capability cap) {
- return this->givenCapabilities.count(cap) == 0;
- })) {
- LLVM_DEBUG(llvm::dbgs() << op->getName()
- << " illegal: missing required capability\n");
- return false;
- }
+ // Make sure this op's required capabilities are allowed to use. Ops not
+ // implementing QueryCapabilityInterface do not require capabilities to be
+ // available.
+ if (auto capabilities = dyn_cast<spirv::QueryCapabilityInterface>(op))
+ if (failed(checkCapabilityRequirements(op, this->givenCapabilities,
+ capabilities.getCapabilities())))
+ return false;
+
+ SmallVector<Type, 4> valueTypes;
+ valueTypes.append(op->operand_type_begin(), op->operand_type_end());
+ valueTypes.append(op->result_type_begin(), op->result_type_end());
+
+ // Special treatment for global variables, whose type requirements are
+ // conveyed by type attributes.
+ if (auto globalVar = dyn_cast<spirv::GlobalVariableOp>(op))
+ valueTypes.push_back(globalVar.type());
+
+ // Make sure the op's operands/results use types that are allowed by the
+ // target environment.
+ SmallVector<ArrayRef<spirv::Extension>, 4> typeExtensions;
+ SmallVector<ArrayRef<spirv::Capability>, 8> typeCapabilities;
+ for (Type valueType : valueTypes) {
+ typeExtensions.clear();
+ valueType.cast<spirv::SPIRVType>().getExtensions(typeExtensions);
+ if (failed(checkExtensionRequirements(op, this->givenExtensions,
+ typeExtensions)))
+ return false;
+
+ typeCapabilities.clear();
+ valueType.cast<spirv::SPIRVType>().getCapabilities(typeCapabilities);
+ if (failed(checkCapabilityRequirements(op, this->givenCapabilities,
+ typeCapabilities)))
+ return false;
}
return true;
diff --git a/mlir/test/Conversion/GPUToSPIRV/if.mlir b/mlir/test/Conversion/GPUToSPIRV/if.mlir
index 1585c53116c5..8a8aa1c88813 100644
--- a/mlir/test/Conversion/GPUToSPIRV/if.mlir
+++ b/mlir/test/Conversion/GPUToSPIRV/if.mlir
@@ -1,6 +1,12 @@
// RUN: mlir-opt -convert-gpu-to-spirv %s -o - | FileCheck %s
-module attributes {gpu.container_module} {
+module attributes {
+ gpu.container_module,
+ spv.target_env = #spv.target_env<
+ #spv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]>,
+ {max_compute_workgroup_invocations = 128 : i32,
+ max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>
+} {
func @main(%arg0 : memref<10xf32>, %arg1 : i1) {
%c0 = constant 1 : index
"gpu.launch_func"(%c0, %c0, %c0, %c0, %c0, %c0, %arg0, %arg1) { kernel = "kernel_simple_selection", kernel_module = @kernels} : (index, index, index, index, index, index, memref<10xf32>, i1) -> ()
diff --git a/mlir/test/Conversion/GPUToSPIRV/load-store.mlir b/mlir/test/Conversion/GPUToSPIRV/load-store.mlir
index d0224fd16e02..05c9d90c498c 100644
--- a/mlir/test/Conversion/GPUToSPIRV/load-store.mlir
+++ b/mlir/test/Conversion/GPUToSPIRV/load-store.mlir
@@ -1,6 +1,12 @@
// RUN: mlir-opt -convert-gpu-to-spirv %s -o - | FileCheck %s
-module attributes {gpu.container_module} {
+module attributes {
+ gpu.container_module,
+ spv.target_env = #spv.target_env<
+ #spv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]>,
+ {max_compute_workgroup_invocations = 128 : i32,
+ max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>
+} {
func @load_store(%arg0: memref<12x4xf32>, %arg1: memref<12x4xf32>, %arg2: memref<12x4xf32>) {
%c0 = constant 0 : index
%c12 = constant 12 : index
diff --git a/mlir/test/Conversion/GPUToSPIRV/loop.mlir b/mlir/test/Conversion/GPUToSPIRV/loop.mlir
index 7044d5474d3c..8adc5e355f08 100644
--- a/mlir/test/Conversion/GPUToSPIRV/loop.mlir
+++ b/mlir/test/Conversion/GPUToSPIRV/loop.mlir
@@ -1,6 +1,12 @@
// RUN: mlir-opt -convert-gpu-to-spirv %s -o - | FileCheck %s
-module attributes {gpu.container_module} {
+module attributes {
+ gpu.container_module,
+ spv.target_env = #spv.target_env<
+ #spv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]>,
+ {max_compute_workgroup_invocations = 128 : i32,
+ max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>
+} {
func @loop(%arg0 : memref<10xf32>, %arg1 : memref<10xf32>) {
%c0 = constant 1 : index
"gpu.launch_func"(%c0, %c0, %c0, %c0, %c0, %c0, %arg0, %arg1) { kernel = "loop_kernel", kernel_module = @kernels} : (index, index, index, index, index, index, memref<10xf32>, memref<10xf32>) -> ()
diff --git a/mlir/test/Conversion/StandardToSPIRV/std-to-spirv.mlir b/mlir/test/Conversion/StandardToSPIRV/std-to-spirv.mlir
index 9b8d695af422..26e2ea42d3a2 100644
--- a/mlir/test/Conversion/StandardToSPIRV/std-to-spirv.mlir
+++ b/mlir/test/Conversion/StandardToSPIRV/std-to-spirv.mlir
@@ -1,5 +1,12 @@
// RUN: mlir-opt -convert-std-to-spirv %s -o - | FileCheck %s
+module attributes {
+ spv.target_env = #spv.target_env<
+ #spv.vce<v1.0, [Shader, Int64, Float64], [SPV_KHR_storage_buffer_storage_class]>,
+ {max_compute_workgroup_invocations = 128 : i32,
+ max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>
+} {
+
//===----------------------------------------------------------------------===//
// std binary arithmetic ops
//===----------------------------------------------------------------------===//
@@ -366,3 +373,5 @@ func @load_store_zero_rank_int(%arg0: memref<i32>, %arg1: memref<i32>) {
store %0, %arg1[] : memref<i32>
return
}
+
+} // end module
diff --git a/mlir/test/Conversion/StandardToSPIRV/subview-to-spirv.mlir b/mlir/test/Conversion/StandardToSPIRV/subview-to-spirv.mlir
index c9d1195bc056..cc94c089dfb2 100644
--- a/mlir/test/Conversion/StandardToSPIRV/subview-to-spirv.mlir
+++ b/mlir/test/Conversion/StandardToSPIRV/subview-to-spirv.mlir
@@ -4,6 +4,13 @@
// the desired output. Adding all of patterns within a single pass does
// not seem to work.
+module attributes {
+ spv.target_env = #spv.target_env<
+ #spv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]>,
+ {max_compute_workgroup_invocations = 128 : i32,
+ max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>
+} {
+
//===----------------------------------------------------------------------===//
// std.subview
//===----------------------------------------------------------------------===//
@@ -51,3 +58,5 @@ func @fold_static_stride_subview_with_store(%arg0 : memref<12x32xf32>, %arg1 : i
store %arg5, %0[%arg3, %arg4] : memref<4x4xf32, offset:?, strides: [64, 3]>
return
}
+
+} // end module
More information about the Mlir-commits
mailing list