[Mlir-commits] [mlir] 325426d - [mlir][MemRef] Introduce a memref.extract_metadata op.

Nicolas Vasilache llvmlistbot at llvm.org
Fri Aug 26 09:09:25 PDT 2022


Author: Nicolas Vasilache
Date: 2022-08-26T09:09:15-07:00
New Revision: 325426d72ce50c35e52ce801dcbfabc4a5a2afb3

URL: https://github.com/llvm/llvm-project/commit/325426d72ce50c35e52ce801dcbfabc4a5a2afb3
DIFF: https://github.com/llvm/llvm-project/commit/325426d72ce50c35e52ce801dcbfabc4a5a2afb3.diff

LOG: [mlir][MemRef] Introduce a memref.extract_metadata op.

This is the counterpart of `memref.reinterpret_cast` and is useful to lift
strided memref manipulation out of the LLVM dialect.

Discussion: https://discourse.llvm.org/t/extracting-dynamic-offsets-strides-from-memref/64170

Differential Revision: https://reviews.llvm.org/D132243
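To make the "counterpart of `memref.reinterpret_cast`" relationship concrete, here is a small Python model. It is not MLIR code. It assumes the usual strided addressing rule: element `(i0, ..., in-1)` of a strided memref lives at `buffer[offset + sum(ik * stride_k)]`. The names `StridedMemRef`, `extract_strided_metadata` and `reinterpret_cast` are hypothetical stand-ins for the ops. Also, the real op returns the base as a rank-0 memref, not a Python tuple.

```python
from dataclasses import dataclass
from typing import List, Sequence, Tuple


@dataclass(frozen=True)
class StridedMemRef:
    """Hypothetical model of a strided memref: a flat buffer plus metadata."""
    buffer: Tuple[float, ...]  # stands in for the rank-0 base buffer
    offset: int
    sizes: Tuple[int, ...]
    strides: Tuple[int, ...]

    def load(self, indices: Sequence[int]) -> float:
        # Strided addressing: offset + sum(index_k * stride_k).
        assert len(indices) == len(self.sizes)
        for i, size in zip(indices, self.sizes):
            assert 0 <= i < size, "index out of bounds"
        linear = self.offset + sum(i * s for i, s in zip(indices, self.strides))
        return self.buffer[linear]


def extract_strided_metadata(m: StridedMemRef):
    """Models memref.extract_strided_metadata: (base, offset, sizes, strides)."""
    return m.buffer, m.offset, list(m.sizes), list(m.strides)


def reinterpret_cast(base, offset: int, sizes: List[int],
                     strides: List[int]) -> StridedMemRef:
    """Models memref.reinterpret_cast: rebuild a view from explicit metadata."""
    return StridedMemRef(tuple(base), offset, tuple(sizes), tuple(strides))


# A 2x3 row-major view starting at element 1 of a 7-element buffer.
m = StridedMemRef(tuple(float(x) for x in range(7)), offset=1,
                  sizes=(2, 3), strides=(3, 1))
base, offset, sizes, strides = extract_strided_metadata(m)
m2 = reinterpret_cast(base, offset, sizes, strides)
assert m2 == m          # the round trip folds back to the original view
print(m2.load((1, 2)))  # offset 1 + 1*3 + 2*1 = 6, so this prints 6.0
```

The round trip returns an identical view. This is the folding opportunity that the op description mentions for `%m2` and `%memref`.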

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
    mlir/test/Dialect/MemRef/ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
index 082f4bfa0a764..5ef8d8fcefb5a 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -693,6 +693,71 @@ def MemRef_DmaWaitOp : MemRef_Op<"dma_wait"> {
   let hasVerifier = 1;
 }
 
+//===----------------------------------------------------------------------===//
+// ExtractStridedMetadataOp
+//===----------------------------------------------------------------------===//
+
+def MemRef_ExtractStridedMetadataOp : MemRef_Op<"extract_strided_metadata",
+    [SameVariadicResultSize]> {
+  let summary = "Extracts a buffer base with offset, sizes and strides";
+  let description = [{
+    Extracts a base buffer, offset, sizes and strides. This op allows additional
+    layers of transformations and foldings to be added as lowering progresses
+    from higher-level dialects to lower-level dialects such as the LLVM dialect.
+
+    The op requires a strided memref source operand. If the source operand is not
+    a strided memref, then verification fails.
+
+    This operation also completes the set of metadata accessors alongside the
+    existing `memref.dim` op. Although there is no op to access the strides,
+    offset or base pointer individually, this op composes with its natural
+    complement, `memref.reinterpret_cast`.
+
+    Intended Use Cases:
+
+    The main use case is to expose the logic for manipulating memref metadata at
+    a higher level than the LLVM dialect.
+    This makes lowering more progressive and brings the following benefits:
+      - not all users of MLIR want to lower to LLVM, and the information needed
+        to lower to library calls (e.g. libxsmm) or to SPIR-V was unavailable.
+      - foldings and canonicalizations can happen at a higher level in MLIR:
+        before this op existed, lowering to LLVM would create large amounts of
+        LLVM IR. Even when LLVM folds the low-level IR well from a performance
+        perspective, sending such unkempt IR to LLVM is unnecessarily opaque
+        and inefficient.
+
+    Example:
+
+    ```mlir
+      %base, %offset, %sizes:2, %strides:2 =
+        memref.extract_strided_metadata %memref :
+          memref<10x?xf32> -> memref<f32>, index, index, index, index, index
+
+      // After folding, the type of %m2 can be memref<10x?xf32> and further
+      // folded to %memref.
+      %m2 = memref.reinterpret_cast %base to
+          offset: [%offset],
+          sizes: [%sizes#0, %sizes#1],
+          strides: [%strides#0, %strides#1]
+        : memref<f32> to memref<?x?xf32, offset: ?, strides: [?, ?]>
+    ```
+  }];
+
+  let arguments = (ins
+    AnyStridedMemRef:$source
+  );
+  let results = (outs
+    AnyStridedMemRefOfRank<0>:$base_buffer,
+    Index:$offset,
+    Variadic<Index>:$sizes,
+    Variadic<Index>:$strides
+  );
+
+  let assemblyFormat = [{
+    $source `:` type($source) `->` type(results) attr-dict
+  }];
+}
+
 //===----------------------------------------------------------------------===//
 // GenericAtomicRMWOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/MemRef/ops.mlir b/mlir/test/Dialect/MemRef/ops.mlir
index 39a15b121034c..eda7c4cb3bf5e 100644
--- a/mlir/test/Dialect/MemRef/ops.mlir
+++ b/mlir/test/Dialect/MemRef/ops.mlir
@@ -336,3 +336,20 @@ func.func @generic_atomic_rmw(%I: memref<1x2xf32>, %i : index, %j : index) {
   } { index_attr = 8 : index }
   return
 }
+
+// -----
+
+func.func @extract_strided_metadata(%memref : memref<10x?xf32>)
+    -> memref<?x?xf32, offset: ?, strides: [?, ?]> {
+
+  %base, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %memref
+    : memref<10x?xf32> -> memref<f32>, index, index, index, index, index
+
+  %m2 = memref.reinterpret_cast %base to
+      offset: [%offset],
+      sizes: [%sizes#0, %sizes#1],
+      strides: [%strides#0, %strides#1]
+    : memref<f32> to memref<?x?xf32, offset: ?, strides: [?, ?]>
+
+  return %m2: memref<?x?xf32, offset: ?, strides: [?, ?]>
+}

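The test input `memref<10x?xf32>` uses the default identity layout, which is row-major and contiguous. For such a memref, the extracted offset is 0 and the strides are `[d1, 1]`, where `d1` is the runtime size of the dynamic dimension. The Python sketch below computes these values for an assumed runtime size of 7. `identity_layout_metadata` is a hypothetical helper for illustration, not an MLIR API.

```python
from typing import List, Tuple


def identity_layout_metadata(sizes: List[int]) -> Tuple[int, List[int], List[int]]:
    """Hypothetical helper: offset, sizes and strides of a contiguous
    row-major (identity layout) memref with the given runtime sizes."""
    strides = [0] * len(sizes)
    running = 1
    # The innermost dimension has stride 1; each outer stride is the
    # product of all sizes inside it.
    for k in reversed(range(len(sizes))):
        strides[k] = running
        running *= sizes[k]
    return 0, list(sizes), strides


# memref<10x?xf32> with the dynamic dimension assumed to be 7 at runtime.
offset, sizes, strides = identity_layout_metadata([10, 7])
print(offset, sizes, strides)  # 0 [10, 7] [7, 1]
```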
