[Mlir-commits] [mlir] f309939 - [mlir][nvvm] Remove special case ptr arithmetic lowering in gpu to nvvm
llvmlistbot at llvm.org
Wed Nov 10 10:00:31 PST 2021
Author: thomasraoux
Date: 2021-11-10T10:00:12-08:00
New Revision: f309939d065a67cb2379059d82cb8a76d8b74e3c
URL: https://github.com/llvm/llvm-project/commit/f309939d065a67cb2379059d82cb8a76d8b74e3c
DIFF: https://github.com/llvm/llvm-project/commit/f309939d065a67cb2379059d82cb8a76d8b74e3c.diff
LOG: [mlir][nvvm] Remove special case ptr arithmetic lowering in gpu to nvvm
Use the existing getStridedElementPtr helper instead of open-coding the index
arithmetic, which only handled a subset of cases. Also relax the memref rank
restriction on the GPU MMA ops, since any rank can now be supported.
Differential Revision: https://reviews.llvm.org/D113383
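For illustration, a hypothetical sketch of the relaxed form (the operand names
and memref shape below are made up, not taken from the patch or its tests):
with the rank restriction lifted from [2] to any rank, a load such as the
following is accepted, and its address is computed via getStridedElementPtr
rather than the removed leadDimension * i + j special case:

    %frag = gpu.subgroup_mma_load_matrix %src[%b, %i, %j] {leadDimension = 32 : index}
        : memref<4x32x32xf16, 3> -> !gpu.mma_matrix<16x16xf16, "AOp">

The old lowering folded only the first two indices into the address;
getStridedElementPtr folds every index using the memref's strides and offset.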
Added:
Modified:
mlir/include/mlir/Dialect/GPU/GPUOps.td
mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp
mlir/test/Conversion/GPUToNVVM/wmma-ops-to-nvvm.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/GPU/GPUOps.td b/mlir/include/mlir/Dialect/GPU/GPUOps.td
index 3f1ad84278cb0..5e4d122c69eaa 100644
--- a/mlir/include/mlir/Dialect/GPU/GPUOps.td
+++ b/mlir/include/mlir/Dialect/GPU/GPUOps.td
@@ -991,7 +991,7 @@ def GPU_SubgroupMmaLoadMatrixOp : GPU_Op<"subgroup_mma_load_matrix",
```
}];
- let arguments = (ins Arg<MemRefRankOf<[F16, F32], [2]>, "", [MemRead]>:$srcMemref,
+ let arguments = (ins Arg<MemRefOf<[F16, F32]>, "", [MemRead]>:$srcMemref,
Variadic<Index>:$indices,
IndexAttr:$leadDimension);
@@ -1031,7 +1031,7 @@ def GPU_SubgroupMmaStoreMatrixOp : GPU_Op<"subgroup_mma_store_matrix",
}];
let arguments = (ins Arg<MMAMatrixOf<[F16, F32]>>:$src,
- Arg<MemRefRankOf<[F16, F32], [2]>, "",[MemWrite]>:$dstMemref,
+ Arg<MemRefOf<[F16, F32]>, "",[MemWrite]>:$dstMemref,
Variadic<Index>:$indices,
IndexAttr:$leadDimension);
diff --git a/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp b/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp
index b0bf94b7f8066..6de739088b896 100644
--- a/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp
+++ b/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp
@@ -77,44 +77,6 @@ struct WmmaLoadOpToNVVMLowering
if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)))
return failure();
- Location loc = op->getLoc();
-
- // MemRefDescriptor to extract alignedPtr and offset.
- MemRefDescriptor promotedSrcOp(adaptor.srcMemref());
-
- // Emit ops which compute the load offset using `srcOffsetI`,
- // `srcOffsetJ`. The actualOffset is (memrefOffset + (alignedPtr +
- // ((leadDimension * srcOffsetI) + srcOffsetJ)). The memrefs here are
- // assumed to be normalized and hence the simple conversion works.
- IntegerAttr leadDimension = subgroupMmaLoadMatrixOp.leadDimensionAttr();
- SmallVector<Value> indices(adaptor.indices());
- Value srcOffsetIVal = indices[0];
- Value srcOffsetJVal = indices[1];
- Value leadingDim = rewriter.create<LLVM::ConstantOp>(
- loc, srcOffsetIVal.getType(), leadDimension);
- Value numElemsLeadDim =
- rewriter.create<LLVM::MulOp>(loc, leadingDim, srcOffsetIVal);
- Value loadOffset =
- rewriter.create<LLVM::AddOp>(loc, numElemsLeadDim, srcOffsetJVal);
-
- Value promotedSrcOpToUse;
- promotedSrcOpToUse = promotedSrcOp.offset(rewriter, loc);
- Value actualOffset =
- rewriter.create<LLVM::AddOp>(loc, loadOffset, promotedSrcOpToUse);
- Value loadAddress = rewriter.create<LLVM::GEPOp>(
- loc, promotedSrcOp.getElementPtrType(),
- promotedSrcOp.alignedPtr(rewriter, loc), ArrayRef<Value>{actualOffset});
-
- // Bitcast the base address pointer of the destination memref, So that
- // values can be stored in chunks of 32-bits and semantics match with the
- // intrinsic exposed by NVPTX backend.
- Value loadAddressCasted = rewriter.create<LLVM::BitcastOp>(
- loc,
- LLVM::LLVMPointerType::get(
- rewriter.getI32Type(),
- promotedSrcOp.getElementPtrType().getAddressSpace()),
- loadAddress);
-
// Get the shape of the MMAMatrix type being returned. The shape will
// choose which intrinsic this op will be lowered to.
gpu::MMAMatrixType retType =
@@ -146,15 +108,18 @@ struct WmmaLoadOpToNVVMLowering
return rewriter.notifyMatchFailure(op, kInvalidCaseStr);
Type resType = convertMMAToLLVMType(retType);
+ Location loc = op->getLoc();
// Create nvvm.mma_load op according to the operand types.
- Value leadingDim32 = rewriter.create<LLVM::ConstantOp>(
- loc, rewriter.getI32Type(), leadDimension);
+ Value dataPtr = getStridedElementPtr(
+ loc, subgroupMmaLoadMatrixOp.srcMemref().getType().cast<MemRefType>(),
+ adaptor.srcMemref(), adaptor.indices(), rewriter);
+ Value leadingDim = rewriter.create<LLVM::ConstantOp>(
+ loc, rewriter.getI32Type(),
+ subgroupMmaLoadMatrixOp.leadDimensionAttr());
rewriter.replaceOpWithNewOp<NVVM::WMMALoadOp>(
- op, resType, loadAddressCasted, leadingDim32, m, n, k, layout, eltype,
- frag);
-
+ op, resType, dataPtr, leadingDim, m, n, k, layout, eltype, frag);
return success();
}
};
@@ -178,41 +143,6 @@ struct WmmaStoreOpToNVVMLowering
Location loc = op->getLoc();
- // MemRefDescriptor to extract alignedPtr and offset.
- MemRefDescriptor promotedDstOp(adaptor.dstMemref());
-
- // Emit ops which compute the store offset using `dstOffsetI`,
- // `dstOffsetJ`. The actualOffset is (memrefOffset + (alignedPtr +
- // ((leadDimension * dstOffsetI) + dstOffsetJ)).
- auto leadDimension = subgroupMmaStoreMatrixOp.leadDimensionAttr();
- SmallVector<Value> indices(adaptor.indices());
- Value dstOffsetIVal = indices[0];
- Value dstOffsetJVal = indices[1];
- Value leadingDim = rewriter.create<LLVM::ConstantOp>(
- loc, dstOffsetIVal.getType(), leadDimension);
- Value numElemsLeadDim =
- rewriter.create<LLVM::MulOp>(loc, leadingDim, dstOffsetIVal);
- Value loadOffset =
- rewriter.create<LLVM::AddOp>(loc, numElemsLeadDim, dstOffsetJVal);
-
- Value promotedDstOpToUse;
- promotedDstOpToUse = promotedDstOp.offset(rewriter, loc);
- Value actualOffset =
- rewriter.create<LLVM::AddOp>(loc, loadOffset, promotedDstOpToUse);
- Value storeAddress = rewriter.create<LLVM::GEPOp>(
- loc, promotedDstOp.getElementPtrType(),
- promotedDstOp.alignedPtr(rewriter, loc), ArrayRef<Value>{actualOffset});
-
- // Bitcast the base address pointer of the destination memref, So that
- // values can be stored in chunks of 32-bits and semantics match with the
- // intrinsic exposed by NVPTX backend.
- Value storeAddressCasted = rewriter.create<LLVM::BitcastOp>(
- loc,
- LLVM::LLVMPointerType::get(
- rewriter.getI32Type(),
- promotedDstOp.getElementPtrType().getAddressSpace()),
- storeAddress);
-
SmallVector<Value, 4> storeOpOperands;
// Get the shape of the MMAMatrix type being stored. The shape will
// choose which intrinsic this op will be lowered to.
@@ -234,12 +164,15 @@ struct WmmaStoreOpToNVVMLowering
rewriter.getI32ArrayAttr(i));
storeOpOperands.push_back(toUse);
}
- Value leadingDim32 = rewriter.create<LLVM::ConstantOp>(
- loc, rewriter.getI32Type(), leadDimension);
- rewriter.create<NVVM::WMMAStoreOp>(loc, storeAddressCasted, m, n, k, layout,
- eltype, storeOpOperands, leadingDim32);
- rewriter.eraseOp(op);
+ Value dataPtr = getStridedElementPtr(
+ loc, subgroupMmaStoreMatrixOp.dstMemref().getType().cast<MemRefType>(),
+ adaptor.dstMemref(), adaptor.indices(), rewriter);
+ Value leadingDim = rewriter.create<LLVM::ConstantOp>(
+ loc, rewriter.getI32Type(),
+ subgroupMmaStoreMatrixOp.leadDimensionAttr());
+ rewriter.replaceOpWithNewOp<NVVM::WMMAStoreOp>(
+ op, dataPtr, m, n, k, layout, eltype, storeOpOperands, leadingDim);
return success();
}
};
diff --git a/mlir/test/Conversion/GPUToNVVM/wmma-ops-to-nvvm.mlir b/mlir/test/Conversion/GPUToNVVM/wmma-ops-to-nvvm.mlir
index c0ac8a050288f..f322fcac67647 100644
--- a/mlir/test/Conversion/GPUToNVVM/wmma-ops-to-nvvm.mlir
+++ b/mlir/test/Conversion/GPUToNVVM/wmma-ops-to-nvvm.mlir
@@ -13,32 +13,26 @@ gpu.module @test_module {
%0 = gpu.subgroup_mma_load_matrix %wg[%i, %j] {leadDimension = 32 : index} : memref<32x32xf16, 3> -> !gpu.mma_matrix<16x16xf16, "AOp">
// CHECK: %[[INX:.*]] = llvm.mlir.constant(16 : index) : i64
// CHECK: %{{.*}} = llvm.insertvalue %{{.*}}, %{{.*}}[{{.*}}, {{.*}}]
+ // CHECK: %[[BASE:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr<f16, 3>, ptr<f16, 3>, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: %[[LDM:.*]] = llvm.mlir.constant(32 : index) : i64
- // CHECK: %[[LI:.*]] = llvm.mul %[[LDM]], %[[INX]] : i64
+ // CHECK: %[[LI:.*]] = llvm.mul %[[INX]], %[[LDM]] : i64
// CHECK: %[[LIJ:.*]] = llvm.add %[[LI]], %[[INX]] : i64
- // CHECK: %[[OFFSET:.*]] = llvm.extractvalue %{{.*}}[2] : !llvm.struct<(ptr<f16, 3>, ptr<f16, 3>, i64, array<2 x i64>, array<2 x i64>)>
- // CHECK: %[[LIJO:.*]] = llvm.add %[[LIJ]], %[[OFFSET]] : i64
- // CHECK: %[[BASE:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr<f16, 3>, ptr<f16, 3>, i64, array<2 x i64>, array<2 x i64>)>
- // CHECK: %[[ADDRESS:.*]] = llvm.getelementptr %[[BASE]][%[[LIJO]]] : (!llvm.ptr<f16, 3>, i64) -> !llvm.ptr<f16, 3>
- // CHECK: %[[CADDRESS:.*]] = llvm.bitcast %[[ADDRESS]] : !llvm.ptr<f16, 3> to !llvm.ptr<i32, 3>
+ // CHECK: %[[ADDRESS:.*]] = llvm.getelementptr %[[BASE]][%[[LIJ]]] : (!llvm.ptr<f16, 3>, i64) -> !llvm.ptr<f16, 3>
// CHECK: %[[LDM32:.*]] = llvm.mlir.constant(32 : index) : i32
- // CHECK: %[[FRAG:.*]] = nvvm.wmma.load %[[CADDRESS]], %[[LDM32]]
- // CHECK-SAME: {eltype = "f16", frag = "a", k = 16 : i32, layout = "row", m = 16 : i32, n = 16 : i32} : (!llvm.ptr<i32, 3>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+ // CHECK: %[[FRAG:.*]] = nvvm.wmma.load %[[ADDRESS]], %[[LDM32]]
+ // CHECK-SAME: {eltype = "f16", frag = "a", k = 16 : i32, layout = "row", m = 16 : i32, n = 16 : i32} : (!llvm.ptr<f16, 3>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
// CHECK: llvm.return %[[FRAG]] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
// CHECK32: %[[INX:.*]] = llvm.mlir.constant(16 : index) : i32
// CHECK32: %{{.*}} = llvm.insertvalue %{{.*}}, %{{.*}}[{{.*}}, {{.*}}]
+ // CHECK32: %[[BASE:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr<f16, 3>, ptr<f16, 3>, i32, array<2 x i32>, array<2 x i32>)>
// CHECK32: %[[LDM:.*]] = llvm.mlir.constant(32 : index) : i32
- // CHECK32: %[[LI:.*]] = llvm.mul %[[LDM]], %[[INX]] : i32
+ // CHECK32: %[[LI:.*]] = llvm.mul %[[INX]], %[[LDM]] : i32
// CHECK32: %[[LIJ:.*]] = llvm.add %[[LI]], %[[INX]] : i32
- // CHECK32: %[[OFFSET:.*]] = llvm.extractvalue %{{.*}}[2] : !llvm.struct<(ptr<f16, 3>, ptr<f16, 3>, i32, array<2 x i32>, array<2 x i32>)>
- // CHECK32: %[[LIJO:.*]] = llvm.add %[[LIJ]], %[[OFFSET]] : i32
- // CHECK32: %[[BASE:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr<f16, 3>, ptr<f16, 3>, i32, array<2 x i32>, array<2 x i32>)>
- // CHECK32: %[[ADDRESS:.*]] = llvm.getelementptr %[[BASE]][%[[LIJO]]] : (!llvm.ptr<f16, 3>, i32) -> !llvm.ptr<f16, 3>
- // CHECK32: %[[CADDRESS:.*]] = llvm.bitcast %[[ADDRESS]] : !llvm.ptr<f16, 3> to !llvm.ptr<i32, 3>
+ // CHECK32: %[[ADDRESS:.*]] = llvm.getelementptr %[[BASE]][%[[LIJ]]] : (!llvm.ptr<f16, 3>, i32) -> !llvm.ptr<f16, 3>
// CHECK32: %[[LDM32:.*]] = llvm.mlir.constant(32 : index) : i32
- // CHECK32: %[[FRAG:.*]] = nvvm.wmma.load %[[CADDRESS]], %[[LDM32]]
- // CHECK32-SAME: {eltype = "f16", frag = "a", k = 16 : i32, layout = "row", m = 16 : i32, n = 16 : i32} : (!llvm.ptr<i32, 3>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+ // CHECK32: %[[FRAG:.*]] = nvvm.wmma.load %[[ADDRESS]], %[[LDM32]]
+ // CHECK32-SAME: {eltype = "f16", frag = "a", k = 16 : i32, layout = "row", m = 16 : i32, n = 16 : i32} : (!llvm.ptr<f16, 3>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
// CHECK32: llvm.return %[[FRAG]] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
return %0 : !gpu.mma_matrix<16x16xf16, "AOp">
}
@@ -59,40 +53,34 @@ gpu.module @test_module {
gpu.subgroup_mma_store_matrix %arg0, %sg[%i,%j] {leadDimension= 32 : index} : !gpu.mma_matrix<16x16xf16, "COp">, memref<32x32xf16, 3>
// CHECK: %[[INX:.*]] = llvm.mlir.constant(16 : index) : i64
// CHECK: %{{.*}} = llvm.insertvalue %{{.*}}, %{{.*}}[{{.*}}, {{.*}}]
- // CHECK: %[[LDM:.*]] = llvm.mlir.constant(32 : index) : i64
- // CHECK: %[[LI:.*]] = llvm.mul %[[LDM]], %[[INX]] : i64
- // CHECK: %[[LIJ:.*]] = llvm.add %[[LI]], %[[INX]] : i64
- // CHECK: %[[OFFSET:.*]] = llvm.extractvalue %17[2] : !llvm.struct<(ptr<f16, 3>, ptr<f16, 3>, i64, array<2 x i64>, array<2 x i64>)>
- // CHECK: %[[LIJO:.*]] = llvm.add %[[LIJ]], %[[OFFSET]] : i64
- // CHECK: %[[BASE:.*]] = llvm.extractvalue %17[1] : !llvm.struct<(ptr<f16, 3>, ptr<f16, 3>, i64, array<2 x i64>, array<2 x i64>)>
- // CHECK: %[[ADDRESS:.*]] = llvm.getelementptr %[[BASE]][%[[LIJO]]] : (!llvm.ptr<f16, 3>, i64) -> !llvm.ptr<f16, 3>
- // CHECK: %[[CADDRESS:.*]] = llvm.bitcast %[[ADDRESS]] : !llvm.ptr<f16, 3> to !llvm.ptr<i32, 3>
// CHECK: %[[EL1:.*]] = llvm.extractvalue %[[D]][0 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
// CHECK: %[[EL2:.*]] = llvm.extractvalue %[[D]][1 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
// CHECK: %[[EL3:.*]] = llvm.extractvalue %[[D]][2 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
// CHECK: %[[EL4:.*]] = llvm.extractvalue %[[D]][3 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+ // CHECK: %[[BASE:.*]] = llvm.extractvalue %17[1] : !llvm.struct<(ptr<f16, 3>, ptr<f16, 3>, i64, array<2 x i64>, array<2 x i64>)>
+ // CHECK: %[[LDM:.*]] = llvm.mlir.constant(32 : index) : i64
+ // CHECK: %[[LI:.*]] = llvm.mul %[[INX]], %[[LDM]] : i64
+ // CHECK: %[[LIJ:.*]] = llvm.add %[[LI]], %[[INX]] : i64
+ // CHECK: %[[ADDRESS:.*]] = llvm.getelementptr %[[BASE]][%[[LIJ]]] : (!llvm.ptr<f16, 3>, i64) -> !llvm.ptr<f16, 3>
// CHECK: %[[LDM32:.*]] = llvm.mlir.constant(32 : index) : i32
- // CHECK: nvvm.wmma.store %[[CADDRESS]], %[[LDM32]], %[[EL1]], %[[EL2]], %[[EL3]], %[[EL4]]
- // CHECK-SAME: {eltype = "f16", k = 16 : i32, layout = "row", m = 16 : i32, n = 16 : i32} : !llvm.ptr<i32, 3>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>
+ // CHECK: nvvm.wmma.store %[[ADDRESS]], %[[LDM32]], %[[EL1]], %[[EL2]], %[[EL3]], %[[EL4]]
+ // CHECK-SAME: {eltype = "f16", k = 16 : i32, layout = "row", m = 16 : i32, n = 16 : i32} : !llvm.ptr<f16, 3>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>
// CHECK: llvm.return
// CHECK32: %[[INX:.*]] = llvm.mlir.constant(16 : index) : i32
// CHECK32: %{{.*}} = llvm.insertvalue %{{.*}}, %{{.*}}[{{.*}}, {{.*}}]
- // CHECK32: %[[LDM:.*]] = llvm.mlir.constant(32 : index) : i32
- // CHECK32: %[[LI:.*]] = llvm.mul %[[LDM]], %[[INX]] : i32
- // CHECK32: %[[LIJ:.*]] = llvm.add %[[LI]], %[[INX]] : i32
- // CHECK32: %[[OFFSET:.*]] = llvm.extractvalue %17[2] : !llvm.struct<(ptr<f16, 3>, ptr<f16, 3>, i32, array<2 x i32>, array<2 x i32>)>
- // CHECK32: %[[LIJO:.*]] = llvm.add %[[LIJ]], %[[OFFSET]] : i32
- // CHECK32: %[[BASE:.*]] = llvm.extractvalue %17[1] : !llvm.struct<(ptr<f16, 3>, ptr<f16, 3>, i32, array<2 x i32>, array<2 x i32>)>
- // CHECK32: %[[ADDRESS:.*]] = llvm.getelementptr %[[BASE]][%[[LIJO]]] : (!llvm.ptr<f16, 3>, i32) -> !llvm.ptr<f16, 3>
- // CHECK32: %[[CADDRESS:.*]] = llvm.bitcast %[[ADDRESS]] : !llvm.ptr<f16, 3> to !llvm.ptr<i32, 3>
// CHECK32: %[[EL1:.*]] = llvm.extractvalue %[[D]][0 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
// CHECK32: %[[EL2:.*]] = llvm.extractvalue %[[D]][1 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
// CHECK32: %[[EL3:.*]] = llvm.extractvalue %[[D]][2 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
// CHECK32: %[[EL4:.*]] = llvm.extractvalue %[[D]][3 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+ // CHECK32: %[[BASE:.*]] = llvm.extractvalue %17[1] : !llvm.struct<(ptr<f16, 3>, ptr<f16, 3>, i32, array<2 x i32>, array<2 x i32>)>
+ // CHECK32: %[[LDM:.*]] = llvm.mlir.constant(32 : index) : i32
+ // CHECK32: %[[LI:.*]] = llvm.mul %[[INX]], %[[LDM]] : i32
+ // CHECK32: %[[LIJ:.*]] = llvm.add %[[LI]], %[[INX]] : i32
+ // CHECK32: %[[ADDRESS:.*]] = llvm.getelementptr %[[BASE]][%[[LIJ]]] : (!llvm.ptr<f16, 3>, i32) -> !llvm.ptr<f16, 3>
// CHECK32: %[[LDM32:.*]] = llvm.mlir.constant(32 : index) : i32
- // CHECK32: nvvm.wmma.store %[[CADDRESS]], %[[LDM32]], %[[EL1]], %[[EL2]], %[[EL3]], %[[EL4]]
- // CHECK32-SAME: {eltype = "f16", k = 16 : i32, layout = "row", m = 16 : i32, n = 16 : i32} : !llvm.ptr<i32, 3>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>
+ // CHECK32: nvvm.wmma.store %[[ADDRESS]], %[[LDM32]], %[[EL1]], %[[EL2]], %[[EL3]], %[[EL4]]
+ // CHECK32-SAME: {eltype = "f16", k = 16 : i32, layout = "row", m = 16 : i32, n = 16 : i32} : !llvm.ptr<f16, 3>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>
// CHECK32: llvm.return
return
}
@@ -139,13 +127,13 @@ gpu.module @test_module {
gpu.module @test_module {
// CHECK-LABEL: func @gpu_wmma_mma_loop_op
-// CHECK: %[[C:.+]] = nvvm.wmma.load %{{.*}}, %{{.*}} {eltype = "f16", frag = "c", k = 16 : i32, layout = "row", m = 16 : i32, n = 16 : i32} : (!llvm.ptr<i32>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+// CHECK: %[[C:.+]] = nvvm.wmma.load %{{.*}}, %{{.*}} {eltype = "f16", frag = "c", k = 16 : i32, layout = "row", m = 16 : i32, n = 16 : i32} : (!llvm.ptr<f16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
// CHECK: llvm.br ^bb1(%{{.*}}, %[[C]] : i64, !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>)
// CHECK: ^bb1(%{{.*}}: i64, %[[ACC:.+]]: !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>): // 2 preds: ^bb0, ^bb2
// CHECK: llvm.cond_br %{{.*}}, ^bb2, ^bb3
// CHECK: ^bb2: // pred: ^bb1
-// CHECK: %[[A:.+]] = nvvm.wmma.load %{{.*}}, %{{.*}} {eltype = "f16", frag = "a", k = 16 : i32, layout = "row", m = 16 : i32, n = 16 : i32} : (!llvm.ptr<i32>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
-// CHECK: %[[B:.+]] = nvvm.wmma.load %{{.*}}, %{{.*}} {eltype = "f16", frag = "b", k = 16 : i32, layout = "row", m = 16 : i32, n = 16 : i32} : (!llvm.ptr<i32>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+// CHECK: %[[A:.+]] = nvvm.wmma.load %{{.*}}, %{{.*}} {eltype = "f16", frag = "a", k = 16 : i32, layout = "row", m = 16 : i32, n = 16 : i32} : (!llvm.ptr<f16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+// CHECK: %[[B:.+]] = nvvm.wmma.load %{{.*}}, %{{.*}} {eltype = "f16", frag = "b", k = 16 : i32, layout = "row", m = 16 : i32, n = 16 : i32} : (!llvm.ptr<f16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
// CHECK: %[[A0:.+]] = llvm.extractvalue %[[A]][0 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
// CHECK: %[[A1:.+]] = llvm.extractvalue %[[A]][1 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
// CHECK: %[[A2:.+]] = llvm.extractvalue %[[A]][2 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
@@ -173,7 +161,7 @@ gpu.module @test_module {
// CHECK: %[[E1:.+]] = llvm.extractvalue %[[ACC]][1 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
// CHECK: %[[E2:.+]] = llvm.extractvalue %[[ACC]][2 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
// CHECK: %[[E3:.+]] = llvm.extractvalue %[[ACC]][3 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
-// CHECK: nvvm.wmma.store %{{.*}}, %{{.*}}, %[[E0]], %[[E1]], %[[E2]], %[[E3]] {eltype = "f16", k = 16 : i32, layout = "row", m = 16 : i32, n = 16 : i32} : !llvm.ptr<i32>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>
+// CHECK: nvvm.wmma.store %{{.*}}, %{{.*}}, %[[E0]], %[[E1]], %[[E2]], %[[E3]] {eltype = "f16", k = 16 : i32, layout = "row", m = 16 : i32, n = 16 : i32} : !llvm.ptr<f16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>
builtin.func @gpu_wmma_mma_loop_op(%arg0: memref<128x128xf16>, %arg1: memref<128x128xf16>, %arg2: memref<128x128xf16>) {
%c0 = arith.constant 0 : index