[Mlir-commits] [mlir] 6fd3c20 - [MLIR] Add a utility pass to linearize `memref` (#136797)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu May 22 10:05:41 PDT 2025


Author: Alan Li
Date: 2025-05-22T13:05:37-04:00
New Revision: 6fd3c20d25a88ccc3f2b5275e67de8b88ad5f873

URL: https://github.com/llvm/llvm-project/commit/6fd3c20d25a88ccc3f2b5275e67de8b88ad5f873
DIFF: https://github.com/llvm/llvm-project/commit/6fd3c20d25a88ccc3f2b5275e67de8b88ad5f873.diff

LOG: [MLIR] Add a utility pass to linearize `memref` (#136797)

To add a transformation that simplifies memory access patterns, this PR
adds a memref linearizer based on the GPU/DecomposeMemRefs pass, with
the following changes:
* support for vector dialect ops
* instead of decomposing memrefs to rank-0 memrefs, flatten
higher-ranked memrefs to rank-1.

Notes:
* After linearization, a MemRef's offset is preserved, so a
`memref<4x8xf32, strided<[8, 1], offset: 100>>` becomes `memref<32xf32,
strided<[1], offset: 100>>` (see the sketch below).
* It also works with dynamic shapes, strides, and offsets (see the test
cases for details).
* The shape of the cast memref is computed as a single flattened dimension.
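
For illustration, here is a rough before/after sketch distilled from the
added test cases (SSA value names are illustrative placeholders, not the
exact names the pass produces):

    // Before: a 2-D strided load.
    func.func @example(%input: memref<4x8xf32, strided<[8, 1], offset: 100>>) -> f32 {
      %c1 = arith.constant 1 : index
      %c2 = arith.constant 2 : index
      %value = memref.load %input[%c1, %c2] : memref<4x8xf32, strided<[8, 1], offset: 100>>
      return %value : f32
    }

    // After --flatten-memref: the source is reinterpreted as a rank-1 memref
    // (offset 100 is kept) and the index is linearized to 1 * 8 + 2 = 10.
    func.func @example(%input: memref<4x8xf32, strided<[8, 1], offset: 100>>) -> f32 {
      %c10 = arith.constant 10 : index
      %flat = memref.reinterpret_cast %input to offset: [100], sizes: [32], strides: [1]
          : memref<4x8xf32, strided<[8, 1], offset: 100>> to memref<32xf32, strided<[1], offset: 100>>
      %value = memref.load %flat[%c10] : memref<32xf32, strided<[1], offset: 100>>
      return %value : f32
    }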

Added: 
    mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp
    mlir/test/Dialect/MemRef/flatten_memref.mlir

Modified: 
    mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td
    mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h
    mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td b/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td
index a46f73350bb3c..a8d135caa74f0 100644
--- a/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td
@@ -245,5 +245,15 @@ def ExpandReallocPass : Pass<"expand-realloc"> {
   ];
 }
 
+def FlattenMemrefsPass : Pass<"flatten-memref"> {
+  let summary = "Flatten a multi-dimensional memref to 1-dimensional";
+  let description = [{
+    Flatten multi-dimensional memrefs into 1-D memrefs by linearizing accesses.
+  }];
+  let dependentDialects = [
+      "affine::AffineDialect", "memref::MemRefDialect", "vector::VectorDialect"
+  ];
+}
+
 #endif // MLIR_DIALECT_MEMREF_TRANSFORMS_PASSES
 

diff  --git a/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h b/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h
index 62a2297c80e78..c2b8cb05be922 100644
--- a/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h
@@ -144,6 +144,8 @@ FailureOr<memref::AllocOp> multiBuffer(memref::AllocOp allocOp,
 /// ```
 void populateExtractAddressComputationsPatterns(RewritePatternSet &patterns);
 
+void populateFlattenMemrefsPatterns(RewritePatternSet &patterns);
+
 /// Build a new memref::AllocaOp whose dynamic sizes are independent of all
 /// given independencies. If the op is already independent of all
 /// independencies, the same AllocaOp result is returned.

diff  --git a/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt b/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt
index ecab97bc2b8e7..637f5ec1c9f9b 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt
@@ -8,6 +8,7 @@ add_mlir_dialect_library(MLIRMemRefTransforms
   EmulateWideInt.cpp
   EmulateNarrowType.cpp
   ExtractAddressComputations.cpp
+  FlattenMemRefs.cpp
   FoldMemRefAliasOps.cpp
   IndependenceTransforms.cpp
   MultiBuffer.cpp

diff  --git a/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp b/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp
new file mode 100644
index 0000000000000..e9729a4766a0a
--- /dev/null
+++ b/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp
@@ -0,0 +1,286 @@
+//===----- FlattenMemRefs.cpp - MemRef ops flattener pass  ----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains patterns for flattening multi-rank memref-related
+// ops into 1-d memref ops.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/MemRef/Transforms/Passes.h"
+#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
+#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/TypeSwitch.h"
+
+#include <numeric>
+
+namespace mlir {
+namespace memref {
+#define GEN_PASS_DEF_FLATTENMEMREFSPASS
+#include "mlir/Dialect/MemRef/Transforms/Passes.h.inc"
+} // namespace memref
+} // namespace mlir
+
+using namespace mlir;
+
+static Value getValueFromOpFoldResult(OpBuilder &rewriter, Location loc,
+                                      OpFoldResult in) {
+  if (Attribute offsetAttr = dyn_cast<Attribute>(in)) {
+    return rewriter.create<arith::ConstantIndexOp>(
+        loc, cast<IntegerAttr>(offsetAttr).getInt());
+  }
+  return cast<Value>(in);
+}
+
+/// Returns a collapsed memref and the linearized index to access the element
+/// at the specified indices.
+static std::pair<Value, Value> getFlattenMemrefAndOffset(OpBuilder &rewriter,
+                                                         Location loc,
+                                                         Value source,
+                                                         ValueRange indices) {
+  int64_t sourceOffset;
+  SmallVector<int64_t, 4> sourceStrides;
+  auto sourceType = cast<MemRefType>(source.getType());
+  if (failed(sourceType.getStridesAndOffset(sourceStrides, sourceOffset))) {
+    assert(false);
+  }
+
+  memref::ExtractStridedMetadataOp stridedMetadata =
+      rewriter.create<memref::ExtractStridedMetadataOp>(loc, source);
+
+  auto typeBit = sourceType.getElementType().getIntOrFloatBitWidth();
+  OpFoldResult linearizedIndices;
+  memref::LinearizedMemRefInfo linearizedInfo;
+  std::tie(linearizedInfo, linearizedIndices) =
+      memref::getLinearizedMemRefOffsetAndSize(
+          rewriter, loc, typeBit, typeBit,
+          stridedMetadata.getConstifiedMixedOffset(),
+          stridedMetadata.getConstifiedMixedSizes(),
+          stridedMetadata.getConstifiedMixedStrides(),
+          getAsOpFoldResult(indices));
+
+  return std::make_pair(
+      rewriter.create<memref::ReinterpretCastOp>(
+          loc, source,
+          /* offset = */ linearizedInfo.linearizedOffset,
+          /* shapes = */
+          ArrayRef<OpFoldResult>{linearizedInfo.linearizedSize},
+          /* strides = */
+          ArrayRef<OpFoldResult>{rewriter.getIndexAttr(1)}),
+      getValueFromOpFoldResult(rewriter, loc, linearizedIndices));
+}
+
+static bool needFlattening(Value val) {
+  auto type = cast<MemRefType>(val.getType());
+  return type.getRank() > 1;
+}
+
+static bool checkLayout(Value val) {
+  auto type = cast<MemRefType>(val.getType());
+  return type.getLayout().isIdentity() ||
+         isa<StridedLayoutAttr>(type.getLayout());
+}
+
+namespace {
+static Value getTargetMemref(Operation *op) {
+  return llvm::TypeSwitch<Operation *, Value>(op)
+      .template Case<memref::LoadOp, memref::StoreOp, memref::AllocaOp,
+                     memref::AllocOp>([](auto op) { return op.getMemref(); })
+      .template Case<vector::LoadOp, vector::StoreOp, vector::MaskedLoadOp,
+                     vector::MaskedStoreOp, vector::TransferReadOp,
+                     vector::TransferWriteOp>(
+          [](auto op) { return op.getBase(); })
+      .Default([](auto) { return Value{}; });
+}
+
+template <typename T>
+static void castAllocResult(T oper, T newOper, Location loc,
+                            PatternRewriter &rewriter) {
+  memref::ExtractStridedMetadataOp stridedMetadata =
+      rewriter.create<memref::ExtractStridedMetadataOp>(loc, oper);
+  rewriter.replaceOpWithNewOp<memref::ReinterpretCastOp>(
+      oper, cast<MemRefType>(oper.getType()), newOper,
+      /*offset=*/rewriter.getIndexAttr(0),
+      stridedMetadata.getConstifiedMixedSizes(),
+      stridedMetadata.getConstifiedMixedStrides());
+}
+
+template <typename T>
+static void replaceOp(T op, PatternRewriter &rewriter, Value flatMemref,
+                      Value offset) {
+  Location loc = op->getLoc();
+  llvm::TypeSwitch<Operation *>(op.getOperation())
+      .template Case<memref::AllocOp>([&](auto oper) {
+        auto newAlloc = rewriter.create<memref::AllocOp>(
+            loc, cast<MemRefType>(flatMemref.getType()),
+            oper.getAlignmentAttr());
+        castAllocResult(oper, newAlloc, loc, rewriter);
+      })
+      .template Case<memref::AllocaOp>([&](auto oper) {
+        auto newAlloca = rewriter.create<memref::AllocaOp>(
+            loc, cast<MemRefType>(flatMemref.getType()),
+            oper.getAlignmentAttr());
+        castAllocResult(oper, newAlloca, loc, rewriter);
+      })
+      .template Case<memref::LoadOp>([&](auto op) {
+        auto newLoad = rewriter.create<memref::LoadOp>(
+            loc, op->getResultTypes(), flatMemref, ValueRange{offset});
+        newLoad->setAttrs(op->getAttrs());
+        rewriter.replaceOp(op, newLoad.getResult());
+      })
+      .template Case<memref::StoreOp>([&](auto op) {
+        auto newStore = rewriter.create<memref::StoreOp>(
+            loc, op->getOperands().front(), flatMemref, ValueRange{offset});
+        newStore->setAttrs(op->getAttrs());
+        rewriter.replaceOp(op, newStore);
+      })
+      .template Case<vector::LoadOp>([&](auto op) {
+        auto newLoad = rewriter.create<vector::LoadOp>(
+            loc, op->getResultTypes(), flatMemref, ValueRange{offset});
+        newLoad->setAttrs(op->getAttrs());
+        rewriter.replaceOp(op, newLoad.getResult());
+      })
+      .template Case<vector::StoreOp>([&](auto op) {
+        auto newStore = rewriter.create<vector::StoreOp>(
+            loc, op->getOperands().front(), flatMemref, ValueRange{offset});
+        newStore->setAttrs(op->getAttrs());
+        rewriter.replaceOp(op, newStore);
+      })
+      .template Case<vector::MaskedLoadOp>([&](auto op) {
+        auto newMaskedLoad = rewriter.create<vector::MaskedLoadOp>(
+            loc, op.getType(), flatMemref, ValueRange{offset}, op.getMask(),
+            op.getPassThru());
+        newMaskedLoad->setAttrs(op->getAttrs());
+        rewriter.replaceOp(op, newMaskedLoad.getResult());
+      })
+      .template Case<vector::MaskedStoreOp>([&](auto op) {
+        auto newMaskedStore = rewriter.create<vector::MaskedStoreOp>(
+            loc, flatMemref, ValueRange{offset}, op.getMask(),
+            op.getValueToStore());
+        newMaskedStore->setAttrs(op->getAttrs());
+        rewriter.replaceOp(op, newMaskedStore);
+      })
+      .template Case<vector::TransferReadOp>([&](auto op) {
+        auto newTransferRead = rewriter.create<vector::TransferReadOp>(
+            loc, op.getType(), flatMemref, ValueRange{offset}, op.getPadding());
+        rewriter.replaceOp(op, newTransferRead.getResult());
+      })
+      .template Case<vector::TransferWriteOp>([&](auto op) {
+        auto newTransferWrite = rewriter.create<vector::TransferWriteOp>(
+            loc, op.getVector(), flatMemref, ValueRange{offset});
+        rewriter.replaceOp(op, newTransferWrite);
+      })
+      .Default([&](auto op) {
+        op->emitOpError("unimplemented: do not know how to replace op.");
+      });
+}
+
+template <typename T>
+static ValueRange getIndices(T op) {
+  if constexpr (std::is_same_v<T, memref::AllocaOp> ||
+                std::is_same_v<T, memref::AllocOp>) {
+    return ValueRange{};
+  } else {
+    return op.getIndices();
+  }
+}
+
+template <typename T>
+static LogicalResult canBeFlattened(T op, PatternRewriter &rewriter) {
+  return llvm::TypeSwitch<Operation *, LogicalResult>(op.getOperation())
+      .template Case<vector::TransferReadOp, vector::TransferWriteOp>(
+          [&](auto oper) {
+            // For vector.transfer_read/write, we must make sure that:
+            // 1. all accesses are in bounds, and
+            // 2. the permutation map is an identity or minor identity.
+            auto permutationMap = oper.getPermutationMap();
+            if (!permutationMap.isIdentity() &&
+                !permutationMap.isMinorIdentity()) {
+              return rewriter.notifyMatchFailure(
+                  oper, "only identity or minor identity maps are supported");
+            }
+            mlir::ArrayAttr inbounds = oper.getInBounds();
+            if (llvm::any_of(inbounds, [](Attribute attr) {
+                  return !cast<BoolAttr>(attr).getValue();
+                })) {
+              return rewriter.notifyMatchFailure(oper,
+                                                 "only in-bounds accesses are supported");
+            }
+            return success();
+          })
+      .Default([&](auto op) { return success(); });
+}
+
+template <typename T>
+struct MemRefRewritePattern : public OpRewritePattern<T> {
+  using OpRewritePattern<T>::OpRewritePattern;
+  LogicalResult matchAndRewrite(T op,
+                                PatternRewriter &rewriter) const override {
+    LogicalResult canFlatten = canBeFlattened(op, rewriter);
+    if (failed(canFlatten)) {
+      return canFlatten;
+    }
+
+    Value memref = getTargetMemref(op);
+    if (!needFlattening(memref) || !checkLayout(memref))
+      return failure();
+    auto &&[flatMemref, offset] = getFlattenMemrefAndOffset(
+        rewriter, op->getLoc(), memref, getIndices<T>(op));
+    replaceOp<T>(op, rewriter, flatMemref, offset);
+    return success();
+  }
+};
+
+struct FlattenMemrefsPass
+    : public mlir::memref::impl::FlattenMemrefsPassBase<FlattenMemrefsPass> {
+  using Base::Base;
+
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<affine::AffineDialect, arith::ArithDialect,
+                    memref::MemRefDialect, vector::VectorDialect>();
+  }
+
+  void runOnOperation() override {
+    RewritePatternSet patterns(&getContext());
+
+    memref::populateFlattenMemrefsPatterns(patterns);
+
+    if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
+      return signalPassFailure();
+  }
+};
+
+} // namespace
+
+void memref::populateFlattenMemrefsPatterns(RewritePatternSet &patterns) {
+  patterns.insert<MemRefRewritePattern<memref::LoadOp>,
+                  MemRefRewritePattern<memref::StoreOp>,
+                  MemRefRewritePattern<memref::AllocOp>,
+                  MemRefRewritePattern<memref::AllocaOp>,
+                  MemRefRewritePattern<vector::LoadOp>,
+                  MemRefRewritePattern<vector::StoreOp>,
+                  MemRefRewritePattern<vector::TransferReadOp>,
+                  MemRefRewritePattern<vector::TransferWriteOp>,
+                  MemRefRewritePattern<vector::MaskedLoadOp>,
+                  MemRefRewritePattern<vector::MaskedStoreOp>>(
+      patterns.getContext());
+}

diff  --git a/mlir/test/Dialect/MemRef/flatten_memref.mlir b/mlir/test/Dialect/MemRef/flatten_memref.mlir
new file mode 100644
index 0000000000000..e45a10ca0d431
--- /dev/null
+++ b/mlir/test/Dialect/MemRef/flatten_memref.mlir
@@ -0,0 +1,300 @@
+// RUN: mlir-opt --flatten-memref %s --split-input-file --verify-diagnostics | FileCheck %s
+
+func.func @load_scalar_from_memref(%input: memref<4x8xf32, strided<[8, 1], offset: 100>>) -> f32 {
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  %value = memref.load %input[%c1, %c2] : memref<4x8xf32, strided<[8, 1], offset: 100>>
+  return %value : f32
+}
+// CHECK-LABEL: func @load_scalar_from_memref
+// CHECK-NEXT: %[[C10:.*]] = arith.constant 10 : index
+// CHECK-NEXT: %[[REINT:.*]] = memref.reinterpret_cast %arg0 to offset: [100], sizes: [32], strides: [1]
+// CHECK-SAME: memref<4x8xf32, strided<[8, 1], offset: 100>> to memref<32xf32, strided<[1], offset: 100>>
+// CHECK-NEXT: memref.load %[[REINT]][%[[C10]]] : memref<32xf32, strided<[1], offset: 100>>
+
+
+// -----
+
+func.func @load_scalar_from_memref_dynamic_dim(%input: memref<?x?xf32, strided<[?, ?], offset: ?>>, %row: index, %col: index) -> f32 {
+  %value = memref.load %input[%col, %row] : memref<?x?xf32, strided<[?, ?], offset: ?>>
+  return %value : f32
+}
+
+// CHECK: #[[MAP:.*]] = affine_map<()[s0, s1, s2, s3] -> (s0 * s1 + s2 * s3)>
+// CHECK: #[[MAP1:.*]] = affine_map<()[s0, s1, s2, s3] -> (s0 * s1, s2 * s3)>
+// CHECK: func @load_scalar_from_memref_dynamic_dim
+// CHECK-SAME: (%[[ARG0:.*]]: memref<?x?xf32, strided<[?, ?], offset: ?>>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index)
+// CHECK: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[ARG0]]
+// CHECK: %[[IDX:.*]] = affine.apply #[[MAP]]()[%[[ARG2]], %[[STRIDES]]#0, %[[ARG1]], %[[STRIDES]]#1]
+// CHECK: %[[SIZE:.*]] = affine.max #[[MAP1]]()[%[[STRIDES]]#0, %[[SIZES]]#0, %[[STRIDES]]#1, %[[SIZES]]#1]
+// CHECK: %[[REINT:.*]] = memref.reinterpret_cast %arg0 to offset: [%[[OFFSET]]], sizes: [%[[SIZE]]], strides: [1] : memref<?x?xf32, strided<[?, ?], offset: ?>> to memref<?xf32, strided<[1], offset: ?>> 
+// CHECK: memref.load %[[REINT]][%[[IDX]]]
+
+// -----
+
+func.func @load_scalar_from_memref_static_dim(%input: memref<8x12xf32, strided<[24, 2], offset: 100>>) -> f32 {
+   %c7 = arith.constant 7 : index
+   %c10 = arith.constant 10 : index
+  %value = memref.load %input[%c7, %c10] : memref<8x12xf32, strided<[24, 2], offset: 100>>
+  return %value : f32
+}
+
+// CHECK-LABEL: func @load_scalar_from_memref_static_dim
+// CHECK-SAME: (%[[ARG0:.*]]: memref<8x12xf32, strided<[24, 2], offset: 100>>)
+// CHECK: %[[C188:.*]] = arith.constant 188 : index
+// CHECK: %[[REINT:.*]] = memref.reinterpret_cast %[[ARG0]] to offset: [100], sizes: [192], strides: [1] : memref<8x12xf32, strided<[24, 2], offset: 100>> to memref<192xf32, strided<[1], offset: 100>>
+// CHECK: memref.load %[[REINT]][%[[C188]]] : memref<192xf32, strided<[1], offset: 100>>
+
+// -----
+
+func.func @store_scalar_from_memref_padded(%input: memref<4x8xf32, strided<[18, 2], offset: 100>>, %row: index, %col: index, %value: f32) {
+  memref.store %value, %input[%col, %row] : memref<4x8xf32, strided<[18, 2], offset: 100>>
+  return
+}
+// CHECK: #[[MAP:.*]] = affine_map<()[s0, s1] -> (s0 * 18 + s1 * 2)>
+// CHECK: func @store_scalar_from_memref_padded
+// CHECK-SAME: (%[[ARG0:.*]]: memref<4x8xf32, strided<[18, 2], offset: 100>>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: f32)
+// CHECK: %[[IDX:.*]] = affine.apply #[[MAP]]()[%[[ARG2]], %[[ARG1]]]
+// CHECK: %[[REINT:.*]] = memref.reinterpret_cast %[[ARG0]]
+// CHECK: memref.store %[[ARG3]], %[[REINT]][%[[IDX]]] : memref<72xf32, strided<[1], offset: 100>>
+
+// -----
+
+func.func @store_scalar_from_memref_dynamic_dim(%input: memref<?x?xf32, strided<[?, ?], offset: ?>>, %row: index, %col: index, %value: f32) {
+  memref.store %value, %input[%col, %row] : memref<?x?xf32, strided<[?, ?], offset: ?>>
+  return
+}
+// CHECK: #[[MAP:.*]] = affine_map<()[s0, s1, s2, s3] -> (s0 * s1 + s2 * s3)>
+// CHECK: #[[MAP1:.*]] = affine_map<()[s0, s1, s2, s3] -> (s0 * s1, s2 * s3)>
+// CHECK: func @store_scalar_from_memref_dynamic_dim
+// CHECK-SAME: (%[[ARG0:.*]]: memref<?x?xf32, strided<[?, ?], offset: ?>>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: f32)
+// CHECK: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[ARG0]]
+// CHECK: %[[IDX:.*]] = affine.apply #[[MAP]]()[%[[ARG2]], %[[STRIDES]]#0, %[[ARG1]], %[[STRIDES]]#1]
+// CHECK: %[[SIZE:.*]] = affine.max #[[MAP1]]()[%[[STRIDES]]#0, %[[SIZES]]#0, %[[STRIDES]]#1, %[[SIZES]]#1]
+// CHECK: %[[REINT:.*]] = memref.reinterpret_cast %[[ARG0]] to offset: [%[[OFFSET]]], sizes: [%[[SIZE]]], strides: [1]
+// CHECK: memref.store %[[ARG3]], %[[REINT]][%[[IDX]]]
+
+// -----
+
+func.func @load_vector_from_memref(%input: memref<4x8xf32>) -> vector<8xf32> {
+  %c3 = arith.constant 3 : index
+  %c6 = arith.constant 6 : index
+  %value = vector.load %input[%c3, %c6] : memref<4x8xf32>, vector<8xf32>
+  return %value : vector<8xf32>
+}
+// CHECK-LABEL: func @load_vector_from_memref
+// CHECK: %[[C30:.*]] = arith.constant 30
+// CHECK-NEXT: %[[REINT:.*]] = memref.reinterpret_cast %arg0 to offset: [0], sizes: [32], strides: [1]
+// CHECK-NEXT: vector.load %[[REINT]][%[[C30]]]
+
+// -----
+
+func.func @load_vector_from_memref_odd(%input: memref<3x7xi2>) -> vector<3xi2> {
+  %c1 = arith.constant 1 : index
+  %c3 = arith.constant 3 : index
+  %value = vector.load %input[%c1, %c3] : memref<3x7xi2>, vector<3xi2>
+  return %value : vector<3xi2>
+}
+// CHECK-LABEL: func @load_vector_from_memref_odd
+// CHECK: %[[C10:.*]] = arith.constant 10 : index
+// CHECK-NEXT: %[[REINT:.*]] = memref.reinterpret_cast
+// CHECK-NEXT: vector.load %[[REINT]][%[[C10]]]
+
+// -----
+
+func.func @load_vector_from_memref_dynamic(%input: memref<3x7xi2>, %row: index, %col: index) -> vector<3xi2> {
+  %value = vector.load %input[%col, %row] : memref<3x7xi2>, vector<3xi2>
+  return %value : vector<3xi2>
+}
+// CHECK: #[[MAP:.*]] = affine_map<()[s0, s1] -> (s0 * 7 + s1)>
+// CHECK: func @load_vector_from_memref_dynamic
+// CHECK: %[[IDX:.*]] = affine.apply #[[MAP]]()
+// CHECK: %[[REINT:.*]] = memref.reinterpret_cast
+// CHECK: vector.load %[[REINT]][%[[IDX]]] : memref<21xi2, strided<[1]>>, vector<3xi2>
+
+// -----
+
+func.func @store_vector_to_memref_odd(%input: memref<3x7xi2>, %value: vector<3xi2>) {
+  %c1 = arith.constant 1 : index
+  %c3 = arith.constant 3 : index
+  vector.store %value, %input[%c1, %c3] : memref<3x7xi2>, vector<3xi2>
+  return
+}
+// CHECK-LABEL: func @store_vector_to_memref_odd
+// CHECK-SAME: (%[[ARG0:.*]]: memref<3x7xi2>, %[[ARG1:.*]]: vector<3xi2>)
+// CHECK: %[[C10:.*]] = arith.constant 10 : index
+// CHECK-NEXT: %[[REINT:.*]] = memref.reinterpret_cast
+// CHECK-NEXT: vector.store %[[ARG1]], %[[REINT]][%[[C10]]] : memref<21xi2, strided<[1]>
+
+// -----
+
+func.func @store_vector_to_memref_dynamic(%input: memref<3x7xi2>, %value: vector<3xi2>, %row: index, %col: index) {
+  vector.store %value, %input[%col, %row] : memref<3x7xi2>, vector<3xi2>
+  return
+}
+// CHECK: #[[MAP:.*]] = affine_map<()[s0, s1] -> (s0 * 7 + s1)>
+// CHECK: func @store_vector_to_memref_dynamic
+// CHECK-SAME: (%[[ARG0:.*]]: memref<3x7xi2>, %[[ARG1:.*]]: vector<3xi2>, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index)
+// CHECK: %[[IDX:.*]] = affine.apply #[[MAP]]()[%[[ARG3]], %[[ARG2]]]
+// CHECK: %[[REINT:.*]] = memref.reinterpret_cast %[[ARG0]] to offset: [0], sizes: [21], strides: [1]
+// CHECK: vector.store %[[ARG1]], %[[REINT]][%[[IDX]]]
+
+// -----
+
+func.func @mask_store_vector_to_memref_odd(%input: memref<3x7xi2>, %value: vector<3xi2>, %mask: vector<3xi1>) {
+  %c1 = arith.constant 1 : index
+  %c3 = arith.constant 3 : index
+  vector.maskedstore %input[%c1, %c3], %mask, %value  : memref<3x7xi2>, vector<3xi1>, vector<3xi2>
+  return
+}
+// CHECK-LABEL: func @mask_store_vector_to_memref_odd
+// CHECK-SAME: (%[[ARG0:.*]]: memref<3x7xi2>, %[[ARG1:.*]]: vector<3xi2>, %[[ARG2:.*]]: vector<3xi1>)
+// CHECK: %[[C10:.*]] = arith.constant 10 : index
+// CHECK-NEXT: %[[REINT:.*]] = memref.reinterpret_cast
+// CHECK: vector.maskedstore %[[REINT]][%[[C10]]], %[[ARG2]], %[[ARG1]]
+
+// -----
+
+func.func @mask_store_vector_to_memref_dynamic(%input: memref<3x7xi2>, %value: vector<3xi2>, %row: index, %col: index, %mask: vector<3xi1>) {
+  vector.maskedstore %input[%col, %row], %mask, %value : memref<3x7xi2>, vector<3xi1>, vector<3xi2>
+  return
+}
+// CHECK: #map = affine_map<()[s0, s1] -> (s0 * 7 + s1)>
+// CHECK: func @mask_store_vector_to_memref_dynamic
+// CHECK-SAME: (%[[ARG0:.*]]: memref<3x7xi2>, %[[ARG1:.*]]: vector<3xi2>, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index, %[[ARG4:.*]]: vector<3xi1>)
+// CHECK: %[[IDX:.*]] = affine.apply #map()[%[[ARG3]], %[[ARG2]]]
+// CHECK: %[[REINT:.*]] = memref.reinterpret_cast %[[ARG0]]
+// CHECK: vector.maskedstore %[[REINT]][%[[IDX]]], %[[ARG4]], %[[ARG1]]
+
+// -----
+func.func @mask_load_vector_from_memref_odd(%input: memref<3x7xi2>, %mask: vector<3xi1>, %passthru: vector<3xi2>) -> vector<3xi2> {
+  %c1 = arith.constant 1 : index
+  %c3 = arith.constant 3 : index
+  %result = vector.maskedload %input[%c1, %c3], %mask, %passthru : memref<3x7xi2>, vector<3xi1>, vector<3xi2> into vector<3xi2>
+  return %result : vector<3xi2>
+}
+// CHECK-LABEL: func @mask_load_vector_from_memref_odd
+// CHECK-SAME: (%[[ARG0:.*]]: memref<3x7xi2>, %[[MASK:.*]]: vector<3xi1>, %[[PASSTHRU:.*]]: vector<3xi2>)
+// CHECK: %[[C10:.*]] = arith.constant 10 : index
+// CHECK: %[[REINT:.*]] = memref.reinterpret_cast %[[ARG0]] to offset: [0], sizes: [21], strides: [1]
+// CHECK: vector.maskedload %[[REINT]][%[[C10]]], %[[MASK]], %[[PASSTHRU]]
+
+// -----
+
+func.func @mask_load_vector_from_memref_dynamic(%input: memref<3x7xi2>, %row: index, %col: index, %mask: vector<3xi1>, %passthru: vector<3xi2>) -> vector<3xi2> {
+  %result = vector.maskedload %input[%col, %row], %mask, %passthru : memref<3x7xi2>, vector<3xi1>, vector<3xi2> into vector<3xi2>
+  return %result : vector<3xi2>
+}
+// CHECK: #[[MAP:.*]] = affine_map<()[s0, s1] -> (s0 * 7 + s1)>
+// CHECK: func @mask_load_vector_from_memref_dynamic
+// CHECK-SAME: (%[[ARG0:.*]]: memref<3x7xi2>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: vector<3xi1>, %[[ARG4:.*]]: vector<3xi2>)
+// CHECK: %[[IDX:.*]] = affine.apply #[[MAP]]()[%[[ARG2]], %[[ARG1]]]
+// CHECK: %[[REINT:.*]] = memref.reinterpret_cast %[[ARG0]]
+// CHECK: vector.maskedload %[[REINT]][%[[IDX]]], %[[ARG3]]
+
+// -----
+
+func.func @transfer_read_memref(%input: memref<4x8xi2>, %value: vector<8xi2>, %row: index, %col: index) -> vector<8xi2> {
+   %c0 = arith.constant 0 : i2
+   %0 = vector.transfer_read %input[%col, %row], %c0 {in_bounds = [true]} : memref<4x8xi2>, vector<8xi2>
+   return %0 : vector<8xi2>
+}
+
+// CHECK: #[[MAP:.*]] = affine_map<()[s0, s1] -> (s0 * 8 + s1)>
+// CHECK: func @transfer_read_memref
+// CHECK-SAME: (%[[ARG0:.*]]: memref<4x8xi2>, %[[ARG1:.*]]: vector<8xi2>, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index)
+// CHECK: %[[C0:.*]] = arith.constant 0 : i2
+// CHECK: %[[IDX:.*]] = affine.apply #[[MAP]]()[%[[ARG3]], %[[ARG2]]]
+// CHECK-NEXT: %[[REINT:.*]] = memref.reinterpret_cast %[[ARG0]]
+// CHECK-NEXT: vector.transfer_read %[[REINT]][%[[IDX]]], %[[C0]]
+
+// -----
+
+func.func @transfer_read_memref_not_inbound(%input: memref<4x8xi2>, %value: vector<8xi2>, %row: index, %col: index) -> vector<8xi2> {
+   %c0 = arith.constant 0 : i2
+   %0 = vector.transfer_read %input[%col, %row], %c0 {in_bounds = [false]} : memref<4x8xi2>, vector<8xi2>
+   return %0 : vector<8xi2>
+}
+
+// CHECK-LABEL: func @transfer_read_memref_not_inbound
+// CHECK-SAME: (%[[ARG0:.*]]: memref<4x8xi2>, %[[ARG1:.*]]: vector<8xi2>, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index)
+// CHECK: vector.transfer_read %[[ARG0]][%[[ARG3]], %[[ARG2]]]
+
+// -----
+
+func.func @transfer_read_memref_non_id(%input: memref<4x8xi2>, %value: vector<8xi2>, %row: index, %col: index) -> vector<8xi2> {
+   %c0 = arith.constant 0 : i2
+   %0 = vector.transfer_read %input[%col, %row], %c0 {permutation_map = affine_map<(d0, d1) -> (d0)>, in_bounds = [true]} : memref<4x8xi2>, vector<8xi2>
+   return %0 : vector<8xi2>
+}
+
+// CHECK-LABEL: func @transfer_read_memref_non_id
+// CHECK-SAME: (%[[ARG0:.*]]: memref<4x8xi2>, %[[ARG1:.*]]: vector<8xi2>, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index)
+// CHECK: vector.transfer_read %[[ARG0]][%[[ARG3]], %[[ARG2]]]
+
+// -----
+
+func.func @transfer_write_memref(%input: memref<4x8xi2>, %value: vector<8xi2>, %row: index, %col: index) {
+   vector.transfer_write %value, %input[%col, %row] {in_bounds = [true]} : vector<8xi2>, memref<4x8xi2>
+   return
+}
+
+// CHECK: #[[MAP:.*]] = affine_map<()[s0, s1] -> (s0 * 8 + s1)>
+// CHECK: func @transfer_write_memref
+// CHECK-SAME: (%[[ARG0:.*]]: memref<4x8xi2>, %[[ARG1:.*]]: vector<8xi2>, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index)
+// CHECK: %[[IDX:.*]] = affine.apply #[[MAP]]()[%[[ARG3]], %[[ARG2]]]
+// CHECK: %[[REINT:.*]] = memref.reinterpret_cast %[[ARG0]]
+// CHECK: vector.transfer_write %[[ARG1]], %[[REINT]][%[[IDX]]]
+
+// -----
+
+func.func @alloc() -> memref<4x8xf32> {
+  %0 = memref.alloc() : memref<4x8xf32>
+  return %0 : memref<4x8xf32>
+}
+
+// CHECK-LABEL: func @alloc
+// CHECK-SAME: () -> memref<4x8xf32>
+// CHECK-NEXT: %[[ALLOC:.*]] = memref.alloc() : memref<32xf32, strided<[1]>>
+// CHECK-NEXT: %[[REINT:.*]] = memref.reinterpret_cast %[[ALLOC]] to offset: [0], sizes: [4, 8], strides: [8, 1] : memref<32xf32, strided<[1]>> to memref<4x8xf32>
+
+// -----
+
+func.func @alloca() -> memref<4x8xf32> {
+  %0 = memref.alloca() : memref<4x8xf32>
+  return %0 : memref<4x8xf32>
+}
+
+// CHECK-LABEL: func.func @alloca() -> memref<4x8xf32>
+// CHECK: %[[ALLOC:.*]] = memref.alloca() : memref<32xf32, strided<[1]>>
+// CHECK: %[[REINT:.*]] = memref.reinterpret_cast %[[ALLOC]] to offset: [0], sizes: [4, 8], strides: [8, 1] : memref<32xf32, strided<[1]>> to memref<4x8xf32>
+
+// -----
+
+func.func @chained_alloc_load() -> vector<8xf32> {
+  %c3 = arith.constant 3 : index
+  %c6 = arith.constant 6 : index
+  %0 = memref.alloc() : memref<4x8xf32>
+  %value = vector.load %0[%c3, %c6] : memref<4x8xf32>, vector<8xf32>
+  return %value : vector<8xf32>
+}
+
+// CHECK-LABEL: func @chained_alloc_load
+// CHECK-SAME: () -> vector<8xf32>
+// CHECK-NEXT: %[[C30:.*]] = arith.constant 30 : index
+// CHECK-NEXT: %[[ALLOC:.*]] = memref.alloc() : memref<32xf32, strided<[1]>>
+// CHECK-NEXT: vector.load %[[ALLOC]][%[[C30]]] : memref<32xf32, strided<[1]>>, vector<8xf32>
+
+// -----
+
+func.func @load_scalar_from_memref_static_dim_col_major(%input: memref<4x8xf32, strided<[1, 4], offset: 100>>, %row: index, %col: index) -> f32 {
+  %value = memref.load %input[%col, %row] : memref<4x8xf32, strided<[1, 4], offset: 100>>
+  return %value : f32
+}
+
+// CHECK: #[[MAP:.*]] = affine_map<()[s0, s1] -> (s0 + s1 * 4)>
+// CHECK: func @load_scalar_from_memref_static_dim_col_major
+// CHECK-SAME: (%[[ARG0:.*]]: memref<4x8xf32, strided<[1, 4], offset: 100>>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index)
+// CHECK: %[[IDX:.*]] = affine.apply #[[MAP]]()[%[[ARG2]], %[[ARG1]]]
+// CHECK: %[[REINT:.*]] = memref.reinterpret_cast %[[ARG0]] to offset: [100], sizes: [32], strides: [1] : memref<4x8xf32, strided<[1, 4], offset: 100>> to memref<32xf32, strided<[1], offset: 100>>
+// CHECK: memref.load %[[REINT]][%[[IDX]]] : memref<32xf32, strided<[1], offset: 100>>


        

