[Mlir-commits] [mlir] 558e740 - [mlir] Support tensor types in non-unrolled VectorToSCF

Matthias Springer llvmlistbot at llvm.org
Tue Jun 1 18:38:16 PDT 2021


Author: Matthias Springer
Date: 2021-06-02T10:37:58+09:00
New Revision: 558e740170681c723ecb04156f7177d6dfebff13

URL: https://github.com/llvm/llvm-project/commit/558e740170681c723ecb04156f7177d6dfebff13
DIFF: https://github.com/llvm/llvm-project/commit/558e740170681c723ecb04156f7177d6dfebff13.diff

LOG: [mlir] Support tensor types in non-unrolled VectorToSCF

Support for tensor types in the unrolled version will follow in a separate commit.

Add a new pass option, lower-tensors, that activates lowering of transfer ops operating on tensors (default: disabled).
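
As a quick illustration (not part of the commit): the new behavior can be requested on the command line with -convert-vector-to-scf='lower-tensors=true' (see the test added below), or programmatically through the options struct. The following sketch assumes the populateVectorToSCFConversionPatterns entry point declared in the modified VectorToSCF.h and the usual greedy rewrite driver; funcOp is a hypothetical enclosing FuncOp.

    // Sketch: lower vector transfer ops (including those on tensors) inside funcOp.
    VectorTransferToSCFOptions options;
    options.setLowerTensors(true);  // new setter added in this commit
    RewritePatternSet patterns(funcOp.getContext());
    populateVectorToSCFConversionPatterns(patterns, options);
    (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));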

Differential Revision: https://reviews.llvm.org/D102666

Added: 
    mlir/test/Conversion/VectorToSCF/tensor-transfer-ops.mlir

Modified: 
    mlir/include/mlir/Conversion/Passes.td
    mlir/include/mlir/Conversion/VectorToSCF/VectorToSCF.h
    mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index a1c40a7416770..a659c20e6530e 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -523,7 +523,9 @@ def ConvertVectorToSCF : FunctionPass<"convert-vector-to-scf"> {
            "Target vector rank to which transfer ops should be lowered">,
     Option<"lowerPermutationMaps", "lower-permutation-maps", "bool",
            /*default=*/"false", "Replace permutation maps with vector "
-           "transposes/broadcasts before lowering transfer ops">
+           "transposes/broadcasts before lowering transfer ops">,
+    Option<"lowerTensors", "lower-tensors", "bool", /*default=*/"false",
+           "Lower transfer ops that operate on tensors">
   ];
 }
 

diff --git a/mlir/include/mlir/Conversion/VectorToSCF/VectorToSCF.h b/mlir/include/mlir/Conversion/VectorToSCF/VectorToSCF.h
index a999c4a1fcfc6..123ce4b7ed206 100644
--- a/mlir/include/mlir/Conversion/VectorToSCF/VectorToSCF.h
+++ b/mlir/include/mlir/Conversion/VectorToSCF/VectorToSCF.h
@@ -48,12 +48,18 @@ class RewritePatternSet;
 /// is reused and only a second vector.type_cast is added.
 
 struct VectorTransferToSCFOptions {
-  bool unroll = false;
   unsigned targetRank = 1;
   bool lowerPermutationMaps = false;
+  bool lowerTensors = false;
+  bool unroll = false;
 
-  VectorTransferToSCFOptions &setUnroll(bool u) {
-    unroll = u;
+  VectorTransferToSCFOptions &setLowerPermutationMaps(bool l) {
+    lowerPermutationMaps = l;
+    return *this;
+  }
+
+  VectorTransferToSCFOptions &setLowerTensors(bool l) {
+    lowerTensors = l;
     return *this;
   }
 
@@ -62,8 +68,8 @@ struct VectorTransferToSCFOptions {
     return *this;
   }
 
-  VectorTransferToSCFOptions &setLowerPermutationMaps(bool l) {
-    lowerPermutationMaps = l;
+  VectorTransferToSCFOptions &setUnroll(bool u) {
+    unroll = u;
     return *this;
   }
 };

diff --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
index 48a1b6b320a8d..7637c22d17bd5 100644
--- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
+++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
@@ -99,6 +99,7 @@ static void getXferIndices(OpBuilder &b, OpTy xferOp, Value iv,
 static void maybeYieldValue(OpBuilder &b, Location loc, bool hasRetVal,
                             Value value) {
   if (hasRetVal) {
+    assert(value && "Expected non-empty value");
     b.create<scf::YieldOp>(loc, value);
   } else {
     b.create<scf::YieldOp>(loc);
@@ -242,6 +243,19 @@ static void maybeApplyPassLabel(OpBuilder &b, OpTy newXferOp,
     newXferOp->setAttr(kPassLabel, b.getUnitAttr());
 }
 
+/// Return true if this transfer op operates on a source tensor.
+template <typename OpTy>
+static bool isTensorOp(OpTy xferOp) {
+  if (xferOp.getShapedType().template isa<RankedTensorType>()) {
+    if (xferOp.getOperationName().equals(TransferWriteOp::getOperationName())) {
+      // TransferWriteOps on tensors have a result.
+      assert(xferOp->getNumResults() > 0);
+    }
+    return true;
+  }
+  return false;
+}
+
 namespace lowering_n_d {
 
 /// Helper data structure for data and mask buffers.
@@ -365,8 +379,8 @@ struct Strategy<TransferReadOp> {
   /// Note: The `mask` operand is set in TransferOpConversion.
   static TransferReadOp rewriteOp(OpBuilder &b,
                                   VectorTransferToSCFOptions options,
-                                  TransferReadOp xferOp, Value buffer,
-                                  Value iv) {
+                                  TransferReadOp xferOp, Value buffer, Value iv,
+                                  ValueRange /*loopState*/) {
     SmallVector<Value, 8> storeIndices;
     getBufferIndices(xferOp, storeIndices);
     storeIndices.push_back(iv);
@@ -391,8 +405,9 @@ struct Strategy<TransferReadOp> {
 
   /// Handle out-of-bounds accesses on the to-be-unpacked dimension: Write
   /// padding value to the temporary buffer.
-  static void handleOutOfBoundsDim(OpBuilder &b, TransferReadOp xferOp,
-                                   Value buffer, Value iv) {
+  static Value handleOutOfBoundsDim(OpBuilder &b, TransferReadOp xferOp,
+                                    Value buffer, Value iv,
+                                    ValueRange /*loopState*/) {
     SmallVector<Value, 8> storeIndices;
     getBufferIndices(xferOp, storeIndices);
     storeIndices.push_back(iv);
@@ -402,13 +417,19 @@ struct Strategy<TransferReadOp> {
     auto vecType = bufferType.getElementType().dyn_cast<VectorType>();
     auto vec = b.create<SplatOp>(loc, vecType, xferOp.padding());
     b.create<memref::StoreOp>(loc, vec, buffer, storeIndices);
+
+    return Value();
   }
 
   /// Cleanup after rewriting the op.
-  static void cleanup(PatternRewriter &rewriter, TransferReadOp xferOp) {
+  static void cleanup(PatternRewriter &rewriter, TransferReadOp xferOp,
+                      scf::ForOp /*forOp*/) {
     rewriter.eraseOp(getStoreOp(xferOp));
     rewriter.eraseOp(xferOp);
   }
+
+  /// Return the initial loop state for the generated scf.for loop.
+  static Value initialLoopState(TransferReadOp xferOp) { return Value(); }
 };
 
 /// Codegen strategy for vector TransferWriteOp.
@@ -447,7 +468,7 @@ struct Strategy<TransferWriteOp> {
   static TransferWriteOp rewriteOp(OpBuilder &b,
                                    VectorTransferToSCFOptions options,
                                    TransferWriteOp xferOp, Value buffer,
-                                   Value iv) {
+                                   Value iv, ValueRange loopState) {
     SmallVector<Value, 8> loadIndices;
     getBufferIndices(xferOp, loadIndices);
     loadIndices.push_back(iv);
@@ -458,8 +479,10 @@ struct Strategy<TransferWriteOp> {
     Location loc = xferOp.getLoc();
     auto vec = b.create<memref::LoadOp>(loc, buffer, loadIndices);
     auto inBoundsAttr = dropFirstElem(b, xferOp.in_boundsAttr());
+    auto source = loopState.empty() ? xferOp.source() : loopState[0];
+    Type type = isTensorOp(xferOp) ? xferOp.getShapedType() : Type();
     auto newXferOp = b.create<vector::TransferWriteOp>(
-        loc, Type(), vec, xferOp.source(), xferIndices,
+        loc, type, vec, source, xferIndices,
         AffineMapAttr::get(unpackedPermutationMap(b, xferOp)), Value(),
         inBoundsAttr);
 
@@ -469,12 +492,26 @@ struct Strategy<TransferWriteOp> {
   }
 
   /// Handle out-of-bounds accesses on the to-be-unpacked dimension.
-  static void handleOutOfBoundsDim(OpBuilder &b, TransferWriteOp xferOp,
-                                   Value buffer, Value iv) {}
+  static Value handleOutOfBoundsDim(OpBuilder &b, TransferWriteOp xferOp,
+                                    Value buffer, Value iv,
+                                    ValueRange loopState) {
+    return isTensorOp(xferOp) ? loopState[0] : Value();
+  }
 
   /// Cleanup after rewriting the op.
-  static void cleanup(PatternRewriter &rewriter, TransferWriteOp xferOp) {
-    rewriter.eraseOp(xferOp);
+  static void cleanup(PatternRewriter &rewriter, TransferWriteOp xferOp,
+                      scf::ForOp forOp) {
+    if (isTensorOp(xferOp)) {
+      assert(forOp->getNumResults() == 1 && "Expected one for loop result");
+      rewriter.replaceOp(xferOp, forOp->getResult(0));
+    } else {
+      rewriter.eraseOp(xferOp);
+    }
+  }
+
+  /// Return the initial loop state for the generated scf.for loop.
+  static Value initialLoopState(TransferWriteOp xferOp) {
+    return isTensorOp(xferOp) ? xferOp.source() : Value();
   }
 };
 
@@ -485,7 +522,7 @@ LogicalResult checkPrepareXferOp(OpTy xferOp,
     return failure();
   if (xferOp.getVectorType().getRank() <= options.targetRank)
     return failure();
-  if (xferOp.getShapedType().template isa<RankedTensorType>())
+  if (isTensorOp(xferOp) && !options.lowerTensors)
     return failure();
   // Transfer ops that modify the element type are not supported atm.
   if (xferOp.getVectorType().getElementType() !=
@@ -610,6 +647,18 @@ struct PrepareTransferWriteConversion
 ///    corresponding Strategy<OpTy>. If the to-be-unpacked dimension can be
 ///    out-of-bounds, generate an if-check and handle both cases separately.
 /// 3. Clean up according to the corresponding Strategy<OpTy>.
+///
+/// Note: If the transfer op is a TransferWriteOp and operates on a tensor
+/// source (as opposed to a memref source), then each iteration of the generated
+/// scf.for loop yields the new tensor value. E.g.:
+/// ```
+/// %result = scf.for i = 0 to 5 {
+///   %0 = memref.load %buffer[i] : memref<5xvector<4x3xf32>>
+///   %1 = vector.transfer_write %0, %source[...]
+///       : vector<4x3xf32>, tensor<5x4x3xf32>
+///   scf.yield %1 : tensor<5x4x3xf32>
+/// }
+/// ```
 template <typename OpTy>
 struct TransferOpConversion : public VectorToSCFPattern<OpTy> {
   using VectorToSCFPattern<OpTy>::VectorToSCFPattern;
@@ -652,18 +701,24 @@ struct TransferOpConversion : public VectorToSCFPattern<OpTy> {
     auto ub = locB.create<ConstantIndexOp>(
         castedDataType.getDimSize(castedDataType.getRank() - 1));
     auto step = locB.create<ConstantIndexOp>(1);
+    // TransferWriteOps that operate on tensors return the modified tensor and
+    // require a loop state.
+    auto loopState = Strategy<OpTy>::initialLoopState(xferOp);
 
     // Generate for loop.
-    locB.create<scf::ForOp>(
-        lb, ub, step, ValueRange(),
-        [&](OpBuilder &b, Location loc, Value iv, ValueRange /*loopState*/) {
-          generateInBoundsCheck(
+    auto result = locB.create<scf::ForOp>(
+        lb, ub, step, loopState ? ValueRange(loopState) : ValueRange(),
+        [&](OpBuilder &b, Location loc, Value iv, ValueRange loopState) {
+          Type stateType = loopState.empty() ? Type() : loopState[0].getType();
+
+          auto result = generateInBoundsCheck(
               b, xferOp, iv, unpackedDim(xferOp),
+              stateType ? TypeRange(stateType) : TypeRange(),
               /*inBoundsCase=*/
               [&](OpBuilder &b, Location loc) {
                 // Create new transfer op.
                 OpTy newXfer = Strategy<OpTy>::rewriteOp(
-                    b, this->options, xferOp, castedDataBuffer, iv);
+                    b, this->options, xferOp, castedDataBuffer, iv, loopState);
 
                 // If old transfer op has a mask: Set mask on new transfer op.
                 // Special case: If the mask of the old transfer op is 1D and
@@ -687,16 +742,19 @@ struct TransferOpConversion : public VectorToSCFPattern<OpTy> {
                   rewriter.updateRootInPlace(
                       newXfer, [&]() { newXfer.maskMutable().assign(mask); });
                 }
+
+                return loopState.empty() ? Value() : newXfer->getResult(0);
               },
               /*outOfBoundsCase=*/
               [&](OpBuilder &b, Location /*loc*/) {
-                Strategy<OpTy>::handleOutOfBoundsDim(b, xferOp,
-                                                     castedDataBuffer, iv);
+                return Strategy<OpTy>::handleOutOfBoundsDim(
+                    b, xferOp, castedDataBuffer, iv, loopState);
               });
-          b.create<scf::YieldOp>(loc);
+
+          maybeYieldValue(b, loc, !loopState.empty(), result);
         });
 
-    Strategy<OpTy>::cleanup(rewriter, xferOp);
+    Strategy<OpTy>::cleanup(rewriter, xferOp, result);
     return success();
   }
 };
@@ -1184,6 +1242,7 @@ struct ConvertVectorToSCFPass
     this->fullUnroll = options.unroll;
     this->targetRank = options.targetRank;
     this->lowerPermutationMaps = options.lowerPermutationMaps;
+    this->lowerTensors = options.lowerTensors;
   }
 
   void runOnFunction() override {
@@ -1191,6 +1250,7 @@ struct ConvertVectorToSCFPass
     options.unroll = fullUnroll;
     options.targetRank = targetRank;
     options.lowerPermutationMaps = lowerPermutationMaps;
+    options.lowerTensors = lowerTensors;
 
     // Lower permutation maps first.
     if (lowerPermutationMaps) {

diff --git a/mlir/test/Conversion/VectorToSCF/tensor-transfer-ops.mlir b/mlir/test/Conversion/VectorToSCF/tensor-transfer-ops.mlir
new file mode 100644
index 0000000000000..0cfc6ab814a70
--- /dev/null
+++ b/mlir/test/Conversion/VectorToSCF/tensor-transfer-ops.mlir
@@ -0,0 +1,38 @@
+// RUN: mlir-opt %s -convert-vector-to-scf='lower-tensors=true' -split-input-file -allow-unregistered-dialect | FileCheck %s
+
+// CHECK-LABEL: func @transfer_read_2d(
+//       CHECK: %[[ALLOC:.*]] = memref.alloca() : memref<vector<4x9xf32>>
+//       CHECK: %[[CASTED:.*]] = vector.type_cast %[[ALLOC]] : memref<vector<4x9xf32>> to memref<4xvector<9xf32>>
+//       CHECK: scf.for {{.*}} {
+//       CHECK:   %[[READ:.*]] = vector.transfer_read %{{.*}}[{{.*}}], %cst {in_bounds = [true]} : tensor<?x?xf32>, vector<9xf32>
+//       CHECK:   memref.store %[[READ]], %[[CASTED]][%{{.*}}] : memref<4xvector<9xf32>>
+//       CHECK: }
+//       CHECK: %[[LOADED:.*]] = memref.load %[[ALLOC]][] : memref<vector<4x9xf32>>
+//       CHECK: return %[[LOADED]] : vector<4x9xf32>
+func @transfer_read_2d(%A : tensor<?x?xf32>, %base1 : index, %base2 : index)
+    -> (vector<4x9xf32>){
+  %p = constant -42.0: f32
+  %f = vector.transfer_read %A[%base1, %base2], %p {in_bounds = [true, true]}
+      : tensor<?x?xf32>, vector<4x9xf32>
+  return %f : vector<4x9xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @transfer_write_2d(
+//       CHECK: %[[ALLOC:.*]] = memref.alloca() : memref<vector<2x3xf32>>
+//       CHECK: memref.store {{.*}}, %[[ALLOC]][] : memref<vector<2x3xf32>>
+//       CHECK: %[[CASTED:.*]] = vector.type_cast %[[ALLOC]] : memref<vector<2x3xf32>> to memref<2xvector<3xf32>>
+//       CHECK: %[[RESULT:.*]] = scf.for {{.*}} iter_args(%[[STATE:.*]] = %{{.*}}) -> (tensor<?x?xf32>) {
+//       CHECK:   %[[LOADED:.*]] = memref.load %[[CASTED]][%{{.*}}] : memref<2xvector<3xf32>>
+//       CHECK:   %[[WRITE:.*]] = vector.transfer_write %[[LOADED]], %[[STATE]][{{.*}}] {in_bounds = [true]} : vector<3xf32>, tensor<?x?xf32>
+//       CHECK:   scf.yield %[[WRITE]] : tensor<?x?xf32>
+//       CHECK: }
+//       CHECK: return %[[RESULT]] : tensor<?x?xf32>
+func @transfer_write_2d(%A : tensor<?x?xf32>, %vec : vector<2x3xf32>,
+                        %base1 : index, %base2 : index) -> (tensor<?x?xf32>) {
+  %t = vector.transfer_write %vec, %A[%base1, %base2] {in_bounds = [true, true]}
+      : vector<2x3xf32>, tensor<?x?xf32>
+  return %t : tensor<?x?xf32>
+}
+


        

