[Mlir-commits] [mlir] Support folding of higher-dimensional memref subviews in XeGPUFoldAliasOps (PR #99593)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Jul 18 19:48:31 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Charitha Saumya (charithaintc)

<details>
<summary>Changes</summary>

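The `XeGPUFoldAliasOps` pass folds `memref.subview` operations that are sources of `xegpu.create_nd_tdesc`, but it does not support subviews created from n-D memrefs (n > 2). This PR adds support for higher-dimensional memrefs by reinterpreting an n-D source as a 2-D memref (leading dimensions collapsed into rows) and folding the leading offsets into the row offset. The fold now also requires the subview source to have a static shape.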

Example input:
```
// Input: a rank-reducing subview of a 3-D memref (sizes [1, 32, 32] drop the
// unit leading dimension, giving a 2-D view) that feeds xegpu.create_nd_tdesc.
// The descriptor offsets [%arg4, %arg5] are relative to the subview.
// NOTE: the "<!-- -->" after "@" below is an artifact of the mailing-list
// archive; the symbol name is @fold_subview_with_xegpu_create_nd_tdesc.
func.func @<!-- -->fold_subview_with_xegpu_create_nd_tdesc(%arg0 : memref<32x256x256xf32>, %arg1 : index, %arg2 : index, %arg3 : index, %arg4 : index, %arg5 : index) ->(!xegpu.tensor_desc<8x16xf32>) {
  %subview = memref.subview %arg0[%arg1, %arg2, %arg3] [1, 32, 32] [1, 1, 1] :
    memref<32x256x256xf32> to memref<32x32xf32, strided<[256, 1], offset: ?>>
  %0 = xegpu.create_nd_tdesc %subview[%arg4, %arg5] :
    memref<32x32xf32, strided<[256, 1], offset: ?>> -> !xegpu.tensor_desc<8x16xf32>
  return %0 : !xegpu.tensor_desc<8x16xf32>
}
```

is folded to:
```
// Folded result: the subview is removed, and create_nd_tdesc indexes a 2-D
// reinterpretation of the original 3-D source directly.
#map = affine_map<()[s0, s1] -> (s0 + s1)>
module {
  func.func @<!-- -->fold_subview_with_xegpu_create_nd_tdesc(%arg0: memref<32x256x256xf32>, %arg1: index, %arg2: index, %arg3: index, %arg4: index, %arg5: index) -> !xegpu.tensor_desc<8x16xf32> {
    %c65536 = arith.constant 65536 : index
    // Offsets for source dims 1 and 2: subview offset + descriptor offset.
    %0 = affine.apply #map()[%arg2, %arg4]
    %1 = affine.apply #map()[%arg3, %arg5]
    // Collapse the leading dims into rows: 32 * 256 = 8192 rows of 256 elements.
    %reinterpret_cast = memref.reinterpret_cast %arg0 to offset: [0], sizes: [8192, 256], strides: [256, 1] : memref<32x256x256xf32> to memref<8192x256xf32>
    // Row offset in the 8192x256 view: scaled dim-0 offset %arg1, plus %0.
    // NOTE(review): 65536 (= 256 * 256) is the *element* stride of dim 0, but
    // create_nd_tdesc takes row/column indices into the 8192x256 view, where
    // one step of dim 0 spans 256 rows. The row offset should presumably be
    // %arg1 * 256 + %0. With 65536, %arg1 = 1 and %0 = 0 already give row
    // 65536, past the 8192-row bound. Confirm against the pass implementation.
    %2 = arith.muli %arg1, %c65536 : index
    %3 = arith.addi %2, %0 : index
    %4 = xegpu.create_nd_tdesc %reinterpret_cast[%3, %1] : memref<8192x256xf32> -> !xegpu.tensor_desc<8x16xf32>
    return %4 : !xegpu.tensor_desc<8x16xf32>
  }
}
```

Please review these guidelines to help with the review process:
- [ ] Have you provided a meaningful PR description?
- [ ] Have you added a test, a reproducer, or a reference to an issue with a reproducer?
- [ ] Have you tested your changes locally for CPU and GPU devices?
- [ ] Have you made sure that new changes do not introduce compiler warnings?
- [ ] If this PR is a work in progress, are you filing the PR as a draft?
- [ ] Have you organized your commits logically and ensured each can be built by itself?


---
Full diff: https://github.com/llvm/llvm-project/pull/99593.diff


2 Files Affected:

- (modified) mlir/lib/Dialect/XeGPU/Transforms/XeGPUFoldAliasOps.cpp (+39-1) 
- (modified) mlir/test/Dialect/XeGPU/xegpu-fold-alias-ops.mlir (+54) 


``````````diff
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUFoldAliasOps.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUFoldAliasOps.cpp
index 9307e8eb784b5..27e10dfc785e4 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUFoldAliasOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUFoldAliasOps.cpp
@@ -6,12 +6,14 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/XeGPU/Transforms/Passes.h"
 
 #include "mlir/Dialect/Affine/ViewLikeInterfaceUtils.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/XeGPU/IR/XeGPU.h"
 #include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
+#include "mlir/IR/BuiltinTypes.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "llvm/Support/Debug.h"
@@ -48,6 +50,8 @@ LogicalResult XegpuCreateNdDescOpSubViewOpFolder::matchAndRewrite(
     return rewriter.notifyMatchFailure(descOp, "not a subview producer");
   if (!subViewOp.hasUnitStride())
     return rewriter.notifyMatchFailure(descOp, "requires unit strides");
+  if (!subViewOp.getSource().getType().hasStaticShape())
+    return rewriter.notifyMatchFailure(descOp, "requires static shape");
 
   SmallVector<Value> resolvedOffsets;
   affine::resolveIndicesIntoOpWithOffsetsAndStrides(
@@ -55,8 +59,42 @@ LogicalResult XegpuCreateNdDescOpSubViewOpFolder::matchAndRewrite(
       subViewOp.getMixedStrides(), subViewOp.getDroppedDims(),
       descOp.getMixedOffsets(), resolvedOffsets);
 
+  auto updatedSource = subViewOp.getSource();
+  // If the source memref rank is greater than 2, we need to cast the source to
+  // 2D and compute the height, width offsets relative to that.
+  if (resolvedOffsets.size() > 2) {
+    // Cast the source to 2D. This will become the new source.
+    auto sourceTy = subViewOp.getSource().getType();
+    int64_t newWidth = sourceTy.getShape().back();
+    int64_t newHeight = 1;
+    for (int64_t dim : sourceTy.getShape().drop_back())
+      newHeight *= dim;
+    auto newSourceTy =
+        MemRefType::get({newHeight, newWidth}, sourceTy.getElementType());
+    int64_t offset = 0;
+    updatedSource = rewriter.create<memref::ReinterpretCastOp>(
+        descOp.getLoc(), newSourceTy, subViewOp.getSource(), offset,
+        llvm::SmallVector<int64_t>({newHeight, newWidth}),
+        llvm::SmallVector<int64_t>({newWidth, 1}));
+    // Get source strides.
+    llvm::SmallVector<int64_t> sourceStrides;
+    int64_t sourceOffset;
+    std::tie(sourceStrides, sourceOffset) = mlir::getStridesAndOffset(sourceTy);
+    // Compute height offset.
+    mlir::Value heightOffset = resolvedOffsets[resolvedOffsets.size() - 2];
+    for (int64_t i = resolvedOffsets.size() - 3; i >= 0; --i) {
+      auto constStrideOp = rewriter.create<arith::ConstantIndexOp>(
+          descOp.getLoc(), sourceStrides[i]);
+      auto mulOp = rewriter.create<arith::MulIOp>(
+          descOp.getLoc(), resolvedOffsets[i], constStrideOp);
+      heightOffset =
+          rewriter.create<arith::AddIOp>(descOp.getLoc(), mulOp, heightOffset);
+    }
+    resolvedOffsets = {heightOffset, resolvedOffsets.back()};
+  }
+
   rewriter.replaceOpWithNewOp<xegpu::CreateNdDescOp>(
-      descOp, descOp.getTensorDesc().getType(), subViewOp.getSource(),
+      descOp, descOp.getTensorDesc().getType(), updatedSource,
       getAsOpFoldResult(resolvedOffsets));
 
   return success();
diff --git a/mlir/test/Dialect/XeGPU/xegpu-fold-alias-ops.mlir b/mlir/test/Dialect/XeGPU/xegpu-fold-alias-ops.mlir
index d32954127fce6..69f195d0d328b 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-fold-alias-ops.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-fold-alias-ops.mlir
@@ -18,3 +18,57 @@ func.func @fold_subview_with_xegpu_create_nd_tdesc(%arg0 : memref<256x256xf32>,
 //   CHECK-DAG:   %[[IDX0:.+]] = affine.apply #[[MAP]]()[%[[ARG1]], %[[ARG3]]]
 //   CHECK-DAG:   %[[IDX1:.+]] = affine.apply #[[MAP]]()[%[[ARG2]], %[[ARG4]]]
 //   CHECK:       xegpu.create_nd_tdesc %[[ARG0]][%[[IDX0]], %[[IDX1]]] : memref<256x256xf32> -> !xegpu.tensor_desc<8x16xf32>
+
+// -----
+func.func @fold_subview_with_xegpu_create_nd_tdesc(%arg0 : memref<32x256x256xf32>, %arg1 : index, %arg2 : index, %arg3 : index, %arg4 : index, %arg5 : index) ->(!xegpu.tensor_desc<8x16xf32>) {
+  %subview = memref.subview %arg0[%arg1, %arg2, %arg3] [1, 32, 32] [1, 1, 1] :
+    memref<32x256x256xf32> to memref<32x32xf32, strided<[256, 1], offset: ?>>
+  %0 = xegpu.create_nd_tdesc %subview[%arg4, %arg5] :
+    memref<32x32xf32, strided<[256, 1], offset: ?>> -> !xegpu.tensor_desc<8x16xf32>
+  return %0 : !xegpu.tensor_desc<8x16xf32>
+}
+
+//   CHECK-DAG: #[[MAP:.+]] = affine_map<()[s0, s1] -> (s0 + s1)>
+//       CHECK: func @fold_subview_with_xegpu_create_nd_tdesc
+//  CHECK-SAME:   %[[ARG0:[a-zA-Z0-9]+]]: memref<32x256x256xf32>
+//  CHECK-SAME:   %[[ARG1:[a-zA-Z0-9]+]]: index
+//  CHECK-SAME:   %[[ARG2:[a-zA-Z0-9]+]]: index
+//  CHECK-SAME:   %[[ARG3:[a-zA-Z0-9]+]]: index
+//  CHECK-SAME:   %[[ARG4:[a-zA-Z0-9]+]]: index
+//  CHECK-SAME:   %[[ARG5:[a-zA-Z0-9]+]]: index
+//       CHECK:   %[[C65536:[a-zA-Z0-9]+]] = arith.constant 65536 : index
+//   CHECK-DAG:   %[[IDX0:.+]] = affine.apply #[[MAP]]()[%[[ARG2]], %[[ARG4]]]
+//   CHECK-DAG:   %[[IDX1:.+]] = affine.apply #[[MAP]]()[%[[ARG3]], %[[ARG5]]]
+//       CHECK:   %[[CAST:.+]] = memref.reinterpret_cast %[[ARG0]] to offset: [0], sizes: [8192, 256], strides: [256, 1] : memref<32x256x256xf32> to memref<8192x256xf32>
+//       CHECK:   %[[T1:.+]] = arith.muli %[[ARG1]], %[[C65536]] : index
+//       CHECK:   %[[T2:.+]] = arith.addi %[[T1]], %[[IDX0]] : index
+//       CHECK:   xegpu.create_nd_tdesc %[[CAST]][%[[T2]], %[[IDX1]]] : memref<8192x256xf32> -> !xegpu.tensor_desc<8x16xf32>
+
+// -----
+func.func @fold_subview_with_xegpu_create_nd_tdesc(%arg0 : memref<32x32x256x256xf32>, %arg1 : index, %arg2 : index, %arg3 : index, %arg4 : index, %arg5 : index, %arg6: index) ->(!xegpu.tensor_desc<8x16xf32>) {
+  %subview = memref.subview %arg0[%arg1, %arg2, %arg3, %arg4] [1, 1, 32, 32] [1, 1, 1, 1] :
+    memref<32x32x256x256xf32> to memref<32x32xf32, strided<[256, 1], offset: ?>>
+  %0 = xegpu.create_nd_tdesc %subview[%arg5, %arg6] :
+    memref<32x32xf32, strided<[256, 1], offset: ?>> -> !xegpu.tensor_desc<8x16xf32>
+  return %0 : !xegpu.tensor_desc<8x16xf32>
+}
+
+//   CHECK-DAG: #[[MAP:.+]] = affine_map<()[s0, s1] -> (s0 + s1)>
+//       CHECK: func @fold_subview_with_xegpu_create_nd_tdesc
+//  CHECK-SAME:   %[[ARG0:[a-zA-Z0-9]+]]: memref<32x32x256x256xf32>
+//  CHECK-SAME:   %[[ARG1:[a-zA-Z0-9]+]]: index
+//  CHECK-SAME:   %[[ARG2:[a-zA-Z0-9]+]]: index
+//  CHECK-SAME:   %[[ARG3:[a-zA-Z0-9]+]]: index
+//  CHECK-SAME:   %[[ARG4:[a-zA-Z0-9]+]]: index
+//  CHECK-SAME:   %[[ARG5:[a-zA-Z0-9]+]]: index
+//  CHECK-SAME:   %[[ARG6:[a-zA-Z0-9]+]]: index
+//       CHECK:   %[[C2097152:[a-zA-Z0-9]+]] = arith.constant 2097152 : index
+//       CHECK:   %[[C65536:[a-zA-Z0-9]+]] = arith.constant 65536 : index
+//   CHECK-DAG:   %[[IDX0:.+]] = affine.apply #[[MAP]]()[%[[ARG3]], %[[ARG5]]]
+//   CHECK-DAG:   %[[IDX1:.+]] = affine.apply #[[MAP]]()[%[[ARG4]], %[[ARG6]]]
+//       CHECK:   %[[CAST:.+]] = memref.reinterpret_cast %[[ARG0]] to offset: [0], sizes: [262144, 256], strides: [256, 1] : memref<32x32x256x256xf32> to memref<262144x256xf32>
+//       CHECK:   %[[T1:.+]] = arith.muli %[[ARG2]], %[[C65536]] : index
+//       CHECK:   %[[T2:.+]] = arith.addi %[[T1]], %[[IDX0]] : index
+//       CHECK:   %[[T3:.+]] = arith.muli %[[ARG1]], %[[C2097152]] : index
+//       CHECK:   %[[T4:.+]] = arith.addi %[[T3]], %[[T2]] : index
+//       CHECK:   xegpu.create_nd_tdesc %[[CAST]][%[[T4]], %[[IDX1]]] : memref<262144x256xf32> -> !xegpu.tensor_desc<8x16xf32>

``````````

</details>


https://github.com/llvm/llvm-project/pull/99593


More information about the Mlir-commits mailing list