[Mlir-commits] [mlir] Support folding of higher dimensional memref subviews in XeGPUFoldAliasOps (PR #99593)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Jul 18 19:48:31 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Charitha Saumya (charithaintc)
<details>
<summary>Changes</summary>
The `XeGPUFoldAliasOps` pass folds `memref.subview` operations that are sources of `xegpu.create_nd_tdesc`. But it does not support subviews created from n-D memrefs (n > 2). This PR adds support for higher-dimensional memrefs.
Example usage:
```
func.func @fold_subview_with_xegpu_create_nd_tdesc(%arg0 : memref<32x256x256xf32>, %arg1 : index, %arg2 : index, %arg3 : index, %arg4 : index, %arg5 : index) ->(!xegpu.tensor_desc<8x16xf32>) {
%subview = memref.subview %arg0[%arg1, %arg2, %arg3] [1, 32, 32] [1, 1, 1] :
memref<32x256x256xf32> to memref<32x32xf32, strided<[256, 1], offset: ?>>
%0 = xegpu.create_nd_tdesc %subview[%arg4, %arg5] :
memref<32x32xf32, strided<[256, 1], offset: ?>> -> !xegpu.tensor_desc<8x16xf32>
return %0 : !xegpu.tensor_desc<8x16xf32>
}
```
Gets folded to:
```
#map = affine_map<()[s0, s1] -> (s0 + s1)>
module {
func.func @fold_subview_with_xegpu_create_nd_tdesc(%arg0: memref<32x256x256xf32>, %arg1: index, %arg2: index, %arg3: index, %arg4: index, %arg5: index) -> !xegpu.tensor_desc<8x16xf32> {
%c65536 = arith.constant 65536 : index
%0 = affine.apply #map()[%arg2, %arg4]
%1 = affine.apply #map()[%arg3, %arg5]
%reinterpret_cast = memref.reinterpret_cast %arg0 to offset: [0], sizes: [8192, 256], strides: [256, 1] : memref<32x256x256xf32> to memref<8192x256xf32>
%2 = arith.muli %arg1, %c65536 : index
%3 = arith.addi %2, %0 : index
%4 = xegpu.create_nd_tdesc %reinterpret_cast[%3, %1] : memref<8192x256xf32> -> !xegpu.tensor_desc<8x16xf32>
return %4 : !xegpu.tensor_desc<8x16xf32>
}
}
```
Please review these guidelines to help with the review process:
- [ ] Have you provided a meaningful PR description?
- [ ] Have you added a test, a reproducer, or a reference to an issue with a reproducer?
- [ ] Have you tested your changes locally for CPU and GPU devices?
- [ ] Have you made sure that new changes do not introduce compiler warnings?
- [ ] If this PR is a work in progress, are you filing the PR as a draft?
- [ ] Have you organized your commits logically and ensured each can be built by itself?
---
Full diff: https://github.com/llvm/llvm-project/pull/99593.diff
2 Files Affected:
- (modified) mlir/lib/Dialect/XeGPU/Transforms/XeGPUFoldAliasOps.cpp (+39-1)
- (modified) mlir/test/Dialect/XeGPU/xegpu-fold-alias-ops.mlir (+54)
``````````diff
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUFoldAliasOps.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUFoldAliasOps.cpp
index 9307e8eb784b5..27e10dfc785e4 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUFoldAliasOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUFoldAliasOps.cpp
@@ -6,12 +6,14 @@
//
//===----------------------------------------------------------------------===//
+#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/XeGPU/Transforms/Passes.h"
#include "mlir/Dialect/Affine/ViewLikeInterfaceUtils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
#include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
+#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/Support/Debug.h"
@@ -48,6 +50,8 @@ LogicalResult XegpuCreateNdDescOpSubViewOpFolder::matchAndRewrite(
return rewriter.notifyMatchFailure(descOp, "not a subview producer");
if (!subViewOp.hasUnitStride())
return rewriter.notifyMatchFailure(descOp, "requires unit strides");
+ if (!subViewOp.getSource().getType().hasStaticShape())
+ return rewriter.notifyMatchFailure(descOp, "requires static shape");
SmallVector<Value> resolvedOffsets;
affine::resolveIndicesIntoOpWithOffsetsAndStrides(
@@ -55,8 +59,42 @@ LogicalResult XegpuCreateNdDescOpSubViewOpFolder::matchAndRewrite(
subViewOp.getMixedStrides(), subViewOp.getDroppedDims(),
descOp.getMixedOffsets(), resolvedOffsets);
+ auto updatedSource = subViewOp.getSource();
+ // If the source memref rank is greater than 2, we need to cast the source to
+ // 2D and compute the height, width offsets relative to that.
+ if (resolvedOffsets.size() > 2) {
+ // Cast the source to 2D. This will become the new source.
+ auto sourceTy = subViewOp.getSource().getType();
+ int64_t newWidth = sourceTy.getShape().back();
+ int64_t newHeight = 1;
+ for (int64_t dim : sourceTy.getShape().drop_back())
+ newHeight *= dim;
+ auto newSourceTy =
+ MemRefType::get({newHeight, newWidth}, sourceTy.getElementType());
+ int64_t offset = 0;
+ updatedSource = rewriter.create<memref::ReinterpretCastOp>(
+ descOp.getLoc(), newSourceTy, subViewOp.getSource(), offset,
+ llvm::SmallVector<int64_t>({newHeight, newWidth}),
+ llvm::SmallVector<int64_t>({newWidth, 1}));
+ // Get source strides.
+ llvm::SmallVector<int64_t> sourceStrides;
+ int64_t sourceOffset;
+ std::tie(sourceStrides, sourceOffset) = mlir::getStridesAndOffset(sourceTy);
+ // Compute height offset.
+ mlir::Value heightOffset = resolvedOffsets[resolvedOffsets.size() - 2];
+ for (int64_t i = resolvedOffsets.size() - 3; i >= 0; --i) {
+ auto constStrideOp = rewriter.create<arith::ConstantIndexOp>(
+ descOp.getLoc(), sourceStrides[i]);
+ auto mulOp = rewriter.create<arith::MulIOp>(
+ descOp.getLoc(), resolvedOffsets[i], constStrideOp);
+ heightOffset =
+ rewriter.create<arith::AddIOp>(descOp.getLoc(), mulOp, heightOffset);
+ }
+ resolvedOffsets = {heightOffset, resolvedOffsets.back()};
+ }
+
rewriter.replaceOpWithNewOp<xegpu::CreateNdDescOp>(
- descOp, descOp.getTensorDesc().getType(), subViewOp.getSource(),
+ descOp, descOp.getTensorDesc().getType(), updatedSource,
getAsOpFoldResult(resolvedOffsets));
return success();
diff --git a/mlir/test/Dialect/XeGPU/xegpu-fold-alias-ops.mlir b/mlir/test/Dialect/XeGPU/xegpu-fold-alias-ops.mlir
index d32954127fce6..69f195d0d328b 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-fold-alias-ops.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-fold-alias-ops.mlir
@@ -18,3 +18,57 @@ func.func @fold_subview_with_xegpu_create_nd_tdesc(%arg0 : memref<256x256xf32>,
// CHECK-DAG: %[[IDX0:.+]] = affine.apply #[[MAP]]()[%[[ARG1]], %[[ARG3]]]
// CHECK-DAG: %[[IDX1:.+]] = affine.apply #[[MAP]]()[%[[ARG2]], %[[ARG4]]]
// CHECK: xegpu.create_nd_tdesc %[[ARG0]][%[[IDX0]], %[[IDX1]]] : memref<256x256xf32> -> !xegpu.tensor_desc<8x16xf32>
+
+// -----
+func.func @fold_subview_with_xegpu_create_nd_tdesc(%arg0 : memref<32x256x256xf32>, %arg1 : index, %arg2 : index, %arg3 : index, %arg4 : index, %arg5 : index) ->(!xegpu.tensor_desc<8x16xf32>) {
+ %subview = memref.subview %arg0[%arg1, %arg2, %arg3] [1, 32, 32] [1, 1, 1] :
+ memref<32x256x256xf32> to memref<32x32xf32, strided<[256, 1], offset: ?>>
+ %0 = xegpu.create_nd_tdesc %subview[%arg4, %arg5] :
+ memref<32x32xf32, strided<[256, 1], offset: ?>> -> !xegpu.tensor_desc<8x16xf32>
+ return %0 : !xegpu.tensor_desc<8x16xf32>
+}
+
+// CHECK-DAG: #[[MAP:.+]] = affine_map<()[s0, s1] -> (s0 + s1)>
+// CHECK: func @fold_subview_with_xegpu_create_nd_tdesc
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: memref<32x256x256xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG3:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG4:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG5:[a-zA-Z0-9]+]]: index
+// CHECK: %[[C65536:[a-zA-Z0-9]+]] = arith.constant 65536 : index
+// CHECK-DAG: %[[IDX0:.+]] = affine.apply #[[MAP]]()[%[[ARG2]], %[[ARG4]]]
+// CHECK-DAG: %[[IDX1:.+]] = affine.apply #[[MAP]]()[%[[ARG3]], %[[ARG5]]]
+// CHECK: %[[CAST:.+]] = memref.reinterpret_cast %[[ARG0]] to offset: [0], sizes: [8192, 256], strides: [256, 1] : memref<32x256x256xf32> to memref<8192x256xf32>
+// CHECK: %[[T1:.+]] = arith.muli %[[ARG1]], %[[C65536]] : index
+// CHECK: %[[T2:.+]] = arith.addi %[[T1]], %[[IDX0]] : index
+// CHECK: xegpu.create_nd_tdesc %[[CAST]][%[[T2]], %[[IDX1]]] : memref<8192x256xf32> -> !xegpu.tensor_desc<8x16xf32>
+
+// -----
+func.func @fold_subview_with_xegpu_create_nd_tdesc(%arg0 : memref<32x32x256x256xf32>, %arg1 : index, %arg2 : index, %arg3 : index, %arg4 : index, %arg5 : index, %arg6: index) ->(!xegpu.tensor_desc<8x16xf32>) {
+ %subview = memref.subview %arg0[%arg1, %arg2, %arg3, %arg4] [1, 1, 32, 32] [1, 1, 1, 1] :
+ memref<32x32x256x256xf32> to memref<32x32xf32, strided<[256, 1], offset: ?>>
+ %0 = xegpu.create_nd_tdesc %subview[%arg5, %arg6] :
+ memref<32x32xf32, strided<[256, 1], offset: ?>> -> !xegpu.tensor_desc<8x16xf32>
+ return %0 : !xegpu.tensor_desc<8x16xf32>
+}
+
+// CHECK-DAG: #[[MAP:.+]] = affine_map<()[s0, s1] -> (s0 + s1)>
+// CHECK: func @fold_subview_with_xegpu_create_nd_tdesc
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: memref<32x32x256x256xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG3:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG4:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG5:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG6:[a-zA-Z0-9]+]]: index
+// CHECK: %[[C2097152:[a-zA-Z0-9]+]] = arith.constant 2097152 : index
+// CHECK: %[[C65536:[a-zA-Z0-9]+]] = arith.constant 65536 : index
+// CHECK-DAG: %[[IDX0:.+]] = affine.apply #[[MAP]]()[%[[ARG3]], %[[ARG5]]]
+// CHECK-DAG: %[[IDX1:.+]] = affine.apply #[[MAP]]()[%[[ARG4]], %[[ARG6]]]
+// CHECK: %[[CAST:.+]] = memref.reinterpret_cast %[[ARG0]] to offset: [0], sizes: [262144, 256], strides: [256, 1] : memref<32x32x256x256xf32> to memref<262144x256xf32>
+// CHECK: %[[T1:.+]] = arith.muli %[[ARG2]], %[[C65536]] : index
+// CHECK: %[[T2:.+]] = arith.addi %[[T1]], %[[IDX0]] : index
+// CHECK: %[[T3:.+]] = arith.muli %[[ARG1]], %[[C2097152]] : index
+// CHECK: %[[T4:.+]] = arith.addi %[[T3]], %[[T2]] : index
+// CHECK: xegpu.create_nd_tdesc %[[CAST]][%[[T4]], %[[IDX1]]] : memref<262144x256xf32> -> !xegpu.tensor_desc<8x16xf32>
``````````
</details>
https://github.com/llvm/llvm-project/pull/99593
More information about the Mlir-commits
mailing list