[Mlir-commits] [mlir] 96179df - [mlir][Linalg] Add a transform dialect op to rewrite ops to destination passing style.
Nicolas Vasilache
llvmlistbot at llvm.org
Thu Feb 16 05:39:01 PST 2023
Author: Nicolas Vasilache
Date: 2023-02-16T05:26:33-08:00
New Revision: 96179dff46a9a3981708b06bc9e0f981be4cc1a8
URL: https://github.com/llvm/llvm-project/commit/96179dff46a9a3981708b06bc9e0f981be4cc1a8
DIFF: https://github.com/llvm/llvm-project/commit/96179dff46a9a3981708b06bc9e0f981be4cc1a8.diff
LOG: [mlir][Linalg] Add a transform dialect op to rewrite ops to destination passing style.
A new transform dialect op is introduced to perform the rewrite.
The test pass option is now obsolete and is removed in favor of the transform.
In the process, I realized the tensor.pad nofold attribute was not taken into account
and added support for emitting a bufferization.alloc_tensor + linalg.copy in that case.
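The new op is driven from a transform script. The updated test exercises it
along these lines (handle types follow the current PDL-based tests):

  transform.sequence failures(propagate) {
  ^bb1(%arg1: !pdl.operation):
    %0 = transform.structured.match ops{["tensor.pad"]} in %arg1
      : (!pdl.operation) -> !pdl.operation
    transform.structured.rewrite_in_destination_passing_style %0
      : (!pdl.operation) -> !pdl.operation
  }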
Reviewed By: springerm
Differential Revision: https://reviews.llvm.org/D143943
Added:
mlir/test/Dialect/Linalg/transform-op-rewrite-in-destination-passing-style.mlir
Modified:
mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp
mlir/lib/Dialect/Linalg/Utils/Utils.cpp
mlir/lib/Dialect/Utils/StaticValueUtils.cpp
mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
Removed:
mlir/test/Dialect/Linalg/convert-to-destination-style.mlir
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 58ef106563a5b..ff8f3e33703c2 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -873,6 +873,45 @@ def ScalarizeOp : Op<Transform_Dialect, "structured.scalarize",
}];
}
+//===----------------------------------------------------------------------===//
+// RewriteInDestinationPassingStyleOp.
+//===----------------------------------------------------------------------===//
+
+def RewriteInDestinationPassingStyleOp : Op<
+ Transform_Dialect, "structured.rewrite_in_destination_passing_style",
+ [MemoryEffectsOpInterface,
+ NavigationTransformOpTrait,
+ DeclareOpInterfaceMethods<TransformOpInterface>]> {
+ let description = [{
+ Rewrite a supported tensor operation that is not in destination-passing style
+ into a form that is in destination-passing style.
+ Currently supported operations are:
+ - tensor.pad
+ - tensor.generate
+ - tensor.from_elements
+ This dichotomy hints at a future interface; for now the implementation just
+ switches between different implementations.
+
+ #### Return modes
+
+ This operation ignores unsupported ops and drops them from the return.
+ If all the operations referred to by the `target` PDLOperation are rewritten
+ properly, the transform succeeds. Otherwise the transform silently fails.
+ The return handle points to a subset of successfully produced operations:
+ - tensor.pad case, the returned handle points to the tensor.insert_slice.
+ - tensor.generate case, the returned handle points to the linalg.generic.
+ - tensor.from_elements case, the returned handle points to the last
+ tensor.insert.
+ }];
+
+ let arguments = (ins TransformHandleTypeInterface:$target);
+ let results = (outs TransformHandleTypeInterface:$transformed);
+ let assemblyFormat = [{
+ $target attr-dict
+ `:` functional-type($target, results)
+ }];
+}
+
//===----------------------------------------------------------------------===//
// SplitOp
//===----------------------------------------------------------------------===//
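As a rough sketch of the tensor.pad case described in the op documentation
above (the value names, shapes and pad amounts below are illustrative, not
taken from this commit), a pad with a constant yield value such as

  %padded = tensor.pad %t low[5, %l] high[%h, 7] {
  ^bb0(%i: index, %j: index):
    tensor.yield %cst : index
  } : tensor<?x10xindex> to tensor<?x?xindex>

is rewritten into an explicit destination that is filled and then overwritten
with the source; the returned handle points at the tensor.insert_slice
(%d0, %d1 and %sz0 stand for the reified result and source sizes):

  %empty = tensor.empty(%d0, %d1) : tensor<?x?xindex>
  %filled = linalg.fill ins(%cst : index) outs(%empty : tensor<?x?xindex>)
              -> tensor<?x?xindex>
  %result = tensor.insert_slice %t into %filled[5, %l] [%sz0, 10] [1, 1]
              : tensor<?x10xindex> into tensor<?x?xindex>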
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index dd01a2e3325bf..aedef03b88fc7 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1206,6 +1206,20 @@ packTranspose(RewriterBase &rewriter, tensor::PackOp packOp,
linalg::LinalgOp linalgOp, tensor::UnPackOp maybeUnPackOp,
ArrayRef<int64_t> outerPerm, ArrayRef<int64_t> innerPerm);
+/// Rewrite tensor.from_elements to a chain of tensor.insert.
+FailureOr<Operation *>
+rewriteInDestinationPassingStyle(RewriterBase &rewriter,
+ tensor::FromElementsOp fromElementsOp);
+
+/// Rewrite tensor.generate to linalg.generic.
+FailureOr<Operation *>
+rewriteInDestinationPassingStyle(RewriterBase &rewriter,
+ tensor::GenerateOp generateOp);
+
+/// Rewrite tensor.pad to linalg.generic + tensor.insert_slice.
+FailureOr<Operation *> rewriteInDestinationPassingStyle(RewriterBase &rewriter,
+ tensor::PadOp padOp);
+
} // namespace linalg
} // namespace mlir
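For the tensor.from_elements entry point declared above, the rewrite produces
a tensor.empty followed by a chain of tensor.insert ops; a hand-written sketch
with made-up value names:

  %t = tensor.from_elements %a, %b, %c : tensor<3xindex>

  // becomes, roughly:
  %empty = tensor.empty() : tensor<3xindex>
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c2 = arith.constant 2 : index
  %i0 = tensor.insert %a into %empty[%c0] : tensor<3xindex>
  %i1 = tensor.insert %b into %i0[%c1] : tensor<3xindex>
  %res = tensor.insert %c into %i1[%c2] : tensor<3xindex>

The handle returned by the transform points at that last tensor.insert.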
diff --git a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
index 0b0e03dc2788c..c37d35134dce1 100644
--- a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
@@ -21,6 +21,10 @@
namespace mlir {
+/// Return true if `v` is an IntegerAttr with value `0` or a ConstantIndexOp
+/// with value `0`.
+bool isZeroIndex(OpFoldResult v);
+
/// Represents a range (offset, size, and stride) where each element of the
/// triple may be dynamic or static.
struct Range {
@@ -30,8 +34,8 @@ struct Range {
};
/// Given an array of Range values, return a tuple of (offset vector, sizes
-/// vector, and strides vector) formed by separating out the individual elements
-/// of each range.
+/// vector, and strides vector) formed by separating out the individual
+/// elements of each range.
std::tuple<SmallVector<OpFoldResult>, SmallVector<OpFoldResult>,
SmallVector<OpFoldResult>>
getOffsetsSizesAndStrides(ArrayRef<Range> ranges);
@@ -40,14 +44,15 @@ getOffsetsSizesAndStrides(ArrayRef<Range> ranges);
/// a) it is an IntegerAttr
/// In other cases, the OpFoldResult is dispached to the `dynamicVec`.
/// In such dynamic cases, ShapedType::kDynamic is also pushed to
-/// `staticVec`. This is useful to extract mixed static and dynamic entries that
-/// come from an AttrSizedOperandSegments trait.
+/// `staticVec`. This is useful to extract mixed static and dynamic entries
+/// that come from an AttrSizedOperandSegments trait.
void dispatchIndexOpFoldResult(OpFoldResult ofr,
SmallVectorImpl<Value> &dynamicVec,
SmallVectorImpl<int64_t> &staticVec);
-/// Helper function to dispatch multiple OpFoldResults according to the behavior
-/// of `dispatchIndexOpFoldResult(OpFoldResult ofr` for a single OpFoldResult.
+/// Helper function to dispatch multiple OpFoldResults according to the
+/// behavior of `dispatchIndexOpFoldResult(OpFoldResult ofr` for a single
+/// OpFoldResult.
void dispatchIndexOpFoldResults(ArrayRef<OpFoldResult> ofrs,
SmallVectorImpl<Value> &dynamicVec,
SmallVectorImpl<int64_t> &staticVec);
@@ -72,27 +77,28 @@ std::optional<int64_t> getConstantIntValue(OpFoldResult ofr);
/// Return true if `ofr` is constant integer equal to `value`.
bool isConstantIntValue(OpFoldResult ofr, int64_t value);
-/// Return true if ofr1 and ofr2 are the same integer constant attribute values
-/// or the same SSA value.
-/// Ignore integer bitwitdh and type mismatch that come from the fact there is
-/// no IndexAttr and that IndexType have no bitwidth.
+/// Return true if ofr1 and ofr2 are the same integer constant attribute
+/// values or the same SSA value. Ignore integer bitwidth and type mismatches
+/// that come from the fact there is no IndexAttr and that IndexType has no
+/// bitwidth.
bool isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2);
/// Helper function to convert a vector of `OpFoldResult`s into a vector of
-/// `Value`s. For each `OpFoldResult` in `valueOrAttrVec` return the fold result
-/// if it casts to a `Value` or create an index-type constant if it casts to
-/// `IntegerAttr`. No other attribute types are supported.
+/// `Value`s. For each `OpFoldResult` in `valueOrAttrVec` return the fold
+/// result if it casts to a `Value` or create an index-type constant if it
+/// casts to `IntegerAttr`. No other attribute types are supported.
SmallVector<Value> getAsValues(OpBuilder &b, Location loc,
ArrayRef<OpFoldResult> valueOrAttrVec);
-/// Return a vector of OpFoldResults with the same size a staticValues, but all
-/// elements for which ShapedType::isDynamic is true, will be replaced by
+/// Return a vector of OpFoldResults with the same size as staticValues, but
+/// all elements for which ShapedType::isDynamic is true will be replaced by
/// dynamicValues.
SmallVector<OpFoldResult> getMixedValues(ArrayRef<int64_t> staticValues,
ValueRange dynamicValues, Builder &b);
-/// Decompose a vector of mixed static or dynamic values into the corresponding
-/// pair of arrays. This is the inverse function of `getMixedValues`.
+/// Decompose a vector of mixed static or dynamic values into the
+/// corresponding pair of arrays. This is the inverse function of
+/// `getMixedValues`.
std::pair<ArrayAttr, SmallVector<Value>>
decomposeMixedValues(Builder &b,
const SmallVectorImpl<OpFoldResult> &mixedValues);
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index dab98d2406f45..82c4c39032001 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -39,6 +39,7 @@
#include "llvm/ADT/SetOperations.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringSet.h"
+#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
using namespace mlir;
@@ -1919,6 +1920,32 @@ transform::ScalarizeOp::applyToOne(LinalgOp target,
return DiagnosedSilenceableFailure::success();
}
+//===----------------------------------------------------------------------===//
+// RewriteInDestinationPassingStyleOp
+//===----------------------------------------------------------------------===//
+
+DiagnosedSilenceableFailure
+transform::RewriteInDestinationPassingStyleOp::apply(
+ transform::TransformResults &results, transform::TransformState &state) {
+ SmallVector<Operation *> res;
+ ArrayRef<Operation *> targetOps = state.getPayloadOps(getTarget());
+ for (Operation *target : targetOps) {
+ IRRewriter rewriter(target->getContext());
+ rewriter.setInsertionPoint(target);
+ FailureOr<Operation *> maybeResult =
+ TypeSwitch<Operation *, FailureOr<Operation *>>(target)
+ .Case<tensor::FromElementsOp, tensor::GenerateOp, tensor::PadOp>(
+ [&rewriter](auto op) {
+ return rewriteInDestinationPassingStyle(rewriter, op);
+ });
+ if (failed(maybeResult))
+ return emitDefaultSilenceableFailure(target);
+ res.push_back(*maybeResult);
+ }
+ results.set(getResult().cast<OpResult>(), res);
+ return DiagnosedSilenceableFailure::success();
+}
+
//===----------------------------------------------------------------------===//
// SplitOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp b/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp
index 1915c2acea441..85235ef2db376 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp
@@ -19,8 +19,10 @@
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
+#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/Debug.h"
using namespace mlir;
@@ -50,94 +52,6 @@ static Value createInserts(RewriterBase &rewriter, Location loc, int dim,
return destination;
}
-namespace {
-
-/// Lower tensor.from_elements to a sequence of chained tensor.insert.
-struct FromElementsOpConverter : public OpRewritePattern<FromElementsOp> {
- using OpRewritePattern<FromElementsOp>::OpRewritePattern;
-
- LogicalResult matchAndRewrite(FromElementsOp elementsOp,
- PatternRewriter &rewriter) const override {
- Location loc = elementsOp.getLoc();
- RankedTensorType tensorType = elementsOp.getType().cast<RankedTensorType>();
- auto shape = tensorType.getShape();
-
- // Create tensor.empty.
- auto emptyOp = rewriter.create<EmptyOp>(loc, tensorType, ValueRange());
-
- // Case: tensor<elem_type>.
- if (shape.empty()) {
- rewriter.replaceOpWithNewOp<tensor::InsertOp>(
- elementsOp, elementsOp.getElements().front(), emptyOp.getResult(),
- ValueRange());
- return success();
- }
-
- // Create constants for the range of possible indices [0, max{shape_i}).
- auto maxDim = *std::max_element(shape.begin(), shape.end());
- SmallVector<Value, 2> constants;
- constants.reserve(maxDim);
- for (int i = 0; i < maxDim; ++i)
- constants.push_back(rewriter.create<arith::ConstantIndexOp>(loc, i));
-
- // Traverse all elements and create tensor.insert ops.
- auto elementIt = elementsOp.getElements().begin();
- SmallVector<Value, 2> indices(tensorType.getRank(), constants[0]);
- Value result = createInserts(rewriter, loc, /*dim=*/0, emptyOp.getResult(),
- shape, constants, elementIt, indices);
-
- // Replace tensor.from_elements.
- rewriter.replaceOp(elementsOp, result);
- return success();
- }
-};
-
-/// Lower tensor.generate to linalg.generic.
-struct GenerateOpConverter : public OpRewritePattern<GenerateOp> {
- using OpRewritePattern<GenerateOp>::OpRewritePattern;
-
- LogicalResult matchAndRewrite(GenerateOp generateOp,
- PatternRewriter &rewriter) const override {
- // Only ops with exactly one block are supported.
- if (!generateOp.getBody().hasOneBlock())
- return failure();
-
- Location loc = generateOp.getLoc();
- RankedTensorType tensorType = generateOp.getType().cast<RankedTensorType>();
-
- // Create tensor.empty.
- auto emptyOp = rewriter.create<EmptyOp>(loc, tensorType,
- generateOp.getDynamicExtents());
-
- // Create linalg.generic.
- SmallVector<utils::IteratorType> iteratorTypes(
- tensorType.getRank(), utils::IteratorType::parallel);
- SmallVector<AffineMap> indexingMaps(
- 1, rewriter.getMultiDimIdentityMap(tensorType.getRank()));
- auto genericOp = rewriter.create<linalg::GenericOp>(
- loc, tensorType, /*inputs=*/ValueRange(),
- /*outputs=*/ValueRange{emptyOp.getResult()}, /*indexingMaps=*/
- indexingMaps, iteratorTypes);
- Block *body = rewriter.createBlock(&genericOp->getRegion(0), {},
- tensorType.getElementType(), loc);
- rewriter.setInsertionPointToStart(body);
- SmallVector<Value> bbArgReplacements;
- for (int64_t i = 0; i < tensorType.getRank(); ++i)
- bbArgReplacements.push_back(rewriter.create<linalg::IndexOp>(loc, i));
- rewriter.mergeBlocks(&generateOp.getBody().front(), body,
- bbArgReplacements);
-
- // Update terminator.
- auto yieldOp = cast<tensor::YieldOp>(body->getTerminator());
- rewriter.replaceOpWithNewOp<linalg::YieldOp>(yieldOp, yieldOp.getValue());
-
- // Replace tensor.generate.
- rewriter.replaceOp(generateOp, genericOp->getResult(0));
- return success();
- }
-};
-} // namespace
-
static Operation *movePaddingToFillOrGenericOp(RewriterBase &rewriter,
Location loc, PadOp padOp,
Value dest) {
@@ -287,49 +201,133 @@ Value linalg::bufferizeToAllocation(RewriterBase &rewriter, PadOp padOp,
return toTensorOp;
}
-namespace {
-/// Lower tensor.pad to linalg.generic + tensor.insert_slice.
-struct PadOpConverter : public OpRewritePattern<PadOp> {
- using OpRewritePattern<PadOp>::OpRewritePattern;
+/// Lower tensor.from_elements to a sequence of chained tensor.insert.
+FailureOr<Operation *> mlir::linalg::rewriteInDestinationPassingStyle(
+ RewriterBase &rewriter, tensor::FromElementsOp fromElementsOp) {
+ Location loc = fromElementsOp.getLoc();
+ RankedTensorType tensorType =
+ fromElementsOp.getType().cast<RankedTensorType>();
+ auto shape = tensorType.getShape();
+
+ // Create tensor.empty.
+ auto emptyOp = rewriter.create<EmptyOp>(loc, tensorType, ValueRange());
+
+ // Case: tensor<elem_type>.
+ if (shape.empty()) {
+ Operation *res = rewriter.replaceOpWithNewOp<tensor::InsertOp>(
+ fromElementsOp, fromElementsOp.getElements().front(),
+ emptyOp.getResult(), ValueRange());
+ return res;
+ }
- LogicalResult matchAndRewrite(PadOp padOp,
- PatternRewriter &rewriter) const override {
- // Only ops with exactly one block are supported.
- if (!padOp.getBodyRegion().hasOneBlock())
- return failure();
+ // Create constants for the range of possible indices [0, max{shape_i}).
+ auto maxDim = *std::max_element(shape.begin(), shape.end());
+ SmallVector<Value, 2> constants;
+ constants.reserve(maxDim);
+ for (int i = 0; i < maxDim; ++i)
+ constants.push_back(rewriter.create<arith::ConstantIndexOp>(loc, i));
+
+ // Traverse all elements and create tensor.insert ops.
+ auto elementIt = fromElementsOp.getElements().begin();
+ SmallVector<Value, 2> indices(tensorType.getRank(), constants[0]);
+ Value result = createInserts(rewriter, loc, /*dim=*/0, emptyOp.getResult(),
+ shape, constants, elementIt, indices);
+
+ // Replace tensor.from_elements.
+ rewriter.replaceOp(fromElementsOp, result);
+ return result.getDefiningOp();
+}
- // Create tensor.empty.
- Location loc = padOp.getLoc();
- RankedTensorType resultType = padOp.getResultType();
- ReifiedRankedShapedTypeDims reifiedShape;
- if (failed(cast<ReifyRankedShapedTypeOpInterface>(padOp.getOperation())
- .reifyResultShapes(rewriter, reifiedShape)))
- return rewriter.notifyMatchFailure(
- padOp, "failed to reify tensor.pad op result shape");
- SmallVector<Value> dynamicSizes;
- for (int64_t i = 0; i < resultType.getRank(); ++i)
- if (resultType.isDynamicDim(i))
- dynamicSizes.push_back(reifiedShape[0][i]);
- auto emptyOp = rewriter.create<EmptyOp>(loc, resultType, dynamicSizes);
-
- // Create linalg.fill or linalg.generic.
- Operation *fillOp =
- movePaddingToFillOrGenericOp(rewriter, loc, padOp, emptyOp.getResult());
- rewriter.setInsertionPointAfter(fillOp);
-
- // Create tensor::InsertSliceOp.
- SmallVector<OpFoldResult> sliceSizes =
- getMixedSizes(rewriter, loc, padOp.getSource());
- SmallVector<OpFoldResult> sliceStrides(resultType.getRank(),
- rewriter.getIndexAttr(1));
- rewriter.replaceOpWithNewOp<tensor::InsertSliceOp>(
- padOp, padOp.getSource(), fillOp->getResult(0),
- /*offsets=*/padOp.getMixedLowPad(), sliceSizes, sliceStrides);
+/// Lower tensor.generate to linalg.generic.
+FailureOr<Operation *>
+mlir::linalg::rewriteInDestinationPassingStyle(RewriterBase &rewriter,
+ tensor::GenerateOp generateOp) {
+ // Only ops with exactly one block are supported.
+ if (!generateOp.getBody().hasOneBlock())
+ return failure();
- return success();
+ Location loc = generateOp.getLoc();
+ RankedTensorType tensorType = generateOp.getType().cast<RankedTensorType>();
+
+ // Create tensor.empty.
+ auto emptyOp =
+ rewriter.create<EmptyOp>(loc, tensorType, generateOp.getDynamicExtents());
+
+ // Create linalg.generic.
+ SmallVector<utils::IteratorType> iteratorTypes(tensorType.getRank(),
+ utils::IteratorType::parallel);
+ SmallVector<AffineMap> indexingMaps(
+ 1, rewriter.getMultiDimIdentityMap(tensorType.getRank()));
+ auto genericOp = rewriter.create<linalg::GenericOp>(
+ loc, tensorType, /*inputs=*/ValueRange(),
+ /*outputs=*/ValueRange{emptyOp.getResult()}, /*indexingMaps=*/
+ indexingMaps, iteratorTypes);
+ Block *body = rewriter.createBlock(&genericOp->getRegion(0), {},
+ tensorType.getElementType(), loc);
+ rewriter.setInsertionPointToStart(body);
+ SmallVector<Value> bbArgReplacements;
+ for (int64_t i = 0; i < tensorType.getRank(); ++i)
+ bbArgReplacements.push_back(rewriter.create<linalg::IndexOp>(loc, i));
+ rewriter.mergeBlocks(&generateOp.getBody().front(), body, bbArgReplacements);
+
+ // Update terminator.
+ auto yieldOp = cast<tensor::YieldOp>(body->getTerminator());
+ rewriter.replaceOpWithNewOp<linalg::YieldOp>(yieldOp, yieldOp.getValue());
+
+ // Replace tensor.generate.
+ rewriter.replaceOp(generateOp, genericOp->getResult(0));
+ return genericOp.getOperation();
+}
+
+/// Lower tensor.pad to linalg.generic + tensor.insert_slice.
+FailureOr<Operation *>
+mlir::linalg::rewriteInDestinationPassingStyle(RewriterBase &rewriter,
+ tensor::PadOp padOp) {
+ // Only ops with exactly one block are supported.
+ if (!padOp.getBodyRegion().hasOneBlock())
+ return failure();
+
+ // Create tensor.empty.
+ Location loc = padOp.getLoc();
+ RankedTensorType resultType = padOp.getResultType();
+ ReifiedRankedShapedTypeDims reifiedShape;
+ if (failed(cast<ReifyRankedShapedTypeOpInterface>(padOp.getOperation())
+ .reifyResultShapes(rewriter, reifiedShape)))
+ return rewriter.notifyMatchFailure(
+ padOp, "failed to reify tensor.pad op result shape");
+ SmallVector<Value> dynamicSizes;
+ for (int64_t i = 0; i < resultType.getRank(); ++i)
+ if (resultType.isDynamicDim(i))
+ dynamicSizes.push_back(reifiedShape[0][i]);
+
+ // If the `padOp` has a nofold attribute and all paddings are known to be 0,
+ // explicitly insert a `linalg.copy`.
+ if (padOp.getNofoldAttr() &&
+ llvm::all_of(padOp.getMixedLowPad(), isZeroIndex) &&
+ llvm::all_of(padOp.getMixedHighPad(), isZeroIndex)) {
+ using bufferization::AllocTensorOp;
+ Value allocated =
+ rewriter.create<AllocTensorOp>(loc, resultType, dynamicSizes);
+ auto copyOp = rewriter.replaceOpWithNewOp<linalg::CopyOp>(
+ padOp, padOp.getSource(), allocated);
+ return copyOp.getOperation();
}
-};
-} // namespace
+
+ Value empty = rewriter.create<EmptyOp>(loc, resultType, dynamicSizes);
+ // Create linalg.fill or linalg.generic.
+ Operation *fillOp = movePaddingToFillOrGenericOp(rewriter, loc, padOp, empty);
+ rewriter.setInsertionPointAfter(fillOp);
+
+ // Create tensor::InsertSliceOp.
+ SmallVector<OpFoldResult> sliceSizes =
+ getMixedSizes(rewriter, loc, padOp.getSource());
+ SmallVector<OpFoldResult> sliceStrides(resultType.getRank(),
+ rewriter.getIndexAttr(1));
+ auto insertSliceOp = rewriter.replaceOpWithNewOp<tensor::InsertSliceOp>(
+ padOp, padOp.getSource(), fillOp->getResult(0),
+ /*offsets=*/padOp.getMixedLowPad(), sliceSizes, sliceStrides);
+ return insertSliceOp.getOperation();
+}
Value linalg::bufferizeToAllocation(RewriterBase &rewriter, Value value,
Attribute memorySpace) {
@@ -368,6 +366,45 @@ Value linalg::bufferizeToAllocation(RewriterBase &rewriter, Value value,
return toTensorOp;
}
+namespace {
+/// Lower tensor.from_elements to a sequence of chained tensor.insert.
+struct FromElementsOpConverter : public OpRewritePattern<FromElementsOp> {
+ using OpRewritePattern<FromElementsOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(FromElementsOp fromElementsOp,
+ PatternRewriter &rewriter) const override {
+ if (failed(
+ linalg::rewriteInDestinationPassingStyle(rewriter, fromElementsOp)))
+ return failure();
+ return success();
+ }
+};
+
+/// Lower tensor.generate to linalg.generic.
+struct GenerateOpConverter : public OpRewritePattern<GenerateOp> {
+ using OpRewritePattern<GenerateOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(GenerateOp generateOp,
+ PatternRewriter &rewriter) const override {
+ if (failed(linalg::rewriteInDestinationPassingStyle(rewriter, generateOp)))
+ return failure();
+ return success();
+ }
+};
+
+/// Lower tensor.pad to linalg.generic + tensor.insert_slice.
+struct PadOpConverter : public OpRewritePattern<PadOp> {
+ using OpRewritePattern<PadOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(PadOp padOp,
+ PatternRewriter &rewriter) const override {
+ if (failed(linalg::rewriteInDestinationPassingStyle(rewriter, padOp)))
+ return failure();
+ return success();
+ }
+};
+} // namespace
+
void linalg::populateConvertToDestinationStylePatterns(
RewritePatternSet &patterns) {
patterns.insert<FromElementsOpConverter, GenerateOpConverter, PadOpConverter>(
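The tensor.generate path implemented above can be pictured as follows (a
sketch only; the dynamic extents %s0, %s1 and the body computation are
illustrative):

  %gen = tensor.generate %s0, %s1 {
  ^bb0(%i: index, %j: index):
    %v = arith.addi %i, %j : index
    tensor.yield %v : index
  } : tensor<?x?xindex>

  // becomes, roughly:
  %empty = tensor.empty(%s0, %s1) : tensor<?x?xindex>
  %result = linalg.generic {
      indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>],
      iterator_types = ["parallel", "parallel"]}
      outs(%empty : tensor<?x?xindex>) {
  ^bb0(%out: index):
    %i = linalg.index 0 : index
    %j = linalg.index 1 : index
    %v = arith.addi %i, %j : index
    linalg.yield %v : index
  } -> tensor<?x?xindex>

The returned handle points at the linalg.generic, as stated in the op's
return modes.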
diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
index 7ea36248c96f7..f3879f5dd9d12 100644
--- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -44,18 +44,6 @@ using namespace presburger;
using namespace mlir::linalg;
using namespace mlir::scf;
-static bool isZero(OpFoldResult v) {
- if (!v)
- return false;
- if (auto attr = v.dyn_cast<Attribute>()) {
- IntegerAttr intAttr = attr.dyn_cast<IntegerAttr>();
- return intAttr && intAttr.getValue().isZero();
- }
- if (auto cst = v.get<Value>().getDefiningOp<arith::ConstantIndexOp>())
- return cst.value() == 0;
- return false;
-}
-
namespace {
// Helper visitor to determine whether an AffineExpr is tiled.
@@ -70,7 +58,7 @@ struct TileCheck : public AffineExprVisitor<TileCheck> {
TileCheck(ArrayRef<OpFoldResult> tileSizes) : tileSizes(tileSizes) {}
void visitDimExpr(AffineDimExpr expr) {
- isTiled |= !isZero(tileSizes[expr.getPosition()]);
+ isTiled |= !isZeroIndex(tileSizes[expr.getPosition()]);
}
void visitAffineBinaryOpExpr(AffineBinaryOpExpr expr) {
visit(expr.getLHS());
@@ -869,7 +857,7 @@ SmallVector<OpFoldResult> computeTileOffsets(OpBuilder &b, Location loc,
SmallVector<OpFoldResult> offsets;
for (unsigned idx = 0, idxIvs = 0, e = tileSizes.size(); idx < e; ++idx) {
LLVM_DEBUG(llvm::dbgs() << "makeTiledShapes: for loop#" << idx << "\n");
- bool isTiled = !isZero(tileSizes[idx]);
+ bool isTiled = !isZeroIndex(tileSizes[idx]);
offsets.push_back(isTiled ? ivs[idxIvs++] : b.getIndexAttr(0));
LLVM_DEBUG(llvm::dbgs()
<< "computeTileOffsets: " << offsets.back() << "\n");
@@ -882,7 +870,7 @@ SmallVector<OpFoldResult> computeTileSizes(OpBuilder &b, Location loc,
ArrayRef<OpFoldResult> sizeBounds) {
SmallVector<OpFoldResult> sizes;
for (unsigned idx = 0, e = tileSizes.size(); idx < e; ++idx) {
- bool isTiled = !isZero(tileSizes[idx]);
+ bool isTiled = !isZeroIndex(tileSizes[idx]);
// Before composing, we need to make range a closed interval.
OpFoldResult size = isTiled ? tileSizes[idx] : sizeBounds[idx];
AffineExpr d0 = getAffineDimExpr(0, b.getContext());
@@ -938,7 +926,7 @@ computeAllSliceParameters(OpBuilder &builder, Location loc, LinalgOp linalgOp,
bool omitPartialTileCheck) {
assert(ivs.size() == static_cast<size_t>(llvm::count_if(
llvm::make_range(tileSizes.begin(), tileSizes.end()),
- [](OpFoldResult v) { return !isZero(v); })) &&
+ [](OpFoldResult v) { return !isZeroIndex(v); })) &&
"expected as many ivs as non-zero sizes");
// Construct (potentially temporary) mins and maxes on which to apply maps
diff --git a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
index 436e6e901a4a6..294dc810507b4 100644
--- a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
+++ b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
@@ -14,6 +14,18 @@
namespace mlir {
+bool isZeroIndex(OpFoldResult v) {
+ if (!v)
+ return false;
+ if (auto attr = v.dyn_cast<Attribute>()) {
+ IntegerAttr intAttr = attr.dyn_cast<IntegerAttr>();
+ return intAttr && intAttr.getValue().isZero();
+ }
+ if (auto cst = v.get<Value>().getDefiningOp<arith::ConstantIndexOp>())
+ return cst.value() == 0;
+ return false;
+}
+
std::tuple<SmallVector<OpFoldResult>, SmallVector<OpFoldResult>,
SmallVector<OpFoldResult>>
getOffsetsSizesAndStrides(ArrayRef<Range> ranges) {
diff --git a/mlir/test/Dialect/Linalg/convert-to-destination-style.mlir b/mlir/test/Dialect/Linalg/transform-op-rewrite-in-destination-passing-style.mlir
similarity index 71%
rename from mlir/test/Dialect/Linalg/convert-to-destination-style.mlir
rename to mlir/test/Dialect/Linalg/transform-op-rewrite-in-destination-passing-style.mlir
index 3a7472fcec08f..9b89d83f6a95d 100644
--- a/mlir/test/Dialect/Linalg/convert-to-destination-style.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-rewrite-in-destination-passing-style.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -split-input-file -test-linalg-transform-patterns=test-convert-to-destination-style-patterns -canonicalize %s | FileCheck %s
+// RUN: mlir-opt -test-transform-dialect-interpreter --split-input-file -canonicalize %s | FileCheck %s
// CHECK-LABEL: func @tensor_from_elements_0d(
// CHECK-SAME: %[[arg0:.*]]: index
@@ -10,6 +10,14 @@ func.func @tensor_from_elements_0d(%arg0: index) -> tensor<index> {
return %0 : tensor<index>
}
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+ %0 = transform.structured.match ops{["tensor.from_elements"]} in %arg1
+ : (!pdl.operation) -> !pdl.operation
+ transform.structured.rewrite_in_destination_passing_style %0
+ : (!pdl.operation) -> !pdl.operation
+}
+
// -----
// CHECK-LABEL: func @tensor_from_elements_1d(
@@ -25,6 +33,14 @@ func.func @tensor_from_elements_1d(%arg0: index, %arg1: index) -> tensor<2xindex
return %0 : tensor<2xindex>
}
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+ %0 = transform.structured.match ops{["tensor.from_elements"]} in %arg1
+ : (!pdl.operation) -> !pdl.operation
+ transform.structured.rewrite_in_destination_passing_style %0
+ : (!pdl.operation) -> !pdl.operation
+}
+
// -----
// CHECK-LABEL: func @tensor_from_elements_2d(
@@ -46,6 +62,14 @@ func.func @tensor_from_elements_2d(%arg0: index, %arg1: index) -> tensor<3x2xind
return %0 : tensor<3x2xindex>
}
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+ %0 = transform.structured.match ops{["tensor.from_elements"]} in %arg1
+ : (!pdl.operation) -> !pdl.operation
+ transform.structured.rewrite_in_destination_passing_style %0
+ : (!pdl.operation) -> !pdl.operation
+}
+
// -----
// CHECK: #[[$map:.*]] = affine_map<(d0, d1) -> (d0, d1)>
@@ -70,6 +94,14 @@ func.func @tensor_generate(%s1: index, %s2: index) -> tensor<?x?xindex> {
return %0 : tensor<?x?xindex>
}
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+ %0 = transform.structured.match ops{["tensor.generate"]} in %arg1
+ : (!pdl.operation) -> !pdl.operation
+ transform.structured.rewrite_in_destination_passing_style %0
+ : (!pdl.operation) -> !pdl.operation
+}
+
// -----
// CHECK: #[[$map:.+]] = affine_map<()[s0, s1] -> (s0 + s1 + 5)>
@@ -103,6 +135,14 @@ func.func @tensor_pad(%t1: tensor<?x10xindex>, %l2: index, %h1: index,
return %0 : tensor<?x?xindex>
}
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+ %0 = transform.structured.match ops{["tensor.pad"]} in %arg1
+ : (!pdl.operation) -> !pdl.operation
+ transform.structured.rewrite_in_destination_passing_style %0
+ : (!pdl.operation) -> !pdl.operation
+}
+
// -----
// CHECK: #[[$map:.+]] = affine_map<()[s0, s1] -> (s0 + s1 + 5)>
@@ -129,6 +169,14 @@ func.func @tensor_pad_constant(%t1: tensor<?x10xindex>, %l2: index, %h1: index,
return %0 : tensor<?x?xindex>
}
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+ %0 = transform.structured.match ops{["tensor.pad"]} in %arg1
+ : (!pdl.operation) -> !pdl.operation
+ transform.structured.rewrite_in_destination_passing_style %0
+ : (!pdl.operation) -> !pdl.operation
+}
+
// -----
// CHECK: #[[$map:.+]] = affine_map<()[s0, s1] -> (s0 + s1 + 5)>
@@ -152,3 +200,39 @@ func.func @tensor_pad_invariant(%t1: tensor<?x10xindex>, %l2: index, %h1: index,
} : tensor<?x10xindex> to tensor<?x?xindex>
return %0 : tensor<?x?xindex>
}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+ %0 = transform.structured.match ops{["tensor.pad"]} in %arg1
+ : (!pdl.operation) -> !pdl.operation
+ transform.structured.rewrite_in_destination_passing_style %0
+ : (!pdl.operation) -> !pdl.operation
+}
+
+// -----
+
+// CHECK-LABEL: func @tensor_pad_nofold(
+// CHECK-SAME: %[[t1:.*]]: tensor<?x?xindex>, %[[padding:.*]]: index
+// CHECK-NOT: linalg.fill
+// CHECK-NOT: generic
+// CHECK-NOT: insert_slice
+// CHECK: %[[alloc_tensor:.*]] = bufferization.alloc_tensor(%{{.*}}) : tensor<?x?xindex>
+// CHECK: %[[copied:.*]] = linalg.copy ins(%[[t1]] : tensor<?x?xindex>) outs(%[[alloc_tensor]] : tensor<?x?xindex>) -> tensor<?x?xindex>
+// CHECK: return %[[copied]]
+func.func @tensor_pad_nofold(%t1: tensor<?x?xindex>, %padding: index)
+ -> tensor<?x?xindex> {
+ %c0 = arith.constant 0 : index
+ %0 = tensor.pad %t1 nofold low[0, %c0] high[%c0, 0] {
+ ^bb0(%arg0: index, %arg1: index):
+ tensor.yield %padding : index
+ } : tensor<?x?xindex> to tensor<?x?xindex>
+ return %0: tensor<?x?xindex>
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+ %0 = transform.structured.match ops{["tensor.pad"]} in %arg1
+ : (!pdl.operation) -> !pdl.operation
+ transform.structured.rewrite_in_destination_passing_style %0
+ : (!pdl.operation) -> !pdl.operation
+}
diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
index 7842e860c6a67..5ce43ff99232b 100644
--- a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
@@ -128,10 +128,6 @@ struct TestLinalgTransforms
*this, "test-erase-unnecessary-inputs",
llvm::cl::desc("Test patterns to erase unnecessary inputs"),
llvm::cl::init(false)};
- Option<bool> testConvertToDestinationStylePatterns{
- *this, "test-convert-to-destination-style-patterns",
- llvm::cl::desc("Test patterns that convert ops to destination style"),
- llvm::cl::init(false)};
};
} // namespace
@@ -222,12 +218,6 @@ static void applyEraseUnnecessaryInputs(func::FuncOp funcOp) {
(void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
}
-static void applyConvertToDestinationStylePatterns(Operation *rootOp) {
- RewritePatternSet patterns(rootOp->getContext());
- populateConvertToDestinationStylePatterns(patterns);
- (void)applyPatternsAndFoldGreedily(rootOp, std::move(patterns));
-}
-
/// Apply transformations specified as patterns.
void TestLinalgTransforms::runOnOperation() {
if (testPatterns)
@@ -254,8 +244,6 @@ void TestLinalgTransforms::runOnOperation() {
return applyEraseUnusedOperandsAndResultsPatterns(getOperation());
if (testEraseUnnecessaryInputs)
return applyEraseUnnecessaryInputs(getOperation());
- if (testConvertToDestinationStylePatterns)
- applyConvertToDestinationStylePatterns(getOperation());
}
namespace mlir {