[Mlir-commits] [mlir] 670eee0 - [mlir][VectorToGPU] Fix support for i4, col-major operand support

Christopher Bate llvmlistbot at llvm.org
Thu Jun 30 09:31:41 PDT 2022


Author: Christopher Bate
Date: 2022-06-30T10:26:59-06:00
New Revision: 670eee08cecfcfe170fb0e7daa88df8c2a150dbe

URL: https://github.com/llvm/llvm-project/commit/670eee08cecfcfe170fb0e7daa88df8c2a150dbe
DIFF: https://github.com/llvm/llvm-project/commit/670eee08cecfcfe170fb0e7daa88df8c2a150dbe.diff

LOG: [mlir][VectorToGPU] Fix support for i4, col-major operand support

For the conversion to nvgpu `mma.sync` and `ldmatrix` pathways, the code
was missing support for the `i4` data type. While fixing this, another
bug was discovered that caused the number of ldmatrix tiles to be computed
incorrectly for certain operand types and configurations. This change
fixes both issues and adds additional tests.

Differential Revision: https://reviews.llvm.org/D128074

Added: 
    

Modified: 
    mlir/lib/Conversion/VectorToGPU/NvGpuSupport.cpp
    mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
    mlir/test/Conversion/VectorToGPU/vector-to-mma-ops-mma-sync.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/VectorToGPU/NvGpuSupport.cpp b/mlir/lib/Conversion/VectorToGPU/NvGpuSupport.cpp
index a0d16e537b5b1..af14aaa1deeec 100644
--- a/mlir/lib/Conversion/VectorToGPU/NvGpuSupport.cpp
+++ b/mlir/lib/Conversion/VectorToGPU/NvGpuSupport.cpp
@@ -124,6 +124,14 @@ getMmaSyncRegisterType(const WarpMatrixInfo &type) {
         LLVM::getFixedVectorType(IntegerType::get(ctx, 8), 4), 4, 32,
         inferNumRegistersPerMatrixFragment(type)};
   }
+
+  // int4 operand
+  if (elType.isInteger(4)) {
+    return FragmentElementInfo{
+        LLVM::getFixedVectorType(IntegerType::get(ctx, 4), 8), 8, 32,
+        inferNumRegistersPerMatrixFragment(type)};
+  }
+
   // Integer 32bit acc operands
   if (elType.isInteger(32)) {
     return FragmentElementInfo{
@@ -212,7 +220,7 @@ FailureOr<nvgpu::LdMatrixParams> getLdMatrixParams(const WarpMatrixInfo &type,
   params.contiguousDimType =
       transpose ? IteratorType::Parallel : IteratorType::Reduction;
 
-  if (params.targetLayout == NVVM::MMALayout::row) {
+  if (params.contiguousDimType == IteratorType::Reduction) {
     params.numTiles = (shape[0] / kNumRowsPerTile) *
                       ((shape[1] * elType.getIntOrFloatBitWidth()) / 128);
   } else {

diff  --git a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
index 74e81e961455f..4e7cf6b31cd00 100644
--- a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
+++ b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
@@ -524,11 +524,6 @@ createNonLdMatrixLoads(vector::TransferReadOp op, OpBuilder &builder,
     return failure();
   }
 
-  NVVM::MMALayout targetLayout =
-      warpMatrixInfo->operandRole == nvgpu::MatMulOperandRole::B
-          ? NVVM::MMALayout::col
-          : NVVM::MMALayout::row;
-
   Value laneId = builder.create<gpu::LaneIdOp>(loc);
   SmallVector<Value, 4> elements;
 
@@ -543,8 +538,9 @@ createNonLdMatrixLoads(vector::TransferReadOp op, OpBuilder &builder,
 
   bool isTransposeLoad = !op.getPermutationMap().isMinorIdentity();
 
-  // Vectorized loads.
-  if (!isTransposeLoad && targetLayout == NVVM::MMALayout::row) {
+  // If we are not transposing, then we can use vectorized loads. Otherwise, we
+  // must load each element individually.
+  if (!isTransposeLoad) {
     if (!loadedElType.isa<VectorType>()) {
       loadedElType = VectorType::get({1}, loadedElType);
     }
@@ -566,11 +562,10 @@ createNonLdMatrixLoads(vector::TransferReadOp op, OpBuilder &builder,
       result = builder.create<vector::InsertOp>(loc, el, result,
                                                 builder.getI64ArrayAttr(i));
     }
-  } else if (isTransposeLoad && targetLayout == NVVM::MMALayout::col) {
+  } else {
     if (auto vecType = loadedElType.dyn_cast<VectorType>()) {
       loadedElType = vecType.getElementType();
     }
-    // Load each element individually.
     for (int i = 0; i < vectorType.getShape()[0]; i++) {
       for (unsigned innerIdx = 0; innerIdx < vectorType.getShape()[1];
            innerIdx++) {
@@ -592,8 +587,6 @@ createNonLdMatrixLoads(vector::TransferReadOp op, OpBuilder &builder,
             op.getLoc(), el, result, builder.getI64ArrayAttr({i, innerIdx}));
       }
     }
-  } else {
-    return failure();
   }
 
   valueMapping[op.getResult()] = result;

diff  --git a/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops-mma-sync.mlir b/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops-mma-sync.mlir
index 479533d1fa411..42dc06c937d40 100644
--- a/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops-mma-sync.mlir
+++ b/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops-mma-sync.mlir
@@ -407,3 +407,205 @@ func.func @m16n8k8_tf32_f32_row_row_row(%arg0: memref<20x20xf32, 3>, %arg1: memr
   vector.transfer_write %D, %arg2[%c0, %c0] {in_bounds = [true, true]} : vector<16x8xf32>, memref<20x20xf32>
   return
 }
+
+// -----
+
+//#########################################################
+// INT4 row-col-row
+//#########################################################
+
+// CHECK-DAG: [[$rowA0_map:#.+]] = affine_map<()[s0] -> (s0 mod 16)>
+// CHECK-DAG: [[$colA0_map:#.+]] = affine_map<()[s0] -> ((s0 floordiv 16) * 32)>
+// CHECK-DAG: [[$rowB0_map:#.+]] = affine_map<()[s0] -> (s0 mod 8)>
+// CHECK-DAG: [[$colB0_map:#.+]] = affine_map<()[s0] -> ((s0 floordiv 8) * 32)>
+// CHECK-DAG: [[$rowC0_map:#.+]] = affine_map<()[s0] -> (s0 floordiv 4)>
+// CHECK-DAG: [[$colC0_map:#.+]] = affine_map<()[s0] -> (s0 * 2 - (s0 floordiv 4) * 8
+// CHECK-DAG: [[$rowC8_map:#.+]] = affine_map<()[s0] -> (s0 floordiv 4 + 8)>
+
+#map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d1, d2)>
+#map3 = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+// CHECK-LABEL: func @m16n8k64_int4_row_col_row
+func.func @m16n8k64_int4_row_col_row(%arg0: memref<128x128xi4, 3>, %arg1: memref<128x128xi4, 3>, %arg2: memref<128x128xi32>) {
+  %cst  = arith.constant 0 : i4
+  %cst0  = arith.constant 0 : i32
+  %cst_0 = arith.constant dense<0> : vector<32x8xi4>
+  %c0 = arith.constant 0 : index
+
+  // CHECK: [[lane:%.+]] = gpu.lane_id
+  // CHECK: [[row:%.+]] = affine.apply [[$rowA0_map]]()[[[lane]]]
+  // CHECK: [[col:%.+]] = affine.apply [[$colA0_map]]()[[[lane]]]
+  // CHECK: nvgpu.ldmatrix %arg0[[[row]], [[col]]] {numTiles = 4 : i32, transpose = false} : memref<128x128xi4, 3> -> vector<4x8xi4>
+
+  // CHECK: [[lane:%.+]] = gpu.lane_id
+  // CHECK: [[row:%.+]] = affine.apply [[$rowB0_map]]()[[[lane]]]
+  // CHECK: [[col:%.+]] = affine.apply [[$colB0_map]]()[[[lane]]]
+  // CHECK: nvgpu.ldmatrix %arg1[[[row]], [[col]]] {numTiles = 2 : i32, transpose = false} : memref<128x128xi4, 3> -> vector<2x8xi4>
+  
+  // CHECK: [[lane:%.+]] = gpu.lane_id
+  // CHECK: [[row:%.+]] = affine.apply [[$rowC0_map]]()[{{%.+}}]
+  // CHECK: [[col:%.+]] = affine.apply [[$colC0_map]]()[{{%.+}}]
+  // CHECK: vector.load %arg2[[[row]], [[col]]] : memref<128x128xi32>, vector<2xi32>
+
+  // CHECK: [[row:%.+]] = affine.apply [[$rowC8_map]]()[{{%.+}}]
+  // CHECK: [[col:%.+]] = affine.apply [[$colC0_map]]()[{{%.+}}]
+  // CHECK: vector.load %arg2[[[row]], [[col]]] : memref<128x128xi32>, vector<2xi32>
+  // CHECK-NOT: vector.load
+
+  %A = vector.transfer_read %arg0[%c0, %c0], %cst {in_bounds = [true, true]} : memref<128x128xi4, 3>, vector<16x64xi4>
+  %B = vector.transfer_read %arg1[%c0, %c0], %cst {in_bounds = [true, true]} : memref<128x128xi4, 3>, vector<8x64xi4>
+  %C = vector.transfer_read %arg2[%c0, %c0], %cst0 {in_bounds = [true, true]} : memref<128x128xi32>, vector<16x8xi32>
+  // CHECK: [[d:%.+]] = nvgpu.mma.sync({{.*}}) {mmaShape = [16, 8, 64]} : (vector<4x8xi4>, vector<2x8xi4>, vector<2x2xi32>) -> vector<2x2xi32>
+  %D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %A, %B, %C : vector<16x64xi4>, vector<8x64xi4> into vector<16x8xi32>
+
+  // CHECK: [[lane:%.+]] = gpu.lane_id
+  // CHECK: [[v:%.+]] = vector.extract [[d]][0] : vector<2x2xi32>  
+  // CHECK: [[row:%.+]] = affine.apply [[$rowC0_map]]()[[[lane]]]
+  // CHECK: [[col:%.+]] = affine.apply [[$colC0_map]]()[[[lane]]]
+  // CHECK: vector.store [[v]], %arg2[[[row]], [[col]]] : memref<128x128xi32>, vector<2xi32>
+  
+  // CHECK: [[v:%.+]] = vector.extract [[d]][1] : vector<2x2xi32>
+  // CHECK: [[row:%.+]] = affine.apply [[$rowC8_map]]()[[[lane]]]
+  // CHECK: [[col:%.+]] = affine.apply [[$colC0_map]]()[[[lane]]]
+  // CHECK: vector.store [[v]], %arg2[[[row]], [[col]]] : memref<128x128xi32>, vector<2xi32>
+  vector.transfer_write %D, %arg2[%c0, %c0] {in_bounds = [true, true]} : vector<16x8xi32>, memref<128x128xi32>
+  return
+}
+
+// -----
+
+//#########################################################
+// INT8 row-col-row
+//#########################################################
+
+// CHECK-DAG: [[$rowA0_map:#.+]] = affine_map<()[s0] -> (s0 mod 16)>
+// CHECK-DAG: [[$colA0_map:#.+]] = affine_map<()[s0] -> ((s0 floordiv 16) * 16)>
+// CHECK-DAG: [[$rowB0_map:#.+]] = affine_map<()[s0] -> (s0 mod 8)>
+// CHECK-DAG: [[$colB0_map:#.+]] = affine_map<()[s0] -> ((s0 floordiv 8) * 16)>
+
+// CHECK-DAG: [[$rowC0_map:#.+]] = affine_map<()[s0] -> (s0 floordiv 4)>
+// CHECK-DAG: [[$colC0_map:#.+]] = affine_map<()[s0] -> (s0 * 2 - (s0 floordiv 4) * 8)>
+// CHECK-DAG: [[$rowC8_map:#.+]] = affine_map<()[s0] -> (s0 floordiv 4 + 8)>
+
+
+#map0 = affine_map<(d0, d1) -> (d1, d0)>
+#map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d1, d2)>
+#map3 = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+// CHECK-LABEL: func @m16n8k32_int8_row_col_row
+func.func @m16n8k32_int8_row_col_row(%arg0: memref<128x128xi8, 3>, %arg1: memref<128x128xi8, 3>, %arg2: memref<128x128xi32>) {
+  %cst_0 = arith.constant dense<0> : vector<32x8xi8>
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %cst = arith.constant 0 : i8
+  %cst0 = arith.constant 0 : i32
+
+  // CHECK: [[lane:%.+]] = gpu.lane_id
+  // CHECK: [[row:%.+]] = affine.apply [[$rowA0_map]]()[[[lane]]]
+  // CHECK: [[col:%.+]] = affine.apply [[$colA0_map]]()[[[lane]]]
+  // CHECK: nvgpu.ldmatrix %arg0[[[row]], [[col]]] {numTiles = 4 : i32, transpose = false} : memref<128x128xi8, 3> -> vector<4x4xi8>
+
+  // CHECK: [[lane:%.+]] = gpu.lane_id
+  // CHECK: [[row:%.+]] = affine.apply [[$rowB0_map]]()[[[lane]]]
+  // CHECK: [[col:%.+]] = affine.apply [[$colB0_map]]()[[[lane]]]
+  // CHECK: nvgpu.ldmatrix %arg1[[[row]], [[col]]] {numTiles = 2 : i32, transpose = false} : memref<128x128xi8, 3> -> vector<2x4xi8>
+  
+  // CHECK: [[lane:%.+]] = gpu.lane_id
+  // CHECK: [[row:%.+]] = affine.apply [[$rowC0_map]]()[[[lane]]]
+  // CHECK: [[col:%.+]] = affine.apply [[$colC0_map]]()[[[lane]]]
+  // CHECK: vector.load %arg2[[[row]], [[col]]] : memref<128x128xi32>, vector<2xi32>
+  // CHECK: [[row:%.+]] = affine.apply [[$rowC8_map]]()[[[lane]]]
+  // CHECK: [[col:%.+]] = affine.apply [[$colC0_map]]()[[[lane]]]
+  // CHECK: vector.load %arg2[[[row]], [[col]]] : memref<128x128xi32>, vector<2xi32>
+  // CHECK-NOT: vector.load %arg2{{.*}}
+
+  %A = vector.transfer_read %arg0[%c0, %c0], %cst {in_bounds = [true, true]} : memref<128x128xi8, 3>, vector<16x32xi8>
+  %B = vector.transfer_read %arg1[%c0, %c0], %cst {in_bounds = [true, true]} : memref<128x128xi8, 3>, vector<8x32xi8>
+  %C = vector.transfer_read %arg2[%c0, %c0], %cst0 {in_bounds = [true, true]} : memref<128x128xi32>, vector<16x8xi32>
+  // CHECK: [[d:%.+]] = nvgpu.mma.sync({{.*}}) {mmaShape = [16, 8, 32]} : (vector<4x4xi8>, vector<2x4xi8>, vector<2x2xi32>) -> vector<2x2xi32>
+  %D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %A, %B, %C : vector<16x32xi8>, vector<8x32xi8> into vector<16x8xi32>
+
+  // CHECK: [[lane:%.+]] = gpu.lane_id
+  // CHECK: [[v:%.+]] = vector.extract [[d]][0] : vector<2x2xi32>  
+  // CHECK: [[row:%.+]] = affine.apply [[$rowC0_map]]()[[[lane]]]
+  // CHECK: [[col:%.+]] = affine.apply [[$colC0_map]]()[[[lane]]]
+  // CHECK: vector.store [[v]], %arg2[[[row]], [[col]]] : memref<128x128xi32>, vector<2xi32>
+  // CHECK: [[v:%.+]] = vector.extract [[d]][1] : vector<2x2xi32>
+  // CHECK: [[row:%.+]] = affine.apply [[$rowC8_map]]()[[[lane]]]
+  // CHECK: [[col:%.+]] = affine.apply [[$colC0_map]]()[[[lane]]]
+  // CHECK: vector.store [[v]], %arg2[[[row]], [[col]]] : memref<128x128xi32>, vector<2xi32>
+  vector.transfer_write %D, %arg2[%c0, %c0] {in_bounds = [true, true]} : vector<16x8xi32>, memref<128x128xi32>
+  return
+}
+
+// -----
+
+#map0 = affine_map<(d0, d1) -> (d1, d0)>
+#map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d1, d2)>
+#map3 = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+// CHECK-DAG: [[$rowA0_map:#.+]] = affine_map<()[s0] -> (s0 floordiv 4)>
+// CHECK-DAG: [[$colA0_map:#.+]] = affine_map<()[s0] -> (s0 mod 4)>
+// CHECK-DAG: [[$rowA8_map:#.+]] = affine_map<()[s0] -> (s0 floordiv 4 + 8)>
+// CHECK-DAG: [[$colA4_map:#.+]] = affine_map<()[s0] -> (s0 mod 4 + 4)>
+
+// CHECK-DAG: [[$rowB0_map:#.+]] = affine_map<()[s0] -> (s0 mod 8)>
+// CHECK-DAG: [[$colB0_map:#.+]] = affine_map<()[s0] -> ((s0 floordiv 8) * 4)>
+
+// CHECK-DAG: [[$rowC_map:#.+]] = affine_map<()[s0] -> (s0 floordiv 4 + 16)>
+// CHECK-DAG: [[$rowC8_map:#.+]] = affine_map<()[s0] -> (s0 floordiv 4 + 24)>
+// CHECK-DAG: [[$colC_map:#.+]] = affine_map<()[s0] -> (s0 * 2 - (s0 floordiv 4) * 8 + 8)>
+
+// CHECK-LABEL: func @m16n8k8_tf32_f32_col_col_row
+func.func @m16n8k8_tf32_f32_col_col_row(%arg0: memref<20x20xf32, 3>, %arg1: memref<20x20xf32, 3>, %arg2: memref<20x20xf32>) {
+  %cst_0 = arith.constant dense<0.000000e+00> : vector<16x8xf32>
+  %c0 = arith.constant 0 : index
+  %c16 = arith.constant 16 : index
+  %c8 = arith.constant 8 : index
+  %c1 = arith.constant 1 : index
+  %c3 = arith.constant 3 : index
+  %cst = arith.constant 0.000000e+00 : f32
+
+  // CHECK: [[c_frag:%.+]] = arith.constant {{.*}} : vector<2x2xf32>
+
+  // CHECK-DAG: [[row:%.+]] = affine.apply [[$rowA0_map]]
+  // CHECK-DAG: [[col:%.+]] = affine.apply [[$colA0_map]]
+  // CHECK: [[a_el0:%.+]] = memref.load {{%.+}} : memref<20x20xf32, 3>
+  // CHECK: [[a_frag0:%.+]] = vector.insert [[a_el0]], {{.*}} [0, 0] : f32 into vector<4x1xf32>
+
+  // CHECK-DAG: [[row:%.+]] = affine.apply [[$rowA8_map]]
+  // CHECK-DAG: [[col:%.+]] = affine.apply [[$colA0_map]]
+  // CHECK: [[a_el0:%.+]] = memref.load {{%.+}} : memref<20x20xf32, 3>
+  // CHECK: [[a_frag0:%.+]] = vector.insert [[a_el0]], {{.*}} [1, 0] : f32 into vector<4x1xf32>
+
+  // CHECK: [[a_el:%.+]] = memref.load {{%.+}} : memref<20x20xf32, 3>
+  // CHECK: [[a_frag:%.+]] = vector.insert [[a_el]], {{.*}} [2, 0] : f32 into vector<4x1xf32>
+  // CHECK: [[a_el:%.+]] = memref.load {{%.+}} : memref<20x20xf32, 3>
+  // CHECK: [[a_frag:%.+]] = vector.insert [[a_el]], {{.*}} [3, 0] : f32 into vector<4x1xf32>
+
+  // CHECK-DAG: [[row:%.+]] = affine.apply [[$rowB0_map]]
+  // CHECK-DAG: [[col:%.+]] = affine.apply [[$colB0_map]]
+  // CHECK: [[b_frag:%.+]] = nvgpu.ldmatrix %arg1[[[row]], [[col]]] {numTiles = 2 : i32, transpose = false}
+
+  // CHECK: [[d_frag:%.+]] = nvgpu.mma.sync([[a_frag]], [[b_frag]], [[c_frag]])
+  // CHECK-SAME: mmaShape = [16, 8, 8]
+  // CHECK-SAME: -> vector<2x2xf32>
+  %A = vector.transfer_read %arg0[%c0, %c0], %cst {in_bounds = [true, true], permutation_map = #map0} : memref<20x20xf32, 3>, vector<16x8xf32>
+  %B = vector.transfer_read %arg1[%c0, %c0], %cst {in_bounds = [true, true]} : memref<20x20xf32, 3>, vector<8x8xf32>
+  %D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], 
+    kind = #vector.kind<add>} %A, %B, %cst_0 : vector<16x8xf32>, vector<8x8xf32> into vector<16x8xf32>
+
+  // CHECK: vector.extract [[d_frag]][0] : vector<2x2xf32>
+  // CHECK: affine.apply [[$rowC_map]]
+  // CHECK: affine.apply [[$colC_map]]
+  // CHECK: vector.store
+  // CHECK: vector.extract [[d_frag]][1] : vector<2x2xf32>
+  // CHECK: affine.apply [[$rowC8_map]]
+  // CHECK: affine.apply [[$colC_map]]
+  // CHECK: vector.store
+  vector.transfer_write %D, %arg2[%c16, %c8] {in_bounds = [true, true]} : vector<16x8xf32>, memref<20x20xf32>
+  return
+}
\ No newline at end of file


        


More information about the Mlir-commits mailing list