[Mlir-commits] [mlir] 9a7d111 - [mlir][Linalg] NFC - Modernize transformation APIs.
Nicolas Vasilache
llvmlistbot at llvm.org
Wed Jan 5 08:06:37 PST 2022
Author: Nicolas Vasilache
Date: 2022-01-05T11:01:40-05:00
New Revision: 9a7d111f4fb65ad7343dcbd4f35ee608100634e8
URL: https://github.com/llvm/llvm-project/commit/9a7d111f4fb65ad7343dcbd4f35ee608100634e8
DIFF: https://github.com/llvm/llvm-project/commit/9a7d111f4fb65ad7343dcbd4f35ee608100634e8.diff
LOG: [mlir][Linalg] NFC - Modernize transformation APIs.
Differential Revision: https://reviews.llvm.org/D116665
Added:
Modified:
mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp
mlir/lib/Dialect/Linalg/Transforms/Interchange.cpp
mlir/lib/Dialect/Linalg/Transforms/NamedOpConversions.cpp
mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index bb396ce5a5541..c1185c0a8ff70 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -44,12 +44,12 @@ bool skipUnitDimReshape(const OpResult &producer, OpOperand &consumer);
//===----------------------------------------------------------------------===//
using LinalgLoops = SmallVector<Operation *, 4>;
-/// [DEPRECATED] Populates patterns for vectorization of all ConvN-D ops.
+/// [DEPRECATED] Populate patterns for vectorization of all ConvN-D ops.
void populateConvVectorizationPatterns(
MLIRContext *context, SmallVectorImpl<RewritePatternSet> &patterns,
ArrayRef<int64_t> tileSizes);
-/// Populates patterns for vectorizing low-D convolution ops. This is a step in
+/// Populate patterns for vectorizing low-D convolution ops. This is a step in
/// progressive lowering for convolution ops, it assume high-D convolution ops
/// were decomposed previously.
void populateConvolutionVectorizationPatterns(RewritePatternSet &patterns,
@@ -91,7 +91,7 @@ void populateFoldUnitDimsReshapeOpsByLinearizationPatterns(
/// canonicalizations of named ops into another named op.
void populateLinalgNamedOpConversionPatterns(RewritePatternSet &patterns);
-/// Populates the given list with patterns to bufferize linalg ops.
+/// Populate the given list with patterns to bufferize linalg ops.
void populateLinalgBufferizePatterns(
bufferization::BufferizeTypeConverter &converter,
RewritePatternSet &patterns);
@@ -124,7 +124,7 @@ struct LinalgElementwiseFusionOptions {
return *this;
}
- /// Function that allows the caller to control when to stop fusion. Once a
+ /// Function to allow the caller to control when to stop fusion. Once a
/// producer is deemed fusable with the consumer (structurally), this callback
/// can be used to abort the fusion based on non-structural constraints. This
/// is the hook for cost models to control the amount of fusion done.
@@ -149,7 +149,7 @@ void populateElementwiseOpsFusionPatterns(
/// more fusion opportunities.
void populatePushReshapeOpsPatterns(RewritePatternSet &patterns);
-/// Performs standalone tiling of a single LinalgOp by `tileSizes`.
+/// Perform standalone tiling of a single LinalgOp by `tileSizes`.
/// and permute the loop nest according to `interchangeVector`
/// The permutation is expressed as a list of integers that specify
/// the new ordering of the loop nest. The length of `interchangeVector`
@@ -157,7 +157,7 @@ void populatePushReshapeOpsPatterns(RewritePatternSet &patterns);
/// An empty vector is interpreted as the identity permutation and the
/// transformation returns early.
///
-/// Returns a struct containing the tiled loops in the specified order
+/// Return a struct containing the tiled loops in the specified order
/// and the cloned op if successful, llvm::None otherwise.
///
/// E.g. the permutation `(i,j,k) -> (j,k,i)` is expressed by
@@ -237,7 +237,7 @@ tileAndFuseLinalgOps(OpBuilder &builder, ArrayRef<LinalgOp> ops,
const LinalgDependenceGraph &dependenceGraph,
const LinalgTilingOptions &tilingOptions);
-/// Interchanges the `iterator_types` and `iterator_maps` dimensions and adapts
+/// Interchange the `iterator_types` and `indexing_maps` dimensions and adapt
/// the index accesses of `op`. This is an in-place transformation controlled by
/// `interchangeVector`. An empty vector is interpreted as the identity
/// permutation and the transformation returns early.
@@ -246,12 +246,15 @@ tileAndFuseLinalgOps(OpBuilder &builder, ArrayRef<LinalgOp> ops,
/// `interchangeVector = [1,2,0]`. All values in `interchangeVector` must be
/// integers, in the range 0..`op.rank` without duplications
/// (i.e. `[1,1,2]` is an invalid permutation).
-void interchangeGenericOp(PatternRewriter &rewriter, GenericOp genericOp,
- ArrayRef<unsigned> interchangeVector);
+FailureOr<GenericOp> interchangeGenericOp(RewriterBase &rewriter,
+ GenericOp genericOp,
+ ArrayRef<unsigned> interchangeVector);
-/// Creates a GenericOp from the given named operation `namedOp`. Assumes
-/// `namedOp` is not a GenericOp and has a region builder.
-GenericOp generalizeNamedOp(PatternRewriter &rewriter, LinalgOp namedOp);
+/// Create a GenericOp from the given named operation `namedOp` and replace
+/// `namedOp` with it.
+/// Return failure if `namedOp` is a GenericOp or lacks a region builder.
+FailureOr<GenericOp> generalizeNamedOp(RewriterBase &rewriter,
+ LinalgOp namedOp);
/// Callback function type used to perform the allocation for the promoted
/// `subView`. In `boundingSubViewsize` a best attempt is made to find the
@@ -346,7 +349,7 @@ struct LinalgPromotionOptions {
}
};
-/// Creates a new buffer using the `allocationFn` provided. The size of this
+/// Create a new buffer using the `allocationFn` provided. The size of this
/// buffer is the smallest constant bounding size along each dimension that can
/// be computed for the size of the result of `subView`. Returns the allocated
/// buffer as `fullLocalView` and the view that matches the size of the result
@@ -360,7 +363,7 @@ promoteSubviewAsNewBuffer(OpBuilder &b, Location loc, memref::SubViewOp subView,
const AllocBufferCallbackFn &allocationFn,
DataLayout &layout);
-/// Promotes the `subViews` into a new buffer allocated at the insertion point
+/// Promote the `subViews` into a new buffer allocated at the insertion point
/// `b`. Promotion occurs in 3 steps:
/// 1. Create a new buffer for a full tile (i.e. not clipped at the boundary).
/// 2. Take a full view on the buffer.
@@ -368,24 +371,23 @@ promoteSubviewAsNewBuffer(OpBuilder &b, Location loc, memref::SubViewOp subView,
/// Infers statically sized buffers from subViews unless `dynamicBuffers` is
/// true.
///
-/// Returns the modified linalg op (the modification happens in place) as well
+/// Return the modified linalg op (the modification happens in place) as well
/// as all the copy ops created.
FailureOr<LinalgOp> promoteSubViews(OpBuilder &b, LinalgOp op,
const LinalgPromotionOptions &options);
/// Emit a suitable vector form for a Linalg op with fully static shape.
-LogicalResult vectorizeLinalgOp(OpBuilder &builder, Operation *op,
- SmallVectorImpl<Value> &newResults);
+LogicalResult vectorize(RewriterBase &builder, LinalgOp linalgOp);
-/// Emits a loop nest of `scf.for` with the proper body for `linalgOp`.
+/// Emit a loop nest of `scf.for` with the proper body for `linalgOp`.
FailureOr<LinalgLoops> linalgOpToLoops(PatternRewriter &rewriter,
LinalgOp linalgOp);
-/// Emits a loop nest of `scf.parallel` with the proper body for `linalgOp`.
+/// Emit a loop nest of `scf.parallel` with the proper body for `linalgOp`.
FailureOr<LinalgLoops> linalgOpToParallelLoops(PatternRewriter &rewriter,
LinalgOp linalgOp);
-/// Emits a loop nest of `affine.for` with the proper body for `linalgOp`.
+/// Emit a loop nest of `affine.for` with the proper body for `linalgOp`.
FailureOr<LinalgLoops> linalgOpToAffineLoops(PatternRewriter &rewriter,
LinalgOp linalgOp);
@@ -393,28 +395,10 @@ FailureOr<LinalgLoops> linalgOpToAffineLoops(PatternRewriter &rewriter,
// Preconditions that ensure the corresponding transformation succeeds and can
// be applied as a rewrite pattern.
//===----------------------------------------------------------------------===//
-/// Emits a `generic` operation with the `indexing_maps` and `iterator_types`
-/// permutated according to `permutation`.
-LogicalResult
-interchangeGenericOpPrecondition(GenericOp genericOp,
- ArrayRef<unsigned> interchangeVector);
-
-/// Generalize named operations to generic operations.
-LogicalResult generalizeNamedOpPrecondition(Operation *op);
-
-/// Promote std.subviews feeding linalg operations.
+/// Promote memref.subviews feeding linalg-on-buffers operations.
LogicalResult promoteSubviewsPrecondition(Operation *op,
LinalgPromotionOptions options);
-/// Return success if the operation can be vectorized.
-LogicalResult vectorizeLinalgOpPrecondition(Operation *op);
-
-/// Return success if `op` can be vectorized assuming it is static. This allows
-/// checking if an op will be vectorizable once all the dimensions are folded to
-/// static values.
-/// It is the same as `vectorizeLinalgOpPrecondition` for static shapes.
-LogicalResult vectorizeStaticLinalgOpPrecondition(LinalgOp op);
-
//===----------------------------------------------------------------------===//
// Transformations exposed as rewrite patterns.
//===----------------------------------------------------------------------===//
@@ -610,7 +594,7 @@ struct LinalgTilingOptions {
RewritePatternSet getLinalgTilingCanonicalizationPatterns(MLIRContext *ctx);
void populateLinalgTilingCanonicalizationPatterns(RewritePatternSet &patterns);
-/// Base pattern that applied the tiling transformation specified by `options`.
+/// Base pattern that applies the tiling transformation specified by `options`.
/// Abort and return failure in 2 cases:
/// 1. if the tiling specification is invalid and tiling fails to occur.
/// 2. if tiling occurs but `options.paddingValueComputationFunction` is set
@@ -812,9 +796,9 @@ struct LinalgTileAndFuseTensorOpsPattern : public RewritePattern {
};
///
-/// Linalg generic interchage pattern.
+/// Linalg generic interchange pattern.
///
-/// Apply the `interchange` transformation as a pattern.
+/// Apply the `interchange` transformation on a RewriterBase.
/// `filter` controls LinalgTransformMarker matching and update when specified.
/// See `interchange` for more details.
struct GenericOpInterchangePattern : public OpRewritePattern<GenericOp> {
@@ -909,13 +893,11 @@ struct LinalgPromotionPattern : public LinalgBasePromotionPattern {
///
/// Linalg vectorization patterns.
///
-/// Apply the `vectorizeLinalgOp` transformation as a pattern.
-/// `filter` controls LinalgTransformMarker matching and update when specified.
-/// See `vectorizeLinalgOp` for more details.
-
/// Empty for now, used for SFINAE purposes only.
struct LinalgVectorizationOptions {};
+/// `filter` controls LinalgTransformMarker matching and update when specified.
+/// See `vectorize` for more details.
struct LinalgBaseVectorizationPattern : public RewritePattern {
/// MatchAnyOpTag-based constructor with a mandatory `filter`.
LinalgBaseVectorizationPattern(MLIRContext *context,
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp
index a42ac8d81c4b9..721c47ca01308 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp
@@ -29,7 +29,7 @@
using namespace mlir;
using namespace mlir::linalg;
-LogicalResult mlir::linalg::generalizeNamedOpPrecondition(Operation *op) {
+static LogicalResult generalizeNamedOpPrecondition(Operation *op) {
LinalgOp namedOp = dyn_cast<LinalgOp>(op);
// Check if the operation is a LinalgOp but not a GenericOp.
if (!namedOp || isa<GenericOp>(op))
@@ -40,8 +40,11 @@ LogicalResult mlir::linalg::generalizeNamedOpPrecondition(Operation *op) {
return success();
}
-GenericOp mlir::linalg::generalizeNamedOp(PatternRewriter &rewriter,
- LinalgOp namedOp) {
+FailureOr<GenericOp> mlir::linalg::generalizeNamedOp(RewriterBase &rewriter,
+ LinalgOp namedOp) {
+ if (failed(generalizeNamedOpPrecondition(namedOp)))
+ return rewriter.notifyMatchFailure(namedOp, "preconditions not met");
+
SmallVector<Value> inputOperands = namedOp.getInputOperands();
SmallVector<Value> outputOperands = namedOp.getOutputOperands();
SmallVector<AffineMap> indexingMaps = namedOp.getIndexingMaps();
@@ -58,6 +61,7 @@ GenericOp mlir::linalg::generalizeNamedOp(PatternRewriter &rewriter,
outputOperands, indexingMaps, iterators);
rewriter.inlineRegionBefore(namedOp->getRegion(0), genericOp.region(),
genericOp.region().begin());
+ rewriter.replaceOp(namedOp, genericOp->getResults());
return genericOp;
}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Interchange.cpp b/mlir/lib/Dialect/Linalg/Transforms/Interchange.cpp
index 0a1f7bc4565fc..8d0d26caa5fb0 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Interchange.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Interchange.cpp
@@ -21,6 +21,7 @@
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
+#include "llvm/ADT/ScopeExit.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#include <type_traits>
@@ -30,8 +31,9 @@
using namespace mlir;
using namespace mlir::linalg;
-LogicalResult mlir::linalg::interchangeGenericOpPrecondition(
- GenericOp genericOp, ArrayRef<unsigned> interchangeVector) {
+static LogicalResult
+interchangeGenericOpPrecondition(GenericOp genericOp,
+ ArrayRef<unsigned> interchangeVector) {
// Interchange vector must be non-empty and match the number of loops.
if (interchangeVector.empty() ||
genericOp.getNumLoops() != interchangeVector.size())
@@ -43,31 +45,38 @@ LogicalResult mlir::linalg::interchangeGenericOpPrecondition(
return success();
}
-void mlir::linalg::interchangeGenericOp(PatternRewriter &rewriter,
- GenericOp genericOp,
- ArrayRef<unsigned> interchangeVector) {
- // 1. Compute the inverse permutation map.
+FailureOr<GenericOp>
+mlir::linalg::interchangeGenericOp(RewriterBase &rewriter, GenericOp genericOp,
+ ArrayRef<unsigned> interchangeVector) {
+ if (failed(interchangeGenericOpPrecondition(genericOp, interchangeVector)))
+ return rewriter.notifyMatchFailure(genericOp, "preconditions not met");
+
+ // 1. Compute the inverse permutation map, it must be non-null since the
+ // preconditions are satisfied.
MLIRContext *context = genericOp.getContext();
AffineMap permutationMap = inversePermutation(
AffineMap::getPermutationMap(interchangeVector, context));
- assert(permutationMap && "expected permutation to be invertible");
- assert(interchangeVector.size() == genericOp.getNumLoops() &&
- "expected interchange vector to have entry for every loop");
+ assert(permutationMap && "unexpected null map");
+
+ // Start a guarded inplace update.
+ rewriter.startRootUpdate(genericOp);
+ auto guard =
+ llvm::make_scope_exit([&]() { rewriter.finalizeRootUpdate(genericOp); });
// 2. Compute the interchanged indexing maps.
- SmallVector<Attribute, 4> newIndexingMaps;
+ SmallVector<AffineMap> newIndexingMaps;
for (OpOperand *opOperand : genericOp.getInputAndOutputOperands()) {
AffineMap m = genericOp.getTiedIndexingMap(opOperand);
if (!permutationMap.isEmpty())
m = m.compose(permutationMap);
- newIndexingMaps.push_back(AffineMapAttr::get(m));
+ newIndexingMaps.push_back(m);
}
genericOp->setAttr(getIndexingMapsAttrName(),
- ArrayAttr::get(context, newIndexingMaps));
+ rewriter.getAffineMapArrayAttr(newIndexingMaps));
// 3. Compute the interchanged iterator types.
ArrayRef<Attribute> itTypes = genericOp.iterator_types().getValue();
- SmallVector<Attribute, 4> itTypesVector;
+ SmallVector<Attribute> itTypesVector;
llvm::append_range(itTypesVector, itTypes);
SmallVector<int64_t> permutation(interchangeVector.begin(),
interchangeVector.end());
@@ -91,4 +100,6 @@ void mlir::linalg::interchangeGenericOp(PatternRewriter &rewriter,
indexOp, permutationMap.getSubMap(indexOp.dim()), allIndices);
}
}
+
+ return genericOp;
}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/NamedOpConversions.cpp b/mlir/lib/Dialect/Linalg/Transforms/NamedOpConversions.cpp
index bb38607d769ac..b331b6657fcea 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/NamedOpConversions.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/NamedOpConversions.cpp
@@ -137,7 +137,7 @@ struct SimplifyDepthwiseConvQOp
struct LinalgNamedOpConversionPass
: public LinalgNamedOpConversionBase<LinalgNamedOpConversionPass> {
LinalgNamedOpConversionPass() = default;
- LinalgNamedOpConversionPass(const LinalgNamedOpConversionPass &) {}
+ LinalgNamedOpConversionPass(const LinalgNamedOpConversionPass &) = default;
void runOnOperation() override {
Operation *op = getOperation();
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 8156c5d45744c..c2a3c2bda630a 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -623,16 +623,14 @@ LogicalResult mlir::linalg::GenericOpInterchangePattern::matchAndRewrite(
GenericOp genericOp, PatternRewriter &rewriter) const {
if (failed(filter.checkAndNotify(rewriter, genericOp)))
return failure();
- if (failed(interchangeGenericOpPrecondition(genericOp, interchangeVector)))
+
+ FailureOr<GenericOp> transformedOp =
+ interchangeGenericOp(rewriter, genericOp, interchangeVector);
+ if (failed(transformedOp))
return failure();
- // TODO: figure out how this interplays with named ops. In particular this
- // should break the named op property.
- rewriter.updateRootInPlace(genericOp, [&]() {
- interchangeGenericOp(rewriter, genericOp, interchangeVector);
- // New filter if specified.
- filter.replaceLinalgTransformationFilter(rewriter, genericOp);
- });
+ // New filter if specified.
+ filter.replaceLinalgTransformationFilter(rewriter, genericOp);
return success();
}
@@ -652,12 +650,10 @@ LogicalResult mlir::linalg::LinalgGeneralizationPattern::matchAndRewrite(
Operation *op, PatternRewriter &rewriter) const {
if (failed(filter.checkAndNotify(rewriter, op)))
return failure();
- if (failed(generalizeNamedOpPrecondition(op)))
+ FailureOr<GenericOp> genericOp = generalizeNamedOp(rewriter, op);
+ if (failed(genericOp))
return failure();
-
- GenericOp genericOp = generalizeNamedOp(rewriter, op);
- rewriter.replaceOp(op, genericOp.getResults());
- filter.replaceLinalgTransformationFilter(rewriter, genericOp);
+ filter.replaceLinalgTransformationFilter(rewriter, *genericOp);
return success();
}
@@ -708,19 +704,13 @@ mlir::linalg::LinalgBaseVectorizationPattern::LinalgBaseVectorizationPattern(
LogicalResult mlir::linalg::LinalgBaseVectorizationPattern::matchAndRewrite(
Operation *op, PatternRewriter &rewriter) const {
+ // TODO: Interface-based rewrite.
LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
if (!linalgOp)
return failure();
- if (failed(filter.checkAndNotify(rewriter, linalgOp)))
- return failure();
- SmallVector<Value> newResults;
- if (failed(vectorizeLinalgOp(rewriter, op, newResults)))
+ if (failed(filter.checkAndNotify(rewriter, op)))
return failure();
- if (!newResults.empty())
- rewriter.replaceOp(op, newResults);
- else
- rewriter.eraseOp(op);
- return success();
+ return vectorize(rewriter, linalgOp);
}
LogicalResult mlir::linalg::applyStagedPatterns(
@@ -758,8 +748,8 @@ static SmallVector<StringRef> getNParallelLoopsAttrs(unsigned nParallelLoops) {
return SmallVector<StringRef>(nParallelLoops, getParallelIteratorTypeName());
}
-/// Rewrite a PadTensorOp into a sequence of InitTensorOp, FillOp (to initialize
-/// with pad_val) and GenericOp (to copy contents).
+/// Rewrite a PadTensorOp into a sequence of InitTensorOp, FillOp (to
+/// initialize with pad_val) and GenericOp (to copy contents).
LogicalResult PadTensorOpTransformationPattern::matchAndRewrite(
linalg::PadTensorOp padOp, PatternRewriter &rewriter) const {
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 5fda632b2f860..4a597f64d72ff 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -597,8 +597,7 @@ static LogicalResult reductionPreconditions(LinalgOp op) {
return success();
}
-LogicalResult
-mlir::linalg::vectorizeStaticLinalgOpPrecondition(linalg::LinalgOp op) {
+static LogicalResult vectorizeStaticLinalgOpPrecondition(linalg::LinalgOp op) {
if (isElementwise(op))
return success();
// TODO: isaConvolutionOpInterface that can also infer from generic features.
@@ -620,8 +619,7 @@ mlir::linalg::vectorizeStaticLinalgOpPrecondition(linalg::LinalgOp op) {
return success();
}
-LogicalResult mlir::linalg::vectorizeLinalgOpPrecondition(Operation *op) {
- auto linalgOp = cast<linalg::LinalgOp>(op);
+static LogicalResult vectorizeLinalgOpPrecondition(LinalgOp linalgOp) {
// All types must be static shape to go to vector.
if (linalgOp.hasDynamicShape()) {
LDBG("precondition failed: dynamic shape");
@@ -630,31 +628,32 @@ LogicalResult mlir::linalg::vectorizeLinalgOpPrecondition(Operation *op) {
return vectorizeStaticLinalgOpPrecondition(linalgOp);
}
-LogicalResult
-mlir::linalg::vectorizeLinalgOp(OpBuilder &b, Operation *op,
- SmallVectorImpl<Value> &newResults) {
- if (failed(vectorizeLinalgOpPrecondition(op)))
+LogicalResult mlir::linalg::vectorize(RewriterBase &rewriter,
+ LinalgOp linalgOp) {
+ if (failed(vectorizeLinalgOpPrecondition(linalgOp)))
return failure();
- auto linalgOp = cast<LinalgOp>(op);
-
- // TODO: isaConvolutionOpInterface that can also infer from generic features.
- // But we will still need stride/dilation attributes that will be annoying to
- // reverse-engineer...
- if (auto convOp = dyn_cast<ConvolutionOpInterface>(op)) {
- FailureOr<Operation *> resultOrFail = vectorizeConvolution(b, convOp);
- if (failed(resultOrFail))
+ SmallVector<Value> results;
+ // TODO: isaConvolutionOpInterface that can also infer from generic
+ // features. Will require stride/dilation attributes inference.
+ if (auto convOp = dyn_cast<ConvolutionOpInterface>(linalgOp.getOperation())) {
+ LDBG("Vectorize as a conv: " << linalgOp);
+ FailureOr<Operation *> convOr = vectorizeConvolution(rewriter, convOp);
+ if (failed(convOr))
+ return failure();
+ llvm::append_range(results, (*convOr)->getResults());
+ } else {
+ LDBG("Vectorize generic by broadcasting to a common shape: " << linalgOp);
+ if (failed(vectorizeAsLinalgGeneric(rewriter, linalgOp, results)))
return failure();
- Operation *newOp = *resultOrFail;
- llvm::append_range(newResults, newOp->getResults());
- return success();
}
- LDBG(""
- << "Vectorize linalg op as a generic by broadcasting to "
- "maximal common shape: "
- << *op);
- return vectorizeAsLinalgGeneric(b, linalgOp, newResults);
+ if (!results.empty())
+ rewriter.replaceOp(linalgOp, results);
+ else
+ rewriter.eraseOp(linalgOp);
+
+ return success();
}
//----------------------------------------------------------------------------//
@@ -666,8 +665,9 @@ static int64_t getIntFromAttr(Attribute attr) {
return attr.cast<IntegerAttr>().getInt();
}
-/// Given an ArrayRef of OpFoldResults, return a vector of Values. IntegerAttrs
-/// are converted to ConstantIndexOps. Other attribute types are not supported.
+/// Given an ArrayRef of OpFoldResults, return a vector of Values.
+/// IntegerAttrs are converted to ConstantIndexOps. Other attribute types are
+/// not supported.
static SmallVector<Value> ofrToIndexValues(OpBuilder &builder, Location loc,
ArrayRef<OpFoldResult> ofrs) {
SmallVector<Value> result;
@@ -691,9 +691,9 @@ struct GenericPadTensorOpVectorizationPattern
GenericPadTensorOpVectorizationPattern(MLIRContext *context,
PatternBenefit benefit = 1)
: GeneralizePadTensorOpPattern(context, tryVectorizeCopy, benefit) {}
- /// Vectorize the copying of a PadTensorOp's source. This is possible if each
- /// dimension size is statically know in the source type or the result type
- /// (or both).
+ /// Vectorize the copying of a PadTensorOp's source. This is possible if
+ /// each dimension size is statically known in the source type or the result
+ /// type (or both).
static LogicalResult tryVectorizeCopy(PatternRewriter &rewriter,
PadTensorOp padOp, Value dest) {
auto sourceType = padOp.getSourceType();
@@ -718,13 +718,14 @@ struct GenericPadTensorOpVectorizationPattern
for (unsigned i = 0; i < sourceType.getRank(); ++i) {
if (!sourceType.isDynamicDim(i)) {
vecShape.push_back(sourceType.getDimSize(i));
- // Source shape is statically known: Neither read nor write are out-of-
- // bounds.
+ // Source shape is statically known: Neither read nor write are
+ // out-of-bounds.
readInBounds.push_back(true);
writeInBounds.push_back(true);
} else if (!resultType.isDynamicDim(i)) {
- // Source shape is not statically known, but result shape is. Vectorize
- // with size of result shape. This may be larger than the source size.
+ // Source shape is not statically known, but result shape is.
+ // Vectorize with size of result shape. This may be larger than the
+ // source size.
vecShape.push_back(resultType.getDimSize(i));
// Read may be out-of-bounds because the result size could be larger
// than the source size.
@@ -749,8 +750,8 @@ struct GenericPadTensorOpVectorizationPattern
padOp.getLoc(), vecType, padOp.source(), readIndices, padValue,
ArrayRef<bool>{readInBounds});
- // If `dest` is a FillOp and the TransferWriteOp would overwrite the entire
- // tensor, write directly to the FillOp's operand.
+ // If `dest` is a FillOp and the TransferWriteOp would overwrite the
+ // entire tensor, write directly to the FillOp's operand.
if (llvm::equal(vecShape, resultType.getShape()) &&
llvm::all_of(writeInBounds, [](bool b) { return b; }))
if (auto fill = dest.getDefiningOp<FillOp>())
@@ -766,8 +767,8 @@ struct GenericPadTensorOpVectorizationPattern
}
};
-/// Base pattern for rewriting PadTensorOps whose result is consumed by a given
-/// operation type OpTy.
+/// Base pattern for rewriting PadTensorOps whose result is consumed by a
+/// given operation type OpTy.
template <typename OpTy>
struct VectorizePadTensorOpUserPattern : public OpRewritePattern<PadTensorOp> {
using OpRewritePattern<PadTensorOp>::OpRewritePattern;
@@ -837,10 +838,10 @@ struct PadTensorOpVectorizationWithTransferReadPattern
};
/// Rewrite use of PadTensorOp result in TransferWriteOp.
-/// This pattern rewrites TransferWriteOps that write to a padded tensor value,
-/// where the same amount of padding is immediately removed again after the
-/// write. In such cases, the TransferWriteOp can write to the non-padded tensor
-/// value and apply out-of-bounds masking. E.g.:
+/// This pattern rewrites TransferWriteOps that write to a padded tensor
+/// value, where the same amount of padding is immediately removed again after
+/// the write. In such cases, the TransferWriteOp can write to the non-padded
+/// tensor value and apply out-of-bounds masking. E.g.:
/// ```
/// %0 = tensor.extract_slice ...[...] [%s0, %s1] [1, 1]
/// : tensor<...> to tensor<?x?xf32>
@@ -854,17 +855,19 @@ struct PadTensorOpVectorizationWithTransferReadPattern
/// ```
/// %0 = tensor.extract_slice ...[...] [%s0, %s1] [1, 1]
/// : tensor<...> to tensor<?x?xf32>
-/// %r = vector.transfer_write %vec, %0[...] : vector<17x5xf32>, tensor<?x?xf32>
+/// %r = vector.transfer_write %vec, %0[...] : vector<17x5xf32>,
+///        tensor<?x?xf32>
/// ```
/// Note: It is important that the ExtractSliceOp %r resizes the result of the
-/// TransferWriteOp to the same size as the input of the TensorPadOp (or an even
-/// smaller size). Otherwise, %r's new (dynamic) dimensions would differ from
-/// %r's old dimensions.
+/// TransferWriteOp to the same size as the input of the TensorPadOp (or an
+/// even smaller size). Otherwise, %r's new (dynamic) dimensions would differ
+/// from %r's old dimensions.
///
/// This rewrite is possible if:
/// - Low padding is static 0.
/// - `xferOp` has exactly one use, which is an ExtractSliceOp. This
-/// ExtractSliceOp trims the same amount of padding that was added beforehand.
+/// ExtractSliceOp trims the same amount of padding that was added
+/// beforehand.
/// - Single, scalar padding value.
struct PadTensorOpVectorizationWithTransferWritePattern
: public VectorizePadTensorOpUserPattern<vector::TransferWriteOp> {
@@ -922,8 +925,8 @@ struct PadTensorOpVectorizationWithTransferWritePattern
/// sizes may turn out to be equal at runtime.
bool hasSameTensorSize(Value beforePadding,
tensor::ExtractSliceOp afterTrimming) const {
- // If the input to PadTensorOp is a CastOp, try with with both CastOp result
- // and CastOp operand.
+ // If the input to PadTensorOp is a CastOp, try with both the CastOp
+ // result and the CastOp operand.
if (auto castOp = beforePadding.getDefiningOp<tensor::CastOp>())
if (hasSameTensorSize(castOp.source(), afterTrimming))
return true;
@@ -950,8 +953,9 @@ struct PadTensorOpVectorizationWithTransferWritePattern
if (t1.getNumDynamicDims() == 0)
return true;
- // All dynamic sizes must be the same. The only supported case at the moment
- // is when `beforePadding` is an ExtractSliceOp (or a cast thereof).
+ // All dynamic sizes must be the same. The only supported case at the
+ // moment is when `beforePadding` is an ExtractSliceOp (or a cast
+ // thereof).
// Apart from CastOp, only ExtractSliceOp is supported.
auto beforeSlice = beforePadding.getDefiningOp<tensor::ExtractSliceOp>();
@@ -1062,7 +1066,8 @@ struct PadTensorOpVectorizationWithInsertSlicePattern
// InsertSliceOp.
rewriter.setInsertionPoint(insertOp);
- // Generate TransferReadOp: Read entire source tensor and add high padding.
+ // Generate TransferReadOp: Read entire source tensor and add high
+ // padding.
SmallVector<Value> readIndices(
vecRank, rewriter.create<arith::ConstantIndexOp>(padOp.getLoc(), 0));
auto read = rewriter.create<vector::TransferReadOp>(
@@ -1224,9 +1229,9 @@ void mlir::linalg::populateConvVectorizationPatterns(
// Forwarding patterns
//----------------------------------------------------------------------------//
-/// Check whether there is any interleaved use of any `values` between `firstOp`
-/// and `secondOp`. Conservatively return `true` if any op or value is in a
-/// different block.
+/// Check whether there is any interleaved use of any `values` between
+/// `firstOp` and `secondOp`. Conservatively return `true` if any op or value
+/// is in a different block.
static bool mayExistInterleavedUses(Operation *firstOp, Operation *secondOp,
ValueRange values) {
if (firstOp->getBlock() != secondOp->getBlock() ||
@@ -1252,7 +1257,8 @@ static bool mayExistInterleavedUses(Operation *firstOp, Operation *secondOp,
return false;
}
-/// Return the unique subview use of `v` if it is indeed unique, null otherwise.
+/// Return the unique subview use of `v` if it is indeed unique, null
+/// otherwise.
static memref::SubViewOp getSubViewUseIfUnique(Value v) {
memref::SubViewOp subViewOp;
for (auto &u : v.getUses()) {
@@ -1307,7 +1313,8 @@ LogicalResult LinalgCopyVTRForwardingPattern::matchAndRewrite(
return failure();
LDBG("with copy " << *copyOp);
- // Find the fill into `viewOrAlloc` without interleaved uses before the copy.
+ // Find the fill into `viewOrAlloc` without interleaved uses before the
+ // copy.
FillOp maybeFillOp;
for (auto &u : viewOrAlloc.getUses()) {
if (auto newFillOp = dyn_cast<FillOp>(u.getOwner())) {
@@ -1468,7 +1475,8 @@ struct Conv1DNwcGenerator : public StructuredGenerator<LinalgOp> {
/// Layout: {{n, strideW * w + dilationW * kw, c}, {kw, c, f}, {n, w, f}}
/// ```
/// kw is always unrolled.
- /// TODO: w (resp. kw) is unrolled when the strideW ( resp. dilationW) is > 1.
+ /// TODO: w (resp. kw) is unrolled when strideW (resp. dilationW) is
+ /// greater than 1.
FailureOr<Operation *> conv() {
if (!valid)
return failure();
@@ -1483,7 +1491,8 @@ struct Conv1DNwcGenerator : public StructuredGenerator<LinalgOp> {
Value zero = builder.create<arith::ConstantIndexOp>(loc, 0);
// w is unrolled (i.e. wSizeStep == 1) iff strideW > 1.
- // When strideW == 1, we can batch the contiguous loads and avoid unrolling
+ // When strideW == 1, we can batch the contiguous loads and avoid
+ // unrolling
int64_t wSizeStep = strideW == 1 ? wSize : 1;
Type lhsEltType = lhsShapedType.getElementType();
@@ -1500,7 +1509,8 @@ struct Conv1DNwcGenerator : public StructuredGenerator<LinalgOp> {
VectorType rhsType = VectorType::get({kwSize, cSize, fSize}, rhsEltType);
VectorType resType = VectorType::get({nSize, wSize, fSize}, resEltType);
- // Read lhs slice of size {w * strideW + kw * dilationW, c, f} @ [0, 0, 0].
+ // Read lhs slice of size {n, w * strideW + kw * dilationW, c}
+ // @ [0, 0, 0].
Value lhs = builder.create<vector::TransferReadOp>(
loc, lhsType, lhsShaped, ValueRange{zero, zero, zero});
// Read rhs slice of size {kw, c, f} @ [0, 0, 0].
@@ -1591,7 +1601,8 @@ struct Conv1DNwcGenerator : public StructuredGenerator<LinalgOp> {
/// Layout: {{n, strideW * w + dilationW * kw, c}, {kw, c}, {n, w, c}}
/// ```
/// kw is always unrolled.
- /// TODO: w (resp. kw) is unrolled when the strideW ( resp. dilationW) is > 1.
+ /// TODO: w (resp. kw) is unrolled when strideW (resp. dilationW) is
+ /// greater than 1.
FailureOr<Operation *> dilatedConv() {
if (!valid)
return failure();
@@ -1605,7 +1616,8 @@ struct Conv1DNwcGenerator : public StructuredGenerator<LinalgOp> {
Value zero = builder.create<arith::ConstantIndexOp>(loc, 0);
// w is unrolled (i.e. wSizeStep == 1) iff strideW > 1.
- // When strideW == 1, we can batch the contiguous loads and avoid unrolling
+ // When strideW == 1, we can batch the contiguous loads and avoid
+ // unrolling
int64_t wSizeStep = strideW == 1 ? wSize : 1;
Type lhsEltType = lhsShapedType.getElementType();
@@ -1621,7 +1633,8 @@ struct Conv1DNwcGenerator : public StructuredGenerator<LinalgOp> {
VectorType rhsType = VectorType::get({kwSize, cSize}, rhsEltType);
VectorType resType = VectorType::get({nSize, wSize, cSize}, resEltType);
- // Read lhs slice of size {n, w * strideW + kw * dilationW, c} @ [0, 0, 0].
+ // Read lhs slice of size {n, w * strideW + kw * dilationW, c}
+ // @ [0, 0, 0].
Value lhs = builder.create<vector::TransferReadOp>(
loc, lhsType, lhsShaped, ValueRange{zero, zero, zero});
// Read rhs slice of size {kw, c} @ [0, 0].
More information about the Mlir-commits
mailing list