[Mlir-commits] [mlir] 1e35a76 - [mlir][spirv] Initial support for 64 bit index type and builtins
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Aug 26 15:43:16 PDT 2021
Author: Butygin
Date: 2021-08-27T01:38:53+03:00
New Revision: 1e35a7690d778d0e03add6c8ea33888d46199326
URL: https://github.com/llvm/llvm-project/commit/1e35a7690d778d0e03add6c8ea33888d46199326
DIFF: https://github.com/llvm/llvm-project/commit/1e35a7690d778d0e03add6c8ea33888d46199326.diff
LOG: [mlir][spirv] Initial support for 64 bit index type and builtins
Differential Revision: https://reviews.llvm.org/D108516
Added:
Modified:
mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h
mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
mlir/lib/Conversion/LinalgToSPIRV/LinalgToSPIRV.cpp
mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp
mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp
mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h b/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h
index a11214673cbac..21372d5b7a63e 100644
--- a/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h
+++ b/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h
@@ -49,20 +49,25 @@ class SPIRVTypeConverter : public TypeConverter {
/// values will be packed into one 32-bit value to be memory efficient.
bool emulateNon32BitScalarTypes;
+ /// Use 64-bit integers to convert index types.
+ bool use64bitIndex;
+
/// The number of bits to store a boolean value. It is eight bits by
/// default.
unsigned boolNumBits;
- // Note: we need this instead of inline initializers becuase of
+ // Note: we need this instead of inline initializers because of
// https://bugs.llvm.org/show_bug.cgi?id=36684
- Options() : emulateNon32BitScalarTypes(true), boolNumBits(8) {}
+ Options()
+ : emulateNon32BitScalarTypes(true), use64bitIndex(false),
+ boolNumBits(8) {}
};
explicit SPIRVTypeConverter(spirv::TargetEnvAttr targetAttr,
Options options = {});
/// Gets the SPIR-V correspondence for the standard index type.
- static Type getIndexType(MLIRContext *context);
+ Type getIndexType() const;
/// Returns the corresponding memory space for memref given a SPIR-V storage
/// class.
@@ -79,6 +84,8 @@ class SPIRVTypeConverter : public TypeConverter {
private:
spirv::TargetEnv targetEnv;
Options options;
+
+ MLIRContext *getContext() const;
};
//===----------------------------------------------------------------------===//
@@ -129,21 +136,23 @@ class AccessChainOp;
/// Returns the value for the given `builtin` variable. This function gets or
/// inserts the global variable associated for the builtin within the nearest
/// symbol table enclosing `op`. Returns null Value on error.
-Value getBuiltinVariableValue(Operation *op, BuiltIn builtin,
+Value getBuiltinVariableValue(Operation *op, BuiltIn builtin, Type integerType,
OpBuilder &builder);
/// Gets the value at the given `offset` of the push constant storage with a
-/// total of `elementCount` 32-bit integers. A global variable will be created
-/// in the nearest symbol table enclosing `op` for the push constant storage if
-/// not existing. Load ops will be created via the given `builder` to load
-/// values from the push constant. Returns null Value on error.
+/// total of `elementCount` `integerType` integers. A global variable will be
+/// created in the nearest symbol table enclosing `op` for the push constant
+/// storage if not existing. Load ops will be created via the given `builder` to
+/// load values from the push constant. Returns null Value on error.
Value getPushConstantValue(Operation *op, unsigned elementCount,
- unsigned offset, OpBuilder &builder);
+ unsigned offset, Type integerType,
+ OpBuilder &builder);
/// Generates IR to perform index linearization with the given `indices` and
/// their corresponding `strides`, adding an initial `offset`.
Value linearizeIndex(ValueRange indices, ArrayRef<int64_t> strides,
- int64_t offset, Location loc, OpBuilder &builder);
+ int64_t offset, Type integerType, Location loc,
+ OpBuilder &builder);
/// Performs the index computation to get to the element at `indices` of the
/// memory pointed to by `basePtr`, using the layout map of `baseType`.
diff --git a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
index 79f68e2e476f3..a303a872c55bd 100644
--- a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
+++ b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
@@ -135,10 +135,14 @@ LogicalResult LaunchConfigConversion<SourceOp, builtin>::matchAndRewrite(
if (!index)
return failure();
+ auto *typeConverter = this->template getTypeConverter<SPIRVTypeConverter>();
+ auto indexType = typeConverter->getIndexType();
+
// SPIR-V invocation builtin variables are a vector of type <3xi32>
- auto spirvBuiltin = spirv::getBuiltinVariableValue(op, builtin, rewriter);
+ auto spirvBuiltin =
+ spirv::getBuiltinVariableValue(op, builtin, indexType, rewriter);
rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(
- op, rewriter.getIntegerType(32), spirvBuiltin,
+ op, indexType, spirvBuiltin,
rewriter.getI32ArrayAttr({index.getValue()}));
return success();
}
@@ -148,7 +152,11 @@ LogicalResult
SingleDimLaunchConfigConversion<SourceOp, builtin>::matchAndRewrite(
SourceOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const {
- auto spirvBuiltin = spirv::getBuiltinVariableValue(op, builtin, rewriter);
+ auto *typeConverter = this->template getTypeConverter<SPIRVTypeConverter>();
+ auto indexType = typeConverter->getIndexType();
+
+ auto spirvBuiltin =
+ spirv::getBuiltinVariableValue(op, builtin, indexType, rewriter);
rewriter.replaceOp(op, spirvBuiltin);
return success();
}
diff --git a/mlir/lib/Conversion/LinalgToSPIRV/LinalgToSPIRV.cpp b/mlir/lib/Conversion/LinalgToSPIRV/LinalgToSPIRV.cpp
index a94435c043584..c20f4ed6b567d 100644
--- a/mlir/lib/Conversion/LinalgToSPIRV/LinalgToSPIRV.cpp
+++ b/mlir/lib/Conversion/LinalgToSPIRV/LinalgToSPIRV.cpp
@@ -26,11 +26,11 @@ using namespace mlir;
/// Returns a `Value` containing the `dim`-th dimension's size of SPIR-V
/// location invocation ID. This function will create necessary operations with
/// `builder` at the proper region containing `op`.
-static Value getLocalInvocationDimSize(Operation *op, int dim, Location loc,
- OpBuilder *builder) {
+static Value getLocalInvocationDimSize(Operation *op, int dim, Type integerType,
+ Location loc, OpBuilder *builder) {
assert(dim >= 0 && dim < 3 && "local invocation only has three dimensions");
Value invocation = spirv::getBuiltinVariableValue(
- op, spirv::BuiltIn::LocalInvocationId, *builder);
+ op, spirv::BuiltIn::LocalInvocationId, integerType, *builder);
Type xType = invocation.getType().cast<ShapedType>().getElementType();
return builder->create<spirv::CompositeExtractOp>(
loc, xType, invocation, builder->getI32ArrayAttr({dim}));
@@ -137,12 +137,15 @@ LogicalResult SingleWorkgroupReduction::matchAndRewrite(
Value convertedInput = operands[0], convertedOutput = operands[1];
Location loc = genericOp.getLoc();
+ auto *typeConverter = getTypeConverter<SPIRVTypeConverter>();
+ auto indexType = typeConverter->getIndexType();
+
// Get the invocation ID.
- Value x = getLocalInvocationDimSize(genericOp, /*dim=*/0, loc, &rewriter);
+ Value x = getLocalInvocationDimSize(genericOp, /*dim=*/0, indexType, loc,
+ &rewriter);
// TODO: Load to Workgroup storage class first.
- auto *typeConverter = getTypeConverter<SPIRVTypeConverter>();
// Get the input element accessed by this invocation.
Value inputElementPtr = spirv::getElementPtr(
@@ -164,8 +167,7 @@ LogicalResult SingleWorkgroupReduction::matchAndRewrite(
#undef CREATE_GROUP_NON_UNIFORM_BIN_OP
// Get the output element accessed by this reduction.
- Value zero = spirv::ConstantOp::getZero(
- typeConverter->getIndexType(rewriter.getContext()), loc, rewriter);
+ Value zero = spirv::ConstantOp::getZero(indexType, loc, rewriter);
SmallVector<Value, 1> zeroIndices(originalOutputType.getRank(), zero);
Value outputElementPtr =
spirv::getElementPtr(*typeConverter, originalOutputType, convertedOutput,
diff --git a/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp b/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp
index b65898e77d7fd..c59f291adc1e9 100644
--- a/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp
+++ b/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp
@@ -373,8 +373,11 @@ class TensorExtractPattern final
return failure();
}
+ auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
+ auto indexType = typeConverter.getIndexType();
+
Value index = spirv::linearizeIndex(adaptor.indices(), strides,
- /*offset=*/0, loc, rewriter);
+ /*offset=*/0, indexType, loc, rewriter);
auto acOp = rewriter.create<spirv::AccessChainOp>(loc, varOp, index);
rewriter.replaceOpWithNewOp<spirv::LoadOp>(extractOp, acOp);
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp
index ae607f249c078..20a793d880dc1 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp
@@ -178,6 +178,9 @@ LogicalResult ProcessInterfaceVarABI::matchAndRewrite(
TypeConverter::SignatureConversion signatureConverter(
funcOp.getType().getNumInputs());
+ auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
+ auto indexType = typeConverter.getIndexType();
+
auto attrName = spirv::getInterfaceVarABIAttrName();
for (auto argType : llvm::enumerate(funcOp.getType().getInputs())) {
auto abiInfo = funcOp.getArgAttrOfType<spirv::InterfaceVarABIAttr>(
@@ -206,7 +209,6 @@ LogicalResult ProcessInterfaceVarABI::matchAndRewrite(
// before the use. There might be multiple loads and currently there is no
// easy way to replace all uses with a sequence of operations.
if (argType.value().cast<spirv::SPIRVType>().isScalarOrVector()) {
- auto indexType = SPIRVTypeConverter::getIndexType(funcOp.getContext());
auto zero =
spirv::ConstantOp::getZero(indexType, funcOp.getLoc(), rewriter);
auto loadPtr = rewriter.create<spirv::AccessChainOp>(
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
index 6e807a72ac1cd..9097ef578abf1 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
@@ -112,15 +112,8 @@ wrapInStructAndGetPointer(Type elementType, spirv::StorageClass storageClass) {
// Type Conversion
//===----------------------------------------------------------------------===//
-Type SPIRVTypeConverter::getIndexType(MLIRContext *context) {
- // Convert to 32-bit integers for now. Might need a way to control this in
- // future.
- // TODO: It is probably better to make it 64-bit integers. To
- // this some support is needed in SPIR-V dialect for Conversion
- // instructions. The Vulkan spec requires the builtins like
- // GlobalInvocationID, etc. to be 32-bit (unsigned) integers which should be
- // SExtended to 64-bit for index computations.
- return IntegerType::get(context, 32);
+Type SPIRVTypeConverter::getIndexType() const {
+ return IntegerType::get(getContext(), options.use64bitIndex ? 64 : 32);
}
/// Mapping between SPIR-V storage classes to memref memory spaces.
@@ -183,6 +176,10 @@ const SPIRVTypeConverter::Options &SPIRVTypeConverter::getOptions() const {
return options;
}
+MLIRContext *SPIRVTypeConverter::getContext() const {
+ return targetEnv.getAttr().getContext();
+}
+
#undef STORAGE_SPACE_MAP_LIST
// TODO: This is a utility function that should probably be exposed by the
@@ -505,9 +502,7 @@ SPIRVTypeConverter::SPIRVTypeConverter(spirv::TargetEnvAttr targetAttr,
// want to validate and convert to be safe.
addConversion([](spirv::SPIRVType type) { return type; });
- addConversion([](IndexType indexType) {
- return SPIRVTypeConverter::getIndexType(indexType.getContext());
- });
+ addConversion([this](IndexType /*indexType*/) { return getIndexType(); });
addConversion([this](IntegerType intType) -> Optional<Type> {
if (auto scalarType = intType.dyn_cast<spirv::ScalarType>())
@@ -630,7 +625,7 @@ static std::string getBuiltinVarName(spirv::BuiltIn builtin) {
/// Gets or inserts a global variable for a builtin within `body` block.
static spirv::GlobalVariableOp
getOrInsertBuiltinVariable(Block &body, Location loc, spirv::BuiltIn builtin,
- OpBuilder &builder) {
+ Type integerType, OpBuilder &builder) {
if (auto varOp = getBuiltinVariable(body, builtin))
return varOp;
@@ -644,9 +639,8 @@ getOrInsertBuiltinVariable(Block &body, Location loc, spirv::BuiltIn builtin,
case spirv::BuiltIn::WorkgroupId:
case spirv::BuiltIn::LocalInvocationId:
case spirv::BuiltIn::GlobalInvocationId: {
- auto ptrType = spirv::PointerType::get(
- VectorType::get({3}, builder.getIntegerType(32)),
- spirv::StorageClass::Input);
+ auto ptrType = spirv::PointerType::get(VectorType::get({3}, integerType),
+ spirv::StorageClass::Input);
std::string name = getBuiltinVarName(builtin);
newVarOp =
builder.create<spirv::GlobalVariableOp>(loc, ptrType, name, builtin);
@@ -655,8 +649,8 @@ getOrInsertBuiltinVariable(Block &body, Location loc, spirv::BuiltIn builtin,
case spirv::BuiltIn::SubgroupId:
case spirv::BuiltIn::NumSubgroups:
case spirv::BuiltIn::SubgroupSize: {
- auto ptrType = spirv::PointerType::get(builder.getIntegerType(32),
- spirv::StorageClass::Input);
+ auto ptrType =
+ spirv::PointerType::get(integerType, spirv::StorageClass::Input);
std::string name = getBuiltinVarName(builtin);
newVarOp =
builder.create<spirv::GlobalVariableOp>(loc, ptrType, name, builtin);
@@ -671,6 +665,7 @@ getOrInsertBuiltinVariable(Block &body, Location loc, spirv::BuiltIn builtin,
Value mlir::spirv::getBuiltinVariableValue(Operation *op,
spirv::BuiltIn builtin,
+ Type integerType,
OpBuilder &builder) {
Operation *parent = SymbolTable::getNearestSymbolTable(op->getParentOp());
if (!parent) {
@@ -678,8 +673,9 @@ Value mlir::spirv::getBuiltinVariableValue(Operation *op,
return nullptr;
}
- spirv::GlobalVariableOp varOp = getOrInsertBuiltinVariable(
- *parent->getRegion(0).begin(), op->getLoc(), builtin, builder);
+ spirv::GlobalVariableOp varOp =
+ getOrInsertBuiltinVariable(*parent->getRegion(0).begin(), op->getLoc(),
+ builtin, integerType, builder);
Value ptr = builder.create<spirv::AddressOfOp>(op->getLoc(), varOp);
return builder.create<spirv::LoadOp>(op->getLoc(), ptr);
}
@@ -691,10 +687,10 @@ Value mlir::spirv::getBuiltinVariableValue(Operation *op,
/// Returns the pointer type for the push constant storage containing
/// `elementCount` 32-bit integer values.
static spirv::PointerType getPushConstantStorageType(unsigned elementCount,
- Builder &builder) {
- auto arrayType = spirv::ArrayType::get(
- SPIRVTypeConverter::getIndexType(builder.getContext()), elementCount,
- /*stride=*/4);
+ Builder &builder,
+ Type indexType) {
+ auto arrayType = spirv::ArrayType::get(indexType, elementCount,
+ /*stride=*/4);
auto structType = spirv::StructType::get({arrayType}, /*offsetInfo=*/0);
return spirv::PointerType::get(structType, spirv::StorageClass::PushConstant);
}
@@ -725,19 +721,21 @@ static spirv::GlobalVariableOp getPushConstantVariable(Block &body,
/// `elementCount` 32-bit integer values in `block`.
static spirv::GlobalVariableOp
getOrInsertPushConstantVariable(Location loc, Block &block,
- unsigned elementCount, OpBuilder &b) {
+ unsigned elementCount, OpBuilder &b,
+ Type indexType) {
if (auto varOp = getPushConstantVariable(block, elementCount))
return varOp;
auto builder = OpBuilder::atBlockBegin(&block, b.getListener());
- auto type = getPushConstantStorageType(elementCount, builder);
+ auto type = getPushConstantStorageType(elementCount, builder, indexType);
const char *name = "__push_constant_var__";
return builder.create<spirv::GlobalVariableOp>(loc, type, name,
/*initializer=*/nullptr);
}
Value spirv::getPushConstantValue(Operation *op, unsigned elementCount,
- unsigned offset, OpBuilder &builder) {
+ unsigned offset, Type integerType,
+ OpBuilder &builder) {
Location loc = op->getLoc();
Operation *parent = SymbolTable::getNearestSymbolTable(op->getParentOp());
if (!parent) {
@@ -746,12 +744,11 @@ Value spirv::getPushConstantValue(Operation *op, unsigned elementCount,
}
spirv::GlobalVariableOp varOp = getOrInsertPushConstantVariable(
- loc, parent->getRegion(0).front(), elementCount, builder);
+ loc, parent->getRegion(0).front(), elementCount, builder, integerType);
- auto i32Type = SPIRVTypeConverter::getIndexType(builder.getContext());
- Value zeroOp = spirv::ConstantOp::getZero(i32Type, loc, builder);
+ Value zeroOp = spirv::ConstantOp::getZero(integerType, loc, builder);
Value offsetOp = builder.create<spirv::ConstantOp>(
- loc, i32Type, builder.getI32IntegerAttr(offset));
+ loc, integerType, builder.getI32IntegerAttr(offset));
auto addrOp = builder.create<spirv::AddressOfOp>(loc, varOp);
auto acOp = builder.create<spirv::AccessChainOp>(
loc, addrOp, llvm::makeArrayRef({zeroOp, offsetOp}));
@@ -763,23 +760,22 @@ Value spirv::getPushConstantValue(Operation *op, unsigned elementCount,
//===----------------------------------------------------------------------===//
Value mlir::spirv::linearizeIndex(ValueRange indices, ArrayRef<int64_t> strides,
- int64_t offset, Location loc,
- OpBuilder &builder) {
+ int64_t offset, Type integerType,
+ Location loc, OpBuilder &builder) {
assert(indices.size() == strides.size() &&
"must provide indices for all dimensions");
- auto indexType = SPIRVTypeConverter::getIndexType(builder.getContext());
-
// TODO: Consider moving to use affine.apply and patterns converting
// affine.apply to standard ops. This needs converting to SPIR-V passes to be
// broken down into progressive small steps so we can have intermediate steps
// using other dialects. At the moment SPIR-V is the final sink.
Value linearizedIndex = builder.create<spirv::ConstantOp>(
- loc, indexType, IntegerAttr::get(indexType, offset));
+ loc, integerType, IntegerAttr::get(integerType, offset));
for (auto index : llvm::enumerate(indices)) {
Value strideVal = builder.create<spirv::ConstantOp>(
- loc, indexType, IntegerAttr::get(indexType, strides[index.index()]));
+ loc, integerType,
+ IntegerAttr::get(integerType, strides[index.index()]));
Value update = builder.create<spirv::IMulOp>(loc, strideVal, index.value());
linearizedIndex =
builder.create<spirv::IAddOp>(loc, linearizedIndex, update);
@@ -800,7 +796,7 @@ spirv::AccessChainOp mlir::spirv::getElementPtr(
return nullptr;
}
- auto indexType = typeConverter.getIndexType(builder.getContext());
+ auto indexType = typeConverter.getIndexType();
SmallVector<Value, 2> linearizedIndices;
auto zero = spirv::ConstantOp::getZero(indexType, loc, builder);
@@ -812,7 +808,7 @@ spirv::AccessChainOp mlir::spirv::getElementPtr(
linearizedIndices.push_back(zero);
} else {
linearizedIndices.push_back(
- linearizeIndex(indices, strides, offset, loc, builder));
+ linearizeIndex(indices, strides, offset, indexType, loc, builder));
}
return builder.create<spirv::AccessChainOp>(loc, basePtr, linearizedIndices);
}
More information about the Mlir-commits mailing list