[Mlir-commits] [mlir] 15ea230 - [mlir][NVGPU] Support N-D masks in transform.nvgpu.create_async_groups

Matthias Springer llvmlistbot at llvm.org
Tue Aug 8 05:36:02 PDT 2023


Author: Matthias Springer
Date: 2023-08-08T14:30:03+02:00
New Revision: 15ea2306a41ad52a73e03f51497d115733b8b9d3

URL: https://github.com/llvm/llvm-project/commit/15ea2306a41ad52a73e03f51497d115733b8b9d3
DIFF: https://github.com/llvm/llvm-project/commit/15ea2306a41ad52a73e03f51497d115733b8b9d3.diff

LOG: [mlir][NVGPU] Support N-D masks in transform.nvgpu.create_async_groups

Support IR that is generated by the vector-to-scf lowering of N-D vector transfers with a mask. (Until now, only 1-D and 2-D transfers were supported.) Only transfers that have been fully unrolled are supported.
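
For illustration, the fully unrolled IR that this pattern now matches looks
roughly as follows: a single N-D vector.create_mask whose per-slice masks are
taken out with multi-index vector.extract ops (a minimal sketch; value names
and shapes are illustrative, not taken from the commit):

    %mask = vector.create_mask %sz0, %sz1, %sz2 : vector<2x3x4xi1>
    %m = vector.extract %mask[0, 1] : vector<2x3x4xi1>
    %v = vector.transfer_read %src[%c0, %c1, %c0], %pad, %m
        {in_bounds = [true]} : memref<1024x1024x1024xf32>, vector<4xf32>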

Differential Revision: https://reviews.llvm.org/D157286

Added: 
    

Modified: 
    mlir/lib/Dialect/NVGPU/Transforms/CreateAsyncGroups.cpp
    mlir/test/Dialect/NVGPU/transform-create-async-groups.mlir

Removed: 
    


################################################################################
diff --git a/mlir/lib/Dialect/NVGPU/Transforms/CreateAsyncGroups.cpp b/mlir/lib/Dialect/NVGPU/Transforms/CreateAsyncGroups.cpp
index 3df99b1d374b10..ad2180d501148f 100644
--- a/mlir/lib/Dialect/NVGPU/Transforms/CreateAsyncGroups.cpp
+++ b/mlir/lib/Dialect/NVGPU/Transforms/CreateAsyncGroups.cpp
@@ -68,19 +68,16 @@ static FailureOr<TransferMask> getMaskOp(Operation *loadOp) {
           transferRead.getMask().getDefiningOp<vector::CreateMaskOp>())
     return TransferMask{maskOp, {}};
 
-  // Case 2: Mask is the result of a vector.extract(vector.create_mask). Only
-  // 2D -> 1D extracts are supported at the moment.
+  // Case 2: Mask is the result of a vector.extract(vector.create_mask).
   if (auto extractOp =
           transferRead.getMask().getDefiningOp<vector::ExtractOp>())
     if (auto maskOp =
             extractOp.getVector().getDefiningOp<vector::CreateMaskOp>())
-      if (extractOp.getPosition().size() == 1 &&
-          extractOp.getSourceVectorType().getRank() == 2)
-        return TransferMask{maskOp,
-                            SmallVector<int64_t>(extractOp.getPosition())};
+      return TransferMask{maskOp,
+                          SmallVector<int64_t>(extractOp.getPosition())};
 
   // All other cases: not supported.
-  return {};
+  return failure();
 }
 
 /// Build an SSA value that represents the number of read elements.
@@ -102,18 +99,27 @@ static Value buildNumReadElements(OpBuilder &b, Location loc,
 
   // vector.extract(vector.create_mask).
   // If extract_pos < num_ones, take number of elements from the least
-  // significant dimension.
-  assert(transferMask->createMaskOp.getVectorType().getRank() == 2 &&
-         "expected 2D mask");
-  assert(transferMask->extractPosition.size() == 1 &&
-         "expected 2D->1D extract");
-  Value cmp = b.create<arith::CmpIOp>(
-      loc, arith::CmpIPredicate::slt,
-      b.create<arith::ConstantIndexOp>(loc,
-                                       transferMask->extractPosition.front()),
-      transferMask->createMaskOp->getOperands().front());
+  // significant dimension. (Do this for all dimensions and bit-AND the
+  // conditions.)
+  assert(transferMask->createMaskOp.getVectorType().getRank() -
+                 transferMask->extractPosition.size() ==
+             1 &&
+         "expected N-D -> (N-1)-D extract");
+  Value cond;
+  // Note: There is one more `sz` than `pos`; the loop ends with the last `pos`.
+  for (auto [pos, sz] : llvm::zip(transferMask->extractPosition,
+                                  transferMask->createMaskOp->getOperands())) {
+    Value cmp =
+        b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt,
+                                b.create<arith::ConstantIndexOp>(loc, pos), sz);
+    if (!cond) {
+      cond = cmp;
+      continue;
+    }
+    cond = b.create<arith::AndIOp>(loc, cmp, cond);
+  }
   return b.create<arith::SelectOp>(
-      loc, cmp, transferMask->createMaskOp->getOperands().back(),
+      loc, cond, transferMask->createMaskOp->getOperands().back(),
       b.create<arith::ConstantIndexOp>(loc, 0));
 }
 

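For a concrete picture of the IR that the generalized buildNumReadElements
emits, consider an extract at position [1, 2] of a mask built by
vector.create_mask %sz0, %sz1, %sz2 (a sketch; value names are illustrative):
one arith.cmpi per extract index, AND-ed together, feeding the final select
of the innermost mask operand:

    %in0 = arith.cmpi slt, %c1, %sz0 : index
    %in1 = arith.cmpi slt, %c2, %sz1 : index
    %cond = arith.andi %in1, %in0 : i1
    %n = arith.select %cond, %sz2, %c0 : index

That is, the slice reads %sz2 elements when its coordinates lie inside the
masked region on all outer dimensions, and 0 elements otherwise.
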
diff --git a/mlir/test/Dialect/NVGPU/transform-create-async-groups.mlir b/mlir/test/Dialect/NVGPU/transform-create-async-groups.mlir
index 9b8864219d693b..d4f46fd06c37e3 100644
--- a/mlir/test/Dialect/NVGPU/transform-create-async-groups.mlir
+++ b/mlir/test/Dialect/NVGPU/transform-create-async-groups.mlir
@@ -165,10 +165,6 @@ builtin.module {
     %0 = memref.alloc() : memref<4x32x16xf32, #gpu.address_space<workgroup>>
     %c0 = arith.constant 0 : index
     %cst_0 = arith.constant 0.000000e+00 : f32
-    // CHECK: %[[mask:.*]] = vector.create_mask
-    // CHECK: %[[e0:.*]] = vector.extract %[[mask]][0] : vector<3x4xi1>
-    // CHECK: %[[e1:.*]] = vector.extract %[[mask]][1] : vector<3x4xi1>
-    // CHECK: %[[e2:.*]] = vector.extract %[[mask]][2] : vector<3x4xi1>
 
     // CHECK: %[[cmpi0:.*]] = arith.cmpi slt, %[[c0]], %[[sz0]]
     // CHECK: %[[s0:.*]] = arith.select %[[cmpi0]], %[[sz1]], %[[c0]]
@@ -199,3 +195,64 @@ builtin.module {
     transform.apply_cse to %top_level_func_2 : !transform.any_op
   }
 }
+
+// -----
+
+// 3D vector.transfer_read with a mask.
+builtin.module {
+  // CHECK-LABEL: @read_3d_with_mask(
+  //  CHECK-SAME:     %[[sz0:.*]]: index, %[[sz1:.*]]: index, %[[sz2:.*]]: index, %[[a:.*]]: memref<1024x1024x1024xf32>
+  func.func @read_3d_with_mask(%sz0: index, %sz1: index, %sz2: index, %a: memref<1024x1024x1024xf32>) {
+    // CHECK: %[[c0:.*]] = arith.constant 0 : index
+    // CHECK: %[[c1:.*]] = arith.constant 1 : index
+    // CHECK: %[[c2:.*]] = arith.constant 2 : index
+    %0 = memref.alloc() : memref<4x32x16xf32, #gpu.address_space<workgroup>>
+    %c0 = arith.constant 0 : index
+    %cst_0 = arith.constant 0.000000e+00 : f32
+
+    // CHECK: %[[cmpi0:.*]] = arith.cmpi slt, %[[c0]], %[[sz0]]
+    // CHECK: %[[cmpi1:.*]] = arith.cmpi slt, %[[c0]], %[[sz1]]
+    // CHECK: %[[cond0:.*]] = arith.andi %[[cmpi1]], %[[cmpi0]]
+    // CHECK: %[[s0:.*]] = arith.select %[[cond0]], %[[sz2]], %[[c0]]
+    // CHECK: nvgpu.device_async_copy %[[a]][%[[c0]], %[[c0]], %[[c0]]], {{.*}}, 4, %[[s0]] {bypassL1}
+
+    // CHECK: %[[cmpi2:.*]] = arith.cmpi slt, %[[c1]], %[[sz1]]
+    // CHECK: %[[cond1:.*]] = arith.andi %[[cmpi2]], %[[cmpi0]]
+    // CHECK: %[[s1:.*]] = arith.select %[[cond1]], %[[sz2]], %[[c0]]
+    // CHECK: nvgpu.device_async_copy %[[a]][%[[c0]], %[[c1]], %[[c0]]], {{.*}}, 4, %[[s1]] {bypassL1}
+
+    // CHECK: %[[cmpi3:.*]] = arith.cmpi slt, %[[c2]], %[[sz1]]
+    // CHECK: %[[cond2:.*]] = arith.andi %[[cmpi3]], %[[cmpi0]]
+    // CHECK: %[[s2:.*]] = arith.select %[[cond2]], %[[sz2]], %[[c0]]
+    // CHECK: nvgpu.device_async_copy %[[a]][%[[c0]], %[[c2]], %[[c0]]], {{.*}}, 4, %[[s2]] {bypassL1}
+
+    // CHECK: %[[cmpi4:.*]] = arith.cmpi slt, %[[c1]], %[[sz0]]
+    // CHECK: %[[cond3:.*]] = arith.andi %[[cmpi1]], %[[cmpi4]]
+    // CHECK: %[[s3:.*]] = arith.select %[[cond3]], %[[sz2]], %[[c0]]
+    // CHECK: nvgpu.device_async_copy %[[a]][%[[c1]], %[[c0]], %[[c0]]], {{.*}}, 4, %[[s3]] {bypassL1}
+
+    // CHECK: %[[cond4:.*]] = arith.andi %[[cmpi2]], %[[cmpi4]]
+    // CHECK: %[[s4:.*]] = arith.select %[[cond4]], %[[sz2]], %[[c0]]
+    // CHECK: nvgpu.device_async_copy %[[a]][%[[c1]], %[[c1]], %[[c0]]], {{.*}}, 4, %[[s4]] {bypassL1}
+
+    // CHECK: %[[cond5:.*]] = arith.andi %[[cmpi3]], %[[cmpi4]]
+    // CHECK: %[[s5:.*]] = arith.select %[[cond5]], %[[sz2]], %[[c0]]
+    // CHECK: nvgpu.device_async_copy %[[a]][%[[c1]], %[[c2]], %[[c0]]], {{.*}}, 4, %[[s5]] {bypassL1}
+    %mask = vector.create_mask %sz0, %sz1, %sz2 : vector<2x3x4xi1>
+    %1 = vector.transfer_read %a[%c0, %c0, %c0], %cst_0, %mask {in_bounds = [true, true, true]} : memref<1024x1024x1024xf32>, vector<2x3x4xf32>
+    vector.transfer_write %1, %0[%c0, %c0, %c0] {in_bounds = [true, true, true]} : vector<2x3x4xf32>, memref<4x32x16xf32, #gpu.address_space<workgroup>>
+
+    return
+  }
+
+  transform.sequence failures(propagate) {
+  ^bb1(%variant_op: !transform.any_op):
+    %top_level_func = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %top_level_func {
+      transform.apply_patterns.vector.transfer_to_scf max_transfer_rank = 1 full_unroll = true
+    } : !transform.any_op
+    transform.nvgpu.create_async_groups %top_level_func {bypass_l1} : (!transform.any_op) -> (!transform.any_op)
+    %top_level_func_2 = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op
+    transform.apply_cse to %top_level_func_2 : !transform.any_op
+  }
+}
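
Note that the new test relies on the file's existing RUN configuration. To
reproduce it in isolation, an invocation along these lines should work (the
exact flags are an assumption; the authoritative RUN line is at the top of
the test file):

    mlir-opt -test-transform-dialect-interpreter -split-input-file \
        mlir/test/Dialect/NVGPU/transform-create-async-groups.mlir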

