[Mlir-commits] [mlir] 5d10613 - [mlir][StandardToSPIRV] Emulate bitwidths not supported for store op.

Hanhan Wang llvmlistbot at llvm.org
Mon May 4 15:19:03 PDT 2020


Author: Hanhan Wang
Date: 2020-05-04T15:18:44-07:00
New Revision: 5d10613b6edcfb1e50cf08c7ec97ed872d0989be

URL: https://github.com/llvm/llvm-project/commit/5d10613b6edcfb1e50cf08c7ec97ed872d0989be
DIFF: https://github.com/llvm/llvm-project/commit/5d10613b6edcfb1e50cf08c7ec97ed872d0989be.diff

LOG: [mlir][StandardToSPIRV] Emulate bitwidths not supported for store op.

Summary:
As in D78974, this patch implements the emulation for the store op. The
emulation is done with atomic operations. E.g., if the value being stored is
i8, rewrite the StoreOp to:

 1) load a 32-bit integer
 2) clear 8 bits in the loaded value
 3) store the 32-bit value back
 4) load a 32-bit integer
 5) modify 8 bits in the loaded value
 6) store the 32-bit value back

Steps 1 to 3 are done by AtomicAnd as one atomic step, and steps 4 to 6 are
done by AtomicOr as another atomic step.

Differential Revision: https://reviews.llvm.org/D79272

Added: 
    

Modified: 
    mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp
    mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir

Removed: 
    


################################################################################
diff --git a/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp b/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp
index 3120f45d7992..2f7868e89336 100644
--- a/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp
+++ b/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp
@@ -146,6 +146,15 @@ static Value adjustAccessChainForBitwidth(SPIRVTypeConverter &typeConverter,
   return builder.create<spirv::AccessChainOp>(loc, t, op.base_ptr(), indices);
 }
 
+/// Returns the shifted `targetBits`-bit value with the given offset.
+Value shiftValue(Location loc, Value value, Value offset, Value mask,
+                 int targetBits, OpBuilder &builder) {
+  Type targetType = builder.getIntegerType(targetBits);
+  Value result = builder.create<spirv::BitwiseAndOp>(loc, value, mask);
+  return builder.create<spirv::ShiftLeftLogicalOp>(loc, targetType, result,
+                                                   offset);
+}
+
 //===----------------------------------------------------------------------===//
 // Operation conversion
 //===----------------------------------------------------------------------===//
@@ -292,6 +301,16 @@ class SelectOpPattern final : public SPIRVOpLowering<SelectOp> {
                   ConversionPatternRewriter &rewriter) const override;
 };
 
+/// Converts std.store to spv.Store on integers.
+class IntStoreOpPattern final : public SPIRVOpLowering<StoreOp> {
+public:
+  using SPIRVOpLowering<StoreOp>::SPIRVOpLowering;
+
+  LogicalResult
+  matchAndRewrite(StoreOp storeOp, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override;
+};
+
 /// Converts std.store to spv.Store.
 class StoreOpPattern final : public SPIRVOpLowering<StoreOp> {
 public:
@@ -696,14 +715,92 @@ SelectOpPattern::matchAndRewrite(SelectOp op, ArrayRef<Value> operands,
 // StoreOp
 //===----------------------------------------------------------------------===//
 
+LogicalResult
+IntStoreOpPattern::matchAndRewrite(StoreOp storeOp, ArrayRef<Value> operands,
+                                   ConversionPatternRewriter &rewriter) const {
+  StoreOpOperandAdaptor storeOperands(operands);
+  auto memrefType = storeOp.memref().getType().cast<MemRefType>();
+  if (!memrefType.getElementType().isSignlessInteger())
+    return failure();
+
+  auto loc = storeOp.getLoc();
+  spirv::AccessChainOp accessChainOp =
+      spirv::getElementPtr(typeConverter, memrefType, storeOperands.memref(),
+                           storeOperands.indices(), loc, rewriter);
+  int srcBits = memrefType.getElementType().getIntOrFloatBitWidth();
+  auto dstType = typeConverter.convertType(memrefType)
+                     .cast<spirv::PointerType>()
+                     .getPointeeType()
+                     .cast<spirv::StructType>()
+                     .getElementType(0)
+                     .cast<spirv::ArrayType>()
+                     .getElementType();
+  int dstBits = dstType.getIntOrFloatBitWidth();
+  assert(dstBits % srcBits == 0);
+
+  if (srcBits == dstBits) {
+    rewriter.replaceOpWithNewOp<spirv::StoreOp>(
+        storeOp, accessChainOp.getResult(), storeOperands.value());
+    return success();
+  }
+
+  // Since there are multi threads in the processing, the emulation will be done
+  // with atomic operations. E.g., if the storing value is i8, rewrite the
+  // StoreOp to
+  // 1) load a 32-bit integer
+  // 2) clear 8 bits in the loading value
+  // 3) store 32-bit value back
+  // 4) load a 32-bit integer
+  // 5) modify 8 bits in the loading value
+  // 6) store 32-bit value back
+  // The step 1 to step 3 are done by AtomicAnd as one atomic step, and the step
+  // 4 to step 6 are done by AtomicOr as another atomic step.
+  assert(accessChainOp.indices().size() == 2);
+  Value lastDim = accessChainOp.getOperation()->getOperand(
+      accessChainOp.getNumOperands() - 1);
+  Value offset = getOffsetForBitwidth(loc, lastDim, srcBits, dstBits, rewriter);
+
+  // Create a mask to clear the destination. E.g., if it is the second i8 in
+  // i32, 0xFFFF00FF is created.
+  Value mask = rewriter.create<spirv::ConstantOp>(
+      loc, dstType, rewriter.getIntegerAttr(dstType, (1 << srcBits) - 1));
+  Value clearBitsMask =
+      rewriter.create<spirv::ShiftLeftLogicalOp>(loc, dstType, mask, offset);
+  clearBitsMask = rewriter.create<spirv::NotOp>(loc, dstType, clearBitsMask);
+
+  Value storeVal =
+      shiftValue(loc, storeOperands.value(), offset, mask, dstBits, rewriter);
+  Value adjustedPtr = adjustAccessChainForBitwidth(typeConverter, accessChainOp,
+                                                   srcBits, dstBits, rewriter);
+  Value result = rewriter.create<spirv::AtomicAndOp>(
+      loc, dstType, adjustedPtr, spirv::Scope::Device,
+      spirv::MemorySemantics::AcquireRelease, clearBitsMask);
+  result = rewriter.create<spirv::AtomicOrOp>(
+      loc, dstType, adjustedPtr, spirv::Scope::Device,
+      spirv::MemorySemantics::AcquireRelease, storeVal);
+
+  // The AtomicOrOp has no side effect. Since it is already inserted, we can
+  // just remove the original StoreOp. Note that rewriter.replaceOp()
+  // doesn't work because it only accepts that the numbers of result are the
+  // same.
+  rewriter.eraseOp(storeOp);
+
+  assert(accessChainOp.use_empty());
+  rewriter.eraseOp(accessChainOp);
+
+  return success();
+}
+
 LogicalResult
 StoreOpPattern::matchAndRewrite(StoreOp storeOp, ArrayRef<Value> operands,
                                 ConversionPatternRewriter &rewriter) const {
   StoreOpOperandAdaptor storeOperands(operands);
-  auto storePtr = spirv::getElementPtr(
-      typeConverter, storeOp.memref().getType().cast<MemRefType>(),
-      storeOperands.memref(), storeOperands.indices(), storeOp.getLoc(),
-      rewriter);
+  auto memrefType = storeOp.memref().getType().cast<MemRefType>();
+  if (memrefType.getElementType().isSignlessInteger())
+    return failure();
+  auto storePtr =
+      spirv::getElementPtr(typeConverter, memrefType, storeOperands.memref(),
+                           storeOperands.indices(), storeOp.getLoc(), rewriter);
   rewriter.replaceOpWithNewOp<spirv::StoreOp>(storeOp, storePtr,
                                               storeOperands.value());
   return success();
@@ -769,7 +866,7 @@ void populateStandardToSPIRVPatterns(MLIRContext *context,
       BitwiseOpPattern<OrOp, spirv::LogicalOrOp, spirv::BitwiseOrOp>,
       BoolCmpIOpPattern, ConstantCompositeOpPattern, ConstantScalarOpPattern,
       CmpFOpPattern, CmpIOpPattern, IntLoadOpPattern, LoadOpPattern,
-      ReturnOpPattern, SelectOpPattern, StoreOpPattern,
+      ReturnOpPattern, SelectOpPattern, IntStoreOpPattern, StoreOpPattern,
       TypeCastingOpPattern<SIToFPOp, spirv::ConvertSToFOp>,
       TypeCastingOpPattern<ZeroExtendIOp, spirv::UConvertOp>,
       TypeCastingOpPattern<TruncateIOp, spirv::SConvertOp>,

diff --git a/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir b/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir
index 41cc7c60ca59..3dcc3cdb2115 100644
--- a/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir
+++ b/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir
@@ -654,7 +654,7 @@ func @load_store_zero_rank_int(%arg0: memref<i32>, %arg1: memref<i32>) {
 
 // Check that access chain indices are properly adjusted if non-32-bit types are
 // emulated via 32-bit types.
-// TODO: Test i64 type.
+// TODO: Test i1 and i64 types.
 module attributes {
   spv.target_env = #spv.target_env<
     #spv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]>,
@@ -719,6 +719,70 @@ func @load_f32(%arg0: memref<f32>) {
   return
 }
 
+// CHECK-LABEL: @store_i8
+//       CHECK: (%[[ARG0:.+]]: {{.*}}, %[[ARG1:.+]]: i32)
+func @store_i8(%arg0: memref<i8>, %value: i8) {
+  //     CHECK: %[[ZERO:.+]] = spv.constant 0 : i32
+  //     CHECK: %[[FOUR:.+]] = spv.constant 4 : i32
+  //     CHECK: %[[EIGHT:.+]] = spv.constant 8 : i32
+  //     CHECK: %[[IDX:.+]] = spv.SMod %[[ZERO]], %[[FOUR]] : i32
+  //     CHECK: %[[OFFSET:.+]] = spv.IMul %[[IDX]], %[[EIGHT]] : i32
+  //     CHECK: %[[MASK1:.+]] = spv.constant 255 : i32
+  //     CHECK: %[[TMP1:.+]] = spv.ShiftLeftLogical %[[MASK1]], %[[OFFSET]] : i32, i32
+  //     CHECK: %[[MASK:.+]] = spv.Not %[[TMP1]] : i32
+  //     CHECK: %[[CLAMPED_VAL:.+]] = spv.BitwiseAnd %[[ARG1]], %[[MASK1]] : i32
+  //     CHECK: %[[STORE_VAL:.+]] = spv.ShiftLeftLogical %[[CLAMPED_VAL]], %[[OFFSET]] : i32, i32
+  //     CHECK: %[[FOUR2:.+]] = spv.constant 4 : i32
+  //     CHECK: %[[ACCESS_IDX:.+]] = spv.SDiv %[[ZERO]], %[[FOUR2]] : i32
+  //     CHECK: %[[PTR:.+]] = spv.AccessChain %[[ARG0]][%[[ZERO]], %[[ACCESS_IDX]]]
+  //     CHECK: spv.AtomicAnd "Device" "AcquireRelease" %[[PTR]], %[[MASK]]
+  //     CHECK: spv.AtomicOr "Device" "AcquireRelease" %[[PTR]], %[[STORE_VAL]]
+  store %value, %arg0[] : memref<i8>
+  return
+}
+
+// CHECK-LABEL: @store_i16
+//       CHECK: (%[[ARG0:.+]]: {{.*}}, %[[ARG1:.+]]: i32, %[[ARG2:.+]]: i32)
+func @store_i16(%arg0: memref<10xi16>, %index: index, %value: i16) {
+  //     CHECK: %[[ONE:.+]] = spv.constant 1 : i32
+  //     CHECK: %[[FLAT_IDX:.+]] = spv.IMul %[[ONE]], %[[ARG1]] : i32
+  //     CHECK: %[[ZERO:.+]] = spv.constant 0 : i32
+  //     CHECK: %[[TWO:.+]] = spv.constant 2 : i32
+  //     CHECK: %[[SIXTEEN:.+]] = spv.constant 16 : i32
+  //     CHECK: %[[IDX:.+]] = spv.SMod %[[FLAT_IDX]], %[[TWO]] : i32
+  //     CHECK: %[[OFFSET:.+]] = spv.IMul %[[IDX]], %[[SIXTEEN]] : i32
+  //     CHECK: %[[MASK1:.+]] = spv.constant 65535 : i32
+  //     CHECK: %[[TMP1:.+]] = spv.ShiftLeftLogical %[[MASK1]], %[[OFFSET]] : i32, i32
+  //     CHECK: %[[MASK:.+]] = spv.Not %[[TMP1]] : i32
+  //     CHECK: %[[CLAMPED_VAL:.+]] = spv.BitwiseAnd %[[ARG2]], %[[MASK1]] : i32
+  //     CHECK: %[[STORE_VAL:.+]] = spv.ShiftLeftLogical %[[CLAMPED_VAL]], %[[OFFSET]] : i32, i32
+  //     CHECK: %[[TWO2:.+]] = spv.constant 2 : i32
+  //     CHECK: %[[ACCESS_IDX:.+]] = spv.SDiv %[[FLAT_IDX]], %[[TWO2]] : i32
+  //     CHECK: %[[PTR:.+]] = spv.AccessChain %[[ARG0]][%[[ZERO]], %[[ACCESS_IDX]]]
+  //     CHECK: spv.AtomicAnd "Device" "AcquireRelease" %[[PTR]], %[[MASK]]
+  //     CHECK: spv.AtomicOr "Device" "AcquireRelease" %[[PTR]], %[[STORE_VAL]]
+  store %value, %arg0[%index] : memref<10xi16>
+  return
+}
+
+// CHECK-LABEL: @store_i32
+func @store_i32(%arg0: memref<i32>, %value: i32) {
+  //     CHECK: spv.Store
+  // CHECK-NOT: spv.AtomicAnd
+  // CHECK-NOT: spv.AtomicOr
+  store %value, %arg0[] : memref<i32>
+  return
+}
+
+// CHECK-LABEL: @store_f32
+func @store_f32(%arg0: memref<f32>, %value: f32) {
+  //     CHECK: spv.Store
+  // CHECK-NOT: spv.AtomicAnd
+  // CHECK-NOT: spv.AtomicOr
+  store %value, %arg0[] : memref<f32>
+  return
+}
+
 } // end module
 
 // -----
@@ -760,4 +824,35 @@ func @load_i16(%arg0: memref<i16>) {
   return
 }
 
+// CHECK-LABEL: @store_i8
+//       CHECK: (%[[ARG0:.+]]: {{.*}}, %[[ARG1:.+]]: i32)
+func @store_i8(%arg0: memref<i8>, %value: i8) {
+  //     CHECK: %[[ZERO:.+]] = spv.constant 0 : i32
+  //     CHECK: %[[FOUR:.+]] = spv.constant 4 : i32
+  //     CHECK: %[[EIGHT:.+]] = spv.constant 8 : i32
+  //     CHECK: %[[IDX:.+]] = spv.SMod %[[ZERO]], %[[FOUR]] : i32
+  //     CHECK: %[[OFFSET:.+]] = spv.IMul %[[IDX]], %[[EIGHT]] : i32
+  //     CHECK: %[[MASK1:.+]] = spv.constant 255 : i32
+  //     CHECK: %[[TMP1:.+]] = spv.ShiftLeftLogical %[[MASK1]], %[[OFFSET]] : i32, i32
+  //     CHECK: %[[MASK:.+]] = spv.Not %[[TMP1]] : i32
+  //     CHECK: %[[CLAMPED_VAL:.+]] = spv.BitwiseAnd %[[ARG1]], %[[MASK1]] : i32
+  //     CHECK: %[[STORE_VAL:.+]] = spv.ShiftLeftLogical %[[CLAMPED_VAL]], %[[OFFSET]] : i32, i32
+  //     CHECK: %[[FOUR2:.+]] = spv.constant 4 : i32
+  //     CHECK: %[[ACCESS_IDX:.+]] = spv.SDiv %[[ZERO]], %[[FOUR2]] : i32
+  //     CHECK: %[[PTR:.+]] = spv.AccessChain %[[ARG0]][%[[ZERO]], %[[ACCESS_IDX]]]
+  //     CHECK: spv.AtomicAnd "Device" "AcquireRelease" %[[PTR]], %[[MASK]]
+  //     CHECK: spv.AtomicOr "Device" "AcquireRelease" %[[PTR]], %[[STORE_VAL]]
+  store %value, %arg0[] : memref<i8>
+  return
+}
+
+// CHECK-LABEL: @store_i16
+func @store_i16(%arg0: memref<10xi16>, %index: index, %value: i16) {
+  //     CHECK: spv.Store
+  // CHECK-NOT: spv.AtomicAnd
+  // CHECK-NOT: spv.AtomicOr
+  store %value, %arg0[%index] : memref<10xi16>
+  return
+}
+
 } // end module


        


More information about the Mlir-commits mailing list