[Mlir-commits] [mlir] a011943 - [mlir][gpu] Support arith.extf in subgroup MMA elementwise ops
Lei Zhang
llvmlistbot at llvm.org
Tue Aug 1 21:16:38 PDT 2023
Author: Lei Zhang
Date: 2023-08-01T21:12:37-07:00
New Revision: a01194377c8ae3178a4dd22dbc420caed9bba21d
URL: https://github.com/llvm/llvm-project/commit/a01194377c8ae3178a4dd22dbc420caed9bba21d
DIFF: https://github.com/llvm/llvm-project/commit/a01194377c8ae3178a4dd22dbc420caed9bba21d.diff
LOG: [mlir][gpu] Support arith.extf in subgroup MMA elementwise ops
This commit adds arith.extf to the supported list of elementwise ops
for subgroup MMA ops and enables its lowering to SPIR-V.
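
Concretely, this lets a floating-point widening stay on MMA matrix
values. A minimal sketch of the op form this enables, adapted from the
tests added below (the value names are illustrative):

    %f32 = gpu.subgroup_mma_elementwise extf %f16 : (!gpu.mma_matrix<16x16xf16, "COp">) -> !gpu.mma_matrix<16x16xf32, "COp">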
Reviewed By: mravishankar
Differential Revision: https://reviews.llvm.org/D156847
Added:
Modified:
mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp
mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv.mlir
mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
index e3cd604fcc30ce..3b20ba2b46e351 100644
--- a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
+++ b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
@@ -1472,6 +1472,7 @@ def GPU_ElementwiseOpDivS : I32EnumAttrCase<"DIVS", 9, "divs">;
def GPU_ElementwiseOpDivU : I32EnumAttrCase<"DIVU", 10, "divu">;
def GPU_ElementwiseOpNEGF : I32EnumAttrCase<"NEGATEF", 11, "negatef">;
def GPU_ElementwiseOpNEGS : I32EnumAttrCase<"NEGATES", 12, "negates">;
+def GPU_ElementwiseOpEXTF : I32EnumAttrCase<"EXTF", 13, "extf">;
def MMAElementWise : I32EnumAttr<"MMAElementwiseOp",
"elementwise operation to apply to mma matrix", [
@@ -1487,7 +1488,8 @@ def MMAElementWise : I32EnumAttr<"MMAElementwiseOp",
GPU_ElementwiseOpDivS,
GPU_ElementwiseOpDivU,
GPU_ElementwiseOpNEGF,
- GPU_ElementwiseOpNEGS
+ GPU_ElementwiseOpNEGS,
+ GPU_ElementwiseOpEXTF
]> {
let genSpecializedAttr = 0;
let cppNamespace = "::mlir::gpu";
diff --git a/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp
index d64fa6ac4ece22..57e21530b9da76 100644
--- a/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp
+++ b/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp
@@ -61,6 +61,9 @@ static bool createElementwiseOp(ConversionPatternRewriter &builder,
case gpu::MMAElementwiseOp::NEGATES:
builder.replaceOpWithNewOp<spirv::SNegateOp>(op, coopType, operands);
return true;
+ case gpu::MMAElementwiseOp::EXTF:
+ builder.replaceOpWithNewOp<spirv::FConvertOp>(op, coopType, operands);
+ return true;
default:
break;
}
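
For the SPIR-V path, the new EXTF case reuses spirv.FConvert, which
applies element-wise to cooperative matrix types. Adapted from the test
updated below, the lowering produces (%in and %ext are placeholder
names, not from the commit):

    // gpu.subgroup_mma_elementwise extf on an f16 matrix lowers to:
    %ext = spirv.FConvert %in : !spirv.NV.coopmatrix<16x16xf16, Subgroup> to !spirv.NV.coopmatrix<16x16xf32, Subgroup>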
diff --git a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
index 409e9365a9f207..2c18de1c5b662e 100644
--- a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
+++ b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
@@ -214,6 +214,8 @@ static bool integerExtendSupportsMMAMatrixType(ExtOpTy extOp) {
});
}
+static bool fpExtendSupportsMMAMatrixType(arith::ExtFOp extOp) { return true; }
+
/// Return the MMA elementwise enum associated with `op` if it is supported.
/// Return `std::nullopt` otherwise.
static std::optional<gpu::MMAElementwiseOp>
@@ -242,6 +244,8 @@ convertElementwiseOpToMMA(Operation *op) {
return gpu::MMAElementwiseOp::DIVU;
if (isa<arith::NegFOp>(op))
return gpu::MMAElementwiseOp::NEGATEF;
+ if (isa<arith::ExtFOp>(op))
+ return gpu::MMAElementwiseOp::EXTF;
return std::nullopt;
}
@@ -297,6 +301,8 @@ static bool supportsMMaMatrixType(Operation *op, bool useNvGpu) {
return integerExtendSupportsMMAMatrixType<arith::ExtSIOp>(signedExtend);
if (auto unsignedExtend = dyn_cast<arith::ExtUIOp>(op))
return integerExtendSupportsMMAMatrixType<arith::ExtUIOp>(unsignedExtend);
+ if (auto fpExtend = dyn_cast<arith::ExtFOp>(op))
+ return fpExtendSupportsMMAMatrixType(fpExtend);
return elementwiseSupportsMMAMatrixType(op);
}
@@ -1203,8 +1209,17 @@ convertElementwiseOp(RewriterBase &rewriter, Operation *op,
return rewriter.notifyMatchFailure(op, "no mapping");
matrixOperands.push_back(it->second);
}
+ auto resultType = matrixOperands[0].getType().cast<gpu::MMAMatrixType>();
+ if (opType == gpu::MMAElementwiseOp::EXTF) {
+ // The floating point extension case has a different result type.
+ auto vectorType = op->getResultTypes()[0].cast<VectorType>();
+ resultType = gpu::MMAMatrixType::get(resultType.getShape(),
+ vectorType.getElementType(),
+ resultType.getOperand());
+ }
+
Value newOp = rewriter.create<gpu::SubgroupMmaElementwiseOp>(
- op->getLoc(), matrixOperands[0].getType(), matrixOperands, opType);
+ op->getLoc(), resultType, matrixOperands, opType);
valueMapping[op->getResult(0)] = newOp;
return success();
}
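
The result-type adjustment above is what distinguishes EXTF from the
other elementwise cases: those all preserve the operand type, while
extf widens the element type, so the result !gpu.mma_matrix is rebuilt
from the vector result's element type while keeping the shape and
operand role. At the IR level the conversion pairs up as follows
(value names are illustrative; types match the test added below):

    // Vector level, before conversion:
    %cast = arith.extf %D : vector<16x16xf16> to vector<16x16xf32>
    // MMA level, after conversion; note the widened result element type:
    %ext = gpu.subgroup_mma_elementwise extf %d : (!gpu.mma_matrix<16x16xf16, "COp">) -> !gpu.mma_matrix<16x16xf32, "COp">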
diff --git a/mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv.mlir b/mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv.mlir
index 12b6a2eb94268c..a53eca65fc9869 100644
--- a/mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv.mlir
+++ b/mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv.mlir
@@ -142,6 +142,8 @@ module attributes {
%D = gpu.subgroup_mma_elementwise negatef %C : (!gpu.mma_matrix<16x16xf16, "COp">) -> !gpu.mma_matrix<16x16xf16, "COp">
// CHECK: {{%.*}} = spirv.FDiv {{%.*}}, {{%.*}} : !spirv.NV.coopmatrix<16x16xf16, Subgroup>
%E = gpu.subgroup_mma_elementwise divf %D, %A : (!gpu.mma_matrix<16x16xf16, "COp">, !gpu.mma_matrix<16x16xf16, "COp">) -> !gpu.mma_matrix<16x16xf16, "COp">
+ // CHECK: {{%.*}} = spirv.FConvert {{%.*}} : !spirv.NV.coopmatrix<16x16xf16, Subgroup> to !spirv.NV.coopmatrix<16x16xf32, Subgroup>
+ %F = gpu.subgroup_mma_elementwise extf %E : (!gpu.mma_matrix<16x16xf16, "COp">) -> !gpu.mma_matrix<16x16xf32, "COp">
// CHECK: spirv.Return
gpu.return
}
diff --git a/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir b/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir
index 08f7e12cf55d98..fa9fff2dad6649 100644
--- a/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir
+++ b/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir
@@ -437,4 +437,26 @@ func.func @matmul_mixed_signedness_int8(%arg0: memref<16x32xi8>, %arg1: memref<1
%D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %Ae, %Be, %C : vector<16x32xi32>, vector<16x32xi32> into vector<16x16xi32>
vector.transfer_write %D, %arg2[%c0, %c0] {in_bounds = [true, true]} : vector<16x16xi32>, memref<16x16xi32>
return
-}
\ No newline at end of file
+}
+
+// -----
+
+#map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d1, d2)>
+#map3 = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+// CHECK-LABEL: func @cast_f16_to_f32_write
+// CHECK: %[[COMPUTE:.+]] = gpu.subgroup_mma_compute
+// CHECK: %[[EXT:.+]] = gpu.subgroup_mma_elementwise extf %[[COMPUTE]] : (!gpu.mma_matrix<16x16xf16, "COp">) -> !gpu.mma_matrix<16x16xf32, "COp">
+// CHECK: gpu.subgroup_mma_store_matrix %[[EXT]]
+func.func @cast_f16_to_f32_write(%arg0: memref<16x16xf16>, %arg1: memref<16x16xf16>, %arg2: memref<16x16xf16>, %arg3: memref<16x16xf32>) {
+ %c0 = arith.constant 0 : index
+ %cst = arith.constant 0.000000e+00 : f16
+ %A = vector.transfer_read %arg0[%c0, %c0], %cst {in_bounds = [true, true]} : memref<16x16xf16>, vector<16x16xf16>
+ %B = vector.transfer_read %arg1[%c0, %c0], %cst {in_bounds = [true, true]} : memref<16x16xf16>, vector<16x16xf16>
+ %C = vector.transfer_read %arg2[%c0, %c0], %cst {in_bounds = [true, true]} : memref<16x16xf16>, vector<16x16xf16>
+ %D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %A, %B, %C : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
+ %cast = arith.extf %D : vector<16x16xf16> to vector<16x16xf32>
+ vector.transfer_write %cast, %arg3[%c0, %c0] {in_bounds = [true, true]} : vector<16x16xf32>, memref<16x16xf32>
+ return
+}