[Mlir-commits] [mlir] bf9d8d9 - [mlir][linalg][bufferize][NFC] Rename functions in BufferizationState
Matthias Springer
llvmlistbot at llvm.org
Thu Jan 6 12:29:13 PST 2022
Author: Matthias Springer
Date: 2022-01-07T05:28:58+09:00
New Revision: bf9d8d9dfb8f4a0d326a08f52917c008574d60f8
URL: https://github.com/llvm/llvm-project/commit/bf9d8d9dfb8f4a0d326a08f52917c008574d60f8
DIFF: https://github.com/llvm/llvm-project/commit/bf9d8d9dfb8f4a0d326a08f52917c008574d60f8.diff
LOG: [mlir][linalg][bufferize][NFC] Rename functions in BufferizationState
The old function names (e.g., `replaceOp`) could be confusing to users because they sound like rewriter functions but have slightly different semantics.
Differential Revision: https://reviews.llvm.org/D116449
Added:
Modified:
mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ArithInterfaceImpl.cpp
mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp
mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp
mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp
mlir/lib/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
index 80278fc220116..ccc939ffe030d 100644
--- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
@@ -380,22 +380,6 @@ class BufferizationState {
/// Creates a memcpy between two given buffers.
void createMemCpy(OpBuilder &b, Location loc, Value from, Value to) const;
- /// Replace an op with replacement values. The op is deleted. Tensor OpResults
- /// must be replaced with memref values.
- void replaceOp(RewriterBase &rewriter, Operation *op,
- ValueRange values) const;
-
- /// Replace an op with a new op. Tensor OpResults must be replaced with memref
- /// values.
- template <typename OpTy, typename... Args>
- OpTy replaceOpWithNewOp(RewriterBase &rewriter, Operation *op,
- Args &&...args) const {
- Operation *newOp =
- rewriter.create<OpTy>(op->getLoc(), std::forward<Args>(args)...);
- replaceOp(rewriter, op, newOp->getResults());
- return cast<OpTy>(newOp);
- }
-
/// Lookup the memref buffer that is associated to the given tensor value.
/// Asserts if no buffer is associated.
Value lookupBuffer(RewriterBase &rewriter, Value tensor) const;
@@ -443,6 +427,21 @@ class BufferizationState {
const BufferizationOptions &options;
};
+/// Replace an op with replacement values. The op is deleted. Tensor OpResults
+/// must be replaced with memref values.
+void replaceOpWithBufferizedValues(RewriterBase &rewriter, Operation *op,
+ ValueRange values);
+
+/// Replace an op with a new op. Tensor OpResults must be replaced with memref
+/// values.
+template <typename OpTy, typename... Args>
+OpTy replaceOpWithNewBufferizedOp(RewriterBase &rewriter, Operation *op,
+ Args &&...args) {
+ auto newOp = rewriter.create<OpTy>(op->getLoc(), std::forward<Args>(args)...);
+ replaceOpWithBufferizedValues(rewriter, op, newOp->getResults());
+ return newOp;
+}
+
/// Return a contiguous MemRefType (i.e. with canonical/empty layout map)
/// with the same shape as `shapedType` and specified `layout` and
/// `addressSpace`.
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ArithInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ArithInterfaceImpl.cpp
index 8474c127b1206..40d54445c5e82 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ArithInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ArithInterfaceImpl.cpp
@@ -35,7 +35,7 @@ struct ConstantOpInterface
GlobalCreator globalCreator(moduleOp);
auto globalMemref = globalCreator.getGlobalFor(constantOp);
- state.replaceOpWithNewOp<memref::GetGlobalOp>(
+ replaceOpWithNewBufferizedOp<memref::GetGlobalOp>(
rewriter, op, globalMemref.type(), globalMemref.getName());
return success();
}
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
index 404bb457b20b6..785c49e5f985e 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
@@ -422,8 +422,8 @@ Value mlir::linalg::comprehensive_bufferize::BufferizationState::
return operandBuffer;
}
-void mlir::linalg::comprehensive_bufferize::BufferizationState::replaceOp(
- RewriterBase &rewriter, Operation *op, ValueRange values) const {
+void mlir::linalg::comprehensive_bufferize::replaceOpWithBufferizedValues(
+ RewriterBase &rewriter, Operation *op, ValueRange values) {
OpBuilder::InsertionGuard g(rewriter);
// Replace all OpResults with the given values.
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp
index eb3a52ac3bb57..9922130920dec 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp
@@ -67,7 +67,7 @@ static LogicalResult bufferizeLinalgOp(RewriterBase &rewriter, LinalgOp op,
op.clone(rewriter, op.getLoc(), /*resultTypes=*/TypeRange{}, newOperands);
// Replace the results of the old op with the new output buffers.
- state.replaceOp(rewriter, op, newOutputBuffers);
+ replaceOpWithBufferizedValues(rewriter, op, newOutputBuffers);
return success();
}
@@ -201,7 +201,7 @@ struct InitTensorOpInterface
Value alloc = state.createAllocDeallocPair(rewriter, initTensorOp->getLoc(),
initTensorOp.result());
- state.replaceOp(rewriter, op, alloc);
+ replaceOpWithBufferizedValues(rewriter, op, alloc);
return success();
}
};
@@ -342,7 +342,7 @@ struct TiledLoopOpInterface
rewriter.eraseOp(oldTerminator);
// Replace results and delete old op.
- state.replaceOp(rewriter, op, newResults);
+ replaceOpWithBufferizedValues(rewriter, op, newResults);
return success();
}
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
index 09cb011a13ab5..7d9a5648b1284 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
@@ -634,7 +634,7 @@ struct CallOpInterface
}
// 5. Replace the old op with the new op.
- state.replaceOp(rewriter, callOp, replacementValues);
+ replaceOpWithBufferizedValues(rewriter, callOp, replacementValues);
return success();
}
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp
index 9a81259466c50..1c4185c8ffef9 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp
@@ -192,7 +192,7 @@ struct IfOpInterface
}
// Replace op results.
- state.replaceOp(rewriter, op, newIfOp->getResults());
+ replaceOpWithBufferizedValues(rewriter, op, newIfOp->getResults());
return success();
}
@@ -326,7 +326,7 @@ struct ForOpInterface
yieldOp.getResultsMutable().assign(yieldValues);
// Replace loop results.
- state.replaceOp(rewriter, op, newForOp->getResults());
+ replaceOpWithBufferizedValues(rewriter, op, newForOp->getResults());
return success();
}
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp
index 550e585e4736a..894b5c60c4e16 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp
@@ -68,8 +68,8 @@ struct CastOpInterface
: MemRefLayoutAttrInterface();
Type memRefType = getContiguousOrUnrankedMemRefType(
castOp.getResult().getType(), layout, memorySpace);
- state.replaceOpWithNewOp<memref::CastOp>(rewriter, op, memRefType,
- resultBuffer);
+ replaceOpWithNewBufferizedOp<memref::CastOp>(rewriter, op, memRefType,
+ resultBuffer);
return success();
}
};
@@ -98,7 +98,7 @@ struct DimOpInterface
if (!dimOp.source().getType().isa<RankedTensorType>())
return dimOp.emitError("unranked tensor not supported");
Value v = state.lookupBuffer(rewriter, dimOp.source());
- state.replaceOpWithNewOp<memref::DimOp>(rewriter, op, v, dimOp.index());
+ replaceOpWithNewBufferizedOp<memref::DimOp>(rewriter, op, v, dimOp.index());
return success();
}
};
@@ -164,7 +164,7 @@ struct ExtractSliceOpInterface
subView = alloc;
}
- state.replaceOp(rewriter, op, subView);
+ replaceOpWithBufferizedValues(rewriter, op, subView);
return success();
}
};
@@ -191,8 +191,8 @@ struct ExtractOpInterface
const BufferizationState &state) const {
auto extractOp = cast<tensor::ExtractOp>(op);
Value srcMemref = state.lookupBuffer(rewriter, extractOp.tensor());
- state.replaceOpWithNewOp<memref::LoadOp>(rewriter, op, srcMemref,
- extractOp.indices());
+ replaceOpWithNewBufferizedOp<memref::LoadOp>(rewriter, op, srcMemref,
+ extractOp.indices());
return success();
}
};
@@ -231,7 +231,7 @@ struct InsertOpInterface
state.getResultBuffer(rewriter, insertOp->getOpResult(0));
rewriter.create<memref::StoreOp>(loc, insertOp.scalar(), destMemref,
insertOp.indices());
- state.replaceOp(rewriter, op, destMemref);
+ replaceOpWithBufferizedValues(rewriter, op, destMemref);
return success();
}
@@ -413,7 +413,7 @@ struct InsertSliceOpInterface
Value srcMemref = state.lookupBuffer(rewriter, insertSliceOp.source());
state.createMemCpy(rewriter, insertSliceOp.getLoc(), srcMemref, subView);
- state.replaceOp(rewriter, op, dstMemref);
+ replaceOpWithBufferizedValues(rewriter, op, dstMemref);
return success();
}
};
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.cpp
index 0d66e8879563c..d4c57617b004b 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.cpp
@@ -47,11 +47,10 @@ struct TransferReadOpInterface
// TransferReadOp always reads from the bufferized op.source().
Value buffer = state.lookupBuffer(rewriter, readOp.source());
- Value read = rewriter.create<vector::TransferReadOp>(
- readOp.getLoc(), readOp.getVectorType(), buffer, readOp.indices(),
+ replaceOpWithNewBufferizedOp<vector::TransferReadOp>(
+ rewriter, readOp, readOp.getVectorType(), buffer, readOp.indices(),
readOp.permutation_map(), readOp.padding(), readOp.mask(),
readOp.in_boundsAttr());
- state.replaceOp(rewriter, op, read);
return success();
}
};
@@ -101,7 +100,7 @@ struct TransferWriteOpInterface
rewriter.create<vector::TransferWriteOp>(
writeOp.getLoc(), writeOp.vector(), resultBuffer, writeOp.indices(),
writeOp.permutation_mapAttr(), writeOp.in_boundsAttr());
- state.replaceOp(rewriter, op, resultBuffer);
+ replaceOpWithBufferizedValues(rewriter, op, resultBuffer);
return success();
}
More information about the Mlir-commits
mailing list