[Mlir-commits] [mlir] 4e4af13 - [mlir][gpu][nvvm] fixed bug with literal for inline asm for mma instruction

Aart Bik llvmlistbot at llvm.org
Fri Mar 17 09:22:31 PDT 2023


Author: Aart Bik
Date: 2023-03-17T09:22:15-07:00
New Revision: 4e4af1338da5bdbf10e113c0462d7eb7222b5d97

URL: https://github.com/llvm/llvm-project/commit/4e4af1338da5bdbf10e113c0462d7eb7222b5d97
DIFF: https://github.com/llvm/llvm-project/commit/4e4af1338da5bdbf10e113c0462d7eb7222b5d97.diff

LOG: [mlir][gpu][nvvm] fixed bug with literal for inline asm for mma instruction

The 'mma.sp.sync.aligned' family of instructions expects
the sparsity selector as a direct literal (0x0 or 0x1).
The current MLIR inline asm passed this as a value in a
register, which broke the downstream assemblers.

This is a small step towards supporting 2:4 sparsity on
NVIDIA GPUs in the sparse compiler of MLIR.

Reviewed By: ThomasRaoux, guraypp

Differential Revision: https://reviews.llvm.org/D146110

Added: 
    

Modified: 
    mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
    mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
index d27af456f41b9..e88bb05d0b0b9 100644
--- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
+++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
@@ -401,22 +401,22 @@ static std::string buildMmaSparseAsmConstraintString(unsigned matASize,
     ss << "=r,";
   for (unsigned i = 0; i < matASize + matBSize + matCSize; i++)
     ss << "r,";
-  // The final two operands are for the sparsity metadata and sparsity selector.
-  ss << "r,r";
+  // The final operand is for the sparsity metadata.
+  // The sparsity selector appears as direct literal.
+  ss << "r";
   ss.flush();
   return str;
 }
 
 /// Returns the string for the `mma.sp.sync` instruction that corresponds to
-/// the give parameters. Note that this function doesn't do any validation,
+/// the given parameters. Note that this function doesn't do any validation,
 /// it's expected that the provided parameters correspond to a valid
 /// instruction.
-static std::string
-buildMmaSparseAsmString(const std::array<int64_t, 3> &shape, unsigned matASize,
-                        unsigned matBSize, unsigned matCSize,
-                        NVVM::MMATypes ptxTypeA, NVVM::MMATypes ptxTypeB,
-                        NVVM::MMATypes ptxTypeC, NVVM::MMATypes ptxTypeD,
-                        std::optional<NVVM::MMAIntOverflow> overflow) {
+static std::string buildMmaSparseAsmString(
+    const std::array<int64_t, 3> &shape, unsigned matASize, unsigned matBSize,
+    unsigned matCSize, NVVM::MMATypes ptxTypeA, NVVM::MMATypes ptxTypeB,
+    NVVM::MMATypes ptxTypeC, NVVM::MMATypes ptxTypeD,
+    std::optional<NVVM::MMAIntOverflow> overflow, unsigned metaDataSelector) {
   auto ptxTypeStr = [](NVVM::MMATypes ptxType) {
     return NVVM::stringifyMMATypes(ptxType);
   };
@@ -442,7 +442,8 @@ buildMmaSparseAsmString(const std::array<int64_t, 3> &shape, unsigned matASize,
     ss << "},";
   }
   ss << "$" << asmArgIdx++ << ",";
-  ss << "$" << asmArgIdx++ << ";";
+  assert(metaDataSelector <= 1);
+  ss << "0x" << metaDataSelector << ";";
   ss.flush();
   return asmStr;
 }
@@ -459,22 +460,21 @@ static FailureOr<LLVM::InlineAsmOp> emitMmaSparseSyncOpAsm(
   auto asmDialectAttr = LLVM::AsmDialectAttr::get(rewriter.getContext(),
                                                   LLVM::AsmDialect::AD_ATT);
 
-  std::string asmStr = buildMmaSparseAsmString(
-      shape, unpackedAData.size(), unpackedB.size(), unpackedC.size(), ptxTypeA,
-      ptxTypeB, ptxTypeC, ptxTypeD, overflow);
-  std::string constraintStr = buildMmaSparseAsmConstraintString(
-      unpackedAData.size(), unpackedB.size(), unpackedC.size());
+  const unsigned matASize = unpackedAData.size();
+  const unsigned matBSize = unpackedB.size();
+  const unsigned matCSize = unpackedC.size();
 
-  Value selectorVal = rewriter.create<LLVM::ConstantOp>(
-      loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(metadataSelector));
+  std::string asmStr = buildMmaSparseAsmString(
+      shape, matASize, matBSize, matCSize, ptxTypeA, ptxTypeB, ptxTypeC,
+      ptxTypeD, overflow, metadataSelector);
+  std::string constraintStr =
+      buildMmaSparseAsmConstraintString(matASize, matBSize, matCSize);
 
   SmallVector<Value> asmVals;
-  asmVals.reserve(unpackedAData.size() + unpackedB.size() + unpackedC.size() +
-                  2);
+  asmVals.reserve(matASize + matBSize + matCSize + 1);
   for (ArrayRef<Value> args : {unpackedAData, unpackedB, unpackedC})
     llvm::append_range(asmVals, args);
   asmVals.push_back(indexData);
-  asmVals.push_back(selectorVal);
 
   return rewriter.create<LLVM::InlineAsmOp>(loc,
                                             /*resultTypes=*/intrinsicResultType,

diff  --git a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
index f4125847230de..1133e09c538fa 100644
--- a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
+++ b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
@@ -333,12 +333,11 @@ func.func @mma_sp_sync_f16_16832(%arg0: vector<4x2xf16>,
   // CHECK-NOT llvm.extractvalue
 
   // CHECK: %[[sparseMetadata:.+]] = llvm.bitcast %{{.+}} : vector<2xi16> to i32
-  // CHECK: %[[sparseSelector:.+]] = llvm.mlir.constant(0 : i32) : i32
 
   // CHECK: %[[d:.+]] = llvm.inline_asm has_side_effects asm_dialect = att
-  // CHECK-SAME: "mma.sp.sync.aligned.m16n8k32.row.col.f16.f16.f16.f16 {$0,$1},{$2,$3,$4,$5},{$6,$7,$8,$9},{$10,$11},$12,$13;"
-  // CHECK-SAME: "=r,=r,r,r,r,r,r,r,r,r,r,r,r,r"
-  // CHECK-SAME: %[[sparseMetadata]], %[[sparseSelector]] :
+  // CHECK-SAME: "mma.sp.sync.aligned.m16n8k32.row.col.f16.f16.f16.f16 {$0,$1},{$2,$3,$4,$5},{$6,$7,$8,$9},{$10,$11},$12,0x0;"
+  // CHECK-SAME: "=r,=r,r,r,r,r,r,r,r,r,r,r,r"
+  // CHECK-SAME: %[[sparseMetadata]] :
   // CHECK-SAME: -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
 
   %d = nvgpu.mma.sp.sync(%arg0, %arg1, %arg2) metadata(%arg3) {mmaShape = [16, 8, 32]} :
@@ -372,12 +371,11 @@ func.func @mma_sp_sync_f16_16816(%arg0: vector<2x2xf16>,
   // CHECK-NOT llvm.extractvalue
 
   // CHECK: %[[sparseMetadata:.+]] = llvm.bitcast %{{.+}} : vector<2xi16> to i32
-  // CHECK: %[[sparseSelector:.+]] = llvm.mlir.constant(0 : i32) : i32
 
   // CHECK: %[[d:.+]] = llvm.inline_asm has_side_effects asm_dialect = att
-  // CHECK-SAME: "mma.sp.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 {$0,$1},{$2,$3},{$4,$5},{$6,$7},$8,$9;"
-  // CHECK-SAME: "=r,=r,r,r,r,r,r,r,r,r"
-  // CHECK-SAME: %[[sparseMetadata]], %[[sparseSelector]] :
+  // CHECK-SAME: "mma.sp.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 {$0,$1},{$2,$3},{$4,$5},{$6,$7},$8,0x0;"
+  // CHECK-SAME: "=r,=r,r,r,r,r,r,r,r"
+  // CHECK-SAME: %[[sparseMetadata]] :
   // CHECK-SAME: -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
 
   %d = nvgpu.mma.sp.sync(%arg0, %arg1, %arg2) metadata(%arg3) {mmaShape = [16, 8, 16]} :
@@ -413,12 +411,11 @@ func.func @mma_sp_sync_i8_16864(%arg0: vector<4x4xi8>,
   // CHECK-NOT llvm.extractvalue
 
   // CHECK: %[[sparseMetadata:.+]] = llvm.bitcast %{{.+}} : vector<2xi16> to i32
-  // CHECK: %[[sparseSelector:.+]] = llvm.mlir.constant(0 : i32) : i32
 
   // CHECK: %[[d:.+]] = llvm.inline_asm has_side_effects asm_dialect = att
-  // CHECK-SAME: "mma.sp.sync.aligned.m16n8k64.row.col.satfinite.s32.s8.s8.s32
-  // CHECK-SAME: "=r,=r,=r,=r,r,r,r,r,r,r,r,r,r,r,r,r,r,r"
-  // CHECK-SAME: %[[sparseMetadata]], %[[sparseSelector]] :
+  // CHECK-SAME: "mma.sp.sync.aligned.m16n8k64.row.col.satfinite.s32.s8.s8.s32 {$0,$1,$2,$3},{$4,$5,$6,$7},{$8,$9,$10,$11},{$12,$13,$14,$15},$16,0x0;"
+  // CHECK-SAME: "=r,=r,=r,=r,r,r,r,r,r,r,r,r,r,r,r,r,r"
+  // CHECK-SAME: %[[sparseMetadata]] :
   // CHECK-SAME: -> !llvm.struct<(i32, i32, i32, i32)
 
   %d = nvgpu.mma.sp.sync(%arg0, %arg1, %arg2) metadata(%arg3) {mmaShape = [16, 8, 64]} :


        


More information about the Mlir-commits mailing list