[Mlir-commits] [mlir] [mlir][MemRef] Add more ops to narrow type support, strided metadata expansion (PR #102228)

Krzysztof Drewniak llvmlistbot at llvm.org
Tue Aug 6 14:20:39 PDT 2024


https://github.com/krzysz00 created https://github.com/llvm/llvm-project/pull/102228

- Add support for memory_space_cast to strided metadata expansion and narrow type emulation
- Add support for expand_shape to narrow type emulation (like collapse_shape, it's a noop after linearization) and to expand-strided-metadata (mirroring the collapse_shape pattern)
- Add support for memref.dealloc to narrow type emulation (it is a trivial rewrite) and for memref.copy (which is unsupported when it is used for a layout change but a trivial rewrite otherwise)

>From f65fc29b30df1ea0494847c70256459222fc1c79 Mon Sep 17 00:00:00 2001
From: Krzysztof Drewniak <Krzysztof.Drewniak at amd.com>
Date: Tue, 16 Jul 2024 22:47:28 +0000
Subject: [PATCH] [mlir][MemRef] Add more ops to narrow type support, strided
 metadata

- Add support for memory_space_cast to strided metadata expansion and
narrow type emulation
- Add support for expand_shape to narrow type emulation (like
collapse_shape, it's a noop after linearization) and to
expand-strided-metadata (mirroring the collapse_shape pattern)
- Add support for memref.dealloc to narrow type emulation (it is a
trivial rewrite) and for memref.copy (which is unsupported when it is
used for a layout change but a trivial rewrite otherwise)
---
 .../MemRef/Transforms/EmulateNarrowType.cpp   | 93 ++++++++++++++++++-
 .../Transforms/ExpandStridedMetadata.cpp      | 82 ++++++++++++++++
 .../Dialect/MemRef/emulate-narrow-type.mlir   | 68 ++++++++++++++
 .../MemRef/expand-strided-metadata.mlir       | 38 ++++++++
 4 files changed, 278 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp b/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
index 20ce1b1da4c9b..64e2c0e0e9e08 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
@@ -235,6 +235,46 @@ struct ConvertMemRefAssumeAlignment final
   }
 };
 
+//===----------------------------------------------------------------------===//
+// ConvertMemRefCopy
+//===----------------------------------------------------------------------===//
+
+struct ConvertMemRefCopy final : OpConversionPattern<memref::CopyOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(memref::CopyOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto maybeRankedSource = dyn_cast<MemRefType>(op.getSource().getType());
+    auto maybeRankedDest = dyn_cast<MemRefType>(op.getTarget().getType());
+    if (maybeRankedSource && maybeRankedDest &&
+        maybeRankedSource.getLayout() != maybeRankedDest.getLayout())
+      return rewriter.notifyMatchFailure(
+          op, llvm::formatv("memref.copy emulation with distinct layouts ({0} "
+                            "and {1}) is currently unimplemented",
+                            maybeRankedSource.getLayout(),
+                            maybeRankedDest.getLayout()));
+    rewriter.replaceOpWithNewOp<memref::CopyOp>(op, adaptor.getSource(),
+                                                adaptor.getTarget());
+    return success();
+  }
+};
+
+//===----------------------------------------------------------------------===//
+// ConvertMemRefDealloc
+//===----------------------------------------------------------------------===//
+
+struct ConvertMemRefDealloc final : OpConversionPattern<memref::DeallocOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(memref::DeallocOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    rewriter.replaceOpWithNewOp<memref::DeallocOp>(op, adaptor.getMemref());
+    return success();
+  }
+};
+
 //===----------------------------------------------------------------------===//
 // ConvertMemRefLoad
 //===----------------------------------------------------------------------===//
@@ -301,6 +341,30 @@ struct ConvertMemRefLoad final : OpConversionPattern<memref::LoadOp> {
   }
 };
 
+//===----------------------------------------------------------------------===//
+// ConvertMemRefMemorySpaceCast
+//===----------------------------------------------------------------------===//
+
+struct ConvertMemRefMemorySpaceCast final
+    : OpConversionPattern<memref::MemorySpaceCastOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(memref::MemorySpaceCastOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Type newTy = getTypeConverter()->convertType(op.getDest().getType());
+    if (!newTy) {
+      return rewriter.notifyMatchFailure(
+          op->getLoc(), llvm::formatv("failed to convert memref type: {0}",
+                                      op.getDest().getType()));
+    }
+
+    rewriter.replaceOpWithNewOp<memref::MemorySpaceCastOp>(op, newTy,
+                                                           adaptor.getSource());
+    return success();
+  }
+};
+
 //===----------------------------------------------------------------------===//
 // ConvertMemRefReinterpretCast
 //===----------------------------------------------------------------------===//
@@ -492,6 +556,28 @@ struct ConvertMemRefCollapseShape final
   }
 };
 
+/// Emulating a `memref.expand_shape` becomes a no-op after emulation given
+/// that we flatten memrefs to a single dimension as part of the emulation and
+/// the expansion would just have been undone.
+struct ConvertMemRefExpandShape final
+    : OpConversionPattern<memref::ExpandShapeOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(memref::ExpandShapeOp expandShapeOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Value srcVal = adaptor.getSrc();
+    auto newTy = dyn_cast<MemRefType>(srcVal.getType());
+    if (!newTy)
+      return failure();
+
+    if (newTy.getRank() != 1)
+      return failure();
+
+    rewriter.replaceOp(expandShapeOp, srcVal);
+    return success();
+  }
+};
 } // end anonymous namespace
 
 //===----------------------------------------------------------------------===//
@@ -504,9 +590,10 @@ void memref::populateMemRefNarrowTypeEmulationPatterns(
 
   // Populate `memref.*` conversion patterns.
   patterns.add<ConvertMemRefAllocation<memref::AllocOp>,
-               ConvertMemRefAllocation<memref::AllocaOp>,
-               ConvertMemRefCollapseShape, ConvertMemRefLoad,
-               ConvertMemrefStore, ConvertMemRefAssumeAlignment,
+               ConvertMemRefAllocation<memref::AllocaOp>, ConvertMemRefCopy,
+               ConvertMemRefDealloc, ConvertMemRefCollapseShape,
+               ConvertMemRefExpandShape, ConvertMemRefLoad, ConvertMemrefStore,
+               ConvertMemRefAssumeAlignment, ConvertMemRefMemorySpaceCast,
                ConvertMemRefSubview, ConvertMemRefReinterpretCast>(
       typeConverter, patterns.getContext());
   memref::populateResolveExtractStridedMetadataPatterns(patterns);
diff --git a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
index 585c5b7381421..78c03da7d453c 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
@@ -726,6 +726,41 @@ struct ExtractStridedMetadataOpCollapseShapeFolder
   }
 };
 
+/// Pattern to replace `extract_strided_metadata(expand_shape)`
+/// with the results of computing the sizes and strides on the expanded shape
+/// and dividing up dimensions into static and dynamic parts as needed.
+struct ExtractStridedMetadataOpExpandShapeFolder
+    : OpRewritePattern<memref::ExtractStridedMetadataOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(memref::ExtractStridedMetadataOp op,
+                                PatternRewriter &rewriter) const override {
+    auto expandShapeOp = op.getSource().getDefiningOp<memref::ExpandShapeOp>();
+    if (!expandShapeOp)
+      return failure();
+
+    FailureOr<StridedMetadata> stridedMetadata =
+        resolveReshapeStridedMetadata<memref::ExpandShapeOp>(
+            rewriter, expandShapeOp, getExpandedSizes, getExpandedStrides);
+    if (failed(stridedMetadata)) {
+      return rewriter.notifyMatchFailure(
+          op, "failed to resolve metadata in terms of source expand_shape op");
+    }
+
+    Location loc = expandShapeOp.getLoc();
+    SmallVector<Value> results;
+    results.push_back(stridedMetadata->basePtr);
+    results.push_back(getValueOrCreateConstantIndexOp(rewriter, loc,
+                                                      stridedMetadata->offset));
+    results.append(
+        getValueOrCreateConstantIndexOp(rewriter, loc, stridedMetadata->sizes));
+    results.append(getValueOrCreateConstantIndexOp(rewriter, loc,
+                                                   stridedMetadata->strides));
+    rewriter.replaceOp(op, results);
+    return success();
+  }
+};
+
 /// Replace `base, offset, sizes, strides =
 ///              extract_strided_metadata(allocLikeOp)`
 ///
@@ -1060,6 +1095,49 @@ class ExtractStridedMetadataOpCastFolder
   }
 };
 
+/// Replace `base, offset, sizes, strides = extract_strided_metadata(
+///      memory_space_cast(src) to dstTy)`
+/// with
+/// ```
+///    oldBase, offset, sizes, strides = extract_strided_metadata(src)
+///    destBaseTy = type(oldBase) with memory space from destTy
+///    base = memory_space_cast(oldBase) to destBaseTy
+/// ```
+///
+/// In other words, propagate metadata extraction across memory space casts.
+class ExtractStridedMetadataOpMemorySpaceCastFolder
+    : public OpRewritePattern<memref::ExtractStridedMetadataOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult
+  matchAndRewrite(memref::ExtractStridedMetadataOp extractStridedMetadataOp,
+                  PatternRewriter &rewriter) const override {
+    Location loc = extractStridedMetadataOp.getLoc();
+    Value source = extractStridedMetadataOp.getSource();
+    auto memSpaceCastOp = source.getDefiningOp<memref::MemorySpaceCastOp>();
+    if (!memSpaceCastOp)
+      return failure();
+    auto newExtractStridedMetadata =
+        rewriter.create<memref::ExtractStridedMetadataOp>(
+            loc, memSpaceCastOp.getSource());
+    SmallVector<Value> results(newExtractStridedMetadata.getResults());
+    if (!extractStridedMetadataOp.getBaseBuffer().use_empty()) {
+      auto baseBuffer = results[0];
+      auto baseBufferType = cast<MemRefType>(baseBuffer.getType());
+      MemRefType::Builder newTypeBuilder(baseBufferType);
+      newTypeBuilder.setMemorySpace(
+          memSpaceCastOp.getResult().getType().getMemorySpace());
+      results[0] = rewriter.create<memref::MemorySpaceCastOp>(
+          loc, Type{newTypeBuilder}, baseBuffer);
+    } else {
+      // Don't create spurious casts for values that are going away.
+      results[0] = nullptr;
+    }
+    rewriter.replaceOp(extractStridedMetadataOp, results);
+    return success();
+  }
+};
+
 /// Replace `base, offset =
 ///            extract_strided_metadata(extract_strided_metadata(src)#0)`
 /// With
@@ -1099,11 +1177,13 @@ void memref::populateExpandStridedMetadataPatterns(
                ExtractStridedMetadataOpAllocFolder<memref::AllocOp>,
                ExtractStridedMetadataOpAllocFolder<memref::AllocaOp>,
                ExtractStridedMetadataOpCollapseShapeFolder,
+               ExtractStridedMetadataOpExpandShapeFolder,
                ExtractStridedMetadataOpGetGlobalFolder,
                RewriteExtractAlignedPointerAsIndexOfViewLikeOp,
                ExtractStridedMetadataOpReinterpretCastFolder,
                ExtractStridedMetadataOpSubviewFolder,
                ExtractStridedMetadataOpCastFolder,
+               ExtractStridedMetadataOpMemorySpaceCastFolder,
                ExtractStridedMetadataOpExtractStridedMetadataFolder>(
       patterns.getContext());
 }
@@ -1113,11 +1193,13 @@ void memref::populateResolveExtractStridedMetadataPatterns(
   patterns.add<ExtractStridedMetadataOpAllocFolder<memref::AllocOp>,
                ExtractStridedMetadataOpAllocFolder<memref::AllocaOp>,
                ExtractStridedMetadataOpCollapseShapeFolder,
+               ExtractStridedMetadataOpExpandShapeFolder,
                ExtractStridedMetadataOpGetGlobalFolder,
                ExtractStridedMetadataOpSubviewFolder,
                RewriteExtractAlignedPointerAsIndexOfViewLikeOp,
                ExtractStridedMetadataOpReinterpretCastFolder,
                ExtractStridedMetadataOpCastFolder,
+               ExtractStridedMetadataOpMemorySpaceCastFolder,
                ExtractStridedMetadataOpExtractStridedMetadataFolder>(
       patterns.getContext());
 }
diff --git a/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir b/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir
index a67237b5e4dd1..540da239fced0 100644
--- a/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir
+++ b/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir
@@ -6,11 +6,13 @@ func.func @memref_i8() -> i8 {
     %c3 = arith.constant 3 : index
     %m = memref.alloc() : memref<4xi8, 1>
     %v = memref.load %m[%c3] : memref<4xi8, 1>
+    memref.dealloc %m : memref<4xi8, 1>
     return %v : i8
 }
 // CHECK-LABEL: func @memref_i8()
 //       CHECK:   %[[M:.+]] = memref.alloc() : memref<4xi8, 1>
 //  CHECK-NEXT:   %[[V:.+]] = memref.load %[[M]][%{{.+}}] : memref<4xi8, 1>
+//  CHECK-NEXT:   memref.dealloc %[[M]]
 //  CHECK-NEXT:   return %[[V]]
 
 // CHECK32-LABEL: func @memref_i8()
@@ -21,6 +23,7 @@ func.func @memref_i8() -> i8 {
 //       CHECK32:   %[[CAST:.+]] = arith.index_cast %[[C24]] : index to i32
 //       CHECK32:   %[[SHIFTRT:.+]] = arith.shrsi %[[V]], %[[CAST]]
 //       CHECK32:   %[[TRUNC:.+]] = arith.trunci %[[SHIFTRT]] : i32 to i8
+//  CHECK32-NEXT:   memref.dealloc %[[M]]
 //  CHECK32-NEXT:   return %[[TRUNC]]
 
 // -----
@@ -485,3 +488,68 @@ func.func @memref_collapse_shape_i4(%idx0 : index, %idx1 : index) -> i4 {
 //   CHECK32-NOT:     memref.collapse_shape
 //       CHECK32:     memref.load %[[ALLOC]][%{{.*}}] : memref<4096xi32>
 
+// -----
+
+func.func @memref_expand_shape_i4(%idx0 : index, %idx1 : index, %idx2 : index) -> i4 {
+  %arr = memref.alloc() : memref<256x128xi4>
+  %expand = memref.expand_shape %arr[[0, 1], [2]] output_shape [32, 8, 128] : memref<256x128xi4> into memref<32x8x128xi4>
+  %1 = memref.load %expand[%idx0, %idx1, %idx2] : memref<32x8x128xi4>
+  return %1 : i4
+}
+
+// CHECK-LABEL:   func.func @memref_expand_shape_i4(
+//       CHECK:     %[[ALLOC:.*]] = memref.alloc() : memref<16384xi8>
+//   CHECK-NOT:     memref.expand_shape
+//       CHECK:     memref.load %[[ALLOC]][%{{.*}}] : memref<16384xi8>
+
+// CHECK32-LABEL:   func.func @memref_expand_shape_i4(
+//       CHECK32:     %[[ALLOC:.*]] = memref.alloc() : memref<4096xi32>
+//   CHECK32-NOT:     memref.expand_shape
+//       CHECK32:     memref.load %[[ALLOC]][%{{.*}}] : memref<4096xi32>
+
+// -----
+
+func.func @memref_memory_space_cast_i4(%arg0: memref<32x128xi4, 1>) -> memref<32x128xi4> {
+  %cast = memref.memory_space_cast %arg0 : memref<32x128xi4, 1> to memref<32x128xi4>
+  return %cast : memref<32x128xi4>
+}
+
+// CHECK-LABEL:   func.func @memref_memory_space_cast_i4(
+//  CHECK-SAME:   %[[ARG0:.*]]: memref<2048xi8, 1>
+//       CHECK:     %[[CAST:.*]] = memref.memory_space_cast %[[ARG0]] : memref<2048xi8, 1> to memref<2048xi8>
+//       CHECK:     return %[[CAST]]
+
+// CHECK32-LABEL:   func.func @memref_memory_space_cast_i4(
+//  CHECK32-SAME:   %[[ARG0:.*]]: memref<512xi32, 1>
+//       CHECK32:     %[[CAST:.*]] = memref.memory_space_cast %[[ARG0]] : memref<512xi32, 1> to memref<512xi32>
+//       CHECK32:     return %[[CAST]]
+
+// -----
+
+func.func @memref_copy_i4(%arg0: memref<32x128xi4, 1>, %arg1: memref<32x128xi4>) {
+  memref.copy %arg0, %arg1 : memref<32x128xi4, 1> to memref<32x128xi4>
+  return
+}
+
+// CHECK-LABEL:   func.func @memref_copy_i4(
+//  CHECK-SAME:   %[[ARG0:.*]]: memref<2048xi8, 1>, %[[ARG1:.*]]: memref<2048xi8>
+//       CHECK:     memref.copy %[[ARG0]], %[[ARG1]]
+//       CHECK:     return
+
+// CHECK32-LABEL:   func.func @memref_copy_i4(
+//  CHECK32-SAME:   %[[ARG0:.*]]: memref<512xi32, 1>, %[[ARG1:.*]]: memref<512xi32>
+//       CHECK32:     memref.copy %[[ARG0]], %[[ARG1]]
+//       CHECK32:     return
+
+// -----
+
+!colMajor = memref<8x8xi4, strided<[1, 8]>>
+func.func @copy_distinct_layouts(%idx : index) -> i4 {
+  %c0 = arith.constant 0 : index
+  %arr = memref.alloc() : memref<8x8xi4>
+  %arr2 = memref.alloc() : !colMajor
+  // expected-error @+1 {{failed to legalize operation 'memref.copy' that was explicitly marked illegal}}
+  memref.copy %arr, %arr2 : memref<8x8xi4> to !colMajor
+  %ld = memref.load %arr2[%c0, %c0] : !colMajor
+  return %ld : i4
+}
diff --git a/mlir/test/Dialect/MemRef/expand-strided-metadata.mlir b/mlir/test/Dialect/MemRef/expand-strided-metadata.mlir
index d884ade319532..8aac802ba10ae 100644
--- a/mlir/test/Dialect/MemRef/expand-strided-metadata.mlir
+++ b/mlir/test/Dialect/MemRef/expand-strided-metadata.mlir
@@ -1553,3 +1553,41 @@ func.func @extract_strided_metadata_of_collapse_shape(%base: memref<5x4xf32>)
 //   CHECK-DAG:    %[[STEP:.*]] = arith.constant 1 : index
 //       CHECK:    %[[BASE:.*]], %{{.*}}, %{{.*}}, %{{.*}} = memref.extract_strided_metadata
 //       CHECK:    return %[[BASE]], %[[OFFSET]], %[[SIZE]], %[[STEP]] : memref<f32>, index, index, index
+
+// -----
+
+func.func @extract_strided_metadata_of_memory_space_cast(%base: memref<20xf32>)
+    -> (memref<f32, 1>, index, index, index) {
+
+  %memory_space_cast = memref.memory_space_cast %base : memref<20xf32> to memref<20xf32, 1>
+
+  %base_buffer, %offset, %size, %stride = memref.extract_strided_metadata %memory_space_cast :
+    memref<20xf32, 1> -> memref<f32, 1>, index, index, index
+
+  return %base_buffer, %offset, %size, %stride :
+    memref<f32, 1>, index, index, index
+}
+
+// CHECK-LABEL:  func @extract_strided_metadata_of_memory_space_cast
+//   CHECK-DAG:    %[[OFFSET:.*]] = arith.constant 0 : index
+//   CHECK-DAG:    %[[SIZE:.*]] = arith.constant 20 : index
+//   CHECK-DAG:    %[[STEP:.*]] = arith.constant 1 : index
+//       CHECK:    %[[BASE:.*]], %{{.*}}, %{{.*}}, %{{.*}} = memref.extract_strided_metadata
+//       CHECK:    %[[CAST:.*]] = memref.memory_space_cast %[[BASE]]
+//       CHECK:    return %[[CAST]], %[[OFFSET]], %[[SIZE]], %[[STEP]] : memref<f32, 1>, index, index, index
+
+// -----
+
+func.func @extract_strided_metadata_of_memory_space_cast_no_base(%base: memref<20xf32>)
+    -> (index, index, index) {
+
+  %memory_space_cast = memref.memory_space_cast %base : memref<20xf32> to memref<20xf32, 1>
+
+  %base_buffer, %offset, %size, %stride = memref.extract_strided_metadata %memory_space_cast :
+    memref<20xf32, 1> -> memref<f32, 1>, index, index, index
+
+  return %offset, %size, %stride : index, index, index
+}
+
+// CHECK-LABEL:  func @extract_strided_metadata_of_memory_space_cast_no_base
+//   CHECK-NOT:  memref.memory_space_cast



More information about the Mlir-commits mailing list