[Mlir-commits] [mlir] 39d8876 - [mlir][NVGPU] Support 2D masks in transform.nvgpu.create_async_groups

Matthias Springer llvmlistbot at llvm.org
Mon Aug 7 06:47:36 PDT 2023


Author: Matthias Springer
Date: 2023-08-07T15:46:54+02:00
New Revision: 39d8876da363f8d8bde0d3ed65a4b127588fbc6e

URL: https://github.com/llvm/llvm-project/commit/39d8876da363f8d8bde0d3ed65a4b127588fbc6e
DIFF: https://github.com/llvm/llvm-project/commit/39d8876da363f8d8bde0d3ed65a4b127588fbc6e.diff

LOG: [mlir][NVGPU] Support 2D masks in transform.nvgpu.create_async_groups

Support IR that is generated by the vector-to-scf lowering of 2D vector transfers with a mask. Only 2D transfers that were fully unrolled are supported at the moment.
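
For example, the vector-to-scf lowering of a fully unrolled 2D masked transfer produces one 1-D masked transfer_read per row, where each row mask is a vector.extract of a single 2-D vector.create_mask. A rough sketch of the IR shape that is now recognized (sizes and value names are illustrative, mirroring the test added below):

    %mask = vector.create_mask %sz0, %sz1 : vector<3x4xi1>
    %m0 = vector.extract %mask[0] : vector<3x4xi1>
    %v0 = vector.transfer_read %a[%c0, %c0], %cst, %m0 {in_bounds = [true]}
        : memref<1024x1024xf32>, vector<4xf32>

Each such read/write pair is rewritten to an nvgpu.device_async_copy whose number of read elements is computed per row as select(row < %sz0, %sz1, 0).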

Differential Revision: https://reviews.llvm.org/D156695

Added: 
    

Modified: 
    mlir/lib/Dialect/NVGPU/Transforms/CreateAsyncGroups.cpp
    mlir/test/Dialect/NVGPU/transform-create-async-groups.mlir

Removed: 
    


################################################################################
diff --git a/mlir/lib/Dialect/NVGPU/Transforms/CreateAsyncGroups.cpp b/mlir/lib/Dialect/NVGPU/Transforms/CreateAsyncGroups.cpp
index fb64e4e32b400d..3df99b1d374b10 100644
--- a/mlir/lib/Dialect/NVGPU/Transforms/CreateAsyncGroups.cpp
+++ b/mlir/lib/Dialect/NVGPU/Transforms/CreateAsyncGroups.cpp
@@ -46,23 +46,79 @@ static bool isContiguousRead(Operation *read) {
   return isa<vector::LoadOp>(read);
 }
 
+namespace {
+/// A vector.create_mask op and extract position.
+struct TransferMask {
+  vector::CreateMaskOp createMaskOp;
+  SmallVector<int64_t> extractPosition;
+};
+} // namespace
+
 /// If the given vector load op has a mask that is defined by
 /// vector.create_mask, return that op.
-static vector::CreateMaskOp getMaskOp(Operation *loadOp) {
+static FailureOr<TransferMask> getMaskOp(Operation *loadOp) {
   auto transferRead = dyn_cast<vector::TransferReadOp>(loadOp);
   if (!transferRead || !transferRead.getMask())
-    return {};
-  auto maskOp = transferRead.getMask().getDefiningOp<vector::CreateMaskOp>();
-  // TODO: Support 2D masks and higher. Ops with a >1D mask are ignored at the
-  // moment.
-  if (maskOp.getVectorType().getRank() != 1)
-    return {};
-  return maskOp;
+    return TransferMask{{}, {}};
+  assert(transferRead.getMask().getType().getRank() == 1 &&
+         "expected 1-D mask");
+
+  // Case 1: Mask is the result of a vector.create_mask.
+  if (auto maskOp =
+          transferRead.getMask().getDefiningOp<vector::CreateMaskOp>())
+    return TransferMask{maskOp, {}};
+
+  // Case 2: Mask is the result of a vector.extract(vector.create_mask). Only
+  // 2D -> 1D extracts are supported at the moment.
+  if (auto extractOp =
+          transferRead.getMask().getDefiningOp<vector::ExtractOp>())
+    if (auto maskOp =
+            extractOp.getVector().getDefiningOp<vector::CreateMaskOp>())
+      if (extractOp.getPosition().size() == 1 &&
+          extractOp.getSourceVectorType().getRank() == 2)
+        return TransferMask{maskOp,
+                            SmallVector<int64_t>(extractOp.getPosition())};
+
+  // All other cases: not supported.
+  return {};
+}
+
+/// Build an SSA value that represents the number of read elements.
+static Value buildNumReadElements(OpBuilder &b, Location loc,
+                                  Operation *readOp) {
+  FailureOr<TransferMask> transferMask = getMaskOp(readOp);
+  assert(succeeded(transferMask) && "invalid transfer mask");
+
+  // No mask => no num_read_elements.
+  if (!transferMask->createMaskOp)
+    return Value();
+
+  // No extract: return size of "ones" segment in the mask.
+  if (transferMask->extractPosition.empty()) {
+    assert(transferMask->createMaskOp.getNumOperands() == 1 &&
+           "expected single operand");
+    return transferMask->createMaskOp.getOperand(0);
+  }
+
+  // vector.extract(vector.create_mask).
+  // If extract_pos < num_ones, take number of elements from the least
+  // significant dimension.
+  assert(transferMask->createMaskOp.getVectorType().getRank() == 2 &&
+         "expected 2D mask");
+  assert(transferMask->extractPosition.size() == 1 &&
+         "expected 2D->1D extract");
+  Value cmp = b.create<arith::CmpIOp>(
+      loc, arith::CmpIPredicate::slt,
+      b.create<arith::ConstantIndexOp>(loc,
+                                       transferMask->extractPosition.front()),
+      transferMask->createMaskOp->getOperands().front());
+  return b.create<arith::SelectOp>(
+      loc, cmp, transferMask->createMaskOp->getOperands().back(),
+      b.create<arith::ConstantIndexOp>(loc, 0));
 }
 
 /// Return "true" if the conversion to async copy is supported by "async copy".
 static bool resultsInSupportedAsyncCopy(MemRefType memrefType,
-                                        Operation::operand_range indices,
                                         VectorType vecType) {
   assert(vecType.getRank() == 1 && "expected 1-D vector");
   constexpr int64_t kSupportedCpAsyncAlignmentsInBytes[3] = {4, 8, 16};
@@ -121,7 +177,7 @@ void nvgpu::createAsyncGroups(RewriterBase &rewriter, Operation *op,
         if (getConstantIntValue(transferRead.getPadding()) ==
             static_cast<int64_t>(0))
           return;
-        if (!getMaskOp(readOp))
+        if (failed(getMaskOp(readOp)))
           return;
       }
     }
@@ -131,9 +187,9 @@ void nvgpu::createAsyncGroups(RewriterBase &rewriter, Operation *op,
     VectorType vecType = cast<VectorType>(vectorVal.getType());
 
     if (!resultsInSupportedAsyncCopy(cast<MemRefType>(loadBase.getType()),
-                                     nvgpu::getIndices(readOp), vecType) ||
+                                     vecType) ||
         !resultsInSupportedAsyncCopy(cast<MemRefType>(storeBase.getType()),
-                                     nvgpu::getIndices(writeOp), vecType))
+                                     vecType))
       return;
 
     copyToSharedMem.insert(writeOp);
@@ -184,11 +240,8 @@ void nvgpu::createAsyncGroups(RewriterBase &rewriter, Operation *op,
       Operation *readOp = vectorVal.getDefiningOp();
       Value storeBase = nvgpu::getMemrefOperand(writeOp);
       Value loadBase = nvgpu::getMemrefOperand(readOp);
-      Value numReadElements;
-      if (vector::CreateMaskOp maskOp = getMaskOp(readOp)) {
-        assert(maskOp.getNumOperands() == 1 && "expected single operand");
-        numReadElements = maskOp.getOperand(0);
-      }
+      Value numReadElements =
+          buildNumReadElements(rewriter, writeOp->getLoc(), readOp);
       auto dstMemref = cast<MemRefType>(storeBase.getType());
       int64_t sizeInBytes =
           (dstMemref.getElementTypeBitWidth() * numElements) / 8;

diff --git a/mlir/test/Dialect/NVGPU/transform-create-async-groups.mlir b/mlir/test/Dialect/NVGPU/transform-create-async-groups.mlir
index e378d1c578c36f..9b8864219d693b 100644
--- a/mlir/test/Dialect/NVGPU/transform-create-async-groups.mlir
+++ b/mlir/test/Dialect/NVGPU/transform-create-async-groups.mlir
@@ -151,3 +151,51 @@ builtin.module {
     transform.nvgpu.create_async_groups %top_level_func {bypass_l1} : (!transform.any_op) -> (!transform.any_op)
   }
 }
+
+// -----
+
+// 2D vector.transfer_read with a mask.
+builtin.module {
+  // CHECK-LABEL: @read_2d_with_mask(
+  //  CHECK-SAME:     %[[sz0:.*]]: index, %[[sz1:.*]]: index, %[[a:.*]]: memref<1024x1024xf32>
+  func.func @read_2d_with_mask(%sz0: index, %sz1: index, %a: memref<1024x1024xf32>) {
+    // CHECK: %[[c0:.*]] = arith.constant 0 : index
+    // CHECK: %[[c1:.*]] = arith.constant 1 : index
+    // CHECK: %[[c2:.*]] = arith.constant 2 : index
+    %0 = memref.alloc() : memref<4x32x16xf32, #gpu.address_space<workgroup>>
+    %c0 = arith.constant 0 : index
+    %cst_0 = arith.constant 0.000000e+00 : f32
+    // CHECK: %[[mask:.*]] = vector.create_mask
+    // CHECK: %[[e0:.*]] = vector.extract %[[mask]][0] : vector<3x4xi1>
+    // CHECK: %[[e1:.*]] = vector.extract %[[mask]][1] : vector<3x4xi1>
+    // CHECK: %[[e2:.*]] = vector.extract %[[mask]][2] : vector<3x4xi1>
+
+    // CHECK: %[[cmpi0:.*]] = arith.cmpi slt, %[[c0]], %[[sz0]]
+    // CHECK: %[[s0:.*]] = arith.select %[[cmpi0]], %[[sz1]], %[[c0]]
+    // CHECK: nvgpu.device_async_copy %[[a]][%[[c0]], %[[c0]]], {{.*}}, 4, %[[s0]] {bypassL1}
+
+    // CHECK: %[[cmpi1:.*]] = arith.cmpi slt, %[[c1]], %[[sz0]]
+    // CHECK: %[[s1:.*]] = arith.select %[[cmpi1]], %[[sz1]], %[[c0]]
+    // CHECK: nvgpu.device_async_copy %[[a]][%[[c1]], %[[c0]]], {{.*}}, 4, %[[s1]] {bypassL1}
+
+    // CHECK: %[[cmpi2:.*]] = arith.cmpi slt, %[[c2]], %[[sz0]]
+    // CHECK: %[[s2:.*]] = arith.select %[[cmpi2]], %[[sz1]], %[[c0]]
+    // CHECK: nvgpu.device_async_copy %[[a]][%[[c2]], %[[c0]]], {{.*}}, 4, %[[s2]] {bypassL1}
+    %mask = vector.create_mask %sz0, %sz1 : vector<3x4xi1>
+    %1 = vector.transfer_read %a[%c0, %c0], %cst_0, %mask {in_bounds = [true, true]} : memref<1024x1024xf32>, vector<3x4xf32>
+    vector.transfer_write %1, %0[%c0, %c0, %c0] {in_bounds = [true, true]} : vector<3x4xf32>, memref<4x32x16xf32, #gpu.address_space<workgroup>>
+
+    return
+  }
+
+  transform.sequence failures(propagate) {
+  ^bb1(%variant_op: !transform.any_op):
+    %top_level_func = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %top_level_func {
+      transform.apply_patterns.vector.transfer_to_scf max_transfer_rank = 1 full_unroll = true
+    } : !transform.any_op
+    transform.nvgpu.create_async_groups %top_level_func {bypass_l1} : (!transform.any_op) -> (!transform.any_op)
+    %top_level_func_2 = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op
+    transform.apply_cse to %top_level_func_2 : !transform.any_op
+  }
+}


        

