[Mlir-commits] [mlir] a509a18 - [mlir][vector] proper masking support for contract lowering (#67145)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Sep 25 04:38:50 PDT 2023


Author: Oleksandr "Alex" Zinenko
Date: 2023-09-25T13:38:46+02:00
New Revision: a509a18731e6b77629dbacbfc369382eb684d4a9

URL: https://github.com/llvm/llvm-project/commit/a509a18731e6b77629dbacbfc369382eb684d4a9
DIFF: https://github.com/llvm/llvm-project/commit/a509a18731e6b77629dbacbfc369382eb684d4a9.diff

LOG: [mlir][vector] proper masking support for contract lowering (#67145)

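Support all known operand/result layout permutations, not just the canonical
one, when lowering a masked vector.contract to vector.outerproduct. The mask
is transposed where needed so the reduction dimension is outermost, as
vector.outerproduct requires. It is passed through unchanged when it already
has that shape (the tmatvec flavor).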

Added: 
    

Modified: 
    mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
    mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
index 7560db2332cf8d9..04d9ddf2183f8c5 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
@@ -456,61 +456,69 @@ struct UnrolledOuterProductGenerator
     // Set up the parallel/reduction structure in the right form.
     AffineExpr m, n, k;
     bindDims(rewriter.getContext(), m, n, k);
+    Value transposedMask = t(mask, {2, 0, 1});
     // Classical row-major matmul:  Just permute the lhs.
     if (layout({{m, k}, {k, n}, {m, n}}))
-      return outerProd(t(lhs), rhs, res, lhsType.getDimSize(1),
-                       t(mask, {2, 0, 1}));
+      return outerProd(t(lhs), rhs, res, lhsType.getDimSize(1), transposedMask);
     // TODO: may be better to fail and use some vector<k> -> scalar reduction.
     if (layout({{m, k}, {n, k}, {m, n}})) {
       Value tlhs = t(lhs);
-      return outerProd(tlhs, t(rhs), res, lhsType.getDimSize(1));
+      return outerProd(tlhs, t(rhs), res, lhsType.getDimSize(1),
+                       transposedMask);
     }
     // No need to permute anything.
     if (layout({{k, m}, {k, n}, {m, n}}))
-      return outerProd(lhs, rhs, res, lhsType.getDimSize(0));
+      return outerProd(lhs, rhs, res, lhsType.getDimSize(0), transposedMask);
     // Just permute the rhs.
     if (layout({{k, m}, {n, k}, {m, n}}))
-      return outerProd(lhs, t(rhs), res, lhsType.getDimSize(0));
+      return outerProd(lhs, t(rhs), res, lhsType.getDimSize(0), transposedMask);
     // Transposed output: swap RHS and LHS.
     // Classical row-major matmul: permute the lhs.
     if (layout({{m, k}, {k, n}, {n, m}}))
-      return outerProd(rhs, t(lhs), res, lhsType.getDimSize(1));
+      return outerProd(rhs, t(lhs), res, lhsType.getDimSize(1), transposedMask);
     // TODO: may be better to fail and use some vector<k> -> scalar reduction.
     if (layout({{m, k}, {n, k}, {n, m}})) {
       Value trhs = t(rhs);
-      return outerProd(trhs, t(lhs), res, lhsType.getDimSize(1));
+      return outerProd(trhs, t(lhs), res, lhsType.getDimSize(1),
+                       transposedMask);
     }
     if (layout({{k, m}, {k, n}, {n, m}}))
-      return outerProd(rhs, lhs, res, lhsType.getDimSize(0));
+      return outerProd(rhs, lhs, res, lhsType.getDimSize(0), transposedMask);
     if (layout({{k, m}, {n, k}, {n, m}}))
-      return outerProd(t(rhs), lhs, res, lhsType.getDimSize(0));
+      return outerProd(t(rhs), lhs, res, lhsType.getDimSize(0), transposedMask);
     return failure();
   }
 
-  /// One outer parallel, one inner reduction (matvec flavor)
+  //
+  // One outer parallel, one inner reduction (matvec flavor).
+  // Mask needs to be transposed everywhere to turn the reduction dimension
+  // outermost as required by outerproduct.
+  //
   FailureOr<Value> matvec() {
     if (!iters({Par(), Red()}))
       return failure();
     AffineExpr m, k;
     bindDims(rewriter.getContext(), m, k);
+    Value transposedMask = t(mask);
 
     // Case mat-vec: transpose.
     if (layout({{m, k}, {k}, {m}}))
-      return outerProd(t(lhs), rhs, res, lhsType.getDimSize(1), t(mask));
+      return outerProd(t(lhs), rhs, res, lhsType.getDimSize(1), transposedMask);
     // Case mat-trans-vec: ready to go.
     if (layout({{k, m}, {k}, {m}}))
-      return outerProd(lhs, rhs, res, lhsType.getDimSize(0));
+      return outerProd(lhs, rhs, res, lhsType.getDimSize(0), transposedMask);
     // Case vec-mat: swap and transpose.
     if (layout({{k}, {m, k}, {m}}))
-      return outerProd(t(rhs), lhs, res, lhsType.getDimSize(0));
+      return outerProd(t(rhs), lhs, res, lhsType.getDimSize(0), transposedMask);
     // Case vec-mat-trans: swap and ready to go.
     if (layout({{k}, {k, m}, {m}}))
-      return outerProd(rhs, lhs, res, lhsType.getDimSize(0));
+      return outerProd(rhs, lhs, res, lhsType.getDimSize(0), transposedMask);
     return failure();
   }
 
   //
-  // One outer reduction, one inner parallel (tmatvec flavor)
+  // One outer reduction, one inner parallel (tmatvec flavor).
+  // Mask already has the shape of the outer product.
   //
   FailureOr<Value> tmatvec() {
     if (!iters({Red(), Par()}))
@@ -520,16 +528,16 @@ struct UnrolledOuterProductGenerator
 
     // Case mat-vec: transpose.
     if (layout({{m, k}, {k}, {m}}))
-      return outerProd(t(lhs), rhs, res, lhsType.getDimSize(1));
+      return outerProd(t(lhs), rhs, res, lhsType.getDimSize(1), mask);
     // Case mat-trans-vec: ready to go.
     if (layout({{k, m}, {k}, {m}}))
-      return outerProd(lhs, rhs, res, lhsType.getDimSize(0));
+      return outerProd(lhs, rhs, res, lhsType.getDimSize(0), mask);
     // Case vec-mat: swap and transpose.
     if (layout({{k}, {m, k}, {m}}))
-      return outerProd(t(rhs), lhs, res, lhsType.getDimSize(0));
+      return outerProd(t(rhs), lhs, res, lhsType.getDimSize(0), mask);
     // Case vec-mat-trans: swap and ready to go.
     if (layout({{k}, {k, m}, {m}}))
-      return outerProd(rhs, lhs, res, lhsType.getDimSize(0));
+      return outerProd(rhs, lhs, res, lhsType.getDimSize(0), mask);
     return failure();
   }
 

diff  --git a/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms.mlir
index aee4cf6b1379bd0..982de77786745b6 100644
--- a/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms.mlir
@@ -341,6 +341,174 @@ func.func @matmul_7(%arg0: vector<2x1xf32>, %arg1: vector<1x3xf32>, %arg2: vecto
   return %0 : vector<3x2xf32>
 }
 
+// CHECK-LABEL: @masked_matvec_mk_k_m
+// CHECK-SAME:  %[[MAT:.+]]: vector<4x2xf32>
+// CHECK-SAME:  %[[VEC:.+]]: vector<2xf32>
+// CHECK-SAME:  %[[INIT:.+]]: vector<4xf32>
+// CHECK-SAME:  %[[MASK:.+]]: vector<4x2xi1>
+func.func @masked_matvec_mk_k_m(%arg0: vector<4x2xf32>, %arg1: vector<2xf32>, %arg2: vector<4xf32>, %mask: vector<4x2xi1>) -> vector<4xf32> {
+  // CHECK:         vector.transpose %[[MASK]]
+  // CHECK:         vector.transpose %[[MAT]]
+  // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<4xf32>, f32 }
+  %res = vector.mask %mask {
+    vector.contract {
+      indexing_maps = [affine_map<(m, k) -> (m, k)>,
+                       affine_map<(m, k) -> (k)>,
+                       affine_map<(m, k) -> (m)>],
+      iterator_types = ["parallel", "reduction"],
+      kind = #vector.kind<add>
+    } %arg0, %arg1, %arg2 : vector<4x2xf32>, vector<2xf32>, vector<4xf32> into vector<4xf32>
+  } : vector<4x2xi1> -> vector<4xf32>
+  return %res : vector<4xf32>
+}
+
+// CHECK-LABEL: @masked_matvec_km_k_m
+// CHECK-SAME:  %[[MAT:.+]]: vector<2x4xf32>
+// CHECK-SAME:  %[[VEC:.+]]: vector<2xf32>
+// CHECK-SAME:  %[[INIT:.+]]: vector<4xf32>
+// CHECK-SAME:  %[[MASK:.+]]: vector<4x2xi1>
+func.func @masked_matvec_km_k_m(%arg0: vector<2x4xf32>, %arg1: vector<2xf32>, %arg2: vector<4xf32>, %mask: vector<4x2xi1>) -> vector<4xf32> {
+  // CHECK:         vector.transpose %[[MASK]]
+  // CHECK-NOT:     vector.transpose %[[MAT]]
+  // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<4xf32>, f32 }
+  %res = vector.mask %mask {
+    vector.contract {
+      indexing_maps = [affine_map<(m, k) -> (k, m)>,
+                       affine_map<(m, k) -> (k)>,
+                       affine_map<(m, k) -> (m)>],
+      iterator_types = ["parallel", "reduction"],
+      kind = #vector.kind<add>
+    } %arg0, %arg1, %arg2 : vector<2x4xf32>, vector<2xf32>, vector<4xf32> into vector<4xf32>
+  } : vector<4x2xi1> -> vector<4xf32>
+  return %res : vector<4xf32>
+}
+
+// CHECK-LABEL: @masked_matvec_k_mk_m
+// CHECK-SAME:  %[[MAT:.+]]: vector<4x2xf32>
+// CHECK-SAME:  %[[VEC:.+]]: vector<2xf32>
+// CHECK-SAME:  %[[INIT:.+]]: vector<4xf32>
+// CHECK-SAME:  %[[MASK:.+]]: vector<4x2xi1>
+func.func @masked_matvec_k_mk_m(%arg0: vector<4x2xf32>, %arg1: vector<2xf32>, %arg2: vector<4xf32>, %mask: vector<4x2xi1>) -> vector<4xf32> {
+  // CHECK:         vector.transpose %[[MASK]]
+  // CHECK:         vector.transpose %[[MAT]]
+  // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<4xf32>, f32 }
+  %res = vector.mask %mask {
+    vector.contract {
+      indexing_maps = [affine_map<(m, k) -> (k)>,
+                       affine_map<(m, k) -> (m, k)>,
+                       affine_map<(m, k) -> (m)>],
+      iterator_types = ["parallel", "reduction"],
+      kind = #vector.kind<add>
+    } %arg1, %arg0, %arg2 : vector<2xf32>, vector<4x2xf32>, vector<4xf32> into vector<4xf32>
+  } : vector<4x2xi1> -> vector<4xf32>
+  return %res : vector<4xf32>
+}
+
+// CHECK-LABEL: @masked_matvec_k_km_m
+// CHECK-SAME:  %[[MAT:.+]]: vector<2x4xf32>
+// CHECK-SAME:  %[[VEC:.+]]: vector<2xf32>
+// CHECK-SAME:  %[[INIT:.+]]: vector<4xf32>
+// CHECK-SAME:  %[[MASK:.+]]: vector<4x2xi1>
+func.func @masked_matvec_k_km_m(%arg0: vector<2x4xf32>, %arg1: vector<2xf32>, %arg2: vector<4xf32>, %mask: vector<4x2xi1>) -> vector<4xf32> {
+  // CHECK:         vector.transpose %[[MASK]]
+  // CHECK-NOT:     vector.transpose %[[MAT]]
+  // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<4xf32>, f32 }
+  %res = vector.mask %mask {
+    vector.contract {
+      indexing_maps = [affine_map<(m, k) -> (k)>,
+                       affine_map<(m, k) -> (k, m)>,
+                       affine_map<(m, k) -> (m)>],
+      iterator_types = ["parallel", "reduction"],
+      kind = #vector.kind<add>
+    } %arg1, %arg0, %arg2 : vector<2xf32>, vector<2x4xf32>, vector<4xf32> into vector<4xf32>
+  } : vector<4x2xi1> -> vector<4xf32>
+  return %res : vector<4xf32>
+}
+
+// CHECK-LABEL: @masked_tmatvec_mk_k_m
+// CHECK-SAME:  %[[MAT:.+]]: vector<4x2xf32>
+// CHECK-SAME:  %[[VEC:.+]]: vector<2xf32>
+// CHECK-SAME:  %[[INIT:.+]]: vector<4xf32>
+// CHECK-SAME:  %[[MASK:.+]]: vector<2x4xi1>
+func.func @masked_tmatvec_mk_k_m(%arg0: vector<4x2xf32>, %arg1: vector<2xf32>, %arg2: vector<4xf32>, %mask: vector<2x4xi1>) -> vector<4xf32> {
+  // CHECK:         vector.transpose %[[MAT]]
+  // CHECK-NOT:     vector.transpose %[[MASK]]
+  // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<4xf32>, f32 }
+  %res = vector.mask %mask {
+    vector.contract {
+      indexing_maps = [affine_map<(k, m) -> (m, k)>,
+                       affine_map<(k, m) -> (k)>,
+                       affine_map<(k, m) -> (m)>],
+      iterator_types = ["reduction", "parallel"],
+      kind = #vector.kind<add>
+    } %arg0, %arg1, %arg2 : vector<4x2xf32>, vector<2xf32>, vector<4xf32> into vector<4xf32>
+  } : vector<2x4xi1> -> vector<4xf32>
+  return %res : vector<4xf32>
+}
+
+// CHECK-LABEL: @masked_tmatvec_km_k_m
+// CHECK-SAME:  %[[MAT:.+]]: vector<2x4xf32>
+// CHECK-SAME:  %[[VEC:.+]]: vector<2xf32>
+// CHECK-SAME:  %[[INIT:.+]]: vector<4xf32>
+// CHECK-SAME:  %[[MASK:.+]]: vector<2x4xi1>
+func.func @masked_tmatvec_km_k_m(%arg0: vector<2x4xf32>, %arg1: vector<2xf32>, %arg2: vector<4xf32>, %mask: vector<2x4xi1>) -> vector<4xf32> {
+  // CHECK-NOT:     vector.transpose %[[MAT]]
+  // CHECK-NOT:     vector.transpose %[[MASK]]
+  // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<4xf32>, f32 }
+  %res = vector.mask %mask {
+    vector.contract {
+      indexing_maps = [affine_map<(k, m) -> (k, m)>,
+                       affine_map<(k, m) -> (k)>,
+                       affine_map<(k, m) -> (m)>],
+      iterator_types = ["reduction", "parallel"],
+      kind = #vector.kind<add>
+    } %arg0, %arg1, %arg2 : vector<2x4xf32>, vector<2xf32>, vector<4xf32> into vector<4xf32>
+  } : vector<2x4xi1> -> vector<4xf32>
+  return %res : vector<4xf32>
+}
+
+// CHECK-LABEL: @masked_tmatvec_k_mk_m
+// CHECK-SAME:  %[[MAT:.+]]: vector<4x2xf32>
+// CHECK-SAME:  %[[VEC:.+]]: vector<2xf32>
+// CHECK-SAME:  %[[INIT:.+]]: vector<4xf32>
+// CHECK-SAME:  %[[MASK:.+]]: vector<2x4xi1>
+func.func @masked_tmatvec_k_mk_m(%arg0: vector<4x2xf32>, %arg1: vector<2xf32>, %arg2: vector<4xf32>, %mask: vector<2x4xi1>) -> vector<4xf32> {
+  // CHECK:         vector.transpose %[[MAT]]
+  // CHECK-NOT:     vector.transpose %[[MASK]]
+  // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<4xf32>, f32 }
+  %res = vector.mask %mask {
+    vector.contract {
+      indexing_maps = [affine_map<(k, m) -> (k)>,
+                       affine_map<(k, m) -> (m, k)>,
+                       affine_map<(k, m) -> (m)>],
+      iterator_types = ["reduction", "parallel"],
+      kind = #vector.kind<add>
+    } %arg1, %arg0, %arg2 : vector<2xf32>, vector<4x2xf32>, vector<4xf32> into vector<4xf32>
+  } : vector<2x4xi1> -> vector<4xf32>
+  return %res : vector<4xf32>
+}
+
+// CHECK-LABEL: @masked_tmatvec_k_km_m
+// CHECK-SAME:  %[[MAT:.+]]: vector<2x4xf32>
+// CHECK-SAME:  %[[VEC:.+]]: vector<2xf32>
+// CHECK-SAME:  %[[INIT:.+]]: vector<4xf32>
+// CHECK-SAME:  %[[MASK:.+]]: vector<2x4xi1>
+func.func @masked_tmatvec_k_km_m(%arg0: vector<2x4xf32>, %arg1: vector<2xf32>, %arg2: vector<4xf32>, %mask: vector<2x4xi1>) -> vector<4xf32> {
+  // CHECK-NOT:     vector.transpose %[[MAT]]
+  // CHECK-NOT:     vector.transpose %[[MASK]]
+  // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<4xf32>, f32 }
+  %res = vector.mask %mask {
+    vector.contract {
+      indexing_maps = [affine_map<(k, m) -> (k)>,
+                       affine_map<(k, m) -> (k, m)>,
+                       affine_map<(k, m) -> (m)>],
+      iterator_types = ["reduction", "parallel"],
+      kind = #vector.kind<add>
+    } %arg1, %arg0, %arg2 : vector<2xf32>, vector<2x4xf32>, vector<4xf32> into vector<4xf32>
+  } : vector<2x4xi1> -> vector<4xf32>
+  return %res : vector<4xf32>
+}
+
 
 transform.sequence failures(propagate) {
 ^bb1(%module_op: !transform.any_op):


        


More information about the Mlir-commits mailing list