[Mlir-commits] [mlir] 9d25991 - [mlir][MemRef] Simplify extract_strided_metadata(allocLikeOp)

Quentin Colombet llvmlistbot at llvm.org
Mon Sep 26 09:26:18 PDT 2022


Author: Quentin Colombet
Date: 2022-09-26T16:14:29Z
New Revision: 9d259916e1ae69f83762b1801bfb0db1faf20262

URL: https://github.com/llvm/llvm-project/commit/9d259916e1ae69f83762b1801bfb0db1faf20262
DIFF: https://github.com/llvm/llvm-project/commit/9d259916e1ae69f83762b1801bfb0db1faf20262.diff

LOG: [mlir][MemRef] Simplify extract_strided_metadata(allocLikeOp)

Teach the pass that simplifies extract_strided_metadata(other_op(memref))
how to get rid of extract_strided_metadata ops when they are fed by an
allocLikeOp.

For the simplification to happen, the allocLikeOp needs to have been
normalized, i.e., its memref type must have an identity layout (no
non-trivial offset or strides).

When this is the case, we replace:
```
base, offset, sizes, strides =
    extract_strided_metadata(allocLikeOp(allocSizes))
```

With
```
base = reinterpret_cast allocLikeOp(allocSizes) to a flat memref<eltTy>
offset = 0
sizes = allocSizes
strides#i = prod(allocSizes#j, for j in {i+1..rank-1})
```

The computations involving dynamic sizes are expanded into affine.apply ops.

Differential Revision: https://reviews.llvm.org/D134577

Added: 
    

Modified: 
    mlir/lib/Dialect/MemRef/Transforms/SimplifyExtractStridedMetadata.cpp
    mlir/test/Dialect/MemRef/simplify-extract-strided-metadata.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/MemRef/Transforms/SimplifyExtractStridedMetadata.cpp b/mlir/lib/Dialect/MemRef/Transforms/SimplifyExtractStridedMetadata.cpp
index dcadd5b33d07..1d42c9109447 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/SimplifyExtractStridedMetadata.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/SimplifyExtractStridedMetadata.cpp
@@ -487,14 +487,103 @@ struct ForwardStaticMetadata
     return success(atLeastOneReplacement);
   }
 };
+
+/// Folds `extract_strided_metadata` when its source is produced directly by
+/// an `AllocLikeOp` (registered for `memref::AllocOp` and
+/// `memref::AllocaOp`).
+///
+/// Replace `base, offset, sizes, strides =
+///              extract_strided_metadata(allocLikeOp)`
+///
+/// With
+///
+/// ```
+/// base = reinterpret_cast allocLikeOp(allocSizes) to a flat memref<eltTy>
+/// offset = 0
+/// sizes = allocSizes
+/// strides#i = prod(allocSizes#j, for j in {i+1..rank-1})
+/// ```
+///
+/// The reinterpret_cast is skipped when the allocated type already equals
+/// the base buffer type (e.g. a rank-0 allocation).
+///
+/// The transformation only applies if the allocLikeOp has been normalized.
+/// In other words, the layout of the allocated memref type must be the
+/// identity; otherwise the pattern reports a match failure.
+///
+/// \tparam AllocLikeOp  The defining op to match. It must provide
+///                      `getDynamicSizes()` for its dynamic dimensions.
+template <typename AllocLikeOp>
+struct ExtractStridedMetadataOpAllocFolder
+    : public OpRewritePattern<memref::ExtractStridedMetadataOp> {
+public:
+  using OpRewritePattern<memref::ExtractStridedMetadataOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(memref::ExtractStridedMetadataOp op,
+                                PatternRewriter &rewriter) const override {
+    // Only fire when the source memref comes straight from an AllocLikeOp.
+    auto allocLikeOp = op.getSource().getDefiningOp<AllocLikeOp>();
+    if (!allocLikeOp)
+      return failure();
+
+    // The `template` keyword is required because the type of
+    // `allocLikeOp` depends on the AllocLikeOp template parameter.
+    auto memRefType =
+        allocLikeOp.getResult().getType().template cast<MemRefType>();
+    // Reject non-identity layouts: the metadata built below assumes offset 0
+    // and contiguous row-major strides.
+    if (!memRefType.getLayout().isIdentity())
+      return rewriter.notifyMatchFailure(
+          allocLikeOp, "alloc-like operations should have been normalized");
+
+    Location loc = op.getLoc();
+    int rank = memRefType.getRank();
+
+    // Collect the sizes: static dimensions become index attributes, and each
+    // dynamic dimension consumes the next value of getDynamicSizes(), in
+    // order.
+    ValueRange dynamic = allocLikeOp.getDynamicSizes();
+    SmallVector<OpFoldResult> sizes;
+    sizes.reserve(rank);
+    unsigned dynamicPos = 0;
+    for (int64_t size : memRefType.getShape()) {
+      if (ShapedType::isDynamic(size))
+        sizes.push_back(dynamic[dynamicPos++]);
+      else
+        sizes.push_back(rewriter.getIndexAttr(size));
+    }
+
+    // Strides: contiguous row-major strides. Every entry starts at 1, which
+    // is already correct for the innermost dimension. Each outer stride i is
+    // then overwritten with the product of sizes[i+1..rank-1].
+    SmallVector<OpFoldResult> strides(rank, rewriter.getIndexAttr(1));
+    // `expr` accumulates 1 * s0 * s1 * ..., gaining one symbol each time we
+    // move one dimension outward. When computing stride i, symbol k is bound
+    // to sizes[i + 1 + k].
+    AffineExpr expr = rewriter.getAffineConstantExpr(1);
+    unsigned symbolNumber = 0;
+    for (int i = rank - 2; i >= 0; --i) {
+      expr = expr * rewriter.getAffineSymbolExpr(symbolNumber++);
+      assert(i + 1 + symbolNumber == sizes.size() &&
+             "The ArrayRef should encompass the last #symbolNumber sizes");
+      ArrayRef<OpFoldResult> sizesInvolvedInStride(&sizes[i + 1], symbolNumber);
+      // Static sizes are folded into the resulting map (the accompanying
+      // tests expect maps such as `(s0 * s1) * 4`), so only dynamic sizes
+      // should remain as affine.apply operands.
+      strides[i] = makeComposedFoldedAffineApply(rewriter, loc, expr,
+                                                 sizesInvolvedInStride);
+    }
+
+    // Put all the values together to replace the results, in the op's
+    // result order: base buffer, offset, `rank` sizes, then `rank` strides.
+    SmallVector<Value> results;
+    results.reserve(rank * 2 + 2);
+
+    // Reuse the allocation directly when its type already matches the
+    // op's base buffer type (e.g. rank 0). Otherwise view it through a
+    // reinterpret_cast with offset 0 and no sizes/strides.
+    auto baseBufferType = op.getBaseBuffer().getType().cast<MemRefType>();
+    int64_t offset = 0;
+    if (allocLikeOp.getType() == baseBufferType)
+      results.push_back(allocLikeOp);
+    else
+      results.push_back(rewriter.create<memref::ReinterpretCastOp>(
+          loc, baseBufferType, allocLikeOp, offset,
+          /*sizes=*/ArrayRef<int64_t>(),
+          /*strides=*/ArrayRef<int64_t>()));
+
+    // Offset: always 0, since only identity layouts reach this point.
+    results.push_back(rewriter.create<arith::ConstantIndexOp>(loc, offset));
+
+    // Sizes and strides are either attributes (static) or values (dynamic).
+    // Attributes are materialized as index constants.
+    for (OpFoldResult size : sizes)
+      results.push_back(getValueOrCreateConstantIndexOp(rewriter, loc, size));
+
+    for (OpFoldResult stride : strides)
+      results.push_back(getValueOrCreateConstantIndexOp(rewriter, loc, stride));
+
+    rewriter.replaceOp(op, results);
+    return success();
+  }
+};
 } // namespace
 
 void memref::populateSimplifyExtractStridedMetadataOpPatterns(
     RewritePatternSet &patterns) {
-  patterns
-      .add<ExtractStridedMetadataOpSubviewFolder,
-           ExtractStridedMetadataOpExpandShapeFolder, ForwardStaticMetadata>(
-          patterns.getContext());
+  // Register every extract_strided_metadata simplification. The alloc-like
+  // folder is instantiated separately for memref.alloc and memref.alloca.
+  patterns.add<ExtractStridedMetadataOpSubviewFolder,
+               ExtractStridedMetadataOpExpandShapeFolder, ForwardStaticMetadata,
+               ExtractStridedMetadataOpAllocFolder<memref::AllocOp>,
+               ExtractStridedMetadataOpAllocFolder<memref::AllocaOp>>(
+      patterns.getContext());
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/MemRef/simplify-extract-strided-metadata.mlir b/mlir/test/Dialect/MemRef/simplify-extract-strided-metadata.mlir
index dd5889fe47fd..449b35b9ed66 100644
--- a/mlir/test/Dialect/MemRef/simplify-extract-strided-metadata.mlir
+++ b/mlir/test/Dialect/MemRef/simplify-extract-strided-metadata.mlir
@@ -527,3 +527,226 @@ func.func @extract_strided_metadata_of_expand_shape_all_static_0_rank(
       index, index, index, index, index,
       index, index, index, index, index
 }
+
+// -----
+
+// Check that we simplify extract_strided_metadata(alloc)
+// into just the alloc plus the information extracted from
+// the memref type and arguments of the alloc.
+//
+// baseBuffer = reinterpret_cast alloc
+// offset = 0
+// sizes = shape(memref)
+// strides = strides(memref)
+//
+// For dynamic shapes, we simply use the values that feed the alloc.
+//
+// Simple rank 0 test: we don't need a reinterpret_cast here.
+// CHECK-LABEL: func @extract_strided_metadata_of_alloc_all_static_0_rank
+//
+//   CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+//   CHECK-DAG: %[[ALLOC:.*]] = memref.alloc()
+//       CHECK: return %[[ALLOC]], %[[C0]] : memref<i16>, index
+func.func @extract_strided_metadata_of_alloc_all_static_0_rank()
+    -> (memref<i16>, index) {
+
+  // Rank 0: extract_strided_metadata yields only the base buffer and the
+  // offset, and the alloc type already equals the base buffer type.
+  %A = memref.alloc() : memref<i16>
+  %base, %offset = memref.extract_strided_metadata %A :
+    memref<i16>
+    -> memref<i16>, index
+
+  return %base, %offset :
+      memref<i16>, index
+}
+
+// -----
+
+// Simplification of extract_strided_metadata(alloc).
+// Check that we properly use the dynamic sizes to
+// create the new sizes and strides.
+// size 0 = dyn_size0
+// size 1 = 4
+// size 2 = dyn_size2
+// size 3 = dyn_size3
+//
+// stride 0 = size 1 * size 2 * size 3
+//          = 4 * dyn_size2 * dyn_size3
+// stride 1 = size 2 * size 3
+//          = dyn_size2 * dyn_size3
+// stride 2 = size 3
+//          = dyn_size3
+// stride 3 = 1
+//
+//   CHECK-DAG: #[[$STRIDE0_MAP:.*]] = affine_map<()[s0, s1] -> ((s0 * s1) * 4)>
+//   CHECK-DAG: #[[$STRIDE1_MAP:.*]] = affine_map<()[s0, s1] -> (s0 * s1)>
+// CHECK-LABEL: extract_strided_metadata_of_alloc_dyn_size
+//  CHECK-SAME: (%[[DYN_SIZE0:.*]]: index, %[[DYN_SIZE2:.*]]: index, %[[DYN_SIZE3:.*]]: index)
+//
+//   CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
+//   CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+//   CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+//   CHECK-DAG: %[[ALLOC:.*]] = memref.alloc(%[[DYN_SIZE0]], %[[DYN_SIZE2]], %[[DYN_SIZE3]])
+//
+//   CHECK-DAG: %[[STRIDE0:.*]] = affine.apply #[[$STRIDE0_MAP]]()[%[[DYN_SIZE2]], %[[DYN_SIZE3]]]
+//   CHECK-DAG: %[[STRIDE1:.*]] = affine.apply #[[$STRIDE1_MAP]]()[%[[DYN_SIZE2]], %[[DYN_SIZE3]]]
+//
+//   CHECK-DAG:  %[[CASTED_ALLOC:.*]] = memref.reinterpret_cast %[[ALLOC]] to offset: [0], sizes: [], strides: [] : memref<?x4x?x?xi16> to memref<i16>
+//
+//       CHECK: return %[[CASTED_ALLOC]], %[[C0]], %[[DYN_SIZE0]], %[[C4]], %[[DYN_SIZE2]], %[[DYN_SIZE3]], %[[STRIDE0]], %[[STRIDE1]], %[[DYN_SIZE3]], %[[C1]]
+func.func @extract_strided_metadata_of_alloc_dyn_size(
+  %dyn_size0 : index, %dyn_size2 : index, %dyn_size3 : index)
+    -> (memref<i16>, index,
+        index, index, index, index,
+        index, index, index, index) {
+
+  // Dimension 1 is static (4). Dimensions 0, 2 and 3 are dynamic and are
+  // fed by the function arguments, in that order.
+  %A = memref.alloc(%dyn_size0, %dyn_size2, %dyn_size3) : memref<?x4x?x?xi16>
+
+  %base, %offset, %sizes:4, %strides:4 = memref.extract_strided_metadata %A :
+    memref<?x4x?x?xi16>
+    -> memref<i16>, index,
+       index, index, index, index,
+       index, index, index, index
+
+  return %base, %offset,
+    %sizes#0, %sizes#1, %sizes#2, %sizes#3,
+    %strides#0, %strides#1, %strides#2, %strides#3 :
+      memref<i16>, index,
+      index, index, index, index,
+      index, index, index, index
+}
+
+// -----
+
+// Same check as extract_strided_metadata_of_alloc_dyn_size but alloca
+// instead of alloc. Just to make sure we handle allocas the same way
+// we do with alloc.
+// While at it, test a slightly different shape than
+// extract_strided_metadata_of_alloc_dyn_size.
+//
+// size 0 = dyn_size0
+// size 1 = dyn_size1
+// size 2 = 4
+// size 3 = dyn_size3
+//
+// stride 0 = size 1 * size 2 * size 3
+//          = dyn_size1 * 4 * dyn_size3
+// stride 1 = size 2 * size 3
+//          = 4 * dyn_size3
+// stride 2 = size 3
+//          = dyn_size3
+// stride 3 = 1
+//
+//   CHECK-DAG: #[[$STRIDE0_MAP:.*]] = affine_map<()[s0, s1] -> ((s0 * s1) * 4)>
+//   CHECK-DAG: #[[$STRIDE1_MAP:.*]] = affine_map<()[s0] -> (s0 * 4)>
+// CHECK-LABEL: extract_strided_metadata_of_alloca_dyn_size
+//  CHECK-SAME: (%[[DYN_SIZE0:.*]]: index, %[[DYN_SIZE1:.*]]: index, %[[DYN_SIZE3:.*]]: index)
+//
+//   CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
+//   CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+//   CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+//   CHECK-DAG: %[[ALLOCA:.*]] = memref.alloca(%[[DYN_SIZE0]], %[[DYN_SIZE1]], %[[DYN_SIZE3]])
+//
+//   CHECK-DAG: %[[STRIDE0:.*]] = affine.apply #[[$STRIDE0_MAP]]()[%[[DYN_SIZE1]], %[[DYN_SIZE3]]]
+//   CHECK-DAG: %[[STRIDE1:.*]] = affine.apply #[[$STRIDE1_MAP]]()[%[[DYN_SIZE3]]]
+//
+//   CHECK-DAG:  %[[CASTED_ALLOCA:.*]] = memref.reinterpret_cast %[[ALLOCA]] to offset: [0], sizes: [], strides: [] : memref<?x?x4x?xi16> to memref<i16>
+//
+//       CHECK: return %[[CASTED_ALLOCA]], %[[C0]], %[[DYN_SIZE0]], %[[DYN_SIZE1]], %[[C4]], %[[DYN_SIZE3]], %[[STRIDE0]], %[[STRIDE1]], %[[DYN_SIZE3]], %[[C1]]
+func.func @extract_strided_metadata_of_alloca_dyn_size(
+  %dyn_size0 : index, %dyn_size1 : index, %dyn_size3 : index)
+    -> (memref<i16>, index,
+        index, index, index, index,
+        index, index, index, index) {
+
+  // Dimension 2 is static (4). Dimensions 0, 1 and 3 are dynamic and are
+  // fed by the function arguments, in that order.
+  %A = memref.alloca(%dyn_size0, %dyn_size1, %dyn_size3) : memref<?x?x4x?xi16>
+
+  %base, %offset, %sizes:4, %strides:4 = memref.extract_strided_metadata %A :
+    memref<?x?x4x?xi16>
+    -> memref<i16>, index,
+       index, index, index, index,
+       index, index, index, index
+
+  return %base, %offset,
+    %sizes#0, %sizes#1, %sizes#2, %sizes#3,
+    %strides#0, %strides#1, %strides#2, %strides#3 :
+      memref<i16>, index,
+      index, index, index, index,
+      index, index, index, index
+}
+
+// -----
+
+// The following few alloc tests are negative tests (the simplification
+// doesn't happen) to make sure non-trivial memref types are treated
+// as "not normalized".
+// CHECK-LABEL: extract_strided_metadata_of_alloc_with_variable_offset
+//       CHECK: %[[ALLOC:.*]] = memref.alloc
+//       CHECK: %[[BASE:[^,]*]], {{.*}} = memref.extract_strided_metadata %[[ALLOC]]
+//       CHECK: return %[[BASE]]
+#map0 = affine_map<(d0)[s0] -> (d0 + s0)>
+func.func @extract_strided_metadata_of_alloc_with_variable_offset(%arg : index)
+    -> (memref<i16>, index, index, index) {
+
+  // The layout map adds a symbolic offset bound to %arg, so the type is not
+  // normalized and the extract_strided_metadata is expected to stay.
+  %A = memref.alloc()[%arg] : memref<4xi16, #map0>
+  %base, %offset, %size, %stride = memref.extract_strided_metadata %A :
+    memref<4xi16, #map0>
+    -> memref<i16>, index, index, index
+
+  return %base, %offset, %size, %stride :
+      memref<i16>, index, index, index
+}
+
+// -----
+
+// CHECK-LABEL: extract_strided_metadata_of_alloc_with_cst_offset
+//       CHECK: %[[ALLOC:.*]] = memref.alloc
+//       CHECK: %[[BASE:[^,]*]], {{.*}} = memref.extract_strided_metadata %[[ALLOC]]
+//       CHECK: return %[[BASE]]
+#map0 = affine_map<(d0) -> (d0 + 12)>
+func.func @extract_strided_metadata_of_alloc_with_cst_offset(%arg : index)
+    -> (memref<i16>, index, index, index) {
+
+  // The layout map adds a constant offset of 12, so the type is not
+  // normalized and the extract_strided_metadata is expected to stay.
+  // NOTE(review): %arg is unused in this test.
+  %A = memref.alloc() : memref<4xi16, #map0>
+  %base, %offset, %size, %stride = memref.extract_strided_metadata %A :
+    memref<4xi16, #map0>
+    -> memref<i16>, index, index, index
+
+  return %base, %offset, %size, %stride :
+      memref<i16>, index, index, index
+}
+
+// -----
+
+// CHECK-LABEL: extract_strided_metadata_of_alloc_with_cst_offset_in_type
+//       CHECK: %[[ALLOC:.*]] = memref.alloc
+//       CHECK: %[[BASE:[^,]*]], {{.*}} = memref.extract_strided_metadata %[[ALLOC]]
+//       CHECK: return %[[BASE]]
+func.func @extract_strided_metadata_of_alloc_with_cst_offset_in_type(%arg : index)
+    -> (memref<i16>, index, index, index) {
+
+  // The strided layout carries a static offset of 10, so the type is not
+  // normalized and the extract_strided_metadata is expected to stay.
+  // NOTE(review): %arg is unused in this test.
+  %A = memref.alloc() : memref<4xi16, strided<[1], offset : 10>>
+  %base, %offset, %size, %stride = memref.extract_strided_metadata %A :
+    memref<4xi16, strided<[1], offset : 10>>
+    -> memref<i16>, index, index, index
+
+  return %base, %offset, %size, %stride :
+      memref<i16>, index, index, index
+}
+
+// -----
+
+// CHECK-LABEL: extract_strided_metadata_of_alloc_with_strided
+//       CHECK: %[[ALLOC:.*]] = memref.alloc
+//       CHECK: %[[BASE:[^,]*]], {{.*}} = memref.extract_strided_metadata %[[ALLOC]]
+//       CHECK: return %[[BASE]]
+func.func @extract_strided_metadata_of_alloc_with_strided(%arg : index)
+    -> (memref<i16>, index, index, index) {
+
+  // The strided layout uses a stride of 12 instead of the contiguous 1, so
+  // the type is not normalized and the extract_strided_metadata is
+  // expected to stay.
+  // NOTE(review): %arg is unused in this test.
+  %A = memref.alloc() : memref<4xi16, strided<[12]>>
+  %base, %offset, %size, %stride = memref.extract_strided_metadata %A :
+    memref<4xi16, strided<[12]>>
+    -> memref<i16>, index, index, index
+
+  return %base, %offset, %size, %stride :
+      memref<i16>, index, index, index
+}


        


More information about the Mlir-commits mailing list