[Mlir-commits] [mlir] e796924 - [mlir][VectorToGPU] Support more cases in conversion to MMA ops
llvmlistbot at llvm.org
Thu Nov 11 13:12:17 PST 2021
Author: Thomas Raoux
Date: 2021-11-11T13:10:38-08:00
New Revision: e7969240dce5e064f11abd44d7553ba9e9f27210
URL: https://github.com/llvm/llvm-project/commit/e7969240dce5e064f11abd44d7553ba9e9f27210
DIFF: https://github.com/llvm/llvm-project/commit/e7969240dce5e064f11abd44d7553ba9e9f27210.diff
LOG: [mlir][VectorToGPU] Support more cases in conversion to MMA ops
Support loads with broadcast and the elementwise divf op, and remove the
hardcoded restriction on the vector size. Picking the right size is the
user's responsibility; conversion to llvm/spirv will fail if the chosen
size is not supported.
Differential Revision: https://reviews.llvm.org/D113618
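For illustration, here is a minimal sketch (not part of the patch) of the kind
of broadcast load that transferReadSupportsMMAMatrixType now accepts when it
feeds an MMA contraction; the (0, d3) permutation map and the resulting
leadDimension = 0 lowering match the new test added at the end of this patch:

// Hypothetical standalone example; only the transfer_read form matters here.
func @broadcast_load_example(%src: memref<16x16x16x16xf16>) -> vector<16x16xf16> {
  %c0 = arith.constant 0 : index
  %cst = arith.constant 0.000000e+00 : f16
  // The constant 0 in the permutation map broadcasts the innermost row
  // across the first dimension of the result vector.
  %v = vector.transfer_read %src[%c0, %c0, %c0, %c0], %cst
      {in_bounds = [true, true],
       permutation_map = affine_map<(d0, d1, d2, d3) -> (0, d3)>}
      : memref<16x16x16x16xf16>, vector<16x16xf16>
  return %v : vector<16x16xf16>
}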
Added:
Modified:
mlir/include/mlir/Dialect/GPU/GPUOps.td
mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp
mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/GPU/GPUOps.td b/mlir/include/mlir/Dialect/GPU/GPUOps.td
index 5e4d122c69eaa..adefba30d9a8e 100644
--- a/mlir/include/mlir/Dialect/GPU/GPUOps.td
+++ b/mlir/include/mlir/Dialect/GPU/GPUOps.td
@@ -1130,11 +1130,12 @@ def GPU_ELEMENTWISE_OP_ADD : StrEnumAttrCase<"ADDF">;
def GPU_ELEMENTWISE_OP_MUL : StrEnumAttrCase<"MULF">;
def GPU_ELEMENTWISE_OP_MAXF : StrEnumAttrCase<"MAXF">;
def GPU_ELEMENTWISE_OP_MINF : StrEnumAttrCase<"MINF">;
+def GPU_ELEMENTWISE_OP_DIVF : StrEnumAttrCase<"DIVF">;
def MMAElementWiseAttr : StrEnumAttr<"MMAElementwiseOp",
"elementwise operation to apply to mma matrix",
[GPU_ELEMENTWISE_OP_ADD, GPU_ELEMENTWISE_OP_MUL,
- GPU_ELEMENTWISE_OP_MAXF, GPU_ELEMENTWISE_OP_MINF]> {
+ GPU_ELEMENTWISE_OP_MAXF, GPU_ELEMENTWISE_OP_MINF, GPU_ELEMENTWISE_OP_DIVF]> {
let cppNamespace = "::mlir::gpu";
let storageType = "::mlir::StringAttr";
let returnType = "::mlir::gpu::MMAElementwiseOp";
diff --git a/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp b/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp
index 6de739088b896..55935e7f971f7 100644
--- a/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp
+++ b/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp
@@ -304,6 +304,8 @@ static Value createScalarOp(OpBuilder &builder, Location loc,
return builder.create<LLVM::FAddOp>(loc, operands[0].getType(), operands);
case gpu::MMAElementwiseOp::MULF:
return builder.create<LLVM::FMulOp>(loc, operands[0].getType(), operands);
+ case gpu::MMAElementwiseOp::DIVF:
+ return builder.create<LLVM::FDivOp>(loc, operands[0].getType(), operands);
case gpu::MMAElementwiseOp::MAXF:
return createMinMaxF(builder, loc, operands[0], operands[1],
/*isMin=*/false);
diff --git a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
index a9f3c7da9c842..18f472634480d 100644
--- a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
+++ b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
@@ -50,26 +50,7 @@ static bool contractSupportsMMAMatrixType(vector::ContractionOp contract) {
if (contract.getIndexingMaps() != infer({{m, k}, {k, n}, {m, n}}))
return false;
- // Check that the size matches what is natively supported.
- VectorType lhsType = contract.lhs().getType().cast<VectorType>();
- VectorType rhsType = contract.rhs().getType().cast<VectorType>();
- VectorType accType = contract.acc().getType().cast<VectorType>();
-
- std::tuple<int, int, int> dim(lhsType.getDimSize(0), rhsType.getDimSize(1),
- lhsType.getDimSize(1));
- if (lhsType.getElementType().isInteger(8) &&
- rhsType.getElementType().isInteger(8) &&
- accType.getElementType().isInteger(32) &&
- (dim == std::make_tuple(8, 8, 32) || dim == std::make_tuple(16, 16, 32) ||
- dim == std::make_tuple(16, 8, 32)))
- return true;
-
- if (lhsType.getElementType().isF16() && rhsType.getElementType().isF16() &&
- (accType.getElementType().isF16() || accType.getElementType().isF32()) &&
- (dim == std::make_tuple(8, 8, 16) || dim == std::make_tuple(16, 16, 16) ||
- dim == std::make_tuple(16, 8, 16)))
- return true;
- return false;
+ return true;
}
// Return the stide for the dimension 0 of |type| if it is a memref and has a
@@ -95,8 +76,15 @@ static bool transferReadSupportsMMAMatrixType(vector::TransferReadOp readOp) {
return false;
if (!getMemrefConstantHorizontalStride(readOp.getShapedType()))
return false;
+ AffineMap map = readOp.permutation_map();
+ OpBuilder b(readOp.getContext());
+ AffineExpr innerDim = b.getAffineDimExpr(map.getNumDims() - 1);
+ AffineExpr zero = b.getAffineConstantExpr(0);
+ auto broadcastInnerDim = AffineMap::get(map.getNumDims(), 0, {zero, innerDim},
+ readOp.getContext());
// TODO: Support transpose once it is added to GPU dialect ops.
- if (!readOp.permutation_map().isMinorIdentity())
+ // For now we only support (d0, d1) -> (d0, d1) and (d0, d1) -> (0, d1).
+ if (!map.isMinorIdentity() && map != broadcastInnerDim)
return false;
return true;
}
@@ -142,6 +130,8 @@ convertElementwiseOpToMMA(Operation *op) {
return gpu::MMAElementwiseOp::MAXF;
if (isa<MinFOp>(op))
return gpu::MMAElementwiseOp::MINF;
+ if (isa<arith::DivFOp>(op))
+ return gpu::MMAElementwiseOp::DIVF;
return llvm::None;
}
@@ -166,6 +156,44 @@ static bool supportsMMaMatrixType(Operation *op) {
return elementwiseSupportsMMAMatrixType(op);
}
+/// Return an unsorted slice handling scf.for region differently than
+/// `getSlice`. In scf.for we only want to include as part of the slice elements
+/// that are part of the use/def chain.
+static SetVector<Operation *> getSliceContract(Operation *op,
+ TransitiveFilter backwardFilter,
+ TransitiveFilter forwardFilter) {
+ SetVector<Operation *> slice;
+ slice.insert(op);
+ unsigned currentIndex = 0;
+ SetVector<Operation *> backwardSlice;
+ SetVector<Operation *> forwardSlice;
+ while (currentIndex != slice.size()) {
+ auto *currentOp = (slice)[currentIndex];
+ // Compute and insert the backwardSlice starting from currentOp.
+ backwardSlice.clear();
+ getBackwardSlice(currentOp, &backwardSlice, backwardFilter);
+ slice.insert(backwardSlice.begin(), backwardSlice.end());
+
+ // Compute and insert the forwardSlice starting from currentOp.
+ forwardSlice.clear();
+ // Special case for ForOp, we don't want to include the whole region but
+ // only the value using the region arguments.
+ // TODO: We should refine this to only care about the region arguments being
+ // converted to matrix type.
+ if (auto forOp = dyn_cast<scf::ForOp>(currentOp)) {
+ for (Value forOpResult : forOp.getResults())
+ getForwardSlice(forOpResult, &forwardSlice, forwardFilter);
+ for (BlockArgument &arg : forOp.getRegionIterArgs())
+ getForwardSlice(arg, &forwardSlice, forwardFilter);
+ } else {
+ getForwardSlice(currentOp, &forwardSlice, forwardFilter);
+ }
+ slice.insert(forwardSlice.begin(), forwardSlice.end());
+ ++currentIndex;
+ }
+ return slice;
+}
+
// Analyze slice of operations based on convert op to figure out if the whole
// slice can be converted to MMA operations.
static SetVector<Operation *> getOpToConvert(mlir::Operation *op) {
@@ -182,16 +210,17 @@ static SetVector<Operation *> getOpToConvert(mlir::Operation *op) {
if (opToConvert.contains(contract.getOperation()))
return;
SetVector<Operation *> dependentOps =
- getSlice(contract, hasVectorDest, hasVectorSrc);
+ getSliceContract(contract, hasVectorDest, hasVectorSrc);
// If any instruction cannot use MMA matrix type drop the whole
- // chaine. MMA matrix are stored in an opaque type so they cannot be used
+ // chain. MMA matrix are stored in an opaque type so they cannot be used
// by all operations.
if (llvm::any_of(dependentOps,
[](Operation *op) { return !supportsMMaMatrixType(op); }))
return;
opToConvert.insert(dependentOps.begin(), dependentOps.end());
});
- return opToConvert;
+ // Sort the operations so that we can convert them in topological order.
+ return topologicalSort(opToConvert);
}
namespace {
@@ -309,6 +338,12 @@ static void convertTransferReadOp(vector::TransferReadOp op,
assert(transferReadSupportsMMAMatrixType(op));
Optional<int64_t> stride =
getMemrefConstantHorizontalStride(op.getShapedType());
+ AffineMap map = op.permutation_map();
+ // Handle broadcast by setting the stride to 0.
+ if (map.getResult(0).isa<AffineConstantExpr>()) {
+ assert(map.getResult(0).cast<AffineConstantExpr>().getValue() == 0);
+ stride = 0;
+ }
assert(stride);
const char *fragType = inferFragType(op);
gpu::MMAMatrixType type =
diff --git a/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir b/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir
index 2ca899fa5bac4..5e8f40f52b86b 100644
--- a/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir
+++ b/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir
@@ -106,3 +106,28 @@ func @matmul_fused_elementwise(%arg0: memref<16x16xf16>, %arg1: memref<16x16xf16
vector.transfer_write %E, %arg2[%c0, %c0] {in_bounds = [true, true]} : vector<16x16xf16>, memref<16x16xf16>
return
}
+
+// CHECK-LABEL: func @matmul_fused_broadcast
+// CHECK-DAG: %[[CST_0:.+]] = arith.constant 0.000000e+00 : f16
+// CHECK-DAG: %[[A:.+]] = gpu.subgroup_mma_load_matrix %{{.*}}[%{{.*}}, %{{.*}}] {leadDimension = 16 : index} : memref<16x16xf16> -> !gpu.mma_matrix<16x16xf16, "AOp">
+// CHECK-DAG: %[[B:.+]] = gpu.subgroup_mma_load_matrix %{{.*}}[%{{.*}}, %{{.*}}] {leadDimension = 16 : index} : memref<16x16xf16> -> !gpu.mma_matrix<16x16xf16, "BOp">
+// CHECK-DAG: %[[C0:.+]] = gpu.subgroup_mma_constant_matrix %[[CST_0]] : !gpu.mma_matrix<16x16xf16, "COp">
+// CHECK: %[[D:.+]] = gpu.subgroup_mma_compute %[[A]], %[[B]], %[[C0]] : !gpu.mma_matrix<16x16xf16, "AOp">, !gpu.mma_matrix<16x16xf16, "BOp"> -> !gpu.mma_matrix<16x16xf16, "COp">
+// CHECK: %[[E:.+]] = gpu.subgroup_mma_load_matrix %{{.*}}[%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}] {leadDimension = 0 : index} : memref<16x16x16x16xf16> -> !gpu.mma_matrix<16x16xf16, "COp">
+// CHECK: %[[F:.+]] = gpu.subgroup_mma_elementwise %[[D]], %[[E]] {operation = "DIVF"} : (!gpu.mma_matrix<16x16xf16, "COp">, !gpu.mma_matrix<16x16xf16, "COp">) -> !gpu.mma_matrix<16x16xf16, "COp">
+// CHECK: gpu.subgroup_mma_store_matrix %[[F]], %{{.*}}[%{{.*}}, %{{.*}}] {leadDimension = 16 : index} : !gpu.mma_matrix<16x16xf16, "COp">, memref<16x16xf16>
+func @matmul_fused_broadcast(%arg0: memref<16x16xf16>, %arg1: memref<16x16xf16>,
+ %arg2: memref<16x16xf16>, %arg3: memref<16x16x16x16xf16>) {
+ %cst_0 = arith.constant dense<0.000000e+00> : vector<16x16xf16>
+ %c0 = arith.constant 0 : index
+ %cst = arith.constant 0.000000e+00 : f16
+ %A = vector.transfer_read %arg0[%c0, %c0], %cst {in_bounds = [true, true]} : memref<16x16xf16>, vector<16x16xf16>
+ %B = vector.transfer_read %arg1[%c0, %c0], %cst {permutation_map = #map0, in_bounds = [true, true]} : memref<16x16xf16>, vector<16x16xf16>
+ %D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %A, %B, %cst_0 : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
+ %E = vector.transfer_read %arg3[%c0, %c0, %c0, %c0], %cst
+ {in_bounds = [true, true], permutation_map = affine_map<(d0, d1, d2, d3)->(0, d3)>}
+ : memref<16x16x16x16xf16>, vector<16x16xf16>
+ %F = arith.divf %D, %E : vector<16x16xf16>
+ vector.transfer_write %F, %arg2[%c0, %c0] {in_bounds = [true, true]} : vector<16x16xf16>, memref<16x16xf16>
+ return
+}