[Mlir-commits] [mlir] 01581e2 - [mlir][linalg] Add bufferize_to_allocation transform op
Matthias Springer
llvmlistbot at llvm.org
Wed Feb 15 06:22:55 PST 2023
Author: Matthias Springer
Date: 2023-02-15T15:22:46+01:00
New Revision: 01581e28ad929c6f2b8b1c31828006f32b2180d1
URL: https://github.com/llvm/llvm-project/commit/01581e28ad929c6f2b8b1c31828006f32b2180d1
DIFF: https://github.com/llvm/llvm-project/commit/01581e28ad929c6f2b8b1c31828006f32b2180d1.diff
LOG: [mlir][linalg] Add bufferize_to_allocation transform op
This transform materializes a buffer allocation for a given tensor value. All uses of the original value are replaced with the new allocation, wrapped in a `bufferization.to_tensor` op.
Certain non-DPS (non-destination-passing-style) ops may have an optimized lowering path that bufferizes the entire defining op. Such an optimization is added for `tensor.pad` as part of this change.
The resulting IR can be further bufferized with One-Shot Bufferize.
Differential Revision: https://reviews.llvm.org/D144022
Added:
mlir/test/Dialect/Linalg/transform-op-bufferize-to-allocation.mlir
Modified:
mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index a8acb052b60b4..58ef106563a5b 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -26,6 +26,58 @@ def TransformParamTypeOrAnyHandle : Type<
Transform_ParamType.predicate]>,
"transform 'param' type or any handle type">;
+//===----------------------------------------------------------------------===//
+// BufferizeToAllocationOp
+//===----------------------------------------------------------------------===//
+
+def BufferizeToAllocationOp : Op<Transform_Dialect,
+ "structured.bufferize_to_allocation",
+ [DeclareOpInterfaceMethods<TransformOpInterface>,
+ DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
+ let description = [{
+ This transform materializes an allocation for the targeted tensor value. It
+ replaces all original uses of the target with the newly allocated buffer,
+ wrapped in a `bufferization.to_tensor` op. It returns a handle to the result
+ of the `to_tensor` op.
+
+ Example:
+ ```
+ %0 = "some_op"() : () -> (tensor<10xf32>)
+ "some_use"(%0) : (tensor<10xf32>) -> ()
+ ```
+
+ Is rewritten to:
+ ```
+ %0 = "some_op"() : () -> (tensor<10xf32>)
+ %1 = memref.alloc() : memref<10xf32>
+ memref.tensor_store %0, %1 : memref<10xf32>
+ %2 = bufferization.to_tensor %1 restrict writable : memref<10xf32>
+ "some_use"(%2) : (tensor<10xf32>) -> ()
+ ```
+
+ This transform has optimized lowerings for certain targets that are results
+ of non-DPS ops. For such targets, not only a buffer allocation is emitted
+ but also the defining op is bufferized. This is to avoid a second
+ allocation for the missing destination of the non-DPS op (when subsequently
+ running a bufferization pass/transform). Currently supported ops with
+ optimized lowerings:
+ - tensor.pad
+
+ An optional memory space attribute can be specified for the materialized
+ buffer allocation.
+
+ #### Return modes
+
+ This operation consumes the `target` handle and produces the `transformed`
+ handle. It always succeeds.
+ }];
+
+ // `target`: handle to the tensor value(s) to materialize in new buffers.
+ // `memory_space`: optional memory space of the new buffers; it is parsed and
+ // printed through the attribute dictionary (e.g. `{memory_space = 4}`).
+ let arguments = (ins Transform_AnyValue:$target,
+ OptionalAttr<AnyAttr>:$memory_space);
+ // `transformed`: handle to the bufferization.to_tensor results that replace
+ // the targets.
+ let results = (outs Transform_AnyValue:$transformed);
+ let assemblyFormat = "$target attr-dict";
+}
+
//===----------------------------------------------------------------------===//
// DecomposeOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 953eb59b95134..dd01a2e3325bf 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -44,6 +44,35 @@ struct LinalgTilingOptions;
//===----------------------------------------------------------------------===//
using LinalgLoops = SmallVector<Operation *, 4>;
+/// Materialize a buffer allocation for the given tensor.pad op and lower the
+/// op to linalg.fill/linalg.generic + memref.tensor_store. E.g.:
+///
+/// %0 = tensor.pad low[%l] high[%h] %t ...
+///
+/// is lowered to:
+///
+/// %alloc = memref.alloc
+/// linalg.fill ... outs(%alloc)
+/// %subview = memref.subview %alloc [%l] [...] [1]
+/// memref.tensor_store %t, %subview
+/// %0 = bufferization.to_tensor %alloc restrict writable
+///
+/// A linalg.fill is used when the padding value is a constant or is defined
+/// outside of `padOp`; otherwise the body of `padOp` is moved into a
+/// linalg.generic that writes %alloc. A memref.dealloc of %alloc is inserted
+/// before the terminator of the block containing `padOp`, and `padOp` is
+/// replaced by the bufferization.to_tensor result, which is also returned.
+/// `memorySpace` is the memory space of the new buffer (null by default).
+Value bufferizeToAllocation(RewriterBase &rewriter, tensor::PadOp padOp,
+ Attribute memorySpace = {});
+
+/// Materialize a buffer allocation for the given tensor value. E.g.:
+///
+/// %alloc = memref.alloc
+/// memref.tensor_store %value, %alloc
+/// %0 = bufferization.to_tensor %alloc restrict writable
+///
+/// All uses of `value` that exist before the call are redirected to %0, and a
+/// memref.dealloc of %alloc is inserted before the terminator of the block in
+/// which the allocation is created. `value` must have a RankedTensorType.
+/// Returns %0.
+///
+/// In case `value` is a tensor.pad result, the corresponding overload is used
+/// internally to produce a better bufferization.
+Value bufferizeToAllocation(RewriterBase &rewriter, Value value,
+ Attribute memorySpace = {});
+
void populatePadTensorTilingPatterns(RewritePatternSet &patterns,
const LinalgTilingOptions &options);
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 0ac80c1e637fc..dab98d2406f45 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -209,6 +209,30 @@ static PackingMetadata computePackingMetadata(int64_t packedRank,
return res;
}
+//===----------------------------------------------------------------------===//
+// BufferizeToAllocationOp
+//===----------------------------------------------------------------------===//
+DiagnosedSilenceableFailure
+transform::BufferizeToAllocationOp::apply(transform::TransformResults &results,
+                                          transform::TransformState &state) {
+  // An absent `memory_space` attribute is forwarded as a null attribute.
+  Attribute memorySpace = getMemorySpace().value_or(Attribute());
+
+  // Materialize one allocation per payload value, in payload order, and map
+  // the resulting to_tensor values to the `transformed` handle.
+  IRRewriter rewriter(getContext());
+  SmallVector<Value> allocatedTensors;
+  for (Value payload : state.getPayloadValues(getTarget()))
+    allocatedTensors.push_back(
+        linalg::bufferizeToAllocation(rewriter, payload, memorySpace));
+  results.setValues(getTransformed().cast<OpResult>(), allocatedTensors);
+
+  // No failure is ever reported: every payload value is rewritten.
+  return DiagnosedSilenceableFailure::success();
+}
+
+void transform::BufferizeToAllocationOp::getEffects(
+    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  // The target handle is consumed, a fresh handle is produced, and the payload
+  // IR is rewritten.
+  consumesHandle(getTarget(), effects);
+  producesHandle(getTransformed(), effects);
+  modifiesPayload(effects);
+}
+
//===----------------------------------------------------------------------===//
// DecomposeOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp b/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp
index f0f7187804bbb..261ad97ac46be 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp
@@ -14,6 +14,8 @@
//===----------------------------------------------------------------------===//
//
#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
+#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
@@ -134,7 +136,157 @@ struct GenerateOpConverter : public OpRewritePattern<GenerateOp> {
return success();
}
};
+} // namespace
+
+/// Create an op, at the current insertion point of `rewriter`, that writes the
+/// padding value of `padOp` into every element of `dest`. `dest` is either a
+/// tensor (e.g. a tensor.empty result) or a memref (a freshly allocated
+/// buffer).
+/// * A linalg.fill is created if the yielded padding value is a constant or
+///   is defined outside of `padOp`.
+/// * Otherwise a linalg.generic is created, and the body of `padOp` is moved
+///   into it (its index bbArgs are replaced by linalg.index ops). `padOp` is
+///   left with an empty body, so the caller is expected to replace it.
+/// The created op has one tensor result when `dest` is a tensor and no result
+/// when `dest` is a memref.
+static Operation *movePaddingToFillOrGenericOp(RewriterBase &rewriter,
+                                               Location loc, PadOp padOp,
+                                               Value dest) {
+  OpBuilder::InsertionGuard g(rewriter);
+  RankedTensorType resultType = padOp.getResultType();
+
+  // Examine the yielded value to decide if a linalg.generic is needed or a
+  // linalg.fill is sufficient.
+  Value yieldedValue =
+      cast<tensor::YieldOp>(padOp.getBody()->getTerminator()).getValue();
+  Attribute constYieldedValue;
+  // Is the yielded value a bbArg defined outside of the PadOp?
+  bool outsideBbArg =
+      yieldedValue.isa<BlockArgument>() &&
+      yieldedValue.cast<BlockArgument>().getOwner()->getParentOp() !=
+          padOp.getOperation();
+  // Is the yielded value an OpResult defined outside of the PadOp?
+  bool outsideOpResult =
+      yieldedValue.isa<OpResult>() &&
+      yieldedValue.getDefiningOp()->getParentOp() != padOp.getOperation();
+  bool invariantYieldedValue = outsideBbArg || outsideOpResult;
+  if (matchPattern(yieldedValue, m_Constant(&constYieldedValue))) {
+    // Padding with a constant: Create linalg.fill. The constant is
+    // re-materialized at the insertion point because the original one may
+    // live inside the pad body.
+    Dialect *arithDialect =
+        rewriter.getContext()->getLoadedDialect<arith::ArithDialect>();
+    Value fillValue =
+        arithDialect
+            ->materializeConstant(rewriter, constYieldedValue,
+                                  yieldedValue.getType(), yieldedValue.getLoc())
+            ->getResult(0);
+    auto fillOp = rewriter.create<linalg::FillOp>(loc, ValueRange(fillValue),
+                                                  ValueRange(dest));
+    return fillOp;
+  }
+
+  if (invariantYieldedValue) {
+    // Padding with an invariant value.
+    auto fillOp = rewriter.create<linalg::FillOp>(loc, ValueRange(yieldedValue),
+                                                  ValueRange(dest));
+    return fillOp;
+  }
+
+  // Create linalg.generic. On tensors it returns the padded tensor. On buffers
+  // it writes `dest` in place and must not declare any result: a DPS op may
+  // only have one result per tensor init.
+  SmallVector<Type> genericResultTypes;
+  if (dest.getType().isa<TensorType>())
+    genericResultTypes.push_back(resultType);
+  SmallVector<utils::IteratorType> iteratorTypes(resultType.getRank(),
+                                                 utils::IteratorType::parallel);
+  SmallVector<AffineMap> indexingMaps(
+      1, rewriter.getMultiDimIdentityMap(resultType.getRank()));
+  auto genericOp = rewriter.create<linalg::GenericOp>(
+      loc, TypeRange(genericResultTypes), /*inputs=*/ValueRange(),
+      /*outputs=*/ValueRange{dest}, /*indexingMaps=*/
+      indexingMaps, iteratorTypes);
+  Block *body = rewriter.createBlock(&genericOp->getRegion(0), {},
+                                     resultType.getElementType(), loc);
+  rewriter.setInsertionPointToStart(body);
+  // The pad body's bbArgs are the indices of the element being computed; in
+  // the linalg.generic they are provided by linalg.index ops.
+  SmallVector<Value> bbArgReplacements;
+  for (int64_t i = 0; i < resultType.getRank(); ++i)
+    bbArgReplacements.push_back(rewriter.create<linalg::IndexOp>(loc, i));
+  rewriter.mergeBlocks(padOp.getBody(), body, bbArgReplacements);
+
+  // Update terminator.
+  auto yieldOp = cast<tensor::YieldOp>(body->getTerminator());
+  rewriter.replaceOpWithNewOp<linalg::YieldOp>(yieldOp, yieldOp.getValue());
+  return genericOp;
+}
+
+/// Return one index value per dynamic dimension of the ranked tensor `value`
+/// (empty for a static shape). The sizes reified by the defining op are used
+/// when it implements ReifyRankedShapedTypeOpInterface and reification
+/// succeeds; otherwise a tensor.dim op is created for each dynamic dimension.
+static SmallVector<Value> reifyOrComputeDynamicSizes(OpBuilder &b,
+                                                     Value value) {
+  auto tensorType = value.getType().cast<RankedTensorType>();
+  SmallVector<Value> dynSizes;
+  if (tensorType.hasStaticShape())
+    return dynSizes;
+  int64_t rank = tensorType.getRank();
+
+  // Preferred: ask the defining op for its result shapes.
+  auto reifiableOp = value.getDefiningOp<ReifyRankedShapedTypeOpInterface>();
+  ReifiedRankedShapedTypeDims reifiedShape;
+  if (reifiableOp &&
+      succeeded(reifiableOp.reifyResultShapes(b, reifiedShape))) {
+    for (int64_t dim = 0; dim < rank; ++dim)
+      if (tensorType.isDynamicDim(dim))
+        dynSizes.push_back(
+            reifiedShape[value.cast<OpResult>().getResultNumber()][dim]);
+    return dynSizes;
+  }
+
+  // Fallback: query every dynamic dimension with a tensor.dim op.
+  Location loc = value.getLoc();
+  for (int64_t dim = 0; dim < rank; ++dim) {
+    if (!tensorType.isDynamicDim(dim))
+      continue;
+    Value dimIndex = b.create<arith::ConstantIndexOp>(loc, dim);
+    dynSizes.push_back(b.create<DimOp>(loc, value, dimIndex));
+  }
+  return dynSizes;
+}
+
+/// Create a memref.alloc at the current insertion point whose type is the
+/// static-identity-layout memref counterpart of the ranked tensor `value`, in
+/// memory space `memorySpace`. A matching memref.dealloc is placed right
+/// before the terminator of the current insertion block. The insertion point
+/// is restored on return. Returns the allocated buffer.
+static Value createAllocationForTensor(RewriterBase &rewriter, Location loc,
+                                       Value value,
+                                       Attribute memorySpace = {}) {
+  OpBuilder::InsertionGuard guard(rewriter);
+  auto valueType = value.getType().cast<RankedTensorType>();
+  MemRefType bufferType =
+      bufferization::getMemRefTypeWithStaticIdentityLayout(valueType,
+                                                           memorySpace)
+          .cast<MemRefType>();
+
+  // Any IR needed to compute the dynamic sizes is created before the alloc.
+  SmallVector<Value> bufferSizes = reifyOrComputeDynamicSizes(rewriter, value);
+  Value buffer = rewriter.create<memref::AllocOp>(loc, bufferType, bufferSizes);
+
+  // Free the buffer right before the terminator of the allocating block.
+  Block *allocBlock = rewriter.getInsertionBlock();
+  rewriter.setInsertionPoint(allocBlock->getTerminator());
+  rewriter.create<memref::DeallocOp>(loc, buffer);
+  return buffer;
+}
+
+/// Bufferize `padOp` into a new allocation: fill the buffer with the padding
+/// value, store the source into the interior, and replace `padOp` with a
+/// bufferization.to_tensor of the buffer, which is returned.
+Value linalg::bufferizeToAllocation(RewriterBase &rewriter, PadOp padOp,
+                                    Attribute memorySpace) {
+  OpBuilder::InsertionGuard guard(rewriter);
+  rewriter.setInsertionPoint(padOp);
+  Location loc = padOp.getLoc();
+  Value source = padOp.getSource();
+
+  // A buffer large enough for the padded result.
+  Value buffer =
+      createAllocationForTensor(rewriter, loc, padOp.getResult(), memorySpace);
+
+  // Write the padding value everywhere (linalg.fill or linalg.generic), then
+  // continue emitting IR after that op.
+  Operation *paddingOp =
+      movePaddingToFillOrGenericOp(rewriter, loc, padOp, buffer);
+  rewriter.setInsertionPointAfter(paddingOp);
+
+  // Copy the source into the interior of the buffer: a subview starting at the
+  // low padding, with the source's sizes and unit strides.
+  SmallVector<OpFoldResult> sourceSizes = getMixedSizes(rewriter, loc, source);
+  SmallVector<OpFoldResult> unitStrides(padOp.getResultType().getRank(),
+                                        rewriter.getIndexAttr(1));
+  Value interior = rewriter.create<memref::SubViewOp>(
+      loc, buffer, /*offsets=*/padOp.getMixedLowPad(), sourceSizes,
+      unitStrides);
+  rewriter.create<memref::TensorStoreOp>(loc, source, interior);
+
+  // Expose the buffer as a tensor. It is a fresh allocation that does not
+  // alias with any other buffer, so "restrict" and "writable" are set.
+  Value result = rewriter.create<bufferization::ToTensorOp>(
+      loc, buffer, /*restrict=*/true, /*writable=*/true);
+  rewriter.replaceOp(padOp, result);
+  return result;
+}
+
+namespace {
/// Lower tensor.pad to linalg.generic + tensor.insert_slice.
struct PadOpConverter : public OpRewritePattern<PadOp> {
using OpRewritePattern<PadOp>::OpRewritePattern;
@@ -159,65 +311,10 @@ struct PadOpConverter : public OpRewritePattern<PadOp> {
dynamicSizes.push_back(reifiedShape[0][i]);
auto emptyOp = rewriter.create<EmptyOp>(loc, resultType, dynamicSizes);
- // Examine the yielded value to decide if a linalg.generic is neede or a
- // linalg.fill is sufficient.
- Value filled;
- Value yieldedValue =
- cast<tensor::YieldOp>(padOp.getBody()->getTerminator()).getValue();
- Attribute constYieldedValue;
- // Is the yielded value a bbArg defined outside of the PadOp?
- bool outsideBbArg =
- yieldedValue.isa<BlockArgument>() &&
- yieldedValue.cast<BlockArgument>().getOwner()->getParentOp() !=
- padOp.getOperation();
- // Is the yielded value an OpResult defined outside of the PadOp?
- bool outsideOpResult =
- yieldedValue.isa<OpResult>() &&
- yieldedValue.getDefiningOp()->getParentOp() != padOp.getOperation();
- bool invariantYieldedValue = outsideBbArg || outsideOpResult;
- if (matchPattern(yieldedValue, m_Constant(&constYieldedValue))) {
- // Padding with a constant: Create linalg.fill.
- Dialect *arithDialect =
- rewriter.getContext()->getLoadedDialect<arith::ArithDialect>();
- Value fillValue = arithDialect
- ->materializeConstant(rewriter, constYieldedValue,
- yieldedValue.getType(),
- yieldedValue.getLoc())
- ->getResult(0);
- auto fillOp = rewriter.create<linalg::FillOp>(
- loc, ValueRange(fillValue), ValueRange(emptyOp.getResult()));
- rewriter.setInsertionPointAfter(fillOp);
- filled = fillOp.getResult(0);
- } else if (invariantYieldedValue) {
- // Padding with an invariant value.
- auto fillOp = rewriter.create<linalg::FillOp>(
- loc, ValueRange(yieldedValue), ValueRange(emptyOp.getResult()));
- rewriter.setInsertionPointAfter(fillOp);
- filled = fillOp.getResult(0);
- } else {
- // Create linalg.generic.
- SmallVector<utils::IteratorType> iteratorTypes(
- resultType.getRank(), utils::IteratorType::parallel);
- SmallVector<AffineMap> indexingMaps(
- 1, rewriter.getMultiDimIdentityMap(resultType.getRank()));
- auto genericOp = rewriter.create<linalg::GenericOp>(
- loc, resultType, /*inputs=*/ValueRange(),
- /*outputs=*/ValueRange{emptyOp.getResult()}, /*indexingMaps=*/
- indexingMaps, iteratorTypes);
- Block *body = rewriter.createBlock(&genericOp->getRegion(0), {},
- resultType.getElementType(), loc);
- rewriter.setInsertionPointToStart(body);
- SmallVector<Value> bbArgReplacements;
- for (int64_t i = 0; i < resultType.getRank(); ++i)
- bbArgReplacements.push_back(rewriter.create<linalg::IndexOp>(loc, i));
- rewriter.mergeBlocks(padOp.getBody(), body, bbArgReplacements);
-
- // Update terminator.
- auto yieldOp = cast<tensor::YieldOp>(body->getTerminator());
- rewriter.replaceOpWithNewOp<linalg::YieldOp>(yieldOp, yieldOp.getValue());
- rewriter.setInsertionPointAfter(genericOp);
- filled = genericOp->getResult(0);
- }
+ // Create linalg.fill or linalg.generic.
+ Operation *fillOp =
+ movePaddingToFillOrGenericOp(rewriter, loc, padOp, emptyOp.getResult());
+ rewriter.setInsertionPointAfter(fillOp);
// Create tensor::InsertSliceOp.
SmallVector<OpFoldResult> sliceSizes =
@@ -225,15 +322,50 @@ struct PadOpConverter : public OpRewritePattern<PadOp> {
SmallVector<OpFoldResult> sliceStrides(resultType.getRank(),
rewriter.getIndexAttr(1));
rewriter.replaceOpWithNewOp<tensor::InsertSliceOp>(
- padOp, padOp.getSource(), filled,
+ padOp, padOp.getSource(), fillOp->getResult(0),
/*offsets=*/padOp.getMixedLowPad(), sliceSizes, sliceStrides);
return success();
}
};
-
} // namespace
+/// Materialize a new buffer for the ranked tensor `value` and redirect all of
+/// its pre-existing uses to a bufferization.to_tensor of that buffer, which is
+/// returned. Results of tensor.pad ops are handled by the PadOp overload,
+/// which also bufferizes the pad op itself.
+Value linalg::bufferizeToAllocation(RewriterBase &rewriter, Value value,
+                                    Attribute memorySpace) {
+  // Call specialized overload for certain ops.
+  if (auto padOp = value.getDefiningOp<PadOp>())
+    return bufferizeToAllocation(rewriter, padOp, memorySpace);
+
+  // Collect all uses up front, so that the new uses of `value` created below
+  // (memref.tensor_store and possibly tensor.dim) are not redirected.
+  SmallVector<OpOperand *> uses = llvm::to_vector(
+      llvm::map_range(value.getUses(), [](OpOperand &use) { return &use; }));
+
+  // All new IR consumes `value`, so it must be inserted where `value` is
+  // already defined: at the start of the owning block for a block argument,
+  // and *after* the defining op for an op result. Inserting before the
+  // defining op would create uses that are not dominated by the definition.
+  OpBuilder::InsertionGuard g(rewriter);
+  if (auto bbArg = value.dyn_cast<BlockArgument>()) {
+    rewriter.setInsertionPointToStart(bbArg.getOwner());
+  } else {
+    rewriter.setInsertionPointAfter(value.getDefiningOp());
+  }
+  Location loc = value.getLoc();
+
+  // Create buffer allocation (its dealloc is placed before the terminator of
+  // the current block).
+  Value alloc = createAllocationForTensor(rewriter, loc, value, memorySpace);
+
+  // Create memref.tensor_store.
+  rewriter.create<memref::TensorStoreOp>(loc, value, alloc);
+
+  // Create bufferization.to_tensor with "restrict" and "writable". The returned
+  // tensor is a new buffer allocation, so it does not alias with any buffer.
+  Value toTensorOp = rewriter.create<bufferization::ToTensorOp>(
+      loc, alloc, /*restrict=*/true, /*writable=*/true);
+  for (OpOperand *use : uses) {
+    rewriter.updateRootInPlace(use->getOwner(),
+                               [&]() { use->set(toTensorOp); });
+  }
+
+  return toTensorOp;
+}
+
void linalg::populateConvertToDestinationStylePatterns(
RewritePatternSet &patterns) {
patterns.insert<FromElementsOpConverter, GenerateOpConverter, PadOpConverter>(
diff --git a/mlir/test/Dialect/Linalg/transform-op-bufferize-to-allocation.mlir b/mlir/test/Dialect/Linalg/transform-op-bufferize-to-allocation.mlir
new file mode 100644
index 0000000000000..d6a282d2e175d
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/transform-op-bufferize-to-allocation.mlir
@@ -0,0 +1,61 @@
+// RUN: mlir-opt -split-input-file \
+// RUN: -test-transform-dialect-interpreter -canonicalize \
+// RUN: -allow-unregistered-dialect -split-input-file %s | FileCheck %s
+
+// tensor.pad with a constant padding value: the pad op itself is bufferized
+// into a memref.alloc, a linalg.fill of the whole buffer with the padding
+// value, and a memref.tensor_store of the source into a subview at the low
+// padding offsets. The dealloc is placed right before the terminator.
+// CHECK: #[[$map:.+]] = affine_map<()[s0, s1] -> (s0 + s1 + 5)>
+// CHECK: #[[$map1:.+]] = affine_map<()[s0, s1] -> (s0 + s1 + 10)>
+// CHECK-LABEL: func @tensor_pad_constant(
+// CHECK-SAME: %[[t:.*]]: tensor<?x10xindex>, %[[l2:.*]]: index, %[[h1:.*]]: index, %[[h2:.*]]: index
+// CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[c50:.*]] = arith.constant 50 : index
+// CHECK-DAG: %[[dim0:.*]] = tensor.dim %[[t]], %[[c0]]
+// CHECK-DAG: %[[size0:.*]] = affine.apply #[[$map]]()[%[[h1]], %[[dim0]]]
+// CHECK-DAG: %[[size1:.*]] = affine.apply #[[$map1]]()[%[[l2]], %[[h2]]]
+// CHECK: %[[alloc:.*]] = memref.alloc(%[[size0]], %[[size1]]) : memref<?x?xindex>
+// CHECK: linalg.fill ins(%[[c50]] : index) outs(%[[alloc]] : memref<?x?xindex>)
+// CHECK: %[[dim0:.*]] = tensor.dim %[[t]], %[[c0]]
+// CHECK: %[[subview:.*]] = memref.subview %[[alloc]][5, %[[l2]]] [%[[dim0]], 10] [1, 1]
+// CHECK: memref.tensor_store %[[t]], %[[subview]]
+// CHECK: %[[r:.*]] = bufferization.to_tensor %[[alloc]] restrict writable : memref<?x?xindex>
+// CHECK: memref.dealloc %[[alloc]]
+// CHECK: return %[[r]]
+func.func @tensor_pad_constant(%t: tensor<?x10xindex>, %l2: index, %h1: index,
+ %h2: index) -> tensor<?x?xindex> {
+ %0 = tensor.pad %t low[5, %l2] high[%h1, %h2] {
+ ^bb0(%arg0: index, %arg1: index):
+ %c = arith.constant 50 : index
+ tensor.yield %c : index
+ } : tensor<?x10xindex> to tensor<?x?xindex>
+ return %0 : tensor<?x?xindex>
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+ %0 = transform.structured.match ops{["tensor.pad"]} in %arg1 : (!pdl.operation) -> !pdl.operation
+ %1 = transform.get_result %0[0] : (!pdl.operation) -> !transform.any_value
+ %2 = transform.structured.bufferize_to_allocation %1
+}
+
+// -----
+
+// Function block argument (block argument 0 of the block containing the
+// tensor.extract): a buffer in memory space 4 is allocated at the start of the
+// block and initialized with memref.tensor_store. The use in tensor.extract is
+// redirected to the bufferization.to_tensor result.
+// CHECK-LABEL: func @materialization_of_bbarg(
+// CHECK-SAME: %[[t:.*]]: tensor<?x10xindex>
+// CHECK: %[[c0:.*]] = arith.constant 0 : index
+// CHECK: %[[dim:.*]] = tensor.dim %[[t]], %[[c0]]
+// CHECK: %[[alloc:.*]] = memref.alloc(%[[dim]]) : memref<?x10xindex, 4>
+// CHECK: memref.tensor_store %[[t]], %[[alloc]]
+// CHECK: %[[alloc_t:.*]] = bufferization.to_tensor %[[alloc]] restrict writable
+// CHECK: %[[r:.*]] = tensor.extract %[[alloc_t]]
+// CHECK: memref.dealloc %[[alloc]]
+// CHECK: return %[[r]]
+func.func @materialization_of_bbarg(%t: tensor<?x10xindex>, %idx: index) -> index {
+ %r = tensor.extract %t[%idx, %idx] : tensor<?x10xindex>
+ return %r : index
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+ %0 = transform.structured.match ops{["tensor.extract"]} in %arg1 : (!pdl.operation) -> !pdl.operation
+ %1 = test_produce_value_handle_to_argument_of_parent_block %0, 0 : (!pdl.operation) -> !transform.any_value
+ %2 = transform.structured.bufferize_to_allocation %1 {memory_space = 4}
+}
More information about the Mlir-commits
mailing list