[Mlir-commits] [mlir] eacd6e1 - [mlir][GPUtoNVVM] Relax restriction on wmma op lowering
llvmlistbot at llvm.org
Wed Oct 27 21:32:39 PDT 2021
Author: thomasraoux
Date: 2021-10-27T21:31:55-07:00
New Revision: eacd6e1ebef511b216c8b2805e7e662b9902de74
URL: https://github.com/llvm/llvm-project/commit/eacd6e1ebef511b216c8b2805e7e662b9902de74
DIFF: https://github.com/llvm/llvm-project/commit/eacd6e1ebef511b216c8b2805e7e662b9902de74.diff
LOG: [mlir][GPUtoNVVM] Relax restriction on wmma op lowering
Allow lowering of wmma ops with 64-bit indices. Change the default
version of the test to use the default index bitwidth.
Differential Revision: https://reviews.llvm.org/D112479
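As a rough illustration (adapted from the wmma-ops-to-nvvm.mlir test updated below; the @example module and function names are made up for this note), this is the kind of op that now lowers under the default 64-bit index type: the offset arithmetic stays in i64, and only the leadDimension operand handed to the NVVM intrinsic is materialized separately as an i32 constant.

  gpu.module @example {
    builtin.func @wmma_load_64bit_index() -> !gpu.mma_matrix<16x16xf16, "AOp"> {
      %buf = memref.alloca() {alignment = 32} : memref<32x32xf16, 3>
      %i = arith.constant 16 : index
      // With --convert-gpu-to-nvvm (no index-bitwidth option), the address
      // computation for this load is emitted as llvm.mul/llvm.add on i64;
      // a fresh i32 constant feeds nvvm.wmma.m16n16k16.load.a.f16.row.stride.
      %frag = gpu.subgroup_mma_load_matrix %buf[%i, %i] {leadDimension = 32 : index}
          : memref<32x32xf16, 3> -> !gpu.mma_matrix<16x16xf16, "AOp">
      return %frag : !gpu.mma_matrix<16x16xf16, "AOp">
    }
  }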
Added:
Modified:
mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp
mlir/test/Conversion/GPUToNVVM/wmma-ops-to-nvvm.mlir
mlir/test/Integration/GPU/CUDA/TensorCore/wmma-matmul-f16.mlir
mlir/test/Integration/GPU/CUDA/TensorCore/wmma-matmul-f32.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp b/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp
index 0296390d9c082..681725441539f 100644
--- a/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp
+++ b/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp
@@ -75,19 +75,8 @@ struct WmmaLoadOpToNVVMLowering
if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)))
return failure();
- unsigned indexTypeBitwidth =
- this->getTypeConverter()->getIndexTypeBitwidth();
-
- // The corresponding intrinsics expects leadDimension to be a 32-bit
- // integer, so all the calculations of linearizing the load address
- // must also follow this restriction.
- if (indexTypeBitwidth != 32)
- return rewriter.notifyMatchFailure(
- op, "Expected indices to the memref to be 32-bit wide.");
Location loc = op->getLoc();
- auto leadDimension = subgroupMmaLoadMatrixOp.leadDimensionAttr();
-
// MemRefDescriptor to extract alignedPtr and offset.
MemRefDescriptor promotedSrcOp(adaptor.srcMemref());
@@ -95,21 +84,21 @@ struct WmmaLoadOpToNVVMLowering
// `srcOffsetJ`. The actualOffset is (memrefOffset + (alignedPtr +
// ((leadDimension * srcOffsetI) + srcOffsetJ)). The memrefs here are
// assumed to be normalized and hence the simple conversion works.
+ IntegerAttr leadDimension = subgroupMmaLoadMatrixOp.leadDimensionAttr();
SmallVector<Value> indices(adaptor.indices());
Value srcOffsetIVal = indices[0];
Value srcOffsetJVal = indices[1];
- Type i32Ty = rewriter.getI32Type();
- Value leadingDim32 =
- rewriter.create<LLVM::ConstantOp>(loc, i32Ty, leadDimension);
+ Value leadingDim = rewriter.create<LLVM::ConstantOp>(
+ loc, srcOffsetIVal.getType(), leadDimension);
Value numElemsLeadDim =
- rewriter.create<LLVM::MulOp>(loc, i32Ty, leadingDim32, srcOffsetIVal);
- Value loadOffset = rewriter.create<LLVM::AddOp>(loc, i32Ty, numElemsLeadDim,
- srcOffsetJVal);
+ rewriter.create<LLVM::MulOp>(loc, leadingDim, srcOffsetIVal);
+ Value loadOffset =
+ rewriter.create<LLVM::AddOp>(loc, numElemsLeadDim, srcOffsetJVal);
Value promotedSrcOpToUse;
promotedSrcOpToUse = promotedSrcOp.offset(rewriter, loc);
- Value actualOffset = rewriter.create<LLVM::AddOp>(loc, i32Ty, loadOffset,
- promotedSrcOpToUse);
+ Value actualOffset =
+ rewriter.create<LLVM::AddOp>(loc, loadOffset, promotedSrcOpToUse);
Value loadAddress = rewriter.create<LLVM::GEPOp>(
loc, promotedSrcOp.getElementPtrType(),
promotedSrcOp.alignedPtr(rewriter, loc), ArrayRef<Value>{actualOffset});
@@ -120,7 +109,8 @@ struct WmmaLoadOpToNVVMLowering
Value loadAddressCasted = rewriter.create<LLVM::BitcastOp>(
loc,
LLVM::LLVMPointerType::get(
- i32Ty, promotedSrcOp.getElementPtrType().getAddressSpace()),
+ rewriter.getI32Type(),
+ promotedSrcOp.getElementPtrType().getAddressSpace()),
loadAddress);
// Get the shape of the MMAMatrix type being returned. The shape will
@@ -133,6 +123,8 @@ struct WmmaLoadOpToNVVMLowering
StringRef operandStr = retType.getOperand();
// Create nvvm.mma_load op according to the operand types.
+ Value leadingDim32 = rewriter.create<LLVM::ConstantOp>(
+ loc, rewriter.getI32Type(), leadDimension);
SmallVector<Value, 2> loadOpOperands({loadAddressCasted, leadingDim32});
if (operandStr.equals("AOp")) {
if (retTypeShape[0] == 16 && retTypeShape[1] == 16) {
@@ -182,40 +174,29 @@ struct WmmaStoreOpToNVVMLowering
if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)))
return failure();
- unsigned indexTypeBitwidth =
- this->getTypeConverter()->getIndexTypeBitwidth();
- // The corresponding intrinsics expects leadDimension to be a 32-bit
- // integer, so all the calculations of linearizing the store address
- // must also follow this restriction.
- if (indexTypeBitwidth != 32)
- return rewriter.notifyMatchFailure(
- op, "expected indices to the memref to be 32-bit wide.");
-
Location loc = op->getLoc();
// MemRefDescriptor to extract alignedPtr and offset.
MemRefDescriptor promotedDstOp(adaptor.dstMemref());
- auto leadDimension = subgroupMmaStoreMatrixOp.leadDimensionAttr();
-
// Emit ops which compute the store offset using `dstOffsetI`,
// `dstOffsetJ`. The actualOffset is (memrefOffset + (alignedPtr +
// ((leadDimension * dstOffsetI) + dstOffsetJ)).
+ auto leadDimension = subgroupMmaStoreMatrixOp.leadDimensionAttr();
SmallVector<Value> indices(adaptor.indices());
Value dstOffsetIVal = indices[0];
Value dstOffsetJVal = indices[1];
- Type i32Ty = rewriter.getI32Type();
- Value leadingDim32 =
- rewriter.create<LLVM::ConstantOp>(loc, i32Ty, leadDimension);
+ Value leadingDim = rewriter.create<LLVM::ConstantOp>(
+ loc, dstOffsetIVal.getType(), leadDimension);
Value numElemsLeadDim =
- rewriter.create<LLVM::MulOp>(loc, i32Ty, leadingDim32, dstOffsetIVal);
- Value loadOffset = rewriter.create<LLVM::AddOp>(loc, i32Ty, numElemsLeadDim,
- dstOffsetJVal);
+ rewriter.create<LLVM::MulOp>(loc, leadingDim, dstOffsetIVal);
+ Value loadOffset =
+ rewriter.create<LLVM::AddOp>(loc, numElemsLeadDim, dstOffsetJVal);
Value promotedDstOpToUse;
promotedDstOpToUse = promotedDstOp.offset(rewriter, loc);
- Value actualOffset = rewriter.create<LLVM::AddOp>(loc, i32Ty, loadOffset,
- promotedDstOpToUse);
+ Value actualOffset =
+ rewriter.create<LLVM::AddOp>(loc, loadOffset, promotedDstOpToUse);
Value storeAddress = rewriter.create<LLVM::GEPOp>(
loc, promotedDstOp.getElementPtrType(),
promotedDstOp.alignedPtr(rewriter, loc), ArrayRef<Value>{actualOffset});
@@ -226,7 +207,8 @@ struct WmmaStoreOpToNVVMLowering
Value storeAddressCasted = rewriter.create<LLVM::BitcastOp>(
loc,
LLVM::LLVMPointerType::get(
- i32Ty, promotedDstOp.getElementPtrType().getAddressSpace()),
+ rewriter.getI32Type(),
+ promotedDstOp.getElementPtrType().getAddressSpace()),
storeAddress);
SmallVector<Value, 4> storeOpOperands;
@@ -245,6 +227,8 @@ struct WmmaStoreOpToNVVMLowering
rewriter.getI32ArrayAttr(i));
storeOpOperands.push_back(toUse);
}
+ Value leadingDim32 = rewriter.create<LLVM::ConstantOp>(
+ loc, rewriter.getI32Type(), leadDimension);
storeOpOperands.push_back(leadingDim32);
// Unpack the results from the source.
if (srcType.getElementType().isF16()) {
diff --git a/mlir/test/Conversion/GPUToNVVM/wmma-ops-to-nvvm.mlir b/mlir/test/Conversion/GPUToNVVM/wmma-ops-to-nvvm.mlir
index f22fa49324a98..9dd853a39b423 100644
--- a/mlir/test/Conversion/GPUToNVVM/wmma-ops-to-nvvm.mlir
+++ b/mlir/test/Conversion/GPUToNVVM/wmma-ops-to-nvvm.mlir
@@ -1,26 +1,43 @@
-// RUN: mlir-opt --convert-gpu-to-nvvm="index-bitwidth=32" --split-input-file %s | FileCheck %s
+// RUN: mlir-opt --convert-gpu-to-nvvm --split-input-file %s | FileCheck %s
+// RUN: mlir-opt --convert-gpu-to-nvvm="index-bitwidth=32" --split-input-file %s | FileCheck --check-prefix=CHECK32 %s
gpu.module @test_module {
// CHECK-LABEL: func @gpu_wmma_load_op() ->
// CHECK-SAME: !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)> {
+ // CHECK32-LABEL: func @gpu_wmma_load_op() ->
builtin.func @gpu_wmma_load_op() -> (!gpu.mma_matrix<16x16xf16, "AOp">) {
%wg = memref.alloca() {alignment = 32} : memref<32x32xf16, 3>
%i = arith.constant 16 : index
%j = arith.constant 16 : index
%0 = gpu.subgroup_mma_load_matrix %wg[%i, %j] {leadDimension = 32 : index} : memref<32x32xf16, 3> -> !gpu.mma_matrix<16x16xf16, "AOp">
- // CHECK: %[[INX:.*]] = llvm.mlir.constant(16 : index) : i32
+ // CHECK: %[[INX:.*]] = llvm.mlir.constant(16 : index) : i64
// CHECK: %{{.*}} = llvm.insertvalue %{{.*}}, %{{.*}}[{{.*}}, {{.*}}]
- // CHECK: %[[LDM:.*]] = llvm.mlir.constant(32 : index) : i32
- // CHECK: %[[LI:.*]] = llvm.mul %[[LDM]], %[[INX]] : i32
- // CHECK: %[[LIJ:.*]] = llvm.add %[[LI]], %[[INX]] : i32
- // CHECK: %[[OFFSET:.*]] = llvm.extractvalue %{{.*}}[2] : !llvm.struct<(ptr<f16, 3>, ptr<f16, 3>, i32, array<2 x i32>, array<2 x i32>)>
- // CHECK: %[[LIJO:.*]] = llvm.add %[[LIJ]], %[[OFFSET]] : i32
- // CHECK: %[[BASE:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr<f16, 3>, ptr<f16, 3>, i32, array<2 x i32>, array<2 x i32>)>
- // CHECK: %[[ADDRESS:.*]] = llvm.getelementptr %[[BASE]][%[[LIJO]]] : (!llvm.ptr<f16, 3>, i32) -> !llvm.ptr<f16, 3>
+ // CHECK: %[[LDM:.*]] = llvm.mlir.constant(32 : index) : i64
+ // CHECK: %[[LI:.*]] = llvm.mul %[[LDM]], %[[INX]] : i64
+ // CHECK: %[[LIJ:.*]] = llvm.add %[[LI]], %[[INX]] : i64
+ // CHECK: %[[OFFSET:.*]] = llvm.extractvalue %{{.*}}[2] : !llvm.struct<(ptr<f16, 3>, ptr<f16, 3>, i64, array<2 x i64>, array<2 x i64>)>
+ // CHECK: %[[LIJO:.*]] = llvm.add %[[LIJ]], %[[OFFSET]] : i64
+ // CHECK: %[[BASE:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr<f16, 3>, ptr<f16, 3>, i64, array<2 x i64>, array<2 x i64>)>
+ // CHECK: %[[ADDRESS:.*]] = llvm.getelementptr %[[BASE]][%[[LIJO]]] : (!llvm.ptr<f16, 3>, i64) -> !llvm.ptr<f16, 3>
// CHECK: %[[CADDRESS:.*]] = llvm.bitcast %[[ADDRESS]] : !llvm.ptr<f16, 3> to !llvm.ptr<i32, 3>
- // CHECK: %[[FRAG:.*]] = nvvm.wmma.m16n16k16.load.a.f16.row.stride %[[CADDRESS]], %[[LDM]] : (!llvm.ptr<i32, 3>, i32) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+ // CHECK: %[[LDM32:.*]] = llvm.mlir.constant(32 : index) : i32
+ // CHECK: %[[FRAG:.*]] = nvvm.wmma.m16n16k16.load.a.f16.row.stride %[[CADDRESS]], %[[LDM32]] : (!llvm.ptr<i32, 3>, i32) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
// CHECK: llvm.return %[[FRAG]] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+
+ // CHECK32: %[[INX:.*]] = llvm.mlir.constant(16 : index) : i32
+ // CHECK32: %{{.*}} = llvm.insertvalue %{{.*}}, %{{.*}}[{{.*}}, {{.*}}]
+ // CHECK32: %[[LDM:.*]] = llvm.mlir.constant(32 : index) : i32
+ // CHECK32: %[[LI:.*]] = llvm.mul %[[LDM]], %[[INX]] : i32
+ // CHECK32: %[[LIJ:.*]] = llvm.add %[[LI]], %[[INX]] : i32
+ // CHECK32: %[[OFFSET:.*]] = llvm.extractvalue %{{.*}}[2] : !llvm.struct<(ptr<f16, 3>, ptr<f16, 3>, i32, array<2 x i32>, array<2 x i32>)>
+ // CHECK32: %[[LIJO:.*]] = llvm.add %[[LIJ]], %[[OFFSET]] : i32
+ // CHECK32: %[[BASE:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr<f16, 3>, ptr<f16, 3>, i32, array<2 x i32>, array<2 x i32>)>
+ // CHECK32: %[[ADDRESS:.*]] = llvm.getelementptr %[[BASE]][%[[LIJO]]] : (!llvm.ptr<f16, 3>, i32) -> !llvm.ptr<f16, 3>
+ // CHECK32: %[[CADDRESS:.*]] = llvm.bitcast %[[ADDRESS]] : !llvm.ptr<f16, 3> to !llvm.ptr<i32, 3>
+ // CHECK32: %[[LDM32:.*]] = llvm.mlir.constant(32 : index) : i32
+ // CHECK32: %[[FRAG:.*]] = nvvm.wmma.m16n16k16.load.a.f16.row.stride %[[CADDRESS]], %[[LDM32]] : (!llvm.ptr<i32, 3>, i32) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+ // CHECK32: llvm.return %[[FRAG]] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
return %0 : !gpu.mma_matrix<16x16xf16, "AOp">
}
}
@@ -31,27 +48,48 @@ gpu.module @test_module {
// CHECK-LABEL: func @gpu_wmma_store_op
// CHECK-SAME: (%[[D:.*]]: !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>) {
+ // CHECK32-LABEL: func @gpu_wmma_store_op
+ // CHECK32-SAME: (%[[D:.*]]: !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>) {
builtin.func @gpu_wmma_store_op(%arg0 : !gpu.mma_matrix<16x16xf16, "COp">) -> () {
%sg = memref.alloca(){alignment = 32} : memref<32x32xf16, 3>
%i = arith.constant 16 : index
%j = arith.constant 16 : index
gpu.subgroup_mma_store_matrix %arg0, %sg[%i,%j] {leadDimension= 32 : index} : !gpu.mma_matrix<16x16xf16, "COp">, memref<32x32xf16, 3>
- // CHECK: %[[INX:.*]] = llvm.mlir.constant(16 : index) : i32
+ // CHECK: %[[INX:.*]] = llvm.mlir.constant(16 : index) : i64
// CHECK: %{{.*}} = llvm.insertvalue %{{.*}}, %{{.*}}[{{.*}}, {{.*}}]
- // CHECK: %[[LDM:.*]] = llvm.mlir.constant(32 : index) : i32
- // CHECK: %[[LI:.*]] = llvm.mul %[[LDM]], %[[INX]] : i32
- // CHECK: %[[LIJ:.*]] = llvm.add %[[LI]], %[[INX]] : i32
- // CHECK: %[[OFFSET:.*]] = llvm.extractvalue %17[2] : !llvm.struct<(ptr<f16, 3>, ptr<f16, 3>, i32, array<2 x i32>, array<2 x i32>)>
- // CHECK: %[[LIJO:.*]] = llvm.add %[[LIJ]], %[[OFFSET]] : i32
- // CHECK: %[[BASE:.*]] = llvm.extractvalue %17[1] : !llvm.struct<(ptr<f16, 3>, ptr<f16, 3>, i32, array<2 x i32>, array<2 x i32>)>
- // CHECK: %[[ADDRESS:.*]] = llvm.getelementptr %[[BASE]][%[[LIJO]]] : (!llvm.ptr<f16, 3>, i32) -> !llvm.ptr<f16, 3>
+ // CHECK: %[[LDM:.*]] = llvm.mlir.constant(32 : index) : i64
+ // CHECK: %[[LI:.*]] = llvm.mul %[[LDM]], %[[INX]] : i64
+ // CHECK: %[[LIJ:.*]] = llvm.add %[[LI]], %[[INX]] : i64
+ // CHECK: %[[OFFSET:.*]] = llvm.extractvalue %17[2] : !llvm.struct<(ptr<f16, 3>, ptr<f16, 3>, i64, array<2 x i64>, array<2 x i64>)>
+ // CHECK: %[[LIJO:.*]] = llvm.add %[[LIJ]], %[[OFFSET]] : i64
+ // CHECK: %[[BASE:.*]] = llvm.extractvalue %17[1] : !llvm.struct<(ptr<f16, 3>, ptr<f16, 3>, i64, array<2 x i64>, array<2 x i64>)>
+ // CHECK: %[[ADDRESS:.*]] = llvm.getelementptr %[[BASE]][%[[LIJO]]] : (!llvm.ptr<f16, 3>, i64) -> !llvm.ptr<f16, 3>
// CHECK: %[[CADDRESS:.*]] = llvm.bitcast %[[ADDRESS]] : !llvm.ptr<f16, 3> to !llvm.ptr<i32, 3>
// CHECK: %[[EL1:.*]] = llvm.extractvalue %[[D]][0 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
// CHECK: %[[EL2:.*]] = llvm.extractvalue %[[D]][1 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
// CHECK: %[[EL3:.*]] = llvm.extractvalue %[[D]][2 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
// CHECK: %[[EL4:.*]] = llvm.extractvalue %[[D]][3 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
- // CHECK: nvvm.wmma.m16n16k16.store.d.f16.row.stride %[[CADDRESS]], %[[EL1]], %[[EL2]], %[[EL3]], %[[EL4]], %[[LDM]] : !llvm.ptr<i32, 3>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, i32
+ // CHECK: %[[LDM32:.*]] = llvm.mlir.constant(32 : index) : i32
+ // CHECK: nvvm.wmma.m16n16k16.store.d.f16.row.stride %[[CADDRESS]], %[[EL1]], %[[EL2]], %[[EL3]], %[[EL4]], %[[LDM32]] : !llvm.ptr<i32, 3>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, i32
// CHECK: llvm.return
+
+ // CHECK32: %[[INX:.*]] = llvm.mlir.constant(16 : index) : i32
+ // CHECK32: %{{.*}} = llvm.insertvalue %{{.*}}, %{{.*}}[{{.*}}, {{.*}}]
+ // CHECK32: %[[LDM:.*]] = llvm.mlir.constant(32 : index) : i32
+ // CHECK32: %[[LI:.*]] = llvm.mul %[[LDM]], %[[INX]] : i32
+ // CHECK32: %[[LIJ:.*]] = llvm.add %[[LI]], %[[INX]] : i32
+ // CHECK32: %[[OFFSET:.*]] = llvm.extractvalue %17[2] : !llvm.struct<(ptr<f16, 3>, ptr<f16, 3>, i32, array<2 x i32>, array<2 x i32>)>
+ // CHECK32: %[[LIJO:.*]] = llvm.add %[[LIJ]], %[[OFFSET]] : i32
+ // CHECK32: %[[BASE:.*]] = llvm.extractvalue %17[1] : !llvm.struct<(ptr<f16, 3>, ptr<f16, 3>, i32, array<2 x i32>, array<2 x i32>)>
+ // CHECK32: %[[ADDRESS:.*]] = llvm.getelementptr %[[BASE]][%[[LIJO]]] : (!llvm.ptr<f16, 3>, i32) -> !llvm.ptr<f16, 3>
+ // CHECK32: %[[CADDRESS:.*]] = llvm.bitcast %[[ADDRESS]] : !llvm.ptr<f16, 3> to !llvm.ptr<i32, 3>
+ // CHECK32: %[[EL1:.*]] = llvm.extractvalue %[[D]][0 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+ // CHECK32: %[[EL2:.*]] = llvm.extractvalue %[[D]][1 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+ // CHECK32: %[[EL3:.*]] = llvm.extractvalue %[[D]][2 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+ // CHECK32: %[[EL4:.*]] = llvm.extractvalue %[[D]][3 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+ // CHECK32: %[[LDM32:.*]] = llvm.mlir.constant(32 : index) : i32
+ // CHECK32: nvvm.wmma.m16n16k16.store.d.f16.row.stride %[[CADDRESS]], %[[EL1]], %[[EL2]], %[[EL3]], %[[EL4]], %[[LDM32]] : !llvm.ptr<i32, 3>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, i32
+ // CHECK32: llvm.return
return
}
}
@@ -96,9 +134,9 @@ gpu.module @test_module {
// CHECK-LABEL: func @gpu_wmma_mma_loop_op
// CHECK: %[[C:.+]] = nvvm.wmma.m16n16k16.load.c.f16.row.stride %{{.*}}, %{{.*}} : (!llvm.ptr<i32>, i32) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
-// CHECK: llvm.br ^bb1(%{{.*}}, %[[C]] : i32, !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>)
-// CHECK: ^bb1(%{{.*}}: i32, %[[ACC:.+]]: !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>): // 2 preds: ^bb0, ^bb2
-// CHECK: llvm.cond_br %38, ^bb2, ^bb3
+// CHECK: llvm.br ^bb1(%{{.*}}, %[[C]] : i64, !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>)
+// CHECK: ^bb1(%{{.*}}: i64, %[[ACC:.+]]: !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>): // 2 preds: ^bb0, ^bb2
+// CHECK: llvm.cond_br %{{.*}}, ^bb2, ^bb3
// CHECK: ^bb2: // pred: ^bb1
// CHECK: %[[A:.+]] = nvvm.wmma.m16n16k16.load.a.f16.row.stride %{{.*}}, %{{.*}} : (!llvm.ptr<i32>, i32) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
// CHECK: %[[B:.+]] = nvvm.wmma.m16n16k16.load.b.f16.row.stride %{{.*}}, %{{.*}} : (!llvm.ptr<i32>, i32) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
@@ -123,13 +161,13 @@ gpu.module @test_module {
// CHECK: %[[ACC2:.+]] = llvm.extractvalue %[[ACC]][2 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
// CHECK: %[[ACC3:.+]] = llvm.extractvalue %[[ACC]][3 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
// CHECK: %[[ACC_MUL:.+]] = nvvm.wmma.m16n16k16.mma.row.row.f16.f16 %[[A0]], %[[A1]], %[[A2]], %[[A3]], %[[A4]], %[[A5]], %[[A6]], %[[A7]], %[[B0]], %[[B1]], %[[B2]], %[[B3]], %[[B4]], %[[B5]], %[[B6]], %[[B7]], %[[ACC0]], %[[ACC1]], %[[ACC2]], %[[ACC3]] : vector<2xf16> -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
-// CHECK: llvm.br ^bb1(%{{.*}}, %[[ACC_MUL]] : i32, !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>)
+// CHECK: llvm.br ^bb1(%{{.*}}, %[[ACC_MUL]] : i64, !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>)
// CHECK: ^bb3: // pred: ^bb1
-// CHECK: %87 = llvm.extractvalue %[[ACC]][0 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
-// CHECK: %88 = llvm.extractvalue %[[ACC]][1 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
-// CHECK: %89 = llvm.extractvalue %[[ACC]][2 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
-// CHECK: %90 = llvm.extractvalue %[[ACC]][3 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
-// CHECK: nvvm.wmma.m16n16k16.store.d.f16.row.stride %86, %87, %88, %89, %90, %79 : !llvm.ptr<i32>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, i32
+// CHECK: %[[E0:.+]] = llvm.extractvalue %[[ACC]][0 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+// CHECK: %[[E1:.+]] = llvm.extractvalue %[[ACC]][1 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+// CHECK: %[[E2:.+]] = llvm.extractvalue %[[ACC]][2 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+// CHECK: %[[E3:.+]] = llvm.extractvalue %[[ACC]][3 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+// CHECK: nvvm.wmma.m16n16k16.store.d.f16.row.stride %{{.*}}, %[[E0]], %[[E1]], %[[E2]], %[[E3]], %{{.*}} : !llvm.ptr<i32>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, i32
builtin.func @gpu_wmma_mma_loop_op(%arg0: memref<128x128xf16>, %arg1: memref<128x128xf16>, %arg2: memref<128x128xf16>) {
%c0 = arith.constant 0 : index
diff --git a/mlir/test/Integration/GPU/CUDA/TensorCore/wmma-matmul-f16.mlir b/mlir/test/Integration/GPU/CUDA/TensorCore/wmma-matmul-f16.mlir
index b977cdda92b47..52d5faf43f3e5 100644
--- a/mlir/test/Integration/GPU/CUDA/TensorCore/wmma-matmul-f16.mlir
+++ b/mlir/test/Integration/GPU/CUDA/TensorCore/wmma-matmul-f16.mlir
@@ -1,6 +1,6 @@
// RUN: mlir-opt %s \
// RUN: -gpu-kernel-outlining \
-// RUN: -pass-pipeline='gpu.module(strip-debuginfo,convert-gpu-to-nvvm{index-bitwidth=32},gpu-to-cubin{chip=sm_70})' \
+// RUN: -pass-pipeline='gpu.module(strip-debuginfo,convert-gpu-to-nvvm,gpu-to-cubin{chip=sm_70})' \
// RUN: --convert-scf-to-std -gpu-to-llvm \
// RUN: | mlir-cpu-runner \
// RUN: --shared-libs=%linalg_test_lib_dir/libmlir_cuda_runtime%shlibext \
diff --git a/mlir/test/Integration/GPU/CUDA/TensorCore/wmma-matmul-f32.mlir b/mlir/test/Integration/GPU/CUDA/TensorCore/wmma-matmul-f32.mlir
index b3f90e5c1c625..a25f5d3408a39 100644
--- a/mlir/test/Integration/GPU/CUDA/TensorCore/wmma-matmul-f32.mlir
+++ b/mlir/test/Integration/GPU/CUDA/TensorCore/wmma-matmul-f32.mlir
@@ -1,6 +1,6 @@
// RUN: mlir-opt %s \
// RUN: -gpu-kernel-outlining \
-// RUN: -pass-pipeline='gpu.module(strip-debuginfo,convert-gpu-to-nvvm{index-bitwidth=32},gpu-to-cubin{chip=sm_70})' \
+// RUN: -pass-pipeline='gpu.module(strip-debuginfo,convert-gpu-to-nvvm,gpu-to-cubin{chip=sm_70})' \
// RUN: --convert-scf-to-std -gpu-to-llvm \
// RUN: | mlir-cpu-runner \
// RUN: --shared-libs=%linalg_test_lib_dir/libmlir_cuda_runtime%shlibext \