[Mlir-commits] [mlir] [mlir][vector] proper masking support for contract lowering (PR #67145)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Sep 22 07:10:36 PDT 2023


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-vector

<details>
<summary>Changes</summary>

Support all recognized operand layouts (indexing-map permutations) when lowering a masked vector.contract to vector.outerproduct, rather than only the canonical layout.

---
Full diff: https://github.com/llvm/llvm-project/pull/67145.diff


2 Files Affected:

- (modified) mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp (+25-18) 
- (modified) mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms.mlir (+168) 


``````````diff
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
index 64ab0abda26e640..de7aba6e84b748b 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
@@ -457,38 +457,44 @@ struct UnrolledOuterProductGenerator
     // Set up the parallel/reduction structure in the right form.
     AffineExpr m, n, k;
     bindDims(rewriter.getContext(), m, n, k);
+    Value transposedMask = t(mask, {2, 0, 1});
     // Classical row-major matmul:  Just permute the lhs.
     if (layout({{m, k}, {k, n}, {m, n}}))
-      return outerProd(t(lhs), rhs, res, lhsType.getDimSize(1),
-                       t(mask, {2, 0, 1}));
+      return outerProd(t(lhs), rhs, res, lhsType.getDimSize(1), transposedMask);
     // TODO: may be better to fail and use some vector<k> -> scalar reduction.
     if (layout({{m, k}, {n, k}, {m, n}})) {
       Value tlhs = t(lhs);
-      return outerProd(tlhs, t(rhs), res, lhsType.getDimSize(1));
+      return outerProd(tlhs, t(rhs), res, lhsType.getDimSize(1),
+                       transposedMask);
     }
     // No need to permute anything.
     if (layout({{k, m}, {k, n}, {m, n}}))
-      return outerProd(lhs, rhs, res, lhsType.getDimSize(0));
+      return outerProd(lhs, rhs, res, lhsType.getDimSize(0), transposedMask);
     // Just permute the rhs.
     if (layout({{k, m}, {n, k}, {m, n}}))
-      return outerProd(lhs, t(rhs), res, lhsType.getDimSize(0));
+      return outerProd(lhs, t(rhs), res, lhsType.getDimSize(0), transposedMask);
     // Transposed output: swap RHS and LHS.
     // Classical row-major matmul: permute the lhs.
     if (layout({{m, k}, {k, n}, {n, m}}))
-      return outerProd(rhs, t(lhs), res, lhsType.getDimSize(1));
+      return outerProd(rhs, t(lhs), res, lhsType.getDimSize(1), transposedMask);
     // TODO: may be better to fail and use some vector<k> -> scalar reduction.
     if (layout({{m, k}, {n, k}, {n, m}})) {
       Value trhs = t(rhs);
-      return outerProd(trhs, t(lhs), res, lhsType.getDimSize(1));
+      return outerProd(trhs, t(lhs), res, lhsType.getDimSize(1),
+                       transposedMask);
     }
     if (layout({{k, m}, {k, n}, {n, m}}))
-      return outerProd(rhs, lhs, res, lhsType.getDimSize(0));
+      return outerProd(rhs, lhs, res, lhsType.getDimSize(0), transposedMask);
     if (layout({{k, m}, {n, k}, {n, m}}))
-      return outerProd(t(rhs), lhs, res, lhsType.getDimSize(0));
+      return outerProd(t(rhs), lhs, res, lhsType.getDimSize(0), transposedMask);
     return failure();
   }
 
-  /// One outer parallel, one inner reduction (matvec flavor)
+  //
+  // One outer parallel, one inner reduction (matvec flavor).
+  // Mask needs to be transposed everywhere to turn the reduction dimension
+  // outermost as required by outerproduct.
+  //
   FailureOr<Value> matvec() {
     if (!iters({Par(), Red()}))
       return failure();
@@ -500,18 +506,19 @@ struct UnrolledOuterProductGenerator
       return outerProd(t(lhs), rhs, res, lhsType.getDimSize(1), t(mask));
     // Case mat-trans-vec: ready to go.
     if (layout({{k, m}, {k}, {m}}))
-      return outerProd(lhs, rhs, res, lhsType.getDimSize(0));
+      return outerProd(lhs, rhs, res, lhsType.getDimSize(0), t(mask));
     // Case vec-mat: swap and transpose.
     if (layout({{k}, {m, k}, {m}}))
-      return outerProd(t(rhs), lhs, res, lhsType.getDimSize(0));
+      return outerProd(t(rhs), lhs, res, lhsType.getDimSize(0), t(mask));
     // Case vec-mat-trans: swap and ready to go.
     if (layout({{k}, {k, m}, {m}}))
-      return outerProd(rhs, lhs, res, lhsType.getDimSize(0));
+      return outerProd(rhs, lhs, res, lhsType.getDimSize(0), t(mask));
     return failure();
   }
 
   //
-  // One outer reduction, one inner parallel (tmatvec flavor)
+  // One outer reduction, one inner parallel (tmatvec flavor).
+  // Mask already has the shape of the outer product.
   //
   FailureOr<Value> tmatvec() {
     if (!iters({Red(), Par()}))
@@ -521,16 +528,16 @@ struct UnrolledOuterProductGenerator
 
     // Case mat-vec: transpose.
     if (layout({{m, k}, {k}, {m}}))
-      return outerProd(t(lhs), rhs, res, lhsType.getDimSize(1));
+      return outerProd(t(lhs), rhs, res, lhsType.getDimSize(1), mask);
     // Case mat-trans-vec: ready to go.
     if (layout({{k, m}, {k}, {m}}))
-      return outerProd(lhs, rhs, res, lhsType.getDimSize(0));
+      return outerProd(lhs, rhs, res, lhsType.getDimSize(0), mask);
     // Case vec-mat: swap and transpose.
     if (layout({{k}, {m, k}, {m}}))
-      return outerProd(t(rhs), lhs, res, lhsType.getDimSize(0));
+      return outerProd(t(rhs), lhs, res, lhsType.getDimSize(0), mask);
     // Case vec-mat-trans: swap and ready to go.
     if (layout({{k}, {k, m}, {m}}))
-      return outerProd(rhs, lhs, res, lhsType.getDimSize(0));
+      return outerProd(rhs, lhs, res, lhsType.getDimSize(0), mask);
     return failure();
   }
 
diff --git a/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms.mlir
index aee4cf6b1379bd0..1b4428431a92ba3 100644
--- a/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms.mlir
@@ -341,6 +341,174 @@ func.func @matmul_7(%arg0: vector<2x1xf32>, %arg1: vector<1x3xf32>, %arg2: vecto
   return %0 : vector<3x2xf32>
 }
 
+// CHECK-LABEL: @masked_matvec_mk_k_m
+// CHECK-SAME:  %[[MAT:.+]]: vector<4x2xf32>
+// CHECK-SAME:  %[[VEC:.+]]: vector<2xf32>
+// CHECK-SAME:  %[[INIT:.+]]: vector<4xf32>
+// CHECK-SAME:  %[[MASK:.+]]: vector<4x2xi1>
+func.func @masked_matvec_mk_k_m(%arg0: vector<4x2xf32>, %arg1: vector<2xf32>, %arg2: vector<4xf32>, %mask: vector<4x2xi1>) -> vector<4xf32> {
+  // CHECK:         vector.transpose %[[MAT]]
+  // CHECK:         vector.transpose %[[MASK]]
+  // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<4xf32>, f32 }
+  %res = vector.mask %mask {
+    vector.contract {
+      indexing_maps = [affine_map<(m, k) -> (m, k)>,
+                       affine_map<(m, k) -> (k)>,
+                       affine_map<(m, k) -> (m)>],
+      iterator_types = ["parallel", "reduction"],
+      kind = #vector.kind<add>
+    } %arg0, %arg1, %arg2 : vector<4x2xf32>, vector<2xf32>, vector<4xf32> into vector<4xf32>
+  } : vector<4x2xi1> -> vector<4xf32>
+  return %res : vector<4xf32>
+}
+
+// CHECK-LABEL: @masked_matvec_km_k_m
+// CHECK-SAME:  %[[MAT:.+]]: vector<2x4xf32>
+// CHECK-SAME:  %[[VEC:.+]]: vector<2xf32>
+// CHECK-SAME:  %[[INIT:.+]]: vector<4xf32>
+// CHECK-SAME:  %[[MASK:.+]]: vector<4x2xi1>
+func.func @masked_matvec_km_k_m(%arg0: vector<2x4xf32>, %arg1: vector<2xf32>, %arg2: vector<4xf32>, %mask: vector<4x2xi1>) -> vector<4xf32> {
+  // CHECK-NOT:     vector.transpose %[[MAT]]
+  // CHECK:         vector.transpose %[[MASK]]
+  // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<4xf32>, f32 }
+  %res = vector.mask %mask {
+    vector.contract {
+      indexing_maps = [affine_map<(m, k) -> (k, m)>,
+                       affine_map<(m, k) -> (k)>,
+                       affine_map<(m, k) -> (m)>],
+      iterator_types = ["parallel", "reduction"],
+      kind = #vector.kind<add>
+    } %arg0, %arg1, %arg2 : vector<2x4xf32>, vector<2xf32>, vector<4xf32> into vector<4xf32>
+  } : vector<4x2xi1> -> vector<4xf32>
+  return %res : vector<4xf32>
+}
+
+// CHECK-LABEL: @masked_matvec_k_mk_m
+// CHECK-SAME:  %[[MAT:.+]]: vector<4x2xf32>
+// CHECK-SAME:  %[[VEC:.+]]: vector<2xf32>
+// CHECK-SAME:  %[[INIT:.+]]: vector<4xf32>
+// CHECK-SAME:  %[[MASK:.+]]: vector<4x2xi1>
+func.func @masked_matvec_k_mk_m(%arg0: vector<4x2xf32>, %arg1: vector<2xf32>, %arg2: vector<4xf32>, %mask: vector<4x2xi1>) -> vector<4xf32> {
+  // CHECK:         vector.transpose %[[MAT]]
+  // CHECK:         vector.transpose %[[MASK]]
+  // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<4xf32>, f32 }
+  %res = vector.mask %mask {
+    vector.contract {
+      indexing_maps = [affine_map<(m, k) -> (k)>,
+                       affine_map<(m, k) -> (m, k)>,
+                       affine_map<(m, k) -> (m)>],
+      iterator_types = ["parallel", "reduction"],
+      kind = #vector.kind<add>
+    } %arg1, %arg0, %arg2 : vector<2xf32>, vector<4x2xf32>, vector<4xf32> into vector<4xf32>
+  } : vector<4x2xi1> -> vector<4xf32>
+  return %res : vector<4xf32>
+}
+
+// CHECK-LABEL: @masked_matvec_k_km_m
+// CHECK-SAME:  %[[MAT:.+]]: vector<2x4xf32>
+// CHECK-SAME:  %[[VEC:.+]]: vector<2xf32>
+// CHECK-SAME:  %[[INIT:.+]]: vector<4xf32>
+// CHECK-SAME:  %[[MASK:.+]]: vector<4x2xi1>
+func.func @masked_matvec_k_km_m(%arg0: vector<2x4xf32>, %arg1: vector<2xf32>, %arg2: vector<4xf32>, %mask: vector<4x2xi1>) -> vector<4xf32> {
+  // CHECK-NOT:     vector.transpose %[[MAT]]
+  // CHECK:         vector.transpose %[[MASK]]
+  // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<4xf32>, f32 }
+  %res = vector.mask %mask {
+    vector.contract {
+      indexing_maps = [affine_map<(m, k) -> (k)>,
+                       affine_map<(m, k) -> (k, m)>,
+                       affine_map<(m, k) -> (m)>],
+      iterator_types = ["parallel", "reduction"],
+      kind = #vector.kind<add>
+    } %arg1, %arg0, %arg2 : vector<2xf32>, vector<2x4xf32>, vector<4xf32> into vector<4xf32>
+  } : vector<4x2xi1> -> vector<4xf32>
+  return %res : vector<4xf32>
+}
+
+// CHECK-LABEL: @masked_tmatvec_mk_k_m
+// CHECK-SAME:  %[[MAT:.+]]: vector<4x2xf32>
+// CHECK-SAME:  %[[VEC:.+]]: vector<2xf32>
+// CHECK-SAME:  %[[INIT:.+]]: vector<4xf32>
+// CHECK-SAME:  %[[MASK:.+]]: vector<2x4xi1>
+func.func @masked_tmatvec_mk_k_m(%arg0: vector<4x2xf32>, %arg1: vector<2xf32>, %arg2: vector<4xf32>, %mask: vector<2x4xi1>) -> vector<4xf32> {
+  // CHECK:         vector.transpose %[[MAT]]
+  // CHECK-NOT:     vector.transpose %[[MASK]]
+  // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<4xf32>, f32 }
+  %res = vector.mask %mask {
+    vector.contract {
+      indexing_maps = [affine_map<(k, m) -> (m, k)>,
+                       affine_map<(k, m) -> (k)>,
+                       affine_map<(k, m) -> (m)>],
+      iterator_types = ["reduction", "parallel"],
+      kind = #vector.kind<add>
+    } %arg0, %arg1, %arg2 : vector<4x2xf32>, vector<2xf32>, vector<4xf32> into vector<4xf32>
+  } : vector<2x4xi1> -> vector<4xf32>
+  return %res : vector<4xf32>
+}
+
+// CHECK-LABEL: @masked_tmatvec_km_k_m
+// CHECK-SAME:  %[[MAT:.+]]: vector<2x4xf32>
+// CHECK-SAME:  %[[VEC:.+]]: vector<2xf32>
+// CHECK-SAME:  %[[INIT:.+]]: vector<4xf32>
+// CHECK-SAME:  %[[MASK:.+]]: vector<2x4xi1>
+func.func @masked_tmatvec_km_k_m(%arg0: vector<2x4xf32>, %arg1: vector<2xf32>, %arg2: vector<4xf32>, %mask: vector<2x4xi1>) -> vector<4xf32> {
+  // CHECK-NOT:     vector.transpose %[[MAT]]
+  // CHECK-NOT:     vector.transpose %[[MASK]]
+  // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<4xf32>, f32 }
+  %res = vector.mask %mask {
+    vector.contract {
+      indexing_maps = [affine_map<(k, m) -> (k, m)>,
+                       affine_map<(k, m) -> (k)>,
+                       affine_map<(k, m) -> (m)>],
+      iterator_types = ["reduction", "parallel"],
+      kind = #vector.kind<add>
+    } %arg0, %arg1, %arg2 : vector<2x4xf32>, vector<2xf32>, vector<4xf32> into vector<4xf32>
+  } : vector<2x4xi1> -> vector<4xf32>
+  return %res : vector<4xf32>
+}
+
+// CHECK-LABEL: @masked_tmatvec_k_mk_m
+// CHECK-SAME:  %[[MAT:.+]]: vector<4x2xf32>
+// CHECK-SAME:  %[[VEC:.+]]: vector<2xf32>
+// CHECK-SAME:  %[[INIT:.+]]: vector<4xf32>
+// CHECK-SAME:  %[[MASK:.+]]: vector<2x4xi1>
+func.func @masked_tmatvec_k_mk_m(%arg0: vector<4x2xf32>, %arg1: vector<2xf32>, %arg2: vector<4xf32>, %mask: vector<2x4xi1>) -> vector<4xf32> {
+  // CHECK:         vector.transpose %[[MAT]]
+  // CHECK-NOT:     vector.transpose %[[MASK]]
+  // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<4xf32>, f32 }
+  %res = vector.mask %mask {
+    vector.contract {
+      indexing_maps = [affine_map<(k, m) -> (k)>,
+                       affine_map<(k, m) -> (m, k)>,
+                       affine_map<(k, m) -> (m)>],
+      iterator_types = ["reduction", "parallel"],
+      kind = #vector.kind<add>
+    } %arg1, %arg0, %arg2 : vector<2xf32>, vector<4x2xf32>, vector<4xf32> into vector<4xf32>
+  } : vector<2x4xi1> -> vector<4xf32>
+  return %res : vector<4xf32>
+}
+
+// CHECK-LABEL: @masked_tmatvec_k_km_m
+// CHECK-SAME:  %[[MAT:.+]]: vector<2x4xf32>
+// CHECK-SAME:  %[[VEC:.+]]: vector<2xf32>
+// CHECK-SAME:  %[[INIT:.+]]: vector<4xf32>
+// CHECK-SAME:  %[[MASK:.+]]: vector<2x4xi1>
+func.func @masked_tmatvec_k_km_m(%arg0: vector<2x4xf32>, %arg1: vector<2xf32>, %arg2: vector<4xf32>, %mask: vector<2x4xi1>) -> vector<4xf32> {
+  // CHECK-NOT:     vector.transpose %[[MAT]]
+  // CHECK-NOT:     vector.transpose %[[MASK]]
+  // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<4xf32>, f32 }
+  %res = vector.mask %mask {
+    vector.contract {
+      indexing_maps = [affine_map<(k, m) -> (k)>,
+                       affine_map<(k, m) -> (k, m)>,
+                       affine_map<(k, m) -> (m)>],
+      iterator_types = ["reduction", "parallel"],
+      kind = #vector.kind<add>
+    } %arg1, %arg0, %arg2 : vector<2xf32>, vector<2x4xf32>, vector<4xf32> into vector<4xf32>
+  } : vector<2x4xi1> -> vector<4xf32>
+  return %res : vector<4xf32>
+}
+
 
 transform.sequence failures(propagate) {
 ^bb1(%module_op: !transform.any_op):

``````````

</details>


https://github.com/llvm/llvm-project/pull/67145


More information about the Mlir-commits mailing list