[Mlir-commits] [mlir] 8c5ad0a - [mlir][Vector] Add a masked vectorization of tensor.pad
Nicolas Vasilache
llvmlistbot at llvm.org
Thu Apr 13 13:20:36 PDT 2023
Author: Nicolas Vasilache
Date: 2023-04-13T13:20:29-07:00
New Revision: 8c5ad0a2f6532cec2f6841cc3e9a1ea043409398
URL: https://github.com/llvm/llvm-project/commit/8c5ad0a2f6532cec2f6841cc3e9a1ea043409398
DIFF: https://github.com/llvm/llvm-project/commit/8c5ad0a2f6532cec2f6841cc3e9a1ea043409398.diff
LOG: [mlir][Vector] Add a masked vectorization of tensor.pad
This revision takes advantage of masking support to introduce a vectorized
version of tensor.pad that does not require lowering to a lower-level form.
That lowering (if/else + generate + fill + copy + insert_slice) creates
unnecessary complexity that can be sidestepped entirely by using masked
vectorization.
Differential Revision: https://reviews.llvm.org/D148261
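For orientation: given a pad with a static result shape and a constant
padding value, e.g. (a sketch based on the test added at the end of this
diff; SSA names are illustrative)

  %padded = tensor.pad %src low[0, 0] high[%h0, %h1] {
  ^bb0(%i: index, %j: index):
    tensor.yield %cst : f32
  } : tensor<?x?xf32> to tensor<2x4xf32>

the rewrite emits a single masked transfer_read (whose padding operand
supplies %cst for the masked-off lanes) followed by an in-bounds
transfer_write into a tensor.empty, with no control flow. A full sketch of
the emitted IR appears after the Vectorization.cpp hunk below.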
Added:
Modified:
mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
mlir/lib/Dialect/Vector/IR/VectorOps.cpp
mlir/test/Dialect/Linalg/vectorization.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 7eaa2f7168bd..52982c3fd537 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -589,6 +589,13 @@ LogicalResult vectorize(RewriterBase &rewriter, LinalgOp linalgOp,
/// Emit a suitable vector form for a Copy op with fully static shape.
LogicalResult vectorizeCopy(RewriterBase &builder, memref::CopyOp copyOp);
+/// Vectorize a `padOp` with (1) static result type, (2) constant padding value
+/// and (3) all-zero lowPad to
+/// `transfer_write_in_bounds(transfer_read_masked(pad_source, pad_value))`.
+FailureOr<vector::TransferWriteOp>
+maskedVectorize(RewriterBase &rewriter, tensor::PadOp padOp,
+ ArrayRef<int64_t> inputVectorSizes);
+
/// Emit a loop nest of `scf.for` with the proper body for `linalgOp`.
FailureOr<LinalgLoops> linalgOpToLoops(RewriterBase &rewriter,
LinalgOp linalgOp);
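The three preconditions in the new doc comment are load-bearing; hypothetical
pads like the following would be rejected by this entry point (sketches,
assuming %src : tensor<?x?xf32> and %cst = arith.constant 42.43 : f32):

  // Rejected: nonzero low pad (violates precondition (3)).
  %a = tensor.pad %src low[1, 0] high[%h0, %h1] {
  ^bb0(%i: index, %j: index):
    tensor.yield %cst : f32
  } : tensor<?x?xf32> to tensor<3x4xf32>

  // Rejected: dynamic result type (violates precondition (1), since the
  // result shape cannot match the static input vector sizes).
  %b = tensor.pad %src low[0, 0] high[%h0, %h1] {
  ^bb0(%i: index, %j: index):
    tensor.yield %cst : f32
  } : tensor<?x?xf32> to tensor<?x4xf32>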
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 0b844e1e3333..2a95ff243fbf 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -2322,6 +2322,11 @@ def Vector_CreateMaskOp :
```
}];
+ let builders = [
+ // Build with mixed static/dynamic operands.
+ OpBuilder<(ins "VectorType":$type, "ArrayRef<OpFoldResult>":$mixedOperands)>
+ ];
+
let hasCanonicalizer = 1;
let hasVerifier = 1;
let assemblyFormat = "$operands attr-dict `:` type(results)";
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 2970d3476d51..39f7802a688a 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -3098,6 +3098,16 @@ DiagnosedSilenceableFailure transform::MaskedVectorizeOp::apply(
// TODO: Check that the correct number of vectorSizes was provided.
for (Operation *target : targets) {
+ if (auto padOp = dyn_cast<tensor::PadOp>(target)) {
+ FailureOr<vector::TransferWriteOp> maybeWriteOp =
+ maskedVectorize(rewriter, padOp, vectorSizes);
+ if (failed(maybeWriteOp)) {
+ return mlir::emitSilenceableFailure(target->getLoc())
+ << "failed to vectorize padOp";
+ }
+ continue;
+ }
+
auto linalgOp = dyn_cast<LinalgOp>(target);
if (!linalgOp) {
return mlir::emitSilenceableFailure(target->getLoc())
@@ -3107,7 +3117,7 @@ DiagnosedSilenceableFailure transform::MaskedVectorizeOp::apply(
if (failed(linalg::vectorize(rewriter, linalgOp, vectorSizes,
getVectorizeNdExtract()))) {
return mlir::emitSilenceableFailure(target->getLoc())
- << "failed to vectorize op";
+ << "failed to vectorize linalg op";
}
}
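With this dispatch in place, a single transform script can drive both paths;
a hypothetical sequence combining the two kinds of targets exercised by the
tests in this file (tensor.pad dispatches to maskedVectorize, while linalg
ops continue through linalg::vectorize):

  transform.sequence failures(propagate) {
  ^bb1(%arg1: !pdl.operation):
    %pads = transform.structured.match ops{["tensor.pad"]} in %arg1
      : (!pdl.operation) -> !pdl.operation
    transform.structured.masked_vectorize %pads vector_sizes [2, 4]
    %copies = transform.structured.match ops{["linalg.copy"]} in %arg1
      : (!pdl.operation) -> !pdl.operation
    transform.structured.masked_vectorize %copies vector_sizes [2, 4]
  }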
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index b54eb0fa9a4f..14726a8c579f 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -26,6 +26,7 @@
#include "mlir/IR/PatternMatch.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/RegionUtils.h"
+#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/TypeSwitch.h"
@@ -1385,6 +1386,63 @@ static void convertAffineApply(RewriterBase &rewriter, LinalgOp linalgOp) {
}
}
+FailureOr<vector::TransferWriteOp>
+mlir::linalg::maskedVectorize(RewriterBase &rewriter, tensor::PadOp padOp,
+ ArrayRef<int64_t> inputVectorSizes) {
+ auto padValue = padOp.getConstantPaddingValue();
+ if (!padValue) {
+ LDBG("pad value is not constant: " << padOp << "\n");
+ return rewriter.notifyMatchFailure(padOp, "pad value is not constant");
+ }
+
+ ArrayRef<int64_t> resultTensorShape = padOp.getResultType().getShape();
+ if (resultTensorShape != inputVectorSizes) {
+ LDBG("result tensor shape must match input vector sizes: " << padOp
+ << "\n");
+ return rewriter.notifyMatchFailure(
+ padOp, "result tensor shape must match input vector sizes");
+ }
+ if (llvm::any_of(padOp.getStaticLow(),
+ [](int64_t val) { return val != 0; })) {
+ LDBG("low pad must all be zero: " << padOp << "\n");
+ return rewriter.notifyMatchFailure(padOp, "low pad must all be zero");
+ }
+
+ Location loc = padOp.getLoc();
+ int64_t rank = inputVectorSizes.size();
+ auto maskType = VectorType::get(inputVectorSizes, rewriter.getI1Type());
+ auto vectorType = VectorType::get(inputVectorSizes, padValue.getType());
+
+ // transfer_write_in_bounds(transfer_read_masked(pad_source, pad_value))
+ OpBuilder::InsertionGuard g(rewriter);
+ rewriter.setInsertionPoint(padOp);
+ auto zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+ auto emptyOp =
+ rewriter.create<tensor::EmptyOp>(loc, padOp.getResultType(),
+ /*dynamicSizes=*/ValueRange{});
+ SmallVector<OpFoldResult> mixedSourceDims =
+ getMixedDimensions(rewriter, loc, padOp.getSource());
+ Value mask =
+ rewriter.create<vector::CreateMaskOp>(loc, maskType, mixedSourceDims);
+ auto transferReadOp = rewriter.create<vector::TransferReadOp>(
+ loc,
+ /*vectorType=*/vectorType,
+ /*source=*/padOp.getSource(),
+ /*indices=*/SmallVector<Value>(rank, zero),
+ /*padding=*/padValue,
+ /*inBounds=*/SmallVector<bool>(rank, true));
+ auto maskedOp = cast<vector::MaskOp>(
+ mlir::vector::maskOperation(rewriter, transferReadOp, mask));
+ auto transferWriteOp = rewriter.create<vector::TransferWriteOp>(
+ loc,
+ /*vector=*/maskedOp->getResult(0),
+ /*source=*/emptyOp,
+ /*indices=*/SmallVector<Value>(rank, zero),
+ /*inBounds=*/SmallVector<bool>(rank, true));
+ rewriter.replaceOp(padOp, transferWriteOp->getResults());
+ return transferWriteOp;
+}
+
/// Emit a suitable vector form for a Linalg op. If provided, `inputVectorSizes`
/// are used to vectorize this operation. `inputVectorSizes` must match the rank
/// of the iteration space of the operation and the input vector sizes must be
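Concretely, for the 2x4 case in the test below, the emitted IR has the
following shape (a sketch matching the CHECK lines; SSA names are
illustrative):

  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %empty = tensor.empty() : tensor<2x4xf32>
  %d0 = tensor.dim %src, %c0 : tensor<?x?xf32>
  %d1 = tensor.dim %src, %c1 : tensor<?x?xf32>
  %mask = vector.create_mask %d0, %d1 : vector<2x4xi1>
  %read = vector.mask %mask {
    vector.transfer_read %src[%c0, %c0], %cst {in_bounds = [true, true]}
      : tensor<?x?xf32>, vector<2x4xf32>
  } : vector<2x4xi1> -> vector<2x4xf32>
  %result = vector.transfer_write %read, %empty[%c0, %c0]
    {in_bounds = [true, true]} : vector<2x4xf32>, tensor<2x4xf32>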
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 8ee59659c3a4..89ca09911230 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -640,10 +640,9 @@ ParseResult ContractionOp::parse(OpAsmParser &parser, OperationState &result) {
auto loc = parser.getCurrentLocation();
DictionaryAttr dictAttr;
// TODO: Unify linalg op attribute parsing.
- if (parser.parseAttribute(dictAttr) ||
- parser.parseOperand(lhsInfo) || parser.parseComma() ||
- parser.parseOperand(rhsInfo) || parser.parseComma() ||
- parser.parseOperand(accInfo) ||
+ if (parser.parseAttribute(dictAttr) || parser.parseOperand(lhsInfo) ||
+ parser.parseComma() || parser.parseOperand(rhsInfo) ||
+ parser.parseComma() || parser.parseOperand(accInfo) ||
parser.parseTrailingOperandList(masksInfo) ||
parser.parseOptionalAttrDict(result.attributes) ||
parser.parseColonTypeList(types) ||
@@ -5369,6 +5368,14 @@ LogicalResult ConstantMaskOp::verify() {
// CreateMaskOp
//===----------------------------------------------------------------------===//
+void CreateMaskOp::build(OpBuilder &builder, OperationState &result,
+ VectorType type,
+ ArrayRef<OpFoldResult> mixedOperands) {
+ SmallVector<Value> operands =
+ getValueOrCreateConstantIndexOp(builder, result.location, mixedOperands);
+ build(builder, result, type, operands);
+}
+
LogicalResult CreateMaskOp::verify() {
auto vectorType = getResult().getType().cast<VectorType>();
// Verify that an operand was specified for each result vector dimension.
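The new build method accepts a mix of Values and Attributes and relies on
getValueOrCreateConstantIndexOp to materialize arith.constant index ops for
the static entries. For example (illustrative), building a mask from one
dynamic and one static bound yields IR of this shape:

  %d0 = tensor.dim %src, %c0 : tensor<?x?xf32>
  %c4 = arith.constant 4 : index
  %mask = vector.create_mask %d0, %c4 : vector<2x4xi1>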
diff --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir
index c407b49d896c..d54a2f57617c 100644
--- a/mlir/test/Dialect/Linalg/vectorization.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization.mlir
@@ -2757,3 +2757,37 @@ transform.sequence failures(propagate) {
%0 = transform.structured.match ops{["linalg.copy"]} in %arg1 : (!pdl.operation) -> !pdl.operation
transform.structured.masked_vectorize %0 vector_sizes [2, 4]
}
+
+// -----
+
+// CHECK-LABEL: func @test_masked_vectorize_pad
+func.func @test_masked_vectorize_pad(
+ %0 : tensor<?x?xf32>, %h0 : index, %h1 : index)
+ -> tensor<2x4xf32>
+{
+ // CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index
+ // CHECK-DAG: %[[c42:.*]] = arith.constant 4.243000e+01 : f32
+ // CHECK-DAG: %[[empty:.*]] = tensor.empty() : tensor<2x4xf32>
+ // CHECK: %[[d0:.*]] = tensor.dim {{.*}} : tensor<?x?xf32>
+ // CHECK: %[[d1:.*]] = tensor.dim {{.*}} : tensor<?x?xf32>
+ // CHECK: %[[mask:.*]] = vector.create_mask %[[d0]], %[[d1]] : vector<2x4xi1>
+ // CHECK: %[[masked_read:.*]] = vector.mask %[[mask]] {
+ // CHECK-SAME: vector.transfer_read %{{.*}}[%[[c0]], %[[c0]]], %[[c42]]
+ // CHECK-SAME: {in_bounds = [true, true]} : tensor<?x?xf32>, vector<2x4xf32>
+ // CHECK-SAME: } : vector<2x4xi1> -> vector<2x4xf32>
+ // CHECK: vector.transfer_write %[[masked_read]], %[[empty]][%[[c0]], %[[c0]]]
+ // CHECK-SAME: {in_bounds = [true, true]} : vector<2x4xf32>, tensor<2x4xf32>
+ %cst = arith.constant 42.43 : f32
+ %1 = tensor.pad %0 low[0, 0] high[%h0, %h1] {
+ ^bb0(%hh1: index, %hh2: index):
+ tensor.yield %cst : f32
+ } : tensor<?x?xf32> to tensor<2x4xf32>
+ return %1: tensor<2x4xf32>
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+ %0 = transform.structured.match ops{["tensor.pad"]} in %arg1
+ : (!pdl.operation) -> !pdl.operation
+ transform.structured.masked_vectorize %0 vector_sizes [2, 4]
+}