[Mlir-commits] [mlir] 15ea230 - [mlir][NVGPU] Support N-D masks in transform.nvgpu.create_async_groups
Matthias Springer
llvmlistbot at llvm.org
Tue Aug 8 05:36:02 PDT 2023
Author: Matthias Springer
Date: 2023-08-08T14:30:03+02:00
New Revision: 15ea2306a41ad52a73e03f51497d115733b8b9d3
URL: https://github.com/llvm/llvm-project/commit/15ea2306a41ad52a73e03f51497d115733b8b9d3
DIFF: https://github.com/llvm/llvm-project/commit/15ea2306a41ad52a73e03f51497d115733b8b9d3.diff
LOG: [mlir][NVGPU] Support N-D masks in transform.nvgpu.create_async_groups
Support IR that is generated by the vector-to-scf lowering of N-D vector transfers with a mask. (Until now, only 1-D and 2-D transfers were supported.) Only fully unrolled transfers are supported.
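As an illustration, here is a hand-written sketch (not taken from the commit; %a, %sz0-%sz2, and all other SSA names are invented) of the IR shape this now recognizes: vector-to-scf fully unrolls an N-D masked transfer into 1-D transfers whose masks are vector.extract's of a single shared vector.create_mask:

    %c0 = arith.constant 0 : index
    %c1 = arith.constant 1 : index
    %pad = arith.constant 0.000000e+00 : f32
    %mask = vector.create_mask %sz0, %sz1, %sz2 : vector<2x3x4xi1>
    // Mask for the [0, 1, *] slice: a 3-D -> 1-D extract of the shared mask.
    %m01 = vector.extract %mask[0, 1] : vector<2x3x4xi1>
    %v = vector.transfer_read %a[%c0, %c1, %c0], %pad, %m01 {in_bounds = [true]}
        : memref<1024x1024x1024xf32>, vector<4xf32>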
Differential Revision: https://reviews.llvm.org/D157286
Added:
Modified:
mlir/lib/Dialect/NVGPU/Transforms/CreateAsyncGroups.cpp
mlir/test/Dialect/NVGPU/transform-create-async-groups.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/NVGPU/Transforms/CreateAsyncGroups.cpp b/mlir/lib/Dialect/NVGPU/Transforms/CreateAsyncGroups.cpp
index 3df99b1d374b10..ad2180d501148f 100644
--- a/mlir/lib/Dialect/NVGPU/Transforms/CreateAsyncGroups.cpp
+++ b/mlir/lib/Dialect/NVGPU/Transforms/CreateAsyncGroups.cpp
@@ -68,19 +68,16 @@ static FailureOr<TransferMask> getMaskOp(Operation *loadOp) {
transferRead.getMask().getDefiningOp<vector::CreateMaskOp>())
return TransferMask{maskOp, {}};
- // Case 2: Mask is the result of a vector.extract(vector.create_mask). Only
- // 2D -> 1D extracts are supported at the moment.
+ // Case 2: Mask is the result of a vector.extract(vector.create_mask).
if (auto extractOp =
transferRead.getMask().getDefiningOp<vector::ExtractOp>())
if (auto maskOp =
extractOp.getVector().getDefiningOp<vector::CreateMaskOp>())
- if (extractOp.getPosition().size() == 1 &&
- extractOp.getSourceVectorType().getRank() == 2)
- return TransferMask{maskOp,
- SmallVector<int64_t>(extractOp.getPosition())};
+ return TransferMask{maskOp,
+ SmallVector<int64_t>(extractOp.getPosition())};
// All other cases: not supported.
- return {};
+ return failure();
}
/// Build an SSA value that represents the number of read elements.
@@ -102,18 +99,27 @@ static Value buildNumReadElements(OpBuilder &b, Location loc,
// vector.extract(vector.create_mask).
// If extract_pos < num_ones, take number of elements from the least
- // significant dimension.
- assert(transferMask->createMaskOp.getVectorType().getRank() == 2 &&
- "expected 2D mask");
- assert(transferMask->extractPosition.size() == 1 &&
- "expected 2D->1D extract");
- Value cmp = b.create<arith::CmpIOp>(
- loc, arith::CmpIPredicate::slt,
- b.create<arith::ConstantIndexOp>(loc,
- transferMask->extractPosition.front()),
- transferMask->createMaskOp->getOperands().front());
+ // significant dimension. (Do this for all dimensions and bit-AND the
+ // conditions.)
+ assert(transferMask->createMaskOp.getVectorType().getRank() -
+ transferMask->extractPosition.size() ==
+ 1 &&
+ "expected N-D -> (N-1)-D extract");
+ Value cond;
+ // Note: There is one more `sz` than `pos`; the loop ends with the last `pos`.
+ for (auto [pos, sz] : llvm::zip(transferMask->extractPosition,
+ transferMask->createMaskOp->getOperands())) {
+ Value cmp =
+ b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt,
+ b.create<arith::ConstantIndexOp>(loc, pos), sz);
+ if (!cond) {
+ cond = cmp;
+ continue;
+ }
+ cond = b.create<arith::AndIOp>(loc, cmp, cond);
+ }
return b.create<arith::SelectOp>(
- loc, cmp, transferMask->createMaskOp->getOperands().back(),
+ loc, cond, transferMask->createMaskOp->getOperands().back(),
b.create<arith::ConstantIndexOp>(loc, 0));
}
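To make the new loop concrete, here is a rough sketch (hand-written, not actual pass output; SSA names invented) of what buildNumReadElements now emits for an extract at position [0, 1] of a mask created from %sz0, %sz1, %sz2: one cmpi per extracted dimension, AND-ed together, feeding the final select:

    %c0 = arith.constant 0 : index
    %c1 = arith.constant 1 : index
    // One in-bounds check per extracted (outer) dimension.
    %cmp0 = arith.cmpi slt, %c0, %sz0 : index
    %cmp1 = arith.cmpi slt, %c1, %sz1 : index
    // The per-dimension conditions are AND-ed together.
    %cond = arith.andi %cmp1, %cmp0 : i1
    // All outer positions in bounds: read %sz2 elements, otherwise 0.
    %n = arith.select %cond, %sz2, %c0 : index

(The apply_cse in the test below deduplicates the repeated index constants.)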
diff --git a/mlir/test/Dialect/NVGPU/transform-create-async-groups.mlir b/mlir/test/Dialect/NVGPU/transform-create-async-groups.mlir
index 9b8864219d693b..d4f46fd06c37e3 100644
--- a/mlir/test/Dialect/NVGPU/transform-create-async-groups.mlir
+++ b/mlir/test/Dialect/NVGPU/transform-create-async-groups.mlir
@@ -165,10 +165,6 @@ builtin.module {
%0 = memref.alloc() : memref<4x32x16xf32, #gpu.address_space<workgroup>>
%c0 = arith.constant 0 : index
%cst_0 = arith.constant 0.000000e+00 : f32
- // CHECK: %[[mask:.*]] = vector.create_mask
- // CHECK: %[[e0:.*]] = vector.extract %[[mask]][0] : vector<3x4xi1>
- // CHECK: %[[e1:.*]] = vector.extract %[[mask]][1] : vector<3x4xi1>
- // CHECK: %[[e2:.*]] = vector.extract %[[mask]][2] : vector<3x4xi1>
// CHECK: %[[cmpi0:.*]] = arith.cmpi slt, %[[c0]], %[[sz0]]
// CHECK: %[[s0:.*]] = arith.select %[[cmpi0]], %[[sz1]], %[[c0]]
@@ -199,3 +195,64 @@ builtin.module {
transform.apply_cse to %top_level_func_2 : !transform.any_op
}
}
+
+// -----
+
+// 3D vector.transfer_read with a mask.
+builtin.module {
+ // CHECK-LABEL: @read_3d_with_mask(
+ // CHECK-SAME: %[[sz0:.*]]: index, %[[sz1:.*]]: index, %[[sz2:.*]]: index, %[[a:.*]]: memref<1024x1024x1024xf32>
+ func.func @read_3d_with_mask(%sz0: index, %sz1: index, %sz2: index, %a: memref<1024x1024x1024xf32>) {
+ // CHECK: %[[c0:.*]] = arith.constant 0 : index
+ // CHECK: %[[c1:.*]] = arith.constant 1 : index
+ // CHECK: %[[c2:.*]] = arith.constant 2 : index
+ %0 = memref.alloc() : memref<4x32x16xf32, #gpu.address_space<workgroup>>
+ %c0 = arith.constant 0 : index
+ %cst_0 = arith.constant 0.000000e+00 : f32
+
+ // CHECK: %[[cmpi0:.*]] = arith.cmpi slt, %[[c0]], %[[sz0]]
+ // CHECK: %[[cmpi1:.*]] = arith.cmpi slt, %[[c0]], %[[sz1]]
+ // CHECK: %[[cond0:.*]] = arith.andi %[[cmpi1]], %[[cmpi0]]
+ // CHECK: %[[s0:.*]] = arith.select %[[cond0]], %[[sz2]], %[[c0]]
+ // CHECK: nvgpu.device_async_copy %[[a]][%[[c0]], %[[c0]], %[[c0]]], {{.*}}, 4, %[[s0]] {bypassL1}
+
+ // CHECK: %[[cmpi2:.*]] = arith.cmpi slt, %[[c1]], %[[sz1]]
+ // CHECK: %[[cond1:.*]] = arith.andi %[[cmpi2]], %[[cmpi0]]
+ // CHECK: %[[s1:.*]] = arith.select %[[cond1]], %[[sz2]], %[[c0]]
+ // CHECK: nvgpu.device_async_copy %[[a]][%[[c0]], %[[c1]], %[[c0]]], {{.*}}, 4, %[[s1]] {bypassL1}
+
+ // CHECK: %[[cmpi3:.*]] = arith.cmpi slt, %[[c2]], %[[sz1]]
+ // CHECK: %[[cond2:.*]] = arith.andi %[[cmpi3]], %[[cmpi0]]
+ // CHECK: %[[s2:.*]] = arith.select %[[cond2]], %[[sz2]], %[[c0]]
+ // CHECK: nvgpu.device_async_copy %[[a]][%[[c0]], %[[c2]], %[[c0]]], {{.*}}, 4, %[[s2]] {bypassL1}
+
+ // CHECK: %[[cmpi4:.*]] = arith.cmpi slt, %[[c1]], %[[sz0]]
+ // CHECK: %[[cond3:.*]] = arith.andi %[[cmpi1]], %[[cmpi4]]
+ // CHECK: %[[s3:.*]] = arith.select %[[cond3]], %[[sz2]], %[[c0]]
+ // CHECK: nvgpu.device_async_copy %[[a]][%[[c1]], %[[c0]], %[[c0]]], {{.*}}, 4, %[[s3]] {bypassL1}
+
+ // CHECK: %[[cond4:.*]] = arith.andi %[[cmpi2]], %[[cmpi4]]
+ // CHECK: %[[s4:.*]] = arith.select %[[cond4]], %[[sz2]], %[[c0]]
+ // CHECK: nvgpu.device_async_copy %[[a]][%[[c1]], %[[c1]], %[[c0]]], {{.*}}, 4, %[[s4]] {bypassL1}
+
+ // CHECK: %[[cond5:.*]] = arith.andi %[[cmpi3]], %[[cmpi4]]
+ // CHECK: %[[s5:.*]] = arith.select %[[cond5]], %[[sz2]], %[[c0]]
+ // CHECK: nvgpu.device_async_copy %[[a]][%[[c1]], %[[c2]], %[[c0]]], {{.*}}, 4, %[[s5]] {bypassL1}
+ %mask = vector.create_mask %sz0, %sz1, %sz2 : vector<2x3x4xi1>
+ %1 = vector.transfer_read %a[%c0, %c0, %c0], %cst_0, %mask {in_bounds = [true, true, true]} : memref<1024x1024x1024xf32>, vector<2x3x4xf32>
+ vector.transfer_write %1, %0[%c0, %c0, %c0] {in_bounds = [true, true, true]} : vector<2x3x4xf32>, memref<4x32x16xf32, #gpu.address_space<workgroup>>
+
+ return
+ }
+
+ transform.sequence failures(propagate) {
+ ^bb1(%variant_op: !transform.any_op):
+ %top_level_func = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %top_level_func {
+ transform.apply_patterns.vector.transfer_to_scf max_transfer_rank = 1 full_unroll = true
+ } : !transform.any_op
+ transform.nvgpu.create_async_groups %top_level_func {bypass_l1} : (!transform.any_op) -> (!transform.any_op)
+ %top_level_func_2 = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op
+ transform.apply_cse to %top_level_func_2 : !transform.any_op
+ }
+}