[Mlir-commits] [mlir] f309939 - [mlir][nvvm] Remove special case ptr arithmetic lowering in gpu to nvvm

llvmlistbot at llvm.org
Wed Nov 10 10:00:31 PST 2021


Author: thomasraoux
Date: 2021-11-10T10:00:12-08:00
New Revision: f309939d065a67cb2379059d82cb8a76d8b74e3c

URL: https://github.com/llvm/llvm-project/commit/f309939d065a67cb2379059d82cb8a76d8b74e3c
DIFF: https://github.com/llvm/llvm-project/commit/f309939d065a67cb2379059d82cb8a76d8b74e3c.diff

LOG: [mlir][nvvm] Remove special case ptr arithmetic lowering in gpu to nvvm

Use the existing getStridedElementPtr helper instead of hand-rolled index
arithmetic that only handled a subset of cases. Also relax the restriction on
the memref rank for the GPU MMA ops, as any rank can now be supported.

Differential Revision: https://reviews.llvm.org/D113383
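
For illustration, a minimal sketch (hypothetical shapes and names) of what the
relaxed constraint admits: a load from a rank-3 memref, which the previous
MemRefRankOf<[F16, F32], [2]> constraint rejected:

    gpu.module @example {
      builtin.func @load_rank3(%src: memref<4x32x32xf16, 3>, %i: index) {
        // One index per memref dimension; all of them feed the strided
        // address computation emitted by the lowering.
        %frag = gpu.subgroup_mma_load_matrix %src[%i, %i, %i]
            {leadDimension = 32 : index}
            : memref<4x32x32xf16, 3> -> !gpu.mma_matrix<16x16xf16, "AOp">
        return
      }
    }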

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/GPU/GPUOps.td
    mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp
    mlir/test/Conversion/GPUToNVVM/wmma-ops-to-nvvm.mlir

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/GPU/GPUOps.td b/mlir/include/mlir/Dialect/GPU/GPUOps.td
index 3f1ad84278cb0..5e4d122c69eaa 100644
--- a/mlir/include/mlir/Dialect/GPU/GPUOps.td
+++ b/mlir/include/mlir/Dialect/GPU/GPUOps.td
@@ -991,7 +991,7 @@ def GPU_SubgroupMmaLoadMatrixOp : GPU_Op<"subgroup_mma_load_matrix",
     ```
   }];
 
-  let arguments = (ins Arg<MemRefRankOf<[F16, F32], [2]>, "", [MemRead]>:$srcMemref,
+  let arguments = (ins Arg<MemRefOf<[F16, F32]>, "", [MemRead]>:$srcMemref,
                   Variadic<Index>:$indices,
                   IndexAttr:$leadDimension);
 
@@ -1031,7 +1031,7 @@ def GPU_SubgroupMmaStoreMatrixOp : GPU_Op<"subgroup_mma_store_matrix",
   }];
 
   let arguments = (ins Arg<MMAMatrixOf<[F16, F32]>>:$src,
-                  Arg<MemRefRankOf<[F16, F32], [2]>, "",[MemWrite]>:$dstMemref,
+                  Arg<MemRefOf<[F16, F32]>, "",[MemWrite]>:$dstMemref,
                   Variadic<Index>:$indices,
                   IndexAttr:$leadDimension);
 

diff --git a/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp b/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp
index b0bf94b7f8066..6de739088b896 100644
--- a/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp
+++ b/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp
@@ -77,44 +77,6 @@ struct WmmaLoadOpToNVVMLowering
     if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)))
       return failure();
 
-    Location loc = op->getLoc();
-
-    // MemRefDescriptor to extract alignedPtr and offset.
-    MemRefDescriptor promotedSrcOp(adaptor.srcMemref());
-
-    // Emit ops which compute the load offset using `srcOffsetI`,
-    // `srcOffsetJ`. The actualOffset is (memrefOffset + (alignedPtr +
-    // ((leadDimension * srcOffsetI) + srcOffsetJ)). The memrefs here are
-    // assumed to be normalized and hence the simple conversion works.
-    IntegerAttr leadDimension = subgroupMmaLoadMatrixOp.leadDimensionAttr();
-    SmallVector<Value> indices(adaptor.indices());
-    Value srcOffsetIVal = indices[0];
-    Value srcOffsetJVal = indices[1];
-    Value leadingDim = rewriter.create<LLVM::ConstantOp>(
-        loc, srcOffsetIVal.getType(), leadDimension);
-    Value numElemsLeadDim =
-        rewriter.create<LLVM::MulOp>(loc, leadingDim, srcOffsetIVal);
-    Value loadOffset =
-        rewriter.create<LLVM::AddOp>(loc, numElemsLeadDim, srcOffsetJVal);
-
-    Value promotedSrcOpToUse;
-    promotedSrcOpToUse = promotedSrcOp.offset(rewriter, loc);
-    Value actualOffset =
-        rewriter.create<LLVM::AddOp>(loc, loadOffset, promotedSrcOpToUse);
-    Value loadAddress = rewriter.create<LLVM::GEPOp>(
-        loc, promotedSrcOp.getElementPtrType(),
-        promotedSrcOp.alignedPtr(rewriter, loc), ArrayRef<Value>{actualOffset});
-
-    // Bitcast the base address pointer of the destination memref, So that
-    // values can be stored in chunks of 32-bits and semantics match with the
-    // intrinsic exposed by NVPTX backend.
-    Value loadAddressCasted = rewriter.create<LLVM::BitcastOp>(
-        loc,
-        LLVM::LLVMPointerType::get(
-            rewriter.getI32Type(),
-            promotedSrcOp.getElementPtrType().getAddressSpace()),
-        loadAddress);
-
     // Get the shape of the MMAMatrix type being returned. The shape will
     // choose which intrinsic this op will be lowered to.
     gpu::MMAMatrixType retType =
@@ -146,15 +108,18 @@ struct WmmaLoadOpToNVVMLowering
       return rewriter.notifyMatchFailure(op, kInvalidCaseStr);
 
     Type resType = convertMMAToLLVMType(retType);
+    Location loc = op->getLoc();
 
     // Create nvvm.mma_load op according to the operand types.
-    Value leadingDim32 = rewriter.create<LLVM::ConstantOp>(
-        loc, rewriter.getI32Type(), leadDimension);
+    Value dataPtr = getStridedElementPtr(
+        loc, subgroupMmaLoadMatrixOp.srcMemref().getType().cast<MemRefType>(),
+        adaptor.srcMemref(), adaptor.indices(), rewriter);
 
+    Value leadingDim = rewriter.create<LLVM::ConstantOp>(
+        loc, rewriter.getI32Type(),
+        subgroupMmaLoadMatrixOp.leadDimensionAttr());
     rewriter.replaceOpWithNewOp<NVVM::WMMALoadOp>(
-        op, resType, loadAddressCasted, leadingDim32, m, n, k, layout, eltype,
-        frag);
-
+        op, resType, dataPtr, leadingDim, m, n, k, layout, eltype, frag);
     return success();
   }
 };
@@ -178,41 +143,6 @@ struct WmmaStoreOpToNVVMLowering
 
     Location loc = op->getLoc();
 
-    // MemRefDescriptor to extract alignedPtr and offset.
-    MemRefDescriptor promotedDstOp(adaptor.dstMemref());
-
-    // Emit ops which compute the store offset using `dstOffsetI`,
-    // `dstOffsetJ`. The actualOffset is (memrefOffset + (alignedPtr +
-    // ((leadDimension * dstOffsetI) + dstOffsetJ)).
-    auto leadDimension = subgroupMmaStoreMatrixOp.leadDimensionAttr();
-    SmallVector<Value> indices(adaptor.indices());
-    Value dstOffsetIVal = indices[0];
-    Value dstOffsetJVal = indices[1];
-    Value leadingDim = rewriter.create<LLVM::ConstantOp>(
-        loc, dstOffsetIVal.getType(), leadDimension);
-    Value numElemsLeadDim =
-        rewriter.create<LLVM::MulOp>(loc, leadingDim, dstOffsetIVal);
-    Value loadOffset =
-        rewriter.create<LLVM::AddOp>(loc, numElemsLeadDim, dstOffsetJVal);
-
-    Value promotedDstOpToUse;
-    promotedDstOpToUse = promotedDstOp.offset(rewriter, loc);
-    Value actualOffset =
-        rewriter.create<LLVM::AddOp>(loc, loadOffset, promotedDstOpToUse);
-    Value storeAddress = rewriter.create<LLVM::GEPOp>(
-        loc, promotedDstOp.getElementPtrType(),
-        promotedDstOp.alignedPtr(rewriter, loc), ArrayRef<Value>{actualOffset});
-
-    // Bitcast the base address pointer of the destination memref, So that
-    // values can be stored in chunks of 32-bits and semantics match with the
-    // intrinsic exposed by NVPTX backend.
-    Value storeAddressCasted = rewriter.create<LLVM::BitcastOp>(
-        loc,
-        LLVM::LLVMPointerType::get(
-            rewriter.getI32Type(),
-            promotedDstOp.getElementPtrType().getAddressSpace()),
-        storeAddress);
-
     SmallVector<Value, 4> storeOpOperands;
     // Get the shape of the MMAMatrix type being stored. The shape will
     // choose which intrinsic this op will be lowered to.
@@ -234,12 +164,15 @@ struct WmmaStoreOpToNVVMLowering
           rewriter.getI32ArrayAttr(i));
       storeOpOperands.push_back(toUse);
     }
-    Value leadingDim32 = rewriter.create<LLVM::ConstantOp>(
-        loc, rewriter.getI32Type(), leadDimension);
-    rewriter.create<NVVM::WMMAStoreOp>(loc, storeAddressCasted, m, n, k, layout,
-                                       eltype, storeOpOperands, leadingDim32);
 
-    rewriter.eraseOp(op);
+    Value dataPtr = getStridedElementPtr(
+        loc, subgroupMmaStoreMatrixOp.dstMemref().getType().cast<MemRefType>(),
+        adaptor.dstMemref(), adaptor.indices(), rewriter);
+    Value leadingDim = rewriter.create<LLVM::ConstantOp>(
+        loc, rewriter.getI32Type(),
+        subgroupMmaStoreMatrixOp.leadDimensionAttr());
+    rewriter.replaceOpWithNewOp<NVVM::WMMAStoreOp>(
+        op, dataPtr, m, n, k, layout, eltype, storeOpOperands, leadingDim);
     return success();
   }
 };

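For reference, getStridedElementPtr emits the canonical descriptor-based
address computation in place of the hand-written sequence deleted above. A
rough sketch (invented value names) of the LLVM dialect IR it produces for a
static row-major memref<32x32xf16, 3>, matching the updated CHECK lines in the
test below:

    // Aligned base pointer from the memref descriptor.
    %base = llvm.extractvalue %desc[1]
        : !llvm.struct<(ptr<f16, 3>, ptr<f16, 3>, i64, array<2 x i64>, array<2 x i64>)>
    // offset + %i * stride0 + %j * stride1; here the offset is statically 0,
    // stride0 folds to the constant 32, and stride1 to 1.
    %stride0 = llvm.mlir.constant(32 : index) : i64
    %li  = llvm.mul %i, %stride0 : i64
    %lij = llvm.add %li, %j : i64
    %ptr = llvm.getelementptr %base[%lij]
        : (!llvm.ptr<f16, 3>, i64) -> !llvm.ptr<f16, 3>
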
diff --git a/mlir/test/Conversion/GPUToNVVM/wmma-ops-to-nvvm.mlir b/mlir/test/Conversion/GPUToNVVM/wmma-ops-to-nvvm.mlir
index c0ac8a050288f..f322fcac67647 100644
--- a/mlir/test/Conversion/GPUToNVVM/wmma-ops-to-nvvm.mlir
+++ b/mlir/test/Conversion/GPUToNVVM/wmma-ops-to-nvvm.mlir
@@ -13,32 +13,26 @@ gpu.module @test_module {
     %0 = gpu.subgroup_mma_load_matrix %wg[%i, %j] {leadDimension = 32 : index} : memref<32x32xf16, 3> -> !gpu.mma_matrix<16x16xf16, "AOp">
     // CHECK:  %[[INX:.*]] = llvm.mlir.constant(16 : index) : i64
     // CHECK: %{{.*}} = llvm.insertvalue %{{.*}}, %{{.*}}[{{.*}}, {{.*}}]
+    // CHECK:  %[[BASE:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr<f16, 3>, ptr<f16, 3>, i64, array<2 x i64>, array<2 x i64>)>
     // CHECK:  %[[LDM:.*]] = llvm.mlir.constant(32 : index) : i64
-    // CHECK:  %[[LI:.*]] = llvm.mul %[[LDM]], %[[INX]]  : i64
+    // CHECK:  %[[LI:.*]] = llvm.mul %[[INX]], %[[LDM]]  : i64
     // CHECK:  %[[LIJ:.*]] = llvm.add %[[LI]], %[[INX]]  : i64
-    // CHECK:  %[[OFFSET:.*]] = llvm.extractvalue %{{.*}}[2] : !llvm.struct<(ptr<f16, 3>, ptr<f16, 3>, i64, array<2 x i64>, array<2 x i64>)>
-    // CHECK:  %[[LIJO:.*]] = llvm.add %[[LIJ]], %[[OFFSET]]  : i64
-    // CHECK:  %[[BASE:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr<f16, 3>, ptr<f16, 3>, i64, array<2 x i64>, array<2 x i64>)>
-    // CHECK:  %[[ADDRESS:.*]] = llvm.getelementptr %[[BASE]][%[[LIJO]]] : (!llvm.ptr<f16, 3>, i64) -> !llvm.ptr<f16, 3>
-    // CHECK:  %[[CADDRESS:.*]] = llvm.bitcast %[[ADDRESS]] : !llvm.ptr<f16, 3> to !llvm.ptr<i32, 3>
+    // CHECK:  %[[ADDRESS:.*]] = llvm.getelementptr %[[BASE]][%[[LIJ]]] : (!llvm.ptr<f16, 3>, i64) -> !llvm.ptr<f16, 3>
     // CHECK:  %[[LDM32:.*]] = llvm.mlir.constant(32 : index) : i32
-    // CHECK:  %[[FRAG:.*]] = nvvm.wmma.load %[[CADDRESS]], %[[LDM32]]
-    // CHECK-SAME: {eltype = "f16", frag = "a", k = 16 : i32, layout = "row", m = 16 : i32, n = 16 : i32}  : (!llvm.ptr<i32, 3>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+    // CHECK:  %[[FRAG:.*]] = nvvm.wmma.load %[[ADDRESS]], %[[LDM32]]
+    // CHECK-SAME: {eltype = "f16", frag = "a", k = 16 : i32, layout = "row", m = 16 : i32, n = 16 : i32}  : (!llvm.ptr<f16, 3>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
     // CHECK:  llvm.return %[[FRAG]] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
 
     // CHECK32:  %[[INX:.*]] = llvm.mlir.constant(16 : index) : i32
     // CHECK32: %{{.*}} = llvm.insertvalue %{{.*}}, %{{.*}}[{{.*}}, {{.*}}]
+    // CHECK32:  %[[BASE:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr<f16, 3>, ptr<f16, 3>, i32, array<2 x i32>, array<2 x i32>)>
     // CHECK32:  %[[LDM:.*]] = llvm.mlir.constant(32 : index) : i32
-    // CHECK32:  %[[LI:.*]] = llvm.mul %[[LDM]], %[[INX]]  : i32
+    // CHECK32:  %[[LI:.*]] = llvm.mul %[[INX]], %[[LDM]]  : i32
     // CHECK32:  %[[LIJ:.*]] = llvm.add %[[LI]], %[[INX]]  : i32
-    // CHECK32:  %[[OFFSET:.*]] = llvm.extractvalue %{{.*}}[2] : !llvm.struct<(ptr<f16, 3>, ptr<f16, 3>, i32, array<2 x i32>, array<2 x i32>)>
-    // CHECK32:  %[[LIJO:.*]] = llvm.add %[[LIJ]], %[[OFFSET]]  : i32
-    // CHECK32:  %[[BASE:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr<f16, 3>, ptr<f16, 3>, i32, array<2 x i32>, array<2 x i32>)>
-    // CHECK32:  %[[ADDRESS:.*]] = llvm.getelementptr %[[BASE]][%[[LIJO]]] : (!llvm.ptr<f16, 3>, i32) -> !llvm.ptr<f16, 3>
-    // CHECK32:  %[[CADDRESS:.*]] = llvm.bitcast %[[ADDRESS]] : !llvm.ptr<f16, 3> to !llvm.ptr<i32, 3>
+    // CHECK32:  %[[ADDRESS:.*]] = llvm.getelementptr %[[BASE]][%[[LIJ]]] : (!llvm.ptr<f16, 3>, i32) -> !llvm.ptr<f16, 3>
     // CHECK32:  %[[LDM32:.*]] = llvm.mlir.constant(32 : index) : i32
-    // CHECK32:  %[[FRAG:.*]] = nvvm.wmma.load %[[CADDRESS]], %[[LDM32]]
-    // CHECK32-SAME: {eltype = "f16", frag = "a", k = 16 : i32, layout = "row", m = 16 : i32, n = 16 : i32}  : (!llvm.ptr<i32, 3>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+    // CHECK32:  %[[FRAG:.*]] = nvvm.wmma.load %[[ADDRESS]], %[[LDM32]]
+    // CHECK32-SAME: {eltype = "f16", frag = "a", k = 16 : i32, layout = "row", m = 16 : i32, n = 16 : i32}  : (!llvm.ptr<f16, 3>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
     // CHECK32:  llvm.return %[[FRAG]] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
     return %0 : !gpu.mma_matrix<16x16xf16, "AOp">
   }
@@ -59,40 +53,34 @@ gpu.module @test_module {
     gpu.subgroup_mma_store_matrix %arg0, %sg[%i,%j] {leadDimension= 32 : index} : !gpu.mma_matrix<16x16xf16, "COp">, memref<32x32xf16, 3>
     // CHECK:  %[[INX:.*]] = llvm.mlir.constant(16 : index) : i64
     // CHECK: %{{.*}} = llvm.insertvalue %{{.*}}, %{{.*}}[{{.*}}, {{.*}}]
-    // CHECK:  %[[LDM:.*]] = llvm.mlir.constant(32 : index) : i64
-    // CHECK:  %[[LI:.*]] = llvm.mul %[[LDM]], %[[INX]]  : i64
-    // CHECK:  %[[LIJ:.*]] = llvm.add %[[LI]], %[[INX]]  : i64
-    // CHECK:  %[[OFFSET:.*]] = llvm.extractvalue %17[2] : !llvm.struct<(ptr<f16, 3>, ptr<f16, 3>, i64, array<2 x i64>, array<2 x i64>)>
-    // CHECK:  %[[LIJO:.*]] = llvm.add %[[LIJ]], %[[OFFSET]]  : i64
-    // CHECK:  %[[BASE:.*]] = llvm.extractvalue %17[1] : !llvm.struct<(ptr<f16, 3>, ptr<f16, 3>, i64, array<2 x i64>, array<2 x i64>)>
-    // CHECK:  %[[ADDRESS:.*]] = llvm.getelementptr %[[BASE]][%[[LIJO]]] : (!llvm.ptr<f16, 3>, i64) -> !llvm.ptr<f16, 3>
-    // CHECK:  %[[CADDRESS:.*]] = llvm.bitcast %[[ADDRESS]] : !llvm.ptr<f16, 3> to !llvm.ptr<i32, 3>
     // CHECK:  %[[EL1:.*]] = llvm.extractvalue %[[D]][0 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
     // CHECK:  %[[EL2:.*]] = llvm.extractvalue %[[D]][1 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
     // CHECK:  %[[EL3:.*]] = llvm.extractvalue %[[D]][2 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
     // CHECK:  %[[EL4:.*]] = llvm.extractvalue %[[D]][3 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+    // CHECK:  %[[BASE:.*]] = llvm.extractvalue %17[1] : !llvm.struct<(ptr<f16, 3>, ptr<f16, 3>, i64, array<2 x i64>, array<2 x i64>)>
+    // CHECK:  %[[LDM:.*]] = llvm.mlir.constant(32 : index) : i64
+    // CHECK:  %[[LI:.*]] = llvm.mul %[[INX]], %[[LDM]]   : i64
+    // CHECK:  %[[LIJ:.*]] = llvm.add %[[LI]], %[[INX]]  : i64
+    // CHECK:  %[[ADDRESS:.*]] = llvm.getelementptr %[[BASE]][%[[LIJ]]] : (!llvm.ptr<f16, 3>, i64) -> !llvm.ptr<f16, 3>
     // CHECK:  %[[LDM32:.*]] = llvm.mlir.constant(32 : index) : i32
-    // CHECK:  nvvm.wmma.store %[[CADDRESS]], %[[LDM32]], %[[EL1]], %[[EL2]], %[[EL3]], %[[EL4]]
-    // CHECK-SAME: {eltype = "f16", k = 16 : i32, layout = "row", m = 16 : i32, n = 16 : i32} : !llvm.ptr<i32, 3>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>
+    // CHECK:  nvvm.wmma.store %[[ADDRESS]], %[[LDM32]], %[[EL1]], %[[EL2]], %[[EL3]], %[[EL4]]
+    // CHECK-SAME: {eltype = "f16", k = 16 : i32, layout = "row", m = 16 : i32, n = 16 : i32} : !llvm.ptr<f16, 3>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>
     // CHECK:  llvm.return
 
     // CHECK32:  %[[INX:.*]] = llvm.mlir.constant(16 : index) : i32
     // CHECK32: %{{.*}} = llvm.insertvalue %{{.*}}, %{{.*}}[{{.*}}, {{.*}}]
-    // CHECK32:  %[[LDM:.*]] = llvm.mlir.constant(32 : index) : i32
-    // CHECK32:  %[[LI:.*]] = llvm.mul %[[LDM]], %[[INX]]  : i32
-    // CHECK32:  %[[LIJ:.*]] = llvm.add %[[LI]], %[[INX]]  : i32
-    // CHECK32:  %[[OFFSET:.*]] = llvm.extractvalue %17[2] : !llvm.struct<(ptr<f16, 3>, ptr<f16, 3>, i32, array<2 x i32>, array<2 x i32>)>
-    // CHECK32:  %[[LIJO:.*]] = llvm.add %[[LIJ]], %[[OFFSET]]  : i32
-    // CHECK32:  %[[BASE:.*]] = llvm.extractvalue %17[1] : !llvm.struct<(ptr<f16, 3>, ptr<f16, 3>, i32, array<2 x i32>, array<2 x i32>)>
-    // CHECK32:  %[[ADDRESS:.*]] = llvm.getelementptr %[[BASE]][%[[LIJO]]] : (!llvm.ptr<f16, 3>, i32) -> !llvm.ptr<f16, 3>
-    // CHECK32:  %[[CADDRESS:.*]] = llvm.bitcast %[[ADDRESS]] : !llvm.ptr<f16, 3> to !llvm.ptr<i32, 3>
     // CHECK32:  %[[EL1:.*]] = llvm.extractvalue %[[D]][0 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
     // CHECK32:  %[[EL2:.*]] = llvm.extractvalue %[[D]][1 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
     // CHECK32:  %[[EL3:.*]] = llvm.extractvalue %[[D]][2 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
     // CHECK32:  %[[EL4:.*]] = llvm.extractvalue %[[D]][3 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+    // CHECK32:  %[[BASE:.*]] = llvm.extractvalue %17[1] : !llvm.struct<(ptr<f16, 3>, ptr<f16, 3>, i32, array<2 x i32>, array<2 x i32>)>
+    // CHECK32:  %[[LDM:.*]] = llvm.mlir.constant(32 : index) : i32
+    // CHECK32:  %[[LI:.*]] = llvm.mul %[[INX]], %[[LDM]]   : i32
+    // CHECK32:  %[[LIJ:.*]] = llvm.add %[[LI]], %[[INX]]  : i32
+    // CHECK32:  %[[ADDRESS:.*]] = llvm.getelementptr %[[BASE]][%[[LIJ]]] : (!llvm.ptr<f16, 3>, i32) -> !llvm.ptr<f16, 3>
     // CHECK32:  %[[LDM32:.*]] = llvm.mlir.constant(32 : index) : i32
-    // CHECK32:  nvvm.wmma.store %[[CADDRESS]],  %[[LDM32]], %[[EL1]], %[[EL2]], %[[EL3]], %[[EL4]]
-    // CHECK32-SAME: {eltype = "f16", k = 16 : i32, layout = "row", m = 16 : i32, n = 16 : i32} : !llvm.ptr<i32, 3>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>
+    // CHECK32:  nvvm.wmma.store %[[ADDRESS]], %[[LDM32]], %[[EL1]], %[[EL2]], %[[EL3]], %[[EL4]]
+    // CHECK32-SAME: {eltype = "f16", k = 16 : i32, layout = "row", m = 16 : i32, n = 16 : i32} : !llvm.ptr<f16, 3>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>
     // CHECK32:  llvm.return
     return
   }
@@ -139,13 +127,13 @@ gpu.module @test_module {
 gpu.module @test_module {
 
 // CHECK-LABEL: func @gpu_wmma_mma_loop_op
-//       CHECK:   %[[C:.+]] = nvvm.wmma.load %{{.*}}, %{{.*}} {eltype = "f16", frag = "c", k = 16 : i32, layout = "row", m = 16 : i32, n = 16 : i32} : (!llvm.ptr<i32>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+//       CHECK:   %[[C:.+]] = nvvm.wmma.load %{{.*}}, %{{.*}} {eltype = "f16", frag = "c", k = 16 : i32, layout = "row", m = 16 : i32, n = 16 : i32} : (!llvm.ptr<f16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
 //       CHECK:   llvm.br ^bb1(%{{.*}}, %[[C]] : i64, !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>)
 //       CHECK:  ^bb1(%{{.*}}: i64, %[[ACC:.+]]: !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>):  // 2 preds: ^bb0, ^bb2
 //       CHECK:   llvm.cond_br %{{.*}}, ^bb2, ^bb3
 //       CHECK:  ^bb2:  // pred: ^bb1
-//       CHECK:   %[[A:.+]] = nvvm.wmma.load %{{.*}}, %{{.*}} {eltype = "f16", frag = "a", k = 16 : i32, layout = "row", m = 16 : i32, n = 16 : i32} : (!llvm.ptr<i32>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
-//       CHECK:   %[[B:.+]] = nvvm.wmma.load %{{.*}}, %{{.*}} {eltype = "f16", frag = "b", k = 16 : i32, layout = "row", m = 16 : i32, n = 16 : i32} : (!llvm.ptr<i32>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+//       CHECK:   %[[A:.+]] = nvvm.wmma.load %{{.*}}, %{{.*}} {eltype = "f16", frag = "a", k = 16 : i32, layout = "row", m = 16 : i32, n = 16 : i32} : (!llvm.ptr<f16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+//       CHECK:   %[[B:.+]] = nvvm.wmma.load %{{.*}}, %{{.*}} {eltype = "f16", frag = "b", k = 16 : i32, layout = "row", m = 16 : i32, n = 16 : i32} : (!llvm.ptr<f16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
 //       CHECK:   %[[A0:.+]] = llvm.extractvalue %[[A]][0 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
 //       CHECK:   %[[A1:.+]] = llvm.extractvalue %[[A]][1 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
 //       CHECK:   %[[A2:.+]] = llvm.extractvalue %[[A]][2 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
@@ -173,7 +161,7 @@ gpu.module @test_module {
 //       CHECK:   %[[E1:.+]] = llvm.extractvalue %[[ACC]][1 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
 //       CHECK:   %[[E2:.+]] = llvm.extractvalue %[[ACC]][2 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
 //       CHECK:   %[[E3:.+]] = llvm.extractvalue %[[ACC]][3 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
-//       CHECK:   nvvm.wmma.store %{{.*}}, %{{.*}}, %[[E0]], %[[E1]], %[[E2]], %[[E3]] {eltype = "f16", k = 16 : i32, layout = "row", m = 16 : i32, n = 16 : i32} : !llvm.ptr<i32>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>
+//       CHECK:   nvvm.wmma.store %{{.*}}, %{{.*}}, %[[E0]], %[[E1]], %[[E2]], %[[E3]] {eltype = "f16", k = 16 : i32, layout = "row", m = 16 : i32, n = 16 : i32} : !llvm.ptr<f16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>
 
   builtin.func @gpu_wmma_mma_loop_op(%arg0: memref<128x128xf16>, %arg1: memref<128x128xf16>, %arg2: memref<128x128xf16>) {
       %c0 = arith.constant 0 : index
