[Mlir-commits] [mlir] [mlir][vector] Fix invalid IR in `ContractionOpLowering` (PR #78130)

Matthias Springer llvmlistbot at llvm.org
Tue Jan 16 00:18:36 PST 2024


https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/78130

>From f24010d6ae39ad4f841b27890176ba1f19a3de6c Mon Sep 17 00:00:00 2001
From: Matthias Springer <springerm at google.com>
Date: Mon, 15 Jan 2024 09:19:33 +0000
Subject: [PATCH] [mlir][vector] Fix invalid IR in `ContractionOpLowering`

If a rewrite pattern returns "failure", it must not have modified the IR. This commit fixes `Dialect/Vector/vector-contract-to-outerproduct-transforms-unsupported.mlir` when running with `MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS`.

```
  * Pattern (anonymous namespace)::ContractionOpToOuterProductOpLowering : 'vector.contract -> ()' {
Trying to match "(anonymous namespace)::ContractionOpToOuterProductOpLowering"
    ** Insert  : 'vector.transpose'(0x5625b3a8cb30)
    ** Insert  : 'vector.transpose'(0x5625b3a8cbc0)
"(anonymous namespace)::ContractionOpToOuterProductOpLowering" result 0
  } -> failure : pattern failed to match
} -> failure : pattern failed to match

LLVM ERROR: pattern returned failure but IR did change
```

Note: `vector-contract-to-outerproduct-transforms-unsupported.mlir` is merged into `vector-contract-to-outerproduct-matvec-transforms.mlir`. The `greedy pattern application failed` error is no longer produced. This error indicates that the greedy pattern rewrite did not converge; it does not mean that a pattern could not be applied.
---
 .../Vector/Transforms/LowerVectorContract.cpp | 153 ++++++++++++------
 ...act-to-outerproduct-matvec-transforms.mlir |  20 ++-
 ...o-outerproduct-transforms-unsupported.mlir |  35 ----
 3 files changed, 118 insertions(+), 90 deletions(-)
 delete mode 100644 mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms-unsupported.mlir

diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
index 6ff4c26763d2478..c9256b00116c07d 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
@@ -426,16 +426,8 @@ struct UnrolledOuterProductGenerator
   }
 
   FailureOr<Value> outerProd(Value lhs, Value rhs, Value res,
-                             VectorType lhsType, int reductionDim,
+                             VectorType lhsType, int reductionSize,
                              std::optional<Value> maybeMask = std::nullopt) {
-    // Unrolling a scalable dimension would be incorrect - bail out.
-    if (lhsType.getScalableDims()[reductionDim])
-      return failure();
-
-    int reductionSize = lhsType.getDimSize(reductionDim);
-    assert(reductionSize > 0 &&
-           "Reduction dim must be a known static size to allow unrolling");
-
     // Incremental support for masking.
     if (mask && !maybeMask.has_value())
       return failure();
@@ -458,6 +450,20 @@ struct UnrolledOuterProductGenerator
     return res;
   }
 
+  /// Helper function for `matmat`, `matvec`, `tmatvec`. Returns the size of
+  /// dimension `reductionDim`. If the dimension is a scalable dimension,
+  /// returns "nullopt".
+  std::optional<int64_t> getReductionSize(VectorType vecType,
+                                          int64_t reductionDim) {
+    // Cannot unroll scalable dimension.
+    if (vecType.getScalableDims()[reductionDim])
+      return std::nullopt;
+    int64_t reductionSize = vecType.getDimSize(reductionDim);
+    assert(reductionSize > 0 &&
+           "Reduction dim must be a known static size to allow unrolling");
+    return reductionSize;
+  }
+
   /// Two outer parallel, one inner reduction (matmat flavor).
   FailureOr<Value> matmat() {
     if (!iters({Par(), Par(), Red()}))
@@ -465,42 +471,70 @@ struct UnrolledOuterProductGenerator
     // Set up the parallel/reduction structure in the right form.
     AffineExpr m, n, k;
     bindDims(rewriter.getContext(), m, n, k);
-    Value transposedMask = t(mask, {2, 0, 1});
+
     // Classical row-major matmul:  Just permute the lhs.
-    if (layout({{m, k}, {k, n}, {m, n}}))
-      return outerProd(t(lhs), rhs, res, lhsType, /*reductionDim=*/1,
-                       transposedMask);
+    if (layout({{m, k}, {k, n}, {m, n}})) {
+      if (auto reductionSize = getReductionSize(lhsType, 1)) {
+        Value tLhs = t(lhs);
+        Value tMask = t(mask, {2, 0, 1});
+        return outerProd(tLhs, rhs, res, lhsType, *reductionSize, tMask);
+      }
+    }
     // TODO: may be better to fail and use some vector<k> -> scalar reduction.
     if (layout({{m, k}, {n, k}, {m, n}})) {
-      Value tlhs = t(lhs);
-      return outerProd(tlhs, t(rhs), res, lhsType, /*reductionDim=*/1,
-                       transposedMask);
+      if (auto reductionSize = getReductionSize(lhsType, 1)) {
+        Value tLhs = t(lhs);
+        Value tRhs = t(rhs);
+        Value tMask = t(mask, {2, 0, 1});
+        return outerProd(tLhs, tRhs, res, lhsType, *reductionSize, tMask);
+      }
     }
     // No need to permute anything.
-    if (layout({{k, m}, {k, n}, {m, n}}))
-      return outerProd(lhs, rhs, res, lhsType, /*reductionDim=*/0,
-                       transposedMask);
+    if (layout({{k, m}, {k, n}, {m, n}})) {
+      if (auto reductionSize = getReductionSize(lhsType, 0)) {
+        Value tMask = t(mask, {2, 0, 1});
+        return outerProd(lhs, rhs, res, lhsType, *reductionSize, tMask);
+      }
+    }
     // Just permute the rhs.
-    if (layout({{k, m}, {n, k}, {m, n}}))
-      return outerProd(lhs, t(rhs), res, lhsType, /*reductionDim=*/0,
-                       transposedMask);
+    if (layout({{k, m}, {n, k}, {m, n}})) {
+      if (auto reductionSize = getReductionSize(lhsType, 0)) {
+        Value tRhs = t(rhs);
+        Value tMask = t(mask, {2, 0, 1});
+        return outerProd(lhs, tRhs, res, lhsType, *reductionSize, tMask);
+      }
+    }
     // Transposed output: swap RHS and LHS.
     // Classical row-major matmul: permute the lhs.
-    if (layout({{m, k}, {k, n}, {n, m}}))
-      return outerProd(rhs, t(lhs), res, lhsType, /*reductionDim=*/1,
-                       transposedMask);
+    if (layout({{m, k}, {k, n}, {n, m}})) {
+      if (auto reductionSize = getReductionSize(lhsType, 1)) {
+        Value tLhs = t(lhs);
+        Value tMask = t(mask, {2, 0, 1});
+        return outerProd(rhs, tLhs, res, lhsType, *reductionSize, tMask);
+      }
+    }
     // TODO: may be better to fail and use some vector<k> -> scalar reduction.
     if (layout({{m, k}, {n, k}, {n, m}})) {
-      Value trhs = t(rhs);
-      return outerProd(trhs, t(lhs), res, lhsType, /*reductionDim=*/1,
-                       transposedMask);
+      if (auto reductionSize = getReductionSize(lhsType, 1)) {
+        Value tRhs = t(rhs);
+        Value tLhs = t(lhs);
+        Value tMask = t(mask, {2, 0, 1});
+        return outerProd(tRhs, tLhs, res, lhsType, *reductionSize, tMask);
+      }
+    }
+    if (layout({{k, m}, {k, n}, {n, m}})) {
+      if (auto reductionSize = getReductionSize(lhsType, 0)) {
+        Value tMask = t(mask, {2, 0, 1});
+        return outerProd(rhs, lhs, res, lhsType, *reductionSize, tMask);
+      }
+    }
+    if (layout({{k, m}, {n, k}, {n, m}})) {
+      if (auto reductionSize = getReductionSize(lhsType, 0)) {
+        Value tRhs = t(rhs);
+        Value tMask = t(mask, {2, 0, 1});
+        return outerProd(tRhs, lhs, res, lhsType, *reductionSize, tMask);
+      }
     }
-    if (layout({{k, m}, {k, n}, {n, m}}))
-      return outerProd(rhs, lhs, res, lhsType, /*reductionDim=*/0,
-                       transposedMask);
-    if (layout({{k, m}, {n, k}, {n, m}}))
-      return outerProd(t(rhs), lhs, res, lhsType, /*reductionDim=*/0,
-                       transposedMask);
     return failure();
   }
 
@@ -514,24 +548,37 @@ struct UnrolledOuterProductGenerator
       return failure();
     AffineExpr m, k;
     bindDims(rewriter.getContext(), m, k);
-    Value transposedMask = t(mask);
 
     // Case mat-vec: transpose.
-    if (layout({{m, k}, {k}, {m}}))
-      return outerProd(t(lhs), rhs, res, lhsType, /*reductionDim=*/1,
-                       transposedMask);
+    if (layout({{m, k}, {k}, {m}})) {
+      if (auto reductionSize = getReductionSize(lhsType, 1)) {
+        Value tLhs = t(lhs);
+        Value tMask = t(mask);
+        return outerProd(tLhs, rhs, res, lhsType, *reductionSize, tMask);
+      }
+    }
     // Case mat-trans-vec: ready to go.
-    if (layout({{k, m}, {k}, {m}}))
-      return outerProd(lhs, rhs, res, lhsType, /*reductionDim=*/0,
-                       transposedMask);
+    if (layout({{k, m}, {k}, {m}})) {
+      if (auto reductionSize = getReductionSize(lhsType, 0)) {
+        Value tMask = t(mask);
+        return outerProd(lhs, rhs, res, lhsType, *reductionSize, tMask);
+      }
+    }
     // Case vec-mat: swap and transpose.
-    if (layout({{k}, {m, k}, {m}}))
-      return outerProd(t(rhs), lhs, res, lhsType, /*reductionDim=*/0,
-                       transposedMask);
+    if (layout({{k}, {m, k}, {m}})) {
+      if (auto reductionSize = getReductionSize(lhsType, 0)) {
+        Value tRhs = t(rhs);
+        Value tMask = t(mask);
+        return outerProd(tRhs, lhs, res, lhsType, *reductionSize, tMask);
+      }
+    }
     // Case vec-mat-trans: swap and ready to go.
-    if (layout({{k}, {k, m}, {m}}))
-      return outerProd(rhs, lhs, res, lhsType, /*reductionDim=*/0,
-                       transposedMask);
+    if (layout({{k}, {k, m}, {m}})) {
+      if (auto reductionSize = getReductionSize(lhsType, 0)) {
+        Value tMask = t(mask);
+        return outerProd(rhs, lhs, res, lhsType, *reductionSize, tMask);
+      }
+    }
     return failure();
   }
 
@@ -547,16 +594,20 @@ struct UnrolledOuterProductGenerator
 
     // Case mat-vec: transpose.
     if (layout({{m, k}, {k}, {m}}))
-      return outerProd(t(lhs), rhs, res, lhsType, /*reductionDim=*/1, mask);
+      if (auto reductionSize = getReductionSize(lhsType, 1))
+        return outerProd(t(lhs), rhs, res, lhsType, *reductionSize, mask);
     // Case mat-trans-vec: ready to go.
     if (layout({{k, m}, {k}, {m}}))
-      return outerProd(lhs, rhs, res, lhsType, /*reductionDim=*/0, mask);
+      if (auto reductionSize = getReductionSize(lhsType, 0))
+        return outerProd(lhs, rhs, res, lhsType, *reductionSize, mask);
     // Case vec-mat: swap and transpose.
     if (layout({{k}, {m, k}, {m}}))
-      return outerProd(t(rhs), lhs, res, lhsType, /*reductionDim=*/0, mask);
+      if (auto reductionSize = getReductionSize(lhsType, 0))
+        return outerProd(t(rhs), lhs, res, lhsType, *reductionSize, mask);
     // Case vec-mat-trans: swap and ready to go.
     if (layout({{k}, {k, m}, {m}}))
-      return outerProd(rhs, lhs, res, lhsType, /*reductionDim=*/0, mask);
+      if (auto reductionSize = getReductionSize(lhsType, 0))
+        return outerProd(rhs, lhs, res, lhsType, *reductionSize, mask);
     return failure();
   }
 
diff --git a/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-matvec-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-matvec-transforms.mlir
index d86c6158bcdf2fe..5c8527f77e3df0e 100644
--- a/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-matvec-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-matvec-transforms.mlir
@@ -320,8 +320,8 @@ func.func @masked_matvec_k_mk_m(%A: vector<4x2xf32>,
                                 %x: vector<2xf32>,
                                 %b: vector<4xf32>,
                                 %mask: vector<4x2xi1>) -> vector<4xf32> {
-  // CHECK:         vector.transpose %[[MASK]]
-  // CHECK:         vector.transpose %[[A]]
+  // CHECK-DAG:     vector.transpose %[[MASK]]
+  // CHECK-DAG:     vector.transpose %[[A]]
   // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<4xf32>, f32 }
   %res = vector.mask %mask {
       vector.contract #matvec_trait_3 %x, %A, %b
@@ -339,8 +339,8 @@ func.func @masked_matvec_k_mk_m_scalable_parallel_dim(%A: vector<[4]x2xf32>,
                                                       %x: vector<2xf32>,
                                                       %b: vector<[4]xf32>,
                                                       %mask: vector<[4]x2xi1>) -> vector<[4]xf32> {
-  // CHECK:         vector.transpose %[[MASK]]
-  // CHECK:         vector.transpose %[[A]]
+  // CHECK-DAG:     vector.transpose %[[MASK]]
+  // CHECK-DAG:     vector.transpose %[[A]]
   // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<[4]xf32>, f32 }
   %res = vector.mask %mask {
       vector.contract #matvec_trait_3 %x, %A, %b
@@ -641,6 +641,18 @@ func.func @masked_tmatvec_k_km_m_scalable_parallel_dim(%A: vector<2x[4]xf32>,
   return %res : vector<[4]xf32>
 }
 
+// Unrolling scalable reduction dim is not supported - bail out
+// CHECK-LABEL: @masked_extract_contract2_scalable_reduction_dim(
+// CHECK:         vector.contract {{.*}} : vector<[2]x[3]xf32>, vector<[3]xf32> into vector<[2]xf32>
+func.func @masked_extract_contract2_scalable_reduction_dim(%arg0: vector<[2]x[3]xf32>,
+                                    %arg1: vector<[3]xf32>,
+                                    %arg2: vector<[2]xf32>,
+                                    %m: vector<[2]x[3]xi1>) -> vector<[2]xf32> {
+  %0 = vector.mask %m { vector.contract #matvec_trait_1 %arg0, %arg1, %arg2
+          : vector<[2]x[3]xf32>, vector<[3]xf32> into vector<[2]xf32> } : vector<[2]x[3]xi1> -> vector<[2]xf32>
+  return %0 : vector<[2]xf32>
+}
+
 // ============================================================================
 //  TD sequence
 // ============================================================================
diff --git a/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms-unsupported.mlir b/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms-unsupported.mlir
deleted file mode 100644
index 954aa13c3e77b37..000000000000000
--- a/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms-unsupported.mlir
+++ /dev/null
@@ -1,35 +0,0 @@
-// RUN: mlir-opt %s --transform-interpreter -allow-unregistered-dialect --split-input-file --verify-diagnostics
-
-#matvec_accesses = [
-  affine_map<(i, j) -> (i, j)>,
-  affine_map<(i, j) -> (j)>,
-  affine_map<(i, j) -> (i)>
-]
-#matvec_trait = {
-  indexing_maps = #matvec_accesses,
-  iterator_types = ["parallel", "reduction"]
-}
-
-// Unrolling scalable reduction dim is not supported - bail out
-
-// expected-error at below {{greedy pattern application failed}}
-func.func @masked_extract_contract2_scalable_reduction_dim(%arg0: vector<[2]x[3]xf32>,
-                                    %arg1: vector<[3]xf32>,
-                                    %arg2: vector<[2]xf32>,
-                                    %m: vector<[2]x[3]xi1>) -> vector<[2]xf32> {
-  %0 = vector.mask %m { vector.contract #matvec_trait %arg0, %arg1, %arg2
-          : vector<[2]x[3]xf32>, vector<[3]xf32> into vector<[2]xf32> } : vector<[2]x[3]xi1> -> vector<[2]xf32>
-  return %0 : vector<[2]xf32>
-}
-
-module attributes {transform.with_named_sequence} {
-  transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
-    %f = transform.structured.match ops{["func.func"]} in %module_op
-      : (!transform.any_op) -> !transform.any_op
-
-    transform.apply_patterns to %f {
-      transform.apply_patterns.vector.lower_contraction lowering_strategy = "outerproduct"
-    } : !transform.any_op
-    transform.yield
-  }
-}



More information about the Mlir-commits mailing list