[Mlir-commits] [mlir] 793ee2b - [mlir][gpu] Add DecomposeMemrefsPass

Ivan Butygin llvmlistbot at llvm.org
Thu Aug 10 13:33:20 PDT 2023


Author: Ivan Butygin
Date: 2023-08-10T22:28:05+02:00
New Revision: 793ee2bf08680e221018f5707aca6eab121d1a41

URL: https://github.com/llvm/llvm-project/commit/793ee2bf08680e221018f5707aca6eab121d1a41
DIFF: https://github.com/llvm/llvm-project/commit/793ee2bf08680e221018f5707aca6eab121d1a41.diff

LOG: [mlir][gpu] Add DecomposeMemrefsPass

Some GPU backends (e.g. SPIR-V) lower memrefs to bare pointers, so lowering fails for dynamically sized/strided memrefs.
This pass extracts sizes and strides via `memref.extract_strided_metadata` outside the `gpu.launch` body, performs the index/offset calculation explicitly, and then reconstructs the memrefs via `memref.reinterpret_cast`.

`memref.reinterpret_cast` is then lowered via https://reviews.llvm.org/D155011

Differential Revision: https://reviews.llvm.org/D155247
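
For illustration, here is a rough sketch of the rewrite for a store to a
dynamically shaped memref (simplified from the decompose-memrefs.mlir test added
below; SSA names are illustrative and the exact affine maps depend on
folding/canonicalization):

  gpu.launch blocks(%bx, %by, %bz) in (...) threads(%tx, %ty, %tz) in (...) {
    memref.store %val, %mem[%tx, %ty, %tz] : memref<?x?x?xf32>
    gpu.terminator
  }

becomes

  %base, %offset, %sizes:3, %strides:3 = memref.extract_strided_metadata %mem
      : memref<?x?x?xf32> -> memref<f32>, index, index, index, index, index, index, index
  gpu.launch blocks(%bx, %by, %bz) in (...) threads(%tx, %ty, %tz) in (...) {
    %idx = affine.apply
        affine_map<()[s0, s1, s2, s3, s4] -> (s0 * s1 + s2 * s3 + s4)>
        ()[%tx, %strides#0, %ty, %strides#1, %tz]
    %view = memref.reinterpret_cast %base to offset: [%idx], sizes: [], strides: []
        : memref<f32> to memref<f32>
    memref.store %val, %view[] : memref<f32>
    gpu.terminator
  }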

Added: 
    mlir/lib/Dialect/GPU/Transforms/DecomposeMemrefs.cpp
    mlir/test/Dialect/GPU/decompose-memrefs.mlir

Modified: 
    mlir/include/mlir/Dialect/GPU/Transforms/Passes.h
    mlir/include/mlir/Dialect/GPU/Transforms/Passes.td
    mlir/include/mlir/Dialect/Utils/IndexingUtils.h
    mlir/lib/Dialect/GPU/CMakeLists.txt
    mlir/lib/Dialect/Utils/IndexingUtils.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/GPU/Transforms/Passes.h b/mlir/include/mlir/Dialect/GPU/Transforms/Passes.h
index 1afbcb2128d490..970dfea4677d83 100644
--- a/mlir/include/mlir/Dialect/GPU/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/GPU/Transforms/Passes.h
@@ -150,6 +150,12 @@ std::unique_ptr<Pass> createGpuSerializeToHsacoPass(StringRef triple,
                                                     StringRef features,
                                                     int optLevel);
 
+/// Collect a set of patterns to decompose memref ops.
+void populateGpuDecomposeMemrefsPatterns(RewritePatternSet &patterns);
+
+/// Pass that decomposes memref ops inside the `gpu.launch` body.
+std::unique_ptr<Pass> createGpuDecomposeMemrefsPass();
+
 /// Generate the code for registering passes.
 #define GEN_PASS_REGISTRATION
 #include "mlir/Dialect/GPU/Transforms/Passes.h.inc"

diff  --git a/mlir/include/mlir/Dialect/GPU/Transforms/Passes.td b/mlir/include/mlir/Dialect/GPU/Transforms/Passes.td
index 7ee90b5d0f8437..7602f8bcc6a482 100644
--- a/mlir/include/mlir/Dialect/GPU/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/GPU/Transforms/Passes.td
@@ -37,4 +37,22 @@ def GpuMapParallelLoopsPass
   let dependentDialects = ["mlir::gpu::GPUDialect"];
 }
 
+def GpuDecomposeMemrefsPass : Pass<"gpu-decompose-memrefs"> {
+  let summary = "Decomposes memref index computation into explicit ops.";
+  let description = [{
+    This pass decomposes memref index computation into explicit computations on
+    sizes/strides, obtained from `memref.extract_strided_metadata`, which it
+    tries to place outside of the `gpu.launch` body. Memrefs are then
+    reconstructed using `memref.reinterpret_cast`.
+    This is needed because some targets (e.g. SPIR-V) lower memrefs to bare
+    pointers, so sizes/strides of dynamically-sized memrefs are not available
+    inside `gpu.launch`.
+  }];
+  let constructor = "mlir::createGpuDecomposeMemrefsPass()";
+  let dependentDialects = [
+    "mlir::gpu::GPUDialect", "mlir::memref::MemRefDialect",
+    "mlir::affine::AffineDialect"
+  ];
+}
+
 #endif // MLIR_DIALECT_GPU_PASSES

diff  --git a/mlir/include/mlir/Dialect/Utils/IndexingUtils.h b/mlir/include/mlir/Dialect/Utils/IndexingUtils.h
index 72becd8cc01c43..cb8419374c43e3 100644
--- a/mlir/include/mlir/Dialect/Utils/IndexingUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/IndexingUtils.h
@@ -229,6 +229,16 @@ computePermutationVector(int64_t permSize, ArrayRef<int64_t> positions,
 SmallVector<int64_t> getI64SubArray(ArrayAttr arrayAttr, unsigned dropFront = 0,
                                     unsigned dropBack = 0);
 
+/// Compute linear index from provided strides and indices, assuming strided
+/// layout.
+/// Returns AffineExpr and list of values to apply to it, e.g.:
+///
+/// auto &&[expr, values] = computeLinearIndex(...);
+/// offset = affine::makeComposedFoldedAffineApply(builder, loc, expr, values);
+std::pair<AffineExpr, SmallVector<OpFoldResult>>
+computeLinearIndex(OpFoldResult sourceOffset, ArrayRef<OpFoldResult> strides,
+                   ArrayRef<OpFoldResult> indices);
+
 } // namespace mlir
 
 #endif // MLIR_DIALECT_UTILS_INDEXINGUTILS_H
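
As an illustration of the new helper (a hypothetical rank-2 case, not taken from
this patch): for indices (%i0, %i1), strides (%stride0, %stride1), and a dynamic
source offset %offset, computeLinearIndex builds the expression
offset + i0 * stride0 + i1 * stride1, which makeComposedFoldedAffineApply then
materializes roughly as

  %idx = affine.apply
      affine_map<()[s0, s1, s2, s3, s4] -> (s0 + s1 * s2 + s3 * s4)>
      ()[%offset, %i0, %stride0, %i1, %stride1]

with constant operands folded directly into the map.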

diff  --git a/mlir/lib/Dialect/GPU/CMakeLists.txt b/mlir/lib/Dialect/GPU/CMakeLists.txt
index f3c518fd93066e..81d7bf96bbf4c9 100644
--- a/mlir/lib/Dialect/GPU/CMakeLists.txt
+++ b/mlir/lib/Dialect/GPU/CMakeLists.txt
@@ -47,14 +47,15 @@ add_mlir_dialect_library(MLIRGPUDialect
 add_mlir_dialect_library(MLIRGPUTransforms
   Transforms/AllReduceLowering.cpp
   Transforms/AsyncRegionRewriter.cpp
+  Transforms/DecomposeMemrefs.cpp
   Transforms/GlobalIdRewriter.cpp
   Transforms/KernelOutlining.cpp
   Transforms/MemoryPromotion.cpp
   Transforms/ParallelLoopMapper.cpp
-  Transforms/ShuffleRewriter.cpp
   Transforms/SerializeToBlob.cpp
   Transforms/SerializeToCubin.cpp
   Transforms/SerializeToHsaco.cpp
+  Transforms/ShuffleRewriter.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/GPU

diff  --git a/mlir/lib/Dialect/GPU/Transforms/DecomposeMemrefs.cpp b/mlir/lib/Dialect/GPU/Transforms/DecomposeMemrefs.cpp
new file mode 100644
index 00000000000000..4f06c81ccaca29
--- /dev/null
+++ b/mlir/lib/Dialect/GPU/Transforms/DecomposeMemrefs.cpp
@@ -0,0 +1,234 @@
+//===- DecomposeMemrefs.cpp - Decompose memrefs pass implementation -------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements the decompose memrefs pass.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/GPU/Transforms/Passes.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+namespace mlir {
+#define GEN_PASS_DEF_GPUDECOMPOSEMEMREFSPASS
+#include "mlir/Dialect/GPU/Transforms/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+
+static void setInsertionPointToStart(OpBuilder &builder, Value val) {
+  if (auto parentOp = val.getDefiningOp()) {
+    builder.setInsertionPointAfter(parentOp);
+  } else {
+    builder.setInsertionPointToStart(val.getParentBlock());
+  }
+}
+
+static bool isInsideLaunch(Operation *op) {
+  return op->getParentOfType<gpu::LaunchOp>();
+}
+
+static std::tuple<Value, OpFoldResult, SmallVector<OpFoldResult>>
+getFlatOffsetAndStrides(OpBuilder &rewriter, Location loc, Value source,
+                        ArrayRef<OpFoldResult> subOffsets,
+                        ArrayRef<OpFoldResult> subStrides = std::nullopt) {
+  auto sourceType = cast<MemRefType>(source.getType());
+  auto sourceRank = static_cast<unsigned>(sourceType.getRank());
+
+  memref::ExtractStridedMetadataOp newExtractStridedMetadata;
+  {
+    OpBuilder::InsertionGuard g(rewriter);
+    setInsertionPointToStart(rewriter, source);
+    newExtractStridedMetadata =
+        rewriter.create<memref::ExtractStridedMetadataOp>(loc, source);
+  }
+
+  auto &&[sourceStrides, sourceOffset] = getStridesAndOffset(sourceType);
+
+  auto getDim = [&](int64_t dim, Value dimVal) -> OpFoldResult {
+    return ShapedType::isDynamic(dim) ? getAsOpFoldResult(dimVal)
+                                      : rewriter.getIndexAttr(dim);
+  };
+
+  OpFoldResult origOffset =
+      getDim(sourceOffset, newExtractStridedMetadata.getOffset());
+  ValueRange sourceStridesVals = newExtractStridedMetadata.getStrides();
+
+  SmallVector<OpFoldResult> origStrides;
+  origStrides.reserve(sourceRank);
+
+  SmallVector<OpFoldResult> strides;
+  strides.reserve(sourceRank);
+
+  AffineExpr s0 = rewriter.getAffineSymbolExpr(0);
+  AffineExpr s1 = rewriter.getAffineSymbolExpr(1);
+  for (auto i : llvm::seq(0u, sourceRank)) {
+    OpFoldResult origStride = getDim(sourceStrides[i], sourceStridesVals[i]);
+
+    if (!subStrides.empty()) {
+      strides.push_back(affine::makeComposedFoldedAffineApply(
+          rewriter, loc, s0 * s1, {subStrides[i], origStride}));
+    }
+
+    origStrides.emplace_back(origStride);
+  }
+
+  auto &&[expr, values] =
+      computeLinearIndex(origOffset, origStrides, subOffsets);
+  OpFoldResult finalOffset =
+      affine::makeComposedFoldedAffineApply(rewriter, loc, expr, values);
+  return {newExtractStridedMetadata.getBaseBuffer(), finalOffset, strides};
+}
+
+static Value getFlatMemref(OpBuilder &rewriter, Location loc, Value source,
+                           ValueRange offsets) {
+  SmallVector<OpFoldResult> offsetsTemp = getAsOpFoldResult(offsets);
+  auto &&[base, offset, ignore] =
+      getFlatOffsetAndStrides(rewriter, loc, source, offsetsTemp);
+  auto retType = cast<MemRefType>(base.getType());
+  return rewriter.create<memref::ReinterpretCastOp>(loc, retType, base, offset,
+                                                    std::nullopt, std::nullopt);
+}
+
+static bool needFlatten(Value val) {
+  auto type = cast<MemRefType>(val.getType());
+  return type.getRank() != 0;
+}
+
+static bool checkLayout(Value val) {
+  auto type = cast<MemRefType>(val.getType());
+  return type.getLayout().isIdentity() ||
+         isa<StridedLayoutAttr>(type.getLayout());
+}
+
+namespace {
+struct FlattenLoad : public OpRewritePattern<memref::LoadOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(memref::LoadOp op,
+                                PatternRewriter &rewriter) const override {
+    if (!isInsideLaunch(op))
+      return rewriter.notifyMatchFailure(op, "not inside gpu.launch");
+
+    Value memref = op.getMemref();
+    if (!needFlatten(memref))
+      return rewriter.notifyMatchFailure(op, "nothing to do");
+
+    if (!checkLayout(memref))
+      return rewriter.notifyMatchFailure(op, "unsupported layout");
+
+    Location loc = op.getLoc();
+    Value flatMemref = getFlatMemref(rewriter, loc, memref, op.getIndices());
+    rewriter.replaceOpWithNewOp<memref::LoadOp>(op, flatMemref);
+    return success();
+  }
+};
+
+struct FlattenStore : public OpRewritePattern<memref::StoreOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(memref::StoreOp op,
+                                PatternRewriter &rewriter) const override {
+    if (!isInsideLaunch(op))
+      return rewriter.notifyMatchFailure(op, "not inside gpu.launch");
+
+    Value memref = op.getMemref();
+    if (!needFlatten(memref))
+      return rewriter.notifyMatchFailure(op, "nothing to do");
+
+    if (!checkLayout(memref))
+      return rewriter.notifyMatchFailure(op, "unsupported layout");
+
+    Location loc = op.getLoc();
+    Value flatMemref = getFlatMemref(rewriter, loc, memref, op.getIndices());
+    Value value = op.getValue();
+    rewriter.replaceOpWithNewOp<memref::StoreOp>(op, value, flatMemref);
+    return success();
+  }
+};
+
+struct FlattenSubview : public OpRewritePattern<memref::SubViewOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(memref::SubViewOp op,
+                                PatternRewriter &rewriter) const override {
+    if (!isInsideLaunch(op))
+      return rewriter.notifyMatchFailure(op, "not inside gpu.launch");
+
+    Value memref = op.getSource();
+    if (!needFlatten(memref))
+      return rewriter.notifyMatchFailure(op, "nothing to do");
+
+    if (!checkLayout(memref))
+      return rewriter.notifyMatchFailure(op, "unsupported layout");
+
+    Location loc = op.getLoc();
+    SmallVector<OpFoldResult> subOffsets = op.getMixedOffsets();
+    SmallVector<OpFoldResult> subSizes = op.getMixedSizes();
+    SmallVector<OpFoldResult> subStrides = op.getMixedStrides();
+    auto &&[base, finalOffset, strides] =
+        getFlatOffsetAndStrides(rewriter, loc, memref, subOffsets, subStrides);
+
+    auto srcType = cast<MemRefType>(memref.getType());
+    auto resultType = cast<MemRefType>(op.getType());
+    unsigned subRank = static_cast<unsigned>(resultType.getRank());
+
+    llvm::SmallBitVector droppedDims = op.getDroppedDims();
+
+    SmallVector<OpFoldResult> finalSizes;
+    finalSizes.reserve(subRank);
+
+    SmallVector<OpFoldResult> finalStrides;
+    finalStrides.reserve(subRank);
+
+    for (auto i : llvm::seq(0u, static_cast<unsigned>(srcType.getRank()))) {
+      if (droppedDims.test(i))
+        continue;
+
+      finalSizes.push_back(subSizes[i]);
+      finalStrides.push_back(strides[i]);
+    }
+
+    rewriter.replaceOpWithNewOp<memref::ReinterpretCastOp>(
+        op, resultType, base, finalOffset, finalSizes, finalStrides);
+    return success();
+  }
+};
+
+struct GpuDecomposeMemrefsPass
+    : public impl::GpuDecomposeMemrefsPassBase<GpuDecomposeMemrefsPass> {
+
+  void runOnOperation() override {
+    RewritePatternSet patterns(&getContext());
+
+    populateGpuDecomposeMemrefsPatterns(patterns);
+
+    if (failed(
+            applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
+      return signalPassFailure();
+  }
+};
+
+} // namespace
+
+void mlir::populateGpuDecomposeMemrefsPatterns(RewritePatternSet &patterns) {
+  patterns.insert<FlattenLoad, FlattenStore, FlattenSubview>(
+      patterns.getContext());
+}
+
+std::unique_ptr<Pass> mlir::createGpuDecomposeMemrefsPass() {
+  return std::make_unique<GpuDecomposeMemrefsPass>();
+}
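
As a sketch of what the FlattenSubview pattern produces (simplified from the
decompose_subview test below; SSA names are illustrative), a subview taken
inside `gpu.launch`

  %view = memref.subview %mem[%tx, %ty, %tz] [%c2, %c2, %c2] [%c1, %c1, %c1]
      : memref<?x?x?xf32> to memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>>

is rewritten to an affine offset computation plus a reinterpret_cast of the base
buffer extracted outside the launch:

  %idx = affine.apply
      affine_map<()[s0, s1, s2, s3, s4] -> (s0 * s1 + s2 * s3 + s4)>
      ()[%tx, %strides#0, %ty, %strides#1, %tz]
  %view = memref.reinterpret_cast %base to offset: [%idx],
      sizes: [%c2, %c2, %c2], strides: [%strides#0, %strides#1, 1]
      : memref<f32> to memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>>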

diff  --git a/mlir/lib/Dialect/Utils/IndexingUtils.cpp b/mlir/lib/Dialect/Utils/IndexingUtils.cpp
index 2a774b599a8b68..8afa51a2034557 100644
--- a/mlir/lib/Dialect/Utils/IndexingUtils.cpp
+++ b/mlir/lib/Dialect/Utils/IndexingUtils.cpp
@@ -261,3 +261,44 @@ SmallVector<int64_t> mlir::getI64SubArray(ArrayAttr arrayAttr,
     res.push_back((*it).getValue().getSExtValue());
   return res;
 }
+
+// TODO: do we have any common utility for this?
+static MLIRContext *getContext(OpFoldResult val) {
+  assert(val && "Invalid value");
+  if (auto attr = dyn_cast<Attribute>(val)) {
+    return attr.getContext();
+  } else {
+    return cast<Value>(val).getContext();
+  }
+}
+
+std::pair<AffineExpr, SmallVector<OpFoldResult>>
+mlir::computeLinearIndex(OpFoldResult sourceOffset,
+                         ArrayRef<OpFoldResult> strides,
+                         ArrayRef<OpFoldResult> indices) {
+  assert(strides.size() == indices.size());
+  auto sourceRank = static_cast<unsigned>(strides.size());
+
+  // Hold the affine symbols and values for the computation of the offset.
+  SmallVector<OpFoldResult> values(2 * sourceRank + 1);
+  SmallVector<AffineExpr> symbols(2 * sourceRank + 1);
+
+  bindSymbolsList(getContext(sourceOffset), MutableArrayRef{symbols});
+  AffineExpr expr = symbols.front();
+  values[0] = sourceOffset;
+
+  for (unsigned i = 0; i < sourceRank; ++i) {
+    // Compute the stride.
+    OpFoldResult origStride = strides[i];
+
+    // Build up the computation of the offset.
+    unsigned baseIdxForDim = 1 + 2 * i;
+    unsigned subOffsetForDim = baseIdxForDim;
+    unsigned origStrideForDim = baseIdxForDim + 1;
+    expr = expr + symbols[subOffsetForDim] * symbols[origStrideForDim];
+    values[subOffsetForDim] = indices[i];
+    values[origStrideForDim] = origStride;
+  }
+
+  return {expr, values};
+}

diff  --git a/mlir/test/Dialect/GPU/decompose-memrefs.mlir b/mlir/test/Dialect/GPU/decompose-memrefs.mlir
new file mode 100644
index 00000000000000..d714010d0f254b
--- /dev/null
+++ b/mlir/test/Dialect/GPU/decompose-memrefs.mlir
@@ -0,0 +1,137 @@
+// RUN: mlir-opt -gpu-decompose-memrefs -allow-unregistered-dialect -split-input-file %s | FileCheck %s
+
+//       CHECK: #[[MAP:.*]] = affine_map<()[s0, s1, s2, s3, s4] -> (s0 * s1 + s2 * s3 + s4)>
+//       CHECK: @decompose_store
+//  CHECK-SAME: (%[[VAL:.*]]: f32, %[[MEM:.*]]: memref<?x?x?xf32>)
+//       CHECK:  %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:3, %[[STRIDES:.*]]:3 = memref.extract_strided_metadata %[[MEM]]
+//       CHECK:  gpu.launch
+//  CHECK-SAME:  threads(%[[TX:.*]], %[[TY:.*]], %[[TZ:.*]]) in
+//       CHECK:  %[[IDX:.*]] = affine.apply #[[MAP]]()[%[[TX]], %[[STRIDES]]#0, %[[TY]], %[[STRIDES]]#1, %[[TZ]]]
+//       CHECK:  %[[PTR:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [%[[IDX]]], sizes: [], strides: [] : memref<f32> to memref<f32>
+//       CHECK:  memref.store %[[VAL]], %[[PTR]][] : memref<f32>
+func.func @decompose_store(%arg0 : f32, %arg1 : memref<?x?x?xf32>) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  %block_dim0 = memref.dim %arg1, %c0 : memref<?x?x?xf32>
+  %block_dim1 = memref.dim %arg1, %c1 : memref<?x?x?xf32>
+  %block_dim2 = memref.dim %arg1, %c2 : memref<?x?x?xf32>
+  gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %c1, %grid_y = %c1, %grid_z = %c1)
+             threads(%tx, %ty, %tz) in (%block_x = %block_dim0, %block_y = %block_dim1, %block_z = %block_dim2) {
+    memref.store %arg0, %arg1[%tx, %ty, %tz] : memref<?x?x?xf32>
+    gpu.terminator
+  }
+  return
+}
+
+// -----
+
+//       CHECK: #[[MAP:.*]] = affine_map<()[s0, s1, s2, s3, s4, s5, s6] -> (s0 + s1 * s2 + s3 * s4 + s5 * s6)>
+//       CHECK: @decompose_store_strided
+//  CHECK-SAME: (%[[VAL:.*]]: f32, %[[MEM:.*]]: memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>>)
+//       CHECK:  %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:3, %[[STRIDES:.*]]:3 = memref.extract_strided_metadata %[[MEM]]
+//       CHECK:  gpu.launch
+//  CHECK-SAME:  threads(%[[TX:.*]], %[[TY:.*]], %[[TZ:.*]]) in
+//       CHECK:  %[[IDX:.*]] = affine.apply #[[MAP]]()[%[[OFFSET]], %[[TX]], %[[STRIDES]]#0, %[[TY]], %[[STRIDES]]#1, %[[TZ]], %[[STRIDES]]#2]
+//       CHECK:  %[[PTR:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [%[[IDX]]], sizes: [], strides: [] : memref<f32> to memref<f32>
+//       CHECK:  memref.store %[[VAL]], %[[PTR]][] : memref<f32>
+func.func @decompose_store_strided(%arg0 : f32, %arg1 : memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>>) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  %block_dim0 = memref.dim %arg1, %c0 : memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>>
+  %block_dim1 = memref.dim %arg1, %c1 : memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>>
+  %block_dim2 = memref.dim %arg1, %c2 : memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>>
+  gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %c1, %grid_y = %c1, %grid_z = %c1)
+             threads(%tx, %ty, %tz) in (%block_x = %block_dim0, %block_y = %block_dim1, %block_z = %block_dim2) {
+    memref.store %arg0, %arg1[%tx, %ty, %tz] : memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>>
+    gpu.terminator
+  }
+  return
+}
+
+// -----
+
+//       CHECK: #[[MAP:.*]] = affine_map<()[s0, s1, s2, s3, s4] -> (s0 * s1 + s2 * s3 + s4)>
+//       CHECK: @decompose_load
+//  CHECK-SAME: (%[[MEM:.*]]: memref<?x?x?xf32>)
+//       CHECK:  %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:3, %[[STRIDES:.*]]:3 = memref.extract_strided_metadata %[[MEM]]
+//       CHECK:  gpu.launch
+//  CHECK-SAME:  threads(%[[TX:.*]], %[[TY:.*]], %[[TZ:.*]]) in
+//       CHECK:  %[[IDX:.*]] = affine.apply #[[MAP]]()[%[[TX]], %[[STRIDES]]#0, %[[TY]], %[[STRIDES]]#1, %[[TZ]]]
+//       CHECK:  %[[PTR:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [%[[IDX]]], sizes: [], strides: [] : memref<f32> to memref<f32>
+//       CHECK:  %[[RES:.*]] = memref.load %[[PTR]][] : memref<f32>
+//       CHECK:  "test.test"(%[[RES]]) : (f32) -> ()
+func.func @decompose_load(%arg0 : memref<?x?x?xf32>) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  %block_dim0 = memref.dim %arg0, %c0 : memref<?x?x?xf32>
+  %block_dim1 = memref.dim %arg0, %c1 : memref<?x?x?xf32>
+  %block_dim2 = memref.dim %arg0, %c2 : memref<?x?x?xf32>
+  gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %c1, %grid_y = %c1, %grid_z = %c1)
+             threads(%tx, %ty, %tz) in (%block_x = %block_dim0, %block_y = %block_dim1, %block_z = %block_dim2) {
+    %res = memref.load %arg0[%tx, %ty, %tz] : memref<?x?x?xf32>
+    "test.test"(%res) : (f32) -> ()
+    gpu.terminator
+  }
+  return
+}
+
+// -----
+
+//       CHECK: #[[MAP:.*]] = affine_map<()[s0, s1, s2, s3, s4] -> (s0 * s1 + s2 * s3 + s4)>
+//       CHECK: @decompose_subview
+//  CHECK-SAME: (%[[MEM:.*]]: memref<?x?x?xf32>)
+//       CHECK:  %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:3, %[[STRIDES:.*]]:3 = memref.extract_strided_metadata %[[MEM]]
+//       CHECK:  gpu.launch
+//  CHECK-SAME:  threads(%[[TX:.*]], %[[TY:.*]], %[[TZ:.*]]) in
+//       CHECK:  %[[IDX:.*]] = affine.apply #[[MAP]]()[%[[TX]], %[[STRIDES]]#0, %[[TY]], %[[STRIDES]]#1, %[[TZ]]]
+//       CHECK:  %[[PTR:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [%[[IDX]]], sizes: [%{{.*}}, %{{.*}}, %{{.*}}], strides: [%[[STRIDES]]#0, %[[STRIDES]]#1, 1]
+//       CHECK:  "test.test"(%[[PTR]]) : (memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>>) -> ()
+func.func @decompose_subview(%arg0 : memref<?x?x?xf32>) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  %block_dim0 = memref.dim %arg0, %c0 : memref<?x?x?xf32>
+  %block_dim1 = memref.dim %arg0, %c1 : memref<?x?x?xf32>
+  %block_dim2 = memref.dim %arg0, %c2 : memref<?x?x?xf32>
+  gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %c1, %grid_y = %c1, %grid_z = %c1)
+             threads(%tx, %ty, %tz) in (%block_x = %block_dim0, %block_y = %block_dim1, %block_z = %block_dim2) {
+    %res = memref.subview %arg0[%tx, %ty, %tz] [%c2, %c2, %c2] [%c1, %c1, %c1] : memref<?x?x?xf32> to memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>>
+    "test.test"(%res) : (memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>>) -> ()
+    gpu.terminator
+  }
+  return
+}
+
+// -----
+
+//       CHECK: #[[MAP:.*]] = affine_map<()[s0] -> (s0 * 2)>
+//       CHECK: #[[MAP1:.*]] = affine_map<()[s0] -> (s0 * 3)>
+//       CHECK: #[[MAP2:.*]] = affine_map<()[s0, s1, s2, s3, s4] -> (s0 * s1 + s2 * s3 + s4)>
+//       CHECK: @decompose_subview_strided
+//  CHECK-SAME: (%[[MEM:.*]]: memref<?x?x?xf32>)
+//       CHECK:  %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:3, %[[STRIDES:.*]]:3 = memref.extract_strided_metadata %[[MEM]]
+//       CHECK:  gpu.launch
+//  CHECK-SAME:  threads(%[[TX:.*]], %[[TY:.*]], %[[TZ:.*]]) in
+//       CHECK:  %[[IDX:.*]] = affine.apply #[[MAP]]()[%[[STRIDES]]#0]
+//       CHECK:  %[[IDX1:.*]] = affine.apply #[[MAP1]]()[%[[STRIDES]]#1]
+//       CHECK:  %[[IDX2:.*]] = affine.apply #[[MAP2]]()[%[[TX]], %[[STRIDES]]#0, %[[TY]], %[[STRIDES]]#1, %[[TZ]]]
+//       CHECK:  %[[PTR:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [%[[IDX2]]], sizes: [%{{.*}}, %{{.*}}, %{{.*}}], strides: [%[[IDX]], %[[IDX1]], 4]
+//       CHECK:  "test.test"(%[[PTR]]) : (memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>>) -> ()
+func.func @decompose_subview_strided(%arg0 : memref<?x?x?xf32>) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  %block_dim0 = memref.dim %arg0, %c0 : memref<?x?x?xf32>
+  %block_dim1 = memref.dim %arg0, %c1 : memref<?x?x?xf32>
+  %block_dim2 = memref.dim %arg0, %c2 : memref<?x?x?xf32>
+  gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %c1, %grid_y = %c1, %grid_z = %c1)
+             threads(%tx, %ty, %tz) in (%block_x = %block_dim0, %block_y = %block_dim1, %block_z = %block_dim2) {
+    %res = memref.subview %arg0[%tx, %ty, %tz] [%c2, %c2, %c2] [2, 3, 4] : memref<?x?x?xf32> to memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>>
+    "test.test"(%res) : (memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>>) -> ()
+    gpu.terminator
+  }
+  return
+}


        

