[Mlir-commits] [mlir] 2a82dfd - [mlir][VectorOps] Don't drop scalable dims when lowering transfer_reads/writes (in VectorToSCF)

Benjamin Maxwell llvmlistbot at llvm.org
Fri Sep 8 02:44:59 PDT 2023


Author: Benjamin Maxwell
Date: 2023-09-08T09:43:17Z
New Revision: 2a82dfd7040276d50347a3fb4bcb6aced54d9fc5

URL: https://github.com/llvm/llvm-project/commit/2a82dfd7040276d50347a3fb4bcb6aced54d9fc5
DIFF: https://github.com/llvm/llvm-project/commit/2a82dfd7040276d50347a3fb4bcb6aced54d9fc5.diff

LOG: [mlir][VectorOps] Don't drop scalable dims when lowering transfer_reads/writes (in VectorToSCF)

This allows the lowering of rank > 1 transfer_reads/writes to equivalent
lower-rank ones when the trailing dimension is scalable. The resulting
ops still cannot be lowered completely, as that depends on support for
arrays of scalable vectors (plus a few related fixes; see D158517).

This patch also explicitly disables lowering transfer_reads/writes with
a leading scalable dimension, as more changes would be needed to handle
that correctly and it is unclear whether it is required.

Examples of ops that can now be further lowered:

  %vec = vector.transfer_read %arg0[%c0, %c0], %cst, %mask
		 {in_bounds = [true, true]} : memref<3x?xf32>, vector<3x[4]xf32>

  vector.transfer_write %vec, %arg0[%c0, %c0], %mask
		 {in_bounds = [true, true]} : vector<3x[4]xf32>, memref<3x?xf32>

Reviewed By: c-rhodes, awarzynski, dcaballe

Differential Revision: https://reviews.llvm.org/D158753

Added: 
    

Modified: 
    mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
    mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
index 675bddca61a3e2d..1aeed4594f94505 100644
--- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
+++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
@@ -314,15 +314,18 @@ static BufferAllocs allocBuffers(OpBuilder &b, OpTy xferOp) {
 /// the VectorType into the MemRefType.
 ///
 /// E.g.: memref<9xvector<5x6xf32>> --> memref<9x5xvector<6xf32>>
-static MemRefType unpackOneDim(MemRefType type) {
+static FailureOr<MemRefType> unpackOneDim(MemRefType type) {
   auto vectorType = dyn_cast<VectorType>(type.getElementType());
+  // Vectors with leading scalable dims are not supported.
+  // It may be possible to support these in future by using dynamic memref dims.
+  if (vectorType.getScalableDims().front())
+    return failure();
   auto memrefShape = type.getShape();
   SmallVector<int64_t, 8> newMemrefShape;
   newMemrefShape.append(memrefShape.begin(), memrefShape.end());
   newMemrefShape.push_back(vectorType.getDimSize(0));
   return MemRefType::get(newMemrefShape,
-                         VectorType::get(vectorType.getShape().drop_front(),
-                                         vectorType.getElementType()));
+                         VectorType::Builder(vectorType).dropDim(0));
 }
 
 /// Given a transfer op, find the memref from which the mask is loaded. This
@@ -542,6 +545,10 @@ LogicalResult checkPrepareXferOp(OpTy xferOp,
     return failure();
   if (xferOp.getVectorType().getRank() <= options.targetRank)
     return failure();
+  // Currently the unpacking of the leading dimension into the memref is not
+  // supported for scalable dimensions.
+  if (xferOp.getVectorType().getScalableDims().front())
+    return failure();
   if (isTensorOp(xferOp) && !options.lowerTensors)
     return failure();
   // Transfer ops that modify the element type are not supported atm.
@@ -866,8 +873,11 @@ struct TransferOpConversion : public VectorToSCFPattern<OpTy> {
     auto dataBuffer = Strategy<OpTy>::getBuffer(xferOp);
     auto dataBufferType = dyn_cast<MemRefType>(dataBuffer.getType());
     auto castedDataType = unpackOneDim(dataBufferType);
+    if (failed(castedDataType))
+      return failure();
+
     auto castedDataBuffer =
-        locB.create<vector::TypeCastOp>(castedDataType, dataBuffer);
+        locB.create<vector::TypeCastOp>(*castedDataType, dataBuffer);
 
     // If the xferOp has a mask: Find and cast mask buffer.
     Value castedMaskBuffer;
@@ -882,7 +892,9 @@ struct TransferOpConversion : public VectorToSCFPattern<OpTy> {
         //   be broadcasted.)
         castedMaskBuffer = maskBuffer;
       } else {
-        auto castedMaskType = unpackOneDim(maskBufferType);
+        // It's safe to assume the mask buffer can be unpacked if the data
+        // buffer was unpacked.
+        auto castedMaskType = *unpackOneDim(maskBufferType);
         castedMaskBuffer =
             locB.create<vector::TypeCastOp>(castedMaskType, maskBuffer);
       }
@@ -891,7 +903,7 @@ struct TransferOpConversion : public VectorToSCFPattern<OpTy> {
     // Loop bounds and step.
     auto lb = locB.create<arith::ConstantIndexOp>(0);
     auto ub = locB.create<arith::ConstantIndexOp>(
-        castedDataType.getDimSize(castedDataType.getRank() - 1));
+        castedDataType->getDimSize(castedDataType->getRank() - 1));
     auto step = locB.create<arith::ConstantIndexOp>(1);
     // TransferWriteOps that operate on tensors return the modified tensor and
     // require a loop state.
@@ -1074,8 +1086,14 @@ struct UnrollTransferReadConversion
     auto vec = getResultVector(xferOp, rewriter);
     auto vecType = dyn_cast<VectorType>(vec.getType());
     auto xferVecType = xferOp.getVectorType();
-    auto newXferVecType = VectorType::get(xferVecType.getShape().drop_front(),
-                                          xferVecType.getElementType());
+
+    if (xferVecType.getScalableDims()[0]) {
+      // Cannot unroll a scalable dimension at compile time.
+      return failure();
+    }
+
+    VectorType newXferVecType = VectorType::Builder(xferVecType).dropDim(0);
+
     int64_t dimSize = xferVecType.getShape()[0];
 
     // Generate fully unrolled loop of transfer ops.

diff  --git a/mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir b/mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir
index 2653786106c2840..484e1fcde62d64d 100644
--- a/mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir
+++ b/mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir
@@ -635,3 +635,106 @@ func.func @vector_print_scalable_vector(%arg0: vector<[4]xi32>) {
 // CHECK:           vector.print
 // CHECK:           return
 // CHECK:         }
+
+// -----
+
+func.func @transfer_read_array_of_scalable(%arg0: memref<3x?xf32>) -> vector<3x[4]xf32> {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %cst = arith.constant 0.000000e+00 : f32
+  %dim = memref.dim %arg0, %c1 : memref<3x?xf32>
+  %mask = vector.create_mask %c1, %dim : vector<3x[4]xi1>
+  %read = vector.transfer_read %arg0[%c0, %c0], %cst, %mask {in_bounds = [true, true]} : memref<3x?xf32>, vector<3x[4]xf32>
+  return %read : vector<3x[4]xf32>
+}
+// CHECK-LABEL:   func.func @transfer_read_array_of_scalable(
+// CHECK-SAME:                                               %[[ARG:.*]]: memref<3x?xf32>) -> vector<3x[4]xf32> {
+// CHECK:           %[[PADDING:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK:           %[[C0:.*]] = arith.constant 0 : index
+// CHECK:           %[[C3:.*]] = arith.constant 3 : index
+// CHECK:           %[[C1:.*]] = arith.constant 1 : index
+// CHECK:           %[[ALLOCA_VEC:.*]] = memref.alloca() : memref<vector<3x[4]xf32>>
+// CHECK:           %[[ALLOCA_MASK:.*]] = memref.alloca() : memref<vector<3x[4]xi1>>
+// CHECK:           %[[DIM_SIZE:.*]] = memref.dim %[[ARG]], %[[C1]] : memref<3x?xf32>
+// CHECK:           %[[MASK:.*]] = vector.create_mask %[[C1]], %[[DIM_SIZE]] : vector<3x[4]xi1>
+// CHECK:           memref.store %[[MASK]], %[[ALLOCA_MASK]][] : memref<vector<3x[4]xi1>>
+// CHECK:           %[[UNPACK_VECTOR:.*]] = vector.type_cast %[[ALLOCA_VEC]] : memref<vector<3x[4]xf32>> to memref<3xvector<[4]xf32>>
+// CHECK:           %[[UNPACK_MASK:.*]] = vector.type_cast %[[ALLOCA_MASK]] : memref<vector<3x[4]xi1>> to memref<3xvector<[4]xi1>>
+// CHECK:           scf.for %[[VAL_11:.*]] = %[[C0]] to %[[C3]] step %[[C1]] {
+// CHECK:             %[[MASK_SLICE:.*]] = memref.load %[[UNPACK_MASK]]{{\[}}%[[VAL_11]]] : memref<3xvector<[4]xi1>>
+// CHECK:             %[[READ_SLICE:.*]] = vector.transfer_read %[[ARG]]{{\[}}%[[VAL_11]], %[[C0]]], %[[PADDING]], %[[MASK_SLICE]] {in_bounds = [true]} : memref<3x?xf32>, vector<[4]xf32>
+// CHECK:             memref.store %[[READ_SLICE]], %[[UNPACK_VECTOR]]{{\[}}%[[VAL_11]]] : memref<3xvector<[4]xf32>>
+// CHECK:           }
+// CHECK:           %[[RESULT:.*]] = memref.load %[[ALLOCA_VEC]][] : memref<vector<3x[4]xf32>>
+// CHECK:           return %[[RESULT]] : vector<3x[4]xf32>
+// CHECK:         }
+
+// -----
+
+func.func @transfer_write_array_of_scalable(%vec: vector<3x[4]xf32>, %arg0: memref<3x?xf32>) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %cst = arith.constant 0.000000e+00 : f32
+  %dim = memref.dim %arg0, %c1 : memref<3x?xf32>
+  %mask = vector.create_mask %c1, %dim : vector<3x[4]xi1>
+  vector.transfer_write %vec, %arg0[%c0, %c0], %mask {in_bounds = [true, true]} : vector<3x[4]xf32>, memref<3x?xf32>
+  return
+}
+// CHECK-LABEL:   func.func @transfer_write_array_of_scalable(
+// CHECK-SAME:                                                %[[VEC:.*]]: vector<3x[4]xf32>,
+// CHECK-SAME:                                                %[[MEMREF:.*]]: memref<3x?xf32>) {
+// CHECK:           %[[C0:.*]] = arith.constant 0 : index
+// CHECK:           %[[C3:.*]] = arith.constant 3 : index
+// CHECK:           %[[C1:.*]] = arith.constant 1 : index
+// CHECK:           %[[ALLOCA_VEC:.*]] = memref.alloca() : memref<vector<3x[4]xf32>>
+// CHECK:           %[[ALLOCA_MASK:.*]] = memref.alloca() : memref<vector<3x[4]xi1>>
+// CHECK:           %[[DIM_SIZE:.*]] = memref.dim %[[MEMREF]], %[[C1]] : memref<3x?xf32>
+// CHECK:           %[[MASK:.*]] = vector.create_mask %[[C1]], %[[DIM_SIZE]] : vector<3x[4]xi1>
+// CHECK:           memref.store %[[MASK]], %[[ALLOCA_MASK]][] : memref<vector<3x[4]xi1>>
+// CHECK:           memref.store %[[VEC]], %[[ALLOCA_VEC]][] : memref<vector<3x[4]xf32>>
+// CHECK:           %[[UNPACK_VECTOR:.*]] = vector.type_cast %[[ALLOCA_VEC]] : memref<vector<3x[4]xf32>> to memref<3xvector<[4]xf32>>
+// CHECK:           %[[UNPACK_MASK:.*]] = vector.type_cast %[[ALLOCA_MASK]] : memref<vector<3x[4]xi1>> to memref<3xvector<[4]xi1>>
+// CHECK:           scf.for %[[VAL_11:.*]] = %[[C0]] to %[[C3]] step %[[C1]] {
+// CHECK:             %[[VECTOR_SLICE:.*]] = memref.load %[[UNPACK_VECTOR]]{{\[}}%[[VAL_11]]] : memref<3xvector<[4]xf32>>
+// CHECK:             %[[MASK_SLICE:.*]] = memref.load %[[UNPACK_MASK]]{{\[}}%[[VAL_11]]] : memref<3xvector<[4]xi1>>
+// CHECK:             vector.transfer_write %[[VECTOR_SLICE]], %[[MEMREF]]{{\[}}%[[VAL_11]], %[[C0]]], %[[MASK_SLICE]] {in_bounds = [true]} : vector<[4]xf32>, memref<3x?xf32>
+// CHECK:           }
+// CHECK:           return
+// CHECK:         }
+
+// -----
+
+/// The following two tests currently cannot be lowered via unpacking the leading dim since it is scalable.
+/// It may be possible to special case this via a dynamic dim in future.
+
+func.func @cannot_lower_transfer_write_with_leading_scalable(%vec: vector<[4]x4xf32>, %arg0: memref<?x4xf32>) {
+  %c0 = arith.constant 0 : index
+  %c4 = arith.constant 4 : index
+  %cst = arith.constant 0.000000e+00 : f32
+  %dim = memref.dim %arg0, %c0 : memref<?x4xf32>
+  %mask = vector.create_mask %dim, %c4 : vector<[4]x4xi1>
+  vector.transfer_write %vec, %arg0[%c0, %c0], %mask {in_bounds = [true, true]} : vector<[4]x4xf32>, memref<?x4xf32>
+  return
+}
+// CHECK-LABEL:   func.func @cannot_lower_transfer_write_with_leading_scalable(
+// CHECK-SAME:                                                                 %[[VEC:.*]]: vector<[4]x4xf32>,
+// CHECK-SAME:                                                                 %[[MEMREF:.*]]: memref<?x4xf32>)
+// CHECK: vector.transfer_write %[[VEC]], %[[MEMREF]][%{{.*}}, %{{.*}}], %{{.*}} {in_bounds = [true, true]} : vector<[4]x4xf32>, memref<?x4xf32>
+
+// -----
+
+func.func @cannot_lower_transfer_read_with_leading_scalable(%arg0: memref<?x4xf32>) -> vector<[4]x4xf32> {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c4 = arith.constant 4 : index
+  %cst = arith.constant 0.000000e+00 : f32
+  %dim = memref.dim %arg0, %c0 : memref<?x4xf32>
+  %mask = vector.create_mask %dim, %c4 : vector<[4]x4xi1>
+  %read = vector.transfer_read %arg0[%c0, %c0], %cst, %mask {in_bounds = [true, true]} : memref<?x4xf32>, vector<[4]x4xf32>
+  return %read : vector<[4]x4xf32>
+}
+// CHECK-LABEL:   func.func @cannot_lower_transfer_read_with_leading_scalable(
+// CHECK-SAME:                                                                %[[MEMREF:.*]]: memref<?x4xf32>)
+// CHECK: %{{.*}} = vector.transfer_read %[[MEMREF]][%{{.*}}, %{{.*}}], %{{.*}}, %{{.*}} {in_bounds = [true, true]} : memref<?x4xf32>, vector<[4]x4xf32>
+
+


        


More information about the Mlir-commits mailing list