[Mlir-commits] [mlir] [MLIR][NVGPU-Tests] Fix a failing sm90 test (PR #111731)
Durgadoss R
llvmlistbot at llvm.org
Wed Oct 9 11:08:17 PDT 2024
https://github.com/durga4github created https://github.com/llvm/llvm-project/pull/111731
memref.expand_shape now requires an explicit output_shape. This patch adds it to the op's assembly
in the test, which fixes the failure.
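
As a rough illustration of what the added output_shape has to satisfy, the sketch below checks that each reassociation group multiplies back to the corresponding source dimension, then derives the strides after the transpose used in the test. The helper names (check_output_shape, row_major_strides, permute) are hypothetical and are not MLIR APIs. The sketch assumes the source memref is contiguous row-major, and it is not a model of MLIR's actual verifier.

```python
from math import prod


def check_output_shape(src_shape, reassociation, output_shape):
    """Return True if every reassociation group of output dims multiplies to its source dim."""
    if len(src_shape) != len(reassociation):
        return False
    return all(
        prod(output_shape[i] for i in group) == src_dim
        for src_dim, group in zip(src_shape, reassociation)
    )


def row_major_strides(shape):
    """Strides (in elements) of a contiguous row-major buffer with the given shape."""
    strides = []
    running = 1
    for dim in reversed(shape):
        strides.append(running)
        running *= dim
    return list(reversed(strides))


def permute(values, perm):
    """Result position i takes the value at source position perm[i]."""
    return [values[p] for p in perm]


# Shapes taken from the test touched by this patch.
src_shape = [128, 128]
reassociation = [[0, 1], [2, 3]]
output_shape = [2, 64, 2, 64]

ok = check_output_shape(src_shape, reassociation, output_shape)
expanded_strides = row_major_strides(output_shape)

# (d0, d1, d2, d3) -> (d0, d2, d1, d3)
transposed_shape = permute(output_shape, [0, 2, 1, 3])
transposed_strides = permute(expanded_strides, [0, 2, 1, 3])
```

Under these assumptions the transposed shape and strides agree with the memref<2x2x64x64xf16, strided<[8192, 64, 128, 1]>> type that the test's memref.transpose produces.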
From 0c20693d445f413f96f298e3beb1aee31375c756 Mon Sep 17 00:00:00 2001
From: Durgadoss R <durgadossr at nvidia.com>
Date: Wed, 9 Oct 2024 13:12:42 +0000
Subject: [PATCH] [MLIR][NVGPU-Tests] Fix a failing sm90 test
For memref, output_shape is required now. This patch
adds it to the assembly format which fixes the
failing test.
Signed-off-by: Durgadoss R <durgadossr at nvidia.com>
---
.../GPU/CUDA/sm90/tma_load_128x128_stride_noswizzle.mlir | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/test/Integration/GPU/CUDA/sm90/tma_load_128x128_stride_noswizzle.mlir b/mlir/test/Integration/GPU/CUDA/sm90/tma_load_128x128_stride_noswizzle.mlir
index 4d415394482029..1f284f250fd839 100644
--- a/mlir/test/Integration/GPU/CUDA/sm90/tma_load_128x128_stride_noswizzle.mlir
+++ b/mlir/test/Integration/GPU/CUDA/sm90/tma_load_128x128_stride_noswizzle.mlir
@@ -57,7 +57,7 @@ module {
%s4 = gpu.memcpy async [%s3] %srcMemref, %srcMemref_host : memref<128x128xf16>, memref<128x128xf16>
%s5 = gpu.memcpy async [%s4] %dstMemref, %dstMemref_host : memref<128x128xf16>, memref<128x128xf16>
- %expand_shape = memref.expand_shape %srcMemref [[0, 1], [2, 3]] : memref<128x128xf16> into memref<2x64x2x64xf16>
+ %expand_shape = memref.expand_shape %srcMemref [[0, 1], [2, 3]] output_shape [2, 64, 2, 64] : memref<128x128xf16> into memref<2x64x2x64xf16>
%transpose = memref.transpose %expand_shape (d0, d1, d2, d3) -> (d0, d2, d1, d3) : memref<2x64x2x64xf16> to memref<2x2x64x64xf16, strided<[8192, 64, 128, 1]>>
%cast = memref.cast %transpose : memref<2x2x64x64xf16, strided<[8192, 64, 128, 1]>> to memref<*xf16>
%24 = nvgpu.tma.create.descriptor %cast box[%c2, %c2, %c64, %c64] : memref<*xf16> -> <tensor = memref<2x2x64x64xf16, 3>, swizzle = none, l2promo = none, oob = zero, interleave = none>