[Mlir-commits] [mlir] 39d8876 - [mlir][NVGPU] Support 2D masks in transform.nvgpu.create_async_groups
Matthias Springer
llvmlistbot at llvm.org
Mon Aug 7 06:47:36 PDT 2023
Author: Matthias Springer
Date: 2023-08-07T15:46:54+02:00
New Revision: 39d8876da363f8d8bde0d3ed65a4b127588fbc6e
URL: https://github.com/llvm/llvm-project/commit/39d8876da363f8d8bde0d3ed65a4b127588fbc6e
DIFF: https://github.com/llvm/llvm-project/commit/39d8876da363f8d8bde0d3ed65a4b127588fbc6e.diff
LOG: [mlir][NVGPU] Support 2D masks in transform.nvgpu.create_async_groups
Support IR that is generated by the vector-to-scf lowering of 2D vector transfers with a mask. Only 2D transfers that were fully unrolled are supported at the moment.
Differential Revision: https://reviews.llvm.org/D156695
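For illustration, the kind of IR this enables (a minimal sketch distilled from
the test added below; SSA names are illustrative): after the vector-to-scf
lowering with full_unroll, each row of the 2D mask reaches a 1-D transfer_read
through a vector.extract:

  %mask = vector.create_mask %sz0, %sz1 : vector<3x4xi1>
  %m0 = vector.extract %mask[0] : vector<3x4xi1>
  %v0 = vector.transfer_read %a[%c0, %c0], %pad, %m0 {in_bounds = [true]}
      : memref<1024x1024xf32>, vector<4xf32>

The transform recognizes this pattern and lowers each row to an
nvgpu.device_async_copy whose source element count is computed per row: row i
copies %sz1 elements if i < %sz0, and 0 elements otherwise.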
Added: 

Modified: 
    mlir/lib/Dialect/NVGPU/Transforms/CreateAsyncGroups.cpp
    mlir/test/Dialect/NVGPU/transform-create-async-groups.mlir

Removed: 

################################################################################
diff --git a/mlir/lib/Dialect/NVGPU/Transforms/CreateAsyncGroups.cpp b/mlir/lib/Dialect/NVGPU/Transforms/CreateAsyncGroups.cpp
index fb64e4e32b400d..3df99b1d374b10 100644
--- a/mlir/lib/Dialect/NVGPU/Transforms/CreateAsyncGroups.cpp
+++ b/mlir/lib/Dialect/NVGPU/Transforms/CreateAsyncGroups.cpp
@@ -46,23 +46,79 @@ static bool isContiguousRead(Operation *read) {
   return isa<vector::LoadOp>(read);
 }
 
+namespace {
+/// A vector.create_mask op and extract position.
+struct TransferMask {
+  vector::CreateMaskOp createMaskOp;
+  SmallVector<int64_t> extractPosition;
+};
+} // namespace
+
 /// If the given vector load op has a mask that is defined by
 /// vector.create_mask, return that op.
-static vector::CreateMaskOp getMaskOp(Operation *loadOp) {
+static FailureOr<TransferMask> getMaskOp(Operation *loadOp) {
   auto transferRead = dyn_cast<vector::TransferReadOp>(loadOp);
   if (!transferRead || !transferRead.getMask())
-    return {};
-  auto maskOp = transferRead.getMask().getDefiningOp<vector::CreateMaskOp>();
-  // TODO: Support 2D masks and higher. Ops with a >1D mask are ignored at the
-  // moment.
-  if (maskOp.getVectorType().getRank() != 1)
-    return {};
-  return maskOp;
+    return TransferMask{{}, {}};
+  assert(transferRead.getMask().getType().getRank() == 1 &&
+         "expected 1-D mask");
+
+  // Case 1: Mask is the result of a vector.create_mask.
+  if (auto maskOp =
+          transferRead.getMask().getDefiningOp<vector::CreateMaskOp>())
+    return TransferMask{maskOp, {}};
+
+  // Case 2: Mask is the result of a vector.extract(vector.create_mask). Only
+  // 2D -> 1D extracts are supported at the moment.
+  if (auto extractOp =
+          transferRead.getMask().getDefiningOp<vector::ExtractOp>())
+    if (auto maskOp =
+            extractOp.getVector().getDefiningOp<vector::CreateMaskOp>())
+      if (extractOp.getPosition().size() == 1 &&
+          extractOp.getSourceVectorType().getRank() == 2)
+        return TransferMask{maskOp,
+                            SmallVector<int64_t>(extractOp.getPosition())};
+
+  // All other cases: not supported.
+  return failure();
+}
+
+/// Build an SSA value that represents the number of read elements.
+static Value buildNumReadElements(OpBuilder &b, Location loc,
+                                  Operation *readOp) {
+  FailureOr<TransferMask> transferMask = getMaskOp(readOp);
+  assert(succeeded(transferMask) && "invalid transfer mask");
+
+  // No mask => no num_read_elements.
+  if (!transferMask->createMaskOp)
+    return Value();
+
+  // No extract: return size of "ones" segment in the mask.
+  if (transferMask->extractPosition.empty()) {
+    assert(transferMask->createMaskOp.getNumOperands() == 1 &&
+           "expected single operand");
+    return transferMask->createMaskOp.getOperand(0);
+  }
+
+  // vector.extract(vector.create_mask).
+  // If extract_pos < num_ones, take number of elements from the least
+  // significant dimension.
+  assert(transferMask->createMaskOp.getVectorType().getRank() == 2 &&
+         "expected 2D mask");
+  assert(transferMask->extractPosition.size() == 1 &&
+         "expected 2D->1D extract");
+  Value cmp = b.create<arith::CmpIOp>(
+      loc, arith::CmpIPredicate::slt,
+      b.create<arith::ConstantIndexOp>(loc,
+                                       transferMask->extractPosition.front()),
+      transferMask->createMaskOp->getOperands().front());
+  return b.create<arith::SelectOp>(
+      loc, cmp, transferMask->createMaskOp->getOperands().back(),
+      b.create<arith::ConstantIndexOp>(loc, 0));
 }
 
 /// Return "true" if the conversion to async copy is supported by "async copy".
 static bool resultsInSupportedAsyncCopy(MemRefType memrefType,
-                                        Operation::operand_range indices,
                                         VectorType vecType) {
   assert(vecType.getRank() == 1 && "expected 1-D vector");
   constexpr int64_t kSupportedCpAsyncAlignmentsInBytes[3] = {4, 8, 16};
@@ -121,7 +177,7 @@ void nvgpu::createAsyncGroups(RewriterBase &rewriter, Operation *op,
       if (getConstantIntValue(transferRead.getPadding()) !=
          static_cast<int64_t>(0))
        return;
-      if (!getMaskOp(readOp))
+      if (failed(getMaskOp(readOp)))
        return;
    }
  }
@@ -131,9 +187,9 @@ void nvgpu::createAsyncGroups(RewriterBase &rewriter, Operation *op,
     VectorType vecType = cast<VectorType>(vectorVal.getType());
 
     if (!resultsInSupportedAsyncCopy(cast<MemRefType>(loadBase.getType()),
-                                     nvgpu::getIndices(readOp), vecType) ||
+                                     vecType) ||
         !resultsInSupportedAsyncCopy(cast<MemRefType>(storeBase.getType()),
-                                     nvgpu::getIndices(writeOp), vecType))
+                                     vecType))
       return;
 
     copyToSharedMem.insert(writeOp);
@@ -184,11 +240,8 @@ void nvgpu::createAsyncGroups(RewriterBase &rewriter, Operation *op,
     Operation *readOp = vectorVal.getDefiningOp();
     Value storeBase = nvgpu::getMemrefOperand(writeOp);
     Value loadBase = nvgpu::getMemrefOperand(readOp);
-    Value numReadElements;
-    if (vector::CreateMaskOp maskOp = getMaskOp(readOp)) {
-      assert(maskOp.getNumOperands() == 1 && "expected single operand");
-      numReadElements = maskOp.getOperand(0);
-    }
+    Value numReadElements =
+        buildNumReadElements(rewriter, writeOp->getLoc(), readOp);
     auto dstMemref = cast<MemRefType>(storeBase.getType());
     int64_t sizeInBytes =
         (dstMemref.getElementTypeBitWidth() * numElements) / 8;
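As a reading aid (not part of the patch): for a mask defined by
"vector.create_mask %sz0, %sz1 : vector<3x4xi1>" and a row extracted at
constant position 1, buildNumReadElements materializes
num = (1 < %sz0) ? %sz1 : 0, i.e. IR of roughly this shape:

  %pos = arith.constant 1 : index
  %in_bounds = arith.cmpi slt, %pos, %sz0 : index
  %zero = arith.constant 0 : index
  %num = arith.select %in_bounds, %sz1, %zero : index

This matches the cmpi/select pairs checked in the test diff below.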
diff --git a/mlir/test/Dialect/NVGPU/transform-create-async-groups.mlir b/mlir/test/Dialect/NVGPU/transform-create-async-groups.mlir
index e378d1c578c36f..9b8864219d693b 100644
--- a/mlir/test/Dialect/NVGPU/transform-create-async-groups.mlir
+++ b/mlir/test/Dialect/NVGPU/transform-create-async-groups.mlir
@@ -151,3 +151,51 @@ builtin.module {
     transform.nvgpu.create_async_groups %top_level_func {bypass_l1} : (!transform.any_op) -> (!transform.any_op)
   }
 }
+
+// -----
+
+// 2D vector.transfer_read with a mask.
+builtin.module {
+  // CHECK-LABEL: @read_2d_with_mask(
+  //  CHECK-SAME: %[[sz0:.*]]: index, %[[sz1:.*]]: index, %[[a:.*]]: memref<1024x1024xf32>
+  func.func @read_2d_with_mask(%sz0: index, %sz1: index, %a: memref<1024x1024xf32>) {
+    // CHECK: %[[c0:.*]] = arith.constant 0 : index
+    // CHECK: %[[c1:.*]] = arith.constant 1 : index
+    // CHECK: %[[c2:.*]] = arith.constant 2 : index
+    %0 = memref.alloc() : memref<4x32x16xf32, #gpu.address_space<workgroup>>
+    %c0 = arith.constant 0 : index
+    %cst_0 = arith.constant 0.000000e+00 : f32
+    // CHECK: %[[mask:.*]] = vector.create_mask
+    // CHECK: %[[e0:.*]] = vector.extract %[[mask]][0] : vector<3x4xi1>
+    // CHECK: %[[e1:.*]] = vector.extract %[[mask]][1] : vector<3x4xi1>
+    // CHECK: %[[e2:.*]] = vector.extract %[[mask]][2] : vector<3x4xi1>
+
+    // CHECK: %[[cmpi0:.*]] = arith.cmpi slt, %[[c0]], %[[sz0]]
+    // CHECK: %[[s0:.*]] = arith.select %[[cmpi0]], %[[sz1]], %[[c0]]
+    // CHECK: nvgpu.device_async_copy %[[a]][%[[c0]], %[[c0]]], {{.*}}, 4, %[[s0]] {bypassL1}
+
+    // CHECK: %[[cmpi1:.*]] = arith.cmpi slt, %[[c1]], %[[sz0]]
+    // CHECK: %[[s1:.*]] = arith.select %[[cmpi1]], %[[sz1]], %[[c0]]
+    // CHECK: nvgpu.device_async_copy %[[a]][%[[c1]], %[[c0]]], {{.*}}, 4, %[[s1]] {bypassL1}
+
+    // CHECK: %[[cmpi2:.*]] = arith.cmpi slt, %[[c2]], %[[sz0]]
+    // CHECK: %[[s2:.*]] = arith.select %[[cmpi2]], %[[sz1]], %[[c0]]
+    // CHECK: nvgpu.device_async_copy %[[a]][%[[c2]], %[[c0]]], {{.*}}, 4, %[[s2]] {bypassL1}
+    %mask = vector.create_mask %sz0, %sz1 : vector<3x4xi1>
+    %1 = vector.transfer_read %a[%c0, %c0], %cst_0, %mask {in_bounds = [true, true]} : memref<1024x1024xf32>, vector<3x4xf32>
+    vector.transfer_write %1, %0[%c0, %c0, %c0] {in_bounds = [true, true]} : vector<3x4xf32>, memref<4x32x16xf32, #gpu.address_space<workgroup>>
+
+    return
+  }
+
+  transform.sequence failures(propagate) {
+  ^bb1(%variant_op: !transform.any_op):
+    %top_level_func = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %top_level_func {
+      transform.apply_patterns.vector.transfer_to_scf max_transfer_rank = 1 full_unroll = true
+    } : !transform.any_op
+    transform.nvgpu.create_async_groups %top_level_func {bypass_l1} : (!transform.any_op) -> (!transform.any_op)
+    %top_level_func_2 = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op
+    transform.apply_cse to %top_level_func_2 : !transform.any_op
+  }
+}