[Mlir-commits] [mlir] bb6f5c8 - [mlir][spirv] Convert tensor.extract for very small tensors

Lei Zhang llvmlistbot at llvm.org
Sat Mar 6 05:07:12 PST 2021


Author: Lei Zhang
Date: 2021-03-06T08:03:36-05:00
New Revision: bb6f5c8314799a6553829b724c649e825a558caf

URL: https://github.com/llvm/llvm-project/commit/bb6f5c8314799a6553829b724c649e825a558caf
DIFF: https://github.com/llvm/llvm-project/commit/bb6f5c8314799a6553829b724c649e825a558caf.diff

LOG: [mlir][spirv] Convert tensor.extract for very small tensors

Normally tensors are stored in buffers before being converted to SPIR-V,
given that this is how large amounts of data are sent to the GPU. However,
SPIR-V also supports converting from tensors directly. This is for the
cases where a tensor contains only a small number of elements and it
makes sense to inline them directly as a small data array in the shader.
To handle this, the conversion might internally create new local
variables. SPIR-V consumers in GPU drivers may or may not optimize them
away, so this has implications for register pressure. Therefore, a
threshold is used to control when the patterns kick in.

Reviewed By: ThomasRaoux

Differential Revision: https://reviews.llvm.org/D98052

Added: 
    

Modified: 
    mlir/include/mlir/Conversion/StandardToSPIRV/StandardToSPIRV.h
    mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h
    mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp
    mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRVPass.cpp
    mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
    mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Conversion/StandardToSPIRV/StandardToSPIRV.h b/mlir/include/mlir/Conversion/StandardToSPIRV/StandardToSPIRV.h
index 6cea1999f368..87946d387f25 100644
--- a/mlir/include/mlir/Conversion/StandardToSPIRV/StandardToSPIRV.h
+++ b/mlir/include/mlir/Conversion/StandardToSPIRV/StandardToSPIRV.h
@@ -25,6 +25,23 @@ void populateStandardToSPIRVPatterns(MLIRContext *context,
                                      SPIRVTypeConverter &typeConverter,
                                      OwningRewritePatternList &patterns);
 
+/// Appends to a pattern list additional patterns for translating tensor ops
+/// to SPIR-V ops.
+///
+/// Note: Normally tensors are stored in buffers before being converted to
+/// SPIR-V, given that this is how large amounts of data are sent to the GPU.
+/// However, SPIR-V also supports converting from tensors directly. This is
+/// for the cases where a tensor contains only a small number of elements
+/// and it makes sense to inline them directly as a small data array in the
+/// shader. To handle this, the conversion might internally create new local
+/// variables. SPIR-V consumers in GPU drivers may or may not optimize them
+/// away, so this has implications for register pressure. Therefore,
+/// `byteCountThreshold` controls when the patterns kick in: a tensor whose
+/// total size in bytes (element count times element byte width) exceeds it
+/// is not converted by these patterns.
+void populateTensorToSPIRVPatterns(MLIRContext *context,
+                                   SPIRVTypeConverter &typeConverter,
+                                   int64_t byteCountThreshold,
+                                   OwningRewritePatternList &patterns);
+
 /// Appends to a pattern list patterns to legalize ops that are not directly
 /// lowered to SPIR-V.
 void populateStdLegalizationPatternsForSPIRVLowering(

diff  --git a/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h b/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h
index 4143091543d6..1ac7db13793b 100644
--- a/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h
+++ b/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h
@@ -104,6 +104,11 @@ class SPIRVConversionTarget : public ConversionTarget {
 Value getBuiltinVariableValue(Operation *op, BuiltIn builtin,
                               OpBuilder &builder);
 
+/// Generates IR to perform index linearization with the given `indices` and
+/// their corresponding `strides`, adding an initial `offset`.
+///
+/// The returned value has the SPIR-V index type and holds
+/// `offset + sum_i(strides[i] * indices[i])`. It is built from spv.Constant,
+/// spv.IMul and spv.IAdd ops created at `builder`'s insertion point with
+/// location `loc`. `indices` and `strides` must have the same number of
+/// elements.
+Value linearizeIndex(ValueRange indices, ArrayRef<int64_t> strides,
+                     int64_t offset, Location loc, OpBuilder &builder);
+
 /// Performs the index computation to get to the element at `indices` of the
 /// memory pointed to by `basePtr`, using the layout map of `baseType`.
 

diff  --git a/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp b/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp
index d17622d80920..96557dfa1ac7 100644
--- a/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp
+++ b/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp
@@ -16,6 +16,7 @@
 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
 #include "mlir/Dialect/SPIRV/Utils/LayoutUtils.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/IR/AffineMap.h"
 #include "mlir/Support/LogicalResult.h"
 #include "llvm/ADT/SetVector.h"
@@ -512,6 +513,65 @@ class ZeroExtendI1Pattern final : public OpConversionPattern<ZeroExtendIOp> {
   }
 };
 
+/// Converts tensor.extract into loading using access chains from SPIR-V local
+/// variables.
+///
+/// Only statically shaped tensors with integer/float elements whose total
+/// size does not exceed `byteCountThreshold` bytes are handled. Currently the
+/// converted source tensor must also be produced by a spv.Constant, which is
+/// used as the initializer of a Function-storage spv.Variable. All other
+/// cases are reported as match failures.
+///
+/// NOTE(review): the spv.Variable is created at the extract op's position.
+/// SPIR-V requires Function-storage OpVariables to be the first instructions
+/// of a function's first block; confirm this holds (or hoist the variable)
+/// when the extract is not in the entry block.
+class TensorExtractPattern final
+    : public OpConversionPattern<tensor::ExtractOp> {
+public:
+  TensorExtractPattern(TypeConverter &typeConverter, MLIRContext *context,
+                       int64_t threshold, PatternBenefit benefit = 1)
+      : OpConversionPattern(typeConverter, context, benefit),
+        byteCountThreshold(threshold) {}
+
+  LogicalResult
+  matchAndRewrite(tensor::ExtractOp extractOp, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    TensorType tensorType = extractOp.tensor().getType().cast<TensorType>();
+
+    if (!tensorType.hasStaticShape())
+      return rewriter.notifyMatchFailure(extractOp, "non-static tensor");
+
+    // getElementTypeBitWidth() is only valid for integer/float element types
+    // (it asserts otherwise, e.g. for `index`), so bail out gracefully before
+    // computing the byte size.
+    if (!tensorType.getElementType().isIntOrFloat())
+      return rewriter.notifyMatchFailure(extractOp,
+                                         "unsupported element type");
+
+    if (tensorType.getNumElements() * tensorType.getElementTypeBitWidth() >
+        byteCountThreshold * 8)
+      return rewriter.notifyMatchFailure(extractOp,
+                                         "exceeding byte count threshold");
+
+    tensor::ExtractOp::Adaptor adaptor(operands);
+
+    // Only a spv.Constant can serve directly as the variable initializer.
+    // Other values would need to be stored into the local variable first;
+    // it's questionable whether we want to support such case though.
+    if (!adaptor.tensor().getDefiningOp<spirv::ConstantOp>())
+      return rewriter.notifyMatchFailure(extractOp,
+                                         "non-constant source tensor");
+
+    Location loc = extractOp.getLoc();
+
+    // Row-major strides: the innermost dimension has stride 1.
+    int64_t rank = tensorType.getRank();
+    SmallVector<int64_t, 4> strides(rank, 1);
+    for (int64_t i = rank - 2; i >= 0; --i)
+      strides[i] = strides[i + 1] * tensorType.getDimSize(i + 1);
+
+    Type varType = spirv::PointerType::get(adaptor.tensor().getType(),
+                                           spirv::StorageClass::Function);
+    auto varOp = rewriter.create<spirv::VariableOp>(
+        loc, varType, spirv::StorageClass::Function,
+        /*initializer=*/adaptor.tensor());
+
+    Value index = spirv::linearizeIndex(adaptor.indices(), strides,
+                                        /*offset=*/0, loc, rewriter);
+    auto acOp = rewriter.create<spirv::AccessChainOp>(loc, varOp, index);
+
+    rewriter.replaceOpWithNewOp<spirv::LoadOp>(extractOp, acOp);
+
+    return success();
+  }
+
+private:
+  /// Maximum total tensor size, in bytes, that this pattern will convert.
+  int64_t byteCountThreshold;
+};
+
 /// Converts std.trunci to spv.Select if the type of result is i1 or vector of
 /// i1.
 class TruncI1Pattern final : public OpConversionPattern<TruncateIOp> {
@@ -622,6 +682,9 @@ LogicalResult SignedRemIOpPattern::matchAndRewrite(
 // ConstantOp with composite type.
 //===----------------------------------------------------------------------===//
 
+// TODO: This should probably be split into the vector case and the tensor
+// case, so that the tensor case can be moved to the TensorToSPIRV conversion.
+// However, std.constant is an op of the standard dialect.
 LogicalResult ConstantCompositeOpPattern::matchAndRewrite(
     ConstantOp constOp, ArrayRef<Value> operands,
     ConversionPatternRewriter &rewriter) const {
@@ -1170,6 +1233,7 @@ void populateStandardToSPIRVPatterns(MLIRContext *context,
       UnaryAndBinaryOpPattern<math::SinOp, spirv::GLSLSinOp>,
       UnaryAndBinaryOpPattern<math::SqrtOp, spirv::GLSLSqrtOp>,
       UnaryAndBinaryOpPattern<math::TanhOp, spirv::GLSLTanhOp>,
+
       // Unary and binary patterns
       BitwiseOpPattern<AndOp, spirv::LogicalAndOp, spirv::BitwiseAndOp>,
       BitwiseOpPattern<OrOp, spirv::LogicalOrOp, spirv::BitwiseOrOp>,
@@ -1224,4 +1288,13 @@ void populateStandardToSPIRVPatterns(MLIRContext *context,
   patterns.insert<CmpFOpNanKernelPattern>(typeConverter, context,
                                           /*benefit=*/2);
 }
+
+/// Registers the tensor-to-SPIR-V conversion patterns (currently only the
+/// tensor.extract lowering) into `patterns`, forwarding `byteCountThreshold`
+/// as the size limit for tensors the patterns will convert.
+void populateTensorToSPIRVPatterns(MLIRContext *context,
+                                   SPIRVTypeConverter &typeConverter,
+                                   int64_t byteCountThreshold,
+                                   OwningRewritePatternList &patterns) {
+  patterns.insert<TensorExtractPattern>(
+      typeConverter, context, /*threshold=*/byteCountThreshold);
+}
+
 } // namespace mlir

diff  --git a/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRVPass.cpp b/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRVPass.cpp
index 8ae67cc6a16a..ce8419b40719 100644
--- a/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRVPass.cpp
+++ b/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRVPass.cpp
@@ -37,6 +37,8 @@ void ConvertStandardToSPIRVPass::runOnOperation() {
   SPIRVTypeConverter typeConverter(targetAttr);
   OwningRewritePatternList patterns;
   populateStandardToSPIRVPatterns(context, typeConverter, patterns);
+  populateTensorToSPIRVPatterns(context, typeConverter,
+                                /*byteCountThreshold=*/64, patterns);
   populateBuiltinFuncToSPIRVPatterns(context, typeConverter, patterns);
 
   if (failed(applyPartialConversion(module, *target, std::move(patterns))))

diff  --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
index 1ef5cb5e5f0e..c544512950f0 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
@@ -607,6 +607,31 @@ Value mlir::spirv::getBuiltinVariableValue(Operation *op,
 // Index calculation
 //===----------------------------------------------------------------------===//
 
+/// Emits `offset + sum_i(strides[i] * indices[i])` using SPIR-V index-typed
+/// ops. First a spv.Constant holding `offset` is created. Then, for each
+/// dimension in order, the function creates a spv.Constant for the stride, a
+/// spv.IMul of stride and index, and a spv.IAdd into the running sum.
+Value mlir::spirv::linearizeIndex(ValueRange indices, ArrayRef<int64_t> strides,
+                                  int64_t offset, Location loc,
+                                  OpBuilder &builder) {
+  assert(indices.size() == strides.size() &&
+         "must provide indices for all dimensions");
+
+  Type idxType = SPIRVTypeConverter::getIndexType(builder.getContext());
+  auto createIndexConstant = [&](int64_t value) -> Value {
+    return builder.create<spirv::ConstantOp>(loc, idxType,
+                                             IntegerAttr::get(idxType, value));
+  };
+
+  // TODO: Consider moving to use affine.apply and patterns converting
+  // affine.apply to standard ops. This needs converting to SPIR-V passes to be
+  // broken down into progressive small steps so we can have intermediate steps
+  // using other dialects. At the moment SPIR-V is the final sink.
+
+  Value sum = createIndexConstant(offset);
+  for (unsigned dim = 0, numDims = indices.size(); dim < numDims; ++dim) {
+    Value stride = createIndexConstant(strides[dim]);
+    Value scaled = builder.create<spirv::IMulOp>(loc, stride, indices[dim]);
+    sum = builder.create<spirv::IAddOp>(loc, sum, scaled);
+  }
+  return sum;
+}
+
 spirv::AccessChainOp mlir::spirv::getElementPtr(
     SPIRVTypeConverter &typeConverter, MemRefType baseType, Value basePtr,
     ValueRange indices, Location loc, OpBuilder &builder) {
@@ -623,28 +648,16 @@ spirv::AccessChainOp mlir::spirv::getElementPtr(
   auto indexType = typeConverter.getIndexType(builder.getContext());
 
   SmallVector<Value, 2> linearizedIndices;
-  // Add a '0' at the start to index into the struct.
   auto zero = spirv::ConstantOp::getZero(indexType, loc, builder);
+
+  // Add a '0' at the start to index into the struct.
   linearizedIndices.push_back(zero);
 
   if (baseType.getRank() == 0) {
     linearizedIndices.push_back(zero);
   } else {
-    // TODO: Instead of this logic, use affine.apply and add patterns for
-    // lowering affine.apply to standard ops. These will get lowered to SPIR-V
-    // ops by the DialectConversion framework.
-    Value ptrLoc = builder.create<spirv::ConstantOp>(
-        loc, indexType, IntegerAttr::get(indexType, offset));
-    assert(indices.size() == strides.size() &&
-           "must provide indices for all dimensions");
-    for (auto index : llvm::enumerate(indices)) {
-      Value strideVal = builder.create<spirv::ConstantOp>(
-          loc, indexType, IntegerAttr::get(indexType, strides[index.index()]));
-      Value update =
-          builder.create<spirv::IMulOp>(loc, strideVal, index.value());
-      ptrLoc = builder.create<spirv::IAddOp>(loc, ptrLoc, update);
-    }
-    linearizedIndices.push_back(ptrLoc);
+    linearizedIndices.push_back(
+        linearizeIndex(indices, strides, offset, loc, builder));
   }
   return builder.create<spirv::AccessChainOp>(loc, basePtr, linearizedIndices);
 }

diff  --git a/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir b/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir
index 7729ec2df080..7a11228f26d7 100644
--- a/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir
+++ b/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir
@@ -1148,3 +1148,32 @@ func @return_multi_val(%arg0: f32) -> (f32, f32) {
 }
 
 }
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// tensor.extract
+//===----------------------------------------------------------------------===//
+
+// Verifies that tensor.extract from a small constant tensor<2x2x3xi32> is
+// lowered to a Function-storage spv.Variable initialized with the flattened
+// constant. The indices are linearized with row-major strides (6, 3, 1) and
+// the element is read through spv.AccessChain and spv.Load. Index arguments
+// are converted to i32.
+// CHECK-LABEL: func @tensor_extract_constant
+// CHECK-SAME: (%[[A:.+]]: i32, %[[B:.+]]: i32, %[[C:.+]]: i32)
+func @tensor_extract_constant(%a : index, %b: index, %c: index) -> i32 {
+  // CHECK: %[[CST:.+]] = spv.Constant dense<[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]>
+  %cst = constant dense<[[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]]> : tensor<2x2x3xi32>
+  // CHECK: %[[VAR:.+]] = spv.Variable init(%[[CST]]) : !spv.ptr<!spv.array<12 x i32, stride=4>, Function>
+  // CHECK: %[[C0:.+]] = spv.Constant 0 : i32
+  // CHECK: %[[C6:.+]] = spv.Constant 6 : i32
+  // CHECK: %[[MUL0:.+]] = spv.IMul %[[C6]], %[[A]] : i32
+  // CHECK: %[[ADD0:.+]] = spv.IAdd %[[C0]], %[[MUL0]] : i32
+  // CHECK: %[[C3:.+]] = spv.Constant 3 : i32
+  // CHECK: %[[MUL1:.+]] = spv.IMul %[[C3]], %[[B]] : i32
+  // CHECK: %[[ADD1:.+]] = spv.IAdd %[[ADD0]], %[[MUL1]] : i32
+  // CHECK: %[[C1:.+]] = spv.Constant 1 : i32
+  // CHECK: %[[MUL2:.+]] = spv.IMul %[[C1]], %[[C]] : i32
+  // CHECK: %[[ADD2:.+]] = spv.IAdd %[[ADD1]], %[[MUL2]] : i32
+  // CHECK: %[[AC:.+]] = spv.AccessChain %[[VAR]][%[[ADD2]]]
+  // CHECK: %[[VAL:.+]] = spv.Load "Function" %[[AC]] : i32
+  %extract = tensor.extract %cst[%a, %b, %c] : tensor<2x2x3xi32>
+  // CHECK: spv.ReturnValue %[[VAL]]
+  return %extract : i32
+}


        


More information about the Mlir-commits mailing list