[Mlir-commits] [mlir] 5017b0f - [mlir] Check only last dim stride in transfer op lowering

Matthias Springer <llvmlistbot at llvm.org>
Tue May 25 01:53:59 PDT 2021


Author: Matthias Springer
Date: 2021-05-25T17:53:24+09:00
New Revision: 5017b0f88b81083d3f723e7a8e5cc19b1c4eb366

URL: https://github.com/llvm/llvm-project/commit/5017b0f88b81083d3f723e7a8e5cc19b1c4eb366
DIFF: https://github.com/llvm/llvm-project/commit/5017b0f88b81083d3f723e7a8e5cc19b1c4eb366.diff

LOG: [mlir] Check only last dim stride in transfer op lowering

Lower a 1D vector transfer op to LLVM if the last dim stride is 1; the leading dimensions no longer need to be contiguous. This also fixes an off-by-one bug in the loop bound of the original unit stride computation (computeContiguousStrides). A minimal example is sketched after the revision link below.

Differential Revision: https://reviews.llvm.org/D102897
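
For illustration, a minimal sketch (not part of the commit; it closely mirrors the new integration test in the diff below, and the function and value names are placeholders): the subview's layout has a dynamic stride on the first dimension but a unit stride on the last one, so with this change the 1D transfer_read lowers directly to an LLVM vector load instead of falling back to VectorToSCF.

#strided = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>

func @read_last_dim_unit_stride(%A : memref<?x?xf32>, %i : index, %j : index) -> vector<2xf32> {
  %c0 = constant 0 : index
  %pad = constant -42.0 : f32
  // The subview is contiguous only in its last dimension.
  %sv = memref.subview %A[%i, %j] [1, 2] [1, 1]
      : memref<?x?xf32> to memref<1x2xf32, #strided>
  // Last-dim stride is 1, so this read now maps to a single LLVM vector load.
  %v = vector.transfer_read %sv[%c0, %c0], %pad {in_bounds = [true]}
      : memref<1x2xf32, #strided>, vector<2xf32>
  return %v : vector<2xf32>
}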

Added: 
    

Modified: 
    mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
    mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
    mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-1d.mlir

Removed: 
    


################################################################################
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index f3909b3e85c6b..dd6c3b6dd103d 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -1027,6 +1027,15 @@ class VectorInsertStridedSliceOpSameRankRewritePattern
   }
 };
 
+/// Return true if the last dimension of the MemRefType has unit stride. Also
+/// return true for memrefs with no strides.
+static bool isLastMemrefDimUnitStride(MemRefType type) {
+  int64_t offset;
+  SmallVector<int64_t> strides;
+  auto successStrides = getStridesAndOffset(type, strides, offset);
+  return succeeded(successStrides) && (strides.empty() || strides.back() == 1);
+}
+
 /// Returns the strides if the memory underlying `memRefType` has a contiguous
 /// static layout.
 static llvm::Optional<SmallVector<int64_t, 4>>
@@ -1047,7 +1056,7 @@ computeContiguousStrides(MemRefType memRefType) {
   // contiguous dynamic shapes in other ways than with just empty/identity
   // layout.
   auto sizes = memRefType.getShape();
-  for (int index = 0, e = strides.size() - 2; index < e; ++index) {
+  for (int index = 0, e = strides.size() - 1; index < e; ++index) {
     if (ShapedType::isDynamic(sizes[index + 1]) ||
         ShapedType::isDynamicStrideOrOffset(strides[index]) ||
         ShapedType::isDynamicStrideOrOffset(strides[index + 1]))
@@ -1149,8 +1158,7 @@ class VectorTransferConversion : public ConvertOpToLLVMPattern<ConcreteOp> {
                   ConversionPatternRewriter &rewriter) const override {
     auto adaptor = getTransferOpAdapter(xferOp, operands);
 
-    if (xferOp.getVectorType().getRank() > 1 ||
-        llvm::size(xferOp.indices()) == 0)
+    if (xferOp.getVectorType().getRank() > 1 || xferOp.indices().empty())
       return failure();
     if (xferOp.permutation_map() !=
         AffineMap::getMinorIdentityMap(xferOp.permutation_map().getNumInputs(),
@@ -1160,9 +1168,8 @@ class VectorTransferConversion : public ConvertOpToLLVMPattern<ConcreteOp> {
     auto memRefType = xferOp.getShapedType().template dyn_cast<MemRefType>();
     if (!memRefType)
       return failure();
-    // Only contiguous source tensors supported atm.
-    auto strides = computeContiguousStrides(memRefType);
-    if (!strides)
+    // Last dimension must be contiguous. (Otherwise: Use VectorToSCF.)
+    if (!isLastMemrefDimUnitStride(memRefType))
       return failure();
     // Out-of-bounds dims are handled by MaterializeTransferMask.
     if (xferOp.hasOutOfBoundsDim())

diff --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
index ee439c00b9b24..bb0b5162cc0fa 100644
--- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
+++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
@@ -1066,7 +1066,7 @@ static bool isLastMemrefDimUnitStride(MemRefType type) {
   int64_t offset;
   SmallVector<int64_t, 4> strides;
   auto successStrides = getStridesAndOffset(type, strides, offset);
-  return succeeded(successStrides) && strides.back() == 1;
+  return succeeded(successStrides) && (strides.empty() || strides.back() == 1);
 }
 
 /// Lower a 1D vector transfer op to SCF using scalar loads/stores. This is

diff --git a/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-1d.mlir b/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-1d.mlir
index 7e6ef94c872d6..7e1596ae31619 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-1d.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-1d.mlir
@@ -37,6 +37,58 @@ func @transfer_read_1d(%A : memref<?x?xf32>, %base1 : index, %base2 : index) {
   return
 }
 
+#map0 = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>
+#map1 = affine_map<(d0, d1) -> (6 * d0 + 2 * d1)>
+
+// Vector load with unit stride only on last dim.
+func @transfer_read_1d_unit_stride(%A : memref<?x?xf32>) {
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %c2 = constant 2 : index
+  %c3 = constant 3 : index
+  %c4 = constant 4 : index
+  %c5 = constant 5 : index
+  %c6 = constant 6 : index
+  %fm42 = constant -42.0: f32
+  scf.for %arg2 = %c1 to %c5 step %c2 {
+    scf.for %arg3 = %c0 to %c6 step %c3 {
+      %0 = memref.subview %A[%arg2, %arg3] [1, 2] [1, 1]
+          : memref<?x?xf32> to memref<1x2xf32, #map0>
+      %1 = vector.transfer_read %0[%c0, %c0], %fm42 {in_bounds=[true]}
+          : memref<1x2xf32, #map0>, vector<2xf32>
+      vector.print %1 : vector<2xf32>
+    }
+  }
+  return
+}
+
+// Vector load with unit stride only on last dim. Strides are not static, so
+// codegen must go through VectorToSCF 1D lowering.
+func @transfer_read_1d_non_static_unit_stride(%A : memref<?x?xf32>) {
+  %c1 = constant 1 : index
+  %c2 = constant 2 : index
+  %c6 = constant 6 : index
+  %fm42 = constant -42.0: f32
+  %1 = memref.reinterpret_cast %A to offset: [%c6], sizes: [%c1, %c2],  strides: [%c6, %c1]
+      : memref<?x?xf32> to memref<?x?xf32, offset: ?, strides: [?, ?]>
+  %2 = vector.transfer_read %1[%c2, %c1], %fm42 {in_bounds=[true]}
+      : memref<?x?xf32, offset: ?, strides: [?, ?]>, vector<4xf32>
+  vector.print %2 : vector<4xf32>
+  return
+}
+
+// Vector load where last dim has non-unit stride.
+func @transfer_read_1d_non_unit_stride(%A : memref<?x?xf32>) {
+  %B = memref.reinterpret_cast %A to offset: [0], sizes: [4, 3], strides: [6, 2]
+      : memref<?x?xf32> to memref<4x3xf32, #map1>
+  %c1 = constant 1 : index
+  %c2 = constant 2 : index
+  %fm42 = constant -42.0: f32
+  %vec = vector.transfer_read %B[%c2, %c1], %fm42 {in_bounds=[false]} : memref<4x3xf32, #map1>, vector<3xf32>
+  vector.print %vec : vector<3xf32>
+  return
+}
+
 // Broadcast.
 func @transfer_read_1d_broadcast(
     %A : memref<?x?xf32>, %base1 : index, %base2 : index) {
@@ -117,42 +169,58 @@ func @entry() {
   call @transfer_read_1d(%A, %c1, %c2) : (memref<?x?xf32>, index, index) -> ()
   // CHECK: ( 12, 22, 32, 42, -42, -42, -42, -42, -42 )
 
-  // 2. Write to 2D memref on first dimension. Cannot be lowered to an LLVM
+  // 2.a. Read 1D vector from 2D memref with non-unit stride on first dim.
+  call @transfer_read_1d_unit_stride(%A) : (memref<?x?xf32>) -> ()
+  // CHECK: ( 10, 11 )
+  // CHECK: ( 13, 14 )
+  // CHECK: ( 30, 31 )
+  // CHECK: ( 33, 34 )
+
+  // 2.b. Read 1D vector from 2D memref with non-unit stride on first dim.
+  //      Strides are non-static.
+  call @transfer_read_1d_non_static_unit_stride(%A) : (memref<?x?xf32>) -> ()
+  // CHECK: ( 31, 32, 33, 34 )
+
+  // 3. Read 1D vector from 2D memref with non-unit stride on second dim.
+  call @transfer_read_1d_non_unit_stride(%A) : (memref<?x?xf32>) -> ()
+  // CHECK: ( 22, 24, -42 )
+
+  // 4. Write to 2D memref on first dimension. Cannot be lowered to an LLVM
   //    vector store. Instead, generates scalar stores.
   call @transfer_write_1d(%A, %c3, %c2) : (memref<?x?xf32>, index, index) -> ()
 
-  // 3. (Same as 1. To check if 2 works correctly.)
+  // 5. (Same as 1. To check if 4 works correctly.)
   call @transfer_read_1d(%A, %c0, %c2) : (memref<?x?xf32>, index, index) -> ()
   // CHECK: ( 2, 12, 22, -1, -1, -42, -42, -42, -42 )
 
-  // 4. Read a scalar from a 2D memref and broadcast the value to a 1D vector.
+  // 6. Read a scalar from a 2D memref and broadcast the value to a 1D vector.
   //    Generates a loop with vector.insertelement.
   call @transfer_read_1d_broadcast(%A, %c1, %c2)
       : (memref<?x?xf32>, index, index) -> ()
   // CHECK: ( 12, 12, 12, 12, 12, 12, 12, 12, 12 )
 
-  // 5. Read from 2D memref on first dimension. Accesses are in-bounds, so no
+  // 7. Read from 2D memref on first dimension. Accesses are in-bounds, so no
   //    if-check is generated inside the generated loop.
   call @transfer_read_1d_in_bounds(%A, %c1, %c2)
       : (memref<?x?xf32>, index, index) -> ()
   // CHECK: ( 12, 22, -1 )
 
-  // 6. Optional mask attribute is specified and, in addition, there may be
+  // 8. Optional mask attribute is specified and, in addition, there may be
   //    out-of-bounds accesses.
   call @transfer_read_1d_mask(%A, %c1, %c2)
       : (memref<?x?xf32>, index, index) -> ()
   // CHECK: ( 12, -42, -1, -42, -42, -42, -42, -42, -42 )
 
-  // 7. Same as 6, but accesses are in-bounds.
+  // 9. Same as 8, but accesses are in-bounds.
   call @transfer_read_1d_mask_in_bounds(%A, %c1, %c2)
       : (memref<?x?xf32>, index, index) -> ()
   // CHECK: ( 12, -42, -1 )
 
-  // 8. Write to 2D memref on first dimension with a mask.
+  // 10. Write to 2D memref on first dimension with a mask.
   call @transfer_write_1d_mask(%A, %c1, %c0)
       : (memref<?x?xf32>, index, index) -> ()
 
-  // 9. (Same as 1. To check if 8 works correctly.)
+  // 11. (Same as 1. To check if 10 works correctly.)
   call @transfer_read_1d(%A, %c0, %c0) : (memref<?x?xf32>, index, index) -> ()
   // CHECK: ( 0, -2, 20, -2, 40, -42, -42, -42, -42 )
 
