[Mlir-commits] [mlir] 63a2536 - [mlir][MemRef] Simplify extract_strided_metadata(subview)

Quentin Colombet llvmlistbot at llvm.org
Thu Sep 8 10:16:45 PDT 2022


Author: Quentin Colombet
Date: 2022-09-08T17:10:02Z
New Revision: 63a2536f77a4037902be517399cb16b39fb732e7

URL: https://github.com/llvm/llvm-project/commit/63a2536f77a4037902be517399cb16b39fb732e7
DIFF: https://github.com/llvm/llvm-project/commit/63a2536f77a4037902be517399cb16b39fb732e7.diff

LOG: [mlir][MemRef] Simplify extract_strided_metadata(subview)

Add a dedicated pass to simplify
extract_strided_metadata(other_op(memref)).

Currently the pass features only one pattern:
extract_strided_metadata(subview).
The goal is to get rid of the subview while materializing its effects on
the offset, sizes, and strides with respect to the base object.

In other words, this simplification replaces:
```
baseBuffer, offset, sizes, strides =
    extract_strided_metadata(
        subview(memref, subOffset, subSizes, subStrides))
```

With

```
baseBuffer, baseOffset, baseSizes, baseStrides =
    extract_strided_metadata(memref)
strides#i = baseStrides#i * subStrides#i
offset = baseOffset + sum(subOffset#i * baseStrides#i)
sizes = subSizes
```

Differential Revision: https://reviews.llvm.org/D133166

Added: 
    mlir/lib/Dialect/MemRef/Transforms/SimplifyExtractStridedMetadata.cpp
    mlir/test/Dialect/MemRef/simplify-extract-strided-metadata.mlir

Modified: 
    mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h
    mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td
    mlir/include/mlir/IR/AffineExpr.h
    mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h b/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h
index b33ce0e5f2bb3..a5309ddcc42c6 100644
--- a/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h
@@ -55,6 +55,11 @@ void populateResolveRankedShapeTypeResultDimsPatterns(
 /// terms of shapes of its input operands.
 void populateResolveShapedTypeResultDimsPatterns(RewritePatternSet &patterns);
 
+/// Appends patterns that simplify extract_strided_metadata(other_op(memref))
+/// into extract_strided_metadata(memref) plus easier-to-analyze arithmetic.
+void populateSimplifyExtractStridedMetadataOpPatterns(
+    RewritePatternSet &patterns);
+
 /// Transformation to do multi-buffering/array expansion to remove dependencies
 /// on the temporary allocation between consecutive loop iterations.
 /// It return success if the allocation was multi-buffered and returns failure()
@@ -118,6 +123,11 @@ std::unique_ptr<Pass> createResolveRankedShapeTypeResultDimsPass();
 /// in terms of shapes of its input operands.
 std::unique_ptr<Pass> createResolveShapedTypeResultDimsPass();
 
+/// Creates a pass that rewrites
+/// `extract_strided_metadata(other_op(memref))` into
+/// `extract_strided_metadata(memref)` and arithmetic on its results.
+std::unique_ptr<Pass> createSimplifyExtractStridedMetadataPass();
+
 //===----------------------------------------------------------------------===//
 // Registration
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td b/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td
index 5ac124a0fcaa8..64045033cabef 100644
--- a/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td
@@ -173,5 +173,18 @@ def ResolveShapedTypeResultDims : Pass<"resolve-shaped-type-result-dims"> {
   ];
 }
 
+def SimplifyExtractStridedMetadata : Pass<"simplify-extract-strided-metadata"> {
+  let summary = "Simplify extract_strided_metadata ops";
+  let description = [{
+    The pass simplifies extract_strided_metadata(other_op(memref)) to
+    extract_strided_metadata(memref) when it is possible to model the effect
+    of other_op directly with affine maps applied to the result of
+    extract_strided_metadata.
+  }];
+  // The rewrite materializes affine.apply ops, hence the Affine dependency.
+  let dependentDialects = ["AffineDialect", "memref::MemRefDialect"];
+  let constructor =
+      "mlir::memref::createSimplifyExtractStridedMetadataPass()";
+}
 #endif // MLIR_DIALECT_MEMREF_TRANSFORMS_PASSES
 

diff  --git a/mlir/include/mlir/IR/AffineExpr.h b/mlir/include/mlir/IR/AffineExpr.h
index 962797824d95b..e72e2990d3559 100644
--- a/mlir/include/mlir/IR/AffineExpr.h
+++ b/mlir/include/mlir/IR/AffineExpr.h
@@ -320,6 +320,14 @@ void bindSymbols(MLIRContext *ctx, AffineExprTy &e, AffineExprTy2 &...exprs) {
   e = getAffineSymbolExpr(N, ctx);
   bindSymbols<N + 1, AffineExprTy2 &...>(ctx, exprs...);
 }
+
+template <typename AffineExprTy>
+void bindSymbolsList(MLIRContext *ctx, SmallVectorImpl<AffineExprTy> &exprs) {
+  // Slot `pos` of `exprs` receives the affine symbol expression at `pos`.
+  for (unsigned pos = 0, e = exprs.size(); pos < e; ++pos)
+    exprs[pos] = getAffineSymbolExpr(pos, ctx);
+}
+
 } // namespace detail
 
 /// Bind a list of AffineExpr references to DimExpr at positions:

diff  --git a/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt b/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt
index f85b6e50e91e4..d64bbef936aeb 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt
@@ -5,6 +5,7 @@ add_mlir_dialect_library(MLIRMemRefTransforms
   MultiBuffer.cpp
   NormalizeMemRefs.cpp
   ResolveShapedTypeResultDims.cpp
+  SimplifyExtractStridedMetadata.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/MemRef

diff  --git a/mlir/lib/Dialect/MemRef/Transforms/SimplifyExtractStridedMetadata.cpp b/mlir/lib/Dialect/MemRef/Transforms/SimplifyExtractStridedMetadata.cpp
new file mode 100644
index 0000000000000..3cad0af9b9be2
--- /dev/null
+++ b/mlir/lib/Dialect/MemRef/Transforms/SimplifyExtractStridedMetadata.cpp
@@ -0,0 +1,199 @@
+//===- SimplifyExtractStridedMetadata.cpp - Simplify this operation -------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This pass simplifies extract_strided_metadata(other_op(memref)) to
+// extract_strided_metadata(memref) when it is possible to express the effect
+// of other_op using affine apply on the results of
+// extract_strided_metadata(memref).
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Arithmetic/Utils/Utils.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/MemRef/Transforms/Passes.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/ADT/SmallBitVector.h"
+
+namespace mlir {
+namespace memref {
+#define GEN_PASS_DEF_SIMPLIFYEXTRACTSTRIDEDMETADATA
+#include "mlir/Dialect/MemRef/Transforms/Passes.h.inc"
+} // namespace memref
+} // namespace mlir
+using namespace mlir;
+
+namespace {
+/// Replace `baseBuffer, offset, sizes, strides =
+///              extract_strided_metadata(subview(memref, subOffset,
+///                                               subSizes, subStrides))`
+/// With
+///
+/// \verbatim
+/// baseBuffer, baseOffset, baseSizes, baseStrides =
+///     extract_strided_metadata(memref)
+/// strides#i = baseStrides#i * subStrides#i
+/// offset = baseOffset + sum(subOffset#i * baseStrides#i)
+/// sizes = subSizes
+/// \endverbatim
+///
+/// In other words, get rid of the subview in that expression and canonicalize
+/// on its effects on the offset, the sizes, and the strides using affine apply.
+struct ExtractStridedMetadataOpSubviewFolder
+    : public OpRewritePattern<memref::ExtractStridedMetadataOp> {
+public:
+  using OpRewritePattern<memref::ExtractStridedMetadataOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(memref::ExtractStridedMetadataOp op,
+                                PatternRewriter &rewriter) const override {
+    auto subview = op.getSource().getDefiningOp<memref::SubViewOp>();
+    if (!subview)
+      return failure();
+
+    // Build a plain extract_strided_metadata(memref) from
+    // extract_strided_metadata(subview(memref)).
+    Location origLoc = op.getLoc();
+    IndexType indexType = rewriter.getIndexType();
+    Value source = subview.getSource();
+    auto sourceType = source.getType().cast<MemRefType>();
+    unsigned sourceRank = sourceType.getRank();
+    SmallVector<Type> sizeStrideTypes(sourceRank, indexType);
+
+    auto newExtractStridedMetadata =
+        rewriter.create<memref::ExtractStridedMetadataOp>(
+            origLoc, op.getBaseBuffer().getType(), indexType, sizeStrideTypes,
+            sizeStrideTypes, source);
+
+    SmallVector<int64_t> sourceStrides;
+    int64_t sourceOffset;
+
+    bool hasKnownStridesAndOffset =
+        succeeded(getStridesAndOffset(sourceType, sourceStrides, sourceOffset));
+    (void)hasKnownStridesAndOffset;
+    assert(hasKnownStridesAndOffset &&
+           "getStridesAndOffset must work on valid subviews");
+
+    // Compute the new strides and offset from the base strides and offset:
+    // newStride#i = baseStride#i * subStride#i
+    // offset = baseOffset + sum(subOffsets#i * baseStrides#i)
+    SmallVector<OpFoldResult> strides;
+    SmallVector<OpFoldResult> subStrides = subview.getMixedStrides();
+    auto origStrides = newExtractStridedMetadata.getStrides();
+
+    // Offset operands: baseOffset, then one (subOffset, stride) pair per dim.
+    SmallVector<OpFoldResult> values(2 * sourceRank + 1);
+    SmallVector<AffineExpr> symbols(2 * sourceRank + 1);
+
+    detail::bindSymbolsList(rewriter.getContext(), symbols);
+    AffineExpr expr = symbols.front();
+    values[0] = ShapedType::isDynamicStrideOrOffset(sourceOffset)
+                    ? getAsOpFoldResult(newExtractStridedMetadata.getOffset())
+                    : rewriter.getIndexAttr(sourceOffset);
+    SmallVector<OpFoldResult> subOffsets = subview.getMixedOffsets();
+
+    AffineExpr s0 = rewriter.getAffineSymbolExpr(0);
+    AffineExpr s1 = rewriter.getAffineSymbolExpr(1);
+    for (unsigned i = 0; i < sourceRank; ++i) {
+      // Compute the stride.
+      OpFoldResult origStride =
+          ShapedType::isDynamicStrideOrOffset(sourceStrides[i])
+              ? origStrides[i]
+              : OpFoldResult(rewriter.getIndexAttr(sourceStrides[i]));
+      strides.push_back(makeComposedFoldedAffineApply(
+          rewriter, origLoc, s0 * s1, {subStrides[i], origStride}));
+
+      // Build up the computation of the offset:
+      //   offset += subOffset#i * baseStride#i
+      // The subview strides only scale the strides of the result; they do not
+      // move its starting point, so they must not appear here.
+      unsigned baseIdxForDim = 1 + 2 * i;
+      unsigned subOffsetForDim = baseIdxForDim;
+      unsigned origStrideForDim = baseIdxForDim + 1;
+      expr = expr + symbols[subOffsetForDim] * symbols[origStrideForDim];
+      values[subOffsetForDim] = subOffsets[i];
+      values[origStrideForDim] = origStride;
+    }
+
+    // Compute the offset.
+    OpFoldResult finalOffset =
+        makeComposedFoldedAffineApply(rewriter, origLoc, expr, values);
+
+    SmallVector<Value> results;
+    // The final result is <baseBuffer, offset, sizes, strides>.
+    // Thus we need 1 + 1 + subview.getRank() + subview.getRank(), to hold all
+    // the values.
+    auto subType = subview.getType().cast<MemRefType>();
+    unsigned subRank = subType.getRank();
+    // Properly size the array so that we can do random insertions
+    // at the right indices.
+    // We do that to populate the non-dropped sizes and strides in one go.
+    results.resize_for_overwrite(subRank * 2 + 2);
+
+    results[0] = newExtractStridedMetadata.getBaseBuffer();
+    results[1] =
+        getValueOrCreateConstantIndexOp(rewriter, origLoc, finalOffset);
+
+    // The sizes of the final type are defined directly by the input sizes of
+    // the subview.
+    // Moreover, since subviews can drop some dimensions, some strides and sizes
+    // may not end up in the final <base, offset, sizes, strides> value that we
+    // are replacing.
+    // Do the filtering here.
+    SmallVector<OpFoldResult> subSizes = subview.getMixedSizes();
+    const unsigned sizeStartIdx = 2;
+    const unsigned strideStartIdx = sizeStartIdx + subRank;
+    unsigned insertedDims = 0;
+    llvm::SmallBitVector droppedDims = subview.getDroppedDims();
+    for (unsigned i = 0; i < sourceRank; ++i) {
+      if (droppedDims.test(i))
+        continue;
+
+      results[sizeStartIdx + insertedDims] =
+          getValueOrCreateConstantIndexOp(rewriter, origLoc, subSizes[i]);
+      results[strideStartIdx + insertedDims] =
+          getValueOrCreateConstantIndexOp(rewriter, origLoc, strides[i]);
+      ++insertedDims;
+    }
+    assert(insertedDims == subRank &&
+           "Should have populated all the values at this point");
+
+    rewriter.replaceOp(op, results);
+    return success();
+  }
+};
+} // namespace
+
+void memref::populateSimplifyExtractStridedMetadataOpPatterns(
+    RewritePatternSet &patterns) {
+  auto *ctx = patterns.getContext();
+  patterns.add<ExtractStridedMetadataOpSubviewFolder>(ctx);
+}
+
+//===----------------------------------------------------------------------===//
+// Pass registration
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// Greedily applies the extract_strided_metadata simplification patterns to
+/// the regions of the operation this pass runs on.
+struct SimplifyExtractStridedMetadataPass final
+    : public memref::impl::SimplifyExtractStridedMetadataBase<
+          SimplifyExtractStridedMetadataPass> {
+  void runOnOperation() override {
+    RewritePatternSet simplifications(&getContext());
+    memref::populateSimplifyExtractStridedMetadataOpPatterns(simplifications);
+    // The returned LogicalResult is intentionally discarded.
+    (void)applyPatternsAndFoldGreedily(getOperation()->getRegions(),
+                                       std::move(simplifications));
+  }
+};
+} // namespace
+
+std::unique_ptr<Pass> memref::createSimplifyExtractStridedMetadataPass() {
+  return std::make_unique<SimplifyExtractStridedMetadataPass>();
+}

diff  --git a/mlir/test/Dialect/MemRef/simplify-extract-strided-metadata.mlir b/mlir/test/Dialect/MemRef/simplify-extract-strided-metadata.mlir
new file mode 100644
index 0000000000000..8ef1729c0ce07
--- /dev/null
+++ b/mlir/test/Dialect/MemRef/simplify-extract-strided-metadata.mlir
@@ -0,0 +1,283 @@
+// RUN: mlir-opt --simplify-extract-strided-metadata -split-input-file %s -o - | FileCheck %s
+
+// -----
+
+// Check that we simplify extract_strided_metadata of subview to
+// base_buf, base_offset, base_sizes, base_strides = extract_strided_metadata
+// strides = base_stride_i * subview_stride_i
+// offset = base_offset + sum(subview_offsets_i * base_strides_i).
+//
+// This test also checks that we don't create useless arith operations
+// when subview_offsets_i is 0.
+//
+// CHECK-LABEL: func @extract_strided_metadata_of_subview
+//  CHECK-SAME: (%[[ARG:.*]]: memref<5x4xf32>)
+//
+// Materialize the constants used for the final offset, sizes and strides.
+//   CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+//   CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
+//   CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
+//
+// Plain extract_strided_metadata.
+//   CHECK-DAG: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[ARG]]
+//
+// Final offset is:
+//   origOffset + (== 0)
+//   base_stride0 * subview_offset0 + (== 4 * 0 == 0)
+//   base_stride1 * subview_offset1 (== 1 * 2)
+//  == 2
+//
+// Return the new tuple.
+//       CHECK: return %[[BASE]], %[[C2]], %[[C2]], %[[C2]], %[[C4]], %[[C1]]
+func.func @extract_strided_metadata_of_subview(%base: memref<5x4xf32>)
+    -> (memref<f32>, index, index, index, index, index) {
+
+  %subview = memref.subview %base[0, 2][2, 2][1, 1] :
+    memref<5x4xf32> to memref<2x2xf32, strided<[4, 1], offset: 2>>
+
+  %base_buffer, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %subview :
+    memref<2x2xf32, strided<[4,1], offset:2>>
+    -> memref<f32>, index, index, index, index, index
+
+  return %base_buffer, %offset, %sizes#0, %sizes#1, %strides#0, %strides#1 :
+    memref<f32>, index, index, index, index, index
+}
+
+// -----
+
+// Check that we simplify extract_strided_metadata of subview properly
+// when dynamic sizes are involved: such sizes are forwarded as is.
+// See extract_strided_metadata_of_subview for an explanation of the actual
+// expansion.
+// Orig strides: [64, 4, 1]
+// Sub strides: [1, 1, 1]
+// => New strides: [64, 4, 1]
+//
+// Orig offset: 0
+// Sub offsets: [3, 4, 2]
+// => Final offset: 3 * 64 + 4 * 4 + 2 * 1 + 0 == 210
+//
+// Final sizes == subview sizes == [%size, 6, 3]
+//
+// CHECK-LABEL: func @extract_strided_metadata_of_subview_with_dynamic_size
+//  CHECK-SAME: (%[[ARG:.*]]: memref<8x16x4xf32>,
+//  CHECK-SAME: %[[DYN_SIZE:.*]]: index)
+//
+//   CHECK-DAG: %[[C210:.*]] = arith.constant 210 : index
+//   CHECK-DAG: %[[C64:.*]] = arith.constant 64 : index
+//   CHECK-DAG: %[[C6:.*]] = arith.constant 6 : index
+//   CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
+//   CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index
+//   CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+//
+//   CHECK-DAG: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:3, %[[STRIDES:.*]]:3 = memref.extract_strided_metadata %[[ARG]]
+//
+//       CHECK: return %[[BASE]], %[[C210]], %[[DYN_SIZE]], %[[C6]], %[[C3]], %[[C64]], %[[C4]], %[[C1]]
+func.func @extract_strided_metadata_of_subview_with_dynamic_size(
+    %base: memref<8x16x4xf32>, %size: index)
+    -> (memref<f32>, index, index, index, index, index, index, index) {
+
+  %subview = memref.subview %base[3, 4, 2][%size, 6, 3][1, 1, 1] :
+    memref<8x16x4xf32> to memref<?x6x3xf32, strided<[64, 4, 1], offset: 210>>
+
+  %base_buffer, %offset, %sizes:3, %strides:3 = memref.extract_strided_metadata %subview :
+    memref<?x6x3xf32, strided<[64,4,1], offset: 210>>
+    -> memref<f32>, index, index, index, index, index, index, index
+
+  return %base_buffer, %offset, %sizes#0, %sizes#1, %sizes#2, %strides#0, %strides#1, %strides#2 :
+    memref<f32>, index, index, index, index, index, index, index
+}
+
+// -----
+
+// Check that we simplify extract_strided_metadata of subview properly
+// when the subview reduces the rank.
+// In particular the returned strides must come from #1 and #2 of the %strides
+// value of the new extract_strided_metadata, not #0 and #1.
+// See extract_strided_metadata_of_subview for an explanation of the actual
+// expansion.
+//
+// Orig strides: [64, 4, 1]
+// Sub strides: [1, 1, 1]
+// => New strides: [64, 4, 1]
+// Final strides == filterOutReducedDim(new strides, 0) == [4, 1]
+//
+// Orig offset: 0
+// Sub offsets: [3, 4, 2]
+// => Final offset: 3 * 64 + 4 * 4 + 2 * 1 + 0 == 210
+//
+// Final sizes == filterOutReducedDim(subview sizes, 0) == [6, 3]
+//
+// CHECK-LABEL: func @extract_strided_metadata_of_rank_reduced_subview
+//  CHECK-SAME: (%[[ARG:.*]]: memref<8x16x4xf32>)
+//
+//   CHECK-DAG: %[[C210:.*]] = arith.constant 210 : index
+//   CHECK-DAG: %[[C6:.*]] = arith.constant 6 : index
+//   CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
+//   CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index
+//   CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+//
+//   CHECK-DAG: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:3, %[[STRIDES:.*]]:3 = memref.extract_strided_metadata %[[ARG]]
+//
+//       CHECK: return %[[BASE]], %[[C210]], %[[C6]], %[[C3]], %[[C4]], %[[C1]]
+func.func @extract_strided_metadata_of_rank_reduced_subview(%base: memref<8x16x4xf32>)
+    -> (memref<f32>, index, index, index, index, index) {
+
+  %subview = memref.subview %base[3, 4, 2][1, 6, 3][1, 1, 1] :
+    memref<8x16x4xf32> to memref<6x3xf32, strided<[4, 1], offset: 210>>
+
+  %base_buffer, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %subview :
+    memref<6x3xf32, strided<[4,1], offset: 210>>
+    -> memref<f32>, index, index, index, index, index
+
+  return %base_buffer, %offset, %sizes#0, %sizes#1, %strides#0, %strides#1 :
+    memref<f32>, index, index, index, index, index
+}
+
+// -----
+
+// Check that we simplify extract_strided_metadata of subview properly
+// when the subview reduces the rank and some of the strides are variable.
+// In particular, we check that:
+// A. The dynamic stride is multiplied with the base stride to create the new
+//    stride for dimension 1.
+// B. The first returned stride is the value computed in #A.
+// C. The offset does not depend on the dynamic stride.
+// See extract_strided_metadata_of_subview for an explanation of the actual
+// expansion.
+//
+// Orig strides: [64, 4, 1]
+// Sub strides: [1, %stride, 1]
+// => New strides: [64, 4 * %stride, 1]
+// Final strides == filterOutReducedDim(new strides, 0) == [4 * %stride, 1]
+//
+// Orig offset: 0
+// Sub offsets: [3, 4, 2]
+// => Final offset: 3 * 64 + 4 * 4 + 2 * 1 + 0 == 210
+//
+//   CHECK-DAG: #[[$STRIDE1_MAP:.*]] = affine_map<()[s0] -> (s0 * 4)>
+// CHECK-LABEL: func @extract_strided_metadata_of_rank_reduced_subview_w_variable_strides
+//  CHECK-SAME: (%[[ARG:.*]]: memref<8x16x4xf32>,
+//  CHECK-SAME: %[[DYN_STRIDE:.*]]: index)
+//
+//   CHECK-DAG: %[[C210:.*]] = arith.constant 210 : index
+//   CHECK-DAG: %[[C6:.*]] = arith.constant 6 : index
+//   CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index
+//   CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+//
+//   CHECK-DAG: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:3, %[[STRIDES:.*]]:3 = memref.extract_strided_metadata %[[ARG]]
+//
+//   CHECK-DAG: %[[DIM1_STRIDE:.*]] = affine.apply #[[$STRIDE1_MAP]]()[%[[DYN_STRIDE]]]
+//
+//       CHECK: return %[[BASE]], %[[C210]], %[[C6]], %[[C3]], %[[DIM1_STRIDE]], %[[C1]]
+func.func @extract_strided_metadata_of_rank_reduced_subview_w_variable_strides(
+    %base: memref<8x16x4xf32>, %stride: index)
+    -> (memref<f32>, index, index, index, index, index) {
+
+  %subview = memref.subview %base[3, 4, 2][1, 6, 3][1, %stride, 1] :
+    memref<8x16x4xf32> to memref<6x3xf32, strided<[?, 1], offset: 210>>
+
+  %base_buffer, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %subview :
+    memref<6x3xf32, strided<[?, 1], offset: 210>>
+    -> memref<f32>, index, index, index, index, index
+
+  return %base_buffer, %offset, %sizes#0, %sizes#1, %strides#0, %strides#1 :
+    memref<f32>, index, index, index, index, index
+}
+
+// -----
+
+// Check that we simplify extract_strided_metadata of subview properly
+// when the subview uses variable offsets.
+// See extract_strided_metadata_of_subview for an explanation of the actual
+// expansion.
+//
+// Orig strides: [128, 1]
+// Sub strides: [1, 1]
+// => New strides: [128, 1]
+//
+// Orig offset: 0
+// Sub offsets: [%arg1, %arg2]
+// => Final offset: 128 * %arg1 + 1 * %arg2 + 0
+//
+//   CHECK-DAG: #[[$OFFSETS_MAP:.*]] = affine_map<()[s0, s1] -> (s0 * 128 + s1)>
+// CHECK-LABEL: func @extract_strided_metadata_of_subview_w_variable_offset
+//  CHECK-SAME: (%[[ARG:.*]]: memref<384x128xf32>,
+//  CHECK-SAME: %[[DYN_OFFSET0:.*]]: index,
+//  CHECK-SAME: %[[DYN_OFFSET1:.*]]: index)
+//
+//   CHECK-DAG: %[[C128:.*]] = arith.constant 128 : index
+//   CHECK-DAG: %[[C64:.*]] = arith.constant 64 : index
+//   CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+//   CHECK-DAG: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[ARG]]
+//
+//   CHECK-DAG: %[[FINAL_OFFSET:.*]] = affine.apply #[[$OFFSETS_MAP]]()[%[[DYN_OFFSET0]], %[[DYN_OFFSET1]]]
+//
+//       CHECK: return %[[BASE]], %[[FINAL_OFFSET]], %[[C64]], %[[C64]], %[[C128]], %[[C1]]
+#map0 = affine_map<(d0, d1)[s0] -> (d0 * 128 + s0 + d1)>
+func.func @extract_strided_metadata_of_subview_w_variable_offset(
+    %arg0: memref<384x128xf32>, %arg1 : index, %arg2 : index)
+    -> (memref<f32>, index, index, index, index, index) {
+
+  %subview = memref.subview %arg0[%arg1, %arg2] [64, 64] [1, 1] :
+    memref<384x128xf32> to memref<64x64xf32, #map0>
+
+  %base_buffer, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %subview :
+    memref<64x64xf32, #map0> -> memref<f32>, index, index, index, index, index
+
+  return %base_buffer, %offset, %sizes#0, %sizes#1, %strides#0, %strides#1 :
+    memref<f32>, index, index, index, index, index
+}
+
+// -----
+
+// Check that all the math is correct for all types of computations.
+// We achieve that by using dynamic values for all the different types:
+// - Offsets
+// - Sizes
+// - Strides
+//
+// Orig strides: [s0, s1, s2]
+// Sub strides: [subS0, subS1, subS2]
+// => New strides: [s0 * subS0, s1 * subS1, s2 * subS2]
+// ==> 1 affine map (used for each stride) with two values.
+//
+// Orig offset: origOff
+// Sub offsets: [subO0, subO1, subO2]
+// => Final offset: s0 * subO0 + ... + s2 * subO2 + origOff
+// ==> 1 affine map with (rank * 2 + 1) symbols
+//
+// CHECK-DAG: #[[$STRIDE_MAP:.*]] = affine_map<()[s0, s1] -> (s0 * s1)>
+// CHECK-DAG: #[[$OFFSET_MAP:.*]] = affine_map<()[s0, s1, s2, s3, s4, s5, s6] -> (s0 + s1 * s2 + s3 * s4 + s5 * s6)>
+// CHECK-LABEL: func @extract_strided_metadata_of_subview_all_dynamic
+//  CHECK-SAME: (%[[ARG:.*]]: memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>>, %[[DYN_OFFSET0:.*]]: index, %[[DYN_OFFSET1:.*]]: index, %[[DYN_OFFSET2:.*]]: index, %[[DYN_SIZE0:.*]]: index, %[[DYN_SIZE1:.*]]: index, %[[DYN_SIZE2:.*]]: index, %[[DYN_STRIDE0:.*]]: index, %[[DYN_STRIDE1:.*]]: index, %[[DYN_STRIDE2:.*]]: index)
+//
+//   CHECK-DAG: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:3, %[[STRIDES:.*]]:3 = memref.extract_strided_metadata %[[ARG]]
+//
+//  CHECK-DAG: %[[FINAL_STRIDE0:.*]] = affine.apply #[[$STRIDE_MAP]]()[%[[DYN_STRIDE0]], %[[STRIDES]]#0]
+//  CHECK-DAG: %[[FINAL_STRIDE1:.*]] = affine.apply #[[$STRIDE_MAP]]()[%[[DYN_STRIDE1]], %[[STRIDES]]#1]
+//  CHECK-DAG: %[[FINAL_STRIDE2:.*]] = affine.apply #[[$STRIDE_MAP]]()[%[[DYN_STRIDE2]], %[[STRIDES]]#2]
+//
+//  CHECK-DAG: %[[FINAL_OFFSET:.*]] = affine.apply #[[$OFFSET_MAP]]()[%[[OFFSET]], %[[DYN_OFFSET0]], %[[STRIDES]]#0, %[[DYN_OFFSET1]], %[[STRIDES]]#1, %[[DYN_OFFSET2]], %[[STRIDES]]#2]
+//
+//       CHECK: return %[[BASE]], %[[FINAL_OFFSET]], %[[DYN_SIZE0]], %[[DYN_SIZE1]], %[[DYN_SIZE2]], %[[FINAL_STRIDE0]], %[[FINAL_STRIDE1]], %[[FINAL_STRIDE2]]
+func.func @extract_strided_metadata_of_subview_all_dynamic(
+    %base: memref<?x?x?xf32, strided<[?,?,?], offset:?>>,
+    %offset0: index, %offset1: index, %offset2: index,
+    %size0: index, %size1: index, %size2: index,
+    %stride0: index, %stride1: index, %stride2: index)
+    -> (memref<f32>, index, index, index, index, index, index, index) {
+
+  %subview = memref.subview %base[%offset0, %offset1, %offset2]
+                                 [%size0, %size1, %size2]
+                                 [%stride0, %stride1, %stride2] :
+    memref<?x?x?xf32, strided<[?,?,?], offset: ?>> to
+      memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>>
+
+  %base_buffer, %offset, %sizes:3, %strides:3 = memref.extract_strided_metadata %subview :
+    memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>>
+    -> memref<f32>, index, index, index, index, index, index, index
+
+  return %base_buffer, %offset, %sizes#0, %sizes#1, %sizes#2, %strides#0, %strides#1, %strides#2 :
+    memref<f32>, index, index, index, index, index, index, index
+}


        


More information about the Mlir-commits mailing list