[Mlir-commits] [mlir] 54cda2e - [mlir][MemRef] Add patterns to extract address computations
Quentin Colombet
llvmlistbot at llvm.org
Tue Mar 28 05:04:30 PDT 2023
Author: Quentin Colombet
Date: 2023-03-28T13:52:29+02:00
New Revision: 54cda2ec976a89fcf5157d78479a576b09922df7
URL: https://github.com/llvm/llvm-project/commit/54cda2ec976a89fcf5157d78479a576b09922df7
DIFF: https://github.com/llvm/llvm-project/commit/54cda2ec976a89fcf5157d78479a576b09922df7.diff
LOG: [mlir][MemRef] Add patterns to extract address computations
This patch adds patterns to rewrite memory accesses such that the resulting
accesses are only using a base pointer.
E.g.,
```mlir
memref.load %base[%off0, ...]
```
Will be rewritten into:
```mlir
%new_base = memref.subview %base[%off0,...][1,...][1,...]
memref.load %new_base[%c0,...]
```
The idea behind these patterns is to offer a way to more gradually lower
address computations.
These patterns are the exact opposite of FoldMemRefAliasOps.
I've implemented the support of only five operations in this patch:
- memref.load
- memref.store
- nvgpu.ldmatrix
- vector.transfer_read
- vector.transfer_write
Going forward we may want to provide an interface for these rewritings (and
the ones in FoldMemRefAliasOps.)
One step at a time!
Differential Revision: https://reviews.llvm.org/D146724
Added:
mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h
mlir/lib/Dialect/MemRef/Transforms/ExtractAddressComputations.cpp
mlir/test/Dialect/MemRef/extract-address-computations.mlir
Modified:
mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.td
mlir/lib/Dialect/MemRef/TransformOps/CMakeLists.txt
mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp
mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt
utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.td b/mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.td
index ea7784eec0d36..a0b5a6858f04e 100644
--- a/mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.td
@@ -49,4 +49,48 @@ def MemRefMultiBufferOp : Op<Transform_Dialect, "memref.multibuffer",
"$target attr-dict `:` functional-type(operands, results)";
}
+def MemRefExtractAddressComputationsOp :
+ Op<Transform_Dialect, "memref.extract_address_computations",
+ [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
+ TransformEachOpTrait, TransformOpInterface]> {
+ let summary = "Extract address computations from memory accesses";
+ let description = [{
+ Transformation that extracts address computations from instructions
+ with memory accesses such that these memory accesses use only a base
+ pointer.
+
+ For instance,
+ ```mlir
+ memref.load %base[%off0, ...]
+ ```
+
+ Will be rewritten in:
+ ```mlir
+ %new_base = memref.subview %base[%off0,...][1,...][1,...]
+ memref.load %new_base[%c0,...]
+ ```
+
+ Note: The current implementation requires that the input operation
+ is "isolated from above".
+
+ #### Return modes
+
+ This operation produces `definiteFailure` if the extraction fails for any
+ reason.
+ The operation always returns the handle to the target op that is expected
+ to be isolated from above.
+ }];
+
+ let arguments = (ins PDL_Operation:$target);
+ let results = (outs PDL_Operation:$transformed);
+
+ let assemblyFormat = "$target attr-dict `:` functional-type(operands, results)";
+
+ let extraClassDeclaration = [{
+ ::mlir::DiagnosedSilenceableFailure applyToOne(
+ ::mlir::Operation *target,
+ ::mlir::transform::ApplyToEachResultList &transformResults,
+ ::mlir::transform::TransformState &state);
+ }];
+}
#endif // MEMREF_TRANSFORM_OPS
diff --git a/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h b/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h
new file mode 100644
index 0000000000000..18b12d6b31dc7
--- /dev/null
+++ b/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h
@@ -0,0 +1,40 @@
+//===- Transforms.h - MemRef Dialect transformations ------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+/// This header declares functions that assist transformations in the MemRef
+/// dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_MEMREF_TRANSFORMS_TRANSFORMS_H
+#define MLIR_DIALECT_MEMREF_TRANSFORMS_TRANSFORMS_H
+
+namespace mlir {
+class RewritePatternSet;
+
+namespace memref {
+/// Appends to \p patterns the patterns for extracting address computations
+/// from the instructions with memory accesses such that these memory accesses
+/// use only a base pointer.
+///
+/// For instance,
+/// ```mlir
+/// memref.load %base[%off0, ...]
+/// ```
+///
+/// will be rewritten into:
+/// ```mlir
+/// %new_base = memref.subview %base[%off0,...][1,...][1,...]
+/// memref.load %new_base[%c0,...]
+/// ```
+void populateExtractAddressComputationsPatterns(RewritePatternSet &patterns);
+
+} // namespace memref
+} // namespace mlir
+
+#endif
diff --git a/mlir/lib/Dialect/MemRef/TransformOps/CMakeLists.txt b/mlir/lib/Dialect/MemRef/TransformOps/CMakeLists.txt
index b98db40633e4e..b32e06aaa09fe 100644
--- a/mlir/lib/Dialect/MemRef/TransformOps/CMakeLists.txt
+++ b/mlir/lib/Dialect/MemRef/TransformOps/CMakeLists.txt
@@ -15,5 +15,7 @@ add_mlir_dialect_library(MLIRMemRefTransformOps
MLIRLoopLikeInterface
MLIRMemRefDialect
MLIRMemRefTransforms
+ MLIRNVGPUDialect
MLIRTransformDialect
+ MLIRVectorDialect
)
diff --git a/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp b/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp
index ae721fe641a84..3209b1bb83411 100644
--- a/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp
+++ b/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp
@@ -11,10 +11,14 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Transforms/Passes.h"
+#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
+#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
#include "mlir/Dialect/PDL/IR/PDL.h"
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Interfaces/LoopLikeInterface.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/Support/Debug.h"
using namespace mlir;
@@ -68,6 +72,31 @@ DiagnosedSilenceableFailure transform::MemRefMultiBufferOp::apply(
return DiagnosedSilenceableFailure::success();
}
+//===----------------------------------------------------------------------===//
+// MemRefExtractAddressComputationsOp
+//===----------------------------------------------------------------------===//
+
+DiagnosedSilenceableFailure
+transform::MemRefExtractAddressComputationsOp::applyToOne(
+ Operation *target, transform::ApplyToEachResultList &results,
+ transform::TransformState &state) {
+ if (!target->hasTrait<OpTrait::IsIsolatedFromAbove>()) {
+ auto diag = this->emitOpError("requires isolated-from-above targets");
+ diag.attachNote(target->getLoc()) << "non-isolated target";
+ return DiagnosedSilenceableFailure::definiteFailure();
+ }
+
+ MLIRContext *ctx = getContext();
+ RewritePatternSet patterns(ctx);
+ memref::populateExtractAddressComputationsPatterns(patterns);
+
+ if (failed(applyPatternsAndFoldGreedily(target, std::move(patterns))))
+ return emitDefaultDefiniteFailure(target);
+
+ results.push_back(target);
+ return DiagnosedSilenceableFailure::success();
+}
+
//===----------------------------------------------------------------------===//
// Transform op registration
//===----------------------------------------------------------------------===//
@@ -83,6 +112,9 @@ class MemRefTransformDialectExtension
declareDependentDialect<pdl::PDLDialect>();
declareGeneratedDialect<AffineDialect>();
declareGeneratedDialect<arith::ArithDialect>();
+ declareGeneratedDialect<memref::MemRefDialect>();
+ declareGeneratedDialect<nvgpu::NVGPUDialect>();
+ declareGeneratedDialect<vector::VectorDialect>();
registerTransformOps<
#define GET_OP_LIST
diff --git a/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt b/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt
index 744f5c647b228..0b01a1c864327 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt
@@ -4,6 +4,7 @@ add_mlir_dialect_library(MLIRMemRefTransforms
ExpandOps.cpp
ExpandStridedMetadata.cpp
EmulateWideInt.cpp
+ ExtractAddressComputations.cpp
FoldMemRefAliasOps.cpp
MultiBuffer.cpp
NormalizeMemRefs.cpp
@@ -27,6 +28,7 @@ add_mlir_dialect_library(MLIRMemRefTransforms
MLIRInferTypeOpInterface
MLIRLoopLikeInterface
MLIRMemRefDialect
+ MLIRNVGPUDialect
MLIRPass
MLIRTensorDialect
MLIRTransforms
diff --git a/mlir/lib/Dialect/MemRef/Transforms/ExtractAddressComputations.cpp b/mlir/lib/Dialect/MemRef/Transforms/ExtractAddressComputations.cpp
new file mode 100644
index 0000000000000..5ef977f9add3d
--- /dev/null
+++ b/mlir/lib/Dialect/MemRef/Transforms/ExtractAddressComputations.cpp
@@ -0,0 +1,313 @@
+//===- ExtractAddressComputations.cpp - Extract address computations -----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+/// These rewrite patterns turn loads/stores from/to a memref with offsets
+/// into loads/stores from/to a subview, without any offset on the
+/// instruction itself.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
+#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/PatternMatch.h"
+
+using namespace mlir;
+
+namespace {
+
+//===----------------------------------------------------------------------===//
+// Helper functions for the `load base[off0...]`
+// => `load (subview base[off0...])[0...]` pattern.
+//===----------------------------------------------------------------------===//
+
+// getFailureOrSrcMemRef hook for memref::LoadOp: returns the loaded-from memref.
+// \see LoadStoreLikeOpRewriter.
+static FailureOr<Value> getLoadOpSrcMemRef(memref::LoadOp loadOp) {
+ return loadOp.getMemRef();
+}
+
+// rebuildOpFromAddressAndIndices hook for memref::LoadOp: creates a new load of
+// srcMemRef[indices] keeping the nontemporal flag. \see LoadStoreLikeOpRewriter.
+static memref::LoadOp rebuildLoadOp(RewriterBase &rewriter,
+ memref::LoadOp loadOp, Value srcMemRef,
+ ArrayRef<Value> indices) {
+ Location loc = loadOp.getLoc();
+ return rewriter.create<memref::LoadOp>(loc, srcMemRef, indices,
+ loadOp.getNontemporal());
+}
+
+// getViewSizeForEachDim hook for memref::LoadOp: a size of 1 for each dimension
+// of the accessed memref type. \see LoadStoreLikeOpRewriter.
+static SmallVector<OpFoldResult>
+getLoadOpViewSizeForEachDim(RewriterBase &rewriter, memref::LoadOp loadOp) {
+ MemRefType ldTy = loadOp.getMemRefType();
+ unsigned loadRank = ldTy.getRank();
+ return SmallVector<OpFoldResult>(loadRank, rewriter.getIndexAttr(1));
+}
+
+//===----------------------------------------------------------------------===//
+// Helper functions for the `store val, base[off0...]`
+// => `store val, (subview base[off0...])[0...]` pattern.
+//===----------------------------------------------------------------------===//
+
+// getFailureOrSrcMemRef hook for memref::StoreOp: returns the stored-to memref.
+// \see LoadStoreLikeOpRewriter.
+static FailureOr<Value> getStoreOpSrcMemRef(memref::StoreOp storeOp) {
+ return storeOp.getMemRef();
+}
+
+// rebuildOpFromAddressAndIndices hook for memref::StoreOp: creates a new store of
+// the same value to srcMemRef[indices] keeping the nontemporal flag.
+static memref::StoreOp rebuildStoreOp(RewriterBase &rewriter,
+ memref::StoreOp storeOp, Value srcMemRef,
+ ArrayRef<Value> indices) {
+ Location loc = storeOp.getLoc();
+ return rewriter.create<memref::StoreOp>(loc, storeOp.getValueToStore(),
+ srcMemRef, indices,
+ storeOp.getNontemporal());
+}
+
+// getViewSizeForEachDim hook for memref::StoreOp: a size of 1 for each dimension
+// of the accessed memref type. \see LoadStoreLikeOpRewriter.
+static SmallVector<OpFoldResult>
+getStoreOpViewSizeForEachDim(RewriterBase &rewriter, memref::StoreOp storeOp) {
+ MemRefType ldTy = storeOp.getMemRefType();
+ unsigned loadRank = ldTy.getRank();
+ return SmallVector<OpFoldResult>(loadRank, rewriter.getIndexAttr(1));
+}
+
+//===----------------------------------------------------------------------===//
+// Helper functions for the `ldmatrix base[off0...]`
+// => `ldmatrix (subview base[off0...])[0...]` pattern.
+//===----------------------------------------------------------------------===//
+
+// getFailureOrSrcMemRef hook for nvgpu::LdMatrixOp: returns its source memref.
+// \see LoadStoreLikeOpRewriter.
+static FailureOr<Value> getLdMatrixOpSrcMemRef(nvgpu::LdMatrixOp ldMatrixOp) {
+ return ldMatrixOp.getSrcMemref();
+}
+
+// rebuildOpFromAddressAndIndices hook for nvgpu::LdMatrixOp: creates a new op on
+// srcMemRef[indices] with the same result type, transpose and numTiles values.
+static nvgpu::LdMatrixOp rebuildLdMatrixOp(RewriterBase &rewriter,
+ nvgpu::LdMatrixOp ldMatrixOp,
+ Value srcMemRef,
+ ArrayRef<Value> indices) {
+ Location loc = ldMatrixOp.getLoc();
+ return rewriter.create<nvgpu::LdMatrixOp>(
+ loc, ldMatrixOp.getResult().getType(), srcMemRef, indices,
+ ldMatrixOp.getTranspose(), ldMatrixOp.getNumTiles());
+}
+
+//===----------------------------------------------------------------------===//
+// Helper functions for the `transfer_read base[off0...]`
+// => `transfer_read (subview base[off0...])[0...]` pattern.
+//===----------------------------------------------------------------------===//
+
+// getFailureOrSrcMemRef hook for vector transfer ops (read and write): returns
+// the source if it is a memref, failure otherwise. \see LoadStoreLikeOpRewriter.
+template <typename TransferLikeOp>
+static FailureOr<Value>
+getTransferLikeOpSrcMemRef(TransferLikeOp transferLikeOp) {
+ Value src = transferLikeOp.getSource();
+ if (src.getType().isa<MemRefType>())
+ return src;
+ return failure();
+}
+
+// rebuildOpFromAddressAndIndices hook for vector::TransferReadOp: same result
+// type, permutation map, padding, mask and in_bounds, on srcMemRef[indices].
+static vector::TransferReadOp
+rebuildTransferReadOp(RewriterBase &rewriter,
+ vector::TransferReadOp transferReadOp, Value srcMemRef,
+ ArrayRef<Value> indices) {
+ Location loc = transferReadOp.getLoc();
+ return rewriter.create<vector::TransferReadOp>(
+ loc, transferReadOp.getResult().getType(), srcMemRef, indices,
+ transferReadOp.getPermutationMap(), transferReadOp.getPadding(),
+ transferReadOp.getMask(), transferReadOp.getInBoundsAttr());
+}
+
+//===----------------------------------------------------------------------===//
+// Helper functions for the `transfer_write base[off0...]`
+// => `transfer_write (subview base[off0...])[0...]` pattern.
+//===----------------------------------------------------------------------===//
+
+// rebuildOpFromAddressAndIndices hook for vector::TransferWriteOp: same written
+// value, permutation map, mask and in_bounds, on srcMemRef[indices].
+static vector::TransferWriteOp
+rebuildTransferWriteOp(RewriterBase &rewriter,
+ vector::TransferWriteOp transferWriteOp, Value srcMemRef,
+ ArrayRef<Value> indices) {
+ Location loc = transferWriteOp.getLoc();
+ return rewriter.create<vector::TransferWriteOp>(
+ loc, transferWriteOp.getValue(), srcMemRef, indices,
+ transferWriteOp.getPermutationMapAttr(), transferWriteOp.getMask(),
+ transferWriteOp.getInBoundsAttr());
+}
+
+//===----------------------------------------------------------------------===//
+// Generic helper functions used as default implementation in
+// LoadStoreLikeOpRewriter.
+//===----------------------------------------------------------------------===//
+
+/// Helper function to get the src memref.
+/// It forwards to the provided getFailureOrSrcMemRef and asserts that the
+/// call succeeded; callers must have ruled out the failure case beforehand.
+template <typename LoadStoreLikeOp,
+ FailureOr<Value> (*getFailureOrSrcMemRef)(LoadStoreLikeOp)>
+static Value getSrcMemRef(LoadStoreLikeOp loadStoreLikeOp) {
+ FailureOr<Value> failureOrSrcMemRef = getFailureOrSrcMemRef(loadStoreLikeOp);
+ assert(!failed(failureOrSrcMemRef) && "Generic getSrcMemRef cannot be used");
+ return *failureOrSrcMemRef;
+}
+
+/// Helper function to get the sizes of the resulting view.
+/// This function gets the sizes of the source memref then subtracts the
+/// offsets used within \p loadStoreLikeOp, i.e., `size - offset` per dimension.
+/// This gives the maximal sizes for which the view stays in-bounds.
+/// The source memref is retrieved using getSrcMemRef on \p loadStoreLikeOp.
+template <typename LoadStoreLikeOp, Value (*getSrcMemRef)(LoadStoreLikeOp)>
+static SmallVector<OpFoldResult>
+getGenericOpViewSizeForEachDim(RewriterBase &rewriter,
+ LoadStoreLikeOp loadStoreLikeOp) {
+ Location loc = loadStoreLikeOp.getLoc();
+ auto extractStridedMetadataOp =
+ rewriter.create<memref::ExtractStridedMetadataOp>(
+ loc, getSrcMemRef(loadStoreLikeOp));
+ SmallVector<OpFoldResult> srcSizes =
+ extractStridedMetadataOp.getConstifiedMixedSizes();
+ SmallVector<OpFoldResult> indices =
+ getAsOpFoldResult(loadStoreLikeOp.getIndices());
+ SmallVector<OpFoldResult> finalSizes;
+
+ AffineExpr s0 = rewriter.getAffineSymbolExpr(0);
+ AffineExpr s1 = rewriter.getAffineSymbolExpr(1);
+
+ for (auto [srcSize, indice] : llvm::zip(srcSizes, indices)) {
+ finalSizes.push_back(makeComposedFoldedAffineApply(rewriter, loc, s0 - s1,
+ {srcSize, indice}));
+ }
+ return finalSizes;
+}
+
+/// Rewrite a store/load-like op so that all its indices are zeros.
+/// E.g., %ld = memref.load %base[%off0]...[%offN]
+/// =>
+/// %new_base = subview %base[%off0,.., %offN][1,..,1][1,..,1]
+/// %ld = memref.load %new_base[0,..,0] :
+/// memref<1x..x1xTy, strided<[strides of %base], offset: ?>>
+///
+/// `getFailureOrSrcMemRef` returns the source memref for the given load-like
+/// operation, or failure, in which case the pattern does not apply.
+///
+/// `getViewSizeForEachDim` returns the sizes of the view that is going to feed
+/// the new operation. This must return one size per dimension of the view.
+/// The sizes of the view need to be at least as big as what is actually going
+/// to be accessed. Use the provided `loadStoreOp` to get the right sizes.
+///
+/// Using the given rewriter, `rebuildOpFromAddressAndIndices` creates a new
+/// LoadStoreLikeOp that accesses srcMemRef[indices].
+/// The returned operation will be used to replace loadStoreOp.
+template <typename LoadStoreLikeOp,
+ FailureOr<Value> (*getFailureOrSrcMemRef)(LoadStoreLikeOp),
+ LoadStoreLikeOp (*rebuildOpFromAddressAndIndices)(
+ RewriterBase & /*rewriter*/, LoadStoreLikeOp /*loadStoreOp*/,
+ Value /*srcMemRef*/, ArrayRef<Value> /*indices*/),
+ SmallVector<OpFoldResult> (*getViewSizeForEachDim)(
+ RewriterBase & /*rewriter*/, LoadStoreLikeOp /*loadStoreOp*/) =
+ getGenericOpViewSizeForEachDim<
+ LoadStoreLikeOp,
+ getSrcMemRef<LoadStoreLikeOp, getFailureOrSrcMemRef>>>
+struct LoadStoreLikeOpRewriter : public OpRewritePattern<LoadStoreLikeOp> {
+ using OpRewritePattern<LoadStoreLikeOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(LoadStoreLikeOp loadStoreLikeOp,
+ PatternRewriter &rewriter) const override {
+ FailureOr<Value> failureOrSrcMemRef =
+ getFailureOrSrcMemRef(loadStoreLikeOp);
+ if (failed(failureOrSrcMemRef))
+ return rewriter.notifyMatchFailure(loadStoreLikeOp,
+ "source is not a memref");
+ Value srcMemRef = *failureOrSrcMemRef;
+ auto ldStTy = srcMemRef.getType().cast<MemRefType>();
+ unsigned loadStoreRank = ldStTy.getRank();
+ // Don't waste compile time if there is nothing to rewrite.
+ if (loadStoreRank == 0)
+ return rewriter.notifyMatchFailure(loadStoreLikeOp,
+ "0-D accesses don't need rewriting");
+
+ // If the load/store already has only zeros as indices there is nothing
+ // to do.
+ SmallVector<OpFoldResult> indices =
+ getAsOpFoldResult(loadStoreLikeOp.getIndices());
+ if (std::all_of(indices.begin(), indices.end(),
+ [](const OpFoldResult &opFold) {
+ return isConstantIntValue(opFold, 0);
+ })) {
+ return rewriter.notifyMatchFailure(
+ loadStoreLikeOp, "no computation to extract: offsets are 0s");
+ }
+
+ // Create the array of ones of the right size.
+ SmallVector<OpFoldResult> ones(loadStoreRank, rewriter.getIndexAttr(1));
+ SmallVector<OpFoldResult> sizes =
+ getViewSizeForEachDim(rewriter, loadStoreLikeOp);
+ assert(sizes.size() == loadStoreRank &&
+ "Expected one size per load dimension");
+ Location loc = loadStoreLikeOp.getLoc();
+ // The subview inherits its strides from the original memref and will
+ // apply them properly to the input indices.
+ // Therefore the stride multipliers are simply ones.
+ auto subview =
+ rewriter.create<memref::SubViewOp>(loc, /*source=*/srcMemRef,
+ /*offsets=*/indices,
+ /*sizes=*/sizes, /*strides=*/ones);
+ // Rewrite the load/store with the subview as the base pointer.
+ SmallVector<Value> zeros(loadStoreRank,
+ rewriter.create<arith::ConstantIndexOp>(loc, 0));
+ LoadStoreLikeOp newLoadStore = rebuildOpFromAddressAndIndices(
+ rewriter, loadStoreLikeOp, subview.getResult(), zeros);
+ rewriter.replaceOp(loadStoreLikeOp, newLoadStore->getResults());
+ return success();
+ }
+};
+} // namespace
+
+void memref::populateExtractAddressComputationsPatterns(
+ RewritePatternSet &patterns) {
+ patterns.add<
+ LoadStoreLikeOpRewriter<
+ memref::LoadOp,
+ /*getFailureOrSrcMemRef=*/getLoadOpSrcMemRef,
+ /*rebuildOpFromAddressAndIndices=*/rebuildLoadOp,
+ /*getViewSizeForEachDim=*/getLoadOpViewSizeForEachDim>,
+ LoadStoreLikeOpRewriter<
+ memref::StoreOp,
+ /*getFailureOrSrcMemRef=*/getStoreOpSrcMemRef,
+ /*rebuildOpFromAddressAndIndices=*/rebuildStoreOp,
+ /*getViewSizeForEachDim=*/getStoreOpViewSizeForEachDim>,
+ LoadStoreLikeOpRewriter<
+ nvgpu::LdMatrixOp,
+ /*getFailureOrSrcMemRef=*/getLdMatrixOpSrcMemRef,
+ /*rebuildOpFromAddressAndIndices=*/rebuildLdMatrixOp>,
+ LoadStoreLikeOpRewriter<
+ vector::TransferReadOp,
+ /*getFailureOrSrcMemRef=*/getTransferLikeOpSrcMemRef<vector::TransferReadOp>,
+ /*rebuildOpFromAddressAndIndices=*/rebuildTransferReadOp>,
+ LoadStoreLikeOpRewriter<
+ vector::TransferWriteOp,
+ /*getFailureOrSrcMemRef=*/getTransferLikeOpSrcMemRef<vector::TransferWriteOp>,
+ /*rebuildOpFromAddressAndIndices=*/rebuildTransferWriteOp>>(
+ patterns.getContext());
+}
diff --git a/mlir/test/Dialect/MemRef/extract-address-computations.mlir b/mlir/test/Dialect/MemRef/extract-address-computations.mlir
new file mode 100644
index 0000000000000..17e2ac3bc5e24
--- /dev/null
+++ b/mlir/test/Dialect/MemRef/extract-address-computations.mlir
@@ -0,0 +1,393 @@
+// RUN: mlir-opt -test-transform-dialect-interpreter %s --split-input-file --verify-diagnostics | FileCheck %s
+
+// Simple test: check that we extract the address computation of a load into
+// a dedicated subview.
+// The resulting load will be loading from the subview and have only indices
+// set to zero.
+
+// CHECK-LABEL: @test_load(
+// CHECK-SAME: %[[BASE:[^:]*]]: memref{{[^,]*}},
+// CHECK-SAME: %[[DYN_OFFSET:.*]]: index)
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[SUBVIEW:.*]] = memref.subview %[[BASE]][%[[DYN_OFFSET]], 0, 8] [1, 1, 1] [1, 1, 1] : memref<2x16x16xf32> to memref<1x1x1xf32, strided<[256, 16, 1], offset: ?>>
+// CHECK: %[[LOADED_VAL:.*]] = memref.load %[[SUBVIEW]][%[[C0]], %[[C0]], %[[C0]]] : memref<1x1x1xf32, strided<[256, 16, 1], offset: ?>>
+// CHECK: return %[[LOADED_VAL]] : f32
+
+// expected-remark @below {{transformed}}
+func.func @test_load(%base : memref<2x16x16xf32>, %offset : index) -> f32 {
+ %c0 = arith.constant 0 : index
+ %c8 = arith.constant 8 : index
+ %loaded_val = memref.load %base[%offset, %c0, %c8] : memref<2x16x16xf32>
+ return %loaded_val : f32
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+ %0 = transform.structured.match ops{["func.func"]} in %arg1 : (!pdl.operation) -> !pdl.operation
+ %1 = transform.memref.extract_address_computations %0 : (!pdl.operation) -> !pdl.operation
+ // Verify that the returned handle is usable.
+ transform.test_print_remark_at_operand %1, "transformed" : !pdl.operation
+}
+
+// -----
+
+// Same as previous @test_load but with the nontemporal flag.
+
+// CHECK-LABEL: @test_load_nontemporal(
+// CHECK-SAME: %[[BASE:[^:]*]]: memref{{[^,]*}},
+// CHECK-SAME: %[[DYN_OFFSET:.*]]: index)
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[SUBVIEW:.*]] = memref.subview %[[BASE]][%[[DYN_OFFSET]], 0, 8] [1, 1, 1] [1, 1, 1] : memref<2x16x16xf32> to memref<1x1x1xf32, strided<[256, 16, 1], offset: ?>>
+// CHECK: %[[LOADED_VAL:.*]] = memref.load %[[SUBVIEW]][%[[C0]], %[[C0]], %[[C0]]] {nontemporal = true} : memref<1x1x1xf32, strided<[256, 16, 1], offset: ?>>
+// CHECK: return %[[LOADED_VAL]] : f32
+func.func @test_load_nontemporal(%base : memref<2x16x16xf32>, %offset : index) -> f32 {
+ %c0 = arith.constant 0 : index
+ %c8 = arith.constant 8 : index
+ %loaded_val = memref.load %base[%offset, %c0, %c8] {nontemporal = true } : memref<2x16x16xf32>
+ return %loaded_val : f32
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+ %0 = transform.structured.match ops{["func.func"]} in %arg1 : (!pdl.operation) -> !pdl.operation
+ %1 = transform.memref.extract_address_computations %0 : (!pdl.operation) -> !pdl.operation
+}
+
+// -----
+
+// Simple test: check that we extract the address computation of a store into
+// a dedicated subview.
+// The resulting store will use the address from the subview and have only
+// indices set to zero.
+
+// CHECK-LABEL: @test_store(
+// CHECK-SAME: %[[BASE:[^:]*]]: memref{{[^,]*}},
+// CHECK-SAME: %[[DYN_OFFSET:.*]]: index)
+// CHECK-DAG: %[[CF0:.*]] = arith.constant 0.0{{0*e\+00}} : f32
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[SUBVIEW:.*]] = memref.subview %[[BASE]][%[[DYN_OFFSET]], 0, 8] [1, 1, 1] [1, 1, 1] : memref<2x16x16xf32> to memref<1x1x1xf32, strided<[256, 16, 1], offset: ?>>
+// CHECK: memref.store %[[CF0]], %[[SUBVIEW]][%[[C0]], %[[C0]], %[[C0]]] : memref<1x1x1xf32, strided<[256, 16, 1], offset: ?>>
+// CHECK: return
+func.func @test_store(%base : memref<2x16x16xf32>, %offset : index) -> () {
+ %cf0 = arith.constant 0.0 : f32
+ %c0 = arith.constant 0 : index
+ %c8 = arith.constant 8 : index
+ memref.store %cf0, %base[%offset, %c0, %c8] : memref<2x16x16xf32>
+ return
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+ %0 = transform.structured.match ops{["func.func"]} in %arg1 : (!pdl.operation) -> !pdl.operation
+ %1 = transform.memref.extract_address_computations %0 : (!pdl.operation) -> !pdl.operation
+}
+
+// -----
+
+// Same as @test_store but check that the nontemporal flag is preserved.
+
+// CHECK-LABEL: @test_store_nontemporal(
+// CHECK-SAME: %[[BASE:[^:]*]]: memref{{[^,]*}},
+// CHECK-SAME: %[[DYN_OFFSET:.*]]: index)
+// CHECK-DAG: %[[CF0:.*]] = arith.constant 0.0{{0*e\+00}} : f32
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[SUBVIEW:.*]] = memref.subview %[[BASE]][%[[DYN_OFFSET]], 0, 8] [1, 1, 1] [1, 1, 1] : memref<2x16x16xf32> to memref<1x1x1xf32, strided<[256, 16, 1], offset: ?>>
+// CHECK: memref.store %[[CF0]], %[[SUBVIEW]][%[[C0]], %[[C0]], %[[C0]]] {nontemporal = true} : memref<1x1x1xf32, strided<[256, 16, 1], offset: ?>>
+// CHECK: return
+func.func @test_store_nontemporal(%base : memref<2x16x16xf32>, %offset : index) -> () {
+ %cf0 = arith.constant 0.0 : f32
+ %c0 = arith.constant 0 : index
+ %c8 = arith.constant 8 : index
+ memref.store %cf0, %base[%offset, %c0, %c8] { nontemporal = true } : memref<2x16x16xf32>
+ return
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+ %0 = transform.structured.match ops{["func.func"]} in %arg1 : (!pdl.operation) -> !pdl.operation
+ %1 = transform.memref.extract_address_computations %0 : (!pdl.operation) -> !pdl.operation
+}
+
+// -----
+// For this test, we made the source memref fully dynamic.
+// The gist of the check remains the same as the simple test:
+// The address computation is extracted into its own subview.
+// CHECK-LABEL: @testWithLoop(
+// CHECK-SAME: %[[BASE:[^:]*]]: memref
+// CHECK: %[[SUM_ALL:.*]] = arith.constant 0.0{{0*e\+00}} : f32
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[C1:.*]] = arith.constant 1 : index
+// CHECK: %[[C2:.*]] = arith.constant 2 : index
+// CHECK: %[[UPPER_BOUND0:.*]] = memref.dim %[[BASE]], %[[C0]] : memref<?x?x?xf32,
+// CHECK: %[[UPPER_BOUND1:.*]] = memref.dim %[[BASE]], %[[C1]] : memref<?x?x?xf32,
+// CHECK: %[[UPPER_BOUND2:.*]] = memref.dim %[[BASE]], %[[C2]] : memref<?x?x?xf32,
+// CHECK: %[[SUM_RES2:.*]] = scf.for %[[IV2:.*]] = %[[C0]] to %[[UPPER_BOUND2]] step %[[C1]] iter_args(%[[SUM_ITER2:.*]] = %[[SUM_ALL]]) -> (f32) {
+// CHECK: %[[SUM_RES1:.*]] = scf.for %[[IV1:.*]] = %[[C0]] to %[[UPPER_BOUND1]] step %[[C1]] iter_args(%[[SUM_ITER1:.*]] = %[[SUM_ITER2]]) -> (f32) {
+// CHECK: %[[SUM_RES0:.*]] = scf.for %[[IV0:.*]] = %[[C0]] to %[[UPPER_BOUND0]] step %[[C1]] iter_args(%[[SUM_ITER0:.*]] = %[[SUM_ITER1]]) -> (f32) {
+// CHECK: %[[SUBVIEW:.*]] = memref.subview %[[BASE]][%[[IV0]], %[[IV1]], %[[IV2]]] [1, 1, 1] [1, 1, 1] : memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>> to memref<1x1x1xf32, strided<[?, ?, ?], offset: ?>>
+// CHECK: %[[LOADED_VAL:.*]] = memref.load %[[SUBVIEW]][%[[C0]], %[[C0]], %[[C0]]] : memref<1x1x1xf32, strided<[?, ?, ?], offset: ?>>
+// CHECK: %[[RES:.*]] = arith.addf %[[LOADED_VAL]], %[[SUM_ITER2]] : f32
+// CHECK: scf.yield %[[RES]] : f32
+// CHECK: }
+// CHECK: scf.yield %[[SUM_RES0]] : f32
+// CHECK: }
+// CHECK: scf.yield %[[SUM_RES1]] : f32
+// CHECK: }
+// CHECK: return %[[SUM_RES2]] : f32
+func.func @testWithLoop(%base : memref<?x?x?xf32, strided<[?,?,?], offset: ?>>) -> f32 {
+ %sum_all = arith.constant 0.0 : f32
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c2 = arith.constant 2 : index
+ %upper_bound0 = memref.dim %base, %c0 : memref<?x?x?xf32, strided<[?,?,?], offset: ?>>
+ %upper_bound1 = memref.dim %base, %c1 : memref<?x?x?xf32, strided<[?,?,?], offset: ?>>
+ %upper_bound2 = memref.dim %base, %c2 : memref<?x?x?xf32, strided<[?,?,?], offset: ?>>
+ %sum_res2 = scf.for %iv2 = %c0 to %upper_bound2 step %c1 iter_args(%sum_iter2 = %sum_all) -> (f32) {
+ %sum_res1 = scf.for %iv1 = %c0 to %upper_bound1 step %c1 iter_args(%sum_iter1 = %sum_iter2) -> (f32) {
+ %sum_res0 = scf.for %iv0 = %c0 to %upper_bound0 step %c1 iter_args(%sum_iter0 = %sum_iter1) -> (f32) {
+ %loaded_val = memref.load %base[%iv0, %iv1, %iv2] : memref<?x?x?xf32, strided<[?,?,?], offset: ?>>
+ %res = arith.addf %loaded_val, %sum_iter2 : f32
+ scf.yield %res : f32
+ }
+ scf.yield %sum_res0 : f32
+ }
+ scf.yield %sum_res1 : f32
+ }
+ return %sum_res2 : f32
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+ %0 = transform.structured.match ops{["func.func"]} in %arg1 : (!pdl.operation) -> !pdl.operation
+ %1 = transform.memref.extract_address_computations %0 : (!pdl.operation) -> !pdl.operation
+}
+
+// -----
+
+// Simple test: check that we extract the address computation of a ldmatrix into
+// a dedicated subview.
+// The resulting ldmatrix will load from the subview and have only indices set
+// to zero.
+// Also the sizes of the view are adjusted to `original size - offset`.
+
+// CHECK-DAG: #[[$FOUR_MINUS_OFF_MAP:.*]] = affine_map<()[s0] -> (-s0 + 4)>
+// CHECK-DAG: #[[$THIRTY_TWO_MINUS_OFF_MAP:.*]] = affine_map<()[s0] -> (-s0 + 32)>
+// CHECK-LABEL: @test_ldmatrix(
+// CHECK-SAME: %[[BASE:[^:]*]]: memref<{{[^,]*}}, 3>,
+// CHECK-SAME: %[[DYN_OFFSET0:[^:]*]]: index,
+// CHECK-SAME: %[[DYN_OFFSET1:[^:]*]]: index,
+// CHECK-SAME: %[[DYN_OFFSET2:[^:]*]]: index)
+// CHECK-DAG: %[[DYN_SIZE0:.*]] = affine.apply #[[$FOUR_MINUS_OFF_MAP]]()[%[[DYN_OFFSET0]]]
+// CHECK-DAG: %[[DYN_SIZE1:.*]] = affine.apply #[[$THIRTY_TWO_MINUS_OFF_MAP]]()[%[[DYN_OFFSET1]]]
+// CHECK-DAG: %[[DYN_SIZE2:.*]] = affine.apply #[[$THIRTY_TWO_MINUS_OFF_MAP]]()[%[[DYN_OFFSET2]]]
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[SUBVIEW:.*]] = memref.subview %[[BASE]][%[[DYN_OFFSET0]], %[[DYN_OFFSET1]], %[[DYN_OFFSET2]]] [%[[DYN_SIZE0]], %[[DYN_SIZE1]], %[[DYN_SIZE2]]] [1, 1, 1] : memref<4x32x32xf16, 3> to memref<?x?x?xf16, strided<[1024, 32, 1], offset: ?>, 3>
+// CHECK: %[[LOADED_VAL:.*]] = nvgpu.ldmatrix %[[SUBVIEW]][%[[C0]], %[[C0]], %[[C0]]] {numTiles = 4 : i32, transpose = false} : memref<?x?x?xf16, strided<[1024, 32, 1], offset: ?>, 3> -> vector<4x2xf16>
+// CHECK: return %[[LOADED_VAL]] : vector<4x2xf16>
+func.func @test_ldmatrix(%base : memref<4x32x32xf16, 3>,
+ %offset0 : index, %offset1: index, %offset2: index)
+ -> vector<4x2xf16> {
+ %loaded_val = nvgpu.ldmatrix
+ %base[%offset0, %offset1, %offset2]
+ {numTiles = 4 : i32, transpose = false}
+ : memref<4x32x32xf16, 3> -> vector<4x2xf16>
+ return %loaded_val : vector<4x2xf16>
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+ %0 = transform.structured.match ops{["func.func"]} in %arg1 : (!pdl.operation) -> !pdl.operation
+ %1 = transform.memref.extract_address_computations %0 : (!pdl.operation) -> !pdl.operation
+}
+
+// -----
+
+// Same as test_ldmatrix but with fully dynamic memref.
+
+// CHECK-DAG: #[[$A_MINUS_B_MAP:.*]] = affine_map<()[s0, s1] -> (s0 - s1)>
+// CHECK-LABEL: @test_ldmatrix(
+// CHECK-SAME: %[[BASE:[^:]*]]: memref<{{[^,]*}}, 3>,
+// CHECK-SAME: %[[DYN_OFFSET0:[^:]*]]: index,
+// CHECK-SAME: %[[DYN_OFFSET1:[^:]*]]: index,
+// CHECK-SAME: %[[DYN_OFFSET2:[^:]*]]: index)
+// CHECK-DAG: {{.*}}, {{.*}}, %[[DYN_SIZES:.*]]:3, {{.*}} = memref.extract_strided_metadata %[[BASE]]
+// CHECK-DAG: %[[DYN_SIZE0:.*]] = affine.apply #[[$A_MINUS_B_MAP]]()[%[[DYN_SIZES]]#0, %[[DYN_OFFSET0]]]
+// CHECK-DAG: %[[DYN_SIZE1:.*]] = affine.apply #[[$A_MINUS_B_MAP]]()[%[[DYN_SIZES]]#1, %[[DYN_OFFSET1]]]
+// CHECK-DAG: %[[DYN_SIZE2:.*]] = affine.apply #[[$A_MINUS_B_MAP]]()[%[[DYN_SIZES]]#2, %[[DYN_OFFSET2]]]
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[SUBVIEW:.*]] = memref.subview %[[BASE]][%[[DYN_OFFSET0]], %[[DYN_OFFSET1]], %[[DYN_OFFSET2]]] [%[[DYN_SIZE0]], %[[DYN_SIZE1]], %[[DYN_SIZE2]]] [1, 1, 1] : memref<?x?x?xf16, 3> to memref<?x?x?xf16, strided<[?, ?, 1], offset: ?>, 3>
+// CHECK: %[[LOADED_VAL:.*]] = nvgpu.ldmatrix %[[SUBVIEW]][%[[C0]], %[[C0]], %[[C0]]] {numTiles = 4 : i32, transpose = false} : memref<?x?x?xf16, strided<[?, ?, 1], offset: ?>, 3> -> vector<4x2xf16>
+// CHECK: return %[[LOADED_VAL]] : vector<4x2xf16>
+func.func @test_ldmatrix(%base : memref<?x?x?xf16, 3>,
+ %offset0 : index, %offset1: index, %offset2: index)
+ -> vector<4x2xf16> {
+ %loaded_val = nvgpu.ldmatrix
+ %base[%offset0, %offset1, %offset2]
+ {numTiles = 4 : i32, transpose = false}
+ : memref<?x?x?xf16, 3> -> vector<4x2xf16>
+ return %loaded_val : vector<4x2xf16>
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+ %0 = transform.structured.match ops{["func.func"]} in %arg1 : (!pdl.operation) -> !pdl.operation
+ %1 = transform.memref.extract_address_computations %0 : (!pdl.operation) -> !pdl.operation
+}
+
+// -----
+
+// Simple test for vector.transfer_read with fully dynamic memref.
+// We also set a permutation map to make sure it is properly preserved.
+
+// CHECK-DAG: #[[$A_MINUS_B_MAP:.*]] = affine_map<()[s0, s1] -> (s0 - s1)>
+// CHECK-DAG: #[[$PERMUTATION_MAP:.*]] = affine_map<(d0, d1, d2) -> (d2, d0)>
+// CHECK-LABEL: @test_transfer_read_op(
+// CHECK-SAME: %[[BASE:[^:]*]]: memref<{{[^,]*}}>,
+// CHECK-SAME: %[[DYN_OFFSET0:[^:]*]]: index,
+// CHECK-SAME: %[[DYN_OFFSET1:[^:]*]]: index,
+// CHECK-SAME: %[[DYN_OFFSET2:[^:]*]]: index)
+// CHECK-DAG: {{.*}}, {{.*}}, %[[DYN_SIZES:.*]]:3, {{.*}} = memref.extract_strided_metadata %[[BASE]]
+// CHECK-DAG: %[[DYN_SIZE0:.*]] = affine.apply #[[$A_MINUS_B_MAP]]()[%[[DYN_SIZES]]#0, %[[DYN_OFFSET0]]]
+// CHECK-DAG: %[[DYN_SIZE1:.*]] = affine.apply #[[$A_MINUS_B_MAP]]()[%[[DYN_SIZES]]#1, %[[DYN_OFFSET1]]]
+// CHECK-DAG: %[[DYN_SIZE2:.*]] = affine.apply #[[$A_MINUS_B_MAP]]()[%[[DYN_SIZES]]#2, %[[DYN_OFFSET2]]]
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[CF0:.*]] = arith.constant 0.0{{0*e\+00}} : f16
+// CHECK-DAG: %[[SUBVIEW:.*]] = memref.subview %[[BASE]][%[[DYN_OFFSET0]], %[[DYN_OFFSET1]], %[[DYN_OFFSET2]]] [%[[DYN_SIZE0]], %[[DYN_SIZE1]], %[[DYN_SIZE2]]] [1, 1, 1] : memref<?x?x?xf16> to memref<?x?x?xf16, strided<[?, ?, 1], offset: ?>>
+// CHECK: %[[LOADED_VAL:.*]] = vector.transfer_read %[[SUBVIEW]][%[[C0]], %[[C0]], %[[C0]]], %[[CF0]] {permutation_map = #[[$PERMUTATION_MAP]]} : memref<?x?x?xf16, strided<[?, ?, 1], offset: ?>>, vector<4x2xf16>
+// CHECK: return %[[LOADED_VAL]] : vector<4x2xf16>
+func.func @test_transfer_read_op(%base : memref<?x?x?xf16>,
+ %offset0 : index, %offset1: index, %offset2: index)
+ -> vector<4x2xf16> {
+ %cf0 = arith.constant 0.0 : f16
+ %loaded_val = vector.transfer_read %base[%offset0, %offset1, %offset2], %cf0 { permutation_map = affine_map<(d0,d1,d2) -> (d2,d0)> } : memref<?x?x?xf16>, vector<4x2xf16>
+ return %loaded_val : vector<4x2xf16>
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+ %0 = transform.structured.match ops{["func.func"]} in %arg1 : (!pdl.operation) -> !pdl.operation
+ %1 = transform.memref.extract_address_computations %0 : (!pdl.operation) -> !pdl.operation
+}
+
+// -----
+
+// Same as test_transfer_read_op but with tensors.
+// Right now this rewrite is not supported but we still shouldn't choke on it.
+
+// CHECK: #[[$PERMUTATION_MAP:.*]] = affine_map<(d0, d1, d2) -> (d2, d0)>
+// CHECK-LABEL: @test_transfer_read_op_with_tensor(
+// CHECK-SAME: %[[BASE:[^:]*]]: tensor<{{[^,]*}}>,
+// CHECK-SAME: %[[DYN_OFFSET0:[^:]*]]: index,
+// CHECK-SAME: %[[DYN_OFFSET1:[^:]*]]: index,
+// CHECK-SAME: %[[DYN_OFFSET2:[^:]*]]: index)
+// CHECK: %[[CF0:.*]] = arith.constant 0.0{{0*e\+00}} : f16
+// CHECK: %[[LOADED_VAL:.*]] = vector.transfer_read %[[BASE]][%[[DYN_OFFSET0]], %[[DYN_OFFSET1]], %[[DYN_OFFSET2]]], %[[CF0]] {permutation_map = #[[$PERMUTATION_MAP]]} : tensor<?x?x?xf16>, vector<4x2xf16>
+// CHECK: return %[[LOADED_VAL]] : vector<4x2xf16>
+func.func @test_transfer_read_op_with_tensor(%base : tensor<?x?x?xf16>,
+ %offset0 : index, %offset1: index, %offset2: index)
+ -> vector<4x2xf16> {
+ %cf0 = arith.constant 0.0 : f16
+ %loaded_val = vector.transfer_read %base[%offset0, %offset1, %offset2], %cf0 { permutation_map = affine_map<(d0,d1,d2) -> (d2,d0)> } : tensor<?x?x?xf16>, vector<4x2xf16>
+ return %loaded_val : vector<4x2xf16>
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+ %0 = transform.structured.match ops{["func.func"]} in %arg1 : (!pdl.operation) -> !pdl.operation
+ %1 = transform.memref.extract_address_computations %0 : (!pdl.operation) -> !pdl.operation
+}
+
+// -----
+
+// Simple test for vector.transfer_write with fully dynamic memref.
+// We also set a permutation map to make sure it is properly preserved.
+
+// CHECK-DAG: #[[$A_MINUS_B_MAP:.*]] = affine_map<()[s0, s1] -> (s0 - s1)>
+// CHECK-DAG: #[[$PERMUTATION_MAP:.*]] = affine_map<(d0, d1, d2) -> (d2, d0)>
+// CHECK-LABEL: @test_transfer_write_op(
+// CHECK-SAME: %[[BASE:[^:]*]]: memref<{{[^,]*}}>,
+// CHECK-SAME: %[[DYN_OFFSET0:[^:]*]]: index,
+// CHECK-SAME: %[[DYN_OFFSET1:[^:]*]]: index,
+// CHECK-SAME: %[[DYN_OFFSET2:[^:]*]]: index)
+// CHECK-DAG: {{.*}}, {{.*}}, %[[DYN_SIZES:.*]]:3, {{.*}} = memref.extract_strided_metadata %[[BASE]]
+// CHECK-DAG: %[[DYN_SIZE0:.*]] = affine.apply #[[$A_MINUS_B_MAP]]()[%[[DYN_SIZES]]#0, %[[DYN_OFFSET0]]]
+// CHECK-DAG: %[[DYN_SIZE1:.*]] = affine.apply #[[$A_MINUS_B_MAP]]()[%[[DYN_SIZES]]#1, %[[DYN_OFFSET1]]]
+// CHECK-DAG: %[[DYN_SIZE2:.*]] = affine.apply #[[$A_MINUS_B_MAP]]()[%[[DYN_SIZES]]#2, %[[DYN_OFFSET2]]]
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[VCF0:.*]] = arith.constant dense<0.0{{0*e\+00}}> : vector<4x2xf16>
+// CHECK-DAG: %[[SUBVIEW:.*]] = memref.subview %[[BASE]][%[[DYN_OFFSET0]], %[[DYN_OFFSET1]], %[[DYN_OFFSET2]]] [%[[DYN_SIZE0]], %[[DYN_SIZE1]], %[[DYN_SIZE2]]] [1, 1, 1] : memref<?x?x?xf16> to memref<?x?x?xf16, strided<[?, ?, 1], offset: ?>>
+// CHECK: vector.transfer_write %[[VCF0]], %[[SUBVIEW]][%[[C0]], %[[C0]], %[[C0]]] {permutation_map = #[[$PERMUTATION_MAP]]} : vector<4x2xf16>, memref<?x?x?xf16, strided<[?, ?, 1], offset: ?>>
+// CHECK: return
+func.func @test_transfer_write_op(%base : memref<?x?x?xf16>,
+ %offset0 : index, %offset1: index, %offset2: index) {
+ %vcf0 = arith.constant dense<0.000000e+00> : vector<4x2xf16>
+ vector.transfer_write %vcf0, %base[%offset0, %offset1, %offset2] { permutation_map = affine_map<(d0,d1,d2) -> (d2,d0)> } : vector<4x2xf16>, memref<?x?x?xf16>
+ return
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+ %0 = transform.structured.match ops{["func.func"]} in %arg1 : (!pdl.operation) -> !pdl.operation
+ %1 = transform.memref.extract_address_computations %0 : (!pdl.operation) -> !pdl.operation
+}
+
+// -----
+
+// Check that the strides of the original memref are kept.
+// Moreover, even with non-unit strides, the subview should still use [1,...]
+// strides, since subview strides multiply the strides of the source memref.
+
+// CHECK-DAG: #[[$A_MINUS_B_MAP:.*]] = affine_map<()[s0, s1] -> (s0 - s1)>
+// CHECK-DAG: #[[$PERMUTATION_MAP:.*]] = affine_map<(d0, d1, d2) -> (d2, d0)>
+// CHECK-LABEL: @test_transfer_write_op_with_strides(
+// CHECK-SAME: %[[BASE:[^:]*]]: memref<{{[^>]*}}>>,
+// CHECK-SAME: %[[DYN_OFFSET0:[^:]*]]: index,
+// CHECK-SAME: %[[DYN_OFFSET1:[^:]*]]: index,
+// CHECK-SAME: %[[DYN_OFFSET2:[^:]*]]: index)
+// CHECK-DAG: {{.*}}, {{.*}}, %[[DYN_SIZES:.*]]:3, {{.*}} = memref.extract_strided_metadata %[[BASE]]
+// CHECK-DAG: %[[DYN_SIZE0:.*]] = affine.apply #[[$A_MINUS_B_MAP]]()[%[[DYN_SIZES]]#0, %[[DYN_OFFSET0]]]
+// CHECK-DAG: %[[DYN_SIZE1:.*]] = affine.apply #[[$A_MINUS_B_MAP]]()[%[[DYN_SIZES]]#1, %[[DYN_OFFSET1]]]
+// CHECK-DAG: %[[DYN_SIZE2:.*]] = affine.apply #[[$A_MINUS_B_MAP]]()[%[[DYN_SIZES]]#2, %[[DYN_OFFSET2]]]
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[VCF0:.*]] = arith.constant dense<0.0{{0*e\+00}}> : vector<4x2xf16>
+// CHECK-DAG: %[[SUBVIEW:.*]] = memref.subview %[[BASE]][%[[DYN_OFFSET0]], %[[DYN_OFFSET1]], %[[DYN_OFFSET2]]] [%[[DYN_SIZE0]], %[[DYN_SIZE1]], %[[DYN_SIZE2]]] [1, 1, 1] : memref<?x?x?xf16, strided<[329, 26, 12], offset: ?>> to memref<?x?x?xf16, strided<[329, 26, 12], offset: ?>>
+// CHECK: vector.transfer_write %[[VCF0]], %[[SUBVIEW]][%[[C0]], %[[C0]], %[[C0]]] {permutation_map = #[[$PERMUTATION_MAP]]} : vector<4x2xf16>, memref<?x?x?xf16, strided<[329, 26, 12], offset: ?>>
+// CHECK: return
+func.func @test_transfer_write_op_with_strides(%base : memref<?x?x?xf16, strided<[329, 26, 12], offset: ?>>,
+ %offset0 : index, %offset1: index, %offset2: index) {
+ %vcf0 = arith.constant dense<0.000000e+00> : vector<4x2xf16>
+ vector.transfer_write %vcf0, %base[%offset0, %offset1, %offset2] { permutation_map = affine_map<(d0,d1,d2) -> (d2,d0)> } : vector<4x2xf16>, memref<?x?x?xf16, strided<[329, 26, 12], offset: ?>>
+ return
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+ %0 = transform.structured.match ops{["func.func"]} in %arg1 : (!pdl.operation) -> !pdl.operation
+ %1 = transform.memref.extract_address_computations %0 : (!pdl.operation) -> !pdl.operation
+}
+// -----
+
+// Same as test_transfer_write_op but with tensors.
+// Right now this rewrite is not supported but we still shouldn't choke on it.
+
+// CHECK: #[[$PERMUTATION_MAP:.*]] = affine_map<(d0, d1, d2) -> (d2, d0)>
+// CHECK-LABEL: @test_transfer_write_op_with_tensor(
+// CHECK-SAME: %[[BASE:[^:]*]]: tensor<{{[^,]*}}>,
+// CHECK-SAME: %[[DYN_OFFSET0:[^:]*]]: index,
+// CHECK-SAME: %[[DYN_OFFSET1:[^:]*]]: index,
+// CHECK-SAME: %[[DYN_OFFSET2:[^:]*]]: index)
+// CHECK-DAG: %[[VCF0:.*]] = arith.constant dense<0.0{{0*e\+00}}> : vector<4x2xf16>
+// CHECK: %[[RES:.*]] = vector.transfer_write %[[VCF0]], %[[BASE]][%[[DYN_OFFSET0]], %[[DYN_OFFSET1]], %[[DYN_OFFSET2]]] {permutation_map = #[[$PERMUTATION_MAP]]} : vector<4x2xf16>, tensor<?x?x?xf16>
+// CHECK: return %[[RES]] : tensor<?x?x?xf16>
+func.func @test_transfer_write_op_with_tensor(%base : tensor<?x?x?xf16>,
+ %offset0 : index, %offset1: index, %offset2: index) -> tensor<?x?x?xf16> {
+ %vcf0 = arith.constant dense<0.000000e+00> : vector<4x2xf16>
+ %res = vector.transfer_write %vcf0, %base[%offset0, %offset1, %offset2] { permutation_map = affine_map<(d0,d1,d2) -> (d2,d0)> } : vector<4x2xf16>, tensor<?x?x?xf16>
+ return %res : tensor<?x?x?xf16>
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+ %0 = transform.structured.match ops{["func.func"]} in %arg1 : (!pdl.operation) -> !pdl.operation
+ %1 = transform.memref.extract_address_computations %0 : (!pdl.operation) -> !pdl.operation
+}
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index b80ddae92ac6b..6d75536c98026 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -10098,6 +10098,7 @@ cc_library(
":LoopLikeInterface",
":MemRefDialect",
":MemRefPassIncGen",
+ ":NVGPUDialect",
":Pass",
":RuntimeVerifiableOpInterface",
":TensorDialect",
@@ -10152,8 +10153,11 @@ cc_library(
":MemRefDialect",
":MemRefTransformOpsIncGen",
":MemRefTransforms",
+ ":NVGPUDialect",
":PDLDialect",
":TransformDialect",
+ ":TransformUtils",
+ ":VectorDialect",
"//llvm:Support",
],
)
More information about the Mlir-commits
mailing list