[Mlir-commits] [mlir] 428a62f - [mlir][gpu] Add op to create MMA constant matrix
llvmlistbot at llvm.org
Thu Jun 10 08:34:16 PDT 2021
Author: thomasraoux
Date: 2021-06-10T08:34:04-07:00
New Revision: 428a62f65f16f1640b1bfe033d20e6a4f545dd3e
URL: https://github.com/llvm/llvm-project/commit/428a62f65f16f1640b1bfe033d20e6a4f545dd3e
DIFF: https://github.com/llvm/llvm-project/commit/428a62f65f16f1640b1bfe033d20e6a4f545dd3e.diff
LOG: [mlir][gpu] Add op to create MMA constant matrix
This allows creating a matrix with all elements set to a given value, which is
needed to implement a simple dot op.
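To illustrate, here is a hedged sketch of how the new op could seed the
accumulator of a dot product together with the existing
`gpu.subgroup_mma_compute` op (the SSA names and shapes are illustrative, not
taken from this commit):

```mlir
// Hypothetical usage: build a zero accumulator, then compute D = A * B + 0.
%zero = constant 0.0 : f16
%c = gpu.subgroup_mma_constant_matrix %zero : !gpu.mma_matrix<16x16xf16, "COp">
%d = gpu.subgroup_mma_compute %a, %b, %c : !gpu.mma_matrix<16x16xf16, "AOp">,
       !gpu.mma_matrix<16x16xf16, "BOp"> -> !gpu.mma_matrix<16x16xf16, "COp">
```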
Differential Revision: https://reviews.llvm.org/D103870
Added:
Modified:
mlir/include/mlir/Dialect/GPU/GPUOps.td
mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp
mlir/test/Conversion/GPUToNVVM/wmma-ops-to-nvvm.mlir
mlir/test/Dialect/GPU/ops.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/GPU/GPUOps.td b/mlir/include/mlir/Dialect/GPU/GPUOps.td
index 8e2520b675ae6..1e78e4af4d51a 100644
--- a/mlir/include/mlir/Dialect/GPU/GPUOps.td
+++ b/mlir/include/mlir/Dialect/GPU/GPUOps.td
@@ -1022,4 +1022,49 @@ def GPU_SubgroupMmaComputeOp : GPU_Op<"subgroup_mma_compute",
let verifier = [{ return ::verify(*this); }];
}
+def GPU_SubgroupMmaConstantMatrixOp : GPU_Op<"subgroup_mma_constant_matrix",
+ [NoSideEffect,
+ TypesMatchWith<"value type matches element type of mma_matrix",
+ "res", "value",
+ "$_self.cast<gpu::MMAMatrixType>().getElementType()">]>{
+
+ let summary = "GPU warp synchronous constant matrix";
+
+ let description = [{
+ The `gpu.subgroup_mma_constant_matrix` op creates a `!gpu.mma_matrix` with
+ constant elements.
+
+ The operation takes a scalar input and returns a `!gpu.mma_matrix` where
+ each element is equal to the operand constant. The destination mma_matrix
+ type must have an element type equal to the constant type. Since the layout
+ of `!gpu.mma_matrix` is opaque, this op only supports setting all the
+ elements to the same value.
+
+ This op is meant to be used along with `gpu.subgroup_mma_compute`.
+
+ Example:
+
+ ```mlir
+ %0 = gpu.subgroup_mma_constant_matrix %a :
+ !gpu.mma_matrix<16x16xf16, "AOp">
+ %1 = gpu.subgroup_mma_constant_matrix %b :
+ !gpu.mma_matrix<16x16xf32, "COp">
+ ```
+ }];
+
+ let arguments = (ins AnyTypeOf<[F16, F32]>:$value);
+
+ let results = (outs GPU_MMAMatrix:$res);
+
+ let extraClassDeclaration = [{
+ gpu::MMAMatrixType getType() {
+ return res().getType().cast<gpu::MMAMatrixType>();
+ }
+ }];
+
+ let assemblyFormat = [{
+ $value attr-dict `:` type($res)
+ }];
+}
+
#endif // GPU_OPS
diff --git a/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp b/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp
index d72c8c217f86c..d46a185dec22c 100644
--- a/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp
+++ b/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp
@@ -348,12 +348,52 @@ struct WmmaMmaOpToNVVMLowering
}
};
+/// Convert GPU MMA ConstantMatrixOp to a chain of InsertValueOp.
+struct WmmaConstantOpToNVVMLowering
+ : public ConvertOpToLLVMPattern<gpu::SubgroupMmaConstantMatrixOp> {
+ using ConvertOpToLLVMPattern<
+ gpu::SubgroupMmaConstantMatrixOp>::ConvertOpToLLVMPattern;
+
+ LogicalResult
+ matchAndRewrite(gpu::SubgroupMmaConstantMatrixOp subgroupMmaConstantOp,
+ ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ if (failed(areAllLLVMTypes(subgroupMmaConstantOp.getOperation(), operands,
+ rewriter)))
+ return failure();
+ Location loc = subgroupMmaConstantOp.getLoc();
+ Value cst = operands[0];
+ LLVM::LLVMStructType type = convertMMAToLLVMType(
+ subgroupMmaConstantOp.getType().cast<gpu::MMAMatrixType>());
+ // If the element type is a vector create a vector from the operand.
+ if (auto vecType = type.getBody()[0].dyn_cast<VectorType>()) {
+ Value vecCst = rewriter.create<LLVM::UndefOp>(loc, vecType);
+ for (int64_t vecEl = 0; vecEl < vecType.getNumElements(); vecEl++) {
+ Value idx = rewriter.create<LLVM::ConstantOp>(
+ loc, typeConverter->convertType(rewriter.getIntegerType(32)),
+ rewriter.getI32ArrayAttr(vecEl));
+ vecCst = rewriter.create<LLVM::InsertElementOp>(loc, vecType, vecCst,
+ cst, idx);
+ }
+ cst = vecCst;
+ }
+ Value matrixStruct = rewriter.create<LLVM::UndefOp>(loc, type);
+ for (size_t i : llvm::seq(size_t(0), type.getBody().size())) {
+ matrixStruct = rewriter.create<LLVM::InsertValueOp>(
+ loc, matrixStruct, cst, rewriter.getI32ArrayAttr(i));
+ }
+ rewriter.replaceOp(subgroupMmaConstantOp, matrixStruct);
+ return success();
+ }
+};
+
} // anonymous namespace
namespace mlir {
void populateGpuWMMAToNVVMConversionPatterns(LLVMTypeConverter &converter,
RewritePatternSet &patterns) {
patterns.insert<WmmaLoadOpToNVVMLowering, WmmaMmaOpToNVVMLowering,
- WmmaStoreOpToNVVMLowering>(converter);
+ WmmaStoreOpToNVVMLowering, WmmaConstantOpToNVVMLowering>(
+ converter);
}
} // namespace mlir
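For context, a minimal sketch of how a lowering pass might consume this
populate function, assuming the usual GPU-to-NVVM conversion setup (the pass
body below is illustrative and not part of this commit):

```cpp
// Hypothetical pass body: register the WMMA patterns (including the new
// constant-matrix lowering), then run a partial conversion. Target setup is
// reduced to the essentials; a real pass would configure legality in detail.
void runOnOperation() {
  MLIRContext *ctx = &getContext();
  LLVMTypeConverter converter(ctx);
  RewritePatternSet patterns(ctx);
  populateGpuWMMAToNVVMConversionPatterns(converter, patterns);
  LLVMConversionTarget target(*ctx);
  if (failed(applyPartialConversion(getOperation(), target,
                                    std::move(patterns))))
    signalPassFailure();
}
```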
diff --git a/mlir/test/Conversion/GPUToNVVM/wmma-ops-to-nvvm.mlir b/mlir/test/Conversion/GPUToNVVM/wmma-ops-to-nvvm.mlir
index de5d0d3fcf1c0..f692dffdfcbad 100644
--- a/mlir/test/Conversion/GPUToNVVM/wmma-ops-to-nvvm.mlir
+++ b/mlir/test/Conversion/GPUToNVVM/wmma-ops-to-nvvm.mlir
@@ -151,3 +151,28 @@ gpu.module @test_module {
return
}
}
+
+
+// -----
+
+gpu.module @test_module {
+
+// CHECK-LABEL: func @gpu_wmma_constant_op
+// CHECK: %[[CST:.+]] = llvm.mlir.constant(1.000000e+00 : f16) : f16
+// CHECK: %[[V0:.+]] = llvm.mlir.undef : vector<2xf16>
+// CHECK: %[[C0:.+]] = llvm.mlir.constant([0 : i32]) : i32
+// CHECK: %[[V1:.+]] = llvm.insertelement %[[CST]], %[[V0]][%[[C0]] : i32] : vector<2xf16>
+// CHECK: %[[C1:.+]] = llvm.mlir.constant([1 : i32]) : i32
+// CHECK: %[[V2:.+]] = llvm.insertelement %[[CST]], %[[V1]][%[[C1]] : i32] : vector<2xf16>
+// CHECK: %[[M0:.+]] = llvm.mlir.undef : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+// CHECK: %[[M1:.+]] = llvm.insertvalue %[[V2]], %[[M0]][0 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+// CHECK: %[[M2:.+]] = llvm.insertvalue %[[V2]], %[[M1]][1 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+// CHECK: %[[M3:.+]] = llvm.insertvalue %[[V2]], %[[M2]][2 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+// CHECK: %[[M4:.+]] = llvm.insertvalue %[[V2]], %[[M3]][3 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+// CHECK: llvm.return %[[M4]] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+ func @gpu_wmma_constant_op() ->(!gpu.mma_matrix<16x16xf16, "COp">) {
+ %cst = constant 1.0 : f16
+ %C = gpu.subgroup_mma_constant_matrix %cst : !gpu.mma_matrix<16x16xf16, "COp">
+ return %C : !gpu.mma_matrix<16x16xf16, "COp">
+ }
+}
diff --git a/mlir/test/Dialect/GPU/ops.mlir b/mlir/test/Dialect/GPU/ops.mlir
index a98fe1c496838..1bed13c4b21a4 100644
--- a/mlir/test/Dialect/GPU/ops.mlir
+++ b/mlir/test/Dialect/GPU/ops.mlir
@@ -201,8 +201,12 @@ module attributes {gpu.container_module} {
// CHECK: %[[wg:.*]] = memref.alloca()
%i = constant 16 : index
// CHECK: %[[i:.*]] = constant 16 : index
+ %cst = constant 1.000000e+00 : f32
+ // CHECK: %[[cst:.*]] = constant 1.000000e+00 : f32
%0 = gpu.subgroup_mma_load_matrix %wg[%i, %i] {leadDimension = 32 : index} : memref<32x32xf16, 3> -> !gpu.mma_matrix<16x16xf16, "AOp">
// CHECK: gpu.subgroup_mma_load_matrix %[[wg]][%[[i]], %[[i]]] {leadDimension = 32 : index} : memref<32x32xf16, 3> -> !gpu.mma_matrix<16x16xf16, "AOp">
+ %1 = gpu.subgroup_mma_constant_matrix %cst : !gpu.mma_matrix<16x16xf32, "COp">
+ // CHECK: gpu.subgroup_mma_constant_matrix %[[cst]] : !gpu.mma_matrix<16x16xf32, "COp">
return
}
}