[Mlir-commits] [mlir] 271a48e - [mlir][VectorToGPU] Fix bug generating incorrect ldmatrix ops
Thomas Raoux
llvmlistbot at llvm.org
Thu Jun 2 21:30:36 PDT 2022
Author: Thomas Raoux
Date: 2022-06-03T04:30:22Z
New Revision: 271a48e02917859cd09ee7f230adea7b6cc7a578
URL: https://github.com/llvm/llvm-project/commit/271a48e02917859cd09ee7f230adea7b6cc7a578
DIFF: https://github.com/llvm/llvm-project/commit/271a48e02917859cd09ee7f230adea7b6cc7a578.diff
LOG: [mlir][VectorToGPU] Fix bug generating incorrect ldmatrix ops
ldmatrix with transpose can only be used with element types that are 16 bits wide.
Differential Revision: https://reviews.llvm.org/D126846
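
For readers outside the patch context, here is a minimal standalone sketch of the rule the change enforces for transposed loads. The helper name and its parameters are illustrative only (they are not the actual VectorToGPU.cpp code); the logic mirrors the updated guard in the hunk below: ldmatrix with transpose requires a 16-bit element type, at least 8 rows to read, and a 128-bit-wide read for the transpose.

// Hypothetical standalone helper, not the real VectorToGPU.cpp API.
// Returns true when a transposed vector load of shape
// (dimSize0 x dimSize1) with the given element bit width may be
// lowered to nvgpu.ldmatrix with transpose = true.
#include <cstdint>

bool isTransposeLdMatrixCompatible(int64_t dimSize0, int64_t dimSize1,
                                   int64_t elementBitWidth) {
  // 16-bit elements only, at least 8 rows to read, and the width to
  // read for the transpose must be at least 128 bits.
  return elementBitWidth == 16 && dimSize1 >= 8 &&
         dimSize0 * elementBitWidth >= 128;
}

This is the negation of the condition added in the patch, which marks the op as not ldmatrix-compatible when any of the three requirements fails.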
Added:
Modified:
mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
mlir/test/Conversion/VectorToGPU/vector-to-mma-ops-mma-sync.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
index a6e122c380315..90a37dab12ebf 100644
--- a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
+++ b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
@@ -624,7 +624,8 @@ convertTransferReadToLoads(vector::TransferReadOp op,
// at least 8 rows to read and the width to read for the transpose is 128
// bits.
if (!op.getPermutationMap().isMinorIdentity() &&
- (vecTy.getDimSize(1) < 8 || vecTy.getDimSize(0) * bitWidth < 128))
+ (bitWidth != 16 || vecTy.getDimSize(1) < 8 ||
+ vecTy.getDimSize(0) * bitWidth < 128))
isLdMatrixCompatible = false;
if (!isLdMatrixCompatible)
diff --git a/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops-mma-sync.mlir b/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops-mma-sync.mlir
index be8d08be06ce6..479533d1fa411 100644
--- a/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops-mma-sync.mlir
+++ b/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops-mma-sync.mlir
@@ -347,3 +347,63 @@ func.func @m16n8k4_tf32_f32_row_row_row(%arg0: memref<20x20xf32, 3>, %arg1: memr
vector.transfer_write %D, %arg2[%c0, %c0] {in_bounds = [true, true]} : vector<16x8xf32>, memref<20x20xf32>
return
}
+
+// -----
+
+#map0 = affine_map<(d0, d1) -> (d1, d0)>
+#map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d1, d2)>
+#map3 = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+// CHECK-DAG: [[$rowA_map:#.+]] = affine_map<()[s0] -> (s0 mod 16 + 1)>
+// CHECK-DAG: [[$colA_map:#.+]] = affine_map<()[s0] -> ((s0 floordiv 16) * 4 + 3)>
+
+// CHECK-DAG: [[$rowB_map:#.+]] = affine_map<()[s0] -> (s0 mod 4 + 3)>
+// CHECK-DAG: [[$colB_map:#.+]] = affine_map<()[s0] -> (s0 floordiv 4 + 3)>
+
+// CHECK-DAG: [[$rowC_map:#.+]] = affine_map<()[s0] -> (s0 floordiv 4)>
+// CHECK-DAG: [[$rowC8_map:#.+]] = affine_map<()[s0] -> (s0 floordiv 4 + 8)>
+// CHECK-DAG: [[$colC_map:#.+]] = affine_map<()[s0] -> (s0 * 2 - (s0 floordiv 4) * 8)>
+
+// CHECK-LABEL: func @m16n8k8_tf32_f32_row_row_row
+func.func @m16n8k8_tf32_f32_row_row_row(%arg0: memref<20x20xf32, 3>, %arg1: memref<20x20xf32, 3>, %arg2: memref<20x20xf32>) {
+ %cst_0 = arith.constant dense<0.000000e+00> : vector<16x8xf32>
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c3 = arith.constant 3 : index
+ %cst = arith.constant 0.000000e+00 : f32
+
+ // CHECK: [[c_frag:%.+]] = arith.constant {{.*}} : vector<2x2xf32>
+
+ // CHECK-DAG: [[row:%.+]] = affine.apply [[$rowA_map]]
+ // CHECK-DAG: [[col:%.+]] = affine.apply [[$colA_map]]
+ // CHECK: [[a_frag:%.+]] = nvgpu.ldmatrix %arg0[[[row]], [[col]]] {numTiles = 4 : i32, transpose = false}
+
+ // b and c are not loaded by ldmatrix in this test.
+ // CHECK-NOT: nvgpu.ldmatrix
+
+ // CHECK-DAG: [[row:%.+]] = affine.apply [[$rowB_map]]
+ // CHECK-DAG: [[col:%.+]] = affine.apply [[$colB_map]]
+ // CHECK: [[b_el0:%.+]] = memref.load {{%.+}} : memref<20x20xf32, 3>
+ // CHECK: [[b_frag0:%.+]] = vector.insert [[b_el0]], {{.*}} : f32 into vector<2x1xf32>
+ // CHECK: [[b_el1:%.+]] = memref.load {{%.+}} : memref<20x20xf32, 3>
+ // CHECK: [[b_frag1:%.+]] = vector.insert [[b_el1]], {{.*}} : f32 into vector<2x1xf32>
+
+ // CHECK: [[d_frag:%.+]] = nvgpu.mma.sync([[a_frag]], [[b_frag1]], [[c_frag]])
+ // CHECK-SAME: mmaShape = [16, 8, 8]
+ // CHECK-SAME: -> vector<2x2xf32>
+ %A = vector.transfer_read %arg0[%c1, %c3], %cst {in_bounds = [true, true]} : memref<20x20xf32, 3>, vector<16x8xf32>
+ %B = vector.transfer_read %arg1[%c3, %c3], %cst {permutation_map = #map0, in_bounds = [true, true]} : memref<20x20xf32, 3>, vector<8x8xf32>
+ %D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %A, %B, %cst_0 : vector<16x8xf32>, vector<8x8xf32> into vector<16x8xf32>
+
+ // CHECK: vector.extract [[d_frag]][0] : vector<2x2xf32>
+ // CHECK: affine.apply [[$rowC_map]]
+ // CHECK: affine.apply [[$colC_map]]
+ // CHECK: vector.store
+ // CHECK: vector.extract [[d_frag]][1] : vector<2x2xf32>
+ // CHECK: affine.apply [[$rowC8_map]]
+ // CHECK: affine.apply [[$colC_map]]
+ // CHECK: vector.store
+ vector.transfer_write %D, %arg2[%c0, %c0] {in_bounds = [true, true]} : vector<16x8xf32>, memref<20x20xf32>
+ return
+}