[Mlir-commits] [mlir] cafb628 - [mlir][VectorToGPU] Update memref stride preconditions on `nvgpu.mma.sync` path

Christopher Bate llvmlistbot at llvm.org
Thu Sep 14 12:51:48 PDT 2023


Author: Christopher Bate
Date: 2023-09-14T13:51:42-06:00
New Revision: cafb6284d18bbdb952ae6d5e4aa97912d57dbfb8

URL: https://github.com/llvm/llvm-project/commit/cafb6284d18bbdb952ae6d5e4aa97912d57dbfb8
DIFF: https://github.com/llvm/llvm-project/commit/cafb6284d18bbdb952ae6d5e4aa97912d57dbfb8.diff

LOG: [mlir][VectorToGPU] Update memref stride preconditions on `nvgpu.mma.sync` path

This change removes the requirement that the row stride be statically known when
converting `vector.transfer_read` and `vector.transfer_write` to distributed
SIMT operations in the `nvgpu` lowering path. It also adds a check to verify
that the last dimension of the source memref is statically known to have stride
1, since the conversion logic assumes this. No other change should be required,
because the generated `vector.load` operations never span any dimension other
than the last. The routines for checking preconditions on
`vector.transfer_read/write` are moved into the nvgpu utilities.

The change is NFC with respect to the GPU dialect lowering path.

Reviewed By: ThomasRaoux

Differential Revision: https://reviews.llvm.org/D155753

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/NVGPU/Utils/MMAUtils.h
    mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
    mlir/lib/Dialect/NVGPU/Utils/MMAUtils.cpp
    mlir/test/Conversion/VectorToGPU/vector-to-mma-ops-mma-sync.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/NVGPU/Utils/MMAUtils.h b/mlir/include/mlir/Dialect/NVGPU/Utils/MMAUtils.h
index 003a160985ee1b2..10851c617398244 100644
--- a/mlir/include/mlir/Dialect/NVGPU/Utils/MMAUtils.h
+++ b/mlir/include/mlir/Dialect/NVGPU/Utils/MMAUtils.h
@@ -93,6 +93,18 @@ FailureOr<AffineMap>
 getLaneIdToLdMatrixMatrixCoord(OpBuilder &builder, Location loc,
                                const LdMatrixParams &params);
 
+/// Returns whether the `vector.transfer_read` instruction can be interpreted
+/// as a warp-level cooperative matrix load operation. This function is meant to
+/// be used to establish whether `op` is part of a chain of such warp-level
+/// operations.
+bool canLowerToWarpMatrixOperation(vector::TransferReadOp op);
+
+/// Returns whether the `vector.transfer_write` instruction can be interpreted
+/// as a warp-level cooperative matrix store operation. This function is meant
+/// to be used to establish whether `op` is part of a chain of such warp-level
+/// operations.
+bool canLowerToWarpMatrixOperation(vector::TransferWriteOp op);
+
 } // namespace nvgpu
 } // namespace mlir
 

diff  --git a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
index 8a46357acd7bf1f..3089e917d0eed9c 100644
--- a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
+++ b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
@@ -119,10 +119,9 @@ static bool isTransposeMatrixLoadMap(AffineMap permutationMap) {
          permutationMap == AffineMap::get(nDim, 0, {innerDim, zero}, ctx);
 }
 
-// Return the stide for the dimension 0 of |type| if it is a memref and has a
-// constant stride.
-static std::optional<int64_t>
-getMemrefConstantHorizontalStride(ShapedType type) {
+// Return the stride for the second-to-last dimension of |type| if it is a
+// memref and has a constant stride.
+static std::optional<int64_t> getStaticallyKnownRowStride(ShapedType type) {
   auto memrefType = dyn_cast<MemRefType>(type);
   if (!memrefType)
     return false;
@@ -141,35 +140,27 @@ getMemrefConstantHorizontalStride(ShapedType type) {
 }
 
 // Return true if the transfer op can be converted to a MMA matrix load.
-static bool transferReadSupportsMMAMatrixType(vector::TransferReadOp readOp,
-                                              bool useNvGpu) {
+static bool transferReadSupportsMMAMatrixType(vector::TransferReadOp readOp) {
   if (readOp.getMask() || readOp.hasOutOfBoundsDim() ||
       readOp.getVectorType().getRank() != 2)
     return false;
-  if (!getMemrefConstantHorizontalStride(readOp.getShapedType()))
+  if (!getStaticallyKnownRowStride(readOp.getShapedType()))
     return false;
 
   // Only allow integer types if the signedness can be inferred.
-  if (!useNvGpu && readOp.getVectorType().getElementType().isInteger(8))
+  if (readOp.getVectorType().getElementType().isInteger(8))
     if (!readOp->hasOneUse() || (!isa<arith::ExtSIOp>(*readOp->user_begin()) &&
                                  !isa<arith::ExtUIOp>(*readOp->user_begin())))
       return false;
 
   AffineMap map = readOp.getPermutationMap();
-
   MLIRContext *ctx = readOp.getContext();
   AffineExpr innerDim = getAffineDimExpr(map.getNumDims() - 1, ctx);
   AffineExpr zero = getAffineConstantExpr(0, ctx);
   auto broadcastInnerDim =
       AffineMap::get(map.getNumDims(), 0, {zero, innerDim}, ctx);
-
-  if (!useNvGpu) {
-    bool result = map.isMinorIdentity() || map == broadcastInnerDim ||
-                  isTransposeMatrixLoadMap(map);
-    return result;
-  }
-
-  return true;
+  return map.isMinorIdentity() || map == broadcastInnerDim ||
+         isTransposeMatrixLoadMap(map);
 }
 
 // Return true if the transfer op can be converted to a MMA matrix store.
@@ -182,7 +173,7 @@ transferWriteSupportsMMAMatrixType(vector::TransferWriteOp writeOp) {
   if (writeOp.getMask() || writeOp.hasOutOfBoundsDim() ||
       writeOp.getVectorType().getRank() != 2)
     return false;
-  if (!getMemrefConstantHorizontalStride(writeOp.getShapedType()))
+  if (!getStaticallyKnownRowStride(writeOp.getShapedType()))
     return false;
   // TODO: Support transpose once it is added to GPU dialect ops.
   if (!writeOp.getPermutationMap().isMinorIdentity())
@@ -285,9 +276,11 @@ static bool supportsMMaMatrixType(Operation *op, bool useNvGpu) {
   if (isa<scf::ForOp, scf::YieldOp>(op))
     return true;
   if (auto transferRead = dyn_cast<vector::TransferReadOp>(op))
-    return transferReadSupportsMMAMatrixType(transferRead, useNvGpu);
+    return useNvGpu ? nvgpu::canLowerToWarpMatrixOperation(transferRead)
+                    : transferReadSupportsMMAMatrixType(transferRead);
   if (auto transferWrite = dyn_cast<vector::TransferWriteOp>(op))
-    return transferWriteSupportsMMAMatrixType(transferWrite);
+    return useNvGpu ? nvgpu::canLowerToWarpMatrixOperation(transferWrite)
+                    : transferWriteSupportsMMAMatrixType(transferWrite);
   if (auto extractStridedSlice = dyn_cast<vector::ExtractStridedSliceOp>(op))
     return useNvGpu &&
            extractStridedSliceSupportsMMAMatrixType(extractStridedSlice);
@@ -372,9 +365,14 @@ static SetVector<Operation *> getOpToConvert(mlir::Operation *op,
     // chain. MMA matrix are stored in an opaque type so they cannot be used
     // by all operations.
     if (llvm::any_of(dependentOps, [useNvGpu](Operation *op) {
-          return !supportsMMaMatrixType(op, useNvGpu);
+          if (!supportsMMaMatrixType(op, useNvGpu)) {
+            LLVM_DEBUG(DBGS() << "cannot convert op: " << *op << "\n");
+            return true;
+          }
+          return false;
         }))
       return;
+
     opToConvert.insert(dependentOps.begin(), dependentOps.end());
   });
   // Sort the operations so that we can convert them in topological order.
@@ -537,10 +535,11 @@ convertTransferReadOp(RewriterBase &rewriter, vector::TransferReadOp op,
   rewriter.setInsertionPoint(op);
 
   assert(op.getTransferRank() > 0 && "unexpected 0-d transfer");
-  assert(transferReadSupportsMMAMatrixType(op, /*useNvGpu=*/false));
+  assert(transferReadSupportsMMAMatrixType(op) &&
+         "expected convertible operation");
 
   std::optional<int64_t> stride =
-      getMemrefConstantHorizontalStride(op.getShapedType());
+      getStaticallyKnownRowStride(op.getShapedType());
   if (!stride.has_value()) {
     LLVM_DEBUG(DBGS() << "no stride\n");
     return rewriter.notifyMatchFailure(op, "no stride");
@@ -591,7 +590,7 @@ convertTransferWriteOp(RewriterBase &rewriter, vector::TransferWriteOp op,
 
   assert(transferWriteSupportsMMAMatrixType(op));
   std::optional<int64_t> stride =
-      getMemrefConstantHorizontalStride(op.getShapedType());
+      getStaticallyKnownRowStride(op.getShapedType());
   if (!stride.has_value()) {
     LLVM_DEBUG(DBGS() << "no stride\n");
     return rewriter.notifyMatchFailure(op, "no stride");
@@ -1303,7 +1302,8 @@ LogicalResult mlir::convertVectorToNVVMCompatibleMMASync(RewriterBase &rewriter,
               return op->emitError() << "unhandled vector to mma type: " << *op;
             })
             .failed()) {
-      return op->emitError() << "Failed to convert op " << *op;
+      return op->emitOpError()
+             << "failed to convert op during vector-to-nvgpu conversion";
     }
   }
   return success();
@@ -1326,10 +1326,11 @@ struct ConvertVectorToGPUPass
       return signalPassFailure();
 
     IRRewriter rewriter(&getContext());
-    if (useNvGpu.getValue()) {
+    if (useNvGpu) {
       if (failed(
               convertVectorToNVVMCompatibleMMASync(rewriter, getOperation())))
         return signalPassFailure();
+      return;
     }
     (void)convertVectorToMMAOps(rewriter, getOperation());
   }

diff  --git a/mlir/lib/Dialect/NVGPU/Utils/MMAUtils.cpp b/mlir/lib/Dialect/NVGPU/Utils/MMAUtils.cpp
index 5a0018c3151767b..c500815857ca5bc 100644
--- a/mlir/lib/Dialect/NVGPU/Utils/MMAUtils.cpp
+++ b/mlir/lib/Dialect/NVGPU/Utils/MMAUtils.cpp
@@ -272,3 +272,54 @@ nvgpu::getLaneIdToLdMatrixMatrixCoord(OpBuilder &builder, Location loc,
 
   return failure();
 }
+
+bool nvgpu::canLowerToWarpMatrixOperation(vector::TransferReadOp op) {
+  if (op.getMask() || op.hasOutOfBoundsDim())
+    return false;
+  VectorType type = op.getType();
+  // The result type should be 2D. Support could be expanded to tolerate extra
+  // unit dimensions that failed to fold, but that would significantly increase
+  // downstream code complexity in the conversion step. For now, we rely on
+  // other patterns to ensure canonical 2D form when targeting `nvgpu.mma.sync`.
+  if (!type.hasStaticShape() || type.getRank() != 2)
+    return false;
+  // Tensor sources are rejected: stride information is needed to ensure the
+  // correctness of downstream assumptions. This could be enabled if the caller
+  // can assert that the tensor will be lowered in a particular manner.
+  auto sourceType = dyn_cast<MemRefType>(op.getSource().getType());
+  if (!sourceType)
+    return false;
+  // The last dimension must be contiguous (could be relaxed by scalarizing
+  // loads). Use the fallible stride query, since the pair-returning overload
+  // asserts on non-strided layouts; rank-0 sources have no strides at all.
+  SmallVector<int64_t> strides;
+  int64_t offset;
+  if (failed(mlir::getStridesAndOffset(sourceType, strides, offset)) ||
+      strides.empty())
+    return false;
+  return strides.back() == 1;
+}
+
+bool nvgpu::canLowerToWarpMatrixOperation(vector::TransferWriteOp op) {
+  if (op.getMask() || op.hasOutOfBoundsDim() || op.getTransferRank() == 0)
+    return false;
+  VectorType type = op.getVectorType();
+  if (!type.hasStaticShape() || type.getRank() != 2)
+    return false;
+  // TODO: We rely on lowering to `vector.store`; transposed writes could be
+  // supported by lowering to scalarized `memref.store` operations.
+  if (!op.getPermutationMap().isMinorIdentity())
+    return false;
+  // Tensor targets are rejected: stride information is needed for correctness.
+  auto sourceType = dyn_cast<MemRefType>(op.getSource().getType());
+  if (!sourceType)
+    return false;
+  // The last dimension of the target memref must be contiguous (could be
+  // relaxed by scalarizing stores). Use the fallible stride query, since the
+  // pair-returning overload asserts on non-strided layouts.
+  SmallVector<int64_t> strides;
+  int64_t offset;
+  if (failed(mlir::getStridesAndOffset(sourceType, strides, offset)))
+    return false;
+  return strides.back() == 1;
+}

diff  --git a/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops-mma-sync.mlir b/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops-mma-sync.mlir
index 4465819fc7fe406..81cb3e1d05fa3b2 100644
--- a/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops-mma-sync.mlir
+++ b/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops-mma-sync.mlir
@@ -47,7 +47,7 @@ func.func @m16n8k32_int8_row_row_row(%arg0: memref<128x128xi8, #gpu.address_spac
   // CHECK: nvgpu.ldmatrix %arg0[[[m_coord]], [[k_coord]]] {numTiles = 4 : i32, transpose = false} : memref<128x128xi8, #gpu.address_space<workgroup>> -> vector<4x4xi8>
 
   // Verify that the operandB load is lowered to scalar load to be able
-  // to transpose at 8-bit granularity. ldmatrix can only transpose at 
+  // to transpose at 8-bit granularity. ldmatrix can only transpose at
   // 16-bit granularity.
 
   // CHECK-DAG: [[row:%.+]] = affine.apply [[$rowB0_map]]()[{{%.+}}]
@@ -282,7 +282,7 @@ func.func @multi_dim_m16n8k16_fp16_row_row_row(%arg0: memref<4x32x1x32xf16, #gpu
   // CHECK-DAG: [[k_coord:%.+]] = affine.apply [[$strided_map]]
   // CHECK-DAG: [[fragmentB:%.+]] = nvgpu.ldmatrix %arg1[[[c0]], [[c0]], [[k_coord]], [[n_coord]]] {numTiles = 4 : i32, transpose = true}
   %B = vector.transfer_read %arg1[%c0, %c0, %c0, %c0], %cst {in_bounds = [true, true], permutation_map = #map_b} : memref<4x1x32x32xf16, #gpu.address_space<workgroup>>, vector<16x16xf16>
-  
+
   // CHECK-DAG: [[m_coord:%.+]] = affine.apply [[$strided_map]]
   // CHECK-DAG: [[n_coord:%.+]] = affine.apply [[$contiguous_map]]
   // CHECK-DAG: [[fragmentC:%.*]] = nvgpu.ldmatrix %arg2[[[c0]], [[m_coord]], [[n_coord]]] {numTiles = 4 : i32, transpose = false}
@@ -713,3 +713,125 @@ func.func @m16n8k32_int8_row_col_row(%arg0: memref<128x128xi8, #gpu.address_spac
   vector.transfer_write %D, %arg2[%c0, %c0] {in_bounds = [true, true]} : vector<16x8xi32>, memref<128x128xi32>
   return
 }
+
+// -----
+
+
+#map0 = affine_map<(d0, d1) -> (d1, d0)>
+#map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d1, d2)>
+#map3 = affine_map<(d0, d1, d2) -> (d0, d1)>
+!smem_type = memref<20x20xf16, strided<[?, 1], offset: ?>, #gpu.address_space<workgroup>>
+
+// This test case is identical to the m16n8k16 test case, but it tests that a
+// row dimension with an unknown stride is handled correctly.
+
+// CHECK-DAG: [[$strided_map:#.+]] = affine_map<()[s0] -> (s0 mod 16)>
+// CHECK-DAG: [[$contiguous_map:#.+]] = affine_map<()[s0] -> ((s0 floordiv 16) * 8)>
+// CHECK-LABEL: func @strided_memref_read_write
+func.func @strided_memref_read_write(%arg0: !smem_type,
+                                     %arg1: !smem_type,
+                                     %arg2: !smem_type) {
+  %cst_0 = arith.constant dense<0.000000e+00> : vector<16x8xf16>
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 0.000000e+00 : f16
+
+  // CHECK-DAG: [[m_coord:%.+]] = affine.apply [[$strided_map]]
+  // CHECK-DAG: [[k_coord:%.+]] = affine.apply [[$contiguous_map]]
+  // CHECK: nvgpu.ldmatrix %arg0[[[m_coord]], [[k_coord]]] {numTiles = 4 : i32, transpose = false}
+  // CHECK-DAG: [[n_coord:%.+]] = affine.apply [[$contiguous_map]]
+  // CHECK-DAG: [[k_coord:%.+]] = affine.apply [[$strided_map]]
+  // CHECK: nvgpu.ldmatrix %arg1[[[k_coord]], [[n_coord]]] {numTiles = 2 : i32, transpose = true}
+  %A = vector.transfer_read %arg0[%c0, %c0], %cst {in_bounds = [true, true]} : !smem_type, vector<16x16xf16>
+  %B = vector.transfer_read %arg1[%c0, %c0], %cst {permutation_map = #map0, in_bounds = [true, true]} : !smem_type, vector<8x16xf16>
+  %C = vector.transfer_read %arg2[%c0, %c0], %cst {in_bounds = [true, true]} : !smem_type, vector<16x8xf16>
+  %D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
+    %A, %B, %C : vector<16x16xf16>, vector<8x16xf16> into vector<16x8xf16>
+  vector.transfer_write %D, %arg2[%c0, %c0] {in_bounds = [true, true]} : vector<16x8xf16>, !smem_type
+  return
+}
+
+// -----
+
+
+#map0 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+#map2 = affine_map<(d0, d1, d2, d3) -> (d2, d0, d3)>
+#map3 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+!smem_type = memref<20x20x20xf16, strided<[?, ?, 1], offset: ?>, #gpu.address_space<workgroup>>
+
+// CHECK-LABEL: func @unsupported_non_2d_load_store
+func.func @unsupported_non_2d_load_store(%arg0: !smem_type,
+                           %arg1: !smem_type,
+                           %arg2: !smem_type) {
+  %cst_0 = arith.constant dense<0.000000e+00> : vector<16x8xf16>
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 0.000000e+00 : f16
+
+  // CHECK-NOT: nvgpu.ldmatrix
+  // CHECK-NOT: nvgpu.mma
+  %A = vector.transfer_read %arg0[%c0, %c0, %c0], %cst {in_bounds = [true, true, true]} : !smem_type, vector<1x16x16xf16>
+  %B = vector.transfer_read %arg1[%c0, %c0, %c0], %cst {permutation_map = #map0, in_bounds = [true, true, true]} : !smem_type, vector<8x1x16xf16>
+  %C = vector.transfer_read %arg2[%c0, %c0, %c0], %cst {in_bounds = [true, true, true]} : !smem_type, vector<1x16x8xf16>
+  %D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "parallel", "reduction"], kind = #vector.kind<add>}
+    %A, %B, %C : vector<1x16x16xf16>, vector<8x1x16xf16> into vector<1x16x8xf16>
+  vector.transfer_write %D, %arg2[%c0, %c0, %c0] {in_bounds = [true, true, true]} : vector<1x16x8xf16>, !smem_type
+  return
+}
+
+// -----
+
+#map0 = affine_map<(d0, d1) -> (d1, d0)>
+#map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d1, d2)>
+#map3 = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+!smem_type = memref<20x20xf16, strided<[?, ?], offset: ?>, #gpu.address_space<workgroup>>
+
+// CHECK-LABEL: func @unsupported_fully_dynamic_strides
+func.func @unsupported_fully_dynamic_strides(%arg0: !smem_type,
+                                         %arg1: !smem_type,
+                                         %arg2: !smem_type) {
+  %cst_0 = arith.constant dense<0.000000e+00> : vector<16x8xf16>
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 0.000000e+00 : f16
+
+  // CHECK-NOT: nvgpu.ldmatrix
+  // CHECK-NOT: nvgpu.mma
+  %A = vector.transfer_read %arg0[%c0, %c0], %cst {in_bounds = [true, true]} : !smem_type, vector<16x16xf16>
+  %B = vector.transfer_read %arg1[%c0, %c0], %cst {permutation_map = #map0, in_bounds = [true, true]} : !smem_type, vector<8x16xf16>
+  %C = vector.transfer_read %arg2[%c0, %c0], %cst {in_bounds = [true, true]} : !smem_type, vector<16x8xf16>
+  %D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
+    %A, %B, %C : vector<16x16xf16>, vector<8x16xf16> into vector<16x8xf16>
+  vector.transfer_write %D, %arg2[%c0, %c0] {in_bounds = [true, true]} : vector<16x8xf16>, !smem_type
+  return
+}
+
+// -----
+
+#map0 = affine_map<(d0, d1) -> (d1, d0)>
+#map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d1, d2)>
+#map3 = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+
+!smem_type = memref<20x20xf16, strided<[?, 1], offset: ?>, #gpu.address_space<workgroup>>
+
+// CHECK-LABEL: func @unsupported_transposed_store
+func.func @unsupported_transposed_store(%arg0: !smem_type,
+                            %arg1: !smem_type,
+                            %arg2: !smem_type) {
+  %cst_0 = arith.constant dense<0.000000e+00> : vector<16x8xf16>
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 0.000000e+00 : f16
+
+  // CHECK-NOT: nvgpu.ldmatrix
+  // CHECK-NOT: nvgpu.mma
+  %A = vector.transfer_read %arg0[%c0, %c0], %cst {in_bounds = [true, true]} : !smem_type, vector<16x16xf16>
+  %B = vector.transfer_read %arg1[%c0, %c0], %cst {permutation_map = #map0, in_bounds = [true, true]} : !smem_type, vector<8x16xf16>
+  %C = vector.transfer_read %arg2[%c0, %c0], %cst {in_bounds = [true, true]} : !smem_type, vector<16x8xf16>
+  %D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
+    %A, %B, %C : vector<16x16xf16>, vector<8x16xf16> into vector<16x8xf16>
+  vector.transfer_write %D, %arg2[%c0, %c0] {in_bounds = [true, true], permutation_map = affine_map<(d0, d1)->(d1, d0)>} : vector<16x8xf16>, !smem_type
+  return
+}


        


More information about the Mlir-commits mailing list