[Mlir-commits] [mlir] 6601b65 - [mlir][StandardToSPIRV] Emulate bitwidths not supported for load op.
Hanhan Wang
llvmlistbot at llvm.org
Thu Apr 30 19:28:41 PDT 2020
Author: Hanhan Wang
Date: 2020-04-30T19:27:45-07:00
New Revision: 6601b65aedd093cb62549ff9e8f39872d0b55499
URL: https://github.com/llvm/llvm-project/commit/6601b65aedd093cb62549ff9e8f39872d0b55499
DIFF: https://github.com/llvm/llvm-project/commit/6601b65aedd093cb62549ff9e8f39872d0b55499.diff
LOG: [mlir][StandardToSPIRV] Emulate bitwidths not supported for load op.
Summary:
The current implementation in SPIRVTypeConverter unconditionally converts a
type to 32-bit if its bitwidth does not meet the requirements of the available
extensions or capabilities. In that case, we can load the enclosing 32-bit
value and then extract the bits we need from it.
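In sketch form (value names are illustrative and some type annotations are
elided; the exact IR is pinned down by the tests below), loading the x-th
element of an i8 buffer emulated on 32-bit storage divides the index by 4 to
address the enclosing i32 word, then shifts and masks to extract the element:
  %zero = spv.constant 0 : i32
  %four = spv.constant 4 : i32
  %eight = spv.constant 8 : i32
  %quot = spv.SDiv %x, %four : i32                // which i32 word holds the element
  %ptr = spv.AccessChain %base[%zero, %quot]
  %word = spv.Load "StorageBuffer" %ptr : i32
  %rem = spv.SMod %x, %four : i32                 // position of the element in the word
  %bits = spv.IMul %rem, %eight : i32             // bit offset = (x % 4) * 8
  %shifted = spv.ShiftRightArithmetic %word, %bits : i32, i32
  %mask = spv.constant 255 : i32
  %value = spv.BitwiseAnd %shifted, %mask : i32   // keep the low 8 bits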
Differential Revision: https://reviews.llvm.org/D78974
Added:
Modified:
mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp
mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp b/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp
index 12b22cacdee2..16fea677b51e 100644
--- a/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp
+++ b/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp
@@ -97,6 +97,55 @@ static FloatAttr convertFloatAttr(FloatAttr srcAttr, FloatType dstType,
return builder.getF32FloatAttr(dstVal.convertToFloat());
}
+/// Returns the bit offset of the value in the `targetBits` representation.
+/// `srcIdx` is an index into a 1-D array where each element has `sourceBits`.
+/// When the array is instead accessed as if it had elements of `targetBits`,
+/// multiple source values are loaded at the same time. This method returns
+/// the offset at which the `srcIdx`-th value is located within the loaded
+/// value. For example, if `sourceBits` is 8 and `targetBits` is 32, the x-th
+/// element is located at bit offset (x % 4) * 8, because one i32 holds four
+/// elements of 8 bits each.
+static Value getOffsetForBitwidth(Location loc, Value srcIdx, int sourceBits,
+ int targetBits, OpBuilder &builder) {
+ assert(targetBits % sourceBits == 0);
+ IntegerType targetType = builder.getIntegerType(targetBits);
+ IntegerAttr idxAttr =
+ builder.getIntegerAttr(targetType, targetBits / sourceBits);
+ auto idx = builder.create<spirv::ConstantOp>(loc, targetType, idxAttr);
+ IntegerAttr srcBitsAttr = builder.getIntegerAttr(targetType, sourceBits);
+ auto srcBitsValue =
+ builder.create<spirv::ConstantOp>(loc, targetType, srcBitsAttr);
+ auto m = builder.create<spirv::SModOp>(loc, srcIdx, idx);
+ return builder.create<spirv::IMulOp>(loc, targetType, m, srcBitsValue);
+}
+
+/// Returns an adjusted spirv::AccessChainOp. Depending on the enabled
+/// extensions/capabilities, certain integer bitwidths `sourceBits` might not
+/// be supported. During conversion, if a memref of an unsupported type is
+/// used, loads/stores to this memref need to be modified to use a supported
+/// higher bitwidth `targetBits` and to extract the required bits. When
+/// accessing a 1-D array (spv.array or spv.rt_array), the last index is
+/// modified to load the bits needed. The extraction of the actual bits needed
+/// is handled separately. Note that this only works for a 1-D tensor.
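+/// For example, with `sourceBits` = 8 and `targetBits` = 32, an access to
+/// element x becomes an access to element x / 4 of the i32 array.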
+static Value adjustAccessChainForBitwidth(SPIRVTypeConverter &typeConverter,
+ spirv::AccessChainOp op,
+ int sourceBits, int targetBits,
+ OpBuilder &builder) {
+ assert(targetBits % sourceBits == 0);
+ const auto loc = op.getLoc();
+ IntegerType targetType = builder.getIntegerType(targetBits);
+ IntegerAttr attr =
+ builder.getIntegerAttr(targetType, targetBits / sourceBits);
+ auto idx = builder.create<spirv::ConstantOp>(loc, targetType, attr);
+ auto lastDim = op.getOperation()->getOperand(op.getNumOperands() - 1);
+ auto indices = llvm::to_vector<4>(op.indices());
+  // For a 1-D tensor there are two indices: the index into the wrapping
+  // struct and the index into the array.
+ assert(indices.size() == 2);
+ indices.back() = builder.create<spirv::SDivOp>(loc, lastDim, idx);
+ Type t = typeConverter.convertType(op.component_ptr().getType());
+ return builder.create<spirv::AccessChainOp>(loc, t, op.base_ptr(), indices);
+}
+
//===----------------------------------------------------------------------===//
// Operation conversion
//===----------------------------------------------------------------------===//
@@ -204,6 +253,16 @@ class CmpIOpPattern final : public SPIRVOpLowering<CmpIOp> {
ConversionPatternRewriter &rewriter) const override;
};
+/// Converts std.load on integer memrefs to spv.Load, emulating unsupported
+/// integer bitwidths by loading a supported wider type and extracting the
+/// required bits.
+class IntLoadOpPattern final : public SPIRVOpLowering<LoadOp> {
+public:
+ using SPIRVOpLowering<LoadOp>::SPIRVOpLowering;
+
+ LogicalResult
+ matchAndRewrite(LoadOp loadOp, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override;
+};
+
/// Converts std.load to spv.Load for memrefs whose element type is not a
/// signless integer.
class LoadOpPattern final : public SPIRVOpLowering<LoadOp> {
public:
@@ -528,13 +587,79 @@ CmpIOpPattern::matchAndRewrite(CmpIOp cmpIOp, ArrayRef<Value> operands,
// LoadOp
//===----------------------------------------------------------------------===//
+LogicalResult
+IntLoadOpPattern::matchAndRewrite(LoadOp loadOp, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const {
+ LoadOpOperandAdaptor loadOperands(operands);
+ auto loc = loadOp.getLoc();
+ auto memrefType = loadOp.memref().getType().cast<MemRefType>();
+ if (!memrefType.getElementType().isSignlessInteger())
+ return failure();
+ spirv::AccessChainOp accessChainOp =
+ spirv::getElementPtr(typeConverter, memrefType, loadOperands.memref(),
+ loadOperands.indices(), loc, rewriter);
+
+ int srcBits = memrefType.getElementType().getIntOrFloatBitWidth();
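+  // Walk the converted memref type to find the bitwidth the element type was
+  // converted to: the converted type is a pointer to a struct wrapping a
+  // (runtime) array of the supported element type.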
+ auto dstType = typeConverter.convertType(memrefType)
+ .cast<spirv::PointerType>()
+ .getPointeeType()
+ .cast<spirv::StructType>()
+ .getElementType(0)
+ .cast<spirv::ArrayType>()
+ .getElementType();
+ int dstBits = dstType.getIntOrFloatBitWidth();
+ assert(dstBits % srcBits == 0);
+
+  // If the rewritten load op has the same bit width, use the loaded value
+  // directly.
+ if (srcBits == dstBits) {
+ rewriter.replaceOpWithNewOp<spirv::LoadOp>(loadOp,
+ accessChainOp.getResult());
+ return success();
+ }
+
+  // Assume that getElementPtr() linearizes the access. Even for a scalar, the
+  // method still returns a linearized access. If the access were not
+  // linearized, the offset computation below would be incorrect.
+ assert(accessChainOp.indices().size() == 2);
+ Value adjustedPtr = adjustAccessChainForBitwidth(typeConverter, accessChainOp,
+ srcBits, dstBits, rewriter);
+ Value spvLoadOp = rewriter.create<spirv::LoadOp>(
+ loc, dstType, adjustedPtr,
+ loadOp.getAttrOfType<IntegerAttr>(
+ spirv::attributeName<spirv::MemoryAccess>()),
+ loadOp.getAttrOfType<IntegerAttr>("alignment"));
+
+  // Shift the target bits to the rightmost position.
+ // ____XXXX________ -> ____________XXXX
+ Value lastDim = accessChainOp.getOperation()->getOperand(
+ accessChainOp.getNumOperands() - 1);
+ Value offset = getOffsetForBitwidth(loc, lastDim, srcBits, dstBits, rewriter);
+ Value result = rewriter.create<spirv::ShiftRightArithmeticOp>(
+ loc, spvLoadOp.getType(), spvLoadOp, offset);
+
+  // Apply a mask to extract the corresponding bits.
+ Value mask = rewriter.create<spirv::ConstantOp>(
+ loc, dstType, rewriter.getIntegerAttr(dstType, (1 << srcBits) - 1));
+ result = rewriter.create<spirv::BitwiseAndOp>(loc, dstType, result, mask);
+ rewriter.replaceOp(loadOp, result);
+
+ assert(accessChainOp.use_empty());
+ rewriter.eraseOp(accessChainOp);
+
+ return success();
+}
+
LogicalResult
LoadOpPattern::matchAndRewrite(LoadOp loadOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const {
LoadOpOperandAdaptor loadOperands(operands);
- auto loadPtr = spirv::getElementPtr(
- typeConverter, loadOp.memref().getType().cast<MemRefType>(),
- loadOperands.memref(), loadOperands.indices(), loadOp.getLoc(), rewriter);
+ auto memrefType = loadOp.memref().getType().cast<MemRefType>();
+ if (memrefType.getElementType().isSignlessInteger())
+ return failure();
+ auto loadPtr =
+ spirv::getElementPtr(typeConverter, memrefType, loadOperands.memref(),
+ loadOperands.indices(), loadOp.getLoc(), rewriter);
rewriter.replaceOpWithNewOp<spirv::LoadOp>(loadOp, loadPtr);
return success();
}
@@ -642,8 +767,8 @@ void populateStandardToSPIRVPatterns(MLIRContext *context,
BitwiseOpPattern<AndOp, spirv::LogicalAndOp, spirv::BitwiseAndOp>,
BitwiseOpPattern<OrOp, spirv::LogicalOrOp, spirv::BitwiseOrOp>,
BoolCmpIOpPattern, ConstantCompositeOpPattern, ConstantScalarOpPattern,
- CmpFOpPattern, CmpIOpPattern, LoadOpPattern, ReturnOpPattern,
- SelectOpPattern, StoreOpPattern,
+ CmpFOpPattern, CmpIOpPattern, IntLoadOpPattern, LoadOpPattern,
+ ReturnOpPattern, SelectOpPattern, StoreOpPattern,
TypeCastingOpPattern<SIToFPOp, spirv::ConvertSToFOp>,
TypeCastingOpPattern<FPExtOp, spirv::FConvertOp>,
TypeCastingOpPattern<FPTruncOp, spirv::FConvertOp>, XOrOpPattern>(
diff --git a/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir b/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir
index e7ad95a1a173..7351e62b19a2 100644
--- a/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir
+++ b/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir
@@ -619,3 +619,115 @@ func @load_store_zero_rank_int(%arg0: memref<i32>, %arg1: memref<i32>) {
}
} // end module
+
+// -----
+
+// Check that access chain indices are properly adjusted if non-32-bit types are
+// emulated via 32-bit types.
+// TODO: Test i64 type.
+module attributes {
+ spv.target_env = #spv.target_env<
+ #spv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]>,
+ {max_compute_workgroup_invocations = 128 : i32,
+ max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>
+} {
+
+// CHECK-LABEL: @load_i8
+func @load_i8(%arg0: memref<i8>) {
+ // CHECK: %[[ZERO:.+]] = spv.constant 0 : i32
+ // CHECK: %[[FOUR1:.+]] = spv.constant 4 : i32
+ // CHECK: %[[QUOTIENT:.+]] = spv.SDiv %[[ZERO]], %[[FOUR1]] : i32
+ // CHECK: %[[PTR:.+]] = spv.AccessChain %{{.+}}[%[[ZERO]], %[[QUOTIENT]]]
+ // CHECK: %[[LOAD:.+]] = spv.Load "StorageBuffer" %[[PTR]]
+ // CHECK: %[[FOUR2:.+]] = spv.constant 4 : i32
+ // CHECK: %[[EIGHT:.+]] = spv.constant 8 : i32
+ // CHECK: %[[IDX:.+]] = spv.SMod %[[ZERO]], %[[FOUR2]] : i32
+ // CHECK: %[[BITS:.+]] = spv.IMul %[[IDX]], %[[EIGHT]] : i32
+ // CHECK: %[[VALUE:.+]] = spv.ShiftRightArithmetic %[[LOAD]], %[[BITS]] : i32, i32
+ // CHECK: %[[MASK:.+]] = spv.constant 255 : i32
+ // CHECK: spv.BitwiseAnd %[[VALUE]], %[[MASK]] : i32
+ %0 = load %arg0[] : memref<i8>
+ return
+}
+
+// CHECK-LABEL: @load_i16
+// CHECK: (%[[ARG0:.+]]: {{.*}}, %[[ARG1:.+]]: i32)
+func @load_i16(%arg0: memref<10xi16>, %index : index) {
+ // CHECK: %[[ONE:.+]] = spv.constant 1 : i32
+ // CHECK: %[[FLAT_IDX:.+]] = spv.IMul %[[ONE]], %[[ARG1]] : i32
+ // CHECK: %[[ZERO:.+]] = spv.constant 0 : i32
+ // CHECK: %[[TWO1:.+]] = spv.constant 2 : i32
+ // CHECK: %[[QUOTIENT:.+]] = spv.SDiv %[[FLAT_IDX]], %[[TWO1]] : i32
+ // CHECK: %[[PTR:.+]] = spv.AccessChain %{{.+}}[%[[ZERO]], %[[QUOTIENT]]]
+ // CHECK: %[[LOAD:.+]] = spv.Load "StorageBuffer" %[[PTR]]
+ // CHECK: %[[TWO2:.+]] = spv.constant 2 : i32
+ // CHECK: %[[SIXTEEN:.+]] = spv.constant 16 : i32
+ // CHECK: %[[IDX:.+]] = spv.SMod %[[FLAT_IDX]], %[[TWO2]] : i32
+ // CHECK: %[[BITS:.+]] = spv.IMul %[[IDX]], %[[SIXTEEN]] : i32
+ // CHECK: %[[VALUE:.+]] = spv.ShiftRightArithmetic %[[LOAD]], %[[BITS]] : i32, i32
+ // CHECK: %[[MASK:.+]] = spv.constant 65535 : i32
+ // CHECK: spv.BitwiseAnd %[[VALUE]], %[[MASK]] : i32
+ %0 = load %arg0[%index] : memref<10xi16>
+ return
+}
+
+// CHECK-LABEL: @load_i32
+func @load_i32(%arg0: memref<i32>) {
+ // CHECK-NOT: spv.SDiv
+ // CHECK: spv.Load
+ // CHECK-NOT: spv.ShiftRightArithmetic
+ %0 = load %arg0[] : memref<i32>
+ return
+}
+
+// CHECK-LABEL: @load_f32
+func @load_f32(%arg0: memref<f32>) {
+ // CHECK-NOT: spv.SDiv
+ // CHECK: spv.Load
+ // CHECK-NOT: spv.ShiftRightArithmetic
+ %0 = load %arg0[] : memref<f32>
+ return
+}
+
+} // end module
+
+// -----
+
+// Check that access chain indices are properly adjusted if non-16/32-bit types
+// are emulated via 32-bit types.
+module attributes {
+ spv.target_env = #spv.target_env<
+ #spv.vce<v1.0, [Int16, StorageBuffer16BitAccess, Shader],
+ [SPV_KHR_storage_buffer_storage_class, SPV_KHR_16bit_storage]>,
+ {max_compute_workgroup_invocations = 128 : i32,
+ max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>
+} {
+
+// CHECK-LABEL: @load_i8
+func @load_i8(%arg0: memref<i8>) {
+ // CHECK: %[[ZERO:.+]] = spv.constant 0 : i32
+ // CHECK: %[[FOUR1:.+]] = spv.constant 4 : i32
+ // CHECK: %[[QUOTIENT:.+]] = spv.SDiv %[[ZERO]], %[[FOUR1]] : i32
+ // CHECK: %[[PTR:.+]] = spv.AccessChain %{{.+}}[%[[ZERO]], %[[QUOTIENT]]]
+ // CHECK: %[[LOAD:.+]] = spv.Load "StorageBuffer" %[[PTR]]
+ // CHECK: %[[FOUR2:.+]] = spv.constant 4 : i32
+ // CHECK: %[[EIGHT:.+]] = spv.constant 8 : i32
+ // CHECK: %[[IDX:.+]] = spv.SMod %[[ZERO]], %[[FOUR2]] : i32
+ // CHECK: %[[BITS:.+]] = spv.IMul %[[IDX]], %[[EIGHT]] : i32
+ // CHECK: %[[VALUE:.+]] = spv.ShiftRightArithmetic %[[LOAD]], %[[BITS]] : i32, i32
+ // CHECK: %[[MASK:.+]] = spv.constant 255 : i32
+ // CHECK: spv.BitwiseAnd %[[VALUE]], %[[MASK]] : i32
+ %0 = load %arg0[] : memref<i8>
+ return
+}
+
+// CHECK-LABEL: @load_i16
+func @load_i16(%arg0: memref<i16>) {
+ // CHECK-NOT: spv.SDiv
+ // CHECK: spv.Load
+ // CHECK-NOT: spv.ShiftRightArithmetic
+ %0 = load %arg0[] : memref<i16>
+ return
+}
+
+} // end module