[Mlir-commits] [mlir] [mlir][SCF] Bufferize scf.index_switch (PR #67666)
Matthias Springer
llvmlistbot at llvm.org
Thu Sep 28 05:47:48 PDT 2023
https://github.com/matthias-springer created https://github.com/llvm/llvm-project/pull/67666
Add the `BufferizableOpInterface` implementation of `scf.index_switch`.
>From 9e8f9364d27aa67a2c00be1a75e68f889681ddcb Mon Sep 17 00:00:00 2001
From: Matthias Springer <me at m-sp.org>
Date: Thu, 28 Sep 2023 14:47:01 +0200
Subject: [PATCH] [mlir][SCF] Bufferize scf.index_switch
Add the `BufferizableOpInterface` implementation of `scf.index_switch`.
---
.../BufferizableOpInterfaceImpl.cpp | 119 +++++++++++++++++-
mlir/test/Dialect/SCF/one-shot-bufferize.mlir | 31 +++++
2 files changed, 147 insertions(+), 3 deletions(-)
diff --git a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
index bcbc693a9742ccc..dff779170bba733 100644
--- a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -295,6 +295,117 @@ struct IfOpInterface
}
};
+/// Bufferization of scf.index_switch. Replace with a new scf.index_switch that
+/// yields memrefs.
+struct IndexSwitchOpInterface
+    : public BufferizableOpInterface::ExternalModel<IndexSwitchOpInterface,
+                                                    scf::IndexSwitchOp> {
+  AliasingOpOperandList
+  getAliasingOpOperands(Operation *op, Value value,
+                        const AnalysisState &state) const {
+    // IndexSwitchOps do not have tensor OpOperands. The yielded value can be
+    // any SSA value. This is similar to IfOps.
+    auto switchOp = cast<scf::IndexSwitchOp>(op);
+    int64_t resultNum = cast<OpResult>(value).getResultNumber();
+    AliasingOpOperandList result;
+    for (int64_t i = 0, numCases = switchOp.getNumCases(); i < numCases; ++i) {
+      auto yieldOp =
+          cast<scf::YieldOp>(switchOp.getCaseBlock(i).getTerminator());
+      result.addAlias(AliasingOpOperand(&yieldOp->getOpOperand(resultNum),
+                                        BufferRelation::Equivalent,
+                                        /*isDefinite=*/false));
+    }
+    // The default region also yields a value for every result; its yield
+    // operand aliases with the result as well.
+    auto defaultYieldOp =
+        cast<scf::YieldOp>(switchOp.getDefaultBlock().getTerminator());
+    result.addAlias(AliasingOpOperand(&defaultYieldOp->getOpOperand(resultNum),
+                                      BufferRelation::Equivalent,
+                                      /*isDefinite=*/false));
+    return result;
+  }
+
+  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
+                          const BufferizationOptions &options) const {
+    OpBuilder::InsertionGuard g(rewriter);
+    auto switchOp = cast<scf::IndexSwitchOp>(op);
+
+    // Compute bufferized result types.
+    SmallVector<Type> newTypes;
+    for (Value result : switchOp.getResults()) {
+      if (!isa<TensorType>(result.getType())) {
+        // Non-tensor results are carried over unchanged.
+        newTypes.push_back(result.getType());
+        continue;
+      }
+      auto bufferType = bufferization::getBufferType(result, options);
+      if (failed(bufferType))
+        return failure();
+      newTypes.push_back(*bufferType);
+    }
+
+    // Create a new scf.index_switch op with memref result types.
+    rewriter.setInsertionPoint(switchOp);
+    auto newSwitchOp = rewriter.create<scf::IndexSwitchOp>(
+        switchOp.getLoc(), newTypes, switchOp.getArg(), switchOp.getCases(),
+        switchOp.getCases().size());
+
+    // Move over blocks: all case regions and the default region.
+    for (auto [src, dest] :
+         llvm::zip(switchOp.getCaseRegions(), newSwitchOp.getCaseRegions()))
+      rewriter.inlineRegionBefore(src, dest, dest.begin());
+    rewriter.inlineRegionBefore(switchOp.getDefaultRegion(),
+                                newSwitchOp.getDefaultRegion(),
+                                newSwitchOp.getDefaultRegion().begin());
+
+    // Replace op results.
+    replaceOpWithBufferizedValues(rewriter, op, newSwitchOp->getResults());
+
+    return success();
+  }
+
+  FailureOr<BaseMemRefType>
+  getBufferType(Operation *op, Value value, const BufferizationOptions &options,
+                SmallVector<Value> &invocationStack) const {
+    auto switchOp = cast<scf::IndexSwitchOp>(op);
+    assert(value.getDefiningOp() == op && "invalid value");
+    int64_t resultNum = cast<OpResult>(value).getResultNumber();
+
+    // Helper function to get the buffer type yielded by a given block.
+    auto getYieldedBufferType = [&](Block &b) -> FailureOr<BaseMemRefType> {
+      auto yieldOp = cast<scf::YieldOp>(b.getTerminator());
+      Value yieldedValue = yieldOp->getOperand(resultNum);
+      if (auto bufferType = dyn_cast<BaseMemRefType>(yieldedValue.getType()))
+        return bufferType;
+      auto maybeBufferType =
+          bufferization::getBufferType(yieldedValue, options, invocationStack);
+      if (failed(maybeBufferType))
+        return failure();
+      return maybeBufferType;
+    };
+
+    // Compute buffer type of the default case.
+    auto maybeBufferType = getYieldedBufferType(switchOp.getDefaultBlock());
+    if (failed(maybeBufferType))
+      return failure();
+    BaseMemRefType bufferType = *maybeBufferType;
+
+    // Compute buffer types of all other cases.
+    for (int64_t i = 0, numCases = switchOp.getNumCases(); i < numCases; ++i) {
+      auto yieldedBufferType = getYieldedBufferType(switchOp.getCaseBlock(i));
+      if (failed(yieldedBufferType))
+        return failure();
+
+      // Best case: Both branches have the exact same buffer type.
+      if (bufferType == *yieldedBufferType)
+        continue;
+
+      // Memory space mismatch: cannot reconcile by changing the layout.
+      if (bufferType.getMemorySpace() != yieldedBufferType->getMemorySpace())
+        return op->emitError("inconsistent memory space on switch cases");
+
+      // Layout maps are different: Promote to fully dynamic layout map.
+      bufferType = getMemRefTypeWithFullyDynamicLayout(
+          cast<TensorType>(value.getType()), bufferType.getMemorySpace());
+    }
+
+    return bufferType;
+  }
+};
+
/// Helper function for loop bufferization. Return the indices of all values
/// that have a tensor type.
static DenseSet<int64_t> getTensorIndices(ValueRange values) {
@@ -1005,8 +1116,8 @@ struct YieldOpInterface
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
const BufferizationOptions &options) const {
auto yieldOp = cast<scf::YieldOp>(op);
- if (!isa<scf::ExecuteRegionOp, scf::IfOp, scf::ForOp, scf::WhileOp>(
- yieldOp->getParentOp()))
+ if (!isa<scf::ExecuteRegionOp, scf::IfOp, scf::IndexSwitchOp, scf::ForOp,
+ scf::WhileOp>(yieldOp->getParentOp()))
return yieldOp->emitError("unsupported scf::YieldOp parent");
SmallVector<Value> newResults;
@@ -1018,7 +1129,8 @@ struct YieldOpInterface
return failure();
Value buffer = *maybeBuffer;
// We may have to cast the value before yielding it.
- if (isa<scf::ForOp, scf::IfOp>(yieldOp->getParentOp())) {
+ if (isa<scf::ForOp, scf::IfOp, scf::IndexSwitchOp>(
+ yieldOp->getParentOp())) {
FailureOr<BaseMemRefType> resultType = bufferization::getBufferType(
yieldOp->getParentOp()->getResult(it.index()), options);
if (failed(resultType))
@@ -1217,6 +1329,7 @@ void mlir::scf::registerBufferizableOpInterfaceExternalModels(
ExecuteRegionOp::attachInterface<ExecuteRegionOpInterface>(*ctx);
ForOp::attachInterface<ForOpInterface>(*ctx);
IfOp::attachInterface<IfOpInterface>(*ctx);
+ IndexSwitchOp::attachInterface<IndexSwitchOpInterface>(*ctx);
ForallOp::attachInterface<ForallOpInterface>(*ctx);
InParallelOp::attachInterface<InParallelOpInterface>(*ctx);
WhileOp::attachInterface<WhileOpInterface>(*ctx);
diff --git a/mlir/test/Dialect/SCF/one-shot-bufferize.mlir b/mlir/test/Dialect/SCF/one-shot-bufferize.mlir
index 9b5c0cf048c56f5..1db155b2db38243 100644
--- a/mlir/test/Dialect/SCF/one-shot-bufferize.mlir
+++ b/mlir/test/Dialect/SCF/one-shot-bufferize.mlir
@@ -921,3 +921,34 @@ func.func @elide_copy_of_non_writing_scf_if(%c: i1, %p1: index, %p2: index, %f:
%r3 = tensor.extract %r[%p2] : tensor<10xf32>
return %r2, %r3 : tensor<10xf32>, f32
}
+
+// -----
+
+// Test bufferization of scf.index_switch. The yielded tensors bufferize to
+// memrefs with different layout maps (the alloc_tensor below has an identity
+// layout, the CHECK lines show the arguments as strided memrefs), so the
+// switch result is promoted to a memref with a fully dynamic layout and the
+// alloc is cast accordingly in case 2.
+// CHECK-LABEL: func @index_switch(
+// CHECK-SAME: %[[pred:.*]]: index, %[[b:.*]]: memref<{{.*}}>, %[[c:.*]]: memref<{{.*}}>) -> memref<{{.*}}>
+func.func @index_switch(%pred: index, %b: tensor<5xf32>, %c: tensor<5xf32>) -> tensor<5xf32> {
+ // Throw in a tensor that bufferizes to a different layout map.
+ // CHECK: %[[a:.*]] = memref.alloc() {{.*}} : memref<5xf32>
+ %a = bufferization.alloc_tensor() : tensor<5xf32>
+
+ // CHECK: %[[r:.*]] = scf.index_switch %[[pred]] -> memref<5xf32, strided<[?], offset: ?>>
+ %0 = scf.index_switch %pred -> tensor<5xf32>
+ // CHECK: case 2 {
+ // CHECK: %[[cast:.*]] = memref.cast %[[a]] : memref<5xf32> to memref<5xf32, strided<[?], offset: ?>>
+ // CHECK: scf.yield %[[cast]]
+ case 2 {
+ scf.yield %a: tensor<5xf32>
+ }
+ // CHECK: case 5 {
+ // CHECK: scf.yield %[[b]] : memref<5xf32, strided<[?], offset: ?>>
+ case 5 {
+ scf.yield %b: tensor<5xf32>
+ }
+ // CHECK: default {
+ // CHECK: scf.yield %[[c]] : memref<5xf32, strided<[?], offset: ?>>
+ default {
+ scf.yield %c: tensor<5xf32>
+ }
+ // CHECK: return %[[r]]
+ return %0 : tensor<5xf32>
+}
More information about the Mlir-commits
mailing list