[Mlir-commits] [mlir] bf9d8d9 - [mlir][linalg][bufferize][NFC] Rename functions in BufferizationState

Matthias Springer llvmlistbot at llvm.org
Thu Jan 6 12:29:13 PST 2022


Author: Matthias Springer
Date: 2022-01-07T05:28:58+09:00
New Revision: bf9d8d9dfb8f4a0d326a08f52917c008574d60f8

URL: https://github.com/llvm/llvm-project/commit/bf9d8d9dfb8f4a0d326a08f52917c008574d60f8
DIFF: https://github.com/llvm/llvm-project/commit/bf9d8d9dfb8f4a0d326a08f52917c008574d60f8.diff

LOG: [mlir][linalg][bufferize][NFC] Rename functions in BufferizationState

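The old function names (e.g., `replaceOp`) could be confusing to users because they sound similar to rewriter functions but have slightly different semantics. They are therefore renamed to `replaceOpWithBufferizedValues` and `replaceOpWithNewBufferizedOp` and turned into free functions instead of `BufferizationState` members.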

Differential Revision: https://reviews.llvm.org/D116449

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
    mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ArithInterfaceImpl.cpp
    mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
    mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp
    mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
    mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp
    mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp
    mlir/lib/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
index 80278fc220116..ccc939ffe030d 100644
--- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
@@ -380,22 +380,6 @@ class BufferizationState {
   /// Creates a memcpy between two given buffers.
   void createMemCpy(OpBuilder &b, Location loc, Value from, Value to) const;
 
-  /// Replace an op with replacement values. The op is deleted. Tensor OpResults
-  /// must be replaced with memref values.
-  void replaceOp(RewriterBase &rewriter, Operation *op,
-                 ValueRange values) const;
-
-  /// Replace an op with a new op. Tensor OpResults must be replaced with memref
-  /// values.
-  template <typename OpTy, typename... Args>
-  OpTy replaceOpWithNewOp(RewriterBase &rewriter, Operation *op,
-                          Args &&...args) const {
-    Operation *newOp =
-        rewriter.create<OpTy>(op->getLoc(), std::forward<Args>(args)...);
-    replaceOp(rewriter, op, newOp->getResults());
-    return cast<OpTy>(newOp);
-  }
-
   /// Lookup the memref buffer that is associated to the given tensor value.
   /// Asserts if no buffer is associated.
   Value lookupBuffer(RewriterBase &rewriter, Value tensor) const;
@@ -443,6 +427,21 @@ class BufferizationState {
   const BufferizationOptions &options;
 };
 
+/// Replace an op with replacement values. The op is deleted. Tensor OpResults
+/// must be replaced with memref values.
+void replaceOpWithBufferizedValues(RewriterBase &rewriter, Operation *op,
+                                   ValueRange values);
+
+/// Replace an op with a new op. Tensor OpResults must be replaced with memref
+/// values.
+template <typename OpTy, typename... Args>
+OpTy replaceOpWithNewBufferizedOp(RewriterBase &rewriter, Operation *op,
+                                  Args &&...args) {
+  auto newOp = rewriter.create<OpTy>(op->getLoc(), std::forward<Args>(args)...);
+  replaceOpWithBufferizedValues(rewriter, op, newOp->getResults());
+  return newOp;
+}
+
 /// Return a contiguous MemRefType (i.e. with canonical/empty layout map)
 /// with the same shape as `shapedType` and specified `layout` and
 /// `addressSpace`.

diff  --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ArithInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ArithInterfaceImpl.cpp
index 8474c127b1206..40d54445c5e82 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ArithInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ArithInterfaceImpl.cpp
@@ -35,7 +35,7 @@ struct ConstantOpInterface
 
     GlobalCreator globalCreator(moduleOp);
     auto globalMemref = globalCreator.getGlobalFor(constantOp);
-    state.replaceOpWithNewOp<memref::GetGlobalOp>(
+    replaceOpWithNewBufferizedOp<memref::GetGlobalOp>(
         rewriter, op, globalMemref.type(), globalMemref.getName());
     return success();
   }

diff  --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
index 404bb457b20b6..785c49e5f985e 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
@@ -422,8 +422,8 @@ Value mlir::linalg::comprehensive_bufferize::BufferizationState::
   return operandBuffer;
 }
 
-void mlir::linalg::comprehensive_bufferize::BufferizationState::replaceOp(
-    RewriterBase &rewriter, Operation *op, ValueRange values) const {
+void mlir::linalg::comprehensive_bufferize::replaceOpWithBufferizedValues(
+    RewriterBase &rewriter, Operation *op, ValueRange values) {
   OpBuilder::InsertionGuard g(rewriter);
 
   // Replace all OpResults with the given values.

diff  --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp
index eb3a52ac3bb57..9922130920dec 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp
@@ -67,7 +67,7 @@ static LogicalResult bufferizeLinalgOp(RewriterBase &rewriter, LinalgOp op,
   op.clone(rewriter, op.getLoc(), /*resultTypes=*/TypeRange{}, newOperands);
 
   // Replace the results of the old op with the new output buffers.
-  state.replaceOp(rewriter, op, newOutputBuffers);
+  replaceOpWithBufferizedValues(rewriter, op, newOutputBuffers);
 
   return success();
 }
@@ -201,7 +201,7 @@ struct InitTensorOpInterface
 
     Value alloc = state.createAllocDeallocPair(rewriter, initTensorOp->getLoc(),
                                                initTensorOp.result());
-    state.replaceOp(rewriter, op, alloc);
+    replaceOpWithBufferizedValues(rewriter, op, alloc);
     return success();
   }
 };
@@ -342,7 +342,7 @@ struct TiledLoopOpInterface
     rewriter.eraseOp(oldTerminator);
 
     // Replace results and delete old op.
-    state.replaceOp(rewriter, op, newResults);
+    replaceOpWithBufferizedValues(rewriter, op, newResults);
 
     return success();
   }

diff  --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
index 09cb011a13ab5..7d9a5648b1284 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
@@ -634,7 +634,7 @@ struct CallOpInterface
     }
 
     // 5. Replace the old op with the new op.
-    state.replaceOp(rewriter, callOp, replacementValues);
+    replaceOpWithBufferizedValues(rewriter, callOp, replacementValues);
 
     return success();
   }

diff  --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp
index 9a81259466c50..1c4185c8ffef9 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp
@@ -192,7 +192,7 @@ struct IfOpInterface
     }
 
     // Replace op results.
-    state.replaceOp(rewriter, op, newIfOp->getResults());
+    replaceOpWithBufferizedValues(rewriter, op, newIfOp->getResults());
 
     return success();
   }
@@ -326,7 +326,7 @@ struct ForOpInterface
     yieldOp.getResultsMutable().assign(yieldValues);
 
     // Replace loop results.
-    state.replaceOp(rewriter, op, newForOp->getResults());
+    replaceOpWithBufferizedValues(rewriter, op, newForOp->getResults());
 
     return success();
   }

diff  --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp
index 550e585e4736a..894b5c60c4e16 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp
@@ -68,8 +68,8 @@ struct CastOpInterface
             : MemRefLayoutAttrInterface();
     Type memRefType = getContiguousOrUnrankedMemRefType(
         castOp.getResult().getType(), layout, memorySpace);
-    state.replaceOpWithNewOp<memref::CastOp>(rewriter, op, memRefType,
-                                             resultBuffer);
+    replaceOpWithNewBufferizedOp<memref::CastOp>(rewriter, op, memRefType,
+                                                 resultBuffer);
     return success();
   }
 };
@@ -98,7 +98,7 @@ struct DimOpInterface
     if (!dimOp.source().getType().isa<RankedTensorType>())
       return dimOp.emitError("unranked tensor not supported");
     Value v = state.lookupBuffer(rewriter, dimOp.source());
-    state.replaceOpWithNewOp<memref::DimOp>(rewriter, op, v, dimOp.index());
+    replaceOpWithNewBufferizedOp<memref::DimOp>(rewriter, op, v, dimOp.index());
     return success();
   }
 };
@@ -164,7 +164,7 @@ struct ExtractSliceOpInterface
       subView = alloc;
     }
 
-    state.replaceOp(rewriter, op, subView);
+    replaceOpWithBufferizedValues(rewriter, op, subView);
     return success();
   }
 };
@@ -191,8 +191,8 @@ struct ExtractOpInterface
                           const BufferizationState &state) const {
     auto extractOp = cast<tensor::ExtractOp>(op);
     Value srcMemref = state.lookupBuffer(rewriter, extractOp.tensor());
-    state.replaceOpWithNewOp<memref::LoadOp>(rewriter, op, srcMemref,
-                                             extractOp.indices());
+    replaceOpWithNewBufferizedOp<memref::LoadOp>(rewriter, op, srcMemref,
+                                                 extractOp.indices());
     return success();
   }
 };
@@ -231,7 +231,7 @@ struct InsertOpInterface
         state.getResultBuffer(rewriter, insertOp->getOpResult(0));
     rewriter.create<memref::StoreOp>(loc, insertOp.scalar(), destMemref,
                                      insertOp.indices());
-    state.replaceOp(rewriter, op, destMemref);
+    replaceOpWithBufferizedValues(rewriter, op, destMemref);
     return success();
   }
 
@@ -413,7 +413,7 @@ struct InsertSliceOpInterface
     Value srcMemref = state.lookupBuffer(rewriter, insertSliceOp.source());
     state.createMemCpy(rewriter, insertSliceOp.getLoc(), srcMemref, subView);
 
-    state.replaceOp(rewriter, op, dstMemref);
+    replaceOpWithBufferizedValues(rewriter, op, dstMemref);
     return success();
   }
 };

diff  --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.cpp
index 0d66e8879563c..d4c57617b004b 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.cpp
@@ -47,11 +47,10 @@ struct TransferReadOpInterface
 
     // TransferReadOp always reads from the bufferized op.source().
     Value buffer = state.lookupBuffer(rewriter, readOp.source());
-    Value read = rewriter.create<vector::TransferReadOp>(
-        readOp.getLoc(), readOp.getVectorType(), buffer, readOp.indices(),
+    replaceOpWithNewBufferizedOp<vector::TransferReadOp>(
+        rewriter, readOp, readOp.getVectorType(), buffer, readOp.indices(),
         readOp.permutation_map(), readOp.padding(), readOp.mask(),
         readOp.in_boundsAttr());
-    state.replaceOp(rewriter, op, read);
     return success();
   }
 };
@@ -101,7 +100,7 @@ struct TransferWriteOpInterface
     rewriter.create<vector::TransferWriteOp>(
         writeOp.getLoc(), writeOp.vector(), resultBuffer, writeOp.indices(),
         writeOp.permutation_mapAttr(), writeOp.in_boundsAttr());
-    state.replaceOp(rewriter, op, resultBuffer);
+    replaceOpWithBufferizedValues(rewriter, op, resultBuffer);
 
     return success();
   }


        


More information about the Mlir-commits mailing list