[Mlir-commits] [mlir] [mlir][vector] Constrain patterns: vector.contract -> vector.outerproduct (PR #68400)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Oct 6 02:57:12 PDT 2023


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

<details>
<summary>Changes</summary>

This patch constrains the patterns for converting `vector.contract` to
`vector.outerproduct` so that

  * the reduction dimension is _not unrolled_ if the corresponding
    dimension is scalable.

This is necessary as the current lowering is incorrect for scalable
dims. Indeed, instead of the following unrolling that is currently being
generated for `vector.contract` (K is the size of the reduction
dimension):

```
  // K times
  %lhs = vector.extract %LHS[0]
  %rhs = vector.extract %RHS[0]
  vector.outerproduct %lhs, %rhs

  %lhs = vector.extract %LHS[1]
  %rhs = vector.extract %RHS[1]
  vector.outerproduct %lhs, %rhs

  ...
```

we should be generating a `for` loop like the following:
```
scf.for %k = 0 to K step 1
  %lhs = vector.extract %LHS[%k]
  %rhs = vector.extract %RHS[%k]
  vector.outerproduct %lhs, %rhs
```

However, the lowering of `vector.extract` of vector slices with dynamic
indices is incomplete and hence the implementation above wouldn't work
just yet. Instead, this patch effectively disables unrolling in cases
where the generated code would be functionally incorrect (i.e. when the
reduction dimension is scalable).

In order to document unsupported cases, a dedicated test file is added:

* "vector-contract-to-outerproduct-transforms-unsupported.mlir"

This is the first patch in a series of patches that strives to update
these patterns (and to test them) for scalable vectors.


---
Full diff: https://github.com/llvm/llvm-project/pull/68400.diff


3 Files Affected:

- (modified) mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp (+39-18) 
- (added) mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms-unsupported.mlir (+33) 
- (modified) mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms.mlir (+30-23) 


``````````diff
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
index 04d9ddf2183f8c5..c1cc0d7c64de264 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
@@ -424,9 +424,14 @@ struct UnrolledOuterProductGenerator
     return rewriter.create<arith::ExtSIOp>(loc, promotedType, v);
   }
 
-  FailureOr<Value> outerProd(Value lhs, Value rhs, Value res, int reductionSize,
+  FailureOr<Value> outerProd(Value lhs, Value rhs, Value res, VectorType lhsType,
+                             int reductionDim,
                              std::optional<Value> maybeMask = std::nullopt) {
-    assert(reductionSize > 0);
+    // Unrolling a scalable dimension would be incorrect - bail out.
+    if (lhsType.getScalableDims()[reductionDim])
+      return failure();
+
+    int reductionSize = lhsType.getDimSize(reductionDim);
     // Incremental support for masking.
     if (mask && !maybeMask.has_value())
       return failure();
@@ -459,33 +464,39 @@ struct UnrolledOuterProductGenerator
     Value transposedMask = t(mask, {2, 0, 1});
     // Classical row-major matmul:  Just permute the lhs.
     if (layout({{m, k}, {k, n}, {m, n}}))
-      return outerProd(t(lhs), rhs, res, lhsType.getDimSize(1), transposedMask);
+      return outerProd(t(lhs), rhs, res, lhsType, /*reductionDim=*/1,
+                       transposedMask);
     // TODO: may be better to fail and use some vector<k> -> scalar reduction.
     if (layout({{m, k}, {n, k}, {m, n}})) {
       Value tlhs = t(lhs);
-      return outerProd(tlhs, t(rhs), res, lhsType.getDimSize(1),
+      return outerProd(tlhs, t(rhs), res, lhsType, /*reductionDim=*/1,
                        transposedMask);
     }
     // No need to permute anything.
     if (layout({{k, m}, {k, n}, {m, n}}))
-      return outerProd(lhs, rhs, res, lhsType.getDimSize(0), transposedMask);
+      return outerProd(lhs, rhs, res, lhsType, /*reductionDim=*/0,
+                       transposedMask);
     // Just permute the rhs.
     if (layout({{k, m}, {n, k}, {m, n}}))
-      return outerProd(lhs, t(rhs), res, lhsType.getDimSize(0), transposedMask);
+      return outerProd(lhs, t(rhs), res, lhsType, /*reductionDim=*/0,
+                       transposedMask);
     // Transposed output: swap RHS and LHS.
     // Classical row-major matmul: permute the lhs.
     if (layout({{m, k}, {k, n}, {n, m}}))
-      return outerProd(rhs, t(lhs), res, lhsType.getDimSize(1), transposedMask);
+      return outerProd(rhs, t(lhs), res, lhsType, /*reductionDim=*/1,
+                       transposedMask);
     // TODO: may be better to fail and use some vector<k> -> scalar reduction.
     if (layout({{m, k}, {n, k}, {n, m}})) {
       Value trhs = t(rhs);
-      return outerProd(trhs, t(lhs), res, lhsType.getDimSize(1),
+      return outerProd(trhs, t(lhs), res, lhsType, /*reductionDim=*/1,
                        transposedMask);
     }
     if (layout({{k, m}, {k, n}, {n, m}}))
-      return outerProd(rhs, lhs, res, lhsType.getDimSize(0), transposedMask);
+      return outerProd(rhs, lhs, res, lhsType, /*reductionDim=*/0,
+                       transposedMask);
     if (layout({{k, m}, {n, k}, {n, m}}))
-      return outerProd(t(rhs), lhs, res, lhsType.getDimSize(0), transposedMask);
+      return outerProd(t(rhs), lhs, res, lhsType, /*reductionDim=*/0,
+                       transposedMask);
     return failure();
   }
 
@@ -503,16 +514,20 @@ struct UnrolledOuterProductGenerator
 
     // Case mat-vec: transpose.
     if (layout({{m, k}, {k}, {m}}))
-      return outerProd(t(lhs), rhs, res, lhsType.getDimSize(1), transposedMask);
+      return outerProd(t(lhs), rhs, res, lhsType, /*reductionDim=*/1,
+                       transposedMask);
     // Case mat-trans-vec: ready to go.
     if (layout({{k, m}, {k}, {m}}))
-      return outerProd(lhs, rhs, res, lhsType.getDimSize(0), transposedMask);
+      return outerProd(lhs, rhs, res, lhsType, /*reductionDim=*/0,
+                       transposedMask);
     // Case vec-mat: swap and transpose.
     if (layout({{k}, {m, k}, {m}}))
-      return outerProd(t(rhs), lhs, res, lhsType.getDimSize(0), transposedMask);
+      return outerProd(t(rhs), lhs, res, lhsType, /*reductionDim=*/0,
+                       transposedMask);
     // Case vec-mat-trans: swap and ready to go.
     if (layout({{k}, {k, m}, {m}}))
-      return outerProd(rhs, lhs, res, lhsType.getDimSize(0), transposedMask);
+      return outerProd(rhs, lhs, res, lhsType, /*reductionDim=*/0,
+                       transposedMask);
     return failure();
   }
 
@@ -528,16 +543,16 @@ struct UnrolledOuterProductGenerator
 
     // Case mat-vec: transpose.
     if (layout({{m, k}, {k}, {m}}))
-      return outerProd(t(lhs), rhs, res, lhsType.getDimSize(1), mask);
+      return outerProd(t(lhs), rhs, res, lhsType, /*reductionDim=*/1, mask);
     // Case mat-trans-vec: ready to go.
     if (layout({{k, m}, {k}, {m}}))
-      return outerProd(lhs, rhs, res, lhsType.getDimSize(0), mask);
+      return outerProd(lhs, rhs, res, lhsType, /*reductionDim=*/0, mask);
     // Case vec-mat: swap and transpose.
     if (layout({{k}, {m, k}, {m}}))
-      return outerProd(t(rhs), lhs, res, lhsType.getDimSize(0), mask);
+      return outerProd(t(rhs), lhs, res, lhsType, /*reductionDim=*/0, mask);
     // Case vec-mat-trans: swap and ready to go.
     if (layout({{k}, {k, m}, {m}}))
-      return outerProd(rhs, lhs, res, lhsType.getDimSize(0), mask);
+      return outerProd(rhs, lhs, res, lhsType, /*reductionDim=*/0, mask);
     return failure();
   }
 
@@ -980,9 +995,15 @@ FailureOr<Value> ContractionOpLowering::lowerParallel(PatternRewriter &rewriter,
         diag << "expected lhsIndex=" << lhsIndex << " and rhsIndex=" << rhsIndex
              << " to map to the same dimension";
       });
+    // Unrolling a scalable dimension would be incorrect - bail out.
+    if (lhsType.getScalableDims()[lhsIndex])
+      return failure();
     dimSize = lhsType.getDimSize(lhsIndex);
   } else if (rhsIndex >= 0) {
     iterIndex = iMap[1].getDimPosition(rhsIndex);
+    // Unrolling a scalable dimension would be incorrect - bail out.
+    if (rhsType.getScalableDims()[rhsIndex])
+      return failure();
     dimSize = rhsType.getDimSize(rhsIndex);
   }
   if (iterIndex < 0)
diff --git a/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms-unsupported.mlir b/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms-unsupported.mlir
new file mode 100644
index 000000000000000..a955250107d73d7
--- /dev/null
+++ b/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms-unsupported.mlir
@@ -0,0 +1,33 @@
+// RUN: mlir-opt %s --test-transform-dialect-interpreter -allow-unregistered-dialect --split-input-file --verify-diagnostics
+
+#matvec_accesses = [
+  affine_map<(i, j) -> (i, j)>,
+  affine_map<(i, j) -> (j)>,
+  affine_map<(i, j) -> (i)>
+]
+#matvec_trait = {
+  indexing_maps = #matvec_accesses,
+  iterator_types = ["parallel", "reduction"]
+}
+
+// Unrolling scalable reduction dim is not supported - bail out
+
+// expected-error @below {{greedy pattern application failed}}
+func.func @masked_extract_contract2_scalable_reduction_dim(%arg0: vector<[2]x[3]xf32>,
+                                    %arg1: vector<[3]xf32>,
+                                    %arg2: vector<[2]xf32>,
+                                    %m: vector<[2]x[3]xi1>) -> vector<[2]xf32> {
+  %0 = vector.mask %m { vector.contract #matvec_trait %arg0, %arg1, %arg2
+          : vector<[2]x[3]xf32>, vector<[3]xf32> into vector<[2]xf32> } : vector<[2]x[3]xi1> -> vector<[2]xf32>
+  return %0 : vector<[2]xf32>
+}
+
+transform.sequence failures(propagate) {
+^bb1(%module_op: !transform.any_op):
+  %f = transform.structured.match ops{["func.func"]} in %module_op 
+    : (!transform.any_op) -> !transform.any_op
+
+  transform.apply_patterns to %f {
+    transform.apply_patterns.vector.lower_contraction lowering_strategy = "outerproduct"
+  } : !transform.any_op
+}
diff --git a/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms.mlir
index deea7747f36799c..8ee0a35717ce87a 100644
--- a/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms.mlir
@@ -31,19 +31,19 @@
 }
 
 // CHECK-LABEL:   func.func @masked_extract_contract2(
-// CHECK-SAME:                                      %[[VAL_0:.*]]: vector<2x3xf32>,
-// CHECK-SAME:                                      %[[VAL_1:.*]]: vector<3xf32>,
-// CHECK-SAME:                                      %[[VAL_2:.*]]: vector<2xf32>,
-// CHECK-SAME:                                      %[[IN_MASK:.*]]: vector<2x3xi1>) -> vector<2xf32>
+// CHECK-SAME:      %{{.*}}: vector<2x3xf32>,
+// CHECK-SAME:      %{{.*}}: vector<3xf32>,
+// CHECK-SAME:      %{{.*}}: vector<2xf32>,
+// CHECK-SAME:      %[[IN_MASK:.*]]: vector<2x3xi1>) -> vector<2xf32>
 // CHECK:           %[[T_MASK:.*]] = vector.transpose %[[IN_MASK]], [1, 0] : vector<2x3xi1> to vector<3x2xi1>
 // CHECK:           %[[MASK0:.*]] = vector.extract %[[T_MASK]][0] : vector<2xi1> from vector<3x2xi1>
-// CHECK:           vector.mask %[[MASK0]] { vector.outerproduct
+// CHECK:           vector.mask %[[MASK0]] { vector.outerproduct {{.*}} {kind = #vector.kind<add>} : vector<2xf32>, f32 } : vector<2xi1> -> vector<2xf32>
 
 // CHECK:           %[[MASK1:.*]] = vector.extract %[[T_MASK]][1] : vector<2xi1> from vector<3x2xi1>
-// CHECK:           vector.mask %[[MASK1]] { vector.outerproduct
+// CHECK:           vector.mask %[[MASK1]] { vector.outerproduct {{.*}} {kind = #vector.kind<add>} : vector<2xf32>, f32 } : vector<2xi1> -> vector<2xf32>
 
 // CHECK:           %[[MASK2:.*]] = vector.extract %[[T_MASK]][2] : vector<2xi1> from vector<3x2xi1>
-// CHECK:           vector.mask %[[MASK2]] { vector.outerproduct
+// CHECK:           vector.mask %[[MASK2]] { vector.outerproduct {{.*}} {kind = #vector.kind<add>} : vector<2xf32>, f32 } : vector<2xi1> -> vector<2xf32>
 
 func.func @masked_extract_contract2(%arg0: vector<2x3xf32>,
                                     %arg1: vector<3xf32>,
@@ -54,22 +54,29 @@ func.func @masked_extract_contract2(%arg0: vector<2x3xf32>,
   return %0 : vector<2xf32>
 }
 
-// CHECK-LABEL: func.func @masked_extract_contract4(
-// CHECK-SAME:                                      %[[VAL_0:.*]]: vector<3x5xf32>,
-// CHECK-SAME:                                      %[[VAL_1:.*]]: vector<5x7xf32>,
-// CHECK-SAME:                                      %[[VAL_2:.*]]: vector<3x7xf32>,
-// CHECK-SAME:                                      %[[VAL_3:.*]]: vector<3x7x5xi1>) -> vector<3x7xf32> {
-// CHECK:         %[[VAL_5:.*]] = vector.transpose %[[VAL_3]], [2, 0, 1] : vector<3x7x5xi1> to vector<5x3x7xi1>
-// CHECK:         %[[VAL_8:.*]] = vector.extract %[[VAL_5]][0] : vector<3x7xi1> from vector<5x3x7xi1>
-// CHECK:         %[[VAL_9:.*]] = vector.mask %[[VAL_8]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32>
-// CHECK:         %[[VAL_12:.*]] = vector.extract %[[VAL_5]][1] : vector<3x7xi1> from vector<5x3x7xi1>
-// CHECK:         %[[VAL_13:.*]] = vector.mask %[[VAL_12]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32>
-// CHECK:         %[[VAL_16:.*]] = vector.extract %[[VAL_5]][2] : vector<3x7xi1> from vector<5x3x7xi1>
-// CHECK:         %[[VAL_17:.*]] = vector.mask %[[VAL_16]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32>
-// CHECK:         %[[VAL_20:.*]] = vector.extract %[[VAL_5]][3] : vector<3x7xi1> from vector<5x3x7xi1>
-// CHECK:         %[[VAL_21:.*]] = vector.mask %[[VAL_20]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32>
-// CHECK:         %[[VAL_24:.*]] = vector.extract %[[VAL_5]][4] : vector<3x7xi1> from vector<5x3x7xi1>
-// CHECK:         %[[VAL_25:.*]] = vector.mask %[[VAL_24]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32>
+
+// CHECK-LABEL:   func.func @masked_extract_contract2_scalable_parallel_dim(
+// CHECK-SAME:      %{{.*}}: vector<[2]x3xf32>,
+// CHECK-SAME:      %{{.*}}: vector<3xf32>,
+// CHECK-SAME:      %{{.*}}: vector<[2]xf32>,
+// CHECK-SAME:      %[[IN_MASK:.*]]: vector<[2]x3xi1>) -> vector<[2]xf32>
+// CHECK:           %[[T_MASK:.*]] = vector.transpose %[[IN_MASK]], [1, 0] : vector<[2]x3xi1> to vector<3x[2]xi1>
+// CHECK:           %[[MASK0:.*]] = vector.extract %[[T_MASK]][0] : vector<[2]xi1> from vector<3x[2]xi1>
+// CHECK:           vector.mask %[[MASK0]] { vector.outerproduct {{.*}} {kind = #vector.kind<add>} : vector<[2]xf32>, f32 } : vector<[2]xi1> -> vector<[2]xf32>
+
+// CHECK:           %[[MASK1:.*]] = vector.extract %[[T_MASK]][1] : vector<[2]xi1> from vector<3x[2]xi1>
+// CHECK:           vector.mask %[[MASK1]] { vector.outerproduct {{.*}} {kind = #vector.kind<add>} : vector<[2]xf32>, f32 } : vector<[2]xi1> -> vector<[2]xf32>
+
+// CHECK:           %[[MASK2:.*]] = vector.extract %[[T_MASK]][2] : vector<[2]xi1> from vector<3x[2]xi1>
+// CHECK:           vector.mask %[[MASK2]] { vector.outerproduct {{.*}} {kind = #vector.kind<add>} : vector<[2]xf32>, f32 } : vector<[2]xi1> -> vector<[2]xf32>
+func.func @masked_extract_contract2_scalable_parallel_dim(%arg0: vector<[2]x3xf32>,
+                                    %arg1: vector<3xf32>,
+                                    %arg2: vector<[2]xf32>,
+                                    %m: vector<[2]x3xi1>) -> vector<[2]xf32> {
+  %0 = vector.mask %m { vector.contract #matvec_trait %arg0, %arg1, %arg2
+          : vector<[2]x3xf32>, vector<3xf32> into vector<[2]xf32> } : vector<[2]x3xi1> -> vector<[2]xf32>
+  return %0 : vector<[2]xf32>
+}
 
 func.func @masked_extract_contract4(%arg0: vector<3x5xf32>,
                                     %arg1: vector<5x7xf32>,

``````````

</details>


https://github.com/llvm/llvm-project/pull/68400


More information about the Mlir-commits mailing list