[Mlir-commits] [mlir] a819e73 - [mlir] Support broadcast dimensions in ProgressiveVectorToSCF

Matthias Springer llvmlistbot at llvm.org
Fri Apr 23 02:01:59 PDT 2021


Author: Matthias Springer
Date: 2021-04-23T18:01:32+09:00
New Revision: a819e7339315687f06f686971a649f614afbd987

URL: https://github.com/llvm/llvm-project/commit/a819e7339315687f06f686971a649f614afbd987
DIFF: https://github.com/llvm/llvm-project/commit/a819e7339315687f06f686971a649f614afbd987.diff

LOG: [mlir] Support broadcast dimensions in ProgressiveVectorToSCF

This commit adds support for broadcast dimensions in permutation maps of vector transfer ops.
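
A broadcast dimension appears as the constant 0 in a permutation map result. For illustration, the following sketch (adapted from the integration tests added by this commit; the function name is illustrative) reads a 4x9 vector in which each element loaded along d1 is replicated 9 times along the second vector dimension:

```mlir
func @read_with_broadcast(%A : memref<?x?xf32>, %i : index, %j : index) {
  %fm42 = constant -42.0 : f32
  // d1 indexes the memref's second dimension; the constant 0 marks a
  // broadcast, so no memref dimension advances along that vector dimension.
  %f = vector.transfer_read %A[%i, %j], %fm42
      {permutation_map = affine_map<(d0, d1) -> (d1, 0)>}
      : memref<?x?xf32>, vector<4x9xf32>
  vector.print %f : vector<4x9xf32>
  return
}
```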

Also fixes a bug in VectorToSCF that generated incorrect in-bounds checks for broadcast dimensions: a broadcast dimension reads a fixed memref position regardless of the loop induction variable, so it requires no in-bounds check at all.
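
In the 1-D lowering (TransferOp1dConversion), a broadcast is handled by scalar loads whose memref indices do not depend on the loop induction variable. In the pseudo-IR style of the pattern documentation, a read such as

```mlir
vector.transfer_read %A[%a, %b], %pad
    {permutation_map = affine_map<(d0, d1) -> (0)>}
    : memref<?x?xf32>, vector<9xf32>
```

is rewritten to approximately the following (a sketch, not the literal generated IR):

```mlir
for i = 0 to 9 {
  %t = memref.load %A[%a, %b]   // indices do not advance with i
  %vec = vector.insertelement %t, %vec[i] : vector<9xf32>
}
```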

Differential Revision: https://reviews.llvm.org/D101019

Added: 
    

Modified: 
    mlir/lib/Conversion/VectorToSCF/ProgressiveVectorToSCF.cpp
    mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
    mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-1d.mlir
    mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-2d.mlir
    mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-3d.mlir

Removed: 
    


################################################################################
diff --git a/mlir/lib/Conversion/VectorToSCF/ProgressiveVectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/ProgressiveVectorToSCF.cpp
index 05f8dc4b4856..3eb6072e7979 100644
--- a/mlir/lib/Conversion/VectorToSCF/ProgressiveVectorToSCF.cpp
+++ b/mlir/lib/Conversion/VectorToSCF/ProgressiveVectorToSCF.cpp
@@ -70,13 +70,16 @@ static Value setAllocAtFunctionEntry(MemRefType type, Operation *op) {
 
 /// Given a vector transfer op, calculate which dimension of the `source`
 /// memref should be unpacked in the next application of TransferOpConversion.
+/// A return value of None indicates a broadcast.
 template <typename OpTy>
-static unsigned unpackedDim(OpTy xferOp) {
+static Optional<int64_t> unpackedDim(OpTy xferOp) {
   auto map = xferOp.permutation_map();
-  // TODO: Handle broadcast
-  auto expr = map.getResult(0).template dyn_cast<AffineDimExpr>();
-  assert(expr && "Expected AffineDimExpr in permutation map result");
-  return expr.getPosition();
+  if (auto expr = map.getResult(0).template dyn_cast<AffineDimExpr>())
+    return expr.getPosition();
+
+  assert(map.getResult(0).template isa<AffineConstantExpr>() &&
+         "Expected AffineDimExpr or AffineConstantExpr");
+  return None;
 }
 
 /// Compute the permutation map for the new (N-1)-D vector transfer op. This
@@ -103,8 +106,12 @@ static void getXferIndices(OpTy xferOp, Value iv,
   auto dim = unpackedDim(xferOp);
   auto prevIndices = adaptor.indices();
   indices.append(prevIndices.begin(), prevIndices.end());
-  using edsc::op::operator+;
-  indices[dim] = adaptor.indices()[dim] + iv;
+
+  bool isBroadcast = !dim.hasValue();
+  if (!isBroadcast) {
+    using edsc::op::operator+;
+    indices[dim.getValue()] = adaptor.indices()[dim.getValue()] + iv;
+  }
 }
 
 static void maybeYieldValue(bool hasRetVal, OpBuilder builder, Location loc,
@@ -116,7 +123,7 @@ static void maybeYieldValue(bool hasRetVal, OpBuilder builder, Location loc,
   }
 }
 
-/// Helper function TransferOpConversion and Strided1dTransferOpConversion.
+/// Helper function for TransferOpConversion and TransferOp1dConversion.
 /// Generate an in-bounds check if the transfer op may go out-of-bounds on the
 /// specified dimension `dim` with the loop iteration variable `iv`.
 /// E.g., when unpacking dimension 0 from:
@@ -138,15 +145,17 @@ static void maybeYieldValue(bool hasRetVal, OpBuilder builder, Location loc,
 /// `resultTypes`.
 template <typename OpTy>
 static Value generateInBoundsCheck(
-    OpTy xferOp, Value iv, OpBuilder &builder, unsigned dim,
+    OpTy xferOp, Value iv, OpBuilder &builder, Optional<int64_t> dim,
     TypeRange resultTypes,
     function_ref<Value(OpBuilder &, Location)> inBoundsCase,
     function_ref<Value(OpBuilder &, Location)> outOfBoundsCase = nullptr) {
   bool hasRetVal = !resultTypes.empty();
-  if (!xferOp.isDimInBounds(0)) {
-    auto memrefDim = memref_dim(xferOp.source(), std_constant_index(dim));
+  bool isBroadcast = !dim.hasValue(); // No in-bounds check for broadcasts.
+  if (!xferOp.isDimInBounds(0) && !isBroadcast) {
+    auto memrefDim =
+        memref_dim(xferOp.source(), std_constant_index(dim.getValue()));
     using edsc::op::operator+;
-    auto memrefIdx = xferOp.indices()[dim] + iv;
+    auto memrefIdx = xferOp.indices()[dim.getValue()] + iv;
     auto cond = std_cmpi_sgt(memrefDim.value, memrefIdx);
     auto check = builder.create<scf::IfOp>(
         xferOp.getLoc(), resultTypes, cond,
@@ -175,7 +184,7 @@ static Value generateInBoundsCheck(
 /// a return value. Consequently, this function does not have a return value.
 template <typename OpTy>
 static void generateInBoundsCheck(
-    OpTy xferOp, Value iv, OpBuilder &builder, int64_t dim,
+    OpTy xferOp, Value iv, OpBuilder &builder, Optional<int64_t> dim,
     function_ref<void(OpBuilder &, Location)> inBoundsCase,
     function_ref<void(OpBuilder &, Location)> outOfBoundsCase = nullptr) {
   generateInBoundsCheck(
@@ -534,27 +543,31 @@ struct TransferOpConversion : public OpRewritePattern<OpTy> {
 };
 
 /// Compute the indices into the memref for the LoadOp/StoreOp generated as
-/// part of Strided1dTransferOpConversion. Return the memref dimension on which
-/// the transfer is operating.
+/// part of TransferOp1dConversion. Return the memref dimension on which
+/// the transfer is operating. A return value of None indicates a broadcast.
 template <typename OpTy>
-static unsigned get1dMemrefIndices(OpTy xferOp, Value iv,
-                                   SmallVector<Value, 8> &memrefIndices) {
+static Optional<int64_t>
+get1dMemrefIndices(OpTy xferOp, Value iv,
+                   SmallVector<Value, 8> &memrefIndices) {
   auto indices = xferOp.indices();
   auto map = xferOp.permutation_map();
 
   memrefIndices.append(indices.begin(), indices.end());
   assert(map.getNumResults() == 1 &&
          "Expected 1 permutation map result for 1D transfer");
-  // TODO: Handle broadcast
-  auto expr = map.getResult(0).template dyn_cast<AffineDimExpr>();
-  assert(expr && "Expected AffineDimExpr in permutation map result");
-  auto dim = expr.getPosition();
-  using edsc::op::operator+;
-  memrefIndices[dim] = memrefIndices[dim] + iv;
-  return dim;
+  if (auto expr = map.getResult(0).template dyn_cast<AffineDimExpr>()) {
+    auto dim = expr.getPosition();
+    using edsc::op::operator+;
+    memrefIndices[dim] = memrefIndices[dim] + iv;
+    return dim;
+  }
+
+  assert(map.getResult(0).template isa<AffineConstantExpr>() &&
+         "Expected AffineDimExpr or AffineConstantExpr");
+  return None;
 }
 
-/// Codegen strategy for Strided1dTransferOpConversion, depending on the
+/// Codegen strategy for TransferOp1dConversion, depending on the
 /// operation.
 template <typename OpTy>
 struct Strategy1d;
@@ -613,14 +626,24 @@ struct Strategy1d<TransferWriteOp> {
   static Value initialLoopState(TransferWriteOp xferOp) { return Value(); }
 };
 
-/// Lower a 1D vector transfer op that operates on a dimension different from
-/// the last one. Instead of accessing contiguous chunks (vectors) of memory,
-/// such ops access memory in a strided fashion.
+/// Lower a 1D vector transfer op to SCF using scalar loads/stores. This is
+/// necessary in cases where a 1D vector transfer op cannot be lowered into
+/// vector load/stores due to non-unit strides or broadcasts:
+///
+/// * Transfer dimension is not the last memref dimension
+/// * Transfer dimension is a broadcast (i.e., scalar load + broadcast)
+/// * Memref has a layout map with non-unit stride on the last dimension
+///
+/// This pattern generates IR as follows:
 ///
 /// 1. Generate a for loop iterating over each vector element.
 /// 2. Inside the loop, generate an InsertElementOp or ExtractElementOp,
 ///    depending on OpTy.
 ///
+/// TODO: In some cases (no masking, etc.), LLVM::MatrixColumnMajorLoadOp
+///       can be generated instead of TransferOp1dConversion. Add such a pattern
+///       to ConvertVectorToLLVM.
+///
 /// E.g.:
 /// ```
 /// vector.transfer_write %vec, %A[%a, %b]
@@ -635,7 +658,7 @@ struct Strategy1d<TransferWriteOp> {
 /// }
 /// ```
 template <typename OpTy>
-struct Strided1dTransferOpConversion : public OpRewritePattern<OpTy> {
+struct TransferOp1dConversion : public OpRewritePattern<OpTy> {
   using OpRewritePattern<OpTy>::OpRewritePattern;
 
   LogicalResult matchAndRewrite(OpTy xferOp,
@@ -681,8 +704,8 @@ void populateProgressiveVectorToSCFConversionPatterns(
                TransferOpConversion<TransferWriteOp>>(patterns.getContext());
 
   if (kTargetRank == 1) {
-    patterns.add<Strided1dTransferOpConversion<TransferReadOp>,
-                 Strided1dTransferOpConversion<TransferWriteOp>>(
+    patterns.add<TransferOp1dConversion<TransferReadOp>,
+                 TransferOp1dConversion<TransferWriteOp>>(
         patterns.getContext());
   }
 }

diff --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
index 72d32d071e49..4f13e7d8e5af 100644
--- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
+++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
@@ -230,7 +230,10 @@ emitInBoundsCondition(PatternRewriter &rewriter,
     Value iv = std::get<0>(it), off = std::get<1>(it), ub = std::get<2>(it);
     using namespace mlir::edsc::op;
     majorIvsPlusOffsets.push_back(iv + off);
-    if (!xferOp.isDimInBounds(leadingRank + idx)) {
+    auto affineConstExpr =
+        xferOp.permutation_map().getResult(idx).dyn_cast<AffineConstantExpr>();
+    bool isBroadcast = affineConstExpr && affineConstExpr.getValue() == 0;
+    if (!xferOp.isDimInBounds(leadingRank + idx) && !isBroadcast) {
       Value inBoundsCond = onTheFlyFoldSLT(majorIvsPlusOffsets.back(), ub);
       if (inBoundsCond)
         inBoundsCondition = (inBoundsCondition)

diff --git a/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-1d.mlir b/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-1d.mlir
index 17f635f7b78a..b6bd8c404158 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-1d.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-1d.mlir
@@ -10,6 +10,15 @@
 
 // Test for special cases of 1D vector transfer ops.
 
+func @transfer_read_2d(%A : memref<?x?xf32>, %base1 : index, %base2 : index) {
+  %fm42 = constant -42.0: f32
+  %f = vector.transfer_read %A[%base1, %base2], %fm42
+      {permutation_map = affine_map<(d0, d1) -> (d0, d1)>}
+      : memref<?x?xf32>, vector<5x6xf32>
+  vector.print %f: vector<5x6xf32>
+  return
+}
+
 func @transfer_read_1d(%A : memref<?x?xf32>, %base1 : index, %base2 : index) {
   %fm42 = constant -42.0: f32
   %f = vector.transfer_read %A[%base1, %base2], %fm42
@@ -19,6 +28,16 @@ func @transfer_read_1d(%A : memref<?x?xf32>, %base1 : index, %base2 : index) {
   return
 }
 
+func @transfer_read_1d_broadcast(
+    %A : memref<?x?xf32>, %base1 : index, %base2 : index) {
+  %fm42 = constant -42.0: f32
+  %f = vector.transfer_read %A[%base1, %base2], %fm42
+      {permutation_map = affine_map<(d0, d1) -> (0)>}
+      : memref<?x?xf32>, vector<9xf32>
+  vector.print %f: vector<9xf32>
+  return
+}
+
 func @transfer_write_1d(%A : memref<?x?xf32>, %base1 : index, %base2 : index) {
   %fn1 = constant -1.0 : f32
   %vf0 = splat %fn1 : vector<7xf32>
@@ -53,8 +72,11 @@ func @entry() {
   call @transfer_read_1d(%A, %c1, %c2) : (memref<?x?xf32>, index, index) -> ()
   call @transfer_write_1d(%A, %c3, %c2) : (memref<?x?xf32>, index, index) -> ()
   call @transfer_read_1d(%A, %c0, %c2) : (memref<?x?xf32>, index, index) -> ()
+  call @transfer_read_1d_broadcast(%A, %c1, %c2)
+      : (memref<?x?xf32>, index, index) -> ()
   return
 }
 
 // CHECK: ( 12, 22, 32, 42, -42, -42, -42, -42, -42 )
 // CHECK: ( 2, 12, 22, -1, -1, -42, -42, -42, -42 )
+// CHECK: ( 12, 12, 12, 12, 12, 12, 12, 12, 12 )

diff --git a/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-2d.mlir b/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-2d.mlir
index fbcf94a6233c..cbe0aa52a437 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-2d.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-2d.mlir
@@ -27,6 +27,16 @@ func @transfer_read_2d_transposed(
   return
 }
 
+func @transfer_read_2d_broadcast(
+    %A : memref<?x?xf32>, %base1: index, %base2: index) {
+  %fm42 = constant -42.0: f32
+  %f = vector.transfer_read %A[%base1, %base2], %fm42
+      {permutation_map = affine_map<(d0, d1) -> (d1, 0)>} :
+    memref<?x?xf32>, vector<4x9xf32>
+  vector.print %f: vector<4x9xf32>
+  return
+}
+
 func @transfer_write_2d(%A : memref<?x?xf32>, %base1: index, %base2: index) {
   %fn1 = constant -1.0 : f32
   %vf0 = splat %fn1 : vector<1x4xf32>
@@ -73,6 +83,9 @@ func @entry() {
   // Same as above, but transposed
   call @transfer_read_2d_transposed(%A, %c0, %c0)
       : (memref<?x?xf32>, index, index) -> ()
+  // Second vector dimension is a broadcast
+  call @transfer_read_2d_broadcast(%A, %c1, %c2)
+      : (memref<?x?xf32>, index, index) -> ()
   return
 }
 
@@ -80,3 +93,4 @@ func @entry() {
 // CHECK: ( ( 12, 22, -42, -42, -42, -42, -42, -42, -42 ), ( 13, 23, -42, -42, -42, -42, -42, -42, -42 ), ( -42, -42, -42, -42, -42, -42, -42, -42, -42 ), ( -42, -42, -42, -42, -42, -42, -42, -42, -42 ) )
 // CHECK: ( ( 0, 1, 2, 3, -42, -42, -42, -42, -42 ), ( 10, 11, 12, 13, -42, -42, -42, -42, -42 ), ( 20, 21, 22, 23, -42, -42, -42, -42, -42 ), ( -42, -42, -42, -42, -42, -42, -42, -42, -42 ) )
 // CHECK: ( ( 0, 10, 20, -42, -42, -42, -42, -42, -42 ), ( 1, 11, 21, -42, -42, -42, -42, -42, -42 ), ( 2, 12, 22, -42, -42, -42, -42, -42, -42 ), ( 3, 13, 23, -42, -42, -42, -42, -42, -42 ) )
+// CHECK: ( ( 12, 12, 12, 12, 12, 12, 12, 12, 12 ), ( 13, 13, 13, 13, 13, 13, 13, 13, 13 ), ( -42, -42, -42, -42, -42, -42, -42, -42, -42 ), ( -42, -42, -42, -42, -42, -42, -42, -42, -42 ) )

diff --git a/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-3d.mlir b/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-3d.mlir
index 7ecac4a38938..ae7fee3c9110 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-3d.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-3d.mlir
@@ -19,6 +19,16 @@ func @transfer_read_3d(%A : memref<?x?x?x?xf32>,
   return
 }
 
+func @transfer_read_3d_broadcast(%A : memref<?x?x?x?xf32>,
+                                 %o: index, %a: index, %b: index, %c: index) {
+  %fm42 = constant -42.0: f32
+  %f = vector.transfer_read %A[%o, %a, %b, %c], %fm42
+      {permutation_map = affine_map<(d0, d1, d2, d3) -> (d1, 0, d3)>}
+      : memref<?x?x?x?xf32>, vector<2x5x3xf32>
+  vector.print %f: vector<2x5x3xf32>
+  return
+}
+
 func @transfer_read_3d_transposed(%A : memref<?x?x?x?xf32>,
                                   %o: index, %a: index, %b: index, %c: index) {
   %fm42 = constant -42.0: f32
@@ -78,9 +88,12 @@ func @entry() {
       : (memref<?x?x?x?xf32>, index, index, index, index) -> ()
   call @transfer_read_3d_transposed(%A, %c0, %c0, %c0, %c0)
       : (memref<?x?x?x?xf32>, index, index, index, index) -> ()
+  call @transfer_read_3d_broadcast(%A, %c0, %c0, %c0, %c0)
+      : (memref<?x?x?x?xf32>, index, index, index, index) -> ()
   return
 }
 
 // CHECK: ( ( ( 0, 0, -42 ), ( 2, 3, -42 ), ( 4, 6, -42 ), ( 6, 9, -42 ), ( -42, -42, -42 ) ), ( ( 20, 30, -42 ), ( 22, 33, -42 ), ( 24, 36, -42 ), ( 26, 39, -42 ), ( -42, -42, -42 ) ) )
 // CHECK: ( ( ( 0, 0, -42 ), ( 2, -1, -42 ), ( 4, -1, -42 ), ( 6, -1, -42 ), ( -42, -42, -42 ) ), ( ( 20, 30, -42 ), ( 22, -1, -42 ), ( 24, -1, -42 ), ( 26, -1, -42 ), ( -42, -42, -42 ) ) )
 // CHECK: ( ( ( 0, 20, 40 ), ( 0, 20, 40 ), ( 0, 20, 40 ), ( 0, 20, 40 ), ( 0, 20, 40 ) ), ( ( 0, 30, 60 ), ( 0, 30, 60 ), ( 0, 30, 60 ), ( 0, 30, 60 ), ( 0, 30, 60 ) ), ( ( -42, -42, -42 ), ( -42, -42, -42 ), ( -42, -42, -42 ), ( -42, -42, -42 ), ( -42, -42, -42 ) ) )
+// CHECK: ( ( ( 0, 0, -42 ), ( 0, 0, -42 ), ( 0, 0, -42 ), ( 0, 0, -42 ), ( 0, 0, -42 ) ), ( ( 20, 30, -42 ), ( 20, 30, -42 ), ( 20, 30, -42 ), ( 20, 30, -42 ), ( 20, 30, -42 ) ) )

