[PATCH] D78974: [mlir][StandardToSPIRV] Emulate bitwidths not supported for load op.

Lei Zhang via Phabricator via llvm-commits llvm-commits at lists.llvm.org
Wed Apr 29 09:38:47 PDT 2020


antiagainst added a comment.

Awesome, thanks Hanhan for taking this on! Sorry for the many comments, but type availability in SPIR-V is quite nuanced. :)



================
Comment at: mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp:103
+// are four elements in one i32, and one element has 8 bits.
+static Value getOffsetOfInt(spirv::AccessChainOp op, int bits,
+                            ConversionPatternRewriter &rewriter) {
----------------
mravishankar wrote:
> Couple of things here
> 1) This assumes bits < 32. Probably need to assert that as well.
> 2) It would be nice to actually not specialize this to 32-bits. You could take the target integer type as an argument and the same logic should more or less hold.
Nit: s/bits/elementBits/
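To make the second point concrete, here is a plain C++ sketch of the arithmetic. It takes the target width as a parameter and names the element width `elementBits`, as the nit suggests. `EmulatedPosition` and `getEmulatedPosition` are hypothetical names. The sketch computes on ordinary integers, whereas the patch builds SPIR-V ops for the same calculation, and it takes only the index rather than the whole access chain op.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical helper (not part of the patch): where element `index` of an
// array of `elementBits`-bit integers lives once the array is emulated with
// `targetBits`-bit words.
struct EmulatedPosition {
  int64_t wordIndex;  // index into the array of targetBits-bit words
  int64_t bitOffset;  // bit offset of the element inside that word
};

EmulatedPosition getEmulatedPosition(int64_t index, int elementBits,
                                     int targetBits) {
  // Emulation only applies to narrower elements that pack evenly.
  assert(elementBits > 0 && elementBits < targetBits);
  assert(targetBits % elementBits == 0);
  assert(index >= 0);
  const int64_t elementsPerWord = targetBits / elementBits;
  return {index / elementsPerWord, (index % elementsPerWord) * elementBits};
}
```

For example, with 16-bit elements in 32-bit words, index 5 lands in word 2 at bit offset 16. The same function with `targetBits = 64` covers a 64-bit target without changes.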


================
Comment at: mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp:103
+// are four elements in one i32, and one element has 8 bits.
+static Value getOffsetOfInt(spirv::AccessChainOp op, int bits,
+                            ConversionPatternRewriter &rewriter) {
----------------
antiagainst wrote:
> mravishankar wrote:
> > Couple of things here
> > 1) This assumes bits < 32. Probably need to assert that as well.
> > 2) It would be nice to actually not specialize this to 32-bits. You could take the target integer type as an argument and the same logic should more or less hold.
> Nit: s/bits/elementBits/
Do we need to pass in the op here? I think we only need the location and the last index. That way it is clearer that this function just converts an index into 32-bit arrays into an index into `bits`-bit arrays.


================
Comment at: mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp:104
+static Value getOffsetOfInt(spirv::AccessChainOp op, int bits,
+                            ConversionPatternRewriter &rewriter) {
+  assert(32 % bits == 0);
----------------
Just use a normal OpBuilder?


================
Comment at: mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp:121
+/// elements. One element was a `bits`-bit integer. The method adjusts the last
+/// index to make it access the corresponding i32 element. Note that this only
+/// works for a scalar or 1-D tensor.
----------------
Add an assert in the function that the array is 1-D?


================
Comment at: mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp:124
+static Value convertToI32AccessChain(SPIRVTypeConverter &typeConverter,
+                                     spirv::AccessChainOp op, int bits,
+                                     ConversionPatternRewriter &rewriter) {
----------------
Nit: s/bits/elementBits/?


================
Comment at: mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp:125
+                                     spirv::AccessChainOp op, int bits,
+                                     ConversionPatternRewriter &rewriter) {
+  const auto loc = op.getLoc();
----------------
Nit: this can just be a normal OpBuilder?


================
Comment at: mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp:129
+  auto idx = rewriter.create<spirv::ConstantOp>(
+      loc, i32Type, rewriter.getI32IntegerAttr(32 / bits));
+  auto lastDim = op.getOperation()->getOperand(op.getNumOperands() - 1);
----------------
assert llvm::isPowerOf2_32(bits)?
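A side note on this assert, sketched in standalone C++. `isPowerOf2` below is a local stand-in for `llvm::isPowerOf2_32`, and the other function names are made up. Every divisor of 32 is already a power of two, so the existing `assert(32 % bits == 0)` implies the power-of-two property, and an explicit `isPowerOf2_32` assert mainly documents intent. Power-of-two widths also mean the divide/remainder in the index adjustment could be expressed as a shift and a mask. Whether that is worth doing in the lowering is a separate question from this review.

```cpp
#include <cassert>
#include <cstdint>

// Local stand-in for llvm::isPowerOf2_32 so this sketch builds without LLVM.
constexpr bool isPowerOf2(uint32_t value) {
  return value != 0 && (value & (value - 1)) == 0;
}

// Every bitwidth that evenly divides 32 is a power of two, so a
// `32 % bits == 0` check already implies isPowerOf2(bits).
constexpr bool divisorsOf32ArePowersOf2() {
  for (uint32_t bits = 1; bits <= 32; ++bits)
    if (32 % bits == 0 && !isPowerOf2(bits))
      return false;
  return true;
}
static_assert(divisorsOf32ArePowersOf2(), "divisors of 32 are powers of two");

// With a power-of-two number of elements per word, the divide and remainder
// used for index adjustment can be written as a shift and a mask.
bool shiftMaskMatchesDivMod(uint32_t index, uint32_t elementsPerWord) {
  assert(isPowerOf2(elementsPerWord));
  uint32_t shift = 0;
  while ((1u << shift) < elementsPerWord)
    ++shift;  // shift == log2(elementsPerWord)
  return (index >> shift) == index / elementsPerWord &&
         (index & (elementsPerWord - 1)) == index % elementsPerWord;
}
```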


================
Comment at: mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp:131
+  auto lastDim = op.getOperation()->getOperand(op.getNumOperands() - 1);
+  SmallVector<Value, 4> indices;
+  for (auto it : op.indices())
----------------
auto indices = llvm::to_vector<4>(op.indices())?


================
Comment at: mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp:539
+
+  int bits = memrefType.getElementType().getIntOrFloatBitWidth();
+  Type convertedType = typeConverter.convertType(memrefType.getElementType());
----------------
mravishankar wrote:
> This will assert if this is not an integer. So it might be better to have a different pattern for loads/stores when the memref has an integer element type. One pattern would implement this logic for integer loads/stores. Another pattern would be generic and type-agnostic (and would return failure for integer types so that it does not overlap with the other pattern).
+1 to having separate patterns and rejecting unhandled cases early. It's okay to implement only integers for now and add other types gradually.


================
Comment at: mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp:540
+  int bits = memrefType.getElementType().getIntOrFloatBitWidth();
+  Type convertedType = typeConverter.convertType(memrefType.getElementType());
+  int convertedBits = convertedType.getIntOrFloatBitWidth();
----------------
The type conversion must factor in the storage class, which is carried as the memref memory space. This affects the converted element type. For example, if `StorageBuffer16BitAccess` is available, then 16-bit integers in the storage buffer class (which is currently mapped to memory space `0`) do not need conversion. Considering only the element type here can give the wrong result: as long as `Int16` is not available, it would convert 16-bit integers to 32-bit. So here we should convert the whole memref type and then get its element type.
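A toy model of this point in standalone C++. It is deliberately simplified (only 16-bit integers, two storage classes, capabilities as plain strings), and none of these types or functions are MLIR's `SPIRVTypeConverter` API. It shows how converting only the element type and converting it in the context of its storage class disagree when `StorageBuffer16BitAccess` is present but `Int16` is not.

```cpp
#include <cassert>
#include <set>
#include <string>

// Toy model only: a tiny subset of SPIR-V storage classes and capabilities.
enum class StorageClass { StorageBuffer, Function };

struct TargetEnv {
  std::set<std::string> capabilities;
  bool has(const std::string &cap) const { return capabilities.count(cap) != 0; }
};

// Converting the element type alone: a 16-bit integer is kept only if the
// target has the Int16 capability; otherwise it is widened to 32 bits.
int convertScalarIntBits(int bits, const TargetEnv &env) {
  if (bits == 16 && !env.has("Int16"))
    return 32;
  return bits;
}

// Converting the element type together with its storage class: for storage
// buffers, StorageBuffer16BitAccess decides whether 16-bit elements survive.
int convertMemRefElementBits(int bits, StorageClass sc, const TargetEnv &env) {
  assert(bits > 0);
  if (bits != 16)
    return bits;
  if (sc == StorageClass::StorageBuffer)
    return env.has("StorageBuffer16BitAccess") ? 16 : 32;
  return convertScalarIntBits(bits, env);
}
```

With only `StorageBuffer16BitAccess` enabled, `convertScalarIntBits(16, env)` yields 32 while the storage-buffer memref element stays 16-bit, which is the mismatch described above.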


================
Comment at: mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp:548
+  if (!convertedType.isSignlessInteger() || (bits == convertedBits)) {
+    Value spvLoadOp =
+        rewriter.create<spirv::LoadOp>(loc, accessChainOp.getResult());
----------------
Just update `result` directly instead of creating this local variable?


================
Comment at: mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir:616
+
+// Check that non-32-bit integer types are converted to 32-bit types if the
+// corresponding capabilities are not available.
----------------
This needs to be updated:

// Check that access chain indices are properly adjusted if non-32-bit types are emulated via 32-bit types.


================
Comment at: mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir:626
+// CHECK-LABEL: @load
+func @load(%arg0: memref<i8>, %arg1: memref<10xi16>, %arg2: memref<i32>,
+           %arg3: memref<f32>) {
----------------
What about creating separate functions for each type so that we have more focused and easier-to-read tests?


================
Comment at: mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir:626
+// CHECK-LABEL: @load
+func @load(%arg0: memref<i8>, %arg1: memref<10xi16>, %arg2: memref<i32>,
+           %arg3: memref<f32>) {
----------------
antiagainst wrote:
> What about creating separate functions for each type so that we have more focused and easier-to-read tests?
We will also need tests with the `StorageBuffer16BitAccess` (etc.) capabilities.


================
Comment at: mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir:628
+           %arg3: memref<f32>) {
+  //     CHECK: spv.SDiv
+  //     CHECK: spv.AccessChain
----------------
I think we want to check the index calculation in detail for at least one of the cases here, given that it is the crucial part of the adjustment. For the others, we might be able to just check the op name.
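One way to derive the expected values for such a detailed check is a small reference model of the emulated load. The sketch below is hypothetical standalone C++, not code from the patch. It assumes elements are packed starting from the least significant bits of each 32-bit word, and it returns the zero-extended value, since how the lowering extends the loaded value is not covered in this review.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical reference model (not the patch's code): read element `index`
// of an array of `bits`-bit integers packed into 32-bit words.
// Assumes element k of a word occupies bits [k * bits, (k + 1) * bits).
uint32_t emulatedLoad(const std::vector<uint32_t> &words, int64_t index,
                      int bits) {
  assert(bits > 0 && bits < 32 && 32 % bits == 0 && index >= 0);
  const int64_t elementsPerWord = 32 / bits;
  const int64_t wordIndex = index / elementsPerWord;           // adjusted last index
  const int64_t bitOffset = (index % elementsPerWord) * bits;  // shift amount
  assert(wordIndex < static_cast<int64_t>(words.size()));
  const uint32_t mask = (1u << bits) - 1;
  // Zero-extended result; sign extension is deliberately left out.
  return (words[wordIndex] >> bitOffset) & mask;
}
```

For instance, with `bits = 16` and the words `{0x44332211, 0x88776655}`, index 3 maps to word 1 at bit offset 16 and yields `0x8877`.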


Repository:
  rG LLVM Github Monorepo

CHANGES SINCE LAST ACTION
  https://reviews.llvm.org/D78974/new/

https://reviews.llvm.org/D78974




