[Mlir-commits] [mlir] 6dd07fa - [mlir][spirv] Add utilities for push constant value
Lei Zhang
llvmlistbot at llvm.org
Fri Apr 2 04:55:03 PDT 2021
Author: Lei Zhang
Date: 2021-04-02T07:51:07-04:00
New Revision: 6dd07fa513cd3b806e7f852bb98e5c34bab11b36
URL: https://github.com/llvm/llvm-project/commit/6dd07fa513cd3b806e7f852bb98e5c34bab11b36
DIFF: https://github.com/llvm/llvm-project/commit/6dd07fa513cd3b806e7f852bb98e5c34bab11b36.diff
LOG: [mlir][spirv] Add utilities for push constant value
This commit adds utility functions for creating a push constant
storage variable and loading values from it.
Along the way, it performs some cleanup:
* Deleted `setABIAttrs`, which is just a 4-liner function
with one user.
* Moved `SPIRVConversionTarget` into `mlir` namespace,
to be consistent with `SPIRVTypeConverter` and
`LLVMConversionTarget`.
Reviewed By: mravishankar
Differential Revision: https://reviews.llvm.org/D99725
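The get-or-insert behavior of the new push constant utilities can be illustrated with a minimal standalone sketch. It does not use the MLIR API: `GlobalVar`, `ModuleBody`, `findPushConstantVar`, `getOrInsertPushConstantVar`, and `pushConstantByteOffset` are hypothetical names invented for this illustration. It only assumes what the diff below shows: the storage is an array of 32-bit integers with stride 4, an existing push constant variable with the same element count is reused, and a new variable is otherwise created at the beginning of the enclosing block.

```cpp
#include <cstddef>
#include <optional>
#include <string>
#include <vector>

// Hypothetical stand-in for a module-level global variable (not an MLIR type).
struct GlobalVar {
  std::string name;
  bool isPushConstant;
  unsigned elementCount; // number of 32-bit integers in the storage array
};

// Hypothetical stand-in for the body of the nearest enclosing symbol table.
struct ModuleBody {
  std::vector<GlobalVar> globals;
};

// Lookup step: find a push constant variable holding `elementCount` values.
std::optional<std::size_t> findPushConstantVar(const ModuleBody &body,
                                               unsigned elementCount) {
  for (std::size_t i = 0; i < body.globals.size(); ++i) {
    const GlobalVar &var = body.globals[i];
    if (var.isPushConstant && var.elementCount == elementCount)
      return i;
  }
  return std::nullopt;
}

// Get-or-insert step: reuse a matching variable, otherwise create one at the
// beginning of the body. Returns the index of the variable in `body.globals`.
std::size_t getOrInsertPushConstantVar(ModuleBody &body,
                                       unsigned elementCount) {
  if (auto index = findPushConstantVar(body, elementCount))
    return *index;
  body.globals.insert(body.globals.begin(),
                      GlobalVar{"__push_constant_var__", true, elementCount});
  return 0;
}

// Byte offset of the element at `offset`, given an array stride of 4 bytes.
unsigned pushConstantByteOffset(unsigned offset) { return offset * 4; }
```

In this model, calling `getOrInsertPushConstantVar` twice with the same element count leaves a single variable in the body, which corresponds to the reuse of the existing push constant block in the actual change.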
Added:
Modified:
mlir/docs/Dialects/SPIR-V.md
mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h
mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp
mlir/lib/Conversion/LinalgToSPIRV/LinalgToSPIRVPass.cpp
mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRVPass.cpp
mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRVPass.cpp
mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRVPass.cpp
mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
mlir/test/lib/Dialect/SPIRV/TestAvailability.cpp
Removed:
################################################################################
diff --git a/mlir/docs/Dialects/SPIR-V.md b/mlir/docs/Dialects/SPIR-V.md
index 014ca97a63ad2..8951323959927 100644
--- a/mlir/docs/Dialects/SPIR-V.md
+++ b/mlir/docs/Dialects/SPIR-V.md
@@ -932,9 +932,15 @@ The attribute has a few fields:
* Binding number for the corresponding resource variable.
* Storage class for the corresponding resource variable.
-The SPIR-V dialect provides a [`LowerABIAttributesPass`][MlirSpirvPasses] for
-consuming these attributes and create SPIR-V module complying with the
-interface.
+The SPIR-V dialect provides a [`LowerABIAttributesPass`][MlirSpirvPasses] that
+uses this information to lower the entry point function and its ABI consistent
+with the Vulkan validation rules. Specifically,
+
+* Creates `spv.GlobalVariable`s for the arguments, and replaces all uses of
+ the argument with this variable. The SSA value used for replacement is
+ obtained using the `spv.mlir.addressof` operation.
+* Adds the `spv.EntryPoint` and `spv.ExecutionMode` operations into the
+ `spv.module` for the entry function.
## Serialization and deserialization
@@ -1052,29 +1058,8 @@ is obtained from the layout specification of the memref. The storage class of
the pointer type are derived from the memref's memory space with
`SPIRVTypeConverter::getStorageClassForMemorySpace()`.
-### `SPIRVOpLowering`
-
-`mlir::SPIRVOpLowering` is a base class that can be used to define the patterns
-used for implementing the lowering. For now this only provides derived classes
-access to an instance of `mlir::SPIRVTypeLowering` class.
-
### Utility functions for lowering
-#### Setting shader interface
-
-The method `mlir::spirv::setABIAttrs` allows setting the [shader interface
-attributes](#shader-interface-abi) for a function that is to be an entry
-point function within the `spv.module` on lowering. A later pass
-`mlir::spirv::LowerABIAttributesPass` uses this information to lower the entry
-point function and its ABI consistent with the Vulkan validation
-rules. Specifically,
-
-* Creates `spv.GlobalVariable`s for the arguments, and replaces all uses of
- the argument with this variable. The SSA value used for replacement is
- obtained using the `spv.mlir.addressof` operation.
-* Adds the `spv.EntryPoint` and `spv.ExecutionMode` operations into the
- `spv.module` for the entry function.
-
#### Setting layout for shader interface variables
SPIR-V validation rules for shaders require composite objects to be explicitly
diff --git a/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h b/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h
index 881f8e90fa0db..27e47f92396bc 100644
--- a/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h
+++ b/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h
@@ -21,6 +21,10 @@
namespace mlir {
+//===----------------------------------------------------------------------===//
+// Type Converter
+//===----------------------------------------------------------------------===//
+
/// Type conversion from builtin types to SPIR-V types for shader interface.
///
/// Non-32-bit scalar types require special hardware support that may not exist
@@ -63,24 +67,22 @@ class SPIRVTypeConverter : public TypeConverter {
spirv::TargetEnv targetEnv;
};
-/// Appends to a pattern list additional patterns for translating the builtin
-/// `func` op to the SPIR-V dialect. These patterns do not handle shader
-/// interface/ABI; they convert function parameters to be of SPIR-V allowed
-/// types.
-void populateBuiltinFuncToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
- RewritePatternSet &patterns);
-
-namespace spirv {
-class AccessChainOp;
-class FuncOp;
+//===----------------------------------------------------------------------===//
+// Conversion Target
+//===----------------------------------------------------------------------===//
+// The default SPIR-V conversion target.
+//
+// It takes a SPIR-V target environment and controls operation legality based on
+// the their availability in the target environment.
class SPIRVConversionTarget : public ConversionTarget {
public:
/// Creates a SPIR-V conversion target for the given target environment.
- static std::unique_ptr<SPIRVConversionTarget> get(TargetEnvAttr targetAttr);
+ static std::unique_ptr<SPIRVConversionTarget>
+ get(spirv::TargetEnvAttr targetAttr);
private:
- explicit SPIRVConversionTarget(TargetEnvAttr targetAttr);
+ explicit SPIRVConversionTarget(spirv::TargetEnvAttr targetAttr);
// Be explicit that instance of this class cannot be copied or moved: there
// are lambdas capturing fields of the instance.
@@ -93,16 +95,37 @@ class SPIRVConversionTarget : public ConversionTarget {
/// environment.
bool isLegalOp(Operation *op);
- TargetEnv targetEnv;
+ spirv::TargetEnv targetEnv;
};
+//===----------------------------------------------------------------------===//
+// Patterns and Utility Functions
+//===----------------------------------------------------------------------===//
+
+/// Appends to a pattern list additional patterns for translating the builtin
+/// `func` op to the SPIR-V dialect. These patterns do not handle shader
+/// interface/ABI; they convert function parameters to be of SPIR-V allowed
+/// types.
+void populateBuiltinFuncToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
+ RewritePatternSet &patterns);
+
+namespace spirv {
+class AccessChainOp;
+
/// Returns the value for the given `builtin` variable. This function gets or
/// inserts the global variable associated for the builtin within the nearest
-/// enclosing op that has a symbol table. Returns null Value if such an
-/// enclosing op cannot be found.
+/// symbol table enclosing `op`. Returns null Value on error.
Value getBuiltinVariableValue(Operation *op, BuiltIn builtin,
OpBuilder &builder);
+/// Gets the value at the given `offset` of the push constant storage with a
+/// total of `elementCount` 32-bit integers. A global variable will be created
+/// in the nearest symbol table enclosing `op` for the push constant storage if
+/// not existing. Load ops will be created via the given `builder` to load
+/// values from the push constant. Returns null Value on error.
+Value getPushConstantValue(Operation *op, unsigned elementCount,
+ unsigned offset, OpBuilder &builder);
+
/// Generates IR to perform index linearization with the given `indices` and
/// their corresponding `strides`, adding an initial `offset`.
Value linearizeIndex(ValueRange indices, ArrayRef<int64_t> strides,
@@ -118,11 +141,6 @@ spirv::AccessChainOp getElementPtr(SPIRVTypeConverter &typeConverter,
ValueRange indices, Location loc,
OpBuilder &builder);
-/// Sets the InterfaceVarABIAttr and EntryPointABIAttr for a function and its
-/// arguments.
-LogicalResult setABIAttrs(spirv::FuncOp funcOp,
- EntryPointABIAttr entryPointInfo,
- ArrayRef<InterfaceVarABIAttr> argABIInfo);
} // namespace spirv
} // namespace mlir
diff --git a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
index c2cd4baea631c..2066debb7d45e 100644
--- a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
+++ b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
@@ -208,8 +208,13 @@ lowerAsEntryFunction(gpu::GPUFuncOp funcOp, TypeConverter &typeConverter,
return nullptr;
rewriter.eraseOp(funcOp);
- if (failed(spirv::setABIAttrs(newFuncOp, entryPointInfo, argABIInfo)))
- return nullptr;
+ // Set the attributes for argument and the function.
+ StringRef argABIAttrName = spirv::getInterfaceVarABIAttrName();
+ for (auto argIndex : llvm::seq<unsigned>(0, argABIInfo.size())) {
+ newFuncOp.setArgAttr(argIndex, argABIAttrName, argABIInfo[argIndex]);
+ }
+ newFuncOp->setAttr(spirv::getEntryPointABIAttrName(), entryPointInfo);
+
return newFuncOp;
}
diff --git a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp
index 1f23f7ce380ef..7109b2b1e8637 100644
--- a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp
+++ b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp
@@ -54,7 +54,7 @@ void GPUToSPIRVPass::runOnOperation() {
auto targetAttr = spirv::lookupTargetEnvOrDefault(module);
std::unique_ptr<ConversionTarget> target =
- spirv::SPIRVConversionTarget::get(targetAttr);
+ SPIRVConversionTarget::get(targetAttr);
SPIRVTypeConverter typeConverter(targetAttr);
RewritePatternSet patterns(context);
diff --git a/mlir/lib/Conversion/LinalgToSPIRV/LinalgToSPIRVPass.cpp b/mlir/lib/Conversion/LinalgToSPIRV/LinalgToSPIRVPass.cpp
index 6038a9841cf95..815e3d16e54c5 100644
--- a/mlir/lib/Conversion/LinalgToSPIRV/LinalgToSPIRVPass.cpp
+++ b/mlir/lib/Conversion/LinalgToSPIRV/LinalgToSPIRVPass.cpp
@@ -27,7 +27,7 @@ void LinalgToSPIRVPass::runOnOperation() {
auto targetAttr = spirv::lookupTargetEnvOrDefault(module);
std::unique_ptr<ConversionTarget> target =
- spirv::SPIRVConversionTarget::get(targetAttr);
+ SPIRVConversionTarget::get(targetAttr);
SPIRVTypeConverter typeConverter(targetAttr);
RewritePatternSet patterns(context);
diff --git a/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRVPass.cpp b/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRVPass.cpp
index 637e6a7501b71..84e64d6cc62a0 100644
--- a/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRVPass.cpp
+++ b/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRVPass.cpp
@@ -33,7 +33,7 @@ void SCFToSPIRVPass::runOnOperation() {
auto targetAttr = spirv::lookupTargetEnvOrDefault(module);
std::unique_ptr<ConversionTarget> target =
- spirv::SPIRVConversionTarget::get(targetAttr);
+ SPIRVConversionTarget::get(targetAttr);
SPIRVTypeConverter typeConverter(targetAttr);
ScfToSPIRVContext scfContext;
diff --git a/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRVPass.cpp b/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRVPass.cpp
index c738537f74382..5cd00e1c8cd2d 100644
--- a/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRVPass.cpp
+++ b/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRVPass.cpp
@@ -32,7 +32,7 @@ void ConvertStandardToSPIRVPass::runOnOperation() {
auto targetAttr = spirv::lookupTargetEnvOrDefault(module);
std::unique_ptr<ConversionTarget> target =
- spirv::SPIRVConversionTarget::get(targetAttr);
+ SPIRVConversionTarget::get(targetAttr);
SPIRVTypeConverter typeConverter(targetAttr);
RewritePatternSet patterns(context);
diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRVPass.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRVPass.cpp
index e170df2948fef..9ffd1e595cd59 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRVPass.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRVPass.cpp
@@ -34,7 +34,7 @@ void LowerVectorToSPIRVPass::runOnOperation() {
auto targetAttr = spirv::lookupTargetEnvOrDefault(module);
std::unique_ptr<ConversionTarget> target =
- spirv::SPIRVConversionTarget::get(targetAttr);
+ SPIRVConversionTarget::get(targetAttr);
SPIRVTypeConverter typeConverter(targetAttr);
RewritePatternSet patterns(context);
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
index 7539e9a050765..4c01b97cc0bee 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
@@ -171,12 +171,14 @@ static Optional<int64_t> getTypeNumBytes(Type t) {
}
return bitWidth / 8;
}
+
if (auto vecType = t.dyn_cast<VectorType>()) {
auto elementSize = getTypeNumBytes(vecType.getElementType());
if (!elementSize)
return llvm::None;
return vecType.getNumElements() * *elementSize;
}
+
if (auto memRefType = t.dyn_cast<MemRefType>()) {
// TODO: Layout should also be controlled by the ABI attributes. For now
// using the layout from MemRef.
@@ -207,7 +209,9 @@ static Optional<int64_t> getTypeNumBytes(Type t) {
memrefSize = std::max(memrefSize, shape.value() * strides[shape.index()]);
}
return (offset + memrefSize) * elementSize.getValue();
- } else if (auto tensorType = t.dyn_cast<TensorType>()) {
+ }
+
+ if (auto tensorType = t.dyn_cast<TensorType>()) {
if (!tensorType.hasStaticShape()) {
return llvm::None;
}
@@ -221,6 +225,7 @@ static Optional<int64_t> getTypeNumBytes(Type t) {
}
return size;
}
+
// TODO: Add size computation for other types.
return llvm::None;
}
@@ -602,6 +607,80 @@ Value mlir::spirv::getBuiltinVariableValue(Operation *op,
return builder.create<spirv::LoadOp>(op->getLoc(), ptr);
}
+//===----------------------------------------------------------------------===//
+// Push constant storage
+//===----------------------------------------------------------------------===//
+
+/// Returns the pointer type for the push constant storage containing
+/// `elementCount` 32-bit integer values.
+static spirv::PointerType getPushConstantStorageType(unsigned elementCount,
+ Builder &builder) {
+ auto arrayType = spirv::ArrayType::get(
+ SPIRVTypeConverter::getIndexType(builder.getContext()), elementCount,
+ /*stride=*/4);
+ auto structType = spirv::StructType::get({arrayType}, /*offsetInfo=*/0);
+ return spirv::PointerType::get(structType, spirv::StorageClass::PushConstant);
+}
+
+/// Returns the push constant varible containing `elementCount` 32-bit integer
+/// values in `body`. Returns null op if such an op does not exit.
+static spirv::GlobalVariableOp getPushConstantVariable(Block &body,
+ unsigned elementCount) {
+ for (auto varOp : body.getOps<spirv::GlobalVariableOp>()) {
+ auto ptrType = varOp.type().cast<spirv::PointerType>();
+ // Note that Vulkan requires "There must be no more than one push constant
+ // block statically used per shader entry point." So we should always reuse
+ // the existing one.
+ if (ptrType.getStorageClass() == spirv::StorageClass::PushConstant) {
+ auto numElements = ptrType.getPointeeType()
+ .cast<spirv::StructType>()
+ .getElementType(0)
+ .cast<spirv::ArrayType>()
+ .getNumElements();
+ if (numElements == elementCount)
+ return varOp;
+ }
+ }
+ return nullptr;
+}
+
+/// Gets or inserts a global variable for push constant storage containing
+/// `elementCount` 32-bit integer values in `block`.
+static spirv::GlobalVariableOp
+getOrInsertPushConstantVariable(Location loc, Block &block,
+ unsigned elementCount, OpBuilder &b) {
+ if (auto varOp = getPushConstantVariable(block, elementCount))
+ return varOp;
+
+ auto builder = OpBuilder::atBlockBegin(&block, b.getListener());
+ auto type = getPushConstantStorageType(elementCount, builder);
+ const char *name = "__push_constant_var__";
+ return builder.create<spirv::GlobalVariableOp>(loc, type, name,
+ /*initializer=*/nullptr);
+}
+
+Value spirv::getPushConstantValue(Operation *op, unsigned elementCount,
+ unsigned offset, OpBuilder &builder) {
+ Location loc = op->getLoc();
+ Operation *parent = SymbolTable::getNearestSymbolTable(op->getParentOp());
+ if (!parent) {
+ op->emitError("expected operation to be within a module-like op");
+ return nullptr;
+ }
+
+ spirv::GlobalVariableOp varOp = getOrInsertPushConstantVariable(
+ loc, parent->getRegion(0).front(), elementCount, builder);
+
+ auto i32Type = SPIRVTypeConverter::getIndexType(builder.getContext());
+ Value zeroOp = spirv::ConstantOp::getZero(i32Type, loc, builder);
+ Value offsetOp = builder.create<spirv::ConstantOp>(
+ loc, i32Type, builder.getI32IntegerAttr(offset));
+ auto addrOp = builder.create<spirv::AddressOfOp>(loc, varOp);
+ auto acOp = builder.create<spirv::AccessChainOp>(
+ loc, addrOp, llvm::makeArrayRef({zeroOp, offsetOp}));
+ return builder.create<spirv::LoadOp>(loc, acOp);
+}
+
//===----------------------------------------------------------------------===//
// Index calculation
//===----------------------------------------------------------------------===//
@@ -661,45 +740,27 @@ spirv::AccessChainOp mlir::spirv::getElementPtr(
return builder.create<spirv::AccessChainOp>(loc, basePtr, linearizedIndices);
}
-//===----------------------------------------------------------------------===//
-// Set ABI attributes for lowering entry functions.
-//===----------------------------------------------------------------------===//
-
-LogicalResult
-mlir::spirv::setABIAttrs(spirv::FuncOp funcOp,
- spirv::EntryPointABIAttr entryPointInfo,
- ArrayRef<spirv::InterfaceVarABIAttr> argABIInfo) {
- // Set the attributes for argument and the function.
- StringRef argABIAttrName = spirv::getInterfaceVarABIAttrName();
- for (auto argIndex : llvm::seq<unsigned>(0, argABIInfo.size())) {
- funcOp.setArgAttr(argIndex, argABIAttrName, argABIInfo[argIndex]);
- }
- funcOp->setAttr(spirv::getEntryPointABIAttrName(), entryPointInfo);
- return success();
-}
-
//===----------------------------------------------------------------------===//
// SPIR-V ConversionTarget
//===----------------------------------------------------------------------===//
-std::unique_ptr<spirv::SPIRVConversionTarget>
-spirv::SPIRVConversionTarget::get(spirv::TargetEnvAttr targetAttr) {
+std::unique_ptr<SPIRVConversionTarget>
+SPIRVConversionTarget::get(spirv::TargetEnvAttr targetAttr) {
std::unique_ptr<SPIRVConversionTarget> target(
// std::make_unique does not work here because the constructor is private.
new SPIRVConversionTarget(targetAttr));
SPIRVConversionTarget *targetPtr = target.get();
- target->addDynamicallyLegalDialect<SPIRVDialect>(
+ target->addDynamicallyLegalDialect<spirv::SPIRVDialect>(
// We need to capture the raw pointer here because it is stable:
// target will be destroyed once this function is returned.
[targetPtr](Operation *op) { return targetPtr->isLegalOp(op); });
return target;
}
-spirv::SPIRVConversionTarget::SPIRVConversionTarget(
- spirv::TargetEnvAttr targetAttr)
+SPIRVConversionTarget::SPIRVConversionTarget(spirv::TargetEnvAttr targetAttr)
: ConversionTarget(*targetAttr.getContext()), targetEnv(targetAttr) {}
-bool spirv::SPIRVConversionTarget::isLegalOp(Operation *op) {
+bool SPIRVConversionTarget::isLegalOp(Operation *op) {
// Make sure this op is available at the given version. Ops not implementing
// QueryMinVersionInterface/QueryMaxVersionInterface are available to all
// SPIR-V versions.
diff --git a/mlir/test/lib/Dialect/SPIRV/TestAvailability.cpp b/mlir/test/lib/Dialect/SPIRV/TestAvailability.cpp
index 11cd05aa9bec8..d8a0d4bd70abb 100644
--- a/mlir/test/lib/Dialect/SPIRV/TestAvailability.cpp
+++ b/mlir/test/lib/Dialect/SPIRV/TestAvailability.cpp
@@ -137,7 +137,7 @@ void ConvertToTargetEnv::runOnFunction() {
return signalPassFailure();
}
- auto target = spirv::SPIRVConversionTarget::get(targetEnv);
+ auto target = SPIRVConversionTarget::get(targetEnv);
RewritePatternSet patterns(context);
patterns.add<ConvertToAtomCmpExchangeWeak, ConvertToBitReverse,