[Mlir-commits] [mlir] bb6f5c8 - [mlir][spirv] Convert tensor.extract for very small tensors
Lei Zhang
llvmlistbot at llvm.org
Sat Mar 6 05:07:12 PST 2021
Author: Lei Zhang
Date: 2021-03-06T08:03:36-05:00
New Revision: bb6f5c8314799a6553829b724c649e825a558caf
URL: https://github.com/llvm/llvm-project/commit/bb6f5c8314799a6553829b724c649e825a558caf
DIFF: https://github.com/llvm/llvm-project/commit/bb6f5c8314799a6553829b724c649e825a558caf.diff
LOG: [mlir][spirv] Convert tensor.extract for very small tensors
Normally tensors will be stored in buffers before converting to SPIR-V,
given that is how a large amount of data is sent to the GPU. However,
SPIR-V supports converting from tensors directly too. This is for the
cases where the tensor just contains a small amount of elements and it
makes sense to directly inline them as a small data array in the shader.
To handle this, internally the conversion might create new local
variables. SPIR-V consumers in GPU drivers may or may not optimize that
away. So this has implications over register pressure. Therefore, a
threshold is used to control when the patterns should kick in.
Reviewed By: ThomasRaoux
Differential Revision: https://reviews.llvm.org/D98052
Added:
Modified:
mlir/include/mlir/Conversion/StandardToSPIRV/StandardToSPIRV.h
mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h
mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp
mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRVPass.cpp
mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Conversion/StandardToSPIRV/StandardToSPIRV.h b/mlir/include/mlir/Conversion/StandardToSPIRV/StandardToSPIRV.h
index 6cea1999f368..87946d387f25 100644
--- a/mlir/include/mlir/Conversion/StandardToSPIRV/StandardToSPIRV.h
+++ b/mlir/include/mlir/Conversion/StandardToSPIRV/StandardToSPIRV.h
@@ -25,6 +25,23 @@ void populateStandardToSPIRVPatterns(MLIRContext *context,
SPIRVTypeConverter &typeConverter,
OwningRewritePatternList &patterns);
+/// Appends to a pattern list additional patterns for translating tensor ops
+/// to SPIR-V ops.
+///
+/// `byteCountThreshold` is the largest total tensor size, in bytes, that the
+/// patterns will handle: a tensor whose element count times element bit width
+/// exceeds `byteCountThreshold * 8` bits is rejected by the patterns.
+///
+/// Note: Normally tensors will be stored in buffers before converting to
+/// SPIR-V, given that this is how a large amount of data is sent to the GPU.
+/// However, SPIR-V supports converting from tensors directly too. This is
+/// for the cases where the tensor contains just a small number of elements
+/// and it makes sense to directly inline them as a small data array in the
+/// shader. To handle this, internally the conversion might create new local
+/// variables. SPIR-V consumers in GPU drivers may or may not optimize them
+/// away, so this has implications for register pressure. Therefore, a
+/// threshold is used to control when the patterns should kick in.
+void populateTensorToSPIRVPatterns(MLIRContext *context,
+ SPIRVTypeConverter &typeConverter,
+ int64_t byteCountThreshold,
+ OwningRewritePatternList &patterns);
+
/// Appends to a pattern list patterns to legalize ops that are not directly
/// lowered to SPIR-V.
void populateStdLegalizationPatternsForSPIRVLowering(
diff --git a/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h b/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h
index 4143091543d6..1ac7db13793b 100644
--- a/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h
+++ b/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h
@@ -104,6 +104,11 @@ class SPIRVConversionTarget : public ConversionTarget {
Value getBuiltinVariableValue(Operation *op, BuiltIn builtin,
OpBuilder &builder);
+/// Generates IR to perform index linearization with the given `indices` and
+/// their corresponding `strides`, adding an initial `offset`.
+///
+/// The returned value computes `offset + sum(strides[i] * indices[i])` using
+/// SPIR-V constant, multiply, and add ops created with `builder` at `loc`.
+/// `indices` and `strides` must contain the same number of elements.
+Value linearizeIndex(ValueRange indices, ArrayRef<int64_t> strides,
+ int64_t offset, Location loc, OpBuilder &builder);
+
/// Performs the index computation to get to the element at `indices` of the
/// memory pointed to by `basePtr`, using the layout map of `baseType`.
diff --git a/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp b/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp
index d17622d80920..96557dfa1ac7 100644
--- a/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp
+++ b/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp
@@ -16,6 +16,7 @@
#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
#include "mlir/Dialect/SPIRV/Utils/LayoutUtils.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/Support/LogicalResult.h"
#include "llvm/ADT/SetVector.h"
@@ -512,6 +513,65 @@ class ZeroExtendI1Pattern final : public OpConversionPattern<ZeroExtendIOp> {
}
};
+/// Converts tensor.extract into a load through an access chain on a SPIR-V
+/// Function-storage variable that is initialized with the tensor value.
+///
+/// The pattern only applies when all of the following hold; otherwise it
+/// reports a match failure and creates no IR:
+///   * the tensor has an integer or floating-point element type,
+///   * the tensor has a static shape,
+///   * the tensor's total size does not exceed `byteCountThreshold` bytes,
+///   * the converted tensor operand is defined by a spv.Constant.
+class TensorExtractPattern final
+    : public OpConversionPattern<tensor::ExtractOp> {
+public:
+  /// `threshold` is the maximum tensor size in bytes this pattern accepts.
+  TensorExtractPattern(TypeConverter &typeConverter, MLIRContext *context,
+                       int64_t threshold, PatternBenefit benefit = 1)
+      : OpConversionPattern(typeConverter, context, benefit),
+        byteCountThreshold(threshold) {}
+
+  LogicalResult
+  matchAndRewrite(tensor::ExtractOp extractOp, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    TensorType tensorType = extractOp.tensor().getType().cast<TensorType>();
+
+    // The bit width query below is only meaningful for integer and
+    // floating-point element types, so reject everything else (e.g. index)
+    // up front instead of tripping an assertion.
+    if (!tensorType.getElementType().isIntOrFloat())
+      return rewriter.notifyMatchFailure(extractOp,
+                                         "unsupported element type");
+
+    if (!tensorType.hasStaticShape())
+      return rewriter.notifyMatchFailure(extractOp, "non-static tensor");
+
+    // Compare in bits so that sub-byte element types are not rounded.
+    if (tensorType.getNumElements() * tensorType.getElementTypeBitWidth() >
+        byteCountThreshold * 8)
+      return rewriter.notifyMatchFailure(extractOp,
+                                         "exceeding byte count threshold");
+
+    Location loc = extractOp.getLoc();
+    tensor::ExtractOp::Adaptor adaptor(operands);
+
+    // Only a constant tensor can be used directly as the variable's
+    // initializer. Other producers would need the value stored into the
+    // local variable first; it's questionable whether we want to support
+    // such a case.
+    if (!adaptor.tensor().getDefiningOp<spirv::ConstantOp>())
+      return rewriter.notifyMatchFailure(
+          extractOp, "tensor is not defined by a spv.Constant");
+
+    // The innermost dimension has stride 1; each outer dimension's stride is
+    // the product of the sizes of all dimensions inside it.
+    int64_t rank = tensorType.getRank();
+    SmallVector<int64_t, 4> strides(rank, 1);
+    for (int64_t i = rank - 2; i >= 0; --i)
+      strides[i] = strides[i + 1] * tensorType.getDimSize(i + 1);
+
+    Type varType = spirv::PointerType::get(adaptor.tensor().getType(),
+                                           spirv::StorageClass::Function);
+    auto varOp = rewriter.create<spirv::VariableOp>(
+        loc, varType, spirv::StorageClass::Function,
+        /*initializer=*/adaptor.tensor());
+
+    // Collapse the multi-dimensional indices into a single element index,
+    // then load the element through an access chain on the variable.
+    Value index = spirv::linearizeIndex(adaptor.indices(), strides,
+                                        /*offset=*/0, loc, rewriter);
+    auto acOp = rewriter.create<spirv::AccessChainOp>(loc, varOp, index);
+
+    rewriter.replaceOpWithNewOp<spirv::LoadOp>(extractOp, acOp);
+
+    return success();
+  }
+
+private:
+  /// Maximum size in bytes of a tensor this pattern will convert.
+  int64_t byteCountThreshold;
+};
+
+
/// Converts std.trunci to spv.Select if the type of result is i1 or vector of
/// i1.
class TruncI1Pattern final : public OpConversionPattern<TruncateIOp> {
@@ -622,6 +682,9 @@ LogicalResult SignedRemIOpPattern::matchAndRewrite(
// ConstantOp with composite type.
//===----------------------------------------------------------------------===//
+// TODO: This probably should be split into the vector case and tensor case,
+// so that the tensor case can be moved to TensorToSPIRV conversion. But,
+// std.constant is for the standard dialect though.
LogicalResult ConstantCompositeOpPattern::matchAndRewrite(
ConstantOp constOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const {
@@ -1170,6 +1233,7 @@ void populateStandardToSPIRVPatterns(MLIRContext *context,
UnaryAndBinaryOpPattern<math::SinOp, spirv::GLSLSinOp>,
UnaryAndBinaryOpPattern<math::SqrtOp, spirv::GLSLSqrtOp>,
UnaryAndBinaryOpPattern<math::TanhOp, spirv::GLSLTanhOp>,
+
// Unary and binary patterns
BitwiseOpPattern<AndOp, spirv::LogicalAndOp, spirv::BitwiseAndOp>,
BitwiseOpPattern<OrOp, spirv::LogicalOrOp, spirv::BitwiseOrOp>,
@@ -1224,4 +1288,13 @@ void populateStandardToSPIRVPatterns(MLIRContext *context,
patterns.insert<CmpFOpNanKernelPattern>(typeConverter, context,
/*benefit=*/2);
}
+
+// Adds the tensor-to-SPIR-V conversion patterns to `patterns`. The only
+// pattern inserted here is TensorExtractPattern, which receives
+// `byteCountThreshold` as its tensor size limit.
+void populateTensorToSPIRVPatterns(MLIRContext *context,
+ SPIRVTypeConverter &typeConverter,
+ int64_t byteCountThreshold,
+ OwningRewritePatternList &patterns) {
+ patterns.insert<TensorExtractPattern>(typeConverter, context,
+ byteCountThreshold);
+}
+
} // namespace mlir
diff --git a/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRVPass.cpp b/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRVPass.cpp
index 8ae67cc6a16a..ce8419b40719 100644
--- a/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRVPass.cpp
+++ b/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRVPass.cpp
@@ -37,6 +37,8 @@ void ConvertStandardToSPIRVPass::runOnOperation() {
SPIRVTypeConverter typeConverter(targetAttr);
OwningRewritePatternList patterns;
populateStandardToSPIRVPatterns(context, typeConverter, patterns);
+ populateTensorToSPIRVPatterns(context, typeConverter,
+ /*byteCountThreshold=*/64, patterns);
populateBuiltinFuncToSPIRVPatterns(context, typeConverter, patterns);
if (failed(applyPartialConversion(module, *target, std::move(patterns))))
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
index 1ef5cb5e5f0e..c544512950f0 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
@@ -607,6 +607,31 @@ Value mlir::spirv::getBuiltinVariableValue(Operation *op,
// Index calculation
//===----------------------------------------------------------------------===//
+/// Emits SPIR-V ops that compute `offset + sum(strides[i] * indices[i])` in
+/// the SPIR-V index type and returns the resulting value.
+Value mlir::spirv::linearizeIndex(ValueRange indices, ArrayRef<int64_t> strides,
+                                  int64_t offset, Location loc,
+                                  OpBuilder &builder) {
+  assert(indices.size() == strides.size() &&
+         "must provide indices for all dimensions");
+
+  // TODO: Consider moving to use affine.apply and patterns converting
+  // affine.apply to standard ops. This needs converting to SPIR-V passes to be
+  // broken down into progressive small steps so we can have intermediate steps
+  // using other dialects. At the moment SPIR-V is the final sink.
+
+  auto idxTy = SPIRVTypeConverter::getIndexType(builder.getContext());
+
+  // Materializes `value` as a spv.Constant of the index type.
+  auto makeIndexConstant = [&](int64_t value) -> Value {
+    return builder.create<spirv::ConstantOp>(loc, idxTy,
+                                             IntegerAttr::get(idxTy, value));
+  };
+
+  // Seed the running sum with the offset, then add one stride * index term
+  // per dimension, in order of increasing dimension number.
+  Value result = makeIndexConstant(offset);
+  for (size_t dim = 0, numDims = indices.size(); dim < numDims; ++dim) {
+    Value stride = makeIndexConstant(strides[dim]);
+    Value term = builder.create<spirv::IMulOp>(loc, stride, indices[dim]);
+    result = builder.create<spirv::IAddOp>(loc, result, term);
+  }
+  return result;
+}
+
+
spirv::AccessChainOp mlir::spirv::getElementPtr(
SPIRVTypeConverter &typeConverter, MemRefType baseType, Value basePtr,
ValueRange indices, Location loc, OpBuilder &builder) {
@@ -623,28 +648,16 @@ spirv::AccessChainOp mlir::spirv::getElementPtr(
auto indexType = typeConverter.getIndexType(builder.getContext());
SmallVector<Value, 2> linearizedIndices;
- // Add a '0' at the start to index into the struct.
auto zero = spirv::ConstantOp::getZero(indexType, loc, builder);
+
+ // Add a '0' at the start to index into the struct.
linearizedIndices.push_back(zero);
if (baseType.getRank() == 0) {
linearizedIndices.push_back(zero);
} else {
- // TODO: Instead of this logic, use affine.apply and add patterns for
- // lowering affine.apply to standard ops. These will get lowered to SPIR-V
- // ops by the DialectConversion framework.
- Value ptrLoc = builder.create<spirv::ConstantOp>(
- loc, indexType, IntegerAttr::get(indexType, offset));
- assert(indices.size() == strides.size() &&
- "must provide indices for all dimensions");
- for (auto index : llvm::enumerate(indices)) {
- Value strideVal = builder.create<spirv::ConstantOp>(
- loc, indexType, IntegerAttr::get(indexType, strides[index.index()]));
- Value update =
- builder.create<spirv::IMulOp>(loc, strideVal, index.value());
- ptrLoc = builder.create<spirv::IAddOp>(loc, ptrLoc, update);
- }
- linearizedIndices.push_back(ptrLoc);
+ linearizedIndices.push_back(
+ linearizeIndex(indices, strides, offset, loc, builder));
}
return builder.create<spirv::AccessChainOp>(loc, basePtr, linearizedIndices);
}
diff --git a/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir b/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir
index 7729ec2df080..7a11228f26d7 100644
--- a/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir
+++ b/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir
@@ -1148,3 +1148,32 @@ func @return_multi_val(%arg0: f32) -> (f32, f32) {
}
}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// tensor.extract
+//===----------------------------------------------------------------------===//
+
+// Extracting from a small constant tensor: the flattened constant becomes the
+// initializer of a Function-storage spv.Variable, the three indices are
+// linearized with strides 6, 3, and 1 for the 2x2x3 shape, and the element is
+// loaded through an access chain at the linearized index.
+// CHECK-LABEL: func @tensor_extract_constant
+// CHECK-SAME: (%[[A:.+]]: i32, %[[B:.+]]: i32, %[[C:.+]]: i32)
+func @tensor_extract_constant(%a : index, %b: index, %c: index) -> i32 {
+ // CHECK: %[[CST:.+]] = spv.Constant dense<[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]>
+ %cst = constant dense<[[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]]> : tensor<2x2x3xi32>
+ // CHECK: %[[VAR:.+]] = spv.Variable init(%[[CST]]) : !spv.ptr<!spv.array<12 x i32, stride=4>, Function>
+ // CHECK: %[[C0:.+]] = spv.Constant 0 : i32
+ // CHECK: %[[C6:.+]] = spv.Constant 6 : i32
+ // CHECK: %[[MUL0:.+]] = spv.IMul %[[C6]], %[[A]] : i32
+ // CHECK: %[[ADD0:.+]] = spv.IAdd %[[C0]], %[[MUL0]] : i32
+ // CHECK: %[[C3:.+]] = spv.Constant 3 : i32
+ // CHECK: %[[MUL1:.+]] = spv.IMul %[[C3]], %[[B]] : i32
+ // CHECK: %[[ADD1:.+]] = spv.IAdd %[[ADD0]], %[[MUL1]] : i32
+ // CHECK: %[[C1:.+]] = spv.Constant 1 : i32
+ // CHECK: %[[MUL2:.+]] = spv.IMul %[[C1]], %[[C]] : i32
+ // CHECK: %[[ADD2:.+]] = spv.IAdd %[[ADD1]], %[[MUL2]] : i32
+ // CHECK: %[[AC:.+]] = spv.AccessChain %[[VAR]][%[[ADD2]]]
+ // CHECK: %[[VAL:.+]] = spv.Load "Function" %[[AC]] : i32
+ %extract = tensor.extract %cst[%a, %b, %c] : tensor<2x2x3xi32>
+ // CHECK: spv.ReturnValue %[[VAL]]
+ return %extract : i32
+}
More information about the Mlir-commits
mailing list