[Mlir-commits] [mlir] [mlir][memref] Fold subview into nd-tensor descriptor (PR #88698)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Apr 15 01:30:55 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Adam Siemieniuk (adam-smnk)
<details>
<summary>Changes</summary>
---
Full diff: https://github.com/llvm/llvm-project/pull/88698.diff
2 Files Affected:
- (modified) mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp (+35-2)
- (modified) mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir (+21)
``````````diff
diff --git a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
index aa44455ada7f9a..3dda27c91958fa 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
@@ -22,6 +22,7 @@
#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/STLExtras.h"
@@ -323,6 +324,16 @@ class NvgpuAsyncCopyOpSubViewOpFolder final
LogicalResult matchAndRewrite(nvgpu::DeviceAsyncCopyOp copyOp,
PatternRewriter &rewriter) const override;
};
+
+/// Folds a unit-stride memref.subview feeding xegpu.create_nd_tdesc into the descriptor op by rebuilding it on the subview's source memref.
+class XegpuCreateNdDescOpSubViewOpFolder final
+ : public OpRewritePattern<xegpu::CreateNdDescOp> {
+public:
+ using OpRewritePattern<xegpu::CreateNdDescOp>::OpRewritePattern;
+ // Reports a match failure unless descOp's source is produced by a unit-stride memref.subview.
+ LogicalResult matchAndRewrite(xegpu::CreateNdDescOp descOp,
+ PatternRewriter &rewriter) const override;
+};
} // namespace
static SmallVector<Value>
@@ -700,6 +711,28 @@ LogicalResult NvgpuAsyncCopyOpSubViewOpFolder::matchAndRewrite(
return success();
}
+LogicalResult XegpuCreateNdDescOpSubViewOpFolder::matchAndRewrite(
+ xegpu::CreateNdDescOp descOp, PatternRewriter &rewriter) const {
+ auto subViewOp = descOp.getSource().getDefiningOp<memref::SubViewOp>();
+ // Bail out early: only descriptors fed directly by a unit-stride subview are folded.
+ if (!subViewOp)
+ return rewriter.notifyMatchFailure(descOp, "not a subview producer");
+ if (!subViewOp.hasUnitStride())
+ return rewriter.notifyMatchFailure(descOp, "requires unit strides");
+ // Express the descriptor offsets relative to the subview's source; the accompanying test expects subview offset + descriptor offset per dimension.
+ SmallVector<Value> resolvedOffsets;
+ affine::resolveIndicesIntoOpWithOffsetsAndStrides(
+ rewriter, descOp.getLoc(), subViewOp.getMixedOffsets(),
+ subViewOp.getMixedStrides(), subViewOp.getDroppedDims(),
+ descOp.getMixedOffsets(), resolvedOffsets);
+ // Rebuild the descriptor on the subview's source with the resolved offsets; the result type is reused unchanged.
+ rewriter.replaceOpWithNewOp<xegpu::CreateNdDescOp>(
+ descOp, descOp.getTensorDesc().getType(), subViewOp.getSource(),
+ getAsOpFoldResult(resolvedOffsets));
+ // NOTE(review): only the result type, source and offsets are forwarded to the new op — confirm nothing else on descOp needs to be carried over.
+ return success();
+}
+
void memref::populateFoldMemRefAliasOpPatterns(RewritePatternSet &patterns) {
patterns.add<LoadOpOfSubViewOpFolder<affine::AffineLoadOp>,
LoadOpOfSubViewOpFolder<memref::LoadOp>,
@@ -722,8 +755,8 @@ void memref::populateFoldMemRefAliasOpPatterns(RewritePatternSet &patterns) {
LoadOpOfCollapseShapeOpFolder<memref::LoadOp>,
StoreOpOfCollapseShapeOpFolder<affine::AffineStoreOp>,
StoreOpOfCollapseShapeOpFolder<memref::StoreOp>,
- SubViewOfSubViewFolder, NvgpuAsyncCopyOpSubViewOpFolder>(
- patterns.getContext());
+ SubViewOfSubViewFolder, NvgpuAsyncCopyOpSubViewOpFolder,
+ XegpuCreateNdDescOpSubViewOpFolder>(patterns.getContext());
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir b/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
index 5b853a6cc5a37a..58e32653f24c79 100644
--- a/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
+++ b/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
@@ -838,3 +838,24 @@ func.func @fold_vector_maskedstore(
// CHECK-SAME: %[[ARG4:[a-zA-Z0-9_]+]]: vector<32xf32>
// CHECK: vector.maskedstore %[[ARG0]][%[[ARG1]], %[[ARG2]]], %[[ARG3]], %[[ARG4]] : memref<12x32xf32>, vector<32xi1>, vector<32xf32>
// CHECK: return
+
+// -----
+// Expect the subview offsets (%arg1, %arg2) to be added to the descriptor offsets (%arg3, %arg4) and the descriptor to be created directly on %arg0.
+func.func @fold_subview_with_xegpu_create_nd_tdesc(%arg0 : memref<256x256xf32>, %arg1 : index, %arg2 : index, %arg3 : index, %arg4 : index) ->(!xegpu.tensor_desc<8x16xf32>) {
+ %subview = memref.subview %arg0[%arg1, %arg2] [32, 32] [1, 1] :
+ memref<256x256xf32> to memref<32x32xf32, strided<[256, 1], offset: ?>>
+ %0 = xegpu.create_nd_tdesc %subview[%arg3, %arg4] :
+ memref<32x32xf32, strided<[256, 1], offset: ?>> -> !xegpu.tensor_desc<8x16xf32>
+ return %0 : !xegpu.tensor_desc<8x16xf32>
+}
+
+// CHECK-DAG: #[[MAP:.+]] = affine_map<()[s0, s1] -> (s0 + s1)>
+// CHECK: func @fold_subview_with_xegpu_create_nd_tdesc
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: memref<256x256xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG3:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG4:[a-zA-Z0-9]+]]: index
+// CHECK-DAG: %[[IDX0:.+]] = affine.apply #[[MAP]]()[%[[ARG1]], %[[ARG3]]]
+// CHECK-DAG: %[[IDX1:.+]] = affine.apply #[[MAP]]()[%[[ARG2]], %[[ARG4]]]
+// CHECK: xegpu.create_nd_tdesc %[[ARG0]][%[[IDX0]], %[[IDX1]]] : memref<256x256xf32> -> !xegpu.tensor_desc<8x16xf32>
``````````
</details>
https://github.com/llvm/llvm-project/pull/88698
More information about the Mlir-commits
mailing list