[PATCH] D135734: [NFC][mlir][MemRef] Add a builder for `extract_strided_metadata(source)`
Quentin Colombet via Phabricator via llvm-commits
llvm-commits at lists.llvm.org
Tue Oct 11 18:03:24 PDT 2022
qcolombet created this revision.
qcolombet added reviewers: nicolasvasilache, chelini.
qcolombet added a project: MLIR.
Herald added subscribers: bzcheeseman, sdasgup3, wenzhicui, wrengr, cota, teijeong, rdzhabarov, tatianashp, msifontes, jurahul, Kayjukh, grosul1, Joonsoo, liufengdb, aartbik, mgester, arpith-jacob, antiagainst, shauheen, rriddle, mehdi_amini.
Herald added a project: All.
qcolombet requested review of this revision.
Herald added a subscriber: stephenneuendorffer.
The new builder infers the number of results and their types directly from the type of `source` (its rank, element type, and memory space).
This makes the code easier to write and understand.
Note: although this patch is NFC, I would still like someone to review it, since this is the first time I have written a custom builder.
Repository:
rG LLVM Github Monorepo
https://reviews.llvm.org/D135734
Files:
mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
mlir/lib/Dialect/MemRef/Transforms/SimplifyExtractStridedMetadata.cpp
Index: mlir/lib/Dialect/MemRef/Transforms/SimplifyExtractStridedMetadata.cpp
===================================================================
--- mlir/lib/Dialect/MemRef/Transforms/SimplifyExtractStridedMetadata.cpp
+++ mlir/lib/Dialect/MemRef/Transforms/SimplifyExtractStridedMetadata.cpp
@@ -59,16 +59,12 @@
// Build a plain extract_strided_metadata(memref) from
// extract_strided_metadata(subview(memref)).
Location origLoc = op.getLoc();
- IndexType indexType = rewriter.getIndexType();
Value source = subview.getSource();
auto sourceType = source.getType().cast<MemRefType>();
unsigned sourceRank = sourceType.getRank();
- SmallVector<Type> sizeStrideTypes(sourceRank, indexType);
auto newExtractStridedMetadata =
- rewriter.create<memref::ExtractStridedMetadataOp>(
- origLoc, op.getBaseBuffer().getType(), indexType, sizeStrideTypes,
- sizeStrideTypes, source);
+ rewriter.create<memref::ExtractStridedMetadataOp>(origLoc, source);
SmallVector<int64_t> sourceStrides;
int64_t sourceOffset;
@@ -486,16 +482,12 @@
// Build a plain extract_strided_metadata(memref) from
// extract_strided_metadata(reassociative_reshape_like(memref)).
Location origLoc = op.getLoc();
- IndexType indexType = rewriter.getIndexType();
Value source = reshape.getSrc();
auto sourceType = source.getType().cast<MemRefType>();
unsigned sourceRank = sourceType.getRank();
- SmallVector<Type> sizeStrideTypes(sourceRank, indexType);
auto newExtractStridedMetadata =
- rewriter.create<memref::ExtractStridedMetadataOp>(
- origLoc, op.getBaseBuffer().getType(), indexType, sizeStrideTypes,
- sizeStrideTypes, source);
+ rewriter.create<memref::ExtractStridedMetadataOp>(origLoc, source);
// Collect statically known information.
SmallVector<int64_t> strides;
Index: mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
===================================================================
--- mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -1240,6 +1240,20 @@
// ExtractStridedMetadataOp
//===----------------------------------------------------------------------===//
+ void ExtractStridedMetadataOp::build(mlir::OpBuilder &builder,
+ mlir::OperationState &state,
+ Value source) {
+ // All result types come from `source`: a 0-D base buffer (same element type
+ // and memory space), one index result, then `rank` index sizes and strides.
+ auto srcMemRefTy = source.getType().cast<MemRefType>();
+ IndexType idxTy = builder.getIndexType();
+ auto baseBufferTy = MemRefType::get(
+ {}, srcMemRefTy.getElementType(), MemRefLayoutAttrInterface{},
+ srcMemRefTy.getMemorySpace());
+ SmallVector<Type> perDimTypes(srcMemRefTy.getRank(), idxTy);
+ build(builder, state, baseBufferTy, idxTy, perDimTypes, perDimTypes, source);
+ }
+
void ExtractStridedMetadataOp::getAsmResultNames(
function_ref<void(Value, StringRef)> setNameFn) {
setNameFn(getBaseBuffer(), "base_buffer");
Index: mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
===================================================================
--- mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -905,6 +905,11 @@
Variadic<Index>:$strides
);
+ // Build `extract_strided_metadata(source)`.
+ // The number and type of the results are inferred from the
+ // shape of the source.
+ let builders = [OpBuilder<(ins "Value":$source)>];
+
let assemblyFormat = [{
$source `:` type($source) `->` type(results) attr-dict
}];
-------------- next part --------------
A non-text attachment was scrubbed...
Name: D135734.466986.patch
Type: text/x-patch
Size: 3709 bytes
Desc: not available
URL: <http://lists.llvm.org/pipermail/llvm-commits/attachments/20221012/7a453802/attachment.bin>
More information about the llvm-commits
mailing list