[Mlir-commits] [mlir] 01581e2 - [mlir][linalg] Add bufferize_to_allocation transform op

Matthias Springer llvmlistbot at llvm.org
Wed Feb 15 06:22:55 PST 2023


Author: Matthias Springer
Date: 2023-02-15T15:22:46+01:00
New Revision: 01581e28ad929c6f2b8b1c31828006f32b2180d1

URL: https://github.com/llvm/llvm-project/commit/01581e28ad929c6f2b8b1c31828006f32b2180d1
DIFF: https://github.com/llvm/llvm-project/commit/01581e28ad929c6f2b8b1c31828006f32b2180d1.diff

LOG: [mlir][linalg] Add bufferize_to_allocation transform op

This transform materializes a buffer allocation for a given tensor value. All uses of the original value are replaced with a `bufferization.to_tensor` of the allocation.

Certain non-DPS ops may have an optimized lowering path that bufferizes the entire defining op. Such optimization is added for `tensor.pad` as part of this change.

The resulting IR can be further bufferized with One-Shot Bufferize.

Differential Revision: https://reviews.llvm.org/D144022

Added: 
    mlir/test/Dialect/Linalg/transform-op-bufferize-to-allocation.mlir

Modified: 
    mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
    mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
    mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
    mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index a8acb052b60b4..58ef106563a5b 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -26,6 +26,58 @@ def TransformParamTypeOrAnyHandle : Type<
         Transform_ParamType.predicate]>,
     "transform 'param' type or any handle type">;
 
+//===----------------------------------------------------------------------===//
+// BufferizeToAllocationOp
+//===----------------------------------------------------------------------===//
+
+def BufferizeToAllocationOp : Op<Transform_Dialect,
+    "structured.bufferize_to_allocation",
+    [DeclareOpInterfaceMethods<TransformOpInterface>,
+     DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
+  let description = [{
+    This transform materializes an allocation for the targeted tensor value. It
+    replaces all original uses of the target with the newly allocated buffer,
+    wrapped in a `bufferization.to_tensor` op. It returns a handle to the result
+    of the `to_tensor` op.
+
+    Example:
+    ```
+    %0 = "some_op"() : () -> (tensor<10xf32>)
+    "some_use"(%0) : (tensor<10xf32>) -> ()
+    ```
+
+    Is rewritten to:
+    ```
+    %0 = "some_op"() : () -> (tensor<10xf32>)
+    %1 = memref.alloc() : memref<10xf32>
+    memref.tensor_store %0, %1 : memref<10xf32>
+    %2 = bufferization.to_tensor %1 restrict writable : memref<10xf32>
+    "some_use"(%2) : (tensor<10xf32>) -> ()
+    ```
+
+    This transform has optimized lowerings for certain targets that are results
+    of non-DPS ops. For such targets, the defining op itself is bufferized in
+    addition to emitting the buffer allocation. This avoids a second
+    allocation for the missing destination of the non-DPS op (when subsequently
+    running a bufferization pass/transform). Currently supported ops with
+    optimized lowerings:
+    - tensor.pad
+
+    An optional memory space attribute can be specified for the materialized
+    buffer allocation.
+
+    #### Return modes
+
+    This operation consumes the `target` handle and produces the `transformed`
+    handle. It always succeeds.
+  }];
+
+  let arguments = (ins Transform_AnyValue:$target,
+                       OptionalAttr<AnyAttr>:$memory_space);
+  let results = (outs Transform_AnyValue:$transformed);
+  let assemblyFormat = "$target attr-dict";
+}
+
 //===----------------------------------------------------------------------===//
 // DecomposeOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 953eb59b95134..dd01a2e3325bf 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -44,6 +44,35 @@ struct LinalgTilingOptions;
 //===----------------------------------------------------------------------===//
 using LinalgLoops = SmallVector<Operation *, 4>;
 
+/// Materialize a buffer allocation for the given tensor.pad op and lower the
+/// op to linalg.fill/linalg.generic + memref.tensor_store. E.g.:
+///
+/// %0 = tensor.pad low[%l] high[%h] %t ...
+///
+/// is lowered to:
+///
+/// %alloc = memref.alloc
+/// linalg.fill ... outs(%alloc)
+/// %subview = memref.subview %alloc [%l] [...] [1]
+/// memref.tensor_store %t, %subview
+/// %0 = bufferization.to_tensor %alloc restrict writable
+///
+/// In addition to rewriting the IR as shown above, the result of the
+/// bufferization.to_tensor op is returned.
+Value bufferizeToAllocation(RewriterBase &rewriter, tensor::PadOp padOp,
+                            Attribute memorySpace = {});
+
+/// Materialize a buffer allocation for the given tensor value. E.g.:
+///
+/// %alloc = memref.alloc
+/// memref.tensor_store %value, %alloc
+/// %0 = bufferization.to_tensor %alloc restrict writable
+///
+/// In case `value` is a tensor.pad result, the corresponding overload is used
+/// internally to produce a better bufferization.
+Value bufferizeToAllocation(RewriterBase &rewriter, Value value,
+                            Attribute memorySpace = {});
+
 void populatePadTensorTilingPatterns(RewritePatternSet &patterns,
                                      const LinalgTilingOptions &options);
 

diff  --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 0ac80c1e637fc..dab98d2406f45 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -209,6 +209,30 @@ static PackingMetadata computePackingMetadata(int64_t packedRank,
   return res;
 }
 
+//===----------------------------------------------------------------------===//
+// BufferizeToAllocationOp
+//===----------------------------------------------------------------------===//
+DiagnosedSilenceableFailure
+transform::BufferizeToAllocationOp::apply(transform::TransformResults &results,
+                                          transform::TransformState &state) {
+  Attribute memorySpace =
+      getMemorySpace().has_value() ? getMemorySpace().value() : Attribute();
+  IRRewriter rewriter(getContext());
+  auto transformed = llvm::to_vector(
+      llvm::map_range(state.getPayloadValues(getTarget()), [&](Value v) {
+        return linalg::bufferizeToAllocation(rewriter, v, memorySpace);
+      }));
+  results.setValues(getTransformed().cast<OpResult>(), transformed);
+  return DiagnosedSilenceableFailure::success();
+}
+
+void transform::BufferizeToAllocationOp::getEffects(
+    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  consumesHandle(getTarget(), effects);
+  producesHandle(getTransformed(), effects);
+  modifiesPayload(effects);
+}
+
 //===----------------------------------------------------------------------===//
 // DecomposeOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp b/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp
index f0f7187804bbb..261ad97ac46be 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp
@@ -14,6 +14,8 @@
 //===----------------------------------------------------------------------===//
 //
 #include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
+#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
@@ -134,7 +136,157 @@ struct GenerateOpConverter : public OpRewritePattern<GenerateOp> {
     return success();
   }
 };
+} // namespace
+
+static Operation *movePaddingToFillOrGenericOp(RewriterBase &rewriter,
+                                               Location loc, PadOp padOp,
+                                               Value dest) {
+  OpBuilder::InsertionGuard g(rewriter);
+  RankedTensorType resultType = padOp.getResultType();
+
+  // Examine the yielded value to decide if a linalg.generic is needed or a
+  // linalg.fill is sufficient.
+  Value yieldedValue =
+      cast<tensor::YieldOp>(padOp.getBody()->getTerminator()).getValue();
+  Attribute constYieldedValue;
+  // Is the yielded value a bbArg defined outside of the PadOp?
+  bool outsideBbArg =
+      yieldedValue.isa<BlockArgument>() &&
+      yieldedValue.cast<BlockArgument>().getOwner()->getParentOp() !=
+          padOp.getOperation();
+  // Is the yielded value an OpResult defined outside of the PadOp?
+  bool outsideOpResult =
+      yieldedValue.isa<OpResult>() &&
+      yieldedValue.getDefiningOp()->getParentOp() != padOp.getOperation();
+  bool invariantYieldedValue = outsideBbArg || outsideOpResult;
+  if (matchPattern(yieldedValue, m_Constant(&constYieldedValue))) {
+    // Padding with a constant: Create linalg.fill.
+    Dialect *arithDialect =
+        rewriter.getContext()->getLoadedDialect<arith::ArithDialect>();
+    Value fillValue =
+        arithDialect
+            ->materializeConstant(rewriter, constYieldedValue,
+                                  yieldedValue.getType(), yieldedValue.getLoc())
+            ->getResult(0);
+    auto fillOp = rewriter.create<linalg::FillOp>(loc, ValueRange(fillValue),
+                                                  ValueRange(dest));
+    return fillOp;
+  }
+
+  if (invariantYieldedValue) {
+    // Padding with an invariant value.
+    auto fillOp = rewriter.create<linalg::FillOp>(loc, ValueRange(yieldedValue),
+                                                  ValueRange(dest));
+    return fillOp;
+  }
 
+  // Create linalg.generic.
+  SmallVector<utils::IteratorType> iteratorTypes(resultType.getRank(),
+                                                 utils::IteratorType::parallel);
+  SmallVector<AffineMap> indexingMaps(
+      1, rewriter.getMultiDimIdentityMap(resultType.getRank()));
+  auto genericOp = rewriter.create<linalg::GenericOp>(
+      loc, resultType, /*inputs=*/ValueRange(),
+      /*outputs=*/ValueRange{dest}, /*indexingMaps=*/
+      indexingMaps, iteratorTypes);
+  Block *body = rewriter.createBlock(&genericOp->getRegion(0), {},
+                                     resultType.getElementType(), loc);
+  rewriter.setInsertionPointToStart(body);
+  SmallVector<Value> bbArgReplacements;
+  for (int64_t i = 0; i < resultType.getRank(); ++i)
+    bbArgReplacements.push_back(rewriter.create<linalg::IndexOp>(loc, i));
+  rewriter.mergeBlocks(padOp.getBody(), body, bbArgReplacements);
+
+  // Update terminator.
+  auto yieldOp = cast<tensor::YieldOp>(body->getTerminator());
+  rewriter.replaceOpWithNewOp<linalg::YieldOp>(yieldOp, yieldOp.getValue());
+  return genericOp;
+}
+
+static SmallVector<Value> reifyOrComputeDynamicSizes(OpBuilder &b,
+                                                     Value value) {
+  auto tensorType = value.getType().cast<RankedTensorType>();
+  if (tensorType.hasStaticShape())
+    return {};
+
+  // Try to reify dynamic sizes.
+  if (auto reifiableOp =
+          value.getDefiningOp<ReifyRankedShapedTypeOpInterface>()) {
+    ReifiedRankedShapedTypeDims reifiedShape;
+    if (succeeded(reifiableOp.reifyResultShapes(b, reifiedShape))) {
+      SmallVector<Value> dynSizes;
+      for (int64_t i = 0; i < tensorType.getRank(); ++i) {
+        if (tensorType.isDynamicDim(i))
+          dynSizes.push_back(
+              reifiedShape[value.cast<OpResult>().getResultNumber()][i]);
+      }
+      return dynSizes;
+    }
+  }
+
+  // Create tensor.dim ops.
+  SmallVector<Value> dynSizes;
+  for (int64_t i = 0; i < tensorType.getRank(); ++i) {
+    if (tensorType.isDynamicDim(i))
+      dynSizes.push_back(
+          b.create<DimOp>(value.getLoc(), value,
+                          b.create<arith::ConstantIndexOp>(value.getLoc(), i)));
+  }
+  return dynSizes;
+}
+
+static Value createAllocationForTensor(RewriterBase &rewriter, Location loc,
+                                       Value value,
+                                       Attribute memorySpace = {}) {
+  OpBuilder::InsertionGuard g(rewriter);
+  auto tensorType = value.getType().cast<RankedTensorType>();
+
+  // Create buffer allocation.
+  auto memrefType = bufferization::getMemRefTypeWithStaticIdentityLayout(
+                        tensorType, memorySpace)
+                        .cast<MemRefType>();
+  SmallVector<Value> dynamicSizes = reifyOrComputeDynamicSizes(rewriter, value);
+  Value alloc = rewriter.create<memref::AllocOp>(loc, memrefType, dynamicSizes);
+
+  // Place deallocation at the end of the block.
+  rewriter.setInsertionPoint(rewriter.getInsertionBlock()->getTerminator());
+  rewriter.create<memref::DeallocOp>(loc, alloc);
+
+  return alloc;
+}
+
+Value linalg::bufferizeToAllocation(RewriterBase &rewriter, PadOp padOp,
+                                    Attribute memorySpace) {
+  OpBuilder::InsertionGuard g(rewriter);
+  rewriter.setInsertionPoint(padOp);
+  Location loc = padOp.getLoc();
+
+  // Create buffer allocation.
+  Value alloc =
+      createAllocationForTensor(rewriter, loc, padOp.getResult(), memorySpace);
+
+  // Create linalg.fill or linalg.generic.
+  Operation *fillOp = movePaddingToFillOrGenericOp(rewriter, loc, padOp, alloc);
+  rewriter.setInsertionPointAfter(fillOp);
+
+  // Create memref.tensor_store.
+  SmallVector<OpFoldResult> sizes =
+      getMixedSizes(rewriter, loc, padOp.getSource());
+  SmallVector<OpFoldResult> strides(padOp.getResultType().getRank(),
+                                    rewriter.getIndexAttr(1));
+  Value subview = rewriter.create<memref::SubViewOp>(
+      loc, alloc, /*offsets=*/padOp.getMixedLowPad(), sizes, strides);
+  rewriter.create<memref::TensorStoreOp>(loc, padOp.getSource(), subview);
+
+  // Create bufferization.to_tensor with "restrict" and "writable". The returned
+  // tensor is a new buffer allocation, so it does not alias with any buffer.
+  Value toTensorOp = rewriter.create<bufferization::ToTensorOp>(
+      loc, alloc, /*restrict=*/true, /*writable=*/true);
+  rewriter.replaceOp(padOp, toTensorOp);
+  return toTensorOp;
+}
+
+namespace {
 /// Lower tensor.pad to linalg.generic + tensor.insert_slice.
 struct PadOpConverter : public OpRewritePattern<PadOp> {
   using OpRewritePattern<PadOp>::OpRewritePattern;
@@ -159,65 +311,10 @@ struct PadOpConverter : public OpRewritePattern<PadOp> {
         dynamicSizes.push_back(reifiedShape[0][i]);
     auto emptyOp = rewriter.create<EmptyOp>(loc, resultType, dynamicSizes);
 
-    // Examine the yielded value to decide if a linalg.generic is neede or a
-    // linalg.fill is sufficient.
-    Value filled;
-    Value yieldedValue =
-        cast<tensor::YieldOp>(padOp.getBody()->getTerminator()).getValue();
-    Attribute constYieldedValue;
-    // Is the yielded value a bbArg defined outside of the PadOp?
-    bool outsideBbArg =
-        yieldedValue.isa<BlockArgument>() &&
-        yieldedValue.cast<BlockArgument>().getOwner()->getParentOp() !=
-            padOp.getOperation();
-    // Is the yielded value an OpResult defined outside of the PadOp?
-    bool outsideOpResult =
-        yieldedValue.isa<OpResult>() &&
-        yieldedValue.getDefiningOp()->getParentOp() != padOp.getOperation();
-    bool invariantYieldedValue = outsideBbArg || outsideOpResult;
-    if (matchPattern(yieldedValue, m_Constant(&constYieldedValue))) {
-      // Padding with a constant: Create linalg.fill.
-      Dialect *arithDialect =
-          rewriter.getContext()->getLoadedDialect<arith::ArithDialect>();
-      Value fillValue = arithDialect
-                            ->materializeConstant(rewriter, constYieldedValue,
-                                                  yieldedValue.getType(),
-                                                  yieldedValue.getLoc())
-                            ->getResult(0);
-      auto fillOp = rewriter.create<linalg::FillOp>(
-          loc, ValueRange(fillValue), ValueRange(emptyOp.getResult()));
-      rewriter.setInsertionPointAfter(fillOp);
-      filled = fillOp.getResult(0);
-    } else if (invariantYieldedValue) {
-      // Padding with an invariant value.
-      auto fillOp = rewriter.create<linalg::FillOp>(
-          loc, ValueRange(yieldedValue), ValueRange(emptyOp.getResult()));
-      rewriter.setInsertionPointAfter(fillOp);
-      filled = fillOp.getResult(0);
-    } else {
-      // Create linalg.generic.
-      SmallVector<utils::IteratorType> iteratorTypes(
-          resultType.getRank(), utils::IteratorType::parallel);
-      SmallVector<AffineMap> indexingMaps(
-          1, rewriter.getMultiDimIdentityMap(resultType.getRank()));
-      auto genericOp = rewriter.create<linalg::GenericOp>(
-          loc, resultType, /*inputs=*/ValueRange(),
-          /*outputs=*/ValueRange{emptyOp.getResult()}, /*indexingMaps=*/
-          indexingMaps, iteratorTypes);
-      Block *body = rewriter.createBlock(&genericOp->getRegion(0), {},
-                                         resultType.getElementType(), loc);
-      rewriter.setInsertionPointToStart(body);
-      SmallVector<Value> bbArgReplacements;
-      for (int64_t i = 0; i < resultType.getRank(); ++i)
-        bbArgReplacements.push_back(rewriter.create<linalg::IndexOp>(loc, i));
-      rewriter.mergeBlocks(padOp.getBody(), body, bbArgReplacements);
-
-      // Update terminator.
-      auto yieldOp = cast<tensor::YieldOp>(body->getTerminator());
-      rewriter.replaceOpWithNewOp<linalg::YieldOp>(yieldOp, yieldOp.getValue());
-      rewriter.setInsertionPointAfter(genericOp);
-      filled = genericOp->getResult(0);
-    }
+    // Create linalg.fill or linalg.generic.
+    Operation *fillOp =
+        movePaddingToFillOrGenericOp(rewriter, loc, padOp, emptyOp.getResult());
+    rewriter.setInsertionPointAfter(fillOp);
 
     // Create tensor::InsertSliceOp.
     SmallVector<OpFoldResult> sliceSizes =
@@ -225,15 +322,50 @@ struct PadOpConverter : public OpRewritePattern<PadOp> {
     SmallVector<OpFoldResult> sliceStrides(resultType.getRank(),
                                            rewriter.getIndexAttr(1));
     rewriter.replaceOpWithNewOp<tensor::InsertSliceOp>(
-        padOp, padOp.getSource(), filled,
+        padOp, padOp.getSource(), fillOp->getResult(0),
         /*offsets=*/padOp.getMixedLowPad(), sliceSizes, sliceStrides);
 
     return success();
   }
 };
-
 } // namespace
 
+// Materializes an allocation for `value`, copies `value` into it and redirects
+// all pre-existing uses of `value` to the returned to_tensor of the allocation.
+Value linalg::bufferizeToAllocation(RewriterBase &rewriter, Value value,
+                                    Attribute memorySpace) {
+  // Call specialized overload for certain ops.
+  if (auto padOp = value.getDefiningOp<PadOp>())
+    return bufferizeToAllocation(rewriter, padOp, memorySpace);
+
+  // Collect all uses (before the tensor_store below adds another one).
+  SmallVector<OpOperand *> uses = llvm::to_vector(
+      llvm::map_range(value.getUses(), [](OpOperand &use) { return &use; }));
+
+  // The new ops read `value`, so they must be placed after its definition.
+  OpBuilder::InsertionGuard g(rewriter);
+  if (auto bbArg = value.dyn_cast<BlockArgument>())
+    rewriter.setInsertionPointToStart(bbArg.getOwner());
+  else
+    rewriter.setInsertionPointAfter(value.getDefiningOp());
+  Location loc = value.getLoc();
+
+  // Create buffer allocation and copy `value` into it (memref.tensor_store).
+  Value alloc = createAllocationForTensor(rewriter, loc, value, memorySpace);
+  rewriter.create<memref::TensorStoreOp>(loc, value, alloc);
+
+  // Create bufferization.to_tensor with "restrict" and "writable". The returned
+  // tensor is a new buffer allocation, so it does not alias with any buffer.
+  Value toTensorOp = rewriter.create<bufferization::ToTensorOp>(
+      loc, alloc, /*restrict=*/true, /*writable=*/true);
+  for (OpOperand *use : uses) {
+    rewriter.updateRootInPlace(use->getOwner(),
+                               [&]() { use->set(toTensorOp); });
+  }
+
+  return toTensorOp;
+}
+
 void linalg::populateConvertToDestinationStylePatterns(
     RewritePatternSet &patterns) {
   patterns.insert<FromElementsOpConverter, GenerateOpConverter, PadOpConverter>(

diff  --git a/mlir/test/Dialect/Linalg/transform-op-bufferize-to-allocation.mlir b/mlir/test/Dialect/Linalg/transform-op-bufferize-to-allocation.mlir
new file mode 100644
index 0000000000000..d6a282d2e175d
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/transform-op-bufferize-to-allocation.mlir
@@ -0,0 +1,61 @@
+// RUN: mlir-opt -split-input-file \
+// RUN:   -test-transform-dialect-interpreter -canonicalize \
+// RUN:   -allow-unregistered-dialect -split-input-file %s | FileCheck %s
+
+// CHECK:       #[[$map:.+]] = affine_map<()[s0, s1] -> (s0 + s1 + 5)>
+// CHECK:       #[[$map1:.+]] = affine_map<()[s0, s1] -> (s0 + s1 + 10)>
+// CHECK-LABEL: func @tensor_pad_constant(
+//  CHECK-SAME:   %[[t:.*]]: tensor<?x10xindex>, %[[l2:.*]]: index, %[[h1:.*]]: index, %[[h2:.*]]: index
+//   CHECK-DAG:   %[[c0:.*]] = arith.constant 0 : index
+//   CHECK-DAG:   %[[c50:.*]] = arith.constant 50 : index
+//   CHECK-DAG:   %[[dim0:.*]] = tensor.dim %[[t]], %[[c0]]
+//   CHECK-DAG:   %[[size0:.*]] = affine.apply #[[$map]]()[%[[h1]], %[[dim0]]]
+//   CHECK-DAG:   %[[size1:.*]] = affine.apply #[[$map1]]()[%[[l2]], %[[h2]]]
+//       CHECK:   %[[alloc:.*]] = memref.alloc(%[[size0]], %[[size1]]) : memref<?x?xindex>
+//       CHECK:   linalg.fill ins(%[[c50]] : index) outs(%[[alloc]] : memref<?x?xindex>)
+//       CHECK:   %[[dim0:.*]] = tensor.dim %[[t]], %[[c0]]
+//       CHECK:   %[[subview:.*]] = memref.subview %[[alloc]][5, %[[l2]]] [%[[dim0]], 10] [1, 1]
+//       CHECK:   memref.tensor_store %[[t]], %[[subview]]
+//       CHECK:   %[[r:.*]] = bufferization.to_tensor %[[alloc]] restrict writable : memref<?x?xindex>
+//       CHECK:   memref.dealloc %[[alloc]]
+//       CHECK:   return %[[r]]
+func.func @tensor_pad_constant(%t: tensor<?x10xindex>, %l2: index, %h1: index,
+                               %h2: index) -> tensor<?x?xindex> {
+  %0 = tensor.pad %t low[5, %l2] high[%h1, %h2] {
+  ^bb0(%arg0: index, %arg1: index):
+    %c = arith.constant 50 : index
+    tensor.yield %c : index
+  } : tensor<?x10xindex> to tensor<?x?xindex>
+  return %0 : tensor<?x?xindex>
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+  %0 = transform.structured.match ops{["tensor.pad"]} in %arg1 : (!pdl.operation) -> !pdl.operation
+  %1 = transform.get_result %0[0] : (!pdl.operation) -> !transform.any_value
+  %2 = transform.structured.bufferize_to_allocation %1
+}
+
+// -----
+
+// CHECK-LABEL: func @materialization_of_bbarg(
+//  CHECK-SAME:     %[[t:.*]]: tensor<?x10xindex>
+//       CHECK:   %[[c0:.*]] = arith.constant 0 : index
+//       CHECK:   %[[dim:.*]] = tensor.dim %[[t]], %[[c0]]
+//       CHECK:   %[[alloc:.*]] = memref.alloc(%[[dim]]) : memref<?x10xindex, 4>
+//       CHECK:   memref.tensor_store %[[t]], %[[alloc]]
+//       CHECK:   %[[alloc_t:.*]] = bufferization.to_tensor %[[alloc]] restrict writable
+//       CHECK:   %[[r:.*]] = tensor.extract %[[alloc_t]]
+//       CHECK:   memref.dealloc %[[alloc]]
+//       CHECK:   return %[[r]]
+func.func @materialization_of_bbarg(%t: tensor<?x10xindex>, %idx: index) -> index {
+  %r = tensor.extract %t[%idx, %idx] : tensor<?x10xindex>
+  return %r : index
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+  %0 = transform.structured.match ops{["tensor.extract"]} in %arg1 : (!pdl.operation) -> !pdl.operation
+  %1 = test_produce_value_handle_to_argument_of_parent_block %0, 0 : (!pdl.operation) -> !transform.any_value
+  %2 = transform.structured.bufferize_to_allocation %1 {memory_space = 4}
+}


        


More information about the Mlir-commits mailing list