[Mlir-commits] [mlir] [mlir][SCF] Bufferize scf.index_switch (PR #67666)

Matthias Springer llvmlistbot at llvm.org
Thu Sep 28 08:21:08 PDT 2023


================
@@ -295,6 +295,117 @@ struct IfOpInterface
   }
 };
 
+/// Bufferization of scf.index_switch. Replace with a new scf.index_switch that
+/// yields memrefs.
+struct IndexSwitchOpInterface
+    : public BufferizableOpInterface::ExternalModel<IndexSwitchOpInterface,
+                                                    scf::IndexSwitchOp> {
+  /// Return the scf.yield operands that may alias the given result: the
+  /// corresponding operand of the terminator of every case region AND of the
+  /// default region. None of them is definite, because which region executes
+  /// depends on the switch argument.
+  AliasingOpOperandList
+  getAliasingOpOperands(Operation *op, Value value,
+                        const AnalysisState &state) const {
+    // IndexSwitchOps do not have tensor OpOperands. The yielded value can be
+    // any SSA value. This is similar to IfOps.
+    auto switchOp = cast<scf::IndexSwitchOp>(op);
+    int64_t resultNum = cast<OpResult>(value).getResultNumber();
+    AliasingOpOperandList result;
+    // Record the yield operand of the given region block as a possible alias.
+    auto addYieldAlias = [&](Block &block) {
+      auto yieldOp = cast<scf::YieldOp>(block.getTerminator());
+      result.addAlias(AliasingOpOperand(&yieldOp->getOpOperand(resultNum),
+                                        BufferRelation::Equivalent,
+                                        /*isDefinite=*/false));
+    };
+    for (int64_t i = 0, numCases = switchOp.getNumCases(); i < numCases; ++i)
+      addYieldAlias(switchOp.getCaseBlock(i));
+    // The default region is not one of the case blocks; its yielded value
+    // must also be reported, otherwise it escapes the alias analysis.
+    addYieldAlias(switchOp.getDefaultBlock());
+    return result;
+  }
+
+  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
+                          const BufferizationOptions &options) const {
+    OpBuilder::InsertionGuard g(rewriter);
+    auto switchOp = cast<scf::IndexSwitchOp>(op);
+
+    // Compute bufferized result types.
+    SmallVector<Type> newTypes;
+    for (Value result : switchOp.getResults()) {
+      if (!isa<TensorType>(result.getType())) {
+        newTypes.push_back(result.getType());
+        continue;
+      }
+      auto bufferType = bufferization::getBufferType(result, options);
+      if (failed(bufferType))
+        return failure();
+      newTypes.push_back(*bufferType);
+    }
+
+    // Create new op.
+    rewriter.setInsertionPoint(switchOp);
+    auto newSwitchOp = rewriter.create<scf::IndexSwitchOp>(
+        switchOp.getLoc(), newTypes, switchOp.getArg(), switchOp.getCases(),
+        switchOp.getCases().size());
----------------
matthias-springer wrote:

We generally do not preserve custom attributes. Most transformations just drop them.

https://github.com/llvm/llvm-project/pull/67666


More information about the Mlir-commits mailing list