[Mlir-commits] [mlir] 6601b65 - [mlir][StandardToSPIRV] Emulate bitwidths not supported for load op.

Hanhan Wang llvmlistbot at llvm.org
Thu Apr 30 19:28:41 PDT 2020


Author: Hanhan Wang
Date: 2020-04-30T19:27:45-07:00
New Revision: 6601b65aedd093cb62549ff9e8f39872d0b55499

URL: https://github.com/llvm/llvm-project/commit/6601b65aedd093cb62549ff9e8f39872d0b55499
DIFF: https://github.com/llvm/llvm-project/commit/6601b65aedd093cb62549ff9e8f39872d0b55499.diff

LOG: [mlir][StandardToSPIRV] Emulate bitwidths not supported for load op.

Summary:
The current implementation in SPIRVTypeConverter unconditionally converts a type
to 32 bits if its bitwidth is not supported by the available extensions or
capabilities. In that case, a narrower integer can be loaded by reading the
containing 32-bit value and extracting the relevant bits.

Differential Revision: https://reviews.llvm.org/D78974

Added: 
    

Modified: 
    mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp
    mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp b/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp
index 12b22cacdee2..16fea677b51e 100644
--- a/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp
+++ b/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp
@@ -97,6 +97,55 @@ static FloatAttr convertFloatAttr(FloatAttr srcAttr, FloatType dstType,
   return builder.getF32FloatAttr(dstVal.convertToFloat());
 }
 
+/// Returns the bit offset, within a `targetBits`-wide word, of the element at
+/// `srcIdx` in a 1-D array whose elements are `sourceBits` wide. When such an
+/// array is accessed as if its elements were `targetBits` wide, each loaded
+/// word packs `targetBits / sourceBits` source elements; this computes where
+/// element `srcIdx` sits inside that word. For example, with `sourceBits` = 8
+/// and `targetBits` = 32, element x is at bit offset (x % 4) * 8, because one
+/// i32 holds four 8-bit elements.
+///
+/// The arithmetic is performed in the type of `srcIdx`: index values keep
+/// their own integer type regardless of the element bitwidths involved.
+static Value getOffsetForBitwidth(Location loc, Value srcIdx, int sourceBits,
+                                  int targetBits, OpBuilder &builder) {
+  assert(targetBits % sourceBits == 0);
+  Type type = srcIdx.getType();
+  IntegerAttr idxAttr = builder.getIntegerAttr(type, targetBits / sourceBits);
+  auto idx = builder.create<spirv::ConstantOp>(loc, type, idxAttr);
+  IntegerAttr srcBitsAttr = builder.getIntegerAttr(type, sourceBits);
+  auto srcBitsValue =
+      builder.create<spirv::ConstantOp>(loc, type, srcBitsAttr);
+  // (srcIdx mod elementsPerWord) * sourceBits
+  auto m = builder.create<spirv::SModOp>(loc, srcIdx, idx);
+  return builder.create<spirv::IMulOp>(loc, type, m, srcBitsValue);
+}
+
+/// Returns a spirv::AccessChainOp adjusted for bitwidth emulation. Depending
+/// on the available extensions/capabilities, the integer bitwidth `sourceBits`
+/// might not be supported. A memref of such a type is converted to use a
+/// supported wider bitwidth `targetBits`, so loads/stores to it must access
+/// the containing `targetBits`-wide word and then extract the required bits.
+/// For an access into a 1-D array (spv.array or spv.rt_array), the last index
+/// is divided by `targetBits / sourceBits` to address that word. Extracting
+/// the actual bits is handled separately. Note that this only works for a
+/// linearized access chain, i.e. one with exactly two indices.
+static Value adjustAccessChainForBitwidth(SPIRVTypeConverter &typeConverter,
+                                          spirv::AccessChainOp op,
+                                          int sourceBits, int targetBits,
+                                          OpBuilder &builder) {
+  assert(targetBits % sourceBits == 0);
+  const auto loc = op.getLoc();
+  auto indices = llvm::to_vector<4>(op.indices());
+  // A linearized access has two indices: the struct member and the element.
+  assert(indices.size() == 2);
+  Value lastDim = indices.back();
+  // Do the index arithmetic in the index's own type; an integer type of
+  // `targetBits` (the element storage width) need not match it.
+  Type type = lastDim.getType();
+  IntegerAttr attr = builder.getIntegerAttr(type, targetBits / sourceBits);
+  auto idx = builder.create<spirv::ConstantOp>(loc, type, attr);
+  indices.back() = builder.create<spirv::SDivOp>(loc, lastDim, idx);
+  Type t = typeConverter.convertType(op.component_ptr().getType());
+  return builder.create<spirv::AccessChainOp>(loc, t, op.base_ptr(), indices);
+}
+
 //===----------------------------------------------------------------------===//
 // Operation conversion
 //===----------------------------------------------------------------------===//
@@ -204,6 +253,16 @@ class CmpIOpPattern final : public SPIRVOpLowering<CmpIOp> {
                   ConversionPatternRewriter &rewriter) const override;
 };
 
+/// Converts std.load on a memref with a signless integer element type to
+/// spv.Load. When the converted storage uses a wider integer type than the
+/// memref's element type, the containing wider word is loaded and the
+/// requested bits are shifted down and masked out. Non-integer element types
+/// are left to LoadOpPattern.
+class IntLoadOpPattern final : public SPIRVOpLowering<LoadOp> {
+public:
+  using SPIRVOpLowering<LoadOp>::SPIRVOpLowering;
+
+  LogicalResult
+  matchAndRewrite(LoadOp loadOp, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override;
+};
+
+
 /// Converts std.load to spv.Load.
 class LoadOpPattern final : public SPIRVOpLowering<LoadOp> {
 public:
@@ -528,13 +587,79 @@ CmpIOpPattern::matchAndRewrite(CmpIOp cmpIOp, ArrayRef<Value> operands,
 // LoadOp
 //===----------------------------------------------------------------------===//
 
+LogicalResult
+IntLoadOpPattern::matchAndRewrite(LoadOp loadOp, ArrayRef<Value> operands,
+                                  ConversionPatternRewriter &rewriter) const {
+  LoadOpOperandAdaptor loadOperands(operands);
+  auto loc = loadOp.getLoc();
+  auto memrefType = loadOp.memref().getType().cast<MemRefType>();
+  if (!memrefType.getElementType().isSignlessInteger())
+    return failure();
+
+  // Work out the converted storage element type before creating any IR, so
+  // that bailing out leaves the IR untouched. The converted memref is a
+  // pointer to a struct wrapping either a fixed-size array (spv.array, static
+  // shapes) or a runtime-sized array (spv.rt_array, dynamic shapes).
+  Type convertedType = typeConverter.convertType(memrefType);
+  if (!convertedType)
+    return failure();
+  Type storageType = convertedType.cast<spirv::PointerType>()
+                         .getPointeeType()
+                         .cast<spirv::StructType>()
+                         .getElementType(0);
+  Type dstType;
+  if (auto arrayType = storageType.dyn_cast<spirv::ArrayType>())
+    dstType = arrayType.getElementType();
+  else if (auto rtArrayType = storageType.dyn_cast<spirv::RuntimeArrayType>())
+    dstType = rtArrayType.getElementType();
+  else
+    return failure();
+
+  int srcBits = memrefType.getElementType().getIntOrFloatBitWidth();
+  int dstBits = dstType.getIntOrFloatBitWidth();
+  assert(dstBits % srcBits == 0 &&
+         "converted bitwidth must be a multiple of the source bitwidth");
+
+  spirv::AccessChainOp accessChainOp =
+      spirv::getElementPtr(typeConverter, memrefType, loadOperands.memref(),
+                           loadOperands.indices(), loc, rewriter);
+
+  // The converted element has the same bitwidth: load the value directly.
+  if (srcBits == dstBits) {
+    rewriter.replaceOpWithNewOp<spirv::LoadOp>(loadOp,
+                                               accessChainOp.getResult());
+    return success();
+  }
+
+  // Assume getElementPtr() linearizes the indices: even for a scalar memref
+  // it returns a two-index access chain. If the access were not linearized,
+  // the computed word index and bit offset would be wrong.
+  assert(accessChainOp.indices().size() == 2 &&
+         "expected a linearized two-index access chain");
+  Value adjustedPtr = adjustAccessChainForBitwidth(typeConverter, accessChainOp,
+                                                   srcBits, dstBits, rewriter);
+  Value spvLoadOp = rewriter.create<spirv::LoadOp>(
+      loc, dstType, adjustedPtr,
+      loadOp.getAttrOfType<IntegerAttr>(
+          spirv::attributeName<spirv::MemoryAccess>()),
+      loadOp.getAttrOfType<IntegerAttr>("alignment"));
+
+  // Shift the requested element's bits down to the least significant end.
+  // ____XXXX________ -> ____________XXXX
+  Value lastDim = accessChainOp.getOperation()->getOperand(
+      accessChainOp.getNumOperands() - 1);
+  Value offset = getOffsetForBitwidth(loc, lastDim, srcBits, dstBits, rewriter);
+  Value result = rewriter.create<spirv::ShiftRightArithmeticOp>(
+      loc, spvLoadOp.getType(), spvLoadOp, offset);
+
+  // Mask off everything above the low `srcBits` bits. The mask is computed in
+  // 64-bit unsigned arithmetic because `1 << srcBits` on an `int` overflows
+  // once srcBits reaches 32 (e.g. when emulating i32 via i64).
+  uint64_t maskBits = (uint64_t(1) << srcBits) - 1;
+  Value mask = rewriter.create<spirv::ConstantOp>(
+      loc, dstType,
+      rewriter.getIntegerAttr(dstType, static_cast<int64_t>(maskBits)));
+  result = rewriter.create<spirv::BitwiseAndOp>(loc, dstType, result, mask);
+  rewriter.replaceOp(loadOp, result);
+
+  // The original access chain only served to derive the adjusted one.
+  assert(accessChainOp.use_empty());
+  rewriter.eraseOp(accessChainOp);
+
+  return success();
+}
+
 LogicalResult
 LoadOpPattern::matchAndRewrite(LoadOp loadOp, ArrayRef<Value> operands,
                                ConversionPatternRewriter &rewriter) const {
   LoadOpOperandAdaptor loadOperands(operands);
-  auto loadPtr = spirv::getElementPtr(
-      typeConverter, loadOp.memref().getType().cast<MemRefType>(),
-      loadOperands.memref(), loadOperands.indices(), loadOp.getLoc(), rewriter);
+  auto memrefType = loadOp.memref().getType().cast<MemRefType>();
+  // Signless integer element types are lowered by IntLoadOpPattern, which may
+  // need to emulate unsupported bitwidths; leave them to that pattern.
+  if (memrefType.getElementType().isSignlessInteger())
+    return failure();
+  auto loadPtr =
+      spirv::getElementPtr(typeConverter, memrefType, loadOperands.memref(),
+                           loadOperands.indices(), loadOp.getLoc(), rewriter);
   rewriter.replaceOpWithNewOp<spirv::LoadOp>(loadOp, loadPtr);
   return success();
 }
@@ -642,8 +767,8 @@ void populateStandardToSPIRVPatterns(MLIRContext *context,
       BitwiseOpPattern<AndOp, spirv::LogicalAndOp, spirv::BitwiseAndOp>,
       BitwiseOpPattern<OrOp, spirv::LogicalOrOp, spirv::BitwiseOrOp>,
       BoolCmpIOpPattern, ConstantCompositeOpPattern, ConstantScalarOpPattern,
-      CmpFOpPattern, CmpIOpPattern, LoadOpPattern, ReturnOpPattern,
-      SelectOpPattern, StoreOpPattern,
+      CmpFOpPattern, CmpIOpPattern, IntLoadOpPattern, LoadOpPattern,
+      ReturnOpPattern, SelectOpPattern, StoreOpPattern,
       TypeCastingOpPattern<SIToFPOp, spirv::ConvertSToFOp>,
       TypeCastingOpPattern<FPExtOp, spirv::FConvertOp>,
       TypeCastingOpPattern<FPTruncOp, spirv::FConvertOp>, XOrOpPattern>(

diff  --git a/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir b/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir
index e7ad95a1a173..7351e62b19a2 100644
--- a/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir
+++ b/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir
@@ -619,3 +619,115 @@ func @load_store_zero_rank_int(%arg0: memref<i32>, %arg1: memref<i32>) {
 }
 
 } // end module
+
+// -----
+
+// Check that access chain indices are properly adjusted if non-32-bit types are
+// emulated via 32-bit types.
+// TODO: Test i64 type.
+module attributes {
+  spv.target_env = #spv.target_env<
+    #spv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]>,
+    {max_compute_workgroup_invocations = 128 : i32,
+     max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>
+} {
+
+// CHECK-LABEL: @load_i8
+func @load_i8(%arg0: memref<i8>) {
+  //     CHECK: %[[ZERO:.+]] = spv.constant 0 : i32
+  //     CHECK: %[[FOUR1:.+]] = spv.constant 4 : i32
+  //     CHECK: %[[QUOTIENT:.+]] = spv.SDiv %[[ZERO]], %[[FOUR1]] : i32
+  //     CHECK: %[[PTR:.+]] = spv.AccessChain %{{.+}}[%[[ZERO]], %[[QUOTIENT]]]
+  //     CHECK: %[[LOAD:.+]] = spv.Load  "StorageBuffer" %[[PTR]]
+  //     CHECK: %[[FOUR2:.+]] = spv.constant 4 : i32
+  //     CHECK: %[[EIGHT:.+]] = spv.constant 8 : i32
+  //     CHECK: %[[IDX:.+]] = spv.SMod %[[ZERO]], %[[FOUR2]] : i32
+  //     CHECK: %[[BITS:.+]] = spv.IMul %[[IDX]], %[[EIGHT]] : i32
+  //     CHECK: %[[VALUE:.+]] = spv.ShiftRightArithmetic %[[LOAD]], %[[BITS]] : i32, i32
+  //     CHECK: %[[MASK:.+]] = spv.constant 255 : i32
+  //     CHECK: spv.BitwiseAnd %[[VALUE]], %[[MASK]] : i32
+  %0 = load %arg0[] : memref<i8>
+  return
+}
+
+// CHECK-LABEL: @load_i16
+//       CHECK: (%[[ARG0:.+]]: {{.*}}, %[[ARG1:.+]]: i32)
+func @load_i16(%arg0: memref<10xi16>, %index : index) {
+  //     CHECK: %[[ONE:.+]] = spv.constant 1 : i32
+  //     CHECK: %[[FLAT_IDX:.+]] = spv.IMul %[[ONE]], %[[ARG1]] : i32
+  //     CHECK: %[[ZERO:.+]] = spv.constant 0 : i32
+  //     CHECK: %[[TWO1:.+]] = spv.constant 2 : i32
+  //     CHECK: %[[QUOTIENT:.+]] = spv.SDiv %[[FLAT_IDX]], %[[TWO1]] : i32
+  //     CHECK: %[[PTR:.+]] = spv.AccessChain %{{.+}}[%[[ZERO]], %[[QUOTIENT]]]
+  //     CHECK: %[[LOAD:.+]] = spv.Load  "StorageBuffer" %[[PTR]]
+  //     CHECK: %[[TWO2:.+]] = spv.constant 2 : i32
+  //     CHECK: %[[SIXTEEN:.+]] = spv.constant 16 : i32
+  //     CHECK: %[[IDX:.+]] = spv.SMod %[[FLAT_IDX]], %[[TWO2]] : i32
+  //     CHECK: %[[BITS:.+]] = spv.IMul %[[IDX]], %[[SIXTEEN]] : i32
+  //     CHECK: %[[VALUE:.+]] = spv.ShiftRightArithmetic %[[LOAD]], %[[BITS]] : i32, i32
+  //     CHECK: %[[MASK:.+]] = spv.constant 65535 : i32
+  //     CHECK: spv.BitwiseAnd %[[VALUE]], %[[MASK]] : i32
+  %0 = load %arg0[%index] : memref<10xi16>
+  return
+}
+
+// CHECK-LABEL: @load_i32
+func @load_i32(%arg0: memref<i32>) {
+  // CHECK-NOT: spv.SDiv
+  //     CHECK: spv.Load
+  // CHECK-NOT: spv.ShiftRightArithmetic
+  %0 = load %arg0[] : memref<i32>
+  return
+}
+
+// CHECK-LABEL: @load_f32
+func @load_f32(%arg0: memref<f32>) {
+  // CHECK-NOT: spv.SDiv
+  //     CHECK: spv.Load
+  // CHECK-NOT: spv.ShiftRightArithmetic
+  %0 = load %arg0[] : memref<f32>
+  return
+}
+
+} // end module
+
+// -----
+
+// Check that access chain indices are properly adjusted if non-16/32-bit types
+// are emulated via 32-bit types.
+module attributes {
+  spv.target_env = #spv.target_env<
+    #spv.vce<v1.0, [Int16, StorageBuffer16BitAccess, Shader],
+    [SPV_KHR_storage_buffer_storage_class, SPV_KHR_16bit_storage]>,
+    {max_compute_workgroup_invocations = 128 : i32,
+     max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>
+} {
+
+// CHECK-LABEL: @load_i8
+func @load_i8(%arg0: memref<i8>) {
+  //     CHECK: %[[ZERO:.+]] = spv.constant 0 : i32
+  //     CHECK: %[[FOUR1:.+]] = spv.constant 4 : i32
+  //     CHECK: %[[QUOTIENT:.+]] = spv.SDiv %[[ZERO]], %[[FOUR1]] : i32
+  //     CHECK: %[[PTR:.+]] = spv.AccessChain %{{.+}}[%[[ZERO]], %[[QUOTIENT]]]
+  //     CHECK: %[[LOAD:.+]] = spv.Load  "StorageBuffer" %[[PTR]]
+  //     CHECK: %[[FOUR2:.+]] = spv.constant 4 : i32
+  //     CHECK: %[[EIGHT:.+]] = spv.constant 8 : i32
+  //     CHECK: %[[IDX:.+]] = spv.SMod %[[ZERO]], %[[FOUR2]] : i32
+  //     CHECK: %[[BITS:.+]] = spv.IMul %[[IDX]], %[[EIGHT]] : i32
+  //     CHECK: %[[VALUE:.+]] = spv.ShiftRightArithmetic %[[LOAD]], %[[BITS]] : i32, i32
+  //     CHECK: %[[MASK:.+]] = spv.constant 255 : i32
+  //     CHECK: spv.BitwiseAnd %[[VALUE]], %[[MASK]] : i32
+  %0 = load %arg0[] : memref<i8>
+  return
+}
+
+// CHECK-LABEL: @load_i16
+func @load_i16(%arg0: memref<i16>) {
+  // CHECK-NOT: spv.SDiv
+  //     CHECK: spv.Load
+  // CHECK-NOT: spv.ShiftRightArithmetic
+  %0 = load %arg0[] : memref<i16>
+  return
+}
+
+} // end module


        


More information about the Mlir-commits mailing list