[Mlir-commits] [mlir] 2b5b2bf - [mlir][gpu] Add DecomposeMemrefsPass
Ivan Butygin
llvmlistbot at llvm.org
Wed Aug 9 17:28:26 PDT 2023
Author: Ivan Butygin
Date: 2023-08-10T02:28:03+02:00
New Revision: 2b5b2bfef102b1021d91f2b9485e2443bdea9df5
URL: https://github.com/llvm/llvm-project/commit/2b5b2bfef102b1021d91f2b9485e2443bdea9df5
DIFF: https://github.com/llvm/llvm-project/commit/2b5b2bfef102b1021d91f2b9485e2443bdea9df5.diff
LOG: [mlir][gpu] Add DecomposeMemrefsPass
Some GPU backends (SPIR-V) lower memrefs to bare pointers, so lowering fails for dynamically sized/strided memrefs.
This pass extracts sizes and strides via `memref.extract_strided_metadata` outside the `gpu.launch` body, does the index/offset calculation explicitly, and then reconstructs memrefs via `memref.reinterpret_cast`.
`memref.reinterpret_cast` is then lowered via https://reviews.llvm.org/D155011
Differential Revision: https://reviews.llvm.org/D155247
Added:
mlir/lib/Dialect/GPU/Transforms/DecomposeMemrefs.cpp
mlir/test/Dialect/GPU/decompose-memrefs.mlir
Modified:
mlir/include/mlir/Dialect/GPU/Transforms/Passes.h
mlir/include/mlir/Dialect/GPU/Transforms/Passes.td
mlir/include/mlir/Dialect/Utils/IndexingUtils.h
mlir/lib/Dialect/GPU/CMakeLists.txt
mlir/lib/Dialect/Utils/IndexingUtils.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/GPU/Transforms/Passes.h b/mlir/include/mlir/Dialect/GPU/Transforms/Passes.h
index 1afbcb2128d490..970dfea4677d83 100644
--- a/mlir/include/mlir/Dialect/GPU/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/GPU/Transforms/Passes.h
@@ -150,6 +150,12 @@ std::unique_ptr<Pass> createGpuSerializeToHsacoPass(StringRef triple,
StringRef features,
int optLevel);
+/// Collect a set of patterns that decompose memref ops (`memref.load`,
+/// `memref.store`, `memref.subview`) nested inside a `gpu.launch` body into
+/// explicit offset/stride computations on the underlying base buffer, which is
+/// then re-viewed with `memref.reinterpret_cast`.
+void populateGpuDecomposeMemrefsPatterns(RewritePatternSet &patterns);
+
+/// Create a pass that greedily applies the patterns from
+/// `populateGpuDecomposeMemrefsPatterns` to decompose memref ops inside
+/// `gpu.launch` bodies.
+std::unique_ptr<Pass> createGpuDecomposeMemrefsPass();
+
/// Generate the code for registering passes.
#define GEN_PASS_REGISTRATION
#include "mlir/Dialect/GPU/Transforms/Passes.h.inc"
diff --git a/mlir/include/mlir/Dialect/GPU/Transforms/Passes.td b/mlir/include/mlir/Dialect/GPU/Transforms/Passes.td
index 7ee90b5d0f8437..7602f8bcc6a482 100644
--- a/mlir/include/mlir/Dialect/GPU/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/GPU/Transforms/Passes.td
@@ -37,4 +37,22 @@ def GpuMapParallelLoopsPass
let dependentDialects = ["mlir::gpu::GPUDialect"];
}
+def GpuDecomposeMemrefsPass : Pass<"gpu-decompose-memrefs"> {
+  let summary = "Decomposes memref index computation into explicit ops.";
+  let description = [{
+    This pass decomposes memref index computation into explicit computations on
+    sizes/strides, obtained from `memref.extract_strided_metadata`, which it
+    tries to place outside of the `gpu.launch` body. Memrefs are then
+    reconstructed using `memref.reinterpret_cast`.
+    This is needed because some targets (SPIR-V) lower memrefs to bare pointers,
+    so sizes/strides of dynamically-sized memrefs are not available inside
+    `gpu.launch`.
+  }];
+  let constructor = "mlir::createGpuDecomposeMemrefsPass()";
+  let dependentDialects = [
+    "mlir::gpu::GPUDialect", "mlir::memref::MemRefDialect",
+    "mlir::affine::AffineDialect"
+  ];
+}
+
#endif // MLIR_DIALECT_GPU_PASSES
diff --git a/mlir/include/mlir/Dialect/Utils/IndexingUtils.h b/mlir/include/mlir/Dialect/Utils/IndexingUtils.h
index 72becd8cc01c43..56d028a2576b52 100644
--- a/mlir/include/mlir/Dialect/Utils/IndexingUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/IndexingUtils.h
@@ -229,6 +229,12 @@ computePermutationVector(int64_t permSize, ArrayRef<int64_t> positions,
SmallVector<int64_t> getI64SubArray(ArrayAttr arrayAttr, unsigned dropFront = 0,
unsigned dropBack = 0);
+/// Compute the linear index
+/// `sourceOffset + sum_i(indices[i] * strides[i])`, assuming a strided layout.
+/// `strides` and `indices` must have the same length (one entry per
+/// dimension). The computation is materialized with
+/// `affine::makeComposedFoldedAffineApply` at `builder`'s current insertion
+/// point, so the result may fold to an attribute instead of a Value.
+OpFoldResult computeLinearIndex(OpBuilder &builder, Location loc,
+ OpFoldResult sourceOffset,
+ ArrayRef<OpFoldResult> strides,
+ ArrayRef<OpFoldResult> indices);
} // namespace mlir
#endif // MLIR_DIALECT_UTILS_INDEXINGUTILS_H
diff --git a/mlir/lib/Dialect/GPU/CMakeLists.txt b/mlir/lib/Dialect/GPU/CMakeLists.txt
index f3c518fd93066e..81d7bf96bbf4c9 100644
--- a/mlir/lib/Dialect/GPU/CMakeLists.txt
+++ b/mlir/lib/Dialect/GPU/CMakeLists.txt
@@ -47,14 +47,15 @@ add_mlir_dialect_library(MLIRGPUDialect
add_mlir_dialect_library(MLIRGPUTransforms
Transforms/AllReduceLowering.cpp
Transforms/AsyncRegionRewriter.cpp
+ Transforms/DecomposeMemrefs.cpp
Transforms/GlobalIdRewriter.cpp
Transforms/KernelOutlining.cpp
Transforms/MemoryPromotion.cpp
Transforms/ParallelLoopMapper.cpp
- Transforms/ShuffleRewriter.cpp
Transforms/SerializeToBlob.cpp
Transforms/SerializeToCubin.cpp
Transforms/SerializeToHsaco.cpp
+ Transforms/ShuffleRewriter.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/GPU
diff --git a/mlir/lib/Dialect/GPU/Transforms/DecomposeMemrefs.cpp b/mlir/lib/Dialect/GPU/Transforms/DecomposeMemrefs.cpp
new file mode 100644
index 00000000000000..1e255635edb29d
--- /dev/null
+++ b/mlir/lib/Dialect/GPU/Transforms/DecomposeMemrefs.cpp
@@ -0,0 +1,232 @@
+//===- DecomposeMemrefs.cpp - Decompose memrefs pass implementation -------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements decompose memrefs pass.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/GPU/Transforms/Passes.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+namespace mlir {
+#define GEN_PASS_DEF_GPUDECOMPOSEMEMREFSPASS
+#include "mlir/Dialect/GPU/Transforms/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+
+/// Move `builder` to the earliest point where `val` is available: directly
+/// after the operation that defines it, or at the top of the owning block when
+/// `val` is a block argument.
+static void setInsertionPointToStart(OpBuilder &builder, Value val) {
+  Operation *definingOp = val.getDefiningOp();
+  if (!definingOp) {
+    // Block arguments have no defining op; start of their block is earliest.
+    builder.setInsertionPointToStart(val.getParentBlock());
+    return;
+  }
+  builder.setInsertionPointAfter(definingOp);
+}
+
+/// Returns true when some ancestor of `op` is a `gpu.launch` operation.
+static bool isInsideLaunch(Operation *op) {
+  gpu::LaunchOp enclosingLaunch = op->getParentOfType<gpu::LaunchOp>();
+  return static_cast<bool>(enclosingLaunch);
+}
+
+/// Split `source` into its base buffer and a flattened offset, optionally
+/// composing subview strides with the source strides.
+///
+/// A `memref.extract_strided_metadata` op is created right after the op that
+/// defines `source` (or at the start of its block when `source` is a block
+/// argument), so when `source` is defined outside `gpu.launch` the metadata is
+/// extracted outside the launch body too.
+///
+/// Returns a tuple of:
+/// - the base buffer produced by the extract op;
+/// - `origOffset + sum_i(subOffsets[i] * origStrides[i])`;
+/// - if `subStrides` is non-empty, `subStrides[i] * origStrides[i]` for every
+///   dimension, otherwise an empty vector.
+/// The offset and composed strides are created at the builder's insertion
+/// point on entry. Strides/offset that are static in the memref type are used
+/// as index attributes; dynamic ones come from the extract op's results.
+/// NOTE(review): relies on `getStridesAndOffset` succeeding for `source`'s
+/// layout; the patterns in this file check `checkLayout` before reaching here.
+static std::tuple<Value, OpFoldResult, SmallVector<OpFoldResult>>
+getFlatOffsetAndStrides(OpBuilder &rewriter, Location loc, Value source,
+ ArrayRef<OpFoldResult> subOffsets,
+ ArrayRef<OpFoldResult> subStrides = std::nullopt) {
+ auto sourceType = cast<MemRefType>(source.getType());
+ auto sourceRank = static_cast<unsigned>(sourceType.getRank());
+
+ memref::ExtractStridedMetadataOp newExtractStridedMetadata;
+ {
+ // Temporarily move next to the definition of `source`; the guard restores
+ // the caller's insertion point when this scope ends.
+ OpBuilder::InsertionGuard g(rewriter);
+ setInsertionPointToStart(rewriter, source);
+ newExtractStridedMetadata =
+ rewriter.create<memref::ExtractStridedMetadataOp>(loc, source);
+ }
+
+ // Strides/offset as encoded in the type (dynamic entries are marked dynamic).
+ auto &&[sourceStrides, sourceOffset] = getStridesAndOffset(sourceType);
+
+ // Prefer the static value from the type; fall back to the runtime value
+ // returned by the extract op when the type entry is dynamic.
+ auto getDim = [&](int64_t dim, Value dimVal) -> OpFoldResult {
+ return ShapedType::isDynamic(dim) ? getAsOpFoldResult(dimVal)
+ : rewriter.getIndexAttr(dim);
+ };
+
+ OpFoldResult origOffset =
+ getDim(sourceOffset, newExtractStridedMetadata.getOffset());
+ ValueRange sourceStridesVals = newExtractStridedMetadata.getStrides();
+
+ SmallVector<OpFoldResult> origStrides;
+ origStrides.reserve(sourceRank);
+
+ SmallVector<OpFoldResult> strides;
+ strides.reserve(sourceRank);
+
+ AffineExpr s0 = rewriter.getAffineSymbolExpr(0);
+ AffineExpr s1 = rewriter.getAffineSymbolExpr(1);
+ for (auto i : llvm::seq(0u, sourceRank)) {
+ OpFoldResult origStride = getDim(sourceStrides[i], sourceStridesVals[i]);
+
+ // Only subviews pass `subStrides`; compose them with the source stride.
+ if (!subStrides.empty()) {
+ strides.push_back(affine::makeComposedFoldedAffineApply(
+ rewriter, loc, s0 * s1, {subStrides[i], origStride}));
+ }
+
+ origStrides.emplace_back(origStride);
+ }
+
+ // Linearize the access offsets against the original strides.
+ OpFoldResult finalOffset =
+ computeLinearIndex(rewriter, loc, origOffset, origStrides, subOffsets);
+ return {newExtractStridedMetadata.getBaseBuffer(), finalOffset, strides};
+}
+
+/// Build a rank-0 view of the element of `source` addressed by `offsets`: the
+/// base buffer of `source` reinterpreted at the linearized element offset.
+/// NOTE(review): the result type is taken verbatim from the base buffer (a
+/// rank-0, identity-layout memref, i.e. static offset 0), while the
+/// reinterpret_cast carries a possibly non-zero offset. If the offset folds to
+/// a non-zero constant this likely fails `memref.reinterpret_cast`
+/// verification, and for dynamic offsets the type under-reports the offset.
+/// Consider a `strided<[], offset: ...>` result type; the expected IR in
+/// decompose-memrefs.mlir would have to be updated together — TODO confirm.
+static Value getFlatMemref(OpBuilder &rewriter, Location loc, Value source,
+ ValueRange offsets) {
+ SmallVector<OpFoldResult> offsetsTemp = getAsOpFoldResult(offsets);
+ // No sub-strides are passed, so the third tuple element is always empty.
+ auto &&[base, offset, ignore] =
+ getFlatOffsetAndStrides(rewriter, loc, source, offsetsTemp);
+ auto retType = cast<MemRefType>(base.getType());
+ return rewriter.create<memref::ReinterpretCastOp>(loc, retType, base, offset,
+ std::nullopt, std::nullopt);
+}
+
+/// A rank-0 memref already addresses a single element, so only memrefs with
+/// at least one dimension need to be flattened.
+static bool needFlatten(Value val) {
+  int64_t rank = cast<MemRefType>(val.getType()).getRank();
+  return rank != 0;
+}
+
+/// Only identity layouts and explicit `strided<...>` layouts are accepted,
+/// since the decomposition reads strides and an offset from the memref.
+static bool checkLayout(Value val) {
+  MemRefLayoutAttrInterface layout =
+      cast<MemRefType>(val.getType()).getLayout();
+  if (layout.isIdentity())
+    return true;
+  return isa<StridedLayoutAttr>(layout);
+}
+
+namespace {
+
+/// Preconditions shared by every pattern below: `op` must be nested in a
+/// `gpu.launch`, `memref` must have rank > 0, and its layout must be identity
+/// or strided. Reports a match failure describing the first failed check.
+LogicalResult checkPreconditions(Operation *op, Value memref,
+                                 PatternRewriter &rewriter) {
+  if (!isInsideLaunch(op))
+    return rewriter.notifyMatchFailure(op, "not inside gpu.launch");
+  if (!needFlatten(memref))
+    return rewriter.notifyMatchFailure(op, "nothing to do");
+  if (!checkLayout(memref))
+    return rewriter.notifyMatchFailure(op, "unsupported layout");
+  return success();
+}
+
+/// Rewrites `memref.load %m[%i, ...]` into a load from a rank-0
+/// `memref.reinterpret_cast` of the base buffer at the linearized offset.
+struct FlattenLoad : public OpRewritePattern<memref::LoadOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(memref::LoadOp op,
+                                PatternRewriter &rewriter) const override {
+    Value source = op.getMemref();
+    if (failed(checkPreconditions(op, source, rewriter)))
+      return failure();
+
+    Value elementView =
+        getFlatMemref(rewriter, op.getLoc(), source, op.getIndices());
+    rewriter.replaceOpWithNewOp<memref::LoadOp>(op, elementView);
+    return success();
+  }
+};
+
+/// Rewrites `memref.store %v, %m[%i, ...]` into a store to a rank-0
+/// `memref.reinterpret_cast` of the base buffer at the linearized offset.
+struct FlattenStore : public OpRewritePattern<memref::StoreOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(memref::StoreOp op,
+                                PatternRewriter &rewriter) const override {
+    Value source = op.getMemref();
+    if (failed(checkPreconditions(op, source, rewriter)))
+      return failure();
+
+    Value elementView =
+        getFlatMemref(rewriter, op.getLoc(), source, op.getIndices());
+    rewriter.replaceOpWithNewOp<memref::StoreOp>(op, op.getValue(),
+                                                 elementView);
+    return success();
+  }
+};
+
+/// Replaces `memref.subview` with a `memref.reinterpret_cast` of the base
+/// buffer using an explicitly computed offset, the subview sizes, and the
+/// product of subview and source strides.
+struct FlattenSubview : public OpRewritePattern<memref::SubViewOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(memref::SubViewOp op,
+                                PatternRewriter &rewriter) const override {
+    Value source = op.getSource();
+    if (failed(checkPreconditions(op, source, rewriter)))
+      return failure();
+
+    SmallVector<OpFoldResult> mixedOffsets = op.getMixedOffsets();
+    SmallVector<OpFoldResult> mixedSizes = op.getMixedSizes();
+    SmallVector<OpFoldResult> mixedStrides = op.getMixedStrides();
+    auto &&[base, linearOffset, combinedStrides] = getFlatOffsetAndStrides(
+        rewriter, op.getLoc(), source, mixedOffsets, mixedStrides);
+
+    auto resultType = cast<MemRefType>(op.getType());
+    llvm::SmallBitVector droppedDims = op.getDroppedDims();
+
+    SmallVector<OpFoldResult> keptSizes;
+    keptSizes.reserve(resultType.getRank());
+    SmallVector<OpFoldResult> keptStrides;
+    keptStrides.reserve(resultType.getRank());
+
+    int64_t sourceRank = cast<MemRefType>(source.getType()).getRank();
+    for (int64_t dim = 0; dim < sourceRank; ++dim) {
+      // Dimensions dropped by a rank-reducing subview are absent from the
+      // result type, so they contribute neither a size nor a stride.
+      if (droppedDims.test(dim))
+        continue;
+      keptSizes.push_back(mixedSizes[dim]);
+      keptStrides.push_back(combinedStrides[dim]);
+    }
+
+    rewriter.replaceOpWithNewOp<memref::ReinterpretCastOp>(
+        op, resultType, base, linearOffset, keptSizes, keptStrides);
+    return success();
+  }
+};
+
+/// Greedily applies the memref decomposition patterns to the whole operation.
+struct GpuDecomposeMemrefsPass
+    : public impl::GpuDecomposeMemrefsPassBase<GpuDecomposeMemrefsPass> {
+
+  void runOnOperation() override {
+    RewritePatternSet patterns(&getContext());
+    populateGpuDecomposeMemrefsPatterns(patterns);
+
+    LogicalResult status =
+        applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
+    if (failed(status))
+      signalPassFailure();
+  }
+};
+
+} // namespace
+
+/// Registers the load, store and subview flattening patterns of this file.
+void mlir::populateGpuDecomposeMemrefsPatterns(RewritePatternSet &patterns) {
+  MLIRContext *ctx = patterns.getContext();
+  patterns.add<FlattenLoad, FlattenStore, FlattenSubview>(ctx);
+}
+
+/// Factory named by the pass definition's `constructor` field in Passes.td.
+std::unique_ptr<Pass> mlir::createGpuDecomposeMemrefsPass() {
+  std::unique_ptr<Pass> pass = std::make_unique<GpuDecomposeMemrefsPass>();
+  return pass;
+}
diff --git a/mlir/lib/Dialect/Utils/IndexingUtils.cpp b/mlir/lib/Dialect/Utils/IndexingUtils.cpp
index 2a774b599a8b68..a344b01a958946 100644
--- a/mlir/lib/Dialect/Utils/IndexingUtils.cpp
+++ b/mlir/lib/Dialect/Utils/IndexingUtils.cpp
@@ -8,6 +8,7 @@
#include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
@@ -261,3 +262,34 @@ SmallVector<int64_t> mlir::getI64SubArray(ArrayAttr arrayAttr,
res.push_back((*it).getValue().getSExtValue());
return res;
}
+
+/// Materialize `sourceOffset + sum_d(indices[d] * strides[d])` as a composed,
+/// folded affine apply at the builder's insertion point.
+/// NOTE(review): this makes DialectUtils call into the Affine dialect (see the
+/// AffineOps.h include); confirm it introduces no library dependency cycle.
+OpFoldResult mlir::computeLinearIndex(OpBuilder &builder, Location loc,
+                                      OpFoldResult sourceOffset,
+                                      ArrayRef<OpFoldResult> strides,
+                                      ArrayRef<OpFoldResult> indices) {
+  assert(strides.size() == indices.size());
+  unsigned rank = static_cast<unsigned>(strides.size());
+  unsigned numSymbols = 2 * rank + 1;
+
+  // Symbol 0 is the base offset; dimension `d` binds its (index, stride) pair
+  // to symbols 2*d+1 and 2*d+2.
+  SmallVector<AffineExpr> syms(numSymbols);
+  bindSymbolsList(builder.getContext(), MutableArrayRef<AffineExpr>(syms));
+  SmallVector<OpFoldResult> operands(numSymbols);
+  operands[0] = sourceOffset;
+
+  AffineExpr linear = syms[0];
+  for (unsigned d = 0; d < rank; ++d) {
+    unsigned indexSym = 2 * d + 1;
+    unsigned strideSym = indexSym + 1;
+    operands[indexSym] = indices[d];
+    operands[strideSym] = strides[d];
+    linear = linear + syms[indexSym] * syms[strideSym];
+  }
+
+  return affine::makeComposedFoldedAffineApply(builder, loc, linear, operands);
+}
diff --git a/mlir/test/Dialect/GPU/decompose-memrefs.mlir b/mlir/test/Dialect/GPU/decompose-memrefs.mlir
new file mode 100644
index 00000000000000..d714010d0f254b
--- /dev/null
+++ b/mlir/test/Dialect/GPU/decompose-memrefs.mlir
@@ -0,0 +1,137 @@
+// RUN: mlir-opt -gpu-decompose-memrefs -allow-unregistered-dialect -split-input-file %s | FileCheck %s
+
+// Store into a dynamically-shaped, identity-layout memref: the strided
+// metadata is extracted outside `gpu.launch`, and the store index becomes an
+// explicit affine.apply (the innermost stride is statically 1 and folds away).
+// CHECK: #[[MAP:.*]] = affine_map<()[s0, s1, s2, s3, s4] -> (s0 * s1 + s2 * s3 + s4)>
+// CHECK: @decompose_store
+// CHECK-SAME: (%[[VAL:.*]]: f32, %[[MEM:.*]]: memref<?x?x?xf32>)
+// CHECK: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:3, %[[STRIDES:.*]]:3 = memref.extract_strided_metadata %[[MEM]]
+// CHECK: gpu.launch
+// CHECK-SAME: threads(%[[TX:.*]], %[[TY:.*]], %[[TZ:.*]]) in
+// CHECK: %[[IDX:.*]] = affine.apply #[[MAP]]()[%[[TX]], %[[STRIDES]]#0, %[[TY]], %[[STRIDES]]#1, %[[TZ]]]
+// CHECK: %[[PTR:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [%[[IDX]]], sizes: [], strides: [] : memref<f32> to memref<f32>
+// CHECK: memref.store %[[VAL]], %[[PTR]][] : memref<f32>
+func.func @decompose_store(%arg0 : f32, %arg1 : memref<?x?x?xf32>) {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c2 = arith.constant 2 : index
+ %block_dim0 = memref.dim %arg1, %c0 : memref<?x?x?xf32>
+ %block_dim1 = memref.dim %arg1, %c1 : memref<?x?x?xf32>
+ %block_dim2 = memref.dim %arg1, %c2 : memref<?x?x?xf32>
+ gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %c1, %grid_y = %c1, %grid_z = %c1)
+ threads(%tx, %ty, %tz) in (%block_x = %block_dim0, %block_y = %block_dim1, %block_z = %block_dim2) {
+ memref.store %arg0, %arg1[%tx, %ty, %tz] : memref<?x?x?xf32>
+ gpu.terminator
+ }
+ return
+}
+
+// -----
+
+// Store into a fully dynamic strided memref: the dynamic base offset and all
+// three dynamic strides are taken from the extracted metadata.
+// CHECK: #[[MAP:.*]] = affine_map<()[s0, s1, s2, s3, s4, s5, s6] -> (s0 + s1 * s2 + s3 * s4 + s5 * s6)>
+// CHECK: @decompose_store_strided
+// CHECK-SAME: (%[[VAL:.*]]: f32, %[[MEM:.*]]: memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>>)
+// CHECK: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:3, %[[STRIDES:.*]]:3 = memref.extract_strided_metadata %[[MEM]]
+// CHECK: gpu.launch
+// CHECK-SAME: threads(%[[TX:.*]], %[[TY:.*]], %[[TZ:.*]]) in
+// CHECK: %[[IDX:.*]] = affine.apply #[[MAP]]()[%[[OFFSET]], %[[TX]], %[[STRIDES]]#0, %[[TY]], %[[STRIDES]]#1, %[[TZ]], %[[STRIDES]]#2]
+// CHECK: %[[PTR:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [%[[IDX]]], sizes: [], strides: [] : memref<f32> to memref<f32>
+// CHECK: memref.store %[[VAL]], %[[PTR]][] : memref<f32>
+func.func @decompose_store_strided(%arg0 : f32, %arg1 : memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>>) {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c2 = arith.constant 2 : index
+ %block_dim0 = memref.dim %arg1, %c0 : memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>>
+ %block_dim1 = memref.dim %arg1, %c1 : memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>>
+ %block_dim2 = memref.dim %arg1, %c2 : memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>>
+ gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %c1, %grid_y = %c1, %grid_z = %c1)
+ threads(%tx, %ty, %tz) in (%block_x = %block_dim0, %block_y = %block_dim1, %block_z = %block_dim2) {
+ memref.store %arg0, %arg1[%tx, %ty, %tz] : memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>>
+ gpu.terminator
+ }
+ return
+}
+
+// -----
+
+// Load counterpart of the identity-layout store case: the load is rewritten to
+// read from a rank-0 reinterpret_cast of the base buffer.
+// CHECK: #[[MAP:.*]] = affine_map<()[s0, s1, s2, s3, s4] -> (s0 * s1 + s2 * s3 + s4)>
+// CHECK: @decompose_load
+// CHECK-SAME: (%[[MEM:.*]]: memref<?x?x?xf32>)
+// CHECK: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:3, %[[STRIDES:.*]]:3 = memref.extract_strided_metadata %[[MEM]]
+// CHECK: gpu.launch
+// CHECK-SAME: threads(%[[TX:.*]], %[[TY:.*]], %[[TZ:.*]]) in
+// CHECK: %[[IDX:.*]] = affine.apply #[[MAP]]()[%[[TX]], %[[STRIDES]]#0, %[[TY]], %[[STRIDES]]#1, %[[TZ]]]
+// CHECK: %[[PTR:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [%[[IDX]]], sizes: [], strides: [] : memref<f32> to memref<f32>
+// CHECK: %[[RES:.*]] = memref.load %[[PTR]][] : memref<f32>
+// CHECK: "test.test"(%[[RES]]) : (f32) -> ()
+func.func @decompose_load(%arg0 : memref<?x?x?xf32>) {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c2 = arith.constant 2 : index
+ %block_dim0 = memref.dim %arg0, %c0 : memref<?x?x?xf32>
+ %block_dim1 = memref.dim %arg0, %c1 : memref<?x?x?xf32>
+ %block_dim2 = memref.dim %arg0, %c2 : memref<?x?x?xf32>
+ gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %c1, %grid_y = %c1, %grid_z = %c1)
+ threads(%tx, %ty, %tz) in (%block_x = %block_dim0, %block_y = %block_dim1, %block_z = %block_dim2) {
+ %res = memref.load %arg0[%tx, %ty, %tz] : memref<?x?x?xf32>
+ "test.test"(%res) : (f32) -> ()
+ gpu.terminator
+ }
+ return
+}
+
+// -----
+
+// Unit-stride subview: the subview offsets are linearized into the
+// reinterpret_cast offset, and each result stride equals the source stride.
+// CHECK: #[[MAP:.*]] = affine_map<()[s0, s1, s2, s3, s4] -> (s0 * s1 + s2 * s3 + s4)>
+// CHECK: @decompose_subview
+// CHECK-SAME: (%[[MEM:.*]]: memref<?x?x?xf32>)
+// CHECK: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:3, %[[STRIDES:.*]]:3 = memref.extract_strided_metadata %[[MEM]]
+// CHECK: gpu.launch
+// CHECK-SAME: threads(%[[TX:.*]], %[[TY:.*]], %[[TZ:.*]]) in
+// CHECK: %[[IDX:.*]] = affine.apply #[[MAP]]()[%[[TX]], %[[STRIDES]]#0, %[[TY]], %[[STRIDES]]#1, %[[TZ]]]
+// CHECK: %[[PTR:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [%[[IDX]]], sizes: [%{{.*}}, %{{.*}}, %{{.*}}], strides: [%[[STRIDES]]#0, %[[STRIDES]]#1, 1]
+// CHECK: "test.test"(%[[PTR]]) : (memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>>) -> ()
+func.func @decompose_subview(%arg0 : memref<?x?x?xf32>) {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c2 = arith.constant 2 : index
+ %block_dim0 = memref.dim %arg0, %c0 : memref<?x?x?xf32>
+ %block_dim1 = memref.dim %arg0, %c1 : memref<?x?x?xf32>
+ %block_dim2 = memref.dim %arg0, %c2 : memref<?x?x?xf32>
+ gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %c1, %grid_y = %c1, %grid_z = %c1)
+ threads(%tx, %ty, %tz) in (%block_x = %block_dim0, %block_y = %block_dim1, %block_z = %block_dim2) {
+ %res = memref.subview %arg0[%tx, %ty, %tz] [%c2, %c2, %c2] [%c1, %c1, %c1] : memref<?x?x?xf32> to memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>>
+ "test.test"(%res) : (memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>>) -> ()
+ gpu.terminator
+ }
+ return
+}
+
+// -----
+
+// Subview with static strides [2, 3, 4]: each result stride is the subview
+// stride times the source stride; it stays an affine.apply for the dynamic
+// outer source strides and folds to the constant 4 for the unit inner one.
+// CHECK: #[[MAP:.*]] = affine_map<()[s0] -> (s0 * 2)>
+// CHECK: #[[MAP1:.*]] = affine_map<()[s0] -> (s0 * 3)>
+// CHECK: #[[MAP2:.*]] = affine_map<()[s0, s1, s2, s3, s4] -> (s0 * s1 + s2 * s3 + s4)>
+// CHECK: @decompose_subview_strided
+// CHECK-SAME: (%[[MEM:.*]]: memref<?x?x?xf32>)
+// CHECK: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:3, %[[STRIDES:.*]]:3 = memref.extract_strided_metadata %[[MEM]]
+// CHECK: gpu.launch
+// CHECK-SAME: threads(%[[TX:.*]], %[[TY:.*]], %[[TZ:.*]]) in
+// CHECK: %[[IDX:.*]] = affine.apply #[[MAP]]()[%[[STRIDES]]#0]
+// CHECK: %[[IDX1:.*]] = affine.apply #[[MAP1]]()[%[[STRIDES]]#1]
+// CHECK: %[[IDX2:.*]] = affine.apply #[[MAP2]]()[%[[TX]], %[[STRIDES]]#0, %[[TY]], %[[STRIDES]]#1, %[[TZ]]]
+// CHECK: %[[PTR:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [%[[IDX2]]], sizes: [%{{.*}}, %{{.*}}, %{{.*}}], strides: [%[[IDX]], %[[IDX1]], 4]
+// CHECK: "test.test"(%[[PTR]]) : (memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>>) -> ()
+func.func @decompose_subview_strided(%arg0 : memref<?x?x?xf32>) {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c2 = arith.constant 2 : index
+ %block_dim0 = memref.dim %arg0, %c0 : memref<?x?x?xf32>
+ %block_dim1 = memref.dim %arg0, %c1 : memref<?x?x?xf32>
+ %block_dim2 = memref.dim %arg0, %c2 : memref<?x?x?xf32>
+ gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %c1, %grid_y = %c1, %grid_z = %c1)
+ threads(%tx, %ty, %tz) in (%block_x = %block_dim0, %block_y = %block_dim1, %block_z = %block_dim2) {
+ %res = memref.subview %arg0[%tx, %ty, %tz] [%c2, %c2, %c2] [2, 3, 4] : memref<?x?x?xf32> to memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>>
+ "test.test"(%res) : (memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>>) -> ()
+ gpu.terminator
+ }
+ return
+}
More information about the Mlir-commits
mailing list