[Mlir-commits] [mlir] 67e8690 - [mlir][spirv] Let SPIRVConversionTarget consider type availability

Lei Zhang llvmlistbot at llvm.org
Wed Mar 18 17:13:22 PDT 2020


Author: Lei Zhang
Date: 2020-03-18T20:11:04-04:00
New Revision: 67e8690e53c341ba433f9c2de3f5a16b8beb7f0b

URL: https://github.com/llvm/llvm-project/commit/67e8690e53c341ba433f9c2de3f5a16b8beb7f0b
DIFF: https://github.com/llvm/llvm-project/commit/67e8690e53c341ba433f9c2de3f5a16b8beb7f0b.diff

LOG: [mlir][spirv] Let SPIRVConversionTarget consider type availability

Previously we only considered the version/extension/capability requirements
on the op itself. This commit updates SPIRVConversionTarget to also
take into consideration the types of the op's operands and results when
deciding op legality.

Differential Revision: https://reviews.llvm.org/D75876

Added: 
    

Modified: 
    mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp
    mlir/test/Conversion/GPUToSPIRV/if.mlir
    mlir/test/Conversion/GPUToSPIRV/load-store.mlir
    mlir/test/Conversion/GPUToSPIRV/loop.mlir
    mlir/test/Conversion/StandardToSPIRV/std-to-spirv.mlir
    mlir/test/Conversion/StandardToSPIRV/subview-to-spirv.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp b/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp
index e9250c56a1d2..6d73432fead4 100644
--- a/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp
+++ b/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp
@@ -1,4 +1,4 @@
-//===- SPIRVLowering.cpp - Standard to SPIR-V dialect conversion--===//
+//===- SPIRVLowering.cpp - SPIR-V lowering utilities ----------------------===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -15,6 +15,7 @@
 #include "mlir/Dialect/SPIRV/SPIRVDialect.h"
 #include "mlir/Dialect/SPIRV/SPIRVOps.h"
 #include "llvm/ADT/Sequence.h"
+#include "llvm/ADT/StringExtras.h"
 #include "llvm/Support/Debug.h"
 
 #include <functional>
@@ -443,6 +444,66 @@ spirv::SPIRVConversionTarget::SPIRVConversionTarget(
   }
 }
 
+/// Checks that the extension requirements in `candidates` can be satisfied
+/// with the given `allowedExtensions`. `op` is only used for debug output.
+///
+/// `candidates` is a vector of vector following the ((Extension::A OR
+/// Extension::B) AND (Extension::C OR Extension::D)) convention.
+static LogicalResult checkExtensionRequirements(
+    Operation *op, const llvm::SmallSet<spirv::Extension, 4> &allowedExtensions,
+    const spirv::SPIRVType::ExtensionArrayRefVector &candidates) {
+  for (const auto &ors : candidates) {
+    auto chosen = llvm::find_if(ors, [&](spirv::Extension ext) {
+      return allowedExtensions.count(ext);
+    });
+
+    if (chosen == ors.end()) {
+      SmallVector<StringRef, 4> extStrings;
+      for (spirv::Extension ext : ors)
+        extStrings.push_back(spirv::stringifyExtension(ext));
+
+      LLVM_DEBUG(llvm::dbgs()
+                 << op->getName()
+                 << " illegal: requires at least one extension in ["
+                 << llvm::join(extStrings, ", ")
+                 << "] but none allowed in target environment\n");
+      return failure();
+    }
+  }
+  return success();
+}
+
+/// Checks that the capability requirements in `candidates` can be satisfied
+/// with the given `allowedCapabilities`.
+///
+/// `candidates` is a vector of vector for capability requirements following
+/// ((Capability::A OR Capability::B) AND (Capability::C OR Capability::D))
+/// convention. `op` is only used to label debug output on failure.
+static LogicalResult checkCapabilityRequirements(
+    Operation *op,
+    const llvm::SmallSet<spirv::Capability, 8> &allowedCapabilities,
+    const spirv::SPIRVType::CapabilityArrayRefVector &candidates) {
+  for (const auto &ors : candidates) {
+    auto chosen = llvm::find_if(ors, [&](spirv::Capability cap) {
+      return allowedCapabilities.count(cap);
+    });
+
+    if (chosen == ors.end()) {
+      SmallVector<StringRef, 4> capStrings;
+      for (spirv::Capability cap : ors)
+        capStrings.push_back(spirv::stringifyCapability(cap));
+
+      LLVM_DEBUG(llvm::dbgs()
+                 << op->getName()
+                 << " illegal: requires at least one capability in ["
+                 << llvm::join(capStrings, ", ")
+                 << "] but none allowed in target environment\n");
+      return failure();
+    }
+  }
+  return success();
+}
+
 bool spirv::SPIRVConversionTarget::isLegalOp(Operation *op) {
   // Make sure this op is available at the given version. Ops not implementing
   // QueryMinVersionInterface/QueryMaxVersionInterface are available to all
@@ -464,38 +525,47 @@ bool spirv::SPIRVConversionTarget::isLegalOp(Operation *op) {
       return false;
     }
 
-  // Make sure this op's required extensions are allowed to use. For each op,
-  // we return a vector of vector for its extension requirements following
-  // ((Extension::A OR Extension::B) AND (Extension::C OR Extension::D))
-  // convention. Ops not implementing QueryExtensionInterface do not require
-  // extensions to be available.
-  if (auto extensions = dyn_cast<spirv::QueryExtensionInterface>(op)) {
-    auto exts = extensions.getExtensions();
-    for (const auto &ors : exts)
-      if (llvm::all_of(ors, [this](spirv::Extension ext) {
-            return this->givenExtensions.count(ext) == 0;
-          })) {
-        LLVM_DEBUG(llvm::dbgs() << op->getName()
-                                << " illegal: missing required extension\n");
-        return false;
-      }
-  }
+  // Make sure this op's required extensions are allowed to use. Ops not
+  // implementing QueryExtensionInterface do not require extensions to be
+  // available.
+  if (auto extensions = dyn_cast<spirv::QueryExtensionInterface>(op))
+    if (failed(checkExtensionRequirements(op, this->givenExtensions,
+                                          extensions.getExtensions())))
+      return false;
 
-  // Make sure this op's required extensions are allowed to use. For each op,
-  // we return a vector of vector for its capability requirements following
-  // ((Capability::A OR Extension::B) AND (Capability::C OR Capability::D))
-  // convention. Ops not implementing QueryExtensionInterface do not require
-  // extensions to be available.
-  if (auto capabilities = dyn_cast<spirv::QueryCapabilityInterface>(op)) {
-    auto caps = capabilities.getCapabilities();
-    for (const auto &ors : caps)
-      if (llvm::all_of(ors, [this](spirv::Capability cap) {
-            return this->givenCapabilities.count(cap) == 0;
-          })) {
-        LLVM_DEBUG(llvm::dbgs() << op->getName()
-                                << " illegal: missing required capability\n");
-        return false;
-      }
+  // Make sure this op's required extensions are allowed to use. Ops not
+  // implementing QueryCapabilityInterface do not require capabilities to be
+  // available.
+  if (auto capabilities = dyn_cast<spirv::QueryCapabilityInterface>(op))
+    if (failed(checkCapabilityRequirements(op, this->givenCapabilities,
+                                           capabilities.getCapabilities())))
+      return false;
+
+  SmallVector<Type, 4> valueTypes;
+  valueTypes.append(op->operand_type_begin(), op->operand_type_end());
+  valueTypes.append(op->result_type_begin(), op->result_type_end());
+
+  // Special treatment for global variables, whose type requirements are
+  // conveyed by type attributes.
+  if (auto globalVar = dyn_cast<spirv::GlobalVariableOp>(op))
+    valueTypes.push_back(globalVar.type());
+
+  // Make sure the op's operands/results use types that are allowed by the
+  // target environment.
+  SmallVector<ArrayRef<spirv::Extension>, 4> typeExtensions;
+  SmallVector<ArrayRef<spirv::Capability>, 8> typeCapabilities;
+  for (Type valueType : valueTypes) {
+    typeExtensions.clear();
+    valueType.cast<spirv::SPIRVType>().getExtensions(typeExtensions);
+    if (failed(checkExtensionRequirements(op, this->givenExtensions,
+                                          typeExtensions)))
+      return false;
+
+    typeCapabilities.clear();
+    valueType.cast<spirv::SPIRVType>().getCapabilities(typeCapabilities);
+    if (failed(checkCapabilityRequirements(op, this->givenCapabilities,
+                                           typeCapabilities)))
+      return false;
   }
 
   return true;

diff  --git a/mlir/test/Conversion/GPUToSPIRV/if.mlir b/mlir/test/Conversion/GPUToSPIRV/if.mlir
index 1585c53116c5..8a8aa1c88813 100644
--- a/mlir/test/Conversion/GPUToSPIRV/if.mlir
+++ b/mlir/test/Conversion/GPUToSPIRV/if.mlir
@@ -1,6 +1,12 @@
 // RUN: mlir-opt -convert-gpu-to-spirv %s -o - | FileCheck %s
 
-module attributes {gpu.container_module} {
+module attributes {
+  gpu.container_module,
+  spv.target_env = #spv.target_env<
+    #spv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]>,
+    {max_compute_workgroup_invocations = 128 : i32,
+     max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>
+} {
   func @main(%arg0 : memref<10xf32>, %arg1 : i1) {
     %c0 = constant 1 : index
     "gpu.launch_func"(%c0, %c0, %c0, %c0, %c0, %c0, %arg0, %arg1) { kernel = "kernel_simple_selection", kernel_module = @kernels} : (index, index, index, index, index, index, memref<10xf32>, i1) -> ()

diff  --git a/mlir/test/Conversion/GPUToSPIRV/load-store.mlir b/mlir/test/Conversion/GPUToSPIRV/load-store.mlir
index d0224fd16e02..05c9d90c498c 100644
--- a/mlir/test/Conversion/GPUToSPIRV/load-store.mlir
+++ b/mlir/test/Conversion/GPUToSPIRV/load-store.mlir
@@ -1,6 +1,12 @@
 // RUN: mlir-opt -convert-gpu-to-spirv %s -o - | FileCheck %s
 
-module attributes {gpu.container_module} {
+module attributes {
+  gpu.container_module,
+  spv.target_env = #spv.target_env<
+    #spv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]>,
+    {max_compute_workgroup_invocations = 128 : i32,
+     max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>
+} {
   func @load_store(%arg0: memref<12x4xf32>, %arg1: memref<12x4xf32>, %arg2: memref<12x4xf32>) {
     %c0 = constant 0 : index
     %c12 = constant 12 : index

diff  --git a/mlir/test/Conversion/GPUToSPIRV/loop.mlir b/mlir/test/Conversion/GPUToSPIRV/loop.mlir
index 7044d5474d3c..8adc5e355f08 100644
--- a/mlir/test/Conversion/GPUToSPIRV/loop.mlir
+++ b/mlir/test/Conversion/GPUToSPIRV/loop.mlir
@@ -1,6 +1,12 @@
 // RUN: mlir-opt -convert-gpu-to-spirv %s -o - | FileCheck %s
 
-module attributes {gpu.container_module} {
+module attributes {
+  gpu.container_module,
+  spv.target_env = #spv.target_env<
+    #spv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]>,
+    {max_compute_workgroup_invocations = 128 : i32,
+     max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>
+} {
   func @loop(%arg0 : memref<10xf32>, %arg1 : memref<10xf32>) {
     %c0 = constant 1 : index
     "gpu.launch_func"(%c0, %c0, %c0, %c0, %c0, %c0, %arg0, %arg1) { kernel = "loop_kernel", kernel_module = @kernels} : (index, index, index, index, index, index, memref<10xf32>, memref<10xf32>) -> ()

diff  --git a/mlir/test/Conversion/StandardToSPIRV/std-to-spirv.mlir b/mlir/test/Conversion/StandardToSPIRV/std-to-spirv.mlir
index 9b8d695af422..26e2ea42d3a2 100644
--- a/mlir/test/Conversion/StandardToSPIRV/std-to-spirv.mlir
+++ b/mlir/test/Conversion/StandardToSPIRV/std-to-spirv.mlir
@@ -1,5 +1,12 @@
 // RUN: mlir-opt -convert-std-to-spirv %s -o - | FileCheck %s
 
+module attributes {
+  spv.target_env = #spv.target_env<
+    #spv.vce<v1.0, [Shader, Int64, Float64], [SPV_KHR_storage_buffer_storage_class]>,
+    {max_compute_workgroup_invocations = 128 : i32,
+     max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>
+} {
+
 //===----------------------------------------------------------------------===//
 // std binary arithmetic ops
 //===----------------------------------------------------------------------===//
@@ -366,3 +373,5 @@ func @load_store_zero_rank_int(%arg0: memref<i32>, %arg1: memref<i32>) {
   store %0, %arg1[] : memref<i32>
   return
 }
+
+} // end module

diff  --git a/mlir/test/Conversion/StandardToSPIRV/subview-to-spirv.mlir b/mlir/test/Conversion/StandardToSPIRV/subview-to-spirv.mlir
index c9d1195bc056..cc94c089dfb2 100644
--- a/mlir/test/Conversion/StandardToSPIRV/subview-to-spirv.mlir
+++ b/mlir/test/Conversion/StandardToSPIRV/subview-to-spirv.mlir
@@ -4,6 +4,13 @@
 // the desired output. Adding all of patterns within a single pass does
 // not seem to work.
 
+module attributes {
+  spv.target_env = #spv.target_env<
+    #spv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]>,
+    {max_compute_workgroup_invocations = 128 : i32,
+     max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>
+} {
+
 //===----------------------------------------------------------------------===//
 // std.subview
 //===----------------------------------------------------------------------===//
@@ -51,3 +58,5 @@ func @fold_static_stride_subview_with_store(%arg0 : memref<12x32xf32>, %arg1 : i
   store %arg5, %0[%arg3, %arg4] : memref<4x4xf32, offset:?, strides: [64, 3]>
   return
 }
+
+} // end module


        


More information about the Mlir-commits mailing list