[Mlir-commits] [mlir] 6dd07fa - [mlir][spirv] Add utilities for push constant value

Lei Zhang llvmlistbot at llvm.org
Fri Apr 2 04:55:03 PDT 2021


Author: Lei Zhang
Date: 2021-04-02T07:51:07-04:00
New Revision: 6dd07fa513cd3b806e7f852bb98e5c34bab11b36

URL: https://github.com/llvm/llvm-project/commit/6dd07fa513cd3b806e7f852bb98e5c34bab11b36
DIFF: https://github.com/llvm/llvm-project/commit/6dd07fa513cd3b806e7f852bb98e5c34bab11b36.diff

LOG: [mlir][spirv] Add utilities for push constant value

This commit adds utility functions for creating a push constant
storage variable and loading values from it.
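
As a rough illustration of the get-or-insert behavior these utilities
implement, here is a standalone C++ sketch. It is a simplified model, not
the MLIR code: `PushConstantVar`, `ModuleBody`, `getOrInsertPushConstantVar`
and `pushConstantByteOffset` are hypothetical names invented for this
example, and the bounds check is an assumption that the patch itself does
not make. Only the variable name `__push_constant_var__`, the reuse of an
existing variable with a matching element count, insertion at the start of
the body, and the 4-byte array stride come from the diff below.

```cpp
#include <cassert>
#include <deque>
#include <string>

// Hypothetical stand-in for a spv.GlobalVariable in the PushConstant
// storage class. This is not an MLIR type.
struct PushConstantVar {
  std::string name;
  unsigned elementCount; // number of 32-bit integers in the storage
};

// Hypothetical stand-in for the body of the enclosing module-like op.
struct ModuleBody {
  std::deque<PushConstantVar> globals;
};

// Reuses an existing push constant variable whose element count matches;
// otherwise inserts a new one at the beginning of the body.
PushConstantVar &getOrInsertPushConstantVar(ModuleBody &body,
                                            unsigned elementCount) {
  for (PushConstantVar &var : body.globals)
    if (var.elementCount == elementCount)
      return var;
  body.globals.push_front({"__push_constant_var__", elementCount});
  return body.globals.front();
}

// Byte offset of element `offset` in an array of 32-bit integers laid out
// with a stride of 4 bytes. The bounds check is an assumption of this sketch.
unsigned pushConstantByteOffset(unsigned elementCount, unsigned offset) {
  assert(offset < elementCount && "offset must index into the storage");
  return offset * 4;
}
```

In this model, two requests for a 4-element storage return the same
variable, and element 3 sits at byte offset 12.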

Along the way, it performs some cleanup:

* Deleted `setABIAttrs`, which is just a 4-liner function
  with one user.
* Moved `SPIRVConversionTarget` into `mlir` namespace,
  to be consistent with `SPIRVTypeConverter` and
  `LLVMConversionTarget`.

Reviewed By: mravishankar

Differential Revision: https://reviews.llvm.org/D99725

Added: 
    

Modified: 
    mlir/docs/Dialects/SPIR-V.md
    mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h
    mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
    mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp
    mlir/lib/Conversion/LinalgToSPIRV/LinalgToSPIRVPass.cpp
    mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRVPass.cpp
    mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRVPass.cpp
    mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRVPass.cpp
    mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
    mlir/test/lib/Dialect/SPIRV/TestAvailability.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/docs/Dialects/SPIR-V.md b/mlir/docs/Dialects/SPIR-V.md
index 014ca97a63ad2..8951323959927 100644
--- a/mlir/docs/Dialects/SPIR-V.md
+++ b/mlir/docs/Dialects/SPIR-V.md
@@ -932,9 +932,15 @@ The attribute has a few fields:
 *   Binding number for the corresponding resource variable.
 *   Storage class for the corresponding resource variable.
 
-The SPIR-V dialect provides a [`LowerABIAttributesPass`][MlirSpirvPasses] for
-consuming these attributes and create SPIR-V module complying with the
-interface.
+The SPIR-V dialect provides a [`LowerABIAttributesPass`][MlirSpirvPasses] that
+uses this information to lower the entry point function and its ABI consistent
+with the Vulkan validation rules. Specifically,
+
+*   Creates `spv.GlobalVariable`s for the arguments, and replaces all uses of
+    the argument with this variable. The SSA value used for replacement is
+    obtained using the `spv.mlir.addressof` operation.
+*   Adds the `spv.EntryPoint` and `spv.ExecutionMode` operations into the
+    `spv.module` for the entry function.
 
 ## Serialization and deserialization
 
@@ -1052,29 +1058,8 @@ is obtained from the layout specification of the memref. The storage class of
 the pointer type are derived from the memref's memory space with
 `SPIRVTypeConverter::getStorageClassForMemorySpace()`.
 
-### `SPIRVOpLowering`
-
-`mlir::SPIRVOpLowering` is a base class that can be used to define the patterns
-used for implementing the lowering. For now this only provides derived classes
-access to an instance of `mlir::SPIRVTypeLowering` class.
-
 ### Utility functions for lowering
 
-#### Setting shader interface
-
-The method `mlir::spirv::setABIAttrs` allows setting the [shader interface
-attributes](#shader-interface-abi) for a function that is to be an entry
-point function within the `spv.module` on lowering. A later pass
-`mlir::spirv::LowerABIAttributesPass` uses this information to lower the entry
-point function and its ABI consistent with the Vulkan validation
-rules. Specifically,
-
-*   Creates `spv.GlobalVariable`s for the arguments, and replaces all uses of
-    the argument with this variable. The SSA value used for replacement is
-    obtained using the `spv.mlir.addressof` operation.
-*   Adds the `spv.EntryPoint` and `spv.ExecutionMode` operations into the
-    `spv.module` for the entry function.
-
 #### Setting layout for shader interface variables
 
 SPIR-V validation rules for shaders require composite objects to be explicitly

diff  --git a/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h b/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h
index 881f8e90fa0db..27e47f92396bc 100644
--- a/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h
+++ b/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h
@@ -21,6 +21,10 @@
 
 namespace mlir {
 
+//===----------------------------------------------------------------------===//
+// Type Converter
+//===----------------------------------------------------------------------===//
+
 /// Type conversion from builtin types to SPIR-V types for shader interface.
 ///
 /// Non-32-bit scalar types require special hardware support that may not exist
@@ -63,24 +67,22 @@ class SPIRVTypeConverter : public TypeConverter {
   spirv::TargetEnv targetEnv;
 };
 
-/// Appends to a pattern list additional patterns for translating the builtin
-/// `func` op to the SPIR-V dialect. These patterns do not handle shader
-/// interface/ABI; they convert function parameters to be of SPIR-V allowed
-/// types.
-void populateBuiltinFuncToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
-                                        RewritePatternSet &patterns);
-
-namespace spirv {
-class AccessChainOp;
-class FuncOp;
+//===----------------------------------------------------------------------===//
+// Conversion Target
+//===----------------------------------------------------------------------===//
 
+// The default SPIR-V conversion target.
+//
+// It takes a SPIR-V target environment and controls operation legality based on
+// their availability in the target environment.
 class SPIRVConversionTarget : public ConversionTarget {
 public:
   /// Creates a SPIR-V conversion target for the given target environment.
-  static std::unique_ptr<SPIRVConversionTarget> get(TargetEnvAttr targetAttr);
+  static std::unique_ptr<SPIRVConversionTarget>
+  get(spirv::TargetEnvAttr targetAttr);
 
 private:
-  explicit SPIRVConversionTarget(TargetEnvAttr targetAttr);
+  explicit SPIRVConversionTarget(spirv::TargetEnvAttr targetAttr);
 
   // Be explicit that instance of this class cannot be copied or moved: there
   // are lambdas capturing fields of the instance.
@@ -93,16 +95,37 @@ class SPIRVConversionTarget : public ConversionTarget {
   /// environment.
   bool isLegalOp(Operation *op);
 
-  TargetEnv targetEnv;
+  spirv::TargetEnv targetEnv;
 };
 
+//===----------------------------------------------------------------------===//
+// Patterns and Utility Functions
+//===----------------------------------------------------------------------===//
+
+/// Appends to a pattern list additional patterns for translating the builtin
+/// `func` op to the SPIR-V dialect. These patterns do not handle shader
+/// interface/ABI; they convert function parameters to be of SPIR-V allowed
+/// types.
+void populateBuiltinFuncToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
+                                        RewritePatternSet &patterns);
+
+namespace spirv {
+class AccessChainOp;
+
 /// Returns the value for the given `builtin` variable. This function gets or
 /// inserts the global variable associated for the builtin within the nearest
-/// enclosing op that has a symbol table. Returns null Value if such an
-/// enclosing op cannot be found.
+/// symbol table enclosing `op`. Returns null Value on error.
 Value getBuiltinVariableValue(Operation *op, BuiltIn builtin,
                               OpBuilder &builder);
 
+/// Gets the value at the given `offset` of the push constant storage with a
+/// total of `elementCount` 32-bit integers. A global variable will be created
+/// in the nearest symbol table enclosing `op` for the push constant storage if
+/// not existing. Load ops will be created via the given `builder` to load
+/// values from the push constant. Returns null Value on error.
+Value getPushConstantValue(Operation *op, unsigned elementCount,
+                           unsigned offset, OpBuilder &builder);
+
 /// Generates IR to perform index linearization with the given `indices` and
 /// their corresponding `strides`, adding an initial `offset`.
 Value linearizeIndex(ValueRange indices, ArrayRef<int64_t> strides,
@@ -118,11 +141,6 @@ spirv::AccessChainOp getElementPtr(SPIRVTypeConverter &typeConverter,
                                    ValueRange indices, Location loc,
                                    OpBuilder &builder);
 
-/// Sets the InterfaceVarABIAttr and EntryPointABIAttr for a function and its
-/// arguments.
-LogicalResult setABIAttrs(spirv::FuncOp funcOp,
-                          EntryPointABIAttr entryPointInfo,
-                          ArrayRef<InterfaceVarABIAttr> argABIInfo);
 } // namespace spirv
 } // namespace mlir
 

diff  --git a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
index c2cd4baea631c..2066debb7d45e 100644
--- a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
+++ b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
@@ -208,8 +208,13 @@ lowerAsEntryFunction(gpu::GPUFuncOp funcOp, TypeConverter &typeConverter,
     return nullptr;
   rewriter.eraseOp(funcOp);
 
-  if (failed(spirv::setABIAttrs(newFuncOp, entryPointInfo, argABIInfo)))
-    return nullptr;
+  // Set the attributes for argument and the function.
+  StringRef argABIAttrName = spirv::getInterfaceVarABIAttrName();
+  for (auto argIndex : llvm::seq<unsigned>(0, argABIInfo.size())) {
+    newFuncOp.setArgAttr(argIndex, argABIAttrName, argABIInfo[argIndex]);
+  }
+  newFuncOp->setAttr(spirv::getEntryPointABIAttrName(), entryPointInfo);
+
   return newFuncOp;
 }
 

diff  --git a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp
index 1f23f7ce380ef..7109b2b1e8637 100644
--- a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp
+++ b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp
@@ -54,7 +54,7 @@ void GPUToSPIRVPass::runOnOperation() {
 
   auto targetAttr = spirv::lookupTargetEnvOrDefault(module);
   std::unique_ptr<ConversionTarget> target =
-      spirv::SPIRVConversionTarget::get(targetAttr);
+      SPIRVConversionTarget::get(targetAttr);
 
   SPIRVTypeConverter typeConverter(targetAttr);
   RewritePatternSet patterns(context);

diff  --git a/mlir/lib/Conversion/LinalgToSPIRV/LinalgToSPIRVPass.cpp b/mlir/lib/Conversion/LinalgToSPIRV/LinalgToSPIRVPass.cpp
index 6038a9841cf95..815e3d16e54c5 100644
--- a/mlir/lib/Conversion/LinalgToSPIRV/LinalgToSPIRVPass.cpp
+++ b/mlir/lib/Conversion/LinalgToSPIRV/LinalgToSPIRVPass.cpp
@@ -27,7 +27,7 @@ void LinalgToSPIRVPass::runOnOperation() {
 
   auto targetAttr = spirv::lookupTargetEnvOrDefault(module);
   std::unique_ptr<ConversionTarget> target =
-      spirv::SPIRVConversionTarget::get(targetAttr);
+      SPIRVConversionTarget::get(targetAttr);
 
   SPIRVTypeConverter typeConverter(targetAttr);
   RewritePatternSet patterns(context);

diff  --git a/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRVPass.cpp b/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRVPass.cpp
index 637e6a7501b71..84e64d6cc62a0 100644
--- a/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRVPass.cpp
+++ b/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRVPass.cpp
@@ -33,7 +33,7 @@ void SCFToSPIRVPass::runOnOperation() {
 
   auto targetAttr = spirv::lookupTargetEnvOrDefault(module);
   std::unique_ptr<ConversionTarget> target =
-      spirv::SPIRVConversionTarget::get(targetAttr);
+      SPIRVConversionTarget::get(targetAttr);
 
   SPIRVTypeConverter typeConverter(targetAttr);
   ScfToSPIRVContext scfContext;

diff  --git a/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRVPass.cpp b/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRVPass.cpp
index c738537f74382..5cd00e1c8cd2d 100644
--- a/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRVPass.cpp
+++ b/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRVPass.cpp
@@ -32,7 +32,7 @@ void ConvertStandardToSPIRVPass::runOnOperation() {
 
   auto targetAttr = spirv::lookupTargetEnvOrDefault(module);
   std::unique_ptr<ConversionTarget> target =
-      spirv::SPIRVConversionTarget::get(targetAttr);
+      SPIRVConversionTarget::get(targetAttr);
 
   SPIRVTypeConverter typeConverter(targetAttr);
   RewritePatternSet patterns(context);

diff  --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRVPass.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRVPass.cpp
index e170df2948fef..9ffd1e595cd59 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRVPass.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRVPass.cpp
@@ -34,7 +34,7 @@ void LowerVectorToSPIRVPass::runOnOperation() {
 
   auto targetAttr = spirv::lookupTargetEnvOrDefault(module);
   std::unique_ptr<ConversionTarget> target =
-      spirv::SPIRVConversionTarget::get(targetAttr);
+      SPIRVConversionTarget::get(targetAttr);
 
   SPIRVTypeConverter typeConverter(targetAttr);
   RewritePatternSet patterns(context);

diff  --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
index 7539e9a050765..4c01b97cc0bee 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
@@ -171,12 +171,14 @@ static Optional<int64_t> getTypeNumBytes(Type t) {
     }
     return bitWidth / 8;
   }
+
   if (auto vecType = t.dyn_cast<VectorType>()) {
     auto elementSize = getTypeNumBytes(vecType.getElementType());
     if (!elementSize)
       return llvm::None;
     return vecType.getNumElements() * *elementSize;
   }
+
   if (auto memRefType = t.dyn_cast<MemRefType>()) {
     // TODO: Layout should also be controlled by the ABI attributes. For now
     // using the layout from MemRef.
@@ -207,7 +209,9 @@ static Optional<int64_t> getTypeNumBytes(Type t) {
       memrefSize = std::max(memrefSize, shape.value() * strides[shape.index()]);
     }
     return (offset + memrefSize) * elementSize.getValue();
-  } else if (auto tensorType = t.dyn_cast<TensorType>()) {
+  }
+
+  if (auto tensorType = t.dyn_cast<TensorType>()) {
     if (!tensorType.hasStaticShape()) {
       return llvm::None;
     }
@@ -221,6 +225,7 @@ static Optional<int64_t> getTypeNumBytes(Type t) {
     }
     return size;
   }
+
   // TODO: Add size computation for other types.
   return llvm::None;
 }
@@ -602,6 +607,80 @@ Value mlir::spirv::getBuiltinVariableValue(Operation *op,
   return builder.create<spirv::LoadOp>(op->getLoc(), ptr);
 }
 
+//===----------------------------------------------------------------------===//
+// Push constant storage
+//===----------------------------------------------------------------------===//
+
+/// Returns the pointer type for the push constant storage containing
+/// `elementCount` 32-bit integer values.
+static spirv::PointerType getPushConstantStorageType(unsigned elementCount,
+                                                     Builder &builder) {
+  auto arrayType = spirv::ArrayType::get(
+      SPIRVTypeConverter::getIndexType(builder.getContext()), elementCount,
+      /*stride=*/4);
+  auto structType = spirv::StructType::get({arrayType}, /*offsetInfo=*/0);
+  return spirv::PointerType::get(structType, spirv::StorageClass::PushConstant);
+}
+
+/// Returns the push constant variable containing `elementCount` 32-bit integer
+/// values in `body`. Returns null op if such an op does not exist.
+static spirv::GlobalVariableOp getPushConstantVariable(Block &body,
+                                                       unsigned elementCount) {
+  for (auto varOp : body.getOps<spirv::GlobalVariableOp>()) {
+    auto ptrType = varOp.type().cast<spirv::PointerType>();
+    // Note that Vulkan requires "There must be no more than one push constant
+    // block statically used per shader entry point." So we should always reuse
+    // the existing one.
+    if (ptrType.getStorageClass() == spirv::StorageClass::PushConstant) {
+      auto numElements = ptrType.getPointeeType()
+                             .cast<spirv::StructType>()
+                             .getElementType(0)
+                             .cast<spirv::ArrayType>()
+                             .getNumElements();
+      if (numElements == elementCount)
+        return varOp;
+    }
+  }
+  return nullptr;
+}
+
+/// Gets or inserts a global variable for push constant storage containing
+/// `elementCount` 32-bit integer values in `block`.
+static spirv::GlobalVariableOp
+getOrInsertPushConstantVariable(Location loc, Block &block,
+                                unsigned elementCount, OpBuilder &b) {
+  if (auto varOp = getPushConstantVariable(block, elementCount))
+    return varOp;
+
+  auto builder = OpBuilder::atBlockBegin(&block, b.getListener());
+  auto type = getPushConstantStorageType(elementCount, builder);
+  const char *name = "__push_constant_var__";
+  return builder.create<spirv::GlobalVariableOp>(loc, type, name,
+                                                 /*initializer=*/nullptr);
+}
+
+Value spirv::getPushConstantValue(Operation *op, unsigned elementCount,
+                                  unsigned offset, OpBuilder &builder) {
+  Location loc = op->getLoc();
+  Operation *parent = SymbolTable::getNearestSymbolTable(op->getParentOp());
+  if (!parent) {
+    op->emitError("expected operation to be within a module-like op");
+    return nullptr;
+  }
+
+  spirv::GlobalVariableOp varOp = getOrInsertPushConstantVariable(
+      loc, parent->getRegion(0).front(), elementCount, builder);
+
+  auto i32Type = SPIRVTypeConverter::getIndexType(builder.getContext());
+  Value zeroOp = spirv::ConstantOp::getZero(i32Type, loc, builder);
+  Value offsetOp = builder.create<spirv::ConstantOp>(
+      loc, i32Type, builder.getI32IntegerAttr(offset));
+  auto addrOp = builder.create<spirv::AddressOfOp>(loc, varOp);
+  auto acOp = builder.create<spirv::AccessChainOp>(
+      loc, addrOp, llvm::makeArrayRef({zeroOp, offsetOp}));
+  return builder.create<spirv::LoadOp>(loc, acOp);
+}
+
 //===----------------------------------------------------------------------===//
 // Index calculation
 //===----------------------------------------------------------------------===//
@@ -661,45 +740,27 @@ spirv::AccessChainOp mlir::spirv::getElementPtr(
   return builder.create<spirv::AccessChainOp>(loc, basePtr, linearizedIndices);
 }
 
-//===----------------------------------------------------------------------===//
-// Set ABI attributes for lowering entry functions.
-//===----------------------------------------------------------------------===//
-
-LogicalResult
-mlir::spirv::setABIAttrs(spirv::FuncOp funcOp,
-                         spirv::EntryPointABIAttr entryPointInfo,
-                         ArrayRef<spirv::InterfaceVarABIAttr> argABIInfo) {
-  // Set the attributes for argument and the function.
-  StringRef argABIAttrName = spirv::getInterfaceVarABIAttrName();
-  for (auto argIndex : llvm::seq<unsigned>(0, argABIInfo.size())) {
-    funcOp.setArgAttr(argIndex, argABIAttrName, argABIInfo[argIndex]);
-  }
-  funcOp->setAttr(spirv::getEntryPointABIAttrName(), entryPointInfo);
-  return success();
-}
-
 //===----------------------------------------------------------------------===//
 // SPIR-V ConversionTarget
 //===----------------------------------------------------------------------===//
 
-std::unique_ptr<spirv::SPIRVConversionTarget>
-spirv::SPIRVConversionTarget::get(spirv::TargetEnvAttr targetAttr) {
+std::unique_ptr<SPIRVConversionTarget>
+SPIRVConversionTarget::get(spirv::TargetEnvAttr targetAttr) {
   std::unique_ptr<SPIRVConversionTarget> target(
       // std::make_unique does not work here because the constructor is private.
       new SPIRVConversionTarget(targetAttr));
   SPIRVConversionTarget *targetPtr = target.get();
-  target->addDynamicallyLegalDialect<SPIRVDialect>(
+  target->addDynamicallyLegalDialect<spirv::SPIRVDialect>(
       // We need to capture the raw pointer here because it is stable:
       // target will be destroyed once this function is returned.
       [targetPtr](Operation *op) { return targetPtr->isLegalOp(op); });
   return target;
 }
 
-spirv::SPIRVConversionTarget::SPIRVConversionTarget(
-    spirv::TargetEnvAttr targetAttr)
+SPIRVConversionTarget::SPIRVConversionTarget(spirv::TargetEnvAttr targetAttr)
     : ConversionTarget(*targetAttr.getContext()), targetEnv(targetAttr) {}
 
-bool spirv::SPIRVConversionTarget::isLegalOp(Operation *op) {
+bool SPIRVConversionTarget::isLegalOp(Operation *op) {
   // Make sure this op is available at the given version. Ops not implementing
   // QueryMinVersionInterface/QueryMaxVersionInterface are available to all
   // SPIR-V versions.

diff  --git a/mlir/test/lib/Dialect/SPIRV/TestAvailability.cpp b/mlir/test/lib/Dialect/SPIRV/TestAvailability.cpp
index 11cd05aa9bec8..d8a0d4bd70abb 100644
--- a/mlir/test/lib/Dialect/SPIRV/TestAvailability.cpp
+++ b/mlir/test/lib/Dialect/SPIRV/TestAvailability.cpp
@@ -137,7 +137,7 @@ void ConvertToTargetEnv::runOnFunction() {
     return signalPassFailure();
   }
 
-  auto target = spirv::SPIRVConversionTarget::get(targetEnv);
+  auto target = SPIRVConversionTarget::get(targetEnv);
 
   RewritePatternSet patterns(context);
   patterns.add<ConvertToAtomCmpExchangeWeak, ConvertToBitReverse,


        

