[Mlir-commits] [mlir] 19efb84 - [mlir][shape][bufferize][NFC] Bufferize block terminators separately

Matthias Springer llvmlistbot at llvm.org
Mon Jun 27 04:08:24 PDT 2022


Author: Matthias Springer
Date: 2022-06-27T13:08:13+02:00
New Revision: 19efb84c7a03a951cdc54662a21e0fd71a33601c

URL: https://github.com/llvm/llvm-project/commit/19efb84c7a03a951cdc54662a21e0fd71a33601c
DIFF: https://github.com/llvm/llvm-project/commit/19efb84c7a03a951cdc54662a21e0fd71a33601c.diff

LOG: [mlir][shape][bufferize][NFC] Bufferize block terminators separately

This allows for better type inference during bufferization and is in preparation for supporting memory spaces.

Differential Revision: https://reviews.llvm.org/D128579

Added: 
    

Modified: 
    mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp
index 02a073de45699..9a4f8e187a8b7 100644
--- a/mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -61,42 +61,17 @@ struct AssumingOpInterface
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                           const BufferizationOptions &options) const {
     auto assumingOp = cast<shape::AssumingOp>(op);
-
-    // Compute new result types.
-    SmallVector<Type> newResultTypes;
-    for (Type type : assumingOp->getResultTypes()) {
-      if (auto tensorType = type.dyn_cast<TensorType>()) {
-        // TODO: Infer the result type instead of computing it.
-        newResultTypes.push_back(getMemRefType(tensorType, options));
-      } else {
-        newResultTypes.push_back(type);
-      }
-    }
+    assert(assumingOp.getDoRegion().getBlocks().size() == 1 &&
+           "only 1 block supported");
+    auto yieldOp = cast<shape::AssumingYieldOp>(
+        assumingOp.getDoRegion().front().getTerminator());
 
     // Create new op and move over region.
+    TypeRange newResultTypes(yieldOp.operands());
     auto newOp = rewriter.create<shape::AssumingOp>(
         op->getLoc(), newResultTypes, assumingOp.getWitness());
     newOp.getDoRegion().takeBody(assumingOp.getRegion());
 
-    // Update terminator.
-    assert(newOp.getDoRegion().getBlocks().size() == 1 &&
-           "only 1 block supported");
-    Block *newBlock = &newOp.getDoRegion().front();
-    auto yieldOp = cast<shape::AssumingYieldOp>(newBlock->getTerminator());
-    rewriter.setInsertionPoint(yieldOp);
-    SmallVector<Value> newYieldValues;
-    for (const auto &it : llvm::enumerate(yieldOp.operands())) {
-      Value val = it.value();
-      if (val.getType().isa<TensorType>()) {
-        newYieldValues.push_back(rewriter.create<bufferization::ToMemrefOp>(
-            yieldOp.getLoc(), newResultTypes[it.index()], val));
-      } else {
-        newYieldValues.push_back(val);
-      }
-    }
-    rewriter.replaceOpWithNewOp<shape::AssumingYieldOp>(yieldOp,
-                                                        newYieldValues);
-
     // Update all uses of the old op.
     rewriter.setInsertionPointAfter(newOp);
     SmallVector<Value> newResults;
@@ -153,7 +128,14 @@ struct AssumingYieldOpInterface
 
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                           const BufferizationOptions &options) const {
-    // Op is bufferized as part of AssumingOp.
+    auto yieldOp = cast<shape::AssumingYieldOp>(op);
+    SmallVector<Value> newResults;
+    for (Value value : yieldOp.operands())
+      newResults.push_back(value.getType().isa<TensorType>()
+                               ? getBuffer(rewriter, value, options)
+                               : value);
+    replaceOpWithNewBufferizedOp<shape::AssumingYieldOp>(rewriter, op,
+                                                         newResults);
     return success();
   }
 };


        


More information about the Mlir-commits mailing list