[Mlir-commits] [mlir] 5d10613 - [mlir][StandardToSPIRV] Emulate bitwidths not supported for store op.
Hanhan Wang
llvmlistbot at llvm.org
Mon May 4 15:19:03 PDT 2020
Author: Hanhan Wang
Date: 2020-05-04T15:18:44-07:00
New Revision: 5d10613b6edcfb1e50cf08c7ec97ed872d0989be
URL: https://github.com/llvm/llvm-project/commit/5d10613b6edcfb1e50cf08c7ec97ed872d0989be
DIFF: https://github.com/llvm/llvm-project/commit/5d10613b6edcfb1e50cf08c7ec97ed872d0989be.diff
LOG: [mlir][StandardToSPIRV] Emulate bitwidths not supported for store op.
Summary:
Like D78974, this patch implements the emulation for the store op. The emulation
is done with atomic operations. E.g., if the value to store is i8, rewrite the
StoreOp to:
1) load a 32-bit integer
2) clear 8 bits in the loaded value
3) store 32-bit value back
4) load a 32-bit integer
5) modify 8 bits in the loaded value
6) store 32-bit value back
Steps 1 to 3 are done by AtomicAnd as one atomic operation, and steps 4
to 6 are done by AtomicOr as another atomic operation.
Differential Revision: https://reviews.llvm.org/D79272
Added:
Modified:
mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp
mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp b/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp
index 3120f45d7992..2f7868e89336 100644
--- a/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp
+++ b/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp
@@ -146,6 +146,15 @@ static Value adjustAccessChainForBitwidth(SPIRVTypeConverter &typeConverter,
return builder.create<spirv::AccessChainOp>(loc, t, op.base_ptr(), indices);
}
+/// Returns the shifted `targetBits`-bit value with the given offset.
+Value shiftValue(Location loc, Value value, Value offset, Value mask,
+ int targetBits, OpBuilder &builder) {
+ Type targetType = builder.getIntegerType(targetBits);
+ Value result = builder.create<spirv::BitwiseAndOp>(loc, value, mask);
+ return builder.create<spirv::ShiftLeftLogicalOp>(loc, targetType, result,
+ offset);
+}
+
//===----------------------------------------------------------------------===//
// Operation conversion
//===----------------------------------------------------------------------===//
@@ -292,6 +301,16 @@ class SelectOpPattern final : public SPIRVOpLowering<SelectOp> {
ConversionPatternRewriter &rewriter) const override;
};
+/// Converts std.store to spv.Store on integers.
+class IntStoreOpPattern final : public SPIRVOpLowering<StoreOp> {
+public:
+ using SPIRVOpLowering<StoreOp>::SPIRVOpLowering;
+
+ LogicalResult
+ matchAndRewrite(StoreOp storeOp, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override;
+};
+
/// Converts std.store to spv.Store.
class StoreOpPattern final : public SPIRVOpLowering<StoreOp> {
public:
@@ -696,14 +715,92 @@ SelectOpPattern::matchAndRewrite(SelectOp op, ArrayRef<Value> operands,
// StoreOp
//===----------------------------------------------------------------------===//
+LogicalResult
+IntStoreOpPattern::matchAndRewrite(StoreOp storeOp, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const {
+ StoreOpOperandAdaptor storeOperands(operands);
+ auto memrefType = storeOp.memref().getType().cast<MemRefType>();
+ if (!memrefType.getElementType().isSignlessInteger())
+ return failure();
+
+ auto loc = storeOp.getLoc();
+ spirv::AccessChainOp accessChainOp =
+ spirv::getElementPtr(typeConverter, memrefType, storeOperands.memref(),
+ storeOperands.indices(), loc, rewriter);
+ int srcBits = memrefType.getElementType().getIntOrFloatBitWidth();
+ auto dstType = typeConverter.convertType(memrefType)
+ .cast<spirv::PointerType>()
+ .getPointeeType()
+ .cast<spirv::StructType>()
+ .getElementType(0)
+ .cast<spirv::ArrayType>()
+ .getElementType();
+ int dstBits = dstType.getIntOrFloatBitWidth();
+ assert(dstBits % srcBits == 0);
+
+ if (srcBits == dstBits) {
+ rewriter.replaceOpWithNewOp<spirv::StoreOp>(
+ storeOp, accessChainOp.getResult(), storeOperands.value());
+ return success();
+ }
+
+ // Since there are multi threads in the processing, the emulation will be done
+ // with atomic operations. E.g., if the storing value is i8, rewrite the
+ // StoreOp to
+ // 1) load a 32-bit integer
+ // 2) clear 8 bits in the loading value
+ // 3) store 32-bit value back
+ // 4) load a 32-bit integer
+ // 5) modify 8 bits in the loading value
+ // 6) store 32-bit value back
+ // The step 1 to step 3 are done by AtomicAnd as one atomic step, and the step
+ // 4 to step 6 are done by AtomicOr as another atomic step.
+ assert(accessChainOp.indices().size() == 2);
+ Value lastDim = accessChainOp.getOperation()->getOperand(
+ accessChainOp.getNumOperands() - 1);
+ Value offset = getOffsetForBitwidth(loc, lastDim, srcBits, dstBits, rewriter);
+
+ // Create a mask to clear the destination. E.g., if it is the second i8 in
+ // i32, 0xFFFF00FF is created.
+ Value mask = rewriter.create<spirv::ConstantOp>(
+ loc, dstType, rewriter.getIntegerAttr(dstType, (1 << srcBits) - 1));
+ Value clearBitsMask =
+ rewriter.create<spirv::ShiftLeftLogicalOp>(loc, dstType, mask, offset);
+ clearBitsMask = rewriter.create<spirv::NotOp>(loc, dstType, clearBitsMask);
+
+ Value storeVal =
+ shiftValue(loc, storeOperands.value(), offset, mask, dstBits, rewriter);
+ Value adjustedPtr = adjustAccessChainForBitwidth(typeConverter, accessChainOp,
+ srcBits, dstBits, rewriter);
+ Value result = rewriter.create<spirv::AtomicAndOp>(
+ loc, dstType, adjustedPtr, spirv::Scope::Device,
+ spirv::MemorySemantics::AcquireRelease, clearBitsMask);
+ result = rewriter.create<spirv::AtomicOrOp>(
+ loc, dstType, adjustedPtr, spirv::Scope::Device,
+ spirv::MemorySemantics::AcquireRelease, storeVal);
+
+ // The AtomicOrOp has no side effect. Since it is already inserted, we can
+ // just remove the original StoreOp. Note that rewriter.replaceOp()
+ // doesn't work because it only accepts that the numbers of result are the
+ // same.
+ rewriter.eraseOp(storeOp);
+
+ assert(accessChainOp.use_empty());
+ rewriter.eraseOp(accessChainOp);
+
+ return success();
+}
+
LogicalResult
StoreOpPattern::matchAndRewrite(StoreOp storeOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const {
StoreOpOperandAdaptor storeOperands(operands);
- auto storePtr = spirv::getElementPtr(
- typeConverter, storeOp.memref().getType().cast<MemRefType>(),
- storeOperands.memref(), storeOperands.indices(), storeOp.getLoc(),
- rewriter);
+ auto memrefType = storeOp.memref().getType().cast<MemRefType>();
+ if (memrefType.getElementType().isSignlessInteger())
+ return failure();
+ auto storePtr =
+ spirv::getElementPtr(typeConverter, memrefType, storeOperands.memref(),
+ storeOperands.indices(), storeOp.getLoc(), rewriter);
rewriter.replaceOpWithNewOp<spirv::StoreOp>(storeOp, storePtr,
storeOperands.value());
return success();
@@ -769,7 +866,7 @@ void populateStandardToSPIRVPatterns(MLIRContext *context,
BitwiseOpPattern<OrOp, spirv::LogicalOrOp, spirv::BitwiseOrOp>,
BoolCmpIOpPattern, ConstantCompositeOpPattern, ConstantScalarOpPattern,
CmpFOpPattern, CmpIOpPattern, IntLoadOpPattern, LoadOpPattern,
- ReturnOpPattern, SelectOpPattern, StoreOpPattern,
+ ReturnOpPattern, SelectOpPattern, IntStoreOpPattern, StoreOpPattern,
TypeCastingOpPattern<SIToFPOp, spirv::ConvertSToFOp>,
TypeCastingOpPattern<ZeroExtendIOp, spirv::UConvertOp>,
TypeCastingOpPattern<TruncateIOp, spirv::SConvertOp>,
diff --git a/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir b/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir
index 41cc7c60ca59..3dcc3cdb2115 100644
--- a/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir
+++ b/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir
@@ -654,7 +654,7 @@ func @load_store_zero_rank_int(%arg0: memref<i32>, %arg1: memref<i32>) {
// Check that access chain indices are properly adjusted if non-32-bit types are
// emulated via 32-bit types.
-// TODO: Test i64 type.
+// TODO: Test i1 and i64 types.
module attributes {
spv.target_env = #spv.target_env<
#spv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]>,
@@ -719,6 +719,70 @@ func @load_f32(%arg0: memref<f32>) {
return
}
+// CHECK-LABEL: @store_i8
+// CHECK: (%[[ARG0:.+]]: {{.*}}, %[[ARG1:.+]]: i32)
+func @store_i8(%arg0: memref<i8>, %value: i8) {
+ // CHECK: %[[ZERO:.+]] = spv.constant 0 : i32
+ // CHECK: %[[FOUR:.+]] = spv.constant 4 : i32
+ // CHECK: %[[EIGHT:.+]] = spv.constant 8 : i32
+ // CHECK: %[[IDX:.+]] = spv.SMod %[[ZERO]], %[[FOUR]] : i32
+ // CHECK: %[[OFFSET:.+]] = spv.IMul %[[IDX]], %[[EIGHT]] : i32
+ // CHECK: %[[MASK1:.+]] = spv.constant 255 : i32
+ // CHECK: %[[TMP1:.+]] = spv.ShiftLeftLogical %[[MASK1]], %[[OFFSET]] : i32, i32
+ // CHECK: %[[MASK:.+]] = spv.Not %[[TMP1]] : i32
+ // CHECK: %[[CLAMPED_VAL:.+]] = spv.BitwiseAnd %[[ARG1]], %[[MASK1]] : i32
+ // CHECK: %[[STORE_VAL:.+]] = spv.ShiftLeftLogical %[[CLAMPED_VAL]], %[[OFFSET]] : i32, i32
+ // CHECK: %[[FOUR2:.+]] = spv.constant 4 : i32
+ // CHECK: %[[ACCESS_IDX:.+]] = spv.SDiv %[[ZERO]], %[[FOUR2]] : i32
+ // CHECK: %[[PTR:.+]] = spv.AccessChain %[[ARG0]][%[[ZERO]], %[[ACCESS_IDX]]]
+ // CHECK: spv.AtomicAnd "Device" "AcquireRelease" %[[PTR]], %[[MASK]]
+ // CHECK: spv.AtomicOr "Device" "AcquireRelease" %[[PTR]], %[[STORE_VAL]]
+ store %value, %arg0[] : memref<i8>
+ return
+}
+
+// CHECK-LABEL: @store_i16
+// CHECK: (%[[ARG0:.+]]: {{.*}}, %[[ARG1:.+]]: i32, %[[ARG2:.+]]: i32)
+func @store_i16(%arg0: memref<10xi16>, %index: index, %value: i16) {
+ // CHECK: %[[ONE:.+]] = spv.constant 1 : i32
+ // CHECK: %[[FLAT_IDX:.+]] = spv.IMul %[[ONE]], %[[ARG1]] : i32
+ // CHECK: %[[ZERO:.+]] = spv.constant 0 : i32
+ // CHECK: %[[TWO:.+]] = spv.constant 2 : i32
+ // CHECK: %[[SIXTEEN:.+]] = spv.constant 16 : i32
+ // CHECK: %[[IDX:.+]] = spv.SMod %[[FLAT_IDX]], %[[TWO]] : i32
+ // CHECK: %[[OFFSET:.+]] = spv.IMul %[[IDX]], %[[SIXTEEN]] : i32
+ // CHECK: %[[MASK1:.+]] = spv.constant 65535 : i32
+ // CHECK: %[[TMP1:.+]] = spv.ShiftLeftLogical %[[MASK1]], %[[OFFSET]] : i32, i32
+ // CHECK: %[[MASK:.+]] = spv.Not %[[TMP1]] : i32
+ // CHECK: %[[CLAMPED_VAL:.+]] = spv.BitwiseAnd %[[ARG2]], %[[MASK1]] : i32
+ // CHECK: %[[STORE_VAL:.+]] = spv.ShiftLeftLogical %[[CLAMPED_VAL]], %[[OFFSET]] : i32, i32
+ // CHECK: %[[TWO2:.+]] = spv.constant 2 : i32
+ // CHECK: %[[ACCESS_IDX:.+]] = spv.SDiv %[[FLAT_IDX]], %[[TWO2]] : i32
+ // CHECK: %[[PTR:.+]] = spv.AccessChain %[[ARG0]][%[[ZERO]], %[[ACCESS_IDX]]]
+ // CHECK: spv.AtomicAnd "Device" "AcquireRelease" %[[PTR]], %[[MASK]]
+ // CHECK: spv.AtomicOr "Device" "AcquireRelease" %[[PTR]], %[[STORE_VAL]]
+ store %value, %arg0[%index] : memref<10xi16>
+ return
+}
+
+// CHECK-LABEL: @store_i32
+func @store_i32(%arg0: memref<i32>, %value: i32) {
+ // CHECK: spv.Store
+ // CHECK-NOT: spv.AtomicAnd
+ // CHECK-NOT: spv.AtomicOr
+ store %value, %arg0[] : memref<i32>
+ return
+}
+
+// CHECK-LABEL: @store_f32
+func @store_f32(%arg0: memref<f32>, %value: f32) {
+ // CHECK: spv.Store
+ // CHECK-NOT: spv.AtomicAnd
+ // CHECK-NOT: spv.AtomicOr
+ store %value, %arg0[] : memref<f32>
+ return
+}
+
} // end module
// -----
@@ -760,4 +824,35 @@ func @load_i16(%arg0: memref<i16>) {
return
}
+// CHECK-LABEL: @store_i8
+// CHECK: (%[[ARG0:.+]]: {{.*}}, %[[ARG1:.+]]: i32)
+func @store_i8(%arg0: memref<i8>, %value: i8) {
+ // CHECK: %[[ZERO:.+]] = spv.constant 0 : i32
+ // CHECK: %[[FOUR:.+]] = spv.constant 4 : i32
+ // CHECK: %[[EIGHT:.+]] = spv.constant 8 : i32
+ // CHECK: %[[IDX:.+]] = spv.SMod %[[ZERO]], %[[FOUR]] : i32
+ // CHECK: %[[OFFSET:.+]] = spv.IMul %[[IDX]], %[[EIGHT]] : i32
+ // CHECK: %[[MASK1:.+]] = spv.constant 255 : i32
+ // CHECK: %[[TMP1:.+]] = spv.ShiftLeftLogical %[[MASK1]], %[[OFFSET]] : i32, i32
+ // CHECK: %[[MASK:.+]] = spv.Not %[[TMP1]] : i32
+ // CHECK: %[[CLAMPED_VAL:.+]] = spv.BitwiseAnd %[[ARG1]], %[[MASK1]] : i32
+ // CHECK: %[[STORE_VAL:.+]] = spv.ShiftLeftLogical %[[CLAMPED_VAL]], %[[OFFSET]] : i32, i32
+ // CHECK: %[[FOUR2:.+]] = spv.constant 4 : i32
+ // CHECK: %[[ACCESS_IDX:.+]] = spv.SDiv %[[ZERO]], %[[FOUR2]] : i32
+ // CHECK: %[[PTR:.+]] = spv.AccessChain %[[ARG0]][%[[ZERO]], %[[ACCESS_IDX]]]
+ // CHECK: spv.AtomicAnd "Device" "AcquireRelease" %[[PTR]], %[[MASK]]
+ // CHECK: spv.AtomicOr "Device" "AcquireRelease" %[[PTR]], %[[STORE_VAL]]
+ store %value, %arg0[] : memref<i8>
+ return
+}
+
+// CHECK-LABEL: @store_i16
+func @store_i16(%arg0: memref<10xi16>, %index: index, %value: i16) {
+ // CHECK: spv.Store
+ // CHECK-NOT: spv.AtomicAnd
+ // CHECK-NOT: spv.AtomicOr
+ store %value, %arg0[%index] : memref<10xi16>
+ return
+}
+
} // end module
More information about the Mlir-commits
mailing list