[Mlir-commits] [mlir] d09bef8 - [MLIR] Vectorize tensor.extract on 1-d tensor
Diego Caballero
llvmlistbot at llvm.org
Mon Oct 17 17:06:37 PDT 2022
Author: Che-Yu Wu
Date: 2022-10-18T00:06:02Z
New Revision: d09bef82c0709c5755ce33da20532481b7da2245
URL: https://github.com/llvm/llvm-project/commit/d09bef82c0709c5755ce33da20532481b7da2245
DIFF: https://github.com/llvm/llvm-project/commit/d09bef82c0709c5755ce33da20532481b7da2245.diff
LOG: [MLIR] Vectorize tensor.extract on 1-d tensor
This patch implements the vectorization of tensor.extract for the
basic 1-D lookup case. It only vectorizes the tensor.extract to a
vector.gather when the op extracts a value from a 1-D tensor.
Related discussion: https://github.com/iree-org/iree/issues/9198
Reviewed By: dcaballe
Differential Revision: https://reviews.llvm.org/D133786
Added:
Modified:
mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
mlir/lib/Dialect/Linalg/Utils/Utils.cpp
mlir/test/Dialect/Linalg/vectorization.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 2b70155b24887..a435c10833676 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -232,6 +232,12 @@ static Value buildVectorWrite(OpBuilder &b, Value value,
return Value();
}
+// Custom vectorization precondition function type. This is intended to be used
+// with CustomVectorizationHook. Returns success if the corresponding custom
+// hook can vectorize the op.
+using CustomVectorizationPrecondition =
+ std::function<LogicalResult(Operation *)>;
+
// Custom vectorization function type. Produce a vector form of Operation*
// assuming all its vectorized operands are already in the BlockAndValueMapping.
// Return nullptr if the Operation cannot be vectorized.
@@ -300,6 +306,69 @@ static VectorizationResult vectorizeLinalgIndex(OpBuilder &b, Operation *op,
return VectorizationResult{VectorizationStatus::NewOp, transposeOp};
}
+/// Precondition for the custom hook vectorizeTensorExtract. Succeeds only for
+/// a tensor.extract with one index and vector-compatible index/result types.
+static LogicalResult tensorExtractVectorizationPrecondition(Operation *op) {
+ tensor::ExtractOp extractOp = dyn_cast<tensor::ExtractOp>(op);
+ if (!extractOp)
+ return failure();
+
+ // Currently only supports extraction with a single (1-D) index.
+ if (extractOp.getIndices().size() != 1)
+ return failure();
+ // The index and the extracted value must be valid vector element types.
+ if (!VectorType::isValidElementType(extractOp.getIndices()[0].getType()))
+ return failure();
+
+ if (llvm::any_of(extractOp->getResultTypes(), [](Type type) {
+ return !VectorType::isValidElementType(type);
+ })) {
+ return failure();
+ }
+ // NOTE(review): whether the index is defined in the body is not checked.
+ return success();
+}
+
+/// Vectorizes a 1-D tensor.extract as a vector.gather whose shape is the
+/// static loop sizes of `linalgOp`, using an all-true mask and a zero
+/// pass-through. Returns NewOp so the caller maps the gather, or Failure when
+/// `op` is not a tensor.extract. Intended for use as a CustomVectorizationHook.
+static VectorizationResult
+vectorizeTensorExtract(OpBuilder &b, Operation *op, LinalgOp linalgOp,
+ const BlockAndValueMapping &bvm) {
+ tensor::ExtractOp extractOp = dyn_cast<tensor::ExtractOp>(op);
+ if (!extractOp)
+ return VectorizationResult{VectorizationStatus::Failure, nullptr};
+ auto loc = extractOp.getLoc();
+
+ // Currently only supports extraction with a 1-D index. Checked in the
+ // tensorExtractVectorizationPrecondition.
+ assert(extractOp.getIndices().size() == 1);
+ // NOTE(review): indexVec is used as-is; confirm it can never be a scalar.
+ auto indexVec = bvm.lookup(extractOp.getIndices()[0]);
+ // The gather covers the static loop sizes of the enclosing linalgOp.
+ auto targetShape = linalgOp.computeStaticLoopSizes();
+
+ SmallVector<Value> gatherIndices;
+ gatherIndices.push_back(b.create<arith::ConstantIndexOp>(loc, 0));
+ // Mask: a constant all-true i1 vector of the target shape.
+ auto maskConstantOp = b.create<arith::ConstantOp>(
+ loc,
+ DenseIntElementsAttr::get(VectorType::get(targetShape, b.getI1Type()),
+ /*value=*/true));
+ // Pass-through: a zero vector of the target shape and extracted type.
+ auto resultType =
+ VectorType::get(targetShape, extractOp.getResult().getType());
+ auto passThruConstantOp =
+ b.create<arith::ConstantOp>(loc, b.getZeroAttr(resultType));
+
+ auto gatherOp = b.create<vector::GatherOp>(
+ loc, resultType, extractOp.getTensor(), gatherIndices, indexVec,
+ maskConstantOp, passThruConstantOp);
+
+ return VectorizationResult{VectorizationStatus::NewOp, gatherOp};
+}
+
/// Emit reduction operations if the shapes of the value to reduce is different
/// that the result shape.
static Operation *reduceIfNeeded(OpBuilder &b, LinalgOp linalgOp, Operation *op,
@@ -515,6 +584,14 @@ vectorizeAsLinalgGeneric(OpBuilder &b, LinalgOp linalgOp,
};
hooks.push_back(vectorizeIndex);
+ // 4c. Register CustomVectorizationHook for extractOp.
+ CustomVectorizationHook vectorizeExtract =
+ [&](Operation *op,
+ const BlockAndValueMapping &bvm) -> VectorizationResult {
+ return vectorizeTensorExtract(b, op, linalgOp, bvm);
+ };
+ hooks.push_back(vectorizeExtract);
+
// 5. Iteratively call `vectorizeOneOp` to each op in the slice.
for (Operation &op : block->getOperations()) {
VectorizationResult result = vectorizeOneOp(b, linalgOp, &op, bvm, hooks);
@@ -552,9 +629,20 @@ static LogicalResult reductionPreconditions(LinalgOp op) {
return success();
}
-static LogicalResult vectorizeStaticLinalgOpPrecondition(linalg::LinalgOp op) {
+static LogicalResult vectorizeStaticLinalgOpPrecondition(
+ linalg::LinalgOp op,
+ ArrayRef<CustomVectorizationPrecondition> customPreconditions) {
+
// All types in the body should be a supported element type for VectorType.
for (Operation &innerOp : op->getRegion(0).front()) {
+ // Check if any custom hook can vectorize the inner op.
+ if (llvm::any_of(
+ customPreconditions,
+ [&](const CustomVectorizationPrecondition &customPrecondition) {
+ return succeeded(customPrecondition(&innerOp));
+ })) {
+ continue;
+ }
if (llvm::any_of(innerOp.getOperandTypes(), [](Type type) {
return !VectorType::isValidElementType(type);
})) {
@@ -566,16 +654,8 @@ static LogicalResult vectorizeStaticLinalgOpPrecondition(linalg::LinalgOp op) {
return failure();
}
}
- if (isElementwise(op)) {
- // Some operations in the body cannot be vectorized.
- for (Operation &payloadOp : *op.getBlock()) {
- if (isa<tensor::ExtractOp>(payloadOp)) {
- LDBG("precondition failed: `tensor.extract` not vectorizable");
- return failure();
- }
- }
+ if (isElementwise(op))
return success();
- }
// TODO: isaConvolutionOpInterface that can also infer from generic features.
// But we will still need stride/dilation attributes that will be annoying to
// reverse-engineer...
@@ -601,7 +681,13 @@ LogicalResult mlir::linalg::vectorizeLinalgOpPrecondition(LinalgOp linalgOp) {
LDBG("precondition failed: dynamic shape");
return failure();
}
- return vectorizeStaticLinalgOpPrecondition(linalgOp);
+
+ SmallVector<CustomVectorizationPrecondition> customPreconditions;
+
+ // Register CustomVectorizationPrecondition for extractOp.
+ customPreconditions.push_back(tensorExtractVectorizationPrecondition);
+
+ return vectorizeStaticLinalgOpPrecondition(linalgOp, customPreconditions);
}
LogicalResult mlir::linalg::vectorize(RewriterBase &rewriter,
diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
index 119c3db644c55..aba2d5f5cd49f 100644
--- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -161,8 +161,8 @@ bool hasOnlyScalarElementwiseOp(Region &r) {
if (!llvm::hasSingleElement(r))
return false;
for (Operation &op : r.front()) {
- if (!(isa<arith::ConstantOp, func::ConstantOp, linalg::YieldOp,
- linalg::IndexOp>(op) ||
+ if (!(isa<arith::ConstantOp, func::ConstantOp, tensor::ExtractOp,
+ linalg::YieldOp, linalg::IndexOp>(op) ||
OpTrait::hasElementwiseMappableTraits(&op)) ||
llvm::any_of(op.getResultTypes(),
[](Type type) { return !type.isIntOrIndexOrFloat(); }))
diff --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir
index f3735be9a3a5d..8a83c3137e6a0 100644
--- a/mlir/test/Dialect/Linalg/vectorization.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization.mlir
@@ -1457,3 +1457,68 @@ transform.sequence failures(propagate) {
%1 = get_closest_isolated_parent %0 : (!pdl.operation) -> !pdl.operation
%2 = transform.structured.vectorize %1 { disable_multi_reduction_to_contract_patterns, disable_transfer_permutation_map_lowering_patterns }
}
+
+// -----
+// A 1-D tensor.extract with a data-dependent index becomes a vector.gather.
+#map0 = affine_map<(d0, d1, d2, d3) -> (d0, d2)>
+#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+#map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+func.func @vectorize_1d_tensor_extract(%arg0: tensor<3xf32>, %arg1: tensor<4x3xi32>, %arg2: tensor<4x7x2xf32>, %arg3: tensor<4x7x3x2xf32>) -> tensor<4x7x3x2xf32> {
+ %2 = linalg.generic {
+ indexing_maps = [#map0, #map1, #map2],
+ iterator_types = ["parallel", "parallel", "parallel", "parallel"]
+ } ins(%arg1, %arg2 : tensor<4x3xi32>, tensor<4x7x2xf32>) outs(%arg3 : tensor<4x7x3x2xf32>) {
+ ^bb0(%arg4: i32, %arg5: f32, %arg6: f32):
+ %3 = arith.index_cast %arg4 : i32 to index
+ %7 = tensor.extract %arg0[%3] : tensor<3xf32>
+ linalg.yield %7 : f32
+ } -> tensor<4x7x3x2xf32>
+ return %2 : tensor<4x7x3x2xf32>
+}
+// CHECK-LABEL: func.func @vectorize_1d_tensor_extract
+// CHECK-SAME: %[[ARG0:.*]]: tensor<3xf32>
+// CHECK-SAME: %[[ARG1:.*]]: tensor<4x3xi32>
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[MASK:.*]] = arith.constant dense<true> : vector<4x7x3x2xi1>
+// CHECK: %[[PASSTHRU:.*]] = arith.constant dense<0.000000e+00> : vector<4x7x3x2xf32>
+// CHECK: %[[V0:.*]] = vector.transfer_read %[[ARG1]]
+// CHECK: %[[CAST:.*]] = arith.index_cast %[[V0]]
+// CHECK: %[[BROADCAST:.*]] = vector.broadcast %[[CAST]]
+// CHECK: %[[INDICES:.*]] = vector.transpose %[[BROADCAST]]
+// CHECK: %[[GATHER:.*]] = vector.gather %[[ARG0]][%[[C0]]] [%[[INDICES]]], %[[MASK]], %[[PASSTHRU]]
+// CHECK: vector.transfer_write %[[GATHER]]
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+ %0 = transform.structured.match ops{["linalg.generic"]} in %arg1
+ %1 = get_closest_isolated_parent %0 : (!pdl.operation) -> !pdl.operation
+ %2 = transform.structured.vectorize %1
+}
+
+// -----
+// tensor.extract with two indices is not vectorized; it stays in the output.
+#map0 = affine_map<(d0, d1, d2, d3) -> (d0, d2)>
+#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+#map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+func.func @not_vectorize_nd_tensor_extract(%arg0: tensor<3x3xf32>, %arg1: tensor<4x3xi32>, %arg2: tensor<4x3xi32>, %arg3: tensor<4x7x2xf32>, %arg4: tensor<4x7x3x2xf32>) -> tensor<4x7x3x2xf32> {
+ %2 = linalg.generic {
+ indexing_maps = [#map0, #map0, #map1, #map2],
+ iterator_types = ["parallel", "parallel", "parallel", "parallel"]
+ } ins(%arg1, %arg2, %arg3 : tensor<4x3xi32>, tensor<4x3xi32>, tensor<4x7x2xf32>) outs(%arg4 : tensor<4x7x3x2xf32>) {
+ ^bb0(%arg5: i32, %arg6: i32, %arg7: f32, %arg8: f32):
+ %3 = arith.index_cast %arg5 : i32 to index
+ %4 = arith.index_cast %arg6 : i32 to index
+ %7 = tensor.extract %arg0[%3, %4] : tensor<3x3xf32>
+ linalg.yield %7 : f32
+ } -> tensor<4x7x3x2xf32>
+ return %2 : tensor<4x7x3x2xf32>
+}
+// CHECK-LABEL: func.func @not_vectorize_nd_tensor_extract
+// CHECK: tensor.extract
+// NOTE(review): a negative match on vector.gather would make this stricter.
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+ %0 = transform.structured.match ops{["linalg.generic"]} in %arg1
+ %1 = get_closest_isolated_parent %0 : (!pdl.operation) -> !pdl.operation
+ %2 = transform.structured.vectorize %1
+}
More information about the Mlir-commits
mailing list