[Mlir-commits] [mlir] edbdea7 - [mlir][vector] Add unrolling patterns for Transfer read/write

Thomas Raoux llvmlistbot at llvm.org
Thu Oct 15 15:18:37 PDT 2020


Author: Thomas Raoux
Date: 2020-10-15T15:17:36-07:00
New Revision: edbdea7466d25c5e4d9f73e3043ac87efe433193

URL: https://github.com/llvm/llvm-project/commit/edbdea7466d25c5e4d9f73e3043ac87efe433193
DIFF: https://github.com/llvm/llvm-project/commit/edbdea7466d25c5e4d9f73e3043ac87efe433193.diff

LOG: [mlir][vector] Add unrolling patterns for Transfer read/write

Add unroll support for transfer_read and transfer_write operations. This
allows picking the ideal memory-access size for a given target.

Differential Revision: https://reviews.llvm.org/D89289

Added: 
    mlir/test/Dialect/Vector/vector-transfer-unroll.mlir

Modified: 
    mlir/include/mlir/Dialect/Vector/VectorTransforms.h
    mlir/lib/Dialect/Vector/VectorTransforms.cpp
    mlir/test/lib/Transforms/TestVectorTransforms.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Vector/VectorTransforms.h b/mlir/include/mlir/Dialect/Vector/VectorTransforms.h
index da9650c67efb..157084a2bff1 100644
--- a/mlir/include/mlir/Dialect/Vector/VectorTransforms.h
+++ b/mlir/include/mlir/Dialect/Vector/VectorTransforms.h
@@ -69,6 +69,22 @@ SmallVector<Value, 1> unrollSingleResultVectorOp(OpBuilder &builder,
                                                  Operation *op,
                                                  ArrayRef<int64_t> targetShape);
 
+/// Unroll `op`, which must be a vector.transfer_write. Break up the vector
+/// source into a tuple of vectors matching `targetShape`, then store each
+/// element with its own transfer_write. Returns failure, creating no ops, if
+/// the permutation map is not an identity suffix. `op` itself is not erased.
+/// Example with targetShape = [2, 4]:
+/// vector.transfer_write %A, %M[%c0, %c0] : vector<4x4xf32>, memref<4x4xf32>
+/// ->
+/// %0 = vector.extract_slices %A, [2, 4], [1, 1] :
+///                vector<4x4xf32> into tuple<vector<2x4xf32>, vector<2x4xf32>>
+/// %1 = vector.tuple_get %0, 0 : tuple<vector<2x4xf32>, vector<2x4xf32>>
+/// vector.transfer_write %1, %M[%c0, %c0] : vector<2x4xf32>, memref<4x4xf32>
+/// %2 = vector.tuple_get %0, 1 : tuple<vector<2x4xf32>, vector<2x4xf32>>
+/// vector.transfer_write %2, %M[%c2, %c0] : vector<2x4xf32>, memref<4x4xf32>
+LogicalResult unrollTransferWriteOp(OpBuilder &builder, Operation *op,
+                                    ArrayRef<int64_t> targetShape);
+
 /// Pattern to apply `unrollSingleResultVectorOp` to a `targetShape`
 /// declaratively.
 template <typename OpTy>
@@ -95,6 +111,12 @@ struct UnrollVectorPattern : public OpRewritePattern<OpTy> {
     if (!maybeShapeRatio ||
         llvm::all_of(*maybeShapeRatio, [](int64_t v) { return v == 1; }))
       return failure();
+    if (std::is_same<OpTy, TransferWriteOp>::value) {
+      if (failed(unrollTransferWriteOp(rewriter, op, targetShape)))
+        return failure();
+      rewriter.eraseOp(op);
+      return success();
+    }
     if (op.getOperation()->getNumResults() != 1)
       return failure();
     auto resultVector = unrollSingleResultVectorOp(rewriter, op, targetShape);

diff  --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
index 08ee5c64af09..5f34b6caf65a 100644
--- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
@@ -511,35 +511,6 @@ static void getVectorElementwiseOpUnrollState(Operation *op,
   resultIndex = numVectors - 1;
 }
 
-// Entry point for unrolling declarative pattern rewrites.
-SmallVector<Value, 1>
-mlir::vector::unrollSingleResultVectorOp(OpBuilder &builder, Operation *op,
-                                         ArrayRef<int64_t> targetShape) {
-  assert(op->getNumResults() == 1 && "Expected single result operation");
-
-  // Populate 'iterationBounds', 'vectors' and 'resultIndex' to unroll 'op'.
-  SmallVector<int64_t, 6> iterationBounds;
-  auto unrollableVectorOp = cast<VectorUnrollOpInterface>(op);
-  auto maybeUnrollShape = unrollableVectorOp.getShapeForUnroll();
-  assert(maybeUnrollShape && "Trying to unroll an incorrect vector op");
-
-  std::vector<VectorState> vectors;
-  unsigned resultIndex;
-
-  if (auto contractionOp = dyn_cast<vector::ContractionOp>(op)) {
-    // Populate state for vector ContractionOp.
-    getVectorContractionOpUnrollState(contractionOp, targetShape, vectors,
-                                      resultIndex);
-  } else {
-    // Populate state for vector elementwise op.
-    getVectorElementwiseOpUnrollState(op, targetShape, vectors, resultIndex);
-  }
-
-  // Unroll 'op' with 'iterationBounds' to 'targetShape'.
-  return SmallVector<Value, 1>{unrollSingleResultStructuredOp(
-      op, *maybeUnrollShape, vectors, resultIndex, targetShape, builder)};
-}
-
 /// Generates slices of 'vectorType' according to 'sizes' and 'strides, and
 /// calls 'fn' with linear index and indices for each slice.
 static void generateTransferOpSlices(
@@ -615,6 +586,114 @@ static bool isIdentitySuffix(AffineMap map) {
   return true;
 }
 
+/// Unroll `readOp` into `targetShape`-sized transfer_reads recombined via
+/// insert_slices. Returns nullptr (no ops built) for non-identity-suffix maps.
+static Value unrollTransferReadOp(vector::TransferReadOp readOp,
+                                  ArrayRef<int64_t> targetShape,
+                                  OpBuilder &builder) {
+  if (!isIdentitySuffix(readOp.permutation_map()))
+    return nullptr; // Unsupported permutation map; nothing created.
+  auto sourceVectorType = readOp.getVectorType();
+  SmallVector<int64_t, 4> strides(targetShape.size(), 1);
+
+  Location loc = readOp.getLoc();
+  auto memrefElementType =
+      readOp.memref().getType().cast<MemRefType>().getElementType();
+  auto tupleType = generateExtractSlicesOpResultType(
+      sourceVectorType, targetShape, strides, builder);
+  int64_t numSlices = tupleType.size();
+
+  SmallVector<Value, 4> vectorTupleValues(numSlices);
+  SmallVector<Value, 4> indices(readOp.indices().begin(),
+                                readOp.indices().end());
+  auto createSlice = [&](unsigned index, ArrayRef<Value> sliceIndices) {
+    // Get VectorType for slice 'index'.
+    auto sliceVectorType = tupleType.getType(index);
+    // Read slice 'index' starting at 'sliceIndices'.
+    // `masked` attribute propagates conservatively: if the coarse op didn't
+    // need masking, the fine op doesn't either.
+    vectorTupleValues[index] = builder.create<vector::TransferReadOp>(
+        loc, sliceVectorType, readOp.memref(), sliceIndices,
+        readOp.permutation_map(), readOp.padding(),
+        readOp.masked() ? *readOp.masked() : ArrayAttr());
+  };
+  generateTransferOpSlices(memrefElementType, sourceVectorType, tupleType,
+                           targetShape, strides, indices, builder, createSlice);
+
+  // Gather the per-slice transfer reads into a tuple.
+  Value tupleOp =
+      builder.create<vector::TupleOp>(loc, tupleType, vectorTupleValues);
+  // Reassemble the full vector; callers are responsible for replacing 'readOp'.
+  Value newVec = builder.create<vector::InsertSlicesOp>(
+      loc, sourceVectorType, tupleOp, builder.getI64ArrayAttr(targetShape),
+      builder.getI64ArrayAttr(strides));
+  return newVec;
+}
+
+// Pattern entry point: split a transfer_write into per-slice transfer_writes.
+LogicalResult
+mlir::vector::unrollTransferWriteOp(OpBuilder &builder, Operation *op,
+                                    ArrayRef<int64_t> targetShape) {
+  auto writeOp = cast<vector::TransferWriteOp>(op); // Asserts on other ops.
+  if (!isIdentitySuffix(writeOp.permutation_map()))
+    return failure(); // Bail out before any op is created.
+  VectorType sourceVectorType = writeOp.getVectorType();
+  SmallVector<int64_t, 4> strides(targetShape.size(), 1);
+  TupleType tupleType = generateExtractSlicesOpResultType(
+      sourceVectorType, targetShape, strides, builder);
+  Location loc = writeOp.getLoc();
+  Value tuple = builder.create<vector::ExtractSlicesOp>(
+      loc, tupleType, writeOp.vector(), targetShape, strides);
+  auto memrefElementType =
+      writeOp.memref().getType().cast<MemRefType>().getElementType();
+  SmallVector<Value, 4> indices(writeOp.indices().begin(),
+                                writeOp.indices().end());
+  auto createSlice = [&](unsigned index, ArrayRef<Value> sliceIndices) {
+    auto element = builder.create<vector::TupleGetOp>(
+        loc, tupleType.getType(index), tuple, builder.getI64IntegerAttr(index));
+    builder.create<vector::TransferWriteOp>(
+        loc, element.getResult(), writeOp.memref(), sliceIndices,
+        writeOp.permutation_map(),
+        writeOp.masked() ? *writeOp.masked() : ArrayAttr());
+  };
+  generateTransferOpSlices(memrefElementType, sourceVectorType, tupleType,
+                           targetShape, strides, indices, builder, createSlice);
+  return success(); // The original op is left for the caller to erase.
+}
+
+// Entry point for unrolling declarative pattern rewrites. Returns an empty
+// vector (no ops created) if a transfer_read cannot be unrolled; check size().
+SmallVector<Value, 1>
+mlir::vector::unrollSingleResultVectorOp(OpBuilder &builder, Operation *op,
+                                         ArrayRef<int64_t> targetShape) {
+  assert(op->getNumResults() == 1 && "Expected single result operation");
+
+  // transfer_read is sliced directly; a nullptr (unsupported map) maps to {}.
+  if (auto readOp = dyn_cast<vector::TransferReadOp>(op)) {
+    if (Value newVec = unrollTransferReadOp(readOp, targetShape, builder))
+      return SmallVector<Value, 1>{newVec};
+    return SmallVector<Value, 1>();
+  }
+
+  auto unrollableVectorOp = cast<VectorUnrollOpInterface>(op);
+  auto maybeUnrollShape = unrollableVectorOp.getShapeForUnroll();
+  assert(maybeUnrollShape && "Trying to unroll an incorrect vector op");
+
+  // Populate 'vectors' and 'resultIndex' to unroll 'op'.
+  std::vector<VectorState> vectors;
+  unsigned resultIndex;
+
+  if (auto contractionOp = dyn_cast<vector::ContractionOp>(op))
+    getVectorContractionOpUnrollState(contractionOp, targetShape, vectors,
+                                      resultIndex);
+  else
+    getVectorElementwiseOpUnrollState(op, targetShape, vectors, resultIndex);
+
+  // Unroll 'op' from its unroll shape down to 'targetShape'.
+  return SmallVector<Value, 1>{unrollSingleResultStructuredOp(
+      op, *maybeUnrollShape, vectors, resultIndex, targetShape, builder)};
+}
+
 namespace {
 
 // Splits vector TransferReadOp into smaller TransferReadOps based on slicing
@@ -636,43 +715,16 @@ struct SplitTransferReadOp : public OpRewritePattern<vector::TransferReadOp> {
       return failure();
 
     // Get 'sizes' and 'strides' parameters from ExtractSlicesOp user.
-    auto sourceVectorType = extractSlicesOp.getSourceVectorType();
-    auto resultTupleType = extractSlicesOp.getResultTupleType();
     SmallVector<int64_t, 4> sizes;
     extractSlicesOp.getSizes(sizes);
     SmallVector<int64_t, 4> strides;
     extractSlicesOp.getStrides(strides);
     assert(llvm::all_of(strides, [](int64_t s) { return s == 1; }));
 
-    Location loc = xferReadOp.getLoc();
-    auto memrefElementType =
-        xferReadOp.memref().getType().cast<MemRefType>().getElementType();
-    int64_t numSlices = resultTupleType.size();
-    SmallVector<Value, 4> vectorTupleValues(numSlices);
-    SmallVector<Value, 4> indices(xferReadOp.indices().begin(),
-                                  xferReadOp.indices().end());
-    auto createSlice = [&](unsigned index, ArrayRef<Value> sliceIndices) {
-      // Get VectorType for slice 'i'.
-      auto sliceVectorType = resultTupleType.getType(index);
-      // Create split TransferReadOp for 'sliceUser'.
-      // `masked` attribute propagates conservatively: if the coarse op didn't
-      // need masking, the fine op doesn't either.
-      vectorTupleValues[index] = rewriter.create<vector::TransferReadOp>(
-          loc, sliceVectorType, xferReadOp.memref(), sliceIndices,
-          xferReadOp.permutation_map(), xferReadOp.padding(),
-          xferReadOp.masked() ? *xferReadOp.masked() : ArrayAttr());
-    };
-    generateTransferOpSlices(memrefElementType, sourceVectorType,
-                             resultTupleType, sizes, strides, indices, rewriter,
-                             createSlice);
-
-    // Create tuple of splice xfer read operations.
-    Value tupleOp = rewriter.create<vector::TupleOp>(loc, resultTupleType,
-                                                     vectorTupleValues);
-    // Replace 'xferReadOp' with result 'insertSlicesResult'.
-    rewriter.replaceOpWithNewOp<vector::InsertSlicesOp>(
-        xferReadOp, sourceVectorType, tupleOp, extractSlicesOp.sizes(),
-        extractSlicesOp.strides());
+    Value newVec = unrollTransferReadOp(xferReadOp, sizes, rewriter);
+    if (!newVec)
+      return failure();
+    rewriter.replaceOp(xferReadOp, newVec);
     return success();
   }
 };

diff  --git a/mlir/test/Dialect/Vector/vector-transfer-unroll.mlir b/mlir/test/Dialect/Vector/vector-transfer-unroll.mlir
new file mode 100644
index 000000000000..b676700dae06
--- /dev/null
+++ b/mlir/test/Dialect/Vector/vector-transfer-unroll.mlir
@@ -0,0 +1,60 @@
+// RUN: mlir-opt %s -test-vector-transfer-unrolling-patterns | FileCheck %s
+
+// CHECK-LABEL: func @transfer_read_unroll
+//       CHECK:   %[[C0:.*]] = constant 0 : index
+//       CHECK:   %[[C2:.*]] = constant 2 : index
+//       CHECK:   %[[VTR0:.*]] = vector.transfer_read {{.*}}[%[[C0]], %[[C0]]], %{{.*}} : memref<4x4xf32>, vector<2x2xf32>
+//  CHECK-NEXT:   %[[VTR1:.*]] = vector.transfer_read {{.*}}[%[[C0]], %[[C2]]], %{{.*}} : memref<4x4xf32>, vector<2x2xf32>
+//  CHECK-NEXT:   %[[VTR2:.*]] = vector.transfer_read {{.*}}[%[[C2]], %[[C0]]], %{{.*}} : memref<4x4xf32>, vector<2x2xf32>
+//  CHECK-NEXT:   %[[VTR3:.*]] = vector.transfer_read {{.*}}[%[[C2]], %[[C2]]], %{{.*}} : memref<4x4xf32>, vector<2x2xf32>
+//  CHECK-NEXT:   %[[TUPL:.*]] = vector.tuple %[[VTR0]], %[[VTR1]], %[[VTR2]], %[[VTR3]] : vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>
+//  CHECK-NEXT:   %[[VEC:.*]] = vector.insert_slices %[[TUPL]], [2, 2], [1, 1] : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>> into vector<4x4xf32>
+//  CHECK-NEXT:   return %[[VEC]] : vector<4x4xf32>
+// A lone 4x4 transfer_read is split into four 2x2 reads plus insert_slices.
+func @transfer_read_unroll(%arg0 : memref<4x4xf32>) -> vector<4x4xf32> {
+  %c0 = constant 0 : index
+  %cf0 = constant 0.0 : f32
+  %0 = vector.transfer_read %arg0[%c0, %c0], %cf0 : memref<4x4xf32>, vector<4x4xf32>
+  return %0 : vector<4x4xf32>
+}
+
+// CHECK-LABEL: func @transfer_write_unroll
+//       CHECK:   %[[C0:.*]] = constant 0 : index
+//       CHECK:   %[[C2:.*]] = constant 2 : index
+//       CHECK:   %[[TUPL:.*]] = vector.extract_slices {{.*}}, [2, 2], [1, 1] : vector<4x4xf32> into tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
+//  CHECK-NEXT:   %[[T0:.*]] = vector.tuple_get %[[TUPL]], 0 : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
+//  CHECK-NEXT:   vector.transfer_write %[[T0]], {{.*}}[%[[C0]], %[[C0]]] {{.*}} : vector<2x2xf32>, memref<4x4xf32>
+//  CHECK-NEXT:   %[[T1:.*]] = vector.tuple_get %[[TUPL]], 1 : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
+//  CHECK-NEXT:   vector.transfer_write %[[T1]], {{.*}}[%[[C0]], %[[C2]]] {{.*}} : vector<2x2xf32>, memref<4x4xf32>
+//  CHECK-NEXT:   %[[T2:.*]] = vector.tuple_get %[[TUPL]], 2 : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
+//  CHECK-NEXT:   vector.transfer_write %[[T2]], {{.*}}[%[[C2]], %[[C0]]] {{.*}} : vector<2x2xf32>, memref<4x4xf32>
+//  CHECK-NEXT:   %[[T3:.*]] = vector.tuple_get %[[TUPL]], 3 : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
+//  CHECK-NEXT:   vector.transfer_write %[[T3]], {{.*}}[%[[C2]], %[[C2]]] {{.*}} : vector<2x2xf32>, memref<4x4xf32>
+//  CHECK-NEXT:   return
+// A lone 4x4 transfer_write becomes extract_slices plus four 2x2 writes.
+func @transfer_write_unroll(%arg0 : memref<4x4xf32>, %arg1 : vector<4x4xf32>) {
+  %c0 = constant 0 : index
+  vector.transfer_write %arg1, %arg0[%c0, %c0] : vector<4x4xf32>, memref<4x4xf32>
+  return
+}
+
+// CHECK-LABEL: func @transfer_readwrite_unroll
+//       CHECK:   %[[C0:.*]] = constant 0 : index
+//       CHECK:   %[[C2:.*]] = constant 2 : index
+//       CHECK:   %[[VTR0:.*]] = vector.transfer_read {{.*}}[%[[C0]], %[[C0]]], %{{.*}} : memref<4x4xf32>, vector<2x2xf32>
+//  CHECK-NEXT:   %[[VTR1:.*]] = vector.transfer_read {{.*}}[%[[C0]], %[[C2]]], %{{.*}} : memref<4x4xf32>, vector<2x2xf32>
+//  CHECK-NEXT:   %[[VTR2:.*]] = vector.transfer_read {{.*}}[%[[C2]], %[[C0]]], %{{.*}} : memref<4x4xf32>, vector<2x2xf32>
+//  CHECK-NEXT:   %[[VTR3:.*]] = vector.transfer_read {{.*}}[%[[C2]], %[[C2]]], %{{.*}} : memref<4x4xf32>, vector<2x2xf32>
+//  CHECK-NEXT:   vector.transfer_write %[[VTR0]], {{.*}}[%[[C0]], %[[C0]]] {{.*}} : vector<2x2xf32>, memref<4x4xf32>
+//  CHECK-NEXT:   vector.transfer_write %[[VTR1]], {{.*}}[%[[C0]], %[[C2]]] {{.*}} : vector<2x2xf32>, memref<4x4xf32>
+//  CHECK-NEXT:   vector.transfer_write %[[VTR2]], {{.*}}[%[[C2]], %[[C0]]] {{.*}} : vector<2x2xf32>, memref<4x4xf32>
+//  CHECK-NEXT:   vector.transfer_write %[[VTR3]], {{.*}}[%[[C2]], %[[C2]]] {{.*}} : vector<2x2xf32>, memref<4x4xf32>
+//  CHECK-NEXT:   return
+// Read feeding write: the checks expect no tuple or slice ops in between.
+func @transfer_readwrite_unroll(%arg0 : memref<4x4xf32>) {
+  %c0 = constant 0 : index
+  %cf0 = constant 0.0 : f32
+  %0 = vector.transfer_read %arg0[%c0, %c0], %cf0 : memref<4x4xf32>, vector<4x4xf32>
+  vector.transfer_write %0, %arg0[%c0, %c0] : vector<4x4xf32>, memref<4x4xf32>
+  return
+}

diff  --git a/mlir/test/lib/Transforms/TestVectorTransforms.cpp b/mlir/test/lib/Transforms/TestVectorTransforms.cpp
index fe0947d0ac30..52d0f7b2bb5e 100644
--- a/mlir/test/lib/Transforms/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Transforms/TestVectorTransforms.cpp
@@ -156,6 +156,24 @@ struct TestVectorDistributePatterns
   }
 };
 
+struct TestVectorTransferUnrollingPatterns
+    : public PassWrapper<TestVectorTransferUnrollingPatterns, FunctionPass> {
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<AffineDialect>(); // Presumably for slice offset math.
+  }
+  void runOnFunction() override {
+    MLIRContext *ctx = &getContext();
+    OwningRewritePatternList patterns; // All applied in one greedy run.
+    patterns.insert<UnrollVectorPattern<vector::TransferReadOp>>(
+        ArrayRef<int64_t>{2, 2}, ctx);
+    patterns.insert<UnrollVectorPattern<vector::TransferWriteOp>>(
+        ArrayRef<int64_t>{2, 2}, ctx);
+    populateVectorToVectorCanonicalizationPatterns(patterns, ctx);
+    populateVectorToVectorTransformationPatterns(patterns, ctx);
+    applyPatternsAndFoldGreedily(getFunction(), patterns);
+  }
+};
+
 struct TestVectorTransferFullPartialSplitPatterns
     : public PassWrapper<TestVectorTransferFullPartialSplitPatterns,
                          FunctionPass> {
@@ -205,6 +223,10 @@ void registerTestVectorConversions() {
       "test-vector-unrolling-patterns",
       "Test conversion patterns to unroll contract ops in the vector dialect");
 
+  PassRegistration<TestVectorTransferUnrollingPatterns> transferOpUnrollingPass(
+      "test-vector-transfer-unrolling-patterns",
+      "Test conversion patterns to unroll transfer ops in the vector dialect");
+
   PassRegistration<TestVectorTransferFullPartialSplitPatterns>
       vectorTransformFullPartialPass("test-vector-transfer-full-partial-split",
                                      "Test conversion patterns to split "


        


More information about the Mlir-commits mailing list