[Mlir-commits] [mlir] 96179df - [mlir][Linalg] Add a transform dialect op to rewrite ops to destination passing style.

Nicolas Vasilache llvmlistbot at llvm.org
Thu Feb 16 05:39:01 PST 2023


Author: Nicolas Vasilache
Date: 2023-02-16T05:26:33-08:00
New Revision: 96179dff46a9a3981708b06bc9e0f981be4cc1a8

URL: https://github.com/llvm/llvm-project/commit/96179dff46a9a3981708b06bc9e0f981be4cc1a8
DIFF: https://github.com/llvm/llvm-project/commit/96179dff46a9a3981708b06bc9e0f981be4cc1a8.diff

LOG: [mlir][Linalg] Add a transform dialect op to rewrite ops to destination passing style.

A new transform dialect op is introduced to perform the rewrite.
The test pass option is now obsolete and is removed in favor of the transform.

In the process I realized the tensor.pad nofold attribute was not taken into account
and added support to emit a bufferization.alloc_tensor + linalg.copy.

Reviewed By: springerm

Differential Revision: https://reviews.llvm.org/D143943

Added: 
    mlir/test/Dialect/Linalg/transform-op-rewrite-in-destination-passing-style.mlir

Modified: 
    mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
    mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
    mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
    mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
    mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp
    mlir/lib/Dialect/Linalg/Utils/Utils.cpp
    mlir/lib/Dialect/Utils/StaticValueUtils.cpp
    mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp

Removed: 
    mlir/test/Dialect/Linalg/convert-to-destination-style.mlir


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 58ef106563a5b..ff8f3e33703c2 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -873,6 +873,45 @@ def ScalarizeOp : Op<Transform_Dialect, "structured.scalarize",
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// RewriteInDestinationPassingStyleOp.
+//===----------------------------------------------------------------------===//
+
+def RewriteInDestinationPassingStyleOp : Op<
+    Transform_Dialect, "structured.rewrite_in_destination_passing_style",
+    [MemoryEffectsOpInterface,
+     NavigationTransformOpTrait,
+     DeclareOpInterfaceMethods<TransformOpInterface>]> {
+  let description = [{
+    Rewrite a supported tensor operation that is not in destination-passing style
+    into a form that is in destination-passing style.
+    Currently supported operations are:
+      - tensor.pad
+      - tensor.generate
+      - tensor.from_elements
+    This dichotomy hints at a future interface, for now the implementation just 
+    switches between different implementation.
+
+    #### Return modes
+
+    This operation ignores unsupported ops and drops them from the return.
+    If all the operations referred to by the `target` PDLOperation generalize
+    properly, the transform succeeds. Otherwise the transform silently fails.
+    The return handle points to a subset of successfully produced operations:
+      - tensor.pad case, the returned handle points to the tensor.insert_slice.
+      - tensor.generate case, the returned handle points to the linalg.generic.
+      - tensor.from_elements case, the returned handle points to the last 
+        tensor.insert.
+  }];
+
+  let arguments = (ins TransformHandleTypeInterface:$target);
+  let results = (outs TransformHandleTypeInterface:$transformed);
+  let assemblyFormat = [{
+    $target attr-dict
+    `:` functional-type($target, results)
+  }];
+}
+
 //===----------------------------------------------------------------------===//
 // SplitOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index dd01a2e3325bf..aedef03b88fc7 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1206,6 +1206,20 @@ packTranspose(RewriterBase &rewriter, tensor::PackOp packOp,
               linalg::LinalgOp linalgOp, tensor::UnPackOp maybeUnPackOp,
               ArrayRef<int64_t> outerPerm, ArrayRef<int64_t> innerPerm);
 
+/// Rewrite tensor.from_elements to a sequence of chained tensor.insert.
+FailureOr<Operation *>
+rewriteInDestinationPassingStyle(RewriterBase &rewriter,
+                                 tensor::FromElementsOp fromElementsOp);
+
+/// Rewrite tensor.generate to linalg.generic.
+FailureOr<Operation *>
+rewriteInDestinationPassingStyle(RewriterBase &rewriter,
+                                 tensor::GenerateOp generateOp);
+
+/// Rewrite tensor.pad to linalg.generic + tensor.insert_slice.
+FailureOr<Operation *> rewriteInDestinationPassingStyle(RewriterBase &rewriter,
+                                                        tensor::PadOp padOp);
+
 } // namespace linalg
 } // namespace mlir
 

diff  --git a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
index 0b0e03dc2788c..c37d35134dce1 100644
--- a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
@@ -21,6 +21,10 @@
 
 namespace mlir {
 
+/// Return true if `v` is an IntegerAttr with value `0` or a ConstantIndexOp
+/// with attribute with value `0`.
+bool isZeroIndex(OpFoldResult v);
+
 /// Represents a range (offset, size, and stride) where each element of the
 /// triple may be dynamic or static.
 struct Range {
@@ -30,8 +34,8 @@ struct Range {
 };
 
 /// Given an array of Range values, return a tuple of (offset vector, sizes
-/// vector, and strides vector) formed by separating out the individual elements
-/// of each range.
+/// vector, and strides vector) formed by separating out the individual
+/// elements of each range.
 std::tuple<SmallVector<OpFoldResult>, SmallVector<OpFoldResult>,
            SmallVector<OpFoldResult>>
 getOffsetsSizesAndStrides(ArrayRef<Range> ranges);
@@ -40,14 +44,15 @@ getOffsetsSizesAndStrides(ArrayRef<Range> ranges);
 ///   a) it is an IntegerAttr
 /// In other cases, the OpFoldResult is dispached to the `dynamicVec`.
 /// In such dynamic cases, ShapedType::kDynamic is also pushed to
-/// `staticVec`. This is useful to extract mixed static and dynamic entries that
-/// come from an AttrSizedOperandSegments trait.
+/// `staticVec`. This is useful to extract mixed static and dynamic entries
+/// that come from an AttrSizedOperandSegments trait.
 void dispatchIndexOpFoldResult(OpFoldResult ofr,
                                SmallVectorImpl<Value> &dynamicVec,
                                SmallVectorImpl<int64_t> &staticVec);
 
-/// Helper function to dispatch multiple OpFoldResults according to the behavior
-/// of `dispatchIndexOpFoldResult(OpFoldResult ofr` for a single OpFoldResult.
+/// Helper function to dispatch multiple OpFoldResults according to the
+/// behavior of `dispatchIndexOpFoldResult(OpFoldResult ofr` for a single
+/// OpFoldResult.
 void dispatchIndexOpFoldResults(ArrayRef<OpFoldResult> ofrs,
                                 SmallVectorImpl<Value> &dynamicVec,
                                 SmallVectorImpl<int64_t> &staticVec);
@@ -72,27 +77,28 @@ std::optional<int64_t> getConstantIntValue(OpFoldResult ofr);
 /// Return true if `ofr` is constant integer equal to `value`.
 bool isConstantIntValue(OpFoldResult ofr, int64_t value);
 
-/// Return true if ofr1 and ofr2 are the same integer constant attribute values
-/// or the same SSA value.
-/// Ignore integer bitwitdh and type mismatch that come from the fact there is
-/// no IndexAttr and that IndexType have no bitwidth.
+/// Return true if ofr1 and ofr2 are the same integer constant attribute
+/// values or the same SSA value. Ignore integer bitwitdh and type mismatch
+/// that come from the fact there is no IndexAttr and that IndexType have no
+/// bitwidth.
 bool isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2);
 
 /// Helper function to convert a vector of `OpFoldResult`s into a vector of
-/// `Value`s. For each `OpFoldResult` in `valueOrAttrVec` return the fold result
-/// if it casts to  a `Value` or create an index-type constant if it casts to
-/// `IntegerAttr`. No other attribute types are supported.
+/// `Value`s. For each `OpFoldResult` in `valueOrAttrVec` return the fold
+/// result if it casts to  a `Value` or create an index-type constant if it
+/// casts to `IntegerAttr`. No other attribute types are supported.
 SmallVector<Value> getAsValues(OpBuilder &b, Location loc,
                                ArrayRef<OpFoldResult> valueOrAttrVec);
 
-/// Return a vector of OpFoldResults with the same size a staticValues, but all
-/// elements for which ShapedType::isDynamic is true, will be replaced by
+/// Return a vector of OpFoldResults with the same size a staticValues, but
+/// all elements for which ShapedType::isDynamic is true, will be replaced by
 /// dynamicValues.
 SmallVector<OpFoldResult> getMixedValues(ArrayRef<int64_t> staticValues,
                                          ValueRange dynamicValues, Builder &b);
 
-/// Decompose a vector of mixed static or dynamic values into the corresponding
-/// pair of arrays. This is the inverse function of `getMixedValues`.
+/// Decompose a vector of mixed static or dynamic values into the
+/// corresponding pair of arrays. This is the inverse function of
+/// `getMixedValues`.
 std::pair<ArrayAttr, SmallVector<Value>>
 decomposeMixedValues(Builder &b,
                      const SmallVectorImpl<OpFoldResult> &mixedValues);

diff  --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index dab98d2406f45..82c4c39032001 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -39,6 +39,7 @@
 #include "llvm/ADT/SetOperations.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/StringSet.h"
+#include "llvm/ADT/TypeSwitch.h"
 #include "llvm/Support/Debug.h"
 
 using namespace mlir;
@@ -1919,6 +1920,32 @@ transform::ScalarizeOp::applyToOne(LinalgOp target,
   return DiagnosedSilenceableFailure::success();
 }
 
+//===----------------------------------------------------------------------===//
+// RewriteInDestinationPassingStyleOp
+//===----------------------------------------------------------------------===//
+
+DiagnosedSilenceableFailure
+transform::RewriteInDestinationPassingStyleOp::apply(
+    transform::TransformResults &results, transform::TransformState &state) {
+  SmallVector<Operation *> res;
+  ArrayRef<Operation *> targetOps = state.getPayloadOps(getTarget());
+  for (Operation *target : targetOps) {
+    IRRewriter rewriter(target->getContext());
+    rewriter.setInsertionPoint(target);
+    FailureOr<Operation *> maybeResult =
+        TypeSwitch<Operation *, FailureOr<Operation *>>(target)
+            .Case<tensor::FromElementsOp, tensor::GenerateOp, tensor::PadOp>(
+                [&rewriter](auto op) {
+                  return rewriteInDestinationPassingStyle(rewriter, op);
+                });
+    if (failed(maybeResult))
+      return emitDefaultSilenceableFailure(target);
+    res.push_back(*maybeResult);
+  }
+  results.set(getResult().cast<OpResult>(), res);
+  return DiagnosedSilenceableFailure::success();
+}
+
 //===----------------------------------------------------------------------===//
 // SplitOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp b/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp
index 1915c2acea441..85235ef2db376 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp
@@ -19,8 +19,10 @@
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/IR/Matchers.h"
 #include "mlir/IR/PatternMatch.h"
+#include "llvm/ADT/STLExtras.h"
 #include "llvm/Support/Debug.h"
 
 using namespace mlir;
@@ -50,94 +52,6 @@ static Value createInserts(RewriterBase &rewriter, Location loc, int dim,
   return destination;
 }
 
-namespace {
-
-/// Lower tensor.from_elements to a sequence of chained tensor.insert.
-struct FromElementsOpConverter : public OpRewritePattern<FromElementsOp> {
-  using OpRewritePattern<FromElementsOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(FromElementsOp elementsOp,
-                                PatternRewriter &rewriter) const override {
-    Location loc = elementsOp.getLoc();
-    RankedTensorType tensorType = elementsOp.getType().cast<RankedTensorType>();
-    auto shape = tensorType.getShape();
-
-    // Create tensor.empty.
-    auto emptyOp = rewriter.create<EmptyOp>(loc, tensorType, ValueRange());
-
-    // Case: tensor<elem_type>.
-    if (shape.empty()) {
-      rewriter.replaceOpWithNewOp<tensor::InsertOp>(
-          elementsOp, elementsOp.getElements().front(), emptyOp.getResult(),
-          ValueRange());
-      return success();
-    }
-
-    // Create constants for the range of possible indices [0, max{shape_i}).
-    auto maxDim = *std::max_element(shape.begin(), shape.end());
-    SmallVector<Value, 2> constants;
-    constants.reserve(maxDim);
-    for (int i = 0; i < maxDim; ++i)
-      constants.push_back(rewriter.create<arith::ConstantIndexOp>(loc, i));
-
-    // Traverse all elements and create tensor.insert ops.
-    auto elementIt = elementsOp.getElements().begin();
-    SmallVector<Value, 2> indices(tensorType.getRank(), constants[0]);
-    Value result = createInserts(rewriter, loc, /*dim=*/0, emptyOp.getResult(),
-                                 shape, constants, elementIt, indices);
-
-    // Replace tensor.from_elements.
-    rewriter.replaceOp(elementsOp, result);
-    return success();
-  }
-};
-
-/// Lower tensor.generate to linalg.generic.
-struct GenerateOpConverter : public OpRewritePattern<GenerateOp> {
-  using OpRewritePattern<GenerateOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(GenerateOp generateOp,
-                                PatternRewriter &rewriter) const override {
-    // Only ops with exactly one block are supported.
-    if (!generateOp.getBody().hasOneBlock())
-      return failure();
-
-    Location loc = generateOp.getLoc();
-    RankedTensorType tensorType = generateOp.getType().cast<RankedTensorType>();
-
-    // Create tensor.empty.
-    auto emptyOp = rewriter.create<EmptyOp>(loc, tensorType,
-                                            generateOp.getDynamicExtents());
-
-    // Create linalg.generic.
-    SmallVector<utils::IteratorType> iteratorTypes(
-        tensorType.getRank(), utils::IteratorType::parallel);
-    SmallVector<AffineMap> indexingMaps(
-        1, rewriter.getMultiDimIdentityMap(tensorType.getRank()));
-    auto genericOp = rewriter.create<linalg::GenericOp>(
-        loc, tensorType, /*inputs=*/ValueRange(),
-        /*outputs=*/ValueRange{emptyOp.getResult()}, /*indexingMaps=*/
-        indexingMaps, iteratorTypes);
-    Block *body = rewriter.createBlock(&genericOp->getRegion(0), {},
-                                       tensorType.getElementType(), loc);
-    rewriter.setInsertionPointToStart(body);
-    SmallVector<Value> bbArgReplacements;
-    for (int64_t i = 0; i < tensorType.getRank(); ++i)
-      bbArgReplacements.push_back(rewriter.create<linalg::IndexOp>(loc, i));
-    rewriter.mergeBlocks(&generateOp.getBody().front(), body,
-                         bbArgReplacements);
-
-    // Update terminator.
-    auto yieldOp = cast<tensor::YieldOp>(body->getTerminator());
-    rewriter.replaceOpWithNewOp<linalg::YieldOp>(yieldOp, yieldOp.getValue());
-
-    // Replace tensor.generate.
-    rewriter.replaceOp(generateOp, genericOp->getResult(0));
-    return success();
-  }
-};
-} // namespace
-
 static Operation *movePaddingToFillOrGenericOp(RewriterBase &rewriter,
                                                Location loc, PadOp padOp,
                                                Value dest) {
@@ -287,49 +201,133 @@ Value linalg::bufferizeToAllocation(RewriterBase &rewriter, PadOp padOp,
   return toTensorOp;
 }
 
-namespace {
-/// Lower tensor.pad to linalg.generic + tensor.insert_slice.
-struct PadOpConverter : public OpRewritePattern<PadOp> {
-  using OpRewritePattern<PadOp>::OpRewritePattern;
+/// Lower tensor.from_elements to a sequence of chained tensor.insert.
+FailureOr<Operation *> mlir::linalg::rewriteInDestinationPassingStyle(
+    RewriterBase &rewriter, tensor::FromElementsOp fromElementsOp) {
+  Location loc = fromElementsOp.getLoc();
+  RankedTensorType tensorType =
+      fromElementsOp.getType().cast<RankedTensorType>();
+  auto shape = tensorType.getShape();
+
+  // Create tensor.empty.
+  auto emptyOp = rewriter.create<EmptyOp>(loc, tensorType, ValueRange());
+
+  // Case: tensor<elem_type>.
+  if (shape.empty()) {
+    Operation *res = rewriter.replaceOpWithNewOp<tensor::InsertOp>(
+        fromElementsOp, fromElementsOp.getElements().front(),
+        emptyOp.getResult(), ValueRange());
+    return res;
+  }
 
-  LogicalResult matchAndRewrite(PadOp padOp,
-                                PatternRewriter &rewriter) const override {
-    // Only ops with exactly one block are supported.
-    if (!padOp.getBodyRegion().hasOneBlock())
-      return failure();
+  // Create constants for the range of possible indices [0, max{shape_i}).
+  auto maxDim = *std::max_element(shape.begin(), shape.end());
+  SmallVector<Value, 2> constants;
+  constants.reserve(maxDim);
+  for (int i = 0; i < maxDim; ++i)
+    constants.push_back(rewriter.create<arith::ConstantIndexOp>(loc, i));
+
+  // Traverse all elements and create tensor.insert ops.
+  auto elementIt = fromElementsOp.getElements().begin();
+  SmallVector<Value, 2> indices(tensorType.getRank(), constants[0]);
+  Value result = createInserts(rewriter, loc, /*dim=*/0, emptyOp.getResult(),
+                               shape, constants, elementIt, indices);
+
+  // Replace tensor.from_elements.
+  rewriter.replaceOp(fromElementsOp, result);
+  return result.getDefiningOp();
+}
 
-    // Create tensor.empty.
-    Location loc = padOp.getLoc();
-    RankedTensorType resultType = padOp.getResultType();
-    ReifiedRankedShapedTypeDims reifiedShape;
-    if (failed(cast<ReifyRankedShapedTypeOpInterface>(padOp.getOperation())
-                   .reifyResultShapes(rewriter, reifiedShape)))
-      return rewriter.notifyMatchFailure(
-          padOp, "failed to reify tensor.pad op result shape");
-    SmallVector<Value> dynamicSizes;
-    for (int64_t i = 0; i < resultType.getRank(); ++i)
-      if (resultType.isDynamicDim(i))
-        dynamicSizes.push_back(reifiedShape[0][i]);
-    auto emptyOp = rewriter.create<EmptyOp>(loc, resultType, dynamicSizes);
-
-    // Create linalg.fill or linalg.generic.
-    Operation *fillOp =
-        movePaddingToFillOrGenericOp(rewriter, loc, padOp, emptyOp.getResult());
-    rewriter.setInsertionPointAfter(fillOp);
-
-    // Create tensor::InsertSliceOp.
-    SmallVector<OpFoldResult> sliceSizes =
-        getMixedSizes(rewriter, loc, padOp.getSource());
-    SmallVector<OpFoldResult> sliceStrides(resultType.getRank(),
-                                           rewriter.getIndexAttr(1));
-    rewriter.replaceOpWithNewOp<tensor::InsertSliceOp>(
-        padOp, padOp.getSource(), fillOp->getResult(0),
-        /*offsets=*/padOp.getMixedLowPad(), sliceSizes, sliceStrides);
+/// Lower tensor.generate to linalg.generic.
+FailureOr<Operation *>
+mlir::linalg::rewriteInDestinationPassingStyle(RewriterBase &rewriter,
+                                               tensor::GenerateOp generateOp) {
+  // Only ops with exactly one block are supported.
+  if (!generateOp.getBody().hasOneBlock())
+    return failure();
 
-    return success();
+  Location loc = generateOp.getLoc();
+  RankedTensorType tensorType = generateOp.getType().cast<RankedTensorType>();
+
+  // Create tensor.empty.
+  auto emptyOp =
+      rewriter.create<EmptyOp>(loc, tensorType, generateOp.getDynamicExtents());
+
+  // Create linalg.generic.
+  SmallVector<utils::IteratorType> iteratorTypes(tensorType.getRank(),
+                                                 utils::IteratorType::parallel);
+  SmallVector<AffineMap> indexingMaps(
+      1, rewriter.getMultiDimIdentityMap(tensorType.getRank()));
+  auto genericOp = rewriter.create<linalg::GenericOp>(
+      loc, tensorType, /*inputs=*/ValueRange(),
+      /*outputs=*/ValueRange{emptyOp.getResult()}, /*indexingMaps=*/
+      indexingMaps, iteratorTypes);
+  Block *body = rewriter.createBlock(&genericOp->getRegion(0), {},
+                                     tensorType.getElementType(), loc);
+  rewriter.setInsertionPointToStart(body);
+  SmallVector<Value> bbArgReplacements;
+  for (int64_t i = 0; i < tensorType.getRank(); ++i)
+    bbArgReplacements.push_back(rewriter.create<linalg::IndexOp>(loc, i));
+  rewriter.mergeBlocks(&generateOp.getBody().front(), body, bbArgReplacements);
+
+  // Update terminator.
+  auto yieldOp = cast<tensor::YieldOp>(body->getTerminator());
+  rewriter.replaceOpWithNewOp<linalg::YieldOp>(yieldOp, yieldOp.getValue());
+
+  // Replace tensor.generate.
+  rewriter.replaceOp(generateOp, genericOp->getResult(0));
+  return genericOp.getOperation();
+}
+
+/// Lower tensor.pad to linalg.generic + tensor.insert_slice.
+FailureOr<Operation *>
+mlir::linalg::rewriteInDestinationPassingStyle(RewriterBase &rewriter,
+                                               tensor::PadOp padOp) {
+  // Only ops with exactly one block are supported.
+  if (!padOp.getBodyRegion().hasOneBlock())
+    return failure();
+
+  // Create tensor.empty.
+  Location loc = padOp.getLoc();
+  RankedTensorType resultType = padOp.getResultType();
+  ReifiedRankedShapedTypeDims reifiedShape;
+  if (failed(cast<ReifyRankedShapedTypeOpInterface>(padOp.getOperation())
+                 .reifyResultShapes(rewriter, reifiedShape)))
+    return rewriter.notifyMatchFailure(
+        padOp, "failed to reify tensor.pad op result shape");
+  SmallVector<Value> dynamicSizes;
+  for (int64_t i = 0; i < resultType.getRank(); ++i)
+    if (resultType.isDynamicDim(i))
+      dynamicSizes.push_back(reifiedShape[0][i]);
+
+  // If the `padOp` has a nofold attribute and all paddings are known to be 0,
+  // explicitly insert a `linalg.copy`.
+  if (padOp.getNofoldAttr() &&
+      llvm::all_of(padOp.getMixedLowPad(), isZeroIndex) &&
+      llvm::all_of(padOp.getMixedHighPad(), isZeroIndex)) {
+    using bufferization::AllocTensorOp;
+    Value allocated =
+        rewriter.create<AllocTensorOp>(loc, resultType, dynamicSizes);
+    auto copyOp = rewriter.replaceOpWithNewOp<linalg::CopyOp>(
+        padOp, padOp.getSource(), allocated);
+    return copyOp.getOperation();
   }
-};
-} // namespace
+
+  Value empty = rewriter.create<EmptyOp>(loc, resultType, dynamicSizes);
+  // Create linalg.fill or linalg.generic.
+  Operation *fillOp = movePaddingToFillOrGenericOp(rewriter, loc, padOp, empty);
+  rewriter.setInsertionPointAfter(fillOp);
+
+  // Create tensor::InsertSliceOp.
+  SmallVector<OpFoldResult> sliceSizes =
+      getMixedSizes(rewriter, loc, padOp.getSource());
+  SmallVector<OpFoldResult> sliceStrides(resultType.getRank(),
+                                         rewriter.getIndexAttr(1));
+  auto insertSliceOp = rewriter.replaceOpWithNewOp<tensor::InsertSliceOp>(
+      padOp, padOp.getSource(), fillOp->getResult(0),
+      /*offsets=*/padOp.getMixedLowPad(), sliceSizes, sliceStrides);
+  return insertSliceOp.getOperation();
+}
 
 Value linalg::bufferizeToAllocation(RewriterBase &rewriter, Value value,
                                     Attribute memorySpace) {
@@ -368,6 +366,45 @@ Value linalg::bufferizeToAllocation(RewriterBase &rewriter, Value value,
   return toTensorOp;
 }
 
+namespace {
+/// Lower tensor.from_elements to a sequence of chained tensor.insert.
+struct FromElementsOpConverter : public OpRewritePattern<FromElementsOp> {
+  using OpRewritePattern<FromElementsOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(FromElementsOp fromElementsOp,
+                                PatternRewriter &rewriter) const override {
+    if (failed(
+            linalg::rewriteInDestinationPassingStyle(rewriter, fromElementsOp)))
+      return failure();
+    return success();
+  }
+};
+
+/// Lower tensor.generate to linalg.generic.
+struct GenerateOpConverter : public OpRewritePattern<GenerateOp> {
+  using OpRewritePattern<GenerateOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(GenerateOp generateOp,
+                                PatternRewriter &rewriter) const override {
+    if (failed(linalg::rewriteInDestinationPassingStyle(rewriter, generateOp)))
+      return failure();
+    return success();
+  }
+};
+
+/// Lower tensor.pad to linalg.generic + tensor.insert_slice.
+struct PadOpConverter : public OpRewritePattern<PadOp> {
+  using OpRewritePattern<PadOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(PadOp padOp,
+                                PatternRewriter &rewriter) const override {
+    if (failed(linalg::rewriteInDestinationPassingStyle(rewriter, padOp)))
+      return failure();
+    return success();
+  }
+};
+} // namespace
+
 void linalg::populateConvertToDestinationStylePatterns(
     RewritePatternSet &patterns) {
   patterns.insert<FromElementsOpConverter, GenerateOpConverter, PadOpConverter>(

diff  --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
index 7ea36248c96f7..f3879f5dd9d12 100644
--- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -44,18 +44,6 @@ using namespace presburger;
 using namespace mlir::linalg;
 using namespace mlir::scf;
 
-static bool isZero(OpFoldResult v) {
-  if (!v)
-    return false;
-  if (auto attr = v.dyn_cast<Attribute>()) {
-    IntegerAttr intAttr = attr.dyn_cast<IntegerAttr>();
-    return intAttr && intAttr.getValue().isZero();
-  }
-  if (auto cst = v.get<Value>().getDefiningOp<arith::ConstantIndexOp>())
-    return cst.value() == 0;
-  return false;
-}
-
 namespace {
 
 // Helper visitor to determine whether an AffineExpr is tiled.
@@ -70,7 +58,7 @@ struct TileCheck : public AffineExprVisitor<TileCheck> {
   TileCheck(ArrayRef<OpFoldResult> tileSizes) : tileSizes(tileSizes) {}
 
   void visitDimExpr(AffineDimExpr expr) {
-    isTiled |= !isZero(tileSizes[expr.getPosition()]);
+    isTiled |= !isZeroIndex(tileSizes[expr.getPosition()]);
   }
   void visitAffineBinaryOpExpr(AffineBinaryOpExpr expr) {
     visit(expr.getLHS());
@@ -869,7 +857,7 @@ SmallVector<OpFoldResult> computeTileOffsets(OpBuilder &b, Location loc,
   SmallVector<OpFoldResult> offsets;
   for (unsigned idx = 0, idxIvs = 0, e = tileSizes.size(); idx < e; ++idx) {
     LLVM_DEBUG(llvm::dbgs() << "makeTiledShapes: for loop#" << idx << "\n");
-    bool isTiled = !isZero(tileSizes[idx]);
+    bool isTiled = !isZeroIndex(tileSizes[idx]);
     offsets.push_back(isTiled ? ivs[idxIvs++] : b.getIndexAttr(0));
     LLVM_DEBUG(llvm::dbgs()
                << "computeTileOffsets: " << offsets.back() << "\n");
@@ -882,7 +870,7 @@ SmallVector<OpFoldResult> computeTileSizes(OpBuilder &b, Location loc,
                                            ArrayRef<OpFoldResult> sizeBounds) {
   SmallVector<OpFoldResult> sizes;
   for (unsigned idx = 0, e = tileSizes.size(); idx < e; ++idx) {
-    bool isTiled = !isZero(tileSizes[idx]);
+    bool isTiled = !isZeroIndex(tileSizes[idx]);
     // Before composing, we need to make range a closed interval.
     OpFoldResult size = isTiled ? tileSizes[idx] : sizeBounds[idx];
     AffineExpr d0 = getAffineDimExpr(0, b.getContext());
@@ -938,7 +926,7 @@ computeAllSliceParameters(OpBuilder &builder, Location loc, LinalgOp linalgOp,
                           bool omitPartialTileCheck) {
   assert(ivs.size() == static_cast<size_t>(llvm::count_if(
                            llvm::make_range(tileSizes.begin(), tileSizes.end()),
-                           [](OpFoldResult v) { return !isZero(v); })) &&
+                           [](OpFoldResult v) { return !isZeroIndex(v); })) &&
          "expected as many ivs as non-zero sizes");
 
   // Construct (potentially temporary) mins and maxes on which to apply maps

diff  --git a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
index 436e6e901a4a6..294dc810507b4 100644
--- a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
+++ b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
@@ -14,6 +14,18 @@
 
 namespace mlir {
 
+bool isZeroIndex(OpFoldResult v) {
+  if (!v)
+    return false;
+  if (auto attr = v.dyn_cast<Attribute>()) {
+    IntegerAttr intAttr = attr.dyn_cast<IntegerAttr>();
+    return intAttr && intAttr.getValue().isZero();
+  }
+  if (auto cst = v.get<Value>().getDefiningOp<arith::ConstantIndexOp>())
+    return cst.value() == 0;
+  return false;
+}
+
 std::tuple<SmallVector<OpFoldResult>, SmallVector<OpFoldResult>,
            SmallVector<OpFoldResult>>
 getOffsetsSizesAndStrides(ArrayRef<Range> ranges) {

diff  --git a/mlir/test/Dialect/Linalg/convert-to-destination-style.mlir b/mlir/test/Dialect/Linalg/transform-op-rewrite-in-destination-passing-style.mlir
similarity index 71%
rename from mlir/test/Dialect/Linalg/convert-to-destination-style.mlir
rename to mlir/test/Dialect/Linalg/transform-op-rewrite-in-destination-passing-style.mlir
index 3a7472fcec08f..9b89d83f6a95d 100644
--- a/mlir/test/Dialect/Linalg/convert-to-destination-style.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-rewrite-in-destination-passing-style.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -split-input-file -test-linalg-transform-patterns=test-convert-to-destination-style-patterns -canonicalize %s | FileCheck %s
+// RUN: mlir-opt  -test-transform-dialect-interpreter --split-input-file -canonicalize %s | FileCheck %s
 
 // CHECK-LABEL: func @tensor_from_elements_0d(
 //  CHECK-SAME:     %[[arg0:.*]]: index
@@ -10,6 +10,14 @@ func.func @tensor_from_elements_0d(%arg0: index) -> tensor<index> {
   return %0 : tensor<index>
 }
 
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+  %0 = transform.structured.match ops{["tensor.from_elements"]} in %arg1
+    : (!pdl.operation) -> !pdl.operation
+  transform.structured.rewrite_in_destination_passing_style %0
+    : (!pdl.operation) -> !pdl.operation
+}
+
 // -----
 
 // CHECK-LABEL: func @tensor_from_elements_1d(
@@ -25,6 +33,14 @@ func.func @tensor_from_elements_1d(%arg0: index, %arg1: index) -> tensor<2xindex
   return %0 : tensor<2xindex>
 }
 
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+  %0 = transform.structured.match ops{["tensor.from_elements"]} in %arg1
+    : (!pdl.operation) -> !pdl.operation
+  transform.structured.rewrite_in_destination_passing_style %0
+    : (!pdl.operation) -> !pdl.operation
+}
+
 // -----
 
 // CHECK-LABEL: func @tensor_from_elements_2d(
@@ -46,6 +62,14 @@ func.func @tensor_from_elements_2d(%arg0: index, %arg1: index) -> tensor<3x2xind
   return %0 : tensor<3x2xindex>
 }
 
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+  %0 = transform.structured.match ops{["tensor.from_elements"]} in %arg1
+    : (!pdl.operation) -> !pdl.operation
+  transform.structured.rewrite_in_destination_passing_style %0
+    : (!pdl.operation) -> !pdl.operation
+}
+
 // -----
 
 // CHECK: #[[$map:.*]] = affine_map<(d0, d1) -> (d0, d1)>
@@ -70,6 +94,14 @@ func.func @tensor_generate(%s1: index, %s2: index) -> tensor<?x?xindex> {
   return %0 : tensor<?x?xindex>
 }
 
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+  %0 = transform.structured.match ops{["tensor.generate"]} in %arg1
+    : (!pdl.operation) -> !pdl.operation
+  transform.structured.rewrite_in_destination_passing_style %0
+    : (!pdl.operation) -> !pdl.operation
+}
+
 // -----
 
 // CHECK:       #[[$map:.+]] = affine_map<()[s0, s1] -> (s0 + s1 + 5)>
@@ -103,6 +135,14 @@ func.func @tensor_pad(%t1: tensor<?x10xindex>, %l2: index, %h1: index,
   return %0 : tensor<?x?xindex>
 }
 
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+  %0 = transform.structured.match ops{["tensor.pad"]} in %arg1
+    : (!pdl.operation) -> !pdl.operation
+  transform.structured.rewrite_in_destination_passing_style %0
+    : (!pdl.operation) -> !pdl.operation
+}
+
 // -----
 
 // CHECK:       #[[$map:.+]] = affine_map<()[s0, s1] -> (s0 + s1 + 5)>
@@ -129,6 +169,14 @@ func.func @tensor_pad_constant(%t1: tensor<?x10xindex>, %l2: index, %h1: index,
   return %0 : tensor<?x?xindex>
 }
 
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+  %0 = transform.structured.match ops{["tensor.pad"]} in %arg1
+    : (!pdl.operation) -> !pdl.operation
+  transform.structured.rewrite_in_destination_passing_style %0
+    : (!pdl.operation) -> !pdl.operation
+}
+
 // -----
 
 // CHECK:       #[[$map:.+]] = affine_map<()[s0, s1] -> (s0 + s1 + 5)>
@@ -152,3 +200,39 @@ func.func @tensor_pad_invariant(%t1: tensor<?x10xindex>, %l2: index, %h1: index,
   } : tensor<?x10xindex> to tensor<?x?xindex>
   return %0 : tensor<?x?xindex>
 }
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+  %0 = transform.structured.match ops{["tensor.pad"]} in %arg1
+    : (!pdl.operation) -> !pdl.operation
+  transform.structured.rewrite_in_destination_passing_style %0
+    : (!pdl.operation) -> !pdl.operation
+}
+
+// -----
+
+// CHECK-LABEL: func @tensor_pad_nofold(
+//  CHECK-SAME:   %[[t1:.*]]: tensor<?x?xindex>, %[[padding:.*]]: index
+//   CHECK-NOT:   linalg.fill
+//   CHECK-NOT:   generic
+//   CHECK-NOT:   insert_slice
+//       CHECK:   %[[alloc_tensor:.*]] = bufferization.alloc_tensor(%{{.*}}) : tensor<?x?xindex>
+//       CHECK:   %[[copied:.*]] = linalg.copy ins(%[[t1]] : tensor<?x?xindex>) outs(%[[alloc_tensor]] : tensor<?x?xindex>) -> tensor<?x?xindex>
+//       CHECK:   return %[[copied]]
+func.func @tensor_pad_nofold(%t1: tensor<?x?xindex>, %padding: index)
+    -> tensor<?x?xindex> {
+  %c0 = arith.constant 0 : index
+  %0 = tensor.pad %t1 nofold low[0, %c0] high[%c0, 0] {
+  ^bb0(%arg0: index, %arg1: index):
+    tensor.yield %padding : index
+  } : tensor<?x?xindex> to tensor<?x?xindex>
+  return %0: tensor<?x?xindex>
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+  %0 = transform.structured.match ops{["tensor.pad"]} in %arg1
+    : (!pdl.operation) -> !pdl.operation
+  transform.structured.rewrite_in_destination_passing_style %0
+    : (!pdl.operation) -> !pdl.operation
+}

diff  --git a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
index 7842e860c6a67..5ce43ff99232b 100644
--- a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
@@ -128,10 +128,6 @@ struct TestLinalgTransforms
       *this, "test-erase-unnecessary-inputs",
       llvm::cl::desc("Test patterns to erase unnecessary inputs"),
       llvm::cl::init(false)};
-  Option<bool> testConvertToDestinationStylePatterns{
-      *this, "test-convert-to-destination-style-patterns",
-      llvm::cl::desc("Test patterns that convert ops to destination style"),
-      llvm::cl::init(false)};
 };
 } // namespace
 
@@ -222,12 +218,6 @@ static void applyEraseUnnecessaryInputs(func::FuncOp funcOp) {
   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
 }
 
-static void applyConvertToDestinationStylePatterns(Operation *rootOp) {
-  RewritePatternSet patterns(rootOp->getContext());
-  populateConvertToDestinationStylePatterns(patterns);
-  (void)applyPatternsAndFoldGreedily(rootOp, std::move(patterns));
-}
-
 /// Apply transformations specified as patterns.
 void TestLinalgTransforms::runOnOperation() {
   if (testPatterns)
@@ -254,8 +244,6 @@ void TestLinalgTransforms::runOnOperation() {
     return applyEraseUnusedOperandsAndResultsPatterns(getOperation());
   if (testEraseUnnecessaryInputs)
     return applyEraseUnnecessaryInputs(getOperation());
-  if (testConvertToDestinationStylePatterns)
-    applyConvertToDestinationStylePatterns(getOperation());
 }
 
 namespace mlir {


        


More information about the Mlir-commits mailing list