[PATCH] D135734: [NFC][mlir][MemRef] Add a builder for `extract_strided_metadata(source)`

Quentin Colombet via Phabricator via llvm-commits llvm-commits at lists.llvm.org
Tue Oct 11 18:03:24 PDT 2022


qcolombet created this revision.
qcolombet added reviewers: nicolasvasilache, chelini.
qcolombet added a project: MLIR.
Herald added subscribers: bzcheeseman, sdasgup3, wenzhicui, wrengr, cota, teijeong, rdzhabarov, tatianashp, msifontes, jurahul, Kayjukh, grosul1, Joonsoo, liufengdb, aartbik, mgester, arpith-jacob, antiagainst, shauheen, rriddle, mehdi_amini.
Herald added a project: All.
qcolombet requested review of this revision.
Herald added a subscriber: stephenneuendorffer.

The new builder infers the number of results and their types directly from the memref type of `source` (its rank, element type, and memory space).
This makes the code easier to write and understand.

Note: although this patch is NFC, I wanted someone to review it, since this is the first time I have written a builder.


Repository:
  rG LLVM Github Monorepo

https://reviews.llvm.org/D135734

Files:
  mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
  mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
  mlir/lib/Dialect/MemRef/Transforms/SimplifyExtractStridedMetadata.cpp


Index: mlir/lib/Dialect/MemRef/Transforms/SimplifyExtractStridedMetadata.cpp
===================================================================
--- mlir/lib/Dialect/MemRef/Transforms/SimplifyExtractStridedMetadata.cpp
+++ mlir/lib/Dialect/MemRef/Transforms/SimplifyExtractStridedMetadata.cpp
@@ -59,16 +59,12 @@
     // Build a plain extract_strided_metadata(memref) from
     // extract_strided_metadata(subview(memref)).
     Location origLoc = op.getLoc();
-    IndexType indexType = rewriter.getIndexType();
     Value source = subview.getSource();
     auto sourceType = source.getType().cast<MemRefType>();
     unsigned sourceRank = sourceType.getRank();
-    SmallVector<Type> sizeStrideTypes(sourceRank, indexType);
 
     auto newExtractStridedMetadata =
-        rewriter.create<memref::ExtractStridedMetadataOp>(
-            origLoc, op.getBaseBuffer().getType(), indexType, sizeStrideTypes,
-            sizeStrideTypes, source);
+        rewriter.create<memref::ExtractStridedMetadataOp>(origLoc, source);
 
     SmallVector<int64_t> sourceStrides;
     int64_t sourceOffset;
@@ -486,16 +482,12 @@
     // Build a plain extract_strided_metadata(memref) from
     // extract_strided_metadata(reassociative_reshape_like(memref)).
     Location origLoc = op.getLoc();
-    IndexType indexType = rewriter.getIndexType();
     Value source = reshape.getSrc();
     auto sourceType = source.getType().cast<MemRefType>();
     unsigned sourceRank = sourceType.getRank();
-    SmallVector<Type> sizeStrideTypes(sourceRank, indexType);
 
     auto newExtractStridedMetadata =
-        rewriter.create<memref::ExtractStridedMetadataOp>(
-            origLoc, op.getBaseBuffer().getType(), indexType, sizeStrideTypes,
-            sizeStrideTypes, source);
+        rewriter.create<memref::ExtractStridedMetadataOp>(origLoc, source);
 
     // Collect statically known information.
     SmallVector<int64_t> strides;
Index: mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
===================================================================
--- mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -1240,6 +1240,20 @@
 // ExtractStridedMetadataOp
 //===----------------------------------------------------------------------===//
 
+void ExtractStridedMetadataOp::build(mlir::OpBuilder &builder,
+                                     mlir::OperationState &state,
+                                     Value source) {
+  auto sourceType = source.getType().cast<MemRefType>();
+  unsigned sourceRank = sourceType.getRank();
+  IndexType indexType = builder.getIndexType();
+  SmallVector<Type> sizeStrideTypes(sourceRank, indexType);
+  auto memrefType =
+      MemRefType::get({}, sourceType.getElementType(),
+                      MemRefLayoutAttrInterface{}, sourceType.getMemorySpace());
+  ExtractStridedMetadataOp::build(builder, state, memrefType, indexType,
+                                  sizeStrideTypes, sizeStrideTypes, source);
+}
+
 void ExtractStridedMetadataOp::getAsmResultNames(
     function_ref<void(Value, StringRef)> setNameFn) {
   setNameFn(getBaseBuffer(), "base_buffer");
Index: mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
===================================================================
--- mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -905,6 +905,11 @@
     Variadic<Index>:$strides
   );
 
+  // Build `extract_strided_metadata(source)`.
+  // The number and type of the results are inferred from the
+  // shape of the source.
+  let builders = [OpBuilder<(ins "Value":$source)>];
+
   let assemblyFormat = [{
     $source `:` type($source) `->` type(results) attr-dict
   }];


-------------- next part --------------
A non-text attachment was scrubbed...
Name: D135734.466986.patch
Type: text/x-patch
Size: 3709 bytes
Desc: not available
URL: <http://lists.llvm.org/pipermail/llvm-commits/attachments/20221012/7a453802/attachment.bin>


More information about the llvm-commits mailing list