[Mlir-commits] [mlir] eacd6e1 - [mlir][GPUtoNVVM] Relax restriction on wmma op lowering

llvmlistbot at llvm.org
Wed Oct 27 21:32:39 PDT 2021


Author: thomasraoux
Date: 2021-10-27T21:31:55-07:00
New Revision: eacd6e1ebef511b216c8b2805e7e662b9902de74

URL: https://github.com/llvm/llvm-project/commit/eacd6e1ebef511b216c8b2805e7e662b9902de74
DIFF: https://github.com/llvm/llvm-project/commit/eacd6e1ebef511b216c8b2805e7e662b9902de74.diff

LOG: [mlir][GPUtoNVVM] Relax restriction on wmma op lowering

Allow lowering of wmma ops with 64-bit indices. Change the default
version of the test to use the default lowering (64-bit index type) and
check the 32-bit lowering under a separate CHECK32 prefix.

Differential Revision: https://reviews.llvm.org/D112479
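
For example, the following op from the updated test now lowers under the
default 64-bit index type; before this change the pattern reported a match
failure unless the pass was run with index-bitwidth=32:

    %0 = gpu.subgroup_mma_load_matrix %wg[%i, %j] {leadDimension = 32 : index}
        : memref<32x32xf16, 3> -> !gpu.mma_matrix<16x16xf16, "AOp">

The address arithmetic is now emitted in the converter's index type (i64 by
default), while the leadDimension operand of the NVVM intrinsic is still
materialized as a separate i32 constant, since the intrinsic expects a
32-bit integer.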

Added: 
    

Modified: 
    mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp
    mlir/test/Conversion/GPUToNVVM/wmma-ops-to-nvvm.mlir
    mlir/test/Integration/GPU/CUDA/TensorCore/wmma-matmul-f16.mlir
    mlir/test/Integration/GPU/CUDA/TensorCore/wmma-matmul-f32.mlir

Removed: 
    


################################################################################
diff --git a/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp b/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp
index 0296390d9c082..681725441539f 100644
--- a/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp
+++ b/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp
@@ -75,19 +75,8 @@ struct WmmaLoadOpToNVVMLowering
     if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)))
       return failure();
 
-    unsigned indexTypeBitwidth =
-        this->getTypeConverter()->getIndexTypeBitwidth();
-
-    // The corresponding intrinsics expects leadDimension to be a 32-bit
-    // integer, so all the calculations of linearizing the load address
-    // must also follow this restriction.
-    if (indexTypeBitwidth != 32)
-      return rewriter.notifyMatchFailure(
-          op, "Expected indices to the memref to be 32-bit wide.");
     Location loc = op->getLoc();
 
-    auto leadDimension = subgroupMmaLoadMatrixOp.leadDimensionAttr();
-
     // MemRefDescriptor to extract alignedPtr and offset.
     MemRefDescriptor promotedSrcOp(adaptor.srcMemref());
 
@@ -95,21 +84,21 @@ struct WmmaLoadOpToNVVMLowering
     // `srcOffsetJ`. The actualOffset is (memrefOffset + (alignedPtr +
     // ((leadDimension * srcOffsetI) + srcOffsetJ)). The memrefs here are
     // assumed to be normalized and hence the simple conversion works.
+    IntegerAttr leadDimension = subgroupMmaLoadMatrixOp.leadDimensionAttr();
     SmallVector<Value> indices(adaptor.indices());
     Value srcOffsetIVal = indices[0];
     Value srcOffsetJVal = indices[1];
-    Type i32Ty = rewriter.getI32Type();
-    Value leadingDim32 =
-        rewriter.create<LLVM::ConstantOp>(loc, i32Ty, leadDimension);
+    Value leadingDim = rewriter.create<LLVM::ConstantOp>(
+        loc, srcOffsetIVal.getType(), leadDimension);
     Value numElemsLeadDim =
-        rewriter.create<LLVM::MulOp>(loc, i32Ty, leadingDim32, srcOffsetIVal);
-    Value loadOffset = rewriter.create<LLVM::AddOp>(loc, i32Ty, numElemsLeadDim,
-                                                    srcOffsetJVal);
+        rewriter.create<LLVM::MulOp>(loc, leadingDim, srcOffsetIVal);
+    Value loadOffset =
+        rewriter.create<LLVM::AddOp>(loc, numElemsLeadDim, srcOffsetJVal);
 
     Value promotedSrcOpToUse;
     promotedSrcOpToUse = promotedSrcOp.offset(rewriter, loc);
-    Value actualOffset = rewriter.create<LLVM::AddOp>(loc, i32Ty, loadOffset,
-                                                      promotedSrcOpToUse);
+    Value actualOffset =
+        rewriter.create<LLVM::AddOp>(loc, loadOffset, promotedSrcOpToUse);
     Value loadAddress = rewriter.create<LLVM::GEPOp>(
         loc, promotedSrcOp.getElementPtrType(),
         promotedSrcOp.alignedPtr(rewriter, loc), ArrayRef<Value>{actualOffset});
@@ -120,7 +109,8 @@ struct WmmaLoadOpToNVVMLowering
     Value loadAddressCasted = rewriter.create<LLVM::BitcastOp>(
         loc,
         LLVM::LLVMPointerType::get(
-            i32Ty, promotedSrcOp.getElementPtrType().getAddressSpace()),
+            rewriter.getI32Type(),
+            promotedSrcOp.getElementPtrType().getAddressSpace()),
         loadAddress);
 
     // Get the shape of the MMAMatrix type being returned. The shape will
@@ -133,6 +123,8 @@ struct WmmaLoadOpToNVVMLowering
     StringRef operandStr = retType.getOperand();
 
     // Create nvvm.mma_load op according to the operand types.
+    Value leadingDim32 = rewriter.create<LLVM::ConstantOp>(
+        loc, rewriter.getI32Type(), leadDimension);
     SmallVector<Value, 2> loadOpOperands({loadAddressCasted, leadingDim32});
     if (operandStr.equals("AOp")) {
       if (retTypeShape[0] == 16 && retTypeShape[1] == 16) {
@@ -182,40 +174,29 @@ struct WmmaStoreOpToNVVMLowering
     if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)))
       return failure();
 
-    unsigned indexTypeBitwidth =
-        this->getTypeConverter()->getIndexTypeBitwidth();
-    // The corresponding intrinsics expects leadDimension to be a 32-bit
-    // integer, so all the calculations of linearizing the store address
-    // must also follow this restriction.
-    if (indexTypeBitwidth != 32)
-      return rewriter.notifyMatchFailure(
-          op, "expected indices to the memref to be 32-bit wide.");
-
     Location loc = op->getLoc();
 
     // MemRefDescriptor to extract alignedPtr and offset.
     MemRefDescriptor promotedDstOp(adaptor.dstMemref());
 
-    auto leadDimension = subgroupMmaStoreMatrixOp.leadDimensionAttr();
-
     // Emit ops which compute the store offset using `dstOffsetI`,
     // `dstOffsetJ`. The actualOffset is (memrefOffset + (alignedPtr +
     // ((leadDimension * dstOffsetI) + dstOffsetJ)).
+    auto leadDimension = subgroupMmaStoreMatrixOp.leadDimensionAttr();
     SmallVector<Value> indices(adaptor.indices());
     Value dstOffsetIVal = indices[0];
     Value dstOffsetJVal = indices[1];
-    Type i32Ty = rewriter.getI32Type();
-    Value leadingDim32 =
-        rewriter.create<LLVM::ConstantOp>(loc, i32Ty, leadDimension);
+    Value leadingDim = rewriter.create<LLVM::ConstantOp>(
+        loc, dstOffsetIVal.getType(), leadDimension);
     Value numElemsLeadDim =
-        rewriter.create<LLVM::MulOp>(loc, i32Ty, leadingDim32, dstOffsetIVal);
-    Value loadOffset = rewriter.create<LLVM::AddOp>(loc, i32Ty, numElemsLeadDim,
-                                                    dstOffsetJVal);
+        rewriter.create<LLVM::MulOp>(loc, leadingDim, dstOffsetIVal);
+    Value loadOffset =
+        rewriter.create<LLVM::AddOp>(loc, numElemsLeadDim, dstOffsetJVal);
 
     Value promotedDstOpToUse;
     promotedDstOpToUse = promotedDstOp.offset(rewriter, loc);
-    Value actualOffset = rewriter.create<LLVM::AddOp>(loc, i32Ty, loadOffset,
-                                                      promotedDstOpToUse);
+    Value actualOffset =
+        rewriter.create<LLVM::AddOp>(loc, loadOffset, promotedDstOpToUse);
     Value storeAddress = rewriter.create<LLVM::GEPOp>(
         loc, promotedDstOp.getElementPtrType(),
         promotedDstOp.alignedPtr(rewriter, loc), ArrayRef<Value>{actualOffset});
@@ -226,7 +207,8 @@ struct WmmaStoreOpToNVVMLowering
     Value storeAddressCasted = rewriter.create<LLVM::BitcastOp>(
         loc,
         LLVM::LLVMPointerType::get(
-            i32Ty, promotedDstOp.getElementPtrType().getAddressSpace()),
+            rewriter.getI32Type(),
+            promotedDstOp.getElementPtrType().getAddressSpace()),
         storeAddress);
 
     SmallVector<Value, 4> storeOpOperands;
@@ -245,6 +227,8 @@ struct WmmaStoreOpToNVVMLowering
           rewriter.getI32ArrayAttr(i));
       storeOpOperands.push_back(toUse);
     }
+    Value leadingDim32 = rewriter.create<LLVM::ConstantOp>(
+        loc, rewriter.getI32Type(), leadDimension);
     storeOpOperands.push_back(leadingDim32);
     // Unpack the results from the source.
     if (srcType.getElementType().isF16()) {

diff --git a/mlir/test/Conversion/GPUToNVVM/wmma-ops-to-nvvm.mlir b/mlir/test/Conversion/GPUToNVVM/wmma-ops-to-nvvm.mlir
index f22fa49324a98..9dd853a39b423 100644
--- a/mlir/test/Conversion/GPUToNVVM/wmma-ops-to-nvvm.mlir
+++ b/mlir/test/Conversion/GPUToNVVM/wmma-ops-to-nvvm.mlir
@@ -1,26 +1,43 @@
-// RUN: mlir-opt --convert-gpu-to-nvvm="index-bitwidth=32" --split-input-file %s | FileCheck %s
+// RUN: mlir-opt --convert-gpu-to-nvvm --split-input-file %s | FileCheck %s
+// RUN: mlir-opt --convert-gpu-to-nvvm="index-bitwidth=32" --split-input-file %s | FileCheck --check-prefix=CHECK32 %s
 
 gpu.module @test_module {
 
   // CHECK-LABEL: func @gpu_wmma_load_op() ->
   // CHECK-SAME: !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)> {
+  // CHECK32-LABEL: func @gpu_wmma_load_op() ->
   builtin.func @gpu_wmma_load_op() -> (!gpu.mma_matrix<16x16xf16, "AOp">) {
     %wg = memref.alloca() {alignment = 32} : memref<32x32xf16, 3>
     %i = arith.constant 16 : index
     %j = arith.constant 16 : index
     %0 = gpu.subgroup_mma_load_matrix %wg[%i, %j] {leadDimension = 32 : index} : memref<32x32xf16, 3> -> !gpu.mma_matrix<16x16xf16, "AOp">
-    // CHECK:  %[[INX:.*]] = llvm.mlir.constant(16 : index) : i32
+    // CHECK:  %[[INX:.*]] = llvm.mlir.constant(16 : index) : i64
     // CHECK: %{{.*}} = llvm.insertvalue %{{.*}}, %{{.*}}[{{.*}}, {{.*}}]
-    // CHECK:  %[[LDM:.*]] = llvm.mlir.constant(32 : index) : i32
-    // CHECK:  %[[LI:.*]] = llvm.mul %[[LDM]], %[[INX]]  : i32
-    // CHECK:  %[[LIJ:.*]] = llvm.add %[[LI]], %[[INX]]  : i32
-    // CHECK:  %[[OFFSET:.*]] = llvm.extractvalue %{{.*}}[2] : !llvm.struct<(ptr<f16, 3>, ptr<f16, 3>, i32, array<2 x i32>, array<2 x i32>)>
-    // CHECK:  %[[LIJO:.*]] = llvm.add %[[LIJ]], %[[OFFSET]]  : i32
-    // CHECK:  %[[BASE:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr<f16, 3>, ptr<f16, 3>, i32, array<2 x i32>, array<2 x i32>)>
-    // CHECK:  %[[ADDRESS:.*]] = llvm.getelementptr %[[BASE]][%[[LIJO]]] : (!llvm.ptr<f16, 3>, i32) -> !llvm.ptr<f16, 3>
+    // CHECK:  %[[LDM:.*]] = llvm.mlir.constant(32 : index) : i64
+    // CHECK:  %[[LI:.*]] = llvm.mul %[[LDM]], %[[INX]]  : i64
+    // CHECK:  %[[LIJ:.*]] = llvm.add %[[LI]], %[[INX]]  : i64
+    // CHECK:  %[[OFFSET:.*]] = llvm.extractvalue %{{.*}}[2] : !llvm.struct<(ptr<f16, 3>, ptr<f16, 3>, i64, array<2 x i64>, array<2 x i64>)>
+    // CHECK:  %[[LIJO:.*]] = llvm.add %[[LIJ]], %[[OFFSET]]  : i64
+    // CHECK:  %[[BASE:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr<f16, 3>, ptr<f16, 3>, i64, array<2 x i64>, array<2 x i64>)>
+    // CHECK:  %[[ADDRESS:.*]] = llvm.getelementptr %[[BASE]][%[[LIJO]]] : (!llvm.ptr<f16, 3>, i64) -> !llvm.ptr<f16, 3>
     // CHECK:  %[[CADDRESS:.*]] = llvm.bitcast %[[ADDRESS]] : !llvm.ptr<f16, 3> to !llvm.ptr<i32, 3>
-    // CHECK:  %[[FRAG:.*]] = nvvm.wmma.m16n16k16.load.a.f16.row.stride %[[CADDRESS]], %[[LDM]] : (!llvm.ptr<i32, 3>, i32) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+    // CHECK:  %[[LDM32:.*]] = llvm.mlir.constant(32 : index) : i32
+    // CHECK:  %[[FRAG:.*]] = nvvm.wmma.m16n16k16.load.a.f16.row.stride %[[CADDRESS]], %[[LDM32]] : (!llvm.ptr<i32, 3>, i32) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
     // CHECK:  llvm.return %[[FRAG]] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+
+    // CHECK32:  %[[INX:.*]] = llvm.mlir.constant(16 : index) : i32
+    // CHECK32: %{{.*}} = llvm.insertvalue %{{.*}}, %{{.*}}[{{.*}}, {{.*}}]
+    // CHECK32:  %[[LDM:.*]] = llvm.mlir.constant(32 : index) : i32
+    // CHECK32:  %[[LI:.*]] = llvm.mul %[[LDM]], %[[INX]]  : i32
+    // CHECK32:  %[[LIJ:.*]] = llvm.add %[[LI]], %[[INX]]  : i32
+    // CHECK32:  %[[OFFSET:.*]] = llvm.extractvalue %{{.*}}[2] : !llvm.struct<(ptr<f16, 3>, ptr<f16, 3>, i32, array<2 x i32>, array<2 x i32>)>
+    // CHECK32:  %[[LIJO:.*]] = llvm.add %[[LIJ]], %[[OFFSET]]  : i32
+    // CHECK32:  %[[BASE:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr<f16, 3>, ptr<f16, 3>, i32, array<2 x i32>, array<2 x i32>)>
+    // CHECK32:  %[[ADDRESS:.*]] = llvm.getelementptr %[[BASE]][%[[LIJO]]] : (!llvm.ptr<f16, 3>, i32) -> !llvm.ptr<f16, 3>
+    // CHECK32:  %[[CADDRESS:.*]] = llvm.bitcast %[[ADDRESS]] : !llvm.ptr<f16, 3> to !llvm.ptr<i32, 3>
+    // CHECK32:  %[[LDM32:.*]] = llvm.mlir.constant(32 : index) : i32
+    // CHECK32:  %[[FRAG:.*]] = nvvm.wmma.m16n16k16.load.a.f16.row.stride %[[CADDRESS]], %[[LDM32]] : (!llvm.ptr<i32, 3>, i32) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+    // CHECK32:  llvm.return %[[FRAG]] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
     return %0 : !gpu.mma_matrix<16x16xf16, "AOp">
   }
 }
@@ -31,27 +48,48 @@ gpu.module @test_module {
 
   // CHECK-LABEL: func @gpu_wmma_store_op
   // CHECK-SAME: (%[[D:.*]]: !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>) {
+  // CHECK32-LABEL: func @gpu_wmma_store_op
+  // CHECK32-SAME: (%[[D:.*]]: !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>) {
   builtin.func @gpu_wmma_store_op(%arg0 : !gpu.mma_matrix<16x16xf16, "COp">) -> () {
     %sg = memref.alloca(){alignment = 32} : memref<32x32xf16, 3>
     %i = arith.constant 16 : index
     %j = arith.constant 16 : index
     gpu.subgroup_mma_store_matrix %arg0, %sg[%i,%j] {leadDimension= 32 : index} : !gpu.mma_matrix<16x16xf16, "COp">, memref<32x32xf16, 3>
-    // CHECK:  %[[INX:.*]] = llvm.mlir.constant(16 : index) : i32
+    // CHECK:  %[[INX:.*]] = llvm.mlir.constant(16 : index) : i64
     // CHECK: %{{.*}} = llvm.insertvalue %{{.*}}, %{{.*}}[{{.*}}, {{.*}}]
-    // CHECK:  %[[LDM:.*]] = llvm.mlir.constant(32 : index) : i32
-    // CHECK:  %[[LI:.*]] = llvm.mul %[[LDM]], %[[INX]]  : i32
-    // CHECK:  %[[LIJ:.*]] = llvm.add %[[LI]], %[[INX]]  : i32
-    // CHECK:  %[[OFFSET:.*]] = llvm.extractvalue %17[2] : !llvm.struct<(ptr<f16, 3>, ptr<f16, 3>, i32, array<2 x i32>, array<2 x i32>)>
-    // CHECK:  %[[LIJO:.*]] = llvm.add %[[LIJ]], %[[OFFSET]]  : i32
-    // CHECK:  %[[BASE:.*]] = llvm.extractvalue %17[1] : !llvm.struct<(ptr<f16, 3>, ptr<f16, 3>, i32, array<2 x i32>, array<2 x i32>)>
-    // CHECK:  %[[ADDRESS:.*]] = llvm.getelementptr %[[BASE]][%[[LIJO]]] : (!llvm.ptr<f16, 3>, i32) -> !llvm.ptr<f16, 3>
+    // CHECK:  %[[LDM:.*]] = llvm.mlir.constant(32 : index) : i64
+    // CHECK:  %[[LI:.*]] = llvm.mul %[[LDM]], %[[INX]]  : i64
+    // CHECK:  %[[LIJ:.*]] = llvm.add %[[LI]], %[[INX]]  : i64
+    // CHECK:  %[[OFFSET:.*]] = llvm.extractvalue %17[2] : !llvm.struct<(ptr<f16, 3>, ptr<f16, 3>, i64, array<2 x i64>, array<2 x i64>)>
+    // CHECK:  %[[LIJO:.*]] = llvm.add %[[LIJ]], %[[OFFSET]]  : i64
+    // CHECK:  %[[BASE:.*]] = llvm.extractvalue %17[1] : !llvm.struct<(ptr<f16, 3>, ptr<f16, 3>, i64, array<2 x i64>, array<2 x i64>)>
+    // CHECK:  %[[ADDRESS:.*]] = llvm.getelementptr %[[BASE]][%[[LIJO]]] : (!llvm.ptr<f16, 3>, i64) -> !llvm.ptr<f16, 3>
     // CHECK:  %[[CADDRESS:.*]] = llvm.bitcast %[[ADDRESS]] : !llvm.ptr<f16, 3> to !llvm.ptr<i32, 3>
     // CHECK:  %[[EL1:.*]] = llvm.extractvalue %[[D]][0 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
     // CHECK:  %[[EL2:.*]] = llvm.extractvalue %[[D]][1 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
     // CHECK:  %[[EL3:.*]] = llvm.extractvalue %[[D]][2 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
     // CHECK:  %[[EL4:.*]] = llvm.extractvalue %[[D]][3 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
-    // CHECK:  nvvm.wmma.m16n16k16.store.d.f16.row.stride %[[CADDRESS]], %[[EL1]], %[[EL2]], %[[EL3]], %[[EL4]], %[[LDM]] : !llvm.ptr<i32, 3>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, i32
+    // CHECK:  %[[LDM32:.*]] = llvm.mlir.constant(32 : index) : i32
+    // CHECK:  nvvm.wmma.m16n16k16.store.d.f16.row.stride %[[CADDRESS]], %[[EL1]], %[[EL2]], %[[EL3]], %[[EL4]], %[[LDM32]] : !llvm.ptr<i32, 3>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, i32
     // CHECK:  llvm.return
+
+    // CHECK32:  %[[INX:.*]] = llvm.mlir.constant(16 : index) : i32
+    // CHECK32: %{{.*}} = llvm.insertvalue %{{.*}}, %{{.*}}[{{.*}}, {{.*}}]
+    // CHECK32:  %[[LDM:.*]] = llvm.mlir.constant(32 : index) : i32
+    // CHECK32:  %[[LI:.*]] = llvm.mul %[[LDM]], %[[INX]]  : i32
+    // CHECK32:  %[[LIJ:.*]] = llvm.add %[[LI]], %[[INX]]  : i32
+    // CHECK32:  %[[OFFSET:.*]] = llvm.extractvalue %17[2] : !llvm.struct<(ptr<f16, 3>, ptr<f16, 3>, i32, array<2 x i32>, array<2 x i32>)>
+    // CHECK32:  %[[LIJO:.*]] = llvm.add %[[LIJ]], %[[OFFSET]]  : i32
+    // CHECK32:  %[[BASE:.*]] = llvm.extractvalue %17[1] : !llvm.struct<(ptr<f16, 3>, ptr<f16, 3>, i32, array<2 x i32>, array<2 x i32>)>
+    // CHECK32:  %[[ADDRESS:.*]] = llvm.getelementptr %[[BASE]][%[[LIJO]]] : (!llvm.ptr<f16, 3>, i32) -> !llvm.ptr<f16, 3>
+    // CHECK32:  %[[CADDRESS:.*]] = llvm.bitcast %[[ADDRESS]] : !llvm.ptr<f16, 3> to !llvm.ptr<i32, 3>
+    // CHECK32:  %[[EL1:.*]] = llvm.extractvalue %[[D]][0 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+    // CHECK32:  %[[EL2:.*]] = llvm.extractvalue %[[D]][1 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+    // CHECK32:  %[[EL3:.*]] = llvm.extractvalue %[[D]][2 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+    // CHECK32:  %[[EL4:.*]] = llvm.extractvalue %[[D]][3 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+    // CHECK32:  %[[LDM32:.*]] = llvm.mlir.constant(32 : index) : i32
+    // CHECK32:  nvvm.wmma.m16n16k16.store.d.f16.row.stride %[[CADDRESS]], %[[EL1]], %[[EL2]], %[[EL3]], %[[EL4]], %[[LDM32]] : !llvm.ptr<i32, 3>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, i32
+    // CHECK32:  llvm.return
     return
   }
 }
@@ -96,9 +134,9 @@ gpu.module @test_module {
 
 // CHECK-LABEL: func @gpu_wmma_mma_loop_op
 //       CHECK:   %[[C:.+]] = nvvm.wmma.m16n16k16.load.c.f16.row.stride %{{.*}}, %{{.*}} : (!llvm.ptr<i32>, i32) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
-//       CHECK:   llvm.br ^bb1(%{{.*}}, %[[C]] : i32, !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>)
-//       CHECK:  ^bb1(%{{.*}}: i32, %[[ACC:.+]]: !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>):  // 2 preds: ^bb0, ^bb2
-//       CHECK:   llvm.cond_br %38, ^bb2, ^bb3
+//       CHECK:   llvm.br ^bb1(%{{.*}}, %[[C]] : i64, !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>)
+//       CHECK:  ^bb1(%{{.*}}: i64, %[[ACC:.+]]: !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>):  // 2 preds: ^bb0, ^bb2
+//       CHECK:   llvm.cond_br %{{.*}}, ^bb2, ^bb3
 //       CHECK:  ^bb2:  // pred: ^bb1
 //       CHECK:   %[[A:.+]] = nvvm.wmma.m16n16k16.load.a.f16.row.stride %{{.*}}, %{{.*}} : (!llvm.ptr<i32>, i32) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
 //       CHECK:   %[[B:.+]] = nvvm.wmma.m16n16k16.load.b.f16.row.stride %{{.*}}, %{{.*}} : (!llvm.ptr<i32>, i32) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
@@ -123,13 +161,13 @@ gpu.module @test_module {
 //       CHECK:   %[[ACC2:.+]] = llvm.extractvalue %[[ACC]][2 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
 //       CHECK:   %[[ACC3:.+]] = llvm.extractvalue %[[ACC]][3 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
 //       CHECK:   %[[ACC_MUL:.+]] = nvvm.wmma.m16n16k16.mma.row.row.f16.f16 %[[A0]], %[[A1]], %[[A2]], %[[A3]], %[[A4]], %[[A5]], %[[A6]], %[[A7]], %[[B0]], %[[B1]], %[[B2]], %[[B3]], %[[B4]], %[[B5]], %[[B6]], %[[B7]], %[[ACC0]], %[[ACC1]], %[[ACC2]], %[[ACC3]] : vector<2xf16> -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
-//       CHECK:   llvm.br ^bb1(%{{.*}}, %[[ACC_MUL]] : i32, !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>)
+//       CHECK:   llvm.br ^bb1(%{{.*}}, %[[ACC_MUL]] : i64, !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>)
 //       CHECK:  ^bb3:  // pred: ^bb1
-//       CHECK:   %87 = llvm.extractvalue %[[ACC]][0 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
-//       CHECK:   %88 = llvm.extractvalue %[[ACC]][1 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
-//       CHECK:   %89 = llvm.extractvalue %[[ACC]][2 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
-//       CHECK:   %90 = llvm.extractvalue %[[ACC]][3 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
-//       CHECK:   nvvm.wmma.m16n16k16.store.d.f16.row.stride %86, %87, %88, %89, %90, %79 : !llvm.ptr<i32>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, i32
+//       CHECK:   %[[E0:.+]] = llvm.extractvalue %[[ACC]][0 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+//       CHECK:   %[[E1:.+]] = llvm.extractvalue %[[ACC]][1 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+//       CHECK:   %[[E2:.+]] = llvm.extractvalue %[[ACC]][2 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+//       CHECK:   %[[E3:.+]] = llvm.extractvalue %[[ACC]][3 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+//       CHECK:   nvvm.wmma.m16n16k16.store.d.f16.row.stride %{{.*}}, %[[E0]], %[[E1]], %[[E2]], %[[E3]], %{{.*}} : !llvm.ptr<i32>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, i32
 
   builtin.func @gpu_wmma_mma_loop_op(%arg0: memref<128x128xf16>, %arg1: memref<128x128xf16>, %arg2: memref<128x128xf16>) {
       %c0 = arith.constant 0 : index

diff --git a/mlir/test/Integration/GPU/CUDA/TensorCore/wmma-matmul-f16.mlir b/mlir/test/Integration/GPU/CUDA/TensorCore/wmma-matmul-f16.mlir
index b977cdda92b47..52d5faf43f3e5 100644
--- a/mlir/test/Integration/GPU/CUDA/TensorCore/wmma-matmul-f16.mlir
+++ b/mlir/test/Integration/GPU/CUDA/TensorCore/wmma-matmul-f16.mlir
@@ -1,6 +1,6 @@
 // RUN: mlir-opt %s \
 // RUN: -gpu-kernel-outlining \
-// RUN: -pass-pipeline='gpu.module(strip-debuginfo,convert-gpu-to-nvvm{index-bitwidth=32},gpu-to-cubin{chip=sm_70})' \
+// RUN: -pass-pipeline='gpu.module(strip-debuginfo,convert-gpu-to-nvvm,gpu-to-cubin{chip=sm_70})' \
 // RUN: --convert-scf-to-std -gpu-to-llvm \
 // RUN: | mlir-cpu-runner \
 // RUN:   --shared-libs=%linalg_test_lib_dir/libmlir_cuda_runtime%shlibext \

diff --git a/mlir/test/Integration/GPU/CUDA/TensorCore/wmma-matmul-f32.mlir b/mlir/test/Integration/GPU/CUDA/TensorCore/wmma-matmul-f32.mlir
index b3f90e5c1c625..a25f5d3408a39 100644
--- a/mlir/test/Integration/GPU/CUDA/TensorCore/wmma-matmul-f32.mlir
+++ b/mlir/test/Integration/GPU/CUDA/TensorCore/wmma-matmul-f32.mlir
@@ -1,6 +1,6 @@
 // RUN: mlir-opt %s \
 // RUN: -gpu-kernel-outlining \
-// RUN: -pass-pipeline='gpu.module(strip-debuginfo,convert-gpu-to-nvvm{index-bitwidth=32},gpu-to-cubin{chip=sm_70})' \
+// RUN: -pass-pipeline='gpu.module(strip-debuginfo,convert-gpu-to-nvvm,gpu-to-cubin{chip=sm_70})' \
 // RUN: --convert-scf-to-std -gpu-to-llvm \
 // RUN: | mlir-cpu-runner \
 // RUN:   --shared-libs=%linalg_test_lib_dir/libmlir_cuda_runtime%shlibext \
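
With the restriction removed, the conversion no longer needs a forced
32-bit index type; a minimal invocation matching the updated RUN lines
above is:

    mlir-opt --convert-gpu-to-nvvm --split-input-file \
        mlir/test/Conversion/GPUToNVVM/wmma-ops-to-nvvm.mlir

The 32-bit behavior remains available via
--convert-gpu-to-nvvm="index-bitwidth=32" and is now checked under the
CHECK32 prefix.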
