[Mlir-commits] [mlir] c91d3b0 - [mlir][vector] Constrain patterns: vector.contract -> vector.outerproduct

Andrzej Warzynski llvmlistbot at llvm.org
Fri Oct 6 09:08:45 PDT 2023


Author: Andrzej Warzynski
Date: 2023-10-06T16:07:07Z
New Revision: c91d3b0b08ee133573b4c24f239306a9f1741174

URL: https://github.com/llvm/llvm-project/commit/c91d3b0b08ee133573b4c24f239306a9f1741174
DIFF: https://github.com/llvm/llvm-project/commit/c91d3b0b08ee133573b4c24f239306a9f1741174.diff

LOG: [mlir][vector] Constrain patterns: vector.contract -> vector.outerproduct

This patch constrains the patterns for converting `vector.contract` to
`vector.outerproduct` so that

  * the reduction dimension is _not unrolled_ if the corresponding
    dimension is scalable.

This is necessary as the current lowering is incorrect for scalable
dims. Indeed, the following unrolling for `vector.contract` would be
invalid if the corresponding dimension was scalable (K is the size of
the reduction dimension):

```
  // K times. This is valid if K _is not_ scalable.
  %lhs = vector.extract %LHS[0]
  %rhs = vector.extract %RHS[0]
  vector.outerproduct %lhs, %rhs

  %lhs = vector.extract %LHS[1]
  %rhs = vector.extract %RHS[1]
  vector.outerproduct %lhs, %rhs

  // ...
```

Instead, a `for` loop should be generated:
```
// This would be valid regardless of whether K is scalable or not
scf.for %k = 0 to K step 1
  %lhs = vector.extract %LHS[%k]
  %rhs = vector.extract %RHS[%k]
  vector.outerproduct %lhs, %rhs
```

However, the lowering of:

  * `vector.extract` of vector slices with dynamic indices

is incomplete and hence the implementation proposed above (with
`scf.for`) wouldn't work just yet, i.e. it wouldn't be possible to lower
it further. Instead, this patch disables unrolling in cases when the
reduction dimension is scalable, i.e. where the generated code would be
functionally incorrect.

In order to document unsupported cases, a dedicated test file is added:

  * "vector-contract-to-outerproduct-transforms-unsupported.mlir"

This is the first patch in a series of patches that strives to update
these patterns (and to test them) for scalable vectors.

Resolves #68400

Added: 
    mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms-unsupported.mlir

Modified: 
    mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
    mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
index 04d9ddf2183f8c5..311e589547b89bc 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
@@ -424,9 +424,14 @@ struct UnrolledOuterProductGenerator
     return rewriter.create<arith::ExtSIOp>(loc, promotedType, v);
   }
 
-  FailureOr<Value> outerProd(Value lhs, Value rhs, Value res, int reductionSize,
+  FailureOr<Value> outerProd(Value lhs, Value rhs, Value res,
+                             VectorType lhsType, int reductionDim,
                              std::optional<Value> maybeMask = std::nullopt) {
-    assert(reductionSize > 0);
+    // Unrolling a scalable dimension would be incorrect - bail out.
+    if (lhsType.getScalableDims()[reductionDim])
+      return failure();
+
+    int reductionSize = lhsType.getDimSize(reductionDim);
     // Incremental support for masking.
     if (mask && !maybeMask.has_value())
       return failure();
@@ -459,33 +464,39 @@ struct UnrolledOuterProductGenerator
     Value transposedMask = t(mask, {2, 0, 1});
     // Classical row-major matmul:  Just permute the lhs.
     if (layout({{m, k}, {k, n}, {m, n}}))
-      return outerProd(t(lhs), rhs, res, lhsType.getDimSize(1), transposedMask);
+      return outerProd(t(lhs), rhs, res, lhsType, /*reductionDim=*/1,
+                       transposedMask);
     // TODO: may be better to fail and use some vector<k> -> scalar reduction.
     if (layout({{m, k}, {n, k}, {m, n}})) {
       Value tlhs = t(lhs);
-      return outerProd(tlhs, t(rhs), res, lhsType.getDimSize(1),
+      return outerProd(tlhs, t(rhs), res, lhsType, /*reductionDim=*/1,
                        transposedMask);
     }
     // No need to permute anything.
     if (layout({{k, m}, {k, n}, {m, n}}))
-      return outerProd(lhs, rhs, res, lhsType.getDimSize(0), transposedMask);
+      return outerProd(lhs, rhs, res, lhsType, /*reductionDim=*/0,
+                       transposedMask);
     // Just permute the rhs.
     if (layout({{k, m}, {n, k}, {m, n}}))
-      return outerProd(lhs, t(rhs), res, lhsType.getDimSize(0), transposedMask);
+      return outerProd(lhs, t(rhs), res, lhsType, /*reductionDim=*/0,
+                       transposedMask);
     // Transposed output: swap RHS and LHS.
     // Classical row-major matmul: permute the lhs.
     if (layout({{m, k}, {k, n}, {n, m}}))
-      return outerProd(rhs, t(lhs), res, lhsType.getDimSize(1), transposedMask);
+      return outerProd(rhs, t(lhs), res, lhsType, /*reductionDim=*/1,
+                       transposedMask);
     // TODO: may be better to fail and use some vector<k> -> scalar reduction.
     if (layout({{m, k}, {n, k}, {n, m}})) {
       Value trhs = t(rhs);
-      return outerProd(trhs, t(lhs), res, lhsType.getDimSize(1),
+      return outerProd(trhs, t(lhs), res, lhsType, /*reductionDim=*/1,
                        transposedMask);
     }
     if (layout({{k, m}, {k, n}, {n, m}}))
-      return outerProd(rhs, lhs, res, lhsType.getDimSize(0), transposedMask);
+      return outerProd(rhs, lhs, res, lhsType, /*reductionDim=*/0,
+                       transposedMask);
     if (layout({{k, m}, {n, k}, {n, m}}))
-      return outerProd(t(rhs), lhs, res, lhsType.getDimSize(0), transposedMask);
+      return outerProd(t(rhs), lhs, res, lhsType, /*reductionDim=*/0,
+                       transposedMask);
     return failure();
   }
 
@@ -503,16 +514,20 @@ struct UnrolledOuterProductGenerator
 
     // Case mat-vec: transpose.
     if (layout({{m, k}, {k}, {m}}))
-      return outerProd(t(lhs), rhs, res, lhsType.getDimSize(1), transposedMask);
+      return outerProd(t(lhs), rhs, res, lhsType, /*reductionDim=*/1,
+                       transposedMask);
     // Case mat-trans-vec: ready to go.
     if (layout({{k, m}, {k}, {m}}))
-      return outerProd(lhs, rhs, res, lhsType.getDimSize(0), transposedMask);
+      return outerProd(lhs, rhs, res, lhsType, /*reductionDim=*/0,
+                       transposedMask);
     // Case vec-mat: swap and transpose.
     if (layout({{k}, {m, k}, {m}}))
-      return outerProd(t(rhs), lhs, res, lhsType.getDimSize(0), transposedMask);
+      return outerProd(t(rhs), lhs, res, lhsType, /*reductionDim=*/0,
+                       transposedMask);
     // Case vec-mat-trans: swap and ready to go.
     if (layout({{k}, {k, m}, {m}}))
-      return outerProd(rhs, lhs, res, lhsType.getDimSize(0), transposedMask);
+      return outerProd(rhs, lhs, res, lhsType, /*reductionDim=*/0,
+                       transposedMask);
     return failure();
   }
 
@@ -528,16 +543,16 @@ struct UnrolledOuterProductGenerator
 
     // Case mat-vec: transpose.
     if (layout({{m, k}, {k}, {m}}))
-      return outerProd(t(lhs), rhs, res, lhsType.getDimSize(1), mask);
+      return outerProd(t(lhs), rhs, res, lhsType, /*reductionDim=*/1, mask);
     // Case mat-trans-vec: ready to go.
     if (layout({{k, m}, {k}, {m}}))
-      return outerProd(lhs, rhs, res, lhsType.getDimSize(0), mask);
+      return outerProd(lhs, rhs, res, lhsType, /*reductionDim=*/0, mask);
     // Case vec-mat: swap and transpose.
     if (layout({{k}, {m, k}, {m}}))
-      return outerProd(t(rhs), lhs, res, lhsType.getDimSize(0), mask);
+      return outerProd(t(rhs), lhs, res, lhsType, /*reductionDim=*/0, mask);
     // Case vec-mat-trans: swap and ready to go.
     if (layout({{k}, {k, m}, {m}}))
-      return outerProd(rhs, lhs, res, lhsType.getDimSize(0), mask);
+      return outerProd(rhs, lhs, res, lhsType, /*reductionDim=*/0, mask);
     return failure();
   }
 
@@ -980,9 +995,19 @@ FailureOr<Value> ContractionOpLowering::lowerParallel(PatternRewriter &rewriter,
         diag << "expected lhsIndex=" << lhsIndex << " and rhsIndex=" << rhsIndex
              << " to map to the same dimension";
       });
+    if (lhsType.getScalableDims()[lhsIndex])
+      return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
+        diag << "Unrolloing scalable dimension (lhsIndex=" << lhsIndex
+             << ") is not supported yet";
+      });
     dimSize = lhsType.getDimSize(lhsIndex);
   } else if (rhsIndex >= 0) {
     iterIndex = iMap[1].getDimPosition(rhsIndex);
+    if (rhsType.getScalableDims()[rhsIndex])
+      return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
+        diag << "Unrolloing scalable dimension (lhsIndex=" << lhsIndex
+             << ") is not supported yet";
+      });
     dimSize = rhsType.getDimSize(rhsIndex);
   }
   if (iterIndex < 0)

diff  --git a/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms-unsupported.mlir b/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms-unsupported.mlir
new file mode 100644
index 000000000000000..a955250107d73d7
--- /dev/null
+++ b/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms-unsupported.mlir
@@ -0,0 +1,33 @@
+// RUN: mlir-opt %s --test-transform-dialect-interpreter -allow-unregistered-dialect --split-input-file --verify-diagnostics
+
+#matvec_accesses = [
+  affine_map<(i, j) -> (i, j)>,
+  affine_map<(i, j) -> (j)>,
+  affine_map<(i, j) -> (i)>
+]
+#matvec_trait = {
+  indexing_maps = #matvec_accesses,
+  iterator_types = ["parallel", "reduction"]
+}
+
+// Unrolling scalable reduction dim is not supported - bail out
+
+// expected-error@below {{greedy pattern application failed}}
+func.func @masked_extract_contract2_scalable_reduction_dim(%arg0: vector<[2]x[3]xf32>,
+                                    %arg1: vector<[3]xf32>,
+                                    %arg2: vector<[2]xf32>,
+                                    %m: vector<[2]x[3]xi1>) -> vector<[2]xf32> {
+  %0 = vector.mask %m { vector.contract #matvec_trait %arg0, %arg1, %arg2
+          : vector<[2]x[3]xf32>, vector<[3]xf32> into vector<[2]xf32> } : vector<[2]x[3]xi1> -> vector<[2]xf32>
+  return %0 : vector<[2]xf32>
+}
+
+transform.sequence failures(propagate) {
+^bb1(%module_op: !transform.any_op):
+  %f = transform.structured.match ops{["func.func"]} in %module_op 
+    : (!transform.any_op) -> !transform.any_op
+
+  transform.apply_patterns to %f {
+    transform.apply_patterns.vector.lower_contraction lowering_strategy = "outerproduct"
+  } : !transform.any_op
+}

diff  --git a/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms.mlir
index deea7747f36799c..1e92fcff64dea57 100644
--- a/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms.mlir
@@ -31,19 +31,19 @@
 }
 
 // CHECK-LABEL:   func.func @masked_extract_contract2(
-// CHECK-SAME:                                      %[[VAL_0:.*]]: vector<2x3xf32>,
-// CHECK-SAME:                                      %[[VAL_1:.*]]: vector<3xf32>,
-// CHECK-SAME:                                      %[[VAL_2:.*]]: vector<2xf32>,
-// CHECK-SAME:                                      %[[IN_MASK:.*]]: vector<2x3xi1>) -> vector<2xf32>
+// CHECK-SAME:      %{{.*}}: vector<2x3xf32>,
+// CHECK-SAME:      %{{.*}}: vector<3xf32>,
+// CHECK-SAME:      %{{.*}}: vector<2xf32>,
+// CHECK-SAME:      %[[IN_MASK:.*]]: vector<2x3xi1>) -> vector<2xf32>
 // CHECK:           %[[T_MASK:.*]] = vector.transpose %[[IN_MASK]], [1, 0] : vector<2x3xi1> to vector<3x2xi1>
 // CHECK:           %[[MASK0:.*]] = vector.extract %[[T_MASK]][0] : vector<2xi1> from vector<3x2xi1>
-// CHECK:           vector.mask %[[MASK0]] { vector.outerproduct
+// CHECK:           vector.mask %[[MASK0]] { vector.outerproduct {{.*}} {kind = #vector.kind<add>} : vector<2xf32>, f32 } : vector<2xi1> -> vector<2xf32>
 
 // CHECK:           %[[MASK1:.*]] = vector.extract %[[T_MASK]][1] : vector<2xi1> from vector<3x2xi1>
-// CHECK:           vector.mask %[[MASK1]] { vector.outerproduct
+// CHECK:           vector.mask %[[MASK1]] { vector.outerproduct {{.*}} {kind = #vector.kind<add>} : vector<2xf32>, f32 } : vector<2xi1> -> vector<2xf32>
 
 // CHECK:           %[[MASK2:.*]] = vector.extract %[[T_MASK]][2] : vector<2xi1> from vector<3x2xi1>
-// CHECK:           vector.mask %[[MASK2]] { vector.outerproduct
+// CHECK:           vector.mask %[[MASK2]] { vector.outerproduct {{.*}} {kind = #vector.kind<add>} : vector<2xf32>, f32 } : vector<2xi1> -> vector<2xf32>
 
 func.func @masked_extract_contract2(%arg0: vector<2x3xf32>,
                                     %arg1: vector<3xf32>,
@@ -54,6 +54,30 @@ func.func @masked_extract_contract2(%arg0: vector<2x3xf32>,
   return %0 : vector<2xf32>
 }
 
+
+// CHECK-LABEL:   func.func @masked_extract_contract2_scalable_parallel_dim(
+// CHECK-SAME:      %{{.*}}: vector<[2]x3xf32>,
+// CHECK-SAME:      %{{.*}}: vector<3xf32>,
+// CHECK-SAME:      %{{.*}}: vector<[2]xf32>,
+// CHECK-SAME:      %[[IN_MASK:.*]]: vector<[2]x3xi1>) -> vector<[2]xf32>
+// CHECK:           %[[T_MASK:.*]] = vector.transpose %[[IN_MASK]], [1, 0] : vector<[2]x3xi1> to vector<3x[2]xi1>
+// CHECK:           %[[MASK0:.*]] = vector.extract %[[T_MASK]][0] : vector<[2]xi1> from vector<3x[2]xi1>
+// CHECK:           vector.mask %[[MASK0]] { vector.outerproduct {{.*}} {kind = #vector.kind<add>} : vector<[2]xf32>, f32 } : vector<[2]xi1> -> vector<[2]xf32>
+
+// CHECK:           %[[MASK1:.*]] = vector.extract %[[T_MASK]][1] : vector<[2]xi1> from vector<3x[2]xi1>
+// CHECK:           vector.mask %[[MASK1]] { vector.outerproduct {{.*}} {kind = #vector.kind<add>} : vector<[2]xf32>, f32 } : vector<[2]xi1> -> vector<[2]xf32>
+
+// CHECK:           %[[MASK2:.*]] = vector.extract %[[T_MASK]][2] : vector<[2]xi1> from vector<3x[2]xi1>
+// CHECK:           vector.mask %[[MASK2]] { vector.outerproduct {{.*}} {kind = #vector.kind<add>} : vector<[2]xf32>, f32 } : vector<[2]xi1> -> vector<[2]xf32>
+func.func @masked_extract_contract2_scalable_parallel_dim(%arg0: vector<[2]x3xf32>,
+                                    %arg1: vector<3xf32>,
+                                    %arg2: vector<[2]xf32>,
+                                    %m: vector<[2]x3xi1>) -> vector<[2]xf32> {
+  %0 = vector.mask %m { vector.contract #matvec_trait %arg0, %arg1, %arg2
+          : vector<[2]x3xf32>, vector<3xf32> into vector<[2]xf32> } : vector<[2]x3xi1> -> vector<[2]xf32>
+  return %0 : vector<[2]xf32>
+}
+
 // CHECK-LABEL: func.func @masked_extract_contract4(
 // CHECK-SAME:                                      %[[VAL_0:.*]]: vector<3x5xf32>,
 // CHECK-SAME:                                      %[[VAL_1:.*]]: vector<5x7xf32>,


        


More information about the Mlir-commits mailing list