[Mlir-commits] [mlir] 1e35a76 - [mlir][spirv] Initial support for 64 bit index type and builtins

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Aug 26 15:43:16 PDT 2021


Author: Butygin
Date: 2021-08-27T01:38:53+03:00
New Revision: 1e35a7690d778d0e03add6c8ea33888d46199326

URL: https://github.com/llvm/llvm-project/commit/1e35a7690d778d0e03add6c8ea33888d46199326
DIFF: https://github.com/llvm/llvm-project/commit/1e35a7690d778d0e03add6c8ea33888d46199326.diff

LOG: [mlir][spirv] Initial support for 64 bit index type and builtins

Differential Revision: https://reviews.llvm.org/D108516

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h
    mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
    mlir/lib/Conversion/LinalgToSPIRV/LinalgToSPIRV.cpp
    mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp
    mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp
    mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h b/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h
index a11214673cbac..21372d5b7a63e 100644
--- a/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h
+++ b/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h
@@ -49,20 +49,25 @@ class SPIRVTypeConverter : public TypeConverter {
     /// values will be packed into one 32-bit value to be memory efficient.
     bool emulateNon32BitScalarTypes;
 
+    /// Use 64-bit integers to convert index types.
+    bool use64bitIndex;
+
     /// The number of bits to store a boolean value. It is eight bits by
     /// default.
     unsigned boolNumBits;
 
-    // Note: we need this instead of inline initializers becuase of
+    // Note: we need this instead of inline initializers because of
     // https://bugs.llvm.org/show_bug.cgi?id=36684
-    Options() : emulateNon32BitScalarTypes(true), boolNumBits(8) {}
+    Options()
+        : emulateNon32BitScalarTypes(true), use64bitIndex(false),
+          boolNumBits(8) {}
   };
 
   explicit SPIRVTypeConverter(spirv::TargetEnvAttr targetAttr,
                               Options options = {});
 
   /// Gets the SPIR-V correspondence for the standard index type.
-  static Type getIndexType(MLIRContext *context);
+  Type getIndexType() const;
 
   /// Returns the corresponding memory space for memref given a SPIR-V storage
   /// class.
@@ -79,6 +84,8 @@ class SPIRVTypeConverter : public TypeConverter {
 private:
   spirv::TargetEnv targetEnv;
   Options options;
+
+  MLIRContext *getContext() const;
 };
 
 //===----------------------------------------------------------------------===//
@@ -129,21 +136,23 @@ class AccessChainOp;
 /// Returns the value for the given `builtin` variable. This function gets or
 /// inserts the global variable associated for the builtin within the nearest
 /// symbol table enclosing `op`. Returns null Value on error.
-Value getBuiltinVariableValue(Operation *op, BuiltIn builtin,
+Value getBuiltinVariableValue(Operation *op, BuiltIn builtin, Type integerType,
                               OpBuilder &builder);
 
 /// Gets the value at the given `offset` of the push constant storage with a
-/// total of `elementCount` 32-bit integers. A global variable will be created
-/// in the nearest symbol table enclosing `op` for the push constant storage if
-/// not existing. Load ops will be created via the given `builder` to load
-/// values from the push constant. Returns null Value on error.
+/// total of `elementCount` `integerType` integers. A global variable will be
+/// created in the nearest symbol table enclosing `op` for the push constant
+/// storage if not existing. Load ops will be created via the given `builder` to
+/// load values from the push constant. Returns null Value on error.
 Value getPushConstantValue(Operation *op, unsigned elementCount,
-                           unsigned offset, OpBuilder &builder);
+                           unsigned offset, Type integerType,
+                           OpBuilder &builder);
 
 /// Generates IR to perform index linearization with the given `indices` and
 /// their corresponding `strides`, adding an initial `offset`.
 Value linearizeIndex(ValueRange indices, ArrayRef<int64_t> strides,
-                     int64_t offset, Location loc, OpBuilder &builder);
+                     int64_t offset, Type integerType, Location loc,
+                     OpBuilder &builder);
 
 /// Performs the index computation to get to the element at `indices` of the
 /// memory pointed to by `basePtr`, using the layout map of `baseType`.

diff  --git a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
index 79f68e2e476f3..a303a872c55bd 100644
--- a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
+++ b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
@@ -135,10 +135,14 @@ LogicalResult LaunchConfigConversion<SourceOp, builtin>::matchAndRewrite(
   if (!index)
     return failure();
 
+  auto *typeConverter = this->template getTypeConverter<SPIRVTypeConverter>();
+  auto indexType = typeConverter->getIndexType();
+
   // SPIR-V invocation builtin variables are a vector of type <3xi32>
-  auto spirvBuiltin = spirv::getBuiltinVariableValue(op, builtin, rewriter);
+  auto spirvBuiltin =
+      spirv::getBuiltinVariableValue(op, builtin, indexType, rewriter);
   rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(
-      op, rewriter.getIntegerType(32), spirvBuiltin,
+      op, indexType, spirvBuiltin,
       rewriter.getI32ArrayAttr({index.getValue()}));
   return success();
 }
@@ -148,7 +152,11 @@ LogicalResult
 SingleDimLaunchConfigConversion<SourceOp, builtin>::matchAndRewrite(
     SourceOp op, ArrayRef<Value> operands,
     ConversionPatternRewriter &rewriter) const {
-  auto spirvBuiltin = spirv::getBuiltinVariableValue(op, builtin, rewriter);
+  auto *typeConverter = this->template getTypeConverter<SPIRVTypeConverter>();
+  auto indexType = typeConverter->getIndexType();
+
+  auto spirvBuiltin =
+      spirv::getBuiltinVariableValue(op, builtin, indexType, rewriter);
   rewriter.replaceOp(op, spirvBuiltin);
   return success();
 }

diff  --git a/mlir/lib/Conversion/LinalgToSPIRV/LinalgToSPIRV.cpp b/mlir/lib/Conversion/LinalgToSPIRV/LinalgToSPIRV.cpp
index a94435c043584..c20f4ed6b567d 100644
--- a/mlir/lib/Conversion/LinalgToSPIRV/LinalgToSPIRV.cpp
+++ b/mlir/lib/Conversion/LinalgToSPIRV/LinalgToSPIRV.cpp
@@ -26,11 +26,11 @@ using namespace mlir;
 /// Returns a `Value` containing the `dim`-th dimension's size of SPIR-V
 /// location invocation ID. This function will create necessary operations with
 /// `builder` at the proper region containing `op`.
-static Value getLocalInvocationDimSize(Operation *op, int dim, Location loc,
-                                       OpBuilder *builder) {
+static Value getLocalInvocationDimSize(Operation *op, int dim, Type integerType,
+                                       Location loc, OpBuilder *builder) {
   assert(dim >= 0 && dim < 3 && "local invocation only has three dimensions");
   Value invocation = spirv::getBuiltinVariableValue(
-      op, spirv::BuiltIn::LocalInvocationId, *builder);
+      op, spirv::BuiltIn::LocalInvocationId, integerType, *builder);
   Type xType = invocation.getType().cast<ShapedType>().getElementType();
   return builder->create<spirv::CompositeExtractOp>(
       loc, xType, invocation, builder->getI32ArrayAttr({dim}));
@@ -137,12 +137,15 @@ LogicalResult SingleWorkgroupReduction::matchAndRewrite(
   Value convertedInput = operands[0], convertedOutput = operands[1];
   Location loc = genericOp.getLoc();
 
+  auto *typeConverter = getTypeConverter<SPIRVTypeConverter>();
+  auto indexType = typeConverter->getIndexType();
+
   // Get the invocation ID.
-  Value x = getLocalInvocationDimSize(genericOp, /*dim=*/0, loc, &rewriter);
+  Value x = getLocalInvocationDimSize(genericOp, /*dim=*/0, indexType, loc,
+                                      &rewriter);
 
   // TODO: Load to Workgroup storage class first.
 
-  auto *typeConverter = getTypeConverter<SPIRVTypeConverter>();
 
   // Get the input element accessed by this invocation.
   Value inputElementPtr = spirv::getElementPtr(
@@ -164,8 +167,7 @@ LogicalResult SingleWorkgroupReduction::matchAndRewrite(
 #undef CREATE_GROUP_NON_UNIFORM_BIN_OP
 
   // Get the output element accessed by this reduction.
-  Value zero = spirv::ConstantOp::getZero(
-      typeConverter->getIndexType(rewriter.getContext()), loc, rewriter);
+  Value zero = spirv::ConstantOp::getZero(indexType, loc, rewriter);
   SmallVector<Value, 1> zeroIndices(originalOutputType.getRank(), zero);
   Value outputElementPtr =
       spirv::getElementPtr(*typeConverter, originalOutputType, convertedOutput,

diff  --git a/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp b/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp
index b65898e77d7fd..c59f291adc1e9 100644
--- a/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp
+++ b/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp
@@ -373,8 +373,11 @@ class TensorExtractPattern final
       return failure();
     }
 
+    auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
+    auto indexType = typeConverter.getIndexType();
+
     Value index = spirv::linearizeIndex(adaptor.indices(), strides,
-                                        /*offset=*/0, loc, rewriter);
+                                        /*offset=*/0, indexType, loc, rewriter);
     auto acOp = rewriter.create<spirv::AccessChainOp>(loc, varOp, index);
 
     rewriter.replaceOpWithNewOp<spirv::LoadOp>(extractOp, acOp);

diff  --git a/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp
index ae607f249c078..20a793d880dc1 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp
@@ -178,6 +178,9 @@ LogicalResult ProcessInterfaceVarABI::matchAndRewrite(
   TypeConverter::SignatureConversion signatureConverter(
       funcOp.getType().getNumInputs());
 
+  auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
+  auto indexType = typeConverter.getIndexType();
+
   auto attrName = spirv::getInterfaceVarABIAttrName();
   for (auto argType : llvm::enumerate(funcOp.getType().getInputs())) {
     auto abiInfo = funcOp.getArgAttrOfType<spirv::InterfaceVarABIAttr>(
@@ -206,7 +209,6 @@ LogicalResult ProcessInterfaceVarABI::matchAndRewrite(
     // before the use. There might be multiple loads and currently there is no
     // easy way to replace all uses with a sequence of operations.
     if (argType.value().cast<spirv::SPIRVType>().isScalarOrVector()) {
-      auto indexType = SPIRVTypeConverter::getIndexType(funcOp.getContext());
       auto zero =
           spirv::ConstantOp::getZero(indexType, funcOp.getLoc(), rewriter);
       auto loadPtr = rewriter.create<spirv::AccessChainOp>(

diff  --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
index 6e807a72ac1cd..9097ef578abf1 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
@@ -112,15 +112,8 @@ wrapInStructAndGetPointer(Type elementType, spirv::StorageClass storageClass) {
 // Type Conversion
 //===----------------------------------------------------------------------===//
 
-Type SPIRVTypeConverter::getIndexType(MLIRContext *context) {
-  // Convert to 32-bit integers for now. Might need a way to control this in
-  // future.
-  // TODO: It is probably better to make it 64-bit integers. To
-  // this some support is needed in SPIR-V dialect for Conversion
-  // instructions. The Vulkan spec requires the builtins like
-  // GlobalInvocationID, etc. to be 32-bit (unsigned) integers which should be
-  // SExtended to 64-bit for index computations.
-  return IntegerType::get(context, 32);
+Type SPIRVTypeConverter::getIndexType() const {
+  return IntegerType::get(getContext(), options.use64bitIndex ? 64 : 32);
 }
 
 /// Mapping between SPIR-V storage classes to memref memory spaces.
@@ -183,6 +176,10 @@ const SPIRVTypeConverter::Options &SPIRVTypeConverter::getOptions() const {
   return options;
 }
 
+MLIRContext *SPIRVTypeConverter::getContext() const {
+  return targetEnv.getAttr().getContext();
+}
+
 #undef STORAGE_SPACE_MAP_LIST
 
 // TODO: This is a utility function that should probably be exposed by the
@@ -505,9 +502,7 @@ SPIRVTypeConverter::SPIRVTypeConverter(spirv::TargetEnvAttr targetAttr,
   // want to validate and convert to be safe.
   addConversion([](spirv::SPIRVType type) { return type; });
 
-  addConversion([](IndexType indexType) {
-    return SPIRVTypeConverter::getIndexType(indexType.getContext());
-  });
+  addConversion([this](IndexType /*indexType*/) { return getIndexType(); });
 
   addConversion([this](IntegerType intType) -> Optional<Type> {
     if (auto scalarType = intType.dyn_cast<spirv::ScalarType>())
@@ -630,7 +625,7 @@ static std::string getBuiltinVarName(spirv::BuiltIn builtin) {
 /// Gets or inserts a global variable for a builtin within `body` block.
 static spirv::GlobalVariableOp
 getOrInsertBuiltinVariable(Block &body, Location loc, spirv::BuiltIn builtin,
-                           OpBuilder &builder) {
+                           Type integerType, OpBuilder &builder) {
   if (auto varOp = getBuiltinVariable(body, builtin))
     return varOp;
 
@@ -644,9 +639,8 @@ getOrInsertBuiltinVariable(Block &body, Location loc, spirv::BuiltIn builtin,
   case spirv::BuiltIn::WorkgroupId:
   case spirv::BuiltIn::LocalInvocationId:
   case spirv::BuiltIn::GlobalInvocationId: {
-    auto ptrType = spirv::PointerType::get(
-        VectorType::get({3}, builder.getIntegerType(32)),
-        spirv::StorageClass::Input);
+    auto ptrType = spirv::PointerType::get(VectorType::get({3}, integerType),
+                                           spirv::StorageClass::Input);
     std::string name = getBuiltinVarName(builtin);
     newVarOp =
         builder.create<spirv::GlobalVariableOp>(loc, ptrType, name, builtin);
@@ -655,8 +649,8 @@ getOrInsertBuiltinVariable(Block &body, Location loc, spirv::BuiltIn builtin,
   case spirv::BuiltIn::SubgroupId:
   case spirv::BuiltIn::NumSubgroups:
   case spirv::BuiltIn::SubgroupSize: {
-    auto ptrType = spirv::PointerType::get(builder.getIntegerType(32),
-                                           spirv::StorageClass::Input);
+    auto ptrType =
+        spirv::PointerType::get(integerType, spirv::StorageClass::Input);
     std::string name = getBuiltinVarName(builtin);
     newVarOp =
         builder.create<spirv::GlobalVariableOp>(loc, ptrType, name, builtin);
@@ -671,6 +665,7 @@ getOrInsertBuiltinVariable(Block &body, Location loc, spirv::BuiltIn builtin,
 
 Value mlir::spirv::getBuiltinVariableValue(Operation *op,
                                            spirv::BuiltIn builtin,
+                                           Type integerType,
                                            OpBuilder &builder) {
   Operation *parent = SymbolTable::getNearestSymbolTable(op->getParentOp());
   if (!parent) {
@@ -678,8 +673,9 @@ Value mlir::spirv::getBuiltinVariableValue(Operation *op,
     return nullptr;
   }
 
-  spirv::GlobalVariableOp varOp = getOrInsertBuiltinVariable(
-      *parent->getRegion(0).begin(), op->getLoc(), builtin, builder);
+  spirv::GlobalVariableOp varOp =
+      getOrInsertBuiltinVariable(*parent->getRegion(0).begin(), op->getLoc(),
+                                 builtin, integerType, builder);
   Value ptr = builder.create<spirv::AddressOfOp>(op->getLoc(), varOp);
   return builder.create<spirv::LoadOp>(op->getLoc(), ptr);
 }
@@ -691,10 +687,10 @@ Value mlir::spirv::getBuiltinVariableValue(Operation *op,
 /// Returns the pointer type for the push constant storage containing
 /// `elementCount` 32-bit integer values.
 static spirv::PointerType getPushConstantStorageType(unsigned elementCount,
-                                                     Builder &builder) {
-  auto arrayType = spirv::ArrayType::get(
-      SPIRVTypeConverter::getIndexType(builder.getContext()), elementCount,
-      /*stride=*/4);
+                                                     Builder &builder,
+                                                     Type indexType) {
+  unsigned stride = indexType.getIntOrFloatBitWidth() / 8; // bytes per element
+  auto arrayType = spirv::ArrayType::get(indexType, elementCount, stride);
   auto structType = spirv::StructType::get({arrayType}, /*offsetInfo=*/0);
   return spirv::PointerType::get(structType, spirv::StorageClass::PushConstant);
 }
@@ -725,19 +721,21 @@ static spirv::GlobalVariableOp getPushConstantVariable(Block &body,
 /// `elementCount` 32-bit integer values in `block`.
 static spirv::GlobalVariableOp
 getOrInsertPushConstantVariable(Location loc, Block &block,
-                                unsigned elementCount, OpBuilder &b) {
+                                unsigned elementCount, OpBuilder &b,
+                                Type indexType) {
   if (auto varOp = getPushConstantVariable(block, elementCount))
     return varOp;
 
   auto builder = OpBuilder::atBlockBegin(&block, b.getListener());
-  auto type = getPushConstantStorageType(elementCount, builder);
+  auto type = getPushConstantStorageType(elementCount, builder, indexType);
   const char *name = "__push_constant_var__";
   return builder.create<spirv::GlobalVariableOp>(loc, type, name,
                                                  /*initializer=*/nullptr);
 }
 
 Value spirv::getPushConstantValue(Operation *op, unsigned elementCount,
-                                  unsigned offset, OpBuilder &builder) {
+                                  unsigned offset, Type integerType,
+                                  OpBuilder &builder) {
   Location loc = op->getLoc();
   Operation *parent = SymbolTable::getNearestSymbolTable(op->getParentOp());
   if (!parent) {
@@ -746,12 +744,11 @@ Value spirv::getPushConstantValue(Operation *op, unsigned elementCount,
   }
 
   spirv::GlobalVariableOp varOp = getOrInsertPushConstantVariable(
-      loc, parent->getRegion(0).front(), elementCount, builder);
+      loc, parent->getRegion(0).front(), elementCount, builder, integerType);
 
-  auto i32Type = SPIRVTypeConverter::getIndexType(builder.getContext());
-  Value zeroOp = spirv::ConstantOp::getZero(i32Type, loc, builder);
+  Value zeroOp = spirv::ConstantOp::getZero(integerType, loc, builder);
   Value offsetOp = builder.create<spirv::ConstantOp>(
-      loc, i32Type, builder.getI32IntegerAttr(offset));
+      loc, integerType, IntegerAttr::get(integerType, offset));
   auto addrOp = builder.create<spirv::AddressOfOp>(loc, varOp);
   auto acOp = builder.create<spirv::AccessChainOp>(
       loc, addrOp, llvm::makeArrayRef({zeroOp, offsetOp}));
@@ -763,23 +760,22 @@ Value spirv::getPushConstantValue(Operation *op, unsigned elementCount,
 //===----------------------------------------------------------------------===//
 
 Value mlir::spirv::linearizeIndex(ValueRange indices, ArrayRef<int64_t> strides,
-                                  int64_t offset, Location loc,
-                                  OpBuilder &builder) {
+                                  int64_t offset, Type integerType,
+                                  Location loc, OpBuilder &builder) {
   assert(indices.size() == strides.size() &&
          "must provide indices for all dimensions");
 
-  auto indexType = SPIRVTypeConverter::getIndexType(builder.getContext());
-
   // TODO: Consider moving to use affine.apply and patterns converting
   // affine.apply to standard ops. This needs converting to SPIR-V passes to be
   // broken down into progressive small steps so we can have intermediate steps
   // using other dialects. At the moment SPIR-V is the final sink.
 
   Value linearizedIndex = builder.create<spirv::ConstantOp>(
-      loc, indexType, IntegerAttr::get(indexType, offset));
+      loc, integerType, IntegerAttr::get(integerType, offset));
   for (auto index : llvm::enumerate(indices)) {
     Value strideVal = builder.create<spirv::ConstantOp>(
-        loc, indexType, IntegerAttr::get(indexType, strides[index.index()]));
+        loc, integerType,
+        IntegerAttr::get(integerType, strides[index.index()]));
     Value update = builder.create<spirv::IMulOp>(loc, strideVal, index.value());
     linearizedIndex =
         builder.create<spirv::IAddOp>(loc, linearizedIndex, update);
@@ -800,7 +796,7 @@ spirv::AccessChainOp mlir::spirv::getElementPtr(
     return nullptr;
   }
 
-  auto indexType = typeConverter.getIndexType(builder.getContext());
+  auto indexType = typeConverter.getIndexType();
 
   SmallVector<Value, 2> linearizedIndices;
   auto zero = spirv::ConstantOp::getZero(indexType, loc, builder);
@@ -812,7 +808,7 @@ spirv::AccessChainOp mlir::spirv::getElementPtr(
     linearizedIndices.push_back(zero);
   } else {
     linearizedIndices.push_back(
-        linearizeIndex(indices, strides, offset, loc, builder));
+        linearizeIndex(indices, strides, offset, indexType, loc, builder));
   }
   return builder.create<spirv::AccessChainOp>(loc, basePtr, linearizedIndices);
 }


        


More information about the Mlir-commits mailing list