[Mlir-commits] [mlir] 0aea49a - [mlir][Vector] Patterns flattening vector transfers to 1D

Nicolas Vasilache llvmlistbot at llvm.org
Mon Dec 13 13:50:53 PST 2021


Author: Benoit Jacob
Date: 2021-12-13T21:49:04Z
New Revision: 0aea49a7308322e6987c7b45e4e0d7ab15609e78

URL: https://github.com/llvm/llvm-project/commit/0aea49a7308322e6987c7b45e4e0d7ab15609e78
DIFF: https://github.com/llvm/llvm-project/commit/0aea49a7308322e6987c7b45e4e0d7ab15609e78.diff

LOG: [mlir][Vector] Patterns flattening vector transfers to 1D

This is the first part of https://reviews.llvm.org/D114993 which has been
split into small independent commits.

This is needed at the moment to get good codegen from 2d vector.transfer
ops that aim to compile to SIMD load/store instructions but that can
only do so if the whole 2d transfer shape is handled in one piece, in
particular taking advantage of the memref being contiguous rowmajor.

For instance, if the target architecture has 128bit SIMD then we would
expect that contiguous row-major transfers of <4x4xi8> map to one SIMD
load/store instruction each.

The current generic lowering of multi-dimensional vector.transfer ops
can't achieve that because it peels dimensions one by one, so a transfer
of <4x4xi8> becomes 4 transfers of <4xi8>.

The new patterns here are only enabled for now by
 -test-vector-transfer-drop-unit-dims-patterns.

Reviewed By: nicolasvasilache

Added: 
    mlir/test/Dialect/Vector/vector-transfer-drop-unit-dims-patterns.mlir

Modified: 
    mlir/include/mlir/Dialect/Vector/VectorOps.h
    mlir/lib/Dialect/Vector/VectorTransferOpTransforms.cpp
    mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Vector/VectorOps.h b/mlir/include/mlir/Dialect/Vector/VectorOps.h
index 6dd4b9aaf5552..c6b63a949f642 100644
--- a/mlir/include/mlir/Dialect/Vector/VectorOps.h
+++ b/mlir/include/mlir/Dialect/Vector/VectorOps.h
@@ -67,6 +67,13 @@ void populateShapeCastFoldingPatterns(RewritePatternSet &patterns);
 /// pairs or forward write-read pairs.
 void populateCastAwayVectorLeadingOneDimPatterns(RewritePatternSet &patterns);
 
+/// Collect a set of unit dimension removal patterns for vector transfer ops.
+///
+/// These patterns insert rank-reducing memref.subview ops to remove unit
+/// dimensions. With them, there are more chances that we can avoid
+/// potentially expensive vector.shape_cast operations.
+void populateVectorTransferDropUnitDimsPatterns(RewritePatternSet &patterns);
+
 /// Collect a set of patterns that bubble up/down bitcast ops.
 ///
 /// These patterns move vector.bitcast ops to be before insert ops or after

diff  --git a/mlir/lib/Dialect/Vector/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransferOpTransforms.cpp
index ae6f3949c3998..c9438c4a28f4e 100644
--- a/mlir/lib/Dialect/Vector/VectorTransferOpTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorTransferOpTransforms.cpp
@@ -10,12 +10,14 @@
 // transfer_write ops.
 //
 //===----------------------------------------------------------------------===//
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/Dialect/Vector/VectorOps.h"
 #include "mlir/Dialect/Vector/VectorTransforms.h"
 #include "mlir/Dialect/Vector/VectorUtils.h"
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/Dominance.h"
+#include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/StringRef.h"
 #include "llvm/Support/Debug.h"
 
@@ -209,6 +211,128 @@ void TransferOptimization::storeToLoadForwarding(vector::TransferReadOp read) {
   opToErase.push_back(read.getOperation());
 }
 
+/// Returns `inputType` with unit dims dropped and a canonical strided layout.
+static MemRefType dropUnitDims(MemRefType inputType) {
+  ArrayRef<int64_t> none{}; // NOTE(review): confirm empty lists are intended.
+  Type rankReducedType = memref::SubViewOp::inferRankReducedResultType(
+      0, inputType, none, none, none);
+  return canonicalizeStridedLayout(rankReducedType.cast<MemRefType>());
+}
+
+/// Creates a rank-reducing memref.subview op that drops unit dims from
+/// `input`, or returns `input` itself if dropping them leaves its type as is.
+static Value rankReducingSubviewDroppingUnitDims(PatternRewriter &rewriter,
+                                                 mlir::Location loc,
+                                                 Value input) {
+  MemRefType inputType = input.getType().cast<MemRefType>();
+  assert(inputType.hasStaticShape()); // Only static shapes are supported.
+  MemRefType resultType = dropUnitDims(inputType);
+  if (resultType == inputType)
+    return input; // Nothing to drop.
+  SmallVector<int64_t> subviewOffsets(inputType.getRank(), 0);
+  SmallVector<int64_t> subviewStrides(inputType.getRank(), 1);
+  return rewriter.create<memref::SubViewOp>(
+      loc, resultType, input, subviewOffsets, inputType.getShape(),
+      subviewStrides); // Zero offsets + full shape: spans all of input.
+}
+
+/// Returns the number of dims in `shape` whose size is not 1.
+static int getReducedRank(ArrayRef<int64_t> shape) {
+  return llvm::count_if(shape, [](int64_t dimSize) { return dimSize != 1; });
+}
+
+/// Returns true if `v` is defined by an `arith.constant 0 : index` op.
+static bool isZero(Value v) {
+  auto cst = v.getDefiningOp<arith::ConstantIndexOp>();
+  return cst && cst.value() == 0;
+}
+
+/// Rewrites vector.transfer_read ops whose memref source has unit dims, by
+/// reading from a rank-reducing memref.subview that drops those unit dims.
+class TransferReadDropUnitDimsPattern
+    : public OpRewritePattern<vector::TransferReadOp> {
+  using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::TransferReadOp transferReadOp,
+                                PatternRewriter &rewriter) const override {
+    auto loc = transferReadOp.getLoc();
+    Value vector = transferReadOp.vector();
+    VectorType vectorType = vector.getType().cast<VectorType>();
+    Value source = transferReadOp.source();
+    MemRefType sourceType = source.getType().dyn_cast<MemRefType>();
+    // TODO: support tensor types.
+    if (!sourceType || !sourceType.hasStaticShape())
+      return failure(); // Only statically shaped memref sources are handled.
+    if (sourceType.getNumElements() != vectorType.getNumElements())
+      return failure(); // The vector must have as many elements as the source.
+    // TODO: generalize this pattern, relax the requirements here.
+    if (transferReadOp.hasOutOfBoundsDim())
+      return failure(); // Only in-bounds transfers are handled.
+    if (!transferReadOp.permutation_map().isMinorIdentity())
+      return failure(); // Only minor identity permutation maps are handled.
+    int reducedRank = getReducedRank(sourceType.getShape());
+    if (reducedRank == sourceType.getRank())
+      return failure(); // The source shape can't be further reduced.
+    if (reducedRank != vectorType.getRank())
+      return failure(); // This pattern requires the vector rank to match the
+                        // reduced source rank.
+    if (llvm::any_of(transferReadOp.indices(),
+                     [](Value v) { return !isZero(v); }))
+      return failure(); // Only all-zero constant indices are handled.
+    Value reducedShapeSource =
+        rankReducingSubviewDroppingUnitDims(rewriter, loc, source);
+    Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+    SmallVector<Value> zeros(reducedRank, c0); // Indices into reduced source.
+    auto identityMap = rewriter.getMultiDimIdentityMap(reducedRank);
+    rewriter.replaceOpWithNewOp<vector::TransferReadOp>(
+        transferReadOp, vectorType, reducedShapeSource, zeros, identityMap);
+    return success();
+  }
+};
+
+/// Rewrites vector.transfer_write ops where the "source" (i.e. destination) has
+/// unit dims, by writing to a rank-reducing memref.subview dropping those dims.
+class TransferWriteDropUnitDimsPattern
+    : public OpRewritePattern<vector::TransferWriteOp> {
+  using OpRewritePattern<vector::TransferWriteOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::TransferWriteOp transferWriteOp,
+                                PatternRewriter &rewriter) const override {
+    auto loc = transferWriteOp.getLoc();
+    Value vector = transferWriteOp.vector();
+    VectorType vectorType = vector.getType().cast<VectorType>();
+    Value source = transferWriteOp.source();
+    MemRefType sourceType = source.getType().dyn_cast<MemRefType>();
+    // TODO: support tensor types.
+    if (!sourceType || !sourceType.hasStaticShape())
+      return failure(); // Only statically shaped memref targets are handled.
+    if (sourceType.getNumElements() != vectorType.getNumElements())
+      return failure(); // The vector must have as many elements as the memref.
+    // TODO: generalize this pattern, relax the requirements here.
+    if (transferWriteOp.hasOutOfBoundsDim())
+      return failure(); // Only in-bounds transfers are handled.
+    if (!transferWriteOp.permutation_map().isMinorIdentity())
+      return failure(); // Only minor identity permutation maps are handled.
+    int reducedRank = getReducedRank(sourceType.getShape());
+    if (reducedRank == sourceType.getRank())
+      return failure(); // The source shape can't be further reduced.
+    if (reducedRank != vectorType.getRank())
+      return failure(); // This pattern requires the vector rank to match the
+                        // reduced source rank.
+    if (llvm::any_of(transferWriteOp.indices(),
+                     [](Value v) { return !isZero(v); }))
+      return failure(); // Only all-zero constant indices are handled.
+    Value reducedShapeSource =
+        rankReducingSubviewDroppingUnitDims(rewriter, loc, source);
+    Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+    SmallVector<Value> zeros(reducedRank, c0); // Indices into reduced memref.
+    auto identityMap = rewriter.getMultiDimIdentityMap(reducedRank);
+    rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
+        transferWriteOp, vector, reducedShapeSource, zeros, identityMap);
+    return success();
+  }
+};
+
 } // namespace
 
 void mlir::vector::transferOpflowOpt(FuncOp func) {
@@ -226,3 +350,11 @@ void mlir::vector::transferOpflowOpt(FuncOp func) {
   });
   opt.removeDeadOp();
 }
+
+void mlir::vector::populateVectorTransferDropUnitDimsPatterns(
+    RewritePatternSet &patterns) {
+  patterns // Register the read and write unit-dim-dropping patterns.
+      .add<TransferReadDropUnitDimsPattern, TransferWriteDropUnitDimsPattern>(
+          patterns.getContext());
+  populateShapeCastFoldingPatterns(patterns); // Also add shape_cast folding.
+}

diff  --git a/mlir/test/Dialect/Vector/vector-transfer-drop-unit-dims-patterns.mlir b/mlir/test/Dialect/Vector/vector-transfer-drop-unit-dims-patterns.mlir
new file mode 100644
index 0000000000000..a3d34a646c2fd
--- /dev/null
+++ b/mlir/test/Dialect/Vector/vector-transfer-drop-unit-dims-patterns.mlir
@@ -0,0 +1,33 @@
+// RUN: mlir-opt %s -test-vector-transfer-drop-unit-dims-patterns -split-input-file | FileCheck %s
+
+// -----
+
+func @transfer_read_rank_reducing(
+      %arg : memref<1x1x3x2xi8, offset:?, strides:[6, 6, 2, 1]>) -> vector<3x2xi8> {
+    %c0 = arith.constant 0 : index
+    %cst = arith.constant 0 : i8
+    %v = vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst : 
+      memref<1x1x3x2xi8, offset:?, strides:[6, 6, 2, 1]>, vector<3x2xi8>
+    return %v : vector<3x2xi8>
+}
+
+// CHECK-LABEL: func @transfer_read_rank_reducing
+//  CHECK-SAME:     %[[ARG:.+]]: memref<1x1x3x2xi8
+//       CHECK:   %[[SUBVIEW:.+]] = memref.subview %[[ARG]][0, 0, 0, 0] [1, 1, 3, 2] [1, 1, 1, 1] 
+//  CHECK-SAME:     memref<1x1x3x2xi8, {{.*}}> to memref<3x2xi8, {{.*}}>
+//       CHECK:   vector.transfer_read %[[SUBVIEW]]
+
+// -----
+
+func @transfer_write_rank_reducing(%arg : memref<1x1x3x2xi8, offset:?, strides:[6, 6, 2, 1]>, %vec : vector<3x2xi8>) {
+    %c0 = arith.constant 0 : index
+    vector.transfer_write %vec, %arg [%c0, %c0, %c0, %c0] : 
+      vector<3x2xi8>, memref<1x1x3x2xi8, offset:?, strides:[6, 6, 2, 1]>
+    return
+}
+
+// CHECK-LABEL: func @transfer_write_rank_reducing
+//  CHECK-SAME:     %[[ARG:.+]]: memref<1x1x3x2xi8
+//       CHECK:   %[[SUBVIEW:.+]] = memref.subview %[[ARG]][0, 0, 0, 0] [1, 1, 3, 2] [1, 1, 1, 1] 
+//  CHECK-SAME:     memref<1x1x3x2xi8, {{.*}}> to memref<3x2xi8, {{.*}}>
+//       CHECK:   vector.transfer_write %{{.*}}, %[[SUBVIEW]]
\ No newline at end of file

diff  --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index f7e13bc330d44..cf33b0d7117d5 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -583,6 +583,21 @@ struct TestVectorReduceToContractPatternsPatterns
   }
 };
 
+struct TestVectorTransferDropUnitDimsPatterns
+    : public PassWrapper<TestVectorTransferDropUnitDimsPatterns, FunctionPass> {
+  StringRef getArgument() const final { // Command-line name of this pass.
+    return "test-vector-transfer-drop-unit-dims-patterns";
+  }
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<memref::MemRefDialect>(); // Patterns create memref ops.
+  }
+  void runOnFunction() override { // Greedily apply the drop-unit-dims patterns.
+    RewritePatternSet patterns(&getContext());
+    populateVectorTransferDropUnitDimsPatterns(patterns);
+    (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
+  }
+};
+
 } // namespace
 
 namespace mlir {
@@ -613,6 +628,8 @@ void registerTestVectorLowerings() {
   PassRegistration<TestVectorTransferCollapseInnerMostContiguousDims>();
 
   PassRegistration<TestVectorReduceToContractPatternsPatterns>();
+
+  PassRegistration<TestVectorTransferDropUnitDimsPatterns>();
 }
 } // namespace test
 } // namespace mlir


        


More information about the Mlir-commits mailing list