[Mlir-commits] [mlir] c181f21 - [MLIR] Vectorize tensor.extract on n-D tensor (n >= 2)
Andrzej Warzynski
llvmlistbot at llvm.org
Mon Dec 12 01:34:58 PST 2022
Author: Andrzej Warzynski
Date: 2022-12-12T09:32:16Z
New Revision: c181f21ac71f1d52ec3d11a381634bbcec3844a8
URL: https://github.com/llvm/llvm-project/commit/c181f21ac71f1d52ec3d11a381634bbcec3844a8
DIFF: https://github.com/llvm/llvm-project/commit/c181f21ac71f1d52ec3d11a381634bbcec3844a8.diff
LOG: [MLIR] Vectorize tensor.extract on n-D tensor (n >= 2)
This patch implements the vectorization of tensor.extract for arbitrary
tensors. It extends https://reviews.llvm.org/D133786 by adding
support for n-D tensors (n >= 2). This is implemented by flattening the
n-D indices into a single linear offset that is passed to `vector.gather`.
When benchmarking the vectorized code, we observed that it is slower
than the scalar code. This is most likely due to sub-optimal (and, in
general, slow) gather loads. More work is needed to identify an
implementation and/or a representation that would lead to better code.
In the meantime, the vectorization of n-D tensors (where n >= 2) has to
be explicitly enabled. This can be done either via:
* the Transform dialect's `vectorize_nd_extract` attribute, or
* a dedicated bool argument of the `vectorize` function in
"Vectorization.cpp".
The second option was added to control the new functionality through
means other than the Transform dialect.
Related discussion: https://github.com/iree-org/iree/issues/9198
Differential Revision: https://reviews.llvm.org/D137660
Added:
Modified:
mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
mlir/test/Dialect/Linalg/vectorization.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index f2b3fb795723c..e6bfa0841fa5e 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -1091,6 +1091,7 @@ def VectorizeOp : Op<Transform_Dialect, "structured.vectorize",
let arguments = (ins PDL_Operation:$target,
UnitAttr:$vectorize_padding,
+ UnitAttr:$vectorize_nd_extract,
UnitAttr:$disable_multi_reduction_to_contract_patterns,
UnitAttr:$disable_transfer_permutation_map_lowering_patterns);
let results = (outs PDL_Operation:$transformed);
@@ -1098,7 +1099,9 @@ def VectorizeOp : Op<Transform_Dialect, "structured.vectorize",
let assemblyFormat = "$target attr-dict";
let builders = [
- OpBuilder<(ins "Value":$target, CArg<"bool", "false">:$vectorizePadding)>
+ OpBuilder<(ins "Value":$target,
+ CArg<"bool", "false">:$vectorizePadding,
+ CArg<"bool", "false">:$vectorizeNDExtract)>,
];
let extraClassDeclaration = [{
::mlir::DiagnosedSilenceableFailure applyToOne(
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 81ec026a8a79f..7d6a58431979d 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -345,7 +345,8 @@ FailureOr<LinalgOp> promoteSubViews(OpBuilder &b, LinalgOp op,
const LinalgPromotionOptions &options);
/// Emit a suitable vector form for a Linalg op with fully static shape.
-LogicalResult vectorize(RewriterBase &builder, LinalgOp linalgOp);
+LogicalResult vectorize(RewriterBase &builder, LinalgOp linalgOp,
+ bool vectorizeNDExtract = false);
/// Emit a suitable vector form for a Copy op with fully static shape.
LogicalResult vectorizeCopy(RewriterBase &builder, memref::CopyOp copyOp);
@@ -371,7 +372,8 @@ LogicalResult promoteSubviewsPrecondition(Operation *op,
LinalgPromotionOptions options);
/// Return success if the operation can be vectorized.
-LogicalResult vectorizeLinalgOpPrecondition(LinalgOp linalgOp);
+LogicalResult vectorizeLinalgOpPrecondition(LinalgOp linalgOp,
+ bool vectorizeNDExtract = false);
//===----------------------------------------------------------------------===//
// Transformations exposed as rewrite patterns.
@@ -865,6 +867,9 @@ struct GeneralizePadOpPattern : public OpRewritePattern<tensor::PadOp> {
void populatePadOpVectorizationPatterns(RewritePatternSet &patterns,
PatternBenefit baseBenefit = 1);
+void populateExtractOpVectorizationPatterns(RewritePatternSet &patterns,
+ PatternBenefit baseBenefit = 1);
+
/// Match and rewrite for the pattern:
/// ```
/// %alloc = ...
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 8fdd6cb2b6dee..d35d96ac4310e 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -1781,12 +1781,17 @@ void transform::TileToScfForOp::getEffects(
//===----------------------------------------------------------------------===//
void transform::VectorizeOp::build(OpBuilder &builder, OperationState &result,
- Value target, bool vectorizePadding) {
+ Value target, bool vectorizePadding,
+ bool vectorizeExtract) {
result.addOperands(target);
if (vectorizePadding) {
result.addAttribute(VectorizeOp::getVectorizePaddingAttrName(result.name),
builder.getUnitAttr());
}
+ if (vectorizeExtract) {
+ result.addAttribute(VectorizeOp::getVectorizeNdExtractAttrName(result.name),
+ builder.getUnitAttr());
+ }
result.addTypes(pdl::OperationType::get(builder.getContext()));
}
@@ -1794,15 +1799,22 @@ namespace {
/// This is an helper only to call vectorize via a pattern inside of
/// VectorizeOp::applyToOne.
struct VectorizationPattern : public RewritePattern {
- explicit VectorizationPattern(MLIRContext *context)
- : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {}
+  /// Registers the pattern for any op type with benefit 1. `vectorizeExtract`
+  /// is stored and later forwarded to `vectorize` as its `vectorizeNDExtract`
+  /// argument.
+ explicit VectorizationPattern(MLIRContext *context,
+ bool vectorizeExtract = false)
+ : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context),
+ vectorizeNDExtract(vectorizeExtract) {}
+  /// Fails (with a match-failure note) unless `op` implements `LinalgOp`;
+  /// otherwise returns the result of vectorizing it.
LogicalResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const override {
LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
if (!linalgOp)
return rewriter.notifyMatchFailure(op, "expected Linalg Op");
- return vectorize(rewriter, linalgOp);
+ return vectorize(rewriter, linalgOp, vectorizeNDExtract);
}
+
+private:
+ /// Controls whether to vectorize `tensor.extract` when the input tensor is
+ /// rank >= 2.
+ bool vectorizeNDExtract = false;
};
} // namespace
@@ -1818,7 +1830,7 @@ transform::VectorizeOp::applyToOne(Operation *target,
MLIRContext *ctx = getContext();
RewritePatternSet patterns(ctx);
- patterns.add<VectorizationPattern>(ctx);
+ patterns.add<VectorizationPattern>(ctx, getVectorizeNdExtract());
if (!getDisableTransferPermutationMapLoweringPatterns())
vector::populateVectorTransferPermutationMapLoweringPatterns(patterns);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index ad52a5a28c349..a7c3c0094889f 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -242,7 +242,7 @@ static Value buildVectorWrite(OpBuilder &b, Value value,
// with CustomVectorizationHook. Returns success if the corresponding custom
// hook can vectorize the op.
using CustomVectorizationPrecondition =
- std::function<LogicalResult(Operation *)>;
+ std::function<LogicalResult(Operation *, bool)>;
// Custom vectorization function type. Produce a vector form of Operation*
// assuming all its vectorized operands are already in the BlockAndValueMapping.
@@ -314,13 +314,13 @@ vectorizeLinalgIndex(RewriterBase &rewriter, Operation *op, LinalgOp linalgOp) {
/// Helper function to check if the tensor.extract can be vectorized by the
/// custom hook vectorizeTensorExtract.
-static LogicalResult tensorExtractVectorizationPrecondition(Operation *op) {
+static LogicalResult
+tensorExtractVectorizationPrecondition(Operation *op, bool vectorizeNDExtract) {
tensor::ExtractOp extractOp = dyn_cast<tensor::ExtractOp>(op);
if (!extractOp)
return failure();
- // Currently only supports extraction with an 1-D index.
- if (extractOp.getIndices().size() != 1)
+ if (extractOp.getIndices().size() != 1 && !vectorizeNDExtract)
return failure();
if (!VectorType::isValidElementType(extractOp.getIndices()[0].getType()))
@@ -335,6 +335,51 @@ static LogicalResult tensorExtractVectorizationPrecondition(Operation *op) {
return success();
}
+/// Calculates the offsets (`$index_vec`) for `vector.gather` operations
+/// generated from `tensor.extract`. The n-D extract index is flattened into a
+/// single row-major offset (example using scalar values):
+///
+///   offset = extractOp.indices[0]
+///   for (i = 1; i < numIndices; i++)
+///     offset = extractOp.dimSize[i] * offset + extractOp.indices[i];
+///
+/// For tensor<45 x 80 x 15 x f32> and index [1, 2, 3], this leads to:
+///   offset = ( ( 1 ) * 80 + 2 ) * 15 + 3
+///
+/// Every index is broadcast to `targetShape` before it is combined. For an
+/// index that is already a vector of that shape, the broadcast is a same-type
+/// no-op that folds away. For a scalar index (e.g. a value defined above the
+/// Linalg op), it is required for `arith.muli`/`arith.addi` to type-check.
+/// Dimension sizes that are dynamic in the source tensor type are read with
+/// `tensor.dim` instead of being materialized as bogus constants.
+static Value
+calculateGatherOffset(OpBuilder &b, tensor::ExtractOp extractOp,
+                      const BlockAndValueMapping &bvm,
+                      const SmallVectorImpl<int64_t> &targetShape) {
+  assert(!extractOp.getIndices().empty() &&
+         "expected tensor.extract with at least one index");
+
+  // The vector of indices for GatherOp should be shaped as the output vector.
+  auto indexVecType = VectorType::get(targetShape, b.getIndexType());
+  auto loc = extractOp.getLoc();
+  Value source = extractOp.getTensor();
+  auto sourceType = source.getType().cast<ShapedType>();
+
+  // Broadcasts the vectorized counterpart of the `idx`-th extract index to
+  // `indexVecType`.
+  auto broadcastIndex = [&](size_t idx) -> Value {
+    return b.create<vector::BroadcastOp>(
+        loc, indexVecType, bvm.lookup(extractOp.getIndices()[idx]));
+  };
+
+  Value offset = broadcastIndex(0);
+
+  const size_t numIndices = extractOp.getIndices().size();
+  for (size_t i = 1; i < numIndices; ++i) {
+    // offset = offset * dimSize[i]
+    Value dimSize;
+    int64_t staticDimSize = sourceType.getDimSize(i);
+    if (ShapedType::isDynamic(staticDimSize))
+      dimSize = b.create<tensor::DimOp>(loc, source, static_cast<int64_t>(i));
+    else
+      dimSize = b.create<arith::ConstantIndexOp>(loc, staticDimSize);
+    Value dimSizeBcast =
+        b.create<vector::BroadcastOp>(loc, indexVecType, dimSize);
+    offset = b.create<arith::MulIOp>(loc, offset, dimSizeBcast);
+
+    // offset = indices[i] + offset
+    offset = b.create<arith::AddIOp>(loc, broadcastIndex(i), offset);
+  }
+
+  return offset;
+}
+
/// Helper function to vectorize the tensor.extract operations. Returns
/// VectorizationStatus::NewOp to signal the vectorization algorithm that it
/// should map the produced operations. This function is meant to be used as a
@@ -347,29 +392,29 @@ vectorizeTensorExtract(RewriterBase &rewriter, Operation *op, LinalgOp linalgOp,
return VectorizationResult{VectorizationStatus::Failure, nullptr};
auto loc = extractOp.getLoc();
- // Currently only supports extraction with an 1-D index. Checked in the
- // tensorExtractVectorizationPrecondition.
- assert(extractOp.getIndices().size() == 1);
-
- auto indexVec = bvm.lookup(extractOp.getIndices()[0]);
// Compute the static loop sizes of the extract op.
auto targetShape = linalgOp.computeStaticLoopSizes();
- SmallVector<Value> gatherIndices;
- gatherIndices.push_back(rewriter.create<arith::ConstantIndexOp>(loc, 0));
-
+ auto resultType =
+ VectorType::get(targetShape, extractOp.getResult().getType());
auto maskConstantOp = rewriter.create<arith::ConstantOp>(
loc, DenseIntElementsAttr::get(
VectorType::get(targetShape, rewriter.getI1Type()),
/*value=*/true));
-
- auto resultType =
- VectorType::get(targetShape, extractOp.getResult().getType());
auto passThruConstantOp =
rewriter.create<arith::ConstantOp>(loc, rewriter.getZeroAttr(resultType));
+ // Base indices are currently set to 0. We will need to re-visit if more
+ // generic scenarios are to be supported.
+ SmallVector<Value> baseIndices(
+ extractOp.getIndices().size(),
+ rewriter.create<arith::ConstantIndexOp>(loc, 0));
+
+ Value offset = calculateGatherOffset(rewriter, extractOp, bvm, targetShape);
+
+ // Generate the gather load
auto gatherOp = rewriter.create<vector::GatherOp>(
- loc, resultType, extractOp.getTensor(), gatherIndices, indexVec,
+ loc, resultType, extractOp.getTensor(), baseIndices, offset,
maskConstantOp, passThruConstantOp);
return VectorizationResult{VectorizationStatus::NewOp, gatherOp};
@@ -586,7 +631,7 @@ vectorizeAsLinalgGeneric(RewriterBase &rewriter, LinalgOp linalgOp,
};
hooks.push_back(vectorizeYield);
- // 4rewriter. Register CustomVectorizationHook for indexOp.
+ // 4b. Register CustomVectorizationHook for indexOp.
CustomVectorizationHook vectorizeIndex =
[&](Operation *op,
const BlockAndValueMapping &bvm) -> VectorizationResult {
@@ -642,7 +687,8 @@ static LogicalResult reductionPreconditions(LinalgOp op) {
static LogicalResult vectorizeStaticLinalgOpPrecondition(
linalg::LinalgOp op,
- ArrayRef<CustomVectorizationPrecondition> customPreconditions) {
+ ArrayRef<CustomVectorizationPrecondition> customPreconditions,
+ bool vectorizeNDExtract) {
// All types in the body should be a supported element type for VectorType.
for (Operation &innerOp : op->getRegion(0).front()) {
@@ -650,7 +696,8 @@ static LogicalResult vectorizeStaticLinalgOpPrecondition(
if (llvm::any_of(
customPreconditions,
[&](const CustomVectorizationPrecondition &customPrecondition) {
- return succeeded(customPrecondition(&innerOp));
+ return succeeded(
+ customPrecondition(&innerOp, vectorizeNDExtract));
})) {
continue;
}
@@ -686,7 +733,9 @@ static LogicalResult vectorizeStaticLinalgOpPrecondition(
return success();
}
-LogicalResult mlir::linalg::vectorizeLinalgOpPrecondition(LinalgOp linalgOp) {
+LogicalResult
+mlir::linalg::vectorizeLinalgOpPrecondition(LinalgOp linalgOp,
+ bool vectorizeNDExtract) {
// All types must be static shape to go to vector.
if (linalgOp.hasDynamicShape()) {
LDBG("precondition failed: dynamic shape");
@@ -698,12 +747,13 @@ LogicalResult mlir::linalg::vectorizeLinalgOpPrecondition(LinalgOp linalgOp) {
// Register CustomVectorizationPrecondition for extractOp.
customPreconditions.push_back(tensorExtractVectorizationPrecondition);
- return vectorizeStaticLinalgOpPrecondition(linalgOp, customPreconditions);
+ return vectorizeStaticLinalgOpPrecondition(linalgOp, customPreconditions,
+ vectorizeNDExtract);
}
-LogicalResult mlir::linalg::vectorize(RewriterBase &rewriter,
- LinalgOp linalgOp) {
- if (failed(vectorizeLinalgOpPrecondition(linalgOp)))
+LogicalResult mlir::linalg::vectorize(RewriterBase &rewriter, LinalgOp linalgOp,
+ bool vectorizeNDExtract) {
+ if (failed(vectorizeLinalgOpPrecondition(linalgOp, vectorizeNDExtract)))
return failure();
SmallVector<Value> results;
@@ -713,7 +763,7 @@ LogicalResult mlir::linalg::vectorize(RewriterBase &rewriter,
if (succeeded(convOr)) {
llvm::append_range(results, (*convOr)->getResults());
} else {
- if (failed(vectorizeLinalgOpPrecondition(linalgOp)))
+ if (failed(vectorizeLinalgOpPrecondition(linalgOp, vectorizeNDExtract)))
return failure();
LDBG("Vectorize generic by broadcasting to a common shape: " << linalgOp);
if (failed(vectorizeAsLinalgGeneric(rewriter, linalgOp, results)))
diff --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir
index 3b351dc6b334f..a9a536ae596b8 100644
--- a/mlir/test/Dialect/Linalg/vectorization.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization.mlir
@@ -1500,7 +1500,7 @@ transform.sequence failures(propagate) {
#map0 = affine_map<(d0, d1, d2, d3) -> (d0, d2)>
#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
#map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
-func.func @not_vectorize_nd_tensor_extract(%arg0: tensor<3x3xf32>, %arg1: tensor<4x3xi32>, %arg2: tensor<4x3xi32>, %arg3: tensor<4x7x2xf32>, %arg4: tensor<4x7x3x2xf32>) -> tensor<4x7x3x2xf32> {
+func.func @vectorize_nd_tensor_extract(%arg0: tensor<3x3xf32>, %arg1: tensor<4x3xi32>, %arg2: tensor<4x3xi32>, %arg3: tensor<4x7x2xf32>, %arg4: tensor<4x7x3x2xf32>) -> tensor<4x7x3x2xf32> {
%2 = linalg.generic {
indexing_maps = [#map0, #map0, #map1, #map2],
iterator_types = ["parallel", "parallel", "parallel", "parallel"]
@@ -1513,14 +1513,34 @@ func.func @not_vectorize_nd_tensor_extract(%arg0: tensor<3x3xf32>, %arg1: tensor
} -> tensor<4x7x3x2xf32>
return %2 : tensor<4x7x3x2xf32>
}
-// CHECK-LABEL: func.func @not_vectorize_nd_tensor_extract
-// CHECK: tensor.extract
+// CHECK-LABEL: func.func @vectorize_nd_tensor_extract
+// CHECK-SAME: %[[ARG0:.*]]: tensor<3x3xf32>
+// CHECK-SAME: %[[ARG1:arg1]]: tensor<4x3xi32>
+// CHECK-SAME: %[[ARG2:arg2]]: tensor<4x3xi32>
+// CHECK-SAME: %[[ARG3:.*]]: tensor<4x7x2xf32>
+// CHECK-SAME: %[[ARG4:.*]]: tensor<4x7x3x2xf32>
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[C0_i32:.*]] = arith.constant 0 : i32
+// CHECK: %[[CST:.*]] = arith.constant dense<3> : vector<7x2x4x3xindex>
+// CHECK: %[[CST_1:.*]] = arith.constant dense<true> : vector<4x7x3x2xi1>
+// CHECK: %[[PASSTHRU:.*]] = arith.constant dense<0.000000e+00> : vector<4x7x3x2xf32>
+// CHECK: %[[V0:.*]] = vector.transfer_read %[[ARG1]][%[[C0]], %[[C0]]], %[[C0_i32]] {in_bounds = [true, true]} : tensor<4x3xi32>, vector<4x3xi32>
+// CHECK: %[[V1:.*]] = vector.transfer_read %[[ARG2]][%[[C0]], %[[C0]]], %[[C0_i32]] {in_bounds = [true, true]} : tensor<4x3xi32>, vector<4x3xi32>
+// CHECK: %[[CAST:.*]] = arith.index_cast %[[V0]] : vector<4x3xi32> to vector<4x3xindex>
+// CHECK: %[[B1:.*]] = vector.broadcast %[[CAST]] : vector<4x3xindex> to vector<7x2x4x3xindex>
+// CHECK: %[[CAST_1:.*]] = arith.index_cast %[[V1]] : vector<4x3xi32> to vector<4x3xindex>
+// CHECK: %[[B2:.*]] = vector.broadcast %[[CAST_1]] : vector<4x3xindex> to vector<7x2x4x3xindex>
+// CHECK: %[[MULI:.*]] = arith.muli %[[B1]], %[[CST]] : vector<7x2x4x3xindex>
+// CHECK: %[[ADDI:.*]] = arith.addi %[[B2]], %[[MULI]] : vector<7x2x4x3xindex>
+// CHECK: %[[T:.*]] = vector.transpose %[[ADDI]], [2, 0, 3, 1] : vector<7x2x4x3xindex> to vector<4x7x3x2xindex>
+// CHECK: %[[GATHER:.*]] = vector.gather %[[ARG0]][%[[C0]], %[[C0]]] [%[[T]]], %[[CST_1]], %[[PASSTHRU]] : tensor<3x3xf32>, vector<4x7x3x2xindex>, vector<4x7x3x2xi1>, vector<4x7x3x2xf32> into vector<4x7x3x2xf32>
+// CHECK: vector.transfer_write %[[GATHER]], %[[ARG4]][%[[C0]], %[[C0]], %[[C0]], %[[C0]]] {in_bounds = [true, true, true, true]} : vector<4x7x3x2xf32>, tensor<4x7x3x2xf32>
transform.sequence failures(propagate) {
^bb1(%arg1: !pdl.operation):
%0 = transform.structured.match ops{["linalg.generic"]} in %arg1
%1 = get_closest_isolated_parent %0 : (!pdl.operation) -> !pdl.operation
- %2 = transform.structured.vectorize %1
+ %2 = transform.structured.vectorize %1 { vectorize_nd_extract }
}
// -----
More information about the Mlir-commits
mailing list