[Mlir-commits] [mlir] aba437c - [mlir][Vector] Patterns flattening vector transfers to 1D

Nicolas Vasilache llvmlistbot at llvm.org
Mon Dec 13 14:42:41 PST 2021


Author: Benoit Jacob
Date: 2021-12-13T22:39:41Z
New Revision: aba437ceb2379f219935b98a10ca3c5081f0c8b7

URL: https://github.com/llvm/llvm-project/commit/aba437ceb2379f219935b98a10ca3c5081f0c8b7
DIFF: https://github.com/llvm/llvm-project/commit/aba437ceb2379f219935b98a10ca3c5081f0c8b7.diff

LOG: [mlir][Vector] Patterns flattening vector transfers to 1D

This is the second part of https://reviews.llvm.org/D114993 after slicing
into 2 independent commits.

This is currently needed to get good codegen from 2-D vector.transfer
ops that are meant to compile to SIMD load/store instructions: they can
only do so if the whole 2-D transfer shape is handled in one piece, which
in particular relies on the memref being contiguous and row-major.

For instance, if the target architecture has 128bit SIMD then we would
expect that contiguous row-major transfers of <4x4xi8> map to one SIMD
load/store instruction each.

The current generic lowering of multi-dimensional vector.transfer ops
can't achieve that because it peels dimensions one by one, so a transfer
of <4x4xi8> becomes 4 transfers of <4xi8>.

For now, the new patterns are only exercised by the
-test-vector-transfer-flatten-patterns test pass.

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D114993

Added: 
    mlir/test/Dialect/Vector/vector-transfer-flatten.mlir

Modified: 
    mlir/include/mlir/Dialect/Vector/VectorOps.h
    mlir/include/mlir/IR/BuiltinTypes.h
    mlir/lib/Dialect/Vector/VectorTransferOpTransforms.cpp
    mlir/lib/IR/BuiltinTypes.cpp
    mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Vector/VectorOps.h b/mlir/include/mlir/Dialect/Vector/VectorOps.h
index c6b63a949f642..14bd03968fcf6 100644
--- a/mlir/include/mlir/Dialect/Vector/VectorOps.h
+++ b/mlir/include/mlir/Dialect/Vector/VectorOps.h
@@ -67,13 +67,21 @@ void populateShapeCastFoldingPatterns(RewritePatternSet &patterns);
 /// pairs or forward write-read pairs.
 void populateCastAwayVectorLeadingOneDimPatterns(RewritePatternSet &patterns);
 
-/// Collect a set of leading one dimension removal patterns.
+/// Collect a set of unit dimension removal patterns.
 ///
 /// These patterns insert rank-reducing memref.subview ops to remove one
 /// dimensions. With them, there are more chances that we can avoid
 /// potentially exensive vector.shape_cast operations.
 void populateVectorTransferDropUnitDimsPatterns(RewritePatternSet &patterns);
 
+/// Collect a set of patterns to flatten n-D vector transfers on contiguous
+/// row-major memrefs.
+///
+/// These patterns insert memref.collapse_shape + vector.shape_cast ops to
+/// rewrite each such n-D transfer into a single 1-D transfer where the memref
+/// contiguity properties allow it.
+void populateFlattenVectorTransferPatterns(RewritePatternSet &patterns);
+
 /// Collect a set of patterns that bubble up/down bitcast ops.
 ///
 /// These patterns move vector.bitcast ops to be before insert ops or after

diff  --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h
index 296b12c427927..be98b577ffab2 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.h
+++ b/mlir/include/mlir/IR/BuiltinTypes.h
@@ -531,6 +531,11 @@ bool isStrided(MemRefType t);
 /// Return null if the layout is not compatible with a strided layout.
 AffineMap getStridedLinearLayoutMap(MemRefType t);
 
+/// Return true if `memrefType` has a static shape and a contiguous row-major
+/// layout, while still allowing for an arbitrary offset (any static or
+/// dynamic value).
+bool isStaticShapeAndContiguousRowMajor(MemRefType memrefType);
+
 } // namespace mlir
 
 #endif // MLIR_IR_BUILTINTYPES_H

diff  --git a/mlir/lib/Dialect/Vector/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransferOpTransforms.cpp
index c9438c4a28f4e..9b1ae7a402261 100644
--- a/mlir/lib/Dialect/Vector/VectorTransferOpTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorTransferOpTransforms.cpp
@@ -227,7 +227,8 @@ static Value rankReducingSubviewDroppingUnitDims(PatternRewriter &rewriter,
   MemRefType inputType = input.getType().cast<MemRefType>();
   assert(inputType.hasStaticShape());
   MemRefType resultType = dropUnitDims(inputType);
-  if (resultType == inputType)
+  if (canonicalizeStridedLayout(resultType) ==
+      canonicalizeStridedLayout(inputType))
     return input;
   SmallVector<int64_t> subviewOffsets(inputType.getRank(), 0);
   SmallVector<int64_t> subviewStrides(inputType.getRank(), 1);
@@ -333,6 +334,130 @@ class TransferWriteDropUnitDimsPattern
   }
 };
 
+/// Returns a 1-D view of `input` (assumed static, contiguous and row-major):
+/// any unit dims are first dropped by a rank-reducing memref.subview, then the
+/// remaining dims, if more than one, are merged by a memref.collapse_shape.
+// TODO: move helper function
+static Value collapseContiguousRowMajorMemRefTo1D(PatternRewriter &rewriter,
+                                                  mlir::Location loc,
+                                                  Value input) {
+  Value rankReducedInput =
+      rankReducingSubviewDroppingUnitDims(rewriter, loc, input);
+  int64_t rank = rankReducedInput.getType().cast<ShapedType>().getRank();
+  if (rank == 1)
+    return rankReducedInput;
+  ReassociationIndices allDims;
+  for (int64_t dim = 0; dim < rank; ++dim)
+    allDims.push_back(dim);
+  return rewriter.create<memref::CollapseShapeOp>(
+      loc, rankReducedInput, std::array<ReassociationIndices, 1>{allDims});
+}
+
+/// Rewrites contiguous row-major vector.transfer_read ops by inserting
+/// memref.collapse_shape on the source so that the resulting read has a 1D
+/// source. Requires the source to be already rank-reduced (no unit dims) and
+/// to have the same element type as the read vector.
+class FlattenContiguousRowMajorTransferReadPattern
+    : public OpRewritePattern<vector::TransferReadOp> {
+  using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::TransferReadOp transferReadOp,
+                                PatternRewriter &rewriter) const override {
+    auto loc = transferReadOp.getLoc();
+    VectorType vectorType = transferReadOp.getVectorType();
+    Value source = transferReadOp.source();
+    MemRefType sourceType = source.getType().dyn_cast<MemRefType>();
+    // Contiguity check is valid on memrefs only.
+    if (!sourceType)
+      return failure();
+    if (vectorType.getRank() == 1 && sourceType.getRank() == 1)
+      // Already 1D, nothing to do.
+      return failure();
+    if (!isStaticShapeAndContiguousRowMajor(sourceType))
+      return failure();
+    if (getReducedRank(sourceType.getShape()) != sourceType.getRank())
+      // This pattern requires the source to already be rank-reduced.
+      return failure();
+    if (sourceType.getNumElements() != vectorType.getNumElements() ||
+        sourceType.getElementType() != vectorType.getElementType())
+      return failure();
+    // TODO: generalize this pattern, relax the requirements here.
+    if (transferReadOp.hasOutOfBoundsDim())
+      return failure();
+    if (!transferReadOp.permutation_map().isMinorIdentity())
+      return failure();
+    if (transferReadOp.mask())
+      return failure();
+    if (llvm::any_of(transferReadOp.indices(),
+                     [](Value v) { return !isZero(v); }))
+      return failure();
+    Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+    auto identityMap1D = rewriter.getMultiDimIdentityMap(1);
+    VectorType vectorType1d = VectorType::get({sourceType.getNumElements()},
+                                              sourceType.getElementType());
+    Value source1d =
+        collapseContiguousRowMajorMemRefTo1D(rewriter, loc, source);
+    Value read1d = rewriter.create<vector::TransferReadOp>(
+        loc, vectorType1d, source1d, ValueRange{c0}, identityMap1D);
+    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(transferReadOp,
+                                                     vectorType, read1d);
+    return success();
+  }
+};
+
+/// Rewrites contiguous row-major vector.transfer_write ops by inserting
+/// memref.collapse_shape on the destination so that the resulting write has
+/// a 1D destination. Requires the destination to be already rank-reduced (no
+/// unit dims) and to have the same element type as the written vector.
+class FlattenContiguousRowMajorTransferWritePattern
+    : public OpRewritePattern<vector::TransferWriteOp> {
+  using OpRewritePattern<vector::TransferWriteOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::TransferWriteOp transferWriteOp,
+                                PatternRewriter &rewriter) const override {
+    auto loc = transferWriteOp.getLoc();
+    Value vector = transferWriteOp.vector();
+    VectorType vectorType = vector.getType().cast<VectorType>();
+    Value source = transferWriteOp.source();
+    MemRefType sourceType = source.getType().dyn_cast<MemRefType>();
+    // Contiguity check is valid on memrefs only.
+    if (!sourceType)
+      return failure();
+    if (vectorType.getRank() == 1 && sourceType.getRank() == 1)
+      // Already 1D, nothing to do.
+      return failure();
+    if (!isStaticShapeAndContiguousRowMajor(sourceType))
+      return failure();
+    if (getReducedRank(sourceType.getShape()) != sourceType.getRank())
+      // This pattern requires the source to already be rank-reduced.
+      return failure();
+    if (sourceType.getNumElements() != vectorType.getNumElements() ||
+        sourceType.getElementType() != vectorType.getElementType())
+      return failure();
+    // TODO: generalize this pattern, relax the requirements here.
+    if (transferWriteOp.hasOutOfBoundsDim())
+      return failure();
+    if (!transferWriteOp.permutation_map().isMinorIdentity())
+      return failure();
+    if (transferWriteOp.mask())
+      return failure();
+    if (llvm::any_of(transferWriteOp.indices(),
+                     [](Value v) { return !isZero(v); }))
+      return failure();
+    Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+    auto identityMap1D = rewriter.getMultiDimIdentityMap(1);
+    VectorType vectorType1d = VectorType::get({sourceType.getNumElements()},
+                                              sourceType.getElementType());
+    Value source1d =
+        collapseContiguousRowMajorMemRefTo1D(rewriter, loc, source);
+    Value vector1d =
+        rewriter.create<vector::ShapeCastOp>(loc, vectorType1d, vector);
+    rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
+        transferWriteOp, vector1d, source1d, ValueRange{c0}, identityMap1D);
+    return success();
+  }
+};
+
 } // namespace
 
 void mlir::vector::transferOpflowOpt(FuncOp func) {
@@ -358,3 +483,11 @@ void mlir::vector::populateVectorTransferDropUnitDimsPatterns(
           patterns.getContext());
   populateShapeCastFoldingPatterns(patterns);
 }
+
+void mlir::vector::populateFlattenVectorTransferPatterns(
+    RewritePatternSet &patterns) {
+  MLIRContext *ctx = patterns.getContext();
+  patterns.add<FlattenContiguousRowMajorTransferReadPattern,
+               FlattenContiguousRowMajorTransferWritePattern>(ctx);
+  populateShapeCastFoldingPatterns(patterns);
+}

diff  --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp
index 10c38a86314fa..8e408e440dc3b 100644
--- a/mlir/lib/IR/BuiltinTypes.cpp
+++ b/mlir/lib/IR/BuiltinTypes.cpp
@@ -1168,3 +1168,40 @@ AffineMap mlir::getStridedLinearLayoutMap(MemRefType t) {
     return AffineMap();
   return makeStridedLinearLayoutMap(strides, offset, t.getContext());
 }
+
+/// Return the AffineExpr representation of the offset of `memrefType`, or a
+/// null AffineExpr if `memrefType` does not have a strided layout.
+static AffineExpr getOffsetExpr(MemRefType memrefType) {
+  SmallVector<AffineExpr> strides;
+  AffineExpr offset;
+  return succeeded(getStridesAndOffset(memrefType, strides, offset))
+             ? offset
+             : AffineExpr();
+}
+
+/// Helper to construct a contiguous row-major MemRefType of `shape`,
+/// `elementType` and `offset` AffineExpr.
+static MemRefType makeContiguousRowMajorMemRefType(MLIRContext *context,
+                                                   ArrayRef<int64_t> shape,
+                                                   Type elementType,
+                                                   AffineExpr offset) {
+  AffineExpr canonical = makeCanonicalStridedLayoutExpr(shape, context);
+  AffineExpr contiguousRowMajor = canonical + offset;
+  AffineMap contiguousRowMajorMap =
+      AffineMap::inferFromExprList({contiguousRowMajor})[0];
+  return MemRefType::get(shape, elementType, contiguousRowMajorMap);
+}
+
+/// Returns true if `memrefType` has a static shape and a contiguous row-major
+/// layout with an arbitrary (static or dynamic) offset, false otherwise.
+bool mlir::isStaticShapeAndContiguousRowMajor(MemRefType memrefType) {
+  if (!memrefType.hasStaticShape())
+    return false;
+  AffineExpr offset = getOffsetExpr(memrefType);
+  if (!offset) // Non-strided layouts cannot be contiguous row-major.
+    return false;
+  return canonicalizeStridedLayout(memrefType) ==
+         canonicalizeStridedLayout(makeContiguousRowMajorMemRefType(
+             memrefType.getContext(), memrefType.getShape(),
+             memrefType.getElementType(), offset));
+}

diff  --git a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
new file mode 100644
index 0000000000000..68a6779461d62
--- /dev/null
+++ b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
@@ -0,0 +1,35 @@
+// RUN: mlir-opt %s -test-vector-transfer-flatten-patterns -split-input-file | FileCheck %s
+
+func @transfer_read_flattenable_with_offset(
+      %arg : memref<5x4x3x2xi8, offset:?, strides:[24, 6, 2, 1]>) -> vector<5x4x3x2xi8> {
+    %c0 = arith.constant 0 : index
+    %cst = arith.constant 0 : i8
+    %v = vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst :
+      memref<5x4x3x2xi8, offset:?, strides:[24, 6, 2, 1]>, vector<5x4x3x2xi8>
+    return %v : vector<5x4x3x2xi8>
+}
+
+// CHECK-LABEL: func @transfer_read_flattenable_with_offset
+// CHECK-SAME:      %[[ARG:[0-9a-zA-Z]+]]: memref<5x4x3x2xi8
+// CHECK:         %[[COLLAPSED:.+]] = memref.collapse_shape %[[ARG]] {{.}}[0, 1, 2, 3]
+// CHECK:         %[[READ1D:.+]] = vector.transfer_read %[[COLLAPSED]]
+// CHECK:         %[[VEC2D:.+]] = vector.shape_cast %[[READ1D]] : vector<120xi8> to vector<5x4x3x2xi8>
+// CHECK:         return %[[VEC2D]]
+
+// -----
+
+func @transfer_write_flattenable_with_offset(
+      %arg : memref<5x4x3x2xi8, offset:?, strides:[24, 6, 2, 1]>, %vec : vector<5x4x3x2xi8>) {
+    %c0 = arith.constant 0 : index
+    vector.transfer_write %vec, %arg [%c0, %c0, %c0, %c0] :
+      vector<5x4x3x2xi8>, memref<5x4x3x2xi8, offset:?, strides:[24, 6, 2, 1]>
+    return
+}
+
+// CHECK-LABEL: func @transfer_write_flattenable_with_offset
+// CHECK-SAME:      %[[ARG:[0-9a-zA-Z]+]]: memref<5x4x3x2xi8
+// CHECK-SAME:      %[[VEC:[0-9a-zA-Z]+]]: vector<5x4x3x2xi8>
+// CHECK-DAG:     %[[COLLAPSED:.+]] = memref.collapse_shape %[[ARG]] {{.}}[0, 1, 2, 3]{{.}} : memref<5x4x3x2xi8, {{.+}}> into memref<120xi8, {{.+}}>
+// CHECK-DAG:     %[[VEC1D:.+]] = vector.shape_cast %[[VEC]] : vector<5x4x3x2xi8> to vector<120xi8>
+// CHECK:         vector.transfer_write %[[VEC1D]], %[[COLLAPSED]]
+

diff  --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index cf33b0d7117d5..a0d5a1b915ff8 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -598,6 +598,25 @@ struct TestVectorTransferDropUnitDimsPatterns
   }
 };
 
+struct TestFlattenVectorTransferPatterns
+    : public PassWrapper<TestFlattenVectorTransferPatterns, FunctionPass> {
+  void runOnFunction() override {
+    RewritePatternSet flattening(&getContext());
+    populateFlattenVectorTransferPatterns(flattening);
+    (void)applyPatternsAndFoldGreedily(getFunction(), std::move(flattening));
+  }
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<memref::MemRefDialect>();
+  }
+  StringRef getArgument() const final {
+    return "test-vector-transfer-flatten-patterns";
+  }
+  StringRef getDescription() const final {
+    return "Test patterns to rewrite contiguous row-major N-dimensional "
+           "vector.transfer_{read,write} ops into 1D transfers";
+  }
+};
+
 } // namespace
 
 namespace mlir {
@@ -630,6 +649,8 @@ void registerTestVectorLowerings() {
   PassRegistration<TestVectorReduceToContractPatternsPatterns>();
 
   PassRegistration<TestVectorTransferDropUnitDimsPatterns>();
+
+  PassRegistration<TestFlattenVectorTransferPatterns>();
 }
 } // namespace test
 } // namespace mlir


        


More information about the Mlir-commits mailing list