[Mlir-commits] [mlir] [mlir][SCF] Bufferize scf.index_switch (PR #67666)
    llvmlistbot at llvm.org 
    llvmlistbot at llvm.org
       
    Thu Sep 28 05:49:04 PDT 2023
    
    
  
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-scf
<details>
<summary>Changes</summary>
Add the `BufferizableOpInterface` implementation of `scf.index_switch`.
---
Full diff: https://github.com/llvm/llvm-project/pull/67666.diff
2 Files Affected:
- (modified) mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp (+116-3) 
- (modified) mlir/test/Dialect/SCF/one-shot-bufferize.mlir (+31) 
``````````diff
diff --git a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
index bcbc693a9742ccc..dff779170bba733 100644
--- a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -295,6 +295,117 @@ struct IfOpInterface
   }
 };
 
+/// Bufferization of scf.index_switch. Replace with a new scf.index_switch that
+/// yields memrefs. Analogous to scf.if, except that there is an arbitrary
+/// number of case regions plus a default region.
+struct IndexSwitchOpInterface
+    : public BufferizableOpInterface::ExternalModel<IndexSwitchOpInterface,
+                                                    scf::IndexSwitchOp> {
+  AliasingOpOperandList
+  getAliasingOpOperands(Operation *op, Value value,
+                        const AnalysisState &state) const {
+    // IndexSwitchOps do not have tensor OpOperands. The yielded value can be
+    // any SSA value. This is similar to IfOps: the result may alias the value
+    // yielded by any case region *or* by the default region.
+    auto switchOp = cast<scf::IndexSwitchOp>(op);
+    int64_t resultNum = cast<OpResult>(value).getResultNumber();
+    AliasingOpOperandList result;
+    auto addYieldedAlias = [&](Block &block) {
+      auto yieldOp = cast<scf::YieldOp>(block.getTerminator());
+      result.addAlias(AliasingOpOperand(&yieldOp->getOpOperand(resultNum),
+                                        BufferRelation::Equivalent,
+                                        /*isDefinite=*/false));
+    };
+    for (int64_t i = 0, numCases = switchOp.getNumCases(); i < numCases; ++i)
+      addYieldedAlias(switchOp.getCaseBlock(i));
+    addYieldedAlias(switchOp.getDefaultBlock());
+    return result;
+  }
+
+  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
+                          const BufferizationOptions &options) const {
+    OpBuilder::InsertionGuard g(rewriter);
+    auto switchOp = cast<scf::IndexSwitchOp>(op);
+
+    // Compute bufferized result types. Non-tensor results keep their type.
+    SmallVector<Type> newTypes;
+    for (Value result : switchOp.getResults()) {
+      if (!isa<TensorType>(result.getType())) {
+        newTypes.push_back(result.getType());
+        continue;
+      }
+      auto bufferType = bufferization::getBufferType(result, options);
+      if (failed(bufferType))
+        return failure();
+      newTypes.push_back(*bufferType);
+    }
+
+    // Create new op with the same selector and case values.
+    rewriter.setInsertionPoint(switchOp);
+    auto newSwitchOp = rewriter.create<scf::IndexSwitchOp>(
+        switchOp.getLoc(), newTypes, switchOp.getArg(), switchOp.getCases(),
+        switchOp.getCases().size());
+
+    // Move over the blocks of all case regions and of the default region.
+    for (auto [src, dest] :
+         llvm::zip(switchOp.getCaseRegions(), newSwitchOp.getCaseRegions()))
+      rewriter.inlineRegionBefore(src, dest, dest.begin());
+    rewriter.inlineRegionBefore(switchOp.getDefaultRegion(),
+                                newSwitchOp.getDefaultRegion(),
+                                newSwitchOp.getDefaultRegion().begin());
+
+    // Replace op results.
+    replaceOpWithBufferizedValues(rewriter, op, newSwitchOp->getResults());
+
+    return success();
+  }
+
+  FailureOr<BaseMemRefType>
+  getBufferType(Operation *op, Value value, const BufferizationOptions &options,
+                SmallVector<Value> &invocationStack) const {
+    auto switchOp = cast<scf::IndexSwitchOp>(op);
+    assert(value.getDefiningOp() == op && "invalid value");
+    int64_t resultNum = cast<OpResult>(value).getResultNumber();
+
+    // Helper function to get the buffer type of the value yielded by a region.
+    auto getYieldedBufferType = [&](Block &b) -> FailureOr<BaseMemRefType> {
+      auto yieldOp = cast<scf::YieldOp>(b.getTerminator());
+      Value yieldedValue = yieldOp->getOperand(resultNum);
+      if (auto bufferType = dyn_cast<BaseMemRefType>(yieldedValue.getType()))
+        return bufferType;
+      return bufferization::getBufferType(yieldedValue, options,
+                                          invocationStack);
+    };
+
+    // Compute buffer type of the default case.
+    auto maybeBufferType = getYieldedBufferType(switchOp.getDefaultBlock());
+    if (failed(maybeBufferType))
+      return failure();
+    BaseMemRefType bufferType = *maybeBufferType;
+
+    // Compute buffer types of all other cases.
+    for (int64_t i = 0, numCases = switchOp.getNumCases(); i < numCases; ++i) {
+      auto yieldedBufferType = getYieldedBufferType(switchOp.getCaseBlock(i));
+      if (failed(yieldedBufferType))
+        return failure();
+
+      // Best case: Both branches have the exact same buffer type.
+      if (bufferType == *yieldedBufferType)
+        continue;
+
+      // Memory space mismatch.
+      if (bufferType.getMemorySpace() != yieldedBufferType->getMemorySpace())
+        return op->emitError("inconsistent memory space on switch cases");
+
+      // Layout maps are different: Promote to fully dynamic layout map.
+      bufferType = getMemRefTypeWithFullyDynamicLayout(
+          cast<TensorType>(value.getType()), bufferType.getMemorySpace());
+    }
+
+    return bufferType;
+  }
+};
+
 /// Helper function for loop bufferization. Return the indices of all values
 /// that have a tensor type.
 static DenseSet<int64_t> getTensorIndices(ValueRange values) {
@@ -1005,8 +1116,8 @@ struct YieldOpInterface
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                           const BufferizationOptions &options) const {
     auto yieldOp = cast<scf::YieldOp>(op);
-    if (!isa<scf::ExecuteRegionOp, scf::IfOp, scf::ForOp, scf::WhileOp>(
-            yieldOp->getParentOp()))
+    if (!isa<scf::ExecuteRegionOp, scf::IfOp, scf::IndexSwitchOp, scf::ForOp,
+             scf::WhileOp>(yieldOp->getParentOp()))
       return yieldOp->emitError("unsupported scf::YieldOp parent");
 
     SmallVector<Value> newResults;
@@ -1018,7 +1129,8 @@ struct YieldOpInterface
           return failure();
         Value buffer = *maybeBuffer;
         // We may have to cast the value before yielding it.
-        if (isa<scf::ForOp, scf::IfOp>(yieldOp->getParentOp())) {
+        if (isa<scf::ForOp, scf::IfOp, scf::IndexSwitchOp>(
+                yieldOp->getParentOp())) {
           FailureOr<BaseMemRefType> resultType = bufferization::getBufferType(
               yieldOp->getParentOp()->getResult(it.index()), options);
           if (failed(resultType))
@@ -1217,6 +1329,7 @@ void mlir::scf::registerBufferizableOpInterfaceExternalModels(
     ExecuteRegionOp::attachInterface<ExecuteRegionOpInterface>(*ctx);
     ForOp::attachInterface<ForOpInterface>(*ctx);
     IfOp::attachInterface<IfOpInterface>(*ctx);
+    IndexSwitchOp::attachInterface<IndexSwitchOpInterface>(*ctx);
     ForallOp::attachInterface<ForallOpInterface>(*ctx);
     InParallelOp::attachInterface<InParallelOpInterface>(*ctx);
     WhileOp::attachInterface<WhileOpInterface>(*ctx);
diff --git a/mlir/test/Dialect/SCF/one-shot-bufferize.mlir b/mlir/test/Dialect/SCF/one-shot-bufferize.mlir
index 9b5c0cf048c56f5..1db155b2db38243 100644
--- a/mlir/test/Dialect/SCF/one-shot-bufferize.mlir
+++ b/mlir/test/Dialect/SCF/one-shot-bufferize.mlir
@@ -921,3 +921,34 @@ func.func @elide_copy_of_non_writing_scf_if(%c: i1, %p1: index, %p2: index, %f:
   %r3 = tensor.extract %r[%p2] : tensor<10xf32>
   return %r2, %r3 : tensor<10xf32>, f32
 }
+
+// -----
+
+// CHECK-LABEL: func @index_switch(
+//  CHECK-SAME:     %[[pred:.*]]: index, %[[b:.*]]: memref<{{.*}}>, %[[c:.*]]: memref<{{.*}}>) -> memref<{{.*}}>
+func.func @index_switch(%pred: index, %b: tensor<5xf32>, %c: tensor<5xf32>) -> tensor<5xf32> {
+  // %a has an identity layout (unlike %b/%c): result layout becomes fully dynamic.
+  // CHECK: %[[a:.*]] = memref.alloc() {{.*}} : memref<5xf32>
+  %a = bufferization.alloc_tensor() : tensor<5xf32>
+
+  // CHECK: %[[r:.*]] = scf.index_switch %[[pred]] -> memref<5xf32, strided<[?], offset: ?>>
+  %0 = scf.index_switch %pred -> tensor<5xf32>
+  // CHECK: case 2 {
+  // CHECK:   %[[cast:.*]] = memref.cast %[[a]] : memref<5xf32> to memref<5xf32, strided<[?], offset: ?>>
+  // CHECK:   scf.yield %[[cast]]
+  case 2 {
+    scf.yield %a: tensor<5xf32>
+  }
+  // CHECK: case 5 {
+  // CHECK:   scf.yield %[[b]] : memref<5xf32, strided<[?], offset: ?>>
+  case 5 {
+    scf.yield %b: tensor<5xf32>
+  }
+  // CHECK: default {
+  // CHECK:   scf.yield %[[c]] : memref<5xf32, strided<[?], offset: ?>>
+  default {
+    scf.yield %c: tensor<5xf32>
+  }
+  // CHECK: return %[[r]]
+  return %0 : tensor<5xf32>
+}
``````````
</details>
https://github.com/llvm/llvm-project/pull/67666
    
    
More information about the Mlir-commits
mailing list