[Mlir-commits] [mlir] 9774cd1 - [mlir][nvgpu] Fix affine maps computing indices for LdMatrixOp srcMemref

Manish Gupta llvmlistbot at llvm.org
Thu Dec 1 18:30:55 PST 2022


Author: Manish Gupta
Date: 2022-12-01T18:26:33-08:00
New Revision: 9774cd17e80fc413cef73e1e7e9bac20ef21ebae

URL: https://github.com/llvm/llvm-project/commit/9774cd17e80fc413cef73e1e7e9bac20ef21ebae
DIFF: https://github.com/llvm/llvm-project/commit/9774cd17e80fc413cef73e1e7e9bac20ef21ebae.diff

LOG: [mlir][nvgpu] Fix affine maps computing indices for LdMatrixOp srcMemref

This patch fixes and simplifies the ldmatrix affine map arithmetic by
expressing the affine expressions in terms of a pitch-linear layout
(strided and contiguous dimensions). It then applies the strided and
contiguous maps for both row-major and col-major layouts.

LdMatrixOp collaboratively (32 threads in a warp) loads tiles
(8 rows x 128b columns) of data. It can load x1, x2, or x4 tiles.
Additionally, it can transpose at 16-bit granularity when moving
data from shared memory to registers.

This patch fixes the affine map
(laneid -> coordinate of the element a thread points to in a tile).

- Loading x4 tiles requires all 32 lanes (T0-31) to point to contiguous
  128b chunks. The issue was exposed when running this case.
- Loading x2 and x1 tiles requires lanes T0-15 and T0-7, respectively,
  to point to contiguous 128b chunks. The patch is NFC for these cases.

Differential Revision: https://reviews.llvm.org/D138978

Added: 
    

Modified: 
    mlir/lib/Dialect/NVGPU/Utils/MMAUtils.cpp
    mlir/test/Conversion/VectorToGPU/vector-to-mma-ops-mma-sync.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/NVGPU/Utils/MMAUtils.cpp b/mlir/lib/Dialect/NVGPU/Utils/MMAUtils.cpp
index 6de16f84668ad..6fdaaad746ec1 100644
--- a/mlir/lib/Dialect/NVGPU/Utils/MMAUtils.cpp
+++ b/mlir/lib/Dialect/NVGPU/Utils/MMAUtils.cpp
@@ -238,7 +238,6 @@ FailureOr<AffineMap>
 nvgpu::getLaneIdToLdMatrixMatrixCoord(Location loc, OpBuilder &builder,
                                       const LdMatrixParams &params) {
   // One thread per 128b row.
-  const int64_t kNumThreadsPerTile = kNumRowsPerTile;
   const int bitsPerElement = static_cast<int>(
       params.fragmentType.getElementType().getIntOrFloatBitWidth());
   const int kElementsPer128b = (128 / bitsPerElement);
@@ -249,27 +248,28 @@ nvgpu::getLaneIdToLdMatrixMatrixCoord(Location loc, OpBuilder &builder,
     return AffineMap::get(1, 0, dimExprs, builder.getContext());
   };
 
-  // This case corresponds to row-major A|C or col-major B operands.
-  if (params.contiguousDimType == vector::IteratorType::reduction) {
-    AffineExpr row = d0 % (operandShape[0]);
-    AffineExpr col = d0.floorDiv(operandShape[0]) * (kElementsPer128b);
-    return makeMap({row, col});
-  }
+  // Index `idx` in vectorType `operandShape` maps to the strided dimension of
+  // the `srcMemref` memory of the LdMatrixOp.
+  int idx =
+      (params.contiguousDimType == vector::IteratorType::reduction) ? 0 : 1;
+
+  // Affine expr in strided and contiguous dimension encodes the coordinate
+  // mapping for the element a thread points to for warp-wide LdMatrixOp.
+  AffineExpr strided = d0 % (operandShape[idx]);
+  AffineExpr contiguous = d0.floorDiv(operandShape[idx]) * (kElementsPer128b);
+
+  // This case corresponds to row-major matrixA or col-major matrixB or
+  // row-major matrixC. This is when the memory layout in `srcMemref`
+  // match mma.sync hardware vector register operand layout.
+  if (params.contiguousDimType == vector::IteratorType::reduction)
+    return makeMap({strided, contiguous});
+
+  // This case corresponds to col-major matrixA or row-major matrixB or
+  // col-major matrixC. This is when the memory layout in `srcMemref` does not
+  // match mma.sync hardware vector register operand layout.
+  if (params.contiguousDimType == vector::IteratorType::parallel)
+    return makeMap({contiguous, strided});
 
-  // This case Corresponds to col-major A|C or row-major B operands. The
-  // operandShape given is already pre-transposed (e.g. 8x16 = KxN).
-  if (params.contiguousDimType == vector::IteratorType::parallel) {
-    const int64_t num8x128bCols = (operandShape[0] * bitsPerElement) / 128;
-    // Threads are assigned in groups of 8 first across columns, then to
-    // rows. This is transpose of what `ldmatrix` expects, but when
-    // `ldmatrix` gets the `.trans` qualifier, final the effect will be to
-    // transpose just the blocks.
-    auto groupIdx = d0.floorDiv(kNumThreadsPerTile);
-    auto tileCol = (groupIdx % num8x128bCols);
-    auto tileRow = groupIdx.floorDiv(num8x128bCols);
-    return makeMap({tileCol * kElementsPer128b,
-                    tileRow * kNumRowsPerTile + (d0 % kNumRowsPerTile)});
-  }
   return failure();
 }
 

diff  --git a/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops-mma-sync.mlir b/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops-mma-sync.mlir
index b2e34faf09147..d12c2a56ef5be 100644
--- a/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops-mma-sync.mlir
+++ b/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops-mma-sync.mlir
@@ -4,8 +4,8 @@
 // INT8 row-row-row
 //#########################################################
 
-// CHECK-DAG: [[$rowA0_map:#.+]] = affine_map<()[s0] -> (s0 mod 16 + 1)>
-// CHECK-DAG: [[$colA0_map:#.+]] = affine_map<()[s0] -> ((s0 floordiv 16) * 16 + 1)>
+// CHECK-DAG: [[$strided_map:#.+]] = affine_map<()[s0] -> (s0 mod 16)>
+// CHECK-DAG: [[$contiguous_map:#.+]] = affine_map<()[s0] -> ((s0 floordiv 16) * 16)>
 
 // CHECK-DAG: [[$rowB0_map:#.+]] = affine_map<()[s0] -> (s0 * 4 - (s0 floordiv 4) * 16 + 39)>
 // CHECK-DAG: [[$colB0_map:#.+]] = affine_map<()[s0] -> (s0 floordiv 4 + 40)>
@@ -40,14 +40,15 @@ func.func @m16n8k32_int8_row_row_row(%arg0: memref<128x128xi8, 3>, %arg1: memref
   %cst = arith.constant 0 : i8
   %cst0 = arith.constant 0 : i32
 
-  // Verify that the operand A is distributed to loads correctly.
+  // Verify that the operandA load is lowered to warp-wide ldmatrix.
 
-  // CHECK: [[row:%.+]] = affine.apply [[$rowA0_map]]()[{{%.+}}]
-  // CHECK: [[col:%.+]] = affine.apply [[$colA0_map]]()[{{%.+}}]
-  // CHECK: nvgpu.ldmatrix %arg0[[[row]], [[col]]] {numTiles = 4 : i32, transpose = false} : memref<128x128xi8, 3> -> vector<4x4xi8>
+  // CHECK: [[m_coord:%.+]] = affine.apply [[$strided_map]]()[{{%.+}}]
+  // CHECK: [[k_coord:%.+]] = affine.apply [[$contiguous_map]]()[{{%.+}}]
+  // CHECK: nvgpu.ldmatrix %arg0[[[m_coord]], [[k_coord]]] {numTiles = 4 : i32, transpose = false} : memref<128x128xi8, 3> -> vector<4x4xi8>
 
-  // Verify that the operand B is distributed to loads correctly. It's elements
-  // must be loaded in a non-vectorized manner to do the transpose.
+  // Verify that the operandB load is lowered to scalar load to be able
+  // to transpose at 8-bit granularity. ldmatrix can only transpose at 
+  // 16-bit granularity.
 
   // CHECK-DAG: [[row:%.+]] = affine.apply [[$rowB0_map]]()[{{%.+}}]
   // CHECK-DAG: [[col:%.+]] = affine.apply [[$colB0_map]]()[{{%.+}}]
@@ -84,7 +85,7 @@ func.func @m16n8k32_int8_row_row_row(%arg0: memref<128x128xi8, 3>, %arg1: memref
   // CHECK: vector.load %arg2[[[row]], [[col]]] : memref<128x128xi32>, vector<2xi32>
   // CHECK-NOT: vector.load %arg2{{.*}}
 
-  %A = vector.transfer_read %arg0[%c1, %c1], %cst {in_bounds = [true, true]} : memref<128x128xi8, 3>, vector<16x32xi8>
+  %A = vector.transfer_read %arg0[%c0, %c0], %cst {in_bounds = [true, true]} : memref<128x128xi8, 3>, vector<16x32xi8>
   %B = vector.transfer_read %arg1[%c39, %c40], %cst {in_bounds = [true, true], permutation_map = #map0} : memref<128x128xi8, 3>, vector<8x32xi8>
   %C = vector.transfer_read %arg2[%c49, %c40], %cst0 {in_bounds = [true, true]} : memref<128x128xi32>, vector<16x8xi32>
   // CHECK: [[d:%.+]] = nvgpu.mma.sync({{.*}}) {mmaShape = [16, 8, 32]} : (vector<4x4xi8>, vector<2x4xi8>, vector<2x2xi32>) -> vector<2x2xi32>
@@ -173,28 +174,23 @@ func.func @m8n8k4_f64_row_row_row(%arg0: memref<128x128xf64>, %arg1: memref<128x
 #map2 = affine_map<(d0, d1, d2) -> (d1, d2)>
 #map3 = affine_map<(d0, d1, d2) -> (d0, d1)>
 
-// CHECK-DAG: [[$rowA_map:#.+]] = affine_map<()[s0] -> (s0 mod 16 + 1)>
-// CHECK-DAG: [[$colA_map:#.+]] = affine_map<()[s0] -> ((s0 floordiv 16) * 8 + 3)>
-
-// CHECK-DAG: [[$rowB_map:#.+]] = affine_map<()[s0] -> (s0 + 3)>
-// CHECK-DAG: [[$colB_map:#.+]] = affine_map<() -> (3)>
+// CHECK-DAG: [[$strided_map:#.+]] = affine_map<()[s0] -> (s0 mod 16)>
+// CHECK-DAG: [[$contiguous_map:#.+]] = affine_map<()[s0] -> ((s0 floordiv 16) * 8)>
 
 // CHECK-LABEL: func @m16n8k16_fp16_row_row_row
 func.func @m16n8k16_fp16_row_row_row(%arg0: memref<20x20xf16, 3>, %arg1: memref<20x20xf16, 3>, %arg2: memref<20x20xf16, 3>) {
   %cst_0 = arith.constant dense<0.000000e+00> : vector<16x8xf16>
   %c0 = arith.constant 0 : index
-  %c1 = arith.constant 1 : index
-  %c3 = arith.constant 3 : index
   %cst = arith.constant 0.000000e+00 : f16
-  // CHECK-DAG: [[row:%.+]] = affine.apply [[$rowA_map]]
-  // CHECK-DAG: [[col:%.+]] = affine.apply [[$colA_map]]
-  // CHECK: nvgpu.ldmatrix %arg0[[[row]], [[col]]] {numTiles = 4 : i32, transpose = false}
 
-  // CHECK-DAG: [[row:%.+]] = affine.apply [[$rowB_map]]
-  // CHECK-DAG: [[col:%.+]] = affine.apply [[$colB_map]]
-  // CHECK: nvgpu.ldmatrix %arg1[[[row]], [[col]]] {numTiles = 2 : i32, transpose = true}
-  %A = vector.transfer_read %arg0[%c1, %c3], %cst {in_bounds = [true, true]} : memref<20x20xf16, 3>, vector<16x16xf16>
-  %B = vector.transfer_read %arg1[%c3, %c3], %cst {permutation_map = #map0, in_bounds = [true, true]} : memref<20x20xf16, 3>, vector<8x16xf16>
+  // CHECK-DAG: [[m_coord:%.+]] = affine.apply [[$strided_map]]
+  // CHECK-DAG: [[k_coord:%.+]] = affine.apply [[$contiguous_map]]
+  // CHECK: nvgpu.ldmatrix %arg0[[[m_coord]], [[k_coord]]] {numTiles = 4 : i32, transpose = false}
+  // CHECK-DAG: [[n_coord:%.+]] = affine.apply [[$contiguous_map]]
+  // CHECK-DAG: [[k_coord:%.+]] = affine.apply [[$strided_map]]
+  // CHECK: nvgpu.ldmatrix %arg1[[[k_coord]], [[n_coord]]] {numTiles = 2 : i32, transpose = true}
+  %A = vector.transfer_read %arg0[%c0, %c0], %cst {in_bounds = [true, true]} : memref<20x20xf16, 3>, vector<16x16xf16>
+  %B = vector.transfer_read %arg1[%c0, %c0], %cst {permutation_map = #map0, in_bounds = [true, true]} : memref<20x20xf16, 3>, vector<8x16xf16>
   %C = vector.transfer_read %arg2[%c0, %c0], %cst {in_bounds = [true, true]} : memref<20x20xf16, 3>, vector<16x8xf16>
   %D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %A, %B, %C : vector<16x16xf16>, vector<8x16xf16> into vector<16x8xf16>
   vector.transfer_write %D, %arg2[%c0, %c0] {in_bounds = [true, true]} : vector<16x8xf16>, memref<20x20xf16, 3>
@@ -207,10 +203,8 @@ func.func @m16n8k16_fp16_row_row_row(%arg0: memref<20x20xf16, 3>, %arg1: memref<
 // FP16 row-row-row (ldmatrix x4 for matrixA and ldmatrix x4 for matrixB)
 //#########################################################################
 
-// CHECK-DAG: [[$rowA_map:#.+]] = affine_map<()[s0] -> (s0 mod 16)>
-// CHECK-DAG: [[$colA_map:#.+]] = affine_map<()[s0] -> ((s0 floordiv 16) * 8)>
-// CHECK-DAG: [[$rowB_map:#.+]] = affine_map<()[s0] -> ((s0 floordiv 8) * 8 - ((s0 floordiv 8) floordiv 2) * 16)>
-// CHECK-DAG: [[$colB_map:#.+]] = affine_map<()[s0] -> (s0 mod 8 + ((s0 floordiv 8) floordiv 2) * 8)>
+// CHECK-DAG: [[$strided_map:#.+]] = affine_map<()[s0] -> (s0 mod 16)>
+// CHECK-DAG: [[$contiguous_map:#.+]] = affine_map<()[s0] -> ((s0 floordiv 16) * 8)>
 
 #map0 = affine_map<(d0, d1) -> (d1, d0)>
 #map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
@@ -224,19 +218,19 @@ func.func @m16n16k16_mmasync16816_fp16_f16_row_row_row(%arg0: memref<42x32xf16,
   %c8 = arith.constant 8 : index
   %cst = arith.constant 0.000000e+00 : f16
 
-  // CHECK-DAG: [[row:%.+]] = affine.apply [[$rowA_map]]
-  // CHECK-DAG: [[col:%.+]] = affine.apply [[$colA_map]]
-  // CHECK: [[fragmentA:%.+]] = nvgpu.ldmatrix %arg0[[[row]], [[col]]] {numTiles = 4 : i32, transpose = false}
+  // CHECK-DAG: [[m_coord:%.+]] = affine.apply [[$strided_map]]
+  // CHECK-DAG: [[k_coord:%.+]] = affine.apply [[$contiguous_map]]
+  // CHECK: [[fragmentA:%.+]] = nvgpu.ldmatrix %arg0[[[m_coord]], [[k_coord]]] {numTiles = 4 : i32, transpose = false}
   %A = vector.transfer_read %arg0[%c0, %c0], %cst {in_bounds = [true, true]} : memref<42x32xf16, 3>, vector<16x16xf16>
 
-  // CHECK-DAG: [[row:%.+]] = affine.apply [[$rowB_map]]
-  // CHECK-DAG: [[col:%.+]] = affine.apply [[$colB_map]]
-  // CHECK-DAG: [[fragmentB:%.+]] = nvgpu.ldmatrix %arg1[[[col]], [[row]]] {numTiles = 4 : i32, transpose = true}
+  // CHECK-DAG: [[n_coord:%.+]] = affine.apply [[$contiguous_map]]
+  // CHECK-DAG: [[k_coord:%.+]] = affine.apply [[$strided_map]]
+  // CHECK-DAG: [[fragmentB:%.+]] = nvgpu.ldmatrix %arg1[[[k_coord]], [[n_coord]]] {numTiles = 4 : i32, transpose = true}
   %B = vector.transfer_read %arg1[%c0, %c0], %cst {permutation_map = #map0, in_bounds = [true, true]} : memref<32x64xf16, 3>, vector<16x16xf16>
 
-  // CHECK-DAG: [[row:%.+]] = affine.apply [[$rowA_map]]
-  // CHECK-DAG: [[col:%.+]] = affine.apply [[$colA_map]]
-  // CHECK-DAG: [[fragmentC:%.*]] = nvgpu.ldmatrix %arg2[[[row]], [[col]]] {numTiles = 4 : i32, transpose = false}
+  // CHECK-DAG: [[m_coord:%.+]] = affine.apply [[$strided_map]]
+  // CHECK-DAG: [[n_coord:%.+]] = affine.apply [[$contiguous_map]]
+  // CHECK-DAG: [[fragmentC:%.*]] = nvgpu.ldmatrix %arg2[[[m_coord]], [[n_coord]]] {numTiles = 4 : i32, transpose = false}
   %C = vector.transfer_read %arg2[%c0, %c0], %cst {in_bounds = [true, true]} : memref<42x64xf16, 3>, vector<16x16xf16>
 
   // CHECK-DAG: [[fragmentB0:%.+]] = vector.extract_strided_slice [[fragmentB]] {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x2xf16> to vector<2x2xf16>
@@ -259,10 +253,8 @@ func.func @m16n16k16_mmasync16816_fp16_f16_row_row_row(%arg0: memref<42x32xf16,
 }
 // -----
 
-// CHECK-DAG: [[$Arow_map:#.+]] = affine_map<()[s0] -> (s0 mod 16 + 1)>
-// CHECK-DAG: [[$Acol_map:#.+]] = affine_map<()[s0] -> ((s0 floordiv 16) * 8 + 3)>
-// CHECK-DAG: [[$Bcol_map:#.+]] = affine_map<() -> (3)>
-// CHECK-DAG: [[$Brow_map:#.+]] = affine_map<()[s0] -> (s0 + 3)>
+// CHECK-DAG: [[$strided_map:#.+]] = affine_map<()[s0] -> (s0 mod 16)>
+// CHECK-DAG: [[$contiguous_map:#.+]] = affine_map<()[s0] -> ((s0 floordiv 16) * 8)>
 
 #map0 = affine_map<(d0, d1, d2) -> (d2, d1)>
 #map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
@@ -274,26 +266,24 @@ func.func @batch_m16n8k16_fp16_row_row_row(%arg0: memref<2x20x20xf16, 3>, %arg1:
   %cst_0 = arith.constant dense<0.000000e+00> : vector<20x20xf16>
   // CHECK: [[C0:%.+]] = arith.constant 0 : index
   %c0 = arith.constant 0 : index
-  %c1 = arith.constant 1 : index
-  %c3 = arith.constant 3 : index
   %cst = arith.constant 0.000000e+00 : f16
 
-  // CHECK-DAG: [[row:%.+]] = affine.apply [[$Arow_map]]
-  // CHECK-DAG: [[col:%.+]] = affine.apply [[$Acol_map]]
-  // CHECK: nvgpu.ldmatrix %arg0[[[C0]], [[row]], [[col]]] {numTiles = 4 : i32, transpose = false} : memref<2x20x20xf16, 3> -> vector<4x2xf16>
-  %A = vector.transfer_read %arg0[%c0, %c1, %c3], %cst {in_bounds = [true, true]} : memref<2x20x20xf16, 3>, vector<16x16xf16>
+  // CHECK-DAG: [[m_coord:%.+]] = affine.apply [[$strided_map]]
+  // CHECK-DAG: [[k_coord:%.+]] = affine.apply [[$contiguous_map]]
+  // CHECK: nvgpu.ldmatrix %arg0[[[C0]], [[m_coord]], [[k_coord]]] {numTiles = 4 : i32, transpose = false} : memref<2x20x20xf16, 3> -> vector<4x2xf16>
+  %A = vector.transfer_read %arg0[%c0, %c0, %c0], %cst {in_bounds = [true, true]} : memref<2x20x20xf16, 3>, vector<16x16xf16>
 
-  // CHECK-DAG: [[row:%.+]] = affine.apply [[$Brow_map]]
-  // CHECK-DAG: [[col:%.+]] = affine.apply [[$Bcol_map]]
-  // CHECK: nvgpu.ldmatrix %arg1[[[C0]], [[row]], [[col]]] {numTiles = 2 : i32, transpose = true} : memref<2x20x20xf16, 3> -> vector<2x2xf16>
-  %B = vector.transfer_read %arg1[%c0, %c3, %c3], %cst {permutation_map = #map0, in_bounds = [true, true]} : memref<2x20x20xf16, 3>, vector<8x16xf16>
+  // CHECK-DAG: [[n_coord:%.+]] = affine.apply [[$contiguous_map]]
+  // CHECK-DAG: [[k_coord:%.+]] = affine.apply [[$strided_map]]
+  // CHECK: nvgpu.ldmatrix %arg1[[[C0]], [[k_coord]], [[n_coord]]] {numTiles = 2 : i32, transpose = true} : memref<2x20x20xf16, 3> -> vector<2x2xf16>
+  %B = vector.transfer_read %arg1[%c0, %c0, %c0], %cst {permutation_map = #map0, in_bounds = [true, true]} : memref<2x20x20xf16, 3>, vector<8x16xf16>
 
-  // CHECK-DAG: [[row:%.+]] = affine.apply [[$Arow_map]]
-  // CHECK-DAG: [[col:%.+]] = affine.apply [[$Acol_map]]
-  // CHECK: nvgpu.ldmatrix %arg2[[[C0]], [[row]], [[col]]] {numTiles = 2 : i32, transpose = false} : memref<2x20x20xf16, 3> -> vector<2x2xf16>
-  %C = vector.transfer_read %arg2[%c0, %c1, %c3], %cst {in_bounds = [true, true]} : memref<2x20x20xf16, 3>, vector<16x8xf16>
+  // CHECK-DAG: [[m_coord:%.+]] = affine.apply [[$strided_map]]
+  // CHECK-DAG: [[n_coord:%.+]] = affine.apply [[$contiguous_map]]
+  // CHECK: nvgpu.ldmatrix %arg2[[[C0]], [[m_coord]], [[n_coord]]] {numTiles = 2 : i32, transpose = false} : memref<2x20x20xf16, 3> -> vector<2x2xf16>
+  %C = vector.transfer_read %arg2[%c0, %c0, %c0], %cst {in_bounds = [true, true]} : memref<2x20x20xf16, 3>, vector<16x8xf16>
   %D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %A, %B, %C : vector<16x16xf16>, vector<8x16xf16> into vector<16x8xf16>
-  vector.transfer_write %D, %arg2[%c0, %c1, %c3] {in_bounds = [true, true]} : vector<16x8xf16>, memref<2x20x20xf16, 3>
+  vector.transfer_write %D, %arg2[%c0, %c0, %c0] {in_bounds = [true, true]} : vector<16x8xf16>, memref<2x20x20xf16, 3>
   return
 }
 
@@ -307,36 +297,36 @@ func.func @batch_m16n8k16_fp16_row_row_row(%arg0: memref<2x20x20xf16, 3>, %arg1:
 #map2 = affine_map<(d0, d1, d2) -> (d1, d2)>
 #map3 = affine_map<(d0, d1, d2) -> (d0, d1)>
 
-// CHECK: [[$rowA_map:#.+]] = affine_map<()[s0] -> (s0 mod 16 + 1)>
-// CHECK: [[$colA_map:#.+]] = affine_map<()[s0] -> ((s0 floordiv 16) * 8 + 3)>
+// Affine maps for ldmatrix x4 tile of `16 x 16` f16 elements in `strided x contiguous` dimensions.
+// CHECK: [[$strided_ldmatrix_x4_map:#.+]] = affine_map<()[s0] -> (s0 mod 16)>
+// CHECK: [[$contiguous_ldmatrix_x4_map:#.+]] = affine_map<()[s0] -> ((s0 floordiv 16) * 8)>
 
-// CHECK: [[$rowB_map:#.+]] = affine_map<()[s0] -> (s0 mod 8 + 1)>
-// CHECK: [[$colB_map:#.+]] = affine_map<()[s0] -> ((s0 floordiv 8) * 8 + 3)>
+// CHECK: [[$strided_ldmatrix_x2_map:#.+]] = affine_map<()[s0] -> (s0 mod 8)>
+// CHECK: [[$contiguous_ldmatrix_x2_map:#.+]] = affine_map<()[s0] -> ((s0 floordiv 8) * 8)>
 
 // CHECK-LABEL: func @m16n8k16_fp16_row_col_row
 func.func @m16n8k16_fp16_row_col_row(%arg0: memref<20x20xf16, 3>, %arg1: memref<20x20xf16, 3>, %arg2: memref<20x20xf16, 3>) {
   %cst_0 = arith.constant dense<0.000000e+00> : vector<16x8xf16>
   %c0 = arith.constant 0 : index
-  %c1 = arith.constant 1 : index
-  %c3 = arith.constant 3 : index
+
   %cst = arith.constant 0.000000e+00 : f16
-  // CHECK-DAG: [[row:%.+]] = affine.apply [[$rowA_map]]
-  // CHECK-DAG: [[col:%.+]] = affine.apply [[$colA_map]]
-  // CHECK: nvgpu.ldmatrix %arg0[[[row]], [[col]]] {numTiles = 4 : i32
+  // CHECK-DAG: [[m_coord:%.+]] = affine.apply [[$strided_ldmatrix_x4_map]]
+  // CHECK-DAG: [[k_coord:%.+]] = affine.apply [[$contiguous_ldmatrix_x4_map]]
+  // CHECK: nvgpu.ldmatrix %arg0[[[m_coord]], [[k_coord]]] {numTiles = 4 : i32
   // CHECK-SAME: transpose = false
 
-  // CHECK-DAG: [[row:%.+]] = affine.apply [[$rowB_map]]
-  // CHECK-DAG: [[col:%.+]] = affine.apply [[$colB_map]]
-  // CHECK: nvgpu.ldmatrix %arg1[[[row]], [[col]]] {numTiles = 2 : i32
+  // CHECK-DAG: [[n_coord:%.+]] = affine.apply [[$strided_ldmatrix_x2_map]]
+  // CHECK-DAG: [[k_coord:%.+]] = affine.apply [[$contiguous_ldmatrix_x2_map]]
+  // CHECK: nvgpu.ldmatrix %arg1[[[n_coord]], [[k_coord]]] {numTiles = 2 : i32
   // CHECK-SAME: transpose = false
 
-  // CHECK-DAG: [[row:%.+]] = affine.apply [[$rowA_map]]
-  // CHECK-DAG: [[col:%.+]] = affine.apply [[$colA_map]]
-  // CHECK: nvgpu.ldmatrix %arg2[[[row]], [[col]]] {numTiles = 2 : i32
+  // CHECK-DAG: [[m_coord:%.+]] = affine.apply [[$strided_ldmatrix_x4_map]]
+  // CHECK-DAG: [[n_coord:%.+]] = affine.apply [[$contiguous_ldmatrix_x4_map]]
+  // CHECK: nvgpu.ldmatrix %arg2[[[m_coord]], [[n_coord]]] {numTiles = 2 : i32
   // CHECK-SAME: transpose = false
-  %A = vector.transfer_read %arg0[%c1, %c3], %cst {in_bounds = [true, true]} : memref<20x20xf16, 3>, vector<16x16xf16>
-  %B = vector.transfer_read %arg1[%c1, %c3], %cst {in_bounds = [true, true]} : memref<20x20xf16, 3>, vector<8x16xf16>
-  %C = vector.transfer_read %arg2[%c1, %c3], %cst {in_bounds = [true, true]} : memref<20x20xf16, 3>, vector<16x8xf16>
+  %A = vector.transfer_read %arg0[%c0, %c0], %cst {in_bounds = [true, true]} : memref<20x20xf16, 3>, vector<16x16xf16>
+  %B = vector.transfer_read %arg1[%c0, %c0], %cst {in_bounds = [true, true]} : memref<20x20xf16, 3>, vector<8x16xf16>
+  %C = vector.transfer_read %arg2[%c0, %c0], %cst {in_bounds = [true, true]} : memref<20x20xf16, 3>, vector<16x8xf16>
   %D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %A, %B, %C : vector<16x16xf16>, vector<8x16xf16> into vector<16x8xf16>
   vector.transfer_write %D, %arg2[%c0, %c0] {in_bounds = [true, true]} : vector<16x8xf16>, memref<20x20xf16, 3>
   return
@@ -345,7 +335,7 @@ func.func @m16n8k16_fp16_row_col_row(%arg0: memref<20x20xf16, 3>, %arg1: memref<
 // -----
 
 //#########################################################
-// TF32 (multiplicand) F32 (accumulator) row-row-row
+// TF32 row-row-row
 //#########################################################
 
 #map0 = affine_map<(d0, d1) -> (d1, d0)>
@@ -406,6 +396,9 @@ func.func @m16n8k4_tf32_f32_row_row_row(%arg0: memref<20x20xf32, 3>, %arg1: memr
 
 // -----
 
+//#########################################################
+// TF32 row-row-row
+//#########################################################
 #map0 = affine_map<(d0, d1) -> (d1, d0)>
 #map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
 #map2 = affine_map<(d0, d1, d2) -> (d1, d2)>
@@ -467,13 +460,88 @@ func.func @m16n8k8_tf32_f32_row_row_row(%arg0: memref<20x20xf32, 3>, %arg1: memr
 // -----
 
 //#########################################################
-// INT4 row-col-row
+// TF32 col-col-row
 //#########################################################
+#map0 = affine_map<(d0, d1) -> (d1, d0)>
+#map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d1, d2)>
+#map3 = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+// CHECK-DAG: [[$rowA0_map:#.+]] = affine_map<()[s0] -> (s0 floordiv 4)>
+// CHECK-DAG: [[$colA0_map:#.+]] = affine_map<()[s0] -> (s0 mod 4)>
+// CHECK-DAG: [[$rowA8_map:#.+]] = affine_map<()[s0] -> (s0 floordiv 4 + 8)>
+// CHECK-DAG: [[$colA4_map:#.+]] = affine_map<()[s0] -> (s0 mod 4 + 4)>
 
-// CHECK-DAG: [[$rowA0_map:#.+]] = affine_map<()[s0] -> (s0 mod 16)>
-// CHECK-DAG: [[$colA0_map:#.+]] = affine_map<()[s0] -> ((s0 floordiv 16) * 32)>
 // CHECK-DAG: [[$rowB0_map:#.+]] = affine_map<()[s0] -> (s0 mod 8)>
-// CHECK-DAG: [[$colB0_map:#.+]] = affine_map<()[s0] -> ((s0 floordiv 8) * 32)>
+// CHECK-DAG: [[$colB0_map:#.+]] = affine_map<()[s0] -> ((s0 floordiv 8) * 4)>
+
+// CHECK-DAG: [[$rowC_map:#.+]] = affine_map<()[s0] -> (s0 floordiv 4 + 16)>
+// CHECK-DAG: [[$rowC8_map:#.+]] = affine_map<()[s0] -> (s0 floordiv 4 + 24)>
+// CHECK-DAG: [[$colC_map:#.+]] = affine_map<()[s0] -> (s0 * 2 - (s0 floordiv 4) * 8 + 8)>
+
+// CHECK-LABEL: func @m16n8k8_tf32_f32_col_col_row
+func.func @m16n8k8_tf32_f32_col_col_row(%arg0: memref<20x20xf32, 3>, %arg1: memref<20x20xf32, 3>, %arg2: memref<20x20xf32>) {
+  %cst_0 = arith.constant dense<0.000000e+00> : vector<16x8xf32>
+  %c0 = arith.constant 0 : index
+  %c16 = arith.constant 16 : index
+  %c8 = arith.constant 8 : index
+  %cst = arith.constant 0.000000e+00 : f32
+
+  // CHECK: [[c_frag:%.+]] = arith.constant {{.*}} : vector<2x2xf32>
+
+  // CHECK-DAG: [[row:%.+]] = affine.apply [[$rowA0_map]]
+  // CHECK-DAG: [[col:%.+]] = affine.apply [[$colA0_map]]
+  // CHECK: [[a_el0:%.+]] = memref.load {{%.+}} : memref<20x20xf32, 3>
+  // CHECK: [[a_frag0:%.+]] = vector.insert [[a_el0]], {{.*}} [0, 0] : f32 into vector<4x1xf32>
+
+  // CHECK-DAG: [[row:%.+]] = affine.apply [[$rowA8_map]]
+  // CHECK-DAG: [[col:%.+]] = affine.apply [[$colA0_map]]
+  // CHECK: [[a_el0:%.+]] = memref.load {{%.+}} : memref<20x20xf32, 3>
+  // CHECK: [[a_frag0:%.+]] = vector.insert [[a_el0]], {{.*}} [1, 0] : f32 into vector<4x1xf32>
+
+  // CHECK: [[a_el:%.+]] = memref.load {{%.+}} : memref<20x20xf32, 3>
+  // CHECK: [[a_frag:%.+]] = vector.insert [[a_el]], {{.*}} [2, 0] : f32 into vector<4x1xf32>
+  // CHECK: [[a_el:%.+]] = memref.load {{%.+}} : memref<20x20xf32, 3>
+  // CHECK: [[a_frag:%.+]] = vector.insert [[a_el]], {{.*}} [3, 0] : f32 into vector<4x1xf32>
+
+  // CHECK-DAG: [[row:%.+]] = affine.apply [[$rowB0_map]]
+  // CHECK-DAG: [[col:%.+]] = affine.apply [[$colB0_map]]
+  // CHECK: [[b_frag:%.+]] = nvgpu.ldmatrix %arg1[[[row]], [[col]]] {numTiles = 2 : i32, transpose = false}
+
+  // CHECK: [[d_frag:%.+]] = nvgpu.mma.sync([[a_frag]], [[b_frag]], [[c_frag]])
+  // CHECK-SAME: mmaShape = [16, 8, 8]
+  // CHECK-SAME: -> vector<2x2xf32>
+  %A = vector.transfer_read %arg0[%c0, %c0], %cst {in_bounds = [true, true], permutation_map = #map0} : memref<20x20xf32, 3>, vector<16x8xf32>
+  %B = vector.transfer_read %arg1[%c0, %c0], %cst {in_bounds = [true, true]} : memref<20x20xf32, 3>, vector<8x8xf32>
+  %D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"],
+    kind = #vector.kind<add>} %A, %B, %cst_0 : vector<16x8xf32>, vector<8x8xf32> into vector<16x8xf32>
+
+  // CHECK: vector.extract [[d_frag]][0] : vector<2x2xf32>
+  // CHECK: affine.apply [[$rowC_map]]
+  // CHECK: affine.apply [[$colC_map]]
+  // CHECK: vector.store
+  // CHECK: vector.extract [[d_frag]][1] : vector<2x2xf32>
+  // CHECK: affine.apply [[$rowC8_map]]
+  // CHECK: affine.apply [[$colC_map]]
+  // CHECK: vector.store
+  vector.transfer_write %D, %arg2[%c16, %c8] {in_bounds = [true, true]} : vector<16x8xf32>, memref<20x20xf32>
+  return
+}
+
+// -----
+
+//#########################################################
+// INT4 row-col-row
+//#########################################################
+// Affine maps for loading operandA and operandB
+// maps (laneid -> coordinate pointed by the lane in the ldmatrix operand tile)
+// CHECK-DAG: [[$strided_ldmatrix_x4_map:#.+]] = affine_map<()[s0] -> (s0 mod 16)>
+// CHECK-DAG: [[$contiguous_ldmatrix_x4_map:#.+]] = affine_map<()[s0] -> ((s0 floordiv 16) * 32)>
+// CHECK-DAG: [[$strided_ldmatrix_x2_map:#.+]] = affine_map<()[s0] -> (s0 mod 8)>
+// CHECK-DAG: [[$contiguous_ldmatrix_x2_map:#.+]] = affine_map<()[s0] -> ((s0 floordiv 8) * 32)>
+
+// Affine maps for accumulator registers
+// maps (laneid -> coordinate pointed by the lane in accumulator register tile)
 // CHECK-DAG: [[$rowC0_map:#.+]] = affine_map<()[s0] -> (s0 floordiv 4)>
 // CHECK-DAG: [[$colC0_map:#.+]] = affine_map<()[s0] -> (s0 * 2 - (s0 floordiv 4) * 8
 // CHECK-DAG: [[$rowC8_map:#.+]] = affine_map<()[s0] -> (s0 floordiv 4 + 8)>
@@ -490,14 +558,14 @@ func.func @m16n8k64_int4_row_col_row(%arg0: memref<128x128xi4, 3>, %arg1: memref
   %c0 = arith.constant 0 : index
 
   // CHECK: [[lane:%.+]] = gpu.lane_id
-  // CHECK: [[row:%.+]] = affine.apply [[$rowA0_map]]()[[[lane]]]
-  // CHECK: [[col:%.+]] = affine.apply [[$colA0_map]]()[[[lane]]]
-  // CHECK: nvgpu.ldmatrix %arg0[[[row]], [[col]]] {numTiles = 4 : i32, transpose = false} : memref<128x128xi4, 3> -> vector<4x8xi4>
+  // CHECK: [[m_coord:%.+]] = affine.apply [[$strided_ldmatrix_x4_map]]()[[[lane]]]
+  // CHECK: [[k_coord:%.+]] = affine.apply [[$contiguous_ldmatrix_x4_map]]()[[[lane]]]
+  // CHECK: nvgpu.ldmatrix %arg0[[[m_coord]], [[k_coord]]] {numTiles = 4 : i32, transpose = false} : memref<128x128xi4, 3> -> vector<4x8xi4>
 
   // CHECK: [[lane:%.+]] = gpu.lane_id
-  // CHECK: [[row:%.+]] = affine.apply [[$rowB0_map]]()[[[lane]]]
-  // CHECK: [[col:%.+]] = affine.apply [[$colB0_map]]()[[[lane]]]
-  // CHECK: nvgpu.ldmatrix %arg1[[[row]], [[col]]] {numTiles = 2 : i32, transpose = false} : memref<128x128xi4, 3> -> vector<2x8xi4>
+  // CHECK: [[n_coord:%.+]] = affine.apply [[$strided_ldmatrix_x2_map]]()[[[lane]]]
+  // CHECK: [[k_coord:%.+]] = affine.apply [[$contiguous_ldmatrix_x2_map]]()[[[lane]]]
+  // CHECK: nvgpu.ldmatrix %arg1[[[n_coord]], [[k_coord]]] {numTiles = 2 : i32, transpose = false} : memref<128x128xi4, 3> -> vector<2x8xi4>
 
   // CHECK: [[lane:%.+]] = gpu.lane_id
   // CHECK: [[row:%.+]] = affine.apply [[$rowC0_map]]()[{{%.+}}]
@@ -534,12 +602,15 @@ func.func @m16n8k64_int4_row_col_row(%arg0: memref<128x128xi4, 3>, %arg1: memref
 //#########################################################
 // INT8 row-col-row
 //#########################################################
-
-// CHECK-DAG: [[$rowA0_map:#.+]] = affine_map<()[s0] -> (s0 mod 16)>
-// CHECK-DAG: [[$colA0_map:#.+]] = affine_map<()[s0] -> ((s0 floordiv 16) * 16)>
-// CHECK-DAG: [[$rowB0_map:#.+]] = affine_map<()[s0] -> (s0 mod 8)>
-// CHECK-DAG: [[$colB0_map:#.+]] = affine_map<()[s0] -> ((s0 floordiv 8) * 16)>
-
+// Affine maps for loading operandA and operandB
+// maps (laneid -> coordinate pointed by the lane in the ldmatrix operand tile)
+// CHECK-DAG: [[$strided_ldmatrix_x4_map:#.+]] = affine_map<()[s0] -> (s0 mod 16)>
+// CHECK-DAG: [[$contiguous_ldmatrix_x4_map:#.+]] = affine_map<()[s0] -> ((s0 floordiv 16) * 16)>
+// CHECK-DAG: [[$strided_ldmatrix_x2_map:#.+]] = affine_map<()[s0] -> (s0 mod 8)>
+// CHECK-DAG: [[$contiguous_ldmatrix_x2_map:#.+]] = affine_map<()[s0] -> ((s0 floordiv 8) * 16)>
+
+// Affine maps for accumulator registers
+// maps (laneid -> coordinate pointed by the lane in accumulator register tile)
 // CHECK-DAG: [[$rowC0_map:#.+]] = affine_map<()[s0] -> (s0 floordiv 4)>
 // CHECK-DAG: [[$colC0_map:#.+]] = affine_map<()[s0] -> (s0 * 2 - (s0 floordiv 4) * 8)>
 // CHECK-DAG: [[$rowC8_map:#.+]] = affine_map<()[s0] -> (s0 floordiv 4 + 8)>
@@ -554,27 +625,26 @@ func.func @m16n8k64_int4_row_col_row(%arg0: memref<128x128xi4, 3>, %arg1: memref
 func.func @m16n8k32_int8_row_col_row(%arg0: memref<128x128xi8, 3>, %arg1: memref<128x128xi8, 3>, %arg2: memref<128x128xi32>) {
   %cst_0 = arith.constant dense<0> : vector<32x8xi8>
   %c0 = arith.constant 0 : index
-  %c1 = arith.constant 1 : index
   %cst = arith.constant 0 : i8
   %cst0 = arith.constant 0 : i32
 
   // CHECK: [[lane:%.+]] = gpu.lane_id
-  // CHECK: [[row:%.+]] = affine.apply [[$rowA0_map]]()[[[lane]]]
-  // CHECK: [[col:%.+]] = affine.apply [[$colA0_map]]()[[[lane]]]
-  // CHECK: nvgpu.ldmatrix %arg0[[[row]], [[col]]] {numTiles = 4 : i32, transpose = false} : memref<128x128xi8, 3> -> vector<4x4xi8>
+  // CHECK: [[m_coord:%.+]] = affine.apply [[$strided_ldmatrix_x4_map]]()[[[lane]]]
+  // CHECK: [[k_coord:%.+]] = affine.apply [[$contiguous_ldmatrix_x4_map]]()[[[lane]]]
+  // CHECK: nvgpu.ldmatrix %arg0[[[m_coord]], [[k_coord]]] {numTiles = 4 : i32, transpose = false} : memref<128x128xi8, 3> -> vector<4x4xi8>
 
   // CHECK: [[lane:%.+]] = gpu.lane_id
-  // CHECK: [[row:%.+]] = affine.apply [[$rowB0_map]]()[[[lane]]]
-  // CHECK: [[col:%.+]] = affine.apply [[$colB0_map]]()[[[lane]]]
-  // CHECK: nvgpu.ldmatrix %arg1[[[row]], [[col]]] {numTiles = 2 : i32, transpose = false} : memref<128x128xi8, 3> -> vector<2x4xi8>
+  // CHECK: [[n_coord:%.+]] = affine.apply [[$strided_ldmatrix_x2_map]]()[[[lane]]]
+  // CHECK: [[k_coord:%.+]] = affine.apply [[$contiguous_ldmatrix_x2_map]]()[[[lane]]]
+  // CHECK: nvgpu.ldmatrix %arg1[[[n_coord]], [[k_coord]]] {numTiles = 2 : i32, transpose = false} : memref<128x128xi8, 3> -> vector<2x4xi8>
 
   // CHECK: [[lane:%.+]] = gpu.lane_id
-  // CHECK: [[row:%.+]] = affine.apply [[$rowC0_map]]()[[[lane]]]
-  // CHECK: [[col:%.+]] = affine.apply [[$colC0_map]]()[[[lane]]]
-  // CHECK: vector.load %arg2[[[row]], [[col]]] : memref<128x128xi32>, vector<2xi32>
-  // CHECK: [[row:%.+]] = affine.apply [[$rowC8_map]]()[[[lane]]]
-  // CHECK: [[col:%.+]] = affine.apply [[$colC0_map]]()[[[lane]]]
-  // CHECK: vector.load %arg2[[[row]], [[col]]] : memref<128x128xi32>, vector<2xi32>
+  // CHECK: [[m_coord:%.+]] = affine.apply [[$rowC0_map]]()[[[lane]]]
+  // CHECK: [[n_coord:%.+]] = affine.apply [[$colC0_map]]()[[[lane]]]
+  // CHECK: vector.load %arg2[[[m_coord]], [[n_coord]]] : memref<128x128xi32>, vector<2xi32>
+  // CHECK: [[m_coord:%.+]] = affine.apply [[$rowC8_map]]()[[[lane]]]
+  // CHECK: [[n_coord:%.+]] = affine.apply [[$colC0_map]]()[[[lane]]]
+  // CHECK: vector.load %arg2[[[m_coord]], [[n_coord]]] : memref<128x128xi32>, vector<2xi32>
   // CHECK-NOT: vector.load %arg2{{.*}}
 
   %A = vector.transfer_read %arg0[%c0, %c0], %cst {in_bounds = [true, true]} : memref<128x128xi8, 3>, vector<16x32xi8>
@@ -595,73 +665,3 @@ func.func @m16n8k32_int8_row_col_row(%arg0: memref<128x128xi8, 3>, %arg1: memref
   vector.transfer_write %D, %arg2[%c0, %c0] {in_bounds = [true, true]} : vector<16x8xi32>, memref<128x128xi32>
   return
 }
-
-// -----
-
-#map0 = affine_map<(d0, d1) -> (d1, d0)>
-#map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
-#map2 = affine_map<(d0, d1, d2) -> (d1, d2)>
-#map3 = affine_map<(d0, d1, d2) -> (d0, d1)>
-
-// CHECK-DAG: [[$rowA0_map:#.+]] = affine_map<()[s0] -> (s0 floordiv 4)>
-// CHECK-DAG: [[$colA0_map:#.+]] = affine_map<()[s0] -> (s0 mod 4)>
-// CHECK-DAG: [[$rowA8_map:#.+]] = affine_map<()[s0] -> (s0 floordiv 4 + 8)>
-// CHECK-DAG: [[$colA4_map:#.+]] = affine_map<()[s0] -> (s0 mod 4 + 4)>
-
-// CHECK-DAG: [[$rowB0_map:#.+]] = affine_map<()[s0] -> (s0 mod 8)>
-// CHECK-DAG: [[$colB0_map:#.+]] = affine_map<()[s0] -> ((s0 floordiv 8) * 4)>
-
-// CHECK-DAG: [[$rowC_map:#.+]] = affine_map<()[s0] -> (s0 floordiv 4 + 16)>
-// CHECK-DAG: [[$rowC8_map:#.+]] = affine_map<()[s0] -> (s0 floordiv 4 + 24)>
-// CHECK-DAG: [[$colC_map:#.+]] = affine_map<()[s0] -> (s0 * 2 - (s0 floordiv 4) * 8 + 8)>
-
-// CHECK-LABEL: func @m16n8k8_tf32_f32_col_col_row
-func.func @m16n8k8_tf32_f32_col_col_row(%arg0: memref<20x20xf32, 3>, %arg1: memref<20x20xf32, 3>, %arg2: memref<20x20xf32>) {
-  %cst_0 = arith.constant dense<0.000000e+00> : vector<16x8xf32>
-  %c0 = arith.constant 0 : index
-  %c16 = arith.constant 16 : index
-  %c8 = arith.constant 8 : index
-  %c1 = arith.constant 1 : index
-  %c3 = arith.constant 3 : index
-  %cst = arith.constant 0.000000e+00 : f32
-
-  // CHECK: [[c_frag:%.+]] = arith.constant {{.*}} : vector<2x2xf32>
-
-  // CHECK-DAG: [[row:%.+]] = affine.apply [[$rowA0_map]]
-  // CHECK-DAG: [[col:%.+]] = affine.apply [[$colA0_map]]
-  // CHECK: [[a_el0:%.+]] = memref.load {{%.+}} : memref<20x20xf32, 3>
-  // CHECK: [[a_frag0:%.+]] = vector.insert [[a_el0]], {{.*}} [0, 0] : f32 into vector<4x1xf32>
-
-  // CHECK-DAG: [[row:%.+]] = affine.apply [[$rowA8_map]]
-  // CHECK-DAG: [[col:%.+]] = affine.apply [[$colA0_map]]
-  // CHECK: [[a_el0:%.+]] = memref.load {{%.+}} : memref<20x20xf32, 3>
-  // CHECK: [[a_frag0:%.+]] = vector.insert [[a_el0]], {{.*}} [1, 0] : f32 into vector<4x1xf32>
-
-  // CHECK: [[a_el:%.+]] = memref.load {{%.+}} : memref<20x20xf32, 3>
-  // CHECK: [[a_frag:%.+]] = vector.insert [[a_el]], {{.*}} [2, 0] : f32 into vector<4x1xf32>
-  // CHECK: [[a_el:%.+]] = memref.load {{%.+}} : memref<20x20xf32, 3>
-  // CHECK: [[a_frag:%.+]] = vector.insert [[a_el]], {{.*}} [3, 0] : f32 into vector<4x1xf32>
-
-  // CHECK-DAG: [[row:%.+]] = affine.apply [[$rowB0_map]]
-  // CHECK-DAG: [[col:%.+]] = affine.apply [[$colB0_map]]
-  // CHECK: [[b_frag:%.+]] = nvgpu.ldmatrix %arg1[[[row]], [[col]]] {numTiles = 2 : i32, transpose = false}
-
-  // CHECK: [[d_frag:%.+]] = nvgpu.mma.sync([[a_frag]], [[b_frag]], [[c_frag]])
-  // CHECK-SAME: mmaShape = [16, 8, 8]
-  // CHECK-SAME: -> vector<2x2xf32>
-  %A = vector.transfer_read %arg0[%c0, %c0], %cst {in_bounds = [true, true], permutation_map = #map0} : memref<20x20xf32, 3>, vector<16x8xf32>
-  %B = vector.transfer_read %arg1[%c0, %c0], %cst {in_bounds = [true, true]} : memref<20x20xf32, 3>, vector<8x8xf32>
-  %D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"],
-    kind = #vector.kind<add>} %A, %B, %cst_0 : vector<16x8xf32>, vector<8x8xf32> into vector<16x8xf32>
-
-  // CHECK: vector.extract [[d_frag]][0] : vector<2x2xf32>
-  // CHECK: affine.apply [[$rowC_map]]
-  // CHECK: affine.apply [[$colC_map]]
-  // CHECK: vector.store
-  // CHECK: vector.extract [[d_frag]][1] : vector<2x2xf32>
-  // CHECK: affine.apply [[$rowC8_map]]
-  // CHECK: affine.apply [[$colC_map]]
-  // CHECK: vector.store
-  vector.transfer_write %D, %arg2[%c16, %c8] {in_bounds = [true, true]} : vector<16x8xf32>, memref<20x20xf32>
-  return
-}


        


More information about the Mlir-commits mailing list